Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): add fft #37

Merged
merged 31 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e0455b3
adding fft mod
minghuaw Apr 18, 2024
5efba3a
output dtype of fft is complex64
minghuaw Apr 20, 2024
2a98380
Merge branch 'dev' into api/fft
minghuaw Apr 20, 2024
cd3fd30
add device input to default_device macro
minghuaw Apr 20, 2024
502add8
use macro to generate non device fft unchecked
minghuaw Apr 20, 2024
c2f41cb
Merge branch 'dev' into api/fft
minghuaw Apr 20, 2024
4cab6a7
test try fft
minghuaw Apr 20, 2024
02f6c0d
use unit test as example
minghuaw Apr 21, 2024
c3cc17e
cargo fmt
minghuaw Apr 21, 2024
facf23d
added fft_device
minghuaw Apr 21, 2024
34b3c80
follow original mlx indexing behavior
minghuaw Apr 21, 2024
3e00b2a
cargo fmt
minghuaw Apr 21, 2024
d0aa5d5
Merge branch 'dev' into api/fft
minghuaw Apr 22, 2024
16f1e88
added fft2
minghuaw Apr 22, 2024
5ef1d45
Merge branch 'dev' into api/fft
minghuaw Apr 23, 2024
892d55e
impl fftn
minghuaw Apr 23, 2024
fc776b6
check for negative output size in one dim fft
minghuaw Apr 23, 2024
551ebbe
init impl of ifftn, rfftn, irfftn
minghuaw Apr 23, 2024
ae17488
added docs
minghuaw Apr 23, 2024
71dc6ee
re-organize fft mod and unit tests
minghuaw Apr 23, 2024
d17d159
Merge branch 'dev' into api/fft
minghuaw Apr 23, 2024
96b2208
use *fftn in *fft2 & added missing docs for irfft*
minghuaw Apr 24, 2024
2952da1
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
518d9b4
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
5ab2134
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
0f15e19
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
82fdb96
Update src/error.rs
minghuaw Apr 24, 2024
ad1ac1c
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
8c4ab42
Update src/fft/fftn.rs
minghuaw Apr 24, 2024
3df96d3
moved helper fn to fft/utils.rs
minghuaw Apr 24, 2024
6ba9597
Merge branch 'api/fft' of https://github.com/oxideai/mlx-rs into api/fft
minghuaw Apr 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ num-traits = "0.2.18"
paste = "1.0.14"
strum = { version = "0.26", features = ["derive"] }
thiserror = "1.0.58"
smallvec = "1"

[dev-dependencies]
pretty_assertions = "1.4.0"
Expand Down
1 change: 1 addition & 0 deletions mlx-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ proc-macro = true
[dependencies]
syn = { version = "2.0.60", features = ["full"] }
quote = "1.0"
darling = "0.20"
50 changes: 47 additions & 3 deletions mlx-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
extern crate proc_macro;
use darling::FromMeta;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, parse_quote, DeriveInput, FnArg, ItemFn, Pat};

#[derive(Debug, FromMeta)]
enum DeviceType {
Cpu,
Gpu,
}

#[derive(Debug)]
struct DefaultDeviceInput {
device: DeviceType,
}

impl FromMeta for DefaultDeviceInput {
fn from_meta(meta: &syn::Meta) -> darling::Result<Self> {
let syn::Meta::NameValue(meta_name_value) = meta else {
return Err(darling::Error::unsupported_format(
"expected a name-value attribute",
));
};

let ident = meta_name_value.path.get_ident().unwrap();
assert_eq!(ident, "device", "expected `device`");

let device = DeviceType::from_expr(&meta_name_value.value)?;

Ok(DefaultDeviceInput { device })
}
}

#[proc_macro_attribute]
pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn default_device(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = if !attr.is_empty() {
let meta = syn::parse_macro_input!(attr as syn::Meta);
Some(DefaultDeviceInput::from_meta(&meta).unwrap())
} else {
None
};

let mut input_fn = parse_macro_input!(item as ItemFn);
let original_fn = input_fn.clone();

Expand Down Expand Up @@ -37,8 +73,16 @@ pub fn default_device(_attr: TokenStream, item: TokenStream) -> TokenStream {
input_fn.sig.inputs = Punctuated::from_iter(filtered_inputs);

// Prepend default stream initialization
let default_stream_stmt = parse_quote! {
let stream = StreamOrDevice::default();
let default_stream_stmt = match input.map(|input| input.device) {
Some(DeviceType::Cpu) => parse_quote! {
let stream = StreamOrDevice::cpu();
},
Some(DeviceType::Gpu) => parse_quote! {
let stream = StreamOrDevice::gpu();
},
None => parse_quote! {
let stream = StreamOrDevice::default();
},
};
input_fn.block.stmts.insert(0, default_stream_stmt);

Expand Down
3 changes: 2 additions & 1 deletion src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ impl Device {

/// Set the default device.
///
/// Example:
/// # Example:
///
/// ```rust
/// use mlx::{Device, DeviceType};
/// Device::set_default(&Device::new(DeviceType::Cpu, 1));
Expand Down
18 changes: 18 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,21 @@ pub enum AsSliceError {
#[error("dtype mismatch: expected {expecting:?}, found {found:?}")]
DtypeMismatch { expecting: Dtype, found: Dtype },
}

#[derive(Error, Debug, PartialEq)]
pub enum FftError {
#[error("fftn requires at least one dimension")]
ScalarArray,

#[error("Invalid axis received for array with {ndim} dimensions")]
InvalidAxis { ndim: usize },

#[error("Shape and axes/axis have different sizes")]
minghuaw marked this conversation as resolved.
Show resolved Hide resolved
IncompatibleShapeAndAxes { shape_size: usize, axes_size: usize },

#[error("Duplicate axis received: {axis}")]
DuplicateAxis { axis: i32 },

#[error("Invalid output size requested")]
InvalidOutputSize,
}
Loading
Loading