Skip to content

feat: improve UTxO selection #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 173 additions & 29 deletions cardano_clusterlib/txtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def _organize_utxos_by_id(
return db


def _organize_utxos_by_coin_and_id(
tx_list: list[structs.UTXOData],
) -> dict[str, dict[str, int]]:
"""Organize UTxOs by coin and ID (hash#ix)."""
db: dict[str, dict[str, int]] = {}
for r in tx_list:
utxo_id = f"{r.utxo_hash}#{r.utxo_ix}"
db_rec = db.get(r.coin)
if db_rec is None:
db[r.coin] = {utxo_id: r.amount}
continue
db_rec[utxo_id] = r.amount
return db


def _get_usable_utxos(
address_utxos: list[structs.UTXOData], coins: set[str]
) -> list[structs.UTXOData]:
Expand All @@ -67,31 +82,151 @@ def _get_usable_utxos(
return txins


def _collect_utxos_amount(
utxos: list[structs.UTXOData], amount: int, min_change_value: int
) -> list[structs.UTXOData]:
"""Collect UTxOs so their total combined amount >= `amount`."""
collected_utxos: list[structs.UTXOData] = []
collected_amount = 0
# `_min_change_value` applies only to ADA
amount_plus_change = (
amount + min_change_value if utxos and utxos[0].coin == consts.DEFAULT_COIN else amount
)
for utxo in utxos:
def _pick_coins_from_already_selected_utxos(
coin_txins: dict[str, int],
already_selected_utxos: set[str],
target_amount: int,
target_with_change: int,
) -> tuple[set[str], int, bool]:
"""Pick coins from already selected UTxOs if they have the desired coin.

Args:
coin_txins (dict): A dictionary of coin UTxOs.
already_selected_utxos (set): A set of already selected UTxOs (for different coins).
target_amount (int): The desired amount.
target_with_change (int): The desired amount with minimal change.

Returns:
tuple: A tuple with selected UTxO IDs, accumulated amount and a bool indicating if the
desired amount was met.
"""
picked_utxos: set[str] = set()
accumulated_amount = 0

# See if the coin exists in UTxOs that were already selected
for utxo_id in already_selected_utxos:
utxo_amount = coin_txins.get(utxo_id)
if utxo_amount is None:
continue
accumulated_amount += utxo_amount

# If we were able to collect exact amount, no change is needed
if collected_amount == amount:
if accumulated_amount == target_amount:
break
# Make sure the change is higher than `_min_change_value`
if collected_amount >= amount_plus_change:
if accumulated_amount >= target_with_change:
break
collected_utxos.append(utxo)
collected_amount += utxo.amount
else:
return picked_utxos, accumulated_amount, False

return picked_utxos, accumulated_amount, True


def _pick_utxos_with_defragmentation(
utxos: list[tuple[str, int]],
target_amount: int,
target_with_change: int,
accumulated_amount: int,
) -> tuple[set[str], int, bool]:
"""Pick UTxOs to meet or exceed the target amount while prioritizing defragmentation.

Args:
utxos (list of tuple): A list of tuples (utxo_id, coin_amount).
target_amount (int): The desired amount.
target_with_change (int): The desired amount with minimal change.
accumulated_amount (int): The accumulated amount.

Returns:
tuple: A tuple with selected UTxO IDs, accumulated amount and a bool indicating if the
desired amount was met.
"""
# Sort UTxOs by amount in ascending order
sorted_utxos = sorted(enumerate(utxos), key=lambda x: x[1][1]) # Keep original indices
selected_indices = set()
picked_utxos = set()

# Step 1: Select up to 10 smallest UTxOs
for i, (utxo_id, coin_amount) in sorted_utxos[:10]:
picked_utxos.add(utxo_id)
selected_indices.add(i)
accumulated_amount += coin_amount

# If we were able to collect exact amount, no change is needed
if accumulated_amount == target_amount:
return picked_utxos, accumulated_amount, True
# Make sure the change is higher than `_min_change_value`
if accumulated_amount >= target_with_change:
return picked_utxos, accumulated_amount, True

# Step 2: If target is not met, select UTxO closest to remaining amount
while accumulated_amount < target_with_change:
# If we were able to collect exact amount, no change is needed
if accumulated_amount == target_amount:
return picked_utxos, accumulated_amount, True

# We target exact amount, but if we are already over it, we need at least additional
# `_min_change_value` for change.
if accumulated_amount > target_amount:
remaining_amount = target_with_change - accumulated_amount
else:
remaining_amount = target_amount - accumulated_amount

# Find the index of the UTxO closest to the remaining amount
closest_index = min(
(i for i, _ in sorted_utxos if i not in selected_indices),
key=lambda i: abs(utxos[i][1] - remaining_amount),
default=None,
)

# If all UTxOs have been considered, the target was not met
if closest_index is None:
return picked_utxos, accumulated_amount, False

# Select the closest UTxO
utxo_id, coin_amount = utxos[closest_index]
picked_utxos.add(utxo_id)
selected_indices.add(closest_index)
accumulated_amount += coin_amount

return picked_utxos, accumulated_amount, True


def _select_utxos_per_coin(
coin_txins: dict[str, int],
coin: str,
target_amount: int,
target_with_change: int,
already_selected_utxos: set[str],
) -> set[str]:
"""Select UTxOs for a given coin so their total combined amount >= `amount`."""
selected_utxos, accumulated_amount, target_met = _pick_coins_from_already_selected_utxos(
coin_txins=coin_txins,
already_selected_utxos=already_selected_utxos,
target_amount=target_amount,
target_with_change=target_with_change,
)

return collected_utxos
# Pick more UTxOs if the amount is not satisfied yet
if not target_met:
ids_and_amounts = [(i, a) for i, a in coin_txins.items() if i not in already_selected_utxos]
more_utxos, _, target_met = _pick_utxos_with_defragmentation(
utxos=ids_and_amounts,
target_amount=target_amount,
target_with_change=target_with_change,
accumulated_amount=accumulated_amount,
)
selected_utxos.update(more_utxos)

if not target_met:
LOGGER.warning(
f"Could not meet target amount {target_amount} for coin '{coin}' with the given UTxOs."
)

return selected_utxos


def _select_utxos(
txins_db: dict[str, list[structs.UTXOData]],
txins_by_coin_and_id: dict[str, dict[str, int]],
txouts_passed_db: dict[str, list[structs.TxOut]],
txouts_mint_db: dict[str, list[structs.TxOut]],
fee: int,
Expand All @@ -107,8 +242,8 @@ def _select_utxos(
utxo_ids: set[str] = set()

# Iterate over coins both in txins and txouts
for coin in set(txins_db).union(txouts_passed_db).union(txouts_mint_db):
coin_txins = txins_db.get(coin) or []
for coin in set(txins_by_coin_and_id).union(txouts_passed_db).union(txouts_mint_db):
coin_txins = txins_by_coin_and_id.get(coin) or {}
coin_txouts = txouts_passed_db.get(coin) or []

total_output_amount = functools.reduce(lambda x, y: x + y.amount, coin_txouts, 0)
Expand All @@ -117,26 +252,35 @@ def _select_utxos(
# The value "-1" means all available funds
max_index = [idx for idx, val in enumerate(coin_txouts) if val.amount == -1]
if max_index:
utxo_ids.update(f"{rec.utxo_hash}#{rec.utxo_ix}" for rec in coin_txins)
utxo_ids.update(r for r in coin_txins)
continue

tx_fee = max(1, fee)
funds_needed = total_output_amount + tx_fee + deposit + treasury_donation
total_withdrawals_amount = functools.reduce(lambda x, y: x + y.amount, withdrawals, 0)
# Fee needs an input, even if withdrawal would cover all needed funds
input_funds_needed = max(funds_needed - total_withdrawals_amount, tx_fee)
# `_min_change_value` applies only to ADA
target_with_change = input_funds_needed + min_change_value
else:
coin_txouts_minted = txouts_mint_db.get(coin) or []
total_minted_amount = functools.reduce(lambda x, y: x + y.amount, coin_txouts_minted, 0)
# In case of token burning, `total_minted_amount` might be negative.
# Try to collect enough funds to satisfy both token burning and token
# transfers, even though there might be an overlap.
input_funds_needed = total_output_amount - total_minted_amount

filtered_coin_utxos = _collect_utxos_amount(
utxos=coin_txins, amount=input_funds_needed, min_change_value=min_change_value
)
utxo_ids.update(f"{rec.utxo_hash}#{rec.utxo_ix}" for rec in filtered_coin_utxos)
target_with_change = input_funds_needed

if input_funds_needed:
utxo_ids.update(
_select_utxos_per_coin(
coin_txins=txins_by_coin_and_id.get(coin) or {},
coin=coin,
target_amount=input_funds_needed,
target_with_change=target_with_change,
already_selected_utxos=utxo_ids,
)
)

return utxo_ids

Expand Down Expand Up @@ -536,11 +680,11 @@ def _get_tx_ins_outs(
msg = "No input UTxO."
raise exceptions.CLIError(msg)

txins_db_all: dict[str, list[structs.UTXOData]] = _organize_tx_ins_outs_by_coin(txins_all)
txins_by_coin_and_id = _organize_utxos_by_coin_and_id(txins_all)

# All output coins, except those minted by this transaction, need to be present in
# transaction inputs
if not set(outcoins_passed).difference(txouts_mint_db).issubset(txins_db_all):
if not set(outcoins_passed).difference(txouts_mint_db).issubset(txins_by_coin_and_id):
msg = "Not all output coins are present in input UTxOs."
raise exceptions.CLIError(msg)

Expand All @@ -555,11 +699,11 @@ def _get_tx_ins_outs(
if txins:
# Don't touch txins that were passed to the function
txins_filtered = txins_all
txins_db_filtered = txins_db_all
txins_db_filtered = _organize_tx_ins_outs_by_coin(txins_all)
else:
# Select only UTxOs that are needed to satisfy all outputs, deposits and fee
selected_utxo_ids = _select_utxos(
txins_db=txins_db_all,
txins_by_coin_and_id=txins_by_coin_and_id,
txouts_passed_db=txouts_passed_db,
txouts_mint_db=txouts_mint_db,
fee=fee,
Expand Down
Loading