Skip to content

Commit

Permalink
Use Buffer::from_raw_host_device for cl cpu exec, fix automatic unifi…
Browse files Browse the repository at this point in the history
…ed memory
  • Loading branch information
elftausend committed Dec 17, 2023
1 parent 5ce6e98 commit d5eb979
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 39 deletions.
18 changes: 10 additions & 8 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ impl<'a, T, D: Device> Buffer<'a, T, D> {
}

#[cfg(feature = "cpu")]
impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
impl<'a, T, S: Shape> Buffer<'a, T, CPU<Base>, S> {
/// Constructs a deviceless `Buffer` out of a host pointer and a length.
/// # Example
/// ```
Expand All @@ -484,15 +484,17 @@ impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
/// The pointer must be valid.
/// The `Buffer` does not manage deallocation of the allocated memory.
#[inline]
pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU<Mods>, S> {
let x = 2;
todo!()
// Buffer {
// data: self.wrap_in_base(CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper)),
// device: None,
// }
pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU<Base>, S> {
Buffer {
data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
device: None,
}
}

}

#[cfg(feature = "cpu")]
impl<'a, Mods: OnDropBuffer, T, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
/// Constructs a `Buffer` out of a host pointer and a length.
/// The provided device can be used to shorten operation calls.
///
Expand Down
3 changes: 2 additions & 1 deletion src/devices/opencl/cl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, pass_down_grad_fn,
pass_down_optimize_mem_graph, pass_down_tape_actions, pass_down_use_gpu_or_cpu, Alloc, Base,
Buffer, Cached, CachedCPU, CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup,
WrappedData, CPU,
WrappedData, CPU, pass_down_replace_buf,
};
use crate::{PtrConv, Shape};

Expand Down Expand Up @@ -309,6 +309,7 @@ impl<Mods> crate::LazyRun for OpenCL<Mods> {}

pass_down_tape_actions!(OpenCL);
pass_down_grad_fn!(OpenCL);
pass_down_replace_buf!(OpenCL);

