From e0a4d05ca0b4d1998aaf13ad92d05631cc420b90 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 3 May 2024 14:15:32 +0000 Subject: [PATCH] DeviceBuffer: accept memory resource when taking ownership In c_from_unique_ptr we should not just rely on get_current_device_resource, but rather allow the user to pass in the memory resource they _know_ was used to allocate the buffer we are taking ownership of. So that we are backwards-compatible we default, as before, to the current device resource. --- python/rmm/rmm/_lib/device_buffer.pxd | 3 ++- python/rmm/rmm/_lib/device_buffer.pyx | 14 ++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/rmm/rmm/_lib/device_buffer.pxd b/python/rmm/rmm/_lib/device_buffer.pxd index b48df21e7..0603227fb 100644 --- a/python/rmm/rmm/_lib/device_buffer.pxd +++ b/python/rmm/rmm/_lib/device_buffer.pxd @@ -65,7 +65,8 @@ cdef class DeviceBuffer: @staticmethod cdef DeviceBuffer c_from_unique_ptr( unique_ptr[device_buffer] ptr, - Stream stream=* + Stream stream=*, + DeviceMemoryResource mr=*, ) @staticmethod diff --git a/python/rmm/rmm/_lib/device_buffer.pyx b/python/rmm/rmm/_lib/device_buffer.pyx index bbeaa614e..230271c39 100644 --- a/python/rmm/rmm/_lib/device_buffer.pyx +++ b/python/rmm/rmm/_lib/device_buffer.pyx @@ -33,6 +33,7 @@ from cuda.ccudart cimport ( ) from rmm._lib.memory_resource cimport ( + DeviceMemoryResource, device_memory_resource, get_current_device_resource, ) @@ -48,7 +49,8 @@ cdef class DeviceBuffer: def __cinit__(self, *, uintptr_t ptr=0, size_t size=0, - Stream stream=DEFAULT_STREAM): + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None): """Construct a ``DeviceBuffer`` with optional size and data pointer Parameters @@ -65,6 +67,9 @@ cdef class DeviceBuffer: scope while the DeviceBuffer is in use. Destroying the underlying stream while the DeviceBuffer is in use will result in undefined behavior. + mr : optional + DeviceMemoryResource for the allocation, if not provided + defaults to the current device resource. Note ---- @@ -80,7 +85,7 @@ cdef class DeviceBuffer: cdef const void* c_ptr cdef device_memory_resource * mr_ptr # Save a reference to the MR and stream used for allocation - self.mr = get_current_device_resource() + self.mr = get_current_device_resource() if mr is None else mr self.stream = stream mr_ptr = self.mr.get_mr() @@ -162,13 +167,14 @@ cdef class DeviceBuffer: @staticmethod cdef DeviceBuffer c_from_unique_ptr( unique_ptr[device_buffer] ptr, - Stream stream=DEFAULT_STREAM + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None, ): cdef DeviceBuffer buf = DeviceBuffer.__new__(DeviceBuffer) if stream.c_is_default(): stream.c_synchronize() buf.c_obj = move(ptr) - buf.mr = get_current_device_resource() + buf.mr = get_current_device_resource() if mr is None else mr buf.stream = stream return buf