Skip to content

Commit

Permalink
Add look up hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Feb 5, 2024
1 parent 914de39 commit 31bd695
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "graph", "cached", "autograd", "macro", "opencl"]
default = ["cpu", "graph", "macro", "cached"]
# default = ["no-std"]
# default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]

Expand Down
30 changes: 18 additions & 12 deletions src/cache/owned_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use std::collections::HashMap;

use std::rc::Rc;

use crate::{flag::AllocFlag, Alloc, Device, HashLocation, FxHasher, ShallowCopy, Shape};
use crate::{flag::AllocFlag, Alloc, Device, FxHasher, HashLocation, NoHasher, ShallowCopy, Shape};

#[derive(Debug, Clone)]
pub struct Cache {
pub nodes:
HashMap<HashLocation<'static>, Rc<dyn core::any::Any>, BuildHasherDefault<FxHasher>>,
pub nodes: HashMap<u64, Rc<dyn core::any::Any>, BuildHasherDefault<NoHasher>>,
}

impl Default for Cache {
Expand All @@ -30,16 +29,22 @@ impl Cache {
/// Lifetime of data must be at least as long as the lifetime of the cache (usually the device).
#[track_caller]
#[inline]
pub unsafe fn get<T, S, D>(&mut self, device: &D, len: usize, callback: fn()) -> D::Base<T, S>
pub unsafe fn get<T, S, D>(
&mut self,
device: &D,
len: usize,
callback: fn(),
look_up_hash: u64,
) -> D::Base<T, S>
where
D: Alloc<T> + 'static,
D::Base<T, S>: ShallowCopy + 'static,
S: Shape,
{
let maybe_allocated = self.nodes.get(&Location::caller().into());
let maybe_allocated = self.nodes.get(&look_up_hash);
match maybe_allocated {
Some(data) => unsafe { data.downcast_ref::<D::Base<T, S>>().unwrap().shallow() },
None => self.add_node(device, len, callback),
None => self.add_node(device, len, callback, look_up_hash),
}
}

Expand All @@ -49,6 +54,7 @@ impl Cache {
device: &D,
len: usize,
callback: fn(),
look_up_hash: u64,
) -> <D as Device>::Base<T, S>
where
D: Alloc<T>,
Expand All @@ -58,7 +64,7 @@ impl Cache {
let data = device.alloc::<S>(len, AllocFlag::None);
let shallow_data = unsafe { data.shallow() };

self.nodes.insert(Location::caller().into(), Rc::new(data));
self.nodes.insert(look_up_hash, Rc::new(data));
callback();

shallow_data
Expand All @@ -77,12 +83,12 @@ mod tests {

assert_eq!(cache.nodes.len(), 0);

let out = cache.add_node::<f32, (), _>(&device, 10, || ());
let out = cache.add_node::<f32, (), _>(&device, 10, || (), 0);

assert_eq!(cache.nodes.len(), 1);
assert_eq!(out.len, 10);

let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || (), 1) };
assert_ne!(out.ptr, out1.ptr);
}

Expand All @@ -93,10 +99,10 @@ mod tests {

assert_eq!(cache.nodes.len(), 0);

let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || (), 0) };
assert_eq!(cache.nodes.len(), 1);

let out2 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
let out2 = unsafe { cache.get::<f32, (), _>(&device, 10, || (), 1) };

