-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add missing at functions #193
base: main
Are you sure you want to change the base?
Conversation
- fixes ml-explore#188 - also fixes bug with array[idx] += 3 that would _not_ assign back to array
import Cmlx | ||
import Foundation | ||
|
||
public struct ArrayAt { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- needs documentation
} | ||
} | ||
|
||
public struct ArrayAtIndices { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- needs documentation
let indexOperations: [MLXArrayIndexOperation] | ||
let stream: StreamOrDevice | ||
|
||
public func add(_ values: ScalarOrArray) -> MLXArray { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the pattern on the python side:
mx::array mlx_subtract_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_add(src, indices, -updates, axes);
} else {
return src - updates;
}
}
This is the same but using the mlx-c API.
@@ -62,7 +62,7 @@ extension MLXArray { | |||
/// ### See Also | |||
/// - <doc:arithmetic> | |||
/// - ``add(_:_:stream:)`` | |||
public static func += (lhs: MLXArray, rhs: MLXArray) { | |||
public static func += (lhs: inout MLXArray, rhs: MLXArray) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Found in unit tests:
// this references each index twice
let idx = MLXArray([0, 1, 0, 1])
// similar to above -- we can only observe one assignment, so we just get a +1
// note: there was a bug in the += operator where the lhs was not inout and
// this was producing [0, 0]
let a1 = MLXArray([0, 0])
a1[idx] += 1
assertEqual(a1, MLXArray([1, 1]))
Without the inout
it seems that:
array[idx] += 2
is roughly equivalent to:
tmp = array[idx]
tmp += 2
let a = MLXArray([1, 2, 3]) | ||
let b = MLXArray(converting: [-5.0, 37.5, 4]) | ||
var a = MLXArray([1, 2, 3]) | ||
var b = MLXArray(converting: [-5.0, 37.5, 4]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because of the inout
change this needs to be var
now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
} | ||
|
||
public func divide(_ values: ScalarOrArray) -> MLXArray { | ||
multiply(1 / values.asMLXArray(dtype: array.dtype)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Total nitpick but I would use reciprocal
here to avoid the extra array and broadcast op (although likely negligible).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- use reciprocal
at()
function #188