Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): Arithmetic & Logical ops #39

Merged
merged 24 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ half = "2"
num-complex = "0.4"
num_enum = "0.7.2"
num-traits = "0.2.18"
strum = { version = "0.26", features = ["derive"] }
thiserror = "1.0.58"

[dev-dependencies]
pretty_assertions = "1.4.0"

[workspace]
members = ["mlx-macros", "mlx-sys", "mlx-macros"]
Expand Down
49 changes: 48 additions & 1 deletion mlx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, parse_quote, FnArg, ItemFn, Pat};
use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat};

#[proc_macro_attribute]
pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -51,3 +51,50 @@ pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {

TokenStream::from(expanded)
}

#[proc_macro_derive(GenerateDtypeTestCases)]
pub fn generate_test_cases(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;

let tests = quote! {
/// MLX's rules for promoting two dtypes.
#[rustfmt::skip]
const TYPE_RULES: [[Dtype; 13]; 13] = [
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
[Dtype::Bool, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bool
[Dtype::Uint8, Dtype::Uint8, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint8
[Dtype::Uint16, Dtype::Uint16, Dtype::Uint16, Dtype::Uint32, Dtype::Uint64, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint16
[Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint32, Dtype::Uint64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint32
[Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Uint64, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // uint64
[Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int8, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int8
[Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int16, Dtype::Int16, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int16
[Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float32, Dtype::Int32, Dtype::Int32, Dtype::Int32, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int32
[Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float32, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Int64, Dtype::Float16, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // int64
[Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float16, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float16
[Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Float32, Dtype::Complex64], // float32
[Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Bfloat16, Dtype::Float32, Dtype::Float32, Dtype::Bfloat16, Dtype::Complex64], // bfloat16
[Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64, Dtype::Complex64], // complex64
];

#[cfg(test)]
mod generated_tests {
use super::*;
use strum::IntoEnumIterator;
use pretty_assertions::assert_eq;

#[test]
fn test_all_combinations() {
for a in #name::iter() {
for b in #name::iter() {
let result = a.promote_with(b);
let expected = TYPE_RULES[a as usize][b as usize];
assert_eq!(result, expected, "{}", format!("Failed promotion test for {:?} and {:?}", a, b));
}
}
}
}
};

TokenStream::from(tests)
}
64 changes: 55 additions & 9 deletions src/array.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::ffi::c_void;
use std::ops::Add;
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

use half::{bf16, f16};
use mlx_sys::mlx_array;
use num_complex::Complex;
use num_traits::Pow;

use crate::{dtype::Dtype, error::AsSliceError, sealed::Sealed, StreamOrDevice};

Expand Down Expand Up @@ -137,6 +138,55 @@ pub struct Array {
pub(crate) c_array: mlx_array,
}

dcvz marked this conversation as resolved.
Show resolved Hide resolved
impl<'a> Add for &'a Array {
type Output = Array;
fn add(self, rhs: Self) -> Self::Output {
self.add_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Sub for &'a Array {
type Output = Array;
fn sub(self, rhs: Self) -> Self::Output {
self.sub_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Neg for &'a Array {
type Output = Array;
fn neg(self) -> Self::Output {
self.logical_not()
}
}

impl<'a> Mul for &'a Array {
type Output = Array;
fn mul(self, rhs: Self) -> Self::Output {
self.mul_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Div for &'a Array {
type Output = Array;
fn div(self, rhs: Self) -> Self::Output {
self.div_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Pow<&'a Array> for &'a Array {
type Output = Array;
fn pow(self, rhs: &'a Array) -> Self::Output {
self.pow_device(rhs, StreamOrDevice::default())
}
}

impl<'a> Rem for &'a Array {
type Output = Array;
fn rem(self, rhs: Self) -> Self::Output {
self.rem_device(rhs, StreamOrDevice::default())
}
}

impl std::fmt::Debug for Array {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let description = crate::utils::mlx_describe(self.c_array as *mut c_void)
Expand All @@ -153,13 +203,6 @@ impl std::fmt::Display for Array {
}
}

impl<'a> Add for &'a Array {
type Output = Array;
fn add(self, rhs: Self) -> Self::Output {
self.add_device(rhs, StreamOrDevice::default())
}
}

// TODO: Clone should probably NOT be implemented because the underlying pointer is atomically
// reference counted but not guarded by a mutex.

Expand Down Expand Up @@ -371,7 +414,10 @@ impl Array {
}

if self.dtype() != T::DTYPE {
return Err(AsSliceError::DtypeMismatch);
return Err(AsSliceError::DtypeMismatch {
expecting: T::DTYPE,
found: self.dtype(),
});
}

Ok(unsafe { self.as_slice_unchecked() })
Expand Down
122 changes: 121 additions & 1 deletion src/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use mlx_macros::GenerateDtypeTestCases;
use strum::EnumIter;

/// Array element type
#[derive(
Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive,
Debug,
Clone,
Copy,
PartialEq,
Eq,
num_enum::IntoPrimitive,
num_enum::TryFromPrimitive,
EnumIter,
GenerateDtypeTestCases,
)]
#[repr(u32)]
pub enum Dtype {
Expand All @@ -18,3 +29,112 @@ pub enum Dtype {
Bfloat16 = mlx_sys::mlx_array_dtype__MLX_BFLOAT16,
Complex64 = mlx_sys::mlx_array_dtype__MLX_COMPLEX64,
}

impl Dtype {
pub fn is_complex(&self) -> bool {
matches!(self, Dtype::Complex64)
}

pub fn is_floating(&self) -> bool {
matches!(self, Dtype::Float16 | Dtype::Float32 | Dtype::Bfloat16)
}

pub fn is_inexact(&self) -> bool {
matches!(
self,
Dtype::Float16 | Dtype::Float32 | Dtype::Complex64 | Dtype::Bfloat16
)
}

pub fn from_promoting_types(a: Dtype, b: Dtype) -> Self {
a.promote_with(b)
}
}

pub(crate) trait TypePromotion {
fn promote_with(self, other: Self) -> Self;
}

impl TypePromotion for Dtype {
fn promote_with(self, other: Self) -> Self {
use crate::dtype::Dtype::*;
match (self, other) {
// Boolean promotions
(Bool, Bool) => Bool,
(Bool, _) | (_, Bool) => {
if self == Bool {
other
} else {
self
}
}

// Uint8 promotions
(Uint8, Uint8) => Uint8,
(Uint8, Uint16) | (Uint16, Uint8) => Uint16,
(Uint8, Uint32) | (Uint32, Uint8) => Uint32,
(Uint8, Uint64) | (Uint64, Uint8) => Uint64,
(Uint8, Int8) | (Int8, Uint8) => Int16,
(Uint8, Int16) | (Int16, Uint8) => Int16,
(Uint8, Int32) | (Int32, Uint8) => Int32,
(Uint8, Int64) | (Int64, Uint8) => Int64,

// Uint16 promotions
(Uint16, Uint16) => Uint16,
(Uint16, Uint32) | (Uint32, Uint16) => Uint32,
(Uint16, Uint64) | (Uint64, Uint16) => Uint64,
(Uint16, Int8) | (Int8, Uint16) => Int32,
(Uint16, Int16) | (Int16, Uint16) => Int32,
(Uint16, Int32) | (Int32, Uint16) => Int32,
(Uint16, Int64) | (Int64, Uint16) => Int64,

// Uint32 promotions
(Uint32, Uint32) => Uint32,
(Uint32, Uint64) | (Uint64, Uint32) => Uint64,
(Uint32, Int8) | (Int8, Uint32) => Int64,
(Uint32, Int16) | (Int16, Uint32) => Int64,
(Uint32, Int32) | (Int32, Uint32) => Int64,
(Uint32, Int64) | (Int64, Uint32) => Int64,

// Uint64 promotions
(Uint64, Uint64) => Uint64,
(Uint64, Int8) | (Int8, Uint64) => Float32,
(Uint64, Int16) | (Int16, Uint64) => Float32,
(Uint64, Int32) | (Int32, Uint64) => Float32,
(Uint64, Int64) | (Int64, Uint64) => Float32,

// Int8 promotions
(Int8, Int8) => Int8,
(Int8, Int16) | (Int16, Int8) => Int16,
(Int8, Int32) | (Int32, Int8) => Int32,
(Int8, Int64) | (Int64, Int8) => Int64,

// Int16 promotions
(Int16, Int16) => Int16,
(Int16, Int32) | (Int32, Int16) => Int32,
(Int16, Int64) | (Int64, Int16) => Int64,

// Int32 promotions
(Int32, Int32) => Int32,
(Int32, Int64) | (Int64, Int32) => Int64,

// Int64 promotions
(Int64, Int64) => Int64,

// Float16 promotions
(Float16, Bfloat16) | (Bfloat16, Float16) => Float32,

// Complex type
(Complex64, _) | (_, Complex64) => Complex64,

// Float32 promotions
(Float32, _) | (_, Float32) => Float32,

// Float16 promotions
(Float16, _) | (_, Float16) => Float16,

// Bfloat16 promotions
(Bfloat16, _) | (_, Bfloat16) => Bfloat16,
}
}
}
35 changes: 32 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
use crate::Dtype;
use thiserror::Error;

#[derive(Error, Debug)]
#[derive(Error, PartialEq, Debug)]
pub enum MlxError {
#[error("data store error: {0}")]
DataStore(#[from] DataStoreError),
#[error("operation error: {0}")]
Operation(#[from] OperationError),
#[error("as slice error: {0}")]
AsSlice(#[from] AsSliceError),
}

#[derive(Error, PartialEq, Debug)]
pub enum DataStoreError {
#[error("negative dimension: {0}")]
NegativeDimensions(String),

#[error("negative integer: {0}")]
NegativeInteger(String),

#[error("broadcast error")]
BroadcastError,
}

#[derive(Error, PartialEq, Debug)]
pub enum OperationError {
#[error("operation not supported: {0}")]
NotSupported(String),

#[error("wrong input: {0}")]
WrongInput(String),

#[error("wrong dimensions: {0}")]
WrongDimensions(String),

#[error("axis out of bounds: {0}")]
AxisOutOfBounds(String),
}

/// Error associated with `Array::try_as_slice()`
Expand All @@ -19,6 +48,6 @@ pub enum AsSliceError {
Null,

/// The output dtype does not match the data type of the array.
#[error("Desired output dtype does not match the data type of the array.")]
DtypeMismatch,
#[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
DtypeMismatch { expecting: Dtype, found: Dtype },
}
Loading
Loading