Enable squared error loss in classifier_loss_and_stats.
PiperOrigin-RevId: 730457825
timothyn617 authored and KfacJaxDev committed Feb 26, 2025
1 parent 4cc9d8f commit f2e8c26
Showing 4 changed files with 104 additions and 31 deletions.
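A minimal sketch (not part of the commit) of how a call site can opt into the new squared error loss. The toy shapes, the placeholder Haiku-style params, and the use of register_loss=False and top_k_stats=(1,) are assumptions for illustration only:

import jax
import jax.numpy as jnp
from examples import losses  # import path assumed from this repo's layout

rng = jax.random.PRNGKey(0)
predictions = jax.random.normal(rng, (4, 3))  # 4 examples, 3 classes
labels_as_int = jnp.array([0, 2, 1, 2])       # shape predictions.shape[:-1]
params = {"linear": {"w": jnp.zeros((3, 3))}}  # placeholder Haiku-style params

loss, stats = losses.classifier_loss_and_stats(
    predictions=predictions,     # keyword renamed from `logits` in this commit
    labels_as_int=labels_as_int,
    params=params,
    l2_reg=0.0,
    top_k_stats=(1,),            # only top-1 makes sense with 3 classes
    register_loss=False,         # skip K-FAC registration in this toy setting
    loss_type="squared_error",   # new argument; default is "cross_entropy"
)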
2 changes: 1 addition & 1 deletion examples/classifier_mnist/experiment.py
@@ -66,7 +66,7 @@ def classifier_loss(
   logits = convolutional_classifier().apply(params, batch["images"])
 
   loss, stats = losses.classifier_loss_and_stats(
-      logits=logits,
+      predictions=logits,
       labels_as_int=batch["labels"],
       params=params,
       l2_reg=l2_reg if is_training else 0.0,
129 changes: 101 additions & 28 deletions examples/losses.py
@@ -123,7 +123,7 @@ def softmax_cross_entropy(
   log_z = special.logsumexp(logits, axis=-1)
 
   if logits.shape == labels.shape:
-    # Labels are encoded as 1-hot vectors
+    # Labels are encoded as (possibly smoothed) 1-hot vectors
     loss = -jnp.sum(logits * labels, axis=-1) + log_z
 
   elif logits.ndim == labels.ndim + 1:
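Context for the comment change: with label smoothing, labels are dense vectors rather than exact one-hot, so the logits.shape == labels.shape branch is the one taken. A toy illustration of what "(possibly smoothed)" means, with the formula assumed from the standard definition of label smoothing (coefficient 0.1 over 3 classes):

import jax
import jax.numpy as jnp

one_hot = jax.nn.one_hot(jnp.array([0, 2]), 3)
smoothed = one_hot * (1.0 - 0.1) + 0.1 / 3
# [[0.933, 0.033, 0.033],
#  [0.033, 0.033, 0.933]]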
@@ -156,20 +156,28 @@ def squared_error(
     targets: Array,
     weight: float = 1.0,
     register_loss: bool = True,
+    mask: Array | None = None,
     extra_registration_kwargs: dict[str, Any] | None = None,
     registration_module: types.ModuleType = kfac_jax,
 ) -> Array:
   """Squared error loss."""
   extra_registration_kwargs = extra_registration_kwargs or {}
 
-  if prediction.shape != targets.shape:
-    raise ValueError("prediction and targets should have the same shape.")
-
   if register_loss:
     registration_module.register_squared_error_loss(
         prediction, targets, weight, **extra_registration_kwargs)
 
-  return weight * jnp.sum(jnp.square(prediction - targets), axis=-1)
+  if prediction.shape != targets.shape:
+    raise ValueError("prediction and targets should have the same shape.")
+
+  squared_residuals = jnp.square(prediction - targets)
+
+  loss = jnp.sum(squared_residuals, axis=-1)
+
+  if mask is not None:
+    loss = loss * mask
+
+  return weight * loss
 
 
 def top_k_accuracy(
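The new mask argument zeroes out the per-example losses of masked entries before the weight is applied. A small sketch of the intended behavior (register_loss=False keeps it a pure computation; import path assumed):

import jax.numpy as jnp
from examples import losses

prediction = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
targets = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
mask = jnp.array([1.0, 0.0, 1.0])  # drop the second example

per_example = losses.squared_error(
    prediction, targets, register_loss=False, mask=mask)
# sum((prediction - targets)**2, axis=-1) * mask == [0.0, 0.0, 0.5]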
Expand Down Expand Up @@ -226,7 +234,7 @@ def add_label_smoothing(


def classifier_loss_and_stats(
logits: Array,
predictions: Array,
labels_as_int: Array,
params: Params,
l2_reg: Numeric,
@@ -240,26 +248,67 @@
     normalization_mode: str = "batch_size_only",
     extra_registration_kwargs: dict[str, Any] | None = None,
     registration_module: types.ModuleType = kfac_jax,
+    loss_type: str = "cross_entropy",
 ) -> tuple[Array, dict[str, Array]]:
-  """Softmax cross-entropy with regularizer and accuracy statistics."""
-
-  batch_size = logits.shape[0]
-
-  if labels_as_int.shape[0] != batch_size:
-    raise ValueError(f"Size of first dimension of logits ({batch_size}) "
-                     f"(i.e. batch size) doesn't match that of labels "
-                     f"({labels_as_int.shape[0]})")
-
-  if mask is not None and mask.shape[0] != batch_size:
-    raise ValueError(f"Size of first dimension of logits ({batch_size}) "
-                     f"(i.e. batch size) doesn't match that of mask "
-                     f"({mask.shape[0]})")
+  """Classification loss with regularizer and accuracy statistics.
+
+  Args:
+    predictions: The output of the model. Logits for loss_type="cross_entropy",
+      or the quantity compared to (possibly smoothed) one-hot targets for
+      loss_type="squared_error". Predictions have shape (batch_size, ...,
+      num_classes).
+    labels_as_int: The labels to be used in the loss, as integers. Must be of
+      shape predictions.shape[:-1].
+    params: The parameters of the model.
+    l2_reg: The L2 regularization coefficient.
+    haiku_exclude_batch_norm: Whether to exclude batch norm parameters from
+      the L2 regularization (assumes models are Haiku models).
+    haiku_exclude_biases: Whether to exclude biases from the L2 regularization
+      (assumes models are Haiku models).
+    label_smoothing: The label smoothing coefficient.
+    top_k_stats: The top-k accuracies to compute.
+    average_loss: Whether to average the loss over the batch.
+    register_loss: Whether to register the loss.
+    mask: If not None, a binary mask of shape predictions.shape[:-1]. Its
+      nonzero values determine which predictions are used in the loss
+      computation.
+    normalization_mode: The normalization mode to use for the returned loss,
+      one of "batch_size_only", "all_dims", or "all_dims_nonmasked".
+      "batch_size_only" means the loss is normalized by the batch size.
+      "all_dims" means the loss is normalized by the product of the dimensions
+      of the predictions array, excluding the last dimension.
+      "all_dims_nonmasked" means the loss is normalized by the number of
+      nonzero entries of the mask if it is not None.
+    extra_registration_kwargs: Extra kwargs to pass to the registration
+      functions.
+    registration_module: The module containing the loss registration functions
+      that will be used. These are 'register_softmax_cross_entropy_loss' and
+      'register_squared_error_loss'.
+    loss_type: The type of loss to use ("cross_entropy" or "squared_error").
+
+  Returns:
+    The regularized loss and a dictionary of statistics.
+  """
+
+  batch_size = predictions.shape[0]
+
+  if labels_as_int.shape != predictions.shape[:-1]:
+    raise ValueError(
+        f"Shape mismatch: labels_as_int shape ({labels_as_int.shape}) "
+        f"not compatible with predictions shape {predictions.shape}"
+    )
+
+  if mask is not None and mask.shape != predictions.shape[:-1]:
+    raise ValueError(
+        f"Shape mismatch: mask shape ({mask.shape}) "
+        f"not compatible with predictions shape {predictions.shape}"
+    )
 
   if normalization_mode == "batch_size_only":
     weight = 1.0
 
   elif normalization_mode == "all_dims":
-    weight = 1.0 / kfac_jax.utils.product(logits.shape[1:-1])
+    weight = 1.0 / kfac_jax.utils.product(predictions.shape[1:-1])
 
   elif normalization_mode == "all_dims_nonmasked":
     assert mask is not None
@@ -269,16 +318,40 @@
     raise ValueError(f"Unrecognized value for normalization_mode: "
                      f"{normalization_mode}")
 
-  labels = add_label_smoothing(labels_as_int, label_smoothing, logits.shape[-1])
+  labels = add_label_smoothing(
+      labels_as_int, label_smoothing, predictions.shape[-1]
+  )
 
-  softmax_loss = softmax_cross_entropy(
-      logits, labels, weight=weight, register_loss=register_loss, mask=mask,
-      extra_registration_kwargs=extra_registration_kwargs,
-      registration_module=registration_module)
+  if loss_type == "cross_entropy":
+    raw_loss = softmax_cross_entropy(
+        predictions,
+        labels,
+        weight=weight,
+        register_loss=register_loss,
+        mask=mask,
+        extra_registration_kwargs=extra_registration_kwargs,
+        registration_module=registration_module,
+    )
+  elif loss_type == "squared_error":
+
+    if predictions.ndim == labels.ndim + 1:
+      labels = jax.nn.one_hot(labels, predictions.shape[-1])
+
+    raw_loss = squared_error(
+        predictions,
+        labels,
+        weight=weight,
+        register_loss=register_loss,
+        mask=mask,
+        extra_registration_kwargs=extra_registration_kwargs,
+        registration_module=registration_module,
+    )
+  else:
+    raise ValueError(f"Unknown loss type: {loss_type}")
 
-  averaged_raw_loss = jnp.sum(softmax_loss, axis=0) / batch_size
+  averaged_raw_loss = jnp.sum(raw_loss, axis=0) / batch_size
 
-  loss = averaged_raw_loss if average_loss else softmax_loss
+  loss = averaged_raw_loss if average_loss else raw_loss
 
   l2_reg_val = l2_regularizer(
       params, haiku_exclude_batch_norm, haiku_exclude_biases)
@@ -290,6 +363,6 @@
       l2_reg_val=l2_reg_val,
   )
   for k in top_k_stats:
-    stats[f"top_{k}_accuracy"] = top_k_accuracy(logits, labels_as_int, k)
+    stats[f"top_{k}_accuracy"] = top_k_accuracy(predictions, labels_as_int, k)
 
   return regularized_loss, stats
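For reference, how the normalization modes translate into the weight handed to the underlying loss; a hedged paraphrase of the branches above, with an assumed (batch, time, classes) shape. The "all_dims_nonmasked" branch is truncated in this diff, so it is not reproduced here:

import kfac_jax

predictions_shape = (8, 16, 10)  # assumed (batch, time, classes)

# "batch_size_only": raw_loss is summed over axis 0 and divided by
# batch_size later, so the weight stays 1.
weight = 1.0

# "all_dims": additionally divide by the non-batch, non-class dims (16 here),
# so the final loss is a mean over all prediction positions.
weight = 1.0 / kfac_jax.utils.product(predictions_shape[1:-1])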
2 changes: 1 addition & 1 deletion examples/lrelunet101_imagenet/experiment.py
@@ -360,7 +360,7 @@ def lrelunet_loss(
       params, rng, batch["images"], is_training)
 
   return losses.classifier_loss_and_stats(
-      logits=logits,
+      predictions=logits,
       labels_as_int=batch["labels"],
       params=params,
       l2_reg=l2_reg if is_training else 0.0,
2 changes: 1 addition & 1 deletion examples/resnet50_imagenet/experiment.py
@@ -79,7 +79,7 @@ def resnet50_loss(
   ).apply(params, state, batch["images"], is_training=is_training)
 
   loss, stats = losses.classifier_loss_and_stats(
-      logits=logits,
+      predictions=logits,
       labels_as_int=batch["labels"],
       params=params,
       l2_reg=l2_reg if is_training else 0.0,