Skip to content

Commit

Permalink
remove threads argument
Browse files Browse the repository at this point in the history
  • Loading branch information
sjdilkes committed Nov 18, 2024
1 parent 3d66238 commit 45d0773
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 79 deletions.
4 changes: 1 addition & 3 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,14 +965,12 @@ PYBIND11_MODULE(passes, m) {
"\n:param only_reduce: Only returns modified circuit if it has "
"fewer two-qubit gates."
"\n:param trials: Sets maximum number of found solutions."
"\n:param threads: Sets maximum number of threads used when finding "
"solutions in parallel."
"\n:return: a pass to perform the simplification",
py::arg("discount_rate") = 0.7, py::arg("depth_weight") = 0.3,
py::arg("max_lookahead") = 500, py::arg("max_tqe_candidates") = 500,
py::arg("seed") = 0, py::arg("allow_zzphase") = false,
py::arg("thread_timeout") = 100, py::arg("only_reduce") = false,
py::arg("trials") = 1, py::arg("threads") = 1);
py::arg("trials") = 1);
m.def(
"PauliSquash", &PauliSquash,
"Applies :py:meth:`PauliSimp` followed by "
Expand Down
5 changes: 1 addition & 4 deletions pytket/binders/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,11 @@ PYBIND11_MODULE(transform, m) {
"\n:param thread_timeout: Sets maximum out of time spent finding a "
"single solution in one thread."
"\n:param trials: Sets maximum number of found solutions."
"\n:param threads: Sets maximum number of threads used when finding "
"solutions in parallel."
"\n:return: a pass to perform the simplification",
py::arg("discount_rate") = 0.7, py::arg("depth_weight") = 0.3,
py::arg("max_tqe_candidates") = 500, py::arg("max_lookahead") = 500,
py::arg("seed") = 0, py::arg("allow_zzphase") = false,
py::arg("thread_timeout") = 100, py::arg("trials") = 1,
py::arg("threads") = 1)
py::arg("thread_timeout") = 100, py::arg("trials") = 1)
.def_static(
"ZZPhaseToRz", &Transforms::ZZPhase_to_Rz,
"Fixes all ZZPhase gate angles to [-1, 1) half turns.")
Expand Down
2 changes: 1 addition & 1 deletion pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Features:
conditions.
* Add `custom_deserialisation` argument to `BasePass` and `SequencePass`
`from_dict` method to support construction of `CustomPass` from json.
* Add `thread_timeout`, `only_reduce`, `threads` and `trials` arguments
* Add `thread_timeout`, `only_reduce`, and `trials` arguments
to `GreedyPauliSimp`.
* Add option to not relabel `ClassicalExpBox` when calling `rename_units`
and `flatten_registers`
Expand Down
1 change: 0 additions & 1 deletion pytket/tests/passes_serialisation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def nonparam_predicate_dict(name: str) -> Dict[str, Any]:
"thread_timeout": 5000,
"only_reduce": False,
"trials": 1,
"threads": 1,
}
),
# lists must be sorted by OpType value
Expand Down
2 changes: 1 addition & 1 deletion pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def test_greedy_pauli_synth() -> None:
).SWAP(regb[1], rega[0])
d = circ.copy()
assert not GreedyPauliSimp(0.5, 0.5, thread_timeout=0, only_reduce=True).apply(d)
assert GreedyPauliSimp(0.5, 0.5, thread_timeout=10, trials=5, threads=3).apply(d)
assert GreedyPauliSimp(0.5, 0.5, thread_timeout=10, trials=5).apply(d)

