diff --git a/src/buffer.rs b/src/buffer.rs index 2701295a..1fac9fda 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -461,7 +461,7 @@ impl<'a, T, D: Device> Buffer<'a, T, D> { } #[cfg(feature = "cpu")] -impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU, S> { +impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> { /// Constructs a deviceless `Buffer` out of a host pointer and a length. /// # Example /// ``` @@ -484,15 +484,17 @@ impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU, S> { /// The pointer must be valid. /// The `Buffer` does not manage deallocation of the allocated memory. #[inline] - pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU, S> { - let x = 2; - todo!() - // Buffer { - // data: self.wrap_in_base(CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper)), - // device: None, - // } + pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU, S> { + Buffer { + data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper), + device: None, + } } +} + +#[cfg(feature = "cpu")] +impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU, S> { /// Constructs a `Buffer` out of a host pointer and a length. /// The provided device can be used to shorten operation calls. /// diff --git a/src/devices/opencl/cl_device.rs b/src/devices/opencl/cl_device.rs index 9b889ab3..d7a44b41 100644 --- a/src/devices/opencl/cl_device.rs +++ b/src/devices/opencl/cl_device.rs @@ -8,7 +8,7 @@ use crate::{ impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, pass_down_grad_fn, pass_down_optimize_mem_graph, pass_down_tape_actions, pass_down_use_gpu_or_cpu, Alloc, Base, Buffer, Cached, CachedCPU, CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup, - WrappedData, CPU, + WrappedData, CPU, pass_down_replace_buf, }; use crate::{PtrConv, Shape}; @@ -309,6 +309,7 @@ impl crate::LazyRun for OpenCL {} pass_down_tape_actions!(OpenCL); pass_down_grad_fn!(OpenCL); +pass_down_replace_buf!(OpenCL); #[cfg(test)] mod tests { diff --git a/src/devices/opencl/unified.rs b/src/devices/opencl/unified.rs index ea5fe0a0..5f2f3f32 100644 --- a/src/devices/opencl/unified.rs +++ b/src/devices/opencl/unified.rs @@ -56,7 +56,7 @@ impl UnifiedMemChain for Base { /// This function is used in the `constuct_buffer()` function. /// # Safety /// The host pointer inside the no_drop `Buffer` must live as long as the resulting pointer. -pub unsafe fn to_cached_unified( +pub unsafe fn to_cached_unified( device: &OpenCL, no_drop: Buffer, S>, cache: &mut HashMap< @@ -65,7 +65,13 @@ pub unsafe fn to_cached_unified, >, location: HashLocation<'static>, -) -> crate::Result<*mut c_void> { +) -> crate::Result<*mut c_void> +where + OclMods: OnDropBuffer, + CpuMods: OnDropBuffer, + T: 'static, + S: Shape, +{ // use the host pointer to create an OpenCL buffer let cl_ptr = create_buffer( device.ctx(), @@ -78,7 +84,7 @@ pub unsafe fn to_cached_unified as Device>::Base>().unwrap(); + let rawcl = rawcl + .downcast_ref::< as Device>::Base>() + .unwrap(); let data = device.base_to_data::(CLPtr { ptr: rawcl.ptr, host_ptr: rawcl.host_ptr as *mut T, @@ -300,9 +308,8 @@ mod tests { #[cfg(unified_cl)] #[test] fn test_cpu_to_unified_is_reusing_converted_buf() -> crate::Result<()> { - use std::time::Instant; - use crate::{Base, Cached, HashLocation, Retriever}; + use std::time::Instant; let cl_dev = OpenCL::>::new(0)?; let device = CPU::>::new(); @@ -310,7 +317,7 @@ mod tests { let mut dur = 0.; for _ in 0..100 { - let mut buf: Buffer = device.retrieve::<0>(6, ()); + let mut buf: Buffer = device.retrieve::<0>(6, ()); buf.copy_from_slice(&[1, 2, 3, 4, 5, 6]); diff --git a/src/exec_on_cpu.rs b/src/exec_on_cpu.rs index 0d9d9f5b..cbed7f53 100644 --- a/src/exec_on_cpu.rs +++ b/src/exec_on_cpu.rs @@ -219,9 +219,9 @@ macro_rules! to_cpu { /// The old `Buffer`s are shadowed. #[macro_export] macro_rules! to_raw_host { - ($cpu_ty:ty, $($t:ident),*) => { + ($cpu:expr, $($t:ident),*) => { $( - let $t = &unsafe { $crate::Buffer::<_, $cpu_ty, ()>::from_raw_host($t.data.host_ptr, $t.len()) }; + let $t = &unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) }; )* }; } @@ -230,10 +230,10 @@ macro_rules! to_raw_host { /// New names for the `CPU` `Buffer`s are provided by the user. #[macro_export] macro_rules! to_raw_host_mut { - ($cpu_ty:ty, $($t:ident, $cpu_name:ident),*) => { + ($cpu:expr, $($t:ident, $cpu_name:ident),*) => { $( #[allow(unused_mut)] - let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, $cpu_ty, ()>::from_raw_host($t.data.host_ptr, $t.len()) }; + let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) }; )* }; } diff --git a/src/exec_on_cpu/cl_may_unified.rs b/src/exec_on_cpu/cl_may_unified.rs index 0921fec8..bf0f88da 100644 --- a/src/exec_on_cpu/cl_may_unified.rs +++ b/src/exec_on_cpu/cl_may_unified.rs @@ -31,7 +31,7 @@ where // host ptr buffer let no_drop = f(&device.cpu, &unsafe { - Buffer::from_raw_host(x.base().host_ptr, x.len()) + Buffer::from_raw_host_device(&device.cpu, x.base().host_ptr, x.len()) }); // convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer @@ -73,7 +73,7 @@ where if device.unified_mem() { return { f(&cpu, &mut unsafe { - Buffer::from_raw_host(lhs.base().host_ptr, lhs.len()) + Buffer::from_raw_host_device(&cpu, lhs.base().host_ptr, lhs.len()) }); Ok(()) }; @@ -87,12 +87,7 @@ where /// This is way faster than [cpu_exec_binary], as new memory is not allocated. /// /// `cpu_exec_binary_may_unified` can be used interchangeably with [cpu_exec_binary]. -pub fn cpu_exec_binary_may_unified< - 'a, - T, - F, - Mods: OnDropBuffer + UnifiedMemChain> + Retrieve, T> + 'static, ->( +pub fn cpu_exec_binary_may_unified<'a, T, F, Mods>( device: &'a OpenCL, lhs: &Buffer>, rhs: &Buffer>, @@ -105,6 +100,7 @@ where &Buffer<'_, T, CachedCPU>, &Buffer<'_, T, CachedCPU>, ) -> Buffer<'b, T, CachedCPU>, + Mods: UnifiedMemChain> + Retrieve, T> + 'static, { // TODO: use compile time unified_cl flag -> get from custos? #[cfg(not(feature = "realloc"))] @@ -116,8 +112,8 @@ where // host ptr buffer let no_drop = f( &device.cpu, - &unsafe { Buffer::from_raw_host(lhs.base().host_ptr, lhs.len()) }, - &unsafe { Buffer::from_raw_host(rhs.base().host_ptr, rhs.len()) }, + &unsafe { Buffer::from_raw_host_device(&device.cpu, lhs.base().host_ptr, lhs.len()) }, + &unsafe { Buffer::from_raw_host_device(&device.cpu, rhs.base().host_ptr, rhs.len()) }, ); // convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer @@ -163,8 +159,8 @@ where return { f( &cpu, - &mut unsafe { Buffer::from_raw_host(lhs.base().host_ptr, lhs.len()) }, - &unsafe { Buffer::from_raw_host(rhs.base().host_ptr, rhs.len()) }, + &mut unsafe { Buffer::from_raw_host_device(&cpu, lhs.base().host_ptr, lhs.len()) }, + &unsafe { Buffer::from_raw_host_device(&cpu, rhs.base().host_ptr, rhs.len()) }, ); Ok(()) }; @@ -242,8 +238,8 @@ macro_rules! cl_cpu_exec_unified_mut { ($device:ident, $($t:ident),* WRITE_TO<$($write_to:ident, $from:ident),*> $op:expr) => {{ // TODO: add to graph?: convert.node = device.graph().add(convert.len(), matrix.node.idx); if $device.unified_mem() { - $crate::to_raw_host!($crate::CPU::<$crate::CachedModule<$crate::Base, $crate::CPU>>, $($t),*); - $crate::to_raw_host_mut!($crate::CPU::<$crate::CachedModule<$crate::Base, $crate::CPU>>, $($write_to, $from),*); + $crate::to_raw_host!(&$device.cpu, $($t),*); + $crate::to_raw_host_mut!(&$device.cpu, $($write_to, $from),*); $op; } else { diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index e55d40ef..57927200 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -24,7 +24,7 @@ type Buffers = HashMap, BuildHasherDefault>; #[derive(Default)] pub struct Lazy { pub modules: Mods, - alloc_later: RefCell>, + alloc_later: RefCell>, // could use D generic instead of dyn Any (required LazyModule structure) allocated: Cell, buffers: RefCell, graph: RefCell, @@ -199,6 +199,9 @@ where alloc_later.push((id, |buffers, id, device| { let device = device.downcast_ref::().unwrap(); // TODO: should be fixable - (lazy) -> either return error or fix + // creating buffers (with data) is not lazy - they are allocated instantly + // these are then added to `buffers` with their ID (which is the pointing address) + // new IDs start at 0. 1, 2, 3, ... till a collision with an address happens. assert!( !buffers.contains_key(&id.id), "IDs collided! Maybe pointing address already occupied this ID." @@ -368,6 +371,24 @@ mod tests { assert_eq!(out.replace().read(), &[3; 10]); } + #[cfg(feature = "cpu")] + #[test] + fn test_lazy_alloc_later() { + use crate::Run; + + let device = CPU::>::new(); + + let buf = Buffer::::new(&device, 10); + let out = device.apply_fn(&buf, |x| x.add(3)); + + device.modules.alloc_later(&device); + device.modules.allocated.set(true); + assert_eq!(out.replace().read(), &[0; 10]); + unsafe { device.run().unwrap() }; + assert_eq!(out.replace().read(), &[3; 10]); + + } + #[test] #[cfg(feature = "opencl")] fn test_lazy_apply_fn_with_run_cl() { @@ -378,9 +399,8 @@ mod tests { let buf = Buffer::::new(&device, 10); let out = device.apply_fn(&buf, |x| x.add(3)); - assert_eq!(out.read(), &[0; 10]); unsafe { device.run().unwrap() } - assert_eq!(out.read(), &[3; 10]); + assert_eq!(out.replace().read(), &[3; 10]); } #[test] #[cfg(feature = "cpu")] diff --git a/src/shape.rs b/src/shape.rs index 986a6e1f..8087fd9b 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -39,10 +39,10 @@ impl Shape for () { // TODO: impl for net device // this is used to /// If the [`Shape`] does not matter for a specific device [`Buffer`](crate::Buffer), than this trait should be implemented. -pub trait IsShapeIndep: Device {} +pub unsafe trait IsShapeIndep: Device {} #[cfg(not(feature = "no-std"))] -impl IsShapeIndep for D {} +unsafe impl IsShapeIndep for D {} /// If the [`Shape`] is provides a fixed size, than this trait should be implemented. /// Forgot how this is useful.