-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Test for read-only arrays #205
Conversation
3c2f31d
to
6884a34
Compare
This seems related to the suggestions at #146. Ping @mdhaber @lucascolley to check if these helpers will be useful in SciPy. |
x-ref data-apis/array-api#609, cc @jakevdp @rgommers |
Yes, it looks like it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, looks useful! The diffs will be dauntingly large I suppose, but hopefully trivial enough!
array_api_compat/common/_helpers.py
Outdated
if is_jax_array(x): | ||
return x.at | ||
if is_numpy_array(x) and not x.flags.writeable: | ||
x = x.copy() | ||
return _InPlaceAt(x, idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if I understand correctly, this implementation works for:
- JAX
- NumPy
- Mutable arrays (arrays for which the in-place updates work)
It doesn't work for immutable arrays coming from other libraries, for the reasons discussed in data-apis/array-api#845.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It works for all arrays explicitly listed in array-api-compat, including numpy subclasses.
It won't work for immutable arrays from other, unknown libraries - in fact, they won't be even recognized as immutable. Without a standardized way for is_immutable_array
to detect such a use case (__array_writeable__
interface?) I would not know how to implement it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would attempting to mutate the array as a fallback work, or could that lead to undesired side-effects?
# Something along these lines, but more robust
a = x[0]
try:
x[0] = 0
x[0] = a
mutable = True
except:
mutable = False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that reimplementing at[]
without the ability to update and without specific knowledge of the library at hand would be extraordinarily inefficient.
Consider at(x, 0).add(1)
. Implementing it on top of where
would be extremely expensive; an alternative would be to break the array down and rebuild it with concatenate
and stack
, which would be very complicated and probably fragile.
The snippet of code you just wrote would perform very poorly in many cases. Consider:
- a library like dask: the
x[0] = a
line adds extra labour into the graph, which is very likely nontrivial - any library where the memory transparently moves from device to host and vice versa:
a = x[0]
would likely cause either a page or the whole array to do just that.
A somewhat slightly more reasonable alternative would be to
- blindly try the update
- if the update fails, try if by any chance the library has a
at[]
method with exactly the same API as JAX
At the moment we're only doing (1).
My personal opinion would be to explicitly cater for these libraries if and when they crop up.
Is this whole discussion hypothetical, or do we know of a specific read-only library other than JAX?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realised that sparse does not support updates, so it is going to face the same issues in scipy as JAX. Has there been any discussion about it?
Do you mean the diffs inside scipy? Yes they will. Not much that can be done about it I'm afraid. I gave this some more thought and reworked the PR.
In scipy and similar libraries, we should replace all instances of x[idx] += y with x = at(x, idx).add(y, copy=None) (read below for the masked use case). Offline comment by @rgommers :
There is discussion elsewhere that something should be pushed into the standard.
Masked JAX arraysI've deliberately omitted special-casing for masked JAX arrays, unlike in @rgommers 's POC here https://github.com/scipy/scipy/compare/main...rgommers:scipy:array-types-inplace-ops?expand=1. IMHO, the problem with the POC code e.g. if is_jax(xp):
if hasattr(idx, 'dtype') and xp.isdtype(idx.dtype, 'bool'):
x = xp.where(idx, x + y, x)
else:
x = x.at[idx].add(y)
else:
x[idx] += val is that it produces incorrect results when y has shape other than This strongly feels to me less of a generic problem with read-only arrays (which array-api-compat and scipy should cater for) and more of a JAX-specific quirk (needing to know all shapes in advance).
Until that happens, scipy code will work with JAX, but it will break when jitted. CC @jakevdp |
I think I disagree. Everything in array-api-extra should work with any standard-compatible library, whereas this implementation will not work with immutable arrays from libraries not covered by array-api-compat (there doesn't seem to be any way it could). Meanwhile, none of the helper functions in array-api-compat (https://data-apis.org/array-api-compat/helper-functions.html) are in the standard1 - that's where this PR is targeting. I suppose the worry is that this is effectively arguing for the standardisation of Footnotes
|
I think there is room to make function behaviour vary on https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.info.capabilities.html. But I don't think we should add this to array-api-extra without some sort of array-level flag for mutability in the standard (data-apis/array-api#845). |
I think it's a fuzzy zone whether this belongs in compat or in extra. The scope for array-api-compat is spelled out here https://data-apis.org/array-api-compat/#scope. On the one hand, the
(by the way, we should add a reference to array-api-extra to that scope section) |
IMO the bar is finding agreement on a sensible API in data-apis/array-api#609 (and here). But we don't need to have any confidence that existing array libraries will adopt it, since the implementations can live in array-api-compat if need be. |
Note that I renamed it to |
ca33a66
to
c8f6613
Compare
I think this is ready for a second round of consultation. sparse
JAX apply() and ufuncsI can't seem to make JAX's apply() work: >>> import jax.numpy as jnp
>>> import numpy as np
>>> a = jnp.array([1,2,3])
>>> a.at[:2].apply(np.negative)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[] This works without JIT... >>> a.at[:2].set(np.negative(a[:2]))
Array([-1, -2, 3], dtype=int32) ... but it crashes with >>> np.negative(a)
array([-1, -2, -3], dtype=int32) Given
I'm inclined to completely remove apply() from the API. PyTorch ufuncsApplying a ufunc to a torch array raises a warning, which makes the tests for min, max, and apply fail. ```python
>>> import torch
>>> import numpy
>>> a = torch.asarray([1,2,3])
>>> numpy.negative(a)
DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
tensor([-1, -2, -3]) |
Looking a little closer at this, my personal opinion is that it would make more sense at the moment to put It also looks like this is trying to use NumPy functions uniformly, regardless of whether the underlying library is NumPy. That's going to work for specific libraries that implement the NumPy interop (
|
Just to be clear, someone else (probably @rgommers) should make the actual decision on whether or not this belongs here or elsewhere. |
I've replaced np.minimum with xp.minimum - is this what you were referring to? |
Yes, that was what I meant. And also the apply function which you already mentioned you want to remove. |
In JAX, you can apply JAX functions, not NumPy functions: >>> import jax.numpy as jnp
>>> a = jnp.array([1,2,3])
>>> a.at[:2].apply(jnp.negative)
Array([-1, -2, 3], dtype=int32) That said, I don't think Same with the other snippet: it will work under JIT if you use |
This is ready for final review and merge. |
https://numpy.org/neps/nep-0047-array-api-standard.html. | ||
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that is | ||
compatible with the Array API standard https://data-apis.org/array-api/latest/. | ||
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(unrelated to this PR, but probably no need to link to that NEP now)
8a43d2a
to
6caf96a
Compare
@lucascolley this is ready to be merged |
Sounds good, but I don't have merge permissions on this repo. |
Would be helpful to summarize the current scope. So ATM, this PR only adds a single public function, in the vein of data-apis/array-api#845, correct? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, seems ready to merge.
Then it'd be helpful to spell the plan w.r.t. the spec: if the spec gets something in
capabilities
, will it obviate the need for this function, or will this function wrap it or....?
I wouldn't worry about that here. The chance doesn't seem too high that the standard gets this (soon at least), and if it does then wrapping it if needed is straightforward. I'd merge this as is.
Ready to merge |
Thanks @crusaderky, all. |
New public function
is_writeable_array
, which introduce transparent support for read-only backends such as JAX (but more may be added in the future).[EDIT] this PR also implemented an
at
function, mocking the syntax of JAX's omonymous method.This function will be proposed in array-api-extra instead.