Skip to content

Commit

Permalink
Add custom_deserialisation argument to BasePass.from_dict (#1647)
Browse files Browse the repository at this point in the history
* replace nlohmann json with custom deserialise

* bump

* add extra handling for serialisation and `SequencePass`

* add serialise/deserialise methods

* add dict argument

* Update test_json.cpp

* Update test_json.cpp

* Update passes.pyi

* add optional argument

* add test

* Update passes.cpp

* Update passes_serialisation_test.py

* fix sequence pass serialisation and stub generation

* bump

* bump

* Update tket/src/Predicates/CompilerPass.cpp

Co-authored-by: Alec Edgington <[email protected]>

* update changelog and docstring

* Update CompilerPass.cpp

* remove outdated code, update stubs

* Update PassGenerators.cpp

---------

Co-authored-by: Alec Edgington <[email protected]>
  • Loading branch information
sjdilkes and cqc-alec authored Nov 7, 2024
1 parent 5624c72 commit a9add4b
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 45 deletions.
31 changes: 26 additions & 5 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,23 +265,35 @@ PYBIND11_MODULE(passes, m) {
.def(
"to_dict",
[](const BasePass &base_pass) {
return py::object(base_pass.get_config()).cast<py::dict>();
return py::cast(serialise(base_pass));
},
":return: A JSON serializable dictionary representation of the Pass.")
.def_static(
"from_dict",
[](const py::dict &base_pass_dict) {
return json(base_pass_dict).get<PassPtr>();
[](const py::dict &base_pass_dict,

std::map<std::string, std::function<Circuit(const Circuit &)>>
&custom_deserialisation) {
return deserialise(base_pass_dict, custom_deserialisation);
},
"Construct a new Pass instance from a JSON serializable dictionary "
"representation.")
"representation. `custom_deserialisation` is a map between "
"`CustomPass` "
"label attributes and a Circuit to Circuit function matching the "
"`CustomPass` `transform` argument. This allows the construction of "
"some `CustomPass` from JSON. `CustomPass` without a matching entry "
"in "
"`custom_deserialisation` will be rejected.",
py::arg("base_pass_dict"),
py::arg("custom_deserialisation") =
std::map<std::string, std::function<Circuit(const Circuit &)>>{})
.def(py::pickle(
[](py::object self) { // __getstate__
return py::make_tuple(self.attr("to_dict")());
},
[](const py::tuple &t) { // __setstate__
const json j = t[0].cast<json>();
return j.get<PassPtr>();
return deserialise(j);
}));
py::class_<SequencePass, std::shared_ptr<SequencePass>, BasePass>(
m, "SequencePass", "A sequence of compilation passes.")
Expand All @@ -296,9 +308,18 @@ PYBIND11_MODULE(passes, m) {
"\n:return: a pass that applies the sequence",
py::arg("pass_list"), py::arg("strict") = true)
.def("__str__", [](const BasePass &) { return "<tket::SequencePass>"; })
.def(
"to_dict",
[](const SequencePass &seq_pass) {
return py::cast(
serialise(std::make_shared<SequencePass>(seq_pass)));
},
":return: A JSON serializable dictionary representation of the "
"SequencePass.")
.def(
"get_sequence", &SequencePass::get_sequence,
":return: The underlying sequence of passes.");

py::class_<RepeatPass, std::shared_ptr<RepeatPass>, BasePass>(
m, "RepeatPass",
"Repeat a pass until its `apply()` method returns False, or if "
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def requirements(self):
self.requires("pybind11_json/0.2.14")
self.requires("symengine/0.12.0")
self.requires("tkassert/0.3.4@tket/stable")
self.requires("tket/1.3.38@tket/stable")
self.requires("tket/1.3.39@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tktokenswap/0.3.9@tket/stable")
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Features:
in `ClExprOp`.
* Use `ClExprOp` instead of `ClassicalExpBox` when deconstructing complex
conditions.
* Add `custom_deserialisation` argument to `BasePass` and `SequencePass`
`from_dict` method to support construction of `CustomPass` from json.

Fixes:

Expand Down
10 changes: 7 additions & 3 deletions pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class BasePass:
def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore
...
@staticmethod
def from_dict(arg0: dict) -> BasePass:
def from_dict(base_pass_dict: dict, custom_deserialisation: dict[str, typing.Callable[[pytket._tket.circuit.Circuit], pytket._tket.circuit.Circuit]] = {}) -> BasePass:
"""
Construct a new Pass instance from a JSON serializable dictionary representation.
Construct a new Pass instance from a JSON serializable dictionary representation. `custom_deserialisation` is a map between `CustomPass` label attributes and a Circuit to Circuit function matching the `CustomPass` `transform` argument. This allows the construction of some `CustomPass` from JSON. `CustomPass` without a matching entry in `custom_deserialisation` will be rejected.
"""
def __getstate__(self) -> tuple:
...
Expand Down Expand Up @@ -54,7 +54,7 @@ class BasePass:
:param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback.
:return: True if pass modified the circuit, else False
"""
def to_dict(self) -> dict:
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the Pass.
"""
Expand Down Expand Up @@ -227,6 +227,10 @@ class SequencePass(BasePass):
"""
:return: The underlying sequence of passes.
"""
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the SequencePass.
"""
def AASRouting(arc: pytket._tket.architecture.Architecture, **kwargs: Any) -> BasePass:
"""
Construct a pass to relabel :py:class:`Circuit` Qubits to :py:class:`Device` Nodes, and then use architecture-aware synthesis to route the circuit. In the steps of the pass the circuit will be converted to CX, Rz, H gateset. The limited connectivity of the :py:class:`Architecture` is used for the routing. The direction of the edges is ignored. The placement used is GraphPlacement. This pass can take a few parameters for the routing, described below:
Expand Down
13 changes: 13 additions & 0 deletions pytket/tests/passes_serialisation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DefaultMappingPass,
AASRouting,
SquashCustom,
CustomPass,
)
from pytket.mapping import (
LexiLabellingMethod,
Expand Down Expand Up @@ -771,3 +772,15 @@ def no_CX(circ: Circuit) -> bool:
rps.to_dict()["RepeatUntilSatisfiedPass"]["predicate"]["type"]
== "UserDefinedPredicate"
)


def test_custom_deserialisation() -> None:
def t(c: Circuit) -> Circuit:
return Circuit(2).CX(0, 1)

custom_pass_post = BasePass.from_dict(
CustomPass(t, label="test").to_dict(), {"test": t}
)
c: Circuit = Circuit(3).H(0).H(1).H(2)
custom_pass_post.apply(c)
assert c == Circuit(2).CX(0, 1)
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.38"
version = "1.3.39"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
10 changes: 8 additions & 2 deletions tket/include/tket/Predicates/CompilerPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ typedef std::pair<PredicatePtrMap, PostConditions> PassConditions;
typedef std::function<void(const CompilationUnit&, const nlohmann::json&)>
PassCallback;

JSON_DECL(PassPtr)

class IncompatibleCompilerPasses : public std::logic_error {
public:
explicit IncompatibleCompilerPasses(const std::type_index& typeid1)
Expand Down Expand Up @@ -301,6 +299,14 @@ class RepeatUntilSatisfiedPass : public BasePass {
PredicatePtr pred_;
};

nlohmann::json serialise(const BasePass& bp);
nlohmann::json serialise(const PassPtr& pp);
nlohmann::json serialise(const std::vector<PassPtr>& pp);

PassPtr deserialise(
const nlohmann::json& j,
const std::map<std::string, std::function<Circuit(const Circuit&)>>&
custom_deserialise = {});
// TODO: Repeat with a metric, repeat until a Predicate is satisfied...

} // namespace tket
43 changes: 34 additions & 9 deletions tket/src/Predicates/CompilerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ std::string SequencePass::to_string() const {
nlohmann::json SequencePass::get_config() const {
nlohmann::json j;
j["pass_class"] = "SequencePass";
j["SequencePass"]["sequence"] = seq_;
j["SequencePass"]["sequence"] = serialise(seq_);
return j;
}

Expand All @@ -270,7 +270,7 @@ std::string RepeatPass::to_string() const {
nlohmann::json RepeatPass::get_config() const {
nlohmann::json j;
j["pass_class"] = "RepeatPass";
j["RepeatPass"]["body"] = pass_;
j["RepeatPass"]["body"] = serialise(pass_);
return j;
}

Expand Down Expand Up @@ -313,7 +313,7 @@ std::string RepeatWithMetricPass::to_string() const {
nlohmann::json RepeatWithMetricPass::get_config() const {
nlohmann::json j;
j["pass_class"] = "RepeatWithMetricPass";
j["RepeatWithMetricPass"]["body"] = pass_;
j["RepeatWithMetricPass"]["body"] = serialise(pass_);
j["RepeatWithMetricPass"]["metric"] =
"SERIALIZATION OF METRICS NOT YET IMPLEMENTED";
return j;
Expand Down Expand Up @@ -347,15 +347,27 @@ std::string RepeatUntilSatisfiedPass::to_string() const {
nlohmann::json RepeatUntilSatisfiedPass::get_config() const {
nlohmann::json j;
j["pass_class"] = "RepeatUntilSatisfiedPass";
j["RepeatUntilSatisfiedPass"]["body"] = pass_;
j["RepeatUntilSatisfiedPass"]["body"] = serialise(pass_);
j["RepeatUntilSatisfiedPass"]["predicate"] = pred_;
return j;
}

void to_json(nlohmann::json& j, const PassPtr& pp) { j = pp->get_config(); }
nlohmann::json serialise(const BasePass& bp) { return bp.get_config(); }
nlohmann::json serialise(const PassPtr& pp) { return pp->get_config(); }
nlohmann::json serialise(const std::vector<PassPtr>& pp) {
nlohmann::json j = nlohmann::json::array();
for (const auto& p : pp) {
j.push_back(serialise(p));
}
return j;
}

void from_json(const nlohmann::json& j, PassPtr& pp) {
PassPtr deserialise(
const nlohmann::json& j,
const std::map<std::string, std::function<Circuit(const Circuit&)>>&
custom_deserialise) {
std::string classname = j.at("pass_class").get<std::string>();
PassPtr pp;
if (classname == "StandardPass") {
const nlohmann::json& content = j.at("StandardPass");
std::string passname = content.at("name").get<std::string>();
Expand Down Expand Up @@ -576,29 +588,42 @@ void from_json(const nlohmann::json& j, PassPtr& pp) {
unsigned n = content.at("n").get<unsigned>();
bool only_zeros = content.at("only_zeros").get<bool>();
pp = RoundAngles(n, only_zeros);
} else if (passname == "CustomPass") {
std::string label = content.at("label").get<std::string>();
auto it = custom_deserialise.find(label);
if (it != custom_deserialise.end()) {
pp = CustomPass(it->second, label);
} else {
throw JsonError(
"Cannot deserialise CustomPass without passing a "
"custom_deserialisation map "
"with a key corresponding to the pass's label.");
}
} else {
throw JsonError("Cannot load StandardPass of unknown type");
}
} else if (classname == "SequencePass") {
const nlohmann::json& content = j.at("SequencePass");
std::vector<PassPtr> seq;
for (const auto& j_entry : content.at("sequence")) {
seq.push_back(j_entry.get<PassPtr>());
seq.push_back(deserialise(j_entry, custom_deserialise));
}
pp = std::make_shared<SequencePass>(seq);
} else if (classname == "RepeatPass") {
const nlohmann::json& content = j.at("RepeatPass");
pp = std::make_shared<RepeatPass>(content.at("body").get<PassPtr>());
pp = std::make_shared<RepeatPass>(
deserialise(content.at("body"), custom_deserialise));
} else if (classname == "RepeatWithMetricPass") {
throw PassNotSerializable(classname);
} else if (classname == "RepeatUntilSatisfiedPass") {
const nlohmann::json& content = j.at("RepeatUntilSatisfiedPass");
PassPtr body = content.at("body").get<PassPtr>();
PassPtr body = deserialise(content.at("body"), custom_deserialise);
PredicatePtr pred = content.at("predicate").get<PredicatePtr>();
pp = std::make_shared<RepeatUntilSatisfiedPass>(body, pred);
} else {
throw JsonError("Cannot load PassPtr of unknown type.");
}
return pp;
}

} // namespace tket
Loading

0 comments on commit a9add4b

Please sign in to comment.