diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 14c28bbe95..8d5396ac08 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,14 +1,17 @@
 import logging
 import typing
 import warnings
+from collections.abc import Sequence
 from functools import reduce
 from typing import Literal, cast
 
 import numpy as np
-import scipy.linalg
+import scipy
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor import Variable
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -25,8 +28,6 @@
 
 
 class Cholesky(Op):
-    # TODO: LAPACK wrapper with in-place behavior, for solve also
-
     __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
     gufunc_signature = "(m,m)->(m,m)"
 
@@ -396,6 +397,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     )(A, b)
 
 
+class LU(Op):
+    """Decompose a matrix into lower and upper triangular matrices."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function. We can copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        # L and U keep the input dtype (complex inputs yield complex factors);
+        # real_dtype is only used for the permutation output
+        L = tensor(shape=x.type.shape, dtype=x.type.dtype)
+        U = tensor(shape=x.type.shape, dtype=x.type.dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        r"""
+        Derivation is due to "Differentiation of Matrix Functionals Using Triangular Factorization",
+        F. R. De Hoog, R. S. Anderssen, and M. A. Lukas.
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # The permutation matrix is orthogonal, so P.T undoes the row permutation
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
+        # A_bar = P @ L^{-T} @ (tril(L^T @ L_bar, -1) + triu(U_bar @ U^T)) @ U^{-T}
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a permutation matrix, a unit lower triangular matrix,
+    and an upper triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is
+    upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values,
+        PL and U, will be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return an array of integer row indices for the permutation matrix. Otherwise,
+        return the permutation matrix itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for the permutation matrix. Not returned if
+        permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if
+        permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
+    """
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 67607891fc..3517e8703e 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -21,6 +21,7 @@
     cholesky,
     eigvalsh,
     expm,
+    lu,
     solve,
     solve_continuous_lyapunov,
     solve_discrete_are,
@@ -437,6 +438,67 @@ def test_solve_dtype(self):
         assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
 
 
+@pytest.mark.parametrize("permute_l", [True, False], ids=["permute_l", "no_permute_l"])
+@pytest.mark.parametrize("p_indices", [True, False], ids=["p_indices", "no_p_indices"])
+@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
+def test_lu_decomposition(permute_l, p_indices, complex):
+    dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
+    A = tensor("A", shape=(None, None), dtype=dtype)
+    out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+    f = pytensor.function([A], out)
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    x = rng.normal(size=(5, 5)).astype(config.floatX)
+    if complex:
+        x = x + 1j * rng.normal(size=(5, 5)).astype(config.floatX)
+
+    out = f(x)
+
+    if permute_l:
+        PL, U = out
+        x_rebuilt = PL @ U
+    elif p_indices:
+        p, L, U = out
+        P = np.eye(5)[p]
+        x_rebuilt = P @ L @ U
+    else:
+        P, L, U = out
+        x_rebuilt = P @ L @ U
+
+    np.testing.assert_allclose(x, x_rebuilt)
+    scipy_out = scipy.linalg.lu(x, permute_l=permute_l, p_indices=p_indices)
+
+    for a, b in zip(out, scipy_out, strict=True):
+        np.testing.assert_allclose(a, b)
+
+
+@pytest.mark.parametrize("grad_case", [0, 1, 2], ids=["U_only", "L_only", "U_and_L"])
+@pytest.mark.parametrize("permute_l", [True, False])
+@pytest.mark.parametrize("p_indices", [True, False])
+def test_lu_grad(grad_case, permute_l, p_indices):
+    rng = np.random.default_rng(utt.fetch_seed())
+    A_value = rng.normal(size=(5, 5))
+
+    def f_pt(A):
+        out = lu(A, permute_l=permute_l, p_indices=p_indices)
+
+        if permute_l:
+            L, U = out
+        else:
+            _, L, U = out
+
+        match grad_case:
+            case 0:
+                return U.sum()
+            case 1:
+                return L.sum()
+            case 2:
+                return U.sum() + L.sum()
+
+    utt.verify_grad(f_pt, [A_value], rng=rng)
+
+
 def test_cho_solve():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()
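
Reviewer note: a minimal usage sketch of the new `lu` helper, mirroring the behavior exercised by the tests above (illustration only, not part of the diff):

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import lu

    # Default mode returns the permutation matrix itself: A = P @ L @ U
    A = pt.matrix("A")
    P, L, U = lu(A)
    f = pytensor.function([A], [P, L, U])

    x = np.random.default_rng(0).normal(size=(4, 4)).astype(pytensor.config.floatX)
    P_val, L_val, U_val = f(x)
    np.testing.assert_allclose(P_val @ L_val @ U_val, x)

    # The factorization is differentiable through LU.L_op, so this compiles:
    g = pytensor.grad(U.sum(), A)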