Merge pull request #58 from nasa-nccs-hpda/hackathon_2024_cssprad1
Hackathon 2024 cssprad1
cssprad1 authored Aug 28, 2024
2 parents 265f252 + f0597d7 commit 089c81c
Showing 6 changed files with 117 additions and 21 deletions.
8 changes: 8 additions & 0 deletions pytorch_caney/config.py
@@ -119,6 +119,8 @@
# Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
# OneCycle LR Scheduler max LR percentage
_C.TRAIN.LR_SCHEDULER.CYCLE_PERCENTAGE = 0.3

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
@@ -133,6 +135,10 @@
# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# Tensorboard settings
_C.TENSORBOARD = CN()
_C.TENSORBOARD.WRITER_DIR = '.'


# -----------------------------------------------------------------------------
# Testing settings
@@ -218,6 +224,8 @@ def _check_args(name):
config.EVAL_MODE = True
if _check_args('enable_amp'):
config.ENABLE_AMP = args.enable_amp
if _check_args('tensorboard_dir'):
config.TENSORBOARD.WRITER_DIR = args.tensorboard_dir

# output folder
config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
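A minimal sketch (not the repository's actual parse_option code) of how the new --tensorboard-dir flag flows into config.TENSORBOARD.WRITER_DIR, assuming the defaults above live on a standard yacs CfgNode; the argparse wiring and the example path are hypothetical.

import argparse
from yacs.config import CfgNode as CN

# Stand-in for the defaults added in pytorch_caney/config.py
_C = CN()
_C.TENSORBOARD = CN()
_C.TENSORBOARD.WRITER_DIR = '.'

parser = argparse.ArgumentParser()
parser.add_argument('--tensorboard-dir', type=str, default=None)
args = parser.parse_args(['--tensorboard-dir', '/tmp/tb_runs'])  # hypothetical path

config = _C.clone()
config.defrost()
if args.tensorboard_dir is not None:  # mirrors the _check_args('tensorboard_dir') guard
    config.TENSORBOARD.WRITER_DIR = args.tensorboard_dir
config.freeze()
print(config.TENSORBOARD.WRITER_DIR)  # /tmp/tb_runs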
22 changes: 22 additions & 0 deletions pytorch_caney/optimizer/lamb.py
@@ -51,11 +51,28 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import collections
import math

import torch
from torch.optim import Optimizer

from torch.utils.tensorboard import SummaryWriter


def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
"""Log a histogram of trust ratio scalars in across layers."""
results = collections.defaultdict(list)
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
if i in state:
results[i].append(state[i])

for k, v in results.items():
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)


class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
@@ -185,6 +202,11 @@ def step(self, closure=None):
if group['trust_clip']:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)

state['weight_norm'] = w_norm
state['adam_norm'] = g_norm
state['trust_ratio'] = trust_ratio

update.mul_(trust_ratio)

p.add_(update, alpha=-group['lr'])
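A hedged usage sketch for the new log_lamb_rs helper: step() now stores weight_norm, adam_norm, and trust_ratio in the optimizer state, and log_lamb_rs writes them to TensorBoard as histograms under 'lamb/...'. The toy model, constructor arguments, and log directory below are illustrative only.

import torch
from torch.utils.tensorboard import SummaryWriter
from pytorch_caney.optimizer.lamb import Lamb, log_lamb_rs

model = torch.nn.Linear(16, 4)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
writer = SummaryWriter('runs/lamb_debug')  # hypothetical log dir

for step in range(10):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                      # populates weight_norm / adam_norm / trust_ratio
    log_lamb_rs(optimizer, writer, step)  # one histogram per statistic at this step

writer.close()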
60 changes: 46 additions & 14 deletions pytorch_caney/pipelines/pretraining/mim_deepspeed.py
@@ -24,11 +24,9 @@

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/vit_160_exp_1')

NUM_SAMPLES: int = 1962000


def parse_args():
"""
Parse command-line arguments
@@ -50,6 +48,13 @@ def parse_args():
required=True,
help="paths where dataset is stored")

parser.add_argument(
'--tensorboard-dir',
type=str,
required=True,
help='Dir path for tensorboard to write to.'
)

parser.add_argument('--validation-path',
type=str,
required=True,
@@ -98,7 +103,8 @@ def train(config,
dataloader,
model_engine,
optimizer,
device):
device,
writer):
"""
Start pre-training a specific model and dataset.
@@ -130,7 +136,7 @@

execute_one_epoch(config, model_engine, dataloader,
optimizer, epoch, resuming_step,
target_dtype, device)
target_dtype, device, writer)

epoch_time = time.time() - start
logger.info(
@@ -144,8 +150,6 @@

logger.info('Training time {}'.format(total_time_str))

writer.close()


def execute_one_epoch(config,
model,
@@ -154,7 +158,8 @@ def execute_one_epoch(config,
epoch,
resuming_step,
target_dtype,
device):
device,
writer):
"""
Execute training iterations on a single epoch.
@@ -169,6 +174,10 @@
"""
validationDataset = validation_setup(config)

# Setup lamb gradient logging
if config.TRAIN.OPTIMIZER.NAME == 'lamb':
from pytorch_caney.optimizer.lamb import log_lamb_rs

num_steps = max(1,
NUM_SAMPLES // (config.DATA.BATCH_SIZE * dist.get_world_size()))

@@ -210,11 +219,20 @@

if idx % config.VALIDATION_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
validate(model, validationDataset, lr, idx, epoch, target_dtype, device)
validate(model,
validationDataset,
lr,
idx,
epoch,
target_dtype,
device,
writer)

if idx % config.PRINT_FREQ == 0:
lr = optimizer.param_groups[0]['lr']
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
cached_memory = torch.cuda.memory_reserved() / (1024 * 1024) # in MB
max_memory = torch.cuda.max_memory_reserved() / (1024 * 1024) # in MB
etas = batch_time.avg * (num_steps - idx)
logger.info(
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
@@ -223,8 +241,12 @@
f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t'
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
f'mem {memory_used:.0f}MB')
writer.add_scalar('training loss ', loss_meter.val, idx)
writer.add_scalar('memory usage ', memory_used, idx)
writer.add_scalar('training_loss ', loss_meter.val, idx)
writer.add_scalar('memory_usage ', memory_used, idx)
writer.add_scalar('cached_memory', cached_memory, idx)
writer.add_scalar('max_memory', max_memory, idx)
if config.TRAIN.OPTIMIZER.NAME == 'lamb':
log_lamb_rs(optimizer, writer, idx)
writer.flush()

if idx % config.SAVE_FREQ == 0 or idx == num_steps-1:
@@ -249,6 +271,11 @@ def main(config):

logger.info('In main')

tensorboardMainDir = config.TENSORBOARD.WRITER_DIR
tensorboardDir = f'{tensorboardMainDir}/{config.TAG}'
logger.info(f'Initializing tensorboard to {tensorboardDir}')
writer = SummaryWriter(tensorboardDir)

transform = SimmimTransform(config)

dataset = MODIS22MDataset(config,
@@ -308,7 +335,7 @@ def main(config):
total_steps = num_steps * config.TRAIN.EPOCHS
logger.info(f'Total steps for {config.TRAIN.EPOCHS} epochs: {total_steps}')

cycle_one_percentage = 0.3
cycle_one_percentage = config.TRAIN.LR_SCHEDULER.CYCLE_PERCENTAGE
cycle_stage_one = int(total_steps * cycle_one_percentage)
cycle_stage_two = (total_steps - cycle_stage_one) - 1

@@ -430,7 +457,10 @@ def main(config):
dataloader,
model_engine,
optimizer,
local_device)
local_device,
writer)

writer.close()


@torch.no_grad()
@@ -448,7 +478,7 @@ def validation_setup(config):


@torch.no_grad()
def validate(model, img_masks, lr, step, epoch, target_dtype, device):
def validate(model, img_masks, lr, step, epoch, target_dtype, device, writer):
start_time = time.time()

img, mask = img_masks
@@ -468,6 +498,8 @@ def validate(model, img_masks, lr, step, epoch, target_dtype, device):
f"lr {lr}\t"
f"val_loss {loss:.4f}\t"
f"time {validation_time:.4f}s")
writer.add_scalar('validation_loss', loss, step)
writer.flush()

del img, mask, loss

@@ -509,7 +541,7 @@ def setup_seeding(config):
setup_seeding(config)

config.defrost()
base_batch_size = 2048
base_batch_size = 512
config.TRAIN.BASE_LR = (config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size()) / base_batch_size
config.TRAIN.WARMUP_LR = (config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size()) / base_batch_size
config.TRAIN.MIN_LR = (config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size()) / base_batch_size
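A worked sketch of the scheduling and learning-rate arithmetic touched above, using illustrative values (256 per-rank batch size, 64 ranks, 50 epochs) rather than numbers from a real run: the OneCycle stage split now comes from TRAIN.LR_SCHEDULER.CYCLE_PERCENTAGE instead of a hardcoded 0.3, and the LR scaling divides by the new base_batch_size of 512.

NUM_SAMPLES = 1962000
batch_size, world_size, epochs = 256, 64, 50                    # illustrative only

num_steps = max(1, NUM_SAMPLES // (batch_size * world_size))    # 119 steps per epoch
total_steps = num_steps * epochs                                # 5950

cycle_percentage = 0.3                                          # TRAIN.LR_SCHEDULER.CYCLE_PERCENTAGE
cycle_stage_one = int(total_steps * cycle_percentage)           # 1785 ramp-up steps
cycle_stage_two = (total_steps - cycle_stage_one) - 1           # 4164 decay steps

base_batch_size = 512
base_lr = 3e-4                                                  # TRAIN.BASE_LR from the adamw config
scaled_lr = base_lr * batch_size * world_size / base_batch_size # 0.0096
print(num_steps, total_steps, cycle_stage_one, cycle_stage_two, scaled_lr)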
@@ -0,0 +1,33 @@
MODEL:
TYPE: swinv2
NAME: mim_satvision_pretrain-huge
DROP_PATH_RATE: 0.1
SWINV2:
IN_CHANS: 14
EMBED_DIM: 352
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32]
WINDOW_SIZE: 8
NORM_PERIOD: 6
DATA:
IMG_SIZE: 128
MASK_PATCH_SIZE: 8
MASK_RATIO: 0.6
TRAIN:
USE_CHECKPOINT: True
EPOCHS: 50
WARMUP_EPOCHS: 1
BASE_LR: 3e-4
MIN_LR: 2e-4
WARMUP_LR: 1e-4
WEIGHT_DECAY: 0.05
OPTIMIZER:
NAME: adamw
LR_SCHEDULER:
NAME: 'multistep'
GAMMA: 0.1
MULTISTEPS: [700,]
PRINT_FREQ: 10
SAVE_FREQ: 1000
VALIDATION_FREQ: 20
TAG: mim_pretrain_675m_2m_128_window8_onecycle_hackathon_adamw
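A small sketch, assuming standard yacs CfgNode behavior, of how this experiment file's TAG combines with the new TENSORBOARD.WRITER_DIR to name the per-run TensorBoard directory, mirroring the f'{tensorboardMainDir}/{config.TAG}' line added in main(); the override value below is hypothetical.

from yacs.config import CfgNode as CN

_C = CN(new_allowed=True)
_C.TENSORBOARD = CN()
_C.TENSORBOARD.WRITER_DIR = '.'

# TAG comes from the experiment YAML; only that field is reproduced here.
_C.merge_from_other_cfg(
    CN.load_cfg('TAG: mim_pretrain_675m_2m_128_window8_onecycle_hackathon_adamw'))

# WRITER_DIR normally arrives via the new --tensorboard-dir flag; this value is made up.
_C.TENSORBOARD.WRITER_DIR = '/tmp/tensorboard'

tensorboard_dir = f'{_C.TENSORBOARD.WRITER_DIR}/{_C.TAG}'
print(tensorboard_dir)
# /tmp/tensorboard/mim_pretrain_675m_2m_128_window8_onecycle_hackathon_adamw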
@@ -27,7 +27,7 @@ TRAIN:
NAME: 'multistep'
GAMMA: 0.1
MULTISTEPS: [700,]
PRINT_FREQ: 100
PRINT_FREQ: 10
SAVE_FREQ: 1000
VALIDATION_FREQ: 200
TAG: mim_pretrain_675m_2m_128_window8_onecycle_hackathon
VALIDATION_FREQ: 20
TAG: mim_pretrain_675m_2m_128_window8_onecycle_hackathon_lamb
9 changes: 5 additions & 4 deletions runs/runners/frontier_svtoa_pretraining_runner.sh
@@ -28,11 +28,11 @@ mkdir -p ${MIOPEN_USER_DB_PATH}
echo "copying torch_env to each node in the job"
conda_env_name='rocm-torch-test-full-0.1.0'

sbcast -pf $MEMBERWORK/geo160/${conda_env_name}.tar.gz /mnt/bb/${USER}/${conda_env_name}.tar.gz
echo $MEMBERWORK/geo160/${conda_env_name}.tar.gz
sbcast -pf /lustre/orion/geo160/proj-shared/envs/${conda_env_name}.tar.gz.hackathon /mnt/bb/${USER}/${conda_env_name}.tar.gz
echo /lustre/orion/geo160/proj-shared/envs/${conda_env_name}.tar.gz.hackathon
echo /mnt/bb/${USER}/${conda_env_name}.tar.gz
ls -l /mnt/bb/${USER}
ls -l $MEMBERWORK/geo160
ls -l /lustre/orion/geo160/proj-shared/envs

if [ ! "$?" == "0" ]; then
# CHECK EXIT CODE. When SBCAST fails, it may leave partial files on the compute nodes, and if you continue to launch srun,
@@ -89,13 +89,14 @@ echo $MASTER_PORT
nnodes=$SLURM_JOB_NUM_NODES
datapaths=/lustre/orion/geo160/proj-shared/data/satvision-toa/50m
validationpath=/lustre/orion/geo160/proj-shared/data/satvision-toa/validation/sv_toa_128_chip_validation_04_24.npy
tensorboard_dir=/lustre/orion/geo160/proj-shared/data/tensorboard/hackathon_2024
batchsize=256
nprocpernode=8

launcher="python -u -m torch.distributed.run --nnodes=${nnodes} --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=${nprocpernode}"
echo $launcher

cmd=" pytorch-caney/pytorch_caney/pipelines/pretraining/mim_deepspeed.py --cfg $1 --dataset MODIS --data-paths ${datapaths} --output . --batch-size ${batchsize} --validation-path ${validationpath}"
cmd=" pytorch-caney/pytorch_caney/pipelines/pretraining/mim_deepspeed.py --cfg $1 --dataset MODIS --data-paths ${datapaths} --output . --batch-size ${batchsize} --validation-path ${validationpath} --tensorboard-dir ${tensorboard_dir}"
echo $cmd

srun -l -c56 --gpus-per-task=${nprocpernode} --gpu-bind=closest --jobid $SLURM_JOBID bash -c "$launcher --node_rank \$SLURM_PROCID $cmd"
