Skip to content

Commit

Permalink
Deconstruct classical conditions in terms of ClExprOp instead of `C…
Browse files Browse the repository at this point in the history
…lassicalExpBox` (#1657)
  • Loading branch information
cqc-alec authored Nov 6, 2024
1 parent 629c975 commit 71d76a4
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 196 deletions.
4 changes: 2 additions & 2 deletions pytket/binders/circuit/clexpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace tket {

static std::string qasm_bit_repr(
const ClExprTerm &term, const std::map<int, Bit> &input_bits) {
if (const int *n = std::get_if<int>(&term)) {
if (const uint64_t *n = std::get_if<uint64_t>(&term)) {
switch (*n) {
case 0:
return "0";
Expand All @@ -56,7 +56,7 @@ static std::string qasm_bit_repr(

static std::string qasm_reg_repr(
const ClExprTerm &term, const std::map<int, BitRegister> &input_regs) {
if (const int *n = std::get_if<int>(&term)) {
if (const uint64_t *n = std::get_if<uint64_t>(&term)) {
std::stringstream ss;
ss << *n;
return ss.str();
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def requirements(self):
self.requires("pybind11_json/0.2.14")
self.requires("symengine/0.12.0")
self.requires("tkassert/0.3.4@tket/stable")
self.requires("tket/1.3.37@tket/stable")
self.requires("tket/1.3.38@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tktokenswap/0.3.9@tket/stable")
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Features:

* Add `clexpr.check_register_alignments()` method to check register alignments
in `ClExprOp`.
* Use `ClExprOp` instead of `ClassicalExpBox` when deconstructing complex
conditions.

Fixes:

Expand Down
17 changes: 8 additions & 9 deletions pytket/pytket/circuit/add_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
from typing import Tuple, Union

from pytket.circuit import Bit, Circuit, BitRegister
from pytket._tket.unit_id import (
_TEMP_REG_SIZE,
_TEMP_BIT_NAME,
_TEMP_BIT_REG_BASE,
)
from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp
from pytket.circuit.logic_exp import (
BitLogicExp,
Constant,
Expand Down Expand Up @@ -79,7 +76,8 @@ def _add_condition(
circ.add_bit(condition_bit)

if isinstance(pred_exp, BitLogicExp):
circ.add_classicalexpbox_bit(pred_exp, [condition_bit])
wexpr, args = wired_clexpr_from_logic_exp(pred_exp, [condition_bit])
circ.add_clexpr(wexpr, args)
return condition_bit, bool(pred_val)

assert isinstance(pred_exp, (RegLogicExp, BitRegister))
Expand All @@ -99,10 +97,11 @@ def _add_condition(
int(r_name.split("_")[-1]) for r_name in existing_reg_names
)
next_index = max(existing_reg_indices, default=-1) + 1
temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", _TEMP_REG_SIZE)
temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", min_reg_size)
circ.add_c_register(temp_reg)
target_bits = temp_reg.to_list()[:min_reg_size]
circ.add_classicalexpbox_register(pred_exp, target_bits)
target_bits = temp_reg.to_list()
wexpr, args = wired_clexpr_from_logic_exp(pred_exp, target_bits)
circ.add_clexpr(wexpr, args)
elif isinstance(pred_exp, BitRegister):
target_bits = pred_exp.to_list()

Expand Down
7 changes: 7 additions & 0 deletions pytket/pytket/circuit/decompose_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,5 +358,12 @@ def _decompose_expressions(circ: Circuit) -> Tuple[Circuit, bool]:
# add_gate doesn't work for metaops
newcirc.add_barrier(args)
else:
for arg in args:
if (
isinstance(arg, Bit)
and arg.reg_name != "_w" # workaround: this shouldn't be type Bit
and arg not in newcirc.bits
):
newcirc.add_bit(arg)
newcirc.add_gate(op, args, **kwargs)
return newcirc, modified
174 changes: 1 addition & 173 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def test_regpredicate(condition: PredicateExp) -> None:
circ.add_bit(inp, reject_dups=False)

circ.X(qb, condition=condition)
assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 1
assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 0
newcirc = circ.copy()
DecomposeClassicalExp().apply(newcirc)

Expand Down Expand Up @@ -1115,178 +1115,6 @@ def check_serialization_roundtrip(circ: Circuit) -> None:
assert circ_from_dict.to_dict() == circ_dict


def test_decomposition_known() -> None:
bits = [Bit(i) for i in range(10)]
registers = [BitRegister(c, 3) for c in "abdefghijk"]

qreg = QubitRegister("q_", 10)
circ = Circuit()
conditioned_circ = Circuit()
decomposed_circ = Circuit()

for c in (circ, conditioned_circ, decomposed_circ):
for b in bits:
c.add_bit(b)
for br in registers:
for b in br.to_list():
c.add_bit(b, reject_dups=False)
c.add_q_register(qreg.name, qreg.size)

circ.H(qreg[0], condition=bits[0])
circ.X(qreg[0], condition=if_bit(bits[1]))
circ.S(qreg[0])
circ.T(qreg[1], condition=if_not_bit(bits[2]))
circ.Z(qreg[0], condition=(bits[2] & bits[3]))
circ.Z(qreg[1], condition=if_not_bit(bits[3] & bits[4]))
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
# ^ no need for parantheses as python operator precedence
# will enforce correct precedence in LogicExp
circ.CX(qreg[0], qreg[1])
circ.CX(qreg[1], qreg[2], condition=big_exp)

circ.add_barrier(qreg.to_list())

circ.H(qreg[2], condition=reg_eq(registers[0], 3))
circ.X(qreg[3], condition=reg_lt(registers[1], 6))
circ.Y(qreg[4], condition=reg_neq(registers[2], 5))
circ.Z(qreg[5], condition=reg_gt(registers[3], 3))
circ.S(qreg[6], condition=reg_leq(registers[4], 6))
circ.T(qreg[7], condition=reg_geq(registers[5], 3))
big_reg_exp = registers[4] & registers[3] | registers[6] ^ registers[7]
circ.CX(qreg[3], qreg[4], condition=reg_eq(big_reg_exp, 3))

circ.add_classicalexpbox_bit(
bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1]
)
check_serialization_roundtrip(circ)

temp_bits = BitRegister(_TEMP_BIT_NAME, 64)

def temp_reg(i: int) -> BitRegister:
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 64)

for b in (temp_bits[i] for i in range(0, 10)):
conditioned_circ.add_bit(b)

for t_r in (temp_reg(i) for i in range(0, 1)):
conditioned_circ.add_c_register(t_r.name, t_r.size)

# relies on existing interface for adding conditionals
# may need a more low level interface for that if we decide to get rid of it
conditioned_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
conditioned_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
conditioned_circ.S(qreg[0])
conditioned_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0)

conditioned_circ.add_classicalexpbox_bit((bits[2] & bits[3]), [temp_bits[0]])
conditioned_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1)
conditioned_circ.add_classicalexpbox_bit((bits[3] & bits[4]), [temp_bits[1]])
conditioned_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0)
conditioned_circ.CX(qreg[0], qreg[1])
conditioned_circ.add_classicalexpbox_bit(big_exp, [temp_bits[2]])
conditioned_circ.CX(
qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1
)

conditioned_circ.add_barrier(qreg.to_list())

registers_lists = [reg.to_list() for reg in registers]

conditioned_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3])
conditioned_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
conditioned_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1)
conditioned_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
conditioned_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
conditioned_circ.add_c_range_predicate(
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
conditioned_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
conditioned_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
conditioned_circ.add_c_range_predicate(
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)
conditioned_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

temp_reg_bits = [temp_reg(0)[i] for i in range(3)]
conditioned_circ.add_classicalexpbox_register(big_reg_exp, temp_reg_bits)
conditioned_circ.add_c_range_predicate(3, 3, temp_reg_bits, temp_bits[9])
conditioned_circ.CX(
qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1
)
conditioned_circ.add_classicalexpbox_bit(
bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1]
)

assert compare_commands_box(circ, conditioned_circ)

for b in (temp_bits[i] for i in range(0, 12)):
decomposed_circ.add_bit(b)

decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_0", 3))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 64))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_2", 64))

decomposed_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
decomposed_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
decomposed_circ.S(qreg[0])
decomposed_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0)
decomposed_circ.add_c_and(bits[2], bits[3], temp_bits[0])
decomposed_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1)
decomposed_circ.add_c_and(bits[3], bits[4], temp_bits[1])
decomposed_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0)
decomposed_circ.CX(qreg[0], qreg[1])
decomposed_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3])
decomposed_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
decomposed_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
decomposed_circ.add_c_range_predicate(
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
decomposed_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
decomposed_circ.add_c_range_predicate(
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)

decomposed_circ.add_c_xor(bits[5], bits[6], temp_bits[10])
decomposed_circ.add_c_and(bits[7], bits[8], temp_bits[11])
decomposed_circ.add_c_or(bits[4], temp_bits[10], temp_bits[10])
decomposed_circ.add_c_or(temp_bits[10], temp_bits[11], temp_bits[2])
decomposed_circ.CX(
qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1
)

decomposed_circ.add_barrier(qreg.to_list())

decomposed_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1)
decomposed_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1)
decomposed_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
decomposed_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
decomposed_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
decomposed_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