assert_ne!(out1.ptr, out2.ptr);
assert_eq!(cache.nodes.len(), 2);
Expand All @@ -110,7 +116,7 @@ mod tests {

let mut prev = None;
for _ in 0..1000 {
let out3 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
let out3 = unsafe { cache.get::<f32, (), _>(&device, 10, || (), 0) };
if prev.is_none() {
prev = Some(out3.ptr);
}
Expand Down
81 changes: 44 additions & 37 deletions src/devices/opencl/unified.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,55 @@
use core::{hash::BuildHasherDefault, panic::Location};
use core::hash::BuildHasherDefault;
use std::{collections::HashMap, ffi::c_void, rc::Rc};

#[cfg(not(feature = "realloc"))]
use crate::{AllocFlag, DeviceError};

use super::CLPtr;
use crate::{
Base, Buffer, CachedCPU, CachedModule, Device, HashLocation, FxHasher, OnDropBuffer,
OpenCL, Shape, UnifiedMemChain, CPU,
Base, Buffer, CachedCPU, CachedModule, Device, NoHasher, OnDropBuffer, OpenCL, Parents, Shape,
UnifiedMemChain, CPU,
};
use min_cl::api::{create_buffer, MemFlags};

impl<Mods: UnifiedMemChain<Self> + OnDropBuffer> UnifiedMemChain<Self> for OpenCL<Mods> {
#[inline]
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape>(
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape, const N: usize>(
&self,
device: &'a Self,
no_drop_buf: Buffer<'a, T, CachedCPU, S>,
parents: impl Parents<N>,
) -> crate::Result<Buffer<'a, T, Self, S>> {
self.modules
.construct_unified_buf_from_cpu_buf(device, no_drop_buf)
.construct_unified_buf_from_cpu_buf(device, no_drop_buf, parents)
}
}

impl<Mods, OclMods: OnDropBuffer, SimpleMods: OnDropBuffer> UnifiedMemChain<OpenCL<OclMods>>
for CachedModule<Mods, OpenCL<SimpleMods>>
{
#[inline]
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape>(
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape, const N: usize>(
&self,
device: &'a OpenCL<OclMods>,
no_drop_buf: Buffer<'a, T, CachedCPU, S>,
parents: impl Parents<N>,
) -> crate::Result<Buffer<'a, T, OpenCL<OclMods>, S>> {
construct_buffer(
device,
no_drop_buf,
&mut self.cache.borrow_mut().nodes,
Location::caller().into(),
parents.hash(),
)
}
}

impl<D: Device> UnifiedMemChain<D> for Base {
#[inline]
fn construct_unified_buf_from_cpu_buf<'a, T, S: Shape>(
fn construct_unified_buf_from_cpu_buf<'a, T, S: Shape, const N: usize>(
&self,
_device: &'a D,
_no_drop_buf: Buffer<'a, T, CachedCPU, S>,
_parents: impl Parents<N>,
) -> crate::Result<Buffer<'a, T, D, S>> {
Err(DeviceError::UnifiedConstructNotAvailable.into())
}
Expand All @@ -59,12 +62,8 @@ impl<D: Device> UnifiedMemChain<D> for Base {
pub unsafe fn to_cached_unified<OclMods, CpuMods, T, S>(
device: &OpenCL<OclMods>,
no_drop: Buffer<T, CPU<CpuMods>, S>,
cache: &mut HashMap<
HashLocation<'static>,
Rc<dyn core::any::Any>,
BuildHasherDefault<FxHasher>,
>,
location: HashLocation<'static>,
cache: &mut HashMap<u64, Rc<dyn core::any::Any>, BuildHasherDefault<NoHasher>>,
look_up_hash: u64,
) -> crate::Result<*mut c_void>
where
OclMods: OnDropBuffer,
Expand All @@ -81,7 +80,7 @@ where
)?;

let old_ptr = cache.insert(
location,
look_up_hash,
Rc::new(CLPtr {
ptr: cl_ptr,
host_ptr: no_drop.base().ptr as *mut T,
Expand All @@ -97,7 +96,6 @@ where
Ok(cl_ptr)
}

#[cfg(not(feature = "realloc"))]
/// Converts an 'only' CPU buffer into an OpenCL + CPU (unified memory) buffer.
///
/// # Example
Expand All @@ -123,12 +121,8 @@ where
pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'static, S: Shape>(
device: &'a OpenCL<OclMods>,
no_drop: Buffer<'a, T, CPU<CpuMods>, S>,
cache: &mut HashMap<
HashLocation<'static>,
Rc<dyn core::any::Any>,
BuildHasherDefault<FxHasher>,
>,
location: HashLocation<'static>,
cache: &mut HashMap<u64, Rc<dyn core::any::Any>, BuildHasherDefault<NoHasher>>,
look_up_hash: u64,
) -> crate::Result<Buffer<'a, T, OpenCL<OclMods>, S>> {
use crate::PtrType;

Expand All @@ -137,7 +131,7 @@ pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'st
}

// if buffer was already converted, return the cache entry.
if let Some(rawcl) = cache.get(&location) {
if let Some(rawcl) = cache.get(&look_up_hash) {
let rawcl = rawcl
.downcast_ref::<<OpenCL<OclMods> as Device>::Base<T, S>>()
.unwrap();
Expand All @@ -153,7 +147,7 @@ pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'st
});
}
let (host_ptr, len) = (no_drop.base().ptr, no_drop.len());
let ptr = unsafe { to_cached_unified(device, no_drop, cache, location)? };
let ptr = unsafe { to_cached_unified(device, no_drop, cache, look_up_hash)? };