assert np.allclose(circ.get_unitary(), d.get_unitary())
assert d.name == "test"
Expand Down
9 changes: 2 additions & 7 deletions schemas/compiler_pass_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,6 @@
"trials": {
"type": "number",
"definition": "parameter controlling the number of random solutions found when calling \"GreedyPauliSimp\""
},
"threads": {
"type": "number",
"definition": "parameter controlling the maximum number of threads used in parallel when calling \"GreedyPauliSimp\""
}
},
"required": [
Expand Down Expand Up @@ -923,10 +919,9 @@
"allow_zzphase",
"thread_timeout",
"only_reduce",
"trials",
"threads"
"trials"
],
"maxProperties": 11
"maxProperties": 10
}
},
{
Expand Down
2 changes: 1 addition & 1 deletion tket/include/tket/Predicates/PassGenerators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ PassPtr gen_greedy_pauli_simp(
unsigned max_lookahead = 500, unsigned max_tqe_candidates = 500,
unsigned seed = 0, bool allow_zzphase = false,
unsigned thread_timeout = 100, bool only_reduce = false,
unsigned threads = 1, unsigned trials = 1);
unsigned trials = 1);

/**
* Generate a pass to simplify the circuit where it acts on known basis states.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ Transform greedy_pauli_optimisation(
double discount_rate = 0.7, double depth_weight = 0.3,
unsigned max_lookahead = 500, unsigned max_tqe_candidates = 500,
unsigned seed = 0, bool allow_zzphase = false,
unsigned thread_timeout = 100, unsigned trials = 1, unsigned threads = 1);
unsigned thread_timeout = 100, unsigned trials = 1);

} // namespace Transforms

Expand Down
3 changes: 1 addition & 2 deletions tket/src/Predicates/CompilerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,10 @@ PassPtr deserialise(
bool allow_zzphase = content.at("allow_zzphase").get<bool>();
unsigned timeout = content.at("thread_timeout").get<unsigned>();
bool only_reduce = content.at("only_reduce").get<bool>();
unsigned threads = content.at("threads").get<unsigned>();
unsigned trials = content.at("trials").get<unsigned>();
pp = gen_greedy_pauli_simp(
discount_rate, depth_weight, max_lookahead, max_tqe_candidates, seed,
allow_zzphase, timeout, only_reduce, threads, trials);
allow_zzphase, timeout, only_reduce, trials);

} else if (passname == "PauliSimp") {
// SEQUENCE PASS - DESERIALIZABLE ONLY
Expand Down
55 changes: 26 additions & 29 deletions tket/src/Predicates/PassGenerators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,35 +1018,33 @@ PassPtr gen_synthesise_pauli_graph(
PassPtr gen_greedy_pauli_simp(
double discount_rate, double depth_weight, unsigned max_lookahead,
unsigned max_tqe_candidates, unsigned seed, bool allow_zzphase,
unsigned thread_timeout, bool only_reduce, unsigned threads,
unsigned trials) {
Transform t =
Transform([discount_rate, depth_weight, max_lookahead, max_tqe_candidates,
seed, allow_zzphase, thread_timeout, only_reduce, threads,
trials](Circuit& circ) {
Transform gpo = Transforms::greedy_pauli_optimisation(
discount_rate, depth_weight, max_lookahead, max_tqe_candidates,
seed, allow_zzphase, thread_timeout, threads, trials);
if (only_reduce) {
Circuit gpo_circ = circ;
// comparison will be inaccurate if circuit has PauliExpBox
gpo_circ.decompose_boxes_recursively();
unsigned original_n_2qb_gates = gpo_circ.count_n_qubit_gates(2);
unsigned original_n_gates = gpo_circ.n_gates();
unsigned original_depth = gpo_circ.depth();
if (gpo.apply(gpo_circ) &&
unsigned thread_timeout, bool only_reduce, unsigned trials) {
Transform t = Transform([discount_rate, depth_weight, max_lookahead,
max_tqe_candidates, seed, allow_zzphase,
thread_timeout, only_reduce, trials](Circuit& circ) {
Transform gpo = Transforms::greedy_pauli_optimisation(
discount_rate, depth_weight, max_lookahead, max_tqe_candidates, seed,
allow_zzphase, thread_timeout, trials);
if (only_reduce) {
Circuit gpo_circ = circ;
// comparison will be inaccurate if circuit has PauliExpBox
gpo_circ.decompose_boxes_recursively();
unsigned original_n_2qb_gates = gpo_circ.count_n_qubit_gates(2);
unsigned original_n_gates = gpo_circ.n_gates();
unsigned original_depth = gpo_circ.depth();
if (gpo.apply(gpo_circ) &&
std::make_tuple(
gpo_circ.count_n_qubit_gates(2), gpo_circ.n_gates(),
gpo_circ.depth()) <
std::make_tuple(
gpo_circ.count_n_qubit_gates(2), gpo_circ.n_gates(),
gpo_circ.depth()) <
std::make_tuple(
original_n_2qb_gates, original_n_gates, original_depth)) {
circ = gpo_circ;
return true;
}
return false;
}
return gpo.apply(circ);
});
original_n_2qb_gates, original_n_gates, original_depth)) {
circ = gpo_circ;
return true;
}
return false;
}
return gpo.apply(circ);
});
OpTypeSet ins = {
OpType::Z,
OpType::X,
Expand Down Expand Up @@ -1097,7 +1095,6 @@ PassPtr gen_greedy_pauli_simp(
j["allow_zzphase"] = allow_zzphase;
j["thread_timeout"] = thread_timeout;
j["only_reduce"] = only_reduce;
j["threads"] = threads;
j["trials"] = trials;

return std::make_shared<StandardPass>(precons, t, postcon, j);
Expand Down
46 changes: 17 additions & 29 deletions tket/src/Transformations/GreedyPauliOptimisation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,51 +827,39 @@ Circuit greedy_pauli_graph_synthesis(
Transform greedy_pauli_optimisation(
double discount_rate, double depth_weight, unsigned max_lookahead,
unsigned max_tqe_candidates, unsigned seed, bool allow_zzphase,
unsigned thread_timeout, unsigned trials, unsigned threads) {
unsigned thread_timeout, unsigned trials) {
return Transform([discount_rate, depth_weight, max_lookahead,
max_tqe_candidates, seed, allow_zzphase, thread_timeout,
trials, threads](Circuit& circ) {
trials](Circuit& circ) {
std::mt19937 seed_gen(seed);
std::queue<
std::pair<std::future<Circuit>, std::shared_ptr<std::atomic<bool>>>>
all_threads;
std::vector<Circuit> circuits;
unsigned max_threads =
std::min(threads, std::thread::hardware_concurrency());
unsigned threads_started = 0;

while (threads_started < trials || !all_threads.empty()) {
// Start new jobs if we haven't reached the max threads or trials
if (threads_started < trials && all_threads.size() < max_threads) {
std::shared_ptr<std::atomic<bool>> stop_flag =
std::make_shared<std::atomic<bool>>(false);
// Circuit copy(circ);
std::future<Circuit> future = std::async(
std::launch::async,
[&, stop_flag]() { // Capture `stop_flag` explicitly in the lambda
return GreedyPauliSimp::greedy_pauli_graph_synthesis_flag(
circ, stop_flag, discount_rate, depth_weight, max_lookahead,
max_tqe_candidates, seed_gen(), allow_zzphase);
});
all_threads.emplace(std::move(future), stop_flag);
threads_started++;
// continue to come straight back to this if statement, meaning we
// maximise parallel threads
continue;
}
while (threads_started < trials) {
std::shared_ptr<std::atomic<bool>> stop_flag =
std::make_shared<std::atomic<bool>>(false);
// Circuit copy(circ);
std::future<Circuit> future = std::async(
std::launch::async,
[&, stop_flag]() { // Capture `stop_flag` explicitly in the lambda
return GreedyPauliSimp::greedy_pauli_graph_synthesis_flag(
circ, stop_flag, discount_rate, depth_weight, max_lookahead,
max_tqe_candidates, seed_gen(), allow_zzphase);
});
threads_started++;

// Check the oldest thread for completion
auto& [thread, stop_flag] = all_threads.front();
if (thread.wait_for(std::chrono::seconds(thread_timeout)) ==
if (future.wait_for(std::chrono::seconds(thread_timeout)) ==
std::future_status::ready) {
Circuit c = thread.get();
Circuit c = future.get();
c.decompose_boxes_recursively();
circuits.push_back(c);
all_threads.pop();
} else {
// If the thread is not ready, move it to the back of the queue
*stop_flag = true;
all_threads.pop();
break;
}
}

Expand Down

0 comments on commit 45d0773

Please sign in to comment.