-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlinalg.py
120 lines (88 loc) · 3.02 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
r"""Linear algebra helpers"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax import Array
from typing import *
def transpose(A: Callable[[Array], Array], x: Array) -> Callable[[Array], Array]:
    r"""Returns the transpose of a linear operation.

    For a linear map :math:`A`, the vector-Jacobian product at any point is
    :math:`A^T`, so the VJP closure built at `x` is the transpose operator.

    Arguments:
        A: A linear operation.
        x: A point with the shape/dtype of `A`'s input (values are irrelevant
            since `A` is linear).

    Returns:
        A callable computing :math:`y \mapsto A^T y`.
    """

    _, vjp = jax.vjp(A, x)

    def At(y: Array) -> Array:
        # `vjp` returns a one-element tuple (one primal argument).
        return vjp(y)[0]

    return At
class DPLR(NamedTuple):
r"""Diagonal plus low-rank (DPLR) matrix."""
D: Array
U: Array = None
V: Array = None
def __add__(self, C: Array) -> DPLR: # C is scalar or diagonal
return DPLR(self.D + C, self.U, self.V)
def __radd__(self, C: Array) -> DPLR:
return DPLR(C + self.D, self.U, self.V)
def __sub__(self, C: Array) -> DPLR:
return DPLR(self.D - C, self.U, self.V)
def __mul__(self, C: Array) -> DPLR:
D = self.D * C
if self.U is None:
U, V = None, None
else:
U, V = self.U, self.V * C[..., None, :]
return DPLR(D, U, V)
def __rmul__(self, C: Array) -> DPLR:
D = C * self.D
if self.U is None:
U, V = None, None
else:
U, V = C[..., None] * self.U, self.V
return DPLR(D, U, V)
def __matmul__(self, x: Array) -> Array:
if self.U is None:
return self.D * x
else:
return self.D * x + jnp.einsum('...ij,...jk,...k', self.U, self.V, x)
@property
def rank(self) -> int:
if self.U is None:
return 0
else:
return self.U.shape[-1]
@property
def W(self) -> Array: # capacitance
return jnp.eye(self.rank) + jnp.einsum('...ik,...k,...kj', self.V, 1 / self.D, self.U)
@property
def inv(self) -> DPLR:
D = 1 / self.D
if self.U is None:
U, V = None, None
else:
U = -D[..., None] * self.U
V = jnp.linalg.solve(self.W, self.V) * D[..., None, :]
return DPLR(D, U, V)
def solve(self, x: Array) -> Array:
D = 1 / self.D
if self.U is None:
return D * x
else:
return D * x - D * jnp.squeeze(
self.U @ jnp.linalg.solve(self.W, self.V @ jnp.expand_dims(D * x, axis=-1)),
axis=-1,
)
def diag(self) -> Array:
if self.U is None:
return self.D
else:
return self.D + jnp.einsum('...ij,...ji->...i', self.U, self.V)
def norm(self) -> Array:
if self.U is None:
return jnp.sum(self.D**2, axis=-1)
else:
return (
jnp.sum(self.D**2, axis=-1)
+ 2 * jnp.einsum('...i,...ij,...ji', self.D, self.U, self.V)
+ jnp.sum((self.V @ self.U) ** 2, axis=(-1, -2))
)
def slogdet(self) -> Tuple[Array, Array]:
sign, logabsdet = jnp.linalg.slogdet(self.W)
sign = sign * jnp.prod(jnp.sign(self.D), axis=-1)
logabsdet = logabsdet + jnp.sum(jnp.log(jnp.abs(self.D)), axis=-1)
return sign, logabsdet