Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Persist circuits and tasks in batch execute #270

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument;
the device will retry circuits that do not complete within the timeout.
A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits.

Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device

>> print(remote_device.circuits[0])

Device options
~~~~~~~~~~~~~~

Expand Down
18 changes: 18 additions & 0 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def __init__(
super().__init__(wires, shots=shots or None)
self._device = device
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []
self._noise_model = noise_model
self._parametrize_differentiable = parametrize_differentiable
self._run_kwargs = run_kwargs
Expand All @@ -153,7 +155,9 @@ def __init__(
def reset(self):
super().reset()
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []

@property
def operations(self) -> frozenset[str]:
Expand All @@ -173,11 +177,21 @@ def circuit(self) -> Circuit:
"""Circuit: The last circuit run on this device."""
return self._circuit

@property
def circuits(self) -> list[Circuit]:
"""Circuit: The circuits run on this device."""
return self._circuits

@property
def task(self) -> QuantumTask:
"""QuantumTask: The task corresponding to the last run circuit."""
return self._task

@property
def tasks(self) -> list[QuantumTask]:
"""The tasks corresponding to the circuits run on this device."""
return self._tasks

def _pl_to_braket_circuit(
self,
circuit: QuantumTape,
Expand Down Expand Up @@ -584,6 +598,8 @@ def __init__(
self._max_parallel = max_parallel
self._max_connections = max_connections
self._max_retries = max_retries
self._circuits = []
self._tasks = []

@property
def use_grouping(self) -> bool:
Expand Down Expand Up @@ -621,6 +637,7 @@ def batch_execute(self, circuits, **run_kwargs):
**run_kwargs,
)
)
self._circuits = braket_circuits

batch_shots = 0 if self.analytic else self.shots

Expand All @@ -639,6 +656,7 @@ def batch_execute(self, circuits, **run_kwargs):
),
**self._run_kwargs,
)
self._tasks = task_batch.tasks
# Call results() to retrieve the Braket results in parallel.
try:
braket_results_batch = task_batch.results(
Expand Down
25 changes: 25 additions & 0 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def test_reset():
"""Tests that the members of the device are cleared on reset."""
dev = _aws_device(wires=2)
dev._circuit = CIRCUIT
dev._circuits = [CIRCUIT, CIRCUIT]
dev._task = TASK
dev._tasks = [TASK, TASK]

dev.reset()
assert dev.circuit is None
assert dev.circuits == []
assert dev.task is None
assert dev.tasks == []


def test_apply():
Expand Down Expand Up @@ -910,6 +914,25 @@ def test_batch_execute_non_parallel_tracker(mock_run):
callback.assert_called_with(latest=latest, history=history, totals=totals)


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel_circuits_persistance(mock_run_batch):
mock_run_batch.return_value = TASK_BATCH
dev = _aws_device(wires=4, foo="bar", parallel=True)
assert dev.parallel is True

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.probs(wires=[0])
qml.expval(qml.PauliX(1))
qml.var(qml.PauliY(2))
qml.sample(qml.PauliZ(3))

circuits = [circuit, circuit]
dev.batch_execute(circuits)
assert dev.circuits[1]


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel(mock_run_batch):
"""Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK"""
Expand All @@ -927,6 +950,8 @@ def test_batch_execute_parallel(mock_run_batch):

circuits = [circuit, circuit]
batch_results = dev.batch_execute(circuits)

assert dev.tasks[0]
for results in batch_results:
assert np.allclose(
results[0], RESULT.get_value_by_result_type(result_types.Probability(target=[0]))
Expand Down
Loading