From d14cf226b611ee2967fa233b52159fc8f4ae9aae Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 29 Nov 2024 17:04:18 +0100 Subject: [PATCH] Make do interventions shared variables by default --- pymc/model/transform/conditioning.py | 17 +++++++++++---- tests/model/transform/test_conditioning.py | 25 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py index 23e0175503b..62d8e6e12c8 100644 --- a/pymc/model/transform/conditioning.py +++ b/pymc/model/transform/conditioning.py @@ -16,7 +16,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Union -from pytensor.graph import ancestors +import pytensor + +from pytensor.graph import Constant, ancestors from pytensor.tensor import TensorVariable from pymc.logprob.transforms import Transform @@ -123,7 +125,9 @@ def observe( def do( model: Model, vars_to_interventions: Mapping[Union["str", TensorVariable], Any], - prune_vars=False, + *, + make_interventions_shared: bool = True, + prune_vars: bool = False, ) -> Model: """Replace model variables by intervention variables. @@ -137,6 +141,8 @@ def do( Dictionary that maps model variables (or names) to intervention expressions. Intervention expressions must have a shape and data type that is compatible with the original model variable. + make_interventions_shared: bool, defaults to True, + Whether to make constant interventions shared variables. prune_vars: bool, defaults to False Whether to prune model variables that are not connected to any observed variables, after the interventions. @@ -167,11 +173,14 @@ def do( """ do_mapping = {} - for var, obs in vars_to_interventions.items(): + for var, intervention in vars_to_interventions.items(): if isinstance(var, str): var = model[var] try: - do_mapping[var] = var.type.filter_variable(obs) + intervention = var.type.filter_variable(intervention) + if make_interventions_shared and isinstance(intervention, Constant): + intervention = pytensor.shared(intervention.data, name=var.name) + do_mapping[var] = intervention except TypeError as err: raise TypeError( "Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables" diff --git a/tests/model/transform/test_conditioning.py b/tests/model/transform/test_conditioning.py index 2aba88b99d1..8d87635ff4b 100644 --- a/tests/model/transform/test_conditioning.py +++ b/tests/model/transform/test_conditioning.py @@ -16,9 +16,12 @@ import pytest from pytensor import config +from pytensor.compile import SharedVariable +from pytensor.graph import Constant import pymc as pm +from pymc import set_data from pymc.distributions.transforms import logodds from pymc.model.transform.conditioning import ( change_value_transforms, @@ -253,6 +256,28 @@ def test_do_self_reference(): np.testing.assert_allclose(draw_x + 100, draw_do_x) +def test_do_make_intervenstions_shared(): + with pm.Model(coords={"obs": [1]}) as m: + x = pm.Normal("x", dims="obs") + y = pm.Normal("y", dims="obs") + + constant_m = do(m, {x: [0.5]}, make_interventions_shared=False) + constant_x = constant_m["x"] + assert isinstance(constant_x, Constant) + np.testing.assert_array_equal(constant_x.data, [0.5]) + + shared_m = do(m, {x: [0.5]}, make_interventions_shared=True) + shared_x = shared_m["x"] + assert isinstance(shared_x, SharedVariable) + np.testing.assert_array_equal(shared_x.get_value(borrow=True), [0.5]) + + with shared_m: + set_data({"x": [0.6, 0.9]}, coords={"obs": [2, 3]}) + pp_y = pm.sample_prior_predictive(draws=3).prior["y"] + assert pp_y.sizes == {"chain": 1, "draw": 3, "obs": 2} + assert pp_y.shape == (1, 3, 2) + + def test_change_value_transforms(): with pm.Model() as base_m: p = pm.Uniform("p", 0, 1, default_transform=None)