Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/python/qubed/Qube.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,28 @@ def hash_node(node: Qube) -> int:

return hash_node(self)

def remove_branch(self, b: Qube) -> Qube:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a docstring would be helpful here I think, e.g.

Suggested change
def remove_branch(self, b: Qube) -> Qube:
def remove_branch(self, b: Qube) -> Qube:
"""
Navigates down self until it finds a key that matches the top key of b, then subtracts b from the subtree.
b can only have 1 child
"""

b_key = b.children[0].key
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we enforce that b only has 1 child?


new_children = []
for c in self.children:
if c.key == b_key:
update_c = type(self).make_root(children=(c,), update_depth=False)
new_c = set_operations.set_operation(
update_c,
b,
set_operations.SetOperation.DIFFERENCE,
type(self),
check_depth=False,
)
if len(new_c.children) != 0:
new_children.extend(new_c.children)
else:
c = c.remove_branch(b)
if len(c.children) != 0:
new_children.append(c)
return self.replace(children=tuple(sorted(new_children)))

def compress(self) -> Qube:
"""
This method is quite computationally heavy because of trees like this:
Expand Down Expand Up @@ -559,7 +581,7 @@ def compare_metadata(self, B: Qube) -> bool:
return False
for k in self.metadata.keys():
if k not in B.metadata:
print(f"'{k}' not in {B.metadata.keys() = }")
print(f"'{k}' not in {B.metadata.keys()=}")
return False
if not np.array_equal(self.metadata[k], B.metadata[k]):
print(f"self.metadata[{k}] != B.metadata.[{k}]")
Expand Down
21 changes: 13 additions & 8 deletions src/python/qubed/set_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def added_axis(size: int, metadata: dict[str, np.ndarray]) -> dict[str, np.ndarr

# @line_profiler.profile
def set_operation(
A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0
A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0, check_depth=True
) -> Qube | None:
if DEBUG:
print(f"{pad()}operation({operation_type.name}, depth={depth})")
Expand All @@ -334,7 +334,8 @@ def set_operation(
assert A.key == B.key
assert A.type == B.type
assert A.values == B.values
assert A.depth == B.depth
if check_depth:
assert A.depth == B.depth

new_children: list[Qube] = []

Expand All @@ -347,7 +348,9 @@ def set_operation(
# For every node group, perform the set operation
for A_nodes, B_nodes in nodes_by_key.values():
output = list(
_set_operation(A_nodes, B_nodes, operation_type, node_type, depth + 1)
_set_operation(
A_nodes, B_nodes, operation_type, node_type, depth + 1, check_depth
)
)
new_children.extend(output)

Expand Down Expand Up @@ -398,6 +401,7 @@ def _set_operation(
operation_type: SetOperation,
node_type,
depth: int,
check_depth,
) -> Iterable[Qube]:
"""
This operation get called from `operation` when we've found two nodes that match and now need
Expand Down Expand Up @@ -461,6 +465,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices):
operation_type,
node_type,
depth=depth + 1,
check_depth=check_depth,
)
if result is not None:
# If we're doing a difference or xor we might want to throw away the intersection
Expand All @@ -482,7 +487,7 @@ def make_new_node(source: Qube, values_indices: ValuesIndices):
continue
else:
raise ValueError(
f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result = }"
f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result=}"
)

if keep_only_A:
Expand Down Expand Up @@ -558,7 +563,7 @@ def merge_values(qubes: list[Qube]) -> Qube:
axis = example.depth

if DEBUG:
print(f"{pad()}merge_values --- {axis = }")
print(f"{pad()}merge_values --- {axis=}")
for i, qube in enumerate(qubes):
qube.display(f"{pad()}in_{i}")

Expand Down Expand Up @@ -694,7 +699,7 @@ def concat_metadata(
example = qubes[0]

if DEBUG:
print(f"concat_metadata --- {axis = }, qubes:")
print(f"concat_metadata --- {axis=}, qubes:")
for qube in qubes:
qube.display()

Expand Down Expand Up @@ -747,8 +752,8 @@ def shallow_concat_metadata(

if DEBUG:
print("shallow_concat_metadata")
print(f"{concatenation_axis = }")
print(f"{sorting_indices = }")
print(f"{concatenation_axis=}")
print(f"{sorting_indices=}")
for k, metadata_group in metadata_groups.items():
print(k, [m.shape for m in metadata_group])

Expand Down
73 changes: 73 additions & 0 deletions tests/test_remove_branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from qubed import Qube


def test_remove_branch():
a = Qube.from_tree("""
root
β”œβ”€β”€ class=od, expver=0001/0002, param=1/2
└── class=rd
β”œβ”€β”€ expver=0001, param=1/2/3
└── expver=0002, param=1/2
""")

b = Qube.from_tree("""
root
β”œβ”€β”€ class=od, expver=0001/0002, param=1/2
""")

c = Qube.from_tree("""
root
└── class=rd
β”œβ”€β”€ expver=0001, param=1/2/3
└── expver=0002, param=1/2
""")

assert a.remove_branch(b) == c


def test_2():
a = Qube.from_tree("""
root
β”œβ”€β”€ class=od, expver=0001/0002, param=1/2
└── class=rd
β”œβ”€β”€ expver=0001, param=1/2/3
└── expver=0002, param=1/2
""")

b = Qube.from_tree("""
root
└── expver=0001/0002, param=1/2
""")

c = Qube.from_tree("""
root
└── class=rd
β”œβ”€β”€ expver=0001, param=3
""")

assert a.remove_branch(b) == c


def test_3():
a = Qube.from_tree("""
root
β”œβ”€β”€ class=od, expver=0001/0002, param=1/2
└── class=rd
β”œβ”€β”€ expver=0001, param=1/2/3
└── expver=0002, param=1/2
""")

b = Qube.from_tree("""
root
└── expver=0001, param=1/2
""")

c = Qube.from_tree("""
root
β”œβ”€β”€ class=od, expver=0002, param=1/2
└── class=rd
β”œβ”€β”€ expver=0001, param=3
└── expver=0002, param=1/2
""")

assert a.remove_branch(b) == c