From c1aee7dad2dd3a23f94d235545d11ad9c37e15e0 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:23:54 +0100 Subject: [PATCH] Add Drop impl for LazyWrapper --- src/buffer.rs | 13 -------- src/modules/autograd.rs | 2 +- src/modules/lazy.rs | 5 +-- src/modules/lazy/wrapper.rs | 66 ++++++++++++++++++++++++++----------- src/modules/mod.rs | 2 +- 5 files changed, 52 insertions(+), 36 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 08235944..1001cdc1 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -190,19 +190,6 @@ impl<'a, T: Unit, D: Device, S: Shape> HasId for &mut Buffer<'a, T, D, S> { } } -impl<'a, T: Unit, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> { - #[inline] - fn drop(&mut self) { - if self.data.flag() != AllocFlag::None { - return; - } - - if let Some(device) = self.device { - device.on_drop_buffer(device, self) - } - } -} - impl<'a, T: Unit, D: Device + OnNewBuffer<'a, T, D, S>, S: Shape> Buffer<'a, T, D, S> { /// Creates a new `Buffer` from a slice (&[T]). #[inline] diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index 3f50bc7c..39714b22 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -108,7 +108,7 @@ impl<'dev, Mods: OnDropBuffer> OnDropBuffer for Autograd<'dev, Mods> { #[inline] fn on_drop_buffer(&self, device: &D, buf: &Buffer) { unsafe { (*self.grads.get()).buf_requires_grad.remove(&*buf.id()) }; - unregister_buf_copyable(unsafe { &mut (*self.grads.get()).no_grads_pool }, buf.id()); + unregister_buf_copyable(unsafe { &mut (*self.grads.get()).no_grads_pool }, *buf.id()); // TODO // FIXME if an alloc flag None buffer goes out of scope and it has used it's gradient buffer before, diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index de0317f0..0c0f1382 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -192,7 +192,7 @@ impl OnDropBuffer for Lazy<'_, Mods, T2> { device: &D, buf: &Buffer, ) { - unregister_buf_copyable(&mut self.buffers.borrow_mut(), buf.id()); + unregister_buf_copyable(&mut self.buffers.borrow_mut(), *buf.id()); self.modules.on_drop_buffer(device, buf) } } @@ -350,7 +350,7 @@ where let base = device .alloc::(id.len, crate::flag::AllocFlag::Lazy) .unwrap(); - let data = device.default_base_to_data(base); + let data = device.default_base_to_data_unbound(base); let buffer = Buffer { data, device: Some(device), @@ -365,6 +365,7 @@ where Ok(LazyWrapper { maybe_data: MaybeData::Id(id), + remove_id_cb: None, _pd: core::marker::PhantomData, }) } diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index f1fd1bb3..6e646504 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -8,22 +8,47 @@ use core::{ use crate::{ flag::AllocFlag, Device, HasId, HostPtr, IsBasePtr, Lazy, PtrType, ShallowCopy, Shape, ToBase, - ToDim, Unit, WrappedData, + ToDim, UniqueId, Unit, WrappedData, }; -#[derive(Debug, Default)] -pub struct LazyWrapper { +use super::unregister_buf_copyable; + +#[derive(Default)] +pub struct LazyWrapper<'a, Data: HasId, T> { pub maybe_data: MaybeData, - pub _pd: PhantomData, + pub remove_id_cb: Option>, + pub _pd: PhantomData<&'a T>, +} + +impl<'a, Data: HasId, T> Drop for LazyWrapper<'a, Data, T> { + #[inline] + fn drop(&mut self) { + if let Some(remove_id_cb) = &self.remove_id_cb { + remove_id_cb(*self.id()) + } + } +} + +impl<'a, Data: std::fmt::Debug + HasId, T> std::fmt::Debug for LazyWrapper<'a, Data, T> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("LazyWrapper") + .field("maybe_data", &self.maybe_data) + .field("remove_id_cb", &"callback()") + .field("_pd", &self._pd) + .finish() + } } impl WrappedData for Lazy<'_, Mods, T2> { - type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper, T>; + type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper<'a, Mods::Wrap<'a, T, Base>, T>; #[inline] fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { LazyWrapper { maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)), + remove_id_cb: Some(Box::new(|id| { + unregister_buf_copyable(&mut self.buffers.borrow_mut(), id) + })), _pd: PhantomData, } } @@ -35,6 +60,7 @@ impl WrappedData for Lazy<'_, Mods, T2> { ) -> Self::Wrap<'a, T, Base> { LazyWrapper { maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)), + remove_id_cb: None, _pd: PhantomData, } } @@ -54,7 +80,7 @@ impl WrappedData for Lazy<'_, Mods, T2> { } } -impl HasId for LazyWrapper { +impl<'a, Data: HasId, T> HasId for LazyWrapper<'a, Data, T> { #[inline] fn id(&self) -> crate::Id { match self.maybe_data { @@ -65,7 +91,7 @@ impl HasId for LazyWrapper { } } -impl PtrType for LazyWrapper { +impl<'a, Data: PtrType + HasId, T: Unit> PtrType for LazyWrapper<'a, Data, T> { #[inline] fn size(&self) -> usize { match self.maybe_data { @@ -92,7 +118,7 @@ impl PtrType for LazyWrapper { const MISSING_DATA: &str = "This lazy buffer does not contain any data. Try with a buffer.replace() call."; -impl, T> Deref for LazyWrapper { +impl<'a, Data: HasId + Deref, T> Deref for LazyWrapper<'a, Data, T> { type Target = Data; #[inline] @@ -101,14 +127,14 @@ impl, T> Deref for LazyWrapper { } } -impl, T> DerefMut for LazyWrapper { +impl<'a, Data: HasId + DerefMut, T> DerefMut for LazyWrapper<'a, Data, T> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { self.maybe_data.data_mut().expect(MISSING_DATA) } } -impl> HostPtr for LazyWrapper { +impl<'a, T: Unit, Data: HasId + HostPtr> HostPtr for LazyWrapper<'a, Data, T> { #[inline] fn ptr(&self) -> *const T { self.maybe_data.data().unwrap().ptr() @@ -120,7 +146,7 @@ impl> HostPtr for LazyWrapper { } } -impl ShallowCopy for LazyWrapper { +impl<'a, Data: HasId + ShallowCopy, T> ShallowCopy for LazyWrapper<'a, Data, T> { #[inline] unsafe fn shallow(&self) -> Self { LazyWrapper { @@ -129,25 +155,27 @@ impl ShallowCopy for LazyWrapper { MaybeData::Id(id) => MaybeData::Id(*id), MaybeData::None => unimplemented!(), }, + remove_id_cb: None, _pd: PhantomData, } } } -impl, T1, D: Device> ToBase - for LazyWrapper +impl<'a, T: Unit, S: Shape, Data: HasId + ToBase, T1, D: Device> ToBase + for LazyWrapper<'a, Data, T1> { #[inline] fn to_base(self) -> ::Base { - match self.maybe_data { - MaybeData::Data(data) => data.to_base(), - MaybeData::Id(_id) => unimplemented!("Cannot convert id wrapper to base"), - MaybeData::None => unimplemented!("Cannot convert nothin to base"), - } + todo!() + // match self.maybe_data { + // MaybeData::Data(data) => data.to_base(), + // MaybeData::Id(_id) => unimplemented!("Cannot convert id wrapper to base"), + // MaybeData::None => unimplemented!("Cannot convert nothin to base"), + // } } } -impl ToDim for LazyWrapper { +impl<'a, T, Data: HasId> ToDim for LazyWrapper<'a, Data, T> { type Out = Self; #[inline] diff --git a/src/modules/mod.rs b/src/modules/mod.rs index c6ad3dc9..f4424c0a 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -145,7 +145,7 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>( #[allow(unused)] pub(crate) fn unregister_buf_copyable( cache: &mut HashMap, impl BuildHasher>, - id: Id, + id: UniqueId, ) { cache.remove(&id); }