Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add way to set backend fn random generators #7629

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Prev Previous commit
copy_function_with_new_rngs warns with JAXLinker
lucianopaz committed Jan 10, 2025
commit 7c597a8934ca2c5e375f09095dd24cfd97681d4b
10 changes: 10 additions & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
)
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.link.jax.linker import JAXLinker
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
@@ -1208,6 +1209,15 @@ def copy_function_with_new_rngs(
fn_ = fn.f if isinstance(fn, PointFunc) else fn
shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)]
n_shared_rngs = len(shared_rngs)
if n_shared_rngs > 0 and isinstance(fn_.maker.linker, JAXLinker):
    # Reseeding RVs in JAX backend requires a different logic, because the SharedVariables
# used internally are not the ones that `function.get_shared()` returns.
warnings.warn(
"At the moment, it is not possible to set the random generator's key for "
"JAX linked functions. This means that the draws yielded by the random "
"variables that are requested by 'Deterministic' will not be reproducible."
)
return fn
swap = {
old_shared_rng: shared(rng, borrow=True)
for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True)
27 changes: 22 additions & 5 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -929,12 +929,29 @@ def trace_backend(request):
return trace


def test_random_deterministics(trace_backend):
@pytest.fixture(scope="function", params=["FAST_COMPILE", "NUMBA", "JAX"])
def pytensor_mode(request):
return request.param


def test_random_deterministics(trace_backend, pytensor_mode):
with pm.Model() as m:
x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero
pm.Deterministic("y", x + pm.Normal.dist())

idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)

assert idata1.posterior.equals(idata2.posterior)
if pytensor_mode == "JAX":
expected_warning = (
"At the moment, it is not possible to set the random generator's key for "
"JAX linked functions. This means that the draws yielded by the random "
"variables that are requested by 'Deterministic' will not be reproducible."
)
with pytest.warns(UserWarning, match=expected_warning):
with pytensor.config.change_flags(mode=pytensor_mode):
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
assert not idata1.posterior.equals(idata2.posterior)
else:
with pytensor.config.change_flags(mode=pytensor_mode):
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
assert idata1.posterior.equals(idata2.posterior)