From a6f8bcb53c794198316607023ccd31a38146acf5 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Tue, 18 Oct 2022 01:28:12 -0700 Subject: [PATCH] Add memory resource argument through `to_device` While the `to_device` function already included a memory resource, it didn't use it. Plus other functions calling `to_device` did not use the argument. The change here makes sure `to_device` passes this to the `DeviceBuffer` constructor. Also it makes sure other functions have a default argument, which they set if one is not specified. --- python/rmm/_lib/device_buffer.pyx | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/rmm/_lib/device_buffer.pyx b/python/rmm/_lib/device_buffer.pyx index 5dc6d1b55..59a51066d 100644 --- a/python/rmm/_lib/device_buffer.pyx +++ b/python/rmm/_lib/device_buffer.pyx @@ -182,15 +182,17 @@ cdef class DeviceBuffer: @staticmethod cdef DeviceBuffer c_to_device(const unsigned char[::1] b, - Stream stream=DEFAULT_STREAM): + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None): """Calls ``to_device`` function on arguments provided""" - return to_device(b, stream) + return to_device(b, stream, mr) @staticmethod def to_device(const unsigned char[::1] b, - Stream stream=DEFAULT_STREAM): + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None): """Calls ``to_device`` function on arguments provided.""" - return to_device(b, stream) + return to_device(b, stream, mr) cpdef copy_to_host(self, ary=None, Stream stream=DEFAULT_STREAM): """Copy from a ``DeviceBuffer`` to a buffer on host. @@ -356,7 +358,8 @@ cdef class DeviceBuffer: @cython.boundscheck(False) cpdef DeviceBuffer to_device(const unsigned char[::1] b, - Stream stream=DEFAULT_STREAM): + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None): """Return a new ``DeviceBuffer`` with a copy of the data. Parameters @@ -384,7 +387,7 @@ cpdef DeviceBuffer to_device(const unsigned char[::1] b, cdef uintptr_t p = &b[0] cdef size_t s = len(b) - return DeviceBuffer(ptr=p, size=s, stream=stream) + return DeviceBuffer(ptr=p, size=s, stream=stream, mr=mr) @cython.boundscheck(False)