Skip to content

Commit

Permalink
Use memory resource in DeviceBuffer construtor
Browse files Browse the repository at this point in the history
Update the Python constructor to take and handle a
`DeviceMemoryResource` argument. Also pass this through to
`device_buffer` constructors.
  • Loading branch information
jakirkham committed Oct 18, 2022
1 parent fc8e3ed commit 295156f
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions python/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ cdef class DeviceBuffer:
def __cinit__(self, *,
uintptr_t ptr=0,
size_t size=0,
Stream stream=DEFAULT_STREAM):
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None):
"""Construct a ``DeviceBuffer`` with optional size and data pointer
Parameters
Expand All @@ -64,6 +65,9 @@ cdef class DeviceBuffer:
scope while the DeviceBuffer is in use. Destroying the
underlying stream while the DeviceBuffer is in use will
result in undefined behavior.
mr : optional
Memory resource to use to allocate memory for the underlying
``device_buffer``.
Note
----
Expand All @@ -77,22 +81,31 @@ cdef class DeviceBuffer:
>>> db = rmm.DeviceBuffer(size=5)
"""
cdef const void* c_ptr
cdef device_memory_resource* c_mr

# Use default memory resource if none is specified.
# Also get C++ representation to call constructor below.
if mr is None:
mr = get_current_device_resource()
c_mr = mr.get_mr()

with nogil:
c_ptr = <const void*>ptr

if size == 0:
self.c_obj.reset(new device_buffer())
self.c_obj.reset(new device_buffer(c_mr))
elif c_ptr == NULL:
self.c_obj.reset(new device_buffer(size, stream.view()))
self.c_obj.reset(new device_buffer(size, stream.view(), c_mr))
else:
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
self.c_obj.reset(
new device_buffer(c_ptr, size, stream.view(), c_mr)
)

if stream.c_is_default():
stream.c_synchronize()

# Save a reference to the MR and stream used for allocation
self.mr = get_current_device_resource()
self.mr = mr
self.stream = stream

def __len__(self):
Expand Down

0 comments on commit 295156f

Please sign in to comment.