diff --git a/docs/changelog.rst b/docs/changelog.rst index db8f665a..fee08d40 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,12 @@ Changelog ~~~~~~~~~ +Unreleased +---------- + +* Updated pytket version requirement to 1.34. +* Now supporting ``ClExpr`` operations (the new version of tket's ``ClassicalExpBox``). + 0.10.0 (October 2024) --------------------- diff --git a/pytket/extensions/cutensornet/structured_state/classical.py b/pytket/extensions/cutensornet/structured_state/classical.py index d49ac4d9..474529ab 100644 --- a/pytket/extensions/cutensornet/structured_state/classical.py +++ b/pytket/extensions/cutensornet/structured_state/classical.py @@ -22,11 +22,13 @@ SetBitsOp, CopyBitsOp, RangePredicateOp, + ClExprOp, ClassicalExpBox, LogicExp, BitWiseOp, RegWiseOp, ) +from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int] @@ -56,6 +58,40 @@ def apply_classical_command( # Check that the value is in the range bits_dict[res_bit] = val >= op.lower and val <= op.upper + elif isinstance(op, ClExprOp): + # Convert bit_posn to dictionary of `ClBitVar` index to its value + bitvar_val = { + var_id: int(bits_dict[args[bit_pos]]) + for var_id, bit_pos in op.expr.bit_posn.items() + } + # Convert reg_posn to dictionary of `ClRegVar` index to its value + regvar_val = { + var_id: from_little_endian( + [bits_dict[args[bit_pos]] for bit_pos in reg_pos_list] + ) + for var_id, reg_pos_list in op.expr.reg_posn.items() + } + # Identify number of bits on each register + regvar_size = { + var_id: len(reg_pos_list) + for var_id, reg_pos_list in op.expr.reg_posn.items() + } + # Identify number of bits in output register + output_size = len(op.expr.output_posn) + result = evaluate_clexpr( + op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size + ) + + # The result is an int in little-endian encoding. We update the + # output register accordingly. + for bit_pos in op.expr.output_posn: + bits_dict[args[bit_pos]] = (result % 2) == 1 + result = result >> 1 + # If there has been overflow in the operations, error out. + # This can be detected if `result != 0` + if result != 0: + raise ValueError("Evaluation of the ClExpr resulted in overflow.") + elif isinstance(op, ClassicalExpBox): the_exp = op.get_exp() result = evaluate_logic_exp(the_exp, bits_dict) @@ -74,6 +110,95 @@ def apply_classical_command( raise NotImplementedError(f"Commands of type {op.type} are not supported.") +def evaluate_clexpr( + expr: ClExpr, + bitvar_val: dict[int, int], + regvar_val: dict[int, int], + regvar_size: dict[int, int], + output_size: int, +) -> int: + """Recursive evaluation of a ClExpr.""" + + # Evaluate arguments to operation + args_val = [] + for arg in expr.args: + if isinstance(arg, int): + value = arg + elif isinstance(arg, ClBitVar): + value = bitvar_val[arg.index] + elif isinstance(arg, ClRegVar): + value = regvar_val[arg.index] + elif isinstance(arg, ClExpr): + value = evaluate_clexpr( + arg, bitvar_val, regvar_val, regvar_size, output_size + ) + else: + raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.") + + args_val.append(value) + + # Apply the operation at the root of this ClExpr + if expr.op in [ClOp.BitAnd, ClOp.RegAnd]: + result = args_val[0] & args_val[1] + elif expr.op in [ClOp.BitOr, ClOp.RegOr]: + result = args_val[0] | args_val[1] + elif expr.op in [ClOp.BitXor, ClOp.RegXor]: + result = args_val[0] ^ args_val[1] + elif expr.op in [ClOp.BitEq, ClOp.RegEq]: + result = int(args_val[0] == args_val[1]) + elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]: + result = int(args_val[0] != args_val[1]) + elif expr.op == ClOp.RegGeq: + result = int(args_val[0] >= args_val[1]) + elif expr.op == ClOp.RegGt: + result = int(args_val[0] > args_val[1]) + elif expr.op == ClOp.RegLeq: + result = int(args_val[0] <= args_val[1]) + elif expr.op == ClOp.RegLt: + result = int(args_val[0] < args_val[1]) + elif expr.op == ClOp.BitNot: + result = 1 - args_val[0] + elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits) + n_bits = regvar_size[expr.args[0].index] # type: ignore + result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring + elif expr.op in [ClOp.BitZero, ClOp.RegZero]: + result = 0 + elif expr.op == ClOp.BitOne: + result = 1 + elif expr.op == ClOp.RegOne: # All 1s bitstring + n_bits = output_size + result = 2**n_bits - 1 + elif expr.op == ClOp.RegAdd: + result = args_val[0] + args_val[1] + elif expr.op == ClOp.RegSub: + if args_val[0] < args_val[1]: + raise NotImplementedError( + "Currently not supporting ClOp.RegSub where the outcome is negative." + ) + result = args_val[0] - args_val[1] + elif expr.op == ClOp.RegMul: + result = args_val[0] * args_val[1] + elif expr.op == ClOp.RegDiv: # floor(a / b) + result = args_val[0] // args_val[1] + elif expr.op == ClOp.RegPow: + result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegLsh: + result = args_val[0] << args_val[1] + elif expr.op == ClOp.RegRsh: + result = args_val[0] >> args_val[1] + # elif expr.op == ClOp.RegNeg: + # result = -args_val[0] + else: + # TODO: Not supporting RegNeg because I do not know if we have agreed how to + # specify signed ints. + raise NotImplementedError( + f"Evaluation of {expr.op} not supported in ClExpr ", + "by pytket-cutensornet.", + ) + + return result + + def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int: """Recursive evaluation of a LogicExp.""" @@ -132,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int def from_little_endian(bitstring: list[bool]) -> int: """Obtain the integer from the little-endian encoded bitstring (i.e. bitstring [False, True] is interpreted as the integer 2).""" + # TODO: Assumes unisigned integer. What are the specs for signed integers? return sum(1 << i for i, b in enumerate(bitstring) if b) diff --git a/setup.py b/setup.py index 23e033b2..0f4a3927 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ license="Apache 2", packages=find_namespace_packages(include=["pytket.*"]), include_package_data=True, - install_requires=["pytket >= 1.33.0", "networkx >= 2.8.8"], + install_requires=["pytket >= 1.34.0", "networkx >= 2.8.8"], classifiers=[ "Environment :: Console", "Programming Language :: Python :: 3.10", diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index c5b6a116..3fdaa177 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -10,8 +10,12 @@ Bit, if_not_bit, reg_eq, + WiredClExpr, + ClExpr, + ClOp, ) from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp +from pytket.circuit.clexpr import wired_clexpr_from_logic_exp from pytket.extensions.cutensornet.structured_state import ( CuTensorNetHandle, @@ -26,6 +30,36 @@ # Further down, there are tests to check that the simulation works correctly. +def test_circuit_with_clexpr_i() -> None: + # test conditional handling + + circ = Circuit(3) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + circ.H(0) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) + circ.add_clexpr(wexpr, args, condition=a[4]) + circ.H(0) + circ.Measure(Qubit(0), d[4]) + circ.H(1) + circ.Measure(Qubit(1), d[3]) + circ.H(2) + circ.Measure(Qubit(2), d[2]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + def test_circuit_with_classicalexpbox_i() -> None: # test conditional handling @@ -35,9 +69,39 @@ def test_circuit_with_classicalexpbox_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register(c | b, d, condition=a[4]) # type: ignore + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=a[4]) + circ.H(0) + circ.Measure(Qubit(0), d[4]) + circ.H(1) + circ.Measure(Qubit(1), d[3]) + circ.H(2) + circ.Measure(Qubit(2), d[2]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + +def test_circuit_with_clexpr_ii() -> None: + # test conditional handling with else case + + circ = Circuit(3) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + circ.H(0) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) + circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -62,11 +126,9 @@ def test_circuit_with_classicalexpbox_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register( - c | b, d, condition=if_not_bit(a[4]) # type: ignore - ) + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -82,6 +144,36 @@ def test_circuit_with_classicalexpbox_ii() -> None: assert state.get_fidelity() == 1.0 +@pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr") +def test_circuit_with_clexpr_iii() -> None: + # test complicated conditions and recursive classical op + + circ = Circuit(2) + + a = circ.add_c_register("a", 15) + b = circ.add_c_register("b", 15) + c = circ.add_c_register("c", 15) + d = circ.add_c_register("d", 15) + e = circ.add_c_register("e", 15) + + circ.H(0) + bits = [Bit(i) for i in range(10)] + big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] + circ.H(0, condition=big_exp) + + wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c.to_list()) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e.to_list()) + circ.add_clexpr(wexpr, args) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + @pytest.mark.skip(reason="Currently not supporting arithmetic operations in LogicExp") def test_circuit_with_classicalexpbox_iii() -> None: # test complicated conditions and recursive classical op @@ -99,8 +191,8 @@ def test_circuit_with_classicalexpbox_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - circ.add_classicalexpbox_register(a + b - d, c) # type: ignore - circ.add_classicalexpbox_register(a * b * d * c, e) # type: ignore + circ.add_classicalexpbox_register(a + b - d, c.to_list()) + circ.add_classicalexpbox_register(a * b * d * c, e.to_list()) with CuTensorNetHandle() as libhandle: cfg = Config() @@ -177,7 +269,7 @@ def test_circuit_with_conditional_gate_iv() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_8() -> None: +def test_pytket_basic_conditional_i() -> None: c = Circuit(4) c.H(0) c.H(1) @@ -196,7 +288,7 @@ def test_pytket_qir_conditional_8() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_9() -> None: +def test_pytket_basic_conditional_ii() -> None: c = Circuit(4) c.X(0) c.Y(1) @@ -215,7 +307,31 @@ def test_pytket_qir_conditional_9() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_10() -> None: +def test_pytket_basic_conditional_iii_classicalexpbox() -> None: + box_circ = Circuit(4) + box_circ.X(0) + box_circ.Y(1) + box_circ.Z(2) + box_circ.H(3) + box_c = box_circ.add_c_register("c", 5) + + box_circ.H(0) + box_circ.add_classicalexpbox_register(box_c | box_c, box_c.to_list()) + + cbox = CircBox(box_circ) + d = Circuit(4, 5) + a = d.add_c_register("a", 4) + d.add_circbox(cbox, [0, 2, 1, 3, 0, 1, 2, 3, 4], condition=a[0]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, d, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + +def test_pytket_basic_conditional_iii_clexpr() -> None: box_circ = Circuit(4) box_circ.X(0) box_circ.Y(1) @@ -224,7 +340,9 @@ def test_pytket_qir_conditional_10() -> None: box_c = box_circ.add_c_register("c", 5) box_circ.H(0) - box_circ.add_classicalexpbox_register(box_c | box_c, box_c) # type: ignore + + wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c.to_list()) + box_circ.add_clexpr(wexpr, args) cbox = CircBox(box_circ) d = Circuit(4, 5) @@ -430,7 +548,84 @@ def test_repeat_until_success_i() -> None: assert np.allclose(target_state, output_state) -def test_repeat_until_success_ii() -> None: +def test_repeat_until_success_ii_clexpr() -> None: + # From Figure 1(c) of https://arxiv.org/pdf/1311.1074 + + attempts = 100 + + circ = Circuit() + qin = circ.add_q_register("qin", 1) + qaux = circ.add_q_register("aux", 2) + flag = circ.add_c_register("flag", 3) + circ.add_c_setbits([True, True], [flag[0], flag[1]]) # Set flag bits to 11 + circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) + + for _ in range(attempts): + wexpr, args = wired_clexpr_from_logic_exp( + flag[0] | flag[1], [flag[2]] # Success if both are zero + ) + circ.add_clexpr(wexpr, args) + + circ.add_gate( + OpType.Reset, [qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate( + OpType.Reset, [qaux[1]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + + circ.add_gate(OpType.T, [qin[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate( + OpType.Tdg, [qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate( + OpType.CX, [qaux[1], qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.T, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate( + OpType.CX, [qin[0], qaux[1]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.T, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + + circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + circ.Measure(qaux[0], flag[0], condition_bits=[flag[2]], condition_value=1) + circ.Measure(qaux[1], flag[1], condition_bits=[flag[2]], condition_value=1) + + # From chat with Silas and exploring the RUS as a block matrix, we have noticed + # that the circuit is missing an X correction when this condition is satisfied + wexpr, args = wired_clexpr_from_logic_exp(flag[0] ^ flag[1], [flag[2]]) + circ.add_clexpr(wexpr, args) + circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) + + circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + # All of the flag bits should have turned False + assert all(not state.get_bits()[bit] for bit in flag) + # The auxiliary qubits should be in state |0> + prob = state.postselect({qaux[0]: 0, qaux[1]: 0}) + assert np.isclose(prob, 1.0) + + target_state = [np.sqrt(1 / 5), np.sqrt(4 / 5) * 1j] + output_state = state.get_statevector() + # As indicated in the paper, the gate is implemented up to global phase + global_phase = target_state[0] / output_state[0] + assert np.isclose(abs(global_phase), 1.0) + output_state *= global_phase + assert np.allclose(target_state, output_state) + + +def test_repeat_until_success_ii_classicalexpblox() -> None: # From Figure 1(c) of https://arxiv.org/pdf/1311.1074 attempts = 100 @@ -503,3 +698,41 @@ def test_repeat_until_success_ii() -> None: assert np.isclose(abs(global_phase), 1.0) output_state *= global_phase assert np.allclose(target_state, output_state) + + +def test_clexpr_on_regs() -> None: + """Non-exhaustive test on some ClOp on registers.""" + circ = Circuit(2) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + e = circ.add_c_register("e", 5) + + w_expr_regone = WiredClExpr(ClExpr(ClOp.RegOne, []), output_posn=list(range(5))) + circ.add_clexpr(w_expr_regone, a.to_list()) # a = 0b11111 = 31 + circ.add_c_setbits([True, True, False, False, False], b.to_list()) # b = 3 + circ.add_c_setbits([False, True, False, True, False], c.to_list()) # c = 10 + circ.add_clexpr(*wired_clexpr_from_logic_exp(b | c, d.to_list())) # d = 11 + circ.add_clexpr(*wired_clexpr_from_logic_exp(a - d, e.to_list())) # e = 20 + + with CuTensorNetHandle() as libhandle: + cfg = Config() + + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + # Check the bits + bits_dict = state.get_bits() + a_bitstring = list(bits_dict[bit] for bit in a) + assert all(a_bitstring) # a = 0b11111 + b_bitstring = list(bits_dict[bit] for bit in b) + assert b_bitstring == [True, True, False, False, False] # b = 0b11000 + c_bitstring = list(bits_dict[bit] for bit in c) + assert c_bitstring == [False, True, False, True, False] # c = 0b01010 + d_bitstring = list(bits_dict[bit] for bit in d) + assert d_bitstring == [True, True, False, True, False] # d = 0b11010 + e_bitstring = list(bits_dict[bit] for bit in e) + assert e_bitstring == [False, False, True, False, True] # e = 0b00101