train.py

# Main training script
# BoMeyering, 2024

import os
import torch
import yaml
import sys
import json
import logging.config
import datetime

from torch.utils.data import DataLoader
from argparse import ArgumentParser
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
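# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so device-side
# errors surface with accurate stack traces; useful for debugging, but it slows
# training. It takes effect as long as it is set before the CUDA context is initialized.
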
from src.utils.config import YamlConfigLoader, ArgsAttributes, setup_loggers
from src.utils.parameters import get_params
from src.models import create_smp_model
from src.optim import EMA, ConfigOptim
from src.datasets import LabeledDataset, UnlabeledDataset
from src.transforms import get_train_transforms, get_strong_transforms, get_weak_transforms, get_val_transforms
from src.dataloaders import DataLoaderBalancer
from src.trainer import FixMatchTrainer, SupervisedTrainer, TBLogger
from src.losses import get_loss_criterion
from metadata.dataset_maps import mapping

# Parse the command line argument configs
parser = ArgumentParser()
parser.add_argument('config', nargs='?', default='configs/train_config.yaml')
args = parser.parse_args()

def main(args):
    # Load the training configuration YAML
    config_loader = YamlConfigLoader(args.config)
    config = config_loader.load_config()

    # Instantiate the args namespace with the config and set values
    arg_setter = ArgsAttributes(args, config)
    arg_setter.set_args_attr()

    # Grab the validated args Namespace
    args = arg_setter.args
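
    # For reference, the attribute accesses below imply a config shaped roughly
    # like the following. This is an inferred sketch, not the template's
    # documented schema; field names come only from how `args` is used in this
    # script, and the values are placeholders:
    #
    #   run_name: my_run
    #   device: cuda
    #   model:
    #     model_name: unet
    #     encoder_name: resnet34
    #     starting_weights: null
    #     resize: 512
    #     lab_bs: 8
    #     unlab_bs: 24
    #   optimizer:
    #     name: adamw
    #     filter_bias_and_bn: true
    #     ema: false
    #     ema_decay: 0.999
    #   scheduler:
    #     name: cosine
    #   loss:
    #     name: cross_entropy
    #   directories:
    #     train_l_dir: data/train_labeled
    #     val_dir: data/val
    #     test_dir: data/test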

    # Get class map
    class_map = mapping['new_mapping']

    # Set up Tensorboard
    tb_writer = SummaryWriter(log_dir=os.path.join('runs', args.run_name))
    tb_logger = TBLogger(tb_writer)
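    # Events are written under runs/<run_name>; inspect them with:
    #   tensorboard --logdir runs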

    # Set up the loggers
    setup_loggers(args)
    logger = logging.getLogger()
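
    # create_smp_model presumably wraps segmentation_models_pytorch ("smp"),
    # selecting the architecture and pretrained encoder from args.model.model_name
    # and args.model.encoder_name; that mapping is an assumption based on the name.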
    # Create model specified in configs
    model = create_smp_model(args)
    model.to(args.device)
    logger.info(f"Instantiated {args.model.model_name} with {args.model.encoder_name} backbone.")
    if args.model.starting_weights:
        # map_location keeps GPU-saved checkpoints loadable on CPU-only machines
        state_dict = torch.load(args.model.starting_weights, map_location=args.device)['model_state_dict']
        model.load_state_dict(state_dict=state_dict)

    # Get model parameters and weight decay; filter out bias and batch-norm parameters if necessary
    parameters = get_params(args, model)
    if args.optimizer.filter_bias_and_bn:
        logger.info("Applied weight decay to all parameters except biases and batch-norm parameters.")
    else:
        logger.info("Applied weight decay to all parameters.")

    # Get optimizer and scheduler
    optim_config = ConfigOptim(args, parameters)
    optimizer = optim_config.get_optimizer()
    scheduler = optim_config.get_scheduler()
    logger.info(f"Initialized optimizer {args.optimizer.name}")
    logger.info(f"Initialized scheduler {args.scheduler.name}")

    # Load class pixel counts, then compute normalized inverse-frequency weights
    with open('metadata/class_pixel_counts.json', 'r') as f:
        samples = json.load(f)
    samples = torch.tensor([v for v in samples.values()]).to(args.device)
    inv_weights = 1 / samples
    inv_weights = inv_weights / inv_weights.sum()
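    # Worked example with hypothetical pixel counts [900, 90, 10]:
    # raw inverse weights are [1/900, 1/90, 1/10] ≈ [0.0011, 0.0111, 0.1000];
    # dividing by their sum (≈ 0.1122) gives ≈ [0.0099, 0.0990, 0.8911], so the
    # rarest class receives the largest weight and the weights sum to 1.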
    logger.info(f"Loaded sample class pixel distribution {samples}")
    logger.info(f"Loaded class inverse weights {inv_weights}.")

    # Set pixel class samples and inverse weights as args.loss attributes
    setattr(args.loss, 'samples', samples)
    setattr(args.loss, 'weights', inv_weights)

    # Get loss criterion from args
    loss_criterion = get_loss_criterion(args)
    logger.info(f"Initialized loss criterion {args.loss.name}")
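    # The criterion presumably reads the args.loss.weights set above, e.g. as the
    # per-class `weight` tensor of a weighted cross-entropy; exactly how
    # get_loss_criterion consumes it is internal to src.losses.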

    # Set up EMA if configured
    if args.optimizer.ema:
        ema = EMA(model, args.optimizer.ema_decay)
        logger.info(f"Applied exponential moving average with decay {args.optimizer.ema_decay} to model weights.")
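    # A typical EMA update keeps a shadow copy of each parameter:
    #   shadow = decay * shadow + (1 - decay) * param
    # evaluated after every optimizer step; whether this EMA class follows exactly
    # that form is an assumption about src.optim.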

    # Build Datasets and Dataloaders
    logger.info(f"Building datasets from {[v for v in vars(args.directories).values() if v.startswith('data')]}")
    train_l_ds = LabeledDataset(root_dir=args.directories.train_l_dir, transforms=get_train_transforms(resize=args.model.resize))
    val_ds = LabeledDataset(root_dir=args.directories.val_dir, transforms=get_val_transforms(resize=args.model.resize))
    test_ds = LabeledDataset(root_dir=args.directories.test_dir, transforms=get_val_transforms(resize=args.model.resize))

    # Semi-supervised route: only taken when an unlabeled training directory is configured
    if 'train_u_dir' in vars(args.directories):
        train_u_ds = UnlabeledDataset(root_dir=args.directories.train_u_dir, weak_transforms=get_weak_transforms(resize=args.model.resize), strong_transforms=get_strong_transforms(resize=args.model.resize))
        dl_balancer = DataLoaderBalancer(train_l_ds, train_u_ds, batch_sizes=[args.model.lab_bs, args.model.unlab_bs], drop_last=False)
        dataloaders, max_length = dl_balancer.balance_loaders()
        logger.info(f"Training dataloaders balanced. Labeled DL BS: {args.model.lab_bs}, Unlabeled DL BS: {args.model.unlab_bs}.")
        logger.info(f"Max loader length for epoch iteration: {max_length}")
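        # balance_loaders presumably reconciles the two loaders' different lengths
        # (e.g. by cycling the shorter one) so both yield a batch on each of the
        # max_length iterations per epoch; the exact strategy lives in src.dataloaders.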
        val_dataloader = DataLoader(val_ds, batch_size=args.model.lab_bs, shuffle=False, drop_last=False)
        logger.info("Validation dataloader instantiated")
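
        # FixMatch (Sohn et al., 2020): the model pseudo-labels weakly augmented
        # unlabeled images, and predictions whose confidence clears a threshold
        # supervise the same images under strong augmentation, alongside the
        # ordinary supervised loss on the labeled batch.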
        fixmatch_trainer = FixMatchTrainer(
            name='Test_Trainer',
            args=args,
            model=model,
            train_loaders=dataloaders,
            train_length=max_length,
            val_loader=val_dataloader,
            optimizer=optimizer,
            criterion=loss_criterion,
            scheduler=scheduler,
            tb_logger=tb_logger,
            class_map=class_map
        )
        logger.info(f"Created FixMatchTrainer {fixmatch_trainer.trainer_id} for semi-supervised learning.")
        logger.info("Training initiated")
        fixmatch_trainer.train()
        logger.info("Training complete")
        tb_writer.flush()
        tb_writer.close()
    else:
        # Fully supervised route: no unlabeled data configured
        train_dataloader = DataLoader(train_l_ds, batch_size=args.model.lab_bs, shuffle=True, drop_last=True)
        val_dataloader = DataLoader(val_ds, batch_size=args.model.lab_bs, shuffle=False, drop_last=False)
        test_dataloader = DataLoader(test_ds, batch_size=args.model.lab_bs, shuffle=False, drop_last=False)
        logger.info("All dataloaders instantiated.")
        supervised_trainer = SupervisedTrainer(
            name='Supervised Trainer',
            args=args,
            model=model,
            train_loader=train_dataloader,
            val_loader=val_dataloader,
            optimizer=optimizer,
            criterion=loss_criterion,
            scheduler=scheduler,
            tb_logger=tb_logger,
            class_map=class_map
        )
        logger.info(f"Created SupervisedTrainer {supervised_trainer.trainer_id} for fully supervised learning.")
        logger.info("Training initiated")
        supervised_trainer.train()
        logger.info("Training complete")
        tb_writer.flush()
        tb_writer.close()


if __name__ == '__main__':
    main(args)
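
# Usage: the config path is an optional positional argument, e.g.
#   python train.py configs/train_config.yaml
# or simply `python train.py` to fall back to the default config path.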