From 2bb76726040fa5fca26194fe880d2ecb8ea122f9 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 17 Nov 2024 22:50:02 +0100 Subject: [PATCH] Add ToBase, ToDim, remove unsafe from retrieve, .. --- Cargo.toml | 2 +- README.md | 10 ++-- examples/custom_module.rs | 19 +++--- examples/modules_usage.rs | 4 +- src/buffer.rs | 45 ++++++-------- src/cache/borrow_cache.rs | 17 +++--- src/cache/locking/guard.rs | 23 ++++++-- src/cache/owned_cache/fast_cache2.rs | 8 ++- src/devices/cpu/cpu_ptr.rs | 34 +++++++---- src/devices/stack_array.rs | 11 +--- src/features.rs | 9 ++- src/lib.rs | 19 +++--- src/modules/autograd.rs | 87 ++++++++++++++++++++-------- src/modules/autograd/gradients.rs | 3 +- src/modules/autograd/wrapper.rs | 50 ++++++++-------- src/modules/base.rs | 4 +- src/modules/cached.rs | 25 +++++--- src/modules/lazy.rs | 36 +++++++----- src/modules/lazy/wrapper.rs | 50 ++++++++++------ src/modules/mod.rs | 46 ++++----------- src/shape.rs | 40 ++----------- src/unary.rs | 6 +- 22 files changed, 292 insertions(+), 256 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 08373c4f..1a947d98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"] -default = ["cpu", "cached"] +default = ["cpu", "cached", "autograd"] # default = ["no-std"] # default = ["opencl"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"] diff --git a/README.md b/README.md index 76899b0b..1ff1898e 100644 --- a/README.md +++ b/README.md @@ -93,19 +93,19 @@ This operation is only affected by the `Cached` module (and partially `Autograd` use custos::prelude::*; use std::ops::{Deref, Mul}; -pub trait MulBuf: Sized + Device { - fn mul(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer; +pub trait MulBuf<'a, T: Unit, S: Shape = (), D: Device = Self>: Sized + Device { + fn mul(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S>; } -impl MulBuf for CPU +impl<'a, Mods, T, S, D> MulBuf<'a, T, S, D> for CPU where - Mods: Retrieve, + Mods: Retrieve<'a, Self, T, S>, T: Unit + Mul + Copy + 'static, S: Shape, D: Device, D::Base: Deref, { - fn mul(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { + fn mul(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S> { let mut out = self.retrieve(lhs.len(), (lhs, rhs)).unwrap(); // unwrap or return error (update trait) for ((lhs, rhs), out) in lhs.iter().zip(rhs.iter()).zip(&mut out) { diff --git a/examples/custom_module.rs b/examples/custom_module.rs index 72bd8a4c..e0bd30a4 100644 --- a/examples/custom_module.rs +++ b/examples/custom_module.rs @@ -1,5 +1,6 @@ use custos::{ - Alloc, Base, Device, HasId, IsBasePtr, Module, OnDropBuffer, Parents, PtrType, Retrieve, Setup, Shape, Unit, WrappedData, CPU + Alloc, Base, Device, HasId, IsBasePtr, Module, OnDropBuffer, Parents, PtrType, Retrieve, Setup, + Shape, Unit, WrappedData, CPU, }; pub struct CustomModule { @@ -43,7 +44,9 @@ impl WrappedData for CustomModule { } #[inline] - fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base { + fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base { Mods::wrapped_as_base(wrap) } @@ -88,19 +91,19 @@ where self.mods.retrieve_entry(device, len, parents) } - fn on_retrieve_finish(&self, + fn on_retrieve_finish( + &self, len: usize, parents: impl Parents, - retrieved_buf: &custos::prelude::Buffer - ) - where + retrieved_buf: &custos::prelude::Buffer, + ) where D: Alloc, { // inject custom behaviour in this body self.mods.on_retrieve_finish(len, parents, retrieved_buf) } - + unsafe fn retrieve( &self, device: &D, @@ -109,7 +112,7 @@ where ) -> custos::Result::Base>> where S: Shape, - D: Alloc + D: Alloc, { self.mods.retrieve(device, len, parents) } diff --git a/examples/modules_usage.rs b/examples/modules_usage.rs index e667bba6..5d6dcc2b 100644 --- a/examples/modules_usage.rs +++ b/examples/modules_usage.rs @@ -1,8 +1,8 @@ use std::ops::{Add, AddAssign, Deref, DerefMut, Mul}; use custos::{ - AddGradFn, AddOperation, Alloc, Buffer, Device, MayGradActions, - Retrieve, Retriever, Shape, Unit, ZeroGrad, CPU, + AddGradFn, AddOperation, Alloc, Buffer, Device, MayGradActions, Retrieve, Retriever, Shape, + Unit, ZeroGrad, CPU, }; pub trait ElementWise<'a, T: Unit, D: Device, S: Shape>: Device { diff --git a/src/buffer.rs b/src/buffer.rs index e352f912..4a9fd39e 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -11,8 +11,8 @@ use crate::CPU; use crate::{ flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId, - IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit, - WrappedCopy, WrappedData, WriteBuf, ZeroGrad, + IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, ToDim, Unit, + WrappedData, WriteBuf, ZeroGrad, }; pub use self::num::Num; @@ -136,7 +136,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { // #[inline] // fn id(&self) -> super::Id { // self.data.id() -// } +// }i // } impl<'a, T: Unit, D: Device, S: Shape> HasId for Buffer<'a, T, D, S> { @@ -231,7 +231,7 @@ impl<'a, T: Unit, D: Device + OnNewBuffer<'a, T, D, S>, S: Shape> Buffer<'a, T, /// Creates a new `Buffer` from an nd-array. /// The dimension is defined by the [`Shape`]. #[inline] - pub fn from_array(device: &'a D, array: S::ARR) -> Buffer + pub fn from_array(device: &'a D, array: S::ARR) -> Buffer<'a, T, D, S> where T: Clone, D: Alloc, @@ -271,21 +271,24 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { } #[inline] - pub fn to_deviceless<'b>(self) -> Buffer<'b, T, D, S> + pub fn to_deviceless<'b>(mut self) -> Buffer<'b, T, D, S> where D::Data<'b, T, S>: Default, + D::Base: ShallowCopy, { if let Some(device) = self.device { if self.data.flag() != AllocFlag::None { device.on_drop_buffer(device, &self) } } - todo!() - // let mut val = ManuallyDrop::new(self); - // let data = core::mem::take(&mut val.data); + unsafe { self.set_flag(AllocFlag::Wrapper) }; + let mut base = unsafe { self.base().shallow() }; + unsafe { base.set_flag(AllocFlag::None) }; + + let data: ::Data<'b, T, S> = self.device().base_to_data::(base); - // Buffer { data, device: None } + Buffer { data, device: None } } /// Returns the device of the `Buffer`. @@ -460,7 +463,6 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { } } -// TODO better solution for the to_dims stack problem? impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { /// Converts a non stack allocated `Buffer` with shape `S` to a `Buffer` with shape `O`. /// # Example @@ -474,24 +476,15 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { /// /// ``` #[inline] - pub fn to_dims(self) -> Buffer<'a, T, D, O> + pub fn to_dims(mut self) -> Buffer<'a, T, D, O> where - // D: crate::ToDim, - D::Data<'a, T, S>: WrappedCopy>, - D::Base: ShallowCopy, + D::Data<'a, T, S>: Default + ToDim>, { - let base = unsafe { (*self).shallow() }; - let data = self.data.wrapped_copy(base); - let buf = ManuallyDrop::new(self); - - // let mut data = buf.device().to_dim(data); - // unsafe { data.set_flag(AllocFlag::None) }; - todo!() - - // Buffer { - // data, - // device: buf.device, - // } + let data = std::mem::take(&mut self.data).to_dim(); + Buffer { + data, + device: self.device, + } } } diff --git a/src/cache/borrow_cache.rs b/src/cache/borrow_cache.rs index ad3c13e2..3aec8c3a 100644 --- a/src/cache/borrow_cache.rs +++ b/src/cache/borrow_cache.rs @@ -122,22 +122,21 @@ impl BorrowCache { } #[inline] - pub fn get_buf( - &self, + pub fn get_buf<'a, 'b, T, D, S>( + &'a self, _device: &D, id: Id, - ) -> Result<&Buffer<'_, T, D, S>, CachingError> + ) -> Result<&'a Buffer<'b, T, D, S>, CachingError> where T: Unit + 'static, D: Device + 'static, S: Shape, { - todo!() - // self.cache - // .get(&id) - // .ok_or(CachingError::InvalidId)? - // .downcast_ref() - // .ok_or(CachingError::InvalidTypeInfo) + let out = self.cache.get(&id).ok_or(CachingError::InvalidId)?; + if !out.is::>() { + return Err(CachingError::InvalidTypeInfo); + } + Ok(unsafe { out.downcast_ref_unchecked::>() }) } #[inline] diff --git a/src/cache/locking/guard.rs b/src/cache/locking/guard.rs index 049d224f..76328ca8 100644 --- a/src/cache/locking/guard.rs +++ b/src/cache/locking/guard.rs @@ -1,9 +1,6 @@ -use core::{ - mem::ManuallyDrop, - ops::{Deref, DerefMut}, -}; +use core::ops::{Deref, DerefMut}; -use crate::{CowMutCell, HasId, HostPtr, PtrType, ShallowCopy}; +use crate::{CowMutCell, HasId, HostPtr, PtrType, ShallowCopy, ToDim}; #[derive(Debug)] pub struct Guard<'a, T> { @@ -24,7 +21,7 @@ impl<'a, T> Guard<'a, T> { Guard { data: f(data) } } - #[inline] + #[inline] pub fn make_static(self) -> Option> { match self.data { CowMutCell::Borrowed(_) => None, @@ -91,3 +88,17 @@ impl<'a, T, P: PtrType + HostPtr> HostPtr for Guard<'a, P> { self.data.get_mut().unwrap().ptr_mut() } } + +impl<'a, P> ToDim for Guard<'a, P> { + type Out = Self; + + #[inline] + fn to_dim(self) -> Self::Out { + self + } + + #[inline] + fn as_dim(&self) -> &Self::Out { + self + } +} diff --git a/src/cache/owned_cache/fast_cache2.rs b/src/cache/owned_cache/fast_cache2.rs index 0b5a5a9c..585c7bad 100644 --- a/src/cache/owned_cache/fast_cache2.rs +++ b/src/cache/owned_cache/fast_cache2.rs @@ -1,5 +1,9 @@ -use core::{any::Any, cell::{Ref, RefMut}, hash::BuildHasherDefault}; use crate::{LockedMap, NoHasher, State, UniqueId}; +use core::{ + any::Any, + cell::{Ref, RefMut}, + hash::BuildHasherDefault, +}; use super::Cache; @@ -18,7 +22,7 @@ impl Cache> for FastCache2 { fn insert(&self, id: UniqueId, _len: usize, data: Box) { self.nodes.insert(id, data); } - + #[inline] fn get(&self, id: UniqueId, _len: usize) -> State>> { self.nodes.get(&id) diff --git a/src/devices/cpu/cpu_ptr.rs b/src/devices/cpu/cpu_ptr.rs index 1d546907..dcb8cffd 100644 --- a/src/devices/cpu/cpu_ptr.rs +++ b/src/devices/cpu/cpu_ptr.rs @@ -7,7 +7,9 @@ use core::{ use std::alloc::handle_alloc_error; -use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, Unit, WrappedCopy}; +use crate::{ + flag::AllocFlag, Device, HasId, HostPtr, Id, PtrType, ShallowCopy, Shape, ToBase, ToDim, Unit, +}; /// The pointer used for `CPU` [`Buffer`](crate::Buffer)s #[derive(Debug)] @@ -229,15 +231,6 @@ impl PtrType for CPUPtr { } } -impl WrappedCopy for CPUPtr { - type Base = Self; - - #[inline] - fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { - to_wrap - } -} - impl ShallowCopy for CPUPtr { #[inline] unsafe fn shallow(&self) -> Self { @@ -303,6 +296,27 @@ impl Drop for DeallocWithLayout { } } +impl ToDim for CPUPtr { + type Out = Self; + + #[inline] + fn to_dim(self) -> Self::Out { + self + } + + #[inline] + fn as_dim(&self) -> &Self::Out { + self + } +} + +impl = CPUPtr>, S: Shape> ToBase for CPUPtr { + #[inline] + fn to_base(self) -> D::Base { + self + } +} + #[cfg(feature = "serde")] pub mod serde { use core::{fmt, marker::PhantomData}; diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index 72274aa7..f16ddabf 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -1,6 +1,6 @@ use core::ops::{Deref, DerefMut}; -use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, Unit, WrappedCopy}; +use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, Unit}; /// A possibly multi-dimensional array allocated on the stack. /// It uses `S:`[`Shape`] to get the type of the array. @@ -137,15 +137,6 @@ impl HostPtr for StackArray { } } -impl WrappedCopy for StackArray { - type Base = Self; - - #[inline] - fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { - to_wrap - } -} - impl ShallowCopy for StackArray where S::ARR: Copy, diff --git a/src/features.rs b/src/features.rs index 7213cd8c..83eec0f7 100644 --- a/src/features.rs +++ b/src/features.rs @@ -27,7 +27,7 @@ pub trait Feature: OnDropBuffer {} pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer { // "generator" #[track_caller] - unsafe fn retrieve_entry( + fn retrieve_entry( &'a self, device: &D, len: usize, @@ -38,7 +38,7 @@ pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer { D: Alloc; #[track_caller] - unsafe fn retrieve( + fn retrieve( &self, device: &D, len: usize, @@ -684,6 +684,11 @@ pub trait CachedBuffers { ) -> Option>>> { None } + + #[inline] + fn is_supplied_from_below_module(&self) -> bool { + false + } } #[macro_export] diff --git a/src/lib.rs b/src/lib.rs index c0a88966..363003c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,19 +29,19 @@ //! use custos::prelude::*; //! use std::ops::{Deref, Mul}; //! -//! pub trait MulBuf: Sized + Device { -//! fn mul(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer; +//! pub trait MulBuf<'a, T: Unit, S: Shape = (), D: Device = Self>: Sized + Device { +//! fn mul(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S>; //! } //! -//! impl MulBuf for CPU +//! impl<'a, Mods, T, S, D> MulBuf<'a, T, S, D> for CPU //! where -//! Mods: Retrieve, +//! Mods: Retrieve<'a, Self, T, S>, //! T: Unit + Mul + Copy + 'static, //! S: Shape, //! D: Device, //! D::Base: Deref, //! { -//! fn mul(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { +//! fn mul(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S> { //! let mut out = self.retrieve(lhs.len(), (lhs, rhs)).unwrap(); // unwrap or return error (update trait) //! //! for ((lhs, rhs), out) in lhs.iter().zip(rhs.iter()).zip(&mut out) { @@ -185,9 +185,8 @@ pub trait Unit: 'static {} // useful for Sync and Send or 'static impl Unit for T {} -pub trait WrappedCopy { - type Base; - fn wrapped_copy(&self, to_wrap: Self::Base) -> Self; +pub trait ToBase { + fn to_base(self) -> D::Base; } /// Used to shallow-copy a pointer. Use is discouraged. @@ -281,8 +280,8 @@ pub mod tests_helper { use crate::{Buffer, Device, Number, Shape, Unit}; - pub trait AddEw: Device { - fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer; + pub trait AddEw<'a, T: Unit, D: Device, S: Shape>: Device { + fn add(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S>; } pub fn add_ew_slice + Copy>(lhs: &[T], rhs: &[T], out: &mut [T]) { diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index e8ffd7a7..96b550db 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -13,9 +13,9 @@ use core::{ use crate::{ impl_remove_layer, pass_down_add_operation, pass_down_cached_buffers, pass_down_cursor, pass_down_exec_now_module, pass_down_replace_buf_module, register_buf_copyable, - unregister_buf_copyable, AddGradFn, AddLayer, Alloc, Buffer, Device, GradActions, HasId, - HasModules, IsShapeIndep, Module, OnDropBuffer, OnNewBuffer, Parents, Retrieve, RunModule, - Setup, ShallowCopy, Shape, TapeActions, Unit, + unregister_buf_copyable, AddGradFn, AddLayer, Alloc, Buffer, CachedBuffers, Device, + GradActions, HasId, HasModules, IsShapeIndep, Module, OnDropBuffer, OnNewBuffer, Parents, + Retrieve, RunModule, Setup, ShallowCopy, Shape, TapeActions, Unit, WrappedData, }; use self::wrapper::ReqGradWrapper; @@ -55,9 +55,14 @@ impl<'dev, Mods> Autograd<'dev, Mods> { where T: Unit + 'static, D: Device + IsShapeIndep + 'static, - D::Data<'a, T, S>: ShallowCopy, + D::Data<'static, T, S>: ShallowCopy, + D::Base: ShallowCopy, S: Shape, + Mods: CachedBuffers, { + if self.modules.is_supplied_from_below_module() { + return; + } let no_grads_pool = unsafe { &mut (*self.grads.get()).no_grads_pool }; if no_grads_pool.get(&buf.id()).is_some() { @@ -72,8 +77,9 @@ impl<'dev, T, D, Mods, S: Shape> OnNewBuffer<'dev, T, D, S> for Autograd<'_, Mod where T: Unit + 'static, D: Alloc + IsShapeIndep + 'static, - D::Data<'dev, T, S>: ShallowCopy, - Mods: OnNewBuffer<'dev, T, D, S>, + D::Data<'static, T, S>: ShallowCopy, + D::Base: ShallowCopy, + Mods: OnNewBuffer<'dev, T, D, S> + CachedBuffers, { #[inline] unsafe fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { @@ -121,24 +127,24 @@ impl<'dev, Mods: Setup, NewDev> Setup for Autograd<'dev, Mods> { } } -impl<'dev, 'a, T, Mods: Retrieve<'a, D, T, S>, D, S: Shape> Retrieve<'a, D, T, S> for Autograd<'dev, Mods> +impl<'dev, Mods> Autograd<'dev, Mods> where - T: Unit + 'static, - D: IsShapeIndep + Device + 'static, - D::Data<'a, T, S>: ShallowCopy, + Mods: crate::WrappedData, { - #[inline] - unsafe fn retrieve( + fn retrieve_inner<'a, D, T, S, const NUM_PARENTS: usize>( &self, - device: &D, - len: usize, + _device: &D, + _len: usize, parents: &impl Parents, - ) -> crate::Result>> + retrieve_cb: impl Fn() -> crate::Result>>, + ) -> crate::Result<::Wrap<'a, T, D::Base>> where - D: Alloc, + D: Device, + T: Unit, + S: Shape, { let requires_grad = parents.requires_grads().iter().any(|&x| x); - let data = self.modules.retrieve(device, len, parents)?; + let data = retrieve_cb()?; unsafe { (*self.grads.get()) .buf_requires_grad @@ -151,7 +157,32 @@ where _pd: core::marker::PhantomData, }) } - +} + +impl<'dev, 'a, T, Mods: Retrieve<'a, D, T, S>, D, S: Shape> Retrieve<'a, D, T, S> + for Autograd<'dev, Mods> +where + T: Unit + 'static, + D: IsShapeIndep + Device + 'static, + D::Data<'static, T, S>: ShallowCopy, + D::Base: ShallowCopy, + Mods: CachedBuffers, +{ + #[inline] + fn retrieve( + &self, + device: &D, + len: usize, + parents: &impl Parents, + ) -> crate::Result>> + where + D: Alloc, + { + self.retrieve_inner(device, len, parents, || { + self.modules.retrieve(device, len, parents) + }) + } + #[inline] fn on_retrieve_finish( &self, @@ -165,8 +196,9 @@ where self.modules.on_retrieve_finish(len, parents, retrieved_buf) } - - unsafe fn retrieve_entry( + + #[inline] + fn retrieve_entry( &'a self, device: &D, len: usize, @@ -174,9 +206,11 @@ where ) -> crate::Result::Base>> where S: Shape, - D: Alloc + D: Alloc, { - todo!() + self.retrieve_inner(device, len, parents, || { + self.modules.retrieve_entry(device, len, parents) + }) } } @@ -302,7 +336,7 @@ impl<'a, Mods> HasModules for Autograd<'a, Mods> { mod tests { use crate::{ AddGradFn, Autograd, Base, BoxedShallowCopy, Buffer, Cached, Combiner, Cursor, Device, - HasId, Lazy, Retriever, Shape, UnaryGrad, Unit, CPU, + Downcast, HasId, Lazy, Retriever, Shape, UnaryGrad, Unit, CPU, }; #[inline] @@ -310,8 +344,11 @@ mod tests { buf_any: &'b Box, _device: &'a D, ) -> Option<&'b Buffer<'a, T, D, S>> { - todo!() - // buf_any.as_any().downcast_ref::>() + let any = buf_any.as_any(); + if !any.is::>() { + return None; + } + Some(unsafe { Downcast::downcast_ref_unchecked::>(any) }) } #[test] diff --git a/src/modules/autograd/gradients.rs b/src/modules/autograd/gradients.rs index a73fd926..c0d3bb19 100644 --- a/src/modules/autograd/gradients.rs +++ b/src/modules/autograd/gradients.rs @@ -135,7 +135,7 @@ impl Gradients { self.get_ref(buf.device(), buf.id()) } - #[inline] + /*#[inline] pub fn get_buf_from_no_grad_pool<'a, T, S, D>(&self, id: Id) -> &Buffer<'a, T, D, S> where T: Unit + 'static, @@ -151,6 +151,7 @@ impl Gradients { .ok_or(CachingError::InvalidTypeInfo) .expect(INVALID_ID) } + */ } #[cfg(test)] diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 34da1a5c..2fb058a4 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -1,7 +1,8 @@ use core::marker::PhantomData; use crate::{ - flag::AllocFlag, Autograd, HasId, IsBasePtr, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData + flag::AllocFlag, Autograd, Device, HasId, IsBasePtr, PtrType, ShallowCopy, Shape, ToBase, + ToDim, Unit, WrappedData, }; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -12,14 +13,10 @@ pub struct ReqGradWrapper { } impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { - type Wrap<'a, T: Unit, Base: IsBasePtr> = - ReqGradWrapper, T>; + type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper, T>; #[inline] - fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>( - &self, - base: Base, - ) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { ReqGradWrapper { // by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change requires_grad: true, @@ -77,22 +74,6 @@ impl PtrType for ReqGradWrapper { } } -impl WrappedCopy for ReqGradWrapper -where - Data: WrappedCopy, -{ - type Base = T; - - #[inline] - fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { - Self { - requires_grad: self.requires_grad, - data: self.data.wrapped_copy(to_wrap), - _pd: PhantomData, - } - } -} - impl ShallowCopy for ReqGradWrapper where Data: ShallowCopy, @@ -105,3 +86,26 @@ where } } } + +impl, T1, D: Device> ToBase + for ReqGradWrapper +{ + #[inline] + fn to_base(self) -> ::Base { + self.data.to_base() + } +} + +impl ToDim for ReqGradWrapper { + type Out = Self; + + #[inline] + fn to_dim(self) -> Self::Out { + self + } + + #[inline] + fn as_dim(&self) -> &Self::Out { + self + } +} diff --git a/src/modules/base.rs b/src/modules/base.rs index fbe8d52a..f7e3915c 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -87,7 +87,7 @@ impl OnDropBuffer for Base {} impl<'a, D, T: Unit, S: Shape> Retrieve<'a, D, T, S> for Base { #[inline] - unsafe fn retrieve_entry( + fn retrieve_entry( &'a self, device: &D, len: usize, @@ -100,7 +100,7 @@ impl<'a, D, T: Unit, S: Shape> Retrieve<'a, D, T, S> for Base { } #[inline] - unsafe fn retrieve( + fn retrieve( &self, device: &D, len: usize, diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 254a369f..f1f9112b 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -1,9 +1,14 @@ use core::{ - any::Any, cell::{Cell, RefMut}, marker::PhantomData + any::Any, + cell::{Cell, RefMut}, + marker::PhantomData, }; use crate::{ - AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData + AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, + ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, + OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, + SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, }; #[cfg(feature = "graph")] @@ -163,10 +168,11 @@ where S: Shape, { let entry = self.cache.get_mut(id, len)?; - let entry = RefMut::map(entry, |x| { + let mut entry = RefMut::map(entry, |x| { x.downcast_mut::>>() .unwrap() }); + unsafe { entry.set_flag(crate::flag::AllocFlag::BorrowedCache) }; Ok(Guard::new(CowMut::BorrowedMut(entry))) } } @@ -183,7 +189,7 @@ where CacheType: Cache>, { #[inline] - unsafe fn retrieve_entry( + fn retrieve_entry( &'a self, device: &D, len: usize, @@ -197,13 +203,16 @@ where Ok(out) => { unsafe { device.bump_cursor() }; Ok(out) - }, + } Err(state) => match state { // return err LockInfo::Locked => panic!("Locked!!"), LockInfo::None => { - self.cache - .insert(id, len, Box::new(self.modules.retrieve(device, len, _parents)?)); + self.cache.insert( + id, + len, + Box::new(self.modules.retrieve(device, len, _parents)?), + ); unsafe { device.bump_cursor() }; Ok(self.get_mut::(id, len).unwrap()) @@ -232,7 +241,7 @@ where self.modules.on_retrieve_finish(len, parents, retrieved_buf) } - unsafe fn retrieve( + fn retrieve( &self, _device: &D, _len: usize, diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 86409b58..a77f08c4 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -201,7 +201,8 @@ impl<'a, T, D, Mods, S, T2> OnNewBuffer<'a, T, D, S> for Lazy<'_, Mods, T2> where T: Unit + 'static, D: Device + IsShapeIndep + 'static, - D::Data<'a, T, S>: ShallowCopy, + D::Data<'static, T, S>: ShallowCopy, + D::Base: ShallowCopy, Mods: OnNewBuffer<'a, T, D, S>, S: Shape, { @@ -307,11 +308,11 @@ where T: Unit + 'static, Mods: Retrieve<'a, D, T, S>, D: IsShapeIndep + 'static, - D::Data<'a, T, S>: ShallowCopy, + D::Data<'static, T, S>: ShallowCopy, S: Shape, { #[inline] - unsafe fn retrieve( + fn retrieve( &self, _device: &D, len: usize, @@ -369,17 +370,21 @@ where } #[inline] - fn on_retrieve_finish(&self, retrieved_buf: &Buffer) - where + fn on_retrieve_finish( + &self, + len: usize, + parents: impl Parents, + retrieved_buf: &Buffer, + ) where D: Alloc, { // unsafe { register_buf(&mut self.buffers.borrow_mut(), retrieved_buf) }; // pass down - self.modules.on_retrieve_finish(retrieved_buf) + self.modules.on_retrieve_finish(len, parents, retrieved_buf) } - - unsafe fn retrieve_entry( + + fn retrieve_entry( &'a self, device: &D, len: usize, @@ -387,9 +392,9 @@ where ) -> crate::Result::Base>> where S: Shape, - D: Alloc + D: Alloc, { - todo!() + self.retrieve(device, len, parents) } } @@ -492,6 +497,11 @@ impl CachedBuffers for Lazy<'_, Mods, T> { ) -> Option>>> { Some(self.buffers.borrow_mut()) } + + #[inline] + fn is_supplied_from_below_module(&self) -> bool { + true + } } impl HasModules for Lazy<'_, Mods> { @@ -552,16 +562,16 @@ mod tests { } #[cfg(feature = "cpu")] - impl AddEw for CPU + impl<'a, T, D, S, Mods> AddEw<'a, T, D, S> for CPU where T: Unit + Add + Copy + 'static, D: Device + 'static, D::Base: Deref, S: Shape, - Mods: AddOperation + Retrieve + 'static, + Mods: AddOperation + Retrieve<'a, Self, T, S> + 'static, { #[inline] - fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { + fn add(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S> { let mut out = self.retrieve(lhs.len(), ()).unwrap(); self.add_op((lhs, rhs, &mut out), |(lhs, rhs, out)| { add_ew_slice(lhs, rhs, out.as_mut_slice()); diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index b3ee80b4..5bbd81f5 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -7,7 +7,8 @@ use core::{ }; use crate::{ - flag::AllocFlag, HasId, HostPtr, IsBasePtr, Lazy, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData + flag::AllocFlag, Device, HasId, HostPtr, IsBasePtr, Lazy, PtrType, ShallowCopy, Shape, ToBase, + ToDim, Unit, WrappedData, }; #[derive(Debug, Default)] @@ -28,7 +29,9 @@ impl WrappedData for Lazy<'_, Mods, T2> { } #[inline] - fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base { + fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( + wrap: &'b Self::Wrap<'a, T, Base>, + ) -> &'b Base { Mods::wrapped_as_base(wrap.maybe_data.data().expect(MISSING_DATA)) } @@ -106,16 +109,12 @@ impl> HostPtr for LazyWrapper { } } -impl WrappedCopy for LazyWrapper -where - Data: WrappedCopy, -{ - type Base = T; - - fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { +impl ShallowCopy for LazyWrapper { + #[inline] + unsafe fn shallow(&self) -> Self { LazyWrapper { maybe_data: match &self.maybe_data { - MaybeData::Data(data) => MaybeData::Data(data.wrapped_copy(to_wrap)), + MaybeData::Data(data) => MaybeData::Data(data.shallow()), MaybeData::Id(id) => MaybeData::Id(*id), MaybeData::None => unimplemented!(), }, @@ -124,16 +123,29 @@ where } } -impl ShallowCopy for LazyWrapper { +impl, T1, D: Device> ToBase + for LazyWrapper +{ #[inline] - unsafe fn shallow(&self) -> Self { - LazyWrapper { - maybe_data: match &self.maybe_data { - MaybeData::Data(data) => MaybeData::Data(data.shallow()), - MaybeData::Id(id) => MaybeData::Id(*id), - MaybeData::None => unimplemented!(), - }, - _pd: PhantomData, + fn to_base(self) -> ::Base { + match self.maybe_data { + MaybeData::Data(data) => data.to_base(), + MaybeData::Id(_id) => unimplemented!("Cannot convert id wrapper to base"), + MaybeData::None => unimplemented!("Cannot convert nothin to base"), } } } + +impl ToDim for LazyWrapper { + type Out = Self; + + #[inline] + fn to_dim(self) -> Self::Out { + self + } + + #[inline] + fn as_dim(&self) -> &Self::Out { + self + } +} diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 8cfb6303..48682995 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -94,18 +94,15 @@ pub(crate) unsafe fn register_buf_any<'a, T, D, S>( T: crate::Unit + 'static, D: Device + crate::IsShapeIndep + 'static, D::Data<'a, T, S>: ShallowCopy, + D::Base: ShallowCopy, S: Shape, { // shallow copy sets flag to AllocFlag::Wrapper + let wrapped_data = unsafe { buf.base().shallow() }; + let data: ::Data<'static, T, S> = buf.device().base_to_data::(wrapped_data); - let wrapped_data = unsafe { buf.data.shallow() }; - - let buf: Buffer = Buffer { - data: wrapped_data, - device: None, - }; - todo!() - // cache.insert(*buf.id(), Box::new(buf)); + let buf: Buffer<'static, T, D, S> = Buffer { data, device: None }; + cache.insert(*buf.id(), Box::new(buf)); } #[cfg(feature = "std")] @@ -123,43 +120,20 @@ pub(crate) fn unregister_buf_any( #[allow(unused)] pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>( cache: &mut HashMap, impl BuildHasher>, - buf: &Buffer<'a, T, D, S>, -) where - T: crate::Unit + 'static, - D: Device + crate::IsShapeIndep + 'static, - D::Data<'a, T, S>: ShallowCopy, - D::Base: ShallowCopy, - S: Shape, -{ - // shallow copy sets flag to AllocFlag::Wrapper - let wrapped_data = unsafe { buf.data.shallow() }; - - let buf: Buffer = Buffer { - data: wrapped_data, - device: None, - }; - todo!() - // cache.insert(*buf.id(), Box::new(buf)); -} - -pub(crate) unsafe fn register_buf_copyable2<'a, T, D, S>( - cache: &mut HashMap, impl BuildHasher>, - buf: &Buffer<'a, T, D, S>, + buf: &Buffer, ) where T: crate::Unit + 'static, D: Device + crate::IsShapeIndep + 'static, D::Base: ShallowCopy, + D::Data<'static, T, S>: ShallowCopy, S: Shape, { // shallow copy sets flag to AllocFlag::Wrapper let wrapped_data = unsafe { buf.base().shallow() }; + let data: ::Data<'static, T, S> = buf.device().base_to_data::(wrapped_data); - // let buf: Buffer = Buffer { - // data: wrapped_data, - // device: None, - // }; - todo!() - // cache.insert(*buf.id(), Box::new(buf)); + let buf: Buffer<'static, T, D, S> = Buffer { data, device: None }; + cache.insert(*buf.id(), Box::new(buf)); } #[cfg(feature = "std")] diff --git a/src/shape.rs b/src/shape.rs index 34d7ad05..b4bb0bc0 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -1,4 +1,4 @@ -use crate::{Device, ShallowCopy, Unit}; +use crate::Device; /// Determines the shape of a [`Buffer`](crate::Buffer). /// `Shape` is used to get the size and ND-Array for a stack allocated `Buffer`. @@ -118,40 +118,10 @@ impl Shape for Dim3 { } } -// TODO: do not use device -/// Converts a pointer to a different [`Shape`]. -// pub trait ToDim: crate::Device { -// /// Converts a pointer to a different [`Shape`]. -// /// This is only possible for [`Buffer`](crate::Buffer)s that are not allocated on the stack. -// fn to_dim(&self, ptr: Self::Data) -> Self::Data; -// } - -// #[cfg(feature = "std")] -// impl ToDim for D -// where -// T: Unit, -// D::Data: crate::PtrType + ShallowCopy, -// D: IsShapeIndep + Device, -// I: Shape, -// O: Shape, -// { -// #[inline] -// fn to_dim(&self, ptr: Self::Data) -> D::Data { -// // resources are now mananged by the destructed raw pointer (prevents double free). -// // could set alloc flag as well -// let ptr = core::mem::ManuallyDrop::new(ptr); - -// let shape_changed = unsafe { &*(&*ptr as *const D::Data as *const D::Data) }; -// unsafe { shape_changed.shallow() } -// } -// } - -#[cfg(feature = "stack")] -impl ToDim for crate::Stack { - #[inline] - fn to_dim(&self, ptr: Self::Data) -> Self::Data { - ptr - } +pub trait ToDim { + type Out; + fn to_dim(self) -> Self::Out; + fn as_dim(&self) -> &Self::Out; } #[cfg(test)] diff --git a/src/unary.rs b/src/unary.rs index 861bf3e1..5ac16a8c 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -108,7 +108,7 @@ where buf: &Buffer<'a, T, D, S>, forward_fn: impl Fn(Resolve) -> FO + Copy + 'static, grad_fn: fn(Resolve) -> GO, - ) -> Buffer + ) -> Buffer<'a, T, Self, S> where FO: TwoWay, GO: Eval + MayToCLSource + 'static, @@ -169,14 +169,14 @@ mod tests { #[cfg(feature = "autograd")] fn test_unary_autograd<'a, 'b, D>(device: &'a D) where - D::Data: crate::ShallowCopy, + D::Data<'a, f32, ()>: crate::ShallowCopy, D: 'static + crate::WriteBuf + crate::Read + crate::GradActions + crate::TapeActions<'b> + crate::HasAutograd - + crate::UnaryElementWiseMayGrad + + crate::UnaryElementWiseMayGrad<'a, f32, D, ()> + crate::Alloc + crate::CachedBuffers + crate::AddOperation