readme_example8_gp-rnn.py

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from varz.spec import parametrised, Positive
from varz.tensorflow import Vars, minimise_adam
from wbml.net import rnn as rnn_constructor
from wbml.plot import tweak

from stheno.tensorflow import B, Measure, GP, Delta, EQ

# Increase regularisation because we are dealing with `tf.float32`s.
B.epsilon = 1e-6

# Construct the points at which to predict.
x = B.linspace(tf.float32, 0, 1, 100)[:, None]
inds_obs = B.range(0, int(0.75 * len(x)))  # Train on the first 75% only.
x_obs = B.take(x, inds_obs)

# Construct function and observations.
#   Draw random modulation functions.
a_true = GP(1e-2 * EQ().stretch(0.1))(x).sample()
b_true = GP(1e-2 * EQ().stretch(0.1))(x).sample()
#   Construct the true, underlying function.
f_true = (1 + a_true) * B.sin(2 * np.pi * 7 * x) + b_true
#   Add noise.
y_true = f_true + 0.1 * B.randn(*f_true.shape)

# Normalise and split.
f_true = (f_true - B.mean(y_true)) / B.std(y_true)
y_true = (y_true - B.mean(y_true)) / B.std(y_true)
y_obs = B.take(y_true, inds_obs)


@parametrised
def model(
    vs, a_scale: Positive = 0.1, b_scale: Positive = 0.1, noise: Positive = 0.01
):
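    # The model is `f = (1 + a) * rnn + b`: an RNN whose amplitude and offset
    # are modulated by the smooth GPs `a` and `b`, with observations corrupted
    # by the noise process `e`.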
    prior = Measure()

    # Construct an RNN.
    f_rnn = rnn_constructor(
        output_size=1, widths=(10,), nonlinearity=B.tanh, final_dense=True
    )

    # Set the weights for the RNN.
    num_weights = f_rnn.num_weights(input_size=1)
    weights = Vars(tf.float32, source=vs.get(shape=(num_weights,), name="rnn"))
    f_rnn.initialise(input_size=1, vs=weights)

    # Construct GPs that modulate the RNN.
    a = GP(1e-2 * EQ().stretch(a_scale), measure=prior)
    b = GP(1e-2 * EQ().stretch(b_scale), measure=prior)
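    # `Delta()` is a white-noise kernel, so `e` models i.i.d. observation noise.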
    e = GP(noise * Delta(), measure=prior)

    # GP-RNN model: multiplying the GP `1 + a` by the deterministic RNN and
    # adding `b` yields another GP, so `f_gp_rnn` is again a GP under `prior`.
    f_gp_rnn = (1 + a) * (lambda x: f_rnn(x)) + b
    y_gp_rnn = f_gp_rnn + e

    return f_rnn, f_gp_rnn, y_gp_rnn, a, b
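

# Pretraining objective: a plain mean-squared error that fits the RNN on its
# own, giving the joint optimisation a sensible starting point.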
def objective_rnn(vs):
    f_rnn, _, _, _, _ = model(vs)
    return B.mean((f_rnn(x_obs) - y_obs) ** 2)
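

# Joint objective: the negative log marginal likelihood (the evidence) of the
# observations under the GP-RNN model.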
def objective_gp_rnn(vs):
    _, _, y_gp_rnn, _, _ = model(vs)
    evidence = y_gp_rnn(x_obs).logpdf(y_obs)
    return -evidence


# Pretrain the RNN.
vs = Vars(tf.float32)
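# Compiling the objective with `tf.function` (AutoGraph disabled) speeds up
# the optimisation.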
minimise_adam(
tf.function(objective_rnn, autograph=False), vs, rate=1e-2, iters=1000, trace=True
)

# Jointly train the RNN and GPs.
minimise_adam(
tf.function(objective_gp_rnn, autograph=False),
vs,
rate=1e-3,
iters=1000,
trace=True,
)
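
# Construct the model with the learned parameters.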
_, f_gp_rnn, y_gp_rnn, a, b = model(vs)

# Condition on the observations to obtain the posterior measure.
post = f_gp_rnn.measure | (y_gp_rnn(x_obs), y_obs)

# Predict and plot results.
plt.figure(figsize=(10, 6))
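# The `style` keyword arguments below are not standard matplotlib; they are
# handled by `wbml.plot`.
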
plt.subplot(2, 1, 1)
plt.title("$(1 + a)\\cdot {}$RNN${} + b$")
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
mean, lower, upper = post(f_gp_rnn(x)).marginals()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.subplot(2, 2, 3)
plt.title("$a$")
mean, lower, upper = post(a(x)).marginals()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.subplot(2, 2, 4)
plt.title("$b$")
mean, lower, upper = post(b(x)).marginals()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example8_gp-rnn.png")
plt.show()