Add SGHMC
zaxtax authored and rlouf committed Sep 30, 2022
1 parent e7b882d commit bf15ed8
Showing 1 changed file with 54 additions and 17 deletions.
examples/SGMCMC.md
@@ -104,14 +104,6 @@ class NN(nn.Module):
model = NN()
@jax.jit
def predict_fn(params, X):
"""Returns the probability for the image represented by X
to be in each category given the MLP's weight values.
"""
return model.apply(params, X)
def logprior_fn(params):
"""Compute the value of the log-prior density function."""
leaves, _ = jax.tree_util.tree_flatten(params)
@@ -122,7 +114,7 @@ def logprior_fn(params):
def loglikelihood_fn(params, data):
"""Categorical log-likelihood"""
X, y = data
return jnp.sum(y * predict_fn(params, X))
return jnp.sum(y * model.apply(params, X))
@jax.jit
@@ -132,7 +124,7 @@ def compute_accuracy(params, X, y):
To make predictions we take the number that corresponds to the highest probability value.
"""
target_class = jnp.argmax(y, axis=1)
predicted_class = jnp.argmax(predict_fn(params, X), axis=1)
predicted_class = jnp.argmax(model.apply(params, X), axis=1)
return jnp.mean(predicted_class == target_class)
```

@@ -142,7 +134,7 @@ Now we need to get initial values for the parameters, and we simply sample from

+++

We now sample from the model's posteriors. We discard the first 1000 samples until the sampler has reached the typical set, and then take 2000 samples. We record the model's accuracy with the current values every 100 steps.
We now sample from the model's posterior using SGLD. We discard the samples drawn during warmup, while the sampler travels to the typical set, and then keep the draws that follow. We record the model's accuracy with the current values every 100 steps.
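
For intuition, one SGLD transition is simply a stochastic-gradient step perturbed with Gaussian noise whose scale is tied to the step size. The following is a minimal sketch for a flat parameter array, not Blackjax's actual implementation; `grad_fn` stands for any minibatch estimate of the log-posterior gradient:

```python
import jax
import jax.numpy as jnp

def sgld_step_sketch(rng_key, position, batch, grad_fn, step_size):
    """Schematic SGLD transition: theta' = theta + (eps/2) * grad + N(0, eps)."""
    grad = grad_fn(position, batch)  # minibatch estimate of grad log-posterior
    noise = jax.random.normal(rng_key, position.shape)
    return position + 0.5 * step_size * grad + jnp.sqrt(step_size) * noise
```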

```{code-cell} ipython3
%%time
@@ -155,7 +147,6 @@ from blackjax.sgmcmc.gradients import grad_estimator
data_size = len(y_train)
batch_size = 512
layer_sizes = [784, 100, 10]
step_size = 5e-5
num_warmup = (data_size // batch_size) * 20
@@ -165,12 +156,13 @@ num_samples = 1000
rng_key = jax.random.PRNGKey(1)
batches = batch_data(rng_key, (X_train, y_train), batch_size, data_size)
# Set the initial state
init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
# Build the SGLD kernel with a constant learning rate
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, lambda _: step_size)
# Set the initial state
init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
state = sgld.init(init_positions, next(batches))
# Sample from the posterior
@@ -208,10 +200,55 @@ plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SGLD")
plt.plot();
```
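
Both samplers rely on `grad_estimator` to turn minibatches into gradient estimates. The idea is to rescale the minibatch log-likelihood by `data_size / batch_size` so that its gradient is an unbiased estimate of the full-data gradient. Here is a sketch of that idea, not the library's exact code, under the convention that `loglikelihood_fn` sums over its batch:

```python
def grad_estimator_sketch(logprior_fn, loglikelihood_fn, data_size):
    """Gradient of the log-prior plus a rescaled minibatch log-likelihood."""
    def logdensity_estimate(position, batch):
        X, _ = batch
        scale = data_size / X.shape[0]  # upweight the minibatch to full-data scale
        return logprior_fn(position) + scale * loglikelihood_fn(position, batch)
    return jax.grad(logdensity_estimate)
```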

### Sampling with SGHMC

We can also use SGHMC (stochastic gradient Hamiltonian Monte Carlo) to sample from this model.
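
SGHMC augments the parameters with a momentum variable and simulates Hamiltonian dynamics with friction and injected noise, so each draw follows the stochastic gradient over several inner steps rather than SGLD's single step. A minimal sketch for a flat parameter array, not Blackjax's actual implementation; the number of inner steps `n_steps` and the friction `alpha` are illustrative choices:

```python
def sghmc_draw_sketch(rng_key, position, batch, grad_fn, step_size,
                      n_steps=10, alpha=0.01):
    """Schematic SGHMC draw: momentum updates with friction and noise."""
    momentum = jnp.zeros_like(position)
    for _ in range(n_steps):
        rng_key, subkey = jax.random.split(rng_key)
        noise = jax.random.normal(subkey, position.shape)
        momentum = (
            (1.0 - alpha) * momentum
            + step_size * grad_fn(position, batch)
            + jnp.sqrt(2.0 * alpha * step_size) * noise
        )
        position = position + momentum
    return position
```

Because each draw composes several gradient steps through the momentum, SGHMC is typically run with a smaller step size than SGLD, which is consistent with the drop from `5e-5` to `9e-6` below.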

```{code-cell} ipython3
# Build the SGHMC kernel with a constant learning rate
step_size = 9e-6
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sghmc = blackjax.sghmc(grad_fn, lambda _: step_size)
# Set the initial state
state = sghmc.init(init_positions, next(batches))
# Sample from the posterior
sghmc_accuracies = []
samples = []
steps = []
for step in progress_bar(range(num_samples + num_warmup)):
_, rng_key = jax.random.split(rng_key)
batch = next(batches)
state = jax.jit(sghmc.step)(rng_key, state, batch)
if step % 100 == 0:
sghmc_accuracy = compute_accuracy(state.position, X_test, y_test)
sghmc_accuracies.append(sghmc_accuracy)
steps.append(step)
if step > num_warmup:
samples.append(state.position)
```
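
A small refactor worth considering: wrapping `sghmc.step` with `jax.jit` once, outside the loop, avoids re-creating the jitted wrapper at every iteration (JAX caches the compiled function either way, so this is about tidiness rather than a large speedup). A sketch, omitting the accuracy tracking:

```python
step_fn = jax.jit(sghmc.step)
for step in progress_bar(range(num_samples + num_warmup)):
    _, rng_key = jax.random.split(rng_key)
    state = step_fn(rng_key, state, next(batches))
```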

```{code-cell} ipython3
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)
ld_plot, = ax.plot(steps, accuracies)
hmc_plot, = ax.plot(steps, sghmc_accuracies)
ax.set_xlabel("Number of sampling steps")
ax.set_ylabel("Prediction accuracy")
ax.set_xlim([0, num_warmup + num_samples])
ax.set_ylim([0, 1])
ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset)")
ax.legend((ld_plot, hmc_plot), ('SGLD', 'SGHMC'), loc='lower right', shadow=True)
plt.plot();
```

```{code-cell} ipython3
:tags: [hide-input]
print(f"The average accuracy in the sampling phase is {np.mean(accuracies[10:]):.2f}")
print(f"The average accuracy for SGLD in the sampling phase is {100 * np.mean(accuracies[10:]):.2f}%")
print(f"The average accuracy for SGHMC in the sampling phase is {100 * np.mean(sghmc_accuracies[10:]):.2f}%")
```

Which is not a bad accuracy at all for such a simple model! Remember though that we draw samples from the posterior distribution of the digit probabilities; we can thus use this information to filter out examples for which the model is "unsure" of its prediction.
@@ -220,7 +257,7 @@ Here we will say that the model is unsure of its prediction for a given image if

```{code-cell} ipython3
predicted_class = jnp.exp(
jnp.stack([jax.vmap(predict_fn, in_axes=(None, 0))(s, X_test) for s in samples])
jnp.stack([jax.vmap(model.apply, in_axes=(None, 0))(s, X_test) for s in samples])
)
```
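
The construction of `certain_mask` is not shown in this hunk. A hypothetical version consistent with how it is used below, where the mean-probability criterion and the `0.5` threshold are assumptions rather than the tutorial's actual rule:

```python
# Hypothetical: average predictive probabilities over posterior samples and
# keep the images whose most likely class has a high average probability.
mean_probs = predicted_class.mean(axis=0)           # shape (n_test, 10)
certain_mask = jnp.max(mean_probs, axis=-1) > 0.5   # assumed certainty threshold
```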

@@ -257,6 +294,6 @@ avg_accuracy = np.mean(
[compute_accuracy(s, X_test[certain_mask], y_test[certain_mask]) for s in samples]
)
print(
f"The average accuracy removing the samples for which the model is uncertain is {avg_accuracy:.3f}"
f"The average accuracy removing the samples for which the model is uncertain is {100*avg_accuracy:.3f}%"
)
```
