diff --git a/python/cudaq/handlers/photonics_kernel.py b/python/cudaq/handlers/photonics_kernel.py index 5738234314..a0d1d65907 100644 --- a/python/cudaq/handlers/photonics_kernel.py +++ b/python/cudaq/handlers/photonics_kernel.py @@ -13,11 +13,6 @@ from ..mlir._mlir_libs._quakeDialects import cudaq_runtime -_TARGET_NAME = 'photonics' - -# The qudit level must be explicitly defined -globalQuditLevel = None - @dataclass class PyQudit: @@ -34,14 +29,39 @@ class PyQudit: level: int id: int - def __del__(self): - try: - cudaq_runtime.photonics.release_qudit(self.level, self.id) - except Exception as e: - if _TARGET_NAME == cudaq_runtime.get_target().name: - raise e - else: - pass + +class QuditManager(object): + """ + A class to explicitly manage resource allocation for qudits within a + `PhotonicsKernel`. + """ + qudit_level = None + allocated_ids = [] + + @classmethod + def reset(cls): + cls.qudit_level = None + cls.allocated_ids = [] + + @classmethod + def allocate(cls, level: int): + if cls.qudit_level is None: + cls.qudit_level = level + elif level != cls.qudit_level: + raise RuntimeError( + "The qudits must be of same level within a kernel.") + id = cudaq_runtime.photonics.allocate_qudit(cls.qudit_level) + cls.allocated_ids.append(id) + return PyQudit(cls.qudit_level, id) + + def __enter__(cls): + cls.reset() + + def __exit__(cls, exc_type, exc_val, exc_tb): + while cls.allocated_ids: + cudaq_runtime.photonics.release_qudit(cls.allocated_ids.pop(), + cls.qudit_level) + cls.reset() def _is_qudit_type(q: any) -> bool: @@ -71,7 +91,7 @@ def _check_args(q: any): RuntimeError: If the qudit level is not set. Exception: If input argument is not instance of `PyQudit` class. """ - if globalQuditLevel is None: + if QuditManager.qudit_level is None: raise RuntimeError( "Qudit level not set. Define a qudit (`qudit(level=N)`) or list of qudits." ) @@ -97,15 +117,7 @@ def qudit(level: int) -> PyQudit: RuntimeError: If a qudit of level different than one already defined in the kernel is requested. """ - global globalQuditLevel - - if globalQuditLevel is None: - globalQuditLevel = level - elif level != globalQuditLevel: - raise RuntimeError("The qudits must be of same level within a kernel.") - - id = cudaq_runtime.photonics.allocate_qudit(globalQuditLevel) - return PyQudit(globalQuditLevel, id) + return QuditManager.allocate(level) def plus(qudit: PyQudit): @@ -202,13 +214,11 @@ class PhotonicsHandler(object): def __init__(self, function): - if _TARGET_NAME != cudaq_runtime.get_target().name: + if 'photonics' != cudaq_runtime.get_target().name: raise RuntimeError( "A photonics kernel can only be used with 'photonics' target.") - global globalQuditLevel - globalQuditLevel = None - + QuditManager.reset() self.kernelFunction = function self.kernelFunction.__globals__['qudit'] = qudit @@ -218,4 +228,5 @@ def __init__(self, function): self.kernelFunction.__globals__['mz'] = mz def __call__(self, *args): - return self.kernelFunction(*args) + with QuditManager(): + return self.kernelFunction(*args) diff --git a/python/tests/handlers/test_photonics_kernel.py b/python/tests/handlers/test_photonics_kernel.py index 771769b4ad..dd2fa762f2 100644 --- a/python/tests/handlers/test_photonics_kernel.py +++ b/python/tests/handlers/test_photonics_kernel.py @@ -8,7 +8,6 @@ import pytest -import gc from typing import List import cudaq @@ -20,8 +19,6 @@ def do_something(): yield cudaq.reset_target() cudaq.__clearKernelRegistries() - # Make the tests stable by enforcing resource release - gc.collect() def test_qudit():