diff --git a/Cargo.toml b/Cargo.toml index bae839cf0..1cdbbd8ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,11 @@ half = "2" num-complex = "0.4" num_enum = "0.7.2" num-traits = "0.2.18" +strum = { version = "0.26", features = ["derive"] } thiserror = "1.0.58" [dev-dependencies] +pretty_assertions = "1.4.0" [workspace] members = ["mlx-macros", "mlx-sys", "mlx-macros"] diff --git a/mlx-macros/src/lib.rs b/mlx-macros/src/lib.rs index 3fe8f311c..958ed9190 100644 --- a/mlx-macros/src/lib.rs +++ b/mlx-macros/src/lib.rs @@ -2,7 +2,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::punctuated::Punctuated; -use syn::{parse_macro_input, parse_quote, FnArg, ItemFn, Pat}; +use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat}; #[proc_macro_attribute] pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream { @@ -51,3 +51,50 @@ pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream { TokenStream::from(expanded) } + +#[proc_macro_derive(GenerateDtypeTestCases)] +pub fn generate_test_cases(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + + let tests = quote! { + /// MLX's rules for promoting two dtypes. + #[rustfmt::skip] + const TYPE_RULES: [[Dtype; 13]; 13] = [ + // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64 + [Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bool + [Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint8 + [Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint16 + [Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint32 + [Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint64 + [Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int8 + [Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int16 + [Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int32 + [Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int64 + [Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float16 + [Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float32 + [Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bfloat16 + [Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64 + ]; + + #[cfg(test)] + mod generated_tests { + use super::*; + use strum::IntoEnumIterator; + use pretty_assertions::assert_eq; + + #[test] + fn test_all_combinations() { + for a in #name::iter() { + for b in #name::iter() { + let result = a.promote_with(b); + let expected = TYPE_RULES[a as usize][b as usize]; + assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b)); + } + } + } + } + }; + + TokenStream::from(tests) +} diff --git a/src/array.rs b/src/array.rs index 2cca9f3b6..276d28815 100644 --- a/src/array.rs +++ b/src/array.rs @@ -1,9 +1,10 @@ use std::ffi::c_void; -use std::ops::Add; +use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; use half::{bf16, f16}; use mlx_sys::mlx_array; use num_complex::Complex; +use num_traits::Pow; use crate::{dtype::Dtype, error::AsSliceError, sealed::Sealed, StreamOrDevice}; @@ -137,6 +138,55 @@ pub struct Array { pub(crate) c_array: mlx_array, } +impl<'a> Add for &'a Array { + type Output = Array; + fn add(self, rhs: Self) -> Self::Output { + self.add_device(rhs, StreamOrDevice::default()) + } +} + +impl<'a> Sub for &'a Array { + type Output = Array; + fn sub(self, rhs: Self) -> Self::Output { + self.sub_device(rhs, StreamOrDevice::default()) + } +} + +impl<'a> Neg for &'a Array { + type Output = Array; + fn neg(self) -> Self::Output { + self.logical_not() + } +} + +impl<'a> Mul for &'a Array { + type Output = Array; + fn mul(self, rhs: Self) -> Self::Output { + self.mul_device(rhs, StreamOrDevice::default()) + } +} + +impl<'a> Div for &'a Array { + type Output = Array; + fn div(self, rhs: Self) -> Self::Output { + self.div_device(rhs, StreamOrDevice::default()) + } +} + +impl<'a> Pow<&'a Array> for &'a Array { + type Output = Array; + fn pow(self, rhs: &'a Array) -> Self::Output { + self.pow_device(rhs, StreamOrDevice::default()) + } +} + +impl<'a> Rem for &'a Array { + type Output = Array; + fn rem(self, rhs: Self) -> Self::Output { + self.rem_device(rhs, StreamOrDevice::default()) + } +} + impl std::fmt::Debug for Array { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let description = crate::utils::mlx_describe(self.c_array as *mut c_void) @@ -153,13 +203,6 @@ impl std::fmt::Display for Array { } } -impl<'a> Add for &'a Array { - type Output = Array; - fn add(self, rhs: Self) -> Self::Output { - self.add_device(rhs, StreamOrDevice::default()) - } -} - // TODO: Clone should probably NOT be implemented because the underlying pointer is atomically // reference counted but not guarded by a mutex. @@ -371,7 +414,10 @@ impl Array { } if self.dtype() != T::DTYPE { - return Err(AsSliceError::DtypeMismatch); + return Err(AsSliceError::DtypeMismatch { + expecting: T::DTYPE, + found: self.dtype(), + }); } Ok(unsafe { self.as_slice_unchecked() }) diff --git a/src/dtype.rs b/src/dtype.rs index 99f630299..5e64e73c6 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -1,6 +1,17 @@ +use mlx_macros::GenerateDtypeTestCases; +use strum::EnumIter; + /// Array element type #[derive( - Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, + Debug, + Clone, + Copy, + PartialEq, + Eq, + num_enum::IntoPrimitive, + num_enum::TryFromPrimitive, + EnumIter, + GenerateDtypeTestCases, )] #[repr(u32)] pub enum Dtype { @@ -18,3 +29,112 @@ pub enum Dtype { Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16, Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64, } + +impl Dtype { + pub fn is_complex(&self) -> bool { + matches!(self, Dtype::Complex64) + } + + pub fn is_floating(&self) -> bool { + matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16) + } + + pub fn is_inexact(&self) -> bool { + matches!( + self, + Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16 + ) + } + + pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self { + a.promote_with(b) + } +} + +pub(crate) trait TypePromotion { + fn promote_with(self, other: Self) -> Self; +} + +impl TypePromotion for Dtype { + fn promote_with(self, other: Self) -> Self { + use crate::dtype::Dtype::*; + match (self, other) { + // Boolean promotions + (Bool, Bool) => Bool, + (Bool, _) | (_, Bool) => { + if self == Bool { + other + } else { + self + } + } + + // Uint8 promotions + (Uint8, Uint8) => Uint8, + (Uint8, Uint16) | (Uint16, Uint8) => Uint16, + (Uint8, Uint32) | (Uint32, Uint8) => Uint32, + (Uint8, Uint64) | (Uint64, Uint8) => Uint64, + (Uint8, Int8) | (Int8, Uint8) => Int16, + (Uint8, Int16) | (Int16, Uint8) => Int16, + (Uint8, Int32) | (Int32, Uint8) => Int32, + (Uint8, Int64) | (Int64, Uint8) => Int64, + + // Uint16 promotions + (Uint16, Uint16) => Uint16, + (Uint16, Uint32) | (Uint32, Uint16) => Uint32, + (Uint16, Uint64) | (Uint64, Uint16) => Uint64, + (Uint16, Int8) | (Int8, Uint16) => Int32, + (Uint16, Int16) | (Int16, Uint16) => Int32, + (Uint16, Int32) | (Int32, Uint16) => Int32, + (Uint16, Int64) | (Int64, Uint16) => Int64, + + // Uint32 promotions + (Uint32, Uint32) => Uint32, + (Uint32, Uint64) | (Uint64, Uint32) => Uint64, + (Uint32, Int8) | (Int8, Uint32) => Int64, + (Uint32, Int16) | (Int16, Uint32) => Int64, + (Uint32, Int32) | (Int32, Uint32) => Int64, + (Uint32, Int64) | (Int64, Uint32) => Int64, + + // Uint64 promotions + (Uint64, Uint64) => Uint64, + (Uint64, Int8) | (Int8, Uint64) => Float32, + (Uint64, Int16) | (Int16, Uint64) => Float32, + (Uint64, Int32) | (Int32, Uint64) => Float32, + (Uint64, Int64) | (Int64, Uint64) => Float32, + + // Int8 promotions + (Int8, Int8) => Int8, + (Int8, Int16) | (Int16, Int8) => Int16, + (Int8, Int32) | (Int32, Int8) => Int32, + (Int8, Int64) | (Int64, Int8) => Int64, + + // Int16 promotions + (Int16, Int16) => Int16, + (Int16, Int32) | (Int32, Int16) => Int32, + (Int16, Int64) | (Int64, Int16) => Int64, + + // Int32 promotions + (Int32, Int32) => Int32, + (Int32, Int64) | (Int64, Int32) => Int64, + + // Int64 promotions + (Int64, Int64) => Int64, + + // Float16 promotions + (Float16, Bfloat16) | (Bfloat16, Float16) => Float32, + + // Complex type + (Complex64, _) | (_, Complex64) => Complex64, + + // Float32 promotions + (Float32, _) | (_, Float32) => Float32, + + // Float16 promotions + (Float16, _) | (_, Float16) => Float16, + + // Bfloat16 promotions + (Bfloat16, _) | (_, Bfloat16) => Bfloat16, + } + } +} diff --git a/src/error.rs b/src/error.rs index ff35f7781..c94dec1f0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,41 @@ +use crate::Dtype; use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, PartialEq, Debug)] +pub enum MlxError { + #[error("data store error: {0}")] + DataStore(#[from] DataStoreError), + #[error("operation error: {0}")] + Operation(#[from] OperationError), + #[error("as slice error: {0}")] + AsSlice(#[from] AsSliceError), +} + +#[derive(Error, PartialEq, Debug)] pub enum DataStoreError { #[error("negative dimension: {0}")] NegativeDimensions(String), #[error("negative integer: {0}")] NegativeInteger(String), + + #[error("broadcast error")] + BroadcastError, +} + +#[derive(Error, PartialEq, Debug)] +pub enum OperationError { + #[error("operation not supported: {0}")] + NotSupported(String), + + #[error("wrong input: {0}")] + WrongInput(String), + + #[error("wrong dimensions: {0}")] + WrongDimensions(String), + + #[error("axis out of bounds: {0}")] + AxisOutOfBounds(String), } /// Error associated with `Array::try_as_slice()` @@ -19,6 +48,6 @@ pub enum AsSliceError { Null, /// The output dtype does not match the data type of the array. - #[error("Desired output dtype does not match the data type of the array.")] - DtypeMismatch, + #[error("dtype mismatch: expected {expecting:?}, found {found:?}")] + DtypeMismatch { expecting: Dtype, found: Dtype }, } diff --git a/src/ops/arithmetic.rs b/src/ops/arithmetic.rs new file mode 100644 index 000000000..f93cebdd5 --- /dev/null +++ b/src/ops/arithmetic.rs @@ -0,0 +1,1701 @@ +use crate::array::Array; +use crate::error::{DataStoreError, MlxError, OperationError}; +use crate::stream::StreamOrDevice; +use crate::utils::is_broadcastable; +use crate::Dtype; +use mlx_macros::default_device; + +impl Array { + /// Element-wise absolute value. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let array = Array::from_slice(&[1i32, 2, -3, -4, -5], &[5]); + /// let mut result = array.abs(); + /// + /// result.eval(); + /// let data: &[i32] = result.as_slice(); + /// // data == [1, 2, 3, 4, 5] + /// ``` + /// + /// # Params + /// + /// - stream: stream or device to evaluate on + #[default_device] + pub fn abs_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_abs(self.c_array, stream.as_ptr())) } + } + + /// Element-wise addition. + /// + /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.add_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [5.0, 7.0, 9.0] + /// ``` + /// + /// # Params + /// + /// - other: array to add + /// - stream: stream or device to evaluate on + pub fn add_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_add_device(other, stream).unwrap() + } + + /// Element-wise addition without checking broadcastability. + /// + /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.add_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [5.0, 7.0, 9.0] + /// ``` + /// + /// # Params + /// + /// - other: array to add + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays have the same shape. + #[default_device] + pub unsafe fn add_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_add( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise addition returning an error if arrays are not broadcastable. + /// + /// Add two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.add_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [5.0, 7.0, 9.0] + /// ``` + /// + /// # Params + /// + /// - other: array to add + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_add_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.add_device_unchecked(other, stream) }) + } + + /// Element-wise subtraction. + /// + /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.sub_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [-3.0, -3.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - other: array to subtract + /// - stream: stream or device to evaluate on + pub fn sub_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_sub_device(other, stream).unwrap() + } + + /// Element-wise subtraction without checking broadcastability. + /// + /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.sub_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [-3.0, -3.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - other: array to subtract + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays have the same shape. + #[default_device] + pub unsafe fn sub_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_subtract( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise subtraction returning an error if arrays are not broadcastable. + /// + /// Subtract two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.sub_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [-3.0, -3.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - other: array to subtract + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_sub_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.sub_device_unchecked(other, stream) }) + } + + /// Unary element-wise negation. + /// + /// Negate the values in the array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let mut b = a.neg(); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [-1.0, -2.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - stream: stream or device to evaluate on + #[default_device] + pub fn neg_device(&self, stream: StreamOrDevice) -> Array { + self.try_neg_device(stream).unwrap() + } + + /// Unary element-wise negation without validating the array type. + /// + /// Negate the values in the array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let mut b = unsafe { a.neg_unchecked() }; + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [-1.0, -2.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the array is not a boolean array. + #[default_device] + pub unsafe fn neg_device_unchecked(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_negative(self.c_array, stream.as_ptr())) } + } + + /// Unary element-wise negation. + /// + /// Negate the values in the array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let mut b = a.try_neg().unwrap(); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [-1.0, -2.0, -3.0] + /// ``` + /// + /// # Params + /// + /// - stream: stream or device to evaluate on + /// + /// # Errors + /// + /// Returns an error if the array is of type bool. + #[default_device] + pub fn try_neg_device(&self, stream: StreamOrDevice) -> Result { + if self.dtype() == Dtype::Bool { + return Err(OperationError::NotSupported( + "Negation not supported for bool, use logical_not() instead".to_string(), + )); + } + + Ok(unsafe { self.neg_device_unchecked(stream) }) + } + + /// Unary element-wise logical not. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a: Array = false.into(); + /// let mut b = a.logical_not_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[bool] = b.as_slice(); + /// // b_data == [true] + /// ``` + /// + /// # Params + /// + /// - stream: stream or device to evaluate on + #[default_device] + pub fn logical_not_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_logical_not(self.c_array, stream.as_ptr())) } + } + + /// Element-wise multiplication. + /// + /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.mul_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [4.0, 10.0, 18.0] + /// ``` + pub fn mul_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_mul_device(other, stream).unwrap() + } + + /// Element-wise multiplication without checking broadcastability. + /// + /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.mul_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [4.0, 10.0, 18.0] + /// ``` + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays are broadcastable. + #[default_device] + pub unsafe fn mul_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_multiply( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise multiplication returning an error if arrays are not broadcastable. + /// + /// Multiply two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.mul_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [4.0, 10.0, 18.0] + /// ``` + #[default_device] + pub fn try_mul_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.mul_device_unchecked(other, stream) }) + } + + /// Element-wise division. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.div_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + pub fn div_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_div_device(other, stream).unwrap() + } + + /// Element-wise division without checking broadcastability. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.div_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays are broadcastable. + #[default_device] + pub unsafe fn div_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_divide( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise division returning an error if arrays are not broadcastable. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.div_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_div_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.div_device_unchecked(other, stream) }) + } + + /// Element-wise power operation. + /// + /// Raise the elements of the array to the power of the elements of another array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]); + /// let mut c = a.pow_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 8.0, 81.0] + /// ``` + /// + /// # Params + /// + /// - other: array to raise to the power of + /// - stream: stream or device to evaluate on + pub fn pow_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_pow_device(other, stream).unwrap() + } + + /// Element-wise power operation without checking broadcastability if arrays are different shapes. + /// + /// Raise the elements of the array to the power of the elements of another array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]); + /// let mut c = a.pow_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 8.0, 81.0] + /// ``` + /// + /// # Params + /// + /// - other: array to raise to the power of + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays are broadcastable if they have different shapes. + #[default_device] + pub unsafe fn pow_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_power( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise power operation returning an error if arrays are not broadcastable if they have different shapes. + /// + /// Raise the elements of the array to the power of the elements of another array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]); + /// let mut c = a.pow_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 8.0, 81.0] + /// ``` + /// + /// # Params + /// + /// - other: array to raise to the power of + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_pow_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if self.shape() != other.shape() { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + } + + Ok(unsafe { self.pow_device_unchecked(other, stream) }) + } + + /// Element-wise remainder of division. + /// + /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]); + /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]); + /// let mut c = a.rem_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 3.0, 2.0] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + pub fn rem_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_rem_device(other, stream).unwrap() + } + + /// Element-wise remainder of division without checking broadcastability. + /// + /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]); + /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]); + /// let mut c = a.rem_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 3.0, 2.0] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the arrays are broadcastable. + #[default_device] + pub unsafe fn rem_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_remainder( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise remainder of division returning an error if arrays are not broadcastable. + /// + /// Computes the remainder of dividing `lhs` with `rhs` with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]); + /// let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]); + /// let mut c = a.rem_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [1.0, 3.0, 2.0] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_rem_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.rem_device_unchecked(other, stream) }) + } + + /// Element-wise square root + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]); + /// let mut b = a.sqrt_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [1.0, 2.0, 3.0] + /// ``` + #[default_device] + pub fn sqrt_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_sqrt(self.c_array, stream.as_ptr())) } + } + + /// Element-wise cosine + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + /// let mut b = a.cos_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [1.0, 0.54030234, -0.41614687] + /// ``` + #[default_device] + pub fn cos_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_cos(self.c_array, stream.as_ptr())) } + } + + /// Element-wise exponential. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// + /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + /// let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + /// let mut b = a.exp_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [1.0, 2.7182817, 7.389056] + /// ``` + #[default_device] + pub fn exp_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_exp(self.c_array, stream.as_ptr())) } + } + + /// Element-wise floor. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]); + /// let mut b = a.floor_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 1.0, 2.0] + /// ``` + #[default_device] + pub fn floor_device(&self, stream: StreamOrDevice) -> Array { + self.try_floor_device(stream).unwrap() + } + + /// Element-wise floor without checking the array type. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]); + /// let mut b = a.floor_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 1.0, 2.0] + /// ``` + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the array is not of type complex64. + #[default_device] + pub unsafe fn floor_device_unchecked(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_floor(self.c_array, stream.as_ptr())) } + } + + /// Element-wise floor returning an error if the array is of type complex64. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]); + /// let mut b = a.floor_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 1.0, 2.0] + /// ``` + #[default_device] + pub fn try_floor_device(&self, stream: StreamOrDevice) -> Result { + if self.dtype() == Dtype::Complex64 { + return Err(OperationError::NotSupported( + "Floor not supported for complex64".to_string(), + )); + } + + Ok(unsafe { self.floor_device_unchecked(stream) }) + } + + /// Element-wise integer division. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// If either array is a floating point type then it is equivalent to calling [floor()] after `/`. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.floor_divide_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + #[default_device] + pub fn floor_divide_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_floor_divide_device(other, stream).unwrap() + } + + /// Element-wise integer division without checking the array type or for broadcastability. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// If either array is a floating point type then it is equivalent to calling [floor()] after `/`. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.floor_divide_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check array types or that the arrays are broadcastable. + #[default_device] + pub unsafe fn floor_divide_device_unchecked( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_floor_divide( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise integer division returning an error if arrays are not broadcastable. + /// + /// Divide two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// If either array is a floating point type then it is equivalent to calling [floor()] after `/`. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + /// let mut c = a.floor_divide_device(&b, Default::default()); + /// + /// c.eval(); + /// let c_data: &[f32] = c.as_slice(); + /// // c_data == [0.25, 0.4, 0.5] + /// ``` + /// + /// # Params + /// + /// - other: array to divide + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_floor_divide_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if self.dtype() == Dtype::Complex64 { + return Err(OperationError::NotSupported( + "Floor is not supported for complex64".to_string(), + ) + .into()); + } + + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError.into()); + } + + Ok(unsafe { self.floor_divide_device_unchecked(other, stream) }) + } + + /// Element-wise natural logarithm. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let mut b = a.log_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 0.6931472, 1.0986123] + /// ``` + #[default_device] + pub fn log_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_log(self.c_array, stream.as_ptr())) } + } + + /// Element-wise base-2 logarithm. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]); + /// let mut b = a.log2_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 1.0, 2.0, 3.0] + /// ``` + #[default_device] + pub fn log2_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_log2(self.c_array, stream.as_ptr())) } + } + + /// Element-wise base-10 logarithm. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]); + /// let mut b = a.log10_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.0, 1.0, 2.0] + /// ``` + #[default_device] + pub fn log10_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_log10(self.c_array, stream.as_ptr())) } + } + + /// Element-wise natural log of one plus the array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + /// let mut b = a.log1p_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [0.6931472, 1.0986123, 1.3862944] + /// ``` + #[default_device] + pub fn log1p_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_log1p(self.c_array, stream.as_ptr())) } + } + + /// Matrix multiplication. + /// + /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports + /// broadcasting for arrays with more than two dimensions. + /// + /// - If the first array is 1-D then a 1 is prepended to its shape to make it + /// a matrix. Similarly, if the second array is 1-D then a 1 is appended to its + /// shape to make it a matrix. In either case the singleton dimension is removed + /// from the result. + /// - A batched matrix multiplication is performed if the arrays have more than + /// 2 dimensions. The matrix dimensions for the matrix product are the last + /// two dimensions of each input. + /// - All but the last two dimensions of each input are broadcast with one another using + /// standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]); + /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]); + /// + /// // produces a [2, 3] result + /// let mut c = a.matmul_device(&b, Default::default()); + /// ``` + /// + /// # Params + /// + /// - other: array to multiply + /// - stream: stream or device to evaluate on + #[default_device] + pub fn matmul_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_matmul_device(other, stream).unwrap() + } + + /// Matrix multiplication without validating inputs. + /// + /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports + /// broadcasting for arrays with more than two dimensions. + /// + /// - If the first array is 1-D then a 1 is prepended to its shape to make it + /// a matrix. Similarly, if the second array is 1-D then a 1 is appended to its + /// shape to make it a matrix. In either case the singleton dimension is removed + /// from the result. + /// - A batched matrix multiplication is performed if the arrays have more than + /// 2 dimensions. The matrix dimensions for the matrix product are the last + /// two dimensions of each input. + /// - All but the last two dimensions of each input are broadcast with one another using + /// standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]); + /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]); + /// + /// // produces a [2, 3] result + /// let mut c = a.matmul_device(&b, Default::default()); + /// ``` + /// + /// # Params + /// + /// - other: array to multiply + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check that the inputs are valid for matrix multiplication. + #[default_device] + pub unsafe fn matmul_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_matmul( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Matrix multiplication returning an error if inputs are not valid. + /// + /// Perform the (possibly batched) matrix multiplication of two arrays. This function supports + /// broadcasting for arrays with more than two dimensions. + /// + /// - If the first array is 1-D then a 1 is prepended to its shape to make it + /// a matrix. Similarly, if the second array is 1-D then a 1 is appended to its + /// shape to make it a matrix. In either case the singleton dimension is removed + /// from the result. + /// - A batched matrix multiplication is performed if the arrays have more than + /// 2 dimensions. The matrix dimensions for the matrix product are the last + /// two dimensions of each input. + /// - All but the last two dimensions of each input are broadcast with one another using + /// standard [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]); + /// let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]); + /// + /// // produces a [2, 3] result + /// let mut c = a.matmul_device(&b, Default::default()); + /// ``` + /// + /// # Params + /// + /// - other: array to multiply + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_matmul_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if self.ndim() == 0 || other.ndim() == 0 { + return Err(OperationError::WrongInput( + "Got 0 dimension input. Inputs must have at least one dimension.".to_string(), + ) + .into()); + } + + // get last dimension of first input and second to last dimension of second input + let a_last_dim: i32 = if self.ndim() == 1 { + let new_shape = [1, self.size() as i32]; + new_shape[new_shape.len() - 1] + } else { + self.shape()[self.shape().len() - 1] + }; + + let b_semi_last_dim = if other.ndim() == 1 { + let new_shape = [other.size() as i32, 1]; + new_shape[new_shape.len() - 2] + } else { + other.shape()[other.shape().len() - 2] + }; + + if a_last_dim != b_semi_last_dim { + return Err(OperationError::WrongDimensions( + format!("Last dimension of first input with shape {:?} must match second to last dimension of second input with shape {:?}", + self.shape(), other.shape()) + ) + .into()); + } + + let result_type = Dtype::from_promoting_types(self.dtype(), other.dtype()); + + if !result_type.is_floating() { + return Err(OperationError::WrongInput( + format!("Only real floating point types are supported but {:?} and {:?} where provided, which is not a real floating point type", + self.dtype(), other.dtype()) + ) + .into()); + } + + Ok(unsafe { self.matmul_device_unchecked(other, stream) }) + } + + /// Element-wise reciprocal. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]); + /// let mut b = a.reciprocal_device(Default::default()); + /// + /// b.eval(); + /// let b_data: &[f32] = b.as_slice(); + /// // b_data == [1.0, 0.5, 0.25] + /// ``` + #[default_device] + pub fn reciprocal_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_reciprocal(self.c_array, stream.as_ptr())) } + } + + /// Round to the given number of decimals. + /// + /// # Params + /// + /// - decimals: number of decimals to round to - default is 0 if not provided + /// - stream: stream or device to evaluate on + #[default_device] + pub fn round_device(&self, decimals: impl Into>, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_round( + self.c_array, + decimals.into().unwrap_or(0), + stream.as_ptr(), + )) + } + } + + /// Element-wise reciprocal and square root. + #[default_device] + pub fn rsqrt_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_rsqrt(self.c_array, stream.as_ptr())) } + } + + /// Element-wise sine. + #[default_device] + pub fn sin_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_sin(self.c_array, stream.as_ptr())) } + } + + /// Element-wise square. + #[default_device] + pub fn square_device(&self, stream: StreamOrDevice) -> Array { + unsafe { Array::from_ptr(mlx_sys::mlx_square(self.c_array, stream.as_ptr())) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::complex64; + use num_traits::Pow; + use pretty_assertions::assert_eq; + + #[test] + fn test_abs() { + let data = [1i32, 2, -3, -4, -5]; + let array = Array::from_slice(&data, &[5]); + let mut result = array.abs(); + + result.eval(); + let data: &[i32] = result.as_slice(); + assert_eq!(data, [1, 2, 3, 4, 5]); + + // test that previous array is not modified and valid + let data: &[i32] = array.as_slice(); + assert_eq!(data, [1, 2, -3, -4, -5]); + } + + #[test] + fn test_add() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + + let mut c = &a + &b; + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[5.0, 7.0, 9.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[4.0, 5.0, 6.0]); + } + + #[test] + fn test_add_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0], &[2]); + let c = a.try_add(&b); + assert!(c.is_err()); + } + + #[test] + fn test_sub() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + + let mut c = &a - &b; + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[-3.0, -3.0, -3.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[4.0, 5.0, 6.0]); + } + + #[test] + fn test_sub_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0], &[2]); + let c = a.try_sub(&b); + assert!(c.is_err()); + } + + #[test] + fn test_neg() { + let a = Array::from_slice::(&[1.0, 2.0, 3.0], &[3]); + let mut b = a.neg(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[-1.0, -2.0, -3.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_neg_bool() { + let a = Array::from_slice(&[true, false, true], &[3]); + let b = a.try_neg(); + assert!(b.is_err()); + } + + #[test] + fn test_logical_not() { + let a: Array = false.into(); + let mut b = a.logical_not(); + + b.eval(); + let b_data: &[bool] = b.as_slice(); + assert_eq!(b_data, [true]); + } + + #[test] + fn test_mul() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + + let mut c = &a * &b; + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[4.0, 10.0, 18.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[4.0, 5.0, 6.0]); + } + + #[test] + fn test_mul_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0], &[2]); + let c = a.try_mul(&b); + assert!(c.is_err()); + } + + #[test] + fn test_div() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + + let mut c = &a / &b; + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[0.25, 0.4, 0.5]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[4.0, 5.0, 6.0]); + } + + #[test] + fn test_div_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0], &[2]); + let c = a.try_div(&b); + assert!(c.is_err()); + } + + #[test] + fn test_pow() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]); + + let mut c = a.pow(&b); + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[1.0, 8.0, 81.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[2.0, 3.0, 4.0]); + } + + #[test] + fn test_pow_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[2.0, 3.0], &[2]); + let c = a.try_pow(&b); + assert!(c.is_err()); + } + + #[test] + fn test_rem() { + let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]); + let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]); + + let mut c = &a % &b; + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[1.0, 3.0, 2.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[10.0, 11.0, 12.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[3.0, 4.0, 5.0]); + } + + #[test] + fn test_rem_invalid_broadcast() { + let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]); + let b = Array::from_slice(&[3.0, 4.0], &[2]); + let c = a.try_rem(&b); + assert!(c.is_err()); + } + + #[test] + fn test_sqrt() { + let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]); + let mut b = a.sqrt(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 2.0, 3.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 4.0, 9.0]); + } + + #[test] + fn test_cos() { + let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + let mut b = a.cos(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 0.54030234, -0.41614687]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[0.0, 1.0, 2.0]); + } + + #[test] + fn test_exp() { + let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + let mut b = a.exp(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 2.7182817, 7.389056]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[0.0, 1.0, 2.0]); + } + + #[test] + fn test_floor() { + let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]); + let mut b = a.floor(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.0, 1.0, 2.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[0.1, 1.9, 2.5]); + } + + #[test] + fn test_floor_complex64() { + let val = complex64::new(1.0, 2.0); + let a = Array::from_complex(val); + let b = a.try_floor_device(Default::default()); + assert!(b.is_err()); + } + + #[test] + fn test_floor_divide() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + + let mut c = a.floor_divide(&b); + c.eval(); + + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[0.0, 0.0, 0.0]); + + // check a and b are not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[4.0, 5.0, 6.0]); + } + + #[test] + fn test_floor_divide_complex64() { + let val = complex64::new(1.0, 2.0); + let a = Array::from_complex(val); + let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); + let c = a.try_floor_divide_device(&b, Default::default()); + assert!(c.is_err()); + } + + #[test] + fn test_floor_divide_invalid_broadcast() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[4.0, 5.0], &[2]); + let c = a.try_floor_divide_device(&b, Default::default()); + assert!(c.is_err()); + } + + #[test] + fn test_log() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let mut b = a.log(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_log2() { + let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]); + let mut b = a.log2(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]); + } + + #[test] + fn test_log10() { + let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]); + let mut b = a.log10(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.0, 1.0, 2.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 10.0, 100.0]); + } + + #[test] + fn test_log1p() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let mut b = a.log1p(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + } + + #[test] + fn test_matmul() { + let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]); + let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]); + + let mut c = a.matmul(&b); + c.eval(); + + assert_eq!(c.shape(), &[2, 3]); + let c_data: &[f32] = c.as_slice(); + assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, &[1, 2, 3, 4]); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]); + } + + #[test] + fn test_matmul_ndim_zero() { + let a: Array = 1.0.into(); + let b = Array::from_slice::(&[1], &[1]); + let c = a.try_matmul(&b); + assert!(c.is_err()); + } + + #[test] + fn test_matmul_ndim_one() { + let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]); + let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]); + let c = a.try_matmul(&b); + assert!(c.is_ok()); + } + + #[test] + fn test_matmul_dim_mismatch() { + let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]); + let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]); + let c = a.try_matmul(&b); + assert!(c.is_err()); + } + + #[test] + fn test_matmul_non_float_output_type() { + let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]); + let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]); + + let c = a.try_matmul(&b); + assert!(c.is_err()); + } + + #[test] + fn test_reciprocal() { + let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]); + let mut b = a.reciprocal(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 0.5, 0.25]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 4.0]); + } + + #[test] + fn test_round() { + let a = Array::from_slice(&[1.1, 2.9, 3.5], &[3]); + let mut b = a.round(None); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 3.0, 4.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.1, 2.9, 3.5]); + } + + #[test] + fn test_rsqrt() { + let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]); + let mut b = a.rsqrt(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 0.70710677, 0.5]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 4.0]); + } + + #[test] + fn test_sin() { + let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]); + let mut b = a.sin(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[0.0, 0.841471, 0.9092974]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[0.0, 1.0, 2.0]); + } + + #[test] + fn test_square() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let mut b = a.square(); + b.eval(); + + let b_data: &[f32] = b.as_slice(); + assert_eq!(b_data, &[1.0, 4.0, 9.0]); + + // check a is not modified + let a_data: &[f32] = a.as_slice(); + assert_eq!(a_data, &[1.0, 2.0, 3.0]); + } +} diff --git a/src/ops/array.rs b/src/ops/array.rs deleted file mode 100644 index 80919e46c..000000000 --- a/src/ops/array.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::array::Array; -use crate::stream::StreamOrDevice; -use mlx_macros::default_device; - -impl Array { - /// Element-wise absolute value. - /// - /// # Params - /// - /// - stream: stream or device to evaluate on - #[default_device] - pub fn abs_device(&self, stream: StreamOrDevice) -> Array { - let ctx = stream.as_ptr(); - - unsafe { Array::from_ptr(mlx_sys::mlx_abs(self.c_array, ctx)) } - } - - /// Element-wise addition. - /// - /// Add two arrays with . - pub fn add_device(&self, other: &Array, stream: StreamOrDevice) -> Array { - unsafe { - Array::from_ptr(mlx_sys::mlx_add( - self.c_array, - other.c_array, - stream.as_ptr(), - )) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_abs() { - let data = [1i32, 2, -3, -4, -5]; - let array = Array::from_slice(&data, &[5]); - let mut result = array.abs(); - - result.eval(); - let data: &[i32] = result.as_slice(); - assert_eq!(data, [1, 2, 3, 4, 5]); - - // test that previous array is not modified and valid - let data: &[i32] = array.as_slice(); - assert_eq!(data, [1, 2, -3, -4, -5]); - } - - #[test] - fn test_add_device() { - let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); - let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]); - - let mut c = &a + &b; - c.eval(); - - let c_data: &[f32] = c.as_slice(); - assert_eq!(c_data, &[5.0, 7.0, 9.0]); - - // check a and b are not modified - let a_data: &[f32] = a.as_slice(); - assert_eq!(a_data, &[1.0, 2.0, 3.0]); - - let b_data: &[f32] = b.as_slice(); - assert_eq!(b_data, &[4.0, 5.0, 6.0]); - } -} diff --git a/src/ops/factory.rs b/src/ops/factory.rs index afafe2402..779f5af6d 100644 --- a/src/ops/factory.rs +++ b/src/ops/factory.rs @@ -263,7 +263,7 @@ impl Array { /// Construct an array with the given value. /// /// Constructs an array of size `shape` filled with `values`. If `values` - /// is an [Array] it must be to the given `shape`. + /// is an [Array] it must be [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting) to the given `shape`. /// /// # Example /// @@ -290,7 +290,7 @@ impl Array { /// Construct an array with the given value without validating shape. /// /// Constructs an array of size `shape` filled with `values`. If `values` - /// is an [Array] it must be to the given `shape`. + /// is an [Array] it must be [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting) to the given `shape`. /// /// # Example /// @@ -329,7 +329,7 @@ impl Array { /// Construct an array with the given value returning an error if shape is invalid. /// /// Constructs an array of size `shape` filled with `values`. If `values` - /// is an [Array] it must be to the given `shape`. + /// is an [Array] it must be [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting) to the given `shape`. /// /// # Example /// diff --git a/src/ops/logical.rs b/src/ops/logical.rs new file mode 100644 index 000000000..c44751af8 --- /dev/null +++ b/src/ops/logical.rs @@ -0,0 +1,1343 @@ +use crate::array::Array; +use crate::error::DataStoreError; +use crate::stream::StreamOrDevice; +use crate::utils::is_broadcastable; +use mlx_macros::default_device; + +impl Array { + /// Element-wise equality. + /// + /// Equality comparison on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.eq(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn eq_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_eq_device(other, stream).unwrap() + } + + /// Element-wise equality without broadcasting checks. + /// + /// Equality comparison on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.eq_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn eq_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_equal( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise equality returning an error if the arrays are not broadcastable. + /// + /// Equality comparison on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_eq(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_eq_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.eq_device_unchecked(other, stream) }) + } + + /// Element-wise less than or equal. + /// + /// Less than or equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.le(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn le_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_le_device(other, stream).unwrap() + } + + /// Element-wise less than or equal without broadcasting checks. + /// + /// Less than or equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.le_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn le_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_less_equal( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise less than or equal returning an error if the arrays are not broadcastable. + /// + /// Less than or equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_le(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_le_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.le_device_unchecked(other, stream) }) + } + + /// Element-wise greater than or equal. + /// + /// Greater than or equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.ge(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn ge_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_ge_device(other, stream).unwrap() + } + + /// Element-wise greater than or equal without broadcasting checks. + /// + /// Greater than or equal on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.ge_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn ge_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_greater_equal( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise greater than or equal returning an error if the arrays are not broadcastable. + /// + /// Greater than or equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_ge(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_ge_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.ge_device_unchecked(other, stream) }) + } + + /// Element-wise not equal. + /// + /// Not equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.ne(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn ne_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_ne_device(other, stream).unwrap() + } + + /// Element-wise not equal without broadcasting checks. + /// + /// Not equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.ne_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn ne_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_not_equal( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise not equal returning an error if the arrays are not broadcastable. + /// + /// Not equal on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_ne(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_ne_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.ne_device_unchecked(other, stream) }) + } + + /// Element-wise less than. + /// + /// Less than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.lt(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn lt_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_lt_device(other, stream).unwrap() + } + + /// Element-wise less than without broadcasting checks. + /// + /// Less than on two arrays with + /// [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.lt_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn lt_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_less( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise less than returning an error if the arrays are not broadcastable. + /// + /// Less than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_lt(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_lt_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.lt_device_unchecked(other, stream) }) + } + + /// Element-wise greater than. + /// + /// Greater than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.gt(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn gt_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_gt_device(other, stream).unwrap() + } + + /// Element-wise greater than without broadcasting checks. + /// + /// Greater than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = unsafe { a.gt_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn gt_device_unchecked(&self, other: &Array, stream: StreamOrDevice) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_greater( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise greater than returning an error if the arrays are not broadcastable. + /// + /// Greater than on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[1, 2, 3], &[3]); + /// let b = Array::from_slice(&[1, 2, 3], &[3]); + /// let mut c = a.try_gt(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [false, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_gt_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.gt_device_unchecked(other, stream) }) + } + + /// Element-wise logical and. + /// + /// Logical and on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = a.logical_and(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn logical_and_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_logical_and_device(other, stream).unwrap() + } + + /// Element-wise logical and without broadcasting checks. + /// + /// Logical and on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = unsafe { a.logical_and_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn logical_and_device_unchecked( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_logical_and( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise logical and returning an error if the arrays are not broadcastable. + /// + /// Logical and on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = a.try_logical_and(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, false, false] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_logical_and_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.logical_and_device_unchecked(other, stream) }) + } + + /// Element-wise logical or. + /// + /// Logical or on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = a.logical_or(&b); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn logical_or_device(&self, other: &Array, stream: StreamOrDevice) -> Array { + self.try_logical_or_device(other, stream).unwrap() + } + + /// Element-wise logical or without broadcasting checks. + /// + /// Logical or on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = unsafe { a.logical_or_unchecked(&b) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn logical_or_device_unchecked( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_logical_or( + self.c_array, + other.c_array, + stream.as_ptr(), + )) + } + } + + /// Element-wise logical or returning an error if the arrays are not broadcastable. + /// + /// Logical or on two arrays with [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[true, false, true], &[3]); + /// let b = Array::from_slice(&[true, true, false], &[3]); + /// let mut c = a.try_logical_or(&b).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true, true, true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_logical_or_device( + &self, + other: &Array, + stream: StreamOrDevice, + ) -> Result { + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.logical_or_device_unchecked(other, stream) }) + } + + /// Approximate comparison of two arrays. + /// + /// The arrays are considered equal if: + /// + /// ```text + /// all(abs(a - b) <= (atol + rtol * abs(b))) + /// ``` + /// + /// # Example + /// + /// ```rust + /// use num_traits::Pow; + /// use mlx::Array; + /// let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); + /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]).pow(&(0.5.into())); + /// let mut c = a.all_close(&b, None, None, None); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - rtol: relative tolerance = defaults to 1e-5 when None + /// - atol: absolute tolerance - defaults to 1e-8 when None + /// - equal_nan: whether to consider NaNs equal -- default is false when None + /// - stream: stream or device to evaluate on + #[default_device] + pub fn all_close_device( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Array { + self.try_all_close_device(other, rtol, atol, equal_nan, stream) + .unwrap() + } + + /// Approximate comparison of two arrays without validating inputs. + /// + /// The arrays are considered equal if: + /// + /// ```text + /// all(abs(a - b) <= (atol + rtol * abs(b))) + /// ``` + /// + /// # Example + /// + /// ```rust + /// use num_traits::Pow; + /// use mlx::Array; + /// let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); + /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]).pow(&(0.5.into())); + /// let mut c = unsafe { a.all_close_unchecked(&b, None, None, None) }; + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - rtol: relative tolerance = defaults to 1e-5 when None + /// - atol: absolute tolerance - defaults to 1e-8 when None + /// - equal_nan: whether to consider NaNs equal -- default is false when None + /// - stream: stream or device to evaluate on + /// + /// # Safety + /// + /// This function is unsafe because it does not validate inputs. + #[default_device] + pub unsafe fn all_close_device_unchecked( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_allclose( + self.c_array, + other.c_array, + rtol.into().unwrap_or(1e-5), + atol.into().unwrap_or(1e-8), + equal_nan.into().unwrap_or(false), + stream.as_ptr(), + )) + } + } + + /// Approximate comparison of two arrays returning an error if the inputs aren't valid. + /// + /// The arrays are considered equal if: + /// + /// ```text + /// all(abs(a - b) <= (atol + rtol * abs(b))) + /// ``` + /// + /// # Example + /// + /// ```rust + /// use num_traits::Pow; + /// use mlx::Array; + /// let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); + /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]).pow(&(0.5.into())); + /// let mut c = a.try_all_close(&b, None, None, None).unwrap(); + /// + /// c.eval(); + /// let c_data: &[bool] = c.as_slice(); + /// // c_data == [true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - rtol: relative tolerance = defaults to 1e-5 when None + /// - atol: absolute tolerance - defaults to 1e-8 when None + /// - equal_nan: whether to consider NaNs equal -- default is false when None + /// - stream: stream or device to evaluate on + #[default_device] + pub fn try_all_close_device( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Result { + let is_close = self.try_is_close_device(other, rtol, atol, equal_nan, stream.clone()); + is_close + .map(|is_close| is_close.all_device(None, None, stream)) + .map_err(|error| error) + } + + /// Returns a boolean array where two arrays are element-wise equal within a tolerance. + /// + /// Infinite values are considered equal if they have the same sign, NaN values are not equal unless + /// `equalNAN` is `true`. + /// + /// Two values are considered close if: + /// + /// ```text + /// abs(a - b) <= (atol + rtol * abs(b)) + /// ``` + /// + /// Unlike [self.array_eq] this function supports [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + #[default_device] + pub fn is_close_device( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Array { + self.try_is_close_device(other, rtol, atol, equal_nan, stream) + .unwrap() + } + + /// Returns a boolean array where two arrays are element-wise equal within a tolerance without broadcasting checks. + /// + /// Infinite values are considered equal if they have the same sign, NaN values are not equal unless + /// `equalNAN` is `true`. + /// + /// Two values are considered close if: + /// + /// ```text + /// abs(a - b) <= (atol + rtol * abs(b)) + /// ``` + /// + /// Unlike [self.array_eq] this function supports [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the arrays are broadcastable. + #[default_device] + pub unsafe fn is_close_device_unchecked( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_isclose( + self.c_array, + other.c_array, + rtol.into().unwrap_or(1e-5), + atol.into().unwrap_or(1e-8), + equal_nan.into().unwrap_or(false), + stream.as_ptr(), + )) + } + } + + /// Returns a boolean array where two arrays are element-wise equal within a tolerance returning an error if the arrays are not broadcastable. + /// + /// Infinite values are considered equal if they have the same sign, NaN values are not equal unless + /// `equalNAN` is `true`. + /// + /// Two values are considered close if: + /// + /// ```text + /// abs(a - b) <= (atol + rtol * abs(b)) + /// ``` + /// + /// Unlike [self.array_eq] this function supports [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting). + #[default_device] + pub fn try_is_close_device( + &self, + other: &Array, + rtol: impl Into>, + atol: impl Into>, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Result { + // represents atol and rtol being broadcasted to operate on other + if !is_broadcastable(&[], other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + if !is_broadcastable(self.shape(), other.shape()) { + return Err(DataStoreError::BroadcastError); + } + + Ok(unsafe { self.is_close_device_unchecked(other, rtol, atol, equal_nan, stream) }) + } + + /// Array equality check. + /// + /// Compare two arrays for equality. Returns `True` if and only if the arrays + /// have the same shape and their values are equal. The arrays need not have + /// the same type to be considered equal. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0, 1, 2, 3], &[4]); + /// let b = Array::from_slice(&[0., 1., 2., 3.], &[4]); + /// + /// let c = a.array_eq(&b, None); + /// // c == [true] + /// ``` + /// + /// # Params + /// + /// - other: array to compare + /// - equal_nan: whether to consider NaNs equal -- default is false when None + /// - stream: stream or device to evaluate on + #[default_device] + pub fn array_eq_device( + &self, + other: &Array, + equal_nan: impl Into>, + stream: StreamOrDevice, + ) -> Array { + unsafe { + Array::from_ptr(mlx_sys::mlx_array_equal( + self.c_array, + other.c_array, + equal_nan.into().unwrap_or(false), + stream.as_ptr(), + )) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use num_traits::Pow; + + #[test] + fn test_eq() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.eq(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, true, true]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 2, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_eq_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_eq(&b); + assert!(c.is_err()); + } + + #[test] + fn test_le() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.le(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, true, true]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 2, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_le_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_le(&b); + assert!(c.is_err()); + } + + #[test] + fn test_ge() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.ge(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, true, true]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 2, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_ge_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_ge(&b); + assert!(c.is_err()); + } + + #[test] + fn test_ne() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.ne(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [false, false, false]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 2, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_ne_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_ne(&b); + assert!(c.is_err()); + } + + #[test] + fn test_lt() { + let a = Array::from_slice(&[1, 0, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.lt(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [false, true, false]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 0, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_lt_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_lt(&b); + assert!(c.is_err()); + } + + #[test] + fn test_gt() { + let a = Array::from_slice(&[1, 4, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3], &[3]); + let mut c = a.gt(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [false, true, false]); + + // check a and b are not modified + let a_data: &[i32] = a.as_slice(); + assert_eq!(a_data, [1, 4, 3]); + + let b_data: &[i32] = b.as_slice(); + assert_eq!(b_data, [1, 2, 3]); + } + + #[test] + fn test_gt_invalid_broadcast() { + let a = Array::from_slice(&[1, 2, 3], &[3]); + let b = Array::from_slice(&[1, 2, 3, 4], &[4]); + let c = a.try_gt(&b); + assert!(c.is_err()); + } + + #[test] + fn test_logical_and() { + let a = Array::from_slice(&[true, false, true], &[3]); + let b = Array::from_slice(&[true, true, false], &[3]); + let mut c = a.logical_and(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, false, false]); + + // check a and b are not modified + let a_data: &[bool] = a.as_slice(); + assert_eq!(a_data, [true, false, true]); + + let b_data: &[bool] = b.as_slice(); + assert_eq!(b_data, [true, true, false]); + } + + #[test] + fn test_logical_and_invalid_broadcast() { + let a = Array::from_slice(&[true, false, true], &[3]); + let b = Array::from_slice(&[true, true, false, true], &[4]); + let c = a.try_logical_and(&b); + assert!(c.is_err()); + } + + #[test] + fn test_logical_or() { + let a = Array::from_slice(&[true, false, true], &[3]); + let b = Array::from_slice(&[true, true, false], &[3]); + let mut c = a.logical_or(&b); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, true, true]); + + // check a and b are not modified + let a_data: &[bool] = a.as_slice(); + assert_eq!(a_data, [true, false, true]); + + let b_data: &[bool] = b.as_slice(); + assert_eq!(b_data, [true, true, false]); + } + + #[test] + fn test_logical_or_invalid_broadcast() { + let a = Array::from_slice(&[true, false, true], &[3]); + let b = Array::from_slice(&[true, true, false, true], &[4]); + let c = a.try_logical_or(&b); + assert!(c.is_err()); + } + + #[test] + fn test_all_close() { + let a = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt(); + let b = Array::from_slice(&[0., 1., 2., 3.], &[4]).pow(&(0.5.into())); + let mut c = a.all_close(&b, 1e-5, None, None); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true]); + } + + #[test] + fn test_all_close_invalid_broadcast() { + let a = Array::from_slice(&[0., 1., 2., 3.], &[4]); + let b = Array::from_slice(&[0., 1., 2., 3., 4.], &[5]); + let c = a.try_all_close(&b, 1e-5, None, None); + assert!(c.is_err()); + } + + #[test] + fn test_is_close_false() { + let a = Array::from_slice(&[1., 2., 3.], &[3]); + let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]); + let mut c = a.is_close(&b, None, None, false); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [false, false, false]); + } + + #[test] + fn test_is_close_true() { + let a = Array::from_slice(&[1., 2., 3.], &[3]); + let b = Array::from_slice(&[1.1, 2.2, 3.3], &[3]); + let mut c = a.is_close(&b, 0.1, 0.2, true); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true, true, true]); + } + + #[test] + fn test_is_close_invalid_broadcast() { + let a = Array::from_slice(&[1., 2., 3.], &[3]); + let b = Array::from_slice(&[1.1, 2.2, 3.3, 4.4], &[4]); + let c = a.try_is_close(&b, None, None, false); + assert!(c.is_err()); + } + + #[test] + fn test_array_eq() { + let a = Array::from_slice(&[0, 1, 2, 3], &[4]); + let b = Array::from_slice(&[0., 1., 2., 3.], &[4]); + let mut c = a.array_eq(&b, None); + + c.eval(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true]); + } +} diff --git a/src/ops/mod.rs b/src/ops/mod.rs index bcb09a349..c0ac10a5d 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -1,2 +1,4 @@ -mod array; +mod arithmetic; mod factory; +mod logical; +mod reduction; diff --git a/src/ops/reduction.rs b/src/ops/reduction.rs new file mode 100644 index 000000000..8b9beb009 --- /dev/null +++ b/src/ops/reduction.rs @@ -0,0 +1,177 @@ +use crate::array::Array; +use crate::error::OperationError; +use crate::stream::StreamOrDevice; +use mlx_macros::default_device; + +impl Array { + /// An `and` reduction over the given axes. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]); + /// let mut b = a.all(&[0][..], None); + /// + /// b.eval(); + /// let results: &[bool] = b.as_slice(); + /// // results == [false, true, true, true] + /// ``` + /// + /// # Params + /// + /// - axes: The axes to reduce over + /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided + /// - stream: The stream to execute the operation on + #[default_device] + pub fn all_device<'a>( + &'a self, + axes: impl Into>, + keep_dims: impl Into>, + stream: StreamOrDevice, + ) -> Array { + self.try_all_device(axes, keep_dims, stream).unwrap() + } + + /// An `and` reduction over the given axes without validating axes are valid for the array. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]); + /// let mut b = unsafe { a.all_unchecked(&[0][..], None) }; + /// + /// b.eval(); + /// let results: &[bool] = b.as_slice(); + /// // results == [false, true, true, true] + /// ``` + /// + /// # Params + /// + /// - axes: The axes to reduce over + /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided + /// - stream: The stream to execute the operation on + /// + /// # Safety + /// + /// This function is unsafe because it does not validate that the axes are valid for the array. + #[default_device] + pub unsafe fn all_device_unchecked<'a>( + &'a self, + axes: impl Into>, + keep_dims: impl Into>, + stream: StreamOrDevice, + ) -> Array { + let axes = match axes.into() { + Some(axes) => axes.to_vec(), + None => { + let axes: Vec = (0..self.ndim() as i32).collect(); + axes + } + }; + + Array::from_ptr(mlx_sys::mlx_all_axes( + self.c_array, + axes.as_ptr(), + axes.len(), + keep_dims.into().unwrap_or(false), + stream.as_ptr(), + )) + } + + /// An `and` reduction over the given axes returning an error if the axes are invalid. + /// + /// # Example + /// + /// ```rust + /// use mlx::Array; + /// let a = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]); + /// let mut b = a.try_all(&[0][..], None).unwrap(); + /// + /// b.eval(); + /// let results: &[bool] = b.as_slice(); + /// // results == [false, true, true, true] + /// ``` + /// + /// # Params + /// + /// - axes: The axes to reduce over + /// - keep_dims: Whether to keep the reduced dimensions -- defaults to false if not provided + /// - stream: The stream to execute the operation on + #[default_device] + pub fn try_all_device<'a>( + &'a self, + axes: impl Into>, + keep_dims: impl Into>, + stream: StreamOrDevice, + ) -> Result { + let axes = match axes.into() { + Some(axes) => axes.to_vec(), + None => { + let axes: Vec = (0..self.ndim() as i32).collect(); + axes + } + }; + + let ndim = self.shape().len() as i32; + let mut axes_set = std::collections::HashSet::new(); + for axis in axes.clone() { + let ax = if axis < 0 { axis + ndim } else { axis }; + if ax < 0 || ax >= ndim { + return Err(OperationError::AxisOutOfBounds(format!( + "Invalid axis {} for array with {} dimensions", + axis, ndim + ))); + } + + axes_set.insert(ax); + } + + if axes_set.len() != axes.len() { + return Err(OperationError::WrongInput(format!( + "Duplicate axes in {:?}", + axes + ))); + } + + Ok(unsafe { + Array::from_ptr(mlx_sys::mlx_all_axes( + self.c_array, + axes.as_ptr(), + axes.len(), + keep_dims.into().unwrap_or(false), + stream.as_ptr(), + )) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all_axes() { + let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]); + let mut all = array.all(&[0][..], None); + + all.eval(); + let results: &[bool] = all.as_slice(); + assert_eq!(results, &[false, true, true, true]); + } + + #[test] + fn test_all_axes_out_of_bounds() { + let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[12]); + let result = array.try_all(&[1][..], None); + assert!(result.is_err()); + } + + #[test] + fn test_all_axes_duplicate_axes() { + let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]); + let result = array.try_all(&[0, 0][..], None); + assert!(result.is_err()); + } +} diff --git a/src/utils.rs b/src/utils.rs index 1503971e7..2d29cb0a6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,5 @@ +use crate::Array; + /// Helper method to get a string representation of an mlx object. pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option { let mlx_description = unsafe { mlx_sys::mlx_tostring(ptr) }; @@ -17,3 +19,159 @@ pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option { description } + +/// Helper method to check if two arrays are broadcastable. +/// +/// Uses the same broadcasting rules as numpy. +/// https://numpy.org/doc/1.20/user/theory.broadcasting.html +/// +/// "The size of the trailing axes for both arrays in an operation must +/// either be the same size or one of them must be one." +pub(crate) fn is_broadcastable(a: &[i32], b: &[i32]) -> bool { + a.iter() + .rev() + .zip(b.iter().rev()) + .all(|(a, b)| *a == 1 || *b == 1 || a == b) +} + +impl Array { + /// Helper method to check if an array can be reshaped to a given shape. + pub fn can_reshape_to(&self, shape: &[i32]) -> bool { + if self.shape() == shape { + return true; + } + + let mut size = 1; + let mut infer_idx: isize = -1; + for i in 0..shape.len() { + if shape[i] == -1 { + if infer_idx >= 0 { + return false; + } + + infer_idx = i as isize; + } else { + size *= shape[i]; + } + } + + if size > 0 { + let quotient = self.size() / size as usize; + if infer_idx >= 0 { + size *= quotient as i32; + } + } else if infer_idx >= 0 { + return false; + } + + // validate the reshaping is valid + if self.size() != size as usize { + return false; + } + + return true; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_broadcastable() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[2.0, 2.0, 2.0], &[3]); + assert!(is_broadcastable(a.shape(), b.shape())); + + let a = Array::from_slice( + &[ + 0.0, 0.0, 0.0, 10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0, + ], + &[4, 3], + ); + let b = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + assert!(is_broadcastable(a.shape(), b.shape())); + + let a = Array::from_slice( + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + ], + &[2, 2, 4], + ); + let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]); + assert!(is_broadcastable(a.shape(), b.shape())); + } + + #[test] + fn test_is_broadcastable_scalar() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b: Array = 2.0.into(); + assert!(is_broadcastable(a.shape(), b.shape())); + } + + #[test] + fn test_is_broadcastable_empty() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + assert!(is_broadcastable(&[], a.shape())); + } + + #[test] + fn test_not_broadcastable() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[2.0, 2.0, 2.0, 2.0], &[4]); + assert!(!is_broadcastable(a.shape(), b.shape())); + + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let b = Array::from_slice(&[2.0, 2.0], &[1, 2]); + assert!(!is_broadcastable(a.shape(), b.shape())); + } + + #[test] + fn test_can_reshape_to() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + assert!(a.can_reshape_to(&[3])); + assert!(a.can_reshape_to(&[1, 3])); + assert!(a.can_reshape_to(&[3, 1])); + assert!(a.can_reshape_to(&[1, 1, 3])); + assert!(a.can_reshape_to(&[1, 3, 1])); + assert!(a.can_reshape_to(&[3, 1, 1])); + assert!(a.can_reshape_to(&[1, 1, 1, 3])); + assert!(a.can_reshape_to(&[1, 1, 3, 1])); + assert!(a.can_reshape_to(&[1, 3, 1, 1])); + assert!(a.can_reshape_to(&[3, 1, 1, 1])); + assert!(a.can_reshape_to(&[1, 1, 1, 1, 3])); + assert!(a.can_reshape_to(&[1, 1, 1, 3, 1])); + assert!(a.can_reshape_to(&[1, 1, 3, 1, 1])); + assert!(a.can_reshape_to(&[1, 3, 1, 1, 1])); + assert!(a.can_reshape_to(&[3, 1, 1, 1, 1])); + assert!(a.can_reshape_to(&[1, 1, 1, 1, 1, 3])); + assert!(a.can_reshape_to(&[1, 1, 1, 1, 3, 1])); + assert!(a.can_reshape_to(&[1, 1, 1, 3, 1, 1])); + assert!(a.can_reshape_to(&[1, 1, 3, 1, 1, 1])); + assert!(a.can_reshape_to(&[1, 3, 1, 1, 1, 1])); + assert!(a.can_reshape_to(&[3, 1, 1, 1, 1, 1])); + } + + #[test] + fn test_reshape_negative_dim() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + assert!(a.can_reshape_to(&[1, -1])); + assert!(a.can_reshape_to(&[-1, 1])); + assert!(a.can_reshape_to(&[-1])); + assert!(a.can_reshape_to(&[1, -1, 1])); + assert!(a.can_reshape_to(&[-1, 1, 1])); + + assert!(!a.can_reshape_to(&[1, -2])); + } + + #[test] + fn test_cannot_reshape_to() { + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + assert!(!a.can_reshape_to(&[2])); + assert!(!a.can_reshape_to(&[2, 2])); + assert!(!a.can_reshape_to(&[2, 2, 2])); + assert!(!a.can_reshape_to(&[2, 2, 2, 2])); + assert!(!a.can_reshape_to(&[2, 2, 2, 2, 2])); + assert!(!a.can_reshape_to(&[2, 2, 2, 2, 2, 2])); + } +}