mlx-rs

Rust bindings for Apple's mlx machine learning library.

⚠️ Project is in active development - contributors welcome!



Blaze supports this project by providing ultra-fast Apple Silicon macOS GitHub Actions runners. Apply the discount code AI25 at checkout to enjoy 25% off your first year.

Documentation

Due to a known limitation of docs.rs, we host the documentation on GitHub Pages here.

Features

MLX is an array framework for machine learning on Apple Silicon. mlx-rs provides Rust bindings for MLX, allowing you to use MLX in your Rust projects.

Some key features of MLX and mlx-rs include:

  • Performance: MLX is optimized for Apple Silicon, providing fast performance for machine learning tasks.
  • Lazy Evaluation: MLX uses lazy evaluation to optimize performance and memory usage. Arrays are only materialized when needed (see the short example below).
  • Dynamic Graphs: Computation graphs in MLX are constructed dynamically, allowing for flexible and efficient computation. Changing the shapes of function arguments does not require recompilation.
  • Multi-Device Support: MLX supports running computations on any of the supported devices (currently the CPU and GPU).
  • Unified memory: MLX provides a unified memory model, meaning arrays live in the same memory space, regardless of the device they are computed on. Operations can be performed on arrays on different devices without copying data between them.

mlx-rs is designed to be a safe and idiomatic Rust interface to MLX, providing a seamless experience for Rust developers.
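
As a quick illustration of the lazy-evaluation model, the sketch below builds a small computation and only materializes the result when eval is called. It is a minimal, illustrative sketch: the import paths (mlx_rs::error::Exception, mlx_rs::random) are assumptions based on the snippets later in this README and may differ slightly between releases.

use mlx_rs::{error::Exception, random, Array};

fn main() -> Result<(), Exception> {
    // Build the computation graph lazily; no kernels have run yet.
    let a = random::normal::<f32>(&[1024, 1024], None, None, None)?;
    let b = random::normal::<f32>(&[1024, 1024], None, None, None)?;
    let c = a.matmul(&b)? + Array::from_float(1.0);

    // The result is only materialized here, on the default device.
    c.eval()?;
    Ok(())
}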

Examples

The examples directory contains sample projects demonstrating different use cases of our library.

  • mnist: Train a basic neural network on the MNIST digit dataset
  • mistral: Text generation using the pre-trained Mistral model

Installation

Add this to your Cargo.toml:

[dependencies]
mlx-rs = "0.21.0"

Feature Flags

  • metal - enables Metal (GPU) usage in MLX
  • accelerate - enables using the Accelerate framework in MLX
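
For example, to build with both Metal (GPU) execution and the Accelerate framework enabled, the features can be requested in Cargo.toml. This is plain Cargo syntax using only the feature names listed above:

[dependencies]
mlx-rs = { version = "0.21.0", features = ["metal", "accelerate"] }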

Important Notes on Automatic Differentiation

When using automatic differentiation in mlx-rs, there's an important difference in how closures work compared to Python's MLX. In Python, variables are implicitly captured and properly traced in the compute graph. However, in Rust, we need to be more explicit about which arrays should be traced.

❌ This approach may cause segfaults:

// Don't do this
let x = random::normal::<f32>(&[num_examples, num_features], None, None, None)?;
let y = x.matmul(&w_star)? + eps;

let loss_fn = |w: &Array| -> Result<Array, Exception> {
    let y_pred = x.matmul(w)?;  // x and y are captured from outer scope
    let loss = Array::from_float(0.5) * ops::mean(&ops::square(&(y_pred - &y))?, None, None)?;
    Ok(loss)
};

let grad_fn = transforms::grad(loss_fn, &[0]);

✅ Instead, pass all required arrays as inputs to ensure proper tracing:

let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
    let w = &inputs[0];
    let x = &inputs[1];
    let y = &inputs[2];

    let y_pred = x.matmul(w)?;
    let loss = Array::from_float(0.5) * ops::mean(&ops::square(y_pred - y)?, None, None)?;
    Ok(loss)
};
let argnums = &[0];  // Specify which argument to differentiate with respect to

// Pass all required arrays in the inputs slice
let mut inputs = vec![w, x, y];
let grad = transforms::grad(loss_fn, argnums)(&inputs)?;

When using gradients in training loops, remember to update the appropriate array in your inputs:

let mut inputs = vec![w, x, y];

for _ in 0..num_iterations {
    let grad = transforms::grad(loss_fn, argnums)(&inputs)?;
    inputs[0] = &inputs[0] - Array::from_float(learning_rate) * grad;  // Update the weight array
    inputs[0].eval()?;
}

We are actively working on improving this API to make it more ergonomic and closer to Python's behavior. For now, explicitly passing all required arrays as shown above is the recommended approach.

Versioning

For simplicity, the main crate mlx-rs follows MLX's versioning, allowing you to easily see which MLX version you're using under the hood. The mlx-sys crate follows the versioning of mlx-c, as that is the version from which the API is generated.

Community

If you are excited about the project or want to contribute, don't hesitate to join our Discord! We try to be as welcoming as possible to everybody from any background. We're still building this out, but you can ask your questions there!

Status

mlx-rs is currently in active development and can be used to run MLX models in Rust.

MSRV

The minimum supported Rust version is 1.81.0.

The MSRV is the minimum Rust version that can be used to compile each crate.

License

mlx-rs is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.
