Skip to content
154 changes: 154 additions & 0 deletions slither/analyses/data_flow/analyses/reentrancy/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import Optional, Set, Union

from slither.analyses.data_flow.analyses.reentrancy.analysis.domain import (
DomainVariant,
ReentrancyDomain,
)
from slither.analyses.data_flow.analyses.reentrancy.core.state import State
from slither.analyses.data_flow.engine.analysis import Analysis
from slither.analyses.data_flow.engine.direction import Direction, Forward
from slither.analyses.data_flow.engine.domain import Domain
from slither.core.cfg.node import Node
from slither.core.declarations.function import Function
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations.event_call import EventCall
from slither.slithir.operations.high_level_call import HighLevelCall
from slither.slithir.operations.internal_call import InternalCall
from slither.slithir.operations.low_level_call import LowLevelCall
from slither.slithir.operations.operation import Operation
from slither.slithir.operations.send import Send
from slither.slithir.operations.transfer import Transfer


class ReentrancyAnalysis(Analysis):
def __init__(self):
self._direction = Forward()

def domain(self) -> Domain:
return ReentrancyDomain.bottom()

def direction(self) -> Direction:
return self._direction

def bottom_value(self) -> Domain:
return ReentrancyDomain.bottom()

def transfer_function(self, node: Node, domain: ReentrancyDomain, operation: Operation):
self.transfer_function_helper(node, domain, operation, private_functions_seen=set())

def transfer_function_helper(
self,
node: Node,
domain: ReentrancyDomain,
operation: Operation,
private_functions_seen: Optional[Set[Function]] = None,
):
if private_functions_seen is None:
private_functions_seen = set()

if domain.variant == DomainVariant.BOTTOM:
domain.variant = DomainVariant.STATE
domain.state = State()

self._analyze_operation_by_type(operation, domain, node, private_functions_seen)

def _analyze_operation_by_type(
self,
operation: Operation,
domain: ReentrancyDomain,
node: Node,
private_functions_seen: Set[Function],
):
if isinstance(operation, EventCall):
self._handle_event_call_operation(operation, domain)
elif isinstance(operation, InternalCall):
self._handle_internal_call_operation(operation, domain, private_functions_seen)
elif isinstance(operation, (HighLevelCall, LowLevelCall, Transfer, Send)):
self._handle_abi_call_contract_operation(operation, domain, node)

self._handle_storage(domain, node)
self._update_writes_after_calls(domain, node)

def _handle_storage(self, domain: ReentrancyDomain, node: Node):
# Track state reads
for var in node.state_variables_read:
if isinstance(var, StateVariable) and var.is_stored:
domain.state.add_read(var, node)
# Track state writes
for var in node.state_variables_written:
if isinstance(var, StateVariable) and var.is_stored:
domain.state.add_written(var, node)

def _update_writes_after_calls(self, domain: ReentrancyDomain, node: Node):
# Writes after any external call
if node in domain.state.calls:
for var_name, write_nodes in domain.state.written.items():
for wn in write_nodes:
domain.state.add_write_after_call(var_name, wn)
# Writes after ETH-sending calls
if node in domain.state.send_eth:
for var_name, write_nodes in domain.state.written.items():
for wn in write_nodes:
domain.state.add_write_after_call(var_name, wn)

def _handle_internal_call_operation(
self,
operation: InternalCall,
domain: ReentrancyDomain,
private_functions_seen: Set[Function],
):
function = operation.function
if not isinstance(function, Function) or function in private_functions_seen:
return

private_functions_seen.add(function)
for node in function.nodes:
for internal_operation in node.irs:
if isinstance(internal_operation, (HighLevelCall, LowLevelCall, Transfer, Send)):
continue
self.transfer_function_helper(
node,
domain,
internal_operation,
private_functions_seen,
)
# Mark cross-function reentrancy for written variables
for var_name in domain.state.written.keys():
domain.state.add_cross_function(var_name, function)

def _handle_abi_call_contract_operation(
self,
operation: Union[LowLevelCall, HighLevelCall, Send, Transfer],
domain: ReentrancyDomain,
node: Node,
):
# Track all external calls - avoid duplicates
if operation.node not in domain.state.calls.get(node, set()):
domain.state.add_call(node, operation.node)

# Track variables read prior to this call
for var_name in domain.state.reads.keys():
domain.state.add_reads_prior_calls(node, var_name)

# Track external calls that send ETH - avoid duplicates
if operation.can_send_eth:
if operation.node not in domain.state.send_eth.get(node, set()):
domain.state.add_send_eth(node, operation.node)

