From 1dd80e48f1e7cc0d345cee06f6dbeb84a539b4be Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Tue, 16 Jul 2024 23:41:50 +0200 Subject: [PATCH 01/29] Add add_op2 to AddOperation --- examples/custom_device.rs | 2 +- src/devices/nnapi/nnapi_device.rs | 4 ++-- src/devices/wgsl/wgsl_device.rs | 10 +++++++++- src/features.rs | 13 +++++++++++++ src/modules/autograd/tape.rs | 4 ++-- src/modules/base.rs | 16 ++++++++++++---- src/modules/cached.rs | 19 +++++++++++++++---- src/modules/lazy.rs | 23 ++++++++++++++++++++--- src/modules/lazy/lazy_graph.rs | 5 +++++ 9 files changed, 79 insertions(+), 17 deletions(-) diff --git a/examples/custom_device.rs b/examples/custom_device.rs index 975c4323..c0c2518e 100644 --- a/examples/custom_device.rs +++ b/examples/custom_device.rs @@ -10,7 +10,7 @@ use custos::{ cpu::CPUPtr, flag::AllocFlag, impl_device_traits, AddGradFn, AddOperation, Alloc, Base, BorrowCacheLT, Buffer, Cached, CachedModule, Device, DeviceError, DevicelessAble, HasId, Id, LazyGraph2, Module, OnDropBuffer, OnNewBuffer, PtrType, Retrieve, Retriever, Setup, Shape, - TapeActions, Tape, Unit, WrappedData, CPU, + Tape, TapeActions, Unit, WrappedData, CPU, }; pub trait Str { diff --git a/src/devices/nnapi/nnapi_device.rs b/src/devices/nnapi/nnapi_device.rs index 03650aa7..0b5bae10 100644 --- a/src/devices/nnapi/nnapi_device.rs +++ b/src/devices/nnapi/nnapi_device.rs @@ -64,8 +64,8 @@ impl Device for NnapiDevice { unsafe impl IsShapeIndep for NnapiDevice {} -impl<'a, U, T: Unit, D: Device, S: Shape, Mods: crate::OnNewBuffer<'a, T, D, S>> crate::OnNewBuffer<'a, T, D, S> - for NnapiDevice +impl<'a, U, T: Unit, D: Device, S: Shape, Mods: crate::OnNewBuffer<'a, T, D, S>> + crate::OnNewBuffer<'a, T, D, S> for NnapiDevice { #[inline] fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { diff --git a/src/devices/wgsl/wgsl_device.rs b/src/devices/wgsl/wgsl_device.rs index eea32c4a..a5d83f93 100644 --- a/src/devices/wgsl/wgsl_device.rs +++ b/src/devices/wgsl/wgsl_device.rs @@ -182,6 +182,14 @@ impl, T: Unit, Mods: Retrieve, S: Shape> Retrie } impl AddOperation for Wgsl { + fn add_op2 + crate::AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ) { + self.modules.add_op2(args, op) + } + #[inline] fn add_op + crate::UpdateArgs, const N: usize>( &self, @@ -200,7 +208,7 @@ impl AddOperation for Wgsl { fn set_lazy_enabled(&self, enabled: bool) { self.modules.set_lazy_enabled(enabled) } - + #[inline] fn is_lazy_enabled(&self) -> bool { self.modules.is_lazy_enabled() diff --git a/src/features.rs b/src/features.rs index 9ece5760..2d4020c6 100644 --- a/src/features.rs +++ b/src/features.rs @@ -367,6 +367,12 @@ macro_rules! pass_down_replace_buf_module { } pub trait AddOperation { + fn add_op2 + AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ); + fn add_op + UpdateArgs, const N: usize>( &self, args: Args, @@ -433,6 +439,13 @@ macro_rules! pass_down_add_operation { ) -> $crate::Result<()> { self.modules.add_op(args, operation) } + fn add_op2 + $crate::AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ) { + self.modules.add_op2(args, op) + } #[inline] fn ops_count(&self) -> usize { diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 73e74ca3..89433444 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -1,6 +1,6 @@ use crate::{ - AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, - LazyGraph2, Parents, Shape, Unit, WriteBuf, ZeroGrad, + AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph2, + Parents, Shape, Unit, WriteBuf, ZeroGrad, }; use super::Gradients; diff --git a/src/modules/base.rs b/src/modules/base.rs index d7b9c472..08dfd7cf 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -36,9 +36,12 @@ impl<'a, D: Device + 'a> Module<'a, D> for Base { } impl AddOperation for Base { - #[inline] - fn ops_count(&self) -> usize { - 0 + fn add_op2 + crate::AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ) { + todo!() } #[inline] @@ -51,8 +54,13 @@ impl AddOperation for Base { } #[inline] - fn set_lazy_enabled(&self, _enabled: bool) {} + fn ops_count(&self) -> usize { + 0 + } + #[inline] + fn set_lazy_enabled(&self, _enabled: bool) {} + #[inline] fn is_lazy_enabled(&self) -> bool { false diff --git a/src/modules/cached.rs b/src/modules/cached.rs index afcda432..417af9b7 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -70,9 +70,15 @@ impl, D: Device, NewDev> Setup for CachedModule AddOperation for CachedModule { - #[inline] - fn ops_count(&self) -> usize { - self.modules.ops_count() + fn add_op2 + crate::AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ) { + let op = crate::LazyGraph2::>::convert_to_operation(args, op); + + // op(Args::replication_fn(ids, op)) + todo!() } fn add_op, const N: usize>( @@ -83,11 +89,16 @@ impl AddOperation for CachedModule { operation(&mut args) } + #[inline] + fn ops_count(&self) -> usize { + self.modules.ops_count() + } + #[inline] fn set_lazy_enabled(&self, enabled: bool) { self.modules.set_lazy_enabled(enabled) } - + #[inline] fn is_lazy_enabled(&self) -> bool { self.modules.is_lazy_enabled() diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 20d16762..38e50b77 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -44,6 +44,7 @@ pub struct Lazy { // TODO: remove this, fix id and address collision - then just use `buffers` for duplicate calls allocated_ids: RefCell, pub graph: RefCell, T>>, + pub graph2: RefCell, T>>, cursor: Cell, enabled: Cell, pd: PhantomData, @@ -80,6 +81,7 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L buffers: Default::default(), replaced_buffers: Default::default(), graph: Default::default(), + graph2: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), cursor: Default::default(), @@ -90,9 +92,18 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L } impl AddOperation for Lazy { - #[inline] - fn ops_count(&self) -> usize { - self.graph.borrow().operations.len() + fn add_op2 + AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + ) { + if self.enabled.get() { + self.graph2.try_borrow_mut() + .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") + .add_operation(args, op); + } else { + self.modules.add_op2(args, op); + } } #[inline] @@ -111,6 +122,11 @@ impl AddOperation for Lazy { Ok(()) } + #[inline] + fn ops_count(&self) -> usize { + self.graph.borrow().operations.len() + } + #[inline] fn set_lazy_enabled(&self, enabled: bool) { self.enabled.set(enabled); @@ -292,6 +308,7 @@ impl AddLayer for Lazy<(), T> { buffers: Default::default(), replaced_buffers: Default::default(), graph: Default::default(), + graph2: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), cursor: Default::default(), diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 009ff1fc..ced760a7 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -65,6 +65,11 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { self.operations.clear(); } + #[inline] + pub fn ops_count(&self) -> usize { + self.operations.len() + } + pub fn call_lazily(&mut self, buffers: &mut Buffers) -> crate::Result<()> { for args in self.iter_with(buffers) { args?; From 0d54d19a52979352ec59f1619a439541e08c30a0 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Tue, 23 Jul 2024 23:01:29 +0200 Subject: [PATCH 02/29] Add another lifetime param to Replicate --- Cargo.toml | 2 +- src/any_op.rs | 58 ++++++++++++++++++++++++---------- src/devices/cpu/ops.rs | 13 +++++--- src/features.rs | 12 +++---- src/modules/base.rs | 8 ++--- src/modules/cached.rs | 8 ++--- src/modules/lazy.rs | 11 ++++--- src/modules/lazy/lazy_graph.rs | 8 ++--- src/parents.rs | 26 +++++++++++++++ 9 files changed, 101 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a361f77e..8ff96681 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["no-std"] +default = ["cpu", "opencl", "autograd"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/any_op.rs b/src/any_op.rs index ca302606..990cb238 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -13,50 +13,74 @@ pub trait AnyOp: Sized { ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; } -pub trait AnyOp2<'dev>: Sized { - type Replicated<'a>; +pub trait AnyOp2<'own, 'dev>: Sized { + type Replicated<'a, 'b> where 'b: 'a; #[cfg(feature = "std")] fn replication_fn( ids: Vec, - op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'b>) -> crate::Result<()> + 'static, ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; - fn replication(self) -> Self::Replicated<'dev>; + fn replication(self) -> Self::Replicated<'own, 'dev>; } -pub trait Replicate2<'dev> { - type Replication<'r>; +pub trait Replicate2<'own, 'dev> { + type Replication<'r, 'd> where 'd: 'r; type Downcast<'r>: 'r; - fn replicate(self) -> Self::Replication<'dev>; + fn replicate(self) -> Self::Replication<'own, 'dev>; #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option>; + ) -> Option>; } -impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'a> - for &'a crate::Buffer<'a, T, D, S> +impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'own, 'dev> + for &'own crate::Buffer<'dev, T, D, S> { - type Replication<'r> = &'r Buffer<'r, T, D, S>; + type Replication<'r, 'd> = &'r Buffer<'r, T, D, S> where 'd: 'r; type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option> { + ) -> Option> { buffers.get(id)?.downcast_ref::>() } - fn replicate(self) -> Self::Replication<'a> { + fn replicate(self) -> Self::Replication<'own, 'dev> { self } } +impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'own, 'dev> + for &'own mut crate::Buffer<'dev, T, D, S> +{ + type Replication<'r, 'd> = &'r mut Self::Downcast<'d> where 'd: 'r; + type Downcast<'r> = Buffer<'r, T, D, S>; + + #[cfg(feature = "std")] + unsafe fn replicate_borrowed<'r, B: Downcast>( + id: &Id, + buffers: &'r mut Buffers, + ) -> Option> { + let replication = buffers.get_mut(id)?; + if !replication.is::>() { + return None; + } + Some(unsafe { replication.downcast_mut_unchecked::>() }) + } + + fn replicate(self) -> Self::Replication<'own, 'dev> { + self + } +} + + pub trait Replicate { type Replication<'r>; type Downcast<'r>: 'r; @@ -119,13 +143,13 @@ impl AnyOp for R { type Replicated<'a> = R::Replication<'a>; } -impl<'dev, R: crate::HasId + Replicate2<'dev>> AnyOp2<'dev> for R { - type Replicated<'a> = R::Replication<'a>; +impl<'own, 'dev, R: crate::HasId + Replicate2<'own, 'dev>> AnyOp2<'own, 'dev> for R { + type Replicated<'a, 'b> = R::Replication<'a, 'b> where 'b: 'a; #[cfg(feature = "std")] fn replication_fn( ids: Vec, - op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, + op: impl for<'a> Fn(Self::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> Box) -> crate::Result<()>> { use crate::DeviceError; @@ -136,7 +160,7 @@ impl<'dev, R: crate::HasId + Replicate2<'dev>> AnyOp2<'dev> for R { }) } - fn replication(self) -> Self::Replicated<'dev> { + fn replication(self) -> Self::Replicated<'own, 'dev> { self.replicate() } } diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index 77174e54..65c5ad06 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -30,11 +30,16 @@ where { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op((&mut out, buf, f.no_id()), move |(out, buf, f)| { - apply_fn_slice(buf, out, **f); + self.add_op2((&out), move |(out)| { + Ok(()) - }) - .unwrap(); + }); + + // self.add_op((&mut out, buf, f.no_id()), move |(out, buf, f)| { + // apply_fn_slice(buf, out, **f); + // Ok(()) + // }) + // .unwrap(); self.set_op_hint(unary(f)); diff --git a/src/features.rs b/src/features.rs index 2d4020c6..e42e372f 100644 --- a/src/features.rs +++ b/src/features.rs @@ -367,11 +367,11 @@ macro_rules! pass_down_replace_buf_module { } pub trait AddOperation { - fn add_op2 + AnyOp, const N: usize>( + fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ); + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + ) -> crate::Result<()>; fn add_op + UpdateArgs, const N: usize>( &self, @@ -439,11 +439,11 @@ macro_rules! pass_down_add_operation { ) -> $crate::Result<()> { self.modules.add_op(args, operation) } - fn add_op2 + $crate::AnyOp, const N: usize>( + fn add_op2<'own, 'd: 'own, Args: $crate::Parents + $crate::AnyOp2<'own, 'd>, const N: usize>( &self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) { + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + ) -> $crate::Result<()> { self.modules.add_op2(args, op) } diff --git a/src/modules/base.rs b/src/modules/base.rs index 08dfd7cf..eac1ec0c 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -36,12 +36,12 @@ impl<'a, D: Device + 'a> Module<'a, D> for Base { } impl AddOperation for Base { - fn add_op2 + crate::AnyOp, const N: usize>( + fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( &self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) { - todo!() + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + ) -> crate::Result<()> { + op(args.replication()) } #[inline] diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 417af9b7..704031ef 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -70,12 +70,12 @@ impl, D: Device, NewDev> Setup for CachedModule AddOperation for CachedModule { - fn add_op2 + crate::AnyOp, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) { - let op = crate::LazyGraph2::>::convert_to_operation(args, op); + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + ) -> crate::Result<()> { + // let op = crate::LazyGraph2::>::convert_to_operation(args, op); // op(Args::replication_fn(ids, op)) todo!() diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 38e50b77..194f32f1 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -92,17 +92,18 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L } impl AddOperation for Lazy { - fn add_op2 + AnyOp, const N: usize>( + fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) { + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + ) -> crate::Result<()> { if self.enabled.get() { self.graph2.try_borrow_mut() .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") - .add_operation(args, op); + .add_operation2(args, op); + Ok(()) } else { - self.modules.add_op2(args, op); + self.modules.add_op2(args, op) } } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index dd58afd4..7afa97fa 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -90,9 +90,9 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { Ok(()) } - pub fn convert_to_operation2<'dev, Args: Parents + crate::AnyOp2<'dev>, const N: usize>( + pub fn convert_to_operation2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'b>) -> crate::Result<()> + 'static, ) -> Operation2<'a, B, T> { const { assert!(N > 0, "Size of parents must be greater than 0") }; @@ -175,10 +175,10 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { self.operations.push(operation) } - pub fn add_operation2<'dev, Args: Parents + crate::AnyOp2<'dev>, const N: usize>( + pub fn add_operation2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &mut self, args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, + op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'b>) -> crate::Result<()> + 'static, ) { let operation = Self::convert_to_operation2(args, op); self.operations.push(operation) diff --git a/src/parents.rs b/src/parents.rs index cb0b8714..70d44a25 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -93,6 +93,32 @@ macro_rules! impl_parents { } } + // impl<'dev, $($to_impl: $crate::Replicate2<'dev> + $crate::HasId, )+> $crate::AnyOp2<'dev> for ($($to_impl,)+) { + // type Replicated<'a> = ($($to_impl::Replication<'a>,)+); + + // #[cfg(feature = "std")] + // fn replication_fn( + // ids: Vec<$crate::Id>, + // op: impl for<'a> Fn(Self::Replicated<'a>) -> $crate::Result<()> + 'static, + // ) -> Box) -> $crate::Result<()>> { + // Box::new(move |buffers| { + // let mut ids = ids.iter(); + + // op(($( + // unsafe { + // $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? + // } + // ,)+)) + // }) + // } + // #[inline] + // fn replication(self) -> Self::Replicated<'dev> { + // #[allow(non_snake_case)] + // let ($($to_impl,)+) = self; + // ($($to_impl.replicate(),)+) + // } + // } + impl<$($to_impl: $crate::Replicate + $crate::HasId, )+> $crate::AnyOp for ($($to_impl,)+) { type Replicated<'a> = ($($to_impl::Replication<'a>,)+); From a985577299335ab00a02dc226a2e4b9d1642b8f7 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 01:29:33 +0200 Subject: [PATCH 03/29] Use transmute for extending lifetime in replication --- src/any_op.rs | 55 +++++++++++++++++++++------------- src/features.rs | 8 ++--- src/modules/base.rs | 6 ++-- src/modules/cached.rs | 2 +- src/modules/lazy.rs | 4 +-- src/modules/lazy/lazy_graph.rs | 22 ++++++++++---- 6 files changed, 60 insertions(+), 37 deletions(-) diff --git a/src/any_op.rs b/src/any_op.rs index 990cb238..f8613a72 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -14,22 +14,26 @@ pub trait AnyOp: Sized { } pub trait AnyOp2<'own, 'dev>: Sized { - type Replicated<'a, 'b> where 'b: 'a; + type Replicated<'a, 'b> + where + 'b: 'a; #[cfg(feature = "std")] fn replication_fn( ids: Vec, - op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; - fn replication(self) -> Self::Replicated<'own, 'dev>; + unsafe fn replication<'iown, 'idev>(self) -> Self::Replicated<'iown, 'idev>; } -pub trait Replicate2<'own, 'dev> { - type Replication<'r, 'd> where 'd: 'r; +pub trait Replicate2<'uown, 'udev> { + type Replication<'r, 'd> + where + 'd: 'r; type Downcast<'r>: 'r; - fn replicate(self) -> Self::Replication<'own, 'dev>; + unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev>; #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( @@ -38,8 +42,8 @@ pub trait Replicate2<'own, 'dev> { ) -> Option>; } -impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'own, 'dev> - for &'own crate::Buffer<'dev, T, D, S> +impl<'uown, 'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'a> + for &'uown crate::Buffer<'a, T, D, S> { type Replication<'r, 'd> = &'r Buffer<'r, T, D, S> where 'd: 'r; type Downcast<'r> = Buffer<'r, T, D, S>; @@ -51,14 +55,18 @@ impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'o ) -> Option> { buffers.get(id)?.downcast_ref::>() } - - fn replicate(self) -> Self::Replication<'own, 'dev> { - self + + unsafe fn replicate<'own, 'dev: 'own>(self) -> Self::Replication<'own, 'dev> { + // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work + // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there + // but than something like this happens: https://github.com/rust-lang/rust/issues/100013 + // most of the "double lifetime stuff" is still implemented at the moment + unsafe { core::mem::transmute::>(self) } } } -impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'own, 'dev> - for &'own mut crate::Buffer<'dev, T, D, S> +impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'udev> + for &'uown mut crate::Buffer<'udev, T, D, S> { type Replication<'r, 'd> = &'r mut Self::Downcast<'d> where 'd: 'r; type Downcast<'r> = Buffer<'r, T, D, S>; @@ -74,13 +82,16 @@ impl<'own, 'dev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'o } Some(unsafe { replication.downcast_mut_unchecked::>() }) } - - fn replicate(self) -> Self::Replication<'own, 'dev> { - self + + unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev> { + // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work + // https://github.com/rust-lang/rust/issues/100013 + // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there + // most of the "double lifetime stuff" is still implemented at the moment + unsafe { core::mem::transmute::>(self) } } } - pub trait Replicate { type Replication<'r>; type Downcast<'r>: 'r; @@ -155,12 +166,14 @@ impl<'own, 'dev, R: crate::HasId + Replicate2<'own, 'dev>> AnyOp2<'own, 'dev> fo let id = ids[0]; Box::new(move |buffers| { - let r1 = unsafe { R::replicate_borrowed(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; + let r1 = unsafe { R::replicate_borrowed(&id, buffers) } + .ok_or(DeviceError::InvalidLazyBuf)?; op(r1) }) } - - fn replication(self) -> Self::Replicated<'own, 'dev> { - self.replicate() + + #[inline] + unsafe fn replication<'iown, 'idev: 'iown>(self) -> Self::Replicated<'iown, 'idev> { + unsafe { self.replicate() } } } diff --git a/src/features.rs b/src/features.rs index e42e372f..57ca86c6 100644 --- a/src/features.rs +++ b/src/features.rs @@ -367,10 +367,10 @@ macro_rules! pass_down_replace_buf_module { } pub trait AddOperation { - fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()>; fn add_op + UpdateArgs, const N: usize>( @@ -439,10 +439,10 @@ macro_rules! pass_down_add_operation { ) -> $crate::Result<()> { self.modules.add_op(args, operation) } - fn add_op2<'own, 'd: 'own, Args: $crate::Parents + $crate::AnyOp2<'own, 'd>, const N: usize>( + fn add_op2<'own, 'd, Args: $crate::Parents + $crate::AnyOp2<'own, 'd>, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> $crate::Result<()> { self.modules.add_op2(args, op) } diff --git a/src/modules/base.rs b/src/modules/base.rs index eac1ec0c..3f26fcac 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -36,12 +36,12 @@ impl<'a, D: Device + 'a> Module<'a, D> for Base { } impl AddOperation for Base { - fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { - op(args.replication()) + op(unsafe { args.replication() }) } #[inline] diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 704031ef..6fac785f 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -73,7 +73,7 @@ impl AddOperation for CachedModule { fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { // let op = crate::LazyGraph2::>::convert_to_operation(args, op); diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 194f32f1..e284bb58 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -92,10 +92,10 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L } impl AddOperation for Lazy { - fn add_op2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'b>) -> crate::Result<()> + 'static, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { if self.enabled.get() { self.graph2.try_borrow_mut() diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 7afa97fa..6032247f 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -89,10 +89,15 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { } Ok(()) } - - pub fn convert_to_operation2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + + pub fn convert_to_operation2< + 'own, + 'dev, + Args: Parents + crate::AnyOp2<'own, 'dev>, + const N: usize, + >( args: Args, - op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'b>) -> crate::Result<()> + 'static, + op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'r>) -> crate::Result<()> + 'static, ) -> Operation2<'a, B, T> { const { assert!(N > 0, "Size of parents must be greater than 0") }; @@ -174,11 +179,16 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { let operation = Self::convert_to_operation(args, op); self.operations.push(operation) } - - pub fn add_operation2<'own, 'dev: 'own, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + + pub fn add_operation2< + 'own, + 'dev, + Args: Parents + crate::AnyOp2<'own, 'dev>, + const N: usize, + >( &mut self, args: Args, - op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'b>) -> crate::Result<()> + 'static, + op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'r>) -> crate::Result<()> + 'static, ) { let operation = Self::convert_to_operation2(args, op); self.operations.push(operation) From 79a856c97a84d006c1f860824d91bafa7ecb2a34 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 01:52:32 +0200 Subject: [PATCH 04/29] Add owned replicate to AnyOp, Add 'a to ApplyFn --- src/any_op.rs | 40 +++++++++++++++++++++++++--- src/devices/cpu/ops.rs | 12 ++++----- src/modules/cached.rs | 5 +--- src/parents.rs | 59 +++++++++++++++++++++++------------------- src/unary.rs | 18 ++++++------- 5 files changed, 85 insertions(+), 49 deletions(-) diff --git a/src/any_op.rs b/src/any_op.rs index f8613a72..c8202c2c 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -11,6 +11,8 @@ pub trait AnyOp: Sized { ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; + + unsafe fn replication<'a>(self) -> Self::Replicated<'a>; } pub trait AnyOp2<'own, 'dev>: Sized { @@ -61,6 +63,7 @@ impl<'uown, 'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uo // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // but than something like this happens: https://github.com/rust-lang/rust/issues/100013 // most of the "double lifetime stuff" is still implemented at the moment + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } @@ -88,6 +91,7 @@ impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2< // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } @@ -97,10 +101,12 @@ pub trait Replicate { type Downcast<'r>: 'r; #[cfg(feature = "std")] - unsafe fn replicate<'r, B: Downcast>( + unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, ) -> Option>; + + unsafe fn replicate<'a>(self) -> Self::Replication<'a>; } impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate @@ -110,12 +116,23 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] - unsafe fn replicate<'r, B: Downcast>( + #[inline] + unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, ) -> Option> { buffers.get(id)?.downcast_ref::>() } + + #[inline] + unsafe fn replicate<'r>(self) -> Self::Replication<'r> { + // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work + // https://github.com/rust-lang/rust/issues/100013 + // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there + // most of the "double lifetime stuff" is still implemented at the moment + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + unsafe { core::mem::transmute::>(self) } + } } impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate @@ -125,7 +142,7 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] - unsafe fn replicate<'r, B: Downcast>( + unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, ) -> Option> { @@ -135,6 +152,16 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate } Some(unsafe { replication.downcast_mut_unchecked::>() }) } + + #[inline] + unsafe fn replicate<'r>(self) -> Self::Replication<'r> { + // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work + // https://github.com/rust-lang/rust/issues/100013 + // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there + // most of the "double lifetime stuff" is still implemented at the moment + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + unsafe { core::mem::transmute::>(self) } + } } impl AnyOp for R { @@ -147,11 +174,16 @@ impl AnyOp for R { let id = ids[0]; Box::new(move |buffers| { - let r1 = unsafe { R::replicate(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; + let r1 = unsafe { R::replicate_borrowed(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; op(r1) }) } type Replicated<'a> = R::Replication<'a>; + + #[inline] + unsafe fn replication<'a>(self) -> Self::Replicated<'a> { + self.replicate() + } } impl<'own, 'dev, R: crate::HasId + Replicate2<'own, 'dev>> AnyOp2<'own, 'dev> for R { diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index 65c5ad06..1350c11d 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -20,9 +20,9 @@ where D::Base: Deref, S: Shape, { - fn apply_fn( - &self, - buf: &Buffer, + fn apply_fn<'a, F>( + &'a self, + buf: &Buffer<'a, T, D, S>, f: impl Fn(Resolve) -> F + Copy + 'static, ) -> Buffer where @@ -30,10 +30,10 @@ where { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op2((&out), move |(out)| { - + self.add_op2((&mut out, buf), move |(out, buf)| { + apply_fn_slice(buf, out, f); Ok(()) - }); + }).unwrap(); // self.add_op((&mut out, buf, f.no_id()), move |(out, buf, f)| { // apply_fn_slice(buf, out, **f); diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 6fac785f..ff37bb17 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -75,10 +75,7 @@ impl AddOperation for CachedModule { args: Args, op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { - // let op = crate::LazyGraph2::>::convert_to_operation(args, op); - - // op(Args::replication_fn(ids, op)) - todo!() + self.modules.add_op2(args, op) } fn add_op, const N: usize>( diff --git a/src/parents.rs b/src/parents.rs index 70d44a25..43a3b88e 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -93,31 +93,32 @@ macro_rules! impl_parents { } } - // impl<'dev, $($to_impl: $crate::Replicate2<'dev> + $crate::HasId, )+> $crate::AnyOp2<'dev> for ($($to_impl,)+) { - // type Replicated<'a> = ($($to_impl::Replication<'a>,)+); - - // #[cfg(feature = "std")] - // fn replication_fn( - // ids: Vec<$crate::Id>, - // op: impl for<'a> Fn(Self::Replicated<'a>) -> $crate::Result<()> + 'static, - // ) -> Box) -> $crate::Result<()>> { - // Box::new(move |buffers| { - // let mut ids = ids.iter(); - - // op(($( - // unsafe { - // $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? - // } - // ,)+)) - // }) - // } - // #[inline] - // fn replication(self) -> Self::Replicated<'dev> { - // #[allow(non_snake_case)] - // let ($($to_impl,)+) = self; - // ($($to_impl.replicate(),)+) - // } - // } + impl<'own, 'dev, $($to_impl: $crate::Replicate2<'own, 'dev> + $crate::HasId, )+> $crate::AnyOp2<'own, 'dev> for ($($to_impl,)+) { + type Replicated<'a, 'b> = ($($to_impl::Replication<'a, 'b>,)+) where 'b: 'a; + + #[cfg(feature = "std")] + fn replication_fn( + ids: Vec<$crate::Id>, + op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'a>) -> $crate::Result<()> + 'static, + ) -> Box) -> $crate::Result<()>> { + Box::new(move |buffers| { + let mut ids = ids.iter(); + + op(($( + unsafe { + $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? + } + ,)+)) + }) + } + + #[inline] + unsafe fn replication<'iown, 'idev>(self) -> Self::Replicated<'iown, 'idev> { + #[allow(non_snake_case)] + let ($($to_impl,)+) = self; + ($($to_impl.replicate(),)+) + } + } impl<$($to_impl: $crate::Replicate + $crate::HasId, )+> $crate::AnyOp for ($($to_impl,)+) { type Replicated<'a> = ($($to_impl::Replication<'a>,)+); @@ -132,11 +133,17 @@ macro_rules! impl_parents { op(($( unsafe { - $to_impl::replicate(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? + $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? } ,)+)) }) } + #[inline] + unsafe fn replication<'a>(self) -> Self::Replicated<'a> { + #[allow(non_snake_case)] + let ($($to_impl,)+) = self; + ($($to_impl.replicate(),)+) + } } }; } diff --git a/src/unary.rs b/src/unary.rs index 3d824dc0..0d94ef9f 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -17,10 +17,10 @@ pub trait ApplyFunction: Device { /// let out = device.apply_fn(&a, |x| x.mul(2.)); /// assert_eq!(&**out, &[2., 4., 6., 6., 4., 2.,]); /// ``` - fn apply_fn( - &self, + fn apply_fn<'a, F>( + &'a self, // buf: &D::Data, - buf: &Buffer, + buf: &Buffer<'a, T, D, S>, f: impl Fn(Resolve) -> F + Copy + 'static, ) -> Buffer where @@ -83,9 +83,9 @@ pub trait UnaryElementWiseMayGrad: Device { /// out.backward(); /// assert_eq!(buf.grad().as_slice(), &[2.; 6]); /// ``` - fn unary_ew( - &self, - buf: &Buffer, + fn unary_ew<'a, FO, GO>( + &'a self, + buf: &Buffer<'a, T, D, S>, forward_fn: impl Fn(Resolve) -> FO + Copy + 'static, grad_fn: fn(Resolve) -> GO, ) -> Buffer @@ -103,9 +103,9 @@ where S: Shape, { #[inline(always)] - fn unary_ew( - &self, - buf: &Buffer, + fn unary_ew<'a, FO, GO>( + &'a self, + buf: &Buffer<'a, T, D, S>, forward_fn: impl Fn(Resolve) -> FO + Copy + 'static, grad_fn: fn(Resolve) -> GO, ) -> Buffer From 801c415fc6218e3aeba193a22b013828529d2dbe Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:08:14 +0200 Subject: [PATCH 05/29] Update add_op --- examples/custom_device.rs | 24 ++++++++++++++++-------- examples/modules_usage.rs | 2 -- src/devices/cpu/ops.rs | 16 ++++++++-------- src/devices/opencl/ops.rs | 16 ++++++++-------- src/features.rs | 21 +++++++++++++-------- src/modules/base.rs | 15 ++++++++------- src/modules/cached.rs | 17 +++++++++-------- src/modules/lazy.rs | 26 ++++++++++++-------------- src/unary.rs | 6 +++--- 9 files changed, 77 insertions(+), 66 deletions(-) diff --git a/examples/custom_device.rs b/examples/custom_device.rs index c0c2518e..df6d16fe 100644 --- a/examples/custom_device.rs +++ b/examples/custom_device.rs @@ -209,14 +209,6 @@ impl<'a, T, S: Shape, D: Device, Mods: OnDropBuffer> Retrieve for Autog } impl<'a, Mods: OnDropBuffer> AddOperation for Autograd<'a, Mods> { - fn add_op + custos::UpdateArgs, const N: usize>( - &self, - args: Args, - operation: fn(&mut Args) -> custos::Result<()>, - ) -> custos::Result<()> { - todo!() - } - fn ops_count(&self) -> usize { todo!() } @@ -228,6 +220,22 @@ impl<'a, Mods: OnDropBuffer> AddOperation for Autograd<'a, Mods> { fn is_lazy_enabled(&self) -> bool { todo!() } + + fn add_op2<'own, 'dev, Args: custos::Parents + custos::AnyOp2<'own, 'dev>, const N: usize>( + &self, + args: Args, + op: impl for<'g, 'b> Fn(Args::Replicated<'g, 'g>) -> custos::Result<()> + 'static, + ) -> custos::Result<()> { + todo!() + } + + fn add_op + custos::AnyOp, const N: usize>( + &self, + args: Args, + op: impl for<'b> Fn(Args::Replicated<'b>) -> custos::Result<()> + 'static, + ) -> custos::Result<()> { + todo!() + } } impl<'a, Mods> Autograd<'a, Mods> { diff --git a/examples/modules_usage.rs b/examples/modules_usage.rs index 9a5de985..a658cb56 100644 --- a/examples/modules_usage.rs +++ b/examples/modules_usage.rs @@ -104,8 +104,6 @@ where self.add_op((lhs, rhs, &mut out), |(lhs, rhs, out)| { let dev = lhs.device(); - let out = &mut **out; - #[cfg(unified_cl)] { let cpu_out = unsafe { &mut *(out as *mut Buffer<_, OpenCL, _>) }; diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index 1350c11d..cb5530a8 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -20,9 +20,9 @@ where D::Base: Deref, S: Shape, { - fn apply_fn<'a, F>( - &'a self, - buf: &Buffer<'a, T, D, S>, + fn apply_fn( + &self, + buf: &Buffer, f: impl Fn(Resolve) -> F + Copy + 'static, ) -> Buffer where @@ -30,7 +30,7 @@ where { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op2((&mut out, buf), move |(out, buf)| { + self.add_op((&mut out, buf), move |(out, buf)| { apply_fn_slice(buf, out, f); Ok(()) }).unwrap(); @@ -71,10 +71,10 @@ where ) where F: Eval + MayToCLSource, { - self.add_op::<_, 4>( - (lhs, lhs_grad.buf_no_id(), out, lhs_grad_fn.no_id()), - |(lhs, lhs_grad, out, lhs_grad_fn)| { - crate::cpu_stack_ops::add_unary_grad(lhs, out, lhs_grad, **lhs_grad_fn); + self.add_op::<_, 3>( + (lhs, lhs_grad, out), + move |(lhs, lhs_grad, out)| { + crate::cpu_stack_ops::add_unary_grad(lhs, out, lhs_grad, lhs_grad_fn); Ok(()) }, ) diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index d339625f..728a952c 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -233,18 +233,18 @@ where { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op((&mut out, buf, f.no_id()), |(out, buf, f)| { + self.add_op((&mut out, buf), move |(out, buf)| { let dev = buf.device(); // let out: &mut Buffer<'_, T, OpenCL, S> = out.as_mut().unwrap(); - let out = &mut **out; + let out = &mut *out; #[cfg(unified_cl)] { let cpu_out = unsafe { &mut *(out as *mut Buffer<_, OpenCL, _>) }; dev.use_cpu_or_gpu( (file!(), line!(), column!()).into(), &[buf.len()], - || crate::devices::cpu_stack_ops::apply_fn_slice(buf, cpu_out, **f), - || try_cl_apply_fn_mut(dev, buf, out, **f).unwrap(), + || crate::devices::cpu_stack_ops::apply_fn_slice(buf, cpu_out, f), + || try_cl_apply_fn_mut(dev, buf, out, f).unwrap(), ); Ok(()) } @@ -311,10 +311,10 @@ where ) where F: ToCLSource, { - self.add_op::<_, 4>( - (lhs, lhs_grad.buf_no_id(), out, lhs_grad_fn.no_id()), - move |(lhs, lhs_grad, out, lhs_grad_fn)| { - try_cl_add_unary_grad(lhs.device(), lhs, **lhs_grad, out, **lhs_grad_fn) + self.add_op( + (lhs, lhs_grad, out), + move |(lhs, lhs_grad, out)| { + try_cl_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) }, ) .unwrap(); diff --git a/src/features.rs b/src/features.rs index 57ca86c6..f3b6082a 100644 --- a/src/features.rs +++ b/src/features.rs @@ -165,7 +165,8 @@ pub trait AddGradFn { ) where Self: AddOperation, { - self.add_op(args.clone(), forward_fn).unwrap(); + todo!(); + // self.add_op(args.clone(), forward_fn).unwrap(); self.add_grad_fn(args, grad_fn) } @@ -367,17 +368,19 @@ macro_rules! pass_down_replace_buf_module { } pub trait AddOperation { - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + fn add_op + AnyOp, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ) -> crate::Result<()>; - fn add_op + UpdateArgs, const N: usize>( + + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - operation: fn(&mut Args) -> crate::Result<()>, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()>; + fn ops_count(&self) -> usize; fn set_lazy_enabled(&self, enabled: bool); #[inline] @@ -432,13 +435,15 @@ macro_rules! pass_down_add_operation { impl<'dev, Mods: $crate::AddOperation> $crate::AddOperation for $device<$($generics),*> { #[inline] - fn add_op + $crate::UpdateArgs, const N: usize>( + fn add_op + $crate::AnyOp, const N: usize>( &self, args: Args, - operation: fn(&mut Args) -> $crate::Result<()>, + op: impl for<'a> Fn(Args::Replicated<'a>) -> crate::Result<()> + 'static, ) -> $crate::Result<()> { - self.modules.add_op(args, operation) + self.modules.add_op(args, op) } + + #[inline] fn add_op2<'own, 'd, Args: $crate::Parents + $crate::AnyOp2<'own, 'd>, const N: usize>( &self, args: Args, diff --git a/src/modules/base.rs b/src/modules/base.rs index 3f26fcac..a1805481 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -36,28 +36,29 @@ impl<'a, D: Device + 'a> Module<'a, D> for Base { } impl AddOperation for Base { - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( + #[inline] + fn add_op + crate::AnyOp, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { op(unsafe { args.replication() }) } #[inline] - fn add_op, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( &self, - mut args: Args, - operation: fn(&mut Args) -> crate::Result<()>, + args: Args, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { - operation(&mut args) + op(unsafe { args.replication() }) } #[inline] fn ops_count(&self) -> usize { 0 } - + #[inline] fn set_lazy_enabled(&self, _enabled: bool) {} diff --git a/src/modules/cached.rs b/src/modules/cached.rs index ff37bb17..f822b024 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -70,27 +70,28 @@ impl, D: Device, NewDev> Setup for CachedModule AddOperation for CachedModule { - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + #[inline] + fn add_op + crate::AnyOp, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, + op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { - self.modules.add_op2(args, op) + self.modules.add_op(args, op) } - fn add_op, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, - mut args: Args, - operation: fn(&mut Args) -> crate::Result<()>, + args: Args, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { - operation(&mut args) + self.modules.add_op2(args, op) } #[inline] fn ops_count(&self) -> usize { self.modules.ops_count() } - + #[inline] fn set_lazy_enabled(&self, enabled: bool) { self.modules.set_lazy_enabled(enabled) diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index e284bb58..bdfb24bd 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -92,35 +92,33 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L } impl AddOperation for Lazy { - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( + fn add_op + crate::AnyOp, const N: usize>( &self, args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, + op: impl for<'a> Fn(Args::Replicated<'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { if self.enabled.get() { self.graph2.try_borrow_mut() .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") - .add_operation2(args, op); + .add_operation(args, op); Ok(()) } else { - self.modules.add_op2(args, op) + self.modules.add_op(args, op) } } - - #[inline] - fn add_op + UpdateArgs, const N: usize>( + fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( &self, args: Args, - operation: fn(&mut Args) -> crate::Result<()>, + op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { if self.enabled.get() { - self.graph.try_borrow_mut() + self.graph2.try_borrow_mut() .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") - .add_operation(args, operation); + .add_operation2(args, op); + Ok(()) } else { - return self.modules.add_op(args, operation); + self.modules.add_op2(args, op) } - Ok(()) } #[inline] @@ -783,8 +781,8 @@ mod tests { let vec = vec![1, 2, 3]; device .add_op( - (&mut out, a.no_id(), &b, vec.no_id()), - |(out, a, b, _vec)| { + (&mut out, &b), + move |(out, b)| { for ((lhs, rhs), out) in a.iter().zip(b.iter()).zip(out.iter_mut()) { *out = lhs + rhs; } diff --git a/src/unary.rs b/src/unary.rs index 0d94ef9f..f0fd2f1e 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -17,10 +17,10 @@ pub trait ApplyFunction: Device { /// let out = device.apply_fn(&a, |x| x.mul(2.)); /// assert_eq!(&**out, &[2., 4., 6., 6., 4., 2.,]); /// ``` - fn apply_fn<'a, F>( - &'a self, + fn apply_fn( + &self, // buf: &D::Data, - buf: &Buffer<'a, T, D, S>, + buf: &Buffer, f: impl Fn(Resolve) -> F + Copy + 'static, ) -> Buffer where From cc115c6d0eea66ddbbcd0ae85c56128addd367b4 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:15:37 +0200 Subject: [PATCH 06/29] Remove usages of old LazyGraph --- examples/custom_device.rs | 2 +- src/modules/lazy.rs | 14 ++++++-------- src/modules/lazy/lazy_graph.rs | 10 +++++----- src/op_hint.rs | 6 +++--- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/examples/custom_device.rs b/examples/custom_device.rs index df6d16fe..4561b304 100644 --- a/examples/custom_device.rs +++ b/examples/custom_device.rs @@ -525,7 +525,7 @@ fn main() { // graph.add_operation((&lhs, &rhs), |(lhs, rhs)| Ok(())); let graph: &mut LazyGraph2 = &mut unsafe { device.modules.tape_mut() }.unwrap().lazy_graph; - graph.call_lazily(&mut buffers).unwrap(); + unsafe { graph.call_lazily(&mut buffers).unwrap() }; // // unsafe { register_buf_copyable(&mut buffers, &lhs) }; // unsafe { register_buf_copyable(&mut buffers, &rhs) }; // let tape: &mut LazyGraph2 = &mut unsafe { &mut *device.modules.tape.get()}.lazy_graph; diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index bdfb24bd..08953a2a 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -43,7 +43,7 @@ pub struct Lazy { // This ensures to only allocate a buffer once, without having to remove the ID/address collision check // TODO: remove this, fix id and address collision - then just use `buffers` for duplicate calls allocated_ids: RefCell, - pub graph: RefCell, T>>, + // pub graph: RefCell, T>>, pub graph2: RefCell, T>>, cursor: Cell, enabled: Cell, @@ -80,7 +80,6 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L modules: Mods::new(), buffers: Default::default(), replaced_buffers: Default::default(), - graph: Default::default(), graph2: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), @@ -123,7 +122,7 @@ impl AddOperation for Lazy { #[inline] fn ops_count(&self) -> usize { - self.graph.borrow().operations.len() + self.graph2.borrow().ops_count() } #[inline] @@ -140,7 +139,7 @@ impl AddOperation for Lazy { impl SetOpHint for Lazy { #[inline] fn set_op_hint(&self, op_hint: OpHint) { - if let Some(op) = self.graph.borrow_mut().operations.last_mut() { + if let Some(op) = self.graph2.borrow_mut().operations.last_mut() { op.op_hint = op_hint; } } @@ -155,7 +154,7 @@ impl ExecNow for Lazy { ) -> crate::Result<()> { self.alloc_later(device); unsafe { - self.graph + self.graph2 .borrow_mut() .call_range::(range_bounds, &mut self.buffers.borrow_mut())?; } @@ -166,9 +165,9 @@ impl ExecNow for Lazy { impl Lazy { #[inline] pub unsafe fn call_lazily(&self) -> crate::Result<()> { - self.graph + self.graph2 .borrow_mut() - .call_lazily::(&mut self.buffers.borrow_mut())?; + .call_lazily(&mut self.buffers.borrow_mut())?; Ok(()) } @@ -306,7 +305,6 @@ impl AddLayer for Lazy<(), T> { modules: inner_mods, buffers: Default::default(), replaced_buffers: Default::default(), - graph: Default::default(), graph2: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 6032247f..baed6503 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -35,7 +35,7 @@ pub struct LazyGraph, T = ()> { } pub struct LazyGraph2<'a, B = Box, T = ()> { - operations: Vec>, + pub(crate) operations: Vec>, } impl<'a, B, T> Default for LazyGraph2<'a, B, T> { @@ -70,16 +70,16 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { self.operations.len() } - pub fn call_lazily(&mut self, buffers: &mut Buffers) -> crate::Result<()> { + pub unsafe fn call_lazily(&mut self, buffers: &mut Buffers) -> crate::Result<()> { for args in self.iter_with(buffers) { args?; } Ok(()) } - pub fn call_range( + pub unsafe fn call_range( &mut self, - _device: &'a D, + // _device: &'a D, bounds: impl RangeBounds, buffers: &mut Buffers, ) -> crate::Result<()> { @@ -352,7 +352,7 @@ mod tests { println!("args: {args:?}"); Ok(()) }); - graph.call_lazily(&mut buffers).unwrap(); + unsafe { graph.call_lazily(&mut buffers).unwrap() }; }; // let x = DEVICE2.get().unwrap(); // println!("{:?}", x.modules.cache.borrow().nodes); diff --git a/src/op_hint.rs b/src/op_hint.rs index c4c29e14..2dad3c0b 100644 --- a/src/op_hint.rs +++ b/src/op_hint.rs @@ -56,7 +56,7 @@ mod tests { marker: "x", }; - let ops = &dev.modules.graph.borrow().operations; + let ops = &dev.modules.graph2.borrow().operations; let op_hint = &ops[0].op_hint; if let OpHint::Unary(op) = op_hint { let src = op(resolve).to_cl_source(); @@ -98,7 +98,7 @@ mod tests { let mut out = buf.clone(); for out in out.iter_mut() { - for op in &dev.modules.graph.borrow().operations { + for op in &dev.modules.graph2.borrow().operations { let resolve = Resolve { val: *out, marker: "x", @@ -234,7 +234,7 @@ mod tests { let start = Instant::now(); for out in out.iter_mut() { - for op in &dev.modules.graph.borrow().operations { + for op in &dev.modules.graph2.borrow().operations { let resolve = Resolve { val: *out, marker: "x", From 0d81aa72de3441e6dc1857c62b8dd7c9143d8e4f Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:24:42 +0200 Subject: [PATCH 07/29] Remove old LazyGraph --- examples/custom_device.rs | 6 +- src/modules/autograd/tape.rs | 4 +- src/modules/lazy.rs | 20 +++-- src/modules/lazy/lazy_graph.rs | 136 ++++----------------------------- src/op_hint.rs | 6 +- 5 files changed, 33 insertions(+), 139 deletions(-) diff --git a/examples/custom_device.rs b/examples/custom_device.rs index 4561b304..63f739fd 100644 --- a/examples/custom_device.rs +++ b/examples/custom_device.rs @@ -9,7 +9,7 @@ use std::{ use custos::{ cpu::CPUPtr, flag::AllocFlag, impl_device_traits, AddGradFn, AddOperation, Alloc, Base, BorrowCacheLT, Buffer, Cached, CachedModule, Device, DeviceError, DevicelessAble, HasId, Id, - LazyGraph2, Module, OnDropBuffer, OnNewBuffer, PtrType, Retrieve, Retriever, Setup, Shape, + LazyGraph, Module, OnDropBuffer, OnNewBuffer, PtrType, Retrieve, Retriever, Setup, Shape, Tape, TapeActions, Unit, WrappedData, CPU, }; @@ -20,7 +20,7 @@ pub trait Str { #[derive(Default)] pub struct CPU2<'a, Mods: 'a = Base> { pub modules: Mods, - pub graph: LazyGraph2<'a>, + pub graph: LazyGraph<'a>, pd: PhantomData<&'a ()>, } @@ -524,7 +524,7 @@ fn main() { rhs.backward(); // graph.add_operation((&lhs, &rhs), |(lhs, rhs)| Ok(())); - let graph: &mut LazyGraph2 = &mut unsafe { device.modules.tape_mut() }.unwrap().lazy_graph; + let graph: &mut LazyGraph = &mut unsafe { device.modules.tape_mut() }.unwrap().lazy_graph; unsafe { graph.call_lazily(&mut buffers).unwrap() }; // // unsafe { register_buf_copyable(&mut buffers, &lhs) }; // unsafe { register_buf_copyable(&mut buffers, &rhs) }; diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 89433444..10db28dc 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -1,5 +1,5 @@ use crate::{ - AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph2, + AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph, Parents, Shape, Unit, WriteBuf, ZeroGrad, }; @@ -10,7 +10,7 @@ pub type GradFn = Box; /// Stores the grad functions and gradient cache. #[derive(Default)] pub struct Tape<'a> { - pub lazy_graph: LazyGraph2<'a, Box>, + pub lazy_graph: LazyGraph<'a, Box>, } impl<'t> Tape<'t> { diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 08953a2a..69b36886 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -26,7 +26,6 @@ use core::{ }; use std::collections::HashSet; -pub use self::lazy_graph::LazyGraph; use self::wrapper::LazyWrapper; pub use lazy_graph::*; @@ -43,8 +42,7 @@ pub struct Lazy { // This ensures to only allocate a buffer once, without having to remove the ID/address collision check // TODO: remove this, fix id and address collision - then just use `buffers` for duplicate calls allocated_ids: RefCell, - // pub graph: RefCell, T>>, - pub graph2: RefCell, T>>, + pub graph: RefCell, T>>, cursor: Cell, enabled: Cell, pd: PhantomData, @@ -80,7 +78,7 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L modules: Mods::new(), buffers: Default::default(), replaced_buffers: Default::default(), - graph2: Default::default(), + graph: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), cursor: Default::default(), @@ -97,7 +95,7 @@ impl AddOperation for Lazy { op: impl for<'a> Fn(Args::Replicated<'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { if self.enabled.get() { - self.graph2.try_borrow_mut() + self.graph.try_borrow_mut() .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") .add_operation(args, op); Ok(()) @@ -111,7 +109,7 @@ impl AddOperation for Lazy { op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, ) -> crate::Result<()> { if self.enabled.get() { - self.graph2.try_borrow_mut() + self.graph.try_borrow_mut() .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") .add_operation2(args, op); Ok(()) @@ -122,7 +120,7 @@ impl AddOperation for Lazy { #[inline] fn ops_count(&self) -> usize { - self.graph2.borrow().ops_count() + self.graph.borrow().ops_count() } #[inline] @@ -139,7 +137,7 @@ impl AddOperation for Lazy { impl SetOpHint for Lazy { #[inline] fn set_op_hint(&self, op_hint: OpHint) { - if let Some(op) = self.graph2.borrow_mut().operations.last_mut() { + if let Some(op) = self.graph.borrow_mut().operations.last_mut() { op.op_hint = op_hint; } } @@ -154,7 +152,7 @@ impl ExecNow for Lazy { ) -> crate::Result<()> { self.alloc_later(device); unsafe { - self.graph2 + self.graph .borrow_mut() .call_range::(range_bounds, &mut self.buffers.borrow_mut())?; } @@ -165,7 +163,7 @@ impl ExecNow for Lazy { impl Lazy { #[inline] pub unsafe fn call_lazily(&self) -> crate::Result<()> { - self.graph2 + self.graph .borrow_mut() .call_lazily(&mut self.buffers.borrow_mut())?; Ok(()) @@ -305,7 +303,7 @@ impl AddLayer for Lazy<(), T> { modules: inner_mods, buffers: Default::default(), replaced_buffers: Default::default(), - graph2: Default::default(), + graph: Default::default(), alloc_later: Default::default(), allocated_ids: Default::default(), cursor: Default::default(), diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index baed6503..364b3edb 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -30,15 +30,12 @@ impl Operation { } } -pub struct LazyGraph, T = ()> { - pub operations: Vec>, -} -pub struct LazyGraph2<'a, B = Box, T = ()> { +pub struct LazyGraph<'a, B = Box, T = ()> { pub(crate) operations: Vec>, } -impl<'a, B, T> Default for LazyGraph2<'a, B, T> { +impl<'a, B, T> Default for LazyGraph<'a, B, T> { #[inline] fn default() -> Self { Self { @@ -47,7 +44,7 @@ impl<'a, B, T> Default for LazyGraph2<'a, B, T> { } } -impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { +impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { #[inline] pub fn iter_with<'b>( &'b mut self, @@ -195,89 +192,12 @@ impl<'a, B: Downcast, T> LazyGraph2<'a, B, T> { } } -impl Default for LazyGraph { - #[inline] - fn default() -> Self { - Self { - operations: Vec::new(), - } - } -} - -impl LazyGraph { - #[inline] - pub fn iter_with<'a>(&'a mut self, buffers: &'a mut Buffers) -> ExecIter { - ExecIter { - operations: self.operations.iter_mut(), - buffers, - } - } - - #[inline] - pub fn clear(&mut self) { - self.operations.clear(); - } - - pub unsafe fn convert_to_operation + UpdateArgs, const N: usize>( - args: Args, - op: fn(&mut Args) -> crate::Result<()>, - ) -> Operation { - // store ids and test if buffers are still in cache - let arg_ids = args - .maybe_ids() - .into_iter() - .map(|id| id.map(|id| *id)) - .collect(); - - let args: Box> = Box::new(args); - - Operation { - arg_ids, - op: transmute(op), - args: transmute(args), - op_hint: OpHint::None, - } - } - - pub fn add_operation + UpdateArgs, const N: usize>( - &mut self, - args: Args, - op: fn(&mut Args) -> crate::Result<()>, - ) { - let operation = unsafe { Self::convert_to_operation(args, op) }; - self.operations.push(operation) - } - - pub unsafe fn call_lazily( - &mut self, - outs_unordered: &mut Buffers, - ) -> crate::Result<()> { - for args in self.iter_with(outs_unordered) { - args?; - } - Ok(()) - } - - pub unsafe fn call_range( - &mut self, - bounds: impl RangeBounds, - outs_unordered: &mut Buffers, - ) -> crate::Result<()> { - let range = bounds_to_range(bounds, self.operations.len()); - for mut op in self.operations.drain(range) { - exec_op(&mut op.args, &op.op, &op.arg_ids, outs_unordered)?; - } - Ok(()) - } -} - #[cfg(feature = "cpu")] #[cfg(test)] mod tests { - use super::LazyGraph; use crate::{ register_buf_any, register_buf_copyable, AnyBuffer, AsNoId, Base, BoxedShallowCopy, Buffer, - CloneBuf, Device, HasId, LazyGraph2, Retriever, Shape, UniqueId, CPU, + CloneBuf, Device, HasId, LazyGraph, Retriever, Shape, UniqueId, CPU, }; use core::cell::Cell; use std::collections::HashMap; @@ -328,7 +248,7 @@ mod tests { let device = CPU::>::new(); let mut buffers = HashMap::default(); - let mut graph: LazyGraph2> = LazyGraph2::default(); + let mut graph: LazyGraph> = LazyGraph::default(); let lhs = device.buffer([1f32, 2., 3., 4., 5.]); let rhs = device.buffer([1f32, 2., 6., 4., 5.]); @@ -380,7 +300,7 @@ mod tests { // outs_unordered.insert(out.id(), ) graph.add_operation::<_, 3>((&out, &lhs, &rhs), |args| { - let (_out, lhs, rhs) = *args; + let (_out, lhs, rhs) = args; assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); Ok(()) @@ -390,7 +310,7 @@ mod tests { }; // todo!() - unsafe { graph.call_lazily::(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } #[test] @@ -410,13 +330,13 @@ mod tests { // outs_unordered.insert(out.id(), ) graph.add_operation::<_, 3>((&out, &lhs, &rhs), |args| { - let (_out, lhs, rhs) = *args; + let (_out, lhs, rhs) = args; assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); Ok(()) }); - unsafe { graph.call_lazily::(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } #[test] @@ -449,7 +369,7 @@ mod tests { }); } - unsafe { graph.call_lazily::(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } #[test] fn test_lazy_op_args_no_out_but_use() { @@ -469,14 +389,14 @@ mod tests { // outs_unordered.insert(out.id(), ) graph.add_operation::<_, 2>((&lhs, &rhs), |args| { - let (lhs, rhs) = *args; + let (lhs, rhs) = args; assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); Ok(()) }); - unsafe { graph.call_lazily::(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } #[test] @@ -499,9 +419,9 @@ mod tests { // outs_unordered.insert(out.id(), ) - graph.add_operation::<_, 4>( - (&mut out, &lhs, &rhs, ew_fn.no_id()), - |(_out, lhs, rhs, ew_fn)| { + graph.add_operation::<_, 3>( + (&mut out, &lhs, &rhs), + move |(_out, lhs, rhs)| { assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); @@ -513,30 +433,6 @@ mod tests { }, ); - unsafe { graph.call_lazily::(&mut outs_unordered).unwrap() } - } - - #[test] - fn test_lazy_graph_exec_with_vecs() { - let mut graph = LazyGraph::>::default(); - - { - let vec = vec![1, 2, 3, 4]; - graph.add_operation::<_, 1>(vec.no_id(), |vec| { - assert_eq!(vec.as_slice(), &[1, 2, 3, 4]); - Ok(()) - }); - } - unsafe { graph.call_lazily::(&mut HashMap::default()) }.unwrap(); - } - - #[test] - fn test_args_ref_updating() { - let x = 5; - let y = 3.; - let mut args = (&x, 10, &y); - - let replace_x = &x; - args.0 = replace_x; + unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } } diff --git a/src/op_hint.rs b/src/op_hint.rs index 2dad3c0b..c4c29e14 100644 --- a/src/op_hint.rs +++ b/src/op_hint.rs @@ -56,7 +56,7 @@ mod tests { marker: "x", }; - let ops = &dev.modules.graph2.borrow().operations; + let ops = &dev.modules.graph.borrow().operations; let op_hint = &ops[0].op_hint; if let OpHint::Unary(op) = op_hint { let src = op(resolve).to_cl_source(); @@ -98,7 +98,7 @@ mod tests { let mut out = buf.clone(); for out in out.iter_mut() { - for op in &dev.modules.graph2.borrow().operations { + for op in &dev.modules.graph.borrow().operations { let resolve = Resolve { val: *out, marker: "x", @@ -234,7 +234,7 @@ mod tests { let start = Instant::now(); for out in out.iter_mut() { - for op in &dev.modules.graph2.borrow().operations { + for op in &dev.modules.graph.borrow().operations { let resolve = Resolve { val: *out, marker: "x", From c1d3191ba3975990762a8bac316e04322c21ddfe Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:32:09 +0200 Subject: [PATCH 08/29] Remove UpdateArgs --- src/boxed_shallow_copy.rs | 29 +-------------- src/devices/opencl/ops.rs | 2 +- src/features.rs | 9 ++--- src/id.rs | 55 +-------------------------- src/lib.rs | 2 - src/modules/lazy.rs | 4 +- src/modules/lazy/exec_iter.rs | 49 +++--------------------- src/modules/lazy/lazy_graph.rs | 33 +++-------------- src/parents.rs | 28 +------------- src/unary.rs | 2 +- src/update_args.rs | 68 ---------------------------------- 11 files changed, 22 insertions(+), 259 deletions(-) delete mode 100644 src/update_args.rs diff --git a/src/boxed_shallow_copy.rs b/src/boxed_shallow_copy.rs index 88f4ff80..6340ac69 100644 --- a/src/boxed_shallow_copy.rs +++ b/src/boxed_shallow_copy.rs @@ -1,5 +1,4 @@ -use crate::{AnyBuffer, AsAny, Downcast, ShallowCopy}; -use core::any::Any; +use crate::{AnyBuffer, Downcast, ShallowCopy}; pub trait BoxedShallowCopy: AnyBuffer { fn shallow_copy(&self) -> Box; @@ -77,29 +76,3 @@ impl Downcast for Box { (**self).is::() } } - -impl AsAny for Box { - #[inline] - fn as_any(&self) -> *const () { - let data = &**self; - data as *const _ as *const () - } - - #[inline] - fn as_any_mut(&mut self) -> *mut () { - let data = &mut **self; - data as *mut _ as *mut () - } -} - -impl AsAny for Box { - #[inline] - fn as_any(&self) -> *const () { - (&**self) as *const _ as *const () - } - - #[inline] - fn as_any_mut(&mut self) -> *mut () { - (&mut **self) as *mut _ as *mut () - } -} diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index 728a952c..8d87e2f5 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -7,7 +7,7 @@ use min_cl::{ use crate::{ bounds_to_range, cpu_stack_ops::clear_slice, location, op_hint::unary, pass_down_add_operation, - pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, AsNoId, BufAsNoId, Buffer, + pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer, CDatatype, ClearBuf, CopySlice, OnDropBuffer, OpenCL, Read, Resolve, Retrieve, Retriever, SetOpHint, Shape, ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WriteBuf, ZeroGrad, diff --git a/src/features.rs b/src/features.rs index f3b6082a..2c98c403 100644 --- a/src/features.rs +++ b/src/features.rs @@ -7,7 +7,7 @@ use core::{cell::RefMut, fmt::Debug, ops::RangeBounds}; use crate::{ op_hint::OpHint, range::{AsRange, CursorRange}, - AnyOp, HasId, Parents, Shape, UniqueId, Unit, UpdateArgs, ZeroGrad, CPU, + AnyOp, HasId, Parents, Shape, UniqueId, Unit, ZeroGrad, CPU, }; #[cfg(feature = "cached")] @@ -157,16 +157,15 @@ pub trait AddGradFn { op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ); - fn add_grad_and_forward_fn + UpdateArgs + AnyOp + Clone, const N: usize>( + fn add_grad_and_forward_fn + AnyOp + Clone, const N: usize>( &self, args: Args, - forward_fn: fn(&mut Args) -> crate::Result<()>, + forward_fn: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, grad_fn: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ) where Self: AddOperation, { - todo!(); - // self.add_op(args.clone(), forward_fn).unwrap(); + self.add_op(args.clone(), forward_fn).unwrap(); self.add_grad_fn(args, grad_fn) } diff --git a/src/id.rs b/src/id.rs index a0fc939f..f1b45907 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1,6 +1,6 @@ use core::ops::{Deref, DerefMut}; -use crate::{Buffer, Device, DeviceError, Shape, UniqueId, Unit, UpdateArg}; +use crate::{Buffer, Device, Shape, Unit}; pub trait HasId { const HAS_NO_ID: bool = false; @@ -105,59 +105,6 @@ impl>> AsNoId for T { } } -impl UpdateArg for NoId { - #[inline] - #[cfg(feature = "std")] - fn update_arg( - _to_update: &mut Self, - _id: Option, - _buffers: &mut crate::Buffers, - ) -> crate::Result<()> { - Ok(()) - } -} - -impl<'a, T: Unit + 'static, D: Device + 'static, S: Shape + 'static> UpdateArg - for &Buffer<'a, T, D, S> -{ - #[cfg(feature = "std")] - fn update_arg( - to_update: &mut Self, - id: Option, - buffers: &mut crate::Buffers, - ) -> crate::Result<()> { - // use crate::ShallowCopyable; - - let buf = buffers - .get(&id.unwrap()) - .ok_or(DeviceError::InvalidLazyBuf)?; - // let any = buf.as_any(); - // let _to_update = buf.as_any().downcast_ref::>().unwrap(); - // todo!(); - *to_update = unsafe { &*(buf.as_any() as *const Buffer) }; - // *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap(); - Ok(()) - } -} - -impl<'a, T: Unit + 'static, D: Device + 'static, S: Shape + 'static> UpdateArg - for &mut Buffer<'a, T, D, S> -{ - #[cfg(feature = "std")] - fn update_arg( - to_update: &mut Self, - id: Option, - buffers: &mut crate::Buffers, - ) -> crate::Result<()> { - let buf = buffers - .get_mut(&id.unwrap()) - .ok_or(DeviceError::InvalidLazyBuf)?; - *to_update = unsafe { &mut *(buf.as_any_mut() as *mut Buffer) }; - Ok(()) - // *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap(); - } -} - pub trait BufAsNoId: Sized { fn buf_no_id(self) -> NoId; } diff --git a/src/lib.rs b/src/lib.rs index 1f3e1096..066e3d60 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,7 +116,6 @@ mod range; mod shape; mod two_way_ops; mod unary; -mod update_args; mod wrapper; pub use any_op::*; @@ -130,7 +129,6 @@ pub use number::*; pub use parents::*; pub use ptr_conv::*; pub use range::*; -pub use update_args::*; pub use wrapper::*; #[cfg(not(feature = "cpu"))] diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 69b36886..406ddd94 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -11,7 +11,7 @@ use crate::{ op_hint::OpHint, register_buf_copyable, unregister_buf_copyable, AddLayer, AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, CachedBuffers, Cursor, Device, ExecNow, HasId, HasModules, Id, IsShapeIndep, Module, NoHasher, OnDropBuffer, OnNewBuffer, Parents, ReplaceBuf, Retrieve, - RunModule, SetOpHint, Setup, ShallowCopy, Shape, UniqueId, Unit, UpdateArgs, UseGpuOrCpu, + RunModule, SetOpHint, Setup, ShallowCopy, Shape, UniqueId, Unit, UseGpuOrCpu, }; #[cfg(feature = "graph")] @@ -757,7 +757,7 @@ mod tests { // #[ignore = "causes UB"] #[test] fn test_lazy_exec_ub_testing() { - use crate::{AsNoId, Run}; + use crate::Run; let device = CPU::>::new(); diff --git a/src/modules/lazy/exec_iter.rs b/src/modules/lazy/exec_iter.rs index 02267e6b..750709e2 100644 --- a/src/modules/lazy/exec_iter.rs +++ b/src/modules/lazy/exec_iter.rs @@ -1,13 +1,11 @@ -use crate::{Buffers, Operation2, UniqueId, UpdateArgsDynable}; +use crate::{Buffers, Operation2}; -use super::lazy_graph::Operation; - -pub struct ExecIter2<'a, 'b, B, T> { +pub struct ExecIter<'a, 'b, B, T> { pub(super) operations: std::slice::Iter<'b, Operation2<'a, B, T>>, pub(super) buffers: &'b mut Buffers, } -impl<'a, 'b, B, T> Iterator for ExecIter2<'a, 'b, B, T> { +impl<'a, 'b, B, T> Iterator for ExecIter<'a, 'b, B, T> { type Item = crate::Result<()>; fn next(&mut self) -> Option { @@ -16,51 +14,14 @@ impl<'a, 'b, B, T> Iterator for ExecIter2<'a, 'b, B, T> { } } -impl<'a, 'b, B, T> DoubleEndedIterator for ExecIter2<'a, 'b, B, T> { +impl<'a, 'b, B, T> DoubleEndedIterator for ExecIter<'a, 'b, B, T> { fn next_back(&mut self) -> Option { let op = self.operations.next_back()?; Some((op.op)(self.buffers)) } } -impl<'a, 'b, B, T> ExactSizeIterator for ExecIter2<'a, 'b, B, T> { - fn len(&self) -> usize { - self.operations.len() - } -} -pub struct ExecIter<'a, B, T> { - pub(super) operations: std::slice::IterMut<'a, Operation>, - pub(super) buffers: &'a mut Buffers, -} - -pub fn exec_op( - args: &mut Box>, - op: &fn(*mut ()) -> crate::Result<()>, - ids_to_check: &[Option], - buffers: &mut Buffers, -) -> crate::Result<()> { - args.update_args_dynable(ids_to_check, buffers)?; - let args = core::ptr::addr_of_mut!(**args) as *mut (); - op(args) -} - -impl<'a, B, T> Iterator for ExecIter<'a, B, T> { - type Item = crate::Result<()>; - - fn next(&mut self) -> Option { - let op = self.operations.next()?; - Some(exec_op(&mut op.args, &op.op, &op.arg_ids, self.buffers)) - } -} - -impl<'a, B, T> DoubleEndedIterator for ExecIter<'a, B, T> { - fn next_back(&mut self) -> Option { - let op = self.operations.next_back()?; - Some(exec_op(&mut op.args, &op.op, &op.arg_ids, self.buffers)) - } -} - -impl<'a, B, T> ExactSizeIterator for ExecIter<'a, B, T> { +impl<'a, 'b, B, T> ExactSizeIterator for ExecIter<'a, 'b, B, T> { fn len(&self) -> usize { self.operations.len() } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 364b3edb..02efd936 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -1,36 +1,15 @@ use crate::{ - bounds_to_range, modules::lazy::exec_iter::ExecIter2, op_hint::OpHint, AnyOp, AsAny, - BoxedShallowCopy, Buffers, Device, Downcast, Parents, UniqueId, UpdateArgs, UpdateArgsDynable, + bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, + BoxedShallowCopy, Buffers, Device, Downcast, Parents, }; -use core::{mem::transmute, ops::RangeBounds}; +use core::ops::RangeBounds; use std::collections::HashSet; -use super::exec_iter::{exec_op, ExecIter}; - pub struct Operation2<'a, B, T> { pub op: Box) -> crate::Result<()> + 'a>, pub op_hint: OpHint, } -pub struct Operation { - pub op_hint: OpHint, - pub arg_ids: Vec>, - pub op: fn(*mut ()) -> crate::Result<()>, - pub args: Box>, -} - -impl Operation { - pub fn no_op() -> Self { - Self { - op_hint: OpHint::None, - arg_ids: vec![None], - op: |_: *mut ()| Ok(()), - args: Box::new(()), - } - } -} - - pub struct LazyGraph<'a, B = Box, T = ()> { pub(crate) operations: Vec>, } @@ -50,8 +29,8 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { &'b mut self, // device: &'a D, buffers: &'b mut Buffers, - ) -> ExecIter2<'a, 'b, B, T> { - ExecIter2 { + ) -> ExecIter<'a, 'b, B, T> { + ExecIter { operations: self.operations.iter(), buffers, } @@ -196,7 +175,7 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { #[cfg(test)] mod tests { use crate::{ - register_buf_any, register_buf_copyable, AnyBuffer, AsNoId, Base, BoxedShallowCopy, Buffer, + register_buf_any, register_buf_copyable, AnyBuffer, Base, Buffer, CloneBuf, Device, HasId, LazyGraph, Retriever, Shape, UniqueId, CPU, }; use core::cell::Cell; diff --git a/src/parents.rs b/src/parents.rs index 43a3b88e..b3a1a5d6 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -1,4 +1,4 @@ -use crate::{HasId, Id, UpdateArg}; +use crate::{HasId, Id}; pub trait Parents: AllParents { fn ids(&self) -> [Id; N]; @@ -22,17 +22,6 @@ impl Parents<0> for () { impl AllParents for () {} -impl UpdateArg for () { - #[cfg(feature = "std")] - fn update_arg( - _to_update: &mut Self, - _id: Option, - _buffers: &mut crate::Buffers, - ) -> crate::Result<()> { - Ok(()) - } -} - impl Parents<1> for T { #[inline] fn ids(&self) -> [Id; 1] { @@ -78,21 +67,6 @@ macro_rules! impl_parents { } impl<$($to_impl: $crate::HasId, )+> AllParents for ($($to_impl,)+) {} - impl<$($to_impl: $crate::UpdateArg + $crate::HasId, )+> $crate::UpdateArgs for ($($to_impl,)+) { - #[cfg(feature = "std")] - fn update_args(&mut self, - ids: &[Option<$crate::UniqueId>], - buffers: &mut $crate::Buffers) - -> crate::Result<()> - { - let mut ids = ids.iter(); - #[allow(non_snake_case)] - let ($($to_impl,)+) = self; - $($to_impl::update_arg($to_impl, *ids.next().unwrap(), buffers)?;)* - Ok(()) - } - } - impl<'own, 'dev, $($to_impl: $crate::Replicate2<'own, 'dev> + $crate::HasId, )+> $crate::AnyOp2<'own, 'dev> for ($($to_impl,)+) { type Replicated<'a, 'b> = ($($to_impl::Replication<'a, 'b>,)+) where 'b: 'a; diff --git a/src/unary.rs b/src/unary.rs index f0fd2f1e..45f87e86 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -1,5 +1,5 @@ use crate::{ - AddGradFn, AddOperation, Alloc, AsNoId, Buffer, Device, Eval, HasId, MayGradActions, + AddGradFn, AddOperation, Alloc, Buffer, Device, Eval, HasId, MayGradActions, MayToCLSource, Resolve, Shape, TwoWay, Unit, ZeroGrad, }; diff --git a/src/update_args.rs b/src/update_args.rs deleted file mode 100644 index 2848f9da..00000000 --- a/src/update_args.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::UniqueId; - -#[cfg(feature = "std")] -use crate::Buffers; - -/// A dummy trait for no-std context. [`UpdateArgs`] requires standard library code. -#[cfg(not(feature = "std"))] -pub trait UpdateArgs {} - -#[cfg(feature = "std")] -pub trait UpdateArgs { - #[cfg(feature = "std")] - fn update_args( - &mut self, - ids: &[Option], - buffers: &mut Buffers, - ) -> crate::Result<()>; -} - -/// A dummy trait for no-std context. [`UpdateArg`] requires standard library code. -#[cfg(not(feature = "std"))] -pub trait UpdateArg {} - -#[cfg(feature = "std")] -pub trait UpdateArg { - fn update_arg( - to_update: &mut Self, - id: Option, - buffers: &mut Buffers, - ) -> crate::Result<()>; -} - -#[cfg(feature = "std")] -impl UpdateArgs for T { - fn update_args( - &mut self, - ids: &[Option], - buffers: &mut crate::Buffers, - ) -> crate::Result<()> { - T::update_arg(self, ids[0], buffers) - } -} - -#[cfg(feature = "std")] -pub trait UpdateArgsDynable { - fn update_args_dynable( - &mut self, - ids: &[Option], - buffers: &mut Buffers, - ) -> crate::Result<()>; -} - -#[cfg(feature = "std")] -impl UpdateArgsDynable for A { - #[inline] - fn update_args_dynable( - &mut self, - ids: &[Option], - buffers: &mut Buffers, - ) -> crate::Result<()> { - self.update_args(ids, buffers) - } -} - -pub trait AsAny { - fn as_any(&self) -> *const (); - fn as_any_mut(&mut self) -> *mut (); -} From 53f21cbf168531d99f0c2374add9e6e3cb20b408 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:32:49 +0200 Subject: [PATCH 09/29] Remove custom_device.rs --- examples/custom_device.rs | 605 -------------------------------------- 1 file changed, 605 deletions(-) delete mode 100644 examples/custom_device.rs diff --git a/examples/custom_device.rs b/examples/custom_device.rs deleted file mode 100644 index 63f739fd..00000000 --- a/examples/custom_device.rs +++ /dev/null @@ -1,605 +0,0 @@ -use std::{ - cell::{Cell, RefCell, UnsafeCell}, - collections::HashMap, - convert::Infallible, - marker::PhantomData, - ops::AddAssign, -}; - -use custos::{ - cpu::CPUPtr, flag::AllocFlag, impl_device_traits, AddGradFn, AddOperation, Alloc, Base, - BorrowCacheLT, Buffer, Cached, CachedModule, Device, DeviceError, DevicelessAble, HasId, Id, - LazyGraph, Module, OnDropBuffer, OnNewBuffer, PtrType, Retrieve, Retriever, Setup, Shape, - Tape, TapeActions, Unit, WrappedData, CPU, -}; - -pub trait Str { - fn str(&self) -> &String; -} - -#[derive(Default)] -pub struct CPU2<'a, Mods: 'a = Base> { - pub modules: Mods, - pub graph: LazyGraph<'a>, - pd: PhantomData<&'a ()>, -} - -impl<'dev, T, D, S, Mods> crate::OnNewBuffer<'dev, T, D, S> for CPU2<'_, Mods> -where - Self: 'dev, - T: crate::Unit, - D: Device, - S: Shape, - Mods: crate::OnNewBuffer<'dev, T, D, S>, -{ - #[inline] - fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { - self.modules.on_new_buffer(device, new_buf) - } -} - -impl<'dev, Mods: crate::OnDropBuffer> crate::OnDropBuffer for CPU2<'dev, Mods> { - #[inline] - fn on_drop_buffer( - &self, - device: &D, - buf: &Buffer, - ) { - self.modules.on_drop_buffer(device, buf) - } -} -impl<'dev, Mods: crate::WrappedData> crate::WrappedData for CPU2<'dev, Mods> { - type Wrap = Mods::Wrap; - - #[inline] - fn wrap_in_base( - &self, - base: Base, - ) -> Self::Wrap { - self.modules.wrap_in_base(base) - } - - #[inline] - fn wrapped_as_base( - wrap: &Self::Wrap, - ) -> &Base { - Mods::wrapped_as_base(wrap) - } - - #[inline] - fn wrapped_as_base_mut( - wrap: &mut Self::Wrap, - ) -> &mut Base { - Mods::wrapped_as_base_mut(wrap) - } -} - -impl<'a, Mods: OnDropBuffer> Device for CPU2<'a, Mods> { - type Error = Infallible; - type Base = CPUPtr; - type Data = Self::Wrap>; - // type WrappedData = ; - - fn new() -> Result { - todo!() - } - - #[inline(always)] - fn base_to_data(&self, base: Self::Base) -> Self::Data { - self.wrap_in_base(base) - } - - #[inline(always)] - fn wrap_to_data( - &self, - wrap: Self::Wrap>, - ) -> Self::Data { - wrap - } - - #[inline(always)] - fn data_as_wrap( - data: &Self::Data, - ) -> &Self::Wrap> { - data - } - - #[inline(always)] - fn data_as_wrap_mut( - data: &mut Self::Data, - ) -> &mut Self::Wrap> { - data - } - - // #[inline] - // fn wrap(&self) {} -} - -impl<'a, SimpleMods> CPU2<'a, SimpleMods> { - #[inline] - pub fn new() -> CPU2<'a, SimpleMods::Module> - where - SimpleMods: Module<'a, CPU2<'a>, Module = NewMods>, - // NewMods: Setup>, - { - let mut cpu = CPU2 { - modules: SimpleMods::new(), - graph: Default::default(), - pd: PhantomData, - }; - // NewMods::setup(&mut cpu).unwrap(); - cpu - } -} - -impl Alloc for CPU2<'_, Mods> { - fn alloc(&self, mut len: usize, flag: AllocFlag) -> custos::Result> { - if len == 0 { - return Err(DeviceError::ZeroLengthBuffer.into()); - } - - if S::LEN > len { - len = S::LEN - } - - Ok(CPUPtr::new_initialized(len, flag)) - } - - fn alloc_from_slice(&self, data: &[T]) -> custos::Result> - where - S: Shape, - T: Clone, - { - if data.is_empty() { - return Err(DeviceError::ZeroLengthBuffer.into()); - } - if !(S::LEN == data.len() || S::LEN == 0) { - return Err(DeviceError::ShapeLengthMismatch.into()); - } - - let cpu_ptr = unsafe { CPUPtr::new(data.len(), AllocFlag::None) }; - let slice = unsafe { std::slice::from_raw_parts_mut(cpu_ptr.ptr, data.len()) }; - slice.clone_from_slice(data); - - Ok(cpu_ptr) - } -} - -pub trait New { - fn new1<'a, NewMods>() -> CPU - where - Self: 'a, - SimpleMods: Module<'a, CPU, Module = NewMods>; -} - -impl New for CPU { - #[inline] - fn new1<'a, NewMods>() -> CPU - where - Self: 'a, - SimpleMods: Module<'a, CPU, Module = NewMods>, - { - CPU { - modules: SimpleMods::new(), - } - } -} - -#[derive(Default)] -pub struct Autograd<'a, Mods> { - _cache: UnsafeCell>, - tape: UnsafeCell>, - val: Cell>, - _modules: Mods, -} - -impl<'a, T, S: Shape, D: Device, Mods: OnDropBuffer> Retrieve for Autograd<'a, Mods> { - unsafe fn retrieve( - &self, - device: &D, - len: usize, - parents: impl custos::Parents, - ) -> custos::Result::Base>> - where - S: Shape, - D: Device + Alloc, - { - todo!() - } -} - -impl<'a, Mods: OnDropBuffer> AddOperation for Autograd<'a, Mods> { - fn ops_count(&self) -> usize { - todo!() - } - - fn set_lazy_enabled(&self, enabled: bool) { - todo!() - } - - fn is_lazy_enabled(&self) -> bool { - todo!() - } - - fn add_op2<'own, 'dev, Args: custos::Parents + custos::AnyOp2<'own, 'dev>, const N: usize>( - &self, - args: Args, - op: impl for<'g, 'b> Fn(Args::Replicated<'g, 'g>) -> custos::Result<()> + 'static, - ) -> custos::Result<()> { - todo!() - } - - fn add_op + custos::AnyOp, const N: usize>( - &self, - args: Args, - op: impl for<'b> Fn(Args::Replicated<'b>) -> custos::Result<()> + 'static, - ) -> custos::Result<()> { - todo!() - } -} - -impl<'a, Mods> Autograd<'a, Mods> { - pub fn add_buf(&'a self, device: &'a CPU) { - // unsafe { (*self._cache.get()).add_buf(device) }; - // binding.get_buf_mut(device); - // self.val.set(Some(&device.val)); - } -} -pub trait GradActions<'dev, D: Device> { - fn get_grad(&self, device: &'dev D, for_buf_id: Id) -> &Buffer<'dev, T, D, S> - where - D: Alloc, - T: 'static, - S: Shape; - - #[allow(clippy::mut_from_ref)] - unsafe fn get_grad_mut<'b, T, S>(&'b self, for_buf_id: Id) -> &'b mut Buffer<'dev, T, D, S> - where - T: 'static, - S: Shape; - - fn grad(&self, device: &'dev D, for_buf: &Buffer<'_, T, D, S>) -> &Buffer<'dev, T, D, S> - where - T: 'static, - D: Alloc, - S: Shape, - { - self.get_grad(device, for_buf.id()) - } - - fn grad_mut<'b, T, S>( - &'b self, - for_buf: &'b Buffer<'_, T, D, S>, - ) -> &'b mut Buffer<'dev, T, D, S> - where - T: 'static, - S: Shape, - { - todo!() - // self.get_grad_mut(for_buf.id()) - } -} - -impl<'dev, Mods, D: Device + 'static> GradActions<'dev, D> for Autograd<'dev, Mods> { - fn get_grad<'a, T, S>(&'a self, device: &'dev D, for_buf_id: Id) -> &'a Buffer<'dev, T, D, S> - where - D: Alloc, - T: 'static, - S: Shape, - { - let mut new_buf = false; - unsafe { (*self._cache.get()).add_buf_once::(device, for_buf_id, &mut new_buf) }; - unsafe { (*self._cache.get()).get_buf(for_buf_id) }.unwrap() - } - - unsafe fn get_grad_mut<'b, T, S>(&'b self, for_buf_id: Id) -> &'b mut Buffer<'dev, T, D, S> - where - T: 'static, - S: Shape, - { - unsafe { (*self._cache.get()).get_buf_mut(for_buf_id) }.unwrap() - } -} - -impl<'dev, Mods: OnDropBuffer + GradActions<'dev, Self>> GradActions<'dev, Self> for CPU -where - Self: 'dev, -{ - fn get_grad<'a, T, S>( - &'a self, - device: &'dev Self, - for_buf_id: Id, - ) -> &'a Buffer<'dev, T, Self, S> - where - T: 'static, - S: Shape, - { - self.modules.get_grad(device, for_buf_id) - } - - unsafe fn get_grad_mut<'b, T, S>(&'b self, for_buf_id: Id) -> &'b mut Buffer<'dev, T, Self, S> - where - T: 'static, - S: Shape, - { - self.modules.get_grad_mut(for_buf_id) - } -} - -pub struct Test<'a> { - pd: PhantomData<&'a ()>, -} - -impl<'a, D: 'a, Mods: Module<'a, D>> Module<'a, D> for Autograd<'a, Mods> { - type Module = Autograd<'a, Mods::Module>; - - fn new() -> Self::Module { - Autograd { - _cache: Default::default(), - tape: Default::default(), - _modules: Mods::new(), - val: Default::default(), - } - } -} - -impl<'a, 'b, T, D, S, Mods: OnNewBuffer<'a, T, D, S>> OnNewBuffer<'b, T, D, S> - for Autograd<'a, Mods> -where - D: Device, - S: Shape, -{ -} - -impl<'a, Mods: OnDropBuffer> OnDropBuffer for Autograd<'a, Mods> { - #[inline] - fn on_drop_buffer( - &self, - _device: &D, - _buf: &custos::prelude::Buffer, - ) { - self._modules.on_drop_buffer(_device, _buf) - } -} - -impl<'a, Mods: WrappedData> WrappedData for Autograd<'a, Mods> { - type Wrap = Mods::Wrap; - - #[inline] - fn wrap_in_base(&self, base: Base) -> Self::Wrap { - self._modules.wrap_in_base(base) - } - - #[inline] - fn wrapped_as_base(wrap: &Self::Wrap) -> &Base { - Mods::wrapped_as_base(wrap) - } - - #[inline] - fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { - Mods::wrapped_as_base_mut(wrap) - } -} - -pub trait Grad<'dev, T, D: Device, S: Shape> { - fn grad1(&self) -> &Buffer<'dev, T, D, S>; - fn grad_mut1(&mut self) -> &mut Buffer<'dev, T, D, S>; -} - -impl<'dev, T, D, S> Grad<'dev, T, D, S> for Buffer<'dev, T, D, S> -where - T: 'static, - D: Device + 'static + GradActions<'dev, D> + Alloc, - S: Shape, -{ - fn grad1(&self) -> &Buffer<'dev, T, D, S> { - self.device().get_grad(self.device(), self.id()) - } - - fn grad_mut1(&mut self) -> &mut Buffer<'dev, T, D, S> { - unsafe { self.device().get_grad_mut(self.id()) } - } -} - -pub trait AddBuf<'dev, T: Unit, S: Shape = (), D: Device = Self>: Sized + Device { - fn add(&self, lhs: &mut Buffer, rhs: &mut Buffer) -> Buffer; - fn test<'a>(&self, lhs: &'a Buffer<'_, T, D, S>) -> &'a Buffer<'dev, T, Self, S>; -} - -impl<'dev, T, S, Mods> AddBuf<'dev, T, S, Self> for CPU -where - T: Unit + Copy + AddAssign + 'static, - S: Shape, - Mods: 'static - + for<'d> GradActions<'d, Self> - + OnDropBuffer - + AddOperation - + Retrieve, -{ - fn add( - &self, - lhs: &mut Buffer, - rhs: &mut Buffer, - ) -> Buffer { - let out = self.retrieve(lhs.len, (&*lhs, &*rhs)).unwrap(); - - // lazy fn not grad fn -> wurscht - self.add_op((lhs, rhs /*&out*/), |(lhs, rhs /*out*/)| { - lhs.grad_mut1(); - rhs.grad1(); - // add_ew_grad_slice(lhs.grad_mut1(), out.grad1()); - // add_ew_grad_slice(rhs.grad_mut1(), out.grad1()); - Ok(()) - }) - .unwrap(); - - out - } - - fn test<'a>(&self, lhs: &'a Buffer<'_, T, Self, S>) -> &'a Buffer<'dev, T, Self, S> { - todo!() - // lhs.grad1() - } -} - -fn add_ew_grad_slice(grad_acc: &mut [T], out_grad: &[T]) { - for (grad, out_grad) in grad_acc.iter_mut().zip(out_grad) { - *grad += *out_grad; - } -} - -pub trait OnNewBuffer2<'dev, T: Unit, D: Device + 'dev, S: Shape = ()> { - fn on_new_buffer2(new_buf: &Buffer<'dev, T, D, S>) {} - fn on_new_buffer3(&self, x: &Buffer<'dev, T, D, S>) {} - // fn on_new_buffer2(&self, /*_device: &'dev D,*/ _new_buf: &'_ Buffer<'dev, T, D, S>) {} -} - -impl<'dev, Mods: OnNewBuffer2<'dev, T, Self, S> + OnDropBuffer + 'dev, T, S: Shape> - OnNewBuffer2<'dev, T, Self, S> for CPU -{ - // fn on_new_buffer2(&self, _device: &'dev D, _new_buf: &Buffer<'dev, T, D, S>) { - // self.modules.on_new_buffer2(_device, _new_buf) - // } -} - -impl<'dev, Mods: OnNewBuffer2<'dev, T, D, S>, T, D: Device + 'dev, S: Shape> - OnNewBuffer2<'dev, T, D, S> for Autograd<'dev, Mods> -{ - // fn on_new_buffer2(&self, _device: &'dev D, _new_buf: &Buffer<'dev, T, D, S>) { - // self._modules.on_new_buffer2(_device, _new_buf) - // } -} - -impl<'dev, Mods: OnNewBuffer2<'dev, T, D, S>, T, D: Device + 'dev, S: Shape, SD: Device> - OnNewBuffer2<'dev, T, D, S> for CachedModule -{ - // fn on_new_buffer2(&self, _device: &'dev D, _new_buf: &Buffer<'dev, T, D, S>) { - // self.modules.on_new_buffer2(_device, _new_buf) - // } -} - -impl<'dev, T, D: Device + 'dev, S: Shape> OnNewBuffer2<'dev, T, D, S> for Base {} - -fn x<'a>(device: &'a CPU2<'a, Autograd<'a, Base>>) { - let lhs = device.buffer([1f32, 2., 3., 4., 5.]); - let rhs = device.buffer([1f32, 2., 6., 4., 5.]); - // let mut buffers = HashMap::default(); - // // unsafe { register_buf_copyable(&mut buffers, &lhs) }; - // // unsafe { register_buf_copyable(&mut buffers, &rhs) }; - // let tape: &'a mut LazyGraph2<'a> = &mut unsafe { &mut *device.modules.tape.get()}.lazy_graph; - // tape.add_operation((&lhs, &rhs), |(lhs, rhs)| { - // // lhs.grad(); - // Ok(()) - // }); - // tape.call_lazily(&mut buffers).unwrap(); -} -impl<'a, T, S: Shape, Mods: OnDropBuffer> DevicelessAble<'a, T, S> for CPU2<'_, Mods> {} -fn main() { - // let x = Box::new(Typ::default()); - // Box::into_raw(x); - // - { - // x(&device); - let mut device = CPU::>::new(); - - let lhs = device.buffer([1f32, 2., 3., 4., 5.]); - - // let lhs = Buffer::::deviceless(&device, 10); - let rhs = device.buffer([1f32, 2., 6., 4., 5.]); - // let rhs = Buffer::::deviceless(&device, 10); - let mut buffers = HashMap::default(); - - // let graph = unsafe { &mut *device.graph.get() }; - // let mut graph = &mut device.graph; - - device.add_grad_fn((&lhs, &rhs), |(lhs, rhs)| { - unsafe { lhs.grad_mut() }; - Ok(()) - }); - - unsafe { device.modules.tape_mut() } - .unwrap() - .backward(&mut buffers, false); - // unsafe { device.modules.tape_mut() }.unwrap().backward_seeded_with_buffers(&lhs, &[1., 1., 1., 1., 1.], &mut buffers); - rhs.backward(); - - // graph.add_operation((&lhs, &rhs), |(lhs, rhs)| Ok(())); - let graph: &mut LazyGraph = &mut unsafe { device.modules.tape_mut() }.unwrap().lazy_graph; - unsafe { graph.call_lazily(&mut buffers).unwrap() }; - // // unsafe { register_buf_copyable(&mut buffers, &lhs) }; - // unsafe { register_buf_copyable(&mut buffers, &rhs) }; - // let tape: &mut LazyGraph2 = &mut unsafe { &mut *device.modules.tape.get()}.lazy_graph; - // tape.add_operation((&lhs, &rhs), |(lhs, rhs)| { - // lhs.grad(); - // Ok(()) - // }); - // tape.call_lazily(&device, &mut buffers).unwrap(); - - let dev = CPU::>>::new1(); - - let data = /*dev.wrap_in_base(*/dev.alloc::<()>(10, custos::flag::AllocFlag::None).unwrap(); - let buffer: Buffer = Buffer { data, device: None }; - // CPU::>::on_new_buffer2(&buffer); - // dev.on_new_buffer3(&buffer); - // OnNewBuffer2::<_ ,_>::on_new_buffer3(&dev, &buffer) - // dev.on_new_buffer2(&buffer) - dev.on_new_buffer(&dev, &buffer); - - // let mut out = dev.buffer([1, 2, 3]); - // let mut out1 = dev.buffer([1, 2, 3]); - - // let mut out = dev.add(&mut out, &mut out1); - // dev.add(&mut out, &mut out1); - // dev.test(&out); - - // // dev.get_grad::(out.id()); - // { - // let z = out.grad_mut1(); - // let x = out1.grad_mut1(); - // assert_eq!(z.len(), x.len()); - // out.grad1(); - // } - - // let x = dev.grad_mut(&out); - // let z = dev.grad_mut(&out); - // assert_eq!(z.len(), x.len()); - // unsafe { dev.get_grad_mut::(out.id()) }; - // unsafe { dev.get_grad_mut::(out.id()) }; - } - - // return; - - // let out = Buffer::new(&dev, 10); - // - // out.grad(); - - let mods = Autograd::::default(); - { - let dev = CPU::>::new1(); - let mut cache = BorrowCacheLT::default(); - cache.add_buf::(&dev, Id { id: 0, len: 10 }); - // dev.modules.add_buf(&dev); - // let out = dev.modules._cache._cache.get(&3).unwrap(); - // mods.add_buunsafe { f(&dev); - // mods.add_buf(&dev); - { - // cache.add_buf(&dev); - } - { - // cache.add_buf(&dev); - } - // cache.get_buf_mut(&dev); - let out = cache - .get_buf::>, ()>(Id { id: 0, len: 10 }) - .unwrap(); - let out1 = cache - .get_buf_mut::>, ()>(Id { id: 0, len: 10 }) - .unwrap(); - // assert_eq!(out.len(), out1.len()); - out1; - } - // let out = unsafe { cache.get_buf::>, ()>(Id { id: 0, len: 10 }) }; - let dev = CPU::>::new1(); - // cache.add_buf(&dev); - // mods.val; -} From 00a0b457e16a0d1f955138490a215a8e1e2a71a3 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:34:05 +0200 Subject: [PATCH 10/29] Remove custom_device.rs example link in Cargo.toml --- Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8ff96681..695eec69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,10 +98,6 @@ serde_test = "1" name = "cuda_usage" required-features = ["cuda"] -[[example]] -name = "custom_device" -required-features = ["cpu", "autograd"] - [[example]] name = "cpu_usage" required-features = ["cpu"] From 9843e15cb3a1e0ae338b1e75f1347c2bececdcf0 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:35:35 +0200 Subject: [PATCH 11/29] Bump rust-version to 1.79 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 695eec69..aaae07e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ repository = "https://github.com/elftausend/custos" keywords = ["gpu", "autodiff", "arrays", "deep-learning", "fixed-size"] categories = ["science", "mathematics", "no-std", "external-ffi-bindings"] readme = "README.md" -rust-version = "1.70" +rust-version = "1.79" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] From a42bb4bc0e09a450935860144fdf4bd57f23ec58 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:40:28 +0200 Subject: [PATCH 12/29] Remove add_op2, add_operation2, convert_to_operation2 --- src/any_op.rs | 176 +++++++++++++++++---------------- src/features.rs | 16 --- src/modules/base.rs | 9 -- src/modules/cached.rs | 8 -- src/modules/lazy.rs | 14 --- src/modules/lazy/exec_iter.rs | 4 +- src/modules/lazy/lazy_graph.rs | 66 +------------ src/parents.rs | 27 ----- 8 files changed, 96 insertions(+), 224 deletions(-) diff --git a/src/any_op.rs b/src/any_op.rs index c8202c2c..0cd183a8 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -3,6 +3,19 @@ use crate::{Buffer, Device, Id}; #[cfg(feature = "std")] use crate::{Buffers, Downcast}; +pub trait Replicate { + type Replication<'r>; + type Downcast<'r>: 'r; + + #[cfg(feature = "std")] + unsafe fn replicate_borrowed<'r, B: Downcast>( + id: &Id, + buffers: &'r mut Buffers, + ) -> Option>; + + unsafe fn replicate<'a>(self) -> Self::Replication<'a>; +} + pub trait AnyOp: Sized { type Replicated<'a>; @@ -15,70 +28,43 @@ pub trait AnyOp: Sized { unsafe fn replication<'a>(self) -> Self::Replicated<'a>; } -pub trait AnyOp2<'own, 'dev>: Sized { - type Replicated<'a, 'b> - where - 'b: 'a; - - #[cfg(feature = "std")] - fn replication_fn( - ids: Vec, - op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; - - unsafe fn replication<'iown, 'idev>(self) -> Self::Replicated<'iown, 'idev>; -} - -pub trait Replicate2<'uown, 'udev> { - type Replication<'r, 'd> - where - 'd: 'r; - type Downcast<'r>: 'r; - - unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev>; - - #[cfg(feature = "std")] - unsafe fn replicate_borrowed<'r, B: Downcast>( - id: &Id, - buffers: &'r mut Buffers, - ) -> Option>; -} - -impl<'uown, 'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'a> - for &'uown crate::Buffer<'a, T, D, S> +impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate + for &crate::Buffer<'a, T, D, S> { - type Replication<'r, 'd> = &'r Buffer<'r, T, D, S> where 'd: 'r; + type Replication<'r> = &'r Buffer<'r, T, D, S>; type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] + #[inline] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option> { + ) -> Option> { buffers.get(id)?.downcast_ref::>() } - - unsafe fn replicate<'own, 'dev: 'own>(self) -> Self::Replication<'own, 'dev> { + + #[inline] + unsafe fn replicate<'r>(self) -> Self::Replication<'r> { // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work + // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there - // but than something like this happens: https://github.com/rust-lang/rust/issues/100013 // most of the "double lifetime stuff" is still implemented at the moment // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line - unsafe { core::mem::transmute::>(self) } + unsafe { core::mem::transmute::>(self) } } } -impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'udev> - for &'uown mut crate::Buffer<'udev, T, D, S> +impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate + for &mut crate::Buffer<'a, T, D, S> { - type Replication<'r, 'd> = &'r mut Self::Downcast<'d> where 'd: 'r; + type Replication<'r> = &'r mut Self::Downcast<'r>; type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option> { + ) -> Option> { let replication = buffers.get_mut(id)?; if !replication.is::>() { return None; @@ -86,66 +72,104 @@ impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2< Some(unsafe { replication.downcast_mut_unchecked::>() }) } - unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev> { + #[inline] + unsafe fn replicate<'r>(self) -> Self::Replication<'r> { // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line - unsafe { core::mem::transmute::>(self) } + unsafe { core::mem::transmute::>(self) } } } -pub trait Replicate { - type Replication<'r>; +impl AnyOp for R { + #[cfg(feature = "std")] + fn replication_fn( + ids: Vec, + op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, + ) -> Box) -> crate::Result<()>> { + use crate::DeviceError; + + let id = ids[0]; + Box::new(move |buffers| { + let r1 = unsafe { R::replicate_borrowed(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; + op(r1) + }) + } + type Replicated<'a> = R::Replication<'a>; + + #[inline] + unsafe fn replication<'a>(self) -> Self::Replicated<'a> { + self.replicate() + } +} + +/* +pub trait AnyOp2<'own, 'dev>: Sized { + type Replicated<'a, 'b> + where + 'b: 'a; + + #[cfg(feature = "std")] + fn replication_fn( + ids: Vec, + op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'a>) -> crate::Result<()> + 'static, + ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; + + unsafe fn replication<'iown, 'idev>(self) -> Self::Replicated<'iown, 'idev>; +} + +pub trait Replicate2<'uown, 'udev> { + type Replication<'r, 'd> + where + 'd: 'r; type Downcast<'r>: 'r; + unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev>; + #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option>; - - unsafe fn replicate<'a>(self) -> Self::Replication<'a>; + ) -> Option>; } -impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate - for &crate::Buffer<'a, T, D, S> +impl<'uown, 'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'a> + for &'uown crate::Buffer<'a, T, D, S> { - type Replication<'r> = &'r Buffer<'r, T, D, S>; + type Replication<'r, 'd> = &'r Buffer<'r, T, D, S> where 'd: 'r; type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] - #[inline] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option> { + ) -> Option> { buffers.get(id)?.downcast_ref::>() } - - #[inline] - unsafe fn replicate<'r>(self) -> Self::Replication<'r> { + + unsafe fn replicate<'own, 'dev: 'own>(self) -> Self::Replication<'own, 'dev> { // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work - // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there + // but than something like this happens: https://github.com/rust-lang/rust/issues/100013 // most of the "double lifetime stuff" is still implemented at the moment // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line - unsafe { core::mem::transmute::>(self) } + unsafe { core::mem::transmute::>(self) } } } -impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate - for &mut crate::Buffer<'a, T, D, S> +impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uown, 'udev> + for &'uown mut crate::Buffer<'udev, T, D, S> { - type Replication<'r> = &'r mut Self::Downcast<'r>; + type Replication<'r, 'd> = &'r mut Self::Downcast<'d> where 'd: 'r; type Downcast<'r> = Buffer<'r, T, D, S>; #[cfg(feature = "std")] unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, - ) -> Option> { + ) -> Option> { let replication = buffers.get_mut(id)?; if !replication.is::>() { return None; @@ -153,38 +177,16 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate Some(unsafe { replication.downcast_mut_unchecked::>() }) } - #[inline] - unsafe fn replicate<'r>(self) -> Self::Replication<'r> { + unsafe fn replicate<'own, 'dev>(self) -> Self::Replication<'own, 'dev> { // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line - unsafe { core::mem::transmute::>(self) } + unsafe { core::mem::transmute::>(self) } } } -impl AnyOp for R { - #[cfg(feature = "std")] - fn replication_fn( - ids: Vec, - op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, - ) -> Box) -> crate::Result<()>> { - use crate::DeviceError; - - let id = ids[0]; - Box::new(move |buffers| { - let r1 = unsafe { R::replicate_borrowed(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; - op(r1) - }) - } - type Replicated<'a> = R::Replication<'a>; - - #[inline] - unsafe fn replication<'a>(self) -> Self::Replicated<'a> { - self.replicate() - } -} impl<'own, 'dev, R: crate::HasId + Replicate2<'own, 'dev>> AnyOp2<'own, 'dev> for R { type Replicated<'a, 'b> = R::Replication<'a, 'b> where 'b: 'a; @@ -208,4 +210,4 @@ impl<'own, 'dev, R: crate::HasId + Replicate2<'own, 'dev>> AnyOp2<'own, 'dev> fo unsafe fn replication<'iown, 'idev: 'iown>(self) -> Self::Replicated<'iown, 'idev> { unsafe { self.replicate() } } -} +}*/ diff --git a/src/features.rs b/src/features.rs index 2c98c403..31cacb31 100644 --- a/src/features.rs +++ b/src/features.rs @@ -373,13 +373,6 @@ pub trait AddOperation { op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, ) -> crate::Result<()>; - - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( - &self, - args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> crate::Result<()>; - fn ops_count(&self) -> usize; fn set_lazy_enabled(&self, enabled: bool); #[inline] @@ -442,15 +435,6 @@ macro_rules! pass_down_add_operation { self.modules.add_op(args, op) } - #[inline] - fn add_op2<'own, 'd, Args: $crate::Parents + $crate::AnyOp2<'own, 'd>, const N: usize>( - &self, - args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> $crate::Result<()> { - self.modules.add_op2(args, op) - } - #[inline] fn ops_count(&self) -> usize { self.modules.ops_count() diff --git a/src/modules/base.rs b/src/modules/base.rs index a1805481..74e137f2 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -45,15 +45,6 @@ impl AddOperation for Base { op(unsafe { args.replication() }) } - #[inline] - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev> , const N: usize>( - &self, - args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> crate::Result<()> { - op(unsafe { args.replication() }) - } - #[inline] fn ops_count(&self) -> usize { 0 diff --git a/src/modules/cached.rs b/src/modules/cached.rs index f822b024..257f50c8 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -79,14 +79,6 @@ impl AddOperation for CachedModule { self.modules.add_op(args, op) } - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( - &self, - args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> crate::Result<()> { - self.modules.add_op2(args, op) - } - #[inline] fn ops_count(&self) -> usize { self.modules.ops_count() diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 406ddd94..e70db71a 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -103,20 +103,6 @@ impl AddOperation for Lazy { self.modules.add_op(args, op) } } - fn add_op2<'own, 'dev, Args: Parents + crate::AnyOp2<'own, 'dev>, const N: usize>( - &self, - args: Args, - op: impl for<'a, 'b> Fn(Args::Replicated<'a, 'a>) -> crate::Result<()> + 'static, - ) -> crate::Result<()> { - if self.enabled.get() { - self.graph.try_borrow_mut() - .expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?") - .add_operation2(args, op); - Ok(()) - } else { - self.modules.add_op2(args, op) - } - } #[inline] fn ops_count(&self) -> usize { diff --git a/src/modules/lazy/exec_iter.rs b/src/modules/lazy/exec_iter.rs index 750709e2..4bc87463 100644 --- a/src/modules/lazy/exec_iter.rs +++ b/src/modules/lazy/exec_iter.rs @@ -1,7 +1,7 @@ -use crate::{Buffers, Operation2}; +use crate::{Buffers, Operation}; pub struct ExecIter<'a, 'b, B, T> { - pub(super) operations: std::slice::Iter<'b, Operation2<'a, B, T>>, + pub(super) operations: std::slice::Iter<'b, Operation<'a, B, T>>, pub(super) buffers: &'b mut Buffers, } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 02efd936..9fac1bf4 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -5,13 +5,13 @@ use crate::{ use core::ops::RangeBounds; use std::collections::HashSet; -pub struct Operation2<'a, B, T> { +pub struct Operation<'a, B, T> { pub op: Box) -> crate::Result<()> + 'a>, pub op_hint: OpHint, } pub struct LazyGraph<'a, B = Box, T = ()> { - pub(crate) operations: Vec>, + pub(crate) operations: Vec>, } impl<'a, B, T> Default for LazyGraph<'a, B, T> { @@ -66,53 +66,10 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { Ok(()) } - pub fn convert_to_operation2< - 'own, - 'dev, - Args: Parents + crate::AnyOp2<'own, 'dev>, - const N: usize, - >( - args: Args, - op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'r>) -> crate::Result<()> + 'static, - ) -> Operation2<'a, B, T> { - const { assert!(N > 0, "Size of parents must be greater than 0") }; - - let mut seen_ids = HashSet::new(); - - // store ids and test if buffers are still in cache - let arg_ids = args - .maybe_ids() - .into_iter() - .flat_map(|id| { - // return error / none - let id = id.expect("every parent must have an id"); - if seen_ids.contains(&id.id) { - panic!("each parent (id) must be unique") - } - seen_ids.insert(id.id); - - Some(id) - }) - //.flat_map(|id| id.map(|id| *id)) - .collect::>(); - - if arg_ids.len() != N { - panic!() - } - - let op: Box) -> crate::Result<()>> = - Args::replication_fn::(arg_ids, op); - - Operation2 { - op, - op_hint: OpHint::None, - } - } - pub fn convert_to_operation + AnyOp, const N: usize>( args: Args, op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) -> Operation2<'a, B, T> { + ) -> Operation<'a, B, T> { const { assert!(N > 0, "Size of parents must be greater than 0") }; let mut seen_ids = HashSet::new(); @@ -141,7 +98,7 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { let op: Box) -> crate::Result<()>> = Args::replication_fn::(arg_ids, op); - Operation2 { + Operation { op, op_hint: OpHint::None, } @@ -155,20 +112,7 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { let operation = Self::convert_to_operation(args, op); self.operations.push(operation) } - - pub fn add_operation2< - 'own, - 'dev, - Args: Parents + crate::AnyOp2<'own, 'dev>, - const N: usize, - >( - &mut self, - args: Args, - op: impl for<'r, 'b> Fn(Args::Replicated<'r, 'r>) -> crate::Result<()> + 'static, - ) { - let operation = Self::convert_to_operation2(args, op); - self.operations.push(operation) - } + } #[cfg(feature = "cpu")] diff --git a/src/parents.rs b/src/parents.rs index b3a1a5d6..ac8283e0 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -67,33 +67,6 @@ macro_rules! impl_parents { } impl<$($to_impl: $crate::HasId, )+> AllParents for ($($to_impl,)+) {} - impl<'own, 'dev, $($to_impl: $crate::Replicate2<'own, 'dev> + $crate::HasId, )+> $crate::AnyOp2<'own, 'dev> for ($($to_impl,)+) { - type Replicated<'a, 'b> = ($($to_impl::Replication<'a, 'b>,)+) where 'b: 'a; - - #[cfg(feature = "std")] - fn replication_fn( - ids: Vec<$crate::Id>, - op: impl for<'a, 'b> Fn(Self::Replicated<'a, 'a>) -> $crate::Result<()> + 'static, - ) -> Box) -> $crate::Result<()>> { - Box::new(move |buffers| { - let mut ids = ids.iter(); - - op(($( - unsafe { - $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? - } - ,)+)) - }) - } - - #[inline] - unsafe fn replication<'iown, 'idev>(self) -> Self::Replicated<'iown, 'idev> { - #[allow(non_snake_case)] - let ($($to_impl,)+) = self; - ($($to_impl.replicate(),)+) - } - } - impl<$($to_impl: $crate::Replicate + $crate::HasId, )+> $crate::AnyOp for ($($to_impl,)+) { type Replicated<'a> = ($($to_impl::Replication<'a>,)+); From b75b5dca1c44a625ab910172068434a27bec40d9 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:44:20 +0200 Subject: [PATCH 13/29] Add const { assert .. } to StackArray --- Cargo.toml | 2 +- src/any_op.rs | 15 ++++++++------- src/devices/cpu/ops.rs | 14 ++++++-------- src/devices/opencl/ops.rs | 16 ++++++---------- src/devices/stack_array.rs | 22 ++++++++++++---------- src/devices/wgsl/wgsl_device.rs | 2 +- src/modules/autograd/tape.rs | 4 ++-- src/modules/base.rs | 4 ++-- src/modules/cached.rs | 4 ++-- src/modules/lazy.rs | 15 ++++++--------- src/modules/lazy/lazy_graph.rs | 28 ++++++++++++---------------- src/unary.rs | 4 ++-- 12 files changed, 60 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aaae07e8..7ddece96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cpu", "opencl", "autograd"] +default = ["cpu", "opencl", "stack", "autograd"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/any_op.rs b/src/any_op.rs index 0cd183a8..fa77f1e3 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -24,7 +24,7 @@ pub trait AnyOp: Sized { ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; - + unsafe fn replication<'a>(self) -> Self::Replicated<'a>; } @@ -42,14 +42,14 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate ) -> Option> { buffers.get(id)?.downcast_ref::>() } - + #[inline] unsafe fn replicate<'r>(self) -> Self::Replication<'r> { // TODO: this should work without this trick -> move 'own, 'dev up to the trait when something like for<'a: 'b, ...> starts to work // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment - // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } @@ -78,7 +78,7 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment - // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } @@ -93,7 +93,8 @@ impl AnyOp for R { let id = ids[0]; Box::new(move |buffers| { - let r1 = unsafe { R::replicate_borrowed(&id, buffers) }.ok_or(DeviceError::InvalidLazyBuf)?; + let r1 = unsafe { R::replicate_borrowed(&id, buffers) } + .ok_or(DeviceError::InvalidLazyBuf)?; op(r1) }) } @@ -154,7 +155,7 @@ impl<'uown, 'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2<'uo // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // but than something like this happens: https://github.com/rust-lang/rust/issues/100013 // most of the "double lifetime stuff" is still implemented at the moment - // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } @@ -182,7 +183,7 @@ impl<'uown, 'udev, T: 'static, D: Device + 'static, S: crate::Shape> Replicate2< // https://github.com/rust-lang/rust/issues/100013 // look at commit "0d54d19a52979352ec59f1619a439541e08c30a0" - it was implemented like this there // most of the "double lifetime stuff" is still implemented at the moment - // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line + // commit a985577299335ab00a02dc226a2e4b9d1642b8f7 introduced this line unsafe { core::mem::transmute::>(self) } } } diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index cb5530a8..c098a7ea 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -33,7 +33,8 @@ where self.add_op((&mut out, buf), move |(out, buf)| { apply_fn_slice(buf, out, f); Ok(()) - }).unwrap(); + }) + .unwrap(); // self.add_op((&mut out, buf, f.no_id()), move |(out, buf, f)| { // apply_fn_slice(buf, out, **f); @@ -71,13 +72,10 @@ where ) where F: Eval + MayToCLSource, { - self.add_op::<_, 3>( - (lhs, lhs_grad, out), - move |(lhs, lhs_grad, out)| { - crate::cpu_stack_ops::add_unary_grad(lhs, out, lhs_grad, lhs_grad_fn); - Ok(()) - }, - ) + self.add_op::<_, 3>((lhs, lhs_grad, out), move |(lhs, lhs_grad, out)| { + crate::cpu_stack_ops::add_unary_grad(lhs, out, lhs_grad, lhs_grad_fn); + Ok(()) + }) .unwrap(); } } diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index 8d87e2f5..6a415625 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -7,10 +7,9 @@ use min_cl::{ use crate::{ bounds_to_range, cpu_stack_ops::clear_slice, location, op_hint::unary, pass_down_add_operation, - pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer, - CDatatype, ClearBuf, CopySlice, OnDropBuffer, OpenCL, Read, Resolve, Retrieve, Retriever, - SetOpHint, Shape, ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WriteBuf, - ZeroGrad, + pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer, CDatatype, ClearBuf, + CopySlice, OnDropBuffer, OpenCL, Read, Resolve, Retrieve, Retriever, SetOpHint, Shape, + ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WriteBuf, ZeroGrad, }; use super::{enqueue_kernel, CLPtr}; @@ -311,12 +310,9 @@ where ) where F: ToCLSource, { - self.add_op( - (lhs, lhs_grad, out), - move |(lhs, lhs_grad, out)| { - try_cl_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) - }, - ) + self.add_op((lhs, lhs_grad, out), move |(lhs, lhs_grad, out)| { + try_cl_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) + }) .unwrap(); } } diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index 11eeea8e..39188b80 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -16,12 +16,12 @@ impl StackArray { /// Creates a new `StackArray`. #[inline] pub fn new() -> Self { - // TODO: one day... use const expressions - // rust 1.79 - assert!( - S::LEN > 0, - "The size (N) of a stack allocated buffer must be greater than 0." - ); + const { + assert!( + S::LEN > 0, + "The size (N) of a stack allocated buffer must be greater than 0." + ) + }; StackArray { array: S::new() } } @@ -58,10 +58,12 @@ impl Default for StackArray { impl StackArray { /// Creates a new `StackArray` from a possibly multi-dimensional array. pub fn from_array(array: S::ARR) -> Self { - assert!( - S::LEN > 0, - "The size (N) of a stack allocated buffer must be greater than 0." - ); + const { + assert!( + S::LEN > 0, + "The size (N) of a stack allocated buffer must be greater than 0." + ) + }; StackArray { array } } diff --git a/src/devices/wgsl/wgsl_device.rs b/src/devices/wgsl/wgsl_device.rs index a5d83f93..8a77c4bd 100644 --- a/src/devices/wgsl/wgsl_device.rs +++ b/src/devices/wgsl/wgsl_device.rs @@ -208,7 +208,7 @@ impl AddOperation for Wgsl { fn set_lazy_enabled(&self, enabled: bool) { self.modules.set_lazy_enabled(enabled) } - + #[inline] fn is_lazy_enabled(&self) -> bool { self.modules.is_lazy_enabled() diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 10db28dc..61838bea 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -1,6 +1,6 @@ use crate::{ - AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph, - Parents, Shape, Unit, WriteBuf, ZeroGrad, + AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph, Parents, + Shape, Unit, WriteBuf, ZeroGrad, }; use super::Gradients; diff --git a/src/modules/base.rs b/src/modules/base.rs index 74e137f2..2df392c2 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -49,10 +49,10 @@ impl AddOperation for Base { fn ops_count(&self) -> usize { 0 } - + #[inline] fn set_lazy_enabled(&self, _enabled: bool) {} - + #[inline] fn is_lazy_enabled(&self) -> bool { false diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 257f50c8..e026d900 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -83,12 +83,12 @@ impl AddOperation for CachedModule { fn ops_count(&self) -> usize { self.modules.ops_count() } - + #[inline] fn set_lazy_enabled(&self, enabled: bool) { self.modules.set_lazy_enabled(enabled) } - + #[inline] fn is_lazy_enabled(&self) -> bool { self.modules.is_lazy_enabled() diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index e70db71a..1487aec8 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -762,15 +762,12 @@ mod tests { let b = Buffer::::from_slice(&device, &[1, 2, 3, 4]); let vec = vec![1, 2, 3]; device - .add_op( - (&mut out, &b), - move |(out, b)| { - for ((lhs, rhs), out) in a.iter().zip(b.iter()).zip(out.iter_mut()) { - *out = lhs + rhs; - } - Ok(()) - }, - ) + .add_op((&mut out, &b), move |(out, b)| { + for ((lhs, rhs), out) in a.iter().zip(b.iter()).zip(out.iter_mut()) { + *out = lhs + rhs; + } + Ok(()) + }) .unwrap(); } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 9fac1bf4..eccb18bb 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -1,6 +1,6 @@ use crate::{ - bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, - BoxedShallowCopy, Buffers, Device, Downcast, Parents, + bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, BoxedShallowCopy, + Buffers, Device, Downcast, Parents, }; use core::ops::RangeBounds; use std::collections::HashSet; @@ -112,15 +112,14 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { let operation = Self::convert_to_operation(args, op); self.operations.push(operation) } - } #[cfg(feature = "cpu")] #[cfg(test)] mod tests { use crate::{ - register_buf_any, register_buf_copyable, AnyBuffer, Base, Buffer, - CloneBuf, Device, HasId, LazyGraph, Retriever, Shape, UniqueId, CPU, + register_buf_any, register_buf_copyable, AnyBuffer, Base, Buffer, CloneBuf, Device, HasId, + LazyGraph, Retriever, Shape, UniqueId, CPU, }; use core::cell::Cell; use std::collections::HashMap; @@ -342,19 +341,16 @@ mod tests { // outs_unordered.insert(out.id(), ) - graph.add_operation::<_, 3>( - (&mut out, &lhs, &rhs), - move |(_out, lhs, rhs)| { - assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); - assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); + graph.add_operation::<_, 3>((&mut out, &lhs, &rhs), move |(_out, lhs, rhs)| { + assert_eq!(lhs.as_slice(), &[1f32, 2., 3., 4., 5.,]); + assert_eq!(rhs.as_slice(), &[1f32, 2., 6., 4., 5.,]); - for (out, lhs) in _out.iter_mut().zip(lhs.iter()) { - *out = ew_fn(*lhs); - } + for (out, lhs) in _out.iter_mut().zip(lhs.iter()) { + *out = ew_fn(*lhs); + } - Ok(()) - }, - ); + Ok(()) + }); unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } } diff --git a/src/unary.rs b/src/unary.rs index 45f87e86..2dc6acee 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -1,6 +1,6 @@ use crate::{ - AddGradFn, AddOperation, Alloc, Buffer, Device, Eval, HasId, MayGradActions, - MayToCLSource, Resolve, Shape, TwoWay, Unit, ZeroGrad, + AddGradFn, AddOperation, Alloc, Buffer, Device, Eval, HasId, MayGradActions, MayToCLSource, + Resolve, Shape, TwoWay, Unit, ZeroGrad, }; /// Applies a function to a buffer and returns a new buffer. From 11e442b9a8d28c70bdd0791b792df53de2f88054 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:47:28 +0200 Subject: [PATCH 14/29] Fix compile time error on stack array length 0 --- Cargo.toml | 2 +- src/devices/stack_array.rs | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7ddece96..b544a83e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cpu", "opencl", "stack", "autograd"] +default = ["cpu"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index 39188b80..3aa40101 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -164,11 +164,12 @@ where #[cfg(test)] mod test { - use crate::StackArray; - - #[test] - #[should_panic] - fn test_stack_array_zero_len() { - StackArray::<(), f32>::new(); - } + // use crate::StackArray; + + // compile time error instead! + // #[test] + // #[should_panic] + // fn test_stack_array_zero_len() { + // StackArray::<(), f32>::new(); + // } } From 310c6d6623cad76c36e39ea26a62b8937aac3d19 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:49:08 +0200 Subject: [PATCH 15/29] Adapt cuda add_ops --- Cargo.toml | 2 +- src/devices/cuda/ops.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b544a83e..824fbadc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cpu"] +default = ["cuda", "opencl"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/cuda/ops.rs b/src/devices/cuda/ops.rs index 05b72024..6428bbc4 100644 --- a/src/devices/cuda/ops.rs +++ b/src/devices/cuda/ops.rs @@ -134,8 +134,8 @@ where F: crate::TwoWay, { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op((&mut out, buf, f.no_id()), |(out, buf, f)| { - try_cu_apply_fn_mut(buf.device(), buf, out, &**f) + self.add_op((&mut out, buf), move |(out, buf)| { + try_cu_apply_fn_mut(buf.device(), buf, out, &f) }) .unwrap(); self.set_op_hint(unary(f)); @@ -195,9 +195,9 @@ where F: ToCLSource, { self.add_op( - (lhs, lhs_grad.buf_no_id(), out, lhs_grad_fn.no_id()), - move |(lhs, lhs_grad, out, lhs_grad_fn)| { - try_cu_add_unary_grad(lhs.device(), lhs, lhs_grad, out, **lhs_grad_fn) + (lhs, lhs_grad, out), + move |(lhs, lhs_grad, out)| { + try_cu_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) }, ) .unwrap(); From 9af29b26348475f8e2a5d1322cb6d68de3598ee2 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:51:44 +0200 Subject: [PATCH 16/29] Fix vulkan add op --- Cargo.toml | 2 +- src/devices/vulkan/ops.rs | 8 ++++---- src/devices/wgsl/ops.rs | 2 +- src/devices/wgsl/wgsl_device.rs | 15 ++++----------- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 824fbadc..83ba8321 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cuda", "opencl"] +default = ["cuda", "vulkan"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/vulkan/ops.rs b/src/devices/vulkan/ops.rs index 457fae6f..fb2d87b3 100644 --- a/src/devices/vulkan/ops.rs +++ b/src/devices/vulkan/ops.rs @@ -188,10 +188,10 @@ where ) where F: ToCLSource, { - self.add_op::<_, 4>( - (lhs, lhs_grad.buf_no_id(), out, lhs_grad_fn.no_id()), - move |(lhs, lhs_grad, out, lhs_grad_fn)| { - try_vk_add_unary_grad(lhs.device(), lhs, lhs_grad, out, **lhs_grad_fn) + self.add_op( + (lhs, lhs_grad, out), + move |(lhs, lhs_grad, out)| { + try_vk_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) }, ) .unwrap(); diff --git a/src/devices/wgsl/ops.rs b/src/devices/wgsl/ops.rs index 9de9e72d..655364a9 100644 --- a/src/devices/wgsl/ops.rs +++ b/src/devices/wgsl/ops.rs @@ -47,7 +47,7 @@ where { let mut out = self.retrieve(buf.len(), buf).unwrap(); - self.add_op((&mut out, buf, f.no_id()), move |(out, buf, f)| { + self.add_op((&mut out, buf), move |(out, buf)| { let src = format!( " @group(0) diff --git a/src/devices/wgsl/wgsl_device.rs b/src/devices/wgsl/wgsl_device.rs index 8a77c4bd..a9adb1a4 100644 --- a/src/devices/wgsl/wgsl_device.rs +++ b/src/devices/wgsl/wgsl_device.rs @@ -182,21 +182,12 @@ impl, T: Unit, Mods: Retrieve, S: Shape> Retrie } impl AddOperation for Wgsl { - fn add_op2 + crate::AnyOp, const N: usize>( + fn add_op + crate::AnyOp, const N: usize>( &self, args: Args, op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) { - self.modules.add_op2(args, op) - } - - #[inline] - fn add_op + crate::UpdateArgs, const N: usize>( - &self, - args: Args, - operation: fn(&mut Args) -> crate::Result<()>, ) -> crate::Result<()> { - self.modules.add_op(args, operation) + self.modules.add_op(args, op) } #[inline] @@ -217,10 +208,12 @@ impl AddOperation for Wgsl { #[cfg(test)] mod tests { + #[cfg(feature = "vulkan")] use crate::{Device, Vulkan}; use super::Wgsl; + #[cfg(feature = "vulkan")] #[test] fn test_wgsl_wrapper() { let dev = Wgsl::::new(0).unwrap(); From 15effb72850c013b75a929f67ad49a4aef21bd57 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:55:03 +0200 Subject: [PATCH 17/29] Remove deref inside unified cl cond try_cl_apply_fn_mut --- Cargo.toml | 2 +- src/devices/opencl/ops.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 83ba8321..37cb1784 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cuda", "vulkan"] +default = ["opencl"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index 6a415625..bb56d881 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -249,7 +249,7 @@ where } #[cfg(not(unified_cl))] { - try_cl_apply_fn_mut(dev, buf, out, **f)?; + try_cl_apply_fn_mut(dev, buf, out, f)?; Ok(()) } }) From 578bcb2ed446064a7848df756d2bcd42f28a3829 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 02:58:26 +0200 Subject: [PATCH 18/29] Adapt cuda add_op test in lazy.rs --- Cargo.toml | 2 +- src/devices/cuda/lazy.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 37cb1784..50754791 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["opencl"] +default = ["cuda", "lazy"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/cuda/lazy.rs b/src/devices/cuda/lazy.rs index bfc8d462..e8b5b115 100644 --- a/src/devices/cuda/lazy.rs +++ b/src/devices/cuda/lazy.rs @@ -265,14 +265,14 @@ mod tests { device .add_op( - (lhs, rhs, &mut out, src.no_id(), fn_name.no_id()), - |(lhs, rhs, out, src, fn_name)| { + (lhs, rhs, &mut out), + move |(lhs, rhs, out)| { let device = lhs.device(); device.launch_kernel1d( lhs.len(), - &**src, + &src, fn_name, - &[lhs, rhs, *out, &lhs.len()], + &[lhs, rhs, out, &lhs.len()], ) }, ) From df9226807bd9d04ea5dddff887313c8f292a1a77 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:58:19 +0200 Subject: [PATCH 19/29] Add Operation::no_op --- Cargo.toml | 2 +- src/modules/lazy/lazy_graph.rs | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 50754791..6b7cd602 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["cuda", "lazy"] +default = ["vulkan", "graph", "autograd", "cpu"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index eccb18bb..0e1a2d9e 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -10,6 +10,15 @@ pub struct Operation<'a, B, T> { pub op_hint: OpHint, } +impl<'a, B, T> Operation<'a, B, T> { + pub fn no_op() -> Self { + Self { + op: Box::new(|_buffers| Ok(())), + op_hint: OpHint::None + } + } +} + pub struct LazyGraph<'a, B = Box, T = ()> { pub(crate) operations: Vec>, } From ebadc32e189503611b383fcb6760c96d2554dc7c Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:50:49 +0200 Subject: [PATCH 20/29] Remove lifetime parameter from LazyGraph --- src/modules/lazy/exec_iter.rs | 10 +++++----- src/modules/lazy/lazy_graph.rs | 23 ++++++++++++----------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/modules/lazy/exec_iter.rs b/src/modules/lazy/exec_iter.rs index 4bc87463..bc4ab16b 100644 --- a/src/modules/lazy/exec_iter.rs +++ b/src/modules/lazy/exec_iter.rs @@ -1,11 +1,11 @@ use crate::{Buffers, Operation}; -pub struct ExecIter<'a, 'b, B, T> { - pub(super) operations: std::slice::Iter<'b, Operation<'a, B, T>>, +pub struct ExecIter<'b, B, T> { + pub(super) operations: std::slice::Iter<'b, Operation>, pub(super) buffers: &'b mut Buffers, } -impl<'a, 'b, B, T> Iterator for ExecIter<'a, 'b, B, T> { +impl<'b, B, T> Iterator for ExecIter<'b, B, T> { type Item = crate::Result<()>; fn next(&mut self) -> Option { @@ -14,14 +14,14 @@ impl<'a, 'b, B, T> Iterator for ExecIter<'a, 'b, B, T> { } } -impl<'a, 'b, B, T> DoubleEndedIterator for ExecIter<'a, 'b, B, T> { +impl<'b, B, T> DoubleEndedIterator for ExecIter<'b, B, T> { fn next_back(&mut self) -> Option { let op = self.operations.next_back()?; Some((op.op)(self.buffers)) } } -impl<'a, 'b, B, T> ExactSizeIterator for ExecIter<'a, 'b, B, T> { +impl<'b, B, T> ExactSizeIterator for ExecIter<'b, B, T> { fn len(&self) -> usize { self.operations.len() } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 0e1a2d9e..5d50b70e 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -2,28 +2,29 @@ use crate::{ bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, BoxedShallowCopy, Buffers, Device, Downcast, Parents, }; -use core::ops::RangeBounds; +use core::{marker::PhantomData, ops::RangeBounds}; use std::collections::HashSet; -pub struct Operation<'a, B, T> { - pub op: Box) -> crate::Result<()> + 'a>, +pub struct Operation { + pub op: Box) -> crate::Result<()> + 'static>, pub op_hint: OpHint, + // pub pd: PhantomData<&'a ()>, } -impl<'a, B, T> Operation<'a, B, T> { +impl Operation { pub fn no_op() -> Self { Self { op: Box::new(|_buffers| Ok(())), - op_hint: OpHint::None + op_hint: OpHint::None, } } } -pub struct LazyGraph<'a, B = Box, T = ()> { - pub(crate) operations: Vec>, +pub struct LazyGraph, T = ()> { + pub(crate) operations: Vec>, } -impl<'a, B, T> Default for LazyGraph<'a, B, T> { +impl Default for LazyGraph { #[inline] fn default() -> Self { Self { @@ -32,13 +33,13 @@ impl<'a, B, T> Default for LazyGraph<'a, B, T> { } } -impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { +impl LazyGraph { #[inline] pub fn iter_with<'b>( &'b mut self, // device: &'a D, buffers: &'b mut Buffers, - ) -> ExecIter<'a, 'b, B, T> { + ) -> ExecIter<'b, B, T> { ExecIter { operations: self.operations.iter(), buffers, @@ -78,7 +79,7 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> { pub fn convert_to_operation + AnyOp, const N: usize>( args: Args, op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static, - ) -> Operation<'a, B, T> { + ) -> Operation { const { assert!(N > 0, "Size of parents must be greater than 0") }; let mut seen_ids = HashSet::new(); From c1b8e882511ae8d75e10ecbe1e13dca1ac064978 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:05:44 +0200 Subject: [PATCH 21/29] Remove lifetimes from lazygraph usages --- Cargo.toml | 2 +- src/devices/fusing.rs | 4 ++-- src/modules/autograd/tape.rs | 5 ++++- src/modules/lazy.rs | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6b7cd602..4fc70487 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["vulkan", "graph", "autograd", "cpu"] +default = ["vulkan", "graph", "lazy", "autograd", "cpu"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/fusing.rs b/src/devices/fusing.rs index 23b9ec87..d64e64bf 100644 --- a/src/devices/fusing.rs +++ b/src/devices/fusing.rs @@ -72,12 +72,12 @@ pub trait UnaryFusing: IsShapeIndep { .collect::>(); let out = unsafe { - &mut *(buffers.get_mut(&last_arg_ids[0]).unwrap().as_any_mut() + &mut *(buffers.get_mut(&last_arg_ids[0]).unwrap() as *mut _ as *mut Buffer) }; let buf = unsafe { - &*(buffers.get(&first_arg_ids[1]).unwrap().as_any() as *const Buffer) + &*(buffers.get(&first_arg_ids[1]).unwrap() as *const _ as *const Buffer) }; let op = self.unary_fuse_op::(); diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 61838bea..13be712a 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use crate::{ AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph, Parents, Shape, Unit, WriteBuf, ZeroGrad, @@ -10,7 +12,8 @@ pub type GradFn = Box; /// Stores the grad functions and gradient cache. #[derive(Default)] pub struct Tape<'a> { - pub lazy_graph: LazyGraph<'a, Box>, + pub lazy_graph: LazyGraph>, + pd: PhantomData<&'a ()>, } impl<'t> Tape<'t> { diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 1487aec8..57a54c7c 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -42,7 +42,7 @@ pub struct Lazy { // This ensures to only allocate a buffer once, without having to remove the ID/address collision check // TODO: remove this, fix id and address collision - then just use `buffers` for duplicate calls allocated_ids: RefCell, - pub graph: RefCell, T>>, + pub graph: RefCell, T>>, cursor: Cell, enabled: Cell, pd: PhantomData, From c4c20efed22849583846e5a7844a6c4ce2a4634e Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:58:59 +0200 Subject: [PATCH 22/29] Add ops_to_fuse arg to unary_fuse_op, fix --- Cargo.toml | 2 +- src/devices/cpu/cpu_device.rs | 16 +++++------ src/devices/cuda/fusing.rs | 26 +++++++++--------- src/devices/cuda/lazy.rs | 16 +++-------- src/devices/cuda/ops.rs | 9 +++---- src/devices/fusing.rs | 49 +++++++++++++--------------------- src/devices/opencl/fusing.rs | 26 +++++++++--------- src/devices/vulkan/ops.rs | 9 +++---- src/modules/lazy/lazy_graph.rs | 9 ++++--- 9 files changed, 70 insertions(+), 92 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4fc70487..aefd3e16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["vulkan", "graph", "lazy", "autograd", "cpu"] +default = ["vulkan", "graph", "lazy", "autograd", "cpu", "opencl", "cuda"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 0d92fbdf..1d923740 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -204,17 +204,13 @@ impl UnaryFusing for CPU { #[inline] fn unary_fuse_op( &self, - ) -> fn( - &mut ( - &mut Buffer<'_, T, Self, ()>, - &Buffer<'_, T, Self, ()>, - crate::NoId) -> Box>>>>, - ), - ) -> crate::Result<()> { - |(out, buf, ops)| { + ops_to_fuse: Vec) -> Box>>>, + ) -> Box, &Buffer<'_, T, Self, ()>)) -> crate::Result<()>> + { + Box::new(move |(out, buf)| { for (out, buf) in out.iter_mut().zip(buf.iter()) { let mut current_val = *buf; - for op in ops.iter() { + for op in ops_to_fuse.iter() { let resolve = crate::Resolve { val: current_val, marker: "x", @@ -224,7 +220,7 @@ impl UnaryFusing for CPU { *out = current_val; } Ok(()) - } + }) } } diff --git a/src/devices/cuda/fusing.rs b/src/devices/cuda/fusing.rs index b7ef7b70..c03c4397 100644 --- a/src/devices/cuda/fusing.rs +++ b/src/devices/cuda/fusing.rs @@ -6,20 +6,22 @@ impl UnaryFusing for CUDA { #[inline] fn unary_fuse_op( &self, - ) -> fn( - &mut ( - &mut crate::Buffer<'_, T, Self, ()>, - &crate::Buffer<'_, T, Self, ()>, - crate::NoId) -> Box>>>>, - ), - ) -> crate::Result<()> { + ops_to_fuse: Vec) -> Box>>>, + ) -> Box< + dyn Fn( + ( + &mut crate::Buffer<'_, T, Self, ()>, + &crate::Buffer<'_, T, Self, ()>, + ), + ) -> crate::Result<()>, + > { use crate::operations_to_fused_src; - |(out, buf, ops)| { - if ops.is_empty() { + Box::new(move |(out, buf)| { + if ops_to_fuse.is_empty() { return Ok(()); } - let fused_operations = operations_to_fused_src(&ops); + let fused_operations = operations_to_fused_src(&ops_to_fuse); let src = format!( r#"extern "C" __global__ void applyFn({datatype}* lhs, {datatype}* out, int numElements) @@ -41,8 +43,8 @@ impl UnaryFusing for CUDA { [(buf.len() as u32 / 32 + 1) * 32, 1, 1], [32, 1, 1], 0, - &[buf, *out, &buf.len()], + &[buf, out, &buf.len()], ) - } + }) } } diff --git a/src/devices/cuda/lazy.rs b/src/devices/cuda/lazy.rs index e8b5b115..e61650be 100644 --- a/src/devices/cuda/lazy.rs +++ b/src/devices/cuda/lazy.rs @@ -264,18 +264,10 @@ mod tests { let mut out = device.retrieve(lhs.len(), (lhs.id(), rhs.id())).unwrap(); device - .add_op( - (lhs, rhs, &mut out), - move |(lhs, rhs, out)| { - let device = lhs.device(); - device.launch_kernel1d( - lhs.len(), - &src, - fn_name, - &[lhs, rhs, out, &lhs.len()], - ) - }, - ) + .add_op((lhs, rhs, &mut out), move |(lhs, rhs, out)| { + let device = lhs.device(); + device.launch_kernel1d(lhs.len(), &src, fn_name, &[lhs, rhs, out, &lhs.len()]) + }) .unwrap(); out diff --git a/src/devices/cuda/ops.rs b/src/devices/cuda/ops.rs index 6428bbc4..c9455f9a 100644 --- a/src/devices/cuda/ops.rs +++ b/src/devices/cuda/ops.rs @@ -194,12 +194,9 @@ where ) where F: ToCLSource, { - self.add_op( - (lhs, lhs_grad, out), - move |(lhs, lhs_grad, out)| { - try_cu_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) - }, - ) + self.add_op((lhs, lhs_grad, out), move |(lhs, lhs_grad, out)| { + try_cu_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) + }) .unwrap(); } } diff --git a/src/devices/fusing.rs b/src/devices/fusing.rs index d64e64bf..5dce70ee 100644 --- a/src/devices/fusing.rs +++ b/src/devices/fusing.rs @@ -23,21 +23,23 @@ pub trait UnaryFusing: IsShapeIndep { #[cfg(feature = "graph")] fn unary_fuse_op( &self, - ) -> fn( - &mut ( - &mut crate::Buffer<'_, T, Self, ()>, - &crate::Buffer<'_, T, Self, ()>, - crate::NoId) -> Box>>>>, - ), - ) -> crate::Result<()>; + ops_to_fuse: Vec) -> Box>>>, + ) -> Box< + dyn Fn( + ( + &mut crate::Buffer<'_, T, Self, ()>, + &crate::Buffer<'_, T, Self, ()>, + ), + ) -> crate::Result<()>, + >; #[cfg(feature = "lazy")] #[cfg(feature = "graph")] /// # Safety /// Does not check if specific retrieved buffers contain data of type `T`. - unsafe fn fuse_unary_ops( - &self, - lazy_graph: &crate::LazyGraph, T>, + unsafe fn fuse_unary_ops<'a, T: crate::CDatatype + crate::Numeric>( + &'a self, + lazy_graph: &'a crate::LazyGraph, T>, ops: ( Vec) -> Box>>>, Vec, @@ -47,45 +49,32 @@ pub trait UnaryFusing: IsShapeIndep { where Self: 'static, { - use crate::{AsAny, AsNoId, Buffer}; + use crate::Buffer; let (ops, affected_op_idxs) = ops; let to_insert_idx: usize = affected_op_idxs[0]; let first_op = &lazy_graph.operations[to_insert_idx]; - let first_arg_ids = first_op - .arg_ids - .iter() - .flatten() - .copied() - .collect::>(); + let first_arg_ids = &first_op.arg_ids; let last_op = &lazy_graph.operations[*affected_op_idxs.last().unwrap()]; // use last op in the unary fuse chain as the output buffer - let last_arg_ids = last_op - .arg_ids - .iter() - .flatten() - .copied() - .collect::>(); - + let last_arg_ids = &last_op.arg_ids; let out = unsafe { - &mut *(buffers.get_mut(&last_arg_ids[0]).unwrap() as *mut _ - as *mut Buffer) + &mut *(buffers.get_mut(&last_arg_ids[0]).unwrap() as *mut _ as *mut Buffer) }; let buf = unsafe { &*(buffers.get(&first_arg_ids[1]).unwrap() as *const _ as *const Buffer) }; - let op = self.unary_fuse_op::(); - let mut operation = - unsafe { crate::LazyGraph::convert_to_operation((out, buf, ops.no_id()), op) }; + let op = self.unary_fuse_op::(ops); + let mut operation = crate::LazyGraph::convert_to_operation((out, buf), op); // using the buffers out of the 'buffers' hashmaps results in using allocated buffers that are not in the 'buffers' hashmap // if the lazy graph is executed, it updates the references to the corresponding buffers -> new ids would not be found -> invalid lazy buffer panic - operation.arg_ids = vec![Some(last_arg_ids[0]), Some(first_arg_ids[1]), None]; + operation.arg_ids = vec![last_arg_ids[0], first_arg_ids[1]]; operation.op_hint = crate::op_hint::OpHint::UnaryFused; (to_insert_idx, operation) diff --git a/src/devices/opencl/fusing.rs b/src/devices/opencl/fusing.rs index bee9492d..32e2d598 100644 --- a/src/devices/opencl/fusing.rs +++ b/src/devices/opencl/fusing.rs @@ -6,21 +6,23 @@ impl UnaryFusing for OpenCL { #[inline] fn unary_fuse_op( &self, - ) -> fn( - &mut ( - &mut crate::Buffer<'_, T, Self, ()>, - &crate::Buffer<'_, T, Self, ()>, - crate::NoId) -> Box>>>>, - ), - ) -> crate::Result<()> { + ops_to_fuse: Vec) -> Box>>>, + ) -> Box< + dyn Fn( + ( + &mut crate::Buffer<'_, T, Self, ()>, + &crate::Buffer<'_, T, Self, ()>, + ), + ) -> crate::Result<()>, + > { use crate::operations_to_fused_src; - |(out, buf, ops)| { - if ops.is_empty() { + Box::new(move |(out, buf)| { + if ops_to_fuse.is_empty() { return Ok(()); } - let fused_operations = operations_to_fused_src(ops); + let fused_operations = operations_to_fused_src(&ops_to_fuse); let src = format!( " @@ -41,8 +43,8 @@ impl UnaryFusing for OpenCL { &src, [(buf.len() / 32 + 1) * 32, 0, 0], Some([32, 0, 0]), - &[buf, *out, &buf.len()], + &[buf, out, &buf.len()], ) - } + }) } } diff --git a/src/devices/vulkan/ops.rs b/src/devices/vulkan/ops.rs index fb2d87b3..851598d2 100644 --- a/src/devices/vulkan/ops.rs +++ b/src/devices/vulkan/ops.rs @@ -188,12 +188,9 @@ where ) where F: ToCLSource, { - self.add_op( - (lhs, lhs_grad, out), - move |(lhs, lhs_grad, out)| { - try_vk_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) - }, - ) + self.add_op((lhs, lhs_grad, out), move |(lhs, lhs_grad, out)| { + try_vk_add_unary_grad(lhs.device(), lhs, lhs_grad, out, lhs_grad_fn) + }) .unwrap(); } } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 5d50b70e..0115478c 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -1,11 +1,12 @@ use crate::{ bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, BoxedShallowCopy, - Buffers, Device, Downcast, Parents, + Buffers, Device, Downcast, Id, Parents, }; -use core::{marker::PhantomData, ops::RangeBounds}; +use core::ops::RangeBounds; use std::collections::HashSet; pub struct Operation { + pub arg_ids: Vec, pub op: Box) -> crate::Result<()> + 'static>, pub op_hint: OpHint, // pub pd: PhantomData<&'a ()>, @@ -15,6 +16,7 @@ impl Operation { pub fn no_op() -> Self { Self { op: Box::new(|_buffers| Ok(())), + arg_ids: vec![], op_hint: OpHint::None, } } @@ -106,9 +108,10 @@ impl LazyGraph { } let op: Box) -> crate::Result<()>> = - Args::replication_fn::(arg_ids, op); + Args::replication_fn::(arg_ids.clone(), op); Operation { + arg_ids, op, op_hint: OpHint::None, } From fe19c9e173bbe85407ca7a21388773fdd6aae9d1 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 16:10:21 +0200 Subject: [PATCH 23/29] Add device param to LazyGraph operation, update device in replicated buffer --- Cargo.toml | 2 +- src/any_op.rs | 32 +++++++++++++++++++++++------ src/modules/autograd/tape.rs | 11 +++++----- src/modules/lazy.rs | 21 ++++++++++++------- src/modules/lazy/exec_iter.rs | 19 ++++++++--------- src/modules/lazy/lazy_graph.rs | 37 +++++++++++++++++++--------------- src/parents.rs | 6 +++--- 7 files changed, 81 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aefd3e16..fac63475 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["vulkan", "graph", "lazy", "autograd", "cpu", "opencl", "cuda"] +default = ["vulkan", "lazy", "autograd", "cpu", "opencl", "cuda"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/any_op.rs b/src/any_op.rs index fa77f1e3..2793cf3f 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -11,6 +11,21 @@ pub trait Replicate { unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, + device: Option<&'r dyn core::any::Any>, + ) -> Option>; + + unsafe fn replicate<'a>(self) -> Self::Replication<'a>; +} + +pub trait Replicate2 { + type Replication<'r>; + type Downcast<'r>: 'r; + + #[cfg(feature = "std")] + unsafe fn replicate_borrowed<'r, B: Downcast>( + id: &Id, + buffers: &'r mut Buffers, + device: Option<&'r D>, ) -> Option>; unsafe fn replicate<'a>(self) -> Self::Replication<'a>; @@ -23,7 +38,7 @@ pub trait AnyOp: Sized { fn replication_fn( ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, - ) -> Box Fn(&'i mut Buffers) -> crate::Result<()>>; + ) -> Box Fn(&'i mut Buffers, &dyn core::any::Any) -> crate::Result<()>>; unsafe fn replication<'a>(self) -> Self::Replicated<'a>; } @@ -39,8 +54,10 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, + device: Option<&'r dyn core::any::Any>, ) -> Option> { - buffers.get(id)?.downcast_ref::>() + <&mut Buffer as Replicate>::replicate_borrowed(id, buffers, device) + .map(|buf| &*buf) } #[inline] @@ -64,12 +81,15 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate unsafe fn replicate_borrowed<'r, B: Downcast>( id: &Id, buffers: &'r mut Buffers, + device: Option<&'r dyn core::any::Any>, ) -> Option> { let replication = buffers.get_mut(id)?; if !replication.is::>() { return None; } - Some(unsafe { replication.downcast_mut_unchecked::>() }) + let buf = unsafe { replication.downcast_mut_unchecked::>() }; + buf.device = device.map(|dev| dev.downcast_ref::().unwrap()); + Some(buf) } #[inline] @@ -88,12 +108,12 @@ impl AnyOp for R { fn replication_fn( ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, - ) -> Box) -> crate::Result<()>> { + ) -> Box, &dyn core::any::Any) -> crate::Result<()>> { use crate::DeviceError; let id = ids[0]; - Box::new(move |buffers| { - let r1 = unsafe { R::replicate_borrowed(&id, buffers) } + Box::new(move |buffers, dev| { + let r1 = unsafe { R::replicate_borrowed(&id, buffers, Some(dev)) } .ok_or(DeviceError::InvalidLazyBuf)?; op(r1) }) diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 13be712a..ce5bf93a 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -1,8 +1,8 @@ use core::marker::PhantomData; use crate::{ - AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, GradActions, LazyGraph, Parents, - Shape, Unit, WriteBuf, ZeroGrad, + AddOperation, Alloc, AnyOp, BoxedShallowCopy, Buffer, Buffers, Device, GradActions, LazyGraph, + Parents, Shape, Unit, WriteBuf, ZeroGrad, }; use super::Gradients; @@ -27,12 +27,13 @@ impl<'t> Tape<'t> { } /// Calls all gradient functions in reverse order. - pub fn backward( + pub fn backward( &mut self, + device: &D, buffers: &mut Buffers>, lazy_enabled: bool, ) { - for val in self.lazy_graph.iter_with(buffers).rev() { + for val in self.lazy_graph.iter_with(device, buffers).rev() { val.unwrap(); } if !lazy_enabled { @@ -65,7 +66,7 @@ impl<'t> Tape<'t> { let is_lazy_enabled = buf.device().is_lazy_enabled(); buf.device() - .eagerly(|| self.backward(buffers, is_lazy_enabled)); + .eagerly(|| self.backward(buf.device(), buffers, is_lazy_enabled)); } pub fn backward_seeded_maybe_with_buffers<'a, T, D, S: Shape>( diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 57a54c7c..08383817 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -138,9 +138,11 @@ impl ExecNow for Lazy { ) -> crate::Result<()> { self.alloc_later(device); unsafe { - self.graph - .borrow_mut() - .call_range::(range_bounds, &mut self.buffers.borrow_mut())?; + self.graph.borrow_mut().call_range::( + device, + range_bounds, + &mut self.buffers.borrow_mut(), + )?; } Ok(()) } @@ -148,10 +150,10 @@ impl ExecNow for Lazy { impl Lazy { #[inline] - pub unsafe fn call_lazily(&self) -> crate::Result<()> { + pub unsafe fn call_lazily(&self, device: &D) -> crate::Result<()> { self.graph .borrow_mut() - .call_lazily(&mut self.buffers.borrow_mut())?; + .call_lazily(device, &mut self.buffers.borrow_mut())?; Ok(()) } @@ -176,7 +178,7 @@ impl, D: LazyRun + Device + 'static> RunModule for Lazy #[inline] fn run(&self, device: &D) -> crate::Result<()> { self.alloc_later(device); - unsafe { self.call_lazily::()? }; + unsafe { self.call_lazily::(device)? }; device.run()?; self.modules.run(device) } @@ -527,7 +529,12 @@ mod tests { // assert_eq!(out.read(), &[0; 10]); -- should not work device.modules.alloc_later(&device); - unsafe { device.modules.call_lazily::>>().unwrap() } + unsafe { + device + .modules + .call_lazily::>>(&device) + .unwrap() + } // assert_eq!(out.read(), &[3; 10]); -- should work assert_eq!(out.replace().read(), &[3; 10]); drop(buf); diff --git a/src/modules/lazy/exec_iter.rs b/src/modules/lazy/exec_iter.rs index bc4ab16b..68e88def 100644 --- a/src/modules/lazy/exec_iter.rs +++ b/src/modules/lazy/exec_iter.rs @@ -1,27 +1,28 @@ -use crate::{Buffers, Operation}; +use crate::{Buffers, Device, Operation}; -pub struct ExecIter<'b, B, T> { - pub(super) operations: std::slice::Iter<'b, Operation>, - pub(super) buffers: &'b mut Buffers, +pub struct ExecIter<'_6, B, T, D> { + pub(super) operations: std::slice::Iter<'_6, Operation>, + pub(super) buffers: &'_6 mut Buffers, + pub(super) device: &'_6 D, } -impl<'b, B, T> Iterator for ExecIter<'b, B, T> { +impl<'b, B, T, D: Device + 'static> Iterator for ExecIter<'b, B, T, D> { type Item = crate::Result<()>; fn next(&mut self) -> Option { let op = self.operations.next()?; - Some((op.op)(self.buffers)) + Some((op.op)(self.buffers, self.device)) } } -impl<'b, B, T> DoubleEndedIterator for ExecIter<'b, B, T> { +impl<'b, B, T, D: Device + 'static> DoubleEndedIterator for ExecIter<'b, B, T, D> { fn next_back(&mut self) -> Option { let op = self.operations.next_back()?; - Some((op.op)(self.buffers)) + Some((op.op)(self.buffers, self.device)) } } -impl<'b, B, T> ExactSizeIterator for ExecIter<'b, B, T> { +impl<'b, B, T, D: Device + 'static> ExactSizeIterator for ExecIter<'b, B, T, D> { fn len(&self) -> usize { self.operations.len() } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 0115478c..d8080487 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -7,7 +7,7 @@ use std::collections::HashSet; pub struct Operation { pub arg_ids: Vec, - pub op: Box) -> crate::Result<()> + 'static>, + pub op: Box, &dyn core::any::Any) -> crate::Result<()> + 'static>, pub op_hint: OpHint, // pub pd: PhantomData<&'a ()>, } @@ -15,7 +15,7 @@ pub struct Operation { impl Operation { pub fn no_op() -> Self { Self { - op: Box::new(|_buffers| Ok(())), + op: Box::new(|_buffers, _dev| Ok(())), arg_ids: vec![], op_hint: OpHint::None, } @@ -37,14 +37,15 @@ impl Default for LazyGraph { impl LazyGraph { #[inline] - pub fn iter_with<'b>( + pub fn iter_with<'b, D: Device>( &'b mut self, - // device: &'a D, + device: &'b D, buffers: &'b mut Buffers, - ) -> ExecIter<'b, B, T> { + ) -> ExecIter<'b, B, T, D> { ExecIter { operations: self.operations.iter(), buffers, + device, } } @@ -58,8 +59,12 @@ impl LazyGraph { self.operations.len() } - pub unsafe fn call_lazily(&mut self, buffers: &mut Buffers) -> crate::Result<()> { - for args in self.iter_with(buffers) { + pub unsafe fn call_lazily( + &mut self, + device: &D, + buffers: &mut Buffers, + ) -> crate::Result<()> { + for args in self.iter_with(device, buffers) { args?; } Ok(()) @@ -67,13 +72,13 @@ impl LazyGraph { pub unsafe fn call_range( &mut self, - // _device: &'a D, + device: &D, bounds: impl RangeBounds, buffers: &mut Buffers, ) -> crate::Result<()> { let range = bounds_to_range(bounds, self.operations.len()); for op in self.operations.drain(range) { - (op.op)(buffers)?; + (op.op)(buffers, device)?; } Ok(()) } @@ -107,7 +112,7 @@ impl LazyGraph { panic!() } - let op: Box) -> crate::Result<()>> = + let op: Box, &dyn core::any::Any) -> crate::Result<()>> = Args::replication_fn::(arg_ids.clone(), op); Operation { @@ -207,7 +212,7 @@ mod tests { println!("args: {args:?}"); Ok(()) }); - unsafe { graph.call_lazily(&mut buffers).unwrap() }; + unsafe { graph.call_lazily(&device, &mut buffers).unwrap() }; }; // let x = DEVICE2.get().unwrap(); // println!("{:?}", x.modules.cache.borrow().nodes); @@ -245,7 +250,7 @@ mod tests { }; // todo!() - unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } } #[test] @@ -271,7 +276,7 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } } #[test] @@ -304,7 +309,7 @@ mod tests { }); } - unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } } #[test] fn test_lazy_op_args_no_out_but_use() { @@ -331,7 +336,7 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } } #[test] @@ -365,6 +370,6 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&mut outs_unordered).unwrap() } + unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } } } diff --git a/src/parents.rs b/src/parents.rs index ac8283e0..83a2b249 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -74,13 +74,13 @@ macro_rules! impl_parents { fn replication_fn( ids: Vec<$crate::Id>, op: impl for<'a> Fn(Self::Replicated<'a>) -> $crate::Result<()> + 'static, - ) -> Box) -> $crate::Result<()>> { - Box::new(move |buffers| { + ) -> Box, &dyn core::any::Any) -> $crate::Result<()>> { + Box::new(move |buffers, dev| { let mut ids = ids.iter(); op(($( unsafe { - $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _)).ok_or(crate::DeviceError::InvalidLazyBuf)? + $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _), Some(dev)).ok_or(crate::DeviceError::InvalidLazyBuf)? } ,)+)) }) From d94c367b2ef57e95502e4662706eb69c479ed6f7 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:05:12 +0200 Subject: [PATCH 24/29] Remove device from buffer when registering --- src/hooks.rs | 2 +- src/modules/lazy.rs | 19 +++++++++++++++---- src/modules/mod.rs | 14 ++++++-------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/hooks.rs b/src/hooks.rs index 70e92abc..44c5f618 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -8,5 +8,5 @@ pub trait OnDropBuffer: WrappedData { pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> { #[track_caller] - fn on_new_buffer(&self, _device: &'dev D, _new_buf: &Buffer<'dev, T, D, S>) {} + fn on_new_buffer<'s>(&'s self, _device: &'dev D, _new_buf: &'s Buffer<'dev, T, D, S>) {} } diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 08383817..2078df05 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -205,7 +205,7 @@ where S: Shape, { #[inline] - fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { + fn on_new_buffer<'s>(&'s self, device: &'a D, new_buf: &'s Buffer<'a, T, D, S>) { unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) }; self.modules.on_new_buffer(device, new_buf) } @@ -495,13 +495,24 @@ mod tests { use core::ops::{Add, Deref}; use crate::{ - tests_helper::{add_ew_slice, AddEw}, - AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, Retrieve, Retriever, Shape, - Unit, CPU, + tests_helper::{add_ew_slice, AddEw}, AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Shape, Unit, CPU }; use super::Lazy; + #[test] + fn test_lazy_on_new_buffer() { + let lazy = Lazy::::default(); + { + let device = CPU::::new(); + let buf = device.buffer([1, 2, 3]); + lazy.on_new_buffer(&device, &buf); + } + for value in lazy.buffers.borrow().values() { + + } + } + #[test] #[cfg(feature = "cpu")] fn test_lazy_retrieve() { diff --git a/src/modules/mod.rs b/src/modules/mod.rs index b568b595..68d84128 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -93,13 +93,12 @@ pub(crate) unsafe fn register_buf_any( { // shallow copy sets flag to AllocFlag::Wrapper - let wrapped_data = buf.data.shallow(); + let wrapped_data = unsafe { buf.data.shallow() }; - let buf = Buffer { + let buf: Buffer = Buffer { data: wrapped_data, - device: buf.device, + device: None, }; - let buf: Buffer<'static, T, D, S> = core::mem::transmute(buf); cache.insert(*buf.id(), Box::new(buf)); } @@ -126,13 +125,12 @@ pub(crate) unsafe fn register_buf_copyable( S: Shape, { // shallow copy sets flag to AllocFlag::Wrapper - let wrapped_data = buf.data.shallow(); + let wrapped_data = unsafe { buf.data.shallow() }; - let buf = Buffer { + let buf: Buffer = Buffer { data: wrapped_data, - device: buf.device, + device: None, }; - let buf: Buffer<'static, T, D, S> = core::mem::transmute(buf); cache.insert(*buf.id(), Box::new(buf)); } From a3b7a3fe15c82784fa3739d2176f303551e4f76c Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:12:49 +0200 Subject: [PATCH 25/29] Add lifetime to Lazy --- src/features.rs | 2 +- src/modules/lazy.rs | 58 ++++++++++++++++++------------------- src/modules/lazy/wrapper.rs | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/features.rs b/src/features.rs index 31cacb31..29923deb 100644 --- a/src/features.rs +++ b/src/features.rs @@ -539,7 +539,7 @@ use crate::Lazy; #[cfg(feature = "lazy")] #[cfg(feature = "cached")] -pass_down_unified_mem_chain!(Lazy); +pass_down_unified_mem_chain!(Lazy, 'dev, Mods); #[cfg(feature = "autograd")] pass_down_unified_mem_chain!(Autograd, 'dev, Mods); diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 2078df05..5c32d0e4 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -33,7 +33,7 @@ type Buffers = crate::Buffers>; type AllocatedIds = HashSet>; #[derive(Default)] -pub struct Lazy { +pub struct Lazy<'a, Mods, T = f32> { pub modules: Mods, alloc_later: RefCell>, // could use D generic instead of dyn Any (required LazyModule structure) pub buffers: RefCell, @@ -45,10 +45,10 @@ pub struct Lazy { pub graph: RefCell, T>>, cursor: Cell, enabled: Cell, - pd: PhantomData, + pd: PhantomData<&'a T>, } -impl Debug for Lazy { +impl Debug for Lazy<'_, Mods, T> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Lazy").field("mods", &self.modules).finish() } @@ -68,8 +68,8 @@ pub trait LazyRun { } } -impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for Lazy { - type Module = Lazy; +impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for Lazy<'a, Mods, T> { + type Module = Lazy<'a, Mods::Module, T>; // type Data = LazyWrapper>; #[inline] @@ -88,7 +88,7 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L } } -impl AddOperation for Lazy { +impl AddOperation for Lazy<'_, Mods, T> { fn add_op + crate::AnyOp, const N: usize>( &self, args: Args, @@ -120,7 +120,7 @@ impl AddOperation for Lazy { } } -impl SetOpHint for Lazy { +impl SetOpHint for Lazy<'_, Mods, T> { #[inline] fn set_op_hint(&self, op_hint: OpHint) { if let Some(op) = self.graph.borrow_mut().operations.last_mut() { @@ -129,7 +129,7 @@ impl SetOpHint for Lazy { } } -impl ExecNow for Lazy { +impl ExecNow for Lazy<'_, Mods, T> { #[inline] fn exec_now( &self, @@ -148,7 +148,7 @@ impl ExecNow for Lazy { } } -impl Lazy { +impl Lazy<'_, Mods, T> { #[inline] pub unsafe fn call_lazily(&self, device: &D) -> crate::Result<()> { self.graph @@ -166,7 +166,7 @@ impl Lazy { } } -impl> Setup for Lazy { +impl> Setup for Lazy<'_, Mods, T> { #[inline] fn setup(device: &mut D) -> crate::Result<()> { device.lazy_setup()?; @@ -174,7 +174,7 @@ impl> Setup for Lazy { } } -impl, D: LazyRun + Device + 'static> RunModule for Lazy { +impl, D: LazyRun + Device + 'static> RunModule for Lazy<'_, Mods, T> { #[inline] fn run(&self, device: &D) -> crate::Result<()> { self.alloc_later(device); @@ -184,7 +184,7 @@ impl, D: LazyRun + Device + 'static> RunModule for Lazy } } -impl OnDropBuffer for Lazy { +impl OnDropBuffer for Lazy<'_, Mods, T2> { #[inline] fn on_drop_buffer( &self, @@ -196,7 +196,7 @@ impl OnDropBuffer for Lazy { } } -impl<'a, T, D, Mods, S, T2> OnNewBuffer<'a, T, D, S> for Lazy +impl<'a, T, D, Mods, S, T2> OnNewBuffer<'a, T, D, S> for Lazy<'_, Mods, T2> where T: Unit + 'static, D: Device + IsShapeIndep + 'static, @@ -213,10 +213,10 @@ where // pass_down_tape_actions!(Lazy); #[cfg(feature = "autograd")] -impl crate::HasAutograd for Lazy {} +impl crate::HasAutograd for Lazy<'_, Mods, T> {} #[cfg(feature = "autograd")] -impl crate::GradActions for Lazy { +impl crate::GradActions for Lazy<'_, Mods, U> { unsafe fn grad< 'a, T: 'static, @@ -254,7 +254,7 @@ impl crate::GradActions for Lazy { } } -impl crate::AddGradFn for Lazy { +impl crate::AddGradFn for Lazy<'_, Mods, T> { #[inline] fn add_grad_fn + AnyOp, const N: usize>( &self, @@ -276,14 +276,14 @@ impl crate::AddGradFn for Lazy { } // pass_down_grad_fn!(Lazy); // impl_remove_layer!(Lazy); -impl crate::RemoveLayer for Lazy { +impl crate::RemoveLayer for Lazy<'_, Mods, T> { #[inline] fn inner_mods(self) -> Mods { self.modules } } -impl AddLayer for Lazy<(), T> { - type Wrapped = crate::Lazy; +impl<'a, T, NewMods, SD> AddLayer for Lazy<'a, (), T> { + type Wrapped = crate::Lazy<'a, NewMods, T>; #[inline] fn wrap_layer(inner_mods: NewMods) -> Self::Wrapped { @@ -301,7 +301,7 @@ impl AddLayer for Lazy<(), T> { } } -impl Retrieve for Lazy +impl Retrieve for Lazy<'_, Mods, T2> where T: Unit + 'static, Mods: Retrieve, @@ -380,7 +380,7 @@ where } } -impl Cursor for Lazy { +impl Cursor for Lazy<'_, Mods, T> { #[inline] fn cursor(&self) -> usize { self.cursor.get() @@ -393,7 +393,7 @@ impl Cursor for Lazy { } impl ReplaceBuf - for Lazy + for Lazy<'_, Mods, T2> { #[inline] fn replace_buf<'a, 'b, 'c>( @@ -420,7 +420,7 @@ impl R } } -impl UseGpuOrCpu for Lazy { +impl UseGpuOrCpu for Lazy<'_, Mods, T> { fn use_cpu_or_gpu( &self, location: crate::HashLocation<'static>, @@ -472,7 +472,7 @@ impl crate::Optimize for Lazy CachedBuffers for Lazy { +impl CachedBuffers for Lazy<'_, Mods, T> { #[inline] unsafe fn buffers_mut( &self, @@ -481,7 +481,7 @@ impl CachedBuffers for Lazy { } } -impl HasModules for Lazy { +impl HasModules for Lazy<'_, Mods> { type Mods = Mods; #[inline] @@ -495,7 +495,9 @@ mod tests { use core::ops::{Add, Deref}; use crate::{ - tests_helper::{add_ew_slice, AddEw}, AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Shape, Unit, CPU + tests_helper::{add_ew_slice, AddEw}, + AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer, + Retrieve, Retriever, Shape, Unit, CPU, }; use super::Lazy; @@ -508,9 +510,7 @@ mod tests { let buf = device.buffer([1, 2, 3]); lazy.on_new_buffer(&device, &buf); } - for value in lazy.buffers.borrow().values() { - - } + for value in lazy.buffers.borrow().values() {} } #[test] diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 790352e0..a6e8c552 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -12,7 +12,7 @@ pub struct LazyWrapper { pub _pd: PhantomData, } -impl WrappedData for Lazy { +impl WrappedData for Lazy<'_, Mods, T2> { type Wrap = LazyWrapper, T>; #[inline] From 626bc1af67068fbb21c7aef73a29d32fcbe18150 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Thu, 25 Jul 2024 22:23:17 +0200 Subject: [PATCH 26/29] Add OnNewBuffer2 --- src/hooks.rs | 5 +++++ src/modules/autograd.rs | 8 ++++++-- src/modules/lazy.rs | 33 +++++++++++++++++++++++++++------ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/hooks.rs b/src/hooks.rs index 44c5f618..5c47bbcd 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -10,3 +10,8 @@ pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> { #[track_caller] fn on_new_buffer<'s>(&'s self, _device: &'dev D, _new_buf: &'s Buffer<'dev, T, D, S>) {} } + +pub trait OnNewBuffer2<'s, T: Unit, D: Device, S: Shape = ()> { + #[track_caller] + fn on_new_buffer(&'s self, _device: &D, _new_buf: &'s Buffer) {} +} diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index c4c96819..f6636698 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -5,7 +5,7 @@ mod wrapper; pub use gradients::*; pub use tape::*; -use core::cell::{Cell, UnsafeCell}; +use core::{cell::{Cell, UnsafeCell}, marker::PhantomData}; use crate::{ impl_remove_layer, pass_down_add_operation, pass_down_cached_buffers, pass_down_cursor, @@ -28,6 +28,7 @@ pub struct Autograd<'dev, Mods> { pub(crate) no_grads_pool: core::cell::RefCell>, pub(crate) tape: UnsafeCell>, pub enabled: Cell, + pd: PhantomData>, } impl<'a, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Autograd<'a, Mods> { @@ -43,6 +44,7 @@ impl<'a, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Autograd<'a, Mod no_grads_pool: Default::default(), tape: Default::default(), enabled: Cell::new(true), + pd: PhantomData } } } @@ -255,6 +257,7 @@ impl<'a, NewMods, SD> AddLayer for Autograd<'a, ()> { tape: Default::default(), enabled: Cell::new(true), no_grads_pool: Default::default(), + pd: PhantomData } } } @@ -285,9 +288,10 @@ mod tests { #[cfg(feature = "opencl")] #[test] fn test_autograd_lt() { + + let ag = Autograd::::default(); { let device = crate::OpenCL::based(0).unwrap(); - let ag = Autograd::::default(); let out = unsafe { // ag.gradients_mut() // .unwrap() diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 5c32d0e4..fbbd5b5a 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -45,7 +45,7 @@ pub struct Lazy<'a, Mods, T = f32> { pub graph: RefCell, T>>, cursor: Cell, enabled: Cell, - pd: PhantomData<&'a T>, + pd: PhantomData>, } impl Debug for Lazy<'_, Mods, T> { @@ -83,7 +83,7 @@ impl<'a, T, Mods: Module<'a, D>, D: LazySetup + Device + 'a> Module<'a, D> for L allocated_ids: Default::default(), cursor: Default::default(), enabled: Cell::new(true), - pd: PhantomData, + pd: Default::default(), } } } @@ -211,6 +211,25 @@ where } } +impl<'a, T, D: Device, S: Shape> crate::OnNewBuffer2<'a, T, D, S> for crate::Base { + fn on_new_buffer(&'a self, _device: &D, _new_buf: &'a Buffer) {} +} + +impl<'a, T, D, Mods, S, T2> crate::OnNewBuffer2<'a, T, D, S> for Lazy<'a, Mods, T2> +where + T: Unit + 'static, + D: Device + IsShapeIndep + 'static, + D::Data: ShallowCopy, + Mods: crate::OnNewBuffer2<'a, T, D, S>, + S: Shape, +{ + #[inline] + fn on_new_buffer(&'a self, device: &D, new_buf: &'a Buffer) { + unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) }; + self.modules.on_new_buffer(device, new_buf) + } +} + // pass_down_tape_actions!(Lazy); #[cfg(feature = "autograd")] impl crate::HasAutograd for Lazy<'_, Mods, T> {} @@ -296,7 +315,7 @@ impl<'a, T, NewMods, SD> AddLayer for Lazy<'a, (), T> { allocated_ids: Default::default(), cursor: Default::default(), enabled: Cell::new(true), - pd: PhantomData, + pd: Default::default(), } } } @@ -496,7 +515,7 @@ mod tests { use crate::{ tests_helper::{add_ew_slice, AddEw}, - AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer, + AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer2, Retrieve, Retriever, Shape, Unit, CPU, }; @@ -505,12 +524,14 @@ mod tests { #[test] fn test_lazy_on_new_buffer() { let lazy = Lazy::::default(); - { + // { let device = CPU::::new(); let buf = device.buffer([1, 2, 3]); lazy.on_new_buffer(&device, &buf); + // } + for value in lazy.buffers.borrow().values() { + println!("value"); } - for value in lazy.buffers.borrow().values() {} } #[test] From 35a81d91ab353a7dfe0ce2f793f88278982dae54 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 28 Jul 2024 02:47:34 +0200 Subject: [PATCH 27/29] Add unsafe to on_new_buffer --- Cargo.toml | 4 ++-- src/buffer.rs | 4 ++-- src/devices.rs | 4 ++-- src/devices/wgsl/wgsl_device.rs | 2 +- src/hooks.rs | 7 +----- src/modules/autograd.rs | 2 +- src/modules/cached.rs | 2 +- src/modules/fork.rs | 2 +- src/modules/graph.rs | 2 +- src/modules/lazy.rs | 40 +++----------------------------- src/modules/lazy/lazy_graph.rs | 25 +------------------- src/modules/lazy/optimization.rs | 2 +- 12 files changed, 17 insertions(+), 79 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fac63475..6aac9718 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,8 +53,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -default = ["vulkan", "lazy", "autograd", "cpu", "opencl", "cuda"] -# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] +# default = [ "cpu", "lazy", "autograd", ] +default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] std = [] diff --git a/src/buffer.rs b/src/buffer.rs index f310f49f..ea1ed1e2 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -89,7 +89,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { }; // mind: on_new_buffer must be called for user buffers! - device.on_new_buffer(device, &buf); + unsafe { device.on_new_buffer(device, &buf) }; buf } @@ -111,7 +111,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { } let mut buf = self; buf.set_requires_grad(require_grad); - buf.device().on_new_buffer(buf.device(), &buf); + unsafe { buf.device().on_new_buffer(buf.device(), &buf) }; buf } diff --git a/src/devices.rs b/src/devices.rs index 70278ace..9dfaf805 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -113,8 +113,8 @@ macro_rules! impl_buffer_hook_traits { Self: 'dev, { #[inline] - fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { - self.modules.on_new_buffer(device, new_buf) + unsafe fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { + unsafe { self.modules.on_new_buffer(device, new_buf) } } } diff --git a/src/devices/wgsl/wgsl_device.rs b/src/devices/wgsl/wgsl_device.rs index a9adb1a4..0e583e25 100644 --- a/src/devices/wgsl/wgsl_device.rs +++ b/src/devices/wgsl/wgsl_device.rs @@ -105,7 +105,7 @@ impl<'dev, D: Device, Mods: OnNewBuffer<'dev, T, D1, S>, T: Unit, D1: Device, S: OnNewBuffer<'dev, T, D1, S> for Wgsl { #[inline] - fn on_new_buffer(&self, device: &'dev D1, new_buf: &crate::Buffer<'dev, T, D1, S>) { + unsafe fn on_new_buffer(&self, device: &'dev D1, new_buf: &crate::Buffer<'dev, T, D1, S>) { self.modules.on_new_buffer(device, new_buf) } } diff --git a/src/hooks.rs b/src/hooks.rs index 5c47bbcd..c630a44c 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -8,10 +8,5 @@ pub trait OnDropBuffer: WrappedData { pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> { #[track_caller] - fn on_new_buffer<'s>(&'s self, _device: &'dev D, _new_buf: &'s Buffer<'dev, T, D, S>) {} -} - -pub trait OnNewBuffer2<'s, T: Unit, D: Device, S: Shape = ()> { - #[track_caller] - fn on_new_buffer(&'s self, _device: &D, _new_buf: &'s Buffer) {} + unsafe fn on_new_buffer<'s>(&'s self, _device: &'dev D, _new_buf: &'s Buffer<'dev, T, D, S>) {} } diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index f6636698..45e3f775 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -76,7 +76,7 @@ where Mods: OnNewBuffer<'dev, T, D, S>, { #[inline] - fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { + unsafe fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) { // let mut no_grads = self.no_grads_pool.borrow_mut(); // let wrapped_data = unsafe { new_buf.data.shallow() }; diff --git a/src/modules/cached.rs b/src/modules/cached.rs index e026d900..517ec9f0 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -123,7 +123,7 @@ where S: Shape, { #[inline] - fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { + unsafe fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { self.modules.on_new_buffer(device, new_buf) } } diff --git a/src/modules/fork.rs b/src/modules/fork.rs index 47b98ce4..23ec589f 100644 --- a/src/modules/fork.rs +++ b/src/modules/fork.rs @@ -99,7 +99,7 @@ impl<'a, Mods: OnNewBuffer<'a, T, D, S>, T: Unit, D: Device, S: Shape> OnNewBuff for Fork { #[inline] - fn on_new_buffer(&self, device: &'a D, new_buf: &crate::Buffer<'a, T, D, S>) { + unsafe fn on_new_buffer(&self, device: &'a D, new_buf: &crate::Buffer<'a, T, D, S>) { self.modules.on_new_buffer(device, new_buf) } } diff --git a/src/modules/graph.rs b/src/modules/graph.rs index 37760d6d..56f5a2c4 100644 --- a/src/modules/graph.rs +++ b/src/modules/graph.rs @@ -101,7 +101,7 @@ impl Optimize for Graph { impl<'a, Mods: OnNewBuffer<'a, T, D, S>, T: Unit, D: Device, S: Shape> OnNewBuffer<'a, T, D, S> for Graph { - fn on_new_buffer(&self, _device: &'a D, new_buf: &crate::Buffer<'a, T, D, S>) { + unsafe fn on_new_buffer(&self, _device: &'a D, new_buf: &crate::Buffer<'a, T, D, S>) { let mut graph_trans = self.graph_trans.borrow_mut(); let next_idx = graph_trans.next_idx; diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index fbbd5b5a..5836f242 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -205,26 +205,7 @@ where S: Shape, { #[inline] - fn on_new_buffer<'s>(&'s self, device: &'a D, new_buf: &'s Buffer<'a, T, D, S>) { - unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) }; - self.modules.on_new_buffer(device, new_buf) - } -} - -impl<'a, T, D: Device, S: Shape> crate::OnNewBuffer2<'a, T, D, S> for crate::Base { - fn on_new_buffer(&'a self, _device: &D, _new_buf: &'a Buffer) {} -} - -impl<'a, T, D, Mods, S, T2> crate::OnNewBuffer2<'a, T, D, S> for Lazy<'a, Mods, T2> -where - T: Unit + 'static, - D: Device + IsShapeIndep + 'static, - D::Data: ShallowCopy, - Mods: crate::OnNewBuffer2<'a, T, D, S>, - S: Shape, -{ - #[inline] - fn on_new_buffer(&'a self, device: &D, new_buf: &'a Buffer) { + unsafe fn on_new_buffer<'s>(&'s self, device: &'a D, new_buf: &'s Buffer<'a, T, D, S>) { unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) }; self.modules.on_new_buffer(device, new_buf) } @@ -463,7 +444,7 @@ impl UseGpuOrCpu for Lazy<'_, Mods, T> { } #[cfg(feature = "graph")] -impl crate::Optimize for Lazy { +impl crate::Optimize for Lazy<'_, Mods, T> { #[inline] fn optimize_mem_graph( &self, @@ -514,26 +495,11 @@ mod tests { use core::ops::{Add, Deref}; use crate::{ - tests_helper::{add_ew_slice, AddEw}, - AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, OnDropBuffer, OnNewBuffer2, - Retrieve, Retriever, Shape, Unit, CPU, + tests_helper::{add_ew_slice, AddEw}, AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, Retrieve, Retriever, Shape, Unit, CPU }; use super::Lazy; - #[test] - fn test_lazy_on_new_buffer() { - let lazy = Lazy::::default(); - // { - let device = CPU::::new(); - let buf = device.buffer([1, 2, 3]); - lazy.on_new_buffer(&device, &buf); - // } - for value in lazy.buffers.borrow().values() { - println!("value"); - } - } - #[test] #[cfg(feature = "cpu")] fn test_lazy_retrieve() { diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index d8080487..b4ae65ab 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -139,37 +139,14 @@ mod tests { register_buf_any, register_buf_copyable, AnyBuffer, Base, Buffer, CloneBuf, Device, HasId, LazyGraph, Retriever, Shape, UniqueId, CPU, }; - use core::cell::Cell; use std::collections::HashMap; - - pub(crate) fn register_buf_an_bufsy<'a, T, D, S>( - cache: &mut HashMap, impl core::hash::BuildHasher>, - buf: &Buffer<'a, T, D, S>, - ) where - T: crate::Unit + 'static, - D: Device + crate::IsShapeIndep + 'static + CloneBuf<'a, T, S>, - D::Data: crate::ShallowCopy, - S: Shape, - { - // shallow copy sets flag to AllocFlag::Wrapper - - // let wrapped_data = unsafe { buf.data.shallow() }; - - // let buf = Buffer { - // data: wrapped_data, - // device: buf.device, - // }; - let buf2 = buf.device().clone_buf(&buf); - cache.insert(*buf.id(), Box::new(buf2)); - } - #[cfg(feature = "autograd")] #[test] fn test_autograd_lazy_op() { use crate::TapeActions; // static mut DEVICE: Option<&'static CPU>> = None; thread_local! { - static DEVICE2: Cell>>> = Cell::new(None); + static DEVICE2: std::cell::Cell>>> = std::cell::Cell::new(None); }; // static DEVICES: std::sync::Mutex>>> = Default::default(); { diff --git a/src/modules/lazy/optimization.rs b/src/modules/lazy/optimization.rs index 8021b342..46d19a0d 100644 --- a/src/modules/lazy/optimization.rs +++ b/src/modules/lazy/optimization.rs @@ -1,6 +1,6 @@ use crate::{op_hint::OpHint, DeviceError, Lazy, Operation}; -impl Lazy { +impl Lazy<'_, Mods, T> { pub(crate) fn alloc_later_optimized( &self, device: &D, From 71c83207ae39f1172b4235fd3a722e192b673474 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 28 Jul 2024 17:55:57 +0200 Subject: [PATCH 28/29] Fix fusing -> move arg ids to callback params in replication fn --- Cargo.toml | 2 +- src/any_op.rs | 10 ++++------ src/devices/fusing.rs | 15 ++++++++++++--- src/modules/autograd.rs | 10 ++++++---- src/modules/lazy.rs | 4 +++- src/modules/lazy/exec_iter.rs | 4 ++-- src/modules/lazy/lazy_graph.rs | 15 ++++++++++----- src/op_hint.rs | 3 ++- src/parents.rs | 9 +++++---- 9 files changed, 45 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6aac9718..625dfe02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] -# default = [ "cpu", "lazy", "autograd", ] +# default = [ "cpu", "lazy", "autograd", "graph"] default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] diff --git a/src/any_op.rs b/src/any_op.rs index 2793cf3f..31e95bf7 100644 --- a/src/any_op.rs +++ b/src/any_op.rs @@ -36,9 +36,8 @@ pub trait AnyOp: Sized { #[cfg(feature = "std")] fn replication_fn( - ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, - ) -> Box Fn(&'i mut Buffers, &dyn core::any::Any) -> crate::Result<()>>; + ) -> Box Fn(&[crate::Id], &'i mut Buffers, &dyn core::any::Any) -> crate::Result<()>>; unsafe fn replication<'a>(self) -> Self::Replicated<'a>; } @@ -106,13 +105,12 @@ impl<'a, T: 'static, D: Device + 'static, S: crate::Shape> Replicate impl AnyOp for R { #[cfg(feature = "std")] fn replication_fn( - ids: Vec, op: impl for<'a> Fn(Self::Replicated<'a>) -> crate::Result<()> + 'static, - ) -> Box, &dyn core::any::Any) -> crate::Result<()>> { + ) -> Box, &dyn core::any::Any) -> crate::Result<()>> { use crate::DeviceError; - let id = ids[0]; - Box::new(move |buffers, dev| { + Box::new(move |ids, buffers, dev| { + let id = ids[0]; let r1 = unsafe { R::replicate_borrowed(&id, buffers, Some(dev)) } .ok_or(DeviceError::InvalidLazyBuf)?; op(r1) diff --git a/src/devices/fusing.rs b/src/devices/fusing.rs index 5dce70ee..1d77d136 100644 --- a/src/devices/fusing.rs +++ b/src/devices/fusing.rs @@ -49,7 +49,7 @@ pub trait UnaryFusing: IsShapeIndep { where Self: 'static, { - use crate::Buffer; + use crate::{Buffer, Buffers, Downcast, HasId}; let (ops, affected_op_idxs) = ops; let to_insert_idx: usize = affected_op_idxs[0]; @@ -62,12 +62,21 @@ pub trait UnaryFusing: IsShapeIndep { // use last op in the unary fuse chain as the output buffer let last_arg_ids = &last_op.arg_ids; + + assert_ne!(*last_arg_ids[0], *first_arg_ids[1]); + let out = unsafe { - &mut *(buffers.get_mut(&last_arg_ids[0]).unwrap() as *mut _ as *mut Buffer) + (&mut *(buffers as *mut Buffers>)) + .get_mut(&last_arg_ids[0]) + .unwrap() + .downcast_mut_unchecked::>() }; let buf = unsafe { - &*(buffers.get(&first_arg_ids[1]).unwrap() as *const _ as *const Buffer) + buffers + .get(&first_arg_ids[1]) + .unwrap() + .downcast_ref_unchecked::>() }; let op = self.unary_fuse_op::(ops); diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index 45e3f775..c1afde4e 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -5,7 +5,10 @@ mod wrapper; pub use gradients::*; pub use tape::*; -use core::{cell::{Cell, UnsafeCell}, marker::PhantomData}; +use core::{ + cell::{Cell, UnsafeCell}, + marker::PhantomData, +}; use crate::{ impl_remove_layer, pass_down_add_operation, pass_down_cached_buffers, pass_down_cursor, @@ -44,7 +47,7 @@ impl<'a, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Autograd<'a, Mod no_grads_pool: Default::default(), tape: Default::default(), enabled: Cell::new(true), - pd: PhantomData + pd: PhantomData, } } } @@ -257,7 +260,7 @@ impl<'a, NewMods, SD> AddLayer for Autograd<'a, ()> { tape: Default::default(), enabled: Cell::new(true), no_grads_pool: Default::default(), - pd: PhantomData + pd: PhantomData, } } } @@ -288,7 +291,6 @@ mod tests { #[cfg(feature = "opencl")] #[test] fn test_autograd_lt() { - let ag = Autograd::::default(); { let device = crate::OpenCL::based(0).unwrap(); diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 5836f242..827534e5 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -495,7 +495,9 @@ mod tests { use core::ops::{Add, Deref}; use crate::{ - tests_helper::{add_ew_slice, AddEw}, AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, Retrieve, Retriever, Shape, Unit, CPU + tests_helper::{add_ew_slice, AddEw}, + AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, Retrieve, Retriever, Shape, + Unit, CPU, }; use super::Lazy; diff --git a/src/modules/lazy/exec_iter.rs b/src/modules/lazy/exec_iter.rs index 68e88def..02c10506 100644 --- a/src/modules/lazy/exec_iter.rs +++ b/src/modules/lazy/exec_iter.rs @@ -11,14 +11,14 @@ impl<'b, B, T, D: Device + 'static> Iterator for ExecIter<'b, B, T, D> { fn next(&mut self) -> Option { let op = self.operations.next()?; - Some((op.op)(self.buffers, self.device)) + Some(op.call(self.buffers, self.device)) } } impl<'b, B, T, D: Device + 'static> DoubleEndedIterator for ExecIter<'b, B, T, D> { fn next_back(&mut self) -> Option { let op = self.operations.next_back()?; - Some((op.op)(self.buffers, self.device)) + Some(op.call(self.buffers, self.device)) } } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index b4ae65ab..7b50f7e4 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -7,7 +7,7 @@ use std::collections::HashSet; pub struct Operation { pub arg_ids: Vec, - pub op: Box, &dyn core::any::Any) -> crate::Result<()> + 'static>, + pub op: Box, &dyn core::any::Any) -> crate::Result<()> + 'static>, pub op_hint: OpHint, // pub pd: PhantomData<&'a ()>, } @@ -15,11 +15,16 @@ pub struct Operation { impl Operation { pub fn no_op() -> Self { Self { - op: Box::new(|_buffers, _dev| Ok(())), + op: Box::new(|_ids, _buffers, _dev| Ok(())), arg_ids: vec![], op_hint: OpHint::None, } } + + #[inline] + pub fn call(&self, buffers: &mut Buffers, device: &D) -> crate::Result<()> { + (self.op)(&self.arg_ids, buffers, device) + } } pub struct LazyGraph, T = ()> { @@ -78,7 +83,7 @@ impl LazyGraph { ) -> crate::Result<()> { let range = bounds_to_range(bounds, self.operations.len()); for op in self.operations.drain(range) { - (op.op)(buffers, device)?; + op.call(buffers, device)?; } Ok(()) } @@ -112,8 +117,8 @@ impl LazyGraph { panic!() } - let op: Box, &dyn core::any::Any) -> crate::Result<()>> = - Args::replication_fn::(arg_ids.clone(), op); + let op: Box, &dyn core::any::Any) -> crate::Result<()>> = + Args::replication_fn::(op); Operation { arg_ids, diff --git a/src/op_hint.rs b/src/op_hint.rs index c4c29e14..91b8cb0a 100644 --- a/src/op_hint.rs +++ b/src/op_hint.rs @@ -188,7 +188,7 @@ mod tests { #[cfg(feature = "graph")] #[test] fn test_op_hint_unary_chain_fuse_graph_complex() { - use crate::{ApplyFunction, Base, Combiner, Device, Graph, Lazy, Optimize, Run, CPU}; + use crate::{ApplyFunction, Base, Combiner, Device, Graph, HasId, Lazy, Optimize, Run, CPU}; let dev = CPU::>>::new(); @@ -200,6 +200,7 @@ mod tests { let out1 = dev.apply_fn(&out1, |x| x.abs()); let _out = dev.apply_fn(&out1, |x| x.ln()); + dev.optimize_mem_graph(&dev, None).unwrap(); dev.unary_fusing(&dev, None).unwrap(); unsafe { dev.run().unwrap() }; diff --git a/src/parents.rs b/src/parents.rs index 83a2b249..26d74076 100644 --- a/src/parents.rs +++ b/src/parents.rs @@ -72,15 +72,16 @@ macro_rules! impl_parents { #[cfg(feature = "std")] fn replication_fn( - ids: Vec<$crate::Id>, op: impl for<'a> Fn(Self::Replicated<'a>) -> $crate::Result<()> + 'static, - ) -> Box, &dyn core::any::Any) -> $crate::Result<()>> { - Box::new(move |buffers, dev| { + ) -> Box, &dyn core::any::Any) -> $crate::Result<()>> { + Box::new(move |ids, buffers, dev| { let mut ids = ids.iter(); op(($( unsafe { - $to_impl::replicate_borrowed(ids.next().unwrap(), &mut *(buffers as *mut _), Some(dev)).ok_or(crate::DeviceError::InvalidLazyBuf)? + $to_impl::replicate_borrowed( + ids.next().unwrap(), &mut *(buffers as *mut _), Some(dev) + ).ok_or(crate::DeviceError::InvalidLazyBuf)? } ,)+)) }) From a695bc3cc5c2de1e5df1b41de4e6b17c38553082 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 28 Jul 2024 18:01:07 +0200 Subject: [PATCH 29/29] Fix nnapi --- Cargo.toml | 2 +- src/devices/nnapi/nnapi_device.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 625dfe02..d8025ff9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } # default = ["cpu", "opencl", "cuda", "blas", "static-api", "stack", "macro", "nnapi", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] # default = ["no-std"] # default = [ "cpu", "lazy", "autograd", "graph"] -default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] +default = ["nnapi", "untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "cuda", "vulkan"] std = [] diff --git a/src/devices/nnapi/nnapi_device.rs b/src/devices/nnapi/nnapi_device.rs index 0b5bae10..2de16a0a 100644 --- a/src/devices/nnapi/nnapi_device.rs +++ b/src/devices/nnapi/nnapi_device.rs @@ -68,7 +68,7 @@ impl<'a, U, T: Unit, D: Device, S: Shape, Mods: crate::OnNewBuffer<'a, T, D, S>> crate::OnNewBuffer<'a, T, D, S> for NnapiDevice { #[inline] - fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { + unsafe fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) { self.modules.on_new_buffer(device, new_buf) } } @@ -250,7 +250,7 @@ impl NnapiDevice { impl LazySetup for NnapiDevice {} -impl Default for NnapiDevice> { +impl<'a, T: 'a> Default for NnapiDevice> { #[inline] fn default() -> Self { Self::new().unwrap()