More dask fixes #16
Add a new optional parameter `callback` to least_squares: a user-defined function that is called on each iteration and can stop the optimization by returning True or raising StopIteration. Same signature and API as in e.g. scipy.optimize.differential_evolution. Only implemented for the trf and dogbox methods so far.
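The stop protocol described in this commit message can be sketched with a toy loop. This is only an illustration, not SciPy's implementation: `minimize_sketch`, `stop_when_close`, `stop_after_five`, and the dict used as the intermediate result are all hypothetical stand-ins.

```python
def minimize_sketch(grad, x0, step=0.1, max_iter=100, callback=None):
    """Toy 1-D gradient descent that honors a stop-capable callback.

    Hypothetical helper: the callback receives a plain dict here, standing in
    for whatever intermediate-result object a real solver would pass.
    """
    x = x0
    for nit in range(1, max_iter + 1):
        x = x - step * grad(x)
        if callback is not None:
            try:
                # A truthy return value requests termination.
                stop = callback({"x": x, "nit": nit})
            except StopIteration:
                # Raising StopIteration is treated the same way.
                stop = True
            if stop:
                return x, nit, "stopped by callback"
    return x, max_iter, "max_iter reached"


def stop_when_close(intermediate_result):
    # Stop by returning True once the iterate is near the minimum at 0.
    return abs(intermediate_result["x"]) < 1e-3


def stop_after_five(intermediate_result):
    # Stop by raising StopIteration after a fixed number of iterations.
    if intermediate_result["nit"] >= 5:
        raise StopIteration


# Minimize f(x) = x**2, whose gradient is 2*x.
x_a, nit_a, msg_a = minimize_sketch(lambda x: 2 * x, 1.0, callback=stop_when_close)
x_b, nit_b, msg_b = minimize_sketch(lambda x: 2 * x, 1.0, callback=stop_after_five)
```

Both callbacks end the loop early; the only difference is whether the stop request is signalled by a return value or by an exception.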
Fix linter errors
# Trick to get the array-api-compat namespace for dask
# (otherwise the "naked" dask.array asarray does not respect
# the input dtype)
I think I fixed this in dask/dask#11288?
I skimmed that patch, and at a quick glance it looks like it only fixes the issue when the input is a dask array.
When the input is a numpy array, it looks like we go into the `from_array` codepath, which is not passed the dtype option?
If my patch didn't fix it then this should be a new bug report for Dask - we shouldn't have to work around that.
I'll poke around more inside dask to double check if this is still an issue.
Seems to be fixed in 2024.12.1 at least - gonna take this out.
Maybe I was on too old of a dask version.
Fixed in dask/dask#11586
# use array-api-compat namespace for dask
# since dask asarray never makes a copy
# which makes xp_copy silently a no-op
xp = array_namespace(x)
Could we adjust `xp_copy` to special-case and copy for Dask instead?
We are currently passing `xp` into `xp_copy` in these tests, which is the problem, i.e. `xp_copy(arr, xp=<bad dask.array xp>)`.
If we don't pass `xp` in, this works: I've changed `array_namespace` in our vendored array-api-compat to use the wrapped dask.array namespace by default (will upstream this later).
It would be easier to just do these copies NumPy-side so we don't need to involve `xp_copy` at all:
x = np.random.rand(100)
x2 = x.copy()
x = xp.asarray(x)
x2 = xp.asarray(x2)
We can also drop the `xp` argument to `xp_copy` (which would be a smaller diff).
happy with that!
Actually, no - `x2` should be an array from the unwrapped namespace before it is passed into `func`.
I think I've changed all of these back to using numpy, let me know if there's any more left.
Going to hold off on this until Looks like dask is switching to always copying in
…is slow due to a bug with std::priority_queue
…s.test_shortest_path
Remove trailing spaces, unused variables and imports, change reference of Fibonacci heap to priority queue
Add missing spaces between tests.
model the structure on scipy.ndimage:
- add `_delegators.py` with *_signature functions
- add _signal_api.py to collect "bare" imports from _private modules
- add _support_alternative_backends.py to decorate "bare" functions
- in __init__.py, import decorated names from _support_alternative_backends.py
stats is now fully passing as well.
scipy/_lib/_util.py
Outdated
if is_dask_namespace(xp) or is_jax_namespace(xp):
    # TODO: verify for jax
    return xp.where(cond, f(arrays[0], arrays[1]), f2(arrays[0], arrays[1]) if not fillvalue else fillvalue)
cf. scipy#22070, cc @crusaderky.
I think @mdhaber was also interested in changes to _lazywhere.
(Sorry for the slow ping!)
This looks fine to me if you need something that works with Dask and JAX.
`f2` can be None. IMHO `_lazywhere` should be moved to array-api-extra before such backend-specific hacks take place.
I think I was trying to hack around `_lazywhere` since IIRC @mdhaber mentioned `_lazywhere` didn't seem to be much of an optimization compared to just doing the regular where, and so might be removed.
Happy to wait for your jax/array-api-extra changes to land, though.
Thanks for the pointer about `f2`.
I'll patch it in a later commit, but strange that it wasn't hit in tests.
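For reference, an eager `where`-based fallback that tolerates `f2=None` could look roughly like the sketch below. It is an assumption-laden illustration, not the code in this PR: `eager_where_sketch` is a hypothetical name, NumPy stands in for the Dask/JAX namespace, and the "exactly one of `fillvalue` or `f2`" rule is assumed. Checking `is None` rather than truthiness also keeps a `fillvalue` of `0.0` from being mistaken for "not given".

```python
import numpy as np


def eager_where_sketch(cond, arrays, f, fillvalue=None, f2=None, xp=np):
    """Evaluate both branches eagerly and select with xp.where.

    Hypothetical helper; `xp` defaults to NumPy as a stand-in namespace.
    """
    # Assumed contract: exactly one of `fillvalue` or `f2` is provided.
    if (f2 is None) == (fillvalue is None):
        raise ValueError("Exactly one of `fillvalue` or `f2` must be given.")

    true_branch = f(*arrays)
    # Compare against None explicitly so that fillvalue=0.0 is honored
    # and f2 is never called when it was not supplied.
    false_branch = f2(*arrays) if f2 is not None else fillvalue
    return xp.where(cond, true_branch, false_branch)


a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([5.0, 6.0, 7.0, 8.0])
cond = a > 2

# Scalar fill value of 0.0 where cond is False.
filled = eager_where_sketch(cond, (a, b), lambda a, b: a * b, fillvalue=0.0)

# Second function evaluated where cond is False.
both = eager_where_sketch(cond, (a, b), lambda a, b: a * b, f2=lambda a, b: a + b)
```

Unlike a lazy implementation, this evaluates `f` (and `f2`) on every element, so it trades the masking optimization for compatibility with backends that lack boolean-mask assignment.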
x2 = xp_copy(x, xp=xp)
x2 = xp.asarray(x_np.copy())
Could you remind me why `xp_copy` doesn't work for dask?
I think `xp_copy` does an `asarray(..., copy=True)` to make a copy, but dask silently ignores the `copy` flag in its `asarray` implementation.
hmm, this may be for older Dask only? data-apis/array-api-compat#211 seemed to suggest that copies are now always made even when copy=None
This is a test-specific issue where `xp` is the naked dask module. The wrapper returned by `array_namespace` has a functioning `copy=True`.
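The difference between the two namespaces can be mimicked without Dask. The sketch below is purely illustrative: `NakedNamespace`, `WrappedNamespace`, and `xp_copy_sketch` are hypothetical stand-ins built on NumPy, assumed to behave like the "naked" module and the compat wrapper described in this thread.

```python
import numpy as np


class NakedNamespace:
    """Hypothetical namespace whose asarray silently drops the copy flag."""

    @staticmethod
    def asarray(x, dtype=None, copy=None):
        # `copy` is accepted but ignored, so the input may be returned as-is.
        return np.asarray(x, dtype=dtype)


class WrappedNamespace:
    """Hypothetical compat wrapper whose asarray honors copy=True."""

    @staticmethod
    def asarray(x, dtype=None, copy=None):
        if copy:
            return np.array(x, dtype=dtype, copy=True)
        return np.asarray(x, dtype=dtype)


def xp_copy_sketch(x, xp):
    # Assumed shape of the helper: delegate to asarray(..., copy=True).
    return xp.asarray(x, copy=True)


x = np.arange(3.0)
naked_result = xp_copy_sketch(x, NakedNamespace)
wrapped_result = xp_copy_sketch(x, WrappedNamespace)

# With the naked namespace the "copy" still aliases the input buffer.
naked_shares = np.shares_memory(x, naked_result)
# With the wrapper a genuinely independent array is produced.
wrapped_shares = np.shares_memory(x, wrapped_result)
```

Under these assumptions, passing the naked namespace explicitly makes the copy a no-op, which is why letting the helper resolve the namespace itself avoids the problem in tests.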
MAINT: improve overflow handling in factorial functions
TST: array types: enforce namespace in tests
…hinx` 0.17 (scipy#22161)
* DOC, MAINT: Drop `convert_notebooks.py`
* DOC: Use Markdown notebooks with JupyterLite
* MAINT: Pin jupyterlite-sphinx to >=0.17.1
* DOC: Use the `NotebookLite` directive instead
* DOC: Add a "Download" button for the notebooks [docs only]
@lucascolley Are you able to do one last rebase to pick up the lazywhere changes for jax on scipy main?
[skip cirrus] [skip circle]
Co-authored-by: Guido Imperiale <[email protected]>
Oops, the rebase went wrong for some reason. Let me try to fix it.
I'll open a new PR to SciPy.