From e0455b3f1c69a7fdf1101bab59752666961616f8 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Thu, 18 Apr 2024 06:08:57 -0700 Subject: [PATCH 01/25] adding fft mod --- src/array.rs | 2 +- src/fft.rs | 5 +++++ src/lib.rs | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 src/fft.rs diff --git a/src/array.rs b/src/array.rs index 3a4af69d1..5f1525fdd 100644 --- a/src/array.rs +++ b/src/array.rs @@ -134,7 +134,7 @@ impl ArrayElement for complex64 { } pub struct Array { - c_array: mlx_array, + pub(crate) c_array: mlx_array, } impl std::fmt::Debug for Array { diff --git a/src/fft.rs b/src/fft.rs new file mode 100644 index 000000000..dcd9fe8d1 --- /dev/null +++ b/src/fft.rs @@ -0,0 +1,5 @@ +use crate::{array::Array, stream::StreamOrDevice}; + +// pub fn fft(array: Array, n: i32, axis: i32, s: StreamOrDevice) -> Array { + +// } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 1471f885a..48ab0397e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod array; pub mod device; pub mod dtype; +pub mod fft; pub mod stream; mod utils; From 5efba3aa61a3228fb097f325a43fef2d5e68e299 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Fri, 19 Apr 2024 22:33:49 -0700 Subject: [PATCH 02/25] output dtype of fft is complex64 --- src/device.rs | 3 ++- src/fft.rs | 58 +++++++++++++++++++++++++++++++++++++++++++++++++-- src/stream.rs | 5 +++-- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/device.rs b/src/device.rs index f884ce8b8..2c4f18b9e 100644 --- a/src/device.rs +++ b/src/device.rs @@ -30,7 +30,8 @@ impl Device { /// Set the default device. /// - /// Example: + /// # Example: + /// /// ```rust /// use mlx::device::{Device, DeviceType}; /// Device::set_default(&Device::new(DeviceType::Cpu, 1)); diff --git a/src/fft.rs b/src/fft.rs index dcd9fe8d1..249f0a5ac 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,5 +1,59 @@ use crate::{array::Array, stream::StreamOrDevice}; -// pub fn fft(array: Array, n: i32, axis: i32, s: StreamOrDevice) -> Array { +pub fn fft(array: Array, n: i32, axis: i32) -> Array { + // FFT is not yet implemented on gpu + let s = StreamOrDevice::cpu(); + unsafe { + let c_array = mlx_sys::mlx_fft_fft(array.c_array, n, axis, s.stream.c_stream); + Array::from_ptr(c_array) + } +} -// } \ No newline at end of file +pub fn fft_device(array: Array, n: i32, axis: i32, s: StreamOrDevice) -> Array { + unsafe { + let c_array = mlx_sys::mlx_fft_fft(array.c_array, n, axis, s.stream.c_stream); + Array::from_ptr(c_array) + } +} + +#[cfg(test)] +mod tests { + use crate::{array::complex64, fft::fft_device, stream::StreamOrDevice}; + + use super::fft; + + #[test] + fn test_fft() { + let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + let mut result = fft(array, 4, 0); + result.eval(); + + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), Some(&expected[..])); + } + + #[test] + fn test_fft_device() { + let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + let s = StreamOrDevice::cpu(); + let mut result = fft_device(array, 4, 0, s); + result.eval(); + + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), Some(&expected[..])); + } +} \ No newline at end of file diff --git a/src/stream.rs b/src/stream.rs index 79e032ca2..f3abaf20b 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -8,7 +8,7 @@ use crate::utils::mlx_describe; /// If omitted it will use the [default()], which will be [Device::gpu()] unless /// set otherwise. pub struct StreamOrDevice { - stream: Stream, + pub(crate) stream: Stream, } impl StreamOrDevice { @@ -59,7 +59,7 @@ impl std::fmt::Display for StreamOrDevice { /// /// Typically, this is used via the `stream:` parameter on a method with a [StreamOrDevice]: pub struct Stream { - c_stream: mlx_sys::mlx_stream, + pub(crate) c_stream: mlx_sys::mlx_stream, } impl Stream { @@ -79,6 +79,7 @@ impl Stream { Stream { c_stream } } + // TODO: document how this is different from `Default::default()` pub fn default_stream(device: &Device) -> Stream { let default_stream = unsafe { mlx_sys::mlx_default_stream(device.c_device) }; Stream::new_with_mlx_mlx_stream(default_stream) From cd3fd30306ff0987df9b3d014edee3dbd9046d4c Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sat, 20 Apr 2024 09:46:57 -0700 Subject: [PATCH 03/25] add device input to default_device macro --- mlx-macros/Cargo.toml | 1 + mlx-macros/src/lib.rs | 50 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/mlx-macros/Cargo.toml b/mlx-macros/Cargo.toml index ab8c6388b..13d6e67db 100644 --- a/mlx-macros/Cargo.toml +++ b/mlx-macros/Cargo.toml @@ -17,3 +17,4 @@ proc-macro = true [dependencies] syn = { version = "2.0.60", features = ["full"] } quote = "1.0" +darling = "0.20" diff --git a/mlx-macros/src/lib.rs b/mlx-macros/src/lib.rs index 3fe8f311c..70139de95 100644 --- a/mlx-macros/src/lib.rs +++ b/mlx-macros/src/lib.rs @@ -1,11 +1,47 @@ extern crate proc_macro; +use darling::FromMeta; use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::punctuated::Punctuated; use syn::{parse_macro_input, parse_quote, FnArg, ItemFn, Pat}; +#[derive(Debug, FromMeta)] +enum DeviceType { + Cpu, + Gpu, +} + +#[derive(Debug)] +struct DefaultDeviceInput { + device: DeviceType, +} + +impl FromMeta for DefaultDeviceInput { + fn from_meta(meta: &syn::Meta) -> darling::Result { + let syn::Meta::NameValue(meta_name_value) = meta else { + return Err(darling::Error::unsupported_format( + "expected a name-value attribute", + )); + }; + + let ident = meta_name_value.path.get_ident().unwrap(); + assert_eq!(ident, "device", "expected `device`"); + + let device = DeviceType::from_expr(&meta_name_value.value)?; + + Ok(DefaultDeviceInput { device }) + } +} + #[proc_macro_attribute] -pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream { +pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream { + let input = if !attr.is_empty() { + let meta = syn::parse_macro_input!(attr as syn::Meta); + Some(DefaultDeviceInput::from_meta(&meta).unwrap()) + } else { + None + }; + let mut input_fn = parse_macro_input!(item as ItemFn); let original_fn = input_fn.clone(); @@ -37,8 +73,16 @@ pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream { input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs); // Prepend default stream initialization - let default_stream_stmt = parse_quote! { - let stream = StreamOrDevice::default(); + let default_stream_stmt = match input.map(|input| input.device) { + Some(DeviceType::Cpu) => parse_quote! { + let stream = StreamOrDevice::cpu(); + }, + Some(DeviceType::Gpu) => parse_quote! { + let stream = StreamOrDevice::gpu(); + }, + None => parse_quote! { + let stream = StreamOrDevice::default(); + }, }; input_fn.block.stmts.insert(0, default_stream_stmt); From 502add8e9974836063771da14182bfd1ba0a8e6f Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sat, 20 Apr 2024 10:24:22 -0700 Subject: [PATCH 04/25] use macro to generate non device fft unchecked --- src/fft.rs | 58 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/src/fft.rs b/src/fft.rs index 249f0a5ac..825bdce09 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,31 +1,53 @@ -use crate::{array::Array, stream::StreamOrDevice}; +use mlx_macros::default_device; -pub fn fft(array: Array, n: i32, axis: i32) -> Array { - // FFT is not yet implemented on gpu - let s = StreamOrDevice::cpu(); - unsafe { - let c_array = mlx_sys::mlx_fft_fft(array.c_array, n, axis, s.stream.c_stream); - Array::from_ptr(c_array) - } -} +use crate::{array::Array, stream::StreamOrDevice}; -pub fn fft_device(array: Array, n: i32, axis: i32, s: StreamOrDevice) -> Array { +/// One dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - a: The input array. +/// - n: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]`. +/// - axis: Axis along which to perform the FFT. The default is -1. +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub unsafe fn fft_device_unchecked( + a: Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axis = axis.into().unwrap_or(-1); + let n = n.into().unwrap_or_else(|| { + if axis.is_negative() { + // TODO: replace with unchecked_add when it's stable + let index = (a.ndim() as i32) + .checked_add(axis) + .unwrap() + // index may still be negative + .max(0) as usize; + a.shape()[index] + } else { + // # Safety: positive i32 is always smaller than usize::MAX + a.shape()[axis as usize] + } + }); unsafe { - let c_array = mlx_sys::mlx_fft_fft(array.c_array, n, axis, s.stream.c_stream); + let c_array = mlx_sys::mlx_fft_fft(a.c_array, n, axis, stream.stream.c_stream); Array::from_ptr(c_array) } } #[cfg(test)] mod tests { - use crate::{array::complex64, fft::fft_device, stream::StreamOrDevice}; + use crate::{array::complex64, stream::StreamOrDevice}; - use super::fft; + use super::*; #[test] - fn test_fft() { + fn test_fft_unchecked() { let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let mut result = fft(array, 4, 0); + let mut result = unsafe { fft_unchecked(array, 4, 0) }; result.eval(); assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); @@ -40,10 +62,10 @@ mod tests { } #[test] - fn test_fft_device() { + fn test_fft_device_unchecked() { let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); let s = StreamOrDevice::cpu(); - let mut result = fft_device(array, 4, 0, s); + let mut result = unsafe { fft_device_unchecked(array, 4, 0, s) }; result.eval(); assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); @@ -56,4 +78,4 @@ mod tests { ]; assert_eq!(result.as_slice::(), Some(&expected[..])); } -} \ No newline at end of file +} From 4cab6a7d83f469767f227596720b9724ee3af599 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sat, 20 Apr 2024 11:15:18 -0700 Subject: [PATCH 05/25] test try fft --- src/error.rs | 9 +++++ src/fft.rs | 108 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 4 deletions(-) diff --git a/src/error.rs b/src/error.rs index ff35f7781..8fc7d288f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,3 +22,12 @@ pub enum AsSliceError { #[error("Desired output dtype does not match the data type of the array.")] DtypeMismatch, } + +#[derive(Error, Debug)] +pub enum FftError { + #[error("fftn requires at least one dimension")] + ScalarArray, + + #[error("Invalid axis received for array with {0} dimensions")] + InvalidAxis(usize), +} \ No newline at end of file diff --git a/src/fft.rs b/src/fft.rs index 1cd906235..e33f8f714 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,6 +1,6 @@ use mlx_macros::default_device; -use crate::{array::Array, stream::StreamOrDevice}; +use crate::{array::Array, error::FftError, stream::StreamOrDevice}; /// One dimensional discrete Fourier Transform. /// @@ -12,7 +12,7 @@ use crate::{array::Array, stream::StreamOrDevice}; /// - axis: Axis along which to perform the FFT. The default is -1. #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub unsafe fn fft_device_unchecked( - a: Array, + a: &Array, n: impl Into>, axis: impl Into>, stream: StreamOrDevice, @@ -38,6 +38,36 @@ pub unsafe fn fft_device_unchecked( } } +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn try_fft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + if a.ndim() < 1 { + return Err(FftError::ScalarArray); + } + + let axis = axis.into().unwrap_or(-1); + let (n, axis) = if axis.is_negative() { + if axis.abs() as usize > a.ndim() { + return Err(FftError::InvalidAxis(a.ndim())); + } + let index = a.ndim() - axis.abs() as usize; + let n = n.into().unwrap_or(a.shape()[index]); + (n, axis) + } else { + if axis as usize >= a.ndim() { + return Err(FftError::InvalidAxis(a.ndim())); + } + let n = n.into().unwrap_or(a.shape()[axis as usize]); + (n, axis) + }; + + Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) +} + #[cfg(test)] mod tests { use crate::{array::complex64, stream::StreamOrDevice}; @@ -47,7 +77,7 @@ mod tests { #[test] fn test_fft_unchecked() { let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let mut result = unsafe { fft_unchecked(array, 4, 0) }; + let mut result = unsafe { fft_unchecked(&array, 4, 0) }; result.eval(); assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); @@ -59,13 +89,17 @@ mod tests { complex64::new(-2.0, -2.0), ]; assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } #[test] fn test_fft_device_unchecked() { let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); let s = StreamOrDevice::cpu(); - let mut result = unsafe { fft_device_unchecked(array, 4, 0, s) }; + let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; result.eval(); assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); @@ -77,5 +111,71 @@ mod tests { complex64::new(-2.0, -2.0), ]; assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_try_fft() { + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + + // Error case + let scalar_array = Array::from_float(1.0); + let result = try_fft(&scalar_array, 0, 0); + assert!(result.is_err()); + + let result = try_fft(&array, 4, 2); + assert!(result.is_err()); + + // Success case + let mut result = try_fft(&array, 4, 0).unwrap(); + result.eval(); + + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_try_fft_device() { + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + + // Error case + let scalar_array = Array::from_float(1.0); + let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); + assert!(result.is_err()); + + let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); + assert!(result.is_err()); + + // Success case + let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); + result.eval(); + + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } } From 02f6c0d822654bb4f5f5de006892cdad215126f2 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sat, 20 Apr 2024 21:10:34 -0700 Subject: [PATCH 06/25] use unit test as example --- src/fft.rs | 162 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 112 insertions(+), 50 deletions(-) diff --git a/src/fft.rs b/src/fft.rs index e33f8f714..b1df03ef2 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -10,6 +10,31 @@ use crate::{array::Array, error::FftError, stream::StreamOrDevice}; /// - n: Size of the transformed axis. The corresponding axis in the input is truncated or padded /// with zeros to match `n`. The default value is `a.shape[axis]`. /// - axis: Axis along which to perform the FFT. The default is -1. +/// +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); +/// let s = StreamOrDevice::cpu(); +/// let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; +/// result.eval(); +/// +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// +/// let expected = &[ +/// complex64::new(10.0, 0.0), +/// complex64::new(-2.0, 2.0), +/// complex64::new(-2.0, 0.0), +/// complex64::new(-2.0, -2.0), +/// ]; +/// assert_eq!(result.as_slice::(), &expected[..]); +/// +/// // test that previous array is not modified and valid +/// let data: &[f32] = array.as_slice(); +/// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); +/// ``` #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub unsafe fn fft_device_unchecked( a: &Array, @@ -38,6 +63,39 @@ pub unsafe fn fft_device_unchecked( } } +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); +/// +/// // Error case +/// let scalar_array = Array::from_float(1.0); +/// let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); +/// assert!(result.is_err()); +/// +/// let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); +/// assert!(result.is_err()); +/// +/// // Success case +/// let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); +/// result.eval(); +/// +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// +/// let expected = &[ +/// complex64::new(10.0, 0.0), +/// complex64::new(-2.0, 2.0), +/// complex64::new(-2.0, 0.0), +/// complex64::new(-2.0, -2.0), +/// ]; +/// assert_eq!(result.as_slice::(), &expected[..]); +/// +/// // test that previous array is not modified and valid +/// let data: &[f32] = array.as_slice(); +/// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); +/// ``` #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub fn try_fft_device( a: &Array, @@ -70,17 +128,15 @@ pub fn try_fft_device( #[cfg(test)] mod tests { - use crate::{array::complex64, stream::StreamOrDevice}; - - use super::*; - #[test] fn test_fft_unchecked() { - let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + use crate::{Dtype, Array, complex64, fft::*}; + + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); let mut result = unsafe { fft_unchecked(&array, 4, 0) }; result.eval(); - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + assert_eq!(result.dtype(), Dtype::Complex64); let expected = &[ complex64::new(10.0, 0.0), @@ -90,35 +146,39 @@ mod tests { ]; assert_eq!(result.as_slice::(), &expected[..]); - // test that previous array is not modified and valid + // The original array is not modified and valid let data: &[f32] = array.as_slice(); assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } - #[test] - fn test_fft_device_unchecked() { - let array = crate::array::Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let s = StreamOrDevice::cpu(); - let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; - result.eval(); + // #[test] + // fn test_fft_device_unchecked() { + // use crate::{Array, StreamOrDevice, complex64, fft::*}; - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + // let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + // let s = StreamOrDevice::cpu(); + // let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; + // result.eval(); - let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); + // assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - } + // let expected = &[ + // complex64::new(10.0, 0.0), + // complex64::new(-2.0, 2.0), + // complex64::new(-2.0, 0.0), + // complex64::new(-2.0, -2.0), + // ]; + // assert_eq!(result.as_slice::(), &expected[..]); + + // // test that previous array is not modified and valid + // let data: &[f32] = array.as_slice(); + // assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + // } #[test] fn test_try_fft() { + use crate::{Dtype, Array, complex64, fft::*}; + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); // Error case @@ -133,7 +193,7 @@ mod tests { let mut result = try_fft(&array, 4, 0).unwrap(); result.eval(); - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + assert_eq!(result.dtype(), Dtype::Complex64); let expected = &[ complex64::new(10.0, 0.0), @@ -148,34 +208,36 @@ mod tests { assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } - #[test] - fn test_try_fft_device() { - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + // #[test] + // fn test_try_fft_device() { + // use crate::{Array, StreamOrDevice, complex64, fft::*}; - // Error case - let scalar_array = Array::from_float(1.0); - let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); - assert!(result.is_err()); + // let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); - assert!(result.is_err()); + // // Error case + // let scalar_array = Array::from_float(1.0); + // let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); + // assert!(result.is_err()); - // Success case - let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); - result.eval(); + // let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); + // assert!(result.is_err()); - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + // // Success case + // let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); + // result.eval(); - let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); + // assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - } + // let expected = &[ + // complex64::new(10.0, 0.0), + // complex64::new(-2.0, 2.0), + // complex64::new(-2.0, 0.0), + // complex64::new(-2.0, -2.0), + // ]; + // assert_eq!(result.as_slice::(), &expected[..]); + + // // test that previous array is not modified and valid + // let data: &[f32] = array.as_slice(); + // assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + // } } From c3cc17e4b5d051adaeb7cd699213f47c822a5c7f Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sat, 20 Apr 2024 21:19:30 -0700 Subject: [PATCH 07/25] cargo fmt --- src/device.rs | 2 +- src/error.rs | 2 +- src/fft.rs | 34 +++++++++++++++++----------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/device.rs b/src/device.rs index 4518558fb..8a4451db7 100644 --- a/src/device.rs +++ b/src/device.rs @@ -31,7 +31,7 @@ impl Device { /// Set the default device. /// /// # Example: - /// + /// /// ```rust /// use mlx::{Device, DeviceType}; /// Device::set_default(&Device::new(DeviceType::Cpu, 1)); diff --git a/src/error.rs b/src/error.rs index 8fc7d288f..2a80f794f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,4 +30,4 @@ pub enum FftError { #[error("Invalid axis received for array with {0} dimensions")] InvalidAxis(usize), -} \ No newline at end of file +} diff --git a/src/fft.rs b/src/fft.rs index b1df03ef2..7fc6452db 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -10,19 +10,19 @@ use crate::{array::Array, error::FftError, stream::StreamOrDevice}; /// - n: Size of the transformed axis. The corresponding axis in the input is truncated or padded /// with zeros to match `n`. The default value is `a.shape[axis]`. /// - axis: Axis along which to perform the FFT. The default is -1. -/// +/// /// # Example -/// +/// /// ```rust /// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// +/// /// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); /// let s = StreamOrDevice::cpu(); /// let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; /// result.eval(); -/// +/// /// assert_eq!(result.dtype(), Dtype::Complex64); -/// +/// /// let expected = &[ /// complex64::new(10.0, 0.0), /// complex64::new(-2.0, 2.0), @@ -30,7 +30,7 @@ use crate::{array::Array, error::FftError, stream::StreamOrDevice}; /// complex64::new(-2.0, -2.0), /// ]; /// assert_eq!(result.as_slice::(), &expected[..]); -/// +/// /// // test that previous array is not modified and valid /// let data: &[f32] = array.as_slice(); /// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); @@ -64,26 +64,26 @@ pub unsafe fn fft_device_unchecked( } /// # Example -/// +/// /// ```rust /// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// +/// /// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); -/// +/// /// // Error case /// let scalar_array = Array::from_float(1.0); /// let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); /// assert!(result.is_err()); -/// +/// /// let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); /// assert!(result.is_err()); -/// +/// /// // Success case /// let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); /// result.eval(); -/// +/// /// assert_eq!(result.dtype(), Dtype::Complex64); -/// +/// /// let expected = &[ /// complex64::new(10.0, 0.0), /// complex64::new(-2.0, 2.0), @@ -91,7 +91,7 @@ pub unsafe fn fft_device_unchecked( /// complex64::new(-2.0, -2.0), /// ]; /// assert_eq!(result.as_slice::(), &expected[..]); -/// +/// /// // test that previous array is not modified and valid /// let data: &[f32] = array.as_slice(); /// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); @@ -130,7 +130,7 @@ pub fn try_fft_device( mod tests { #[test] fn test_fft_unchecked() { - use crate::{Dtype, Array, complex64, fft::*}; + use crate::{complex64, fft::*, Array, Dtype}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); let mut result = unsafe { fft_unchecked(&array, 4, 0) }; @@ -177,10 +177,10 @@ mod tests { #[test] fn test_try_fft() { - use crate::{Dtype, Array, complex64, fft::*}; + use crate::{complex64, fft::*, Array, Dtype}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - + // Error case let scalar_array = Array::from_float(1.0); let result = try_fft(&scalar_array, 0, 0); From facf23dbc5d9962547e8cb91d73ee4b99ba84941 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sun, 21 Apr 2024 00:50:55 -0700 Subject: [PATCH 08/25] added fft_device --- src/fft.rs | 157 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 44 deletions(-) diff --git a/src/fft.rs b/src/fft.rs index 7fc6452db..3f011657b 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -63,6 +63,8 @@ pub unsafe fn fft_device_unchecked( } } +/// One dimensional discrete Fourier Transform. +/// /// # Example /// /// ```rust @@ -126,6 +128,23 @@ pub fn try_fft_device( Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) } +/// One dimensional discrete Fourier Transform. +/// +/// # Panic +/// +/// Panics if the input array is a scalar or if the axis is invalid. +/// +/// See [`try_fft_device`] for more details. +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn fft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_fft_device(a, n, axis, stream).unwrap() +} + #[cfg(test)] mod tests { #[test] @@ -151,29 +170,29 @@ mod tests { assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } - // #[test] - // fn test_fft_device_unchecked() { - // use crate::{Array, StreamOrDevice, complex64, fft::*}; + #[test] + fn test_fft_device_unchecked() { + use crate::{Array, StreamOrDevice, complex64, fft::*}; - // let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - // let s = StreamOrDevice::cpu(); - // let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; - // result.eval(); + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + let s = StreamOrDevice::cpu(); + let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; + result.eval(); - // assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - // let expected = &[ - // complex64::new(10.0, 0.0), - // complex64::new(-2.0, 2.0), - // complex64::new(-2.0, 0.0), - // complex64::new(-2.0, -2.0), - // ]; - // assert_eq!(result.as_slice::(), &expected[..]); + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); - // // test that previous array is not modified and valid - // let data: &[f32] = array.as_slice(); - // assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - // } + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } #[test] fn test_try_fft() { @@ -208,36 +227,86 @@ mod tests { assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } - // #[test] - // fn test_try_fft_device() { - // use crate::{Array, StreamOrDevice, complex64, fft::*}; + #[test] + fn test_try_fft_device() { + use crate::{Array, StreamOrDevice, complex64, fft::*}; - // let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - // // Error case - // let scalar_array = Array::from_float(1.0); - // let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); - // assert!(result.is_err()); + // Error case + let scalar_array = Array::from_float(1.0); + let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); + assert!(result.is_err()); + + let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); + assert!(result.is_err()); - // let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); - // assert!(result.is_err()); + // Success case + let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); + result.eval(); - // // Success case - // let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); - // result.eval(); + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - // assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); - // let expected = &[ - // complex64::new(10.0, 0.0), - // complex64::new(-2.0, 2.0), - // complex64::new(-2.0, 0.0), - // complex64::new(-2.0, -2.0), - // ]; - // assert_eq!(result.as_slice::(), &expected[..]); + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } - // // test that previous array is not modified and valid - // let data: &[f32] = array.as_slice(); - // assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - // } + #[test] + fn test_fft() { + use crate::{complex64, fft::*, Array, Dtype}; + + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + + // Success case + let mut result = fft(&array, 4, 0); + result.eval(); + + assert_eq!(result.dtype(), Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_fft_device() { + use crate::{Array, Dtype, StreamOrDevice, complex64, fft::*}; + + let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + + // Success case + let mut result = fft_device(&array, 4, 0, StreamOrDevice::cpu()); + result.eval(); + + assert_eq!(result.dtype(), Dtype::Complex64); + + let expected = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + complex64::new(-2.0, -2.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + } } From 34b3c8017727d838dfa85c9be7f5b5aaf9161523 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sun, 21 Apr 2024 06:14:25 -0700 Subject: [PATCH 09/25] follow original mlx indexing behavior --- src/fft.rs | 60 ++++++++++++++++++++++++++++----------------------- src/stream.rs | 9 ++++---- src/utils.rs | 33 ++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/src/fft.rs b/src/fft.rs index 3f011657b..78c2dd582 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,6 +1,6 @@ use mlx_macros::default_device; -use crate::{array::Array, error::FftError, stream::StreamOrDevice}; +use crate::{array::Array, error::FftError, stream::StreamOrDevice, utils::{resolve_index, resolve_index_unchecked}}; /// One dimensional discrete Fourier Transform. /// @@ -44,18 +44,8 @@ pub unsafe fn fft_device_unchecked( ) -> Array { let axis = axis.into().unwrap_or(-1); let n = n.into().unwrap_or_else(|| { - if axis.is_negative() { - // TODO: replace with unchecked_add when it's stable - let index = (a.ndim() as i32) - .checked_add(axis) - .unwrap() - // index may still be negative - .max(0) as usize; - a.shape()[index] - } else { - // # Safety: positive i32 is always smaller than usize::MAX - a.shape()[axis as usize] - } + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] }); unsafe { let c_array = mlx_sys::mlx_fft_fft(a.c_array, n, axis, stream.stream.c_stream); @@ -110,20 +100,8 @@ pub fn try_fft_device( } let axis = axis.into().unwrap_or(-1); - let (n, axis) = if axis.is_negative() { - if axis.abs() as usize > a.ndim() { - return Err(FftError::InvalidAxis(a.ndim())); - } - let index = a.ndim() - axis.abs() as usize; - let n = n.into().unwrap_or(a.shape()[index]); - (n, axis) - } else { - if axis as usize >= a.ndim() { - return Err(FftError::InvalidAxis(a.ndim())); - } - let n = n.into().unwrap_or(a.shape()[axis as usize]); - (n, axis) - }; + let axis_index = resolve_index(axis, a.ndim()); + let n = n.into().unwrap_or(a.shape()[axis_index]); Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) } @@ -145,6 +123,34 @@ pub fn fft_device( try_fft_device(a, n, axis, stream).unwrap() } +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn fft2_device_unchecked( + a: &Array, + n: &[i32], + axes: &[i32], + stream: StreamOrDevice, +) -> Array { + let num_n = n.len(); + let num_axes = axes.len(); + + let n_ptr = n.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = mlx_sys::mlx_fft_fft2( + a.c_array, + n_ptr, + num_n, + axes_ptr, + num_axes, + stream.as_ptr() + ); + + Array::from_ptr(c_array) + } +} + +// TODO: test out of bound indexing #[cfg(test)] mod tests { #[test] diff --git a/src/stream.rs b/src/stream.rs index 1941bbff7..a13748748 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -18,21 +18,21 @@ impl StreamOrDevice { pub fn new_with_device(device: &Device) -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(device), + stream: Stream::default_stream_on_device(device), } } /// The `[Stream::default_stream()] on the [Device::cpu()] pub fn cpu() -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(&Device::cpu()), + stream: Stream::default_stream_on_device(&Device::cpu()), } } /// The `[Stream::default_stream()] on the [Device::gpu()] pub fn gpu() -> StreamOrDevice { StreamOrDevice { - stream: Stream::default_stream(&Device::gpu()), + stream: Stream::default_stream_on_device(&Device::gpu()), } } @@ -84,8 +84,7 @@ impl Stream { Stream { c_stream } } - // TODO: document how this is different from `Default::default()` - pub fn default_stream(device: &Device) -> Stream { + pub fn default_stream_on_device(device: &Device) -> Stream { let default_stream = unsafe { mlx_sys::mlx_default_stream(device.c_device) }; Stream::new_with_mlx_mlx_stream(default_stream) } diff --git a/src/utils.rs b/src/utils.rs index 1503971e7..6af55eb17 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,3 +17,36 @@ pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option { description } + + +pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize { + if index.is_negative() { + (len as i32 + index) as usize + } else { + index as usize + } +} + +/// `mlx` differs from `numpy` in the way it handles out of bounds indices. It's behavior is more +/// like `jax`. See the related issue here https://github.com/ml-explore/mlx/issues/206. +/// +/// The issue says it would use the last element if the index is out of bounds. But testing with +/// python seems more like undefined behavior. Here we will use the last element if the index is +/// is out of bounds. +pub(crate) fn resolve_index(index: i32, len: usize) -> usize { + let abs_index = index.abs() as usize; + + if index.is_negative() { + if abs_index <= len { + len - abs_index + } else { + len - 1 + } + } else { + if abs_index < len { + abs_index + } else { + len - 1 + } + } +} \ No newline at end of file From 3e00b2a746273f10805f41f6bb10a86e176a4f51 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sun, 21 Apr 2024 06:14:45 -0700 Subject: [PATCH 10/25] cargo fmt --- src/fft.rs | 36 +++++++++++++++--------------------- src/utils.rs | 5 ++--- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/fft.rs b/src/fft.rs index 78c2dd582..3570fb81b 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,6 +1,11 @@ use mlx_macros::default_device; -use crate::{array::Array, error::FftError, stream::StreamOrDevice, utils::{resolve_index, resolve_index_unchecked}}; +use crate::{ + array::Array, + error::FftError, + stream::StreamOrDevice, + utils::{resolve_index, resolve_index_unchecked}, +}; /// One dimensional discrete Fourier Transform. /// @@ -107,11 +112,11 @@ pub fn try_fft_device( } /// One dimensional discrete Fourier Transform. -/// +/// /// # Panic -/// +/// /// Panics if the input array is a scalar or if the axis is invalid. -/// +/// /// See [`try_fft_device`] for more details. #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub fn fft_device( @@ -124,12 +129,7 @@ pub fn fft_device( } #[default_device(device = "cpu")] // fft is not implemented on GPU yet -pub fn fft2_device_unchecked( - a: &Array, - n: &[i32], - axes: &[i32], - stream: StreamOrDevice, -) -> Array { +pub fn fft2_device_unchecked(a: &Array, n: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { let num_n = n.len(); let num_axes = axes.len(); @@ -137,14 +137,8 @@ pub fn fft2_device_unchecked( let axes_ptr = axes.as_ptr(); unsafe { - let c_array = mlx_sys::mlx_fft_fft2( - a.c_array, - n_ptr, - num_n, - axes_ptr, - num_axes, - stream.as_ptr() - ); + let c_array = + mlx_sys::mlx_fft_fft2(a.c_array, n_ptr, num_n, axes_ptr, num_axes, stream.as_ptr()); Array::from_ptr(c_array) } @@ -178,7 +172,7 @@ mod tests { #[test] fn test_fft_device_unchecked() { - use crate::{Array, StreamOrDevice, complex64, fft::*}; + use crate::{complex64, fft::*, Array, StreamOrDevice}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); let s = StreamOrDevice::cpu(); @@ -235,7 +229,7 @@ mod tests { #[test] fn test_try_fft_device() { - use crate::{Array, StreamOrDevice, complex64, fft::*}; + use crate::{complex64, fft::*, Array, StreamOrDevice}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); @@ -293,7 +287,7 @@ mod tests { #[test] fn test_fft_device() { - use crate::{Array, Dtype, StreamOrDevice, complex64, fft::*}; + use crate::{complex64, fft::*, Array, Dtype, StreamOrDevice}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); diff --git a/src/utils.rs b/src/utils.rs index 6af55eb17..aa6187408 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -18,7 +18,6 @@ pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option { description } - pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize { if index.is_negative() { (len as i32 + index) as usize @@ -29,7 +28,7 @@ pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize { /// `mlx` differs from `numpy` in the way it handles out of bounds indices. It's behavior is more /// like `jax`. See the related issue here https://github.com/ml-explore/mlx/issues/206. -/// +/// /// The issue says it would use the last element if the index is out of bounds. But testing with /// python seems more like undefined behavior. Here we will use the last element if the index is /// is out of bounds. @@ -49,4 +48,4 @@ pub(crate) fn resolve_index(index: i32, len: usize) -> usize { len - 1 } } -} \ No newline at end of file +} From 16f1e88b989c3136ba7c28232c3bf8b0426ddca7 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sun, 21 Apr 2024 22:22:15 -0700 Subject: [PATCH 11/25] added fft2 --- Cargo.toml | 1 + src/array.rs | 6 ++ src/error.rs | 14 ++- src/fft.rs | 282 +++++++++++++++++++++++++++++++++++++++++++++++++-- src/utils.rs | 16 +-- 5 files changed, 295 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bae839cf0..705ee83e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ num-complex = "0.4" num_enum = "0.7.2" num-traits = "0.2.18" thiserror = "1.0.58" +smallvec = "1" [dev-dependencies] diff --git a/src/array.rs b/src/array.rs index 2cca9f3b6..edf11873b 100644 --- a/src/array.rs +++ b/src/array.rs @@ -133,6 +133,12 @@ impl ArrayElement for complex64 { } } +// TODO: `mlx` differs from `numpy` in the way it handles out of bounds indices. It's behavior is more +// like `jax`. See the related issue here https://github.com/ml-explore/mlx/issues/206. +// +// The issue says it would use the last element if the index is out of bounds. But testing with +// python seems more like undefined behavior. Here we will use the last element if the index is +// is out of bounds. pub struct Array { pub(crate) c_array: mlx_array, } diff --git a/src/error.rs b/src/error.rs index 2a80f794f..75bd466ad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,11 +23,17 @@ pub enum AsSliceError { DtypeMismatch, } -#[derive(Error, Debug)] -pub enum FftError { +#[derive(Error, Debug, PartialEq)] +pub enum FftnError { #[error("fftn requires at least one dimension")] ScalarArray, - #[error("Invalid axis received for array with {0} dimensions")] - InvalidAxis(usize), + #[error("Invalid axis received for array with {ndim} dimensions")] + InvalidAxis { ndim: usize }, + + #[error("Shape and axis have different sizes")] + ShapeAxisMismatch, + + #[error("Duplcated axis received: {axis}")] + DuplicateAxis { axis: i32 }, } diff --git a/src/fft.rs b/src/fft.rs index 3570fb81b..1a20aaabb 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,8 +1,11 @@ +use std::collections::HashSet; + use mlx_macros::default_device; +use smallvec::SmallVec; use crate::{ array::Array, - error::FftError, + error::FftnError, stream::StreamOrDevice, utils::{resolve_index, resolve_index_unchecked}, }; @@ -11,10 +14,10 @@ use crate::{ /// /// # Params /// -/// - a: The input array. -/// - n: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded /// with zeros to match `n`. The default value is `a.shape[axis]`. -/// - axis: Axis along which to perform the FFT. The default is -1. +/// - `axis`: Axis along which to perform the FFT. The default is -1. /// /// # Example /// @@ -60,6 +63,13 @@ pub unsafe fn fft_device_unchecked( /// One dimensional discrete Fourier Transform. /// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]`. +/// - `axis`: Axis along which to perform the FFT. The default is -1. +/// /// # Example /// /// ```rust @@ -99,13 +109,14 @@ pub fn try_fft_device( n: impl Into>, axis: impl Into>, stream: StreamOrDevice, -) -> Result { +) -> Result { if a.ndim() < 1 { - return Err(FftError::ScalarArray); + return Err(FftnError::ScalarArray); } let axis = axis.into().unwrap_or(-1); - let axis_index = resolve_index(axis, a.ndim()); + let axis_index = + resolve_index(axis, a.ndim()).ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; let n = n.into().unwrap_or(a.shape()[axis_index]); Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) @@ -113,6 +124,13 @@ pub fn try_fft_device( /// One dimensional discrete Fourier Transform. /// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]`. +/// - `axis`: Axis along which to perform the FFT. The default is -1. +/// /// # Panic /// /// Panics if the input array is a scalar or if the axis is invalid. @@ -128,8 +146,13 @@ pub fn fft_device( try_fft_device(a, n, axis, stream).unwrap() } -#[default_device(device = "cpu")] // fft is not implemented on GPU yet -pub fn fft2_device_unchecked(a: &Array, n: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { +#[inline] +fn fft2_device_unchecked_inner( + a: &Array, + n: &[i32], + axes: &[i32], + stream: StreamOrDevice, +) -> Array { let num_n = n.len(); let num_axes = axes.len(); @@ -144,6 +167,154 @@ pub fn fft2_device_unchecked(a: &Array, n: &[i32], axes: &[i32], stream: StreamO } } +/// Two dimensional discrete Fourier Transform. +/// +/// # Param +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); +/// let n = [2, 2]; +/// let axes = [-2, -1]; +/// let mut result = fft2_device_unchecked(&array, &n[..], &axes[..], StreamOrDevice::cpu()); +/// result.eval(); +/// +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// +/// let expected = &[ +/// complex64::new(4.0, 0.0), +/// complex64::new(0.0, 0.0), +/// complex64::new(0.0, 0.0), +/// complex64::new(0.0, 0.0), +/// ]; +/// assert_eq!(result.as_slice::(), &expected[..]); +/// +/// // test that previous array is not modified and valid +/// let data: &[f32] = array.as_slice(); +/// assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); +/// ``` +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn fft2_device_unchecked<'a>( + a: &'a Array, + n: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let mut valid_n = SmallVec::<[i32; 2]>::new(); + match n.into() { + Some(n) => valid_n.extend_from_slice(&n), + None => { + for axis in axes { + let axis_index = resolve_index_unchecked(*axis, a.ndim()); + valid_n.push(a.shape()[axis_index]); + } + } + } + + fft2_device_unchecked_inner(a, &valid_n, axes, stream) +} + +/// Two dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); +/// let mut result = try_fft2_device(&array, None, None, StreamOrDevice::cpu()).unwrap(); +/// result.eval(); +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// let expected = &[ +/// complex64::new(4.0, 0.0), +/// complex64::new(0.0, 0.0), +/// complex64::new(0.0, 0.0), +/// complex64::new(0.0, 0.0), +/// ]; +/// assert_eq!(result.as_slice::(), &expected[..]); +/// ``` +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn try_fft2_device<'a>( + a: &'a Array, + n: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + if a.ndim() < 1 { + return Err(FftnError::ScalarArray); + } + + // Check for duplicate axes + let axes = axes.into().unwrap_or(&[-2, -1]); + let mut unique_axes = HashSet::new(); + for axis in axes { + if !unique_axes.insert(axis) { + return Err(FftnError::DuplicateAxis { axis: *axis }); + } + } + + // valid shape + let mut valid_n = SmallVec::<[i32; 2]>::new(); + match n.into() { + Some(n) => { + if n.len() > a.ndim() { + return Err(FftnError::InvalidAxis { ndim: a.ndim() }); + } + valid_n.extend_from_slice(n); + } + None => { + for axis in axes { + let axis_index = resolve_index(*axis, a.ndim()) + .ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; + valid_n.push(a.shape()[axis_index]); + } + } + } + + // Check if shape and axes have the same size + if valid_n.len() != axes.len() { + return Err(FftnError::ShapeAxisMismatch); + } + + Ok(fft2_device_unchecked_inner(a, &valid_n, axes, stream)) +} + +/// Two dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// See [`try_fft2_device`] for more details. +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn fft2_device<'a>( + a: &'a Array, + n: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_fft2_device(a, n, axes, stream).unwrap() +} + // TODO: test out of bound indexing #[cfg(test)] mod tests { @@ -309,4 +480,97 @@ mod tests { let data: &[f32] = array.as_slice(); assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } + + #[test] + fn test_fft2_device_unchecked() { + use crate::{complex64, fft::*, Array, Dtype}; + + let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let n = [2, 2]; + let axes = [-2, -1]; + let mut result = fft2_device_unchecked(&array, &n[..], &axes[..], StreamOrDevice::cpu()); + result.eval(); + + assert_eq!(result.dtype(), Dtype::Complex64); + + let expected = &[ + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn test_try_fft2_device() { + use crate::{complex64, error::FftnError, fft::*, Array, StreamOrDevice}; + + let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + + // Error case + let scalar_array = Array::from_float(1.0); + let result = try_fft2_device(&scalar_array, None, None, StreamOrDevice::cpu()); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), FftnError::ScalarArray); + + let result = try_fft2_device(&array, &[2, 2, 2][..], None, StreamOrDevice::cpu()); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 2 }); + + let result = try_fft2_device(&array, &[2, 2][..], &[-1][..], StreamOrDevice::cpu()); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), FftnError::ShapeAxisMismatch); + + let result = try_fft2_device(&array, None, &[-2, -2][..], StreamOrDevice::cpu()); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), FftnError::DuplicateAxis { axis: -2 }); + + // Success case + let mut result = try_fft2_device(&array, None, None, StreamOrDevice::cpu()).unwrap(); + result.eval(); + + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + + let expected = &[ + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn test_fft2_device() { + use crate::{complex64, fft::*, Array, Dtype}; + + let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let n = [2, 2]; + let axes = [-2, -1]; + let mut result = fft2_device(&array, Some(&n[..]), Some(&axes[..]), StreamOrDevice::cpu()); + result.eval(); + + assert_eq!(result.dtype(), Dtype::Complex64); + + let expected = &[ + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + } } diff --git a/src/utils.rs b/src/utils.rs index aa6187408..a08901d27 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -26,26 +26,20 @@ pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize { } } -/// `mlx` differs from `numpy` in the way it handles out of bounds indices. It's behavior is more -/// like `jax`. See the related issue here https://github.com/ml-explore/mlx/issues/206. -/// -/// The issue says it would use the last element if the index is out of bounds. But testing with -/// python seems more like undefined behavior. Here we will use the last element if the index is -/// is out of bounds. -pub(crate) fn resolve_index(index: i32, len: usize) -> usize { +pub(crate) fn resolve_index(index: i32, len: usize) -> Option { let abs_index = index.abs() as usize; if index.is_negative() { if abs_index <= len { - len - abs_index + Some(len - abs_index) } else { - len - 1 + None } } else { if abs_index < len { - abs_index + Some(abs_index) } else { - len - 1 + None } } } From 892d55e180f4903a6e3a76e8489d3b04c2a9b82d Mon Sep 17 00:00:00 2001 From: minghuaw Date: Tue, 23 Apr 2024 01:26:52 -0700 Subject: [PATCH 12/25] impl fftn --- src/error.rs | 7 +- src/{fft.rs => fft/fftn.rs} | 468 ++++++++++++++++++++++-------------- src/fft/ifftn.rs | 1 + src/fft/irfftn.rs | 1 + src/fft/mod.rs | 97 ++++++++ src/fft/rfftn.rs | 1 + src/utils.rs | 11 + 7 files changed, 402 insertions(+), 184 deletions(-) rename src/{fft.rs => fft/fftn.rs} (57%) create mode 100644 src/fft/ifftn.rs create mode 100644 src/fft/irfftn.rs create mode 100644 src/fft/mod.rs create mode 100644 src/fft/rfftn.rs diff --git a/src/error.rs b/src/error.rs index 387fc12db..d8e23f77c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -60,9 +60,12 @@ pub enum FftnError { #[error("Invalid axis received for array with {ndim} dimensions")] InvalidAxis { ndim: usize }, - #[error("Shape and axis have different sizes")] - ShapeAxisMismatch, + #[error("Shape and axes/axis have different sizes")] + IncompatibleShapeAndAxes { shape_size: usize, axes_size: usize }, #[error("Duplcated axis received: {axis}")] DuplicateAxis { axis: i32 }, + + #[error("Invalid output size requested")] + InvalidOutputSize, } diff --git a/src/fft.rs b/src/fft/fftn.rs similarity index 57% rename from src/fft.rs rename to src/fft/fftn.rs index 1a20aaabb..1e0196cec 100644 --- a/src/fft.rs +++ b/src/fft/fftn.rs @@ -1,13 +1,8 @@ -use std::collections::HashSet; - use mlx_macros::default_device; use smallvec::SmallVec; use crate::{ - array::Array, - error::FftnError, - stream::StreamOrDevice, - utils::{resolve_index, resolve_index_unchecked}, + array::Array, error::FftnError, stream::StreamOrDevice, utils::resolve_index_unchecked, }; /// One dimensional discrete Fourier Transform. @@ -76,16 +71,6 @@ pub unsafe fn fft_device_unchecked( /// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; /// /// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); -/// -/// // Error case -/// let scalar_array = Array::from_float(1.0); -/// let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); -/// assert!(result.is_err()); -/// -/// let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); -/// assert!(result.is_err()); -/// -/// // Success case /// let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); /// result.eval(); /// @@ -98,10 +83,6 @@ pub unsafe fn fft_device_unchecked( /// complex64::new(-2.0, -2.0), /// ]; /// assert_eq!(result.as_slice::(), &expected[..]); -/// -/// // test that previous array is not modified and valid -/// let data: &[f32] = array.as_slice(); -/// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); /// ``` #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub fn try_fft_device( @@ -110,15 +91,7 @@ pub fn try_fft_device( axis: impl Into>, stream: StreamOrDevice, ) -> Result { - if a.ndim() < 1 { - return Err(FftnError::ScalarArray); - } - - let axis = axis.into().unwrap_or(-1); - let axis_index = - resolve_index(axis, a.ndim()).ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; - let n = n.into().unwrap_or(a.shape()[axis_index]); - + let (n, axis) = super::try_resolve_size_and_axis(a, n, axis)?; Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) } @@ -146,15 +119,9 @@ pub fn fft_device( try_fft_device(a, n, axis, stream).unwrap() } -#[inline] -fn fft2_device_unchecked_inner( - a: &Array, - n: &[i32], - axes: &[i32], - stream: StreamOrDevice, -) -> Array { - let num_n = n.len(); +fn fft2_device_inner(a: &Array, n: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { let num_axes = axes.len(); + let num_n = n.len(); let n_ptr = n.as_ptr(); let axes_ptr = axes.as_ptr(); @@ -162,7 +129,6 @@ fn fft2_device_unchecked_inner( unsafe { let c_array = mlx_sys::mlx_fft_fft2(a.c_array, n_ptr, num_n, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) } } @@ -172,7 +138,7 @@ fn fft2_device_unchecked_inner( /// # Param /// /// - `a`: The input array. -/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded /// with zeros to match `n`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. /// @@ -182,9 +148,9 @@ fn fft2_device_unchecked_inner( /// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; /// /// let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); -/// let n = [2, 2]; -/// let axes = [-2, -1]; -/// let mut result = fft2_device_unchecked(&array, &n[..], &axes[..], StreamOrDevice::cpu()); +/// let mut result = unsafe { +/// fft2_device_unchecked(&array, &[2, 2][..], &[-2,-1][..], StreamOrDevice::cpu()) +/// }; /// result.eval(); /// /// assert_eq!(result.dtype(), Dtype::Complex64); @@ -196,22 +162,18 @@ fn fft2_device_unchecked_inner( /// complex64::new(0.0, 0.0), /// ]; /// assert_eq!(result.as_slice::(), &expected[..]); -/// -/// // test that previous array is not modified and valid -/// let data: &[f32] = array.as_slice(); -/// assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); /// ``` #[default_device(device = "cpu")] // fft is not implemented on GPU yet -pub fn fft2_device_unchecked<'a>( +pub unsafe fn fft2_device_unchecked<'a>( a: &'a Array, - n: impl Into>, + s: impl Into>, axes: impl Into>, stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); let mut valid_n = SmallVec::<[i32; 2]>::new(); - match n.into() { - Some(n) => valid_n.extend_from_slice(&n), + match s.into() { + Some(s) => valid_n.extend_from_slice(&s), None => { for axis in axes { let axis_index = resolve_index_unchecked(*axis, a.ndim()); @@ -220,7 +182,7 @@ pub fn fft2_device_unchecked<'a>( } } - fft2_device_unchecked_inner(a, &valid_n, axes, stream) + fft2_device_inner(a, &valid_n, axes, stream) } /// Two dimensional discrete Fourier Transform. @@ -228,7 +190,7 @@ pub fn fft2_device_unchecked<'a>( /// # Params /// /// - `a`: The input array. -/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded /// with zeros to match `n`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. /// @@ -252,67 +214,193 @@ pub fn fft2_device_unchecked<'a>( #[default_device(device = "cpu")] // fft is not implemented on GPU yet pub fn try_fft2_device<'a>( a: &'a Array, - n: impl Into>, + s: impl Into>, axes: impl Into>, stream: StreamOrDevice, ) -> Result { - if a.ndim() < 1 { - return Err(FftnError::ScalarArray); - } + let valid_axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, valid_axes)?; + Ok(fft2_device_inner(a, &valid_s, &valid_axes, stream)) +} - // Check for duplicate axes - let axes = axes.into().unwrap_or(&[-2, -1]); - let mut unique_axes = HashSet::new(); - for axis in axes { - if !unique_axes.insert(axis) { - return Err(FftnError::DuplicateAxis { axis: *axis }); - } +/// Two dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Panic +/// +/// - if the input array is a scalar array +/// - if the shape and axes have different sizes +/// - if more axes are provided than the array has +/// - if the output sizes are invalid (<= 0) +/// - if the axes are not unique +/// +/// See [`try_fft2_device`] for more details. +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn fft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_fft2_device(a, s, axes, stream).unwrap() +} + +#[inline] +fn fftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_fftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + + Array::from_ptr(c_array) } +} - // valid shape - let mut valid_n = SmallVec::<[i32; 2]>::new(); - match n.into() { - Some(n) => { - if n.len() > a.ndim() { - return Err(FftnError::InvalidAxis { ndim: a.ndim() }); - } - valid_n.extend_from_slice(n); +/// N-dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also None. +/// +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::ones::(&[3, 3, 3]); +/// +/// let mut result = unsafe { fftn_device_unchecked(&array, None, None, StreamOrDevice::cpu()) }; +/// result.eval(); +/// +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// +/// let mut expected = vec![complex64::new(0.0, 0.0); 27]; +/// expected[0] = complex64::new(27.0, 0.0); +/// +/// assert_eq!(result.as_slice::(), &expected[..]); +/// ``` +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub unsafe fn fftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = match (s.into(), axes.into()) { + (Some(s), Some(axes)) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) } - None => { - for axis in axes { - let axis_index = resolve_index(*axis, a.ndim()) - .ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; - valid_n.push(a.shape()[axis_index]); - } + (Some(s), None) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) } - } + (None, Some(axes)) => { + let valid_s = axes + .iter() + .map(|&axis| { + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] + }) + .collect(); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (None, None) => { + let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + }; - // Check if shape and axes have the same size - if valid_n.len() != axes.len() { - return Err(FftnError::ShapeAxisMismatch); - } + fftn_device_inner(a, &valid_s, &valid_axes, stream) +} - Ok(fft2_device_unchecked_inner(a, &valid_n, axes, stream)) +/// N-dimensional discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +/// +/// # Example +/// +/// ```rust +/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +/// +/// let array = Array::ones::(&[3, 3, 3]); +/// +/// let mut result = try_fftn(&array, None, None).unwrap(); +/// result.eval(); +/// +/// assert_eq!(result.dtype(), Dtype::Complex64); +/// +/// let mut expected = vec![complex64::new(0.0, 0.0); 27]; +/// expected[0] = complex64::new(27.0, 0.0); +/// +/// assert_eq!(result.as_slice::(), &expected[..]); +/// ``` +#[default_device(device = "cpu")] // fft is not implemented on GPU yet +pub fn try_fftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, axes)?; + Ok(fftn_device_inner(a, &valid_s, &valid_axes, stream)) } -/// Two dimensional discrete Fourier Transform. +/// N-dimensional discrete Fourier Transform. /// /// # Params /// /// - `a`: The input array. -/// - `n`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. /// -/// See [`try_fft2_device`] for more details. +/// # Panic +/// +/// - if the input array is a scalar array +/// - if the axes are not unique +/// - if the shape and axes have different sizes +/// - if the output sizes are invalid (<= 0) +/// - if more axes are provided than the array has +/// +/// See [`try_fftn_device`] for more details. #[default_device(device = "cpu")] // fft is not implemented on GPU yet -pub fn fft2_device<'a>( +pub fn fftn_device<'a>( a: &'a Array, - n: impl Into>, + s: impl Into>, axes: impl Into>, stream: StreamOrDevice, ) -> Array { - try_fft2_device(a, n, axes, stream).unwrap() + try_fftn_device(a, s, axes, stream).unwrap() } // TODO: test out of bound indexing @@ -341,30 +429,6 @@ mod tests { assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); } - #[test] - fn test_fft_device_unchecked() { - use crate::{complex64, fft::*, Array, StreamOrDevice}; - - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let s = StreamOrDevice::cpu(); - let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; - result.eval(); - - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - - let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); - - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - } - #[test] fn test_try_fft() { use crate::{complex64, fft::*, Array, Dtype}; @@ -399,24 +463,16 @@ mod tests { } #[test] - fn test_try_fft_device() { - use crate::{complex64, fft::*, Array, StreamOrDevice}; + fn test_fft() { + use crate::{complex64, fft::*, Array, Dtype}; let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - // Error case - let scalar_array = Array::from_float(1.0); - let result = try_fft_device(&scalar_array, 0, 0, StreamOrDevice::cpu()); - assert!(result.is_err()); - - let result = try_fft_device(&array, 4, 2, StreamOrDevice::cpu()); - assert!(result.is_err()); - // Success case - let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); + let mut result = fft(&array, 4, 0); result.eval(); - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); + assert_eq!(result.dtype(), Dtype::Complex64); let expected = &[ complex64::new(10.0, 0.0), @@ -432,63 +488,86 @@ mod tests { } #[test] - fn test_fft() { + fn test_fft2_unchecked() { use crate::{complex64, fft::*, Array, Dtype}; - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - - // Success case - let mut result = fft(&array, 4, 0); + let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let n = [2, 2]; + let axes = [-2, -1]; + let mut result = unsafe { fft2_unchecked(&array, &n[..], &axes[..]) }; result.eval(); assert_eq!(result.dtype(), Dtype::Complex64); let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), ]; assert_eq!(result.as_slice::(), &expected[..]); // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); } #[test] - fn test_fft_device() { - use crate::{complex64, fft::*, Array, Dtype, StreamOrDevice}; + fn test_try_fft2() { + use crate::{complex64, error::FftnError, fft::*, Array}; - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + + // Error case + let scalar_array = Array::from_float(1.0); + let result = try_fft2(&scalar_array, None, None); + assert_eq!(result.unwrap_err(), FftnError::ScalarArray); + + let result = try_fft2(&array, &[2, 2, 2][..], &[0, 1, 2][..]); + assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 2 }); + + let result = try_fft2(&array, &[2, 2][..], &[-1][..]); + assert_eq!( + result.unwrap_err(), + FftnError::IncompatibleShapeAndAxes { + shape_size: 2, + axes_size: 1, + } + ); + + let result = try_fft2(&array, None, &[-2, -2][..]); + assert_eq!(result.unwrap_err(), FftnError::DuplicateAxis { axis: -2 }); + + let result = try_fft2(&array, &[-2, 2][..], None); + assert_eq!(result.unwrap_err(), FftnError::InvalidOutputSize); // Success case - let mut result = fft_device(&array, 4, 0, StreamOrDevice::cpu()); + let mut result = try_fft2(&array, None, None).unwrap(); result.eval(); - assert_eq!(result.dtype(), Dtype::Complex64); + assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), ]; assert_eq!(result.as_slice::(), &expected[..]); // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); } #[test] - fn test_fft2_device_unchecked() { + fn test_fft2() { use crate::{complex64, fft::*, Array, Dtype}; let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); let n = [2, 2]; let axes = [-2, -1]; - let mut result = fft2_device_unchecked(&array, &n[..], &axes[..], StreamOrDevice::cpu()); + let mut result = fft2(&array, Some(&n[..]), Some(&axes[..])); result.eval(); assert_eq!(result.dtype(), Dtype::Complex64); @@ -507,70 +586,95 @@ mod tests { } #[test] - fn test_try_fft2_device() { - use crate::{complex64, error::FftnError, fft::*, Array, StreamOrDevice}; + fn test_fftn_unchecked() { + use crate::{complex64, fft::*, Array, Dtype}; - let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let array = Array::ones::(&[3, 3]); + let mut result = unsafe { fftn_unchecked(&array, None, None) }; + result.eval(); + + assert_eq!(result.dtype(), Dtype::Complex64); + + let expected = &[ + complex64::new(9.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + assert_eq!(result.as_slice::(), &expected[..]); + + // test that previous array is not modified and valid + let data: &[f32] = array.as_slice(); + assert_eq!(data, &[1.0; 9]); + } + + #[test] + fn test_try_fftn() { + use crate::{complex64, error::FftnError, fft::*, Array}; + + let array = Array::ones::(&[3, 3, 3]); // Error case let scalar_array = Array::from_float(1.0); - let result = try_fft2_device(&scalar_array, None, None, StreamOrDevice::cpu()); - assert!(result.is_err()); + let result = try_fftn(&scalar_array, None, None); assert_eq!(result.unwrap_err(), FftnError::ScalarArray); - let result = try_fft2_device(&array, &[2, 2, 2][..], None, StreamOrDevice::cpu()); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 2 }); + let result = try_fftn(&array, &[3, 3, 3, 3][..], &[0, 1, 2, 3][..]); + assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 3 }); - let result = try_fft2_device(&array, &[2, 2][..], &[-1][..], StreamOrDevice::cpu()); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), FftnError::ShapeAxisMismatch); + let result = try_fftn(&array, &[3, 3, 3][..], &[-1][..]); + assert_eq!( + result.unwrap_err(), + FftnError::IncompatibleShapeAndAxes { + shape_size: 3, + axes_size: 1, + } + ); - let result = try_fft2_device(&array, None, &[-2, -2][..], StreamOrDevice::cpu()); - assert!(result.is_err()); + let result = try_fftn(&array, None, &[-2, -2][..]); assert_eq!(result.unwrap_err(), FftnError::DuplicateAxis { axis: -2 }); + let result = try_fftn(&array, &[-2, 2][..], None); + assert_eq!(result.unwrap_err(), FftnError::InvalidOutputSize); + // Success case - let mut result = try_fft2_device(&array, None, None, StreamOrDevice::cpu()).unwrap(); + let mut result = try_fftn(&array, None, None).unwrap(); result.eval(); assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - let expected = &[ - complex64::new(4.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - ]; + let mut expected = vec![complex64::new(0.0, 0.0); 27]; + expected[0] = complex64::new(27.0, 0.0); + assert_eq!(result.as_slice::(), &expected[..]); // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + assert_eq!(data, &[1.0; 27]); } #[test] - fn test_fft2_device() { + fn test_fftn() { use crate::{complex64, fft::*, Array, Dtype}; - let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let n = [2, 2]; - let axes = [-2, -1]; - let mut result = fft2_device(&array, Some(&n[..]), Some(&axes[..]), StreamOrDevice::cpu()); + let array = Array::ones::(&[3, 3, 3]); + let mut result = fftn(&array, None, None); result.eval(); assert_eq!(result.dtype(), Dtype::Complex64); - let expected = &[ - complex64::new(4.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - ]; + let mut expected = vec![complex64::new(0.0, 0.0); 27]; + expected[0] = complex64::new(27.0, 0.0); + assert_eq!(result.as_slice::(), &expected[..]); // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + assert_eq!(data, &[1.0; 27]); } } diff --git a/src/fft/ifftn.rs b/src/fft/ifftn.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/fft/ifftn.rs @@ -0,0 +1 @@ + diff --git a/src/fft/irfftn.rs b/src/fft/irfftn.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/fft/irfftn.rs @@ -0,0 +1 @@ + diff --git a/src/fft/mod.rs b/src/fft/mod.rs new file mode 100644 index 000000000..2f4560688 --- /dev/null +++ b/src/fft/mod.rs @@ -0,0 +1,97 @@ +mod fftn; +mod ifftn; +mod irfftn; +mod rfftn; + +use smallvec::SmallVec; + +use crate::{ + error::FftnError, + utils::{all_unique, resolve_index}, + Array, +}; + +pub use self::{fftn::*, ifftn::*, irfftn::*, rfftn::*}; + +#[inline] +fn try_resolve_size_and_axis( + a: &Array, + n: impl Into>, + axis: impl Into>, +) -> Result<(i32, i32), FftnError> { + if a.ndim() < 1 { + return Err(FftnError::ScalarArray); + } + + let axis = axis.into().unwrap_or(-1); + let axis_index = + resolve_index(axis, a.ndim()).ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; + let n = n.into().unwrap_or(a.shape()[axis_index]); + + Ok((n, axis)) +} + +// It's probably rare to perform fft on more than 4 axes +// TODO: check if this is a good default value +#[inline] +fn try_resolve_sizes_and_axes<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, +) -> Result<(SmallVec<[i32; 4]>, SmallVec<[i32; 4]>), FftnError> { + if a.ndim() < 1 { + return Err(FftnError::ScalarArray); + } + + let (valid_s, valid_axes) = match (s.into(), axes.into()) { + (Some(s), Some(axes)) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (Some(s), None) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + (None, Some(axes)) => { + // SmallVec somehow doesn't implement FromIterator with result + let mut valid_s = SmallVec::<[i32; 4]>::new(); + for &axis in axes { + let axis_index = resolve_index(axis, a.ndim()) + .ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; + valid_s.push(a.shape()[axis_index]); + } + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (None, None) => { + let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + }; + + // Check duplicate axes + all_unique(&valid_axes).map_err(|axis| FftnError::DuplicateAxis { axis })?; + + // Check if shape and axes have the same size + if valid_s.len() != valid_axes.len() { + return Err(FftnError::IncompatibleShapeAndAxes { + shape_size: valid_s.len(), + axes_size: valid_axes.len(), + }); + } + + // Check if more axes are provided than the array has + if valid_s.len() > a.ndim() { + return Err(FftnError::InvalidAxis { ndim: a.ndim() }); + } + + // Check if output sizes are valid + if valid_s.iter().any(|val| *val <= 0) { + return Err(FftnError::InvalidOutputSize); + } + + Ok((valid_s, valid_axes)) +} diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/fft/rfftn.rs @@ -0,0 +1 @@ + diff --git a/src/utils.rs b/src/utils.rs index adcc43885..0b688587b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -46,6 +46,17 @@ pub(crate) fn resolve_index(index: i32, len: usize) -> Option { } } +pub(crate) fn all_unique(arr: &[i32]) -> Result<(), i32> { + let mut unique = std::collections::HashSet::new(); + for &x in arr { + if !unique.insert(x) { + return Err(x); + } + } + + Ok(()) +} + /// Helper method to check if two arrays are broadcastable. /// /// Uses the same broadcasting rules as numpy. From fc776b6de82da5a258be0c3d28e6a263760578df Mon Sep 17 00:00:00 2001 From: minghuaw Date: Tue, 23 Apr 2024 01:48:51 -0700 Subject: [PATCH 13/25] check for negative output size in one dim fft --- src/fft/mod.rs | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/fft/mod.rs b/src/fft/mod.rs index 2f4560688..230666bfc 100644 --- a/src/fft/mod.rs +++ b/src/fft/mod.rs @@ -28,6 +28,10 @@ fn try_resolve_size_and_axis( resolve_index(axis, a.ndim()).ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; let n = n.into().unwrap_or(a.shape()[axis_index]); + if n <= 0 { + return Err(FftnError::InvalidOutputSize); + } + Ok((n, axis)) } @@ -95,3 +99,95 @@ fn try_resolve_sizes_and_axes<'a>( Ok((valid_s, valid_axes)) } + +#[cfg(test)] +mod try_resolve_size_and_axis_tests { + use crate::Array; + + use super::{try_resolve_size_and_axis, FftnError}; + + #[test] + fn scalar_array_returns_error() { + // Returns an error if the array is a scalar + let a = Array::from_float(1.0); + let result = try_resolve_size_and_axis(&a, 0, 0); + assert_eq!(result, Err(FftnError::ScalarArray)); + } + + #[test] + fn out_of_bound_axis_returns_error() { + // Returns an error if the axis is invalid (out of bounds) + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, 0, 1); + assert_eq!(result, Err(FftnError::InvalidAxis { ndim: 1 })); + } + + #[test] + fn negative_output_size_returns_error() { + // Returns an error if the output size is negative + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, -1, 0); + assert_eq!(result, Err(FftnError::InvalidOutputSize)); + } + + #[test] + fn valid_input_returns_sizes_and_axis() { + // Returns the output size and axis if the input is valid + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, 4, 0); + assert_eq!(result, Ok((4, 0))); + } +} + +#[cfg(test)] +mod try_resolve_sizes_and_axes_tests { + use crate::Array; + + use super::{try_resolve_sizes_and_axes, FftnError}; + + #[test] + fn scalar_array_returns_error() { + // Returns an error if the array is a scalar + let a = Array::from_float(1.0); + let result = try_resolve_sizes_and_axes(&a, None, None); + assert_eq!(result, Err(FftnError::ScalarArray)); + } + + #[test] + fn out_of_bound_axis_returns_error() { + // Returns an error if the axis is invalid (out of bounds) + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1, 2][..]); + assert_eq!(result, Err(FftnError::InvalidAxis { ndim: 2 })); + } + + #[test] + fn different_num_sizes_and_num_axes_returns_error() { + // Returns an error if the number of sizes and axes are different + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1][..]); + assert_eq!( + result, + Err(FftnError::IncompatibleShapeAndAxes { + shape_size: 3, + axes_size: 2 + }) + ); + } + + #[test] + fn duplicate_axes_returns_error() { + // Returns an error if there are duplicate axes + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2][..], &[0, 0][..]); + assert_eq!(result, Err(FftnError::DuplicateAxis { axis: 0 })); + } + + #[test] + fn negative_output_size_returns_error() { + // Returns an error if the output size is negative + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[-2, 2][..], None); + assert_eq!(result, Err(FftnError::InvalidOutputSize)); + } +} From 551ebbe7a320d00afa608fd05afd1a6a0eeadb63 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Tue, 23 Apr 2024 02:14:52 -0700 Subject: [PATCH 14/25] init impl of ifftn, rfftn, irfftn --- src/error.rs | 2 +- src/fft/fftn.rs | 97 +++++++++------------------------ src/fft/ifftn.rs | 136 ++++++++++++++++++++++++++++++++++++++++++++++ src/fft/irfftn.rs | 136 ++++++++++++++++++++++++++++++++++++++++++++++ src/fft/mod.rs | 96 ++++++++++++++++++++++++-------- src/fft/rfftn.rs | 136 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 509 insertions(+), 94 deletions(-) diff --git a/src/error.rs b/src/error.rs index d8e23f77c..1757e39dd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,7 +53,7 @@ pub enum AsSliceError { } #[derive(Error, Debug, PartialEq)] -pub enum FftnError { +pub enum FftError { #[error("fftn requires at least one dimension")] ScalarArray, diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 1e0196cec..686d41792 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -1,9 +1,8 @@ use mlx_macros::default_device; -use smallvec::SmallVec; -use crate::{ - array::Array, error::FftnError, stream::StreamOrDevice, utils::resolve_index_unchecked, -}; +use crate::{array::Array, error::FftError, stream::StreamOrDevice}; + +use super::resolve_size_and_axis_unchecked; /// One dimensional discrete Fourier Transform. /// @@ -45,11 +44,7 @@ pub unsafe fn fft_device_unchecked( axis: impl Into>, stream: StreamOrDevice, ) -> Array { - let axis = axis.into().unwrap_or(-1); - let n = n.into().unwrap_or_else(|| { - let axis_index = resolve_index_unchecked(axis, a.ndim()); - a.shape()[axis_index] - }); + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); unsafe { let c_array = mlx_sys::mlx_fft_fft(a.c_array, n, axis, stream.stream.c_stream); Array::from_ptr(c_array) @@ -90,9 +85,9 @@ pub fn try_fft_device( n: impl Into>, axis: impl Into>, stream: StreamOrDevice, -) -> Result { +) -> Result { let (n, axis) = super::try_resolve_size_and_axis(a, n, axis)?; - Ok(unsafe { fft_device_unchecked(a, Some(n), Some(axis), stream) }) + unsafe { Ok(fft_device_unchecked(a, Some(n), Some(axis), stream)) } } /// One dimensional discrete Fourier Transform. @@ -119,16 +114,16 @@ pub fn fft_device( try_fft_device(a, n, axis, stream).unwrap() } -fn fft2_device_inner(a: &Array, n: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { +fn fft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); let num_axes = axes.len(); - let num_n = n.len(); - let n_ptr = n.as_ptr(); + let s_ptr = s.as_ptr(); let axes_ptr = axes.as_ptr(); unsafe { let c_array = - mlx_sys::mlx_fft_fft2(a.c_array, n_ptr, num_n, axes_ptr, num_axes, stream.as_ptr()); + mlx_sys::mlx_fft_fft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); Array::from_ptr(c_array) } } @@ -171,18 +166,8 @@ pub unsafe fn fft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let mut valid_n = SmallVec::<[i32; 2]>::new(); - match s.into() { - Some(s) => valid_n.extend_from_slice(&s), - None => { - for axis in axes { - let axis_index = resolve_index_unchecked(*axis, a.ndim()); - valid_n.push(a.shape()[axis_index]); - } - } - } - - fft2_device_inner(a, &valid_n, axes, stream) + let (valid_s, valid_axes) = super::resolve_sizes_and_axes_unchecked(a, s, axes); + fft2_device_inner(a, &valid_s, &valid_axes, stream) } /// Two dimensional discrete Fourier Transform. @@ -217,7 +202,7 @@ pub fn try_fft2_device<'a>( s: impl Into>, axes: impl Into>, stream: StreamOrDevice, -) -> Result { +) -> Result { let valid_axes = axes.into().unwrap_or(&[-2, -1]); let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, valid_axes)?; Ok(fft2_device_inner(a, &valid_s, &valid_axes, stream)) @@ -302,35 +287,7 @@ pub unsafe fn fftn_device_unchecked<'a>( axes: impl Into>, stream: StreamOrDevice, ) -> Array { - let (valid_s, valid_axes) = match (s.into(), axes.into()) { - (Some(s), Some(axes)) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (Some(s), None) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - (None, Some(axes)) => { - let valid_s = axes - .iter() - .map(|&axis| { - let axis_index = resolve_index_unchecked(axis, a.ndim()); - a.shape()[axis_index] - }) - .collect(); - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (None, None) => { - let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - }; - + let (valid_s, valid_axes) = super::resolve_sizes_and_axes_unchecked(a, s, axes); fftn_device_inner(a, &valid_s, &valid_axes, stream) } @@ -368,7 +325,7 @@ pub fn try_fftn_device<'a>( s: impl Into>, axes: impl Into>, stream: StreamOrDevice, -) -> Result { +) -> Result { let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, axes)?; Ok(fftn_device_inner(a, &valid_s, &valid_axes, stream)) } @@ -514,32 +471,32 @@ mod tests { #[test] fn test_try_fft2() { - use crate::{complex64, error::FftnError, fft::*, Array}; + use crate::{complex64, error::FftError, fft::*, Array}; let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); // Error case let scalar_array = Array::from_float(1.0); let result = try_fft2(&scalar_array, None, None); - assert_eq!(result.unwrap_err(), FftnError::ScalarArray); + assert_eq!(result.unwrap_err(), FftError::ScalarArray); let result = try_fft2(&array, &[2, 2, 2][..], &[0, 1, 2][..]); - assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 2 }); + assert_eq!(result.unwrap_err(), FftError::InvalidAxis { ndim: 2 }); let result = try_fft2(&array, &[2, 2][..], &[-1][..]); assert_eq!( result.unwrap_err(), - FftnError::IncompatibleShapeAndAxes { + FftError::IncompatibleShapeAndAxes { shape_size: 2, axes_size: 1, } ); let result = try_fft2(&array, None, &[-2, -2][..]); - assert_eq!(result.unwrap_err(), FftnError::DuplicateAxis { axis: -2 }); + assert_eq!(result.unwrap_err(), FftError::DuplicateAxis { axis: -2 }); let result = try_fft2(&array, &[-2, 2][..], None); - assert_eq!(result.unwrap_err(), FftnError::InvalidOutputSize); + assert_eq!(result.unwrap_err(), FftError::InvalidOutputSize); // Success case let mut result = try_fft2(&array, None, None).unwrap(); @@ -615,32 +572,32 @@ mod tests { #[test] fn test_try_fftn() { - use crate::{complex64, error::FftnError, fft::*, Array}; + use crate::{complex64, error::FftError, fft::*, Array}; let array = Array::ones::(&[3, 3, 3]); // Error case let scalar_array = Array::from_float(1.0); let result = try_fftn(&scalar_array, None, None); - assert_eq!(result.unwrap_err(), FftnError::ScalarArray); + assert_eq!(result.unwrap_err(), FftError::ScalarArray); let result = try_fftn(&array, &[3, 3, 3, 3][..], &[0, 1, 2, 3][..]); - assert_eq!(result.unwrap_err(), FftnError::InvalidAxis { ndim: 3 }); + assert_eq!(result.unwrap_err(), FftError::InvalidAxis { ndim: 3 }); let result = try_fftn(&array, &[3, 3, 3][..], &[-1][..]); assert_eq!( result.unwrap_err(), - FftnError::IncompatibleShapeAndAxes { + FftError::IncompatibleShapeAndAxes { shape_size: 3, axes_size: 1, } ); let result = try_fftn(&array, None, &[-2, -2][..]); - assert_eq!(result.unwrap_err(), FftnError::DuplicateAxis { axis: -2 }); + assert_eq!(result.unwrap_err(), FftError::DuplicateAxis { axis: -2 }); let result = try_fftn(&array, &[-2, 2][..], None); - assert_eq!(result.unwrap_err(), FftnError::InvalidOutputSize); + assert_eq!(result.unwrap_err(), FftError::InvalidOutputSize); // Success case let mut result = try_fftn(&array, None, None).unwrap(); diff --git a/src/fft/ifftn.rs b/src/fft/ifftn.rs index 8b1378917..a533ab116 100644 --- a/src/fft/ifftn.rs +++ b/src/fft/ifftn.rs @@ -1 +1,137 @@ +use mlx_macros::default_device; +use crate::{error::FftError, Array, StreamOrDevice}; + +use super::{ + resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, + try_resolve_sizes_and_axes, +}; + +#[default_device(device = "cpu")] +pub unsafe fn ifft_device_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); + unsafe { + let c_array = mlx_sys::mlx_fft_ifft(a.c_array, n, axis, stream.stream.c_stream); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub fn try_ifft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; + unsafe { Ok(ifft_device_unchecked(a, n, axis, stream)) } +} + +#[default_device(device = "cpu")] +pub fn ifft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifft_device(a, n, axis, stream).unwrap() +} + +fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_ifft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn ifft2_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + ifft2_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_ifft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(ifft2_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn ifft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifft2_device(a, s, axes, stream).unwrap() +} + +fn ifftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_ifftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn ifftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + ifftn_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_ifftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(ifftn_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn ifftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifftn_device(a, s, axes, stream).unwrap() +} diff --git a/src/fft/irfftn.rs b/src/fft/irfftn.rs index 8b1378917..3d50b3c9a 100644 --- a/src/fft/irfftn.rs +++ b/src/fft/irfftn.rs @@ -1 +1,137 @@ +use mlx_macros::default_device; +use crate::{error::FftError, Array, StreamOrDevice}; + +use super::{ + resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, + try_resolve_sizes_and_axes, +}; + +#[default_device(device = "cpu")] +pub unsafe fn irfft_device_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); + unsafe { + let c_array = mlx_sys::mlx_fft_irfft(a.c_array, n, axis, stream.stream.c_stream); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub fn try_irfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; + unsafe { Ok(irfft_device_unchecked(a, n, axis, stream)) } +} + +#[default_device(device = "cpu")] +pub fn irfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfft_device(a, n, axis, stream).unwrap() +} + +fn irfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_irfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn irfft2_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + irfft2_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_irfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(irfft2_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn irfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfft2_device(a, s, axes, stream).unwrap() +} + +fn irfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_irfftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn irfftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + irfftn_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_irfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(irfftn_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn irfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfftn_device(a, s, axes, stream).unwrap() +} diff --git a/src/fft/mod.rs b/src/fft/mod.rs index 230666bfc..5f690f978 100644 --- a/src/fft/mod.rs +++ b/src/fft/mod.rs @@ -6,35 +6,85 @@ mod rfftn; use smallvec::SmallVec; use crate::{ - error::FftnError, - utils::{all_unique, resolve_index}, + error::FftError, + utils::{all_unique, resolve_index, resolve_index_unchecked}, Array, }; pub use self::{fftn::*, ifftn::*, irfftn::*, rfftn::*}; +#[inline] +fn resolve_size_and_axis_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, +) -> (i32, i32) { + let axis = axis.into().unwrap_or(-1); + let n = n.into().unwrap_or_else(|| { + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] + }); + (n, axis) +} + #[inline] fn try_resolve_size_and_axis( a: &Array, n: impl Into>, axis: impl Into>, -) -> Result<(i32, i32), FftnError> { +) -> Result<(i32, i32), FftError> { if a.ndim() < 1 { - return Err(FftnError::ScalarArray); + return Err(FftError::ScalarArray); } let axis = axis.into().unwrap_or(-1); let axis_index = - resolve_index(axis, a.ndim()).ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; + resolve_index(axis, a.ndim()).ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; let n = n.into().unwrap_or(a.shape()[axis_index]); if n <= 0 { - return Err(FftnError::InvalidOutputSize); + return Err(FftError::InvalidOutputSize); } Ok((n, axis)) } +#[inline] +fn resolve_sizes_and_axes_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, +) -> (SmallVec<[i32; 4]>, SmallVec<[i32; 4]>) { + match (s.into(), axes.into()) { + (Some(s), Some(axes)) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (Some(s), None) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + (None, Some(axes)) => { + let valid_s = axes + .iter() + .map(|&axis| { + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] + }) + .collect(); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (None, None) => { + let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + } +} + // It's probably rare to perform fft on more than 4 axes // TODO: check if this is a good default value #[inline] @@ -42,9 +92,9 @@ fn try_resolve_sizes_and_axes<'a>( a: &'a Array, s: impl Into>, axes: impl Into>, -) -> Result<(SmallVec<[i32; 4]>, SmallVec<[i32; 4]>), FftnError> { +) -> Result<(SmallVec<[i32; 4]>, SmallVec<[i32; 4]>), FftError> { if a.ndim() < 1 { - return Err(FftnError::ScalarArray); + return Err(FftError::ScalarArray); } let (valid_s, valid_axes) = match (s.into(), axes.into()) { @@ -63,7 +113,7 @@ fn try_resolve_sizes_and_axes<'a>( let mut valid_s = SmallVec::<[i32; 4]>::new(); for &axis in axes { let axis_index = resolve_index(axis, a.ndim()) - .ok_or_else(|| FftnError::InvalidAxis { ndim: a.ndim() })?; + .ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; valid_s.push(a.shape()[axis_index]); } let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); @@ -77,11 +127,11 @@ fn try_resolve_sizes_and_axes<'a>( }; // Check duplicate axes - all_unique(&valid_axes).map_err(|axis| FftnError::DuplicateAxis { axis })?; + all_unique(&valid_axes).map_err(|axis| FftError::DuplicateAxis { axis })?; // Check if shape and axes have the same size if valid_s.len() != valid_axes.len() { - return Err(FftnError::IncompatibleShapeAndAxes { + return Err(FftError::IncompatibleShapeAndAxes { shape_size: valid_s.len(), axes_size: valid_axes.len(), }); @@ -89,12 +139,12 @@ fn try_resolve_sizes_and_axes<'a>( // Check if more axes are provided than the array has if valid_s.len() > a.ndim() { - return Err(FftnError::InvalidAxis { ndim: a.ndim() }); + return Err(FftError::InvalidAxis { ndim: a.ndim() }); } // Check if output sizes are valid if valid_s.iter().any(|val| *val <= 0) { - return Err(FftnError::InvalidOutputSize); + return Err(FftError::InvalidOutputSize); } Ok((valid_s, valid_axes)) @@ -104,14 +154,14 @@ fn try_resolve_sizes_and_axes<'a>( mod try_resolve_size_and_axis_tests { use crate::Array; - use super::{try_resolve_size_and_axis, FftnError}; + use super::{try_resolve_size_and_axis, FftError}; #[test] fn scalar_array_returns_error() { // Returns an error if the array is a scalar let a = Array::from_float(1.0); let result = try_resolve_size_and_axis(&a, 0, 0); - assert_eq!(result, Err(FftnError::ScalarArray)); + assert_eq!(result, Err(FftError::ScalarArray)); } #[test] @@ -119,7 +169,7 @@ mod try_resolve_size_and_axis_tests { // Returns an error if the axis is invalid (out of bounds) let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); let result = try_resolve_size_and_axis(&a, 0, 1); - assert_eq!(result, Err(FftnError::InvalidAxis { ndim: 1 })); + assert_eq!(result, Err(FftError::InvalidAxis { ndim: 1 })); } #[test] @@ -127,7 +177,7 @@ mod try_resolve_size_and_axis_tests { // Returns an error if the output size is negative let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); let result = try_resolve_size_and_axis(&a, -1, 0); - assert_eq!(result, Err(FftnError::InvalidOutputSize)); + assert_eq!(result, Err(FftError::InvalidOutputSize)); } #[test] @@ -143,14 +193,14 @@ mod try_resolve_size_and_axis_tests { mod try_resolve_sizes_and_axes_tests { use crate::Array; - use super::{try_resolve_sizes_and_axes, FftnError}; + use super::{try_resolve_sizes_and_axes, FftError}; #[test] fn scalar_array_returns_error() { // Returns an error if the array is a scalar let a = Array::from_float(1.0); let result = try_resolve_sizes_and_axes(&a, None, None); - assert_eq!(result, Err(FftnError::ScalarArray)); + assert_eq!(result, Err(FftError::ScalarArray)); } #[test] @@ -158,7 +208,7 @@ mod try_resolve_sizes_and_axes_tests { // Returns an error if the axis is invalid (out of bounds) let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1, 2][..]); - assert_eq!(result, Err(FftnError::InvalidAxis { ndim: 2 })); + assert_eq!(result, Err(FftError::InvalidAxis { ndim: 2 })); } #[test] @@ -168,7 +218,7 @@ mod try_resolve_sizes_and_axes_tests { let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1][..]); assert_eq!( result, - Err(FftnError::IncompatibleShapeAndAxes { + Err(FftError::IncompatibleShapeAndAxes { shape_size: 3, axes_size: 2 }) @@ -180,7 +230,7 @@ mod try_resolve_sizes_and_axes_tests { // Returns an error if there are duplicate axes let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); let result = try_resolve_sizes_and_axes(&a, &[2, 2][..], &[0, 0][..]); - assert_eq!(result, Err(FftnError::DuplicateAxis { axis: 0 })); + assert_eq!(result, Err(FftError::DuplicateAxis { axis: 0 })); } #[test] @@ -188,6 +238,6 @@ mod try_resolve_sizes_and_axes_tests { // Returns an error if the output size is negative let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); let result = try_resolve_sizes_and_axes(&a, &[-2, 2][..], None); - assert_eq!(result, Err(FftnError::InvalidOutputSize)); + assert_eq!(result, Err(FftError::InvalidOutputSize)); } } diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs index 8b1378917..3a96ad98f 100644 --- a/src/fft/rfftn.rs +++ b/src/fft/rfftn.rs @@ -1 +1,137 @@ +use mlx_macros::default_device; +use crate::{error::FftError, Array, StreamOrDevice}; + +use super::{ + resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, + try_resolve_sizes_and_axes, +}; + +#[default_device(device = "cpu")] +pub unsafe fn rfft_device_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); + unsafe { + let c_array = mlx_sys::mlx_fft_rfft(a.c_array, n, axis, stream.stream.c_stream); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub fn try_rfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; + unsafe { Ok(rfft_device_unchecked(a, n, axis, stream)) } +} + +#[default_device(device = "cpu")] +pub fn rfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_rfft_device(a, n, axis, stream).unwrap() +} + +fn rfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_rfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn rfft2_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + rfft2_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_rfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(rfft2_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn rfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_rfft2_device(a, s, axes, stream).unwrap() +} + +fn rfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_rfftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn rfftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + rfftn_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_rfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(rfftn_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn rfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_rfftn_device(a, s, axes, stream).unwrap() +} From ae17488c85d2330e7043a3f310d18a97be29a1d0 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Tue, 23 Apr 2024 11:23:25 -0700 Subject: [PATCH 15/25] added docs --- src/fft/fftn.rs | 155 ++++--------------------------------------- src/fft/ifftn.rs | 94 ++++++++++++++++++++++++++ src/fft/mod.rs | 169 +++++++++++++++++++++++++++++++++++++++++++++++ src/fft/rfftn.rs | 120 +++++++++++++++++++++++++++++++++ 4 files changed, 397 insertions(+), 141 deletions(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 686d41792..1544cfab1 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -12,32 +12,7 @@ use super::resolve_size_and_axis_unchecked; /// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded /// with zeros to match `n`. The default value is `a.shape[axis]`. /// - `axis`: Axis along which to perform the FFT. The default is -1. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); -/// let s = StreamOrDevice::cpu(); -/// let mut result = unsafe { fft_device_unchecked(&array, 4, 0, s) }; -/// result.eval(); -/// -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// -/// let expected = &[ -/// complex64::new(10.0, 0.0), -/// complex64::new(-2.0, 2.0), -/// complex64::new(-2.0, 0.0), -/// complex64::new(-2.0, -2.0), -/// ]; -/// assert_eq!(result.as_slice::(), &expected[..]); -/// -/// // test that previous array is not modified and valid -/// let data: &[f32] = array.as_slice(); -/// assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub unsafe fn fft_device_unchecked( a: &Array, n: impl Into>, @@ -59,27 +34,7 @@ pub unsafe fn fft_device_unchecked( /// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded /// with zeros to match `n`. The default value is `a.shape[axis]`. /// - `axis`: Axis along which to perform the FFT. The default is -1. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); -/// let mut result = try_fft_device(&array, 4, 0, StreamOrDevice::cpu()).unwrap(); -/// result.eval(); -/// -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// -/// let expected = &[ -/// complex64::new(10.0, 0.0), -/// complex64::new(-2.0, 2.0), -/// complex64::new(-2.0, 0.0), -/// complex64::new(-2.0, -2.0), -/// ]; -/// assert_eq!(result.as_slice::(), &expected[..]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub fn try_fft_device( a: &Array, n: impl Into>, @@ -104,7 +59,7 @@ pub fn try_fft_device( /// Panics if the input array is a scalar or if the axis is invalid. /// /// See [`try_fft_device`] for more details. -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub fn fft_device( a: &Array, n: impl Into>, @@ -136,29 +91,7 @@ fn fft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded /// with zeros to match `n`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); -/// let mut result = unsafe { -/// fft2_device_unchecked(&array, &[2, 2][..], &[-2,-1][..], StreamOrDevice::cpu()) -/// }; -/// result.eval(); -/// -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// -/// let expected = &[ -/// complex64::new(4.0, 0.0), -/// complex64::new(0.0, 0.0), -/// complex64::new(0.0, 0.0), -/// complex64::new(0.0, 0.0), -/// ]; -/// assert_eq!(result.as_slice::(), &expected[..]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub unsafe fn fft2_device_unchecked<'a>( a: &'a Array, s: impl Into>, @@ -178,25 +111,7 @@ pub unsafe fn fft2_device_unchecked<'a>( /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded /// with zeros to match `n`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); -/// let mut result = try_fft2_device(&array, None, None, StreamOrDevice::cpu()).unwrap(); -/// result.eval(); -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// let expected = &[ -/// complex64::new(4.0, 0.0), -/// complex64::new(0.0, 0.0), -/// complex64::new(0.0, 0.0), -/// complex64::new(0.0, 0.0), -/// ]; -/// assert_eq!(result.as_slice::(), &expected[..]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub fn try_fft2_device<'a>( a: &'a Array, s: impl Into>, @@ -219,14 +134,8 @@ pub fn try_fft2_device<'a>( /// /// # Panic /// -/// - if the input array is a scalar array -/// - if the shape and axes have different sizes -/// - if more axes are provided than the array has -/// - if the output sizes are invalid (<= 0) -/// - if the axes are not unique -/// -/// See [`try_fft2_device`] for more details. -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +/// Panics if the input arguments are invalid. See [`try_fft2_device`] for more details. +#[default_device(device = "cpu")] pub fn fft2_device<'a>( a: &'a Array, s: impl Into>, @@ -252,7 +161,7 @@ fn fftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) } } -/// N-dimensional discrete Fourier Transform. +/// n-dimensional discrete Fourier Transform. /// /// # Params /// @@ -261,26 +170,8 @@ fn fftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) /// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` /// if not specified. /// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is -/// over the last `len(s)` axes are or all axes if `s` is also None. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::ones::(&[3, 3, 3]); -/// -/// let mut result = unsafe { fftn_device_unchecked(&array, None, None, StreamOrDevice::cpu()) }; -/// result.eval(); -/// -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// -/// let mut expected = vec![complex64::new(0.0, 0.0); 27]; -/// expected[0] = complex64::new(27.0, 0.0); -/// -/// assert_eq!(result.as_slice::(), &expected[..]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +#[default_device(device = "cpu")] pub unsafe fn fftn_device_unchecked<'a>( a: &'a Array, s: impl Into>, @@ -291,7 +182,7 @@ pub unsafe fn fftn_device_unchecked<'a>( fftn_device_inner(a, &valid_s, &valid_axes, stream) } -/// N-dimensional discrete Fourier Transform. +/// n-dimensional discrete Fourier Transform. /// /// # Params /// @@ -301,25 +192,7 @@ pub unsafe fn fftn_device_unchecked<'a>( /// if not specified. /// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is /// over the last `len(s)` axes are or all axes if `s` is also `None`. -/// -/// # Example -/// -/// ```rust -/// use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; -/// -/// let array = Array::ones::(&[3, 3, 3]); -/// -/// let mut result = try_fftn(&array, None, None).unwrap(); -/// result.eval(); -/// -/// assert_eq!(result.dtype(), Dtype::Complex64); -/// -/// let mut expected = vec![complex64::new(0.0, 0.0); 27]; -/// expected[0] = complex64::new(27.0, 0.0); -/// -/// assert_eq!(result.as_slice::(), &expected[..]); -/// ``` -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub fn try_fftn_device<'a>( a: &'a Array, s: impl Into>, @@ -330,7 +203,7 @@ pub fn try_fftn_device<'a>( Ok(fftn_device_inner(a, &valid_s, &valid_axes, stream)) } -/// N-dimensional discrete Fourier Transform. +/// n-dimensional discrete Fourier Transform. /// /// # Params /// @@ -350,7 +223,7 @@ pub fn try_fftn_device<'a>( /// - if more axes are provided than the array has /// /// See [`try_fftn_device`] for more details. -#[default_device(device = "cpu")] // fft is not implemented on GPU yet +#[default_device(device = "cpu")] pub fn fftn_device<'a>( a: &'a Array, s: impl Into>, diff --git a/src/fft/ifftn.rs b/src/fft/ifftn.rs index a533ab116..1ee2265e4 100644 --- a/src/fft/ifftn.rs +++ b/src/fft/ifftn.rs @@ -7,6 +7,14 @@ use super::{ try_resolve_sizes_and_axes, }; +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. #[default_device(device = "cpu")] pub unsafe fn ifft_device_unchecked( a: &Array, @@ -21,6 +29,14 @@ pub unsafe fn ifft_device_unchecked( } } +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. #[default_device(device = "cpu")] pub fn try_ifft_device( a: &Array, @@ -32,6 +48,14 @@ pub fn try_ifft_device( unsafe { Ok(ifft_device_unchecked(a, n, axis, stream)) } } +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. #[default_device(device = "cpu")] pub fn ifft_device( a: &Array, @@ -42,6 +66,14 @@ pub fn ifft_device( try_ifft_device(a, n, axis, stream).unwrap() } +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { let num_s = s.len(); let num_axes = axes.len(); @@ -56,6 +88,14 @@ fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice } } +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub unsafe fn ifft2_device_unchecked<'a>( a: &'a Array, @@ -68,6 +108,14 @@ pub unsafe fn ifft2_device_unchecked<'a>( ifft2_device_inner(a, &valid_s, &valid_axes, stream) } +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn try_ifft2_device<'a>( a: &'a Array, @@ -80,6 +128,18 @@ pub fn try_ifft2_device<'a>( Ok(ifft2_device_inner(a, &valid_s, &valid_axes, stream)) } +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_ifft2_device] for more details. #[default_device(device = "cpu")] pub fn ifft2_device<'a>( a: &'a Array, @@ -104,6 +164,16 @@ fn ifftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice } } +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub unsafe fn ifftn_device_unchecked<'a>( a: &'a Array, @@ -115,6 +185,16 @@ pub unsafe fn ifftn_device_unchecked<'a>( ifftn_device_inner(a, &valid_s, &valid_axes, stream) } +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub fn try_ifftn_device<'a>( a: &'a Array, @@ -126,6 +206,20 @@ pub fn try_ifftn_device<'a>( Ok(ifftn_device_inner(a, &valid_s, &valid_axes, stream)) } +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_ifftn_device] for more details. #[default_device(device = "cpu")] pub fn ifftn_device<'a>( a: &'a Array, diff --git a/src/fft/mod.rs b/src/fft/mod.rs index 5f690f978..02b484e72 100644 --- a/src/fft/mod.rs +++ b/src/fft/mod.rs @@ -1,3 +1,172 @@ +//! Fast Fourier Transform (FFT) and its inverse (IFFT) for one, two, and `N` dimensions. +//! +//! Like all other functions in `mlx-rs`, three variants are provided for each FFT function, plus +//! each variant has a version that uses the default `StreamOrDevice` or takes a user-specified +//! `StreamOrDevice`. +//! +//! The difference are explained below using `fftn` as an example: +//! +//! 1. `fftn_unchecked`/`fftn_device_unchecked`: This function is simply a wrapper around the C API +//! and does not perform any checks on the input. It may panic or get an fatal error that cannot +//! be caught by the rust runtime if the input is invalid. +//! 2. `try_fftn`/`try_fftn_device`: This function performs checks on the input and returns a +//! `Result` instead of panicking. +//! 3. `fftn`/`fftn_device`: This function is a wrapper around `try_fftn` and unwraps the result. It +//! panics if the input is invalid. +//! +//! The functions that contains `device` in their name are meant to be used with a user-specified +//! `StreamOrDevice`. If you don't care about the stream, you can use the functions without `device` +//! in their names. Please note that GPU device support is not yet implemented. +//! +//! # Examples +//! +//! ## One dimension +//! +//! ```rust +//! use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +//! +//! let src = [1.0f32, 2.0, 3.0, 4.0]; +//! let array = Array::from_slice(&src[..], &[4]); +//! +//! let mut fft_result = fft(&array, 4, 0); +//! fft_result.eval(); +//! assert_eq!(fft_result.dtype(), Dtype::Complex64); +//! +//! let expected = &[ +//! complex64::new(10.0, 0.0), +//! complex64::new(-2.0, 2.0), +//! complex64::new(-2.0, 0.0), +//! complex64::new(-2.0, -2.0), +//! ]; +//! assert_eq!(fft_result.as_slice::(), &expected[..]); +//! +//! let mut ifft_result = ifft(&fft_result, 4, 0); +//! ifft_result.eval(); +//! assert_eq!(ifft_result.dtype(), Dtype::Complex64); +//! +//! let expected = &[ +//! complex64::new(1.0, 0.0), +//! complex64::new(2.0, 0.0), +//! complex64::new(3.0, 0.0), +//! complex64::new(4.0, 0.0), +//! ]; +//! assert_eq!(ifft_result.as_slice::(), &expected[..]); +//! +//! let mut rfft_result = rfft(&array, 4, 0); +//! rfft_result.eval(); +//! assert_eq!(rfft_result.dtype(), Dtype::Complex64); +//! +//! let expected = &[ +//! complex64::new(10.0, 0.0), +//! complex64::new(-2.0, 2.0), +//! complex64::new(-2.0, 0.0), +//! ]; +//! assert_eq!(rfft_result.as_slice::(), &expected[..]); +//! +//! let mut irfft_result = irfft(&rfft_result, 4, 0); +//! irfft_result.eval(); +//! assert_eq!(irfft_result.dtype(), Dtype::Float32); +//! assert_eq!(irfft_result.as_slice::(), &src[..]); +//! +//! // The original array is not modified +//! let data: &[f32] = array.as_slice(); +//! assert_eq!(data, &src[..]); +//! ``` +//! +//! ## Two dimensions +//! +//! ```rust +//! use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +//! +//! let src = [1.0f32, 1.0, 1.0, 1.0]; +//! let array = Array::from_slice(&src[..], &[2, 2]); +//! +//! let mut fft2_result = fft2(&array, None, None); +//! fft2_result.eval(); +//! assert_eq!(fft2_result.dtype(), Dtype::Complex64); +//! let expected = &[ +//! complex64::new(4.0, 0.0), +//! complex64::new(0.0, 0.0), +//! complex64::new(0.0, 0.0), +//! complex64::new(0.0, 0.0), +//! ]; +//! assert_eq!(fft2_result.as_slice::(), &expected[..]); +//! +//! let mut ifft2_result = ifft2(&fft2_result, None, None); +//! ifft2_result.eval(); +//! assert_eq!(ifft2_result.dtype(), Dtype::Complex64); +//! +//! let expected = &[ +//! complex64::new(1.0, 0.0), +//! complex64::new(1.0, 0.0), +//! complex64::new(1.0, 0.0), +//! complex64::new(1.0, 0.0), +//! ]; +//! assert_eq!(ifft2_result.as_slice::(), &expected[..]); +//! +//! let mut rfft2_result = rfft2(&array, None, None); +//! rfft2_result.eval(); +//! assert_eq!(rfft2_result.dtype(), Dtype::Complex64); +//! +//! let expected = &[ +//! complex64::new(4.0, 0.0), +//! complex64::new(0.0, 0.0), +//! complex64::new(0.0, 0.0), +//! complex64::new(0.0, 0.0), +//! ]; +//! assert_eq!(rfft2_result.as_slice::(), &expected[..]); +//! +//! let mut irfft2_result = irfft2(&rfft2_result, None, None); +//! irfft2_result.eval(); +//! assert_eq!(irfft2_result.dtype(), Dtype::Float32); +//! assert_eq!(irfft2_result.as_slice::(), &src[..]); +//! +//! // The original array is not modified +//! let data: &[f32] = array.as_slice(); +//! assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); +//! ``` +//! +//! ## `N` dimensions +//! +//! ```rust +//! use mlx::{Dtype, Array, StreamOrDevice, complex64, fft::*}; +//! +//! let array = Array::ones::(&[2, 2, 2]); +//! let mut fftn_result = fftn(&array, None, None); +//! fftn_result.eval(); +//! assert_eq!(fftn_result.dtype(), Dtype::Complex64); +//! +//! let mut expected = [complex64::new(0.0, 0.0); 8]; +//! expected[0] = complex64::new(8.0, 0.0); +//! assert_eq!(fftn_result.as_slice::(), &expected[..]); +//! +//! let mut ifftn_result = ifftn(&fftn_result, None, None); +//! ifftn_result.eval(); +//! assert_eq!(ifftn_result.dtype(), Dtype::Complex64); +//! +//! let expected = [complex64::new(1.0, 0.0); 8]; +//! assert_eq!(ifftn_result.as_slice::(), &expected[..]); +//! +//! let mut rfftn_result = rfftn(&array, None, None); +//! rfftn_result.eval(); +//! assert_eq!(rfftn_result.dtype(), Dtype::Complex64); +//! +//! let mut expected = [complex64::new(0.0, 0.0); 8]; +//! expected[0] = complex64::new(8.0, 0.0); +//! assert_eq!(rfftn_result.as_slice::(), &expected[..]); +//! +//! let mut irfftn_result = irfftn(&rfftn_result, None, None); +//! irfftn_result.eval(); +//! assert_eq!(irfftn_result.dtype(), Dtype::Float32); +//! +//! let expected = [1.0; 8]; +//! assert_eq!(irfftn_result.as_slice::(), &expected[..]); +//! +//! // The original array is not modified +//! let data: &[f32] = array.as_slice(); +//! assert_eq!(data, &[1.0; 8]); +//! ``` + mod fftn; mod ifftn; mod irfftn; diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs index 3a96ad98f..6bc56482e 100644 --- a/src/fft/rfftn.rs +++ b/src/fft/rfftn.rs @@ -7,6 +7,17 @@ use super::{ try_resolve_sizes_and_axes, }; +/// One dimensional discrete Fourier Transform on a real input. +/// +/// The output has the same shape as the input except along `axis` in which case it has size `n // 2 +/// + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. #[default_device(device = "cpu")] pub unsafe fn rfft_device_unchecked( a: &Array, @@ -21,6 +32,17 @@ pub unsafe fn rfft_device_unchecked( } } +/// One dimensional discrete Fourier Transform on a real input. +/// +/// The output has the same shape as the input except along `axis` in which case it has size `n // 2 +/// + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. #[default_device(device = "cpu")] pub fn try_rfft_device( a: &Array, @@ -32,6 +54,21 @@ pub fn try_rfft_device( unsafe { Ok(rfft_device_unchecked(a, n, axis, stream)) } } +/// One dimensional discrete Fourier Transform on a real input. +/// +/// The output has the same shape as the input except along `axis` in which case it has size `n // 2 +/// + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_rfft_device()] for more information. #[default_device(device = "cpu")] pub fn rfft_device( a: &Array, @@ -56,6 +93,18 @@ fn rfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice } } +/// Two dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub unsafe fn rfft2_device_unchecked<'a>( a: &'a Array, @@ -68,6 +117,18 @@ pub unsafe fn rfft2_device_unchecked<'a>( rfft2_device_inner(a, &valid_s, &valid_axes, stream) } +/// Two dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn try_rfft2_device<'a>( a: &'a Array, @@ -80,6 +141,22 @@ pub fn try_rfft2_device<'a>( Ok(rfft2_device_inner(a, &valid_s, &valid_axes, stream)) } +/// Two dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_rfft2_device()] for more information. #[default_device(device = "cpu")] pub fn rfft2_device<'a>( a: &'a Array, @@ -104,6 +181,19 @@ fn rfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice } } +/// n-dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is over +/// the last `len(s)` axes or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub unsafe fn rfftn_device_unchecked<'a>( a: &'a Array, @@ -115,6 +205,19 @@ pub unsafe fn rfftn_device_unchecked<'a>( rfftn_device_inner(a, &valid_s, &valid_axes, stream) } +/// n-dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is over +/// the last `len(s)` axes or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub fn try_rfftn_device<'a>( a: &'a Array, @@ -126,6 +229,23 @@ pub fn try_rfftn_device<'a>( Ok(rfftn_device_inner(a, &valid_s, &valid_axes, stream)) } +/// n-dimensional real discrete Fourier Transform. +/// +/// The output has the same shape as the input except along the dimensions in `axes` in which case +/// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size +/// `s[-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. If the array is complex it will be silently cast to a real type. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match `s`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is over +/// the last `len(s)` axes or all axes if `s` is also `None`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_rfftn_device] for more information. #[default_device(device = "cpu")] pub fn rfftn_device<'a>( a: &'a Array, From 71dc6ee04a7f47509c902b572ef4f2dd5c4190c8 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Tue, 23 Apr 2024 12:31:21 -0700 Subject: [PATCH 16/25] re-organize fft mod and unit tests --- src/fft/fftn.rs | 496 ++++++++++++++++++++++++++-------------------- src/fft/ifftn.rs | 231 --------------------- src/fft/irfftn.rs | 137 ------------- src/fft/mod.rs | 4 +- src/fft/rfftn.rs | 208 +++++++++++++++++++ 5 files changed, 489 insertions(+), 587 deletions(-) delete mode 100644 src/fft/ifftn.rs delete mode 100644 src/fft/irfftn.rs diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 1544cfab1..8ca9580af 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -2,7 +2,10 @@ use mlx_macros::default_device; use crate::{array::Array, error::FftError, stream::StreamOrDevice}; -use super::resolve_size_and_axis_unchecked; +use super::{ + resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, + try_resolve_sizes_and_axes, +}; /// One dimensional discrete Fourier Transform. /// @@ -41,7 +44,7 @@ pub fn try_fft_device( axis: impl Into>, stream: StreamOrDevice, ) -> Result { - let (n, axis) = super::try_resolve_size_and_axis(a, n, axis)?; + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; unsafe { Ok(fft_device_unchecked(a, Some(n), Some(axis), stream)) } } @@ -99,7 +102,7 @@ pub unsafe fn fft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = super::resolve_sizes_and_axes_unchecked(a, s, axes); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); fft2_device_inner(a, &valid_s, &valid_axes, stream) } @@ -119,7 +122,7 @@ pub fn try_fft2_device<'a>( stream: StreamOrDevice, ) -> Result { let valid_axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, valid_axes)?; + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, valid_axes)?; Ok(fft2_device_inner(a, &valid_s, &valid_axes, stream)) } @@ -178,7 +181,7 @@ pub unsafe fn fftn_device_unchecked<'a>( axes: impl Into>, stream: StreamOrDevice, ) -> Array { - let (valid_s, valid_axes) = super::resolve_sizes_and_axes_unchecked(a, s, axes); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); fftn_device_inner(a, &valid_s, &valid_axes, stream) } @@ -199,7 +202,7 @@ pub fn try_fftn_device<'a>( axes: impl Into>, stream: StreamOrDevice, ) -> Result { - let (valid_s, valid_axes) = super::try_resolve_sizes_and_axes(a, s, axes)?; + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; Ok(fftn_device_inner(a, &valid_s, &valid_axes, stream)) } @@ -233,201 +236,309 @@ pub fn fftn_device<'a>( try_fftn_device(a, s, axes, stream).unwrap() } -// TODO: test out of bound indexing -#[cfg(test)] -mod tests { - #[test] - fn test_fft_unchecked() { - use crate::{complex64, fft::*, Array, Dtype}; - - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); - let mut result = unsafe { fft_unchecked(&array, 4, 0) }; - result.eval(); +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. +#[default_device(device = "cpu")] +pub unsafe fn ifft_device_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); + unsafe { + let c_array = mlx_sys::mlx_fft_ifft(a.c_array, n, axis, stream.stream.c_stream); + Array::from_ptr(c_array) + } +} - assert_eq!(result.dtype(), Dtype::Complex64); +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. +#[default_device(device = "cpu")] +pub fn try_ifft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; + unsafe { Ok(ifft_device_unchecked(a, n, axis, stream)) } +} - let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); +/// One dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: Input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. +/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. +#[default_device(device = "cpu")] +pub fn ifft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifft_device(a, n, axis, stream).unwrap() +} - // The original array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - } +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); - #[test] - fn test_try_fft() { - use crate::{complex64, fft::*, Array, Dtype}; + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); + unsafe { + let c_array = + mlx_sys::mlx_fft_ifft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} - // Error case - let scalar_array = Array::from_float(1.0); - let result = try_fft(&scalar_array, 0, 0); - assert!(result.is_err()); +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +#[default_device(device = "cpu")] +pub unsafe fn ifft2_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + ifft2_device_inner(a, &valid_s, &valid_axes, stream) +} - let result = try_fft(&array, 4, 2); - assert!(result.is_err()); +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +#[default_device(device = "cpu")] +pub fn try_ifft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(ifft2_device_inner(a, &valid_s, &valid_axes, stream)) +} - // Success case - let mut result = try_fft(&array, 4, 0).unwrap(); - result.eval(); +/// Two dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded +/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_ifft2_device] for more details. +#[default_device(device = "cpu")] +pub fn ifft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifft2_device(a, s, axes, stream).unwrap() +} - assert_eq!(result.dtype(), Dtype::Complex64); +fn ifftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); - let expected = &[ - complex64::new(10.0, 0.0), - complex64::new(-2.0, 2.0), - complex64::new(-2.0, 0.0), - complex64::new(-2.0, -2.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); + unsafe { + let c_array = + mlx_sys::mlx_fft_ifftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) } +} - #[test] - fn test_fft() { - use crate::{complex64, fft::*, Array, Dtype}; +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +#[default_device(device = "cpu")] +pub unsafe fn ifftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + ifftn_device_inner(a, &valid_s, &valid_axes, stream) +} - let array = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]); +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +#[default_device(device = "cpu")] +pub fn try_ifftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(ifftn_device_inner(a, &valid_s, &valid_axes, stream)) +} - // Success case - let mut result = fft(&array, 4, 0); - result.eval(); +/// n-dimensional inverse discrete Fourier Transform. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` +/// if not specified. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes are or all axes if `s` is also `None`. +/// +/// # Panic +/// +/// Panics if the input arguments are invalid. See [try_ifftn_device] for more details. +#[default_device(device = "cpu")] +pub fn ifftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_ifftn_device(a, s, axes, stream).unwrap() +} - assert_eq!(result.dtype(), Dtype::Complex64); +#[cfg(test)] +mod tests { + use crate::{complex64, fft::*, Array, Dtype}; - let expected = &[ + #[test] + fn test_fft() { + const FFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0]; + const FFT_SHAPE: &[i32] = &[4]; + const FFT_EXPECTED: &[complex64; 4] = &[ complex64::new(10.0, 0.0), complex64::new(-2.0, 2.0), complex64::new(-2.0, 0.0), complex64::new(-2.0, -2.0), ]; - assert_eq!(result.as_slice::(), &expected[..]); - - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]); - } - #[test] - fn test_fft2_unchecked() { - use crate::{complex64, fft::*, Array, Dtype}; + let array = Array::from_slice(FFT_DATA, FFT_SHAPE); + let mut fft = fft(&array, None, None); + fft.eval(); - let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let n = [2, 2]; - let axes = [-2, -1]; - let mut result = unsafe { fft2_unchecked(&array, &n[..], &axes[..]) }; - result.eval(); + assert_eq!(fft.dtype(), Dtype::Complex64); + assert_eq!(fft.as_slice::(), FFT_EXPECTED); - assert_eq!(result.dtype(), Dtype::Complex64); + let mut ifft = ifft(&fft, None, None); + ifft.eval(); - let expected = &[ - complex64::new(4.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); + assert_eq!(ifft.dtype(), Dtype::Complex64); + assert_eq!( + ifft.as_slice::(), + FFT_DATA + .iter() + .map(|&x| complex64::new(x, 0.0)) + .collect::>() + ); - // test that previous array is not modified and valid + // The original array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + assert_eq!(data, FFT_DATA); } #[test] - fn test_try_fft2() { - use crate::{complex64, error::FftError, fft::*, Array}; - - let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - - // Error case - let scalar_array = Array::from_float(1.0); - let result = try_fft2(&scalar_array, None, None); - assert_eq!(result.unwrap_err(), FftError::ScalarArray); - - let result = try_fft2(&array, &[2, 2, 2][..], &[0, 1, 2][..]); - assert_eq!(result.unwrap_err(), FftError::InvalidAxis { ndim: 2 }); - - let result = try_fft2(&array, &[2, 2][..], &[-1][..]); - assert_eq!( - result.unwrap_err(), - FftError::IncompatibleShapeAndAxes { - shape_size: 2, - axes_size: 1, - } - ); - - let result = try_fft2(&array, None, &[-2, -2][..]); - assert_eq!(result.unwrap_err(), FftError::DuplicateAxis { axis: -2 }); - - let result = try_fft2(&array, &[-2, 2][..], None); - assert_eq!(result.unwrap_err(), FftError::InvalidOutputSize); - - // Success case - let mut result = try_fft2(&array, None, None).unwrap(); - result.eval(); - - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - - let expected = &[ + fn test_fft2() { + const FFT2_DATA: &[f32] = &[1.0, 1.0, 1.0, 1.0]; + const FFT2_SHAPE: &[i32] = &[2, 2]; + const FFT2_EXPECTED: &[complex64; 4] = &[ complex64::new(4.0, 0.0), complex64::new(0.0, 0.0), complex64::new(0.0, 0.0), complex64::new(0.0, 0.0), ]; - assert_eq!(result.as_slice::(), &expected[..]); - - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); - } - #[test] - fn test_fft2() { - use crate::{complex64, fft::*, Array, Dtype}; + let array = Array::from_slice(FFT2_DATA, FFT2_SHAPE); + let mut fft2 = fft2(&array, None, None); + fft2.eval(); - let array = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let n = [2, 2]; - let axes = [-2, -1]; - let mut result = fft2(&array, Some(&n[..]), Some(&axes[..])); - result.eval(); + assert_eq!(fft2.dtype(), Dtype::Complex64); + assert_eq!(fft2.as_slice::(), FFT2_EXPECTED); - assert_eq!(result.dtype(), Dtype::Complex64); + let mut ifft2 = ifft2(&fft2, None, None); + ifft2.eval(); - let expected = &[ - complex64::new(4.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - complex64::new(0.0, 0.0), - ]; - assert_eq!(result.as_slice::(), &expected[..]); + assert_eq!(ifft2.dtype(), Dtype::Complex64); + assert_eq!( + ifft2.as_slice::(), + FFT2_DATA + .iter() + .map(|&x| complex64::new(x, 0.0)) + .collect::>() + ); // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0, 1.0, 1.0, 1.0]); + assert_eq!(data, FFT2_DATA); } #[test] - fn test_fftn_unchecked() { - use crate::{complex64, fft::*, Array, Dtype}; - - let array = Array::ones::(&[3, 3]); - let mut result = unsafe { fftn_unchecked(&array, None, None) }; - result.eval(); - - assert_eq!(result.dtype(), Dtype::Complex64); - - let expected = &[ - complex64::new(9.0, 0.0), - complex64::new(0.0, 0.0), + fn test_fftn() { + const FFTN_DATA: &[f32] = &[1.0; 8]; + const FFTN_SHAPE: &[i32] = &[2, 2, 2]; + const FFTN_EXPECTED: &[complex64; 8] = &[ + complex64::new(8.0, 0.0), complex64::new(0.0, 0.0), complex64::new(0.0, 0.0), complex64::new(0.0, 0.0), @@ -436,75 +547,28 @@ mod tests { complex64::new(0.0, 0.0), complex64::new(0.0, 0.0), ]; - assert_eq!(result.as_slice::(), &expected[..]); - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0; 9]); - } - - #[test] - fn test_try_fftn() { - use crate::{complex64, error::FftError, fft::*, Array}; - - let array = Array::ones::(&[3, 3, 3]); + let array = Array::from_slice(FFTN_DATA, FFTN_SHAPE); + let mut fftn = fftn(&array, None, None); + fftn.eval(); - // Error case - let scalar_array = Array::from_float(1.0); - let result = try_fftn(&scalar_array, None, None); - assert_eq!(result.unwrap_err(), FftError::ScalarArray); + assert_eq!(fftn.dtype(), Dtype::Complex64); + assert_eq!(fftn.as_slice::(), FFTN_EXPECTED); - let result = try_fftn(&array, &[3, 3, 3, 3][..], &[0, 1, 2, 3][..]); - assert_eq!(result.unwrap_err(), FftError::InvalidAxis { ndim: 3 }); + let mut ifftn = ifftn(&fftn, None, None); + ifftn.eval(); - let result = try_fftn(&array, &[3, 3, 3][..], &[-1][..]); + assert_eq!(ifftn.dtype(), Dtype::Complex64); assert_eq!( - result.unwrap_err(), - FftError::IncompatibleShapeAndAxes { - shape_size: 3, - axes_size: 1, - } + ifftn.as_slice::(), + FFTN_DATA + .iter() + .map(|&x| complex64::new(x, 0.0)) + .collect::>() ); - let result = try_fftn(&array, None, &[-2, -2][..]); - assert_eq!(result.unwrap_err(), FftError::DuplicateAxis { axis: -2 }); - - let result = try_fftn(&array, &[-2, 2][..], None); - assert_eq!(result.unwrap_err(), FftError::InvalidOutputSize); - - // Success case - let mut result = try_fftn(&array, None, None).unwrap(); - result.eval(); - - assert_eq!(result.dtype(), crate::dtype::Dtype::Complex64); - - let mut expected = vec![complex64::new(0.0, 0.0); 27]; - expected[0] = complex64::new(27.0, 0.0); - - assert_eq!(result.as_slice::(), &expected[..]); - - // test that previous array is not modified and valid - let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0; 27]); - } - - #[test] - fn test_fftn() { - use crate::{complex64, fft::*, Array, Dtype}; - - let array = Array::ones::(&[3, 3, 3]); - let mut result = fftn(&array, None, None); - result.eval(); - - assert_eq!(result.dtype(), Dtype::Complex64); - - let mut expected = vec![complex64::new(0.0, 0.0); 27]; - expected[0] = complex64::new(27.0, 0.0); - - assert_eq!(result.as_slice::(), &expected[..]); - // test that previous array is not modified and valid let data: &[f32] = array.as_slice(); - assert_eq!(data, &[1.0; 27]); + assert_eq!(data, FFTN_DATA); } } diff --git a/src/fft/ifftn.rs b/src/fft/ifftn.rs deleted file mode 100644 index 1ee2265e4..000000000 --- a/src/fft/ifftn.rs +++ /dev/null @@ -1,231 +0,0 @@ -use mlx_macros::default_device; - -use crate::{error::FftError, Array, StreamOrDevice}; - -use super::{ - resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, - try_resolve_sizes_and_axes, -}; - -/// One dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: Input array. -/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded -/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. -/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. -#[default_device(device = "cpu")] -pub unsafe fn ifft_device_unchecked( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Array { - let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); - unsafe { - let c_array = mlx_sys::mlx_fft_ifft(a.c_array, n, axis, stream.stream.c_stream); - Array::from_ptr(c_array) - } -} - -/// One dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: Input array. -/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded -/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. -/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. -#[default_device(device = "cpu")] -pub fn try_ifft_device( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Result { - let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; - unsafe { Ok(ifft_device_unchecked(a, n, axis, stream)) } -} - -/// One dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: Input array. -/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded -/// with zeros to match `n`. The default value is `a.shape[axis]` if not specified. -/// - `axis`: Axis along which to perform the FFT. The default is `-1` if not specified. -#[default_device(device = "cpu")] -pub fn ifft_device( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_ifft_device(a, n, axis, stream).unwrap() -} - -/// Two dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_ifft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - -/// Two dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -#[default_device(device = "cpu")] -pub unsafe fn ifft2_device_unchecked<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - ifft2_device_inner(a, &valid_s, &valid_axes, stream) -} - -/// Two dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -#[default_device(device = "cpu")] -pub fn try_ifft2_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Result { - let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(ifft2_device_inner(a, &valid_s, &valid_axes, stream)) -} - -/// Two dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -/// -/// # Panic -/// -/// Panics if the input arguments are invalid. See [try_ifft2_device] for more details. -#[default_device(device = "cpu")] -pub fn ifft2_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_ifft2_device(a, s, axes, stream).unwrap() -} - -fn ifftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_ifftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - -/// n-dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or -/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` -/// if not specified. -/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is -/// over the last `len(s)` axes are or all axes if `s` is also `None`. -#[default_device(device = "cpu")] -pub unsafe fn ifftn_device_unchecked<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - ifftn_device_inner(a, &valid_s, &valid_axes, stream) -} - -/// n-dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or -/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` -/// if not specified. -/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is -/// over the last `len(s)` axes are or all axes if `s` is also `None`. -#[default_device(device = "cpu")] -pub fn try_ifftn_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Result { - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(ifftn_device_inner(a, &valid_s, &valid_axes, stream)) -} - -/// n-dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or -/// padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes` -/// if not specified. -/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is -/// over the last `len(s)` axes are or all axes if `s` is also `None`. -/// -/// # Panic -/// -/// Panics if the input arguments are invalid. See [try_ifftn_device] for more details. -#[default_device(device = "cpu")] -pub fn ifftn_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_ifftn_device(a, s, axes, stream).unwrap() -} diff --git a/src/fft/irfftn.rs b/src/fft/irfftn.rs deleted file mode 100644 index 3d50b3c9a..000000000 --- a/src/fft/irfftn.rs +++ /dev/null @@ -1,137 +0,0 @@ -use mlx_macros::default_device; - -use crate::{error::FftError, Array, StreamOrDevice}; - -use super::{ - resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, - try_resolve_sizes_and_axes, -}; - -#[default_device(device = "cpu")] -pub unsafe fn irfft_device_unchecked( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Array { - let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); - unsafe { - let c_array = mlx_sys::mlx_fft_irfft(a.c_array, n, axis, stream.stream.c_stream); - Array::from_ptr(c_array) - } -} - -#[default_device(device = "cpu")] -pub fn try_irfft_device( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Result { - let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; - unsafe { Ok(irfft_device_unchecked(a, n, axis, stream)) } -} - -#[default_device(device = "cpu")] -pub fn irfft_device( - a: &Array, - n: impl Into>, - axis: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_irfft_device(a, n, axis, stream).unwrap() -} - -fn irfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_irfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - -#[default_device(device = "cpu")] -pub unsafe fn irfft2_device_unchecked<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - irfft2_device_inner(a, &valid_s, &valid_axes, stream) -} - -#[default_device(device = "cpu")] -pub fn try_irfft2_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Result { - let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(irfft2_device_inner(a, &valid_s, &valid_axes, stream)) -} - -#[default_device(device = "cpu")] -pub fn irfft2_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_irfft2_device(a, s, axes, stream).unwrap() -} - -fn irfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_irfftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - -#[default_device(device = "cpu")] -pub unsafe fn irfftn_device_unchecked<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - irfftn_device_inner(a, &valid_s, &valid_axes, stream) -} - -#[default_device(device = "cpu")] -pub fn try_irfftn_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Result { - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(irfftn_device_inner(a, &valid_s, &valid_axes, stream)) -} - -#[default_device(device = "cpu")] -pub fn irfftn_device<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, - stream: StreamOrDevice, -) -> Array { - try_irfftn_device(a, s, axes, stream).unwrap() -} diff --git a/src/fft/mod.rs b/src/fft/mod.rs index 02b484e72..41d73e0ca 100644 --- a/src/fft/mod.rs +++ b/src/fft/mod.rs @@ -168,8 +168,6 @@ //! ``` mod fftn; -mod ifftn; -mod irfftn; mod rfftn; use smallvec::SmallVec; @@ -180,7 +178,7 @@ use crate::{ Array, }; -pub use self::{fftn::*, ifftn::*, irfftn::*, rfftn::*}; +pub use self::{fftn::*, rfftn::*}; #[inline] fn resolve_size_and_axis_unchecked( diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs index 6bc56482e..bebf24218 100644 --- a/src/fft/rfftn.rs +++ b/src/fft/rfftn.rs @@ -255,3 +255,211 @@ pub fn rfftn_device<'a>( ) -> Array { try_rfftn_device(a, s, axes, stream).unwrap() } + +#[default_device(device = "cpu")] +pub unsafe fn irfft_device_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (n, axis) = resolve_size_and_axis_unchecked(a, n, axis); + unsafe { + let c_array = mlx_sys::mlx_fft_irfft(a.c_array, n, axis, stream.stream.c_stream); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub fn try_irfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (n, axis) = try_resolve_size_and_axis(a, n, axis)?; + unsafe { Ok(irfft_device_unchecked(a, n, axis, stream)) } +} + +#[default_device(device = "cpu")] +pub fn irfft_device( + a: &Array, + n: impl Into>, + axis: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfft_device(a, n, axis, stream).unwrap() +} + +fn irfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_irfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn irfft2_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + irfft2_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_irfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let axes = axes.into().unwrap_or(&[-2, -1]); + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(irfft2_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn irfft2_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfft2_device(a, s, axes, stream).unwrap() +} + +fn irfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { + let num_s = s.len(); + let num_axes = axes.len(); + + let s_ptr = s.as_ptr(); + let axes_ptr = axes.as_ptr(); + + unsafe { + let c_array = + mlx_sys::mlx_fft_irfftn(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); + Array::from_ptr(c_array) + } +} + +#[default_device(device = "cpu")] +pub unsafe fn irfftn_device_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); + irfftn_device_inner(a, &valid_s, &valid_axes, stream) +} + +#[default_device(device = "cpu")] +pub fn try_irfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Result { + let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; + Ok(irfftn_device_inner(a, &valid_s, &valid_axes, stream)) +} + +#[default_device(device = "cpu")] +pub fn irfftn_device<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, + stream: StreamOrDevice, +) -> Array { + try_irfftn_device(a, s, axes, stream).unwrap() +} + +#[cfg(test)] +mod tests { + use crate::{complex64, fft::*, Array, Dtype}; + + #[test] + fn test_rfft() { + const RFFT_DATA: &[f32] = &[1.0, 2.0, 3.0, 4.0]; + const RFFT_SHAPE: &[i32] = &[4]; + const RFFT_N: i32 = 4; + const RFFT_AXIS: i32 = -1; + const RFFT_EXPECTED: &[complex64] = &[ + complex64::new(10.0, 0.0), + complex64::new(-2.0, 2.0), + complex64::new(-2.0, 0.0), + ]; + + let a = Array::from_slice(RFFT_DATA, RFFT_SHAPE); + let mut rfft = rfft(&a, RFFT_N, RFFT_AXIS); + rfft.eval(); + assert_eq!(rfft.dtype(), Dtype::Complex64); + assert_eq!(rfft.as_slice::(), RFFT_EXPECTED); + + let mut irfft = irfft(&rfft, RFFT_N, RFFT_AXIS); + irfft.eval(); + assert_eq!(irfft.dtype(), Dtype::Float32); + assert_eq!(irfft.as_slice::(), RFFT_DATA); + } + + #[test] + fn test_rfft2() { + const RFFT2_DATA: &[f32] = &[1.0; 4]; + const RFFT2_SHAPE: &[i32] = &[2, 2]; + const RFFT2_EXPECTED: &[complex64] = &[ + complex64::new(4.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + + let a = Array::from_slice(RFFT2_DATA, RFFT2_SHAPE); + let mut rfft2 = rfft2(&a, None, None); + rfft2.eval(); + assert_eq!(rfft2.dtype(), Dtype::Complex64); + assert_eq!(rfft2.as_slice::(), RFFT2_EXPECTED); + + let mut irfft2 = irfft2(&rfft2, None, None); + irfft2.eval(); + assert_eq!(irfft2.dtype(), Dtype::Float32); + assert_eq!(irfft2.as_slice::(), RFFT2_DATA); + } + + #[test] + fn test_rfftn() { + const RFFTN_DATA: &[f32] = &[1.0; 8]; + const RFFTN_SHAPE: &[i32] = &[2, 2, 2]; + const RFFTN_EXPECTED: &[complex64] = &[ + complex64::new(8.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + complex64::new(0.0, 0.0), + ]; + + let a = Array::from_slice(RFFTN_DATA, RFFTN_SHAPE); + let mut rfftn = rfftn(&a, None, None); + rfftn.eval(); + assert_eq!(rfftn.dtype(), Dtype::Complex64); + assert_eq!(rfftn.as_slice::(), RFFTN_EXPECTED); + + let mut irfftn = irfftn(&rfftn, None, None); + irfftn.eval(); + assert_eq!(irfftn.dtype(), Dtype::Float32); + assert_eq!(irfftn.as_slice::(), RFFTN_DATA); + } +} From 96b2208c0f50b9465ffd0600a2f9355996755e8c Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 00:43:28 -0700 Subject: [PATCH 17/25] use *fftn in *fft2 & added missing docs for irfft* --- src/fft/fftn.rs | 50 ++------------- src/fft/rfftn.rs | 163 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 126 insertions(+), 87 deletions(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 8ca9580af..595e31348 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -72,20 +72,6 @@ pub fn fft_device( try_fft_device(a, n, axis, stream).unwrap() } -fn fft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_fft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - /// Two dimensional discrete Fourier Transform. /// /// # Param @@ -102,8 +88,7 @@ pub unsafe fn fft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - fft2_device_inner(a, &valid_s, &valid_axes, stream) + fftn_device_unchecked(a, s, axes, stream) } /// Two dimensional discrete Fourier Transform. @@ -121,9 +106,8 @@ pub fn try_fft2_device<'a>( axes: impl Into>, stream: StreamOrDevice, ) -> Result { - let valid_axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, valid_axes)?; - Ok(fft2_device_inner(a, &valid_s, &valid_axes, stream)) + let axes = axes.into().unwrap_or(&[-2, -1]); + try_fftn_device(a, s, axes, stream) } /// Two dimensional discrete Fourier Transform. @@ -295,28 +279,6 @@ pub fn ifft_device( try_ifft_device(a, n, axis, stream).unwrap() } -/// Two dimensional inverse discrete Fourier Transform. -/// -/// # Params -/// -/// - `a`: The input array. -/// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. -/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. -fn ifft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_ifft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - /// Two dimensional inverse discrete Fourier Transform. /// /// # Params @@ -333,8 +295,7 @@ pub unsafe fn ifft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - ifft2_device_inner(a, &valid_s, &valid_axes, stream) + ifftn_device_unchecked(a, s, axes, stream) } /// Two dimensional inverse discrete Fourier Transform. @@ -353,8 +314,7 @@ pub fn try_ifft2_device<'a>( stream: StreamOrDevice, ) -> Result { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(ifft2_device_inner(a, &valid_s, &valid_axes, stream)) + try_ifftn_device(a, s, axes, stream) } /// Two dimensional inverse discrete Fourier Transform. diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs index bebf24218..053ad10f4 100644 --- a/src/fft/rfftn.rs +++ b/src/fft/rfftn.rs @@ -79,25 +79,11 @@ pub fn rfft_device( try_rfft_device(a, n, axis, stream).unwrap() } -fn rfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_rfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - /// Two dimensional real discrete Fourier Transform. /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -113,15 +99,14 @@ pub unsafe fn rfft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - rfft2_device_inner(a, &valid_s, &valid_axes, stream) + rfftn_device_unchecked(a, s, axes, stream) } /// Two dimensional real discrete Fourier Transform. /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -137,15 +122,14 @@ pub fn try_rfft2_device<'a>( stream: StreamOrDevice, ) -> Result { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(rfft2_device_inner(a, &valid_s, &valid_axes, stream)) + try_rfftn_device(a, s, axes, stream) } /// Two dimensional real discrete Fourier Transform. /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -185,7 +169,7 @@ fn rfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -209,7 +193,7 @@ pub unsafe fn rfftn_device_unchecked<'a>( /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -233,7 +217,7 @@ pub fn try_rfftn_device<'a>( /// /// The output has the same shape as the input except along the dimensions in `axes` in which case /// it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size -/// `s[-1] // 2 + 1`. +/// `s[s.len()-1] // 2 + 1`. /// /// # Params /// @@ -256,6 +240,16 @@ pub fn rfftn_device<'a>( try_rfftn_device(a, s, axes, stream).unwrap() } +/// The inverse of [`rfft()`]. +/// +/// The output has the same shape as the input except along axis in which case it has size n. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`. +/// - `axis`: Axis along which to perform the FFT. The default is `-1`. #[default_device(device = "cpu")] pub unsafe fn irfft_device_unchecked( a: &Array, @@ -270,6 +264,16 @@ pub unsafe fn irfft_device_unchecked( } } +/// The inverse of [`rfft()`]. +/// +/// The output has the same shape as the input except along axis in which case it has size n. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`. +/// - `axis`: Axis along which to perform the FFT. The default is `-1`. #[default_device(device = "cpu")] pub fn try_irfft_device( a: &Array, @@ -281,6 +285,16 @@ pub fn try_irfft_device( unsafe { Ok(irfft_device_unchecked(a, n, axis, stream)) } } +/// The inverse of [`rfft()`]. +/// +/// The output has the same shape as the input except along axis in which case it has size n. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `n`: Size of the transformed axis. The corresponding axis in the input is truncated or padded +/// with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`. +/// - `axis`: Axis along which to perform the FFT. The default is `-1`. #[default_device(device = "cpu")] pub fn irfft_device( a: &Array, @@ -291,20 +305,19 @@ pub fn irfft_device( try_irfft_device(a, n, axis, stream).unwrap() } -fn irfft2_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevice) -> Array { - let num_s = s.len(); - let num_axes = axes.len(); - - let s_ptr = s.as_ptr(); - let axes_ptr = axes.as_ptr(); - - unsafe { - let c_array = - mlx_sys::mlx_fft_irfft2(a.c_array, s_ptr, num_s, axes_ptr, num_axes, stream.as_ptr()); - Array::from_ptr(c_array) - } -} - +/// The inverse of [`rfft2()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub unsafe fn irfft2_device_unchecked<'a>( a: &'a Array, @@ -313,10 +326,22 @@ pub unsafe fn irfft2_device_unchecked<'a>( stream: StreamOrDevice, ) -> Array { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = resolve_sizes_and_axes_unchecked(a, s, axes); - irfft2_device_inner(a, &valid_s, &valid_axes, stream) + irfftn_device_unchecked(a, s, axes, stream) } +/// The inverse of [`rfft2()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn try_irfft2_device<'a>( a: &'a Array, @@ -325,10 +350,22 @@ pub fn try_irfft2_device<'a>( stream: StreamOrDevice, ) -> Result { let axes = axes.into().unwrap_or(&[-2, -1]); - let (valid_s, valid_axes) = try_resolve_sizes_and_axes(a, s, axes)?; - Ok(irfft2_device_inner(a, &valid_s, &valid_axes, stream)) + try_irfftn_device(a, s, axes, stream) } +/// The inverse of [`rfft2()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn irfft2_device<'a>( a: &'a Array, @@ -353,6 +390,20 @@ fn irfftn_device_inner(a: &Array, s: &[i32], axes: &[i32], stream: StreamOrDevic } } +/// The inverse of [`rfftn()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub unsafe fn irfftn_device_unchecked<'a>( a: &'a Array, @@ -364,6 +415,20 @@ pub unsafe fn irfftn_device_unchecked<'a>( irfftn_device_inner(a, &valid_s, &valid_axes, stream) } +/// The inverse of [`rfftn()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub fn try_irfftn_device<'a>( a: &'a Array, @@ -375,6 +440,20 @@ pub fn try_irfftn_device<'a>( Ok(irfftn_device_inner(a, &valid_s, &valid_axes, stream)) } +/// The inverse of [`rfftn()`]. +/// +/// Note the input is generally complex. The dimensions of the input specified in `axes` are padded +/// or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis +/// and will have size `s[s.len()-1] // 2 + 1`. +/// +/// # Params +/// +/// - `a`: The input array. +/// - `s`: Sizes of the transformed axes. The corresponding axes in the input are truncated or +/// padded with zeros to match the sizes in `s` except for the last axis which has size +/// `s[s.len()-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. +/// - `axes`: Axes along which to perform the FFT. The default is `None` in which case the FFT is +/// over the last `len(s)` axes or all axes if `s` is also `None`. #[default_device(device = "cpu")] pub fn irfftn_device<'a>( a: &'a Array, From 2952da101ed0a248012eba67063e37e394f68f66 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:44:24 -0700 Subject: [PATCH 18/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 595e31348..a9b9b9a6d 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -285,7 +285,7 @@ pub fn ifft_device( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub unsafe fn ifft2_device_unchecked<'a>( From 518d9b46f34612852ad0faa657bb48729b49a964 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:45:01 -0700 Subject: [PATCH 19/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index a9b9b9a6d..9f7a8a053 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -78,7 +78,7 @@ pub fn fft_device( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub unsafe fn fft2_device_unchecked<'a>( From 5ab2134d45d510e6fd87148f138618a7ad19d1e9 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:45:28 -0700 Subject: [PATCH 20/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 9f7a8a053..071a4afe0 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -304,7 +304,7 @@ pub unsafe fn ifft2_device_unchecked<'a>( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn try_ifft2_device<'a>( From 0f15e19714ad3c94dc0b1fbaa15d8f591909b0ce Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:46:23 -0700 Subject: [PATCH 21/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 071a4afe0..d0a08e77e 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -97,7 +97,7 @@ pub unsafe fn fft2_device_unchecked<'a>( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. #[default_device(device = "cpu")] pub fn try_fft2_device<'a>( From 82fdb9612add0e492c32accc7d84d4e93db3e52b Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:46:44 -0700 Subject: [PATCH 22/25] Update src/error.rs Co-authored-by: David Chavez --- src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 1757e39dd..cb9b96752 100644 --- a/src/error.rs +++ b/src/error.rs @@ -63,7 +63,7 @@ pub enum FftError { #[error("Shape and axes/axis have different sizes")] IncompatibleShapeAndAxes { shape_size: usize, axes_size: usize }, - #[error("Duplcated axis received: {axis}")] + #[error("Duplicate axis received: {axis}")] DuplicateAxis { axis: i32 }, #[error("Invalid output size requested")] From ad1ac1cea7c3f593700736747818dcf7ec3767df Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:47:10 -0700 Subject: [PATCH 23/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index d0a08e77e..d2a51a96c 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -116,7 +116,7 @@ pub fn try_fft2_device<'a>( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. /// /// # Panic From 8c4ab42db6907940ef47f751a7f4be9c5de300e9 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 02:47:36 -0700 Subject: [PATCH 24/25] Update src/fft/fftn.rs Co-authored-by: David Chavez --- src/fft/fftn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index d2a51a96c..d9ebba6db 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -323,7 +323,7 @@ pub fn try_ifft2_device<'a>( /// /// - `a`: The input array. /// - `s`: Size of the transformed axes. The corresponding axes in the input are truncated or padded -/// with zeros to match `n`. The default value is the sizes of `a` along `axes`. +/// with zeros to match `s`. The default value is the sizes of `a` along `axes`. /// - `axes`: Axes along which to perform the FFT. The default is `[-2, -1]`. /// /// # Panic From 3df96d301255b3da7a607dc4ca9003bd38549bc6 Mon Sep 17 00:00:00 2001 From: minghuaw Date: Wed, 24 Apr 2024 03:00:25 -0700 Subject: [PATCH 25/25] moved helper fn to fft/utils.rs --- src/fft/fftn.rs | 2 +- src/fft/mod.rs | 238 +---------------------------------------------- src/fft/rfftn.rs | 2 +- src/fft/utils.rs | 236 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 239 insertions(+), 239 deletions(-) create mode 100644 src/fft/utils.rs diff --git a/src/fft/fftn.rs b/src/fft/fftn.rs index 595e31348..c753d695f 100644 --- a/src/fft/fftn.rs +++ b/src/fft/fftn.rs @@ -2,7 +2,7 @@ use mlx_macros::default_device; use crate::{array::Array, error::FftError, stream::StreamOrDevice}; -use super::{ +use super::utils::{ resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, try_resolve_sizes_and_axes, }; diff --git a/src/fft/mod.rs b/src/fft/mod.rs index 41d73e0ca..f3f867e3c 100644 --- a/src/fft/mod.rs +++ b/src/fft/mod.rs @@ -169,242 +169,6 @@ mod fftn; mod rfftn; - -use smallvec::SmallVec; - -use crate::{ - error::FftError, - utils::{all_unique, resolve_index, resolve_index_unchecked}, - Array, -}; +mod utils; pub use self::{fftn::*, rfftn::*}; - -#[inline] -fn resolve_size_and_axis_unchecked( - a: &Array, - n: impl Into>, - axis: impl Into>, -) -> (i32, i32) { - let axis = axis.into().unwrap_or(-1); - let n = n.into().unwrap_or_else(|| { - let axis_index = resolve_index_unchecked(axis, a.ndim()); - a.shape()[axis_index] - }); - (n, axis) -} - -#[inline] -fn try_resolve_size_and_axis( - a: &Array, - n: impl Into>, - axis: impl Into>, -) -> Result<(i32, i32), FftError> { - if a.ndim() < 1 { - return Err(FftError::ScalarArray); - } - - let axis = axis.into().unwrap_or(-1); - let axis_index = - resolve_index(axis, a.ndim()).ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; - let n = n.into().unwrap_or(a.shape()[axis_index]); - - if n <= 0 { - return Err(FftError::InvalidOutputSize); - } - - Ok((n, axis)) -} - -#[inline] -fn resolve_sizes_and_axes_unchecked<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, -) -> (SmallVec<[i32; 4]>, SmallVec<[i32; 4]>) { - match (s.into(), axes.into()) { - (Some(s), Some(axes)) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (Some(s), None) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - (None, Some(axes)) => { - let valid_s = axes - .iter() - .map(|&axis| { - let axis_index = resolve_index_unchecked(axis, a.ndim()); - a.shape()[axis_index] - }) - .collect(); - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (None, None) => { - let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - } -} - -// It's probably rare to perform fft on more than 4 axes -// TODO: check if this is a good default value -#[inline] -fn try_resolve_sizes_and_axes<'a>( - a: &'a Array, - s: impl Into>, - axes: impl Into>, -) -> Result<(SmallVec<[i32; 4]>, SmallVec<[i32; 4]>), FftError> { - if a.ndim() < 1 { - return Err(FftError::ScalarArray); - } - - let (valid_s, valid_axes) = match (s.into(), axes.into()) { - (Some(s), Some(axes)) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (Some(s), None) => { - let valid_s = SmallVec::<[i32; 4]>::from_slice(s); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - (None, Some(axes)) => { - // SmallVec somehow doesn't implement FromIterator with result - let mut valid_s = SmallVec::<[i32; 4]>::new(); - for &axis in axes { - let axis_index = resolve_index(axis, a.ndim()) - .ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; - valid_s.push(a.shape()[axis_index]); - } - let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); - (valid_s, valid_axes) - } - (None, None) => { - let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); - let valid_axes = (-(valid_s.len() as i32)..0).collect(); - (valid_s, valid_axes) - } - }; - - // Check duplicate axes - all_unique(&valid_axes).map_err(|axis| FftError::DuplicateAxis { axis })?; - - // Check if shape and axes have the same size - if valid_s.len() != valid_axes.len() { - return Err(FftError::IncompatibleShapeAndAxes { - shape_size: valid_s.len(), - axes_size: valid_axes.len(), - }); - } - - // Check if more axes are provided than the array has - if valid_s.len() > a.ndim() { - return Err(FftError::InvalidAxis { ndim: a.ndim() }); - } - - // Check if output sizes are valid - if valid_s.iter().any(|val| *val <= 0) { - return Err(FftError::InvalidOutputSize); - } - - Ok((valid_s, valid_axes)) -} - -#[cfg(test)] -mod try_resolve_size_and_axis_tests { - use crate::Array; - - use super::{try_resolve_size_and_axis, FftError}; - - #[test] - fn scalar_array_returns_error() { - // Returns an error if the array is a scalar - let a = Array::from_float(1.0); - let result = try_resolve_size_and_axis(&a, 0, 0); - assert_eq!(result, Err(FftError::ScalarArray)); - } - - #[test] - fn out_of_bound_axis_returns_error() { - // Returns an error if the axis is invalid (out of bounds) - let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); - let result = try_resolve_size_and_axis(&a, 0, 1); - assert_eq!(result, Err(FftError::InvalidAxis { ndim: 1 })); - } - - #[test] - fn negative_output_size_returns_error() { - // Returns an error if the output size is negative - let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); - let result = try_resolve_size_and_axis(&a, -1, 0); - assert_eq!(result, Err(FftError::InvalidOutputSize)); - } - - #[test] - fn valid_input_returns_sizes_and_axis() { - // Returns the output size and axis if the input is valid - let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); - let result = try_resolve_size_and_axis(&a, 4, 0); - assert_eq!(result, Ok((4, 0))); - } -} - -#[cfg(test)] -mod try_resolve_sizes_and_axes_tests { - use crate::Array; - - use super::{try_resolve_sizes_and_axes, FftError}; - - #[test] - fn scalar_array_returns_error() { - // Returns an error if the array is a scalar - let a = Array::from_float(1.0); - let result = try_resolve_sizes_and_axes(&a, None, None); - assert_eq!(result, Err(FftError::ScalarArray)); - } - - #[test] - fn out_of_bound_axis_returns_error() { - // Returns an error if the axis is invalid (out of bounds) - let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1, 2][..]); - assert_eq!(result, Err(FftError::InvalidAxis { ndim: 2 })); - } - - #[test] - fn different_num_sizes_and_num_axes_returns_error() { - // Returns an error if the number of sizes and axes are different - let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1][..]); - assert_eq!( - result, - Err(FftError::IncompatibleShapeAndAxes { - shape_size: 3, - axes_size: 2 - }) - ); - } - - #[test] - fn duplicate_axes_returns_error() { - // Returns an error if there are duplicate axes - let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let result = try_resolve_sizes_and_axes(&a, &[2, 2][..], &[0, 0][..]); - assert_eq!(result, Err(FftError::DuplicateAxis { axis: 0 })); - } - - #[test] - fn negative_output_size_returns_error() { - // Returns an error if the output size is negative - let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); - let result = try_resolve_sizes_and_axes(&a, &[-2, 2][..], None); - assert_eq!(result, Err(FftError::InvalidOutputSize)); - } -} diff --git a/src/fft/rfftn.rs b/src/fft/rfftn.rs index 053ad10f4..c7b94c4a4 100644 --- a/src/fft/rfftn.rs +++ b/src/fft/rfftn.rs @@ -2,7 +2,7 @@ use mlx_macros::default_device; use crate::{error::FftError, Array, StreamOrDevice}; -use super::{ +use super::utils::{ resolve_size_and_axis_unchecked, resolve_sizes_and_axes_unchecked, try_resolve_size_and_axis, try_resolve_sizes_and_axes, }; diff --git a/src/fft/utils.rs b/src/fft/utils.rs new file mode 100644 index 000000000..0085020e7 --- /dev/null +++ b/src/fft/utils.rs @@ -0,0 +1,236 @@ +use smallvec::SmallVec; + +use crate::{ + error::FftError, + utils::{all_unique, resolve_index, resolve_index_unchecked}, + Array, +}; + +#[inline] +pub(super) fn resolve_size_and_axis_unchecked( + a: &Array, + n: impl Into>, + axis: impl Into>, +) -> (i32, i32) { + let axis = axis.into().unwrap_or(-1); + let n = n.into().unwrap_or_else(|| { + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] + }); + (n, axis) +} + +#[inline] +pub(super) fn try_resolve_size_and_axis( + a: &Array, + n: impl Into>, + axis: impl Into>, +) -> Result<(i32, i32), FftError> { + if a.ndim() < 1 { + return Err(FftError::ScalarArray); + } + + let axis = axis.into().unwrap_or(-1); + let axis_index = + resolve_index(axis, a.ndim()).ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; + let n = n.into().unwrap_or(a.shape()[axis_index]); + + if n <= 0 { + return Err(FftError::InvalidOutputSize); + } + + Ok((n, axis)) +} + +#[inline] +pub(super) fn resolve_sizes_and_axes_unchecked<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, +) -> (SmallVec<[i32; 4]>, SmallVec<[i32; 4]>) { + match (s.into(), axes.into()) { + (Some(s), Some(axes)) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (Some(s), None) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + (None, Some(axes)) => { + let valid_s = axes + .iter() + .map(|&axis| { + let axis_index = resolve_index_unchecked(axis, a.ndim()); + a.shape()[axis_index] + }) + .collect(); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (None, None) => { + let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + } +} + +// It's probably rare to perform fft on more than 4 axes +// TODO: check if this is a good default value +#[inline] +pub(super) fn try_resolve_sizes_and_axes<'a>( + a: &'a Array, + s: impl Into>, + axes: impl Into>, +) -> Result<(SmallVec<[i32; 4]>, SmallVec<[i32; 4]>), FftError> { + if a.ndim() < 1 { + return Err(FftError::ScalarArray); + } + + let (valid_s, valid_axes) = match (s.into(), axes.into()) { + (Some(s), Some(axes)) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (Some(s), None) => { + let valid_s = SmallVec::<[i32; 4]>::from_slice(s); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + (None, Some(axes)) => { + // SmallVec somehow doesn't implement FromIterator with result + let mut valid_s = SmallVec::<[i32; 4]>::new(); + for &axis in axes { + let axis_index = resolve_index(axis, a.ndim()) + .ok_or_else(|| FftError::InvalidAxis { ndim: a.ndim() })?; + valid_s.push(a.shape()[axis_index]); + } + let valid_axes = SmallVec::<[i32; 4]>::from_slice(axes); + (valid_s, valid_axes) + } + (None, None) => { + let valid_s: SmallVec<[i32; 4]> = (0..a.ndim()).map(|axis| a.shape()[axis]).collect(); + let valid_axes = (-(valid_s.len() as i32)..0).collect(); + (valid_s, valid_axes) + } + }; + + // Check duplicate axes + all_unique(&valid_axes).map_err(|axis| FftError::DuplicateAxis { axis })?; + + // Check if shape and axes have the same size + if valid_s.len() != valid_axes.len() { + return Err(FftError::IncompatibleShapeAndAxes { + shape_size: valid_s.len(), + axes_size: valid_axes.len(), + }); + } + + // Check if more axes are provided than the array has + if valid_s.len() > a.ndim() { + return Err(FftError::InvalidAxis { ndim: a.ndim() }); + } + + // Check if output sizes are valid + if valid_s.iter().any(|val| *val <= 0) { + return Err(FftError::InvalidOutputSize); + } + + Ok((valid_s, valid_axes)) +} + +#[cfg(test)] +mod try_resolve_size_and_axis_tests { + use crate::Array; + + use super::{try_resolve_size_and_axis, FftError}; + + #[test] + fn scalar_array_returns_error() { + // Returns an error if the array is a scalar + let a = Array::from_float(1.0); + let result = try_resolve_size_and_axis(&a, 0, 0); + assert_eq!(result, Err(FftError::ScalarArray)); + } + + #[test] + fn out_of_bound_axis_returns_error() { + // Returns an error if the axis is invalid (out of bounds) + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, 0, 1); + assert_eq!(result, Err(FftError::InvalidAxis { ndim: 1 })); + } + + #[test] + fn negative_output_size_returns_error() { + // Returns an error if the output size is negative + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, -1, 0); + assert_eq!(result, Err(FftError::InvalidOutputSize)); + } + + #[test] + fn valid_input_returns_sizes_and_axis() { + // Returns the output size and axis if the input is valid + let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]); + let result = try_resolve_size_and_axis(&a, 4, 0); + assert_eq!(result, Ok((4, 0))); + } +} + +#[cfg(test)] +mod try_resolve_sizes_and_axes_tests { + use crate::Array; + + use super::{try_resolve_sizes_and_axes, FftError}; + + #[test] + fn scalar_array_returns_error() { + // Returns an error if the array is a scalar + let a = Array::from_float(1.0); + let result = try_resolve_sizes_and_axes(&a, None, None); + assert_eq!(result, Err(FftError::ScalarArray)); + } + + #[test] + fn out_of_bound_axis_returns_error() { + // Returns an error if the axis is invalid (out of bounds) + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1, 2][..]); + assert_eq!(result, Err(FftError::InvalidAxis { ndim: 2 })); + } + + #[test] + fn different_num_sizes_and_num_axes_returns_error() { + // Returns an error if the number of sizes and axes are different + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2, 2][..], &[0, 1][..]); + assert_eq!( + result, + Err(FftError::IncompatibleShapeAndAxes { + shape_size: 3, + axes_size: 2 + }) + ); + } + + #[test] + fn duplicate_axes_returns_error() { + // Returns an error if there are duplicate axes + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[2, 2][..], &[0, 0][..]); + assert_eq!(result, Err(FftError::DuplicateAxis { axis: 0 })); + } + + #[test] + fn negative_output_size_returns_error() { + // Returns an error if the output size is negative + let a = Array::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[2, 2]); + let result = try_resolve_sizes_and_axes(&a, &[-2, 2][..], None); + assert_eq!(result, Err(FftError::InvalidOutputSize)); + } +}