Skip to content

Commit

Permalink
Retain more precise types in MergeOptimizer
Browse files Browse the repository at this point in the history
This can avoid some infinite rewrite loops where a SpecifyShape is lifted, removed and then reintroduced at the bottom by the MergeOptimizer
  • Loading branch information
ricardoV94 committed Aug 3, 2023
1 parent 4cc13bc commit 0b558d8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
24 changes: 16 additions & 8 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,14 +743,22 @@ def apply(self, fgraph):
):
continue

if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type:
res = pairs[0][0].type.convert_variable(pairs[0][1])

# Since the fgraph.replace only checks the convert_variable
# in one way, we change the order in the case that
# convert_variable will not be successful.
if not res:
pairs = [(pairs[0][1], pairs[0][0])]
# Keep the variable with the most specific static type from the pairs
# E.g the second in (TensorType(shape=(None,), TensorType(shape=(5,))
# Otherwise we could end up reverting type inference progress done elsewhere.
for pair_idx in range(len(pairs)):
old, new = pairs[pair_idx]
if old.type == new.type:
continue
# Check if type of new replacement is at least as specific as that of the old variable
if not old.type.is_super(new.type):
# Check the other way around
if new.type.is_super(old.type):
pairs[pair_idx] = (new, old)
else:
# Replacement requires some operation like specify_shape
new_repl = old.type.convert_variable(new)
pairs[pair_idx] = (old, new_repl)

try:
# If they're all `AtomicVariable`s, there's no need to call validate.
Expand Down
21 changes: 19 additions & 2 deletions tests/graph/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
pre_greedy_node_rewriter,
)
from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot
from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import (
MyOp,
Expand Down Expand Up @@ -441,6 +441,23 @@ def test_merge_noinput(self):
assert fg.outputs[0] is fg.outputs[1]
assert fg.outputs[0] is not fg.outputs[2]

@pytest.mark.parametrize("reverse", [False, True])
def test_merge_more_specific_types(self, reverse):
"""Check that we choose the most specific static type when merging variables."""

x1 = vector("x1", shape=(None,))
x2 = vector("x2", shape=(500,))

y1 = exp(x1)
y2 = exp(x2)

# Simulate case where we find that x2 is equivalent to x1
fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False)
fg.replace(x1, x2)

MergeOptimizer().rewrite(fg)
assert fg.outputs == [y2, y2]


class TestEquilibrium:
def test_1(self):
Expand Down

0 comments on commit 0b558d8

Please sign in to comment.