decomposed_circ.add_c_and_to_registers(registers[4], registers[3], temp_reg(1))
decomposed_circ.add_c_xor_to_registers(registers[6], registers[7], temp_reg(2))
decomposed_circ.add_c_or_to_registers(
temp_reg(1), BitRegister(temp_reg(2).name, 3), temp_reg(0)
)
decomposed_circ.add_c_range_predicate(3, 3, temp_reg(0).to_list()[:3], temp_bits[9])
decomposed_circ.CX(
qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1
)
decomposed_circ.add_c_and(
bits[5], bits[3], temp_bits[10], condition_bits=[bits[1]], condition_value=1
)
decomposed_circ.add_c_or(
bits[4], temp_bits[10], bits[0], condition_bits=[bits[1]], condition_value=1
)
check_serialization_roundtrip(decomposed_circ)
circ_copy = circ.copy()

DecomposeClassicalExp().apply(circ_copy)
assert circ_copy == decomposed_circ


def test_conditional() -> None:
c = Circuit(1, 2)
c.H(0, condition_bits=[0, 1], condition_value=3)
Expand Down
2 changes: 1 addition & 1 deletion pytket/tests/compilation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_resize_scratch_registers() -> None:
reg_a = circ.add_c_register("a", 1)
reg_b = circ.add_c_register("b", 1)
circ.X(0, condition=reg_eq(reg_a ^ reg_b, 1))
assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 64
assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 1
c_compiled = circ.copy()
scratch_reg_resize_pass(10).apply(c_compiled)
assert circ == c_compiled
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.37"
version = "1.3.38"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
3 changes: 2 additions & 1 deletion tket/include/tket/Ops/ClExpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* @brief Classical expressions involving bits and registers
*/

