Skip to content

Commit

Permalink
feat(api): impl Clone and Debug for Device and Stream (#40)
Browse files Browse the repository at this point in the history
* impl Clone for Device, Stream & StreamOrDevice

* deny missing_debug_impl
  • Loading branch information
minghuaw committed Jul 16, 2024
1 parent e43c9ff commit ecd7de1
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 5 deletions.
5 changes: 3 additions & 2 deletions examples/tutorial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn scalar_basics() {
assert_eq!(shape, vec![]);
}

#[allow(unused_variables)]
fn array_basics() {
// make a multidimensional array.
let x: Array = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
Expand All @@ -40,10 +41,10 @@ fn array_basics() {
let y = Array::ones::<f32>(&[2, 2]);

// Pointwise add x and y:
let mut z = x.add(&y);
let z = x.add(&y);

// Same thing:
z = &x + &y;
let mut z = &x + &y;

// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
Expand Down
8 changes: 8 additions & 0 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ pub struct Array {
pub(crate) c_array: mlx_array,
}

impl std::fmt::Debug for Array {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let description = crate::utils::mlx_describe(self.c_array as *mut c_void)
.unwrap_or_else(|| "Array".to_string());
write!(f, "{:?}", description)
}
}

impl std::fmt::Display for Array {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let description = crate::utils::mlx_describe(self.c_array as *mut c_void)
Expand Down
44 changes: 42 additions & 2 deletions src/device.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use mlx_sys::mlx_retain;

use crate::utils::mlx_describe;

///Type of device.
#[derive(num_enum::IntoPrimitive)]
#[derive(num_enum::IntoPrimitive, Debug, Clone, Copy)]
#[repr(u32)]
pub enum DeviceType {
Cpu = mlx_sys::mlx_device_type__MLX_CPU,
Gpu = mlx_sys::mlx_device_type__MLX_GPU,
}

/// Representation of a Device in MLX.
#[derive(Debug)]
pub struct Device {
pub(crate) c_device: mlx_sys::mlx_device,
}
Expand Down Expand Up @@ -42,6 +43,35 @@ impl Device {
}
}

/// The `Device` is a simple struct on the c++ side
///
/// ```cpp
/// struct Device {
/// enum class DeviceType {
/// cpu,
/// gpu,
/// };
///
/// // ... other methods
///
/// DeviceType type;
/// int index;
/// };
/// ```
///
/// There is no function that mutates the device, so we can implement `Clone` for it.
impl Clone for Device {
fn clone(&self) -> Self {
unsafe {
// Increment the reference count.
mlx_retain(self.c_device as *mut std::ffi::c_void);
Self {
c_device: self.c_device.clone(),
}
}
}
}

impl Drop for Device {
fn drop(&mut self) {
unsafe { mlx_sys::mlx_free(self.c_device as *mut std::ffi::c_void) };
Expand All @@ -55,6 +85,15 @@ impl Default for Device {
}
}

impl std::fmt::Debug for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let description = mlx_describe(self.c_device as *mut std::os::raw::c_void);
let description = description.unwrap_or_else(|| "Device".to_string());

write!(f, "{}", description)
}
}

impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let description = mlx_describe(self.c_device as *mut std::os::raw::c_void);
Expand All @@ -72,6 +111,7 @@ mod tests {
fn test_fmt() {
let device = Device::default();
let description = format!("{}", device);
println!("{:?}", device);
assert_eq!(description, "Device(gpu, 0)");
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![deny(unused_unsafe)]
#![deny(unused_unsafe, missing_debug_implementations)]

mod array;
mod device;
Expand Down
39 changes: 39 additions & 0 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::utils::mlx_describe;
///
/// If omitted it will use the [default()], which will be [Device::gpu()] unless
/// set otherwise.
#[derive(Clone)]
pub struct StreamOrDevice {
stream: Stream,
}
Expand Down Expand Up @@ -54,6 +55,12 @@ impl Default for StreamOrDevice {
}
}

impl std::fmt::Debug for StreamOrDevice {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.stream)
}
}

impl std::fmt::Display for StreamOrDevice {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.stream)
Expand Down Expand Up @@ -90,6 +97,30 @@ impl Stream {
}
}

/// The `Stream` is a simple struct on the c++ side
///
/// ```cpp
/// struct Stream {
/// int index;
/// Device device;
///
/// // ... constructor
/// };
/// ```
///
/// There is no function that mutates the stream, so we can implement `Clone` for it.
impl Clone for Stream {
fn clone(&self) -> Self {
unsafe {
// Increment the reference count.
mlx_sys::mlx_retain(self.c_stream as *mut std::ffi::c_void);
Stream {
c_stream: self.c_stream,
}
}
}
}

impl Drop for Stream {
fn drop(&mut self) {
unsafe { mlx_sys::mlx_free(self.c_stream as *mut std::ffi::c_void) };
Expand All @@ -102,6 +133,14 @@ impl Default for Stream {
}
}

impl std::fmt::Debug for Stream {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let description = mlx_describe(self.c_stream as *mut std::os::raw::c_void);
let description = description.unwrap_or_else(|| "Stream".to_string());
write!(f, "{}", description)
}
}

impl std::fmt::Display for Stream {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let description = mlx_describe(self.c_stream as *mut std::os::raw::c_void);
Expand Down

0 comments on commit ecd7de1

Please sign in to comment.