Skip to content

Commit

Permalink
Fix bug in local_div_switch_sink rewrite
Browse files Browse the repository at this point in the history
Introduced in 4f7d709
  • Loading branch information
ricardoV94 committed Nov 11, 2024
1 parent c2e88c6 commit fdbf3aa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node):
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)

fct = switch(switch_cond, zero_switch_input, fdiv)
if branch == 0:
fct = switch(switch_cond, zero_switch_input, fdiv)
else:
fct = switch(switch_cond, fdiv, zero_switch_input)

# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
Expand Down
23 changes: 22 additions & 1 deletion tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
# The zero branch upcasts the output, so we can't ignore its dtype
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
other_branch = scalar("other_branch", dtype="float32")
outer_var = scalar("mul_var", dtype="bool")
outer_var = scalar("outer_var", dtype="bool")

out = op(switch(cond, zero_branch, other_branch), outer_var)
fgraph = FunctionGraph(outputs=[out], clone=False)
Expand All @@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
assert equal_computations([new_out], [expected_out])

@pytest.mark.parametrize(
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
)
def test_local_mul_div_switch_sink_branch_order(self, op, rewrite):
cond = scalar("cond", dtype="bool")
zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch")
other_branch = scalar("other_branch", dtype="float64")
outer_var = scalar("outer_var", dtype="float64")

left = op(switch(cond, zero_branch, other_branch), outer_var)
right = op(switch(cond, other_branch, zero_branch), outer_var)
fgraph = FunctionGraph(outputs=[left, right], clone=False)
[new_left] = rewrite.transform(fgraph, left.owner)
[new_right] = rewrite.transform(fgraph, right.owner)

expected_left = switch(cond, zero_branch, op(other_branch, outer_var))
expected_right = switch(cond, op(other_branch, outer_var), zero_branch)
assert equal_computations(
[new_left, new_right], [expected_left, expected_right]
)


@pytest.mark.skipif(
config.cxx == "",
Expand Down

0 comments on commit fdbf3aa

Please sign in to comment.