Skip to content

Commit

Permalink
Add Drop impl for LazyWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 18, 2024
1 parent 1072c95 commit c1aee7d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 36 deletions.
13 changes: 0 additions & 13 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,6 @@ impl<'a, T: Unit, D: Device, S: Shape> HasId for &mut Buffer<'a, T, D, S> {
}
}

impl<'a, T: Unit, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> {
#[inline]
fn drop(&mut self) {
if self.data.flag() != AllocFlag::None {
return;
}

if let Some(device) = self.device {
device.on_drop_buffer(device, self)
}
}
}

impl<'a, T: Unit, D: Device + OnNewBuffer<'a, T, D, S>, S: Shape> Buffer<'a, T, D, S> {
/// Creates a new `Buffer` from a slice (&[T]).
#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<'dev, Mods: OnDropBuffer> OnDropBuffer for Autograd<'dev, Mods> {
#[inline]
fn on_drop_buffer<T: Unit, D: Device, S: Shape>(&self, device: &D, buf: &Buffer<T, D, S>) {
unsafe { (*self.grads.get()).buf_requires_grad.remove(&*buf.id()) };
unregister_buf_copyable(unsafe { &mut (*self.grads.get()).no_grads_pool }, buf.id());
unregister_buf_copyable(unsafe { &mut (*self.grads.get()).no_grads_pool }, *buf.id());

// TODO
// FIXME if an alloc flag None buffer goes out of scope and it has used it's gradient buffer before,
Expand Down
5 changes: 3 additions & 2 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ impl<T2, Mods: OnDropBuffer> OnDropBuffer for Lazy<'_, Mods, T2> {
device: &D,
buf: &Buffer<T, D, S>,
) {
unregister_buf_copyable(&mut self.buffers.borrow_mut(), buf.id());
unregister_buf_copyable(&mut self.buffers.borrow_mut(), *buf.id());
self.modules.on_drop_buffer(device, buf)
}
}
Expand Down Expand Up @@ -350,7 +350,7 @@ where
let base = device
.alloc::<S>(id.len, crate::flag::AllocFlag::Lazy)
.unwrap();
let data = device.default_base_to_data(base);
let data = device.default_base_to_data_unbound(base);
let buffer = Buffer {
data,
device: Some(device),
Expand All @@ -365,6 +365,7 @@ where

Ok(LazyWrapper {
maybe_data: MaybeData::Id(id),
remove_id_cb: None,
_pd: core::marker::PhantomData,
})
}
Expand Down
66 changes: 47 additions & 19 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,47 @@ use core::{

use crate::{
flag::AllocFlag, Device, HasId, HostPtr, IsBasePtr, Lazy, PtrType, ShallowCopy, Shape, ToBase,
ToDim, Unit, WrappedData,
ToDim, UniqueId, Unit, WrappedData,
};

#[derive(Debug, Default)]
pub struct LazyWrapper<Data, T> {
use super::unregister_buf_copyable;

#[derive(Default)]
pub struct LazyWrapper<'a, Data: HasId, T> {
pub maybe_data: MaybeData<Data>,
pub _pd: PhantomData<T>,
pub remove_id_cb: Option<Box<dyn Fn(UniqueId) + 'a>>,
pub _pd: PhantomData<&'a T>,
}

impl<'a, Data: HasId, T> Drop for LazyWrapper<'a, Data, T> {
#[inline]
fn drop(&mut self) {
if let Some(remove_id_cb) = &self.remove_id_cb {
remove_id_cb(*self.id())
}
}
}

impl<'a, Data: std::fmt::Debug + HasId, T> std::fmt::Debug for LazyWrapper<'a, Data, T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("LazyWrapper")
.field("maybe_data", &self.maybe_data)
.field("remove_id_cb", &"callback()")
.field("_pd", &self._pd)
.finish()
}
}

impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper<Mods::Wrap<'a, T, Base>, T>;
type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper<'a, Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)),
remove_id_cb: Some(Box::new(|id| {
unregister_buf_copyable(&mut self.buffers.borrow_mut(), id)
})),
_pd: PhantomData,
}
}
Expand All @@ -35,6 +60,7 @@ impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)),
remove_id_cb: None,
_pd: PhantomData,
}
}
Expand All @@ -54,7 +80,7 @@ impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
}
}

impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
impl<'a, Data: HasId, T> HasId for LazyWrapper<'a, Data, T> {
#[inline]
fn id(&self) -> crate::Id {
match self.maybe_data {
Expand All @@ -65,7 +91,7 @@ impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
}
}

impl<Data: PtrType, T: Unit> PtrType for LazyWrapper<Data, T> {
impl<'a, Data: PtrType + HasId, T: Unit> PtrType for LazyWrapper<'a, Data, T> {
#[inline]
fn size(&self) -> usize {
match self.maybe_data {
Expand All @@ -92,7 +118,7 @@ impl<Data: PtrType, T: Unit> PtrType for LazyWrapper<Data, T> {
const MISSING_DATA: &str =
"This lazy buffer does not contain any data. Try with a buffer.replace() call.";

impl<Data: Deref<Target = [T]>, T> Deref for LazyWrapper<Data, T> {
impl<'a, Data: HasId + Deref<Target = [T]>, T> Deref for LazyWrapper<'a, Data, T> {
type Target = Data;

#[inline]
Expand All @@ -101,14 +127,14 @@ impl<Data: Deref<Target = [T]>, T> Deref for LazyWrapper<Data, T> {
}
}

impl<Data: DerefMut<Target = [T]>, T> DerefMut for LazyWrapper<Data, T> {
impl<'a, Data: HasId + DerefMut<Target = [T]>, T> DerefMut for LazyWrapper<'a, Data, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
self.maybe_data.data_mut().expect(MISSING_DATA)
}
}

impl<T: Unit, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
impl<'a, T: Unit, Data: HasId + HostPtr<T>> HostPtr<T> for LazyWrapper<'a, Data, T> {
#[inline]
fn ptr(&self) -> *const T {
self.maybe_data.data().unwrap().ptr()
Expand All @@ -120,7 +146,7 @@ impl<T: Unit, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
}
}

impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
impl<'a, Data: HasId + ShallowCopy, T> ShallowCopy for LazyWrapper<'a, Data, T> {
#[inline]
unsafe fn shallow(&self) -> Self {
LazyWrapper {
Expand All @@ -129,25 +155,27 @@ impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
MaybeData::Id(id) => MaybeData::Id(*id),
MaybeData::None => unimplemented!(),
},
remove_id_cb: None,
_pd: PhantomData,
}
}
}

impl<T: Unit, S: Shape, Data: ToBase<T, D, S>, T1, D: Device> ToBase<T, D, S>
for LazyWrapper<Data, T1>
impl<'a, T: Unit, S: Shape, Data: HasId + ToBase<T, D, S>, T1, D: Device> ToBase<T, D, S>
for LazyWrapper<'a, Data, T1>
{
#[inline]
fn to_base(self) -> <D as Device>::Base<T, S> {
match self.maybe_data {
MaybeData::Data(data) => data.to_base(),
MaybeData::Id(_id) => unimplemented!("Cannot convert id wrapper to base"),
MaybeData::None => unimplemented!("Cannot convert nothin to base"),
}
todo!()
// match self.maybe_data {
// MaybeData::Data(data) => data.to_base(),
// MaybeData::Id(_id) => unimplemented!("Cannot convert id wrapper to base"),
// MaybeData::None => unimplemented!("Cannot convert nothin to base"),
// }
}
}

impl<T, Data> ToDim for LazyWrapper<Data, T> {
impl<'a, T, Data: HasId> ToDim for LazyWrapper<'a, Data, T> {
type Out = Self;

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>(
#[allow(unused)]
pub(crate) fn unregister_buf_copyable(
cache: &mut HashMap<UniqueId, Box<dyn crate::BoxedShallowCopy>, impl BuildHasher>,
id: Id,
id: UniqueId,
) {
cache.remove(&id);
}

0 comments on commit c1aee7d

Please sign in to comment.