def _handle_event_call_operation(self, operation: EventCall, domain: ReentrancyDomain):
# Track events and propagate previous external calls
# Only propagate calls that haven't already been propagated to this event node
existing_calls = domain.state.calls.get(operation.node, set())

# Collect all calls to add before modifying the dictionary
calls_to_add = []
for calls_set in domain.state.calls.values():
for call_node in calls_set:
if call_node not in existing_calls:
calls_to_add.append(call_node)

# Add all collected calls
for call_node in calls_to_add:
domain.state.add_call(operation.node, call_node)

domain.state.add_event(operation, operation.node)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum, auto
from typing import Optional

from slither.analyses.data_flow.analyses.reentrancy.core.state import State
from slither.analyses.data_flow.engine.domain import Domain


class DomainVariant(Enum):
BOTTOM = auto()
TOP = auto()
STATE = auto()


class ReentrancyDomain(Domain):
def __init__(self, variant: DomainVariant, state: Optional[State] = None):
self.variant = variant
self.state = state or State()

@classmethod
def bottom(cls) -> "ReentrancyDomain":

Check warning on line 20 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0221: Number of parameters was 1 in 'Domain.bottom' and is now 1 in overriding 'ReentrancyDomain.bottom' method (arguments-differ)
return cls(DomainVariant.BOTTOM)

@classmethod
def top(cls) -> "ReentrancyDomain":

Check warning on line 24 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0221: Number of parameters was 1 in 'Domain.top' and is now 1 in overriding 'ReentrancyDomain.top' method (arguments-differ)
return cls(DomainVariant.TOP)

@classmethod
def with_state(cls, info: State) -> "ReentrancyDomain":
return cls(DomainVariant.STATE, info)

def join(self, other: "ReentrancyDomain") -> bool:
if self.variant == DomainVariant.TOP or other.variant == DomainVariant.BOTTOM:
return False

if self.variant == DomainVariant.BOTTOM and other.variant == DomainVariant.STATE:
self.variant = DomainVariant.STATE
self.state = other.state.deep_copy()
self.state.written.clear()
self.state.events.clear()
self.state.writes_after_calls.clear()
self.state.cross_function.clear()
return True

if self.variant == DomainVariant.STATE and other.variant == DomainVariant.STATE:

Check warning on line 44 in slither/analyses/data_flow/analyses/reentrancy/analysis/domain.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

R1705: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it (no-else-return)
if self.state == other.state:
return False

self.state.send_eth.update(other.state.send_eth)
self.state.calls.update(other.state.calls)
self.state.reads.update(other.state.reads)
self.state.reads_prior_calls.update(other.state.reads_prior_calls)
self.state.safe_send_eth.update(other.state.safe_send_eth)
self.state.writes_after_calls.update(other.state.writes_after_calls)
self.state.cross_function.update(other.state.cross_function)
return True

else:
self.variant = DomainVariant.TOP

return True
162 changes: 162 additions & 0 deletions slither/analyses/data_flow/analyses/reentrancy/core/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import copy
from collections import defaultdict
from typing import Dict, Set

from slither.core.cfg.node import Node
from slither.core.declarations.function import Function
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations.event_call import EventCall


class State:

Check warning on line 11 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

R0902: Too many instance attributes (9/7) (too-many-instance-attributes)
def __init__(self):
self._send_eth: Dict[Node, Set[Node]] = defaultdict(set)
self._safe_send_eth: Dict[Node, Set[Node]] = defaultdict(set)
self._calls: Dict[Node, Set[Node]] = defaultdict(set)
self._reads: Dict[str, Set[Node]] = defaultdict(set)
self._reads_prior_calls: Dict[Node, Set[str]] = defaultdict(set)
self._events: Dict[EventCall, Set[Node]] = defaultdict(set)
self._written: Dict[str, Set[Node]] = defaultdict(set)
self.writes_after_calls: Dict[str, Set[Node]] = defaultdict(set)
self.cross_function: Dict[StateVariable, Set[Function]] = defaultdict(set)

# -------------------- Add methods --------------------
def add_call(self, node: Node, call_node: Node):
self._calls[node].add(call_node)

def add_send_eth(self, node: Node, call_node: Node):
self._send_eth[node].add(call_node)

def add_safe_send_eth(self, node: Node, call_node: Node):
self._safe_send_eth[node].add(call_node)

def add_written(self, var: StateVariable, node: Node):
# Ensure the canonical name exists and is not None
if var.canonical_name is not None:
# Ensure the key exists in the defaultdict
if var.canonical_name not in self._written:
self._written[var.canonical_name] = set()
self._written[var.canonical_name].add(node)

