Skip to content

Commit

Permalink
Add an optional CompilationTargetGateset postprocessor to contract th…
Browse files Browse the repository at this point in the history
…e circuit (quantumlib#6433)
  • Loading branch information
NoureldinYosri authored Feb 3, 2024
1 parent 039361d commit 0b3f936
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 19 deletions.
56 changes: 39 additions & 17 deletions cirq/transformers/optimize_for_target_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Transformers to rewrite a circuit using gates from a given target gateset."""

from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING, Union

from cirq import circuits
from cirq.protocols import decompose_protocol as dp
Expand Down Expand Up @@ -102,19 +102,29 @@ def optimize_for_target_gateset(
context: Optional['cirq.TransformerContext'] = None,
gateset: Optional['cirq.CompilationTargetGateset'] = None,
ignore_failures: bool = True,
max_num_passes: Union[int, None] = 1,
) -> 'cirq.Circuit':
"""Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`.
Repeat max_num_passes times or when `max_num_passes=None` until no further changes can be done
1. Run all `gateset.preprocess_transformers`
2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`.
3. Run all `gateset.postprocess_transformers`
Note:
The optimizer is a heuristic and may not produce optimal results even with
max_num_passes=None. The preprocessors and postprocessors of the gate set
as well as their order yield different results.
Args:
circuit: Input circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options for transformers.
gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`.
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
conversion failures raise a ValueError.
max_num_passes: The maximum number of passes to do. A value of `None` means to keep
iterating until no more changes happen to the number of moments or operations.
Returns:
An equivalent circuit containing gates accepted by `gateset`.
Expand All @@ -126,20 +136,32 @@ def optimize_for_target_gateset(
return _decompose_operations_to_target_gateset(
circuit, context=context, ignore_failures=ignore_failures
)

for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)

circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)

for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)

if isinstance(max_num_passes, int):
_outerloop = lambda: range(max_num_passes)
else:

def _outerloop():
while True:
yield 0

initial_num_moments, initial_num_ops = len(circuit), sum(1 for _ in circuit.all_operations())
for _ in _outerloop():
for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)
circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)
for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)

num_moments, num_ops = len(circuit), sum(1 for _ in circuit.all_operations())
if (num_moments, num_ops) == (initial_num_moments, initial_num_ops):
# Stop early. No further optimizations can be done.
break
initial_num_moments, initial_num_ops = num_moments, num_ops
return circuit.unfreeze(copy=False)
150 changes: 150 additions & 0 deletions cirq/transformers/optimize_for_target_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset
Expand Down Expand Up @@ -243,3 +245,151 @@ def test_optimize_for_target_gateset_deep():
1: ───#2───────────────────────────────────────────────────────────────────────────
''',
)


@pytest.mark.parametrize('max_num_passes', [2, None])
def test_optimize_for_target_gateset_multiple_passes(max_num_passes: Union[int, None]):
gateset = cirq.CZTargetGateset()

input_circuit = cirq.Circuit(
[
cirq.Moment(
cirq.X(cirq.LineQubit(1)),
cirq.X(cirq.LineQubit(2)),
cirq.X(cirq.LineQubit(3)),
cirq.X(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
cirq.H(cirq.LineQubit(5)),
cirq.H(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
),
]
)
desired_circuit = cirq.Circuit.from_moments(
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(4)
)
),
cirq.Moment(cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5))),
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(1)
),
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(0)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(3)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(2)
),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
),
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(6)
)
),
cirq.Moment(cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5))),
)
got = cirq.optimize_for_target_gateset(
input_circuit, gateset=gateset, max_num_passes=max_num_passes
)
cirq.testing.assert_same_circuits(got, desired_circuit)


@pytest.mark.parametrize('max_num_passes', [2, None])
def test_optimize_for_target_gateset_multiple_passes_dont_preserve_moment_structure(
max_num_passes: Union[int, None]
):
gateset = cirq.CZTargetGateset(preserve_moment_structure=False)

input_circuit = cirq.Circuit(
[
cirq.Moment(
cirq.X(cirq.LineQubit(1)),
cirq.X(cirq.LineQubit(2)),
cirq.X(cirq.LineQubit(3)),
cirq.X(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
cirq.H(cirq.LineQubit(5)),
cirq.H(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
),
]
)
desired_circuit = cirq.Circuit(
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(4)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(1)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(2)
),
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(0)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(3)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(6)
),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
)
got = cirq.optimize_for_target_gateset(
input_circuit, gateset=gateset, max_num_passes=max_num_passes
)
cirq.testing.assert_same_circuits(got, desired_circuit)
28 changes: 26 additions & 2 deletions cirq/transformers/target_gatesets/compilation_target_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Base class for creating custom target gatesets which can be used for compilation."""

from typing import Optional, List, Hashable, TYPE_CHECKING
from typing import Optional, List, Hashable, TYPE_CHECKING, Union, Type
import abc

from cirq import circuits, ops, protocols, transformers
Expand Down Expand Up @@ -80,6 +80,27 @@ class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta):
which can transform any given circuit to contain gates accepted by this gateset.
"""

def __init__(
self,
*gates: Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily'],
name: Optional[str] = None,
unroll_circuit_op: bool = True,
preserve_moment_structure: bool = True,
):
"""Initializes CompilationTargetGateset.
Args:
*gates: A list of `cirq.Gate` subclasses / `cirq.Gate` instances /
`cirq.GateFamily` instances to initialize the Gateset.
name: (Optional) Name for the Gateset. Useful for description.
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
validated by validating the underlying `cirq.Circuit`.
preserve_moment_structure: Whether to preserve the moment structure of the
circuit during compilation or not.
"""
super().__init__(*gates, name=name, unroll_circuit_op=unroll_circuit_op)
self._preserve_moment_structure = preserve_moment_structure

@property
@abc.abstractmethod
def num_qubits(self) -> int:
Expand Down Expand Up @@ -140,11 +161,14 @@ def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
@property
def postprocess_transformers(self) -> List['cirq.TRANSFORMER']:
"""List of transformers which should be run after decomposing individual operations."""
return [
processors: List['cirq.TRANSFORMER'] = [
merge_single_qubit_gates.merge_single_qubit_moments_to_phxz,
transformers.drop_negligible_operations,
transformers.drop_empty_moments,
]
if not self._preserve_moment_structure:
processors.append(transformers.stratified_circuit)
return processors


class TwoQubitCompilationTargetGateset(CompilationTargetGateset):
Expand Down
4 changes: 4 additions & 0 deletions cirq/transformers/target_gatesets/cz_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
atol: float = 1e-8,
allow_partial_czs: bool = False,
additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (),
preserve_moment_structure: bool = True,
) -> None:
"""Initializes CZTargetGateset
Expand All @@ -57,6 +58,8 @@ def __init__(
`cirq.CZ`, are part of this gateset.
additional_gates: Sequence of additional gates / gate families which should also
be "accepted" by this gateset. This is empty by default.
preserve_moment_structure: Whether to preserve the moment structure of the
circuit during compilation or not.
"""
super().__init__(
ops.CZPowGate if allow_partial_czs else ops.CZ,
Expand All @@ -65,6 +68,7 @@ def __init__(
ops.GlobalPhaseGate,
*additional_gates,
name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset',
preserve_moment_structure=preserve_moment_structure,
)
self.additional_gates = tuple(
g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates
Expand Down

0 comments on commit 0b3f936

Please sign in to comment.