Skip to content

Commit

Permalink
Add stabilization rewrite for log_diff_exp
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <[email protected]>

.rewrite
  • Loading branch information
Smit-create authored and ricardoV94 committed Apr 24, 2023
1 parent b1332b2 commit 9fd3af7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3604,6 +3604,14 @@ def local_reciprocal_1_plus_exp(fgraph, node):
)
register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")

# log(exp(a) - exp(b)) -> a + log1mexp(b - a)
logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
(log, (sub, (exp, "x"), (exp, "y"))),
(add, "x", (log1mexp, (sub, "y", "x"))),
allow_multiple_clients=True,
)
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")


# log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x
Expand Down
39 changes: 39 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4136,3 +4136,42 @@ def test_log1mexp_stabilization():
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)


def test_logdiffexp():
rng = np.random.default_rng(3559)
mode = Mode("py").including("stabilize").excluding("fusion")

x = fmatrix("x")
y = fmatrix("y")
f = function([x, y], log(exp(x) - exp(y)), mode=mode)

graph = f.maker.fgraph.toposort()
assert (
len(
[
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Exp, aes.Log))
]
)
== 0
)
assert (
len(
[
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log1mexp)
]
)
== 1
)

y_test = rng.normal(size=(3, 2)).astype("float32")
x_test = rng.normal(size=(3, 2)).astype("float32") + y_test.max()
np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
)

1 comment on commit 9fd3af7

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Python Benchmark with pytest-benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 9fd3af7 Previous: 38dc6c9 Ratio
tests/link/numba/test_scan.py::test_scan_multiple_output 5692.316119036308 iter/sec (stddev: 0.000004073776657206289) 12016.570467614958 iter/sec (stddev: 0.000002392918507767701) 2.11

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.