Skip to content

Commit

Permalink
Add check at LazyWrapper if Data is actually lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 9, 2023
1 parent 7eaa3c5 commit ddccb7d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ mod tests {
use core::ops::{Add, Deref};

use crate::{
cpu::CPUPtr, AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, Retrieve,
Retriever, Shape, CPU,
AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, HostPtr, Retrieve, Retriever,
Shape, CPU,
};

use super::Lazy;
Expand Down Expand Up @@ -215,15 +215,15 @@ mod tests {
T: Add<Output = T> + Copy + 'static,
D: Device + 'static,
D::Data<T, S>: Deref<Target = [T]>,
Mods::Wrap<T, CPUPtr<T>>: core::ops::DerefMut<Target = [T]>,
Self::Data<T, S>: HostPtr<T>,
S: Shape,
Mods: AddOperation + Retrieve<Self, T, S> + 'static,
{
#[inline]
fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S> {
let mut out = self.retrieve(lhs.len(), ());
self.add_op((lhs, rhs, &mut out), |(lhs, rhs, out)| {
add_ew_slice(lhs, rhs, out);
add_ew_slice(lhs, rhs, out.as_mut_slice());
Ok(())
})
.unwrap();
Expand Down
12 changes: 9 additions & 3 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@ impl<Mods: WrappedData> WrappedData for Lazy<Mods> {

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
todo!()
// self.modules.wrap_in_base(base)
LazyWrapper {
data: Some(self.modules.wrap_in_base(base)),
id: None,
_pd: PhantomData,
}
}
}

impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
#[inline]
fn id(&self) -> crate::Id {
self.id.unwrap()
match self.id {
Some(id) => id,
None => self.data.as_ref().unwrap().id(),
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ pub(crate) unsafe fn register_buf<T, D, S>(
D::Data<T, S>: ShallowCopy,
S: Shape,
{
// buf.data.
let wrapped_data = D::convert::<T, S, T, S>(&buf.data, AllocFlag::Wrapper);
// shallow copy sets flag to AllocFlag::Wrapper
let wrapped_data = buf.data.shallow();

// let wrapped_data = D::convert::<T, S, T, S>(&buf.data, AllocFlag::Wrapper);

let buf = Buffer {
data: wrapped_data,
device: buf.device,
Expand Down

0 comments on commit ddccb7d

Please sign in to comment.