-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support ClExpr
#176
Support ClExpr
#176
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,11 +22,13 @@ | |
SetBitsOp, | ||
CopyBitsOp, | ||
RangePredicateOp, | ||
ClExprOp, | ||
ClassicalExpBox, | ||
LogicExp, | ||
BitWiseOp, | ||
RegWiseOp, | ||
) | ||
from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar | ||
|
||
|
||
ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int] | ||
|
@@ -56,6 +58,28 @@ def apply_classical_command( | |
# Check that the value is in the range | ||
bits_dict[res_bit] = val >= op.lower and val <= op.upper | ||
|
||
elif isinstance(op, ClExprOp): | ||
# Convert bit_posn to dictionary of `ClBitVar` index to its value | ||
bitvar_val = { | ||
var_id: int(bits_dict[args[bit_pos]]) | ||
for var_id, bit_pos in op.expr.bit_posn.items() | ||
} | ||
# Convert reg_posn to dictionary of `ClRegVar` index to its value | ||
regvar_val = { | ||
var_id: from_little_endian( | ||
[bits_dict[args[bit_pos]] for bit_pos in reg_pos_list] | ||
) | ||
for var_id, reg_pos_list in op.expr.reg_posn.items() | ||
} | ||
result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val) | ||
|
||
# The result is an int in little-endian encoding. We update the | ||
# output register accordingly. | ||
for bit_pos in op.expr.output_posn: | ||
bits_dict[args[bit_pos]] = (result % 2) == 1 | ||
result = result >> 1 | ||
assert result == 0 # All bits consumed | ||
|
||
elif isinstance(op, ClassicalExpBox): | ||
the_exp = op.get_exp() | ||
result = evaluate_logic_exp(the_exp, bits_dict) | ||
|
@@ -74,6 +98,81 @@ def apply_classical_command( | |
raise NotImplementedError(f"Commands of type {op.type} are not supported.") | ||
|
||
|
||
def evaluate_clexpr( | ||
expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int] | ||
) -> int: | ||
"""Recursive evaluation of a ClExpr.""" | ||
|
||
# Evaluate arguments to operation | ||
args_val = [] | ||
for arg in expr.args: | ||
if isinstance(arg, int): | ||
value = arg | ||
elif isinstance(arg, ClBitVar): | ||
value = bitvar_val[arg.index] | ||
elif isinstance(arg, ClRegVar): | ||
value = regvar_val[arg.index] | ||
elif isinstance(arg, ClExpr): | ||
value = evaluate_clexpr(arg, bitvar_val, regvar_val) | ||
else: | ||
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.") | ||
|
||
args_val.append(value) | ||
|
||
# Apply the operation at the root of this ClExpr | ||
if expr.op in [ClOp.BitAnd, ClOp.RegAnd]: | ||
result = args_val[0] & args_val[1] | ||
elif expr.op in [ClOp.BitOr, ClOp.RegOr]: | ||
result = args_val[0] | args_val[1] | ||
elif expr.op in [ClOp.BitXor, ClOp.RegXor]: | ||
result = args_val[0] ^ args_val[1] | ||
elif expr.op in [ClOp.BitEq, ClOp.RegEq]: | ||
result = int(args_val[0] == args_val[1]) | ||
elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]: | ||
result = int(args_val[0] != args_val[1]) | ||
elif expr.op == ClOp.RegGeq: | ||
result = int(args_val[0] >= args_val[1]) | ||
elif expr.op == ClOp.RegGt: | ||
result = int(args_val[0] > args_val[1]) | ||
elif expr.op == ClOp.RegLeq: | ||
result = int(args_val[0] <= args_val[1]) | ||
elif expr.op == ClOp.RegLt: | ||
result = int(args_val[0] < args_val[1]) | ||
elif expr.op == ClOp.BitNot: | ||
result = 1 - args_val[0] | ||
# elif expr.op == ClOp.RegNot: | ||
# result = int(args_val[0] == 0) | ||
elif expr.op in [ClOp.BitZero, ClOp.RegZero]: | ||
result = 0 | ||
elif expr.op in [ClOp.BitOne, ClOp.RegOne]: | ||
result = 1 | ||
# elif expr.op == ClOp.RegAdd: | ||
# result = args_val[0] + args_val[1] | ||
# elif expr.op == ClOp.RegSub: | ||
# result = args_val[0] - args_val[1] | ||
# elif expr.op == ClOp.RegMul: | ||
# result = args_val[0] * args_val[1] | ||
# elif expr.op == ClOp.RegPow: | ||
# result = int(args_val[0] ** args_val[1]) | ||
elif expr.op == ClOp.RegRsh: | ||
result = args_val[0] >> args_val[1] | ||
# elif expr.op == ClOp.RegNeg: | ||
# result = -args_val[0] | ||
else: | ||
# TODO: Currently not supporting ClOp's RegDiv since it does not return int, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should return the integer quotient, i.e. floor(a/b) where a and b are unsigned integers. (If this doesn't fit in the result register, perhaps error is the kindest response.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But if we are leaving ADD etc unsupported for now I don't see a need to support DIV. |
||
# so I am unsure what the semantic is meant to be. | ||
# TODO: I don't now what to do with RegNot, since input | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intended semantics of |
||
# is not guaranteed to be 0 or 1. | ||
# TODO: It is not clear what to do with overflow of ADD, etc. | ||
# so I have decided to not support them for now. | ||
raise NotImplementedError( | ||
f"Evaluation of {expr.op} not supported in ClExpr ", | ||
"by pytket-cutensornet.", | ||
) | ||
|
||
return result | ||
|
||
|
||
def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int: | ||
"""Recursive evaluation of a LogicExp.""" | ||
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All of the new tests in this file are just copy-pastes of existing tests using Do you know of any tests from other repositories that could be used here? Any suggestions of where to get circuits to test on, for which the intended behaviour is known (and, hence, can be checked against)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really. There's a test here for the quantinuum local emulator, but it uses |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
reg_eq, | ||
) | ||
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp | ||
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp | ||
|
||
from pytket.extensions.cutensornet.structured_state import ( | ||
CuTensorNetHandle, | ||
|
@@ -26,6 +27,36 @@ | |
# Further down, there are tests to check that the simulation works correctly. | ||
|
||
|
||
def test_circuit_with_clexpr_i() -> None: | ||
# test conditional handling | ||
|
||
circ = Circuit(3) | ||
a = circ.add_c_register("a", 5) | ||
b = circ.add_c_register("b", 5) | ||
c = circ.add_c_register("c", 5) | ||
d = circ.add_c_register("d", 5) | ||
circ.H(0) | ||
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could write |
||
circ.add_clexpr(wexpr, args) | ||
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore | ||
circ.add_clexpr(wexpr, args) | ||
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore | ||
circ.add_clexpr(wexpr, args, condition=a[4]) | ||
circ.H(0) | ||
circ.Measure(Qubit(0), d[4]) | ||
circ.H(1) | ||
circ.Measure(Qubit(1), d[3]) | ||
circ.H(2) | ||
circ.Measure(Qubit(2), d[2]) | ||
|
||
with CuTensorNetHandle() as libhandle: | ||
cfg = Config() | ||
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) | ||
assert state.is_valid() | ||
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) | ||
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
def test_circuit_with_classicalexpbox_i() -> None: | ||
# test conditional handling | ||
|
||
|
@@ -53,6 +84,36 @@ def test_circuit_with_classicalexpbox_i() -> None: | |
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
def test_circuit_with_clexpr_ii() -> None: | ||
# test conditional handling with else case | ||
|
||
circ = Circuit(3) | ||
a = circ.add_c_register("a", 5) | ||
b = circ.add_c_register("b", 5) | ||
c = circ.add_c_register("c", 5) | ||
d = circ.add_c_register("d", 5) | ||
circ.H(0) | ||
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore | ||
circ.add_clexpr(wexpr, args) | ||
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore | ||
circ.add_clexpr(wexpr, args) | ||
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore | ||
circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) | ||
circ.H(0) | ||
circ.Measure(Qubit(0), d[4]) | ||
circ.H(1) | ||
circ.Measure(Qubit(1), d[3]) | ||
circ.H(2) | ||
circ.Measure(Qubit(2), d[2]) | ||
|
||
with CuTensorNetHandle() as libhandle: | ||
cfg = Config() | ||
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) | ||
assert state.is_valid() | ||
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) | ||
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
def test_circuit_with_classicalexpbox_ii() -> None: | ||
# test conditional handling with else case | ||
|
||
|
@@ -82,6 +143,36 @@ def test_circuit_with_classicalexpbox_ii() -> None: | |
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
@pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr") | ||
def test_circuit_with_clexpr_iii() -> None: | ||
# test complicated conditions and recursive classical op | ||
|
||
circ = Circuit(2) | ||
|
||
a = circ.add_c_register("a", 15) | ||
b = circ.add_c_register("b", 15) | ||
c = circ.add_c_register("c", 15) | ||
d = circ.add_c_register("d", 15) | ||
e = circ.add_c_register("e", 15) | ||
|
||
circ.H(0) | ||
bits = [Bit(i) for i in range(10)] | ||
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] | ||
circ.H(0, condition=big_exp) | ||
|
||
wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore | ||
circ.add_clexpr(wexpr, args) | ||
wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore | ||
circ.add_clexpr(wexpr, args) | ||
|
||
with CuTensorNetHandle() as libhandle: | ||
cfg = Config() | ||
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) | ||
assert state.is_valid() | ||
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) | ||
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
@pytest.mark.skip(reason="Currently not supporting arithmetic operations in LogicExp") | ||
def test_circuit_with_classicalexpbox_iii() -> None: | ||
# test complicated conditions and recursive classical op | ||
|
@@ -239,6 +330,32 @@ def test_pytket_qir_conditional_10() -> None: | |
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
def test_pytket_qir_conditional_11() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this name? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I grabbed these tests from the |
||
box_circ = Circuit(4) | ||
box_circ.X(0) | ||
box_circ.Y(1) | ||
box_circ.Z(2) | ||
box_circ.H(3) | ||
box_c = box_circ.add_c_register("c", 5) | ||
|
||
box_circ.H(0) | ||
|
||
wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore | ||
box_circ.add_clexpr(wexpr, args) | ||
|
||
cbox = CircBox(box_circ) | ||
d = Circuit(4, 5) | ||
a = d.add_c_register("a", 4) | ||
d.add_circbox(cbox, [0, 2, 1, 3, 0, 1, 2, 3, 4], condition=a[0]) | ||
|
||
with CuTensorNetHandle() as libhandle: | ||
cfg = Config() | ||
state = simulate(libhandle, d, SimulationAlgorithm.MPSxGate, cfg) | ||
assert state.is_valid() | ||
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) | ||
assert state.get_fidelity() == 1.0 | ||
|
||
|
||
def test_circuit_with_conditional_gate_v() -> None: | ||
# test conditional with no register | ||
|
||
|
@@ -430,7 +547,84 @@ def test_repeat_until_success_i() -> None: | |
assert np.allclose(target_state, output_state) | ||
|
||
|
||
def test_repeat_until_success_ii() -> None: | ||
def test_repeat_until_success_ii_clexpr() -> None: | ||
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074 | ||
|
||
attempts = 100 | ||
|
||
circ = Circuit() | ||
qin = circ.add_q_register("qin", 1) | ||
qaux = circ.add_q_register("aux", 2) | ||
flag = circ.add_c_register("flag", 3) | ||
circ.add_c_setbits([True, True], [flag[0], flag[1]]) # Set flag bits to 11 | ||
circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) | ||
|
||
for _ in range(attempts): | ||
wexpr, args = wired_clexpr_from_logic_exp( | ||
flag[0] | flag[1], [flag[2]] # Success if both are zero | ||
) | ||
circ.add_clexpr(wexpr, args) | ||
|
||
circ.add_gate( | ||
OpType.Reset, [qaux[0]], condition_bits=[flag[2]], condition_value=1 | ||
) | ||
circ.add_gate( | ||
OpType.Reset, [qaux[1]], condition_bits=[flag[2]], condition_value=1 | ||
) | ||
circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) | ||
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) | ||
|
||
circ.add_gate(OpType.T, [qin[0]], condition_bits=[flag[2]], condition_value=1) | ||
circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) | ||
circ.add_gate( | ||
OpType.Tdg, [qaux[0]], condition_bits=[flag[2]], condition_value=1 | ||
) | ||
circ.add_gate( | ||
OpType.CX, [qaux[1], qaux[0]], condition_bits=[flag[2]], condition_value=1 | ||
) | ||
circ.add_gate(OpType.T, [qaux[0]], condition_bits=[flag[2]], condition_value=1) | ||
circ.add_gate( | ||
OpType.CX, [qin[0], qaux[1]], condition_bits=[flag[2]], condition_value=1 | ||
) | ||
circ.add_gate(OpType.T, [qaux[1]], condition_bits=[flag[2]], condition_value=1) | ||
|
||
circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) | ||
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) | ||
circ.Measure(qaux[0], flag[0], condition_bits=[flag[2]], condition_value=1) | ||
circ.Measure(qaux[1], flag[1], condition_bits=[flag[2]], condition_value=1) | ||
|
||
# From chat with Silas and exploring the RUS as a block matrix, we have noticed | ||
# that the circuit is missing an X correction when this condition is satisfied | ||
Comment on lines
+597
to
+598
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this issue tracked somewhere? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This RUS circuit comes from a paper where they acknowledge that a "recovery operation" (i.e. correction) is required in some cases for RUS. As far as I can tell, they don't explicitly indicate what is the recovery operation required for this particular circuit (appearing in Fig 1c), but we figured it was an X correction. AFAIK this is not tracked anywhere, but is known by people with experience on RUS. |
||
wexpr, args = wired_clexpr_from_logic_exp(flag[0] ^ flag[1], [flag[2]]) | ||
circ.add_clexpr(wexpr, args) | ||
Comment on lines
+599
to
+600
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reading these tests makes me think we should add a method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, could be a handy addition, but not strictly necessary. |
||
circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) | ||
|
||
circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) | ||
|
||
with CuTensorNetHandle() as libhandle: | ||
cfg = Config() | ||
|
||
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) | ||
assert state.is_valid() | ||
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) | ||
assert state.get_fidelity() == 1.0 | ||
|
||
# All of the flag bits should have turned False | ||
assert all(not state.get_bits()[bit] for bit in flag) | ||
# The auxiliary qubits should be in state |0> | ||
prob = state.postselect({qaux[0]: 0, qaux[1]: 0}) | ||
assert np.isclose(prob, 1.0) | ||
|
||
target_state = [np.sqrt(1 / 5), np.sqrt(4 / 5) * 1j] | ||
output_state = state.get_statevector() | ||
# As indicated in the paper, the gate is implemented up to global phase | ||
global_phase = target_state[0] / output_state[0] | ||
assert np.isclose(abs(global_phase), 1.0) | ||
output_state *= global_phase | ||
assert np.allclose(target_state, output_state) | ||
|
||
|
||
def test_repeat_until_success_ii_classicalexpblox() -> None: | ||
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074 | ||
|
||
attempts = 100 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The intended semantics of
RegOne
is that every bit is set to 1.