36 changes: 36 additions & 0 deletions safe_transaction_service/history/models.py
@@ -1322,6 +1322,42 @@ def safes_pending_to_be_processed(self) -> QuerySet[ChecksumAddress]:
self.not_processed().values_list("internal_tx___from", flat=True).distinct()
)

def safes_pending_to_be_processed_iterator(
self, batch_size: int = 1000, start_after_id: int = 0
) -> Iterator[str]:
"""
Generator that yields `_from` addresses from unprocessed InternalTxDecoded entries, fetching ids in keyset-paginated batches.

:param batch_size: Number of rows to fetch per page
:param start_after_id: Start after this internal_tx_id
:return: Iterator over `_from` addresses
"""
last_seen_id = start_after_id

while True:
# Get next batch of internal_tx_ids from InternalTxDecoded
decoded_ids = list(
InternalTxDecoded.objects.filter(
processed=False, internal_tx_id__gt=last_seen_id
)
.order_by("internal_tx_id")
.values_list("internal_tx_id", flat=True)[:batch_size]
)

if not decoded_ids:
break # No more results

# Fetch corresponding _from addresses
for _from in (
InternalTx.objects.filter(id__in=decoded_ids)
.values_list("_from", flat=True)
.iterator(chunk_size=100)
):
yield _from

# Update last_seen_id for keyset pagination
last_seen_id = decoded_ids[-1]


class InternalTxDecoded(models.Model):
objects = InternalTxDecodedManager.from_queryset(InternalTxDecodedQuerySet)()
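For context, a minimal sketch of how the new iterator might be consumed, e.g. from a Django shell; the batch_size value and the consumer-side set are illustrative and not part of this diff:

# Hypothetical usage sketch (not part of this PR): iterate pending Safes lazily
from safe_transaction_service.history.models import InternalTxDecoded

unique_safes = set()
for safe_address in InternalTxDecoded.objects.safes_pending_to_be_processed_iterator(
    batch_size=500
):
    # The generator can yield the same address more than once, because several
    # unprocessed decoded txs may belong to the same Safe; deduplicate here
    unique_safes.add(safe_address)
print(f"{len(unique_safes)} Safes pending")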
21 changes: 14 additions & 7 deletions safe_transaction_service/history/tasks.py
@@ -263,21 +263,28 @@ def process_decoded_internal_txs_task(self) -> Optional[int]:
"Start process decoded internal txs for every Safe in a different task"
)
count = 0
redis = get_redis()
redis_key = f"safes_being_processed:{process_decoded_internal_txs_task.request.id}"
for (
safe_to_process
) in (
InternalTxDecoded.objects.safes_pending_to_be_processed().iterator()
):
process_decoded_internal_txs_for_safe_task.delay(
safe_to_process, reindex_master_copies=True
)
count += 1
) in InternalTxDecoded.objects.safes_pending_to_be_processed_iterator():
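# Only enqueue Safes that are not already tracked in this run's Redis set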
if not redis.sismember(redis_key, safe_to_process):
logger.debug(f"Sending to process {safe_to_process}")
process_decoded_internal_txs_for_safe_task.delay(
safe_to_process, reindex_master_copies=True
)
redis.sadd(redis_key, safe_to_process)
if count == 0: # Configure TTL on processing start
redis.expire(redis_key, LOCK_TIMEOUT)
count += 1

(
logger.info("%d Safes to process", count)
if count
else logger.info("No Safes to process")
)
logger.info("Clean redis key: %s", redis_key)
redis.unlink(redis_key)
return count


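The Redis bookkeeping above is a standard set-based deduplication pattern: a membership check before enqueueing, a TTL as a safety net, and a non-blocking delete at the end. A minimal standalone sketch, assuming a plain redis-py client and a hypothetical enqueue callable (names here are illustrative, not the service's API):

import redis

def enqueue_unique(client: redis.Redis, key: str, items, ttl_seconds: int, enqueue) -> int:
    """Enqueue every item at most once, tracking already-seen items in a Redis set."""
    count = 0
    for item in items:
        if not client.sismember(key, item):
            enqueue(item)
            client.sadd(key, item)
            if count == 0:
                # TTL acts as a safety net if the final cleanup never runs
                client.expire(key, ttl_seconds)
            count += 1
    client.unlink(key)  # Non-blocking delete once all items were dispatched
    return count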
51 changes: 51 additions & 0 deletions safe_transaction_service/history/tests/test_tasks.py
@@ -412,3 +412,54 @@ def test_delete_expired_delegates_task(self):
delegator=safe_contract_delegate_expected_to_be_deleted.delegator,
).exists()
)

@patch(
"safe_transaction_service.history.tasks.process_decoded_internal_txs_for_safe_task.delay"
)
def test_process_decoded_internal_txs_task_with_batch(self, mock_process_safe_task):
"""
Test that batch processing of decoded internal txs works correctly: each unique Safe is enqueued exactly once, even when duplicate entries exist.
"""
# Create 3 different Safes
safe_addresses = [Account.create().address for _ in range(3)]

internal_decoded_txs = [
InternalTxDecodedFactory(internal_tx___from=address)
for address in safe_addresses
]

internal_decoded_txs.extend(
[
InternalTxDecodedFactory(
internal_tx___from=safe_addresses[0], # Duplicate for first address
),
InternalTxDecodedFactory(
internal_tx___from=safe_addresses[0], # Duplicate for first address
),
]
)

with self.settings(
PROCESSING_ALL_SAFES_TOGETHER=False,
ETH_INTERNAL_TX_DECODED_PROCESS_BATCH=2, # Smaller than number of Safes (3)
):
with self.assertLogs(logger=task_logger) as cm:
result = process_decoded_internal_txs_task.delay().result
# Should process 3 unique Safes
self.assertEqual(result, 3)
self.assertIn(
"Start process decoded internal txs for every Safe in a different task",
cm.output[0],
)
self.assertIn("3 Safes to process", cm.output[1])

# Verify that process_decoded_internal_txs_for_safe_task was called exactly 3 times
self.assertEqual(mock_process_safe_task.call_count, 3)

# Verify each Safe was called exactly once
called_addresses = [
call[0][0] for call in mock_process_safe_task.call_args_list
]
self.assertEqual(called_addresses.count(safe_addresses[0]), 1)
self.assertEqual(called_addresses.count(safe_addresses[1]), 1)
self.assertEqual(called_addresses.count(safe_addresses[2]), 1)
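A possible extra assertion, sketched under the assumption that get_redis() points at the test Redis instance; it relies on the redis.unlink() cleanup added in tasks.py, and the key prefix is taken from that code:

# Hypothetical follow-up check: the tracking set should be gone after the task returns
redis = get_redis()
self.assertEqual(len(redis.keys("safes_being_processed:*")), 0)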