diff --git a/Cargo.toml b/Cargo.toml index 6b91978c..c00ef063 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ repository = "https://github.com/elftausend/custos" keywords = ["gpu", "autodiff", "arrays", "deep-learning", "fixed-size"] categories = ["science", "mathematics", "no-std", "external-ffi-bindings"] readme = "README.md" -rust-version = "1.79" +rust-version = "1.81" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diff --git a/examples/lazy_and_fusing.rs b/examples/lazy_and_fusing.rs index 419859ee..159c2f6b 100644 --- a/examples/lazy_and_fusing.rs +++ b/examples/lazy_and_fusing.rs @@ -15,7 +15,7 @@ fn main() { device.unary_fusing(&device, None).unwrap(); // this executes all operations inside the lazy graph - unsafe { device.run().unwrap() }; + device.run().unwrap(); for (input, out) in buf.read().iter().zip(out2.replace().read()) { assert!((out - (input + 1.).sin()).abs() < 0.01); diff --git a/examples/modules_usage.rs b/examples/modules_usage.rs index 5309fcd4..83bc75fc 100644 --- a/examples/modules_usage.rs +++ b/examples/modules_usage.rs @@ -147,7 +147,7 @@ fn main() { let rhs = device.buffer([1, 2, 3, 4, 5]); let out = device.add(&lhs, &rhs).unwrap(); - unsafe { device.run().unwrap() }; // allocates memory and executes all operations inside the lazy graph + device.run().unwrap(); // allocates memory and executes all operations inside the lazy graph assert_eq!(out.replace().read(), [2, 4, 6, 8, 10]) } @@ -171,7 +171,7 @@ fn main() { device.unary_fusing(&device, None).unwrap(); // this executes all operations inside the lazy graph - unsafe { device.run().unwrap() }; + device.run().unwrap(); for (input, out) in buf.read().iter().zip(out2.replace().read()) { assert!((out - (input + 1.).sin()).abs() < 0.01); @@ -189,7 +189,7 @@ fn main() { let rhs = device.buffer([1, 2, 3, 4, 5]); let out = device.add(&lhs, &rhs).unwrap(); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.replace().read(), vec![2, 4, 6, 8, 10]) } } diff --git a/src/buffer.rs b/src/buffer.rs index 2a58e02b..31638197 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,5 +1,4 @@ use core::{ - ffi::c_void, mem::ManuallyDrop, ops::{Deref, DerefMut}, }; @@ -11,9 +10,9 @@ use crate::cpu::{CPUPtr, CPU}; use crate::CPU; use crate::{ - flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, - DevicelessAble, HasId, IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, - ShallowCopy, Unit, WrappedData, WriteBuf, ZeroGrad, + flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId, + IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit, + WrappedData, WriteBuf, ZeroGrad, }; pub use self::num::Num; @@ -583,7 +582,7 @@ impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, CPU, S> { impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, crate::OpenCL, S> { /// Returns the OpenCL pointer of the `Buffer`. #[inline] - pub fn cl_ptr(&self) -> *mut c_void { + pub fn cl_ptr(&self) -> *mut core::ffi::c_void { assert!( !self.base().ptr.is_null(), "called cl_ptr() on an invalid OpenCL buffer" @@ -687,15 +686,14 @@ where T: Unit + Debug + Default + Clone + 'a, D: Read + Device + 'a, for<'b> >::Read<'b>: Debug, - D::Data: Debug, + D::Data: Debug, S: Shape, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("Buffer") - .field("ptr", self.data()); + f.debug_struct("Buffer").field("ptr", self.data()); writeln!(f, ",")?; - let data = self.read(); + let data = self.read(); writeln!(f, "data: {data:?}")?; writeln!(f, "len: {:?}", self.len())?; write!( diff --git a/src/buffer/num.rs b/src/buffer/num.rs index 556e052f..c7732b06 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -1,13 +1,11 @@ use core::{ convert::Infallible, - ffi::c_void, ops::{Deref, DerefMut}, - ptr::null_mut, }; use crate::{ - flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, OnDropBuffer, PtrType, - ShallowCopy, Unit, WrappedData, + flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, OnDropBuffer, PtrType, ShallowCopy, + Unit, WrappedData, }; #[derive(Debug, Default)] @@ -175,15 +173,37 @@ impl<'a, T: Unit> Buffer<'a, T, ()> { /// /// let x: Buffer = 7f32.into(); /// assert_eq!(**x, 7.); - /// assert_eq!(x.item(), 7.); + /// assert_eq!(x.item(), &7.); /// /// ``` #[inline] - pub fn item(&self) -> T + pub fn item(&self) -> &T where - T: Unit + Copy, + T: Unit, { - self.data.num + &self.data.num + } + + /// Used if the `Buffer` contains only a single value. + /// By derefencing this `Buffer`, you obtain this value as well (which is probably preferred). + /// + /// # Example + /// + /// ``` + /// use custos::Buffer; + /// + /// let mut x: Buffer = 7f32.into(); + /// assert_eq!(**x, 7.); + /// *x.item_mut() += 1.; + /// assert_eq!(*x.item_mut(), 8.); + /// + /// ``` + #[inline] + pub fn item_mut(&mut self) -> &mut T + where + T: Unit, + { + &mut self.data.num } } @@ -229,8 +249,4 @@ mod tests { <()>::new().unwrap(); } - - #[cfg(feature = "lazy")] - #[test] - fn test_num_device_lazy() {} } diff --git a/src/cache/owned_cache.rs b/src/cache/owned_cache.rs index ed09ffab..36eb6024 100644 --- a/src/cache/owned_cache.rs +++ b/src/cache/owned_cache.rs @@ -129,4 +129,42 @@ mod tests { } assert_eq!(cache.nodes.len(), 1); } + + #[cfg(feature = "cpu")] + #[test] + #[should_panic] + fn test_cache_with_diffrent_length_return() { + use crate::{Buffer, Cursor, Retriever, Base}; + + let dev = CPU::>::new(); + + for i in dev.range(10) { + if i == 4 { + // has assert inside, therefore, this line leads to a crash due tue mismatiching lengths + let buf: Buffer = dev.retrieve(5, ()).unwrap(); + assert_eq!(buf.len, 5); + } else { + let _x: Buffer = dev.retrieve(3, ()).unwrap(); + } + } + } + + #[cfg(feature = "cpu")] + #[test] + fn test_cache_with_cursor_range_overlap() { + use crate::{Buffer, Cursor, Retriever, Base}; + + let dev = CPU::>::new(); + + for _i in dev.range(10) { + let _x: Buffer = dev.retrieve(3, ()).unwrap(); + } + + assert_eq!(dev.cursor(), 1); + + for _i in dev.range(1..7) { + let _x: Buffer = dev.retrieve(4, ()).unwrap(); + } + assert_eq!(dev.cursor(), 2); + } } diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 1d923740..a35fb2df 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -176,7 +176,7 @@ impl crate::LazyRun for CPU {} impl> crate::Run for CPU { #[inline] - unsafe fn run(&self) -> crate::Result<()> { + fn run(&self) -> crate::Result<()> { self.modules.run(self) } } diff --git a/src/devices/cpu/cpu_ptr.rs b/src/devices/cpu/cpu_ptr.rs index efed2306..e2a41801 100644 --- a/src/devices/cpu/cpu_ptr.rs +++ b/src/devices/cpu/cpu_ptr.rs @@ -99,11 +99,7 @@ impl CPUPtr { /// ``` #[inline] pub unsafe fn from_ptr(ptr: *mut T, len: usize, flag: AllocFlag) -> CPUPtr { - CPUPtr { - ptr, - len, - flag, - } + CPUPtr { ptr, len, flag } } pub fn from_vec(mut vec: Vec) -> CPUPtr { // CPUPtr only knows about the length, not the capacity -> deallocation happens with length, which may be less than the capacity @@ -246,7 +242,7 @@ impl ShallowCopy for CPUPtr { pub struct DeallocWithLayout { ptr: core::mem::ManuallyDrop>, - layout: Layout, + layout: Layout, } impl DeallocWithLayout { @@ -255,14 +251,18 @@ impl DeallocWithLayout { let (_, layout) = ptr.current_memory()?; let ptr = core::mem::ManuallyDrop::new(ptr); Some(Self { - ptr: core::mem::ManuallyDrop::new(CPUPtr { ptr: ptr.ptr as *mut u8, len: ptr.len, flag: ptr.flag }), - layout + ptr: core::mem::ManuallyDrop::new(CPUPtr { + ptr: ptr.ptr as *mut u8, + len: ptr.len, + flag: ptr.flag, + }), + layout, }) } - #[inline] + #[inline] pub fn layout(&self) -> &Layout { - &self.layout + &self.layout } } @@ -452,9 +452,7 @@ mod tests { #[test] fn test_dealloc_with_layout() { let data = CPUPtr::::new_initialized(10, crate::flag::AllocFlag::None); - let dealloc = unsafe { - DeallocWithLayout::new(data).unwrap() - }; + let dealloc = unsafe { DeallocWithLayout::new(data).unwrap() }; assert_eq!(dealloc.layout().size(), 40) } } diff --git a/src/devices/cuda/cuda_ptr.rs b/src/devices/cuda/cuda_ptr.rs index 8a5e4ddf..c5891183 100644 --- a/src/devices/cuda/cuda_ptr.rs +++ b/src/devices/cuda/cuda_ptr.rs @@ -1,8 +1,6 @@ -use core::{marker::PhantomData, ptr::null_mut}; - -use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy}; - use super::api::{cu_read, cufree, cumalloc, CudaResult}; +use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy}; +use core::marker::PhantomData; /// The pointer used for `CUDA` [`Buffer`](crate::Buffer)s #[derive(Debug, PartialEq, Eq)] diff --git a/src/devices/cuda/lazy.rs b/src/devices/cuda/lazy.rs index e61650be..623f0e15 100644 --- a/src/devices/cuda/lazy.rs +++ b/src/devices/cuda/lazy.rs @@ -50,7 +50,7 @@ impl crate::LazyRun for CUDA { impl> crate::Run for CUDA { #[inline] - unsafe fn run(&self) -> crate::Result<()> { + fn run(&self) -> crate::Result<()> { self.modules.run(self) } } @@ -74,7 +74,7 @@ impl crate::LazySetup for CUDA { #[cfg(test)] mod tests { use crate::{ - AddOperation, ApplyFunction, AsNoId, Base, Buffer, Combiner, Device, HasId, Lazy, Retrieve, + AddOperation, ApplyFunction, Base, Buffer, Combiner, Device, HasId, Lazy, Retrieve, Retriever, Run, CUDA, }; @@ -133,7 +133,7 @@ mod tests { assert_eq!(lhs.read(), vec![1, 2, 3, 4, 5, 6]); assert_eq!(rhs.read(), vec![1, 2, 3, 4, 5, 6]); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.read(), vec![2, 4, 6, 8, 10, 12]); assert_eq!(rhs.read(), vec![3, 6, 9, 12, 15, 18]); @@ -187,7 +187,7 @@ mod tests { lhs.write(&[1, 2, 3, 4, 5, 6]); rhs.write(&[1, 2, 3, 4, 5, 6]); device.mem_transfer_stream.sync().unwrap(); - unsafe { device.run().unwrap() }; + device.run().unwrap(); } assert_eq!(out.read(), vec![2, 4, 6, 8, 10, 12]); @@ -294,7 +294,7 @@ mod tests { let out = cuda_ew(&device, &lhs, &rhs, ew_src("add", '+'), "add"); let out2 = cuda_ew(&device, &out, &rhs, ew_src("add", '+'), "add"); - let _ = unsafe { device.run() }; + let _ = device.run(); assert_eq!(out.replace().read(), [2, 4, 6, 8, 10, 12]); assert_eq!(out2.replace().read(), [3, 6, 9, 12, 15, 18]); @@ -310,6 +310,6 @@ mod tests { let out = device.apply_fn(&out, |x| x.cos()); let _out = device.apply_fn(&out, |x| x.ln()); - let _ = unsafe { device.run() }; + let _ = device.run(); } } diff --git a/src/devices/cuda/ops.rs b/src/devices/cuda/ops.rs index be0378f2..bf6431f8 100644 --- a/src/devices/cuda/ops.rs +++ b/src/devices/cuda/ops.rs @@ -4,9 +4,9 @@ use crate::{ bounds_to_range, cuda::api::{cu_read_async, CUstreamCaptureStatus}, op_hint::unary, - pass_down_add_operation, pass_down_exec_now, AddOperation, ApplyFunction, - Buffer, CDatatype, ClearBuf, CopySlice, OnDropBuffer, Read, Resolve, Retrieve, Retriever, - SetOpHint, Shape, ToCLSource, ToMarker, UnaryGrad, Unit, WriteBuf, ZeroGrad, CUDA, + pass_down_add_operation, pass_down_exec_now, AddOperation, ApplyFunction, Buffer, CDatatype, + ClearBuf, CopySlice, OnDropBuffer, Read, Resolve, Retrieve, Retriever, SetOpHint, Shape, + ToCLSource, ToMarker, UnaryGrad, Unit, WriteBuf, ZeroGrad, CUDA, }; use super::{ @@ -291,7 +291,7 @@ mod tests { assert_eq!(lhs_grad.read(), vec![1, 2, 3, 4, 5, 6]); - unsafe { device.run().unwrap() } + device.run().unwrap(); assert_eq!(lhs_grad.read(), vec![4, 6, 8, 10, 12, 14]); } diff --git a/src/devices/nnapi/nnapi_device.rs b/src/devices/nnapi/nnapi_device.rs index fe6d5dfe..7186e070 100644 --- a/src/devices/nnapi/nnapi_device.rs +++ b/src/devices/nnapi/nnapi_device.rs @@ -1,6 +1,7 @@ use crate::{ - cpu::{CPUPtr, DeallocWithLayout}, Alloc, AsOperandCode, Base, Buffer, Device, HasId, IsShapeIndep, Lazy, LazyRun, - LazySetup, Module, OnDropBuffer, PtrType, Retrieve, Retriever, Setup, Shape, Unit, WrappedData, + cpu::{CPUPtr, DeallocWithLayout}, + Alloc, AsOperandCode, Base, Buffer, Device, HasId, IsShapeIndep, Lazy, LazyRun, LazySetup, + Module, OnDropBuffer, PtrType, Retrieve, Retriever, Setup, Shape, Unit, WrappedData, }; use super::NnapiPtr; @@ -215,11 +216,7 @@ impl NnapiDevice { fn set_input_ptrs(&self, run: &mut Execution) -> crate::Result<()> { for (idx, (_id, input_ptr)) in self.input_ptrs.borrow().iter().enumerate() { unsafe { - run.set_input_raw( - idx as i32, - input_ptr.ptr.cast(), - input_ptr.layout().size() - ) + run.set_input_raw(idx as i32, input_ptr.ptr.cast(), input_ptr.layout().size()) }? } Ok(()) diff --git a/src/devices/opencl/cl_device.rs b/src/devices/opencl/cl_device.rs index 87dc0fd0..8a789b01 100644 --- a/src/devices/opencl/cl_device.rs +++ b/src/devices/opencl/cl_device.rs @@ -305,7 +305,7 @@ pass_down_use_gpu_or_cpu!(OpenCL); impl> crate::Run for OpenCL { #[inline] - unsafe fn run(&self) -> crate::Result<()> { + fn run(&self) -> crate::Result<()> { self.modules.run(self) } } diff --git a/src/devices/opencl/ops.rs b/src/devices/opencl/ops.rs index bb56d881..3e374159 100644 --- a/src/devices/opencl/ops.rs +++ b/src/devices/opencl/ops.rs @@ -413,7 +413,7 @@ mod test { let out = Buffer::from((&device, [1, 1, 1, 1, 1, 1])); device.add_unary_grad(&lhs, &mut lhs_grad, &out, |x| x.mul(2).add(1)); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(lhs_grad.read(), [4, 7, 10, 13, 16, 19]); } diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index a563544a..01ef5e8a 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -1,7 +1,4 @@ -use core::{ - ops::{Deref, DerefMut}, - ptr::null_mut, -}; +use core::ops::{Deref, DerefMut}; use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy}; diff --git a/src/devices/untyped/mod.rs b/src/devices/untyped/mod.rs index a0a75243..e3f730db 100644 --- a/src/devices/untyped/mod.rs +++ b/src/devices/untyped/mod.rs @@ -27,7 +27,7 @@ impl<'a, T: Unit, S: crate::Shape> Buffer<'a, T, Untyped, S> { // Safety: An Untyped device buffer is shape and data type independent! // type Base = UntypedData; <- missing // type Data = UntypedData; <| - // storage type is also matched + // storage type is also matched Some(unsafe { std::mem::transmute(self) }) } @@ -47,9 +47,7 @@ impl<'a, T: Unit, S: crate::Shape> Buffer<'a, T, Untyped, S> { { self.data.matches_storage_type::().ok()?; - Some(unsafe { - &*(self as *const Self as *const Buffer) - }) + Some(unsafe { &*(self as *const Self as *const Buffer) }) } #[inline] @@ -59,9 +57,7 @@ impl<'a, T: Unit, S: crate::Shape> Buffer<'a, T, Untyped, S> { NS: crate::Shape, { self.data.matches_storage_type::().ok()?; - Some(unsafe { - &mut *(self as *mut Self as *mut Buffer) - }) + Some(unsafe { &mut *(self as *mut Self as *mut Buffer) }) } #[inline] @@ -69,9 +65,7 @@ impl<'a, T: Unit, S: crate::Shape> Buffer<'a, T, Untyped, S> { // Safety: An Untyped device buffer is shape and data type independent! // type Base = UntypedData; <- missing // type Data = UntypedData; <| - unsafe { - &*(self as *const Self as *const Buffer<(), Untyped, ()>) - } + unsafe { &*(self as *const Self as *const Buffer<(), Untyped, ()>) } } #[inline] @@ -79,9 +73,7 @@ impl<'a, T: Unit, S: crate::Shape> Buffer<'a, T, Untyped, S> { // Safety: An Untyped device buffer is shape and data type independent! // type Base = UntypedData; <- missing // type Data = UntypedData; <| - unsafe { - &mut *(self as *mut Self as *mut Buffer<(), Untyped, ()>) - } + unsafe { &mut *(self as *mut Self as *mut Buffer<(), Untyped, ()>) } } #[inline] diff --git a/src/devices/vulkan/ops.rs b/src/devices/vulkan/ops.rs index 851598d2..87eb35e5 100644 --- a/src/devices/vulkan/ops.rs +++ b/src/devices/vulkan/ops.rs @@ -2,7 +2,7 @@ use core::fmt::Debug; use crate::{ cpu_stack_ops::clear_slice, pass_down_add_operation, pass_down_exec_now, prelude::Number, - AddOperation, ApplyFunction, AsNoId, BufAsNoId, Buffer, CDatatype, ClearBuf, OnDropBuffer, + AddOperation, ApplyFunction, Buffer, CDatatype, ClearBuf, OnDropBuffer, Read, Resolve, Retrieve, Retriever, Shape, ToCLSource, ToMarker, ToWgslSource, UnaryGrad, Unit, UseGpuOrCpu, Vulkan, WriteBuf, ZeroGrad, }; diff --git a/src/devices/wgsl/ops.rs b/src/devices/wgsl/ops.rs index e3eb68a6..c47ca0d9 100644 --- a/src/devices/wgsl/ops.rs +++ b/src/devices/wgsl/ops.rs @@ -1,5 +1,5 @@ use crate::{ - op_hint::unary, AddOperation, Alloc, ApplyFunction, AsNoId, OnDropBuffer, Read, Retrieve, + op_hint::unary, AddOperation, Alloc, ApplyFunction, OnDropBuffer, Read, Retrieve, Retriever, SetOpHint, Shape, ToMarker, Unit, }; diff --git a/src/error.rs b/src/error.rs index 2db7d4a2..0515c1a1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,38 +1,28 @@ +/// A type alias for Box #[cfg(feature = "std")] -mod std_err { - /// A type alias for Box - pub type Error = Box; - - /// A trait for downcasting errors. - pub trait ErrorKind { - /// Downcasts the error to the specified type. - fn kind(&self) -> Option<&E>; - } - - impl ErrorKind for Error { - fn kind(&self) -> Option<&E> { - self.downcast_ref::() - } - } +pub type Error = Box; +#[cfg(not(feature = "std"))] +pub type Error = DeviceError; - impl std::error::Error for crate::DeviceError {} +/// A trait for downcasting errors. +pub trait ErrorKind { + /// Downcasts the error to the specified type. + fn kind(&self) -> Option<&E>; } -#[cfg(feature = "std")] -pub use std_err::*; +impl ErrorKind for Error { + fn kind(&self) -> Option<&E> { + #[cfg(feature = "std")] + let err = self; -/// A type alias for `Result`. -#[cfg(feature = "std")] -pub type Result = core::result::Result; - -/// An error for no-std. -#[cfg(not(feature = "std"))] -#[derive(Debug)] -pub struct Error {} + #[cfg(not(feature = "std"))] + let err: &dyn core::error::Error = self; + err.downcast_ref::() + } +} /// A type alias for `Result`. -#[cfg(not(feature = "std"))] -pub type Result = core::result::Result; +pub type Result = core::result::Result; /// 'generic' device errors that can occur on any device. #[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -65,6 +55,8 @@ pub enum DeviceError { ShapeLengthMismatch, } +impl core::error::Error for crate::DeviceError {} + impl DeviceError { /// Returns a string slice containing the error message. pub fn as_str(&self) -> &'static str { diff --git a/src/exec_on_cpu.rs b/src/exec_on_cpu.rs index a44dc0db..1bc37cc6 100644 --- a/src/exec_on_cpu.rs +++ b/src/exec_on_cpu.rs @@ -435,6 +435,7 @@ mod tests { } pub trait AddEw: crate::Device { + #[allow(dead_code)] fn add(&self, lhs: &crate::Buffer, rhs: &crate::Buffer) -> crate::Buffer; } diff --git a/src/features.rs b/src/features.rs index 43515d4c..41b7743a 100644 --- a/src/features.rs +++ b/src/features.rs @@ -120,11 +120,7 @@ impl RunModule for crate::Base {} pub trait Run { /// Executes a lazy graph. - /// - /// # Safety - /// The lifetime of captured references is ignored! - /// Specific style of writing operations should prevent UB altogether (at the cost of convenience). - unsafe fn run(&self) -> crate::Result<()>; + fn run(&self) -> crate::Result<()>; } pub trait HasModules { diff --git a/src/lib.rs b/src/lib.rs index c052acc4..adc24f05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,9 +56,6 @@ //! A lot more usage examples can be found in the [tests] and [examples] folder. //! //! [tests]: https://github.com/elftausend/custos/tree/main/tests -use core::ffi::c_void; - -//pub use libs::*; pub use buffer::*; pub use devices::*; diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index 7632865f..faceadbc 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -404,7 +404,7 @@ mod tests { Ok(()) }); - out.backward(); + out.backward().unwrap(); assert_eq!(&***buf.grad(), [5.; 10]); } @@ -443,7 +443,7 @@ mod tests { Ok(()) }); - out.backward(); + out.backward().unwrap(); assert_eq!(lhs.try_grad().unwrap().as_slice(), [4, 5, 6, 7]); } @@ -463,7 +463,7 @@ mod tests { panic!("should not be called"); }); - out.backward(); + out.backward().unwrap(); assert!(lhs.try_grad().is_none()); @@ -475,7 +475,7 @@ mod tests { Ok(()) }); - out.backward(); + out.backward().unwrap(); assert_eq!(lhs.try_grad().unwrap().as_slice(), [4, 5, 6, 7]); } diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs index 2566045d..5dea6261 100644 --- a/src/modules/autograd/tape.rs +++ b/src/modules/autograd/tape.rs @@ -78,7 +78,7 @@ impl<'t> Tape<'t> { buf: &Buffer<'a, T, D, S>, seed: &[T], buffers: Option<&mut Buffers>>, - ) -> crate::Result<()> + ) -> crate::Result<()> where T: Unit + 'static, D: Alloc + ZeroGrad + WriteBuf + GradActions + AddOperation + 'static, diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 347420ff..dbb4d25d 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -1,6 +1,6 @@ use core::marker::PhantomData; -use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, Shape, WrappedData}; +use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedData}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct ReqGradWrapper { diff --git a/src/modules/graph/opt_graph/optimize.rs b/src/modules/graph/opt_graph/optimize.rs index 74e38eb3..d74b7b3e 100644 --- a/src/modules/graph/opt_graph/optimize.rs +++ b/src/modules/graph/opt_graph/optimize.rs @@ -518,7 +518,7 @@ mod tests { let out: Buffer = device.retrieve::<2>(1000, (&mul, &mul_b)).unwrap(); device.optimize_mem_graph(device, None).unwrap(); - let _err = unsafe { device.run() }; + let _err = device.run(); assert_eq!(squared.replace().id(), mul.replace().id()); assert_eq!(squared.replace().id(), out.replace().id()); @@ -659,7 +659,7 @@ mod tests { let out: Buffer = device.retrieve::<2>(1000, (&mul, &mul_b)).unwrap(); device.optimize_mem_graph(&device, None).unwrap(); - let _err = unsafe { device.run() }; + let _err = device.run(); assert_eq!(squared.replace().id(), mul.replace().id()); assert_eq!(squared.replace().id(), out.replace().id()); diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index ce832b35..7e278b54 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -150,7 +150,7 @@ impl ExecNow for Lazy<'_, Mods, T> { impl Lazy<'_, Mods, T> { #[inline] - pub unsafe fn call_lazily(&self, device: &D) -> crate::Result<()> { + pub fn call_lazily(&self, device: &D) -> crate::Result<()> { self.graph .borrow_mut() .call_lazily(device, &mut self.buffers.borrow_mut())?; @@ -178,7 +178,7 @@ impl, D: LazyRun + Device + 'static> RunModule for Lazy #[inline] fn run(&self, device: &D) -> crate::Result<()> { self.alloc_later(device); - unsafe { self.call_lazily::(device)? }; + self.call_lazily::(device)?; device.run()?; self.modules.run(device) } @@ -529,12 +529,10 @@ mod tests { // assert_eq!(out.read(), &[0; 10]); -- should not work device.modules.alloc_later(&device); - unsafe { - device - .modules - .call_lazily::>>(&device) - .unwrap() - } + device + .modules + .call_lazily::>>(&device) + .unwrap(); // assert_eq!(out.read(), &[3; 10]); -- should work assert_eq!(out.replace().read(), &[3; 10]); drop(buf); @@ -583,7 +581,7 @@ mod tests { } if DeviceError::InvalidLazyBuf - != unsafe { *device.run().err().unwrap().downcast().unwrap() } + != *device.run().err().unwrap().downcast().unwrap() { panic!("") } @@ -599,7 +597,7 @@ mod tests { let out = device.apply_fn(&buf, |x| x.add(3)); // assert_eq!(out.read(), &[0; 10]); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.replace().read(), &[3; 10]); } @@ -615,7 +613,7 @@ mod tests { device.modules.alloc_later(&device); assert_eq!(out.replace().read(), &[0; 10]); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.replace().read(), &[3; 10]); } @@ -629,7 +627,7 @@ mod tests { let buf = Buffer::::new(&device, 10); let out = device.apply_fn(&buf, |x| x.add(3)); - unsafe { device.run().unwrap() } + device.run().unwrap(); assert_eq!(out.replace().read(), &[3; 10]); } #[test] @@ -650,7 +648,7 @@ mod tests { let out = device.add(&lhs, &rhs); // assert_eq!(out.read(), &[0; 10]); - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(lhs.replace().read(), &[3; 10]); assert_eq!(out.replace().read(), [4, 5, 6, 7, 8, 9, 10, 11, 12, 13]) @@ -708,7 +706,7 @@ mod tests { device.exec_now(&device, 1..).unwrap(); assert_eq!(out.replace().as_slice(), [2, 4, 6, 8]) } - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.replace().as_slice(), [0; 4]) } @@ -741,7 +739,7 @@ mod tests { device.exec_last_n(&device, 1).unwrap(); assert_eq!(out.replace().as_slice(), [2, 4, 6, 8]) } - unsafe { device.run().unwrap() }; + device.run().unwrap(); assert_eq!(out.replace().as_slice(), [0; 4]) } @@ -777,7 +775,7 @@ mod tests { .unwrap(); } - if unsafe { device.run() }.is_ok() { + if device.run().is_ok() { panic!() } } diff --git a/src/modules/lazy/lazy_graph.rs b/src/modules/lazy/lazy_graph.rs index 1d07b038..7718dce8 100644 --- a/src/modules/lazy/lazy_graph.rs +++ b/src/modules/lazy/lazy_graph.rs @@ -1,5 +1,6 @@ use crate::{ - bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, BoxedShallowCopy, Buffers, Device, Downcast, Id, OperationFn, Parents + bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp, BoxedShallowCopy, + Buffers, Device, Downcast, Id, OperationFn, Parents, }; use core::ops::RangeBounds; use std::collections::HashSet; @@ -67,13 +68,13 @@ impl LazyGraph { self.operations.len() } - pub unsafe fn call_lazily( + pub fn call_lazily( &mut self, device: &D, buffers: &mut Buffers, ) -> crate::Result<()> { - for args in self.iter_with(device, buffers) { - args?; + for res in self.iter_with(device, buffers) { + res?; } Ok(()) } @@ -193,7 +194,7 @@ mod tests { println!("args: {args:?}"); Ok(()) }); - unsafe { graph.call_lazily(&device, &mut buffers).unwrap() }; + graph.call_lazily(&device, &mut buffers).unwrap(); }; // let x = DEVICE2.get().unwrap(); // println!("{:?}", x.modules.cache.borrow().nodes); @@ -231,7 +232,7 @@ mod tests { }; // todo!() - unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } + graph.call_lazily(&device, &mut outs_unordered).unwrap(); } #[test] @@ -257,7 +258,7 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } + graph.call_lazily(&device, &mut outs_unordered).unwrap(); } #[test] @@ -290,7 +291,7 @@ mod tests { }); } - unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } + graph.call_lazily(&device, &mut outs_unordered).unwrap(); } #[test] fn test_lazy_op_args_no_out_but_use() { @@ -317,7 +318,7 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } + graph.call_lazily(&device, &mut outs_unordered).unwrap(); } #[test] @@ -351,6 +352,6 @@ mod tests { Ok(()) }); - unsafe { graph.call_lazily(&device, &mut outs_unordered).unwrap() } + graph.call_lazily(&device, &mut outs_unordered).unwrap(); } } diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 3f4f72f6..ebf2e313 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -3,7 +3,7 @@ use core::{ ops::{Deref, DerefMut}, }; -use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, Shape, WrappedData}; +use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, WrappedData}; #[derive(Debug, Default)] pub struct LazyWrapper { diff --git a/src/op_hint.rs b/src/op_hint.rs index c4c29e14..a8806824 100644 --- a/src/op_hint.rs +++ b/src/op_hint.rs @@ -130,7 +130,7 @@ mod tests { dev.optimize_mem_graph(&dev, None).unwrap(); dev.unary_fusing(&dev, None).unwrap(); - unsafe { dev.run().unwrap() }; + dev.run().unwrap(); for (buf, out) in buf.iter().zip(_out.replace().iter()) { assert!((*out - buf.sin().cos().ln()).abs() < 0.001); @@ -153,7 +153,7 @@ mod tests { dev.optimize_mem_graph(&dev, None).unwrap(); dev.unary_fusing(&dev, None).unwrap(); - unsafe { dev.run().unwrap() }; + dev.run().unwrap(); for (buf, out) in buf.read().iter().zip(_out.replace().read().iter()) { assert!((*out - buf.sin().cos().ln()).abs() < 0.001); @@ -176,7 +176,7 @@ mod tests { dev.optimize_mem_graph(&dev, None).unwrap(); dev.unary_fusing(&dev, None).unwrap(); - let _ = unsafe { dev.run() }; + let _ = dev.run(); for (buf, out) in buf.read().iter().zip(_out.replace().read().iter()) { assert!((*out - buf.sin().cos().ln()).abs() < 0.001); @@ -202,7 +202,7 @@ mod tests { dev.optimize_mem_graph(&dev, None).unwrap(); dev.unary_fusing(&dev, None).unwrap(); - unsafe { dev.run().unwrap() }; + dev.run().unwrap(); for (buf, out) in buf.iter().zip(_out.replace().iter()) { assert_eq!(*out, buf.sin().abs().ln()); @@ -293,11 +293,11 @@ mod tests { println!("unary fusing: {:?}", start.elapsed()); - unsafe { dev.run().unwrap() }; + dev.run().unwrap(); let start = Instant::now(); - unsafe { dev.run().unwrap() }; + dev.run().unwrap(); println!("perf automatic fusing: {:?}", start.elapsed()); diff --git a/src/unary.rs b/src/unary.rs index 087a62f4..c2682af6 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -292,7 +292,7 @@ mod tests { macro_rules! run_several_times { ($device:ident, $buf:ident, $out:ident) => { for i in 1..10 { - unsafe { $device.run() }.unwrap(); + $device.run().unwrap(); roughly_eq_slices( $out.replace().as_slice(), &[ diff --git a/tests/alloc.rs b/tests/alloc.rs index 9a008bbd..5f9f7d08 100644 --- a/tests/alloc.rs +++ b/tests/alloc.rs @@ -1,4 +1,3 @@ - #[cfg(feature = "wgpu")] use custos::prelude::*; #[cfg(feature = "wgpu")] diff --git a/tests/buffer.rs b/tests/buffer.rs index 4f0835cb..d7baa16c 100644 --- a/tests/buffer.rs +++ b/tests/buffer.rs @@ -1,6 +1,5 @@ use custos::prelude::*; - #[cfg(feature = "std")] pub fn read>(device: &D, buf: &Buffer) -> Vec where @@ -194,7 +193,6 @@ fn test_deviceless_buf() { Buffer::::deviceless(&device, 5) }; - for (idx, element) in buf.iter_mut().enumerate() { *element = idx as u8; }