New PRNG in 0.5.0 has degraded pseudorandom capability #26019
Comments
JAX also runs a suite of KS tests for all its distributions, and these succeed with both the old and the new PRNG. One thing to note: these kinds of tests are notoriously difficult to tune, particularly if you're running many of them. For example, if you have 100 parameterized KS tests, then statistically you'd expect about 5 of them to fail at a significance level of 0.05. Is it possible that equinox is running into this expected failure?
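(For concreteness, here is a minimal sketch of the base-rate effect described above. This is not JAX's actual test suite; the sample size, seed range, and significance level are illustrative choices.)

```python
import jax
import scipy.stats as stats

# Run 100 independent KS tests on genuinely standard-normal samples.
# At significance level 0.05 we expect roughly 5 of them to "fail" by chance.
alpha = 0.05
failures = 0
for seed in range(100):
    samples = jax.random.normal(jax.random.key(seed), (10_000,))
    if stats.kstest(samples, "norm").pvalue < alpha:
        failures += 1

print(f"{failures}/100 tests rejected the null at alpha={alpha}")
```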
To be precise we run 12 such tests.

FWIW the failing tests we have are here: https://github.com/patrick-kidger/diffrax/actions/runs/12865365374/job/35924910896?pr=569

(The above also has some other stochastic tests failing, in which the numerical convergence order of some SDE solvers has veered from expected. These are just calibrated 'by eye' rather than in any real statistical sense, but perhaps still offer some evidence of degraded capability.)

Now, notably the failures are all in the cases in which we do some additional splitting of the keys; and moreover, I have just tried rerunning the above tests with 10x and 100x the number of samples (passed in to the test).

If I had to make a hypothesis based on this, it's that the additional key splitting is where the behaviour has changed.
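(To make that hypothesis concrete, here is a rough sketch of the suspicious pattern: derive each sample from an extra key split, then KS-test the pooled draws. This is my own construction, not the Diffrax test; the sample count of 60k is taken from the discussion below, everything else is illustrative.)

```python
import jax
import scipy.stats as stats

def sample_with_extra_split(key):
    # The extra splitting step that the failing codepaths have in common.
    _, subkey = jax.random.split(key)
    return jax.random.normal(subkey)

keys = jax.random.split(jax.random.key(0), 60_000)
samples = jax.vmap(sample_with_extra_split)(keys)
print(stats.kstest(samples, "norm").pvalue)
```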
I noticed some daily unit tests started failing with 0.5.0 on internal sampling code that depends on equinox (but not diffrax), although equinox doesn't really do anything with RNG (to my knowledge). My tests don't actually test statistical randomness, but it's surprising that these all worked before and would now fail if the results are just as random (merely different random numbers), because I ran them with many different keys and never saw a failure before (similar to "These are just calibrated 'by eye' rather than in any real statistical sense, but perhaps still offer some evidence of degraded capability."). I will try to isolate an MWE.
For reference here, btw: Equinox doesn't do anything PRNG-related at all. Still just a fancy PyTree library 😄
My setup was messed up and that was erroneous! See below.
Thanks for reporting this, @patrick-kidger and @lockwo!
Actually the finding from my previous comment was incorrect, since my env was messed up. If we change this line to split the key differently, the test passes again. So it does seem like something is up with splitting...
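(A quick way to poke at this directly; this is a sketch of mine, not a test from either library. It toggles the `jax_threefry_partitionable` flag and compares draws made from split keys; the derived values are expected to differ between the two modes, since that is what changed in 0.5.0.)

```python
import jax

for partitionable in (False, True):
    # Toggle between the old (False) and new-default (True) threefry behaviour.
    jax.config.update("jax_threefry_partitionable", partitionable)
    key = jax.random.key(0)
    subkeys = jax.random.split(key, 4)
    draws = [float(jax.random.normal(k)) for k in subkeys]
    print(f"partitionable={partitionable}: {draws}")
```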
Something I seem to notice is that splitting seems OK for n=2, but for n=3, 4 I see some issues (preliminary, but just something related to this hypothesis). In fact, in diffrax, if you just replace the existing key splitting with

```python
key, state_key = jax.random.split(key)
key, init_key_w = jax.random.split(key)
key, init_key_kk = jax.random.split(key)
key, init_key_hh = jax.random.split(key)
```

it passes the test.
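(For anyone following along, a small sketch of why this substitution changes the random stream at all: a single n-way split and a chain of 2-way splits derive different subkeys from the same parent key, so the downstream draws generally differ. The variable names here are made up; only the splitting pattern mirrors the snippet above.)

```python
import jax

key = jax.random.key(0)

# One four-way split (a single call)...
k4 = jax.random.split(key, 4)

# ...versus chained two-way splits, as in the snippet above.
key2 = jax.random.key(0)
key2, a = jax.random.split(key2)
key2, b = jax.random.split(key2)
key2, c = jax.random.split(key2)
key2, d = jax.random.split(key2)

print([float(jax.random.normal(k)) for k in k4])
print([float(jax.random.normal(k)) for k in (a, b, c, d)])
```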
As someone who makes active use of functions such as Equinox's while loops and unvmaps and error checks, it seems like more than "just" a pytree library 😃
Note that the PRNG sampler did change in 0.5.0, so if your tests depend on the values of particular random draws, this failure would be expected. See the Change log and the Update note for details.
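(If the goal is just to keep value-dependent tests passing, a sketch of the workaround implied above. The flag name comes from the changelog and #18480; whether you set it via the environment variable or `jax.config` is a matter of taste.)

```python
import jax

# Restore the pre-0.5.0 threefry behaviour so existing golden values match.
# Equivalent to running the test suite with JAX_THREEFRY_PARTITIONABLE=0.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.key(0)
print(jax.random.normal(key, (3,)))  # matches draws from jax < 0.5.0
```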
I think we might be tilting at windmills here. Look at these amazing (handcrafted, no AI) shell scripts:

```shell
$ seq 0 100 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=1 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
28
$ seq 0 100 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=0 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
28
```

The first command counts how many seeds the test fails for under the jax==0.5.0 behaviour; the second effectively counts how many seeds it fails for under the jax==0.4.x behaviour. Coincidentally they produce the exact same number (though for different seeds).

In both cases this is a much higher failure rate than you might expect from the confidence level of 0.01. I would guess that's because as we increase the number of samples (e.g. to 60k in this case) we drive down the statistical error, but leave any numerical errors unchanged. (Numerical errors could look like: our empirical distribution isn't converging to the true distribution, but only because we're using bf16 or f32 or something rather than real numbers.) For the p statistic to be uniformly distributed under the null, we probably need fewer samples.

WDYT? Can you check these results to see if I messed up?
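(As a sanity check on "much higher failure rate than you might expect": treating each seed as an independent test at the 0.01 level, the failure count over 101 seeds should be roughly Binomial(101, 0.01). This arithmetic is mine, not part of the scripts above.)

```python
from scipy.stats import binom

n_seeds, alpha = 101, 0.01
print("expected failures:", n_seeds * alpha)                    # ~1
print("P(28 or more failures):", binom.sf(27, n_seeds, alpha))  # astronomically small
```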
0.5.0 pulls into the lead at 200 seed offsets:

```shell
$ seq 0 200 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=0 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
62
$ seq 0 200 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=1 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
47
```

(This is just noise...)
For readers wondering what JAX_RANDOM_SEED_OFFSET does: it offsets the seed of every key a program creates, so that (for example) jr.key(4) under an offset of 3 produces the same draws as jr.key(7) under an offset of 0:

```shell
$ JAX_RANDOM_SEED_OFFSET=0 python -c 'import jax.random as jr; print(jr.normal(jr.key(7)))'
0.45123515
$ JAX_RANDOM_SEED_OFFSET=3 python -c 'import jax.random as jr; print(jr.normal(jr.key(4)))'
0.45123515
```

This is meant as a means of checking a program's sensitivity to random seed choice, as we are doing here.
Okay! I reduced the underlying Diffrax code to an MWE -- I think you're right, and I can explain what happened. Whilst we do have several KS tests, with several of them now failing, and with them testing different codepaths, it just so happens that some of those codepaths split keys in similar ways! This means that we see the tests grouping together, so that many of them will either mutually pass or mutually fail. We got unlucky (2 splits failed, a 1-in-100 event) rather than statistically-improbably-unlucky (4 tests failed, a 1-in-10000 event). I'm really sorry for taking up your time folks -- and I really appreciate everyone jumping in.
Thanks for the follow-up @patrick-kidger! I'd much rather have things surfaced eagerly and then turn out to be false alarms than have any missed detections. Randomness is so subtle that it demands extra scrutiny. Hopefully we've all learned never to use randomness.
Description
Something we just spotted in patrick-kidger/diffrax#569 is that several of Diffrax's tests are now failing as of the 0.5.0 release. We run Kolmogorov-Smirnov tests to check the randomness of our Brownian motion samples, and these are now dramatically exceeding thresholds: at the `p = 0.01` level the null hypothesis (that our random samples are correctly normally distributed) is rejected, in favour of the alternate hypothesis (that our samples are from non-normal distributions). Setting `JAX_THREEFRY_PARTITIONABLE=False` as per #18480 fixes the issue (`p = 0.61`).

As you might imagine an MWE is somewhat tricky to isolate, but I wanted to bring this up now to limit the blast radius of this fairly severe, fairly subtle problem.
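(For readers unfamiliar with the setup, a rough sketch of the kind of check being described. This is not the actual Diffrax test: plain normal draws stand in for real Brownian increments, the step size is arbitrary, and only the sample count and threshold are taken from the discussion above.)

```python
import jax
import scipy.stats as stats

dt = 0.1  # illustrative step size
key = jax.random.key(0)

# Brownian increments over a step dt should be N(0, dt); here plain normal
# draws stand in for the samples the real test produces.
increments = jax.random.normal(key, (60_000,)) * dt**0.5
result = stats.kstest(increments, "norm", args=(0.0, dt**0.5))
print(result.pvalue)  # the test rejects normality if this falls below 0.01
```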
In the meantime this issue can be reproduced by running the following test (we have several tests failing, here is just one) from Diffrax (check out the source from here):
using versions:
System info (python version, jaxlib version, accelerator, etc.)
(Tagging @lockwo @andyElking as folks who are likely to need to know about this from the Diffrax side.)