forked from thu-ml/zhusuan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bayesian_nn.py
168 lines (144 loc) · 6.35 KB
/
bayesian_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import time
import tensorflow as tf
from six.moves import range, zip
import numpy as np
import zhusuan as zs
from examples import conf
from examples.utils import dataset
@zs.reuse('model')
def bayesianNN(observed, x, n_x, layer_sizes, n_particles):
    """Build the Bayesian neural network regression model.

    Each layer's weight matrix has an independent standard-normal prior,
    with the bias folded in as one extra input column. The network output
    is the mean of a Normal likelihood for ``y`` whose log-std is a single
    trainable scalar ``y_logstd``.

    The ``zs.reuse('model')`` decorator presumably wraps the body in a
    reusable variable scope so that repeated calls share ``y_logstd``.
    This function is called more than once in this file; confirm the
    decorator semantics against ZhuSuan.

    Args:
        observed: Dict mapping stochastic node names ('w0', 'w1', ...,
            'y') to the tensors they are observed at.
        x: Input tensor. Axis 0 is treated as the batch axis (presumably
            shape [batch, n_x], matching the caller's placeholder).
        n_x: Input dimensionality. Not used in the body; kept for
            interface compatibility.
        layer_sizes: Layer widths, input first and output last.
        n_particles: Number of weight samples per prior node.

    Returns:
        A ``(model, y_mean)`` tuple: the ``zs.BayesianNet`` and the
        predicted mean of ``y``. The mean is the last-layer output with
        its two trailing axes squeezed, presumably [n_particles, batch].
    """
    with zs.BayesianNet(observed=observed) as model:
        weights = []
        widths = zip(layer_sizes[:-1], layer_sizes[1:])
        for idx, (fan_in, fan_out) in enumerate(widths):
            # The extra "+ 1" column holds the bias term; the leading
            # axis of size 1 broadcasts over the data batch.
            prior_mean = tf.zeros([1, fan_out, fan_in + 1])
            weights.append(zs.Normal('w' + str(idx), prior_mean, std=1.,
                                     n_samples=n_particles, group_ndims=2))

        # Forward pass: replicate the inputs once per weight particle and
        # add a trailing unit axis so every layer is a batched matmul.
        hidden = tf.tile(tf.expand_dims(x, 0), [n_particles, 1, 1])
        hidden = tf.expand_dims(hidden, 3)
        last_idx = len(weights) - 1
        for idx, weight in enumerate(weights):
            tiled_w = tf.tile(weight, [1, tf.shape(x)[0], 1, 1])
            ones_col = tf.ones([n_particles, tf.shape(x)[0], 1, 1])
            hidden = tf.concat([hidden, ones_col], 2)
            # Normalise by sqrt of the bias-augmented input width.
            width = tf.cast(tf.shape(hidden)[2], tf.float32)
            hidden = tf.matmul(tiled_w, hidden) / tf.sqrt(width)
            if idx != last_idx:
                hidden = tf.nn.relu(hidden)

        y_mean = tf.squeeze(hidden, [2, 3])
        y_logstd = tf.get_variable('y_logstd', shape=[],
                                   initializer=tf.constant_initializer(0.))
        # The 'y' node must exist in the net: callers read its
        # log-probability via model.local_log_prob('y').
        zs.Normal('y', y_mean, logstd=y_logstd)
    return model, y_mean
def mean_field_variational(layer_sizes, n_particles):
    """Build a fully factorised Gaussian variational posterior q(w).

    For each layer, q(w_i) is parameterised by a trainable mean
    (``w_mean_i``) and log-std (``w_logstd_i``), both initialised to zero.
    The nodes are named ``'w0'``, ``'w1'``, ... like the model's weight
    nodes, so their samples can be substituted for the model's weights.

    Args:
        layer_sizes: Layer widths, input first and output last.
        n_particles: Number of samples drawn from each q(w_i).

    Returns:
        The ``zs.BayesianNet`` containing the variational weight nodes.
    """
    with zs.BayesianNet() as variational:
        for idx, (fan_in, fan_out) in enumerate(
                zip(layer_sizes[:-1], layer_sizes[1:])):
            # Same layout as the prior weights: a leading axis of size 1
            # plus an extra bias column.
            param_shape = [1, fan_out, fan_in + 1]
            mean = tf.get_variable(
                'w_mean_' + str(idx), shape=param_shape,
                initializer=tf.constant_initializer(0.))
            logstd = tf.get_variable(
                'w_logstd_' + str(idx), shape=param_shape,
                initializer=tf.constant_initializer(0.))
            zs.Normal('w' + str(idx), mean, logstd=logstd,
                      n_samples=n_particles, group_ndims=2)
    return variational
