From c13b93419844067fff4406a5b145c9b1e51c95fd Mon Sep 17 00:00:00 2001 From: Pradnya Khalate Date: Mon, 23 Sep 2024 11:39:34 -0700 Subject: [PATCH] * Unpack arguments based on the caller's execution context * Make tests stable by enforcing garbage collection as part of set up --- python/cudaq/handlers/photonics_kernel.py | 12 ++++- python/cudaq/kernel/kernel_decorator.py | 12 ++++- python/runtime/common/py_ExecutionContext.cpp | 4 ++ python/runtime/cudaq/algorithms/py_state.cpp | 14 +++--- .../tests/handlers/test_photonics_kernel.py | 44 +++++++++++++++++-- .../photonics/PhotonicsExecutionManager.cpp | 5 +-- 6 files changed, 75 insertions(+), 16 deletions(-) diff --git a/python/cudaq/handlers/photonics_kernel.py b/python/cudaq/handlers/photonics_kernel.py index fc676c33a9..5738234314 100644 --- a/python/cudaq/handlers/photonics_kernel.py +++ b/python/cudaq/handlers/photonics_kernel.py @@ -13,6 +13,8 @@ from ..mlir._mlir_libs._quakeDialects import cudaq_runtime +_TARGET_NAME = 'photonics' + # The qudit level must be explicitly defined globalQuditLevel = None @@ -33,7 +35,13 @@ class PyQudit: id: int def __del__(self): - cudaq_runtime.photonics.release_qudit(self.level, self.id) + try: + cudaq_runtime.photonics.release_qudit(self.level, self.id) + except Exception as e: + if _TARGET_NAME == cudaq_runtime.get_target().name: + raise e + else: + pass def _is_qudit_type(q: any) -> bool: @@ -194,7 +202,7 @@ class PhotonicsHandler(object): def __init__(self, function): - if 'photonics' != cudaq_runtime.get_target().name: + if _TARGET_NAME != cudaq_runtime.get_target().name: raise RuntimeError( "A photonics kernel can only be used with 'photonics' target.") diff --git a/python/cudaq/kernel/kernel_decorator.py b/python/cudaq/kernel/kernel_decorator.py index b43c54a6c8..2af09bb891 100644 --- a/python/cudaq/kernel/kernel_decorator.py +++ b/python/cudaq/kernel/kernel_decorator.py @@ -360,7 +360,17 @@ def __call__(self, *args): raise RuntimeError( "The 'photonics' target must be used with a valid function." ) - PhotonicsHandler(self.kernelFunction)(*args) + # NOTE: Since this handler does not support MLIR mode (yet), just + # invoke the kernel. If calling from a bound function, need to + # unpack the arguments, for example, see `pyGetStateLibraryMode` + try: + context_name = cudaq_runtime.getExecutionContextName() + except RuntimeError: + context_name = None + callable_args = args + if "extract-state" == context_name and len(args) == 1: + callable_args = args[0] + PhotonicsHandler(self.kernelFunction)(*callable_args) return # Prepare captured state storage for the run diff --git a/python/runtime/common/py_ExecutionContext.cpp b/python/runtime/common/py_ExecutionContext.cpp index baa1ca518c..b1fc025e1e 100644 --- a/python/runtime/common/py_ExecutionContext.cpp +++ b/python/runtime/common/py_ExecutionContext.cpp @@ -48,5 +48,9 @@ void bindExecutionContext(py::module &mod) { auto &platform = cudaq::get_platform(); return platform.supports_conditional_feedback(); }); + mod.def("getExecutionContextName", []() { + auto &self = cudaq::get_platform(); + return self.get_exec_ctx()->name; + }); } } // namespace cudaq diff --git a/python/runtime/cudaq/algorithms/py_state.cpp b/python/runtime/cudaq/algorithms/py_state.cpp index 62fe055af6..77a8e4a36d 100644 --- a/python/runtime/cudaq/algorithms/py_state.cpp +++ b/python/runtime/cudaq/algorithms/py_state.cpp @@ -138,15 +138,17 @@ state pyGetStateRemote(py::object kernel, py::args args) { } state pyGetStateLibraryMode(py::object kernel, py::args args) { - cudaq::info("Size of arguments = {}", args.size()); - - /// TODO: Pack / unpack arguments return details::extractState([&]() mutable { if (0 == args.size()) cudaq::invokeKernel(std::forward(kernel)); - else - cudaq::invokeKernel(std::forward(kernel), - std::forward(args)); + else { + std::vector argsData; + for (size_t i = 0; i < args.size(); i++) { + py::object arg = args[i]; + argsData.emplace_back(std::forward(arg)); + } + cudaq::invokeKernel(std::forward(kernel), argsData); + } }); } diff --git a/python/tests/handlers/test_photonics_kernel.py b/python/tests/handlers/test_photonics_kernel.py index 8c2d300324..771769b4ad 100644 --- a/python/tests/handlers/test_photonics_kernel.py +++ b/python/tests/handlers/test_photonics_kernel.py @@ -7,6 +7,10 @@ # ============================================================================ # import pytest + +import gc +from typing import List + import cudaq @@ -16,6 +20,8 @@ def do_something(): yield cudaq.reset_target() cudaq.__clearKernelRegistries() + # Make the tests stable by enforcing resource release + gc.collect() def test_qudit(): @@ -83,18 +89,48 @@ def kernel(): def test_kernel_with_args(): + """Test that `PhotonicsHandler` supports basic arguments. + The check here is that all the test kernels run successfully.""" @cudaq.kernel - def kernel(theta: float): + def kernel_1f(theta: float): q = qudit(4) plus(q) phase_shift(q, theta) mz(q) - result = cudaq.sample(kernel, 0.5) + result = cudaq.sample(kernel_1f, 0.5) + result.dump() + + state = cudaq.get_state(kernel_1f, 0.5) + state.dump() + + @cudaq.kernel + def kernel_2f(theta: float, phi: float): + quds = [qudit(3) for _ in range(2)] + plus(quds[0]) + phase_shift(quds[0], theta) + beam_splitter(quds[0], quds[1], phi) + mz(quds) + + result = cudaq.sample(kernel_2f, 0.7854, 0.3927) result.dump() - - state = cudaq.get_state(kernel, 0.5) + + state = cudaq.get_state(kernel_2f, 0.7854, 0.3927) + state.dump() + + @cudaq.kernel + def kernel_list(angles: List[float]): + quds = [qudit(2) for _ in range(3)] + plus(quds[0]) + phase_shift(quds[1], angles[0]) + phase_shift(quds[2], angles[1]) + mz(quds) + + result = cudaq.sample(kernel_list, [0.5236, 1.0472]) + result.dump() + + state = cudaq.get_state(kernel_list, [0.5236, 1.0472]) state.dump() diff --git a/runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp b/runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp index ef45f620c2..bfdcac1617 100644 --- a/runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp +++ b/runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp @@ -64,10 +64,10 @@ struct PhotonicsState : public cudaq::SimulationState { getPrecision()}; } - // /// @brief Return all tensors that represent this state + /// @brief Return all tensors that represent this state std::vector getTensors() const override { return {getTensor()}; } - // /// @brief Return the number of tensors that represent this state. + /// @brief Return the number of tensors that represent this state. std::size_t getNumTensors() const override { return 1; } std::complex @@ -77,7 +77,6 @@ struct PhotonicsState : public cudaq::SimulationState { throw std::runtime_error("[photonics] invalid tensor requested."); if (indices.size() != 1) throw std::runtime_error("[photonics] invalid element extraction."); - return state[indices[0]]; }