Skip to content

Commit

Permalink
feat(api): Arithmetic & Logical ops (#39)
Browse files Browse the repository at this point in the history
Co-authored-by: minghuaw <[email protected]>
  • Loading branch information
dcvz and minghuaw authored Apr 22, 2024
1 parent 5b7ceb5 commit 270a2e5
Show file tree
Hide file tree
Showing 12 changed files with 3,643 additions and 87 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ half = "2"
num-complex = "0.4"
num_enum = "0.7.2"
num-traits = "0.2.18"
strum = { version = "0.26", features = ["derive"] }
thiserror = "1.0.58"

[dev-dependencies]
pretty_assertions = "1.4.0"

[workspace]
members = ["mlx-macros", "mlx-sys", "mlx-macros"]
Expand Down
49 changes: 48 additions & 1 deletion mlx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, parse_quote, FnArg, ItemFn, Pat};
use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat};

#[proc_macro_attribute]
pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -51,3 +51,50 @@ pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {

TokenStream::from(expanded)
}

#[proc_macro_derive(GenerateDtypeTestCases)]
pub fn generate_test_cases(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;

let tests = quote! {
/// MLX's rules for promoting two dtypes.
#[rustfmt::skip]
const TYPE_RULES: [[Dtype; 13]; 13] = [
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
[Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bool
[Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint8
[Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint16
[Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint32
[Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint64
[Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int8
[Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int16
[Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int32
[Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int64
[Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float16
[Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float32
[Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bfloat16
[Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64
];

#[cfg(test)]
mod generated_tests {
use super::*;
use strum::IntoEnumIterator;
use pretty_assertions::assert_eq;

#[test]
fn test_all_combinations() {
for a in #name::iter() {
for b in #name::iter() {
let result = a.promote_with(b);
let expected = TYPE_RULES[a as usize][b as usize];
assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b));
}
}
}
}
};

TokenStream::from(tests)
}
64 changes: 55 additions & 9 deletions src/array.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::ffi::c_void;
use std::ops::Add;
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

use half::{bf16, f16};
use mlx_sys::mlx_array;
use num_complex::Complex;
use num_traits::Pow;

use crate::{dtype::Dtype, error::AsSliceError, sealed::Sealed, StreamOrDevice};

Expand Down Expand Up @@ -137,6 +138,55 @@ pub struct Array {
pub(crate) c_array: mlx_array,
}

impl<'a> Add for &'a Array {
type Output = Array;
fn add(self, rhs: Self) -> Self::Output {
self.add_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Sub for &'a Array {
type Output = Array;
fn sub(self, rhs: Self) -> Self::Output {
self.sub_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Neg for &'a Array {
type Output = Array;
fn neg(self) -> Self::Output {
self.logical_not()
}
}

impl<'a> Mul for &'a Array {
type Output = Array;
fn mul(self, rhs: Self) -> Self::Output {
self.mul_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Div for &'a Array {
type Output = Array;
fn div(self, rhs: Self) -> Self::Output {
self.div_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Pow<&'a Array> for &'a Array {
type Output = Array;
fn pow(self, rhs: &'a Array) -> Self::Output {
self.pow_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Rem for &'a Array {
type Output = Array;
fn rem(self, rhs: Self) -> Self::Output {
self.rem_device(rhs, StreamOrDevice::default())
}
}

impl std::fmt::Debug for Array {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let description = crate::utils::mlx_describe(self.c_array as *mut c_void)
Expand All @@ -153,13 +203,6 @@ impl std::fmt::Display for Array {
}
}

impl<'a> Add for &'a Array {
type Output = Array;
fn add(self, rhs: Self) -> Self::Output {
self.add_device(rhs, StreamOrDevice::default())
}
}

// TODO: Clone should probably NOT be implemented because the underlying pointer is atomically
// reference counted but not guarded by a mutex.

Expand Down Expand Up @@ -371,7 +414,10 @@ impl Array {
}

if self.dtype() != T::DTYPE {
return Err(AsSliceError::DtypeMismatch);
return Err(AsSliceError::DtypeMismatch {
expecting: T::DTYPE,
found: self.dtype(),
});
}

Ok(unsafe { self.as_slice_unchecked() })
Expand Down
122 changes: 121 additions & 1 deletion src/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use mlx_macros::GenerateDtypeTestCases;
use strum::EnumIter;

/// Array element type
#[derive(
Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive,
Debug,
Clone,
Copy,
PartialEq,
Eq,
num_enum::IntoPrimitive,
num_enum::TryFromPrimitive,
EnumIter,
GenerateDtypeTestCases,
)]
#[repr(u32)]
pub enum Dtype {
Expand All @@ -18,3 +29,112 @@ pub enum Dtype {
Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16,
Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64,
}

impl Dtype {
pub fn is_complex(&self) -> bool {
matches!(self, Dtype::Complex64)
}

pub fn is_floating(&self) -> bool {
matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
}

pub fn is_inexact(&self) -> bool {
matches!(
self,
Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
)
}

pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
a.promote_with(b)
}
}

pub(crate) trait TypePromotion {
fn promote_with(self, other: Self) -> Self;
}

impl TypePromotion for Dtype {
fn promote_with(self, other: Self) -> Self {
use crate::dtype::Dtype::*;
match (self, other) {
// Boolean promotions
(Bool, Bool) => Bool,
(Bool, _) | (_, Bool) => {
if self == Bool {
other
} else {
self
}
}

// Uint8 promotions
(Uint8, Uint8) => Uint8,
(Uint8, Uint16) | (Uint16, Uint8) => Uint16,
(Uint8, Uint32) | (Uint32, Uint8) => Uint32,
(Uint8, Uint64) | (Uint64, Uint8) => Uint64,
(Uint8, Int8) | (Int8, Uint8) => Int16,
(Uint8, Int16) | (Int16, Uint8) => Int16,
(Uint8, Int32) | (Int32, Uint8) => Int32,
(Uint8, Int64) | (Int64, Uint8) => Int64,

// Uint16 promotions
(Uint16, Uint16) => Uint16,
(Uint16, Uint32) | (Uint32, Uint16) => Uint32,
(Uint16, Uint64) | (Uint64, Uint16) => Uint64,
(Uint16, Int8) | (Int8, Uint16) => Int32,
(Uint16, Int16) | (Int16, Uint16) => Int32,
(Uint16, Int32) | (Int32, Uint16) => Int32,
(Uint16, Int64) | (Int64, Uint16) => Int64,

// Uint32 promotions
(Uint32, Uint32) => Uint32,
(Uint32, Uint64) | (Uint64, Uint32) => Uint64,
(Uint32, Int8) | (Int8, Uint32) => Int64,
(Uint32, Int16) | (Int16, Uint32) => Int64,
(Uint32, Int32) | (Int32, Uint32) => Int64,
(Uint32, Int64) | (Int64, Uint32) => Int64,

// Uint64 promotions
(Uint64, Uint64) => Uint64,
(Uint64, Int8) | (Int8, Uint64) => Float32,
(Uint64, Int16) | (Int16, Uint64) => Float32,
(Uint64, Int32) | (Int32, Uint64) => Float32,
(Uint64, Int64) | (Int64, Uint64) => Float32,

// Int8 promotions
(Int8, Int8) => Int8,
(Int8, Int16) | (Int16, Int8) => Int16,
(Int8, Int32) | (Int32, Int8) => Int32,
(Int8, Int64) | (Int64, Int8) => Int64,

// Int16 promotions
(Int16, Int16) => Int16,
(Int16, Int32) | (Int32, Int16) => Int32,
(Int16, Int64) | (Int64, Int16) => Int64,

// Int32 promotions
(Int32, Int32) => Int32,
(Int32, Int64) | (Int64, Int32) => Int64,

// Int64 promotions
(Int64, Int64) => Int64,

// Float16 promotions
(Float16, Bfloat16) | (Bfloat16, Float16) => Float32,

// Complex type
(Complex64, _) | (_, Complex64) => Complex64,

// Float32 promotions
(Float32, _) | (_, Float32) => Float32,

// Float16 promotions
(Float16, _) | (_, Float16) => Float16,

// Bfloat16 promotions
(Bfloat16, _) | (_, Bfloat16) => Bfloat16,
}
}
}
35 changes: 32 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
use crate::Dtype;
use thiserror::Error;

#[derive(Error, Debug)]
#[derive(Error, PartialEq, Debug)]
pub enum MlxError {
#[error("data store error: {0}")]
DataStore(#[from] DataStoreError),
#[error("operation error: {0}")]
Operation(#[from] OperationError),
#[error("as slice error: {0}")]
AsSlice(#[from] AsSliceError),
}

#[derive(Error, PartialEq, Debug)]
pub enum DataStoreError {
#[error("negative dimension: {0}")]
NegativeDimensions(String),

#[error("negative integer: {0}")]
NegativeInteger(String),

#[error("broadcast error")]
BroadcastError,
}

#[derive(Error, PartialEq, Debug)]
pub enum OperationError {
#[error("operation not supported: {0}")]
NotSupported(String),

#[error("wrong input: {0}")]
WrongInput(String),

#[error("wrong dimensions: {0}")]
WrongDimensions(String),

#[error("axis out of bounds: {0}")]
AxisOutOfBounds(String),
}

/// Error associated with `Array::try_as_slice()`
Expand All @@ -19,6 +48,6 @@ pub enum AsSliceError {
Null,

/// The output dtype does not match the data type of the array.
#[error("Desired output dtype does not match the data type of the array.")]
DtypeMismatch,
#[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
DtypeMismatch { expecting: Dtype, found: Dtype },
}
Loading

0 comments on commit 270a2e5

Please sign in to comment.