
New PRNG in 0.5.0 has degraded pseudorandom capability #26019

Closed · patrick-kidger opened this issue Jan 21, 2025 · 14 comments
Labels: bug (Something isn't working)

patrick-kidger (Collaborator) commented Jan 21, 2025

Description

Something we just spotted in patrick-kidger/diffrax#569 is that several of Diffrax's tests are now failing as of the 0.5.0 release. We run Kolmogorov-Smirnov tests to check the randomness of our Brownian motion samples, and these are now dramatically exceeding thresholds: at p = 0.01 the null hypothesis (that our random samples are correctly normally distributed) is rejected in favour of the alternative hypothesis (that our samples come from non-normal distributions).
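
For context, a check of this shape looks roughly like the following (a minimal sketch, not Diffrax's actual test; the sample count, dt, and threshold are illustrative):

import jax.random as jr
import scipy.stats as stats

# Draw what should be N(0, dt) increments, mimicking Brownian increments
# over a step of size dt, and KS-test them against the expected normal CDF.
dt = 0.1
samples = jr.normal(jr.key(0), (60_000,)) * dt**0.5
result = stats.kstest(samples, stats.norm(loc=0.0, scale=dt**0.5).cdf)
print(result.pvalue)  # reject the null (correct distribution) if below, say, 0.01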

Setting JAX_THREEFRY_PARTITIONABLE=False as per #18480 fixes the issue (p=0.61).

As you might imagine, an MWE is somewhat tricky to isolate, but I wanted to bring this up now to limit the blast radius of this fairly severe, fairly subtle problem.

In the meantime, this issue can be reproduced by running the following test (we have several failing tests; here is just one) from Diffrax (check out the source from here):

pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]

using versions:

beartype==0.19.0
diffrax==0.6.2
equinox==0.11.11
iniconfig==2.0.0
jax==0.5.0
jaxlib==0.5.0
jaxtyping==0.2.36
lineax==0.0.7
ml_dtypes==0.5.1
numpy==2.2.2
opt_einsum==3.4.0
optimistix==0.0.10
packaging==24.2
pluggy==1.5.0
pytest==8.3.4
scipy==1.15.1
setuptools==75.1.0
typeguard==2.13.3
typing_extensions==4.12.2
wheel==0.44.0

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 10:37:40) [Clang 14.0.6 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Air.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec  6 18:51:28 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T8112', machine='arm64')

(Tagging @lockwo @andyElking as folks who are likely to need to know about this from the Diffrax side.)

jakevdp (Collaborator) commented Jan 21, 2025

JAX also runs a suite of KS tests for all its distributions, and these succeed with both the old and the new PRNG.

One thing to note: these kinds of tests are notoriously difficult to tune, particularly if you're running many of them. For example, if you have 100 parameterized KS tests, then statistically you'd expect 5 of them to fail at p < 0.05 even if the RNG is sound.

Is it possible that equinox is running into this expected failure?
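
To make the multiple-testing point concrete, here's a small sketch (the sample size and test count are illustrative) showing that even a perfect sampler "fails" a few tests out of a hundred:

import numpy as np
import scipy.stats as stats

# 100 KS tests on genuinely normal data: at a p < 0.05 cutoff we expect
# roughly 5 spurious failures even though the sampler is sound.
rng = np.random.default_rng(0)
failures = sum(
    stats.kstest(rng.normal(size=10_000), stats.norm.cdf).pvalue < 0.05
    for _ in range(100)
)
print(failures)  # typically around 5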

patrick-kidger (Collaborator, Author) commented:

To be precise, we run 12 such tests (test/test_brownian.py::test_statistics) with a cutoff at p = 0.1. Plausibly we got a little lucky before in not having a single failure! But now we have 2 failing at p = 0.01 and 2 more failing at p = 0.04, which I think is too many / too unlikely.

FWIW the failing tests we have are here: https://github.com/patrick-kidger/diffrax/actions/runs/12865365374/job/35924910896?pr=569

(The above run also has some other stochastic tests failing, in which the numerical convergence order of some SDE solvers has veered away from the expected value. These are calibrated only 'by eye' rather than in any real statistical sense, but perhaps they still offer some evidence of degraded capability.)


Now, notably, the failures are all in the cases in which we do some additional jax.random.split(..., 2)s:

https://github.com/patrick-kidger/diffrax/blob/134a40ad351e18eb8c2e6afa556356596123a69c/diffrax/_brownian/path.py#L161-L169

and moreover, I have just tried rerunning the above tests with 10x and 100x the number of samples (passed to jax.random.split(key, num_samples)), and these bring the values up to around p = 0.3 and p = 0.7 respectively.

If I had to make a hypothesis based on this, it's that jax.random.split(key, n) no longer produces statistically random streams for low n. (Which I would note is probably the common case; it's not often that one does jax.random.split(key, 6000000) 😄.) I don't know if that's plausible or not, just venturing a hypothesis.
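
(One direct way to probe this hypothesis, as a sketch with illustrative sample counts, would be to pool draws from a small number of split streams and KS-test the result:)

import jax
import jax.random as jr
import scipy.stats as stats

# Split into a small number of streams (the suspected case), draw normals
# from each stream, pool them, and test the pooled sample for normality.
subkeys = jr.split(jr.key(0), 4)
pooled = jax.vmap(lambda k: jr.normal(k, (10_000,)))(subkeys).reshape(-1)
print(stats.kstest(pooled, stats.norm.cdf).pvalue)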

lockwo (Contributor) commented Jan 21, 2025

I noticed some daily unit tests started failing with 0.5.0 on internal sampling code that depends on equinox (but not diffrax), although equinox doesn't really do anything with RNG (to my knowledge). My tests don't actually test statistical randomness, but it's surprising that these all worked before and would now fail if the results are just as random (merely different random numbers): I ran with many different keys before and never saw a failure (similar to "These are just calibrated 'by eye' rather than in any real statistical sense, but perhaps still offer some evidence of degraded capability."). I will try to isolate an MWE.

patrick-kidger (Collaborator, Author) commented:

> that depends on equinox (but not diffrax), although equinox doesn't really do anything fancy with RNG (to my knowledge)

For reference btw, Equinox doesn't do anything PRNG-related at all. Still just a fancy PyTree library 😄

mattjj (Collaborator) commented Jan 21, 2025

> If I had to make a hypothesis based on this, it's that jax.random.split(key, n) no longer produces statistically random streams for low n. (Which I would note is probably the common case, not often that one does jax.random.split(key, 6000000) 😄.) I don't know if that's plausible or not, just venturing a hypothesis.

I put a with jax.config.update('jax_threefry_partitionable', False) around those split lines (and I also tried putting it on jax.random.split) and the test still failed, with the same numerical value in the assertion. (I even manually edited this line to call _threefry_split_original just to make sure.) So I don't think it's splitting specifically...

[Edit: my setup was messed up and the above was erroneous! See below.]
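
(For reference, the flag under discussion can be set either via the environment, as in the original report, or in Python:)

import jax

# Equivalent to running with JAX_THREEFRY_PARTITIONABLE=False in the
# environment: reverts jax.random.split to the pre-0.5.0 default behaviour.
jax.config.update('jax_threefry_partitionable', False)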

mattjj (Collaborator) commented Jan 21, 2025

Thanks for reporting this, @patrick-kidger and @lockwo !

mattjj (Collaborator) commented Jan 21, 2025

Actually the finding from my previous comment was incorrect, since my env was messed up. If we change this line to return _threefry_split_original(key, shape), the test passes.

So it does seem like something is up with splitting...

lockwo (Contributor) commented Jan 21, 2025

> If I had to make a hypothesis based on this, it's that jax.random.split(key, n) no longer produces statistically random streams for low n. (Which I would note is probably the common case, not often that one does jax.random.split(key, 6000000) 😄.) I don't know if that's plausible or not, just venturing a hypothesis.

Something I seem to notice is that splitting seems OK for n=2, but for n=3, 4 I see some issues (preliminary, but related to this hypothesis).

In fact, in diffrax, if you just replace state_key, init_key_w, init_key_hh, init_key_kk = jr.split(key, 4) with a bunch of 2-way splits

key, state_key = jax.random.split(key)
key, init_key_w = jax.random.split(key)
key, init_key_kk = jax.random.split(key)
key, init_key_hh = jax.random.split(key)

it passes the test.
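
(For the record, the two patterns derive genuinely different subkeys, so the downstream draws differ; a quick sketch to see this:)

import jax.random as jr
from jax.random import key_data

key = jr.key(0)
four_way = jr.split(key, 4)  # one 4-way split

chained, k = [], key
for _ in range(4):           # four chained 2-way splits
    k, sub = jr.split(k)
    chained.append(sub)

# Different derivation paths -> different raw key data -> different draws.
print(key_data(four_way))
print([key_data(s) for s in chained])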

> Still just a fancy PyTree library

As someone who makes active use of features such as Equinox's while loops, unvmaps, and error checks, it seems like more than "just" a PyTree library 😃

jakevdp (Collaborator) commented Jan 21, 2025

> I noticed some daily unit tests started failing with 0.5.0 on internal sampling code

Note that the PRNG sampler did change in 0.5.0, so if your tests depend on the values of particular random draws, this failure would be expected. See the Change log and the Update note for details.

mattjj (Collaborator) commented Jan 22, 2025

I think we might be tilting at windmills here.

Look at these amazing (handcrafted, no AI) shell scripts:

$ seq 0 100 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=1 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
28
$ seq 0 100 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=0 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
28

The first command counts how many seeds make the test fail on jax==0.5.0. The second effectively counts how many seeds make it fail on jax==0.4.x. Coincidentally they produce exactly the same number, though for seq 0 10 they do not.

In both cases this is a much higher failure rate than you might expect from the confidence level of 0.01. I would guess that's because, as we increase the number of samples (e.g. to 60k in this case), we drive down the statistical error but leave any numerical errors unchanged. (Numerical errors could look like: our empirical distribution isn't converging to the true distribution, but only because we're using bf16 or f32 or something rather than real numbers.) For the p statistic to be uniformly distributed under the null, we probably need fewer samples.

WDYT? Can you check these results to see if I messed up?
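
A Python version of that seed sweep, for anyone who wants to measure the rejection rate under the null directly (a sketch; the sample sizes and seed count are illustrative):

import numpy as np
import scipy.stats as stats

# How often does the KS test reject at the 0.01 level on genuinely normal
# float32 data? Absent numerical effects, roughly 1% of seeds should
# reject regardless of sample size; a higher rate at large n would point
# at precision rather than the PRNG.
for n in (1_000, 60_000):
    rejections = sum(
        stats.kstest(
            np.float32(np.random.default_rng(seed).normal(size=n)),
            stats.norm.cdf,
        ).pvalue < 0.01
        for seed in range(200)
    )
    print(f"n={n}: {rejections}/200 rejections")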

mattjj (Collaborator) commented Jan 22, 2025

0.5.0 pulls into the lead at 200 seed offsets:

$ seq 0 200 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=0 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
62
$ seq 0 200 | xargs -P 32 -I \{\} bash -c 'JAX_THREEFRY_PARTITIONABLE=1 JAX_RANDOM_SEED_OFFSET={} pytest test/test_brownian.py::test_statistics[True-SpaceTimeLevyArea-VirtualBrownianTree]; echo $? >> /tmp/kidgerverse' &> /dev/null; echo $(grep -c 1 /tmp/kidgerverse); rm /tmp/kidgerverse
47

(This is just noise...)

froystig (Member) commented:

For readers wondering what jax_random_seed_offset does: it adds the given integer offset to every seed argument supplied to jax.random.key or jax.random.PRNGKey. For example:

$ JAX_RANDOM_SEED_OFFSET=0 python -c 'import jax.random as jr; print(jr.normal(jr.key(7)))'
0.45123515
$ JAX_RANDOM_SEED_OFFSET=3 python -c 'import jax.random as jr; print(jr.normal(jr.key(4)))'
0.45123515

This is meant as a way of checking a program's sensitivity to the choice of random seed, as we're doing here.

patrick-kidger (Collaborator, Author) commented:

Okay! I reduced the underlying Diffrax code to an MWE -- I think you're right, and I can explain what happened.

Whilst we do have several KS tests, with several of them now failing, and with them testing different codepaths, it just so happens that some of those codepaths split keys in similar ways! This means the tests group together, so that many of them will either mutually pass or mutually fail.

We got unlucky (2 correlated groups failed: at a cutoff of p = 0.1 that's roughly a 0.1² = 1-in-100 event) rather than statistically-improbably-unlucky (4 independent tests failing would be 0.1⁴, a 1-in-10000 event).

I'm really sorry for taking up your time folks -- and I really appreciate everyone jumping in.

mattjj (Collaborator) commented Jan 22, 2025

Thanks for the follow-up @patrick-kidger! I'd much rather have things surfaced eagerly and turn out to be false alarms than have any missed detections. Randomness is so subtle that it demands extra scrutiny. Hopefully we've all learned never to use randomness.
