From ec106ffa5389c45846dd1d23e3af9c164991dbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 26 Sep 2022 21:58:02 +0200 Subject: [PATCH] Add the transformation between the inverse gamma and the exponential --- aemcmc/transforms.py | 51 ++++++++++++++++++++++++++++++++++++++++ tests/test_transforms.py | 34 ++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/aemcmc/transforms.py b/aemcmc/transforms.py index 6cc9e82..27430d4 100644 --- a/aemcmc/transforms.py +++ b/aemcmc/transforms.py @@ -70,3 +70,54 @@ def location_scale_transform(in_expr, out_expr): eq(out_expr, noncentered_et), location_scale_family(distribution_lv), ) + + +def invgamma_exponential(invgamma_expr, invexponential_expr): + r"""Produce a goal that represents the relation between the inverse gamma distribution + and the inverse of an exponential distribution. + + .. math:: + + \begin{equation*} + \frac{ + X \sim \operatorname{Gamma^{-1}}\left(1, c\right) + }{ + Y = 1 / X, \quad + Y \sim \operatorname{Exp}\left(c\right) + } + \end{equation*} + + TODO: This is a particular case of a more general relation between the inverse gamma + and the gamma distribution (of which the exponential distribution is a special case). + We should implement this more general relation, and the special case separately in the + future. + + Parameters + ---------- + invgamma_expr + An expression that represents a random variable with an inverse gamma + distribution with a shape parameter equal to 1. + invexponential_expr + An expression that represents the inverse of a random variable with an + exponential distribution. + + """ + c_lv = var() + rng_lv, size_lv, dtype_lv = var(), var(), var() + + invgamma_et = etuple( + etuplize(at.random.invgamma), rng_lv, size_lv, dtype_lv, at.as_tensor(1.0), c_lv + ) + + exponential_et = etuple( + etuplize(at.random.exponential), + c_lv, + rng=rng_lv, + size=size_lv, + dtype=dtype_lv, + ) + invexponential_et = etuple(at.true_div, at.as_tensor(1.0), exponential_et) + + return lall( + eq(invgamma_expr, invgamma_et), eq(invexponential_expr, invexponential_et) + ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 14b5a27..40d5d20 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -3,7 +3,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.kanren import KanrenRelationSub -from aemcmc.transforms import location_scale_transform +from aemcmc.transforms import invgamma_exponential, location_scale_transform def test_normal_scale_loc_transform_lift(): @@ -45,3 +45,35 @@ def test_normal_scale_loc_transform_sink(): )[0] assert isinstance(res.owner.op, type(at.random.normal)) + + +def test_invgamma_to_exp(): + + srng = at.random.RandomStream(0) + c_at = at.scalar() + X_rv = srng.invgamma(1.0, c_at) + + fgraph = FunctionGraph(outputs=[X_rv], clone=False) + res = KanrenRelationSub(invgamma_exponential).transform( + fgraph, fgraph.outputs[0].owner + )[0] + + assert isinstance(res.owner.op, type(at.true_div)) + assert isinstance(res.owner.inputs[1].owner.op, type(at.random.exponential)) + + +@pytest.mark.xfail( + reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error" +) +def test_invgamma_from_exp(): + + srng = at.random.RandomStream(0) + c_at = at.scalar() + X_rv = 1.0 / srng.exponential(c_at) + + fgraph = FunctionGraph(outputs=[X_rv], clone=False) + res = KanrenRelationSub(lambda x, y: invgamma_exponential(y, x)).transform( + fgraph, fgraph.outputs[0].owner + )[0] + + assert isinstance(res.owner.op, type(at.random.inversegamma))