diff --git a/Cargo.toml b/Cargo.toml index d4042b53..1d8038bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ min-cl = { path="../min-cl", optional=true } [features] # default = ["cpu", "autograd", "macro"] -default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork",] +default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork", ] cpu = [] opencl = ["dep:min-cl", "cpu", "cached"] diff --git a/src/buffer/num.rs b/src/buffer/num.rs index c4446344..666a8833 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -59,7 +59,7 @@ impl From for Num { impl Device for () { type Data = Self::Base; - type Base = Num; + type Base = Num; type Error = Infallible; diff --git a/src/devices.rs b/src/devices.rs index d57c096c..fe3e8d4c 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -41,7 +41,7 @@ pub mod cpu_stack_ops; use crate::{Buffer, HasId, OnDropBuffer, PtrType, Shape}; pub trait Device: OnDropBuffer + Sized { - type Base: HasId + PtrType; + type Base: HasId + PtrType; type Data: HasId + PtrType; type Error; diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 61603d32..35820d0b 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -36,7 +36,7 @@ impl IsCPU for CPU {} impl Device for CPU { type Error = Infallible; - type Base = CPUPtr; + type Base = CPUPtr; type Data = Self::Wrap>; // type WrappedData = ; diff --git a/src/devices/opencl/cl_device.rs b/src/devices/opencl/cl_device.rs index a3335e71..cc7cbe10 100644 --- a/src/devices/opencl/cl_device.rs +++ b/src/devices/opencl/cl_device.rs @@ -7,7 +7,7 @@ use crate::flag::AllocFlag; use crate::{ impl_buffer_hook_traits, impl_retriever, pass_down_grad_fn, pass_down_optimize_mem_graph, pass_down_tape_actions, pass_down_use_gpu_or_cpu, Alloc, Base, Buffer, Cached, CachedCPU, - CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup, CPU, + CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup, CPU, WrappedData, impl_wrapped_data, }; use crate::{PtrConv, Shape}; @@ -53,6 +53,8 @@ pub type CL = OpenCL; } }*/ +impl_wrapped_data!(OpenCL); + impl Deref for OpenCL { type Target = CLDevice; @@ -152,13 +154,39 @@ impl OpenCL { } impl Device for OpenCL { - type Data = CLPtr; + type Data = Self::Wrap>; + type Base = CLPtr; type Error = (); fn new() -> Result { todo!() // OpenCL::::new(chosen_cl_idx()) } + #[inline(always)] + fn base_to_data(&self, base: Self::Base) -> Self::Data { + self.wrap_in_base(base) + } + + #[inline(always)] + fn wrap_to_data(&self, wrap: Self::Wrap>) -> Self::Data { + wrap + } + + #[inline(always)] + fn data_as_wrap<'a, T, S: Shape>( + data: &'a Self::Data, + ) -> &'a Self::Wrap> { + data + } + + #[inline(always)] + fn data_as_wrap_mut<'a, T, S: Shape>( + data: &'a mut Self::Data, + ) -> &'a mut Self::Wrap> { + data + } + + } impl PtrConv> for OpenCL { @@ -171,12 +199,7 @@ impl PtrConv> for IS: Shape, OS: Shape, { - CLPtr { - ptr: ptr.ptr, - host_ptr: ptr.host_ptr.cast(), - len: ptr.len, - flag, - } + todo!() } } @@ -246,10 +269,10 @@ impl Alloc for OpenCL { } } -impl<'a, T, Mods: OnDropBuffer + OnNewBuffer> CloneBuf<'a, T> for OpenCL { +impl<'a, T, Mods: OnDropBuffer + OnNewBuffer> CloneBuf<'a, T> for OpenCL { fn clone_buf(&'a self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> { let cloned = Buffer::new(self, buf.len()); - enqueue_full_copy_buffer::(self.queue(), buf.data.ptr, cloned.data.ptr, buf.len()) + enqueue_full_copy_buffer::(self.queue(), buf.base().ptr, cloned.base().ptr, buf.len()) .unwrap(); cloned }