Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing at functions #193

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

davidkoski
Copy link
Collaborator

- fixes ml-explore#188
- also fixes bug with array[idx] += 3 that would _not_ assign back to array
@davidkoski davidkoski requested a review from awni February 13, 2025 21:07
import Cmlx
import Foundation

public struct ArrayAt {
Copy link
Collaborator Author

@davidkoski davidkoski Feb 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • needs documentation

}
}

public struct ArrayAtIndices {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • needs documentation

let indexOperations: [MLXArrayIndexOperation]
let stream: StreamOrDevice

public func add(_ values: ScalarOrArray) -> MLXArray {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the pattern on the python side:

mx::array mlx_subtract_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.size() > 0) {
    return scatter_add(src, indices, -updates, axes);
  } else {
    return src - updates;
  }
}

This is the same but using the mlx-c API.

@@ -62,7 +62,7 @@ extension MLXArray {
/// ### See Also
/// - <doc:arithmetic>
/// - ``add(_:_:stream:)``
public static func += (lhs: MLXArray, rhs: MLXArray) {
public static func += (lhs: inout MLXArray, rhs: MLXArray) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found in unit tests:

        // this references each index twice
        let idx = MLXArray([0, 1, 0, 1])

        // similar to above -- we can only observe one assignment, so we just get a +1
        // note: there was a bug in the += operator where the lhs was not inout and
        // this was producing [0, 0]
        let a1 = MLXArray([0, 0])
        a1[idx] += 1
        assertEqual(a1, MLXArray([1, 1]))

Without the inout it seems that:

array[idx] += 2

is roughly equivalent to:

tmp = array[idx]
tmp += 2

let a = MLXArray([1, 2, 3])
let b = MLXArray(converting: [-5.0, 37.5, 4])
var a = MLXArray([1, 2, 3])
var b = MLXArray(converting: [-5.0, 37.5, 4])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the inout change this needs to be var now

Copy link
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

}

public func divide(_ values: ScalarOrArray) -> MLXArray {
multiply(1 / values.asMLXArray(dtype: array.dtype))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Total nitpick but I would use reciprocal here to avoid the extra array and broadcast op (although likely negligible).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • use reciprocal

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

add missing at() function
2 participants