diff --git a/pyzx/circuit/gates.py b/pyzx/circuit/gates.py index e2e7ef89..bef81f8b 100644 --- a/pyzx/circuit/gates.py +++ b/pyzx/circuit/gates.py @@ -592,7 +592,7 @@ def to_graph(self, g, q_mapper, _c_mapper): g.add_edge((t,c)) for r_ in range(top, bot + 1): q_mapper.set_next_row(r_, r+1) - g.scalar.add_power(1) + g.mult_scalar_by_sqrt2_power(1) def to_emoji(self,strings: List[List[str]]) -> None: c,t = self.control, self.target @@ -634,7 +634,7 @@ def to_graph(self, g, q_mapper, _c_mapper): g.add_edge((t,c), EdgeType.HADAMARD) q_mapper.set_next_row(self.target, r+1) q_mapper.set_next_row(self.control, r+1) - g.scalar.add_power(1) + g.mult_scalar_by_sqrt2_power(1) def to_emoji(self,strings: List[List[str]]) -> None: c,t = self.control, self.target @@ -656,7 +656,7 @@ def to_graph(self, g, q_mapper, _c_mapper): g.add_edge((t,c), EdgeType.HADAMARD) q_mapper.set_next_row(self.target, r+1) q_mapper.set_next_row(self.control, r+1) - g.scalar.add_power(1) + g.mult_scalar_by_sqrt2_power(1) def to_basic_gates(self): return [HAD(self.control), CNOT(self.control,self.target), HAD(self.control)] @@ -912,7 +912,7 @@ def to_graph(self, g, q_mapper, _c_mapper): #rs[self.control] = r+2 q_mapper.shift_all_rows(4) - g.scalar.add_power(1) + g.mult_scalar_by_sqrt2_power(1) # c1 = self.graph_add_node(g, q_mapper, VertexType.Z, self.control,r) # t1 = self.graph_add_node(g, q_mapper, VertexType.Z, self.target,r) diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 2f49e911..de8017fd 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -287,6 +287,35 @@ def multigraph(self) -> bool: # Backends may wish to override these methods to implement them more efficiently + + # Helper functions for mutating the scalar + # + # AK: I suggest new code *only* uses these functions to modify the scalar and we deprecate mutating + # the scalar directly (which doesn't work correctly with the Rust backend). I also picked slightly + # clearer names for these options, as the "add_..." method names are misleading for these operations, + # which all entail scalar multiplication. + + def mult_scalar_by_phase(self, phase: FractionLike) -> None: + """Multiplies the scalar by a phase factor.""" + self.scalar.add_phase(phase) + + def mult_scalar_by_one_plus_phase(self, phase: FractionLike) -> None: + """Multiplies the scalar by a phase factor.""" + self.scalar.add_node(phase) + + def mult_scalar_by_sqrt2_power(self, power: int) -> None: + """Multiplies the scalar by sqrt(2) raised to the given power.""" + self.scalar.add_power(power) + + def mult_scalar_by_scalar(self, scalar: Scalar) -> None: + """Multiplies scalar with the given scalar""" + self.scalar.mult_with_scalar(scalar) + + def mult_scalar_by_spider_pair(self, phase1: FractionLike, phase2: FractionLike) -> None: + """Multiplies scalar with a 'spider pair', i.e. a pair of phased Z-spiders connected by an H edge""" + self.scalar.add_spider_pair(phase1, phase2) + + # These methods return mappings from vertices to various pieces of data. If the backend # stores these e.g. as Python dicts, just return the relevant dicts. def phases(self) -> Mapping[VT, FractionLike]: diff --git a/pyzx/graph/scalar.py b/pyzx/graph/scalar.py index ca51837f..bb78ace8 100644 --- a/pyzx/graph/scalar.py +++ b/pyzx/graph/scalar.py @@ -245,7 +245,7 @@ def mult_with_scalar(self, other: 'Scalar') -> None: if other.is_zero: self.is_zero = True if other.is_unknown: self.is_unknown = True - def add_spider_pair(self, p1: FractionLike,p2: FractionLike) -> None: + def add_spider_pair(self, p1: FractionLike, p2: FractionLike) -> None: """Add the scalar corresponding to a connected pair of spiders (p1)-H-(p2).""" # These if statements look quite arbitrary, but they are just calculations of the scalar # of a pair of connected single wire spiders of opposite colors. diff --git a/pyzx/rules.py b/pyzx/rules.py index aa3343e2..6f165db1 100644 --- a/pyzx/rules.py +++ b/pyzx/rules.py @@ -652,18 +652,18 @@ def pivot(g: BaseGraph[VT,ET], matches: List[MatchPivotType[VT]]) -> RewriteOutp [(s,t) for s in n[1] for t in n[2]] + [(s,t) for s in n[0] for t in n[2]]) k0, k1, k2 = len(n[0]), len(n[1]), len(n[2]) - g.scalar.add_power(k0*k2 + k1*k2 + k0*k1) + g.mult_scalar_by_sqrt2_power(k0*k2 + k1*k2 + k0*k1) for v in n[2]: if not g.is_ground(v): g.add_to_phase(v, 1) - if g.phase(m[0][0]) and g.phase(m[0][1]): g.scalar.add_phase(Fraction(1)) + if g.phase(m[0][0]) and g.phase(m[0][1]): g.mult_scalar_by_phase(Fraction(1)) if not m[1][0] and not m[1][1]: - g.scalar.add_power(-(k0+k1+2*k2-1)) + g.mult_scalar_by_sqrt2_power(-(k0+k1+2*k2-1)) elif not m[1][0]: - g.scalar.add_power(-(k1+k2)) - else: g.scalar.add_power(-(k0+k2)) + g.mult_scalar_by_sqrt2_power(-(k1+k2)) + else: g.mult_scalar_by_sqrt2_power(-(k0+k2)) for i in 0, 1: # if m[i] has a phase, it will get copied on to the neighbors of m[1-i]: @@ -762,10 +762,10 @@ def lcomp(g: BaseGraph[VT,ET], matches: List[MatchLcompType[VT]]) -> RewriteOutp a = g.phase(m[0]) rem.append(m[0]) assert isinstance(a,Fraction) # For mypy - if a.numerator == 1: g.scalar.add_phase(Fraction(1,4)) - else: g.scalar.add_phase(Fraction(7,4)) + if a.numerator == 1: g.mult_scalar_by_phase(Fraction(1,4)) + else: g.mult_scalar_by_phase(Fraction(7,4)) n = len(m[1]) - g.scalar.add_power((n-2)*(n-1)//2) + g.mult_scalar_by_sqrt2_power((n-2)*(n-1)//2) for i in range(n): if not g.is_ground(m[1][i]): g.add_to_phase(m[1][i], -a) @@ -892,7 +892,7 @@ def hopf(g: BaseGraph[VT,ET], matches: List[MatchHopfType[VT]]) -> RewriteOutput n = g.num_edges(v,w,et) parity = n % 2 rem_edges.extend([g.edge(v, w, et)]*(n - parity)) - g.scalar.add_power(-(n-parity)) + g.mult_scalar_by_sqrt2_power(-(n-parity)) return (etab, [], rem_edges, False) @@ -934,7 +934,7 @@ def remove_self_loops(g: BaseGraph[VT,ET], matches: List[MatchSelfLoopType[VT]]) for v,ns,nh in matches: rem_edges.extend([g.edge(v, v, EdgeType.SIMPLE)]*ns) rem_edges.extend([g.edge(v, v, EdgeType.HADAMARD)]*nh) - g.scalar.add_power(-nh) + g.mult_scalar_by_sqrt2_power(-nh) if nh % 2 == 1: # A Hadamard self-loop gives a phase of pi g.add_to_phase(v, Fraction(1,1)) @@ -979,16 +979,16 @@ def match_phase_gadgets(g: BaseGraph[VT,ET],vertexf:Optional[Callable[[VT],bool] n = gad[0] v = gadgets[n] if phases[n] != 0: # If the phase of the axel vertex is pi, we change the phase of the gadget - g.scalar.add_phase(phases[v]) + g.mult_scalar_by_phase(phases[v]) g.phase_negate(v) m.append((v,n,-phases[v],[],[])) else: totphase = sum((1 if phases[n]==0 else -1)*phases[gadgets[n]] for n in gad)%2 for n in gad: if phases[n] != 0: - g.scalar.add_phase(phases[gadgets[n]]) + g.mult_scalar_by_phase(phases[gadgets[n]]) g.phase_negate(gadgets[n]) - g.scalar.add_power(-((len(par)-1)*(len(gad)-1))) + g.mult_scalar_by_sqrt2_power(-((len(par)-1)*(len(gad)-1))) n = gad.pop() v = gadgets[n] m.append((v,n,totphase, gad, [gadgets[n] for n in gad])) @@ -1064,20 +1064,20 @@ def apply_supplementarity( rem.append(w) alpha = g.phase(v) beta = g.phase(w) - g.scalar.add_power(-2*len(neigh)) + g.mult_scalar_by_sqrt2_power(-2*len(neigh)) if t == 1: # v and w are not connected - g.scalar.add_node(2*alpha+1) + g.mult_scalar_by_one_plus_phase(2*alpha+1) #if (alpha-beta)%2 == 1: # Standard supplementarity if (alpha+beta)%2 == 1: # Need negation on beta - g.scalar.add_phase(-alpha + 1) + g.mult_scalar_by_phase(-alpha + 1) for n in neigh: g.add_to_phase(n,1) elif t == 2: # they are connected - g.scalar.add_power(-1) - g.scalar.add_node(2*alpha) + g.mult_scalar_by_sqrt2_power(-1) + g.mult_scalar_by_one_plus_phase(2*alpha) #if (alpha-beta)%2 == 1: # Standard supplementarity if (alpha+beta)%2 == 0: # Need negation - g.scalar.add_phase(-alpha) + g.mult_scalar_by_phase(-alpha) for n in neigh: g.add_to_phase(n,1) else: raise Exception("Shouldn't happen") @@ -1118,8 +1118,8 @@ def apply_copy(g: BaseGraph[VT,ET], matches: List[MatchCopyType[VT]]) -> Rewrite for v,w,a,alpha, neigh in matches: rem.append(v) rem.append(w) - g.scalar.add_power(-len(neigh)+1) - if a: g.scalar.add_phase(alpha) + g.mult_scalar_by_sqrt2_power(-len(neigh)+1) + if a: g.mult_scalar_by_phase(alpha) for n in neigh: if types[n] == VertexType.BOUNDARY: r = g.row(n) - 1 if n in outputs else g.row(n)+1