#[cfg(test)]
mod tests {
Expand Down
21 changes: 14 additions & 7 deletions src/devices/opencl/unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<D: Device> UnifiedMemChain<D> for Base {
/// This function is used in the `constuct_buffer()` function.
/// # Safety
/// The host pointer inside the no_drop `Buffer` must live as long as the resulting pointer.
pub unsafe fn to_cached_unified<OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T, S: Shape>(
pub unsafe fn to_cached_unified<OclMods, CpuMods, T, S>(
device: &OpenCL<OclMods>,
no_drop: Buffer<T, CPU<CpuMods>, S>,
cache: &mut HashMap<
Expand All @@ -65,7 +65,13 @@ pub unsafe fn to_cached_unified<OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T,
BuildHasherDefault<LocationHasher>,
>,
location: HashLocation<'static>,
) -> crate::Result<*mut c_void> {
) -> crate::Result<*mut c_void>
where
OclMods: OnDropBuffer,
CpuMods: OnDropBuffer,
T: 'static,
S: Shape,
{
// use the host pointer to create an OpenCL buffer
let cl_ptr = create_buffer(
device.ctx(),
Expand All @@ -78,7 +84,7 @@ pub unsafe fn to_cached_unified<OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T,
location,
Rc::new(CLPtr {
ptr: cl_ptr,
host_ptr: no_drop.base().ptr as *mut u8,
host_ptr: no_drop.base().ptr as *mut T,
len: no_drop.len(),
flag: AllocFlag::None,
}),
Expand Down Expand Up @@ -132,7 +138,9 @@ pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'st

// if buffer was already converted, return the cache entry.
if let Some(rawcl) = cache.get(&location) {
let rawcl = rawcl.downcast_ref::<<OpenCL::<OclMods> as Device>::Base<T, S>>().unwrap();
let rawcl = rawcl
.downcast_ref::<<OpenCL<OclMods> as Device>::Base<T, S>>()
.unwrap();
let data = device.base_to_data::<T, S>(CLPtr {
ptr: rawcl.ptr,
host_ptr: rawcl.host_ptr as *mut T,
Expand Down Expand Up @@ -300,17 +308,16 @@ mod tests {
#[cfg(unified_cl)]
#[test]
fn test_cpu_to_unified_is_reusing_converted_buf() -> crate::Result<()> {
use std::time::Instant;

use crate::{Base, Cached, HashLocation, Retriever};
use std::time::Instant;

let cl_dev = OpenCL::<Cached<Base>>::new(0)?;
let device = CPU::<Cached<Base>>::new();

let mut dur = 0.;

for _ in 0..100 {
let mut buf: Buffer<i32, _> = device.retrieve::<0>(6, ());
let mut buf: Buffer<i32, _> = device.retrieve::<0>(6, ());

buf.copy_from_slice(&[1, 2, 3, 4, 5, 6]);

Expand Down
8 changes: 4 additions & 4 deletions src/exec_on_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ macro_rules! to_cpu {
/// The old `Buffer`s are shadowed.
#[macro_export]
macro_rules! to_raw_host {
($cpu_ty:ty, $($t:ident),*) => {
($cpu:expr, $($t:ident),*) => {
$(
let $t = &unsafe { $crate::Buffer::<_, $cpu_ty, ()>::from_raw_host($t.data.host_ptr, $t.len()) };
let $t = &unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) };
)*
};
}
Expand All @@ -230,10 +230,10 @@ macro_rules! to_raw_host {
/// New names for the `CPU` `Buffer`s are provided by the user.
#[macro_export]
macro_rules! to_raw_host_mut {
($cpu_ty:ty, $($t:ident, $cpu_name:ident),*) => {
($cpu:expr, $($t:ident, $cpu_name:ident),*) => {
$(
#[allow(unused_mut)]
let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, $cpu_ty, ()>::from_raw_host($t.data.host_ptr, $t.len()) };
let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) };
)*
};
}
Expand Down
24 changes: 10 additions & 14 deletions src/exec_on_cpu/cl_may_unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where

// host ptr buffer
let no_drop = f(&device.cpu, &unsafe {
Buffer::from_raw_host(x.base().host_ptr, x.len())
Buffer::from_raw_host_device(&device.cpu, x.base().host_ptr, x.len())
});

// convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer
Expand Down Expand Up @@ -73,7 +73,7 @@ where
if device.unified_mem() {
return {
f(&cpu, &mut unsafe {
Buffer::from_raw_host(lhs.base().host_ptr, lhs.len())
Buffer::from_raw_host_device(&cpu, lhs.base().host_ptr, lhs.len())
});
Ok(())
};
Expand All @@ -87,12 +87,7 @@ where
/// This is way faster than [cpu_exec_binary], as new memory is not allocated.
///
/// `cpu_exec_binary_may_unified` can be used interchangeably with [cpu_exec_binary].
pub fn cpu_exec_binary_may_unified<
'a,
T,
F,
Mods: OnDropBuffer + UnifiedMemChain<OpenCL<Mods>> + Retrieve<OpenCL<Mods>, T> + 'static,
>(
pub fn cpu_exec_binary_may_unified<'a, T, F, Mods>(
device: &'a OpenCL<Mods>,
lhs: &Buffer<T, OpenCL<Mods>>,
rhs: &Buffer<T, OpenCL<Mods>>,
Expand All @@ -105,6 +100,7 @@ where
&Buffer<'_, T, CachedCPU>,
&Buffer<'_, T, CachedCPU>,
) -> Buffer<'b, T, CachedCPU>,
Mods: UnifiedMemChain<OpenCL<Mods>> + Retrieve<OpenCL<Mods>, T> + 'static,
{
// TODO: use compile time unified_cl flag -> get from custos?
#[cfg(not(feature = "realloc"))]
Expand All @@ -116,8 +112,8 @@ where
// host ptr buffer
let no_drop = f(
&device.cpu,
&unsafe { Buffer::from_raw_host(lhs.base().host_ptr, lhs.len()) },
&unsafe { Buffer::from_raw_host(rhs.base().host_ptr, rhs.len()) },
&unsafe { Buffer::from_raw_host_device(&device.cpu, lhs.base().host_ptr, lhs.len()) },
&unsafe { Buffer::from_raw_host_device(&device.cpu, rhs.base().host_ptr, rhs.len()) },
);

// convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer
Expand Down Expand Up @@ -163,8 +159,8 @@ where
return {
f(
&cpu,
&mut unsafe { Buffer::from_raw_host(lhs.base().host_ptr, lhs.len()) },
&unsafe { Buffer::from_raw_host(rhs.base().host_ptr, rhs.len()) },
&mut unsafe { Buffer::from_raw_host_device(&cpu, lhs.base().host_ptr, lhs.len()) },
&unsafe { Buffer::from_raw_host_device(&cpu, rhs.base().host_ptr, rhs.len()) },
);
Ok(())
};
Expand Down Expand Up @@ -242,8 +238,8 @@ macro_rules! cl_cpu_exec_unified_mut {
($device:ident, $($t:ident),* WRITE_TO<$($write_to:ident, $from:ident),*> $op:expr) => {{
// TODO: add to graph?: convert.node = device.graph().add(convert.len(), matrix.node.idx);
if $device.unified_mem() {
$crate::to_raw_host!($crate::CPU::<$crate::CachedModule<$crate::Base, $crate::CPU>>, $($t),*);
$crate::to_raw_host_mut!($crate::CPU::<$crate::CachedModule<$crate::Base, $crate::CPU>>, $($write_to, $from),*);
$crate::to_raw_host!(&$device.cpu, $($t),*);
$crate::to_raw_host_mut!(&$device.cpu, $($write_to, $from),*);
$op;

} else {
Expand Down
26 changes: 23 additions & 3 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Buffers = HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>;
#[derive(Default)]
pub struct Lazy<Mods> {
pub modules: Mods,
alloc_later: RefCell<Vec<(Id, fn(&mut Buffers, Id, &dyn Any))>>,
alloc_later: RefCell<Vec<(Id, fn(&mut Buffers, Id, &dyn Any))>>, // could use D generic instead of dyn Any (required LazyModule structure)
allocated: Cell<bool>,
buffers: RefCell<Buffers>,
graph: RefCell<LazyGraph>,
Expand Down Expand Up @@ -199,6 +199,9 @@ where
alloc_later.push((id, |buffers, id, device| {
let device = device.downcast_ref::<D>().unwrap();
// TODO: should be fixable - (lazy) -> either return error or fix
// creating buffers (with data) is not lazy - they are allocated instantly
// these are then added to `buffers` with their ID (which is the pointing address)
// new IDs start at 0. 1, 2, 3, ... till a collision with an address happens.
assert!(
!buffers.contains_key(&id.id),
"IDs collided! Maybe pointing address already occupied this ID."
Expand Down Expand Up @@ -368,6 +371,24 @@ mod tests {
assert_eq!(out.replace().read(), &[3; 10]);
}

#[cfg(feature = "cpu")]
#[test]
fn test_lazy_alloc_later() {
use crate::Run;

let device = CPU::<Lazy<Base>>::new();

let buf = Buffer::<i32, _>::new(&device, 10);
let out = device.apply_fn(&buf, |x| x.add(3));

device.modules.alloc_later(&device);
device.modules.allocated.set(true);
assert_eq!(out.replace().read(), &[0; 10]);
unsafe { device.run().unwrap() };
assert_eq!(out.replace().read(), &[3; 10]);

}

#[test]
#[cfg(feature = "opencl")]
fn test_lazy_apply_fn_with_run_cl() {
Expand All @@ -378,9 +399,8 @@ mod tests {
let buf = Buffer::<i32, _>::new(&device, 10);
let out = device.apply_fn(&buf, |x| x.add(3));

assert_eq!(out.read(), &[0; 10]);
unsafe { device.run().unwrap() }
assert_eq!(out.read(), &[3; 10]);
assert_eq!(out.replace().read(), &[3; 10]);
}
#[test]
#[cfg(feature = "cpu")]
Expand Down
4 changes: 2 additions & 2 deletions src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ impl Shape for () {
// TODO: impl for net device
// this is used to
/// If the [`Shape`] does not matter for a specific device [`Buffer`](crate::Buffer), than this trait should be implemented.
pub trait IsShapeIndep: Device {}
pub unsafe trait IsShapeIndep: Device {}

#[cfg(not(feature = "no-std"))]
impl<D: PtrConv + Device> IsShapeIndep for D {}
unsafe impl<D: PtrConv + Device> IsShapeIndep for D {}

/// If the [`Shape`] is provides a fixed size, than this trait should be implemented.
/// Forgot how this is useful.
Expand Down

0 comments on commit d5eb979

Please sign in to comment.