diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index ff30350e..6cc09cdc 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -9,7 +9,7 @@ use core::cell::UnsafeCell; use crate::{ pass_down_add_operation, pass_down_exec_now_module, register_buf, unregister_buf, AddGradFn, Alloc, Buffer, Device, HasId, Module, OnDropBuffer, OnNewBuffer, Parents, PtrConv, PtrType, - Retrieve, RunModule, Setup, Shape, TapeActions, WrappedData, + Retrieve, RunModule, Setup, Shape, TapeActions, WrappedData, ShallowCopy, }; use super::{Cached, CachedModule}; @@ -50,6 +50,7 @@ impl Autograd { where T: 'static, D: Device + PtrConv + 'static, + // D::Data: ShallowCopy, S: Shape, { let no_grads_pool = unsafe { &mut (*(self.grads.get())).no_grads_pool.cache }; @@ -67,6 +68,7 @@ impl OnNewBuffer for Autograd where T: 'static, D: Alloc + PtrConv + 'static, + // D::Data: ShallowCopy, Mods: OnNewBuffer, { #[inline] diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index b68989ff..3ed72a7c 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -324,7 +324,7 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_lazy_exec_with_range() { - use crate::{ExecNow, Run}; + use crate::{ExecNow, Run, HostPtr}; let device = CPU::>::new(); let mut out: Buffer = device.retrieve(4, ()); @@ -357,7 +357,7 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_lazy_exec_last_n() { - use crate::{ExecNow, Run}; + use crate::{ExecNow, Run, HostPtr}; let device = CPU::>::new(); let mut out: Buffer = device.retrieve(4, ()); diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 5b9ecae3..f2309390 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -1,16 +1,30 @@ use core::{ops::{Deref, DerefMut}, marker::PhantomData}; -use crate::{HasId, PtrType, WrappedData, Lazy}; +use crate::{HasId, PtrType, WrappedData, Lazy, HostPtr, Id, ShallowCopy}; +#[derive(Debug, Default)] pub struct LazyWrapper { data: Option, + id: Option, _pd: PhantomData } + +impl WrappedData for Lazy { + type Wrap = LazyWrapper, T>; + // type Wrap = Mods::Wrap; + + #[inline] + fn wrap_in_base(&self, base: Base) -> Self::Wrap { + todo!() + // self.modules.wrap_in_base(base) + } +} + impl HasId for LazyWrapper { #[inline] fn id(&self) -> crate::Id { - self.data.as_ref().unwrap().id() + self.id.unwrap() } } @@ -47,12 +61,25 @@ impl, T> DerefMut for LazyWrapper { } } -impl WrappedData for Lazy { - // type Wrap = LazyWrapper, T>; - type Wrap = Mods::Wrap; +impl> HostPtr for LazyWrapper { + #[inline] + fn ptr(&self) -> *const T { + self.data.as_ref().unwrap().ptr() + } #[inline] - fn wrap_in_base(&self, base: Base) -> Self::Wrap { - self.modules.wrap_in_base(base) + fn ptr_mut(&mut self) -> *mut T { + self.data.as_mut().unwrap().ptr_mut() + } +} + +impl ShallowCopy for LazyWrapper { + #[inline] + unsafe fn shallow(&self) -> Self { + LazyWrapper { + id: self.id, + data: self.data.as_ref().map(|data| data.shallow()), + _pd: PhantomData + } } } diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 947e10b2..9cffa63a 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -28,7 +28,7 @@ mod fork; pub use fork::*; #[cfg(not(feature = "no-std"))] -use crate::{flag::AllocFlag, Buffer, Device, HasId, HashLocation, Id, PtrConv, Shape, UniqueId}; +use crate::{flag::AllocFlag, Buffer, Device, HasId, HashLocation, Id, PtrConv, Shape, UniqueId, ShallowCopy}; #[cfg(not(feature = "no-std"))] use core::{any::Any, hash::BuildHasher}; @@ -49,8 +49,11 @@ pub(crate) unsafe fn register_buf( ) where T: 'static, D: Device + PtrConv + 'static, + // D::Data: ShallowCopy, S: Shape, { + + // buf.data let wrapped_data = D::convert::(&buf.data, AllocFlag::Wrapper); let buf = Buffer { data: wrapped_data,