diff --git a/docs/HOWTO.rst b/docs/HOWTO.rst index 1b8d523dd..4c834701f 100644 --- a/docs/HOWTO.rst +++ b/docs/HOWTO.rst @@ -28,17 +28,20 @@ functions. For example, `dash_hash`_ is required for DASH. Scrypt coins require a Python interpreter compiled and/or linked with OpenSSL 1.1.0 or higher. -You **must** be running a non-pruning bitcoin daemon with:: +You **must** be running a non-pruning bitcoin daemon. +It is also recommended to have:: txindex=1 -set in its configuration file. If you have an existing installation -of bitcoind and have not previously set this you will need to reindex -the blockchain with:: +set in its configuration file, for better performance when serving many sessions. +If you have an existing installation of bitcoind and have not previously set this +but you do now, you will need to reindex the blockchain with:: bitcoind -reindex which can take some time. +If you intend to use a bitcoind without txindex enabled, you need to configure the +`DAEMON_HAS_TXINDEX` :ref:`environment variable `. While not a requirement for running ElectrumX, it is intended to be run with supervisor software such as Daniel Bernstein's diff --git a/docs/environment.rst b/docs/environment.rst index d16559a64..a59fde9a5 100644 --- a/docs/environment.rst +++ b/docs/environment.rst @@ -281,6 +281,15 @@ These environment variables are optional: version string. For example to drop versions from 1.0 to 1.2 use the regex ``1\.[0-2]\.\d+``. +.. envvar:: DAEMON_HAS_TXINDEX + + Set this environment variable to empty if the connected bitcoind + does not have txindex enabled. Defaults to True (has txindex). + We rely on bitcoind to lookup arbitrary txs, regardless of this setting. + Without txindex, tx lookup works as in `bitcoind PR #10275`_. + Note that performance when serving many sessions is better with txindex + (but initial sync should be unaffected). + Resource Usage Limits ===================== @@ -504,3 +513,4 @@ your available physical RAM: .. _lib/coins.py: https://github.com/spesmilo/electrumx/blob/master/src/electrumx/lib/coins.py .. _uvloop: https://pypi.python.org/pypi/uvloop +.. _bitcoind PR #10275: https://github.com/bitcoin/bitcoin/pull/10275 diff --git a/electrumx_compact_history b/electrumx_compact_history deleted file mode 100755 index 8104be127..000000000 --- a/electrumx_compact_history +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 ; mode: python -*- -import os -import sys - - -if __name__ == '__main__': - src_dir = os.path.join(os.path.dirname(__file__), "src") - sys.path.insert(0, src_dir) - from electrumx.cli.electrumx_compact_history import main - sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml index 5a91cf79b..606bb6a2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ Repository = "https://github.com/spesmilo/electrumx" [project.scripts] electrumx_server = "electrumx.cli.electrumx_server:main" electrumx_rpc = "electrumx.cli.electrumx_rpc:main" -electrumx_compact_history = "electrumx.cli.electrumx_compact_history:main" [tool.setuptools.dynamic] version = { attr = 'electrumx.__version__' } diff --git a/src/electrumx/cli/electrumx_compact_history.py b/src/electrumx/cli/electrumx_compact_history.py deleted file mode 100644 index 66cc81312..000000000 --- a/src/electrumx/cli/electrumx_compact_history.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (c) 2017, Neil Booth -# -# All rights reserved. 
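For illustration, an ElectrumX environment for a bitcoind running without ``txindex`` might look roughly like the following (envdir/systemd-style; the URL, path and credentials are placeholders)::

    # bitcoind was started without txindex=1, so tell ElectrumX not to assume it
    DAEMON_HAS_TXINDEX=
    DAEMON_URL=http://rpcuser:rpcpassword@127.0.0.1:8332/
    DB_DIRECTORY=/path/to/electrumx-db
    COIN=Bitcoin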
-# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - -'''Script to compact the history database. This should save space and -will reset the flush counter to a low number, avoiding overflow when -the flush count reaches 65,536. - -This needs to lock the database so ElectrumX must not be running - -shut it down cleanly first. - -It is recommended you run this script with the same environment as -ElectrumX. However, it is intended to be runnable with just -DB_DIRECTORY and COIN set (COIN defaults as for ElectrumX). - -If you use daemon tools, you might run this script like so: - - envdir /path/to/the/environment/directory ./compact_history.py - -Depending on your hardware, this script may take up to 6 hours to -complete; it logs progress regularly. - -Compaction can be interrupted and restarted harmlessly and will pick -up where it left off. However, if you restart ElectrumX without -running the compaction to completion, it will not benefit and -subsequent compactions will restart from the beginning. -''' - -import asyncio -import logging -import sys -import traceback -from os import environ - -from electrumx import Env -from electrumx.server.db import DB - - -async def compact_history(): - if sys.version_info < (3, 10): - raise RuntimeError('Python >= 3.10 is required to run ElectrumX') - - environ['DAEMON_URL'] = '' # Avoid Env erroring out - env = Env() - db = DB(env) - await db.open_for_compacting() - - assert not db.first_sync - history = db.history - # Continue where we left off, if interrupted - if history.comp_cursor == -1: - history.comp_cursor = 0 - - history.comp_flush_count = max(history.comp_flush_count, 1) - limit = 8 * 1000 * 1000 - - while history.comp_cursor != -1: - history._compact_history(limit) - - # When completed also update the UTXO flush count - db.set_flush_count(history.flush_count) - - -def main(): - logging.basicConfig(level=logging.INFO) - logging.info('Starting history compaction...') - loop = asyncio.get_event_loop() - try: - loop.run_until_complete(compact_history()) - except Exception: - traceback.print_exc() - logging.critical('History compaction terminated abnormally') - else: - logging.info('History compaction complete') - - -if __name__ == '__main__': - main() diff --git a/src/electrumx/lib/coins.py b/src/electrumx/lib/coins.py index baa32d84f..e05562f6a 100644 --- a/src/electrumx/lib/coins.py +++ b/src/electrumx/lib/coins.py @@ -234,7 +234,7 @@ def privkey_WIF(cls, privkey_bytes, compressed): return cls.ENCODE_CHECK(payload) @classmethod - def header_hash(cls, header): + def header_hash(cls, header: bytes) -> bytes: '''Given a header return hash''' return double_sha256(header) diff --git a/src/electrumx/lib/tx.py b/src/electrumx/lib/tx.py index e1fe87a18..40dfe71ae 100644 --- a/src/electrumx/lib/tx.py +++ b/src/electrumx/lib/tx.py @@ -113,6 +113,13 @@ def serialize(self): )) +@dataclass(kw_only=True, slots=True) +class TXOSpendStatus: + prev_height: Optional[int] # block height TXO is mined at. None if the outpoint never existed + spender_txhash: bytes = None + spender_height: int = None + + class Deserializer: '''Deserializes blocks into transactions. 
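A minimal illustrative sketch of how the three ``TXOSpendStatus`` fields combine, mirroring how ``DB.fs_spender_for_txo`` further down this diff populates them; the stand-in dataclass and the ``describe`` helper are purely hypothetical::

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(kw_only=True, slots=True)
    class TXOSpendStatus:                        # local stand-in for electrumx.lib.tx
        prev_height: Optional[int]               # None: outpoint never existed in the chain
        spender_txhash: Optional[bytes] = None   # set only once the TXO has been spent
        spender_height: Optional[int] = None

    def describe(status: TXOSpendStatus) -> str:
        if status.prev_height is None:
            return 'outpoint not found in the chain'
        if status.spender_txhash is None:
            return f'unspent; funding tx confirmed at height {status.prev_height}'
        return f'spent by {status.spender_txhash.hex()} at height {status.spender_height}'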
diff --git a/src/electrumx/lib/util.py b/src/electrumx/lib/util.py index 11236cd0e..8d8ef8b94 100644 --- a/src/electrumx/lib/util.py +++ b/src/electrumx/lib/util.py @@ -35,7 +35,7 @@ import sys from collections.abc import Container, Mapping from struct import Struct -from typing import Set, Any +from typing import Set, Any, Optional import aiorpcx @@ -166,7 +166,7 @@ def chunks(items, size): yield items[i: i + size] -def resolve_limit(limit): +def resolve_limit(limit: Optional[int]) -> int: if limit is None or limit < 0: return -1 assert isinstance(limit, int) @@ -333,6 +333,7 @@ def is_hex_str(text: Any) -> bool: struct_le_Q = Struct('H') struct_be_I = Struct('>I') +struct_be_Q = Struct('>Q') structB = Struct('B') unpack_le_int32_from = struct_le_i.unpack_from @@ -346,6 +347,7 @@ def is_hex_str(text: Any) -> bool: unpack_le_uint32 = struct_le_I.unpack unpack_le_uint64 = struct_le_Q.unpack unpack_be_uint32 = struct_be_I.unpack +unpack_be_uint64 = struct_be_Q.unpack pack_le_int32 = struct_le_i.pack pack_le_int64 = struct_le_q.pack @@ -354,6 +356,7 @@ def is_hex_str(text: Any) -> bool: pack_le_uint64 = struct_le_Q.pack pack_be_uint16 = struct_be_H.pack pack_be_uint32 = struct_be_I.pack +pack_be_uint64 = struct_be_Q.pack pack_byte = structB.pack hex_to_bytes = bytes.fromhex diff --git a/src/electrumx/server/block_processor.py b/src/electrumx/server/block_processor.py index 23ae54378..7e4b7c37d 100644 --- a/src/electrumx/server/block_processor.py +++ b/src/electrumx/server/block_processor.py @@ -11,7 +11,7 @@ import asyncio import time -from typing import Sequence, Tuple, List, Callable, Optional, TYPE_CHECKING, Type +from typing import Sequence, Tuple, List, Callable, Optional, TYPE_CHECKING, Type, Set from aiorpcx import run_in_thread, CancelledError @@ -23,8 +23,8 @@ chunks, class_logger, pack_le_uint32, pack_le_uint64, unpack_le_uint64, OldTaskGroup ) from electrumx.lib.tx import Tx -from electrumx.server.db import FlushData, COMP_TXID_LEN, DB -from electrumx.server.history import TXNUM_LEN +from electrumx.server.db import FlushData, DB +from electrumx.server.history import TXNUM_LEN, TXOUTIDX_LEN, TXOUTIDX_PADDING, pack_txnum if TYPE_CHECKING: from electrumx.lib.coins import Coin, Block @@ -186,7 +186,8 @@ def __init__(self, env: 'Env', db: DB, daemon: Daemon, notifications: 'Notificat # Meta self.next_cache_check = 0 - self.touched = set() + self.touched_hashxs = set() # type: Set[bytes] + self.touched_outpoints = set() # type: Set[Tuple[bytes, int]] self.reorg_count = 0 self.height = -1 self.tip = None # type: Optional[bytes] @@ -196,7 +197,10 @@ def __init__(self, env: 'Env', db: DB, daemon: Daemon, notifications: 'Notificat # Caches of unflushed items. 
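For context on the new big-endian helpers: tx numbers are truncated to ``TXNUM_LEN`` bytes and packed big-endian so that lexicographic key order matches numeric order, which the ``reverse=True`` iteration in ``history.py`` below relies on. A self-contained sketch using the same constants as the patch; only the asserts are illustrative::

    from struct import Struct

    TXNUM_LEN = 5
    TXNUM_PADDING = bytes(8 - TXNUM_LEN)
    pack_be_uint64 = Struct('>Q').pack
    unpack_be_uint64 = Struct('>Q').unpack

    def pack_txnum(tx_num: int) -> bytes:
        # keep the low TXNUM_LEN bytes; big-endian, so byte order == numeric order
        return pack_be_uint64(tx_num)[-TXNUM_LEN:]

    def unpack_txnum(tx_numb: bytes) -> int:
        return unpack_be_uint64(TXNUM_PADDING + tx_numb)[0]

    assert pack_txnum(2) < pack_txnum(3) < pack_txnum(70_000)   # sorts like the integers
    assert unpack_txnum(pack_txnum(123_456_789)) == 123_456_789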
self.headers = [] - self.tx_hashes = [] + self.tx_hashes = [] # type: List[bytes] + self.wtxids = [] # type: List[bytes] + self.undo_tx_hashes = [] # type: List[bytes] + self.undo_historical_spends = [] # type: List[bytes] self.undo_infos = [] # type: List[Tuple[Sequence[bytes], int]] # UTXO cache @@ -244,8 +248,13 @@ async def check_and_advance_blocks(self, raw_blocks: Sequence[bytes]) -> None: self.logger.info(f'processed {len(blocks):,d} block{s} size {blocks_size:.2f} MB ' f'in {time.monotonic() - start:.1f}s') if self._caught_up_event.is_set(): - await self.notifications.on_block(self.touched, self.height) - self.touched = set() + await self.notifications.on_block( + touched_hashxs=self.touched_hashxs, + touched_outpoints=self.touched_outpoints, + height=self.height, + ) + self.touched_hashxs = set() + self.touched_outpoints = set() elif hprevs[0] != chain[0]: await self.reorg_chain() else: @@ -279,10 +288,10 @@ async def get_raw_blocks(last_height, hex_hashes) -> Sequence[bytes]: return await self.daemon.raw_blocks(hex_hashes) def flush_backup(): - # self.touched can include other addresses which is + # self.touched_hashxs can include other addresses which is # harmless, but remove None. - self.touched.discard(None) - self.db.flush_backup(self.flush_data(), self.touched) + self.touched_hashxs.discard(None) + self.db.flush_backup(self.flush_data(), self.touched_hashxs) _start, last, hashes = await self.reorg_hashes(count) # Reverse and convert to hex strings. @@ -357,9 +366,19 @@ def estimate_txs_remaining(self): def flush_data(self): '''The data for a flush. The lock must be taken.''' assert self.state_lock.locked() - return FlushData(self.height, self.tx_count, self.headers, - self.tx_hashes, self.undo_infos, self.utxo_cache, - self.db_deletes, self.tip) + return FlushData( + height=self.height, + tx_count=self.tx_count, + headers=self.headers, + block_tx_hashes=self.tx_hashes, + block_wtxids=self.wtxids, + undo_block_tx_hashes=self.undo_tx_hashes, + undo_historical_spends=self.undo_historical_spends, + undo_infos=self.undo_infos, + adds=self.utxo_cache, + deletes=self.db_deletes, + tip=self.tip, + ) async def flush(self, flush_utxos): def flush(): @@ -378,7 +397,7 @@ async def _maybe_flush(self): await self.flush(flush_arg) self.next_cache_check = time.monotonic() + 30 - def check_cache_size(self): + def check_cache_size(self) -> Optional[bool]: '''Flush a cache if it gets too big.''' # Good average estimates based on traversal of subobjects and # requesting size from Python (see deep_getsizeof). 
@@ -433,8 +452,6 @@ def advance_txs( txs: Sequence[Tx], is_unspendable: Callable[[bytes], bool], ) -> Sequence[bytes]: - self.tx_hashes.append(b''.join(tx.txid for tx in txs)) - # Use local vars for speed in the loops undo_info = [] tx_num = self.tx_count @@ -442,17 +459,23 @@ def advance_txs( put_utxo = self.utxo_cache.__setitem__ spend_utxo = self.spend_utxo undo_info_append = undo_info.append - update_touched = self.touched.update + update_touched_hashxs = self.touched_hashxs.update + add_touched_outpoint = self.touched_outpoints.add hashXs_by_tx = [] append_hashXs = hashXs_by_tx.append + txhash_to_txnum_map = {} + put_txhash_to_txnum_map = txhash_to_txnum_map.__setitem__ + txo_to_spender_map = {} + put_txo_to_spender_map = txo_to_spender_map.__setitem__ to_le_uint32 = pack_le_uint32 to_le_uint64 = pack_le_uint64 + _pack_txnum = pack_txnum for tx in txs: tx_hash = tx.txid hashXs = [] append_hashX = hashXs.append - tx_numb = to_le_uint64(tx_num)[:TXNUM_LEN] + tx_numb = _pack_txnum(tx_num) # Spend the inputs for txin in tx.inputs: @@ -461,6 +484,9 @@ def advance_txs( cache_value = spend_utxo(txin.prev_hash, txin.prev_idx) undo_info_append(cache_value) append_hashX(cache_value[:HASHX_LEN]) + prevout_tuple = (txin.prev_hash, txin.prev_idx) + put_txo_to_spender_map(prevout_tuple, tx_hash) + add_touched_outpoint(prevout_tuple) # Add the new UTXOs for idx, txout in enumerate(tx.outputs): @@ -471,14 +497,23 @@ def advance_txs( # Get the hashX hashX = script_hashX(txout.pk_script) append_hashX(hashX) - put_utxo(tx_hash + to_le_uint32(idx), + put_utxo(tx_hash + to_le_uint32(idx)[:TXOUTIDX_LEN], hashX + tx_numb + to_le_uint64(txout.value)) + add_touched_outpoint((tx_hash, idx)) append_hashXs(hashXs) - update_touched(hashXs) + update_touched_hashxs(hashXs) + put_txhash_to_txnum_map(tx_hash, tx_num) tx_num += 1 - self.db.history.add_unflushed(hashXs_by_tx, self.tx_count) + self.tx_hashes.append(b''.join(tx.txid for tx in txs)) + self.wtxids.append(b''.join(tx.wtxid for tx in txs)) + self.db.history.add_unflushed( + hashXs_by_tx=hashXs_by_tx, + first_tx_num=self.tx_count, + txhash_to_txnum_map=txhash_to_txnum_map, + txo_to_spender_map=txo_to_spender_map, + ) self.tx_count = tx_num self.db.tx_counts.append(tx_num) @@ -530,7 +565,9 @@ def backup_txs( # Use local vars for speed in the loops put_utxo = self.utxo_cache.__setitem__ spend_utxo = self.spend_utxo - touched = self.touched + add_touched_hashx = self.touched_hashxs.add + add_touched_outpoint = self.touched_outpoints.add + undo_hist_spend = self.undo_historical_spends.append undo_entry_len = HASHX_LEN + TXNUM_LEN + 8 for tx in reversed(txs): @@ -544,7 +581,8 @@ def backup_txs( # Get the hashX cache_value = spend_utxo(tx_hash, idx) hashX = cache_value[:HASHX_LEN] - touched.add(hashX) + add_touched_hashx(hashX) + add_touched_outpoint((tx_hash, idx)) # Restore the inputs for txin in reversed(tx.inputs): @@ -552,9 +590,14 @@ def backup_txs( continue n -= undo_entry_len undo_item = undo_info[n:n + undo_entry_len] - put_utxo(txin.prev_hash + pack_le_uint32(txin.prev_idx), undo_item) + prevout = txin.prev_hash + pack_le_uint32(txin.prev_idx)[:TXOUTIDX_LEN] + put_utxo(prevout, undo_item) hashX = undo_item[:HASHX_LEN] - touched.add(hashX) + add_touched_hashx(hashX) + add_touched_outpoint((txin.prev_hash, txin.prev_idx)) + undo_hist_spend(prevout) + + self.undo_tx_hashes.append(b''.join(tx.txid for tx in txs)) assert n == 0 self.tx_count -= len(txs) @@ -599,21 +642,14 @@ def backup_txs( To this end we maintain two "tables", one for each point above: - 1. 
Key: b'u' + address_hashX + tx_idx + tx_num + 1. Key: b'u' + address_hashX + tx_num + txout_idx Value: the UTXO value as a 64-bit unsigned integer - 2. Key: b'h' + compressed_tx_hash + tx_idx + tx_num + 2. Key: b'h' + tx_num + txout_idx Value: hashX - - The compressed tx hash is just the first few bytes of the hash of - the tx in which the UTXO was created. As this is not unique there - will be potential collisions so tx_num is also in the key. When - looking up a UTXO the prefix space of the compressed hash needs to - be searched and resolved if necessary with the tx_num. The - collision rate is low (<0.1%). ''' - def spend_utxo(self, tx_hash: bytes, tx_idx: int) -> bytes: + def spend_utxo(self, tx_hash: bytes, txout_idx: int) -> bytes: '''Spend a UTXO and return (hashX + tx_num + value_sats). If the UTXO is not in the cache it must be on disk. We store @@ -621,42 +657,36 @@ def spend_utxo(self, tx_hash: bytes, tx_idx: int) -> bytes: corruption. ''' # Fast track is it being in the cache - idx_packed = pack_le_uint32(tx_idx) + idx_packed = pack_le_uint32(txout_idx)[:TXOUTIDX_LEN] cache_value = self.utxo_cache.pop(tx_hash + idx_packed, None) if cache_value: return cache_value # Spend it from the DB. - txnum_padding = bytes(8-TXNUM_LEN) + tx_num = self.db.fs_txnum_for_txhash(tx_hash) + if tx_num is None: + raise ChainError(f'UTXO {hash_to_hex_str(tx_hash)} / {txout_idx:,d} has ' + f'no corresponding tx_num in DB') + tx_numb = pack_txnum(tx_num) - # Key: b'h' + compressed_tx_hash + tx_idx + tx_num + # Key: b'h' + tx_num + txout_idx # Value: hashX - prefix = b'h' + tx_hash[:COMP_TXID_LEN] + idx_packed - candidates = {db_key: hashX for db_key, hashX - in self.db.utxo_db.iterator(prefix=prefix)} - - for hdb_key, hashX in candidates.items(): - tx_num_packed = hdb_key[-TXNUM_LEN:] - - if len(candidates) > 1: - tx_num, = unpack_le_uint64(tx_num_packed + txnum_padding) - hash, _height = self.db.fs_tx_hash(tx_num) - if hash != tx_hash: - assert hash is not None # Should always be found - continue - - # Key: b'u' + address_hashX + tx_idx + tx_num - # Value: the UTXO value as a 64-bit unsigned integer - udb_key = b'u' + hashX + hdb_key[-4-TXNUM_LEN:] - utxo_value_packed = self.db.utxo_db.get(udb_key) - if utxo_value_packed: - # Remove both entries for this UTXO - self.db_deletes.append(hdb_key) - self.db_deletes.append(udb_key) - return hashX + tx_num_packed + utxo_value_packed - - raise ChainError(f'UTXO {hash_to_hex_str(tx_hash)} / {tx_idx:,d} not ' - f'found in "h" table') + hdb_key = b'h' + tx_numb + idx_packed + hashX = self.db.utxo_db.get(hdb_key) + if hashX is None: + raise ChainError(f'UTXO {hash_to_hex_str(tx_hash)} / {txout_idx:,d} not ' + f'found in "h" table') + # Key: b'u' + address_hashX + tx_num + txout_idx + # Value: the UTXO value as a 64-bit unsigned integer + udb_key = b'u' + hashX + tx_numb + idx_packed + utxo_value_packed = self.db.utxo_db.get(udb_key) + if utxo_value_packed is None: + raise ChainError(f'UTXO {hash_to_hex_str(tx_hash)} / {txout_idx:,d} not ' + f'found in "u" table') + # Remove both entries for this UTXO + self.db_deletes.append(hdb_key) + self.db_deletes.append(udb_key) + return hashX + tx_numb + utxo_value_packed async def _process_prefetched_blocks(self): '''Loop forever processing blocks as they arrive.''' @@ -747,7 +777,7 @@ def advance_txs(self, txs, is_unspendable): tx_num = self.tx_count - len(txs) script_name_hashX = self.coin.name_hashX_from_script - update_touched = self.touched.update + update_touched_hashxs = self.touched_hashxs.update hashXs_by_tx = [] 
append_hashXs = hashXs_by_tx.append @@ -763,10 +793,15 @@ def advance_txs(self, txs, is_unspendable): append_hashX(hashX) append_hashXs(hashXs) - update_touched(hashXs) + update_touched_hashxs(hashXs) tx_num += 1 - self.db.history.add_unflushed(hashXs_by_tx, self.tx_count - len(txs)) + self.db.history.add_unflushed( + hashXs_by_tx=hashXs_by_tx, + first_tx_num=self.tx_count - len(txs), + txhash_to_txnum_map={}, + txo_to_spender_map={}, + ) return result @@ -774,8 +809,6 @@ def advance_txs(self, txs, is_unspendable): class LTORBlockProcessor(BlockProcessor): def advance_txs(self, txs, is_unspendable): - self.tx_hashes.append(b''.join(tx.txid for tx in txs)) - # Use local vars for speed in the loops undo_info = [] tx_num = self.tx_count @@ -783,9 +816,15 @@ def advance_txs(self, txs, is_unspendable): put_utxo = self.utxo_cache.__setitem__ spend_utxo = self.spend_utxo undo_info_append = undo_info.append - update_touched = self.touched.update + update_touched_hashxs = self.touched_hashxs.update + add_touched_outpoint = self.touched_outpoints.add + txhash_to_txnum_map = {} + put_txhash_to_txnum_map = txhash_to_txnum_map.__setitem__ + txo_to_spender_map = {} + put_txo_to_spender_map = txo_to_spender_map.__setitem__ to_le_uint32 = pack_le_uint32 to_le_uint64 = pack_le_uint64 + _pack_txnum = pack_txnum hashXs_by_tx = [set() for _ in txs] @@ -793,7 +832,7 @@ def advance_txs(self, txs, is_unspendable): for tx, hashXs in zip(txs, hashXs_by_tx): tx_hash = tx.txid add_hashXs = hashXs.add - tx_numb = to_le_uint64(tx_num)[:TXNUM_LEN] + tx_numb = _pack_txnum(tx_num) for idx, txout in enumerate(tx.outputs): # Ignore unspendable outputs @@ -803,8 +842,10 @@ def advance_txs(self, txs, is_unspendable): # Get the hashX hashX = script_hashX(txout.pk_script) add_hashXs(hashX) - put_utxo(tx_hash + to_le_uint32(idx), + put_utxo(tx_hash + to_le_uint32(idx)[:TXOUTIDX_LEN], hashX + tx_numb + to_le_uint64(txout.value)) + add_touched_outpoint((tx_hash, idx)) + put_txhash_to_txnum_map(tx_hash, tx_num) tx_num += 1 # Spend the inputs @@ -817,12 +858,22 @@ def advance_txs(self, txs, is_unspendable): cache_value = spend_utxo(txin.prev_hash, txin.prev_idx) undo_info_append(cache_value) add_hashXs(cache_value[:HASHX_LEN]) + prevout_tuple = (txin.prev_hash, txin.prev_idx) + put_txo_to_spender_map(prevout_tuple, tx_hash) + add_touched_outpoint(prevout_tuple) # Update touched set for notifications for hashXs in hashXs_by_tx: - update_touched(hashXs) + update_touched_hashxs(hashXs) - self.db.history.add_unflushed(hashXs_by_tx, self.tx_count) + self.tx_hashes.append(b''.join(tx.txid for tx in txs)) + self.wtxids.append(b''.join(tx.wtxid for tx in txs)) + self.db.history.add_unflushed( + hashXs_by_tx=hashXs_by_tx, + first_tx_num=self.tx_count, + txhash_to_txnum_map=txhash_to_txnum_map, + txo_to_spender_map=txo_to_spender_map, + ) self.tx_count = tx_num self.db.tx_counts.append(tx_num) @@ -839,7 +890,8 @@ def backup_txs(self, txs, is_unspendable): # Use local vars for speed in the loops put_utxo = self.utxo_cache.__setitem__ spend_utxo = self.spend_utxo - add_touched = self.touched.add + add_touched_hashx = self.touched_hashxs.add + add_touched_outpoint = self.touched_outpoints.add undo_entry_len = HASHX_LEN + TXNUM_LEN + 8 # Restore coins that had been spent @@ -850,8 +902,10 @@ def backup_txs(self, txs, is_unspendable): if txin.is_generation(): continue undo_item = undo_info[n:n + undo_entry_len] - put_utxo(txin.prev_hash + pack_le_uint32(txin.prev_idx), undo_item) - add_touched(undo_item[:HASHX_LEN]) + prevout = txin.prev_hash + 
pack_le_uint32(txin.prev_idx)[:TXOUTIDX_LEN] + put_utxo(prevout, undo_item) + add_touched_hashx(undo_item[:HASHX_LEN]) + add_touched_outpoint((txin.prev_hash, txin.prev_idx)) n += undo_entry_len assert n == len(undo_info) @@ -868,6 +922,8 @@ def backup_txs(self, txs, is_unspendable): # Get the hashX cache_value = spend_utxo(tx_hash, idx) hashX = cache_value[:HASHX_LEN] - add_touched(hashX) + add_touched_hashx(hashX) + add_touched_outpoint((tx_hash, idx)) + self.undo_tx_hashes.append(b''.join(tx.txid for tx in txs)) self.tx_count -= len(txs) diff --git a/src/electrumx/server/controller.py b/src/electrumx/server/controller.py index 60b8046b9..36408edc0 100644 --- a/src/electrumx/server/controller.py +++ b/src/electrumx/server/controller.py @@ -6,6 +6,7 @@ # and warranty status of this software. from asyncio import Event +from typing import Set, Dict, Tuple from aiorpcx import _version as aiorpcx_version @@ -31,43 +32,83 @@ class Notifications: # notifications appropriately. def __init__(self): - self._touched_mp = {} - self._touched_bp = {} + self._touched_hashxs_mp = {} # type: Dict[int, Set[bytes]] + self._touched_hashxs_bp = {} # type: Dict[int, Set[bytes]] + self._touched_outpoints_mp = {} # type: Dict[int, Set[Tuple[bytes, int]]] + self._touched_outpoints_bp = {} # type: Dict[int, Set[Tuple[bytes, int]]] self._highest_block = -1 async def _maybe_notify(self): - tmp, tbp = self._touched_mp, self._touched_bp - common = set(tmp).intersection(tbp) - if common: - height = max(common) - elif tmp and max(tmp) == self._highest_block: + th_mp, th_bp = self._touched_hashxs_mp, self._touched_hashxs_bp + # figure out block height + common_heights = set(th_mp).intersection(th_bp) + if common_heights: + height = max(common_heights) + elif th_mp and max(th_mp) == self._highest_block: height = self._highest_block else: # Either we are processing a block and waiting for it to # come in, or we have not yet had a mempool update for the # new block height return - touched = tmp.pop(height) - for old in [h for h in tmp if h <= height]: - del tmp[old] - for old in [h for h in tbp if h <= height]: - touched.update(tbp.pop(old)) - await self.notify(height, touched) - - async def notify(self, height, touched): + # hashXs + touched_hashxs = th_mp.pop(height) + for old in [h for h in th_mp if h <= height]: + del th_mp[old] + for old in [h for h in th_bp if h <= height]: + touched_hashxs.update(th_bp.pop(old)) + # outpoints + to_mp, to_bp = self._touched_outpoints_mp, self._touched_outpoints_bp + touched_outpoints = to_mp.pop(height) + for old in [h for h in to_mp if h <= height]: + del to_mp[old] + for old in [h for h in to_bp if h <= height]: + touched_outpoints.update(to_bp.pop(old)) + + await self.notify( + height=height, + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + ) + + async def notify( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height: int, + ): pass - async def start(self, height, notify_func): + async def start(self, height: int, notify_func): self._highest_block = height self.notify = notify_func - await self.notify(height, set()) - - async def on_mempool(self, touched, height): - self._touched_mp[height] = touched + await self.notify( + height=height, + touched_hashxs=set(), + touched_outpoints=set(), + ) + + async def on_mempool( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height: int, + ): + self._touched_hashxs_mp[height] = touched_hashxs + self._touched_outpoints_mp[height] = 
touched_outpoints await self._maybe_notify() - async def on_block(self, touched, height): - self._touched_bp[height] = touched + async def on_block( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height: int, + ): + self._touched_hashxs_bp[height] = touched_hashxs + self._touched_outpoints_bp[height] = touched_outpoints self._highest_block = height await self._maybe_notify() diff --git a/src/electrumx/server/daemon.py b/src/electrumx/server/daemon.py index c76ae91ef..453240628 100644 --- a/src/electrumx/server/daemon.py +++ b/src/electrumx/server/daemon.py @@ -332,11 +332,16 @@ async def mempool_info(self) -> dict[str, float]: 'incrementalrelayfee': mempool_info['incrementalrelayfee'], } - async def getrawtransaction(self, hex_hash, verbose=False): + async def getrawtransaction(self, hex_hash, verbose=False, blockhash=None): '''Return the serialized raw transaction with the given hash.''' # Cast to int because some coin daemons are old and require it - return await self._send_single('getrawtransaction', - (hex_hash, int(verbose))) + verbose = int(verbose) + if blockhash is None: + return await self._send_single('getrawtransaction', (hex_hash, verbose)) + else: + # given a blockhash, modern bitcoind can lookup the tx even without txindex: + # https://github.com/bitcoin/bitcoin/pull/10275 + return await self._send_single('getrawtransaction', (hex_hash, verbose, blockhash)) async def getrawtransactions(self, hex_hashes, replace_errs=True): '''Return the serialized raw transactions with the given hashes. diff --git a/src/electrumx/server/db.py b/src/electrumx/server/db.py index b3d93c37e..4e79504a5 100644 --- a/src/electrumx/server/db.py +++ b/src/electrumx/server/db.py @@ -16,7 +16,8 @@ from bisect import bisect_right from dataclasses import dataclass from glob import glob -from typing import Dict, List, Sequence, Tuple, Optional, TYPE_CHECKING +from typing import Dict, List, Sequence, Tuple, Optional, TYPE_CHECKING, Union +from functools import partial import attr from aiorpcx import run_in_thread, sleep @@ -28,8 +29,11 @@ formatted_time, pack_be_uint16, pack_be_uint32, pack_le_uint64, pack_le_uint32, unpack_le_uint32, unpack_be_uint32, unpack_le_uint64 ) +from electrumx.lib.tx import TXOSpendStatus from electrumx.server.storage import db_class, Storage -from electrumx.server.history import History, TXNUM_LEN +from electrumx.server.history import ( + History, TXNUM_LEN, TXOUTIDX_LEN, TXOUTIDX_PADDING, pack_txnum, unpack_txnum, +) if TYPE_CHECKING: from electrumx.server.env import Env @@ -50,6 +54,9 @@ class FlushData: tx_count = attr.ib() headers = attr.ib() block_tx_hashes = attr.ib() # type: List[bytes] + undo_block_tx_hashes = attr.ib() # type: List[bytes] + block_wtxids = attr.ib() # type: List[bytes] + undo_historical_spends = attr.ib() # type: List[bytes] # The following are flushed to the UTXO DB if undo_infos is not None undo_infos = attr.ib() # type: List[Tuple[Sequence[bytes], int]] adds = attr.ib() # type: Dict[bytes, bytes] # txid+out_idx -> hashX+tx_num+value_sats @@ -57,9 +64,6 @@ class FlushData: tip = attr.ib() -COMP_TXID_LEN = 4 - - class DB: '''Simple wrapper of the backend database for querying. @@ -67,7 +71,7 @@ class DB: it was shutdown uncleanly. 
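The three-argument call corresponds to Bitcoin Core's ``getrawtransaction <txid> <verbose> <blockhash>`` form from PR #10275, which can locate a transaction without ``txindex`` when its containing block is known. A hedged usage sketch; ``fetch_raw_tx`` is a hypothetical helper and its arguments are placeholders::

    async def fetch_raw_tx(daemon, tx_hash_hex: str, block_hash_hex: str) -> str:
        # with a block hash supplied, bitcoind does not need txindex to find the tx
        return await daemon.getrawtransaction(
            tx_hash_hex, verbose=False, blockhash=block_hash_hex)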
''' - DB_VERSIONS = (6, 7, 8) + DB_VERSIONS = (9, ) utxo_db: Optional['Storage'] @@ -93,11 +97,11 @@ def __init__(self, env: 'Env'): self.db_class = db_class(self.env.db_engine) self.history = History() - # Key: b'u' + address_hashX + txout_idx + tx_num + # Key: b'u' + address_hashX + tx_num + txout_idx # Value: the UTXO value as a 64-bit unsigned integer (in satoshis) # "at address, at outpoint, there is a UTXO of value v" # --- - # Key: b'h' + compressed_tx_hash + txout_idx + tx_num + # Key: b'h' + tx_num + txout_idx # Value: hashX # "some outpoint created a UTXO at address" # --- @@ -106,13 +110,11 @@ def __init__(self, env: 'Env'): # "undo data: list of UTXOs spent at block height" self.utxo_db = None - self.utxo_flush_count = 0 self.fs_height = -1 self.fs_tx_count = 0 self.db_height = -1 self.db_tx_count = 0 self.db_tip = None # type: Optional[bytes] - self.tx_counts = None self.last_flush = time.time() self.last_flush_tx_count = 0 self.wall_time = 0 @@ -128,9 +130,12 @@ def __init__(self, env: 'Env'): # on-disk: raw block headers in chain order self.headers_file = util.LogicalFile('meta/headers', 2, 16000000) # on-disk: cumulative number of txs at the end of height N + self.tx_counts = None # type: Optional[array] self.tx_counts_file = util.LogicalFile('meta/txcounts', 2, 2000000) # on-disk: 32 byte txids in chain order, allows (tx_num -> txid) map self.hashes_file = util.LogicalFile('meta/hashes', 4, 16000000) + # on-disk: 32 byte wtxids in chain order, allows (tx_num -> wtxid) map + self.wtxids_file = util.LogicalFile('meta/wtxids', 4, 16000000) if not self.coin.STATIC_BLOCK_HEADERS: self.headers_offsets_file = util.LogicalFile( 'meta/headers_offsets', 2, 16000000) @@ -149,7 +154,7 @@ async def _read_tx_counts(self): else: assert self.db_tx_count == 0 - async def _open_dbs(self, for_sync: bool, compacting: bool): + async def _open_dbs(self, *, for_sync: bool): assert self.utxo_db is None # First UTXO DB @@ -168,17 +173,16 @@ async def _open_dbs(self, for_sync: bool, compacting: bool): self.read_utxo_state() # Then history DB - self.utxo_flush_count = self.history.open_db(self.db_class, for_sync, - self.utxo_flush_count, - compacting) + self.history.open_db( + db_class=self.db_class, + for_sync=for_sync, + utxo_db_tx_count=self.db_tx_count, + ) self.clear_excess_undo_info() # Read TX counts (requires meta directory) await self._read_tx_counts() - async def open_for_compacting(self): - await self._open_dbs(True, True) - async def open_for_sync(self): '''Open the databases to sync to the daemon. @@ -186,7 +190,7 @@ async def open_for_sync(self): synchronization. When serving clients we want the open files for serving network connections. ''' - await self._open_dbs(True, False) + await self._open_dbs(for_sync=True) async def open_for_serving(self): '''Open the databases for serving. 
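To make the revised layout concrete, an illustrative sketch of how the two UTXO-DB key kinds are assembled (constants as in the patch: ``TXNUM_LEN`` = 5, ``TXOUTIDX_LEN`` = 3; tx_num is big-endian, txout_idx is little-endian truncated; the helper name is a local stand-in, not an ElectrumX API)::

    from struct import Struct

    TXNUM_LEN, TXOUTIDX_LEN = 5, 3
    pack_be_uint64 = Struct('>Q').pack
    pack_le_uint32 = Struct('<I').pack

    def utxo_keys(hashX: bytes, tx_num: int, txout_idx: int):
        tx_numb = pack_be_uint64(tx_num)[-TXNUM_LEN:]
        idx_packed = pack_le_uint32(txout_idx)[:TXOUTIDX_LEN]
        suffix = tx_numb + idx_packed
        u_key = b'u' + hashX + suffix    # value: UTXO amount as 64-bit LE satoshis
        h_key = b'h' + suffix            # value: hashX
        return u_key, h_key

    # placeholder 11-byte hashX, arbitrary tx_num and output index
    u_key, h_key = utxo_keys(hashX=bytes(11), tx_num=1_000_000, txout_idx=2)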
If they are already open they are @@ -197,7 +201,7 @@ async def open_for_serving(self): self.utxo_db.close() self.history.close_db() self.utxo_db = None - await self._open_dbs(False, False) + await self._open_dbs(for_sync=False) # Header merkle cache @@ -213,19 +217,22 @@ async def header_branch_and_root(self, length, height): return await self.header_mc.branch_and_root(length, height) # Flushing - def assert_flushed(self, flush_data): + def assert_flushed(self, flush_data: FlushData): '''Asserts state is fully flushed.''' assert flush_data.tx_count == self.fs_tx_count == self.db_tx_count assert flush_data.height == self.fs_height == self.db_height assert flush_data.tip == self.db_tip assert not flush_data.headers assert not flush_data.block_tx_hashes + assert not flush_data.undo_block_tx_hashes + assert not flush_data.block_wtxids + assert not flush_data.undo_historical_spends assert not flush_data.adds assert not flush_data.deletes assert not flush_data.undo_infos self.history.assert_flushed() - def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): + def flush_dbs(self, flush_data: FlushData, flush_utxos, estimate_txs_remaining): '''Flush out cached state. History is always flushed; UTXOs are flushed if flush_utxos.''' if flush_data.height == self.db_height: @@ -253,7 +260,7 @@ def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): self.flush_state(self.utxo_db) elapsed = self.last_flush - start_time - self.logger.info(f'flush #{self.history.flush_count:,d} took ' + self.logger.info(f'flush took ' f'{elapsed:.1f}s. Height {flush_data.height:,d} ' f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})') @@ -268,7 +275,7 @@ def flush_dbs(self, flush_data, flush_utxos, estimate_txs_remaining): self.logger.info(f'sync time: {formatted_time(self.wall_time)} ' f'ETA: {formatted_time(eta)}') - def flush_fs(self, flush_data): + def flush_fs(self, flush_data: FlushData): '''Write headers, tx counts and block tx hashes to the filesystem. The first height to write is self.fs_height + 1. 
The FS @@ -278,15 +285,21 @@ def flush_fs(self, flush_data): prior_tx_count = (self.tx_counts[self.fs_height] if self.fs_height >= 0 else 0) assert len(flush_data.block_tx_hashes) == len(flush_data.headers) + assert len(flush_data.block_wtxids) == len(flush_data.headers) assert flush_data.height == self.fs_height + len(flush_data.headers) assert flush_data.tx_count == (self.tx_counts[-1] if self.tx_counts else 0) assert len(self.tx_counts) == flush_data.height + 1 + hashes = b''.join(flush_data.block_tx_hashes) flush_data.block_tx_hashes.clear() assert len(hashes) % 32 == 0 assert len(hashes) // 32 == flush_data.tx_count - prior_tx_count + wtxids = b''.join(flush_data.block_wtxids) + flush_data.block_wtxids.clear() + assert len(wtxids) == len(hashes) + # Write the headers, tx counts, and tx hashes start_time = time.monotonic() height_start = self.fs_height + 1 @@ -300,6 +313,8 @@ def flush_fs(self, flush_data): self.tx_counts[height_start:].tobytes()) offset = prior_tx_count * 32 self.hashes_file.write(offset, hashes) + offset = prior_tx_count * 32 + self.wtxids_file.write(offset, wtxids) self.fs_height = flush_data.height self.fs_tx_count = flush_data.tx_count @@ -331,11 +346,11 @@ def flush_utxo_db(self, batch, flush_data: FlushData): for key, value in flush_data.adds.items(): # key: txid+out_idx, value: hashX+tx_num+value_sats hashX = value[:HASHX_LEN] - txout_idx = key[-4:] + txout_idx = key[-TXOUTIDX_LEN:] tx_num = value[HASHX_LEN: HASHX_LEN+TXNUM_LEN] value_sats = value[-8:] - suffix = txout_idx + tx_num - batch_put(b'h' + key[:COMP_TXID_LEN] + suffix, hashX) + suffix = tx_num + txout_idx + batch_put(b'h' + suffix, hashX) batch_put(b'u' + hashX + suffix, value_sats) flush_data.adds.clear() @@ -352,7 +367,6 @@ def flush_utxo_db(self, batch, flush_data: FlushData): f'{spend_count:,d} spends in ' f'{elapsed:.1f}s, committing...') - self.utxo_flush_count = self.history.flush_count self.db_height = flush_data.height self.db_tx_count = flush_data.tx_count self.db_tip = flush_data.tip @@ -365,25 +379,39 @@ def flush_state(self, batch): self.last_flush_tx_count = self.fs_tx_count self.write_utxo_state(batch) - def flush_backup(self, flush_data, touched): + def flush_backup(self, flush_data: FlushData, touched_hashxs): '''Like flush_dbs() but when backing up. All UTXOs are flushed.''' assert not flush_data.headers assert not flush_data.block_tx_hashes + assert not flush_data.block_wtxids assert flush_data.height < self.db_height self.history.assert_flushed() + assert len(flush_data.undo_block_tx_hashes) == self.db_height - flush_data.height start_time = time.time() tx_delta = flush_data.tx_count - self.last_flush_tx_count + tx_hashes = [] + for block in flush_data.undo_block_tx_hashes: + tx_hashes += [*util.chunks(block, 32)] + flush_data.undo_block_tx_hashes.clear() + assert len(tx_hashes) == -tx_delta + self.backup_fs(flush_data.height, flush_data.tx_count) - self.history.backup(touched, flush_data.tx_count) + self.history.backup( + hashXs=touched_hashxs, + tx_count=flush_data.tx_count, + tx_hashes=tx_hashes, + spends=flush_data.undo_historical_spends, + ) + flush_data.undo_historical_spends.clear() with self.utxo_db.write_batch() as batch: self.flush_utxo_db(batch, flush_data) # Flush state last as it reads the wall time. self.flush_state(batch) elapsed = self.last_flush - start_time - self.logger.info(f'backup flush #{self.history.flush_count:,d} took ' + self.logger.info(f'backup flush took ' f'{elapsed:.1f}s. 
Height {flush_data.height:,d} ' f'txs: {flush_data.tx_count:,d} ({tx_delta:+,d})') @@ -416,14 +444,14 @@ def backup_fs(self, height, tx_count): # Truncate header_mc: header count is 1 more than the height. self.header_mc.truncate(height + 1) - async def raw_header(self, height): + async def raw_header(self, height: int) -> bytes: '''Return the binary header at the given height.''' header, n = await self.read_headers(height, 1) if n != 1: raise IndexError(f'height {height:,d} out of range') return header - async def read_headers(self, start_height, count): + async def read_headers(self, start_height: int, count: int) -> Tuple[bytes, int]: '''Requires start_height >= 0, count >= 0. Reads as many headers as are available starting at start_height up to count. This would be zero if start_height is beyond self.db_height, for @@ -447,18 +475,19 @@ def read_headers(): return await run_in_thread(read_headers) - def fs_tx_hash(self, tx_num): + def fs_tx_hash(self, tx_num: int, *, wtxid: bool = False) -> Tuple[Optional[bytes], int]: '''Return a pair (tx_hash, tx_height) for the given tx number. If the tx_height is not on disk, returns (None, tx_height).''' + file = self.wtxids_file if wtxid else self.hashes_file tx_height = bisect_right(self.tx_counts, tx_num) if tx_height > self.db_height: tx_hash = None else: - tx_hash = self.hashes_file.read(tx_num * 32, 32) + tx_hash = file.read(tx_num * 32, 32) return tx_hash, tx_height - def fs_tx_hashes_at_blockheight(self, block_height): + def fs_tx_hashes_at_blockheight(self, block_height, *, wtxid: bool = False) -> Sequence[bytes]: '''Return a list of tx_hashes at given block height, in the same order as in the block. ''' @@ -470,12 +499,16 @@ def fs_tx_hashes_at_blockheight(self, block_height): else: first_tx_num = 0 num_txs_in_block = self.tx_counts[block_height] - first_tx_num - tx_hashes = self.hashes_file.read(first_tx_num * 32, num_txs_in_block * 32) + file = self.wtxids_file if wtxid else self.hashes_file + tx_hashes = file.read(first_tx_num * 32, num_txs_in_block * 32) assert num_txs_in_block == len(tx_hashes) // 32 return [tx_hashes[idx * 32: (idx+1) * 32] for idx in range(num_txs_in_block)] - async def tx_hashes_at_blockheight(self, block_height): - return await run_in_thread(self.fs_tx_hashes_at_blockheight, block_height) + async def tx_hashes_at_blockheight( + self, block_height, *, wtxid: bool = False, + ) -> Sequence[bytes]: + func = partial(self.fs_tx_hashes_at_blockheight, block_height, wtxid=wtxid) + return await run_in_thread(func) async def fs_block_hashes(self, height, count): headers_concat, headers_count = await self.read_headers(height, count) @@ -511,6 +544,63 @@ def read_history(): f'not found (reorg?), retrying...') await sleep(0.25) + def fs_txnum_for_txhash(self, tx_hash: bytes) -> Optional[int]: + return self.history.get_txnum_for_txhash(tx_hash) + + async def txnum_for_txhash(self, tx_hash: bytes) -> Optional[int]: + return await run_in_thread(self.fs_txnum_for_txhash, tx_hash) + + async def get_blockheight_and_txpos_for_txhash( + self, tx_hash: bytes, + ) -> Tuple[Optional[int], Optional[int]]: + '''Returns (block_height, tx_pos) for a confirmed tx_hash.''' + tx_num = await self.txnum_for_txhash(tx_hash) + if tx_num is None: + return None, None + return self.get_blockheight_and_txpos_for_txnum(tx_num) + + def get_blockheight_and_txpos_for_txnum( + self, tx_num: int, + ) -> Tuple[Optional[int], Optional[int]]: + '''Returns (block_height, tx_pos) for a tx_num.''' + height = bisect_right(self.tx_counts, tx_num) + if height > 
self.db_height: + return None, None + assert height > 0 + tx_pos = tx_num - self.tx_counts[height - 1] + return height, tx_pos + + def fs_spender_for_txo(self, prev_txhash: bytes, txout_idx: int) -> 'TXOSpendStatus': + '''For an outpoint, returns its spend-status (considering only the DB, + not the mempool). + ''' + prev_txnum = self.fs_txnum_for_txhash(prev_txhash) + if prev_txnum is None: # outpoint never existed (in chain) + return TXOSpendStatus(prev_height=None) + prev_height = self.get_blockheight_and_txpos_for_txnum(prev_txnum)[0] + hashx, _ = self._get_hashX_for_utxo(prev_txhash, txout_idx) + if hashx: # outpoint exists and is unspent + return TXOSpendStatus(prev_height=prev_height) + # by now we know prev_txhash was mined, and + # txout_idx was either spent or is out-of-bounds + spender_txnum = self.history.get_spender_txnum_for_txo(prev_txnum, txout_idx) + if spender_txnum is None: + # txout_idx was out-of-bounds + return TXOSpendStatus(prev_height=None) + # by now we know the outpoint exists and it was spent. + spender_txhash, spender_height = self.fs_tx_hash(spender_txnum) + if spender_txhash is None: + # not sure if this can happen. maybe through a race? + return TXOSpendStatus(prev_height=prev_height) + return TXOSpendStatus( + prev_height=prev_height, + spender_txhash=spender_txhash, + spender_height=spender_height, + ) + + async def spender_for_txo(self, prev_txhash: bytes, txout_idx: int) -> 'TXOSpendStatus': + return await run_in_thread(self.fs_spender_for_txo, prev_txhash, txout_idx) + # -- Undo information def min_undo_height(self, max_height): @@ -588,13 +678,12 @@ def clear_excess_undo_info(self): # -- UTXO database def read_utxo_state(self): - state = self.utxo_db.get(b'state') + state = self.utxo_db.get(b'\0state') if not state: self.db_height = -1 self.db_tx_count = 0 self.db_tip = b'\0' * 32 self.db_version = max(self.DB_VERSIONS) - self.utxo_flush_count = 0 self.wall_time = 0 self.first_sync = True else: @@ -616,7 +705,6 @@ def read_utxo_state(self): self.db_height = state['height'] self.db_tx_count = state['tx_count'] self.db_tip = state['tip'] - self.utxo_flush_count = state['utxo_flush_count'] self.wall_time = state['wall_time'] self.first_sync = state['first_sync'] @@ -627,7 +715,7 @@ def read_utxo_state(self): # Upgrade DB if self.db_version != max(self.DB_VERSIONS): - self.upgrade_db() + pass # call future upgrade logic here # Log some stats self.logger.info(f'UTXO DB version: {self.db_version:d}') @@ -643,90 +731,6 @@ def read_utxo_state(self): f'sync time so far: {util.formatted_time(self.wall_time)}' ) - def upgrade_db(self): - self.logger.info(f'UTXO DB version: {self.db_version}') - self.logger.info('Upgrading your DB; this can take some time...') - - def upgrade_u_prefix(prefix): - count = 0 - with self.utxo_db.write_batch() as batch: - batch_delete = batch.delete - batch_put = batch.put - # Key: b'u' + address_hashX + tx_idx + tx_num - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - if len(db_key) == 21: - return - break - if self.db_version == 6: - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - count += 1 - batch_delete(db_key) - batch_put(db_key[:14] + b'\0\0' + db_key[14:] + b'\0', db_value) - else: - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - count += 1 - batch_delete(db_key) - batch_put(db_key + b'\0', db_value) - return count - - last = time.monotonic() - count = 0 - for cursor in range(65536): - prefix = b'u' + pack_be_uint16(cursor) - count += upgrade_u_prefix(prefix) - now = 
time.monotonic() - if now > last + 10: - last = now - self.logger.info(f'DB 1 of 3: {count:,d} entries updated, ' - f'{cursor * 100 / 65536:.1f}% complete') - self.logger.info('DB 1 of 3 upgraded successfully') - - def upgrade_h_prefix(prefix): - count = 0 - with self.utxo_db.write_batch() as batch: - batch_delete = batch.delete - batch_put = batch.put - # Key: b'h' + compressed_tx_hash + tx_idx + tx_num - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - if len(db_key) == 14: - return - break - if self.db_version == 6: - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - count += 1 - batch_delete(db_key) - batch_put(db_key[:7] + b'\0\0' + db_key[7:] + b'\0', db_value) - else: - for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - count += 1 - batch_delete(db_key) - batch_put(db_key + b'\0', db_value) - return count - - last = time.monotonic() - count = 0 - for cursor in range(65536): - prefix = b'h' + pack_be_uint16(cursor) - count += upgrade_h_prefix(prefix) - now = time.monotonic() - if now > last + 10: - last = now - self.logger.info(f'DB 2 of 3: {count:,d} entries updated, ' - f'{cursor * 100 / 65536:.1f}% complete') - - # Upgrade tx_counts file - size = (self.db_height + 1) * 8 - tx_counts = self.tx_counts_file.read(0, size) - if len(tx_counts) == (self.db_height + 1) * 4: - tx_counts = array('I', tx_counts) - tx_counts = array('Q', tx_counts) - self.tx_counts_file.write(0, tx_counts.tobytes()) - - self.db_version = max(self.DB_VERSIONS) - with self.utxo_db.write_batch() as batch: - self.write_utxo_state(batch) - self.logger.info('DB 2 of 3 upgraded successfully') - def write_utxo_state(self, batch): '''Write (UTXO) state to the batch.''' state = { @@ -734,30 +738,24 @@ def write_utxo_state(self, batch): 'height': self.db_height, 'tx_count': self.db_tx_count, 'tip': self.db_tip, - 'utxo_flush_count': self.utxo_flush_count, 'wall_time': self.wall_time, 'first_sync': self.first_sync, 'db_version': self.db_version, } - batch.put(b'state', repr(state).encode()) - - def set_flush_count(self, count): - self.utxo_flush_count = count - with self.utxo_db.write_batch() as batch: - self.write_utxo_state(batch) + batch.put(b'\0state', repr(state).encode()) async def all_utxos(self, hashX): '''Return all UTXOs for an address sorted in no particular order.''' def read_utxos(): utxos = [] utxos_append = utxos.append - txnum_padding = bytes(8-TXNUM_LEN) - # Key: b'u' + address_hashX + txout_idx + tx_num + # Key: b'u' + address_hashX + tx_num + txout_idx # Value: the UTXO value as a 64-bit unsigned integer prefix = b'u' + hashX for db_key, db_value in self.utxo_db.iterator(prefix=prefix): - txout_idx, = unpack_le_uint32(db_key[-TXNUM_LEN-4:-TXNUM_LEN]) - tx_num, = unpack_le_uint64(db_key[-TXNUM_LEN:] + txnum_padding) + txout_idx, = unpack_le_uint32(db_key[-TXOUTIDX_LEN:] + TXOUTIDX_PADDING) + tx_numb = db_key[-TXOUTIDX_LEN-TXNUM_LEN:-TXOUTIDX_LEN] + tx_num = unpack_txnum(tx_numb) value, = unpack_le_uint64(db_value) tx_hash, height = self.fs_tx_hash(tx_num) utxos_append(UTXO(tx_num, txout_idx, tx_hash, height, value)) @@ -771,6 +769,23 @@ def read_utxos(): f'found (reorg?), retrying...') await sleep(0.25) + def _get_hashX_for_utxo( + self, tx_hash: bytes, txout_idx: int, + ) -> Tuple[Optional[bytes], Optional[bytes]]: + idx_packed = pack_le_uint32(txout_idx)[:TXOUTIDX_LEN] + tx_num = self.fs_txnum_for_txhash(tx_hash) + if tx_num is None: + return None, None + tx_numb = pack_txnum(tx_num) + + # Key: b'h' + tx_num + txout_idx + # Value: hashX + db_key = b'h' + 
tx_numb + idx_packed + hashX = self.utxo_db.get(db_key) + if hashX is None: + return None, None + return hashX, tx_numb + idx_packed + async def lookup_utxos(self, prevouts): '''For each prevout, lookup it up in the DB and return a (hashX, value) pair or None if not found. @@ -781,22 +796,7 @@ def lookup_hashXs(): '''Return (hashX, suffix) pairs, or None if not found, for each prevout. ''' - def lookup_hashX(tx_hash, tx_idx): - idx_packed = pack_le_uint32(tx_idx) - txnum_padding = bytes(8-TXNUM_LEN) - - # Key: b'h' + compressed_tx_hash + tx_idx + tx_num - # Value: hashX - prefix = b'h' + tx_hash[:COMP_TXID_LEN] + idx_packed - - # Find which entry, if any, the TX_HASH matches. - for db_key, hashX in self.utxo_db.iterator(prefix=prefix): - tx_num_packed = db_key[-TXNUM_LEN:] - tx_num, = unpack_le_uint64(tx_num_packed + txnum_padding) - hash, _height = self.fs_tx_hash(tx_num) - if hash == tx_hash: - return hashX, idx_packed + tx_num_packed - return None, None + lookup_hashX = self._get_hashX_for_utxo return [lookup_hashX(*prevout) for prevout in prevouts] def lookup_utxos(hashX_pairs): @@ -806,7 +806,7 @@ def lookup_utxo(hashX, suffix): # of us and has mempool txs spending outputs from # that new block return None - # Key: b'u' + address_hashX + tx_idx + tx_num + # Key: b'u' + address_hashX + tx_num + txout_idx # Value: the UTXO value as a 64-bit unsigned integer key = b'u' + hashX + suffix db_value = self.utxo_db.get(key) diff --git a/src/electrumx/server/env.py b/src/electrumx/server/env.py index 36f5afc00..066e7ac38 100644 --- a/src/electrumx/server/env.py +++ b/src/electrumx/server/env.py @@ -77,6 +77,7 @@ def __init__(self, coin=None): self.reorg_limit = self.integer('REORG_LIMIT', self.coin.REORG_LIMIT) self.daemon_poll_interval_blocks_msec = self.integer('DAEMON_POLL_INTERVAL_BLOCKS', 5000) self.daemon_poll_interval_mempool_msec = self.integer('DAEMON_POLL_INTERVAL_MEMPOOL', 5000) + self.daemon_has_txindex = self.boolean('DAEMON_HAS_TXINDEX', True) # Server limits to help prevent DoS diff --git a/src/electrumx/server/history.py b/src/electrumx/server/history.py index dc7dd8d54..eb5eec0c5 100644 --- a/src/electrumx/server/history.py +++ b/src/electrumx/server/history.py @@ -9,62 +9,74 @@ '''History by script hash (address).''' import ast -import bisect import time -from array import array from collections import defaultdict -from typing import TYPE_CHECKING, Type, Optional +from typing import TYPE_CHECKING, Type, Optional, Dict, Sequence, Tuple +import itertools import electrumx.lib.util as util from electrumx.lib.hash import HASHX_LEN, hash_to_hex_str -from electrumx.lib.util import (pack_be_uint16, pack_le_uint64, - unpack_be_uint16_from, unpack_le_uint64) +from electrumx.lib.util import (pack_le_uint64, unpack_le_uint64, + pack_le_uint32, unpack_le_uint32, + pack_be_uint64, unpack_be_uint64) if TYPE_CHECKING: from electrumx.server.storage import Storage TXNUM_LEN = 5 -FLUSHID_LEN = 2 +TXNUM_PADDING = bytes(8 - TXNUM_LEN) +TXOUTIDX_LEN = 3 +TXOUTIDX_PADDING = bytes(4 - TXOUTIDX_LEN) + + +def unpack_txnum(tx_numb: bytes) -> int: + return unpack_be_uint64(TXNUM_PADDING + tx_numb)[0] + + +def pack_txnum(tx_num: int) -> bytes: + return pack_be_uint64(tx_num)[-TXNUM_LEN:] class History: - DB_VERSIONS = (0, 1) + DB_VERSIONS = (3, ) db: Optional['Storage'] def __init__(self): self.logger = util.class_logger(__name__, self.__class__.__name__) - # For history compaction - self.max_hist_row_entries = 12500 - self.unflushed = defaultdict(bytearray) - self.unflushed_count = 0 - self.flush_count = 
0 - self.comp_flush_count = -1 - self.comp_cursor = -1 + self.hist_db_tx_count = 0 + self.hist_db_tx_count_next = 0 # after next flush, next value for self.hist_db_tx_count self.db_version = max(self.DB_VERSIONS) self.upgrade_cursor = -1 - # Key: address_hashX + flush_id - # Value: sorted "list" of tx_nums in history of hashX + self._unflushed_hashxs = defaultdict(bytearray) # type: Dict[bytes, bytearray] + self._unflushed_hashxs_count = 0 + self._unflushed_txhash_to_txnum_map = {} # type: Dict[bytes, int] + self._unflushed_txo_to_spender = {} # type: Dict[bytes, int] # (tx_num+txout_idx)->tx_num + + # Key: b'H' + address_hashX + tx_num + # Value: + # --- + # Key: b't' + tx_hash + # Value: tx_num + # --- + # Key: b's' + tx_num + txout_idx + # Value: tx_num + # "which tx spent this TXO?" -- note that UTXOs are not stored. self.db = None def open_db( self, + *, db_class: Type['Storage'], for_sync: bool, - utxo_flush_count: int, - compacting: bool, - ): + utxo_db_tx_count: int, + ) -> None: self.db = db_class('hist', for_sync) self.read_state() - self.clear_excess(utxo_flush_count) - # An incomplete compaction needs to be cancelled otherwise - # restarting it will corrupt the history - if not compacting: - self._cancel_compaction() - return self.flush_count + self.clear_excess(utxo_db_tx_count) def close_db(self): if self.db: @@ -72,22 +84,15 @@ def close_db(self): self.db = None def read_state(self): - state = self.db.get(b'state\0\0') + state = self.db.get(b'\0state') if state: state = ast.literal_eval(state.decode()) if not isinstance(state, dict): raise RuntimeError('failed reading state from history DB') - self.flush_count = state['flush_count'] - self.comp_flush_count = state.get('comp_flush_count', -1) - self.comp_cursor = state.get('comp_cursor', -1) self.db_version = state.get('db_version', 0) self.upgrade_cursor = state.get('upgrade_cursor', -1) - else: - self.flush_count = 0 - self.comp_flush_count = -1 - self.comp_cursor = -1 - self.db_version = max(self.DB_VERSIONS) - self.upgrade_cursor = -1 + self.hist_db_tx_count = state.get('hist_db_tx_count', 0) + self.hist_db_tx_count_next = self.hist_db_tx_count if self.db_version not in self.DB_VERSIONS: msg = (f'your history DB version is {self.db_version} but ' @@ -95,30 +100,53 @@ def read_state(self): self.logger.error(msg) raise RuntimeError(msg) if self.db_version != max(self.DB_VERSIONS): - self.upgrade_db() + pass # call future upgrade logic here self.logger.info(f'history DB version: {self.db_version}') - self.logger.info(f'flush count: {self.flush_count:,d}') - def clear_excess(self, utxo_flush_count): - # < might happen at end of compaction as both DBs cannot be - # updated atomically - if self.flush_count <= utxo_flush_count: + def clear_excess(self, utxo_db_tx_count: int) -> None: + # self.hist_db_tx_count != utxo_db_tx_count might happen as + # both DBs cannot be updated atomically + # FIXME when advancing blocks, hist_db is flushed first, so its count can be higher; + # but when backing up (e.g. reorg), hist_db is flushed first as well, + # so its count can be lower?! + # Shouldn't we flush utxo_db first when backing up? + if self.hist_db_tx_count <= utxo_db_tx_count: + assert self.hist_db_tx_count == utxo_db_tx_count return self.logger.info('DB shut down uncleanly. 
Scanning for ' 'excess history flushes...') - keys = [] - for key, _hist in self.db.iterator(prefix=b''): - flush_id, = unpack_be_uint16_from(key[-FLUSHID_LEN:]) - if flush_id > utxo_flush_count: - keys.append(key) - - self.logger.info(f'deleting {len(keys):,d} history entries') - - self.flush_count = utxo_flush_count + hkeys = [] + for db_key, db_val in self.db.iterator(prefix=b'H'): + tx_numb = db_key[-TXNUM_LEN:] + tx_num = unpack_txnum(tx_numb) + if tx_num >= utxo_db_tx_count: + hkeys.append(db_key) + + tkeys = [] + for db_key, db_val in self.db.iterator(prefix=b't'): + tx_numb = db_val + tx_num = unpack_txnum(tx_numb) + if tx_num >= utxo_db_tx_count: + tkeys.append(db_key) + + skeys = [] + for db_key, db_val in self.db.iterator(prefix=b's'): + tx_numb1 = db_key[1:1+TXNUM_LEN] + tx_numb2 = db_val + tx_num1 = unpack_txnum(tx_numb1) + tx_num2 = unpack_txnum(tx_numb2) + if max(tx_num1, tx_num2) >= utxo_db_tx_count: + skeys.append(db_key) + + self.logger.info(f'deleting {len(hkeys):,d} addr entries,' + f' {len(tkeys):,d} txs, and {len(skeys):,d} spends') + + self.hist_db_tx_count = utxo_db_tx_count + self.hist_db_tx_count_next = self.hist_db_tx_count with self.db.write_batch() as batch: - for key in keys: + for key in itertools.chain(hkeys, tkeys, skeys): batch.delete(key) self.write_state(batch) @@ -127,278 +155,165 @@ def clear_excess(self, utxo_flush_count): def write_state(self, batch): '''Write state to the history DB.''' state = { - 'flush_count': self.flush_count, - 'comp_flush_count': self.comp_flush_count, - 'comp_cursor': self.comp_cursor, + 'hist_db_tx_count': self.hist_db_tx_count, 'db_version': self.db_version, 'upgrade_cursor': self.upgrade_cursor, } - # History entries are not prefixed; the suffix \0\0 ensures we - # look similar to other entries and aren't interfered with - batch.put(b'state\0\0', repr(state).encode()) + batch.put(b'\0state', repr(state).encode()) - def add_unflushed(self, hashXs_by_tx, first_tx_num): - unflushed = self.unflushed + def add_unflushed( + self, + *, + hashXs_by_tx: Sequence[Sequence[bytes]], + first_tx_num: int, + txhash_to_txnum_map: Dict[bytes, int], + txo_to_spender_map: Dict[Tuple[bytes, int], bytes], # (tx_hash, txout_idx) -> tx_hash + ): + unflushed_hashxs = self._unflushed_hashxs count = 0 + tx_num = None for tx_num, hashXs in enumerate(hashXs_by_tx, start=first_tx_num): - tx_numb = pack_le_uint64(tx_num)[:TXNUM_LEN] + tx_numb = pack_txnum(tx_num) hashXs = set(hashXs) for hashX in hashXs: - unflushed[hashX] += tx_numb + unflushed_hashxs[hashX] += tx_numb count += len(hashXs) - self.unflushed_count += count + self._unflushed_hashxs_count += count + if tx_num is not None: + assert self.hist_db_tx_count_next + len(hashXs_by_tx) == tx_num + 1 + self.hist_db_tx_count_next = tx_num + 1 + + self._unflushed_txhash_to_txnum_map.update(txhash_to_txnum_map) + + unflushed_spenders = self._unflushed_txo_to_spender + get_txnum_for_txhash = self.get_txnum_for_txhash + for (prev_hash, prev_idx), spender_hash in txo_to_spender_map.items(): + prev_txnum = get_txnum_for_txhash(prev_hash) + assert prev_txnum is not None + spender_txnum = get_txnum_for_txhash(spender_hash) + assert spender_txnum is not None + prev_idx_packed = pack_le_uint32(prev_idx)[:TXOUTIDX_LEN] + prev_txnumb = pack_txnum(prev_txnum) + unflushed_spenders[prev_txnumb+prev_idx_packed] = spender_txnum def unflushed_memsize(self): - return len(self.unflushed) * 180 + self.unflushed_count * TXNUM_LEN + hashXs = len(self._unflushed_hashxs) * 180 + self._unflushed_hashxs_count * TXNUM_LEN + 
txs = 232 + 93 * len(self._unflushed_txhash_to_txnum_map) + spenders = 102 + 113 * len(self._unflushed_txo_to_spender) + return hashXs + txs + spenders def assert_flushed(self): - assert not self.unflushed + assert not self._unflushed_hashxs + assert not self._unflushed_txhash_to_txnum_map + assert not self._unflushed_txo_to_spender def flush(self): start_time = time.monotonic() - self.flush_count += 1 - flush_id = pack_be_uint16(self.flush_count) - unflushed = self.unflushed + unflushed_hashxs = self._unflushed_hashxs + chunks = util.chunks with self.db.write_batch() as batch: - for hashX in sorted(unflushed): - key = hashX + flush_id - batch.put(key, bytes(unflushed[hashX])) + for hashX in sorted(unflushed_hashxs): + for tx_num in sorted(chunks(unflushed_hashxs[hashX], TXNUM_LEN)): + db_key = b'H' + hashX + tx_num + batch.put(db_key, b'') + for tx_hash, tx_num in sorted(self._unflushed_txhash_to_txnum_map.items()): + db_key = b't' + tx_hash + tx_numb = pack_txnum(tx_num) + batch.put(db_key, tx_numb) + for prevout, spender_txnum in sorted(self._unflushed_txo_to_spender.items()): + db_key = b's' + prevout + db_val = pack_txnum(spender_txnum) + batch.put(db_key, db_val) + self.hist_db_tx_count = self.hist_db_tx_count_next self.write_state(batch) - count = len(unflushed) - unflushed.clear() - self.unflushed_count = 0 + addr_count = len(unflushed_hashxs) + tx_count = len(self._unflushed_txhash_to_txnum_map) + spend_count = len(self._unflushed_txo_to_spender) + unflushed_hashxs.clear() + self._unflushed_hashxs_count = 0 + self._unflushed_txhash_to_txnum_map.clear() + self._unflushed_txo_to_spender.clear() if self.db.for_sync: elapsed = time.monotonic() - start_time - self.logger.info(f'flushed history in {elapsed:.1f}s ' - f'for {count:,d} addrs') - - def backup(self, hashXs, tx_count): - # Not certain this is needed, but it doesn't hurt - self.flush_count += 1 - nremoves = 0 - bisect_left = bisect.bisect_left - chunks = util.chunks + self.logger.info(f'flushed history in {elapsed:.1f}s, for: ' + f'{addr_count:,d} addrs, {tx_count:,d} txs, {spend_count:,d} spends') - txnum_padding = bytes(8-TXNUM_LEN) + def backup(self, *, hashXs, tx_count, tx_hashes: Sequence[bytes], spends: Sequence[bytes]): + self.assert_flushed() + get_txnum_for_txhash = self.get_txnum_for_txhash + nremoves_addr = 0 with self.db.write_batch() as batch: for hashX in sorted(hashXs): deletes = [] - puts = {} - for key, hist in self.db.iterator(prefix=hashX, reverse=True): - a = array( - 'Q', - b''.join(item + txnum_padding for item in chunks(hist, TXNUM_LEN)) - ) - # Remove all history entries >= tx_count - idx = bisect_left(a, tx_count) - nremoves += len(a) - idx - if idx > 0: - puts[key] = hist[:TXNUM_LEN * idx] + prefix = b'H' + hashX + for db_key, db_val in self.db.iterator(prefix=prefix, reverse=True): + tx_numb = db_key[-TXNUM_LEN:] + tx_num = unpack_txnum(tx_numb) + if tx_num >= tx_count: + nremoves_addr += 1 + deletes.append(db_key) + else: + # note: we can break now, due to 'reverse=True' and txnums being big endian break - deletes.append(key) - for key in deletes: batch.delete(key) - for key, value in puts.items(): - batch.put(key, value) + for spend in spends: + prev_hash = spend[:32] + prev_idx = spend[32:] + assert len(prev_idx) == TXOUTIDX_LEN + prev_txnum = get_txnum_for_txhash(prev_hash) + assert prev_txnum is not None + prev_txnumb = pack_txnum(prev_txnum) + db_key = b's' + prev_txnumb + prev_idx + batch.delete(db_key) + for tx_hash in sorted(tx_hashes): + db_key = b't' + tx_hash + batch.delete(db_key) + 
self.hist_db_tx_count = tx_count + self.hist_db_tx_count_next = self.hist_db_tx_count self.write_state(batch) - self.logger.info(f'backing up removed {nremoves:,d} history entries') + self.logger.info(f'backing up history, removed {nremoves_addr:,d} addrs, ' + f'{len(tx_hashes):,d} txs, and {len(spends):,d} spends') - def get_txnums(self, hashX, limit=1000): + def get_txnums(self, hashX: bytes, limit: Optional[int] = 1000): '''Generator that returns an unpruned, sorted list of tx_nums in the history of a hashX. Includes both spending and receiving transactions. By default yields at most 1000 entries. Set limit to None to get them all. ''' limit = util.resolve_limit(limit) - chunks = util.chunks - txnum_padding = bytes(8-TXNUM_LEN) - for _key, hist in self.db.iterator(prefix=hashX): - for tx_numb in chunks(hist, TXNUM_LEN): - if limit == 0: - return - tx_num, = unpack_le_uint64(tx_numb + txnum_padding) - yield tx_num - limit -= 1 - - # - # History compaction - # - - # comp_cursor is a cursor into compaction progress. - # -1: no compaction in progress - # 0-65535: Compaction in progress; all prefixes < comp_cursor have - # been compacted, and later ones have not. - # 65536: compaction complete in-memory but not flushed - # - # comp_flush_count applies during compaction, and is a flush count - # for history with prefix < comp_cursor. flush_count applies - # to still uncompacted history. It is -1 when no compaction is - # taking place. Key suffixes up to and including comp_flush_count - # are used, so a parallel history flush must first increment this - # - # When compaction is complete and the final flush takes place, - # flush_count is reset to comp_flush_count, and comp_flush_count to -1 - - def _flush_compaction(self, cursor, write_items, keys_to_delete): - '''Flush a single compaction pass as a batch.''' - # Update compaction state - if cursor == 65536: - self.flush_count = self.comp_flush_count - self.comp_cursor = -1 - self.comp_flush_count = -1 - else: - self.comp_cursor = cursor - - # History DB. Flush compacted history and updated state - with self.db.write_batch() as batch: - # Important: delete first! The keyspace may overlap. - for key in keys_to_delete: - batch.delete(key) - for key, value in write_items: - batch.put(key, value) - self.write_state(batch) - - def _compact_hashX(self, hashX, hist_map, hist_list, - write_items, keys_to_delete): - '''Compres history for a hashX. hist_list is an ordered list of - the histories to be compressed.''' - # History entries (tx numbers) are TXNUM_LEN bytes each. Distribute - # over rows of up to 50KB in size. A fixed row size means - # future compactions will not need to update the first N - 1 - # rows. - max_row_size = self.max_hist_row_entries * TXNUM_LEN - full_hist = b''.join(hist_list) - nrows = (len(full_hist) + max_row_size - 1) // max_row_size - if nrows > 4: - self.logger.info( - f'hashX {hash_to_hex_str(hashX)} is large: ' - f'{len(full_hist) // TXNUM_LEN:,d} entries across {nrows:,d} rows' - ) - - # Find what history needs to be written, and what keys need to - # be deleted. Start by assuming all keys are to be deleted, - # and then remove those that are the same on-disk as when - # compacted. 
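Because the per-address keys end in a big-endian tx_num, backing up past a reorg only needs one reverse prefix scan per touched hashX, stopping at the first entry below the new transaction count. Isolated as a hedged sketch, reusing the helper names assumed above::

    def delete_addr_history_at_or_above(db, batch, hashX: bytes, tx_count: int) -> int:
        '''Delete b'H' entries of one hashX whose tx_num >= tx_count (sketch).'''
        removed = 0
        prefix = b'H' + hashX
        for db_key, _db_val in db.iterator(prefix=prefix, reverse=True):
            if unpack_txnum(db_key[-TXNUM_LEN:]) < tx_count:
                break          # everything further back is older and can stay
            batch.delete(db_key)
            removed += 1
        return removed
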
- write_size = 0 - keys_to_delete.update(hist_map) - for n, chunk in enumerate(util.chunks(full_hist, max_row_size)): - key = hashX + pack_be_uint16(n) - if hist_map.get(key) == chunk: - keys_to_delete.remove(key) - else: - write_items.append((key, chunk)) - write_size += len(chunk) - - assert n + 1 == nrows - self.comp_flush_count = max(self.comp_flush_count, n) - - return write_size - - def _compact_prefix(self, prefix, write_items, keys_to_delete): - '''Compact all history entries for hashXs beginning with the - given prefix. Update keys_to_delete and write.''' - prior_hashX = None - hist_map = {} - hist_list = [] - - key_len = HASHX_LEN + FLUSHID_LEN - write_size = 0 - for key, hist in self.db.iterator(prefix=prefix): - # Ignore non-history entries - if len(key) != key_len: - continue - hashX = key[:-FLUSHID_LEN] - if hashX != prior_hashX and prior_hashX: - write_size += self._compact_hashX(prior_hashX, hist_map, - hist_list, write_items, - keys_to_delete) - hist_map.clear() - hist_list.clear() - prior_hashX = hashX - hist_map[key] = hist - hist_list.append(hist) - - if prior_hashX: - write_size += self._compact_hashX(prior_hashX, hist_map, hist_list, - write_items, keys_to_delete) - return write_size - - def _compact_history(self, limit): - '''Inner loop of history compaction. Loops until limit bytes have - been processed. + prefix = b'H' + hashX + for db_key, db_val in self.db.iterator(prefix=prefix): + tx_numb = db_key[-TXNUM_LEN:] + if limit == 0: + return + tx_num = unpack_txnum(tx_numb) + yield tx_num + limit -= 1 + + def get_txnum_for_txhash(self, tx_hash: bytes) -> Optional[int]: + tx_num = self._unflushed_txhash_to_txnum_map.get(tx_hash) + if tx_num is None: + db_key = b't' + tx_hash + tx_numb = self.db.get(db_key) + if tx_numb: + tx_num = unpack_txnum(tx_numb) + return tx_num + + def get_spender_txnum_for_txo(self, prev_txnum: int, txout_idx: int) -> Optional[int]: + '''For an outpoint, returns the tx_num that spent it. + If the outpoint is unspent, or even if it never existed (!), returns None. 
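The spend lookup mirrors ``add_unflushed``: a spender is searched for first in the in-memory ``_unflushed_txo_to_spender`` map, which is keyed by the packed prevout (``pack_txnum(prev_txnum)`` plus the truncated output index), and only then in the flushed ``b's'`` records. A hedged sketch of that order, with the same assumed helpers::

    def spender_txnum_for_txo(db, unflushed_spenders, prev_txnum: int, txout_idx: int):
        '''Sketch: unflushed_spenders plays the role of _unflushed_txo_to_spender.'''
        prevout = pack_txnum(prev_txnum) + txout_idx.to_bytes(4, 'little')[:TXOUTIDX_LEN]
        spender = unflushed_spenders.get(prevout)     # not yet flushed to disk
        if spender is None:
            raw = db.get(b's' + prevout)              # flushed spend record
            if raw:
                spender = unpack_txnum(raw)
        return spender                                # None => unspent or unknown TXO
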
''' - keys_to_delete = set() - write_items = [] # A list of (key, value) pairs - write_size = 0 - - # Loop over 2-byte prefixes - cursor = self.comp_cursor - while write_size < limit and cursor < 65536: - prefix = pack_be_uint16(cursor) - write_size += self._compact_prefix(prefix, write_items, - keys_to_delete) - cursor += 1 - - max_rows = self.comp_flush_count + 1 - self._flush_compaction(cursor, write_items, keys_to_delete) - - self.logger.info( - f'history compaction: wrote {len(write_items):,d} rows ' - f'({write_size / 1000000:.1f} MB), removed ' - f'{len(keys_to_delete):,d} rows, largest: {max_rows:,d}, ' - f'{100 * cursor / 65536:.1f}% complete' - ) - return write_size - - def _cancel_compaction(self): - if self.comp_cursor != -1: - self.logger.warning('cancelling in-progress history compaction') - self.comp_flush_count = -1 - self.comp_cursor = -1 - - # - # DB upgrade - # - - def upgrade_db(self): - self.logger.info(f'history DB version: {self.db_version}') - self.logger.info('Upgrading your history DB; this can take some time...') - - def upgrade_cursor(cursor): - count = 0 - prefix = pack_be_uint16(cursor) - key_len = HASHX_LEN + 2 - chunks = util.chunks - with self.db.write_batch() as batch: - batch_put = batch.put - for key, hist in self.db.iterator(prefix=prefix): - # Ignore non-history entries - if len(key) != key_len: - continue - count += 1 - hist = b''.join(item + b'\0' for item in chunks(hist, 4)) - batch_put(key, hist) - self.upgrade_cursor = cursor - self.write_state(batch) - return count - - last = time.monotonic() - count = 0 - - for cursor in range(self.upgrade_cursor + 1, 65536): - count += upgrade_cursor(cursor) - now = time.monotonic() - if now > last + 10: - last = now - self.logger.info(f'DB 3 of 3: {count:,d} entries updated, ' - f'{cursor * 100 / 65536:.1f}% complete') - - self.db_version = max(self.DB_VERSIONS) - self.upgrade_cursor = -1 - with self.db.write_batch() as batch: - self.write_state(batch) - self.logger.info('DB 3 of 3 upgraded successfully') + prev_idx_packed = pack_le_uint32(txout_idx)[:TXOUTIDX_LEN] + prev_txnumb = pack_txnum(prev_txnum) + prevout = prev_txnumb + prev_idx_packed + spender_txnum = self._unflushed_txhash_to_txnum_map.get(prevout) + if spender_txnum is None: + db_key = b's' + prevout + spender_txnumb = self.db.get(db_key) + if spender_txnumb: + spender_txnum = unpack_txnum(spender_txnumb) + return spender_txnum diff --git a/src/electrumx/server/mempool.py b/src/electrumx/server/mempool.py index 533233c1e..65e1dc20f 100644 --- a/src/electrumx/server/mempool.py +++ b/src/electrumx/server/mempool.py @@ -21,6 +21,7 @@ from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash from electrumx.lib.tx import SkipTxDeserialize from electrumx.lib.util import class_logger, chunks, OldTaskGroup +from electrumx.lib.tx import TXOSpendStatus from electrumx.server.db import UTXO if TYPE_CHECKING: @@ -53,17 +54,17 @@ class MemPoolAPI(ABC): and used by it to query DB and blockchain state.''' @abstractmethod - async def height(self): + async def height(self) -> int: '''Query bitcoind for its height.''' @abstractmethod - def cached_height(self): + def cached_height(self) -> Optional[int]: '''Return the height of bitcoind the last time it was queried, for any reason, without actually querying it. 
''' @abstractmethod - def db_height(self): + def db_height(self) -> int: '''Return the height flushed to the on-disk DB.''' @abstractmethod @@ -80,17 +81,25 @@ async def raw_transactions(self, hex_hashes): @abstractmethod async def lookup_utxos(self, prevouts): - '''Return a list of (hashX, value) pairs each prevout if unspent, - otherwise return None if spent or not found. + '''Return a list of (hashX, value) pairs, one for each prevout if unspent, + otherwise return None if spent or not found (for the given prevout). - prevouts - an iterable of (hash, index) pairs + prevouts - an iterable of (tx_hash, txout_idx) pairs ''' @abstractmethod - async def on_mempool(self, touched, height): - '''Called each time the mempool is synchronized. touched is a set of - hashXs touched since the previous call. height is the - daemon's height at the time the mempool was obtained.''' + async def on_mempool( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height: int, + ): + '''Called each time the mempool is synchronized. touched_hashxs and + touched_outpoints are sets of hashXs and tx outpoints touched since + the previous call. height is the daemon's height at the time the + mempool was obtained. + ''' class MemPool: @@ -119,8 +128,9 @@ def __init__( self.coin = coin self.api = api self.logger = class_logger(__name__, self.__class__.__name__) - self.txs = {} # type: Dict[bytes, MemPoolTx] - self.hashXs = defaultdict(set) # None can be a key + self.txs = {} # type: Dict[bytes, MemPoolTx] # txid->tx + self.hashXs = defaultdict(set) # type: Dict[Optional[bytes], Set[bytes]] # hashX->txids + self.txo_to_spender = {} # type: Dict[Tuple[bytes, int], bytes] # prevout->txid self.cached_compact_histogram = [] self.refresh_secs = refresh_secs self.log_status_secs = log_status_secs @@ -137,8 +147,9 @@ async def _logging(self, synchronized_event): self.logger.info(f'synced in {elapsed:.2f}s') while True: mempool_size = sum(tx.size for tx in self.txs.values()) / 1_000_000 - self.logger.info(f'{len(self.txs):,d} txs {mempool_size:.2f} MB ' - f'touching {len(self.hashXs):,d} addresses') + self.logger.info(f'{len(self.txs):,d} txs {mempool_size:.2f} MB, ' + f'touching {len(self.hashXs):,d} addresses. ' + f'{len(self.txo_to_spender):,d} spends.') await sleep(self.log_status_secs) await synchronized_event.wait() @@ -205,7 +216,15 @@ def _compress_histogram( prev_fee_rate = fee_rate return compact - def _accept_transactions(self, tx_map: Dict[bytes, MemPoolTx], utxo_map, touched): + def _accept_transactions( + self, + *, + tx_map: Dict[bytes, MemPoolTx], # txid->tx + utxo_map: Dict[Tuple[bytes, int], Tuple[bytes, int]], # prevout->(hashX,value_in_sats) + touched_hashxs: Set[bytes], # set of hashXs + touched_outpoints: Set[Tuple[bytes, int]], # set of outpoints + ) -> Tuple[Dict[bytes, MemPoolTx], + Dict[Tuple[bytes, int], Tuple[bytes, int]]]: '''Accept transactions in tx_map to the mempool if all their inputs can be found in the existing mempool or a utxo_map from the DB. 
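A transaction entering or leaving the mempool "touches" both the outpoints it spends and the outpoints it creates; those two sets are what drive the new outpoint notifications alongside the existing hashX ones. As a small sketch with assumed argument shapes::

    from typing import Iterable, Set, Tuple

    def touched_outpoints_for_tx(
        tx_hash: bytes,
        prevouts: Iterable[Tuple[bytes, int]],   # outpoints this tx spends
        num_outputs: int,                        # outputs this tx creates
    ) -> Set[Tuple[bytes, int]]:
        touched = set(prevouts)
        touched.update((tx_hash, idx) for idx in range(num_outputs))
        return touched
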
@@ -214,11 +233,12 @@ def _accept_transactions(self, tx_map: Dict[bytes, MemPoolTx], utxo_map, touched ''' hashXs = self.hashXs txs = self.txs + txo_to_spender = self.txo_to_spender deferred = {} unspent = set(utxo_map) # Try to find all prevouts so we can accept the TX - for hash, tx in tx_map.items(): + for tx_hash, tx in tx_map.items(): in_pairs = [] try: for prevout in tx.prevouts: @@ -229,7 +249,7 @@ def _accept_transactions(self, tx_map: Dict[bytes, MemPoolTx], utxo_map, touched utxo = txs[prev_hash].out_pairs[prev_index] in_pairs.append(utxo) except KeyError: - deferred[hash] = tx + deferred[tx_hash] = tx continue # Spend the prevouts @@ -241,19 +261,25 @@ def _accept_transactions(self, tx_map: Dict[bytes, MemPoolTx], utxo_map, touched # because some in_parts would be missing tx.fee = max(0, (sum(v for _, v in tx.in_pairs) - sum(v for _, v in tx.out_pairs))) - txs[hash] = tx + txs[tx_hash] = tx for hashX, _value in itertools.chain(tx.in_pairs, tx.out_pairs): - touched.add(hashX) - hashXs[hashX].add(hash) + touched_hashxs.add(hashX) + hashXs[hashX].add(tx_hash) + for prevout in tx.prevouts: + txo_to_spender[prevout] = tx_hash + touched_outpoints.add(prevout) + for out_idx, out_pair in enumerate(tx.out_pairs): + touched_outpoints.add((tx_hash, out_idx)) return deferred, {prevout: utxo_map[prevout] for prevout in unspent} async def _refresh_hashes(self, synchronized_event): '''Refresh our view of the daemon's mempool.''' - # Touched accumulates between calls to on_mempool and each + # touched_* accumulates between calls to on_mempool and each # call transfers ownership - touched = set() + touched_hashxs = set() + touched_outpoints = set() while True: height = self.api.cached_height() hex_hashes = await self.api.mempool_hashes() @@ -262,7 +288,12 @@ async def _refresh_hashes(self, synchronized_event): hashes = {hex_str_to_hash(hh) for hh in hex_hashes} try: async with self.lock: - await self._process_mempool(hashes, touched, height) + await self._process_mempool( + all_hashes=hashes, + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + mempool_height=height, + ) except DBSyncError: # The UTXO DB is not at the same height as the # mempool; wait and try again @@ -270,14 +301,27 @@ async def _refresh_hashes(self, synchronized_event): else: synchronized_event.set() synchronized_event.clear() - await self.api.on_mempool(touched, height) - touched = set() + await self.api.on_mempool( + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + height=height, + ) + touched_hashxs = set() + touched_outpoints = set() await sleep(self.refresh_secs) - async def _process_mempool(self, all_hashes: Set[bytes], touched, mempool_height): + async def _process_mempool( + self, + *, + all_hashes: Set[bytes], # set of txids + touched_hashxs: Set[bytes], # set of hashXs + touched_outpoints: Set[Tuple[bytes, int]], # set of outpoints + mempool_height: int, + ) -> None: # Re-sync with the new set of hashes txs = self.txs hashXs = self.hashXs + txo_to_spender = self.txo_to_spender if mempool_height != self.api.db_height(): raise DBSyncError @@ -285,20 +329,32 @@ async def _process_mempool(self, all_hashes: Set[bytes], touched, mempool_height # First handle txs that have disappeared for tx_hash in (set(txs) - all_hashes): tx = txs.pop(tx_hash) + # hashXs tx_hashXs = {hashX for hashX, value in tx.in_pairs} tx_hashXs.update(hashX for hashX, value in tx.out_pairs) for hashX in tx_hashXs: hashXs[hashX].remove(tx_hash) if not hashXs[hashX]: del hashXs[hashX] - touched |= tx_hashXs + 
touched_hashxs |= tx_hashXs + # outpoints + for prevout in tx.prevouts: + del txo_to_spender[prevout] + touched_outpoints.add(prevout) + for out_idx, out_pair in enumerate(tx.out_pairs): + touched_outpoints.add((tx_hash, out_idx)) # Process new transactions new_hashes = list(all_hashes.difference(txs)) if new_hashes: group = OldTaskGroup() for hashes in chunks(new_hashes, 200): - coro = self._fetch_and_accept(hashes, all_hashes, touched) + coro = self._fetch_and_accept( + hashes=hashes, + all_hashes=all_hashes, + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + ) await group.spawn(coro) if mempool_height != self.api.db_height(): raise DBSyncError @@ -314,14 +370,23 @@ async def _process_mempool(self, all_hashes: Set[bytes], touched, mempool_height # FIXME: this is not particularly efficient while tx_map and len(tx_map) != prior_count: prior_count = len(tx_map) - tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map, - touched) + tx_map, utxo_map = self._accept_transactions( + tx_map=tx_map, + utxo_map=utxo_map, + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + ) if tx_map: self.logger.error(f'{len(tx_map)} txs dropped') - return touched - - async def _fetch_and_accept(self, hashes: Sequence[bytes], all_hashes: Set[bytes], touched): + async def _fetch_and_accept( + self, + *, + hashes: Set[bytes], # set of txids + all_hashes: Set[bytes], # set of txids + touched_hashxs: Set[bytes], # set of hashXs + touched_outpoints: Set[Tuple[bytes, int]], # set of outpoints + ): '''Fetch a list of mempool transactions.''' hex_hashes_iter = (hash_to_hex_str(hash) for hash in hashes) raw_txs = await self.api.raw_transactions(hex_hashes_iter) @@ -372,7 +437,12 @@ def deserialize_txs() -> Dict[bytes, MemPoolTx]: utxos = await self.api.lookup_utxos(prevouts) utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)} - return self._accept_transactions(tx_map, utxo_map, touched) + return self._accept_transactions( + tx_map=tx_map, + utxo_map=utxo_map, + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + ) # # External interface @@ -441,3 +511,36 @@ async def unordered_UTXOs(self, hashX): if hX == hashX: utxos.append(UTXO(-1, pos, tx_hash, 0, value)) return utxos + + async def spender_for_txo(self, prev_txhash: bytes, txout_idx: int) -> 'TXOSpendStatus': + '''For an outpoint, returns its spend-status. + This only considers the mempool, not the DB/blockchain, so e.g. mined + txs are not distinguished from txs that never existed. + ''' + # look up funding tx + prev_tx = self.txs.get(prev_txhash, None) + if prev_tx is None: + # funding tx already mined or never existed + prev_height = None + else: + if len(prev_tx.out_pairs) <= txout_idx: + # output idx out of bounds...? 
+ return TXOSpendStatus(prev_height=None) + prev_has_ui = any(hash in self.txs for hash, idx in prev_tx.prevouts) + prev_height = -prev_has_ui + prevout = (prev_txhash, txout_idx) + # look up spending tx + spender_txhash = self.txo_to_spender.get(prevout, None) + spender_tx = self.txs.get(spender_txhash, None) + if spender_tx is None: + self.logger.warning(f"spender_tx {hash_to_hex_str(spender_txhash)} not in" + f"mempool, but txo_to_spender referenced it as spender " + f"of {hash_to_hex_str(prev_txhash)}:{txout_idx} ?!") + return TXOSpendStatus(prev_height=prev_height) + spender_has_ui = any(hash in self.txs for hash, idx in spender_tx.prevouts) + spender_height = -spender_has_ui + return TXOSpendStatus( + prev_height=prev_height, + spender_txhash=spender_txhash, + spender_height=spender_height, + ) diff --git a/src/electrumx/server/session.py b/src/electrumx/server/session.py index d9c8a010e..2ea7ea809 100644 --- a/src/electrumx/server/session.py +++ b/src/electrumx/server/session.py @@ -18,13 +18,16 @@ from collections import defaultdict from functools import partial from ipaddress import IPv4Address, IPv6Address, IPv4Network, IPv6Network -from typing import Iterable, Optional, TYPE_CHECKING, Sequence, Union, Any +from typing import Iterable, Optional, TYPE_CHECKING, Sequence, Union, Any, Tuple, Set, Dict, Mapping +from typing import Callable import attr from aiorpcx import (Event, JSONRPCAutoDetect, JSONRPCConnection, ReplyAndDisconnect, Request, RPCError, RPCSession, Service, handler_invocation, serve_rs, serve_ws, sleep, - NewlineFramer, TaskTimeout, timeout_after, run_in_thread) + NewlineFramer, TaskTimeout, timeout_after, run_in_thread, + Notification) +from aiorpcx.jsonrpc import SingleRequest import electrumx import electrumx.lib.util as util @@ -50,7 +53,7 @@ DAEMON_ERROR = 2 -def scripthash_to_hashX(scripthash): +def scripthash_to_hashX(scripthash: str) -> bytes: try: bin_hash = hex_str_to_hash(scripthash) if len(bin_hash) == 32: @@ -60,6 +63,13 @@ def scripthash_to_hashX(scripthash): raise RPCError(BAD_REQUEST, f'{scripthash} is not a valid script hash') +def spk_to_scripthash(spk: str) -> str: + """Converts scriptPubKey to scripthash.""" + assert_hex_str(spk) + h = sha256(bytes.fromhex(spk)) + return h[::-1].hex() + + def non_negative_integer(value): '''Return param value it is or can be converted to a non-negative integer, otherwise raise an RPCError.''' @@ -108,7 +118,7 @@ def assert_list_or_tuple(value: Any) -> None: class SessionGroup: name = attr.ib() weight = attr.ib() - sessions = attr.ib() + sessions = attr.ib() # type: Set[ElectrumX] retained_cost = attr.ib() def session_cost(self): @@ -151,8 +161,8 @@ def __init__( self.shutdown_event = shutdown_event self.logger = util.class_logger(__name__, self.__class__.__name__) self.servers = {} # service->server - self.sessions = {} # session->iterable of its SessionGroups - self.session_groups = {} # group name->SessionGroup instance + self.sessions = {} # type: Dict[ElectrumX, Iterable[SessionGroup]] + self.session_groups = {} # type: Dict[str, SessionGroup] self.txs_sent = 0 # Would use monotonic time, but aiorpcx sessions use Unix time: self.start_time = time.time() @@ -160,8 +170,10 @@ def __init__( self._reorg_count = 0 self._history_cache = LRUCache(maxsize=1000) self._txids_cache = LRUCache(maxsize=1000) + self._wtxids_cache = LRUCache(maxsize=1000) # Really a MerkleCache cache self._merkle_txid_cache = LRUCache(maxsize=1000) + self._merkle_wtxid_cache = LRUCache(maxsize=1000) self.estimatefee_cache = 
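``spender_for_txo`` above reuses the height convention Electrum clients already know from address histories: a positive height means mined, ``0`` means in the mempool with all inputs confirmed, ``-1`` means in the mempool with at least one unconfirmed input (hence ``prev_height = -prev_has_ui``), and ``None`` means the mempool knows nothing about the transaction. Illustrative values only (the 32 zero bytes are a placeholder txid)::

    from electrumx.lib.tx import TXOSpendStatus

    unknown = TXOSpendStatus(prev_height=None)   # funding tx not in the mempool
    unconf  = TXOSpendStatus(prev_height=0)      # funding tx in mempool, parents confirmed
    chained = TXOSpendStatus(prev_height=-1)     # funding tx in mempool, unconfirmed parent
    spent   = TXOSpendStatus(prev_height=0,
                             spender_txhash=bytes(32),   # placeholder
                             spender_height=-1)          # spender spends an unconfirmed output
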
LRUCache(maxsize=1000) self.notified_height = None self.hsub_results = None @@ -299,7 +311,9 @@ async def _handle_chain_reorgs(self): self._reorg_count += 1 # not: history_cache is cleared in _notify_sessions self._txids_cache.clear() + self._wtxids_cache.clear() self._merkle_txid_cache.clear() + self._merkle_wtxid_cache.clear() async def _recalc_concurrency(self): '''Periodically recalculate session concurrency.''' @@ -324,7 +338,7 @@ async def _recalc_concurrency(self): # cost_decay_per_sec. for session in self.sessions: # Subs have an on-going cost so decay more slowly with more subs - session.cost_decay_per_sec = hard_limit / (10000 + 5 * session.sub_count()) + session.cost_decay_per_sec = hard_limit / (10000 + 5 * session.sub_count_total()) session.recalc_concurrency() def _get_info(self): @@ -337,23 +351,27 @@ def cache_fmt(cache: LRUCache): 'daemon': self.daemon.logged_url(), 'daemon height': self.daemon.cached_height(), 'db height': self.db.db_height, - 'db_flush_count': self.db.history.flush_count, 'groups': len(self.session_groups), 'history cache': cache_fmt(self._history_cache), 'merkle txid cache': cache_fmt(self._merkle_txid_cache), + 'merkle wtxid cache': cache_fmt(self._merkle_wtxid_cache), 'pid': os.getpid(), 'peers': self.peer_mgr.info(), 'request counts': self._method_counts, 'request total': sum(self._method_counts.values()), 'sessions': { 'count': len(sessions), - 'count with subs': sum(len(getattr(s, 'hashX_subs', ())) > 0 for s in sessions), + 'count with subs_sh': sum(s.sub_count_scripthashes() > 0 for s in sessions), + 'count with subs_txo': sum(s.sub_count_txoutpoints() > 0 for s in sessions), + 'count with subs_any': sum(s.sub_count_total() > 0 for s in sessions), 'errors': sum(s.errors for s in sessions), 'logged': len([s for s in sessions if s.log_me]), 'pending requests': sum(s.unanswered_request_count() for s in sessions), - 'subs': sum(s.sub_count() for s in sessions), + 'subs_sh': sum(s.sub_count_scripthashes() for s in sessions), + 'subs_txo': sum(s.sub_count_txoutpoints() for s in sessions), }, 'txids cache': cache_fmt(self._txids_cache), + 'wtxids cache': cache_fmt(self._wtxids_cache), 'txs sent': self.txs_sent, 'uptime': util.formatted_time(time.time() - self.start_time), 'version': electrumx.version, @@ -372,7 +390,7 @@ def _session_data(self, for_log): session.extra_cost(), session.unanswered_request_count(), session.txs_sent, - session.sub_count(), + session.sub_count_total(), session.recv_count, session.recv_size, session.send_count, session.send_size, now - session.start_time) @@ -389,7 +407,7 @@ def _group_data(self): group.retained_cost, sum(s.unanswered_request_count() for s in sessions), sum(s.txs_sent for s in sessions), - sum(s.sub_count() for s in sessions), + sum(s.sub_count_total() for s in sessions), sum(s.recv_count for s in sessions), sum(s.recv_size for s in sessions), sum(s.send_count for s in sessions), @@ -714,22 +732,39 @@ def extra_cost(self, session): return 0 return sum((group.cost() - session.cost) * group.weight for group in groups) - async def _merkle_branch(self, height, tx_hashes, tx_pos): + async def getrawtransaction(self, tx_hash: bytes, *, verbose: bool = False) -> str: + tx_hash_hex = hash_to_hex_str(tx_hash) + blockhash = None + if not self.env.daemon_has_txindex: + height, tx_pos = await self.db.get_blockheight_and_txpos_for_txhash(tx_hash) + if height is not None: + block_header = await self.db.raw_header(height) + blockhash = self.env.coin.header_hash(block_header).hex() + + return await 
self.daemon_request('getrawtransaction', tx_hash_hex, verbose, blockhash) + + async def _merkle_branch( + self, height: int, tx_hashes: Sequence[bytes], tx_pos: int, wtxid: bool = False, + ) -> Tuple[Sequence[str], float]: + mccache = self._merkle_wtxid_cache if wtxid else self._merkle_txid_cache tx_hash_count = len(tx_hashes) cost = tx_hash_count + if wtxid: + tx_hashes = list(tx_hashes) + tx_hashes[0] = bytes(32) # The wtxid of coinbase tx is assumed to be 0x0000....0000 if tx_hash_count >= 200: - self._merkle_txid_cache.num_lookups += 1 - merkle_cache = self._merkle_txid_cache.get(height) + mccache.num_lookups += 1 + merkle_cache = mccache.get(height) if merkle_cache: - self._merkle_txid_cache.num_hits += 1 + mccache.num_hits += 1 cost = 10 * math.sqrt(tx_hash_count) else: async def tx_hashes_func(start, count): return tx_hashes[start: start + count] merkle_cache = MerkleCache(self.db.merkle, tx_hashes_func) - self._merkle_txid_cache[height] = merkle_cache + mccache[height] = merkle_cache await merkle_cache.initialize(len(tx_hashes)) branch, _root = await merkle_cache.branch_and_root(tx_hash_count, tx_pos) else: @@ -738,16 +773,26 @@ async def tx_hashes_func(start, count): branch = [hash_to_hex_str(hash) for hash in branch] return branch, cost / 2500 - async def merkle_branch_for_tx_hash(self, height, tx_hash): - '''Return a triple (branch, tx_pos, cost).''' - tx_hashes, tx_hashes_cost = await self.tx_hashes_at_blockheight(height) - try: - tx_pos = tx_hashes.index(tx_hash) - except ValueError: + async def merkle_branch_for_tx_hash( + self, *, tx_hash: bytes, witness: bool, + ) -> Tuple[int, Optional[bytes], Sequence[str], int, bytes, float]: + '''Returns (height, wtxid, branch, tx_pos, block_header, cost).''' + cost = 0.1 + height, tx_pos = await self.db.get_blockheight_and_txpos_for_txhash(tx_hash) + if height is None: raise RPCError(BAD_REQUEST, - f'tx {hash_to_hex_str(tx_hash)} not in block at height {height:,d}') - branch, merkle_cost = await self._merkle_branch(height, tx_hashes, tx_pos) - return branch, tx_pos, tx_hashes_cost + merkle_cost + f'tx {hash_to_hex_str(tx_hash)} not in any block') + assert tx_pos is not None + block_header = await self.raw_header(height) + tx_hashes, tx_hashes_cost = await self.tx_hashes_at_blockheight(height, wtxid=witness) + wtxid = tx_hashes[tx_pos] if witness else None + if block_header != await self.raw_header(height): + # there was a reorg while processing the request... TODO maybe retry? + raise RPCError(BAD_REQUEST, + f'tx {hash_to_hex_str(tx_hash)} was reorged while processing request') + branch, merkle_cost = await self._merkle_branch(height, tx_hashes, tx_pos, wtxid=witness) + cost += tx_hashes_cost + merkle_cost + return height, wtxid, branch, tx_pos, block_header, cost async def merkle_branch_for_tx_pos(self, height, tx_pos): '''Return a triple (branch, tx_hash_hex, cost).''' @@ -760,29 +805,32 @@ async def merkle_branch_for_tx_pos(self, height, tx_pos): branch, merkle_cost = await self._merkle_branch(height, tx_hashes, tx_pos) return branch, hash_to_hex_str(tx_hash), tx_hashes_cost + merkle_cost - async def tx_hashes_at_blockheight(self, height): + async def tx_hashes_at_blockheight( + self, height, *, wtxid: bool = False, + ) -> Tuple[Sequence[bytes], float]: '''Returns a pair (tx_hashes, cost). tx_hashes is an ordered list of binary hashes, cost is an estimated cost of getting the hashes; cheaper if in-cache. Raises RPCError. 
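Zeroing the first leaf when ``wtxid=True`` follows BIP 141, which defines the coinbase's wtxid as 32 zero bytes so that the witness merkle root can itself be committed to inside the coinbase. For reference, the commitment a client would ultimately check against is, as a sketch::

    from hashlib import sha256

    def double_sha256(b: bytes) -> bytes:
        return sha256(sha256(b).digest()).digest()

    def witness_commitment(witness_merkle_root: bytes,
                           witness_reserved_value: bytes = bytes(32)) -> bytes:
        # BIP 141: committed in a coinbase OP_RETURN output tagged 0xaa21a9ed
        return double_sha256(witness_merkle_root + witness_reserved_value)
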
''' - self._txids_cache.num_lookups += 1 - tx_hashes = self._txids_cache.get(height) + cache = self._wtxids_cache if wtxid else self._txids_cache + cache.num_lookups += 1 + tx_hashes = cache.get(height) # type: Sequence[bytes] if tx_hashes: - self._txids_cache.num_hits += 1 + cache.num_hits += 1 return tx_hashes, 0.1 # Ensure the tx_hashes are fresh before placing in the cache while True: reorg_count = self._reorg_count try: - tx_hashes = await self.db.tx_hashes_at_blockheight(height) + tx_hashes = await self.db.tx_hashes_at_blockheight(height, wtxid=wtxid) except self.db.DBError as e: raise RPCError(BAD_REQUEST, f'db error: {e!r}') if reorg_count == self._reorg_count: break - self._txids_cache[height] = tx_hashes + cache[height] = tx_hashes return tx_hashes, 0.25 + len(tx_hashes) * 0.0001 @@ -797,7 +845,7 @@ async def daemon_request(self, method, *args): except DaemonError as e: raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None - async def raw_header(self, height): + async def raw_header(self, height: int) -> bytes: '''Return the binary header at the given height.''' try: return await self.db.raw_header(height) @@ -838,21 +886,32 @@ async def limited_history(self, hashX): raise result return result, cost - async def _notify_sessions(self, height, touched): + async def _notify_sessions( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height: int, + ): '''Notify sessions about height changes and touched addresses.''' height_changed = height != self.notified_height if height_changed: await self._refresh_hsub_results(height) # Invalidate our history cache for touched hashXs cache = self._history_cache - for hashX in set(cache).intersection(touched): + for hashX in set(cache).intersection(touched_hashxs): del cache[hashX] for session in self.sessions: if self._task_group.joined: # this can happen during shutdown self.logger.warning(f"task group already terminated. not notifying sessions.") return - await self._task_group.spawn(session.notify, touched, height_changed) + coro = session.notify( + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + height_changed=height_changed, + ) + await self._task_group.spawn(coro) def _ip_addr_group_name(self, session) -> Optional[str]: host = session.remote_address().host @@ -939,6 +998,8 @@ class SessionBase(RPCSessionWithTaskGroup): MAX_CHUNK_SIZE = 2016 session_counter = itertools.count() log_new = False + request_handlers: Dict[str, Callable] + notification_handlers: Dict[str, Callable] def __init__( self, @@ -980,7 +1041,13 @@ def __init__( self.session_mgr.add_session(self) self.recalc_concurrency() # must be called after session_mgr.add_session - async def notify(self, touched, height_changed): + async def notify( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height_changed: bool, + ): pass def default_framer(self): @@ -1016,17 +1083,22 @@ async def connection_lost(self): msg = 'disconnected' + msg self.logger.info(msg) - def sub_count(self): + def sub_count_scripthashes(self): return 0 - async def handle_request(self, request): - '''Handle an incoming request. ElectrumX doesn't receive - notifications from client sessions. 
- ''' + def sub_count_txoutpoints(self): + return 0 + + def sub_count_total(self): + return self.sub_count_scripthashes() + self.sub_count_txoutpoints() + + async def handle_request(self, request: SingleRequest): + '''Handle an incoming request.''' + handler = None if isinstance(request, Request): handler = self.request_handlers.get(request.method) - else: - handler = None + elif isinstance(request, Notification): + handler = self.notification_handlers.get(request.method) method = 'invalid method' if handler is None else request.method # Version negotiation must happen before any other messages. @@ -1073,9 +1145,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.subscribe_headers = False self.connection.max_response_size = self.env.max_send - self.hashX_subs = {} - self.sv_seen = False - self.mempool_statuses = {} + self.hashX_subs = {} # type: Dict[bytes, bytes] # hashX -> scripthash + self.txoutpoint_subs = set() # type: Set[Tuple[bytes, int]] + self.mempool_hashX_statuses = {} # type: Dict[bytes, str] + self.mempool_txoutpoint_statuses = {} # type: Dict[Tuple[bytes, int], Mapping[str, Any]] self.set_request_handlers(self.PROTOCOL_MIN) self.is_peer = False self.cost = 5.0 # Connection cost @@ -1104,6 +1177,7 @@ def server_features(cls, env): 'genesis_hash': env.coin.GENESIS_HASH, 'hash_function': 'sha256', 'services': [str(service) for service in env.report_services], + 'method_flavours': {}, } async def server_features_async(self): @@ -1128,46 +1202,68 @@ def on_disconnect_due_to_excessive_session_cost(self): group_names = [group.name for group in groups] self.logger.info(f"closing session over res usage. ip: {ip_addr}. groups: {group_names}") - def sub_count(self): + def sub_count_scripthashes(self): return len(self.hashX_subs) + def sub_count_txoutpoints(self): + return len(self.txoutpoint_subs) + def unsubscribe_hashX(self, hashX): - self.mempool_statuses.pop(hashX, None) + self.mempool_hashX_statuses.pop(hashX, None) return self.hashX_subs.pop(hashX, None) - async def notify(self, touched, height_changed): + async def notify( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height_changed: bool, + ): '''Wrap _notify_inner; websockets raises exceptions for unclear reasons.''' try: async with timeout_after(30): - await self._notify_inner(touched, height_changed) + await self._notify_inner( + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + height_changed=height_changed, + ) except TaskTimeout: self.logger.warning('timeout notifying client, closing...') await self.close(force_after=1.0) except Exception: self.logger.exception('unexpected exception notifying client') - async def _notify_inner(self, touched, height_changed): + async def _notify_inner( + self, + *, + touched_hashxs: Set[bytes], + touched_outpoints: Set[Tuple[bytes, int]], + height_changed: bool, + ): '''Notify the client about changes to touched addresses (from mempool updates or new blocks) and height. 
''' + # block headers if height_changed and self.subscribe_headers: args = (await self.subscribe_headers_result(), ) await self.send_notification('blockchain.headers.subscribe', args) - touched = touched.intersection(self.hashX_subs) - if touched or (height_changed and self.mempool_statuses): + # hashXs + num_hashx_notifs_sent = 0 + touched_hashxs = touched_hashxs.intersection(self.hashX_subs) + if touched_hashxs or (height_changed and self.mempool_hashX_statuses): changed = {} - for hashX in touched: + for hashX in touched_hashxs: alias = self.hashX_subs.get(hashX) if alias: status = await self.subscription_address_status(hashX) changed[alias] = status # Check mempool hashXs - the status is a function of the confirmed state of - # other transactions. - mempool_statuses = self.mempool_statuses.copy() - for hashX, old_status in mempool_statuses.items(): + # other transactions. (this is to detect if height changed from -1 to 0) + mempool_hashX_statuses = self.mempool_hashX_statuses.copy() + for hashX, old_status in mempool_hashX_statuses.items(): alias = self.hashX_subs.get(hashX) if alias: status = await self.subscription_address_status(hashX) @@ -1177,10 +1273,36 @@ async def _notify_inner(self, touched, height_changed): method = 'blockchain.scripthash.subscribe' for alias, status in changed.items(): await self.send_notification(method, (alias, status)) - - if changed: - es = '' if len(changed) == 1 else 'es' - self.logger.info(f'notified of {len(changed):,d} address{es}') + num_hashx_notifs_sent = len(changed) + + # tx outpoints + num_txo_notifs_sent = 0 + touched_outpoints = touched_outpoints.intersection(self.txoutpoint_subs) + if touched_outpoints or (height_changed and self.mempool_txoutpoint_statuses): + method = 'blockchain.outpoint.subscribe' + txo_to_status = {} + for prevout in touched_outpoints: + txo_to_status[prevout] = await self.txoutpoint_status(*prevout) + + # Check mempool TXOs - the status is a function of the confirmed state of + # other transactions. (this is to detect if height changed from -1 to 0) + mempool_txoutpoint_statuses = self.mempool_txoutpoint_statuses.copy() + for prevout, old_status in mempool_txoutpoint_statuses.items(): + status = await self.txoutpoint_status(*prevout) + if status != old_status: + txo_to_status[prevout] = status + + for tx_hash, txout_idx in touched_outpoints: + spend_status = txo_to_status[(tx_hash, txout_idx)] + tx_hash_hex = hash_to_hex_str(tx_hash) + await self.send_notification(method, ((tx_hash_hex, txout_idx), spend_status)) + num_txo_notifs_sent = len(touched_outpoints) + + if num_hashx_notifs_sent + num_txo_notifs_sent > 0: + es1 = '' if num_hashx_notifs_sent == 1 else 'es' + s2 = '' if num_txo_notifs_sent == 1 else 's' + self.logger.info(f'notified of {num_hashx_notifs_sent:,d} address{es1} and ' + f'{num_txo_notifs_sent:,d} outpoint{s2}') async def subscribe_headers_result(self): '''The result of a header subscription or notification.''' @@ -1203,7 +1325,7 @@ async def peers_subscribe(self): self.bump_cost(1.0) return self.peer_mgr.on_peers_subscribe(self.is_tor()) - async def address_status(self, hashX): + async def address_status(self, hashX: bytes) -> Optional[str]: '''Returns an address status. Status is a hex string, but must be None if there is no history. 
@@ -1229,9 +1351,9 @@ async def address_status(self, hashX): status = None if mempool: - self.mempool_statuses[hashX] = status + self.mempool_hashX_statuses[hashX] = status else: - self.mempool_statuses.pop(hashX, None) + self.mempool_hashX_statuses.pop(hashX, None) return status @@ -1244,6 +1366,40 @@ async def subscription_address_status(self, hashX): self.unsubscribe_hashX(hashX) return None + async def txoutpoint_status(self, prev_txhash: bytes, txout_idx: int) -> Dict[str, Any]: + self.bump_cost(0.2) + spend_status = await self.db.spender_for_txo(prev_txhash, txout_idx) + if spend_status.spender_height is not None: + # TXO was created, was mined, was spent, and spend was mined. + assert spend_status.prev_height > 0 + assert spend_status.spender_height > 0 + assert spend_status.spender_txhash is not None + else: + mp_spend_status = await self.mempool.spender_for_txo(prev_txhash, txout_idx) + if mp_spend_status.prev_height is not None: + spend_status.prev_height = mp_spend_status.prev_height + if mp_spend_status.spender_height is not None: + spend_status.spender_height = mp_spend_status.spender_height + if mp_spend_status.spender_txhash is not None: + spend_status.spender_txhash = mp_spend_status.spender_txhash + # convert to json dict the client expects + status = {} + if spend_status.prev_height is not None: + status['height'] = spend_status.prev_height + if spend_status.spender_txhash is not None: + assert spend_status.spender_height is not None + status['spender_txhash'] = hash_to_hex_str(spend_status.spender_txhash) + status['spender_height'] = spend_status.spender_height + + prevout = (prev_txhash, txout_idx) + if ((spend_status.prev_height is not None and spend_status.prev_height <= 0) + or (spend_status.spender_height is not None and spend_status.spender_height <= 0)): + self.mempool_txoutpoint_statuses[prevout] = status + else: + self.mempool_txoutpoint_statuses.pop(prevout, None) + + return status + async def hashX_listunspent(self, hashX): '''Return the list of UTXOs of a script hash, including mempool effects.''' @@ -1323,6 +1479,55 @@ async def scripthash_unsubscribe(self, scripthash): hashX = scripthash_to_hashX(scripthash) return self.unsubscribe_hashX(hashX) is not None + def scriptpubkey_get_balance(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_get_balance(scripthash) + + def scriptpubkey_get_history(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_get_history(scripthash) + + def scriptpubkey_get_mempool(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_get_mempool(scripthash) + + def scriptpubkey_listunspent(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_listunspent(scripthash) + + def scriptpubkey_subscribe(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_subscribe(scripthash) + + def scriptpubkey_unsubscribe(self, spk: str): + scripthash = spk_to_scripthash(spk) + return self.scripthash_unsubscribe(scripthash) + + async def txoutpoint_subscribe(self, tx_hash, txout_idx, spk_hint=None): + '''Subscribe to an outpoint. + + spk_hint: scriptPubKey corresponding to the outpoint. Might be used by + other servers, but we don't need and hence ignore it. 
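The status dictionary assembled above is both the result of ``blockchain.outpoint.subscribe`` and the second notification parameter sent later (the first being the ``(txid, index)`` pair). Possible shapes, with placeholder values::

    unknown       = {}                           # outpoint not known to exist
    in_mempool    = {'height': 0}                # funding tx unconfirmed (-1: unconfirmed parents)
    mined_unspent = {'height': 700_000}          # funding tx mined, output unspent
    mined_spent   = {
        'height': 700_000,
        'spender_txhash': '<hex txid of the spending tx>',   # placeholder
        'spender_height': 700_123,               # 0 / -1 while the spender is unconfirmed
    }
    # Delivered to subscribed sessions roughly as:
    #   {"method": "blockchain.outpoint.subscribe",
    #    "params": [["<funding txid hex>", <output index>], <status dict>]}
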
+ ''' + tx_hash = assert_tx_hash(tx_hash) + txout_idx = non_negative_integer(txout_idx) + if spk_hint is not None: + assert_hex_str(spk_hint) + spend_status = await self.txoutpoint_status(tx_hash, txout_idx) + self.txoutpoint_subs.add((tx_hash, txout_idx)) + return spend_status + + async def txoutpoint_unsubscribe(self, tx_hash, txout_idx): + '''Unsubscribe from an outpoint.''' + tx_hash = assert_tx_hash(tx_hash) + txout_idx = non_negative_integer(txout_idx) + self.bump_cost(0.1) + prevout = (tx_hash, txout_idx) + was_subscribed = prevout in self.txoutpoint_subs + self.txoutpoint_subs.discard(prevout) + self.mempool_txoutpoint_statuses.pop(prevout, None) + return was_subscribed + async def _merkle_proof(self, cp_height, height): max_height = self.db.db_height if not height <= cp_height <= max_height: @@ -1521,12 +1726,25 @@ async def estimatefee(self, number, mode=None): cache[(number, mode)] = (blockhash, feerate, lock) return feerate - async def ping(self): + async def ping(self, pong_len=0, data=""): '''Serves as a connection keep-alive mechanism and for the client to - confirm the server is still responding. + confirm the server is still responding. It can also be used to obfuscate + traffic patterns. ''' self.bump_cost(0.1) - return None + if self.protocol_tuple < (1, 7): + return None + assert_hex_str(data) + pong_len = non_negative_integer(pong_len) + if pong_len > self.env.max_send: + raise RPCError(BAD_REQUEST, f'pong_len value too high') + pong_data = pong_len * "0" + return {"data": pong_data} + + async def on_ping_notification(self, data=""): + self.bump_cost(0.1) # note: the bw cost for receiving 'data' has already been incurred + assert_hex_str(data) + # nothing to do async def server_version( self, @@ -1644,20 +1862,20 @@ async def package_broadcast(self, tx_package: Sequence[str], verbose: bool = Fal response['errors'] = errors return response - async def transaction_get(self, tx_hash, verbose=False): + async def transaction_get(self, tx_hash: str, verbose=False): '''Return the serialized raw transaction given its hash tx_hash: the transaction hash as a hexadecimal string verbose: passed on to the daemon ''' - assert_tx_hash(tx_hash) + tx_hash = assert_tx_hash(tx_hash) if verbose not in (True, False): raise RPCError(BAD_REQUEST, '"verbose" must be a boolean') self.bump_cost(1.0) - return await self.daemon_request('getrawtransaction', tx_hash, verbose) + return await self.session_mgr.getrawtransaction(tx_hash, verbose=verbose) - async def transaction_merkle(self, tx_hash, height): + async def transaction_merkle(self, tx_hash, height=None): '''Return the merkle branch to a confirmed transaction given its hash and height. @@ -1665,13 +1883,75 @@ async def transaction_merkle(self, tx_hash, height): height: the height of the block it is in ''' tx_hash = assert_tx_hash(tx_hash) - height = non_negative_integer(height) + if height is not None: + height = non_negative_integer(height) # unused + + (height, wtxid, branch, tx_pos, block_header, cost) = ( + await self.session_mgr.merkle_branch_for_tx_hash( + tx_hash=tx_hash, witness=False)) + self.bump_cost(cost) + blockhash = self.coin.header_hash(block_header).hex() + + assert height is not None + return { + "block_height": height, + "block_hash": blockhash, + "merkle": branch, + "pos": tx_pos, + } + + async def transaction_merkle_witness(self, tx_hash: str, height=None, cb=False): + '''Return the witness merkle branch to a confirmed transaction given its hash + and height. 
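With protocol 1.7 ``server.ping`` accepts optional ``pong_len`` and ``data`` arguments so either side can pad traffic, and it may also arrive as a JSON-RPC notification, dispatched to ``on_ping_notification`` via the notification handlers registered below. Abbreviated, assumed wire payloads for illustration::

    # Request asking for 8 characters of padding back:
    {"id": 3, "method": "server.ping", "params": {"pong_len": 8, "data": "ab12"}}
    # Reply:
    {"id": 3, "result": {"data": "00000000"}}
    # Fire-and-forget notification (no id, nothing is sent back):
    {"method": "server.ping", "params": {"data": "ab12"}}
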
- branch, tx_pos, cost = await self.session_mgr.merkle_branch_for_tx_hash( - height, tx_hash) + tx_hash: the transaction hash as a hexadecimal string + height: the height of the block it is in + cb: whether to include `cb_tx` and `cb_proof` in response + ''' + tx_hash = assert_tx_hash(tx_hash) + if height is not None: + height = non_negative_integer(height) # unused + + (height, wtxid, wbranch, tx_pos, block_header, cost) = ( + await self.session_mgr.merkle_branch_for_tx_hash( + tx_hash=tx_hash, witness=True)) self.bump_cost(cost) + blockhash = self.coin.header_hash(block_header).hex() + assert isinstance(wtxid, bytes) and len(wtxid) == 32 + wtxid_hex = hash_to_hex_str(wtxid) + assert height is not None + + ret = { + "block_height": height, + "block_hash": blockhash, + "wmerkle": wbranch, + "pos": tx_pos, + "wtxid": wtxid_hex, + } + if not cb: + return ret - return {"block_height": height, "merkle": branch, "pos": tx_pos} + # add coinbase proof + tx_hashes, cost = await self.session_mgr.tx_hashes_at_blockheight(height) + self.bump_cost(cost) + cb_txid = tx_hashes[0] + cb_tx = await self.session_mgr.getrawtransaction(cb_txid) + + (cb_height, cb_wtxid, cb_branch, cb_tx_pos, cb_block_header, cost) = ( + await self.session_mgr.merkle_branch_for_tx_hash( + tx_hash=cb_txid, witness=False)) + self.bump_cost(cost) + + if block_header != cb_block_header: + # there was a reorg while processing the request... TODO maybe retry? + raise RPCError(BAD_REQUEST, + f'tx {hash_to_hex_str(tx_hash)} was reorged while processing request') + + ret.update({ + "cb_tx": cb_tx, + "cb_proof": cb_branch, + }) + return ret async def transaction_id_from_pos(self, height, tx_pos, merkle=False): '''Return the txid and optionally a merkle proof, given @@ -1709,11 +1989,6 @@ def set_request_handlers(self, ptuple): 'blockchain.block.headers': self.block_headers, 'blockchain.estimatefee': self.estimatefee, 'blockchain.headers.subscribe': self.headers_subscribe, - 'blockchain.scripthash.get_balance': self.scripthash_get_balance, - 'blockchain.scripthash.get_history': self.scripthash_get_history, - 'blockchain.scripthash.get_mempool': self.scripthash_get_mempool, - 'blockchain.scripthash.listunspent': self.scripthash_listunspent, - 'blockchain.scripthash.subscribe': self.scripthash_subscribe, 'blockchain.transaction.broadcast': self.transaction_broadcast, 'blockchain.transaction.get': self.transaction_get, 'blockchain.transaction.get_merkle': self.transaction_merkle, @@ -1727,8 +2002,16 @@ def set_request_handlers(self, ptuple): 'server.ping': self.ping, 'server.version': self.server_version, } + notif_handlers = {} - if ptuple >= (1, 4, 2): + if ptuple < (1, 7): + handlers['blockchain.scripthash.get_balance'] = self.scripthash_get_balance + handlers['blockchain.scripthash.get_history'] = self.scripthash_get_history + handlers['blockchain.scripthash.get_mempool'] = self.scripthash_get_mempool + handlers['blockchain.scripthash.listunspent'] = self.scripthash_listunspent + handlers['blockchain.scripthash.subscribe'] = self.scripthash_subscribe + + if (1, 4, 2) <= ptuple < (1, 7): handlers['blockchain.scripthash.unsubscribe'] = self.scripthash_unsubscribe if ptuple >= (1, 6): @@ -1737,7 +2020,21 @@ def set_request_handlers(self, ptuple): else: handlers['blockchain.relayfee'] = self.relayfee # removed in 1.6 + # experimental: + if ptuple >= (1, 7): + handlers['blockchain.outpoint.subscribe'] = self.txoutpoint_subscribe + handlers['blockchain.outpoint.unsubscribe'] = self.txoutpoint_unsubscribe + 
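Both ``get_merkle`` and the new ``get_merkle_witness`` hand the client a branch of hex hashes plus a position; the client folds the txid (or wtxid, with the coinbase leaf taken as all zeros) up the branch and compares the result against the header's merkle root, or against the coinbase witness commitment sketched earlier. A hedged verification sketch, assuming the usual byte-reversed hex display order::

    from hashlib import sha256

    def double_sha256(b: bytes) -> bytes:
        return sha256(sha256(b).digest()).digest()

    def merkle_root_from_branch(leaf_hex: str, branch_hex, pos: int) -> str:
        '''Fold a txid/wtxid up a merkle branch; returns the root as reversed hex.'''
        node = bytes.fromhex(leaf_hex)[::-1]            # undo display byte order
        for sibling_hex in branch_hex:
            sibling = bytes.fromhex(sibling_hex)[::-1]
            if pos & 1:                                 # we are the right child here
                node = double_sha256(sibling + node)
            else:
                node = double_sha256(node + sibling)
            pos >>= 1
        return node[::-1].hex()

For the ``cb`` variant, the same fold over ``cb_proof`` with the coinbase txid should reproduce the header's merkle root, tying the witness commitment inside ``cb_tx`` to the block.
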
handlers['blockchain.scriptpubkey.get_balance'] = self.scriptpubkey_get_balance + handlers['blockchain.scriptpubkey.get_history'] = self.scriptpubkey_get_history + handlers['blockchain.scriptpubkey.get_mempool'] = self.scriptpubkey_get_mempool + handlers['blockchain.scriptpubkey.listunspent'] = self.scriptpubkey_listunspent + handlers['blockchain.scriptpubkey.subscribe'] = self.scriptpubkey_subscribe + handlers['blockchain.scriptpubkey.unsubscribe'] = self.scriptpubkey_unsubscribe + handlers['blockchain.transaction.get_merkle_witness'] = self.transaction_merkle_witness + notif_handlers['server.ping'] = self.on_ping_notification + self.request_handlers = handlers + self.notification_handlers = notif_handlers class LocalRPC(SessionBase): @@ -1751,6 +2048,8 @@ def __init__(self, *args, **kwargs): self.sv_negotiated.set() self.client = 'RPC' self.connection.max_response_size = 0 + # note: self.request_handlers are set on the class, in SessionManager.__init__ + self.notification_handlers = {} def protocol_version_string(self): return 'RPC' @@ -1781,9 +2080,19 @@ def set_request_handlers(self, ptuple): 'protx.info': self.protx_info, }) - async def _notify_inner(self, touched, height_changed): + async def _notify_inner( + self, + *, + touched_hashxs, + touched_outpoints, + height_changed, + ): '''Notify the client about changes in masternode list.''' - await super()._notify_inner(touched, height_changed) + await super()._notify_inner( + touched_hashxs=touched_hashxs, + touched_outpoints=touched_outpoints, + height_changed=height_changed, + ) for mn in self.mns.copy(): status = await self.daemon_request('masternode_list', ('status', mn)) diff --git a/src/electrumx/server/storage.py b/src/electrumx/server/storage.py index a45c27002..377cf44d9 100644 --- a/src/electrumx/server/storage.py +++ b/src/electrumx/server/storage.py @@ -65,6 +65,13 @@ def iterator(self, prefix=b'', reverse=False): If `prefix` is set, only keys starting with `prefix` will be included. If `reverse` is True the items are returned in reverse order. + + The iterator supports .seek(key), which moves it just left of `key`. 
+ - if forward-iterating + - if `key` is present, it will be the next item + - if `key` is not present, the next item will be the smallest still greater than `key` + - if reverse-iterating + - the next item will be the largest still smaller than `key` ''' raise NotImplementedError @@ -85,10 +92,32 @@ def open(self, name, create): self.close = self.db.close self.get = self.db.get self.put = self.db.put - self.iterator = self.db.iterator self.write_batch = partial(self.db.write_batch, transaction=True, sync=True) + def iterator(self, prefix=b'', reverse=False): + return LevelDBIterator(db=self.db, prefix=prefix, reverse=reverse) + + +class LevelDBIterator: + '''An iterator for LevelDB.''' + + def __init__(self, *, db, prefix, reverse): + self.prefix = prefix + self.iterator = db.iterator(prefix=prefix, reverse=reverse) + + def __iter__(self): + return self + + def __next__(self): + k, v = next(self.iterator) + if not k.startswith(self.prefix): + raise StopIteration + return k, v + + def seek(self, key: bytes) -> None: + self.iterator.seek(key) + class RocksDB(Storage): '''RocksDB database engine.''' @@ -119,7 +148,7 @@ def write_batch(self): return RocksDBWriteBatch(self.db) def iterator(self, prefix=b'', reverse=False): - return RocksDBIterator(self.db, prefix, reverse) + return RocksDBIterator(db=self.db, prefix=prefix, reverse=reverse) class RocksDBWriteBatch: @@ -140,8 +169,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): class RocksDBIterator: '''An iterator for RocksDB.''' - def __init__(self, db, prefix, reverse): + def __init__(self, *, db, prefix, reverse): self.prefix = prefix + self._is_reverse = reverse if reverse: self.iterator = reversed(db.iteritems()) nxt_prefix = util.increment_byte_string(prefix) @@ -165,3 +195,11 @@ def __next__(self): if not k.startswith(self.prefix): raise StopIteration return k, v + + def seek(self, key: bytes) -> None: + self.iterator.seek(key) + if self._is_reverse: + try: + next(self) + except StopIteration: + self.iterator.seek_to_last() diff --git a/tests/server/test_compaction.py b/tests/server/test_compaction.py deleted file mode 100644 index ad6c96a43..000000000 --- a/tests/server/test_compaction.py +++ /dev/null @@ -1,133 +0,0 @@ -'''Test of compaction code in server/history.py''' -import array -import random -from os import environ, urandom - -import pytest - -from electrumx.lib.hash import HASHX_LEN -from electrumx.lib.util import pack_be_uint16, pack_le_uint64 -from electrumx.server.env import Env -from electrumx.server.db import DB - - -def create_histories(history, hashX_count=100): - '''Creates a bunch of random transaction histories, and write them - to disk in a series of small flushes.''' - hashXs = [urandom(HASHX_LEN) for n in range(hashX_count)] - mk_array = lambda : array.array('Q') - histories = {hashX : mk_array() for hashX in hashXs} - unflushed = history.unflushed - tx_num = 0 - while hashXs: - tx_numb = pack_le_uint64(tx_num)[:5] - hash_indexes = set(random.randrange(len(hashXs)) - for n in range(1 + random.randrange(4))) - for index in hash_indexes: - histories[hashXs[index]].append(tx_num) - unflushed[hashXs[index]].extend(tx_numb) - - tx_num += 1 - # Occasionally flush and drop a random hashX if non-empty - if random.random() < 0.1: - history.flush() - index = random.randrange(0, len(hashXs)) - if histories[hashXs[index]]: - del hashXs[index] - - return histories - - -def check_hashX_compaction(history): - history.max_hist_row_entries = 40 - row_size = history.max_hist_row_entries * 5 - full_hist = 
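The ``seek()`` contract documented above lets callers start a prefix scan at an arbitrary key instead of walking the whole prefix. A hedged usage sketch, reusing the packing helpers assumed earlier::

    def history_txnums_from(db, hashX: bytes, start_txnum: int):
        '''Yield tx_nums >= start_txnum from one hashX's history (sketch).'''
        prefix = b'H' + hashX
        it = db.iterator(prefix=prefix)
        it.seek(prefix + pack_txnum(start_txnum))   # lands on start_txnum if present,
        for db_key, _db_val in it:                  # else on the next larger tx_num
            yield unpack_txnum(db_key[-TXNUM_LEN:])

    # With reverse=True the same seek instead positions the iterator so the first
    # item yielded is the largest key strictly smaller than the seek key.
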
b''.join(pack_le_uint64(tx_num)[:5] for tx_num in range(100)) - hashX = urandom(HASHX_LEN) - pairs = ((1, 20), (26, 50), (56, 30)) - - cum = 0 - hist_list = [] - hist_map = {} - for flush_count, count in pairs: - key = hashX + pack_be_uint16(flush_count) - hist = full_hist[cum * 5: (cum+count) * 5] - hist_map[key] = hist - hist_list.append(hist) - cum += count - - write_items = [] - keys_to_delete = set() - write_size = history._compact_hashX(hashX, hist_map, hist_list, - write_items, keys_to_delete) - # Check results for sanity - assert write_size == len(full_hist) - assert len(write_items) == 3 - assert len(keys_to_delete) == 3 - assert len(hist_map) == len(pairs) - for n, item in enumerate(write_items): - assert item == (hashX + pack_be_uint16(n), - full_hist[n * row_size: (n + 1) * row_size]) - for flush_count, count in pairs: - assert hashX + pack_be_uint16(flush_count) in keys_to_delete - - # Check re-compaction is null - hist_map = {key: value for key, value in write_items} - hist_list = [value for key, value in write_items] - write_items.clear() - keys_to_delete.clear() - write_size = history._compact_hashX(hashX, hist_map, hist_list, - write_items, keys_to_delete) - assert write_size == 0 - assert len(write_items) == 0 - assert len(keys_to_delete) == 0 - assert len(hist_map) == len(pairs) - - # Check re-compaction adding a single tx writes the one row - hist_list[-1] += array.array('I', [100]).tobytes() - write_size = history._compact_hashX(hashX, hist_map, hist_list, - write_items, keys_to_delete) - assert write_size == len(hist_list[-1]) - assert write_items == [(hashX + pack_be_uint16(2), hist_list[-1])] - assert len(keys_to_delete) == 1 - assert write_items[0][0] in keys_to_delete - assert len(hist_map) == len(pairs) - - -def check_written(history, histories): - for hashX, hist in histories.items(): - db_hist = array.array('I', history.get_txnums(hashX, limit=None)) - assert hist == db_hist - - -def compact_history(history): - '''Synchronously compact the DB history.''' - history.comp_cursor = 0 - - history.comp_flush_count = max(history.comp_flush_count, 1) - limit = 5 * 1000 - - write_size = 0 - while history.comp_cursor != -1: - write_size += history._compact_history(limit) - assert write_size != 0 - - -@pytest.mark.asyncio -async def test_compaction(tmpdir): - db_dir = str(tmpdir) - print(f'Temp dir: {db_dir}') - environ.clear() - environ['DB_DIRECTORY'] = db_dir - environ['DAEMON_URL'] = '' - environ['COIN'] = 'BitcoinSV' - db = DB(Env()) - await db.open_for_serving() - history = db.history - - # Test abstract compaction - check_hashX_compaction(history) - # Now test in with random data - histories = create_histories(history) - check_written(history, histories) - compact_history(history) - check_written(history, histories) diff --git a/tests/server/test_mempool.py b/tests/server/test_mempool.py index 6cbbc81b3..700c0d135 100644 --- a/tests/server/test_mempool.py +++ b/tests/server/test_mempool.py @@ -222,34 +222,20 @@ def cached_height(self): return self._cached_height async def mempool_hashes(self): - '''Query bitcoind for the hashes of all transactions in its - mempool, returned as a list.''' await sleep(0) return [hash_to_hex_str(hash) for hash in self.txs] async def raw_transactions(self, hex_hashes): - '''Query bitcoind for the serialized raw transactions with the given - hashes. Missing transactions are returned as None. 
- - hex_hashes is an iterable of hexadecimal hash strings.''' await sleep(0) hashes = [hex_str_to_hash(hex_hash) for hex_hash in hex_hashes] return [self.raw_txs.get(hash) for hash in hashes] async def lookup_utxos(self, prevouts): - '''Return a list of (hashX, value) pairs each prevout if unspent, - otherwise return None if spent or not found. - - prevouts - an iterable of (hash, index) pairs - ''' await sleep(0) return [self.db_utxos.get(prevout) for prevout in prevouts] - async def on_mempool(self, touched, height): - '''Called each time the mempool is synchronized. touched is a set of - hashXs touched since the previous call. height is the - daemon's height at the time the mempool was obtained.''' - self.on_mempool_calls.append((touched, height)) + async def on_mempool(self, *, touched_hashxs, touched_outpoints, height): + self.on_mempool_calls.append((touched_hashxs, height)) await sleep(0) diff --git a/tests/server/test_notifications.py b/tests/server/test_notifications.py index c8c55b311..ee14f237b 100644 --- a/tests/server/test_notifications.py +++ b/tests/server/test_notifications.py @@ -7,15 +7,15 @@ async def test_simple_mempool(): n = Notifications() notified = [] - async def notify(height, touched): - notified.append((height, touched)) + async def notify(*, touched_hashxs, touched_outpoints, height): + notified.append((height, touched_hashxs)) await n.start(5, notify) - mtouched = {'a', 'b'} - btouched = {'b', 'c'} - await n.on_mempool(mtouched, 6) + mtouched = {b'a', b'b'} + btouched = {b'b', b'c'} + await n.on_mempool(touched_hashxs=mtouched, height=6, touched_outpoints=set()) assert notified == [(5, set())] - await n.on_block(btouched, 6) + await n.on_block(touched_hashxs=btouched, height=6, touched_outpoints=set()) assert notified == [(5, set()), (6, set.union(mtouched, btouched))] @@ -23,23 +23,23 @@ async def notify(height, touched): async def test_enter_mempool_quick_blocks_2(): n = Notifications() notified = [] - async def notify(height, touched): - notified.append((height, touched)) + async def notify(*, touched_hashxs, touched_outpoints, height): + notified.append((height, touched_hashxs)) await n.start(5, notify) # Suppose a gets in block 6 and blocks 7,8 found right after and # the block processor processes them together. 
- await n.on_mempool({'a'}, 5) - assert notified == [(5, set()), (5, {'a'})] + await n.on_mempool(touched_hashxs={b'a'}, height=5, touched_outpoints=set()) + assert notified == [(5, set()), (5, {b'a'})] # Mempool refreshes with daemon on block 6 - await n.on_mempool({'a'}, 6) - assert notified == [(5, set()), (5, {'a'})] + await n.on_mempool(touched_hashxs={b'a'}, height=6, touched_outpoints=set()) + assert notified == [(5, set()), (5, {b'a'})] # Blocks 6, 7 processed together - await n.on_block({'a', 'b'}, 7) - assert notified == [(5, set()), (5, {'a'})] + await n.on_block(touched_hashxs={b'a', b'b'}, height=7, touched_outpoints=set()) + assert notified == [(5, set()), (5, {b'a'})] # Then block 8 processed - await n.on_block({'c'}, 8) - assert notified == [(5, set()), (5, {'a'})] + await n.on_block(touched_hashxs={b'c'}, height=8, touched_outpoints=set()) + assert notified == [(5, set()), (5, {b'a'})] # Now mempool refreshes - await n.on_mempool(set(), 8) - assert notified == [(5, set()), (5, {'a'}), (8, {'a', 'b', 'c'})] + await n.on_mempool(touched_hashxs=set(), height=8, touched_outpoints=set()) + assert notified == [(5, set()), (5, {b'a'}), (8, {b'a', b'b', b'c'})] diff --git a/tests/server/test_storage.py b/tests/server/test_storage.py index 3cfb6e172..b47b963f5 100644 --- a/tests/server/test_storage.py +++ b/tests/server/test_storage.py @@ -66,6 +66,76 @@ def test_iterator_reverse(db): ] +def test_iterator_seek(db): + db.put(b"first-key1", b"val") + db.put(b"first-key2", b"val") + db.put(b"first-key3", b"val") + db.put(b"key-1", b"value-1") + db.put(b"key-5", b"value-5") + db.put(b"key-3", b"value-3") + db.put(b"key-8", b"value-8") + db.put(b"key-2", b"value-2") + db.put(b"key-4", b"value-4") + db.put(b"last-key1", b"val") + db.put(b"last-key2", b"val") + db.put(b"last-key3", b"val") + # forward-iterate, key present, no prefix + it = db.iterator() + it.seek(b"key-4") + assert list(it) == [(b"key-4", b"value-4"), (b"key-5", b"value-5"), (b"key-8", b"value-8"), + (b"last-key1", b"val"), (b"last-key2", b"val"), (b"last-key3", b"val")] + # forward-iterate, key present + it = db.iterator(prefix=b"key-") + it.seek(b"key-4") + assert list(it) == [(b"key-4", b"value-4"), (b"key-5", b"value-5"), + (b"key-8", b"value-8")] + # forward-iterate, key missing + it = db.iterator(prefix=b"key-") + it.seek(b"key-6") + assert list(it) == [(b"key-8", b"value-8")] + # forward-iterate, after last prefix + it = db.iterator(prefix=b"key-") + it.seek(b"key-9") + assert list(it) == [] + # forward-iterate, after last, no prefix + it = db.iterator() + it.seek(b"z") + assert list(it) == [] + # forward-iterate, no such prefix + it = db.iterator(prefix=b"key---") + it.seek(b"key---5") + assert list(it) == [] + # forward-iterate, seek outside prefix + it = db.iterator(prefix=b"key-") + it.seek(b"last-key2") + assert list(it) == [] + # reverse-iterate, key present + it = db.iterator(prefix=b"key-", reverse=True) + it.seek(b"key-4") + assert list(it) == [(b"key-3", b"value-3"), (b"key-2", b"value-2"), (b"key-1", b"value-1")] + # reverse-iterate, key missing + it = db.iterator(prefix=b"key-", reverse=True) + it.seek(b"key-7") + assert list(it) == [(b"key-5", b"value-5"), (b"key-4", b"value-4"), (b"key-3", b"value-3"), + (b"key-2", b"value-2"), (b"key-1", b"value-1")] + # reverse-iterate, before first prefix + it = db.iterator(prefix=b"key-", reverse=True) + it.seek(b"key-0") + assert list(it) == [] + # reverse-iterate, before first, no prefix + it = db.iterator(reverse=True) + it.seek(b"a") + assert list(it) == [] + 
# reverse-iterate, no such prefix + it = db.iterator(prefix=b"key---", reverse=True) + it.seek(b"key---5") + assert list(it) == [] + # reverse-iterate, seek outside prefix + it = db.iterator(prefix=b"key-", reverse=True) + it.seek(b"first-key2") + assert list(it) == [] + + def test_close(db): db.put(b"a", b"b") db.close()
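
The tests above double as a reference for the new iterator ``seek()`` semantics added in storage.py. As a minimal usage sketch (the ``storage_db`` handle and the b"key-" layout are assumptions mirroring the test fixture above, not real ElectrumX key formats), calling code can resume a prefix scan from a chosen key in either direction:

    def resume_forward(storage_db, start_key: bytes):
        # Forward: yields `start_key` itself if present, otherwise the smallest
        # key greater than it, then continues upward while the prefix matches.
        it = storage_db.iterator(prefix=b"key-")
        it.seek(start_key)
        for key, value in it:
            yield key, value

    def resume_backward(storage_db, boundary_key: bytes):
        # Reverse: yields the largest key strictly smaller than `boundary_key`,
        # then continues downward while the prefix matches.
        it = storage_db.iterator(prefix=b"key-", reverse=True)
        it.seek(boundary_key)
        for key, value in it:
            yield key, value

Seeking to a key outside the iterator's prefix simply exhausts it, as the "seek outside prefix" cases above show, so callers need not special-case that situation.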