# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hydra
from omegaconf import DictConfig
import torch
import plotly
import os
from torch.nn.parallel import DistributedDataParallel
from modulus.models.fno import FNO
from modulus.distributed import DistributedManager
from modulus.launch.utils import load_checkpoint, save_checkpoint
from modulus.launch.logging import (
PythonLogger,
LaunchLogger,
)
from modulus.launch.logging.wandb import initialize_wandb
from modulus.sym.hydra import to_absolute_path
from losses import LossMHDVecPot, LossMHDVecPot_Modulus
from torch.optim import AdamW
from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly
import wandb
dtype = torch.float
# dtype = torch.double
torch.set_default_dtype(dtype)
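# The default dtype set here governs tensors created without an explicit dtype
# below (e.g. the normalization tensors); batches are cast to the same dtype in
# the training and validation loops.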
@hydra.main(
version_base="1.3", config_path="config", config_name="mhd_vec_pot_Re250.yaml"
)
def main(cfg: DictConfig) -> None:
"""Training for the 2D Darcy flow benchmark problem.
This training script demonstrates how to set up a data-driven model for a 2D Darcy flow
using Fourier Neural Operators (FNO) and acts as a benchmark for this type of operator.
Training data is generated in-situ via the Darcy2D data loader from Modulus. Darcy2D
continuously generates data previously unseen by the model, i.e. the model is trained
over a single epoch of a training set consisting of
(cfg.training.max_pseudo_epochs*cfg.training.pseudo_epoch_sample_size) unique samples.
Pseudo_epochs were introduced to leverage the LaunchLogger and its MLFlow integration.
"""
DistributedManager.initialize() # Only call this once in the entire script!
dist = DistributedManager() # call if required elsewhere
# initialize monitoring
log = PythonLogger(name="mhd_pino")
log.file_logging()
wandb_dir = cfg.wandb_params.wandb_dir
wandb_project = cfg.wandb_params.wandb_project
wandb_group = cfg.wandb_params.wandb_group
initialize_wandb(
project=wandb_project,
entity="fresleven",
mode="offline",
group=wandb_group,
config=dict(cfg),
results_dir=wandb_dir,
)
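    # Note: W&B is initialized in offline mode with a hard-coded entity; offline
    # runs land in `wandb_dir` and can be uploaded later with `wandb sync`.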
LaunchLogger.initialize(use_wandb=cfg.use_wandb) # Modulus launch logger
# Load config file parameters
model_params = cfg.model_params
dataset_params = cfg.dataset_params
train_loader_params = cfg.train_loader_params
val_loader_params = cfg.val_loader_params
test_loader_params = cfg.test_loader_params
loss_params = cfg.loss_params
optimizer_params = cfg.optimizer_params
train_params = cfg.train_params
wandb_params = cfg.wandb_params
load_ckpt = cfg.load_ckpt
output_dir = cfg.output_dir
use_wandb = cfg.use_wandb
output_dir = to_absolute_path(output_dir)
os.makedirs(output_dir, exist_ok=True)
data_dir = dataset_params.data_dir
ckpt_path = train_params.ckpt_path
wandb_dir = wandb_params.wandb_dir
    wandb_project = wandb_params.wandb_project
    wandb_group = wandb_params.wandb_group
# Construct dataloaders
dataset_train = Dedalus2DDataset(
dataset_params.data_dir,
output_names=dataset_params.output_names,
field_names=dataset_params.field_names,
num_train=dataset_params.num_train,
num_test=dataset_params.num_test,
use_train=True,
)
dataset_val = Dedalus2DDataset(
data_dir,
output_names=dataset_params.output_names,
field_names=dataset_params.field_names,
num_train=dataset_params.num_train,
num_test=dataset_params.num_test,
use_train=False,
)
mhd_dataloader_train = MHDDataloaderVecPot(
dataset_train,
sub_x=dataset_params.sub_x,
sub_t=dataset_params.sub_t,
ind_x=dataset_params.ind_x,
ind_t=dataset_params.ind_t,
)
mhd_dataloader_val = MHDDataloaderVecPot(
dataset_val,
sub_x=dataset_params.sub_x,
sub_t=dataset_params.sub_t,
ind_x=dataset_params.ind_x,
ind_t=dataset_params.ind_t,
)
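    # sub_x / sub_t appear to control spatial and temporal subsampling, and
    # ind_x / ind_t the index bounds used when slicing each trajectory; see
    # MHDDataloaderVecPot for the exact semantics.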
dataloader_train, sampler_train = mhd_dataloader_train.create_dataloader(
batch_size=train_loader_params.batch_size,
shuffle=train_loader_params.shuffle,
num_workers=train_loader_params.num_workers,
pin_memory=train_loader_params.pin_memory,
distributed=dist.distributed,
)
dataloader_val, sampler_val = mhd_dataloader_val.create_dataloader(
batch_size=val_loader_params.batch_size,
shuffle=val_loader_params.shuffle,
num_workers=val_loader_params.num_workers,
pin_memory=val_loader_params.pin_memory,
distributed=dist.distributed,
)
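    # In distributed runs, create_dataloader also returns a per-rank sampler;
    # set_epoch() is called on it every epoch below so that shuffling varies
    # across epochs while ranks stay in sync.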
# define FNO model
model = FNO(
in_channels=model_params.in_dim,
out_channels=model_params.out_dim,
decoder_layers=model_params.decoder_layers,
decoder_layer_size=model_params.fc_dim,
dimension=model_params.dimension,
latent_channels=model_params.layers,
num_fno_layers=model_params.num_fno_layers,
num_fno_modes=model_params.modes,
padding=[model_params.pad_z, model_params.pad_y, model_params.pad_x],
).to(dist.device)
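    # The FNO expects channels-first input; batches arrive channels-last, hence
    # the permutes around the forward passes below. The padding list appears to
    # be ordered to match the permuted (z, y, x) axes.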
# Set up DistributedDataParallel if using more than a single process.
# The `distributed` property of DistributedManager can be used to
# check this.
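    # Constructing DDP inside a side CUDA stream mirrors the Modulus examples;
    # it avoids blocking the default stream during the initial parameter
    # broadcast.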
if dist.distributed:
ddps = torch.cuda.Stream()
with torch.cuda.stream(ddps):
model = DistributedDataParallel(
model,
device_ids=[dist.local_rank], # Set the device_id to be
# the local rank of this process on
# this node
output_device=dist.device,
broadcast_buffers=dist.broadcast_buffers,
find_unused_parameters=dist.find_unused_parameters,
)
torch.cuda.current_stream().wait_stream(ddps)
# Construct optimizer and scheduler
# optimizer = Adam(model.parameters(), betas=optimizer_params['betas'], lr=optimizer_params['lr'])
optimizer = AdamW(
model.parameters(),
betas=optimizer_params.betas,
lr=optimizer_params.lr,
weight_decay=0.1,
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=optimizer_params.milestones, gamma=optimizer_params.gamma
)
# Construct Loss class
    if cfg.derivative == "modulus":
        mhd_loss = LossMHDVecPot_Modulus(**loss_params)
    elif cfg.derivative == "original":
        mhd_loss = LossMHDVecPot(**loss_params)
    else:
        raise ValueError(f"Unknown derivative option: {cfg.derivative!r}")
# Load model from checkpoint (if exists)
loaded_epoch = 0
if load_ckpt:
loaded_epoch = load_checkpoint(
ckpt_path, model, optimizer, scheduler, device=dist.device
)
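    # load_checkpoint returns the epoch recorded in the checkpoint (0 if none is
    # found), so training resumes from loaded_epoch + 1 in the loop below.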
# Training Loop
epochs = train_params.epochs
ckpt_freq = train_params.ckpt_freq
names = dataset_params.fields
input_norm = torch.tensor(model_params.input_norm).to(dist.device)
output_norm = torch.tensor(model_params.output_norm).to(dist.device)
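    # input_norm / output_norm rescale the raw fields, presumably to keep input
    # and output magnitudes comparable: inputs are divided by input_norm before
    # the forward pass and predictions multiplied by output_norm afterwards.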
for epoch in range(max(1, loaded_epoch + 1), epochs + 1):
with LaunchLogger(
"train",
epoch=epoch,
num_mini_batch=len(dataloader_train),
epoch_alert_freq=1,
) as log:
if dist.distributed:
sampler_train.set_epoch(epoch)
# Train Loop
model.train()
for i, (inputs, outputs) in enumerate(dataloader_train):
                # Cast to the globally selected dtype (matches the validation loop)
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)
# Zero Gradients
optimizer.zero_grad()
# Compute Predictions
pred = (
model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
0, 2, 3, 4, 1
)
* output_norm
)
# Compute Loss
loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)
# Compute Gradients for Back Propagation
loss.backward()
# Update Weights
optimizer.step()
log.log_minibatch(loss_dict)
log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
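        # Step the scheduler once per epoch so MultiStepLR milestones are in epoch units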
scheduler.step()
with LaunchLogger("valid", epoch=epoch) as log:
# Val loop
model.eval()
val_loss_dict = {}
plot_count = 0
plot_dict = {name: {} for name in names}
with torch.no_grad():
for i, (inputs, outputs) in enumerate(dataloader_val):
inputs = inputs.type(dtype).to(dist.device)
outputs = outputs.type(dtype).to(dist.device)
# Compute Predictions
pred = (
model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
0, 2, 3, 4, 1
)
* output_norm
)
# Compute Loss
loss, loss_dict = mhd_loss(
pred, outputs, inputs, return_loss_dict=True
)
log.log_minibatch(loss_dict)
# Get prediction plots to log for wandb
# Do for number of batches specified in the config file
if (i < wandb_params.wandb_num_plots) and (
epoch % wandb_params.wandb_plot_freq == 0
):
# Add all predictions in batch
for j, _ in enumerate(pred):
# Make plots for each field
for index, name in enumerate(names):
# Generate figure
if use_wandb:
figs = plot_predictions_mhd_plotly(
pred[j].cpu(),
outputs[j].cpu(),
inputs[j].cpu(),
index=index,
name=name,
)
                                    # Accumulate figures for this field; plain
                                    # assignment would overwrite plots from
                                    # earlier samples in the batch
                                    plot_dict[name].update(
                                        {
                                            f"{plot_type}-{plot_count}": wandb.Html(
                                                plotly.io.to_html(fig)
                                            )
                                            for plot_type, fig in zip(
                                                wandb_params.wandb_plot_types, figs
                                            )
                                        }
                                    )
plot_count += 1
# Get prediction plots and save images locally
if (i < 2) and (epoch % wandb_params.wandb_plot_freq == 0):
# Add all predictions in batch
for j, _ in enumerate(pred):
# Generate figure
plot_predictions_mhd(
pred[j].cpu(),
outputs[j].cpu(),
inputs[j].cpu(),
names=names,
save_path=os.path.join(
output_dir,
"MHD_" + cfg.derivative + "_" + str(dist.rank),
),
save_suffix=i,
)
            if use_wandb and epoch % wandb_params.wandb_plot_freq == 0:
wandb.log({"plots": plot_dict})
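        # Save checkpoints from rank 0 only, to avoid concurrent writes in
        # multi-process runs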
if epoch % ckpt_freq == 0 and dist.rank == 0:
save_checkpoint(ckpt_path, model, optimizer, scheduler, epoch=epoch)
if __name__ == "__main__":
main()