-
Notifications
You must be signed in to change notification settings - Fork 912
/
rkhs.py
95 lines (65 loc) · 2.37 KB
/
rkhs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Inferring a function from a reproducing kernel Hilbert space (RKHS) by taking
gradients of eval with respect to the function-valued argument
"""
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.extend import Box, VSpace, defvjp, primitive
from autograd.util import func
class RKHSFun:
    """Function represented as a weighted sum of kernel evaluations.

    ``alphas`` maps representer points ``x_repr`` to coefficients ``a``;
    calling the object at ``x`` returns ``sum(a * kernel(x, x_repr))``.
    """

    def __init__(self, kernel, alphas=None):
        """Store the kernel and coefficients and attach the vector space.

        Args:
            kernel: two-argument callable, invoked as ``kernel(x, x_repr)``.
            alphas: optional dict mapping representer points to coefficients.
                When omitted, a fresh empty dict is used (no terms, so the
                function evaluates to 0.0 everywhere).
        """
        # A literal ``{}`` default is created once and shared by every
        # instance built without ``alphas``; mutating one instance's dict
        # would silently change all of them. Build a new dict per instance.
        self.alphas = {} if alphas is None else alphas
        self.kernel = kernel
        # RKHSFunVSpace reads ``self.kernel``, so it must be set before this.
        self.vs = RKHSFunVSpace(self)

    @primitive
    def __call__(self, x):
        """Evaluate the function at ``x``."""
        # The explicit 0.0 start makes the empty function return a float.
        return sum([a * self.kernel(x, x_repr) for x_repr, a in self.alphas.items()], 0.0)

    def __add__(self, f):
        """Return ``self + f`` using the vector space's ``add``."""
        return self.vs.add(self, f)

    def __mul__(self, a):
        """Return ``self`` scaled by ``a`` using the vector space's ``scalar_mul``."""
        return self.vs.scalar_mul(self, a)
# TODO: add vjp of __call__ wrt x (and show it in action)
def _make_call_vjp_wrt_f(ans, f, x):
    """Build the VJP of ``f(x)`` with respect to the function argument ``f``.

    ``ans`` (the value of ``f(x)``) is accepted but not used.
    """

    def vjp(g):
        # An RKHSFun with the single representer ``x`` and coefficient 1,
        # scaled by the incoming gradient ``g``.
        return RKHSFun(f.kernel, {x: 1}) * g

    return vjp


defvjp(func(RKHSFun.__call__), _make_call_vjp_wrt_f)
class RKHSFunBox(Box, RKHSFun):
    """Box type registered for ``RKHSFun`` values."""

    # Forward ``kernel`` to the wrapped value held in ``_value``.
    kernel = property(lambda self: self._value.kernel)


RKHSFunBox.register(RKHSFun)
class RKHSFunVSpace(VSpace):
    """Vector-space operations for ``RKHSFun`` objects sharing one kernel."""

    def __init__(self, value):
        """Remember the kernel of ``value``; results are built with it."""
        self.kernel = value.kernel

    def zeros(self):
        """Return an ``RKHSFun`` over this kernel with default coefficients."""
        return RKHSFun(self.kernel)

    def randn(self):
        """Return an ``RKHSFun`` with one or two random representers and weights."""
        # These arbitrary vectors are not analogous to randn in any meaningful way
        count = npr.randint(1, 3)
        # Draw the representer points first, then the coefficients.
        points = npr.randn(count)
        weights = npr.randn(count)
        return RKHSFun(self.kernel, dict(zip(points, weights)))

    def _add(self, f, g):
        """Return the sum of ``f`` and ``g``; both must use the same kernel object."""
        assert f.kernel is g.kernel
        combined = add_dicts(f.alphas, g.alphas)
        return RKHSFun(f.kernel, combined)

    def _scalar_mul(self, f, a):
        """Return ``f`` with every coefficient multiplied by ``a``."""
        scaled = {}
        for point, coeff in f.alphas.items():
            scaled[point] = a * coeff
        return RKHSFun(f.kernel, scaled)

    def _inner_prod(self, f, g):
        """Return ``sum(a1 * a2 * kernel(x1, x2))`` over all representer pairs."""
        assert f.kernel is g.kernel
        total = 0.0
        for x1, a1 in f.alphas.items():
            for x2, a2 in g.alphas.items():
                total = total + a1 * a2 * f.kernel(x1, x2)
        return total


RKHSFunVSpace.register(RKHSFun)
def add_dicts(d1, d2):
    """Return a new dict holding the key-wise sum of ``d1`` and ``d2``.

    Keys present in both inputs map to ``d1[k] + d2[k]``; keys present in
    only one input keep their value. Neither input is modified.
    """
    # Python 3 dict views cannot be concatenated with ``+``, so start from a
    # copy of d1 and fold d2's entries into it.
    merged = dict(d1)
    for key, value in d2.items():
        merged[key] = merged[key] + value if key in merged else value
    return merged
if __name__ == "__main__":

    def sq_exp_kernel(x1, x2):
        """Return exp(-(x1 - x2)^2)."""
        diff = x1 - x2
        return np.exp(-(diff**2))

    def logprob(f, xs, ys):
        """Return the negated sum of squared differences between f(x) and y."""
        squared_errors = ((f(x) - y) ** 2 for x, y in zip(xs, ys))
        return -sum(squared_errors)

    # Training inputs and targets.
    xs = range(5)
    ys = [1, 2, 3, 2, 1]

    num_steps = 100
    step_size = 0.01
    logprob_grad = grad(logprob)

    # Start from the function with no terms and repeatedly step along the
    # gradient of logprob with respect to f.
    f = RKHSFun(sq_exp_kernel)
    for _ in range(num_steps):
        direction = logprob_grad(f, xs, ys)
        f = f + direction * step_size

    # Print input, target and fitted value, tab-separated.
    for x, y in zip(xs, ys):
        print(f"{x}\t{y}\t{f(x)}")