Skip to content

Commit

Permalink
Add LU decomposition Op
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Feb 11, 2025
1 parent 2460f2d commit 86c5539
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 3 deletions.
187 changes: 184 additions & 3 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
import typing
import warnings
from collections.abc import Sequence
from functools import reduce
from typing import Literal, cast

import numpy as np
import scipy.linalg
import scipy

import pytensor
import pytensor.tensor as pt
from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike, as_tensor_variable
Expand All @@ -25,8 +28,6 @@


class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also

__props__ = ("lower", "check_finite", "on_error", "overwrite_a")
gufunc_signature = "(m,m)->(m,m)"

Expand Down Expand Up @@ -396,6 +397,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
)(A, b)


class LU(Op):
"""Decompose a matrix into lower and upper triangular matrices."""

__props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")

def __init__(
self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
):
self.permute_l = permute_l
self.check_finite = check_finite
self.p_indices = p_indices
self.overwrite_a = overwrite_a

if self.permute_l:
# permute_l overrides p_indices in the scipy function. We can copy that behavior
self.gufunc_signature = "(m,m)->(m,m),(m,m)"
elif self.p_indices:
self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
else:
self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"

if self.overwrite_a:
self.destroy_map = {0: [0]}

def infer_shape(self, fgraph, node, shapes):
n = shapes[0][0]
if self.permute_l:
return [(n, n), (n, n)]
elif self.p_indices:
return [(n,), (n, n), (n, n)]
else:
return [(n, n), (n, n), (n, n)]

def make_node(self, x):
x = as_tensor_variable(x)
if x.type.ndim != 2:
raise TypeError(
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)

real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)

L = tensor(shape=x.type.shape, dtype=real_dtype)
U = tensor(shape=x.type.shape, dtype=real_dtype)

if self.permute_l:
# In this case, L is actually P @ L
return Apply(self, inputs=[x], outputs=[L, U])
elif self.p_indices:
p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
return Apply(self, inputs=[x], outputs=[p, L, U])
else:
P = tensor(shape=x.type.shape, dtype=p_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U])

def perform(self, node, inputs, outputs):
[A] = inputs

out = scipy.linalg.lu(
A,
permute_l=self.permute_l,
overwrite_a=self.overwrite_a,
check_finite=self.check_finite,
p_indices=self.p_indices,
)

outputs[0][0] = out[0]
outputs[1][0] = out[1]

if not self.permute_l:
# In all cases except permute_l, there are three returns
outputs[2][0] = out[2]

def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
if 0 in allowed_inplace_inputs:
new_props = self._props_dict() # type: ignore
new_props["overwrite_a"] = True
return type(self)(**new_props)
else:
return self

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
r"""
Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
F. R. De Hoog, R.S. Anderssen, M. A. Lukas
"""
[A] = inputs
A = cast(TensorVariable, A)

if self.permute_l:
PL_bar, U_bar = output_grads

# TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
P, L, U = lu( # type: ignore
A, permute_l=False, check_finite=self.check_finite, p_indices=False
)

# Permutation matrix is orthogonal
L_bar = (
P.T @ PL_bar
if not isinstance(PL_bar.type, DisconnectedType)
else pt.zeros_like(A)
)

elif self.p_indices:
p, L, U = outputs

# TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
P = pt.eye(A.shape[0])[p]
_, L_bar, U_bar = output_grads
else:
P, L, U = outputs
_, L_bar, U_bar = output_grads

L_bar = (
L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
)
U_bar = (
U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
)

x1 = ptb.tril(L.T @ L_bar, k=-1)
x2 = ptb.triu(U_bar @ U.T)

L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T

return [A_bar]


def lu(
a: TensorLike, permute_l=False, check_finite=True, p_indices=False
) -> (
tuple[TensorVariable, TensorVariable, TensorVariable]
| tuple[TensorVariable, TensorVariable]
):
"""
Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
... math::
A = P L U
Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
Parameters
----------
a: TensorLike
Matrix to be factorized
permute_l: bool
If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
be returned in this case, and PL will not be lower triangular.
check_finite: bool
Whether to check that the input matrix contains only finite numbers.
p_indices: bool
If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
itself.
Returns
-------
P: TensorVariable
Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
L: TensorVariable
Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
U: TensorVariable
Upper triangular matrix
"""
return cast(
tuple[TensorVariable, TensorVariable, TensorVariable]
| tuple[TensorVariable, TensorVariable],
LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
)


class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""

Expand Down
62 changes: 62 additions & 0 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
cholesky,
eigvalsh,
expm,
lu,
solve,
solve_continuous_lyapunov,
solve_discrete_are,
Expand Down Expand Up @@ -437,6 +438,67 @@ def test_solve_dtype(self):
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)


@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
def test_lu_decomposition(permute_l, p_indices, complex):
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
A = tensor("A", shape=(None, None), dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices)

f = pytensor.function([A], out)

rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=(5, 5)).astype(config.floatX)
if complex:
x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)

out = f(x)

if permute_l:
PL, U = out
x_rebuilt = PL @ U
elif p_indices:
p, L, U = out
P = np.eye(5)[p]
x_rebuilt = P @ L @ U
else:
P, L, U = out
x_rebuilt = P @ L @ U

np.testing.assert_allclose(x, x_rebuilt)
scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)

for a, b in zip(out, scipy_out, strict=True):
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
@pytest.mark.parametrize("permute_l", [True, False])
@pytest.mark.parametrize("p_indices", [True, False])
def test_lu_grad(grad_case, permute_l, p_indices):
rng = np.random.default_rng(utt.fetch_seed())
A_value = rng.normal(size=(5, 5))

def f_pt(A):
out = lu(A, permute_l=permute_l, p_indices=p_indices)

if permute_l:
L, U = out
else:
_, L, U = out

match grad_case:
case 0:
return U.sum()
case 1:
return L.sum()
case 2:
return U.sum() + L.sum()

utt.verify_grad(f_pt, [A_value], rng=rng)


def test_cho_solve():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
Expand Down

0 comments on commit 86c5539

Please sign in to comment.