diff --git a/iavl/cli.py b/iavl/cli.py index 1547b13..149f46a 100644 --- a/iavl/cli.py +++ b/iavl/cli.py @@ -9,10 +9,14 @@ from hexbytes import HexBytes from . import dbm -from .utils import (decode_fast_node, diff_iterators, encode_stdint, - fast_node_key, get_node, get_root_node, - iavl_latest_version, iter_fast_nodes, iter_iavl_tree, - load_commit_infos, root_key, store_prefix) +from .utils import (METADATA_KEY_PREFIX, NODE_KEY_PREFIX, ORPHAN_KEY_PREFIX, + ROOT_KEY_PREFIX, decode_fast_node, diff_iterators, + encode_stdint, fast_node_key, get_node, get_root_hash, + get_root_node, iavl_latest_version, + is_legacy, + iter_fast_nodes, iter_iavl_tree, legacy_root_key, + load_commit_infos, node_key_suffix, parse_node_key, + root_key, store_prefix) @click.group @@ -42,7 +46,7 @@ def root_hash(db, store: List[str], version: Optional[int]): for s in store: if version is None: version = iavl_latest_version(db, s) - bz = db.get(store_prefix(s) + root_key(version)) + bz = get_root_hash(db, s, version) print(f"{s}: {binascii.hexlify(bz or b'').decode()}") @@ -85,24 +89,27 @@ def root_versions(db, store: str, reverse: bool = False): """ iterate all root versions """ - begin = store_prefix(store) + b"r" - end = store_prefix(store) + b"s" # exclusive - + prefix = store_prefix(store) db = dbm.open(str(db), read_only=True) it = db.iterkeys() - if not reverse: - it.seek(begin) - for k in it: - if k >= end: + res = [] + legacy = is_legacy(db, store) + begin = prefix + ROOT_KEY_PREFIX if legacy else prefix + NODE_KEY_PREFIX + it.seek(begin) + for k in it: + if legacy: + if k[: len(prefix)] != prefix: break - print(int.from_bytes(k[len(begin) :], "big")) - else: - it = reversed(it) - it.seek_for_prev(end) - for k in it: - if k < begin: + res.append(int.from_bytes(k[len(begin) :], "big")) + else: + if k[: len(prefix)] != prefix: break - print(int.from_bytes(k[len(begin) :], "big")) + k = k[len(prefix) :] + version, nonce = parse_node_key(k) + res.append(f"{version}-{nonce}") + if reverse: + res = reversed(res) + print(*res, sep="\n") @cli.command() @@ -155,7 +162,7 @@ def metadata(db, store): raise click.UsageError("no store names are provided") db = dbm.open(str(db), read_only=True) for s in store: - bz = db.get(store_prefix(s) + b"m" + b"storage_version") + bz = db.get(store_prefix(s) + METADATA_KEY_PREFIX + b"storage_version") print(f"{s} storage version: {bz.decode()}") print(f"{s} latest version: {iavl_latest_version(db, s)}") @@ -211,7 +218,7 @@ def range_iavl(db, store, version, start, end, output_value): # find root node first if version is None: version = iavl_latest_version(db, store) - root_hash = db.get(store_prefix(store) + root_key(version)) + root_hash = get_root_hash(db, store, version) for k, v in iter_iavl_tree(db, store, root_hash, start, end): if output_value: print(f"{HexBytes(k).hex()} {HexBytes(v).hex()}") @@ -268,8 +275,11 @@ def diff_fastnode(db, store, start, end, output_value): it1 = iter_fast_nodes(db, store, start, end) # find root node first + legacy = is_legacy(db, store) version = iavl_latest_version(db, store) - root_hash = db.get(store_prefix(store) + root_key(version)) + prefix = store_prefix(store) if store is not None else b"" + suffix = legacy_root_key(version) if legacy else root_key(version) + root_hash = db.get(prefix + suffix) it2 = iter_iavl_tree(db, store, root_hash, start, end) for status, k, v in diff_iterators(it1, it2): @@ -323,7 +333,7 @@ def fast_rollback( ver = iavl_latest_version(db, info.name) print("delete orphan entries created since target version") - orphan_prefix = prefix + b"o" + target.to_bytes(8, "big") + orphan_prefix = prefix + ORPHAN_KEY_PREFIX + target.to_bytes(8, "big") it = db.iterkeys() it.seek(orphan_prefix) for k in it: @@ -365,12 +375,11 @@ def visualize(db, version, store=None, include_prev_version=False): if version is None: version = iavl_latest_version(db, store) - prefix = store_prefix(store) if store is not None else b"" - root_hash = db.get(prefix + root_key(version)) + root_hash = get_root_hash(db, store, version) root_hash2 = None if include_prev_version and version > 1: - root_hash2 = db.get(prefix + root_key(version - 1)) - g = visualize_iavl(db, prefix, root_hash, version, root_hash2=root_hash2) + root_hash2 = get_root_hash(db, store, version - 1) + g = visualize_iavl(db, root_hash, version, root_hash2=root_hash2, store=store) print(g.source) @@ -396,7 +405,13 @@ def visualize(db, version, store=None, include_prev_version=False): type=click.Path(exists=True), required=True, ) -def dump_changesets(db, start_version, end_version, store: Optional[str], out_dir: str): +def dump_changesets( + db, + start_version, + end_version, + store: Optional[str], + out_dir: str, +): """ extract changeset by comparing iavl versions and save in files with compatible format with file streamer. @@ -408,10 +423,16 @@ def dump_changesets(db, start_version, end_version, store: Optional[str], out_di db = dbm.open(str(db), read_only=True) prefix = store_prefix(store) if store is not None else b"" ndb = NodeDB(db, prefix=prefix) - for _, v, _, changeset in diff.iter_state_changes( - db, ndb, start_version=start_version, end_version=end_version, prefix=prefix + legacy = is_legacy(db, store) + for _, v, n, _, changeset in diff.iter_state_changes( + db, + ndb, + start_version=start_version, + end_version=end_version, + prefix=prefix, + legacy=legacy, ): - with (Path(out_dir) / f"block-{v}-data").open("wb") as fp: + with (Path(out_dir) / f"block-{v}-{n}-data").open("wb") as fp: diff.write_change_set(fp, changeset) @@ -448,13 +469,20 @@ def test_state_round_trip(db, store, start_version): db = dbm.open(str(db), read_only=True) prefix = store_prefix(store) if store is not None else b"" ndb = NodeDB(db, prefix=prefix) - for pversion, v, root, changeset in diff.iter_state_changes( - db, ndb, start_version=start_version, prefix=prefix + legacy = is_legacy(db, store) + for pversion, v, n, root, changeset in diff.iter_state_changes( + db, + ndb, + start_version=start_version, + prefix=prefix, + legacy=legacy, ): # re-apply changeset tree = Tree(ndb, pversion) diff.apply_change_set(tree, changeset) tmp = tree.save_version(dry_run=True) + if not legacy: + tmp = node_key_suffix(v, n) if (root or hashlib.sha256().digest()) == tmp: print(v, len(changeset), "ok") else: @@ -487,8 +515,9 @@ def visualize_pruning(db, store, version): db = dbm.open(str(db), read_only=True) prefix = store_prefix(store) if store is not None else b"" ndb = NodeDB(db, prefix=prefix) - predecessor = ndb.prev_version(version) or 0 - successor = ndb.next_version(version) + legacy = is_legacy(db, store) + predecessor = ndb.prev_version(version, legacy) or 0 + successor = ndb.next_version(version, legacy) root1 = ndb.get_root_hash(version) root2 = ndb.get_root_hash(successor) @@ -549,15 +578,15 @@ def scan_wal(wal: str): for cs in entry.changeset: print(f"store: {cs.name}") for pair in cs.changeset.pairs: - print( - f" key: {binascii.hexlify(pair.key).decode()} value: {binascii.hexlify(pair.value).decode()}" - ) + key = binascii.hexlify(pair.key).decode() + value = binascii.hexlify(pair.value).decode() + print(f" key: {key} value: {value}") for upgrade in entry.upgrades: print(f"upgrade: {upgrade.name}", end="") if upgrade.rename_from: print(f", from {upgrade.rename_from}", end="") if upgrade.delete: - print(f"deleted", end="") + print("deleted", end="") print("") diff --git a/iavl/diff.py b/iavl/diff.py index e4f09a2..8da89a3 100644 --- a/iavl/diff.py +++ b/iavl/diff.py @@ -10,7 +10,9 @@ from . import dbm from .iavl import NodeDB, PersistedNode, Tree -from .utils import GetNode, root_key, visit_iavl_nodes +from .utils import (NODE_KEY_PREFIX, ROOT_KEY_PREFIX, GetNode, legacy_root_key, + node_key_suffix, parse_node_key, root_key, + visit_iavl_nodes) class Op(IntEnum): @@ -188,22 +190,37 @@ def parse_change_set(data): def iter_state_changes( - db: dbm.DBM, ndb: NodeDB, start_version=0, end_version=None, prefix=b"" + db: dbm.DBM, + ndb: NodeDB, + start_version=0, + end_version=None, + prefix=b"", + legacy=False, ): from . import diff - pversion = ndb.prev_version(start_version) or 0 + pversion = ndb.prev_version(start_version, legacy) or 0 prev_root = ndb.get_root_hash(pversion) it = db.iteritems() - it.seek(prefix + root_key(start_version)) + key = legacy_root_key(start_version) if legacy else root_key(start_version) + it.seek(prefix + key) + key_prefix = ROOT_KEY_PREFIX if legacy else NODE_KEY_PREFIX + n = 1 for k, hash in it: - if not k.startswith(prefix + b"r"): + if not k.startswith(prefix + key_prefix): break - v = int.from_bytes(k[len(prefix) + 1 :], "big") + if legacy: + v = int.from_bytes(k[len(prefix) + 1 :], "big") + else: + k = k[len(prefix) :] + v, n = parse_node_key(k) + hash = node_key_suffix(v, n) if end_version is not None and v >= end_version: break - yield pversion, v, hash, diff.state_changes(ndb.get, pversion, prev_root, hash) + yield pversion, v, n, hash, diff.state_changes( + ndb.get, pversion, prev_root, hash, + ) pversion = v prev_root = hash diff --git a/iavl/iavl.py b/iavl/iavl.py index db10f3b..7121c45 100644 --- a/iavl/iavl.py +++ b/iavl/iavl.py @@ -9,8 +9,10 @@ import cprotobuf import rocksdb -from .utils import (GetNode, PersistedNode, encode_bytes, node_key, root_key, - visit_iavl_nodes) +from .utils import (NODE_KEY_PREFIX, ROOT_KEY_PREFIX, GetNode, PersistedNode, + decode_node, encode_bytes, legacy_node_key, + legacy_root_key, node_key, parse_node_key, root_key, + root_key_suffix, visit_iavl_nodes) NodeRef = Union[bytes, "Node"] @@ -35,12 +37,14 @@ def get(self, hash: bytes) -> Optional[PersistedNode]: try: return self.cache[hash] except KeyError: - bz = self.db.get(self.prefix + node_key(hash)) - if bz is None: - return - node = PersistedNode.decode(bz, hash) - self.cache[hash] = node - return node + for key_func in (node_key, legacy_node_key): + key = key_func(hash) + bz = self.db.get(self.prefix + key) + if bz: + legacy = key_func is legacy_node_key + node, _ = decode_node(bz, hash, legacy) + return node + return None def resolve_node(self, ref: NodeRef) -> Union["Node", PersistedNode, None]: if isinstance(ref, Node): @@ -48,28 +52,32 @@ def resolve_node(self, ref: NodeRef) -> Union["Node", PersistedNode, None]: elif ref is not None: return self.get(ref) - def batch_remove_node(self, hash: bytes): + def batch_remove_node(self, hash: bytes, legacy: bool = False): "remove node" if self.batch is None: self.batch = rocksdb.WriteBatch() - self.batch.delete(node_key(hash)) + key = legacy_node_key(hash) if legacy else node_key(hash) + self.batch.delete(key) self.cache.pop(hash, None) - def batch_remove_root_hash(self, version: int): + def batch_remove_root_hash(self, version: int, legacy: bool = False): if self.batch is None: self.batch = rocksdb.WriteBatch() - self.batch.delete(root_key(version)) + key = legacy_root_key(version) if legacy else root_key(version) + self.batch.delete(key) - def batch_set_node(self, hash: bytes, node: PersistedNode): + def batch_set_node(self, hash: bytes, node: PersistedNode, legacy: bool = False): if self.batch is None: self.batch = rocksdb.WriteBatch() self.cache[hash] = node - self.batch.put(node_key(hash), node.encode()) + key = legacy_node_key(hash) if legacy else node_key(hash) + self.batch.put(key, node.encode()) - def batch_set_root_hash(self, version: int, hash: bytes): + def batch_set_root_hash(self, version: int, hash: bytes, legacy: bool = False): if self.batch is None: self.batch = rocksdb.WriteBatch() - self.batch.put(root_key(version), hash) + key = legacy_root_key(version) if legacy else root_key(version) + self.batch.put(key, hash) def batch_commit(self): if self.batch is not None: @@ -77,7 +85,12 @@ def batch_commit(self): self.batch = None def get_root_hash(self, version: int) -> Optional[bytes]: - return self.db.get(self.prefix + root_key(version)) + bz = self.db.get(self.prefix + root_key(version)) + if not bz: + return self.db.get(self.prefix + legacy_root_key(version)) + key = root_key(version) + _, nonce = parse_node_key(key) + return root_key_suffix(nonce) def get_root_node(self, version: int) -> Optional[PersistedNode]: h = self.get_root_hash(version) @@ -90,12 +103,13 @@ def latest_version(self) -> Optional[int]: return iavl_latest_version(self.db, None) - def next_version(self, v: int) -> Optional[int]: + def next_version(self, v: int, legacy=False) -> Optional[int]: """ return the first version larger than v """ it = self.db.iterkeys() - target = self.prefix + root_key(v) + suffix = legacy_root_key(v) if legacy else root_key(v) + target = self.prefix + suffix it.seek(target) k = next(it, None) if k is None: @@ -104,31 +118,43 @@ def next_version(self, v: int) -> Optional[int]: k = next(it, None) if k is None: return - if not k.startswith(self.prefix + b"r"): + suffix = ROOT_KEY_PREFIX if legacy else NODE_KEY_PREFIX + if not k.startswith(self.prefix + suffix): return + if legacy: + return int.from_bytes(k[len(self.prefix) + 1 :], "big") + else: + k = k[len(self.prefix) :] + version, _ = parse_node_key(k) + return version - return int.from_bytes(k[len(self.prefix) + 1 :], "big") - - def prev_version(self, v: int) -> Optional[int]: + def prev_version(self, v: int, legacy=False) -> Optional[int]: """ return the closest version that's smaller than the target """ it = reversed(self.db.iterkeys()) - target = self.prefix + root_key(v) + suffix = legacy_root_key(v) if legacy else root_key(v) + target = self.prefix + suffix it.seek_for_prev(target) - key = next(it, None) - if key == target: - key = next(it, None) - if key is None or not key.startswith(self.prefix + b"r"): + k = next(it, None) + if k == target: + k = next(it, None) + suffix = ROOT_KEY_PREFIX if legacy else NODE_KEY_PREFIX + if k is None or not k.startswith(self.prefix + suffix): return - return int.from_bytes(key[len(self.prefix) + 1 :], "big") + if legacy: + return int.from_bytes(k[len(self.prefix) + 1 :], "big") + else: + k = k[len(self.prefix) :] + version, _ = parse_node_key(k) + return version - def delete_version(self, v: int) -> int: + def delete_version(self, v: int, legacy=False) -> int: """ return how many nodes deleted """ - predecessor = self.prev_version(v) or 0 - successor = self.next_version(v) + predecessor = self.prev_version(v, legacy) or 0 + successor = self.next_version(v, legacy) assert successor is not None, "can't delete latest version" counter = 0 @@ -140,9 +166,9 @@ def delete_version(self, v: int) -> int: self.get_root_hash(successor), ): counter += 1 - self.batch_remove_node(n.hash) + self.batch_remove_node(n.hash, legacy) - self.batch_remove_root_hash(v) + self.batch_remove_root_hash(v, legacy) self.batch_commit() return counter @@ -361,21 +387,21 @@ def remove(self, key: bytes) -> Optional[bytes]: self.root_node_ref = new return value - def save_version(self, dry_run=False): + def save_version(self, dry_run=False, legacy: bool = False): """ if dry_run=True, don't actually modify anything, just return the new root hash """ def save_node(hash: bytes, node: Node): if not dry_run: - self.ndb.batch_set_node(hash, node.persisted(hash)) + self.ndb.batch_set_node(hash, node.persisted(hash), legacy) if isinstance(self.root_node_ref, Node): self.root_node_ref = self.root_node_ref.save(save_node) root_hash = self.root_node_ref or hashlib.sha256().digest() if not dry_run: self.version += 1 - self.ndb.batch_set_root_hash(self.version, root_hash) + self.ndb.batch_set_root_hash(self.version, root_hash, legacy) self.ndb.batch_commit() return root_hash diff --git a/iavl/leveldb.py b/iavl/leveldb.py index e9d4558..d88f81f 100644 --- a/iavl/leveldb.py +++ b/iavl/leveldb.py @@ -55,5 +55,5 @@ def open(dir, read_only: bool = False): return LevelDB(plyvel.DB(str(dir))) -def WriteBatch(db): +def WriteBatch(db): # noqa: N802 return db.db.write_batch() diff --git a/iavl/utils.py b/iavl/utils.py index fa3179a..75b6f20 100644 --- a/iavl/utils.py +++ b/iavl/utils.py @@ -1,7 +1,8 @@ import hashlib import itertools +import struct from collections.abc import Iterator -from typing import Callable, List, NamedTuple, Optional, Tuple +from typing import Callable, List, Optional, Tuple import cprotobuf from hexbytes import HexBytes @@ -9,6 +10,12 @@ from .dbm import DBM EMPTY_HASH = hashlib.sha256().digest() +FAST_KEY_PREFIX = b"f" +METADATA_KEY_PREFIX = b"m" +NODE_KEY_PREFIX = b"s" +ORPHAN_KEY_PREFIX = b"o" +ROOT_KEY_PREFIX = b"r" +LEGACY_NODE_KEY_PREFIX = b"n" GetNode = Callable[bytes, Optional["PersistedNode"]] @@ -31,24 +38,34 @@ class StdInt(cprotobuf.ProtoEntity): value = cprotobuf.Field("uint64", 1) -class PersistedNode(NamedTuple): +class PersistedNode: """ immutable nodes that's loaded from and save to db """ - height: int # height of subtree - size: int # size of subtree - version: int # the version created at - key: bytes - - # only in leaf node - value: Optional[bytes] - - # only in branch nodes - left_node_ref: Optional[bytes] - right_node_ref: Optional[bytes] - - hash: bytes + def __init__( + self, + height: int, + size: int, + version: int, + key: bytes, + value: Optional[bytes], + left_node_ref: Optional[bytes], + right_node_ref: Optional[bytes], + hash: bytes, + ): + self.height = height # height of subtree + self.size = size # size of subtree + self.version = version # the version created at + self.key = key + + # only in leaf node + self.value = value + + # only in branch nodes + self.left_node_ref = left_node_ref + self.right_node_ref = right_node_ref + self.hash = hash def is_leaf(self): return self.height == 0 @@ -62,16 +79,10 @@ def right_node(self, ndb): return ndb.get(self.right_node_ref) def as_json(self): - d = self._asdict() - d["key"] = HexBytes(self.key).hex() - if self.value is not None: - d["value"] = HexBytes(self.value).hex() - if self.left_node_ref is not None: - d["left_node_ref"] = HexBytes(self.left_node_ref).hex() - if self.right_node_ref is not None: - d["right_node_ref"] = HexBytes(self.right_node_ref).hex() - if self.hash is not None: - d["hash"] = HexBytes(self.hash).hex() + d = self.__dict__.copy() + for key, value in d.items(): + if isinstance(value, bytes): + d[key] = HexBytes(value).hex() return d def encode(self): @@ -80,11 +91,6 @@ def encode(self): def calc_balance(self, ndb): return self.left_node(ndb).height - self.right_node(ndb).height - @staticmethod - def decode(bz: bytes, hash: bytes): - nd, _ = decode_node(bz, hash) - return nd - def incr_bytes(prefix: bytes) -> bytes: bz = list(prefix) @@ -125,16 +131,44 @@ def prefix_iteritems( return ((k.removeprefix(prefix), v) for k, v in it) +def root_key_suffix(v: int) -> bytes: + return struct.pack(">qI", v, 1) + + +def node_key_suffix(v, n: int) -> bytes: + return struct.pack(">qI", v, n) + + def root_key(v: int) -> bytes: - return b"r" + v.to_bytes(8, "big") + return NODE_KEY_PREFIX + root_key_suffix(v) + + +def legacy_root_key(v: int) -> bytes: + return ROOT_KEY_PREFIX + v.to_bytes(8, "big") def node_key(hash: bytes) -> bytes: - return b"n" + hash + return NODE_KEY_PREFIX + hash + + +def legacy_node_key(hash: bytes) -> bytes: + return LEGACY_NODE_KEY_PREFIX + hash + + +def parse_node_key(key: bytes) -> tuple[int, int]: + """ + parse version and nonce from given key + """ + if not key.startswith(NODE_KEY_PREFIX): + raise ValueError("Key must start with root prefix 's'") + key = key[len(NODE_KEY_PREFIX) :] + if len(key) != 12: + raise ValueError("Key must be 12 bytes after prefix") + return struct.unpack(">qI", key) def fast_node_key(key: bytes) -> bytes: - return b"f" + key + return FAST_KEY_PREFIX + key def store_prefix(s: str) -> bytes: @@ -154,13 +188,20 @@ def prev_version(db: DBM, store: str, v: int) -> Optional[int]: k = next(it, None) if k is None: # empty db - return + return None + if k >= target: - k = next(it) - if not k.startswith(prefix + b"r"): - return - # parse version from key - return int.from_bytes(k[len(prefix) + 1 :], "big") + k = next(it, None) + if k is None: + return None + + if k.startswith(prefix + ROOT_KEY_PREFIX): + # parse version from legacy key + return int.from_bytes(k[len(prefix) + 1 :], "big") + elif k.startswith(prefix + NODE_KEY_PREFIX): + k = k[len(prefix) :] + version, _ = parse_node_key(k) + return version def iavl_latest_version(db: DBM, store: str) -> Optional[int]: @@ -173,6 +214,10 @@ def decode_bytes(bz: bytes) -> (bytes, int): return bz[n : n + l], n + l +def decode_varint(bz: bytes) -> (int, int): + return cprotobuf.decode_primitive(bz, "sint64") + + def encode_bytes(bz: bytes) -> List[bytes]: return [ cprotobuf.encode_primitive("uint64", len(bz)), @@ -193,14 +238,18 @@ def encode_node(node: PersistedNode) -> bytes: return b"".join(chunks) -def decode_node(bz: bytes, hash: bytes) -> (PersistedNode, int): +def decode_node(bz: bytes, hash: bytes, legacy: bool) -> (PersistedNode, int): offset = 0 height, n = cprotobuf.decode_primitive(bz[offset:], "sint64") + if height < 0: + return None, offset offset += n size, n = cprotobuf.decode_primitive(bz[offset:], "sint64") offset += n - version, n = cprotobuf.decode_primitive(bz[offset:], "sint64") - offset += n + version = None + if legacy: + version, n = cprotobuf.decode_primitive(bz[offset:], "sint64") + offset += n key, n = decode_bytes(bz[offset:]) offset += n @@ -211,11 +260,31 @@ def decode_node(bz: bytes, hash: bytes) -> (PersistedNode, int): value, n = decode_bytes(bz[offset:]) offset += n else: - # container node, read children - left_hash, n = decode_bytes(bz[offset:]) - offset += n - right_hash, n = decode_bytes(bz[offset:]) - offset += n + if legacy: + # container node, read children + left_hash, n = decode_bytes(bz[offset:]) + offset += n + right_hash, n = decode_bytes(bz[offset:]) + offset += n + else: + # hash + _, n = decode_bytes(bz[offset:]) + offset += n + # mode + _, n = decode_varint(bz[offset:]) + offset += n + left_version = right_version = left_nonce = right_nonce = None + left_version, n = decode_varint(bz[offset:]) + offset += n + left_nonce, n = decode_varint(bz[offset:]) + offset += n + right_version, n = decode_varint(bz[offset:]) + offset += n + right_nonce, n = decode_varint(bz[offset:]) + left_hash = node_key_suffix(left_version, left_nonce) + right_hash = node_key_suffix(right_version, right_nonce) + if not legacy: + version, _ = parse_node_key(node_key(hash)) return ( PersistedNode( height=height, @@ -241,26 +310,47 @@ def decode_fast_node(bz: bytes) -> (int, bytes, int): def get_node( - db: DBM, hash: bytes, store: Optional[str] = None + db: DBM, + hash: bytes, + store: Optional[str] = None, ) -> Optional[PersistedNode]: - prefix = store_prefix(store) if store is not None else b"" - bz = db.get(prefix + node_key(hash)) - if not bz: - return - node, _ = decode_node(bz, hash) - return node + prefix = store_prefix(store) if store else b"" + for key_func in (node_key, legacy_node_key): + key = key_func(hash) + bz = db.get(prefix + key) + if bz: + legacy = key_func is legacy_node_key + node, _ = decode_node(bz, hash, legacy) + return node + return None def get_root_node( db: DBM, version: int, store: Optional[str] = None ) -> Optional[PersistedNode]: - prefix = store_prefix(store) if store is not None else b"" - hash = db.get(prefix + root_key(version)) - if not hash: - return + hash = get_root_hash(db, store, version) return get_node(db, hash, store) +def is_legacy(db: DBM, store: str) -> bool: + prefix = store_prefix(store) if store else b"" + version = 1 + for key_func in (root_key, legacy_root_key): + if db.get(prefix + key_func(version)) is not None: + return key_func is legacy_root_key + return False + + +def get_root_hash(db: DBM, store: str, version: int) -> Optional[bytes]: + prefix = store_prefix(store) if store else b"" + bz = db.get(prefix + root_key(version)) + if not bz: + return db.get(prefix + legacy_root_key(version)) + key = root_key(version) + _, nonce = parse_node_key(key) + return root_key_suffix(nonce) + + def iter_fast_nodes(db: DBM, store: str, start: Optional[bytes], end: Optional[bytes]): """ normal kv db iteration @@ -268,7 +358,7 @@ def iter_fast_nodes(db: DBM, store: str, start: Optional[bytes], end: Optional[b """ it = db.iteritems() - prefix = store_prefix(store) + b"f" + prefix = store_prefix(store) + FAST_KEY_PREFIX if start is None: start = prefix else: @@ -304,18 +394,15 @@ def iter_iavl_tree( # empty root node return - prefix = store_prefix(store) if store is not None else b"" - - def get_node(hash: bytes) -> PersistedNode: - n, _ = decode_node(db.get(prefix + node_key(hash)), hash) - return n + def get(hash: bytes) -> PersistedNode: + return get_node(db, hash, store) def prune_check(node: PersistedNode) -> (bool, bool): prune_left = start is not None and node.key <= start prune_right = end is not None and node.key >= end return prune_left, prune_right - for node in visit_iavl_nodes(get_node, prune_check, node_hash): + for node in visit_iavl_nodes(get, prune_check, node_hash): if node.is_leaf() and within_range(node.key, start, end): yield node.key, node.value @@ -340,7 +427,12 @@ def visit_iavl_nodes( yield hash_or_node continue + if hash_or_node is None: + continue + node = get_node(hash_or_node) + if node is None: + continue if not preorder: # postorder, visit later diff --git a/iavl/visualize.py b/iavl/visualize.py index 01719ec..da5af39 100644 --- a/iavl/visualize.py +++ b/iavl/visualize.py @@ -1,11 +1,11 @@ import binascii -from typing import List +from typing import List, Optional from graphviz import Digraph from hexbytes import HexBytes from .iavl import NodeDB -from .utils import PersistedNode, decode_node, node_key +from .utils import PersistedNode, get_node def label(node: PersistedNode): @@ -21,23 +21,23 @@ def label(node: PersistedNode): def visualize_iavl( - db, prefix: bytes, root_hash: bytes, version: int, root_hash2=None + db, + root_hash: bytes, + version: int, + root_hash2=None, + store: Optional[str] = None, ) -> Digraph: g = Digraph(comment="IAVL Tree") - def get_node(hash: bytes) -> PersistedNode: - n, _ = decode_node(db.get(prefix + node_key(hash)), hash) - return n - def vis_node(hash: bytes, n: PersistedNode): style = "solid" if n.version == version else "filled" - g.node(HexBytes(hash).hex(), label=label(node), style=style) + g.node(HexBytes(hash).hex(), label=label(n), style=style) if root_hash2 is not None: stack: List[bytes] = [root_hash2] while stack: hash = stack.pop() - node = get_node(hash) + node = get_node(db, hash, store) vis_node(hash, node) @@ -55,7 +55,7 @@ def vis_node(hash: bytes, n: PersistedNode): stack: List[bytes] = [root_hash] while stack: hash = stack.pop() - node = get_node(hash) + node = get_node(db, hash, store) # don't duplicate nodes in compare mode if root_hash2 is None or node.version == version: diff --git a/tests/test_diff.py b/tests/test_diff.py index 0140dd4..a60e8fd 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -1,4 +1,5 @@ import rocksdb + from iavl.diff import state_changes from iavl.iavl import NodeDB diff --git a/tests/test_iavl.py b/tests/test_iavl.py index 591b1f3..839b570 100644 --- a/tests/test_iavl.py +++ b/tests/test_iavl.py @@ -2,6 +2,7 @@ import rocksdb from hexbytes import HexBytes + from iavl.diff import Op, apply_change_set from iavl.iavl import NodeDB, Tree @@ -58,18 +59,18 @@ def setup_test_tree(kvdb: rocksdb.DB): db = NodeDB(kvdb) tree = Tree(db, 0) apply_change_set(tree, ChangeSets[0]) - tree.save_version() + tree.save_version(legacy=True) tree = Tree(db, 1) assert b"world" == tree.get(b"hello") apply_change_set(tree, ChangeSets[1]) - tree.save_version() + tree.save_version(legacy=True) tree = Tree(db, 2) assert b"world1" == tree.get(b"hello") assert b"world1" == tree.get(b"hello1") apply_change_set(tree, ChangeSets[2]) - tree.save_version() + tree.save_version(legacy=True) tree = Tree(db, 3) assert b"world1" == tree.get(b"hello3") @@ -78,17 +79,17 @@ def setup_test_tree(kvdb: rocksdb.DB): assert 2 == node.height apply_change_set(tree, ChangeSets[3]) - tree.save_version() + tree.save_version(legacy=True) # remove nothing assert tree.remove(b"not exists") is None apply_change_set(tree, ChangeSets[4]) - tree.save_version() + tree.save_version(legacy=True) assert not tree.get(b"hello") apply_change_set(tree, ChangeSets[5]) - tree.save_version() + tree.save_version(legacy=True) # test cache miss db = NodeDB(kvdb) @@ -98,7 +99,7 @@ def setup_test_tree(kvdb: rocksdb.DB): # remove most of the values apply_change_set(tree, ChangeSets[6]) - tree.save_version() + tree.save_version(legacy=True) def test_basic_ops(tmp_path): @@ -134,7 +135,7 @@ def test_empty_tree(tmp_path): tree = Tree(db, 0) assert tree.get("hello") is None assert tree.remove("hello") is None - tree.save_version() + tree.save_version(legacy=True) assert tree.version == 1 @@ -147,7 +148,7 @@ def test_new_key(tmp_path): tree = Tree(db, 0) for i in range(4): tree.set(f"key-{i}".encode(), b"1") - tree.save_version() + tree.save_version(legacy=True) # the smallest key in the right half of the tree assert tree.root_node().key == b"key-2" diff --git a/tests/test_prune.py b/tests/test_prune.py index 587cc59..6891555 100644 --- a/tests/test_prune.py +++ b/tests/test_prune.py @@ -1,4 +1,5 @@ import rocksdb + from iavl.iavl import NodeDB from iavl.utils import iter_iavl_tree @@ -16,7 +17,7 @@ def test_prune_tree(tmp_path): latest_version = db.latest_version() for i in range(1, latest_version): print("delete version", i) - assert EXPECT_OUTPUT[i + 1].orphaned == db.delete_version(i) + assert EXPECT_OUTPUT[i + 1].orphaned == db.delete_version(i, legacy=True) # check the integrity of the other versions for j in range(i + 1, latest_version): for _ in iter_iavl_tree(kvdb, None, db.get_root_hash(j), None, None): diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..187b51a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,10 @@ +from iavl.utils import NODE_KEY_PREFIX, node_key_suffix, parse_node_key + + +def test_parse_key(): + v = 1234567890123456789 + n = 1 + suffix = node_key_suffix(v, n) + assert len(suffix) == 12 + assert suffix == b'\x11"\x10\xf4}\xe9\x81\x15\x00\x00\x00\x01' + assert parse_node_key(NODE_KEY_PREFIX + suffix) == (v, n)