Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: pymc-devs/pymc
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 312effd38c98a8206702395ca305ce0f8668cf94
Choose a base ref
..
head repository: pymc-devs/pymc
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 88308f26c252ee7bc964a6e8f6d141eb87c9c510
Choose a head ref
Showing with 11 additions and 9 deletions.
  1. +2 −2 pymc/backends/base.py
  2. +2 −2 pymc/backends/mcbackend.py
  3. +2 −2 pymc/backends/zarr.py
  4. +5 −3 pymc/pytensorf.py
4 changes: 2 additions & 2 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@

from pymc.backends.report import SamplerReport
from pymc.model import modelcontext
from pymc.pytensorf import compile, set_function_rngs
from pymc.pytensorf import compile, copy_function_with_new_rngs
from pymc.util import get_var_name

logger = logging.getLogger(__name__)
@@ -179,7 +179,7 @@ def __init__(
)
fn.trust_input = True
if rng is not None:
fn = set_function_rngs(fn=fn, rng=rng)
fn = copy_function_with_new_rngs(fn=fn, rng=rng)

# Get variable shapes. Most backends will need this
# information.
4 changes: 2 additions & 2 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@

from pymc.backends.base import IBaseTrace
from pymc.model import Model
from pymc.pytensorf import PointFunc, set_function_rngs
from pymc.pytensorf import PointFunc, copy_function_with_new_rngs
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
@@ -116,7 +116,7 @@ def __init__(
self._chain = chain
self._point_fn = point_fn
if rng is not None:
self._point_fn = set_function_rngs(self._point_fn, rng)
self._point_fn = copy_function_with_new_rngs(self._point_fn, rng)
self._statsbj = stats_bijection
super().__init__()

4 changes: 2 additions & 2 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@
from pymc.backends.base import BaseTrace
from pymc.blocking import StatDtype, StatShape
from pymc.model.core import Model, modelcontext
from pymc.pytensorf import set_function_rngs
from pymc.pytensorf import copy_function_with_new_rngs
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
@@ -548,7 +548,7 @@ def init_trace(
test_point=test_point,
stats_bijection=StatsBijection(step.stats_dtypes),
draws_per_chunk=self.draws_per_chunk,
fn=set_function_rngs(self.fn, rng_),
fn=copy_function_with_new_rngs(self.fn, rng_),
)
for rng_ in get_random_generator(rng).spawn(chains)
]
8 changes: 5 additions & 3 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
@@ -1167,18 +1167,20 @@ def normalize_rng_param(rng: None | Variable) -> Variable:


@overload
def set_function_rngs(
def copy_function_with_new_rngs(
fn: PointFunc, rng: np.random.Generator | RandomGeneratorState
) -> PointFunc: ...


@overload
def set_function_rngs(
def copy_function_with_new_rngs(
fn: Function, rng: np.random.Generator | RandomGeneratorState
) -> Function: ...


def set_function_rngs(fn: Function, rng: np.random.Generator | RandomGeneratorState) -> Function:
def copy_function_with_new_rngs(
fn: Function, rng: np.random.Generator | RandomGeneratorState
) -> Function:
"""Copy a compiled pytensor function and replace the random Generators with spawns.
Parameters