Skip to content

Commit

Permalink
change: Clean up GateConnectivityCriterion instantiation logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
ltnln committed Jul 18, 2024
1 parent 0601c63 commit c5b9fc3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,43 @@ def __init__(
directed=True,
):
super().__init__()
if isinstance(gate_connectivity_graph, DiGraph):
self._gate_connectivity_graph = gate_connectivity_graph
if not directed:
for u, v in self._gate_connectivity_graph.edges:
back_edge = (v, u)
if back_edge not in self._gate_connectivity_graph.edges:
supported_gates = self._gate_connectivity_graph[u][v]["supported_gates"]
self._gate_connectivity_graph.add_edge(
*back_edge, supported_gates=supported_gates
)
else:
# check that the supported gate sets are identical
if (
self._gate_connectivity_graph[u][v]["supported_gates"]
!= self._gate_connectivity_graph[v][u]["supported_gates"]
):
raise ValueError(
f"Connectivity Graph marked as undirected\
but edges ({u}, {v}) and ({v}, {u}) have different supported\
gate sets."
)

elif isinstance(gate_connectivity_graph, dict):
if isinstance(gate_connectivity_graph, dict):
self._gate_connectivity_graph = DiGraph()
for edge, supported_gates in gate_connectivity_graph.items():
self._gate_connectivity_graph.add_edge(
edge[0], edge[1], supported_gates=supported_gates
)
if not directed:
back_edge = (edge[1], edge[0])
if back_edge not in gate_connectivity_graph:
self._gate_connectivity_graph.add_edge(
edge[1], edge[0], supported_gates=supported_gates
)
for (u, v), supported_gates in gate_connectivity_graph.items():
self._gate_connectivity_graph.add_edge(u, v, supported_gates=supported_gates)
elif isinstance(gate_connectivity_graph, DiGraph):
self._gate_connectivity_graph = gate_connectivity_graph
else:
raise TypeError(
"Gate_connectivity_graph must either be a dictionary of edges mapped to \
supported gates lists, or a DiGraph with supported gates \
provided as edge attributes."
)

if not directed:
"""
Add reverse edges and check that any supplied reverse edges have
identical supported gate sets to their corresponding forwards edge.
"""
for u, v in self._gate_connectivity_graph.edges:
back_edge = (v, u)
if back_edge not in self._gate_connectivity_graph.edges:
supported_gates = self._gate_connectivity_graph[u][v]["supported_gates"]
self._gate_connectivity_graph.add_edge(
*back_edge, supported_gates=supported_gates
)
else:
# check that the supported gate sets are identical
if (
self._gate_connectivity_graph[u][v]["supported_gates"]
!= self._gate_connectivity_graph[v][u]["supported_gates"]
):
raise ValueError(
f"Connectivity Graph marked as undirected\
but edges ({u}, {v}) and ({v}, {u}) have different supported\
gate sets."
)

def validate(self, circuit: Circuit) -> None:
"""
Verifies that any multiqubit gates used within a verbatim box are supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_undirected_graph_construction_from_dict():
"""
dict_representation = {
(0, 1): ["CNot", "CZ"],
(1, 0): ["CZ", "XX"],
(1, 0): ["CNot", "CZ"],
(1, 2): ["Swap", "CNot", "YY"],
(0, 2): ["XX", "XY", "CNot", "CZ"],
(2, 5): ["XX", "XY", "CNot", "CZ"],
Expand All @@ -167,7 +167,7 @@ def test_undirected_graph_construction_from_dict():
(1, 2, {"supported_gates": ["Swap", "CNot", "YY"]}),
(0, 2, {"supported_gates": ["XX", "XY", "CNot", "CZ"]}),
(2, 5, {"supported_gates": ["XX", "XY", "CNot", "CZ"]}),
(1, 0, {"supported_gates": ["CZ", "XX"]}),
(1, 0, {"supported_gates": ["CNot", "CZ"]}),
(2, 1, {"supported_gates": ["Swap", "CNot", "YY"]}),
(2, 0, {"supported_gates": ["XX", "XY", "CNot", "CZ"]}),
(5, 2, {"supported_gates": ["XX", "XY", "CNot", "CZ"]}),
Expand Down Expand Up @@ -289,3 +289,24 @@ def test_validate_instruction_method(gate_name, controls, targets, is_valid, bas
else:
with pytest.raises(ValueError):
gcc.validate_instruction_connectivity(gate_name, controls, targets)


@pytest.mark.parametrize(
"graph",
[
(
nx.from_dict_of_dicts(
{
0: {1: {"supported_gates": ["cnot", "cz"]}},
1: {0: {"supported_gates": ["cz", "cnot", "xx"]}},
},
create_using=nx.DiGraph(),
)
),
({(0, 1): ["cnot", "cz"], (1, 0): ["cz", "cnot", "xx"]}),
({(0, 1): ["xx", "yy"], (1, 0): ["yy", "xx"]}),
],
)
def test_invalid_undirected_graph(graph):
with pytest.raises(ValueError):
GateConnectivityCriterion(graph, directed=False)

0 comments on commit c5b9fc3

Please sign in to comment.