Skip to content

Commit

Permalink
Add shape requirement to Device::Base
Browse files: browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 16, 2023
1 parent 223c0ef commit 14c2f21
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ min-cl = { path="../min-cl", optional=true }

[features]
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork",]
default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork", ]

cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
Expand Down
2 changes: 1 addition & 1 deletion src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl<T> From<T> for Num<T> {

impl Device for () {
type Data<T, S: crate::Shape> = Self::Base<T, S>;
type Base<T, S> = Num<T>;
type Base<T, S: crate::Shape> = Num<T>;

type Error = Infallible;

Expand Down
2 changes: 1 addition & 1 deletion src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub mod cpu_stack_ops;
use crate::{Buffer, HasId, OnDropBuffer, PtrType, Shape};

pub trait Device: OnDropBuffer + Sized {
type Base<T, S>: HasId + PtrType;
type Base<T, S: Shape>: HasId + PtrType;
type Data<T, S: Shape>: HasId + PtrType;

type Error;
Expand Down
2 changes: 1 addition & 1 deletion src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<Mods> IsCPU for CPU<Mods> {}

impl<Mods: OnDropBuffer> Device for CPU<Mods> {
type Error = Infallible;
type Base<T, S> = CPUPtr<T>;
type Base<T, S: Shape> = CPUPtr<T>;
type Data<T, S: Shape> = Self::Wrap<T, Self::Base<T, S>>;
// type WrappedData<T, S: Shape> = ;

Expand Down
43 changes: 33 additions & 10 deletions src/devices/opencl/cl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::flag::AllocFlag;
use crate::{
impl_buffer_hook_traits, impl_retriever, pass_down_grad_fn, pass_down_optimize_mem_graph,
pass_down_tape_actions, pass_down_use_gpu_or_cpu, Alloc, Base, Buffer, Cached, CachedCPU,
CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup, CPU,
CloneBuf, Device, Module, OnDropBuffer, OnNewBuffer, Setup, CPU, WrappedData, impl_wrapped_data,
};
use crate::{PtrConv, Shape};

Expand Down Expand Up @@ -53,6 +53,8 @@ pub type CL = OpenCL;
}
}*/

impl_wrapped_data!(OpenCL);

impl<Mods> Deref for OpenCL<Mods> {
type Target = CLDevice;

Expand Down Expand Up @@ -152,13 +154,39 @@ impl<Mods> OpenCL<Mods> {
}

impl<Mods: OnDropBuffer> Device for OpenCL<Mods> {
type Data<U, S: Shape> = CLPtr<U>;
type Data<T, S: Shape> = Self::Wrap<T, Self::Base<T, S>>;
type Base<U, S: Shape> = CLPtr<U>;
type Error = ();

/// `Device::new` for `OpenCL<Mods>`.
///
/// Currently a stub: the body is `todo!()`, so calling this panics instead of
/// returning a `Result`.
fn new() -> Result<Self, Self::Error> {
todo!()
// NOTE(review): presumably the intended implementation, left disabled — confirm before restoring.
// OpenCL::<Base>::new(chosen_cl_idx())
}
/// Converts a `Self::Base<T, S>` value into `Self::Data<T, S>` by delegating
/// to `self.wrap_in_base`.
// NOTE(review): `wrap_in_base` is not defined in the visible code; presumably
// it comes from the `WrappedData` implementation for `OpenCL` — confirm.
#[inline(always)]
fn base_to_data<T, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
self.wrap_in_base(base)
}

/// Converts an already wrapped base value into `Self::Data<T, S>`.
///
/// The argument is returned unchanged; this only type-checks when
/// `Self::Data<T, S>` is the same type as `Self::Wrap<T, Self::Base<T, S>>`.
#[inline(always)]
fn wrap_to_data<T, S: Shape>(&self, wrap: Self::Wrap<T, Self::Base<T, S>>) -> Self::Data<T, S> {
wrap
}

/// Views a shared reference to `Self::Data<T, S>` as a reference to
/// `Self::Wrap<T, Self::Base<T, S>>`.
///
/// The reference is returned unchanged, with the same lifetime `'a`; no
/// conversion is performed in this body.
#[inline(always)]
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
data
}

/// Views a mutable reference to `Self::Data<T, S>` as a mutable reference to
/// `Self::Wrap<T, Self::Base<T, S>>`.
///
/// The reference is returned unchanged, with the same lifetime `'a`; no
/// conversion is performed in this body.
#[inline(always)]
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
data
}


}

impl<Mods: OnDropBuffer, OtherMods: OnDropBuffer> PtrConv<OpenCL<OtherMods>> for OpenCL<Mods> {
Expand All @@ -171,12 +199,7 @@ impl<Mods: OnDropBuffer, OtherMods: OnDropBuffer> PtrConv<OpenCL<OtherMods>> for
IS: Shape,
OS: Shape,
{
CLPtr {
ptr: ptr.ptr,
host_ptr: ptr.host_ptr.cast(),
len: ptr.len,
flag,
}
todo!()
}
}

Expand Down Expand Up @@ -246,10 +269,10 @@ impl<Mods: OnDropBuffer, T> Alloc<T> for OpenCL<Mods> {
}
}

impl<'a, T, Mods: OnDropBuffer + OnNewBuffer<T, Self>> CloneBuf<'a, T> for OpenCL<Mods> {
impl<'a, T, Mods: OnDropBuffer + OnNewBuffer<T, Self, ()>> CloneBuf<'a, T> for OpenCL<Mods> {
fn clone_buf(&'a self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> {
let cloned = Buffer::new(self, buf.len());
enqueue_full_copy_buffer::<T>(self.queue(), buf.data.ptr, cloned.data.ptr, buf.len())
enqueue_full_copy_buffer::<T>(self.queue(), buf.base().ptr, cloned.base().ptr, buf.len())
.unwrap();
cloned
}
Expand Down

0 comments on commit 14c2f21

Please sign in to comment.