if __name__ == '__main__':
    tf.set_random_seed(1237)
    np.random.seed(1234)

    # Load UCI Boston housing data; the validation split is folded into
    # the training set.
    data_path = os.path.join(conf.data_dir, 'housing.data')
    x_train, y_train, x_valid, y_valid, x_test, y_test = \
        dataset.load_uci_boston_housing(data_path)
    x_train = np.vstack([x_train, x_valid])
    y_train = np.hstack([y_train, y_valid])
    N, n_x = x_train.shape

    # Standardize data. std_y_train is kept to map RMSE and log-likelihood
    # back to the original target scale.
    x_train, x_test, _, _ = dataset.standardize(x_train, x_test)
    y_train, y_test, mean_y_train, std_y_train = dataset.standardize(
        y_train, y_test)

    # Define model parameters
    n_hiddens = [50]

    # Define training/evaluation parameters
    lb_samples = 10
    ll_samples = 5000
    epochs = 500
    batch_size = 10
    iters = N // batch_size
    test_freq = 10
    learning_rate = 0.01
    anneal_lr_freq = 100
    anneal_lr_rate = 0.75

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x = tf.placeholder(tf.float32, shape=[None, n_x])
    y = tf.placeholder(tf.float32, shape=[None])
    # One copy of the targets per weight particle.
    y_obs = tf.tile(tf.expand_dims(y, 0), [n_particles, 1])
    layer_sizes = [n_x] + n_hiddens + [1]
    w_names = ['w' + str(i) for i in range(len(layer_sizes) - 1)]

    def log_joint(observed):
        """Return log p(w) + N * log p(y | x, w) for the fed minibatch.

        Each per-example likelihood term is scaled by the training-set size
        N. The ELBO is later averaged over the minibatch, so the result
        estimates the full-data log-likelihood.
        """
        model, _ = bayesianNN(observed, x, n_x, layer_sizes, n_particles)
        log_pws = model.local_log_prob(w_names)
        log_py_xw = model.local_log_prob('y')
        return tf.add_n(log_pws) + log_py_xw * N

    variational = mean_field_variational(layer_sizes, n_particles)
    qw_outputs = variational.query(w_names, outputs=True, local_log_prob=True)
    latent = dict(zip(w_names, qw_outputs))
    lower_bound = zs.variational.elbo(
        log_joint, observed={'y': y_obs}, latent=latent, axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    learning_rate_ph = tf.placeholder(tf.float32, shape=[])
    optimizer = tf.train.AdamOptimizer(learning_rate_ph)
    infer_op = optimizer.minimize(cost)

    # prediction: rmse & log likelihood, with the model's weights replaced
    # by samples from the variational net
    observed = {w_name: latent[w_name][0] for w_name in w_names}
    observed['y'] = y_obs
    model, y_mean = bayesianNN(observed, x, n_x, layer_sizes, n_particles)
    y_pred = tf.reduce_mean(y_mean, 0)
    rmse = tf.sqrt(tf.reduce_mean((y_pred - y) ** 2)) * std_y_train
    log_py_xw = model.local_log_prob('y')
    # Subtracting log(std_y_train) converts the density of the standardized
    # targets back to the original scale (change-of-variables term).
    log_likelihood = tf.reduce_mean(zs.log_mean_exp(log_py_xw, 0)) - \
        tf.log(std_y_train)

    params = tf.trainable_variables()
    for param in params:
        print(param.name, param.get_shape())

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate
            # Reshuffle every epoch. With a fixed order the same minibatches
            # repeat and the last N % batch_size examples are never used.
            # np.random is seeded above, so runs stay reproducible.
            perm = np.random.permutation(N)
            lbs = []
            for t in range(iters):
                batch_idx = perm[t * batch_size:(t + 1) * batch_size]
                x_batch = x_train[batch_idx]
                y_batch = y_train[batch_idx]
                _, lb = sess.run(
                    [infer_op, lower_bound],
                    feed_dict={n_particles: lb_samples,
                               learning_rate_ph: learning_rate,
                               x: x_batch, y: y_batch})
                lbs.append(lb)
            time_epoch += time.time()
            print('Epoch {} ({:.1f}s): Lower bound = {}'.format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lb, test_rmse, test_ll = sess.run(
                    [lower_bound, rmse, log_likelihood],
                    feed_dict={n_particles: ll_samples,
                               x: x_test, y: y_test})
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test lower bound = {}'.format(test_lb))
                print('>> Test rmse = {}'.format(test_rmse))
                print('>> Test log_likelihood = {}'.format(test_ll))