Enable squared error loss in classifier_loss_and_stats. #315

Merged · 1 commit · Feb 26, 2025
2 changes: 1 addition & 1 deletion examples/classifier_mnist/experiment.py
@@ -66,7 +66,7 @@ def classifier_loss(
logits = convolutional_classifier().apply(params, batch["images"])

loss, stats = losses.classifier_loss_and_stats(
logits=logits,
predictions=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg if is_training else 0.0,
129 changes: 101 additions & 28 deletions examples/losses.py
@@ -123,7 +123,7 @@ def softmax_cross_entropy(
log_z = special.logsumexp(logits, axis=-1)

if logits.shape == labels.shape:
# Labels are encoded as 1-hot vectors
# Labels are encoded as (possibly smoothed) 1-hot vectors
loss = -jnp.sum(logits * labels, axis=-1) + log_z

elif logits.ndim == labels.ndim + 1:
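
A quick illustration of why the comment now says "(possibly smoothed)": standard label smoothing (which is what add_label_smoothing is assumed to apply here) produces dense target vectors with the same shape as the logits, so smoothed labels take the same branch as exact one-hot labels. A minimal sketch, not part of the diff:

```python
import jax
import jax.numpy as jnp

labels_as_int = jnp.array([2, 0])                     # shape (2,)
num_classes = 4
smoothing = 0.1

one_hot = jax.nn.one_hot(labels_as_int, num_classes)  # shape (2, 4)
# Standard label smoothing: mass (1 - s) on the true class, s spread uniformly.
smoothed = (1.0 - smoothing) * one_hot + smoothing / num_classes

# smoothed.shape matches logits of shape (2, 4), so softmax_cross_entropy
# takes the "1-hot vectors" branch even though the targets are no longer
# strictly one-hot.
```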
@@ -156,20 +156,28 @@ def squared_error(
targets: Array,
weight: float = 1.0,
register_loss: bool = True,
mask: Array | None = None,
extra_registration_kwargs: dict[str, Any] | None = None,
registration_module: types.ModuleType = kfac_jax,
) -> Array:
"""Squared error loss."""
extra_registration_kwargs = extra_registration_kwargs or {}

if prediction.shape != targets.shape:
raise ValueError("prediction and targets should have the same shape.")

if register_loss:
registration_module.register_squared_error_loss(
prediction, targets, weight, **extra_registration_kwargs)

return weight * jnp.sum(jnp.square(prediction - targets), axis=-1)
if prediction.shape != targets.shape:
raise ValueError("prediction and targets should have the same shape.")

squared_residuals = jnp.square(prediction - targets)

loss = jnp.sum(squared_residuals, axis=-1)

if mask is not None:
loss = loss * mask

return weight * loss


def top_k_accuracy(
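
A hedged usage sketch of the new mask argument (not part of the diff; register_loss=False keeps it self-contained, and squared_error is assumed to be imported from this module):

```python
import jax.numpy as jnp

predictions = jnp.ones((4, 10))          # (batch_size, num_classes)
targets = jnp.zeros((4, 10))
mask = jnp.array([1.0, 1.0, 0.0, 1.0])   # drop the third example

per_example_loss = squared_error(
    predictions, targets, mask=mask, register_loss=False)

# per_example_loss has shape (4,); masked-out entries contribute zero.
```

Note that masking happens after the per-example reduction over the last axis, so the mask has one entry per prediction vector, not per class.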
Expand Down Expand Up @@ -226,7 +234,7 @@ def add_label_smoothing(


def classifier_loss_and_stats(
logits: Array,
predictions: Array,
labels_as_int: Array,
params: Params,
l2_reg: Numeric,
@@ -240,26 +248,67 @@
normalization_mode: str = "batch_size_only",
extra_registration_kwargs: dict[str, Any] | None = None,
registration_module: types.ModuleType = kfac_jax,
loss_type: str = "cross_entropy",
) -> tuple[Array, dict[str, Array]]:
"""Softmax cross-entropy with regularizer and accuracy statistics."""

batch_size = logits.shape[0]

if labels_as_int.shape[0] != batch_size:
raise ValueError(f"Size of first dimension of logits ({batch_size}) "
f"(i.e. batch size) doesn't match that of labels "
f"({labels_as_int.shape[0]})")

if mask is not None and mask.shape[0] != batch_size:
raise ValueError(f"Size of first dimension of logits ({batch_size}) "
f"(i.e. batch size) doesn't match that of mask "
f"({mask.shape[0]})")
"""Classification loss with regularizer and accuracy statistics.

Args:
predictions: The output of the model: logits when loss_type="cross_entropy",
or the quantities compared against (possibly smoothed) one-hot targets
when loss_type="squared_error". Must have shape (batch_size, ...,
num_classes).
labels_as_int: The integer class labels to be used in the loss. Must have
shape predictions.shape[:-1].
params: The parameters of the model.
l2_reg: The L2 regularization coefficient.
haiku_exclude_batch_norm: Whether to exclude batch norm parameters from the
L2 regularization (assumes models are Haiku models).
haiku_exclude_biases: Whether to exclude biases from the L2 regularization
(assumes models are Haiku models).
label_smoothing: The label smoothing coefficient.
top_k_stats: The top-k accuracies to compute.
average_loss: Whether to average the loss over the batch.
register_loss: Whether to register the loss.
mask: If not None, a binary mask of shape predictions.shape[:-1]. Its
nonzero values determine which predictions are used in the loss
computation.
normalization_mode: The normalization mode to use for the returned loss, one
of "batch_size_only", "all_dims", or "all_dims_nonmasked".
"batch_size_only" means the loss is normalized by the batch size.
"all_dims" means the loss is normalized by the product of the dimensions
of the predictions array, excluding the last dimension.
"all_dims_nonmasked" means the loss is normalized by the number of
nonzero entries of the mask if it is not None.
extra_registration_kwargs: Extra kwargs to pass to the registration
functions.
registration_module: The module containing the loss registration functions
that will be used. These are 'register_softmax_cross_entropy_loss' and
'register_squared_error_loss'.
loss_type: The type of loss to use ("cross_entropy" or "squared_error").

Returns:
The regularized loss and a dictionary of statistics.
"""

batch_size = predictions.shape[0]

if labels_as_int.shape != predictions.shape[:-1]:
raise ValueError(
f"Shape mismatch: labels_as_int shape ({labels_as_int.shape}) "
f"not compatible with predictions shape {predictions.shape}"
)

if mask is not None and mask.shape != predictions.shape[:-1]:
raise ValueError(
f"Shape mismatch: mask shape ({mask.shape}) "
f"not compatible with predictions shape {predictions.shape}"
)

if normalization_mode == "batch_size_only":
weight = 1.0

elif normalization_mode == "all_dims":
weight = 1.0 / kfac_jax.utils.product(logits.shape[1:-1])
weight = 1.0 / kfac_jax.utils.product(predictions.shape[1:-1])

elif normalization_mode == "all_dims_nonmasked":
assert mask is not None
@@ -269,16 +318,40 @@
raise ValueError(f"Unrecognized value for normalization_mode: "
f"{normalization_mode}")

labels = add_label_smoothing(labels_as_int, label_smoothing, logits.shape[-1])
labels = add_label_smoothing(
labels_as_int, label_smoothing, predictions.shape[-1]
)

softmax_loss = softmax_cross_entropy(
logits, labels, weight=weight, register_loss=register_loss, mask=mask,
extra_registration_kwargs=extra_registration_kwargs,
registration_module=registration_module)
if loss_type == "cross_entropy":
raw_loss = softmax_cross_entropy(
predictions,
labels,
weight=weight,
register_loss=register_loss,
mask=mask,
extra_registration_kwargs=extra_registration_kwargs,
registration_module=registration_module,
)
elif loss_type == "squared_error":

if predictions.ndim == labels.ndim + 1:
labels = jax.nn.one_hot(labels, predictions.shape[-1])

raw_loss = squared_error(
predictions,
labels,
weight=weight,
register_loss=register_loss,
mask=mask,
extra_registration_kwargs=extra_registration_kwargs,
registration_module=registration_module,
)
else:
raise ValueError(f"Unknown loss type: {loss_type}")

averaged_raw_loss = jnp.sum(softmax_loss, axis=0) / batch_size
averaged_raw_loss = jnp.sum(raw_loss, axis=0) / batch_size

loss = averaged_raw_loss if average_loss else softmax_loss
loss = averaged_raw_loss if average_loss else raw_loss

l2_reg_val = l2_regularizer(
params, haiku_exclude_batch_norm, haiku_exclude_biases)
@@ -290,6 +363,6 @@
l2_reg_val=l2_reg_val,
)
for k in top_k_stats:
stats[f"top_{k}_accuracy"] = top_k_accuracy(logits, labels_as_int, k)
stats[f"top_{k}_accuracy"] = top_k_accuracy(predictions, labels_as_int, k)

return regularized_loss, stats
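
One subtlety in the squared-error branch above: with label_smoothing == 0, add_label_smoothing is assumed to pass the integer labels through unchanged, so the jax.nn.one_hot conversion is what builds the dense targets; with smoothing enabled the labels are already dense and the conversion is skipped. Illustrative shapes (hypothetical values):

```python
import jax
import jax.numpy as jnp

labels = jnp.array([2, 0])        # integer labels, ndim == 1
predictions = jnp.zeros((2, 4))   # ndim == 2 == labels.ndim + 1

if predictions.ndim == labels.ndim + 1:
    labels = jax.nn.one_hot(labels, predictions.shape[-1])  # now (2, 4)

# labels and predictions now have matching shapes, as squared_error requires.
```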
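And a hedged end-to-end sketch of the new loss_type switch (hypothetical shapes; params is empty here so the L2 term is assumed to vanish, top_k_stats is set explicitly rather than relying on the elided default, and loss is assumed to be batch-averaged under the elided average_loss default):

```python
import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (8, 10))
labels = jnp.arange(8) % 10

loss, stats = classifier_loss_and_stats(
    predictions=logits,
    labels_as_int=labels,
    params={},
    l2_reg=0.0,
    register_loss=False,
    top_k_stats=(1,),
    loss_type="squared_error",
)
# loss is the batch-averaged squared error against one-hot targets; stats
# includes l2_reg_val and top_1_accuracy (its other entries are elided in
# the diff).
```

The squared-error path reuses the same masking, weighting, and registration machinery as the cross-entropy path, so only the raw loss computation differs.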
2 changes: 1 addition & 1 deletion examples/lrelunet101_imagenet/experiment.py
@@ -360,7 +360,7 @@ def lrelunet_loss(
params, rng, batch["images"], is_training)

return losses.classifier_loss_and_stats(
logits=logits,
predictions=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg if is_training else 0.0,
2 changes: 1 addition & 1 deletion examples/resnet50_imagenet/experiment.py
@@ -79,7 +79,7 @@ def resnet50_loss(
).apply(params, state, batch["images"], is_training=is_training)

loss, stats = losses.classifier_loss_and_stats(
logits=logits,
predictions=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg if is_training else 0.0,