def add_read(self, var: StateVariable, node: Node):
# Ensure the canonical name exists and is not None
if var.canonical_name is not None:
# Ensure the key exists in the defaultdict
if var.canonical_name not in self._reads:
self._reads[var.canonical_name] = set()
self._reads[var.canonical_name].add(node)

def add_reads_prior_calls(self, node: Node, var_name: str):
self._reads_prior_calls[node].add(var_name)

def add_write_after_call(self, var_name: str, node: Node):
self.writes_after_calls[var_name].add(node)

def add_cross_function(self, var: StateVariable, function: Function):
self.cross_function[var].add(function)

def add_event(self, event: EventCall, node: Node):
self._events[event].add(node)

# -------------------- Properties --------------------
@property
def send_eth(self) -> Dict[Node, Set[Node]]:
return self._send_eth

@property
def safe_send_eth(self) -> Dict[Node, Set[Node]]:
return self._safe_send_eth

@property
def all_eth_calls(self) -> Dict[Node, Set[Node]]:
result = defaultdict(set)
for node, calls in self._send_eth.items():
result[node].update(calls)
for node, calls in self._safe_send_eth.items():
result[node].update(calls)
return result

@property
def calls(self) -> Dict[Node, Set[Node]]:
return self._calls

@property
def reads(self) -> Dict[str, Set[Node]]:
return self._reads

@property
def written(self) -> Dict[str, Set[Node]]:
return self._written

@property
def reads_prior_calls(self) -> Dict[Node, Set[str]]:
return self._reads_prior_calls

@property
def events(self) -> Dict[EventCall, Set[Node]]:
return self._events

# -------------------- Utilities --------------------
def __eq__(self, other):
if not isinstance(other, State):
return False
return (
self._send_eth == other._send_eth
and self._safe_send_eth == other._safe_send_eth
and self._calls == other._calls
and self._reads == other._reads
and self._reads_prior_calls == other._reads_prior_calls
and self._events == other._events
and self._written == other._written
and self.writes_after_calls == other.writes_after_calls
and self.cross_function == other.cross_function
)

def __hash__(self):
return hash(
(
frozenset(self._send_eth.items()),
frozenset(self._safe_send_eth.items()),
frozenset(self._calls.items()),
frozenset(self._reads.items()),
frozenset(self._reads_prior_calls.items()),
frozenset(self._events.items()),
frozenset(self._written.items()),
frozenset((k, frozenset(v)) for k, v in self.writes_after_calls.items()),
frozenset((k, frozenset(v)) for k, v in self.cross_function.items()),
)
)

def __str__(self):
return (
f"State(\n"
f" send_eth: {len(self._send_eth)} items,\n"
f" safe_send_eth: {len(self._safe_send_eth)} items,\n"
f" calls: {len(self._calls)} items,\n"
f" reads: {len(self._reads)} items,\n"
f" reads_prior_calls: {len(self._reads_prior_calls)} items,\n"
f" events: {len(self._events)} items,\n"
f" written: {len(self._written)} items,\n"
f" writes_after_calls: {len(self.writes_after_calls)} items,\n"
f" cross_function: {len(self.cross_function)} items,\n"
f")"
)

def deep_copy(self) -> "State":
new_state = State()
# Use shallow copy for Node objects to avoid circular reference issues

new_state._send_eth.update({k: v.copy() for k, v in self._send_eth.items()})

Check warning on line 149 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _send_eth of a client class (protected-access)
new_state._safe_send_eth.update({k: v.copy() for k, v in self._safe_send_eth.items()})

Check warning on line 150 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _safe_send_eth of a client class (protected-access)
new_state._calls.update({k: v.copy() for k, v in self._calls.items()})

Check warning on line 151 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _calls of a client class (protected-access)
new_state._reads.update({k: v.copy() for k, v in self._reads.items()})

Check warning on line 152 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _reads of a client class (protected-access)
new_state._reads_prior_calls.update(

Check warning on line 153 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _reads_prior_calls of a client class (protected-access)
{k: v.copy() for k, v in self._reads_prior_calls.items()}
)
new_state._events.update({k: v.copy() for k, v in self._events.items()})

Check warning on line 156 in slither/analyses/data_flow/analyses/reentrancy/core/state.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _events of a client class (protected-access)
new_state._written.update({k: v.copy() for k, v in self._written.items()})
new_state.writes_after_calls.update(
{k: v.copy() for k, v in self.writes_after_calls.items()}
)
new_state.cross_function.update({k: v.copy() for k, v in self.cross_function.items()})
return new_state
Loading
Loading