#include <cstdint>
#include <map>
#include <nlohmann/detail/macro_scope.hpp>
#include <ostream>
Expand Down Expand Up @@ -124,7 +125,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var);
/**
* A term in a classical expression (either a constant or a variable)
*/
typedef std::variant<int, ClExprVar> ClExprTerm;
typedef std::variant<uint64_t, ClExprVar> ClExprTerm;

std::ostream& operator<<(std::ostream& os, const ClExprTerm& term);

Expand Down
7 changes: 4 additions & 3 deletions tket/src/Ops/ClExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tket/Ops/ClExpr.hpp"

#include <algorithm>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -131,7 +132,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var) {
}

std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) {
if (const int* n = std::get_if<int>(&term)) {
if (const uint64_t* n = std::get_if<uint64_t>(&term)) {
return os << *n;
} else {
ClExprVar var = std::get<ClExprVar>(term);
Expand All @@ -141,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) {

void to_json(nlohmann::json& j, const ClExprTerm& term) {
nlohmann::json inner_j;
if (const int* n = std::get_if<int>(&term)) {
if (const uint64_t* n = std::get_if<uint64_t>(&term)) {
j["type"] = "int";
inner_j = *n;
} else {
Expand All @@ -155,7 +156,7 @@ void to_json(nlohmann::json& j, const ClExprTerm& term) {
void from_json(const nlohmann::json& j, ClExprTerm& term) {
const std::string termtype = j.at("type").get<std::string>();
if (termtype == "int") {
term = j.at("term").get<int>();
term = j.at("term").get<uint64_t>();
} else {
TKET_ASSERT(termtype == "var");
term = j.at("term").get<ClExprVar>();
Expand Down
11 changes: 6 additions & 5 deletions tket/test/src/test_ClExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <catch2/catch_test_macros.hpp>
#include <cstdint>
#include <memory>
#include <nlohmann/json_fwd.hpp>
#include <sstream>
Expand Down Expand Up @@ -127,7 +128,7 @@ SCENARIO("Serialization and stringification") {
REQUIRE(var_reg1 == var_reg);
}
GIVEN("ClExprTerm") {
ClExprTerm term_int = 7;
ClExprTerm term_int = uint64_t{7};
ClExprTerm term_var = ClRegVar{5};
std::stringstream ss;
ss << term_int << ", " << term_var;
Expand All @@ -140,14 +141,14 @@ SCENARIO("Serialization and stringification") {
REQUIRE(term_var1 == term_var);
}
GIVEN("Vector of ClExprArg (1)") {
std::vector<ClExprArg> args{ClRegVar{2}, int{3}};
std::vector<ClExprArg> args{ClRegVar{2}, uint64_t{3}};
nlohmann::json j = args;
std::vector<ClExprArg> args1 = j.get<std::vector<ClExprArg>>();
REQUIRE(args == args1);
}
GIVEN("ClExpr (1)") {
// r0 + 7
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{7}});
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{7}});
std::stringstream ss;
ss << expr;
REQUIRE(ss.str() == "add(r0, 7)");
Expand All @@ -156,7 +157,7 @@ SCENARIO("Serialization and stringification") {
REQUIRE(expr1 == expr);
}
GIVEN("Vector of ClExprArg (2)") {
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{8}});
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{8}});
std::vector<ClExprArg> args{expr};
nlohmann::json j = args;
std::vector<ClExprArg> args1 = j.get<std::vector<ClExprArg>>();
Expand All @@ -165,7 +166,7 @@ SCENARIO("Serialization and stringification") {
GIVEN("ClExpr (2)") {
// (r0 + r1) / (r2 * 3)
ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}});
ClExpr denom(ClOp::RegMul, {ClRegVar{2}, int{3}});
ClExpr denom(ClOp::RegMul, {ClRegVar{2}, uint64_t{3}});
ClExpr expr(ClOp::RegDiv, {numer, denom});
std::stringstream ss;
ss << expr;
Expand Down

0 comments on commit 71d76a4

Please sign in to comment.