Skip to content

Commit

Permalink
Impl ShallowCopy for LazyWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 9, 2023
1 parent 9a81330 commit 83680a0
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use core::cell::UnsafeCell;
use crate::{
pass_down_add_operation, pass_down_exec_now_module, register_buf, unregister_buf, AddGradFn,
Alloc, Buffer, Device, HasId, Module, OnDropBuffer, OnNewBuffer, Parents, PtrConv, PtrType,
Retrieve, RunModule, Setup, Shape, TapeActions, WrappedData,
Retrieve, RunModule, Setup, Shape, TapeActions, WrappedData, ShallowCopy,
};

use super::{Cached, CachedModule};
Expand Down Expand Up @@ -50,6 +50,7 @@ impl<Mods> Autograd<Mods> {
where
T: 'static,
D: Device + PtrConv + 'static,
// D::Data<T, S>: ShallowCopy,
S: Shape,
{
let no_grads_pool = unsafe { &mut (*(self.grads.get())).no_grads_pool.cache };
Expand All @@ -67,6 +68,7 @@ impl<T, D, Mods> OnNewBuffer<T, D> for Autograd<Mods>
where
T: 'static,
D: Alloc<T> + PtrConv + 'static,
// D::Data<T, S>: ShallowCopy,
Mods: OnNewBuffer<T, D>,
{
#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_lazy_exec_with_range() {
use crate::{ExecNow, Run};
use crate::{ExecNow, Run, HostPtr};

let device = CPU::<Lazy<Base>>::new();
let mut out: Buffer<i32, _, ()> = device.retrieve(4, ());
Expand Down Expand Up @@ -357,7 +357,7 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_lazy_exec_last_n() {
use crate::{ExecNow, Run};
use crate::{ExecNow, Run, HostPtr};

let device = CPU::<Lazy<Base>>::new();
let mut out: Buffer<i32, _, ()> = device.retrieve(4, ());
Expand Down
41 changes: 34 additions & 7 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
use core::{ops::{Deref, DerefMut}, marker::PhantomData};

use crate::{HasId, PtrType, WrappedData, Lazy};
use crate::{HasId, PtrType, WrappedData, Lazy, HostPtr, Id, ShallowCopy};

#[derive(Debug, Default)]
pub struct LazyWrapper<Data, T> {
data: Option<Data>,
id: Option<Id>,
_pd: PhantomData<T>
}


impl<Mods: WrappedData> WrappedData for Lazy<Mods> {
type Wrap<T, Base: HasId + PtrType> = LazyWrapper<Mods::Wrap<T, Base>, T>;
// type Wrap<T, Base: HasId + PtrType> = Mods::Wrap<T, Base>;

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
todo!()
// self.modules.wrap_in_base(base)
}
}

impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
#[inline]
fn id(&self) -> crate::Id {
self.data.as_ref().unwrap().id()
self.id.unwrap()
}
}

Expand Down Expand Up @@ -47,12 +61,25 @@ impl<Data: DerefMut<Target = [T]>, T> DerefMut for LazyWrapper<Data, T> {
}
}

impl<Mods: WrappedData> WrappedData for Lazy<Mods> {
// type Wrap<T, Base: HasId + PtrType> = LazyWrapper<Mods::Wrap<T, Base>, T>;
type Wrap<T, Base: HasId + PtrType> = Mods::Wrap<T, Base>;
impl<T, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
#[inline]
fn ptr(&self) -> *const T {
self.data.as_ref().unwrap().ptr()
}

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
self.modules.wrap_in_base(base)
fn ptr_mut(&mut self) -> *mut T {
self.data.as_mut().unwrap().ptr_mut()
}
}

impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
#[inline]
unsafe fn shallow(&self) -> Self {
LazyWrapper {
id: self.id,
data: self.data.as_ref().map(|data| data.shallow()),
_pd: PhantomData
}
}
}
5 changes: 4 additions & 1 deletion src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod fork;
pub use fork::*;

#[cfg(not(feature = "no-std"))]
use crate::{flag::AllocFlag, Buffer, Device, HasId, HashLocation, Id, PtrConv, Shape, UniqueId};
use crate::{flag::AllocFlag, Buffer, Device, HasId, HashLocation, Id, PtrConv, Shape, UniqueId, ShallowCopy};
#[cfg(not(feature = "no-std"))]
use core::{any::Any, hash::BuildHasher};

Expand All @@ -49,8 +49,11 @@ pub(crate) unsafe fn register_buf<T, D, S>(
) where
T: 'static,
D: Device + PtrConv + 'static,
// D::Data<T, S>: ShallowCopy,
S: Shape,
{

// buf.data
let wrapped_data = D::convert::<T, S, T, S>(&buf.data, AllocFlag::Wrapper);
let buf = Buffer {
data: wrapped_data,
Expand Down

0 comments on commit 83680a0

Please sign in to comment.