diff --git a/examples/SGMCMC.md b/examples/SGMCMC.md
index 4ed1fa1b2..77bd930cd 100644
--- a/examples/SGMCMC.md
+++ b/examples/SGMCMC.md
@@ -104,14 +104,6 @@ class NN(nn.Module):
 model = NN()
 
 
-@jax.jit
-def predict_fn(params, X):
-    """Returns the probability for the image represented by X
-    to be in each category given the MLP's weights vakues.
-    """
-    return model.apply(params, X)
-
-
 def logprior_fn(params):
     """Compute the value of the log-prior density function."""
     leaves, _ = jax.tree_util.tree_flatten(params)
@@ -122,7 +114,7 @@ def logprior_fn(params):
 def loglikelihood_fn(params, data):
     """Categorical log-likelihood"""
     X, y = data
-    return jnp.sum(y * predict_fn(params, X))
+    return jnp.sum(y * model.apply(params, X))
 
 
 @jax.jit
@@ -132,7 +124,7 @@ def compute_accuracy(params, X, y):
 
     To make predictions we take the number that corresponds to the highest probability value.
     """
     target_class = jnp.argmax(y, axis=1)
-    predicted_class = jnp.argmax(predict_fn(params, X), axis=1)
+    predicted_class = jnp.argmax(model.apply(params, X), axis=1)
     return jnp.mean(predicted_class == target_class)
 ```
@@ -142,7 +134,7 @@ Now we need to get initial values for the parameters, and we simply sample from
 
 +++
 
-We now sample from the model's posteriors. We discard the first 1000 samples until the sampler has reached the typical set, and then take 2000 samples. We record the model's accuracy with the current values every 100 steps.
+We now sample from the model's posterior using SGLD. We discard the samples drawn during a warmup phase of roughly 20 passes over the data, which gives the sampler time to reach the typical set, and then keep the next 1000 samples. We record the model's accuracy with the current values every 100 steps.
 
 ```{code-cell} ipython3
 %%time
@@ -155,7 +147,6 @@ from blackjax.sgmcmc.gradients import grad_estimator
 
 data_size = len(y_train)
 batch_size = 512
-layer_sizes = [784, 100, 10]
 step_size = 5e-5
 
 num_warmup = (data_size // batch_size) * 20
@@ -165,12 +156,13 @@ num_samples = 1000
 rng_key = jax.random.PRNGKey(1)
 batches = batch_data(rng_key, (X_train, y_train), batch_size, data_size)
 
+# Set the initial state
+init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
+
 # Build the SGLD kernel with a constant learning rate
 grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
 sgld = blackjax.sgld(grad_fn, lambda _: step_size)
 
-# Set the initial state
-init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
 state = sgld.init(init_positions, next(batches))
 
 # Sample from the posterior
@@ -208,10 +200,55 @@ plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SgLD")
 plt.plot();
 ```
 
+### Sampling with SGHMC
+
+We can also use SGHMC to sample from this model; the only changes are the kernel constructor and a smaller step size.
+
+```{code-cell} ipython3
+# Build the SGHMC kernel with a constant learning rate
+step_size = 9e-6
+grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
+sghmc = blackjax.sghmc(grad_fn, lambda _: step_size)
+
+# Set the initial state
+state = sghmc.init(init_positions, next(batches))
+
+# Sample from the posterior
+sghmc_accuracies = []
+samples = []
+steps = []
+for step in progress_bar(range(num_samples + num_warmup)):
+    _, rng_key = jax.random.split(rng_key)
+    batch = next(batches)
+    state = jax.jit(sghmc.step)(rng_key, state, batch)
+    if step % 100 == 0:
+        sghmc_accuracy = compute_accuracy(state.position, X_test, y_test)
+        sghmc_accuracies.append(sghmc_accuracy)
+        steps.append(step)
+    if step > num_warmup:
+        samples.append(state.position)
+```
+
+```{code-cell} ipython3
+fig = plt.figure(figsize=(12, 8))
+ax = fig.add_subplot(111)
+ld_plot, = ax.plot(steps, accuracies)
+hmc_plot, = ax.plot(steps, sghmc_accuracies)
+ax.set_xlabel("Number of sampling steps")
+ax.set_ylabel("Prediction accuracy")
+ax.set_xlim([0, num_warmup + num_samples])
+ax.set_ylim([0, 1])
+ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
+plt.title("Sample from 3-layer MLP posterior (MNIST dataset)")
+ax.legend((ld_plot, hmc_plot), ('SGLD', 'SGHMC'), loc='lower right', shadow=True)
+plt.plot();
+```
+
 ```{code-cell} ipython3
 :tags: [hide-input]
 
-print(f"The average accuracy in the sampling phase is {np.mean(accuracies[10:]):.2f}")
+print(f"The average accuracy for SGLD in the sampling phase is {100 * np.mean(accuracies[10:]):.2f}%")
+print(f"The average accuracy for SGHMC in the sampling phase is {100 * np.mean(sghmc_accuracies[10:]):.2f}%")
 ```
 
 Which is not a bad accuracy at all for such a simple model! Remember though that we draw samples from the posterior distribution of the digit probabilities; we can thus use this information to filter out examples for which the model is "unsure" of its prediction.
@@ -220,7 +257,7 @@ Here we will say that the model is unsure of its prediction for a given image i
 
 ```{code-cell} ipython3
 predicted_class = jnp.exp(
-    jnp.stack([jax.vmap(predict_fn, in_axes=(None, 0))(s, X_test) for s in samples])
+    jnp.stack([jax.vmap(model.apply, in_axes=(None, 0))(s, X_test) for s in samples])
 )
 ```
 
@@ -257,6 +294,6 @@ avg_accuracy = np.mean(
     [compute_accuracy(s, X_test[certain_mask], y_test[certain_mask]) for s in samples]
 )
 print(
-    f"The average accuracy removing the samples for which the model is uncertain is {avg_accuracy:.3f}"
+    f"The average accuracy, removing the test examples for which the model is uncertain, is {100 * avg_accuracy:.3f}%"
 )
 ```
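For intuition about what the `sgld.step` and `sghmc.step` calls in this patch are doing, the sketch below spells out a single SGLD transition: an Euler-Maruyama step of the overdamped Langevin diffusion, applied leaf by leaf to the parameter pytree. The helper name `sgld_transition` and the `grad_fn(position, minibatch)` call convention are illustrative assumptions rather than part of the notebook or of `blackjax`'s internals, and the library's exact step-size/noise scaling may differ by a constant factor.

```python
import jax
import jax.numpy as jnp


def sgld_transition(rng_key, position, grad_fn, minibatch, step_size):
    """One illustrative SGLD step on a pytree of parameters.

    `grad_fn(position, minibatch)` is assumed to return a pytree holding a
    minibatch estimate of grad log p(position | data), as the estimator
    produced by `grad_estimator` in the notebook does.
    """
    grads = grad_fn(position, minibatch)

    # One PRNG key per leaf of the parameter pytree.
    treedef = jax.tree_util.tree_structure(position)
    num_leaves = len(jax.tree_util.tree_leaves(position))
    keys = jax.tree_util.tree_unflatten(treedef, jax.random.split(rng_key, num_leaves))

    # Euler-Maruyama discretisation of the Langevin diffusion:
    #   theta <- theta + eps * grad_log_posterior + sqrt(2 * eps) * N(0, I)
    return jax.tree_util.tree_map(
        lambda p, g, k: p
        + step_size * g
        + jnp.sqrt(2 * step_size) * jax.random.normal(k, p.shape),
        position,
        grads,
        keys,
    )
```

Under these assumptions one iteration of the sampling loop would look roughly like `position = sgld_transition(rng_key, position, grad_fn, next(batches), step_size)`; SGHMC differs in that it carries a momentum variable with friction instead of injecting the noise directly into the position update.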