Training seems to pause every N steps #13375
-
I am doing feature extraction using an efficientnet_b0 model. The training process works fine, but it seems to pause every once in a while; I verified this by watching GPU utilization. Right now I am training with 4 Tesla T4 GPUs, but I verified the same issue with a single GPU (T4 and V100). I noticed the training pausing at steps 48, 96, 144, ..., so it pauses every 48 steps. I thought the pauses might be caused by logging. Originally I assumed it was a PyTorch "issue", so I opened a post here: https://discuss.pytorch.org/t/gpu-usage-is-not-constant-during-training/154718 . However, I am now wondering whether this could be caused by PyTorch Lightning. Thank you
Replies: 8 comments 3 replies
-
@mfoglio It sounds like it's because of logging, as you mentioned. To check whether the issue really sits with the logger, have you tried disabling it completely and seeing if the same behaviour persists?
How are you checking that it happens at these exact steps? From the progress bar?
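If it helps, a small callback like the one below would show how long the trainer actually waits between batches (rough sketch; the hook signatures differ slightly across Lightning versions, hence the trailing `*args`). A large gap right before the stalled steps would point at the dataloader rather than logging:

```python
import time

import pytorch_lightning as pl


class BatchGapTimer(pl.Callback):
    """Reports how long the trainer waited for the next batch."""

    def __init__(self, warn_threshold_s: float = 1.0):
        self.warn_threshold_s = warn_threshold_s
        self._last_batch_end = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, *args):
        if self._last_batch_end is not None:
            gap = time.perf_counter() - self._last_batch_end
            if gap > self.warn_threshold_s:
                # A long gap here means the GPU sat idle waiting for data.
                print(f"step {trainer.global_step}: waited {gap:.2f}s for the next batch")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args):
        self._last_batch_end = time.perf_counter()


# trainer = Trainer(..., callbacks=[BatchGapTimer()])
```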
-
Hi @akihironitta, thank you for your help. Yes, I am talking about steps, not epochs, and I am checking using the progress bar. The pause is also visible as a drop in GPU utilization.
EDIT: while at the beginning the code seemed to get stuck at step counts that are multiples of 48, I also noticed the progress bar getting stuck at step 965, which is obviously not a multiple of 48.
I disabled both logging and checkpointing, but I still see the same issue:

```python
# Trainer
trainer = Trainer(
accelerator='auto',
gpus=-1,
default_root_dir=checkpoint_path,
# devices=1, # automatically inferred
# num_processes=os.cpu_count(),
strategy=DDPStrategy(find_unused_parameters=True),
precision=16 if torch.cuda.is_available() else 32,
# max_epochs=1000,
# max_steps=1000,
logger=False,
log_every_n_steps=500,
val_check_interval=1.0, # must be a float otherwise it's interpreted as the number of batches
# num_sanity_val_steps=0,
# callbacks=[
# feature_freeze_unfreeze,
# checkpoint_epoch_model,
# checkpoint_best_model,
# RichProgressBar(),
# LearningRateMonitor(logging_interval='step'),
# ],
# profiler='simple',
max_epochs=-1
)
# Train
trainer.fit(
model=model,
datamodule=vehicles_datamodule,
ckpt_path=sorted(
glob.glob(os.path.join(checkpoint_epoch_path, '*.ckpt')) # todo: edit path
)[-1] if resume_from_checkpoint else None
)
```

It's hard for me to share my entire code, but I can share the model and the data module.

```python
from collections import defaultdict
from typing import Dict, List
import seaborn as sn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from vehicles.dataset.attributes import Attribute, AttributeValue
from vehicles.dataset.splits import DatasetSplit
from vehicles.model.modules import Flatten
from utils.plots import convert_matplotlib_plot_to_torch_tensor
class MultiOutputClassifier(pl.LightningModule):
def __init__(self, attributes: List[Attribute],
samples_per_attribute_value: Dict[Attribute, Dict[AttributeValue, int]],
batch_size: int = 64,
freeze_features=False,
verbose=False,
):
super(MultiOutputClassifier, self).__init__()
self.attributes = attributes
self.verbose = verbose
# FEATURES
# Resnet
# resnet = torchvision.models.resnet50(pretrained=True)
# torch.nn.Sequential(*(list(resnet.children())[:-1]))
# EfficientNet
efficientnet_b0 = torchvision.models.efficientnet_b0(pretrained=True)
self.features = torch.nn.Sequential(*(list(efficientnet_b0.children())[:-1]))
# SqueezeNet
# self.features = torchvision.models.squeezenet1_1(pretrained=True).features # output is [n, 512, 13, 13]
if freeze_features is True:
for param in self.features.parameters():
param.requires_grad = False
# MIDDLE LAYERS
self.channels_fc = 128
self.middle_layers = nn.Sequential(
# nn.Conv2d(512, 512, 3),
# nn.Conv2d(512, 512, 3),
# nn.Dropout2d(p=0.1),
# Squeezenet
# nn.AvgPool2d(13),
# Flatten(512),
# EfficientNet (efficientnet_b0 features output 1280 channels)
Flatten(1280),
nn.Dropout(0.25),
nn.Linear(1280, self.channels_fc),  # 1280 features from efficientnet_b0
nn.Linear(self.channels_fc, self.channels_fc, bias=False),
# nn.Linear(self.channels_fc, self.channels_fc, bias=False),
)
# OUTPUTS
self.output_layers = nn.ModuleDict({
f'output_{attribute.name}': nn.Sequential(
nn.Linear(self.channels_fc, len(attribute.get_values()), bias=False),
) for attribute in self.attributes
})
# INITIALIZE WEIGHTS
self._initialize_weights()
# BATCH SIZE
self.batch_size = batch_size
# LOSSES
self.losses = self._get_losses(self.attributes, samples_per_attribute_value)
# METRICS
self.accuracy_metrics = nn.ModuleDict({
dataset_split: nn.ModuleDict({
attribute.name: torchmetrics.Accuracy(
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index,
average='macro',
num_classes=len(attribute.get_values(include_none=False))
)
for attribute in self.attributes
}) for dataset_split in [DatasetSplit.TRAIN, DatasetSplit.VALIDATION, DatasetSplit.TEST]
})
self.precision_metrics = nn.ModuleDict({
dataset_split: nn.ModuleDict({
attribute.name: torchmetrics.Precision(
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index,
average='macro',
num_classes=len(attribute.get_values(include_none=False))
)
for attribute in self.attributes
}) for dataset_split in [DatasetSplit.TRAIN, DatasetSplit.VALIDATION, DatasetSplit.TEST]
})
self.recall_metrics = nn.ModuleDict({
dataset_split: nn.ModuleDict({
attribute.name: torchmetrics.Recall(
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index,
average='macro',
num_classes=len(attribute.get_values(include_none=False))
)
for attribute in self.attributes
}) for dataset_split in [DatasetSplit.TRAIN, DatasetSplit.VALIDATION, DatasetSplit.TEST]
})
self.f1_metrics = nn.ModuleDict({
dataset_split: nn.ModuleDict({
attribute.name: torchmetrics.F1Score(
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index,
average='macro',
num_classes=len(attribute.get_values(include_none=False))
)
for attribute in self.attributes
}) for dataset_split in [DatasetSplit.TRAIN, DatasetSplit.VALIDATION, DatasetSplit.TEST]
})
self.confusion_matrix_metrics = nn.ModuleDict({
dataset_split: nn.ModuleDict({
attribute.name: torchmetrics.ConfusionMatrix(
num_classes=len(attribute.get_values(include_none=False)),
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index,
normalize='true',
nan_strategy='ignore'
)
for attribute in self.attributes
}) for dataset_split in [DatasetSplit.TRAIN, DatasetSplit.VALIDATION, DatasetSplit.TEST]
})
def forward(self, x):
# Features
x = self.features(x)
if self.verbose:
print(f'Features {x.shape}')
# Middle layers
x = self.middle_layers(x)
if self.verbose:
print(f'Middle layers {x.shape}')
# Outputs
outputs = {attribute: output_layer(x) for attribute, output_layer in self.output_layers.items()}
# outputs = self.output_namedtuple(**outputs)
if self.verbose:
print(f'Output layers' + str({attribute: output.shape for attribute, output in outputs.items()}))
return outputs
def configure_optimizers(self):
def lr_lambda(epoch: int):
if epoch < 1:
return 1
# elif epoch < 2:
# return 2
elif epoch < 25:
return 1
else:
return 0.97 ** (epoch / 2.4) # official efficientnet decaying from paper
# Rescaling learning rate of official paper to the used batch size
# learning_rate = 0.1 * 0.256 / 4096 * self.batch_size * torch.cuda.device_count()
# Optimizer used by the paper
# TODO: check that RMSprop is properly initialized as explained in the paper
# optimizer = torch.optim.RMSprop(
# self.parameters(), lr=learning_rate, alpha=0.9, momentum=0.9, weight_decay=1e-5
# )
optimizer = torch.optim.AdamW(self.parameters(), lr=0.1)
# optimizer_features = torch.optim.AdamW(self.features.parameters(), lr=0.0001)
# optimizer_middle_layers = torch.optim.AdamW(self.middle_layers.parameters(), lr=0.01)
# optimizer_outputs = torch.optim.AdamW(self.output_layers.parameters(), lr=0.01)
# lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=0.25, epochs=100, steps_per_epoch=4,
pct_start=0.1, anneal_strategy='cos', cycle_momentum=True,
base_momentum=0.85
)
return [optimizer], [lr_scheduler]
@staticmethod
def _get_losses(attributes: List[Attribute],
samples_per_attribute_value: Dict[Attribute, Dict[AttributeValue, int]]):
# Compute weights
weights_attributes = dict()
for attribute in attributes:
weights_attributes[attribute]: Dict[AttributeValue, float] = dict()
for attribute_value in attribute.get_values():
n_samples = samples_per_attribute_value[attribute][attribute_value]
if attribute_value is not Attribute.ATTRIBUTE_VALUE_NONE:
weights_attributes[attribute][attribute_value] = 1 / n_samples if n_samples != 0 else 1
# Convert to tensors
weights_attributes = {
attribute: torch.tensor(list(weights_attribute.values()))
for attribute, weights_attribute in weights_attributes.items()
}
losses = nn.ModuleDict({
attribute.name: torch.nn.CrossEntropyLoss(
weight=weights_attribute,
ignore_index=Attribute.ATTRIBUTE_VALUE_NONE.index
)
for attribute, weights_attribute in weights_attributes.items()
})
return losses
def training_step(self, batch, batch_id):
inputs, labels = batch
y_pred_attributes = self.forward(inputs)
log_metrics = self._get_log_metrics(
dataset_split=DatasetSplit.TRAIN,
y_pred_attributes=y_pred_attributes,
y_target_attributes=labels,
)
# self.log_dict(log_metrics, rank_zero_only=True)
losses = {
attribute.name: log_metrics[f'{attribute.name} Loss'][DatasetSplit.TRAIN]
for attribute in self.attributes
}
sum_losses = sum([loss for loss in losses.values() if not loss.isnan()])
return {'loss': sum_losses, 'losses': losses, 'log': log_metrics}
def validation_step(self, batch, batch_id):
inputs, labels = batch
y_pred_attributes = self.forward(inputs)
log_metrics = self._get_log_metrics(
dataset_split=DatasetSplit.VALIDATION,
y_pred_attributes=y_pred_attributes,
y_target_attributes=labels,
)
# self.log_dict(log_metrics, sync_dist=True, rank_zero_only=True)
losses = {
attribute.name: log_metrics[f'{attribute.name} Loss'][DatasetSplit.VALIDATION]
for attribute in self.attributes
}
sum_losses = sum([loss for loss in losses.values() if not loss.isnan()])
return {'loss': sum_losses, 'log': log_metrics}
def test_step(self, batch, batch_id):
inputs, labels = batch
y_pred_attributes = self.forward(inputs)
log_metrics = self._get_log_metrics(
dataset_split=DatasetSplit.TEST,
y_pred_attributes=y_pred_attributes,
y_target_attributes=labels,
)
# self.log_dict(log_metrics, sync_dist=True, rank_zero_only=True)
losses = {
attribute.name: log_metrics[f'{attribute.name} Loss'][DatasetSplit.TEST]
for attribute in self.attributes
}
sum_losses = sum([loss for loss in losses.values() if not loss.isnan()])
return {'loss': sum_losses, 'log': log_metrics}
def training_epoch_end(self, outputs):
self._log_confusion_matrix(dataset_split=DatasetSplit.TRAIN)
def validation_epoch_end(self, outputs):
self._log_confusion_matrix(dataset_split=DatasetSplit.VALIDATION)
# Metric to select best model
avg_f1 = torch.mean(torch.tensor(
[f1_metric.compute() for f1_metric in self.f1_metrics[DatasetSplit.VALIDATION].values()]
))
# self.log('val_avg_f1', avg_f1, sync_dist=True, rank_zero_only=True)
def test_epoch_end(self, outputs):
self._log_confusion_matrix(dataset_split=DatasetSplit.TEST)
def _get_log_metrics(self, dataset_split: 'DatasetSplit', y_pred_attributes, y_target_attributes) \
-> Dict[str, Dict[str, float]]:
metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
for attribute, label in zip(self.attributes, y_target_attributes.T):
metrics[f'{attribute.name} Loss'][dataset_split] = self.losses[attribute.name](
y_pred_attributes[f'output_{attribute.name}'], label
)
if not all(label == Attribute.ATTRIBUTE_VALUE_NONE.index): # TODO: is it ok to skip all these samples?
y_pred_attribute = y_pred_attributes[f'output_{attribute.name}']
# Filter out ignored index ( # TODO: torchmetrics seems to have a bug)
# TODO: is it correct?
filtered_label = label[label != Attribute.ATTRIBUTE_VALUE_NONE.index]
filtered_y_pred_attribute = y_pred_attribute[label != Attribute.ATTRIBUTE_VALUE_NONE.index]
# Compute metrics
self.accuracy_metrics[dataset_split][attribute.name](filtered_y_pred_attribute, filtered_label)
self.precision_metrics[dataset_split][attribute.name](filtered_y_pred_attribute, filtered_label)
self.recall_metrics[dataset_split][attribute.name](filtered_y_pred_attribute, filtered_label)
self.f1_metrics[dataset_split][attribute.name](filtered_y_pred_attribute, filtered_label)
self.confusion_matrix_metrics[dataset_split][attribute.name](filtered_y_pred_attribute, filtered_label)
metrics[f'{attribute.name} Accuracy'][dataset_split] = self.accuracy_metrics[dataset_split][attribute.name]
metrics[f'{attribute.name} F-1'][dataset_split] = self.f1_metrics[dataset_split][attribute.name]
metrics[f'{attribute.name} Precision'][dataset_split] = self.precision_metrics[dataset_split][attribute.name]
metrics[f'{attribute.name} Recall'][dataset_split] = self.recall_metrics[dataset_split][attribute.name]
return metrics
def _log_confusion_matrix(self, dataset_split: 'DatasetSplit'):
# Confusion matrix
for attribute in self.attributes:
# Compute confusion matrix
conf_mat = self.confusion_matrix_metrics[dataset_split][attribute.name].compute().detach().cpu().numpy().astype(np.float32)
# Convert confusion matrix to pandas dataframe
attribute_labels = [attribute_value.label for attribute_value in attribute.get_values()]
df_cm = pd.DataFrame(conf_mat, index=attribute_labels, columns=attribute_labels)
# Generate plot
figure = plt.figure(figsize=(25, 25))
# sn.set(font_scale=1.2)
sn.heatmap(df_cm, annot=True, annot_kws={'size': 7}, square=True, fmt='.2%')
# Send plot to tensorboard
image = convert_matplotlib_plot_to_torch_tensor(figure)
# self.logger.experiment.add_image(
# f'{dataset_split} {attribute.name} Confusion Matrix',
# image,
# global_step=self.current_epoch
# )
def _initialize_weights(self):
def init_weights(m):
if type(m) in [nn.Linear, nn.Conv2d]:
torch.nn.init.kaiming_uniform_(m.weight)
self.middle_layers.apply(init_weights)
for output_layer in self.output_layers.values():
output_layer.apply(init_weights)
@torch.jit.export
def _format_single_output(self, labels: List[str], values: torch.Tensor):
indexes = torch.argsort(values, descending=True)
values: List[float] = values.tolist()
sorted_labels = [labels[i] for i in indexes]
sorted_values = [values[i] for i in indexes]
return {label: value for label, value in zip(sorted_labels, sorted_values)}
def freeze_feature_layers(self):
for param in self.features.parameters():
param.requires_grad = False
def unfreeze_feature_layers(self):
for param in self.features.parameters():
param.requires_grad = True
def freeze_middle_layers(self):
for param in self.middle_layers.parameters():
param.requires_grad = False
def unfreeze_middle_layers(self):
for param in self.middle_layers.parameters():
param.requires_grad = True
def to_onnx(self, **kwargs):
# ONNX settings
dynamic_axes = {
'images': {0: 'batch'},
**{
str(output): {0: 'batch'}
for output in self.output_layers.keys()
}
}
torch_onnx_export_kwargs = dict(
input_names=['images'],
output_names=[str(output) for output in self.output_layers.keys()],
dynamic_axes=dynamic_axes,
)
kwargs = {
**kwargs,
**torch_onnx_export_kwargs
}
super().to_onnx(**kwargs)
def export_deepstream_label_file(self, label_file_path: str):
# Create label file
text = ''
for attribute in self.attributes:
text += ';'.join([
f'{attribute.name}_{attribute_value.label}'
for attribute_value in attribute.get_values(include_none=False)
])
text += '\n'
with open(label_file_path, 'w') as fp:
fp.write(text)
class FeatureExtractorFreezeUnfreeze(pl.callbacks.BaseFinetuning):
def __init__(self, unfreeze_at_epoch: int):
super().__init__()
self._unfreeze_at_epoch = unfreeze_at_epoch
def freeze_before_training(self, pl_module):
# freeze any module you want
# Here, we are freezing `features`
self.freeze(pl_module.features)
def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
# When `current_epoch` is self._unfreeze_at_epoch, features layers will start training.
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
modules=pl_module.features,
optimizer=optimizer,
train_bn=True,
)
```

Data module:

```python
import os
from typing import Any, Dict, List, Optional
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule
from vehicles.dataset.attributes import Attribute
class VehiclesDataModule(LightningDataModule):
def __init__(self, train_dataset: Dataset, val_dataset: Dataset, test_dataset: Dataset,
attributes: List[Attribute], batch_size: int, num_workers: int):
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.attributes = attributes
self.batch_size = batch_size
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None):
pass
def _get_data_loaders_kwargs(self):
return dict(
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
# drop_last=True
prefetch_factor=2
)
def train_dataloader(self):
return DataLoader(self.train_dataset, **self._get_data_loaders_kwargs())
def val_dataloader(self):
return DataLoader(self.val_dataset, **self._get_data_loaders_kwargs())
def test_dataloader(self):
return DataLoader(self.test_dataset, **self._get_data_loaders_kwargs())
def predict_dataloader(self):
return DataLoader(self.test_dataset, **self._get_data_loaders_kwargs())
def teardown(self, stage: Optional[str] = None):
# Used to clean-up when the run is finished
pass
def state_dict(self) -> Dict[str, Any]:
return {
'train_dataset': self.train_dataset.state_dict(),
'val_dataset': self.val_dataset.state_dict(),
'test_dataset': self.test_dataset.state_dict(),
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.train_dataset.load_state_dict(state_dict['train_dataset'])
self.val_dataset.load_state_dict(state_dict['val_dataset'])
self.test_dataset.load_state_dict(state_dict['test_dataset'])
```
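One thing I can still try is re-enabling the commented-out `profiler='simple'` line to see where the time actually goes. A minimal version would look like this (just a sketch; the arguments other than the profiler are illustrative):

```python
from pytorch_lightning import Trainer

# The built-in simple profiler prints a per-hook timing summary at the end of
# fit(), which should show whether the time is spent fetching batches from the
# dataloader or inside the training/validation steps themselves.
trainer = Trainer(
    accelerator='auto',
    gpus=-1,
    profiler='simple',
    max_epochs=1,  # one short run is enough to see the breakdown
)
trainer.fit(model=model, datamodule=vehicles_datamodule)
```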
-
As a side question: my checkpoint callback will only save at the end of every epoch, not every time there is an improvement in the monitored metric, right? Because during the first epochs there is going to be an improvement at almost every step.
Again, just to be clear, in my latest test the checkpointing system was disabled.
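For reference, this is roughly how I understand the default behaviour (a sketch with illustrative values, not my exact configuration): `ModelCheckpoint` only evaluates the monitored metric when checkpointing is triggered, which by default happens once per epoch, so improvements within an epoch don't produce extra checkpoints. Saving more often requires `every_n_train_steps` (or `train_time_interval`):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Default behaviour: the callback runs once per epoch, so even if the monitored
# metric improves at every step, at most one checkpoint per epoch is written.
checkpoint_best_model = ModelCheckpoint(
    monitor='val_avg_f1',  # illustrative: the metric logged in validation_epoch_end
    mode='max',
    save_top_k=1,
    every_n_epochs=1,      # matches the default epoch-end behaviour
)

# To checkpoint within an epoch, trigger on training steps instead:
checkpoint_every_steps = ModelCheckpoint(
    every_n_train_steps=500,  # hypothetical interval
    save_top_k=-1,
)
```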
-
I analyzed the problem a little bit more. I noticed that I have 48 CPUs and 48 workers, and the training process pauses every 48 steps. If I use 12 workers, the pause happens every 12 steps.
I'd like to increase the number of workers, but the RAM usage is crazy high: with 48 workers I am almost using all of the 180 GB of RAM available. Is this normal for simply loading images of a few kilobytes each? Any suggestion on how to speed this up?
EDIT: I think I am facing this issue pytorch/pytorch#13246 (comment), even though I am not entirely sure. My memory consumption is about 100-150 GB right after training starts. I tried using a NumPy array to store the huge list of integers containing the IDs of the records in the dataset. However, this didn't reduce the RAM usage.
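For reference, what I tried looks roughly like this (a simplified sketch, not my real dataset). The idea from that issue is that Python lists of objects held by the Dataset get their refcounts touched by every forked worker, which turns copy-on-write pages into real copies, so large index structures should live in NumPy arrays (or similar) instead:

```python
import numpy as np
from torch.utils.data import Dataset


class IdsDataset(Dataset):
    """Keeps the (potentially huge) list of record IDs as a NumPy array instead
    of a Python list, so forked DataLoader workers don't gradually copy it via
    refcount updates on millions of Python ints (copy-on-write)."""

    def __init__(self, record_ids):
        self.record_ids = np.asarray(record_ids, dtype=np.int64)  # one contiguous buffer

    def __len__(self):
        return len(self.record_ids)

    def __getitem__(self, index):
        record_id = int(self.record_ids[index])
        # ...load the image and labels for `record_id` here (omitted)...
        return record_id
```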
-
Hi @mfoglio, did you find a way to solve this problem? I am having the same problem.
-
I am having the exact same issue. I tried to solve it with pytorch/pytorch#13246 (comment), but without success. For me the problem is not the large amount of RAM, but that training really slows down every N steps (where N is num_workers). Things I tried (as suggested in the aforementioned issue): …
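My current understanding of the mechanism, which I haven't fully verified: each of the N workers prepares whole batches, so when a single worker needs roughly N times the GPU step time to produce one batch, the main process burns through whatever was prefetched and then stalls once every N steps while it waits for the next round of workers. The usual DataLoader knobs are below; the values are illustrative, not a confirmed fix:

```python
from torch.utils.data import DataLoader

dataset = list(range(1000))  # placeholder for the real image dataset

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=12,           # the period of the stalls follows this number
    prefetch_factor=4,        # batches each worker keeps ready ahead of time (default: 2)
    persistent_workers=True,  # don't re-fork workers at every epoch
    pin_memory=True,          # faster host-to-GPU copies
)
```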
-
I met the same issue and it turned out that my …
-
Exactly the same issue! Still haven't found any viable solution to this; currently still searching.