-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update jax
config import
#199
base: main
Are you sure you want to change the base?
Conversation
Hi @alberthli , thanks for the PR! To understand a bit better, it is my understand jax deprecated this import recently. Did you check by any chance if this works for older jax version as well? Currently, we require (P.S. don't worry update the failing tests, the failures are unrelated to this PR, I think, and a regression in the CI) |
Hi @gomezzz, I haven't checked whether it works for earlier versions. This commit from 6 months ago seems relevant if you want to control versioning, though. |
Hi @alberthli ! Sorry for the long delay. Yes, it looks good and works mostly with jax>=0.4.17 (having trouble to test with earlier version due to issues with jax). Maybe we should bump the recommended version of jax though. Changes for that would be:
Would you mind updating it, @alberthli in this PR? Otherwise I can do it. I have been trying to run the tests locally on CPU too but I have a problem with jax now.
I think the middle one just has a too aggressive threshold but the other two seem to be changes in the way jax behaves? Errors are thrown here
and
So either we just change the type in the test or might have to make a change in the set_precision.py. If this is too much, @alberthli , we can also move that to a dedicated issue since it is not directly related to your changes, if you prefer? Thanks and sorry again for the delays! |
Hi @gomezzz, my bandwidth is very limited in the next couple of weeks, so I probably won't be able to write this PR. I think separating the testing issues into a separate issue is a good idea, though perhaps that fix should be merged with this one together. |
Hi, |
Hi @HGangloff , Sounds good, please go ahead, thanks! In case you have an idea why the datatypes in the tests changed (which should have been set via the here modified |
This PR changes the import statement for
jax.config
, which resolves the error