
Commit

Merge pull request #56 from nasa-nccs-hpda/develop
Develop
cssprad1 authored Jul 12, 2024
2 parents 3eba08a + 5b1cba9 commit d3f2e2f
Showing 10 changed files with 589 additions and 51 deletions.
@@ -0,0 +1,33 @@
MODEL:
  TYPE: swinv2
  DECODER: unet
  NAME: satvision_finetune_lc5class
  DROP_PATH_RATE: 0.1
  NUM_CLASSES: 5
  SWINV2:
    IN_CHANS: 7
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 14
    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
DATA:
  IMG_SIZE: 224
  DATASET: MODISLC5
  MASK_PATCH_SIZE: 32
  MASK_RATIO: 0.6
LOSS:
  NAME: 'tversky'
  MODE: 'multiclass'
  ALPHA: 0.4
  BETA: 0.6
TRAIN:
  EPOCHS: 100
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.01
  LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep
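The LOSS block selects a multiclass Tversky loss with ALPHA: 0.4 weighting false positives and BETA: 0.6 weighting false negatives, so missed land-cover pixels are penalized slightly more than spurious ones. A minimal sketch of how those fields could map onto a loss object, here via segmentation_models_pytorch; whether pytorch-caney's loss factory actually wraps that package is an assumption:

import torch
from segmentation_models_pytorch.losses import TverskyLoss

# LOSS.MODE / LOSS.ALPHA / LOSS.BETA from the config above
criterion = TverskyLoss(mode='multiclass', alpha=0.4, beta=0.6)

logits = torch.randn(4, 5, 224, 224)          # (batch, NUM_CLASSES, H, W)
targets = torch.randint(0, 5, (4, 224, 224))  # integer class labels per pixel
loss = criterion(logits, targets)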
@@ -0,0 +1,33 @@
MODEL:
  TYPE: swinv2
  DECODER: unet
  NAME: satvision_toa_finetune_lc9class
  DROP_PATH_RATE: 0.1
  NUM_CLASSES: 9
  SWINV2:
    IN_CHANS: 14
    EMBED_DIM: 352
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    WINDOW_SIZE: 14
    NORM_PERIOD: 6
DATA:
  IMG_SIZE: 224
  DATASET: MODISLC9
  MASK_PATCH_SIZE: 8
  MASK_RATIO: 0.6
LOSS:
  NAME: 'tversky'
  MODE: 'multiclass'
  ALPHA: 0.4
  BETA: 0.6
TRAIN:
  EPOCHS: 100
  WARMUP_EPOCHS: 10
  BASE_LR: 1e-4
  WARMUP_LR: 5e-7
  WEIGHT_DECAY: 0.01
  LAYER_DECAY: 0.8
PRINT_FREQ: 100
SAVE_FREQ: 5
TAG: satvision_toa_finetune_land_cover_9class_swinv2_satvision_224_window12__100ep
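Both YAML files are overrides layered onto the defaults defined in pytorch_caney/config.py (the _C.TRAIN.* entries further down follow the yacs CfgNode style). A sketch of that pattern under the assumption that the repo uses the standard yacs workflow; the actual entry point may be named differently:

from yacs.config import CfgNode as CN

# Minimal stand-in for the defaults in pytorch_caney/config.py.
_C = CN()
_C.TRAIN = CN()
_C.TRAIN.ACCUMULATION_STEPS = 1

def load_config(yaml_path):
    # Every key in the YAML must already exist in the defaults node,
    # otherwise merge_from_file() raises an error.
    config = _C.clone()
    config.merge_from_file(yaml_path)
    config.freeze()
    return config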
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc5
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1


export PYTHONPATH=$PWD:../../../:../../../pytorch-caney
export NGPUS=8

torchrun --nproc_per_node $NGPUS \
../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \
--pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \
--dataset MODISLC9 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \
--batch-size 4 \
--output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \
--enable-amp
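The torchrun launch starts one worker per GPU, so the fine-tuning entry point is expected to read its rank from the environment and join a process group (dist.get_rank() appears later in the finetune.py diff). A generic initialization sketch, not the repo's exact code:

import os
import torch
import torch.distributed as dist

def init_distributed():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank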
@@ -0,0 +1,19 @@
#!/bin/bash

#SBATCH -J finetune_satvision_lc9
#SBATCH -t 3-00:00:00
#SBATCH -G 4
#SBATCH -N 1

export PYTHONPATH=$PWD:$PWD/pytorch-caney
export NGPUS=4

torchrun --nproc_per_node $NGPUS \
pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \
--cfg $1 \
--pretrained $2 \
--dataset MODISLC9 \
--data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \
--batch-size 4 \
--output . \
--enable-amp
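Both launch scripts drive pytorch_caney/pipelines/finetuning/finetune.py through the same flags; this second script takes the config and pretrained checkpoint as positional arguments ($1 and $2) at submission time. A minimal argparse sketch consistent with those invocations; the real parser may declare more options or different defaults:

import argparse

def build_parser():
    # Flags mirrored from the torchrun calls above; defaults are guesses.
    parser = argparse.ArgumentParser(description='SatVision fine-tuning')
    parser.add_argument('--cfg', required=True, help='path to the YAML config')
    parser.add_argument('--pretrained', required=True, help='pretrained checkpoint (.pth)')
    parser.add_argument('--dataset', default='MODISLC9', help='e.g. MODISLC5 or MODISLC9')
    parser.add_argument('--data-paths', nargs='+', help='training data directories')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--output', default='.', help='directory for checkpoints and logs')
    parser.add_argument('--enable-amp', action='store_true', help='enable mixed precision')
    return parser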
2 changes: 1 addition & 1 deletion pytorch_caney/config.py
@@ -101,7 +101,7 @@
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
_C.TRAIN.ACCUMULATION_STEPS = 1
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False
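The default for TRAIN.ACCUMULATION_STEPS moves from 0 to 1 because the reworked loop in finetune.py below always divides the loss by this value and only steps the optimizer when (idx + 1) % ACCUMULATION_STEPS == 0; a value of 0 breaks both operations. A stripped-down sketch of that accumulation pattern, not the repo's exact loop:

def train_steps(model, dataloader, criterion, optimizer, accumulation_steps=1):
    # accumulation_steps must be >= 1: the loss is always scaled by it and
    # the optimizer only steps every accumulation_steps batches.
    for idx, (samples, targets) in enumerate(dataloader):
        loss = criterion(model(samples), targets) / accumulation_steps
        loss.backward()
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()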
5 changes: 2 additions & 3 deletions pytorch_caney/data/datasets/modis_lc_nine_dataset.py
@@ -53,9 +53,8 @@ def __len__(self):
def __getitem__(self, idx, transpose=True):

# load image
img = np.load(self.img_list[idx])

img = np.clip(img, 0, 1.0)
img = np.random.rand(224, 224, 14)
# img = np.load(self.img_list[idx])

# load mask
mask = np.load(self.mask_list[idx])
2 changes: 1 addition & 1 deletion pytorch_caney/models/unet_swin_model.py
@@ -13,7 +13,7 @@ class unet_swin(nn.Module):
decoder.
"""

FEATURE_CHANNELS: Tuple[int] = (3, 256, 512, 1024, 1024)
FEATURE_CHANNELS: Tuple[int] = (3, 704, 1408, 2816, 2816)
DECODE_CHANNELS: Tuple[int] = (512, 256, 128, 64)
IN_CHANNELS: int = 64
N_BLOCKS: int = 4
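The new FEATURE_CHANNELS tuple tracks the wider backbone from the second config: SwinV2 stage i outputs EMBED_DIM * 2**i channels, so EMBED_DIM: 128 gives stage widths (128, 256, 512, 1024) and EMBED_DIM: 352 gives (352, 704, 1408, 2816); both tuples follow the same (input, 2C, 4C, 8C, 8C) pattern. A quick arithmetic check:

def feature_channels(embed_dim, in_chans=3):
    # (input, 2C, 4C, 8C, 8C) pattern used by unet_swin's FEATURE_CHANNELS
    return (in_chans, embed_dim * 2, embed_dim * 4, embed_dim * 8, embed_dim * 8)

assert feature_channels(128) == (3, 256, 512, 1024, 1024)   # previous value
assert feature_channels(352) == (3, 704, 1408, 2816, 2816)  # new value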
64 changes: 23 additions & 41 deletions pytorch_caney/pipelines/finetuning/finetune.py
@@ -133,6 +133,10 @@ def train(config,
criterion: loss function to use for fine-tuning
"""

loss = validate(config, model, dataloader_val, criterion)

logger.info(f'Model validation loss: {loss:.3f}%')

logger.info("Start fine-tuning")

start_time = time.time()
@@ -204,45 +208,29 @@ def execute_one_epoch(config,
samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)

samples = samples.to(torch.bfloat16)

with amp.autocast(enabled=config.ENABLE_AMP):
logits = model(samples)

if config.TRAIN.ACCUMULATION_STEPS > 1:
loss = criterion(logits, targets)
loss = loss / config.TRAIN.ACCUMULATION_STEPS
scaler.scale(loss).backward()
if config.TRAIN.CLIP_GRAD:
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
scaler.step(optimizer)
optimizer.zero_grad()
scaler.update()
lr_scheduler.step_update(epoch * num_steps + idx)
else:
loss = criterion(logits, targets)
loss = criterion(logits, targets)
loss = loss / config.TRAIN.ACCUMULATION_STEPS

scaler.scale(loss).backward()

grad_norm = get_grad_norm(model.parameters())

if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
optimizer.zero_grad()
scaler.scale(loss).backward()
if config.TRAIN.CLIP_GRAD:
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
config.TRAIN.CLIP_GRAD)
else:
grad_norm = get_grad_norm(model.parameters())
scaler.step(optimizer)
scaler.update()
lr_scheduler.step_update(epoch * num_steps + idx)
lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)

loss_scale_value = scaler.state_dict()["scale"]

torch.cuda.synchronize()

loss_meter.update(loss.item(), targets.size(0))
norm_meter.update(grad_norm)
loss_scale_meter.update(scaler.get_scale())
loss_scale_meter.update(loss_scale_value)
batch_time.update(time.time() - end)
end = time.time()

@@ -298,11 +286,14 @@ def validate(config, model, dataloader, criterion):

target = target.cuda(non_blocking=True)

images = images.to(torch.bfloat16)

# compute output
output = model(images)
with amp.autocast(enabled=config.ENABLE_AMP):
output = model(images)

# measure accuracy and record loss
loss = criterion(output, target.long())
loss = criterion(output, target)

loss = reduce_tensor(loss)
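reduce_tensor presumably averages the per-rank validation loss across the distributed workers before it is logged; the conventional implementation looks like the sketch below, though the repo's version may differ:

import torch.distributed as dist

def reduce_tensor(tensor):
    # Average a scalar tensor across all ranks in the process group.
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= dist.get_world_size()
    return reduced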

@@ -429,15 +420,6 @@ def setup_seeding(config):

cudnn.benchmark = True

linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \
setup_scaled_lr(config)

config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
config.freeze()

os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT,
dist_rank=dist.get_rank(),
0 comments on commit d3f2e2f