Skip to content

Commit

Permalink
Add memory resource argument through to_device
Browse files Browse the repository at this point in the history
While the `to_device` function already included a memory resource, it
didn't use it. Plus other functions calling `to_device` did not use the
argument. The change here makes sure `to_device` passes this to the
`DeviceBuffer` constructor. Also it makes sure other functions have a
default argument, which they set if one is not specified.
  • Loading branch information
jakirkham committed Oct 18, 2022
1 parent 1532bff commit a6f8bcb
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/rmm/_lib/device_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,17 @@ cdef class DeviceBuffer:

@staticmethod
cdef DeviceBuffer c_to_device(const unsigned char[::1] b,
Stream stream=DEFAULT_STREAM):
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None):
"""Calls ``to_device`` function on arguments provided"""
return to_device(b, stream)
return to_device(b, stream, mr)

@staticmethod
def to_device(const unsigned char[::1] b,
Stream stream=DEFAULT_STREAM):
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None):
"""Calls ``to_device`` function on arguments provided."""
return to_device(b, stream)
return to_device(b, stream, mr)

cpdef copy_to_host(self, ary=None, Stream stream=DEFAULT_STREAM):
"""Copy from a ``DeviceBuffer`` to a buffer on host.
Expand Down Expand Up @@ -356,7 +358,8 @@ cdef class DeviceBuffer:

@cython.boundscheck(False)
cpdef DeviceBuffer to_device(const unsigned char[::1] b,
Stream stream=DEFAULT_STREAM):
Stream stream=DEFAULT_STREAM,
DeviceMemoryResource mr=None):
"""Return a new ``DeviceBuffer`` with a copy of the data.
Parameters
Expand Down Expand Up @@ -384,7 +387,7 @@ cpdef DeviceBuffer to_device(const unsigned char[::1] b,

cdef uintptr_t p = <uintptr_t>&b[0]
cdef size_t s = len(b)
return DeviceBuffer(ptr=p, size=s, stream=stream)
return DeviceBuffer(ptr=p, size=s, stream=stream, mr=mr)


@cython.boundscheck(False)
Expand Down

0 comments on commit a6f8bcb

Please sign in to comment.