let data = device.base_to_data::<T, S>(CLPtr {
ptr,
Expand All @@ -173,7 +167,7 @@ pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'st
mod tests {
use crate::{
opencl::{chosen_cl_idx, CLPtr},
AllocFlag, Base, Buffer, Cache, Cached, Device, DeviceError, HashLocation, HostPtr, OpenCL,
AllocFlag, Base, Buffer, Cache, Cached, Device, DeviceError, HostPtr, Id, OpenCL,
Retriever, UnifiedMemChain, CPU,
};

Expand All @@ -186,7 +180,11 @@ mod tests {
no_drop.write(&[1., 2.3, 0.76]);

let device = OpenCL::<Cached<Base>>::new(0)?;
let buf = device.construct_unified_buf_from_cpu_buf::<_, ()>(&device, no_drop)?;
let buf = device.construct_unified_buf_from_cpu_buf::<_, (), 1>(
&device,
no_drop,
Id { id: 0, len: 0 },
)?;
assert_eq!(buf.as_slice(), &[1., 2.3, 0.76]);
Ok(())
}
Expand All @@ -198,8 +196,11 @@ mod tests {
let mut no_drop = device.cpu.retrieve::<0>(3, ());
no_drop.write(&[1., 2.3, 0.76]);

let buf = device.construct_unified_buf_from_cpu_buf::<_, ()>(&device, no_drop)?;

let buf = device.construct_unified_buf_from_cpu_buf::<_, (), 1>(
&device,
no_drop,
Id { id: 0, len: 0 },
)?;
assert_eq!(buf.as_slice(), &[1., 2.3, 0.76]);
Ok(())
}
Expand All @@ -210,7 +211,11 @@ mod tests {
let no_drop = cpu.buffer([1, 2, 3]);

let device = OpenCL::<Cached<Base>>::new(0)?;
let buf = device.construct_unified_buf_from_cpu_buf(&device, no_drop);
let buf = device.construct_unified_buf_from_cpu_buf::<_, (), 1>(
&device,
no_drop,
Id { id: 0, len: 0 },
);
match buf
.expect_err("Missing error -> failure")
.downcast_ref::<DeviceError>()
Expand All @@ -228,7 +233,11 @@ mod tests {
no_drop.write(&[1., 2.3, 0.76]);

let device = OpenCL::<Base>::new(chosen_cl_idx())?;
let buf = device.construct_unified_buf_from_cpu_buf(&device, no_drop);
let buf = device.construct_unified_buf_from_cpu_buf::<_, (), 1>(
&device,
no_drop,
Id { id: 0, len: 0 },
);
match buf
.expect_err("Missing error -> failure")
.downcast_ref::<DeviceError>()
Expand All @@ -248,7 +257,7 @@ mod tests {
let device = OpenCL::<Base>::new(chosen_cl_idx())?;
let mut cache = Cache::new();

let buf = construct_buffer(&device, no_drop, &mut cache.nodes, HashLocation::here());
let buf = construct_buffer(&device, no_drop, &mut cache.nodes, 0);
match buf
.expect_err("Missing error -> failure")
.downcast_ref::<DeviceError>()
Expand All @@ -269,8 +278,7 @@ mod tests {
let mut cache = Cache::new();

let (host_ptr, len) = (no_drop.data.ptr, no_drop.len());
let cl_host_ptr =
unsafe { to_cached_unified(&device, no_drop, &mut cache.nodes, HashLocation::here())? };
let cl_host_ptr = unsafe { to_cached_unified(&device, no_drop, &mut cache.nodes, 0)? };

let buf: Buffer<f32, OpenCL> = Buffer {
data: CLPtr {
Expand All @@ -296,8 +304,7 @@ mod tests {
let device = OpenCL::<Base>::new(chosen_cl_idx())?;
let mut cache = Cache::new();

let buf: Buffer<_, _> =
construct_buffer(&device, no_drop, &mut cache.nodes, HashLocation::here())?;
let buf: Buffer<_, _> = construct_buffer(&device, no_drop, &mut cache.nodes, 0)?;

assert_eq!(buf.read(), vec![1., 2.3, 0.76]);
assert_eq!(buf.as_slice(), &[1., 2.3, 0.76]);
Expand All @@ -308,7 +315,7 @@ mod tests {
#[cfg(unified_cl)]
#[test]
fn test_cpu_to_unified_is_reusing_converted_buf() -> crate::Result<()> {
use crate::{Base, Cached, HashLocation, Retriever};
use crate::{Base, Cached, Retriever};
use std::time::Instant;

let cl_dev = OpenCL::<Cached<Base>>::new(0)?;
Expand All @@ -326,7 +333,7 @@ mod tests {
&cl_dev,
buf,
&mut cl_dev.modules.cache.borrow_mut().nodes,
HashLocation::here(),
0,
)?;
dur += start.elapsed().as_secs_f64();

Expand Down
4 changes: 2 additions & 2 deletions src/exec_on_cpu/cl_may_unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
});

// convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer
return device.construct_unified_buf_from_cpu_buf(device, no_drop);
return device.construct_unified_buf_from_cpu_buf(device, no_drop, x);
/*return unsafe {
construct_buffer(device, no_drop, /*buf.node.idx*/ ())
};*/
Expand Down Expand Up @@ -117,7 +117,7 @@ where
);

// convert host ptr / CPU buffer into a host ptr + OpenCL ptr buffer
return device.construct_unified_buf_from_cpu_buf(device, no_drop);
return device.construct_unified_buf_from_cpu_buf(device, no_drop, (lhs, rhs));
/*return unsafe {
construct_buffer(device, no_drop, /*buf.node.idx*/ ())
};*/
Expand Down
10 changes: 6 additions & 4 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,11 @@ pub type CachedCPU = CPU<CachedModule<Base, CPU>>;
#[cfg(feature = "cached")]
pub trait UnifiedMemChain<D: Device> {
#[track_caller]
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape>(
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape, const N: usize>(
&self,
device: &'a D,
no_drop_buf: Buffer<'a, T, CachedCPU, S>,
parents: impl Parents<N>,
) -> crate::Result<Buffer<'a, T, D, S>>;
}

Expand All @@ -339,13 +340,14 @@ macro_rules! pass_down_unified_mem_chain {
($($to_impl:ident),*) => {
$(
impl<Mods: $crate::UnifiedMemChain<D>, D: Device> $crate::UnifiedMemChain<D> for $to_impl<Mods> {
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape>(
fn construct_unified_buf_from_cpu_buf<'a, T: 'static, S: Shape, const N: usize>(
&self,
device: &'a D,
no_drop_buf: Buffer<'a, T, $crate::CachedCPU, S>
no_drop_buf: Buffer<'a, T, $crate::CachedCPU, S>,
parents: impl $crate::Parents<N>
) -> $crate::Result<Buffer<'a, T, D, S>>
{
self.modules.construct_unified_buf_from_cpu_buf(device, no_drop_buf)
self.modules.construct_unified_buf_from_cpu_buf(device, no_drop_buf, parents)
}
}

Expand Down
Loading

0 comments on commit 31bd695

Please sign in to comment.