From 03fe0234aac3a2dff9fabe256bf03b59b659a053 Mon Sep 17 00:00:00 2001 From: WingCode Date: Wed, 12 Jun 2024 18:49:03 +0200 Subject: [PATCH] Persist circuits and tasks in batch execute (#260) * feature: persist circuits in batch execute * feature: Add test for batch_execute persistance * fix: circuits attribute * documentation: add doc for printing circuits from batch_execute * fix: linting * Add tasks, circuits --------- Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> Co-authored-by: Tim (Yi-Ting) --- doc/devices/braket_remote.rst | 4 ++++ src/braket/pennylane_plugin/braket_device.py | 18 ++++++++++++++ test/unit_tests/test_braket_device.py | 25 ++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/doc/devices/braket_remote.rst b/doc/devices/braket_remote.rst index 1b8d70e2..b0bc7601 100644 --- a/doc/devices/braket_remote.rst +++ b/doc/devices/braket_remote.rst @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument; the device will retry circuits that do not complete within the timeout. A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits. +Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device + +>> print(remote_device.circuits[0]) + Device options ~~~~~~~~~~~~~~ diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index c6d6c87f..ea49b9ea 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -139,7 +139,9 @@ def __init__( super().__init__(wires, shots=shots or None) self._device = device self._circuit = None + self._circuits = [] self._task = None + self._tasks = [] self._noise_model = noise_model self._parametrize_differentiable = parametrize_differentiable self._run_kwargs = run_kwargs @@ -153,7 +155,9 @@ def __init__( def reset(self): super().reset() self._circuit = None + self._circuits = [] self._task = None + self._tasks = [] @property def operations(self) -> frozenset[str]: @@ -173,11 +177,21 @@ def circuit(self) -> Circuit: """Circuit: The last circuit run on this device.""" return self._circuit + @property + def circuits(self) -> list[Circuit]: + """Circuit: The circuits run on this device.""" + return self._circuits + @property def task(self) -> QuantumTask: """QuantumTask: The task corresponding to the last run circuit.""" return self._task + @property + def tasks(self) -> list[QuantumTask]: + """The tasks corresponding to the circuits run on this device.""" + return self._tasks + def _pl_to_braket_circuit( self, circuit: QuantumTape, @@ -584,6 +598,8 @@ def __init__( self._max_parallel = max_parallel self._max_connections = max_connections self._max_retries = max_retries + self._circuits = [] + self._tasks = [] @property def use_grouping(self) -> bool: @@ -621,6 +637,7 @@ def batch_execute(self, circuits, **run_kwargs): **run_kwargs, ) ) + self._circuits = braket_circuits batch_shots = 0 if self.analytic else self.shots @@ -639,6 +656,7 @@ def batch_execute(self, circuits, **run_kwargs): ), **self._run_kwargs, ) + self._tasks = task_batch.tasks # Call results() to retrieve the Braket results in parallel. try: braket_results_batch = task_batch.results( diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index 7d818570..9c8c44c2 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -92,11 +92,15 @@ def test_reset(): """Tests that the members of the device are cleared on reset.""" dev = _aws_device(wires=2) dev._circuit = CIRCUIT + dev._circuits = [CIRCUIT, CIRCUIT] dev._task = TASK + dev._tasks = [TASK, TASK] dev.reset() assert dev.circuit is None + assert dev.circuits == [] assert dev.task is None + assert dev.tasks == [] def test_apply(): @@ -910,6 +914,25 @@ def test_batch_execute_non_parallel_tracker(mock_run): callback.assert_called_with(latest=latest, history=history, totals=totals) +@patch.object(AwsDevice, "run_batch") +def test_batch_execute_parallel_circuits_persistance(mock_run_batch): + mock_run_batch.return_value = TASK_BATCH + dev = _aws_device(wires=4, foo="bar", parallel=True) + assert dev.parallel is True + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + qml.probs(wires=[0]) + qml.expval(qml.PauliX(1)) + qml.var(qml.PauliY(2)) + qml.sample(qml.PauliZ(3)) + + circuits = [circuit, circuit] + dev.batch_execute(circuits) + assert dev.circuits[1] + + @patch.object(AwsDevice, "run_batch") def test_batch_execute_parallel(mock_run_batch): """Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK""" @@ -927,6 +950,8 @@ def test_batch_execute_parallel(mock_run_batch): circuits = [circuit, circuit] batch_results = dev.batch_execute(circuits) + + assert dev.tasks[0] for results in batch_results: assert np.allclose( results[0], RESULT.get_value_by_result_type(result_types.Probability(target=[0]))