diff --git a/python/rmm/_lib/device_buffer.pyx b/python/rmm/_lib/device_buffer.pyx index 613e0f2a3..5dc6d1b55 100644 --- a/python/rmm/_lib/device_buffer.pyx +++ b/python/rmm/_lib/device_buffer.pyx @@ -146,7 +146,9 @@ cdef class DeviceBuffer: } return intf - def copy(self): + def copy(self, *, + Stream stream=DEFAULT_STREAM, + DeviceMemoryResource mr=None): """Returns a copy of DeviceBuffer. Returns @@ -165,9 +167,9 @@ cdef class DeviceBuffer: >>> assert db is not db_copy >>> assert db.ptr != db_copy.ptr """ - ret = DeviceBuffer(ptr=self.ptr, size=self.size, stream=self.stream) - ret.mr = self.mr - return ret + return DeviceBuffer( + ptr=self.ptr, size=self.size, stream=stream, mr=mr + ) def __copy__(self): return self.copy()