"""Implements a Gaussian mixture model, in which parameters are fit using
gradient descent. This example runs on 2-dimensional data, but the model
works on arbitrarily-high dimension."""
import matplotlib.pyplot as plt
from data import make_pinwheel
from scipy.optimize import minimize
import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.multivariate_normal as mvn
from autograd import grad, hessian_vector_product
from autograd.misc.flatten import flatten_func
from autograd.scipy.special import logsumexp
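

# The mixture is parameterized by a dict of arrays: unnormalized log mixing
# proportions with shape (K,), component means with shape (K, D), and matrix
# square roots of the component covariances with shape (K, D, D). Each
# covariance is recovered as cov_sqrt.T @ cov_sqrt, so it stays symmetric
# positive semi-definite without any explicit constraint on cov_sqrt.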
def init_gmm_params(num_components, D, scale, rs=npr.RandomState(0)):
    return {
        "log proportions": rs.randn(num_components) * scale,
        "means": rs.randn(num_components, D) * scale,
        "lower triangles": np.zeros((num_components, D, D)) + np.eye(D),
    }


def log_normalize(x):
    return x - logsumexp(x)


def unpack_gmm_params(params):
    normalized_log_proportions = log_normalize(params["log proportions"])
    return normalized_log_proportions, params["means"], params["lower triangles"]
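

# Total log-likelihood of the data:
#   sum_n log sum_k exp(log pi_k + log N(x_n | mu_k, Sigma_k)),
# where Sigma_k = L_k^T L_k. The inner sum over components is computed
# stably with logsumexp.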
def gmm_log_likelihood(params, data):
    cluster_lls = []
    for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)):
        cov = np.dot(cov_sqrt.T, cov_sqrt)
        cluster_lls.append(log_proportion + mvn.logpdf(data, mean, cov))
    return np.sum(logsumexp(np.vstack(cluster_lls), axis=0))
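
# A quick smoke test (hypothetical usage; values are illustrative):
#   params = init_gmm_params(num_components=10, D=2, scale=0.1)
#   ll = gmm_log_likelihood(params, npr.randn(5, 2))  # scalar log-likelihood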


def plot_ellipse(ax, mean, cov_sqrt, alpha, num_points=100):
    # Trace a covariance contour: push a circle of radius 2 through the
    # covariance square root and center it at the component mean.
    angles = np.linspace(0, 2 * np.pi, num_points)
    circle_pts = np.vstack([np.cos(angles), np.sin(angles)]).T * 2.0
    cur_pts = mean + np.dot(circle_pts, cov_sqrt)
    ax.plot(cur_pts[:, 0], cur_pts[:, 1], "-", alpha=alpha)


def plot_gaussian_mixture(params, ax):
    # Fade components in proportion to their mixing weights.
    for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)):
        alpha = np.minimum(1.0, np.exp(log_proportion) * 10)
        plot_ellipse(ax, mean, cov_sqrt, alpha)
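

# Demo: fit a 10-component mixture to pinwheel data (from the companion data
# module) with Newton-CG, redrawing the current fit after every iteration.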
if __name__ == "__main__":
    init_params = init_gmm_params(num_components=10, D=2, scale=0.1)

    data = make_pinwheel(
        radial_std=0.3, tangential_std=0.05, num_classes=3, num_per_class=100, rate=0.4
    )

    def objective(params):
        return -gmm_log_likelihood(params, data)

    # flatten_func turns the dict-valued objective into an equivalent function
    # of a single flat parameter vector, which scipy.optimize can handle.
    flattened_obj, unflatten, flattened_init_params = flatten_func(objective, init_params)

    fig = plt.figure(figsize=(12, 8), facecolor="white")
    ax = fig.add_subplot(111, frameon=False)
    plt.show(block=False)

    def callback(flattened_params):
        params = unflatten(flattened_params)
        print(f"Log likelihood {-objective(params)}")
        ax.cla()
        ax.plot(data[:, 0], data[:, 1], "k.")
        ax.set_xticks([])
        ax.set_yticks([])
        plot_gaussian_mixture(params, ax)
        plt.draw()
        plt.pause(1.0 / 60.0)

    # Newton-CG uses the exact gradient and exact Hessian-vector products,
    # both supplied by autograd.
    minimize(
        flattened_obj,
        flattened_init_params,
        jac=grad(flattened_obj),
        hessp=hessian_vector_product(flattened_obj),
        method="Newton-CG",
        callback=callback,
    )
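
    # A first-order alternative (a sketch, assuming autograd's bundled adam
    # optimizer; step size and iteration count are illustrative):
    #
    #   from autograd.misc.optimizers import adam
    #   fit_params = adam(lambda p, i: grad(objective)(p), init_params,
    #                     step_size=0.05, num_iters=500)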