diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 475c88cf78d..cb6026ad78c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,12 @@

New features since last release

+* A compilation pass written with xDSL called `qml.compiler.python_compiler.transforms.ParitySynthPass` + has been added for the experimental xDSL Python compiler integration. This pass resynthesizes + subcircuits that form a phase polynomial (``CNOT`` and ``RZ`` gates), using ``ParitySynth`` by + [Vandaele et al.](https://arxiv.org/abs/2104.00934) + [(#8414)](https://github.com/PennyLaneAI/pennylane/pull/8414) + * Added a :meth:`~pennylane.devices.DeviceCapabilities.gate_set` method to :class:`~pennylane.devices.DeviceCapabilities` that produces a set of gate names to be used as the target gate set in decompositions. [(#8522)](https://github.com/PennyLaneAI/pennylane/pull/8522) @@ -26,7 +32,7 @@ additional templates. [(#8520)](https://github.com/PennyLaneAI/pennylane/pull/8520) [(#8515)](https://github.com/PennyLaneAI/pennylane/pull/8515) - + - :class:`~.QSVT` - :class:`~.AmplitudeEmbedding` @@ -136,13 +142,13 @@ * Access to the follow functions and classes from the ``pennylane.resources`` module are deprecated. Instead, these functions must be imported from the ``pennylane.estimator`` module. [(#8484)](https://github.com/PennyLaneAI/pennylane/pull/8484) - + - ``qml.estimator.estimate_shots`` in favor of ``qml.resources.estimate_shots`` - ``qml.estimator.estimate_error`` in favor of ``qml.resources.estimate_error`` - ``qml.estimator.FirstQuantization`` in favor of ``qml.resources.FirstQuantization`` - ``qml.estimator.DoubleFactorization`` in favor of ``qml.resources.DoubleFactorization`` -* ``argnum`` has been renamed ``argnums`` for ``qml.grad``, ``qml.jacobian``, ``qml.jvp`` and `qml.vjp``. +* ``argnum`` has been renamed ``argnums`` for ``qml.grad``, ``qml.jacobian``, ``qml.jvp`` and ``qml.vjp``. [(#8496)](https://github.com/PennyLaneAI/pennylane/pull/8496) [(#8481)](https://github.com/PennyLaneAI/pennylane/pull/8481) @@ -155,7 +161,7 @@ * Fix all NumPy 1.X `DeprecationWarnings` in our source code. [(#8497)](https://github.com/PennyLaneAI/pennylane/pull/8497) - + * Update versions for `pylint`, `isort` and `black` in `format.yml` [(#8506)](https://github.com/PennyLaneAI/pennylane/pull/8506) @@ -167,7 +173,7 @@ circuit. A clear error is now also raised when there are observables with overlapping wires. [(#8383)](https://github.com/PennyLaneAI/pennylane/pull/8383) -* Add an `outline_state_evolution_pass` pass to the MBQC xDSL transform, which moves all +* Add an `outline_state_evolution_pass` pass to the MBQC xDSL transform, which moves all quantum gate operations to a private callable. [(#8367)](https://github.com/PennyLaneAI/pennylane/pull/8367) @@ -183,17 +189,17 @@ [(#8486)](https://github.com/PennyLaneAI/pennylane/pull/8486) [(#8495)](https://github.com/PennyLaneAI/pennylane/pull/8495) -* The various private functions of the :class:`~pennylane.estimator.FirstQuantization` class have +* The various private functions of the :class:`~pennylane.estimator.FirstQuantization` class have been modified to avoid using `numpy.matrix` as this function is deprecated. [(#8523)](https://github.com/PennyLaneAI/pennylane/pull/8523) -* The `ftqc` module now includes dummy transforms for several Catalyst/MLIR passes (`to-ppr`, `commute-ppr`, `merge-ppr-ppm`, `pprm-to-mbqc` +* The `ftqc` module now includes dummy transforms for several Catalyst/MLIR passes (`to-ppr`, `commute-ppr`, `merge-ppr-ppm`, `pprm-to-mbqc` and `reduce-t-depth`), to allow them to be captured as primitives in PLxPR and mapped to the MLIR passes in Catalyst. This enables using the passes with the unified compiler and program capture. [(#8519)](https://github.com/PennyLaneAI/pennylane/pull/8519) -* The decompositions for several templates have been updated to use +* The decompositions for several templates have been updated to use :class:`~.ops.op_math.ChangeOpBasis`, which makes their decompositions more resource efficient - by eliminating unnecessary controlled operations. The templates include :class:`~.PhaseAdder`, + by eliminating unnecessary controlled operations. The templates include :class:`~.PhaseAdder`, :class:`~.TemporaryAND`, :class:`~.QSVT`, and :class:`~.SelectPauliRot`. [(#8490)](https://github.com/PennyLaneAI/pennylane/pull/8490) diff --git a/pennylane/compiler/python_compiler/compiler.py b/pennylane/compiler/python_compiler/compiler.py index e924f84d762..79d2d556757 100644 --- a/pennylane/compiler/python_compiler/compiler.py +++ b/pennylane/compiler/python_compiler/compiler.py @@ -55,6 +55,7 @@ def run( xmod = parser.parse_module() pipeline = PassPipeline((ApplyTransformSequence(callback=callback),)) pipeline.apply(ctx, xmod) + print(xmod) buffer = io.StringIO() Printer(stream=buffer, print_generic_format=True).print_op(xmod) with jaxContext() as jctx: diff --git a/pennylane/compiler/python_compiler/dialects/quantum.py b/pennylane/compiler/python_compiler/dialects/quantum.py index c092e1623cf..fae08aab6cf 100644 --- a/pennylane/compiler/python_compiler/dialects/quantum.py +++ b/pennylane/compiler/python_compiler/dialects/quantum.py @@ -492,10 +492,10 @@ class GlobalPhaseOp(IRDLOperation): name = "quantum.gphase" assembly_format = """ - `(` $params `)` - attr-dict - ( `ctrls` `(` $in_ctrl_qubits^ `)` )? - ( `ctrlvals` `(` $in_ctrl_values^ `)` )? + `(` $params `)` + attr-dict + ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type(results) """ diff --git a/pennylane/compiler/python_compiler/transforms/__init__.py b/pennylane/compiler/python_compiler/transforms/__init__.py index d3862b631ca..45240491787 100644 --- a/pennylane/compiler/python_compiler/transforms/__init__.py +++ b/pennylane/compiler/python_compiler/transforms/__init__.py @@ -35,6 +35,8 @@ MeasurementsFromSamplesPass, merge_rotations_pass, MergeRotationsPass, + parity_synth_pass, + ParitySynthPass, ) @@ -50,6 +52,8 @@ "MeasurementsFromSamplesPass", "merge_rotations_pass", "MergeRotationsPass", + "parity_synth_pass", + "ParitySynthPass", # MBQC "convert_to_mbqc_formalism_pass", "ConvertToMBQCFormalismPass", diff --git a/pennylane/compiler/python_compiler/transforms/quantum/__init__.py b/pennylane/compiler/python_compiler/transforms/quantum/__init__.py index b863466bbd9..ea028dc5bf8 100644 --- a/pennylane/compiler/python_compiler/transforms/quantum/__init__.py +++ b/pennylane/compiler/python_compiler/transforms/quantum/__init__.py @@ -23,6 +23,7 @@ measurements_from_samples_pass, ) from .merge_rotations import MergeRotationsPass, merge_rotations_pass +from .parity_synth import ParitySynthPass, parity_synth_pass __all__ = [ @@ -36,4 +37,6 @@ "MeasurementsFromSamplesPass", "merge_rotations_pass", "MergeRotationsPass", + "parity_synth_pass", + "ParitySynthPass", ] diff --git a/pennylane/compiler/python_compiler/transforms/quantum/parity_synth.py b/pennylane/compiler/python_compiler/transforms/quantum/parity_synth.py new file mode 100644 index 00000000000..ae8d842c34d --- /dev/null +++ b/pennylane/compiler/python_compiler/transforms/quantum/parity_synth.py @@ -0,0 +1,353 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains the implementation of the ``ParitySynth`` compiler pass. Given that it +operates on the phase polynomial representation of subcircuits, the implementation splits into +an xDSL-agnostic synthesis functionality and an integration thereof into xDSL.""" + +from dataclasses import dataclass +from itertools import product + +import networkx as nx +import numpy as np +from xdsl import context, passes, pattern_rewriter +from xdsl.dialects import arith, builtin, func +from xdsl.ir import SSAValue +from xdsl.rewriter import InsertPoint + +from .....transforms.intermediate_reps.rowcol import _rowcol_parity_matrix +from ...dialects.quantum import CustomOp, QubitType +from ...pass_api import compiler_transform + +### xDSL-agnostic part + + +def _apply_dfs_po_circuit(tree, source, P, inv_synth_matrix=None): + dfs_po = list(nx.dfs_postorder_nodes(tree, source=source)) + sub_circuit = [] + if inv_synth_matrix is None: + for i, j in zip(dfs_po[:-1], dfs_po[1:]): + sub_circuit.append((i, j)) + P[i] += P[j] + else: + for i, j in zip(dfs_po[:-1], dfs_po[1:]): + sub_circuit.append((i, j)) + P[i] += P[j] + inv_synth_matrix[:, i] += inv_synth_matrix[:, j] + P %= 2 + return sub_circuit + + +def _loop_body_parity_network_synth( + P: np.ndarray, + inv_synth_matrix: np.ndarray, + circuit: list[int, list[tuple[int]]], +) -> tuple[np.ndarray, list]: + """Loop body function for ``_parity_network_synth``, the main subroutine of ``parity_synth``. + The loop body corresponds to synthesizing one parity in the parity table ``P``, and updating + all relevant data accordingly. It is the ``for``-loop body in Algorithm 1 + in https://arxiv.org/abs/2104.00934. + + Args: + P (np.ndarray): (Remaining) parity table for which to synthesize the parity network. + inv_synth_matrix (np.ndarray): Inverse of the parity _matrix_ implemented within + the parity network that has been synthesized so far. + circuit (list[int, list[tuple[int]]]): Circuit for the parity network that has been + synthesized so far. Each entry of the list consists of a _relative_ index into + the list of parities (or rotation angles) of the phase polynomial, a qubit + index onto which the rotation should be applied, and the subcircuit that should + be applied _before_ the rotation to achieve the respective parity. + + Returns: + tuple[np.ndarray, list]: Same as inputs, with updates applied; ``P`` has a column less + and has been transformed in addition. ``inv_synth_matrix`` has been transformed + according to the newly synthesized subcircuit implementing the next parity. The + ``circuit`` representation is grown by one entry, corresponding to that parity. + + """ + parity_idx = np.argmin(np.sum(P, axis=0)) # ┬ Line 3 + parity = P[:, parity_idx] # ╯ + graph_nodes = list(map(int, np.where(parity)[0])) # Line 5, vertices + if len(graph_nodes) == 1: + # The parity already has Hamming weight 1, so we don't need any modifications + # Just slice out the parity and append the parity/angle index as well as the qubit + # on which the parity has support + P = np.concatenate([P[:, :parity_idx], P[:, parity_idx + 1 :]], axis=1) # Line 4 + circuit.append((parity_idx, graph_nodes[0], [])) # Record parity index, qubit index, CNOTs + return P, inv_synth_matrix, circuit + + # Note that there is a bug in the algorithm as written in the paper: We first want to compute + # the edge weights for parity_graph (G_y) and _then_ slice out `parity` from `P`. + single_weights = np.sum(P, axis=1) # ╮ + parity_graph = nx.DiGraph() # │ + parity_graph.add_weighted_edges_from( # │ + [ # │ + (i, j, np.sum(np.mod(P[i] + P[j], 2)) - single_weights[j]) # ├ Line 5, edges + for i, j in product(graph_nodes, repeat=2) # │ + if i != j # │ + ] # │ + ) # ╯ + arbor = nx.minimum_spanning_arborescence(parity_graph) # Line 6 + + # Find the root of the tree + root = next(iter(node for node, degree in arbor.in_degree() if degree == 0)) + + P = np.concatenate([P[:, :parity_idx], P[:, parity_idx + 1 :]], axis=1) # Line 4 + # Lines 7-10, update P and inv_synth_matrix in place + sub_circuit = _apply_dfs_po_circuit(arbor, root, P, inv_synth_matrix) + circuit.append((parity_idx, root, sub_circuit)) # Record parity index, qubit index, CNOTs + return P, inv_synth_matrix, circuit + + +def _parity_network_synth(P: np.ndarray) -> list[int, list[tuple[int]]]: + """Main subroutine for the ``ParitySynth`` pass, mostly a ``for``-loop wrapper around + ``_loop_body_parity_network_synth``. It synthesizes the parity network, as described + in Algorithm 1 in https://arxiv.org/abs/2104.00934. + + Args: + P (np.ndarray): Parity table to be synthesized. + Shape should be ``(num_wires, num_parities)`` + + Returns: + tuple[list[int, list[tuple[int]]], np.ndarray]: Synthesized parity network, as a + circuit with structure as described in ``_loop_body_parity_network_synth``. Also, + inverse of the parity matrix implemented by the synthesized circuit. + + """ + if P.shape[-1] == 0: + # Nothing to do if there are not parities + return [], None + + circuit = [] # Line 1 in Alg. 1 + num_wires, num_parities = P.shape + # Initialize an inverse parity matrix that is updated with the CNOTs that are synthesized here. + inv_synth_mat = np.eye(num_wires, dtype=int) + # `num_parities` loop iterations because each loop body takes care of one parity, we just + # don't know which one. This makes the `for`-loop equivalent to line 2 in Alg. 1 + for _ in range(num_parities): + P, inv_synth_mat, circuit = _loop_body_parity_network_synth(P, inv_synth_mat, circuit) + + return circuit, inv_synth_mat % 2 + + +### end of xDSL-agnostic part + +valid_phase_polynomial_ops = {"CNOT", "RZ"} + + +def make_phase_polynomial( + ops: list[CustomOp], + init_wire_map: dict[QubitType, int], +) -> tuple[np.ndarray]: + r"""Compute the phase polynomial representation of a list of ``CustomOp``\ s. + This implementation is very similar to :func:`~.transforms.intermediate_reps.phase_polynomial` + but adjusted to work with xDSL objects.""" + wire_map = init_wire_map + + parity_matrix = np.eye(len(wire_map), dtype=int) + parity_table = [] + angles = [] + arith_ops = [] + for op in ops: + name = op.gate_name.data + if name == "CNOT": + control, target = wire_map.pop(op.in_qubits[0]), wire_map.pop(op.in_qubits[1]) + parity_matrix[target] += parity_matrix[control] + wire_map[op.out_qubits[0]] = control + wire_map[op.out_qubits[1]] = target + continue + + # RZ + angle = op.operands[0] + if getattr(op, "adjoint", False): + neg_op = arith.NegfOp(angle) + arith_ops.append(neg_op) + angle = neg_op.result + angles.append(angle) + wire = wire_map[op.in_qubits[0]] + parity_table.append(parity_matrix[wire].copy()) # append _current_ parity (hence the copy) + wire_map[op.out_qubits[0]] = wire + + return parity_matrix % 2, np.array(parity_table).T % 2, angles, arith_ops + + +class ParitySynthPattern(pattern_rewriter.RewritePattern): + """Rewrite pattern that applies ``ParitySynth`` to subcircuits that constitute + phase polynomials. + """ + + phase_polynomial_ops: list[CustomOp] + init_wire_map: [QubitType, int] + global_wire_map: [QubitType, int] + phase_polynomial_ops: set[QubitType] + num_phase_polynomial_qubits: int + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reset_vars() + + def _reset_vars(self): + """Initialize/reset variables that are used in ``match_and_rewrite`` as well as + ``rewrite_phase_polynomial``.""" + self.phase_polynomial_ops = [] + self.init_wire_map = {} + self.phase_polynomial_qubits = set() + self.num_phase_polynomial_qubits = 0 + + def _record_phase_poly_op(self, op: CustomOp): + """Add a ``CustomOp`` to the phase polynomial ops, remove its input qubits + from ``self.phase_polynomial_qubits`` if present or add them to ``self.init_wire_map`` + if not, and insert its output qubits in ``self.phase_polynomial_qubits``.""" + for i, q in enumerate(op.in_qubits): + if q in self.phase_polynomial_qubits: + self.phase_polynomial_qubits.remove(q) + else: + self.init_wire_map[q] = self.num_phase_polynomial_qubits + self.num_phase_polynomial_qubits += 1 + self.phase_polynomial_qubits.add(op.out_qubits[i]) + self.phase_polynomial_ops.append(op) + + @pattern_rewriter.op_type_rewrite_pattern + def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): + r"""Implementation of rewriting ``FuncOps`` that may contain phase poynomials + with ``ParitySynth``. + + Args: + funcOp (func.FuncOp): function containing the operations to rewrite. + rewriter (pattern_rewriter.PatternRewriter): Rewriter that executed operation erasure + and insertion. + + The logic of this implementation is centered around :attr:`~.rewrite_phase_polynomial`, + which is able to rewrite a collection of ``CustomOp``\ s that forms a phase polynomial + (see ``valid_phase_polynomial_ops`` for the supported types) into a new collection of + ``CustomOp``\ s that is equivalent. In addition to the operators, which are stored in + ``self.phase_polynomial_ops``, the ``rewrite_phase_polynomial`` subroutine requires + the initial mapping from input qubits to integer-valued wire positions, which is computed + in ``self.init_wire_map`` using temporary variables ``self.phase_polynomial_qubits`` + and ``self.num_phase_polynomial_qubits``. + + Iterating over all operations, the collected phase polynomial ops are rewritten as soon + as a non-phase-polynomial operation is encountered. Note that this makes the (size of the) + rewritten phase polynomials dependent on the order in which we walk over the operations. + """ + for op in funcOp.body.walk(): + if not isinstance(op, CustomOp): + # Non-quantum operation. Global phases are ignored as well. + continue + + gate_name = op.gate_name.data + if gate_name in valid_phase_polynomial_ops: + # Include op in phase polynomial ops and track its qubits + self._record_phase_poly_op(op) + continue + + # not a phase polynomial op, so we activate rewriting of the phase polynomial + self.rewrite_phase_polynomial(rewriter) + + # end of operations; rewrite terminal phase polynomial + self.rewrite_phase_polynomial(rewriter) + + @staticmethod + def _cnot(i: int, j: int, inv_wire_map: dict[int, QubitType]): + """Create a CNOT operator acting on the qubits that map to wires ``i`` and ``j`` + and update the wire map so that ``i`` and ``j`` point to the output qubits afterwards.""" + cnot_op = CustomOp( + in_qubits=[inv_wire_map[i], inv_wire_map[j]], + gate_name="CNOT", + params=tuple(), + ) + inv_wire_map[i] = cnot_op.out_qubits[0] + inv_wire_map[j] = cnot_op.out_qubits[1] + return cnot_op + + @staticmethod + def _rz(wire: int, angle: SSAValue[builtin.Float64Type], inv_wire_map: dict[int, QubitType]): + """Create a CNOT operator acting on the qubit that maps to ``wire`` + and update the wire map so that ``wire`` points to the output qubit afterwards.""" + rz_op = CustomOp(in_qubits=[inv_wire_map[wire]], gate_name="RZ", params=(angle,)) + inv_wire_map[wire] = rz_op.out_qubits[0] + return rz_op + + def rewrite_phase_polynomial(self, rewriter: pattern_rewriter.PatternRewriter): + """Rewrite a single region of a circuit that represents a phase polynomial.""" + if not self.phase_polynomial_ops: + # Nothing to do + return + + if len(self.phase_polynomial_ops) == 1: + # Phase polynomials of length 1 are left untouched. Reset internal state + self._reset_vars() + return + + insertion_point: InsertPoint = InsertPoint.after(self.phase_polynomial_ops[-1]) + + # Mapping from integer-valued wire positions to qubits, corresponding to state before + # phase polynomial + inv_wire_map: dict[int, QubitType] = {val: key for key, val in self.init_wire_map.items()} + + # Calculate the new circuit by going to phase polynomial IR and back, including synthesis + # of trailing CNOTs via rowcol + M, P, angles, arith_ops = make_phase_polynomial( + self.phase_polynomial_ops, self.init_wire_map + ) + + # Insert arithmetic operations produced within `make_phase_polynomial` + for op in arith_ops: + rewriter.insert_op(op, insertion_point) + + subcircuits, inv_network_parity_matrix = _parity_network_synth(P) + # `inv_network_parity_matrix` might be None if the parity table was empty + if inv_network_parity_matrix is not None: + M = (M @ inv_network_parity_matrix) % 2 + rowcol_circuit: list[tuple[int]] = _rowcol_parity_matrix(M) + + # Apply the parity network part of the new circuit + for idx, phase_wire, subcircuit in subcircuits: + for i, j in subcircuit: + rewriter.insert_op(self._cnot(i, j, inv_wire_map), insertion_point) + + rewriter.insert_op(self._rz(phase_wire, angles.pop(idx), inv_wire_map), insertion_point) + + # Apply the remaining parity matrix part of the new circuit + for i, j in rowcol_circuit: + rewriter.insert_op(self._cnot(i, j, inv_wire_map), insertion_point) + + # Replace the output qubits of the old phase polynomial operations by the output qubits of + # the new circuit + for old_qubit, int_wire in self.init_wire_map.items(): + rewriter.replace_all_uses_with(old_qubit, inv_wire_map[int_wire]) + + # Erase the old phase polynomial operations. + for op in self.phase_polynomial_ops[::-1]: + rewriter.erase_op(op) + + # Reset internal state + self._reset_vars() + + +@dataclass(frozen=True) +class ParitySynthPass(passes.ModulePass): + """Pass for applying ParitySynth to phase polynomials in a circuit.""" + + name = "xdsl-parity-synth" + + # pylint: disable=no-self-use + def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: + """Apply the ParitySynth pass.""" + applier = pattern_rewriter.GreedyRewritePatternApplier([ParitySynthPattern()]) + walker = pattern_rewriter.PatternRewriteWalker(applier, apply_recursively=False) + walker.rewrite_module(module) + + +parity_synth_pass = compiler_transform(ParitySynthPass) diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_parity_synth.py b/tests/python_compiler/transforms/quantum/test_xdsl_parity_synth.py new file mode 100644 index 00000000000..3b72a7f24bb --- /dev/null +++ b/tests/python_compiler/transforms/quantum/test_xdsl_parity_synth.py @@ -0,0 +1,475 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit test module for the ParitySynth transform""" +from itertools import product + +import numpy as np +import pytest +from numpy.testing import assert_allclose, assert_equal + +pytestmark = pytest.mark.external + +pytest.importorskip("xdsl") +pytest.importorskip("catalyst") + +# pylint: disable=wrong-import-position +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath + +import pennylane as qml +from pennylane.compiler.python_compiler.transforms import ParitySynthPass, parity_synth_pass +from pennylane.compiler.python_compiler.transforms.quantum.parity_synth import ( + _parity_network_synth, +) +from pennylane.transforms.intermediate_reps import phase_polynomial + + +def assert_binary_matrix(matrix: np.ndarray): + """Check that the input matrix is two-dimensional, ``np.int64``-dtyped and + only contains zeros and ones. + """ + if matrix.ndim != 2: + raise ValueError( + f"Expected the matrix to be two-dimensional, but got {matrix.ndim} dimensions." + ) + if matrix.dtype != np.int64: + raise ValueError( + f"Expected the data type of the matrix to be np.int64, but got {matrix.dtype}." + ) + if not set(matrix.flat).issubset({0, 1}): + raise ValueError( + f"Expected the entries of the matrix to be from {{0, 1}} but got {set(matrix.flat)}." + ) + + +class TestParityNetworkSynth: + """Tests for the synthesizing of a parity network with ``_parity_network_synth``.""" + + @staticmethod + def validate_circuit_entry(entry, exp_len=None): + """Validate that an object is a three-tuple consisting of an two integers and a list + of two-tuples with integers in them, like ``(1, 4, [(0, 2), (1, 0)])``. This constitutes + the format for circuit entries in the output of ``_parity_network_synth``.""" + assert isinstance(entry, tuple) and len(entry) == 3 + parity_idx, qubit_idx, cnot_circuit = entry + assert isinstance(parity_idx, np.int64) + assert isinstance(qubit_idx, int) + assert isinstance(cnot_circuit, list) + if exp_len is not None: + assert len(cnot_circuit) == exp_len + assert all(isinstance(_cnot, tuple) and len(_cnot) == 2 for _cnot in cnot_circuit) + + def test_empty_parity_table(self): + """Test that an empty parity table results in an empty circuit.""" + P = np.ones(shape=(10, 0), dtype=int) + circuit, inv_synth_matrix = _parity_network_synth(P) + assert not circuit + assert inv_synth_matrix is None + + @pytest.mark.parametrize("n, idx", [(1, 0), (2, 0), (2, 1), (3, 2), (10, 5)]) + def test_single_unit_vector_parity(self, n, idx): + """Test that a single unit vector-parity is synthesized into no CNOTs and a single RZ.""" + I = np.eye(n, dtype=int) + P = I[:, idx : idx + 1] + circuit, inv_synth_matrix = _parity_network_synth(P) + assert isinstance(circuit, list) and len(circuit) == 1 + self.validate_circuit_entry(circuit[0], exp_len=0) + assert_binary_matrix(inv_synth_matrix) + assert_equal(I, inv_synth_matrix) + + @pytest.mark.parametrize( + "n, ids", + [ + (1, [0]), + (2, [1, 0]), + (2, [0, 1]), + (3, [2, 0]), + (3, [0, 1]), + (5, [0, 4]), + (10, [5, 4, 9, 0, 2, 3]), + ], + ) + def test_multiple_unit_vector_parities(self, n, ids): + """Test that multiple unit vector-parities are synthesized into no CNOTs and a + series of RZ gates.""" + I = np.eye(n, dtype=int) + P = np.concatenate([I[:, idx : idx + 1] for idx in ids], axis=1) + circuit, inv_synth_matrix = _parity_network_synth(P) + assert isinstance(circuit, list) and len(circuit) == len(ids) + for entry in circuit: + self.validate_circuit_entry(entry, exp_len=0) + assert_binary_matrix(inv_synth_matrix) + assert_equal(I, inv_synth_matrix) + + @pytest.mark.parametrize( + "parity", + [ + [1, 0, 0, 1, 0, 1], + [1, 1, 1], + [0, 1, 1, 1, 0, 0, 0, 1, 1], + [1, 1], + [1, 0, 0, 0, 0, 0, 0, 1], + ], + ) + def test_single_non_unit_parity(self, parity): + """Test that a single non-unit vector-parity ``p`` is synthesized into + ``|p|-1`` CNOTs and a single RZ.""" + P = np.array([parity]).T + circuit, inv_synth_matrix = _parity_network_synth(P) + assert isinstance(circuit, list) and len(circuit) == 1 + self.validate_circuit_entry(circuit[0], exp_len=np.sum(P) - 1) + I = np.eye(len(parity), dtype=int) + assert I.shape == inv_synth_matrix.shape + assert set(inv_synth_matrix.flat).issubset({0, 1}) + assert not np.allclose(I, inv_synth_matrix) + + @pytest.mark.parametrize( + "parities, exp_lens", + [ + ([[1, 0, 0, 1, 0, 1], [0, 0, 1, 1, 0, 0], [0, 0, 1, 1, 0, 0]], (1, 0, 2)), + ([[1, 1, 1], [0, 1, 1], [1, 1, 1]], (1, 1, 0)), + ([[0, 1, 1, 1, 0, 0, 0, 1, 1], [0, 1, 1, 1, 0, 0, 0, 1, 1]], (4, 0)), + ([[1, 1], [1, 1], [0, 1], [1, 0], [1, 0], [0, 1]], (0, 0, 0, 0, 1, 0)), + ( + [[1, 0, 0, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0, 0, 1]], + (1, 0, 0), + ), + ], + ) + def test_with_repeated_parities(self, parities, exp_lens): + """Test that repeated (non-unit vector-)parities are synthesized in sequence and + require no CNOT between their RZ gates.""" + P = np.array(parities).T + circuit, inv_synth_matrix = _parity_network_synth(P) + assert isinstance(circuit, list) and len(circuit) == len(parities) + for entry, exp_len in zip(circuit, exp_lens, strict=True): + self.validate_circuit_entry(entry, exp_len=exp_len) + I = np.eye(len(parities[0]), dtype=int) + assert I.shape == inv_synth_matrix.shape + assert set(inv_synth_matrix.flat).issubset({0, 1}) + assert not np.allclose(I, inv_synth_matrix) + + @pytest.mark.parametrize("n, seed", [(2, 851), (3, 231), (4, 8241), (5, 214)]) + @pytest.mark.parametrize("num_parities", (1, 2, 3, 10, 20)) + def test_roundtrip(self, num_parities, n, seed): + """Test that the parity table of a randomly sampled CNOT+RZ circuit is synthesized + into a new CNOT+RZ circuit with the same parities, and that the inverse of the + parity matrix of the new circuit is reported correctly.""" + # pylint: disable=unbalanced-tuple-unpacking + + np.random.seed(seed) # todo: proper seeding + # Make all cnot ops + all_cnots = [qml.CNOT((i, j)) for i, j in product(range(n), repeat=2) if i != j] + # Sample random CNOTs (by index into above list) and rotation angles + cnots = [ + np.random.choice(len(all_cnots), size=n, replace=True) for _ in range(num_parities) + ] + thetas = np.random.random(num_parities) + # Make PL circuit + circuit = sum( + [ + [all_cnots[i] for i in sub_circuit] + [qml.RZ(x, j % n)] + for j, (sub_circuit, x) in enumerate(zip(cnots, thetas, strict=True)) + ], + start=[], + ) + # Compute IR + _, P, angles = phase_polynomial(qml.tape.QuantumScript(circuit), wire_order=range(n)) + + angles_ = list(angles) + # Synthesize parity network and compute new PL circuit from it + new_circuit, inv_parity_matrix = _parity_network_synth(P) + new_circuit = sum( + [ + [qml.CNOT(_cnot) for _cnot in sub_circuit] + + [qml.RZ(angles_.pop(angle_idx), qubit_idx)] + for angle_idx, qubit_idx, sub_circuit in new_circuit + ], + start=[], + ) + # Compute IR of new PL circuit + new_parity_matrix, new_P, new_angles = phase_polynomial( + qml.tape.QuantumScript(new_circuit), wire_order=range(n) + ) + # Compare phase parities and make sure that the inv_parity_matrix is valid + assert_allclose(new_P @ new_angles, P @ angles) + assert_binary_matrix(inv_parity_matrix) + assert_equal((new_parity_matrix @ inv_parity_matrix) % 2, np.eye(n, dtype=int)) + + +def translate_program_to_xdsl(program): + """Translate an almost-xDSL-program into an xDSL program by replacing some shorthand notations.""" + new_lines = [] + for line in program.split("\n"): + if "INIT_QUBIT" in line: + i = int(line.strip().split(" ")[0][1:]) + new_lines.extend( + [ + f'// CHECK: [[q{i}:%.+]] = "test.op"() : () -> !quantum.bit', + f'%{i} = "test.op"() : () -> !quantum.bit', + ] + ) + elif "_CNOT" in line: + bits = line.strip().split(" ") + new_bits = ( + bits[:3] + ['quantum.custom "CNOT"()'] + bits[4:] + [": !quantum.bit, !quantum.bit"] + ) + new_lines.append(" ".join(new_bits)) + else: + new_lines.append(line) + return "\n".join(new_lines) + + +class TestParitySynthPass: + """Unit tests for ParitySynthPass.""" + + pipeline = (ParitySynthPass(),) + + def test_no_phase_polynomial_ops(self, run_filecheck): + """Test that nothing changes when there are no phase polynomial gates.""" + program = """ + func.func @test_func(%arg0: f64) { + %0 = INIT_QUBIT + // CHECK: [[q1:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit + // CHECK: quantum.custom "RX"(%arg0) [[q1]] : !quantum.bit + %1 = quantum.custom "Hadamard"() %0 : !quantum.bit + %2 = quantum.custom "RX"(%arg0) %1 : !quantum.bit + return + } + """ + + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_composable_cnots(self, run_filecheck): + """Test that two out of three CNOT gates are merged.""" + program = """ + func.func @test_func() { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + // CHECK: quantum.custom "CNOT"() [[q0]], [[q1]] : !quantum.bit, !quantum.bit + %2, %3 = _CNOT %0, %1 + %4, %5 = _CNOT %2, %3 + %6, %7 = _CNOT %4, %5 + // CHECK-NOT: "quantum.custom" + return + } + """ + + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_two_cnots_single_rotation_no_merge(self, run_filecheck): + """Test that a phase polynomial of two CNOTs separated by a rotation on the target + is maintained.""" + program = """ + func.func @test_func(%arg0: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + // In the following check, q1 and q0 are exchanged. This is a symmetry + // of the test case and ParitySynth chooses to flip the CNOTs. + // CHECK: [[q2:%.+]], [[q3:%.+]] = quantum.custom "CNOT"() [[q1]], [[q0]] : !quantum.bit, !quantum.bit + %2, %3 = _CNOT %0, %1 + // CHECK: [[q4:%.+]] = quantum.custom "RZ"(%arg0) [[q3]] : !quantum.bit + %4 = quantum.custom "RZ"(%arg0) %3 : !quantum.bit + // CHECK: quantum.custom "CNOT"() [[q2]], [[q4]] : !quantum.bit, !quantum.bit + %5, %6 = _CNOT %2, %4 + // CHECK-NOT: "quantum.custom" + return + } + """ + + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_two_cnots_single_rotation_with_merge(self, run_filecheck): + """Test that a phase polynomial of two CNOTs separated by a rotation on the control + is reduced.""" + program = """ + func.func @test_func(%arg0: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + %2, %3 = _CNOT %0, %1 + // CHECK: quantum.custom "RZ"(%arg0) [[q0]] : !quantum.bit + %4 = quantum.custom "RZ"(%arg0) %2 : !quantum.bit + %5, %6 = _CNOT %4, %3 + // CHECK-NOT: "quantum.custom" + return + } + """ + + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_two_phase_polynomials_first_merge(self, run_filecheck): + """Test that two phase polynomials separated by a non-phase-polynomial operation is + compiled correctly if the former polynomial can be reduced.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64, %arg2: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + %2 = INIT_QUBIT + + %3, %4 = _CNOT %0, %2 + // CHECK: [[q3:%.+]] = quantum.custom "RZ"(%arg2) [[q0]] : !quantum.bit + %5 = quantum.custom "RZ"(%arg2) %3 : !quantum.bit + %6, %7 = _CNOT %5, %4 + + // CHECK: [[q4:%.+]] = quantum.custom "RX"(%arg1) [[q3]] : !quantum.bit + %8 = quantum.custom "RX"(%arg1) %6 : !quantum.bit + + // CHECK: [[q5:%.+]], [[q6:%.+]] = quantum.custom "CNOT"() [[q1]], [[q4]] : !quantum.bit, !quantum.bit + %9, %10 = _CNOT %8, %1 + // CHECK: [[q7:%.+]] = quantum.custom "RZ"(%arg0) [[q6]] : !quantum.bit + %11 = quantum.custom "RZ"(%arg0) %10 : !quantum.bit + // CHECK: quantum.custom "CNOT"() [[q5]], [[q7]] : !quantum.bit, !quantum.bit + %12, %13 = _CNOT %9, %11 + // CHECK-NOT: "quantum.custom" + return + } + """ + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_two_phase_polynomials_second_merge(self, run_filecheck): + """Test that two phase polynomials separated by a non-phase-polynomial operation is + compiled correctly if the latter polynomial can be reduced.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64, %arg2: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + %2 = INIT_QUBIT + + // CHECK: [[q3:%.+]], [[q4:%.+]] = quantum.custom "CNOT"() [[q1]], [[q0]] : !quantum.bit, !quantum.bit + %3, %4 = _CNOT %0, %1 + // CHECK: [[q5:%.+]] = quantum.custom "RZ"(%arg0) [[q4]] : !quantum.bit + %5 = quantum.custom "RZ"(%arg0) %4 : !quantum.bit + // CHECK: [[q6:%.+]], [[q7:%.+]] = quantum.custom "CNOT"() [[q3]], [[q5]] : !quantum.bit, !quantum.bit + %6, %7 = _CNOT %3, %5 + + // CHECK: [[q8:%.+]] = quantum.custom "RX"(%arg1) [[q6]] : !quantum.bit + %8 = quantum.custom "RX"(%arg1) %7 : !quantum.bit + + %9, %10 = _CNOT %8, %2 + // CHECK: quantum.custom "RZ"(%arg2) [[q8]] : !quantum.bit + %11 = quantum.custom "RZ"(%arg2) %9 : !quantum.bit + %12, %13 = _CNOT %11, %10 + // CHECK-NOT: "quantum.custom" + return + } + """ + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_large_phase_polynomial(self, run_filecheck): + """Test that a larger phase polynomial block is handled without an error.""" + program = """ + func.func @test_func(%arg0: f64, %arg1: f64, %arg2: f64, %arg3: f64, %arg4: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + %2 = INIT_QUBIT + %3 = INIT_QUBIT + + %4, %5 = _CNOT %0, %2 + %6, %7 = _CNOT %1, %4 + %8 = quantum.custom "RZ"(%arg0) %7 : !quantum.bit + %9, %10 = _CNOT %8, %6 + %11 = quantum.custom "RZ"(%arg1) %9 : !quantum.bit + %12, %13 = _CNOT %3, %10 + %14, %15 = _CNOT %12, %11 + %16 = quantum.custom "RZ"(%arg2) %14 : !quantum.bit + %17, %18 = _CNOT %5, %15 + %19, %20 = _CNOT %13, %18 + %21 = quantum.custom "RZ"(%arg3) %20 : !quantum.bit + %22, %23 = _CNOT %16, %17 + %24, %25 = _CNOT %22, %19 + %26, %27 = _CNOT %21, %23 + %28, %29 = _CNOT %25, %27 + %30, %31 = _CNOT %28, %26 + %32, %33 = _CNOT %29, %31 + %34, %35 = _CNOT %32, %30 + %36, %37 = _CNOT %34, %33 + %38 = quantum.custom "RZ"(%arg4) %35 : !quantum.bit + + // CHECK: [[q4:%.+]] = quantum.custom "RZ"(%arg2) [[q3]] : !quantum.bit + // CHECK: [[q5:%.+]], [[q6:%.+]] = quantum.custom "CNOT"() [[q1]], [[q0]] : !quantum.bit, !quantum.bit + // CHECK: [[q7:%.+]] = quantum.custom "RZ"(%arg0) [[q6]] : !quantum.bit + // CHECK: [[q8:%.+]] = quantum.custom "RZ"(%arg1) [[q7]] : !quantum.bit + // CHECK: [[q9:%.+]], [[q10:%.+]] = quantum.custom "CNOT"() [[q2]], [[q8]] : !quantum.bit, !quantum.bit + // CHECK: [[q11:%.+]] = quantum.custom "RZ"(%arg3) [[q10]] : !quantum.bit + // CHECK: [[q12:%.+]], [[q13:%.+]] = quantum.custom "CNOT"() [[q4]], [[q5]] : !quantum.bit, !quantum.bit + // CHECK: [[q14:%.+]] = quantum.custom "RZ"(%arg4) [[q13]] : !quantum.bit + + // CHECK: [[q15:%.+]], [[q16:%.+]] = quantum.custom "CNOT"() [[q12]], [[q9]] : !quantum.bit, !quantum.bit + // CHECK: [[q17:%.+]], [[q18:%.+]] = quantum.custom "CNOT"() [[q14]], [[q16]] : !quantum.bit, !quantum.bit + // CHECK: [[q19:%.+]], [[q20:%.+]] = quantum.custom "CNOT"() [[q17]], [[q11]] : !quantum.bit, !quantum.bit + // CHECK: [[q21:%.+]], [[q22:%.+]] = quantum.custom "CNOT"() [[q20]], [[q18]] : !quantum.bit, !quantum.bit + // CHECK: [[q23:%.+]], [[q24:%.+]] = quantum.custom "CNOT"() [[q22]], [[q21]] : !quantum.bit, !quantum.bit + // CHECK-NOT: "quantum.custom" + return + } + """ + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + def test_phase_polynomial_with_adjoint(self, run_filecheck): + """Test that adjoint is handled correctly.""" + program = """ + func.func @test_func(%arg0: f64) { + %0 = INIT_QUBIT + %1 = INIT_QUBIT + %2 = INIT_QUBIT + + // CHECK: [[new_angle:%.+]] = arith.negf %arg0 + // CHECK: [[q3:%.+]], [[q4:%.+]] = quantum.custom "CNOT"() [[q1]], [[q0]] : !quantum.bit, !quantum.bit + %3, %4 = _CNOT %0, %1 + // CHECK: [[q5:%.+]] = quantum.custom "RZ"([[new_angle]]) [[q4]] : !quantum.bit + %5 = quantum.custom "RZ"(%arg0) %4 adj : !quantum.bit + // CHECK: [[q6:%.+]], [[q7:%.+]] = quantum.custom "CNOT"() [[q3]], [[q5]] : !quantum.bit, !quantum.bit + %6, %7 = _CNOT %3, %5 + // CHECK-NOT: "quantum.custom" + return + } + """ + run_filecheck(translate_program_to_xdsl(program), self.pipeline) + + +# pylint: disable=too-few-public-methods +@pytest.mark.usefixtures("enable_disable_plxpr") +class TestParitySynthIntegration: + """Integration tests for the ParitySynthPass.""" + + def test_qjit(self, run_filecheck_qjit): + """Test that the ParitySynthPass works correctly with qjit.""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @parity_synth_pass + @qml.qnode(dev) + def circuit(x: float, y: float, z: float): + # CHECK: [[phi:%.+]] = tensor.extract %arg0 + # CHECK: quantum.custom "CNOT"() + # CHECK: quantum.custom "RZ"([[phi]]) + # CHECK: quantum.custom "CNOT"() + # CHECK: [[omega:%.+]] = tensor.extract %arg1 + # CHECK: quantum.custom "RX"([[omega]]) + # CHECK: [[theta:%.+]] = tensor.extract %arg2 + # CHECK: quantum.custom "RZ"([[theta]]) + # CHECK-NOT: quantum.custom + qml.CNOT((0, 1)) + qml.RZ(x, 1) + qml.CNOT((0, 1)) + qml.RX(y, 1) + qml.CNOT((1, 0)) + qml.RZ(z, 1) + qml.CNOT((1, 0)) + return qml.state() + + run_filecheck_qjit(circuit) + + +if __name__ == "__main__": + pytest.main(["-x", __file__])