This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Support clipping the gradient norm #642

Open · wants to merge 2 commits into main
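Summary of the change, as reconstructed from the diff below: ClassificationTask gains an optional grad_norm_clip setting, exposed both through the config key "grad_norm_clip" and the builder-style setter set_grad_norm_clip. When the value is not None, the combined gradient norm of the model parameters (and of the loss parameters, if the loss has learnable ones) is clipped with nn.utils.clip_grad_norm_ after the backward pass and before the optimizer step. A minimal usage sketch, not part of the patch; it reuses the repo's test config helper purely for illustration:

from classy_vision.tasks import build_task
from test.generic.config_utils import get_fast_test_task_config

# Any valid task config works; only the new key matters here.
config = get_fast_test_task_config()
config["grad_norm_clip"] = 1.0  # clip the total gradient norm to 1.0

task = build_task(config)
task.prepare()

# Equivalent builder-style call added by this PR (pass None to disable):
# task.set_grad_norm_clip(1.0)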
52 changes: 48 additions & 4 deletions classy_vision/tasks/classification_task.py
@@ -6,6 +6,7 @@

import copy
import enum
import itertools
import json
import logging
import math
@@ -157,6 +158,7 @@ def __init__(self):
)
self.amp_args = None
self.mixup_transform = None
self.grad_norm_clip = None
self.perf_log = []
self.last_batch = None
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -412,6 +414,24 @@ def set_optimizer_schedulers(self, schedulers):
self.optimizer_schedulers = schedulers
return self

def set_grad_norm_clip(
self,
grad_norm_clip: Optional[float],
) -> "ClassificationTask":
"""Enable / disable clipping the gradient norm

Args:
            grad_norm_clip: The maximum total gradient norm; set to None to disable clipping
"""
if grad_norm_clip is None:
logging.info(f"Disabled gradient norm clipping: {grad_norm_clip}")
else:
logging.info(
f"Enabled gradient norm clipping with threshold: {grad_norm_clip}"
)
self.grad_norm_clip = grad_norm_clip
return self

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
"""Instantiates a ClassificationTask from a configuration.
Expand Down Expand Up @@ -489,6 +509,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_distributed_options(**distributed_options)
.set_hooks(hooks)
.set_bn_weight_decay(config.get("bn_weight_decay", False))
.set_grad_norm_clip(config.get("grad_norm_clip"))
)

if not test_only:
@@ -719,10 +740,7 @@ def init_distributed_data_parallel_model(self):
broadcast_buffers=broadcast_buffers,
find_unused_parameters=self.find_unused_parameters,
)
-        if (
-            isinstance(self.base_loss, ClassyLoss)
-            and self.base_loss.has_learned_parameters()
-        ):
+        if self._loss_has_learnable_params():
logging.info("Initializing distributed loss")
self.distributed_loss = init_distributed_data_parallel_model(
self.base_loss,
@@ -919,6 +937,9 @@ def train_step(self):
else:
self.optimizer.backward(local_loss)

if self.grad_norm_clip is not None:
self._clip_grad_norm()

self.check_inf_nan(loss)

self.optimizer.step(where=self.where)
@@ -992,6 +1013,22 @@ def create_data_iterators(self):
del self.data_iterator
self.data_iterator = iter(self.dataloader)

def _clip_grad_norm(self):
"""Clip the gradient norms based on self.grad_norm_clip"""
model_params = (
self.base_model.parameters()
if self.amp_args is None
else apex.amp.master_params(self.optimizer.optimizer)
)
loss_params = (
self.base_loss.parameters()
if self._loss_has_learnable_params()
else iter(())
)
nn.utils.clip_grad_norm_(
itertools.chain(model_params, loss_params), self.grad_norm_clip
)

def _set_model_train_mode(self):
"""Set train mode for model"""
phase = self.phases[self.phase_idx]
@@ -1014,6 +1051,13 @@ def _broadcast_buffers(self):
for buffer in buffers:
broadcast(buffer, 0, group=self.distributed_model.process_group)

def _loss_has_learnable_params(self):
"""Returns True if the loss has any learnable parameters"""
return (
isinstance(self.base_loss, ClassyLoss)
and self.base_loss.has_learned_parameters()
)

# TODO: Functions below should be better abstracted into the dataloader
# abstraction
def get_batchsize_per_replica(self):
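For reference, the core of _clip_grad_norm reduces to a single nn.utils.clip_grad_norm_ call over one chained parameter iterator, so the model and the loss share a single norm budget instead of being clipped separately. A standalone sketch with toy modules (the module names are illustrative, not from the patch):

import itertools

import torch
import torch.nn as nn

model = nn.Linear(8, 4)      # stands in for task.base_model
loss_mod = nn.Linear(4, 1)   # stands in for a loss with learnable parameters

loss_mod(model(torch.randn(2, 8))).sum().backward()

# Rescale gradients so the total L2 norm across model and loss
# parameters does not exceed 1.0; returns the norm before clipping.
total_norm = nn.utils.clip_grad_norm_(
    itertools.chain(model.parameters(), loss_mod.parameters()), max_norm=1.0
)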
22 changes: 22 additions & 0 deletions test/tasks_classification_task_test.py
@@ -7,6 +7,7 @@
import copy
import shutil
import tempfile
import itertools
import unittest
from test.generic.config_utils import get_fast_test_task_config, get_test_task_config
from test.generic.utils import (
@@ -284,3 +285,24 @@ def test_get_classy_state_on_loss(self):
task = build_task(config)
task.prepare()
self.assertIn("alpha", task.get_classy_state()["loss"])

def test_grad_norm_clip(self):
config = get_fast_test_task_config()
config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
config["grad_norm_clip"] = grad_norm_clip = 1
task = build_task(config)
task.prepare()

# set fake gradients with norm > grad_norm_clip
for param in itertools.chain(
task.base_model.parameters(), task.base_loss.parameters()
):
param.grad = 1.1 + torch.rand(param.shape)
self.assertGreater(param.grad.norm(), grad_norm_clip)

task._clip_grad_norm()

for param in itertools.chain(
task.base_model.parameters(), task.base_loss.parameters()
):
self.assertLessEqual(param.grad.norm(), grad_norm_clip)
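
Two notes on the clipping path: (1) when mixed precision is enabled (amp_args is set), _clip_grad_norm clips apex.amp.master_params(self.optimizer.optimizer) rather than self.base_model.parameters(), which follows the usual apex recipe of clipping the FP32 master gradients that the optimizer actually steps on; (2) the per-parameter assertLessEqual check in the test above is sound because each individual parameter's gradient norm is bounded by the total norm, and clip_grad_norm_ rescales all gradients so that the total norm does not exceed grad_norm_clip.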