diff --git a/src/buffer/num.rs b/src/buffer/num.rs index 6813ce9f..cf2d1bde 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -27,6 +27,9 @@ impl PtrType for Num { fn flag(&self) -> crate::flag::AllocFlag { crate::flag::AllocFlag::Num } + + #[inline] + unsafe fn set_flag(&mut self, _flag: AllocFlag) {} } impl CommonPtrs for Num { diff --git a/src/device_traits.rs b/src/device_traits.rs index 9fca3e89..0ed98a77 100644 --- a/src/device_traits.rs +++ b/src/device_traits.rs @@ -1,7 +1,5 @@ // TODO: move to devices folder ig -use core::ops::Deref; - use crate::{flag::AllocFlag, prelude::Device, Buffer, HasId, Parents, PtrType, Shape, StackArray}; pub trait Alloc: Device + Sized { diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index d89f7780..5e3936ac 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -8,7 +8,7 @@ use crate::{ cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, pass_down_grad_fn, pass_down_optimize_mem_graph, pass_down_tape_actions, Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble, HasModules, Module, OnDropBuffer, OnNewBuffer, PtrConv, Setup, Shape, - WrappedData, + WrappedData, PtrType, }; pub trait IsCPU {} @@ -163,6 +163,7 @@ impl PtrConv> for CP data: &Mods::Wrap>, flag: AllocFlag, ) -> OtherMods::Wrap> { + // data.flag() todo!() // CPUPtr { // ptr: data.ptr as *mut Conv, diff --git a/src/devices/cpu/cpu_ptr.rs b/src/devices/cpu/cpu_ptr.rs index 79c9ba8e..44391674 100644 --- a/src/devices/cpu/cpu_ptr.rs +++ b/src/devices/cpu/cpu_ptr.rs @@ -207,6 +207,11 @@ impl PtrType for CPUPtr { fn flag(&self) -> AllocFlag { self.flag } + + #[inline] + unsafe fn set_flag(&mut self, flag: AllocFlag) { + self.flag = flag + } } impl CommonPtrs for CPUPtr { diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index 5ff7afbf..7f93df85 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -11,8 +11,6 @@ use crate::{ pass_down_add_operation!(CPU); pass_down_exec_now!(CPU); -pub fn slice(x: &[T]) {} - impl ApplyFunction for CPU where Mods: Retrieve + AddOperation + 'static, diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index a247a784..89f668db 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -120,6 +120,9 @@ impl PtrType for StackArray { fn flag(&self) -> crate::flag::AllocFlag { crate::flag::AllocFlag::None } + + #[inline] + unsafe fn set_flag(&mut self, _flag: crate::flag::AllocFlag) {} } impl CommonPtrs for StackArray { diff --git a/src/lib.rs b/src/lib.rs index 69432146..202b4f81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,6 +144,7 @@ pub trait PtrType { fn size(&self) -> usize; /// Returns the [`AllocFlag`]. fn flag(&self) -> AllocFlag; + unsafe fn set_flag(&mut self, flag: AllocFlag); } pub trait HostPtr: PtrType { diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 2308027e..36300ec8 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -60,6 +60,11 @@ impl PtrType for LazyWrapper { fn flag(&self) -> crate::flag::AllocFlag { self.data.flag() } + + #[inline] + unsafe fn set_flag(&mut self, flag: crate::flag::AllocFlag) { + self.data.set_flag(flag) + } } impl Deref for LazyWrapper { @@ -145,7 +150,7 @@ impl> Setup for Lazy { } } -impl, D: LazyRun + PtrConv + 'static> RunModule for Lazy { +impl, D: LazyRun + Device + 'static> RunModule for Lazy { #[inline] fn run(&self, device: &D) -> crate::Result<()> { unsafe { self.call_lazily::()? };