[WIP] Used per-parameter FSDP
awgu committed Feb 27, 2024
1 parent 5dec536 commit 56f1ae0
Showing 4 changed files with 74 additions and 51 deletions.
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -14,7 +14,7 @@ NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0}


CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/test_llama7b.toml"}

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
50 changes: 17 additions & 33 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
import logging

import torch
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import (
distribute_module,
distribute_tensor,
@@ -19,13 +20,6 @@
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -153,32 +147,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-        }
-
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
-                # before wrapping with FSDP, we need to make sure the layer is on GPU
-                transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model.cuda())
-
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            transformer_block = checkpoint_wrapper(transformer_block, job_config)
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = layer_id < len(model.layers) - 1
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward
+            )
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
rank0_log("Applied FSDP to the model...")

# redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
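The per-parameter FSDP calls above replace the old wrapper-based API. As a rough illustration (not part of this commit), the sketch below applies the same fully_shard / MixedPrecisionPolicy pattern to a toy model; the device-mesh setup, the nn.Sequential model, and the torchrun-style environment variables are assumptions made for the example, while the FSDP arguments mirror the diff.

import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


def shard_toy_model() -> nn.Module:
    # Assumes a torchrun-style environment (WORLD_SIZE/RANK set, NCCL available).
    world_size = int(os.environ["WORLD_SIZE"])
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    blocks = [nn.Linear(1024, 1024) for _ in range(4)]
    model = nn.Sequential(*blocks)
    for layer_id, block in enumerate(blocks):
        # Keep the last block's parameters gathered after forward, since the
        # backward pass would all-gather them again immediately anyway.
        fully_shard(
            block,
            mesh=dp_mesh,
            mp_policy=mp_policy,
            reshard_after_forward=layer_id < len(blocks) - 1,
        )
    # The root call groups any parameters not covered by the per-block calls.
    return fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)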
34 changes: 17 additions & 17 deletions train.py
@@ -11,7 +11,6 @@
# torch imports
import torch
import torch.nn.functional as F
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtrain.checkpoint import CheckpointManager, IntervalType
@@ -56,9 +55,9 @@ def build_optimizer(model, job_config: JobConfig):
name = job_config.optimizer.name
lr = job_config.optimizer.lr
if name == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr, foreach=True)
elif name == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, foreach=True)
else:
raise NotImplementedError(f"optimizer {name} not added")

@@ -67,13 +66,14 @@ def build_optimizer(model, job_config: JobConfig):

def build_grad_scaler(model):
# apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
-        enable_grad_scaling = True
-        rank0_log("Enabling gradient scaling for mixed precision training.")
-    else:
-        enable_grad_scaling = False
-        rank0_log("Gradient scaling not enabled.")
-
+    # TODO: We do not expose the mixed precision attribute. This is low
+    # priority since we do not use fp16.
+    # if model.mixed_precision.param_dtype == torch.float16:
+    #     enable_grad_scaling = True
+    #     rank0_log("Enabling gradient scaling for mixed precision training.")
+    # else:
+    enable_grad_scaling = False
+    rank0_log("Gradient scaling not enabled.")
return ShardedGradScaler(enabled=enable_grad_scaling)


@@ -130,9 +130,6 @@ def main(job_config: JobConfig):
model, world_mesh, parallel_dims, job_config
)

-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
# build optimizer after apply parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)
@@ -144,9 +141,7 @@
# torch.compile model for improved performance
if job_config.training.compile:
rank0_log(f"Compiling model {model_name} with torch.compile...")
-        model = torch.compile(
-            model,
-        )
+        model = torch.compile(model)

train_state = TrainState()

@@ -208,7 +203,12 @@ def main(job_config: JobConfig):

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
-            model.clip_grad_norm_(job_config.training.max_norm)
+            # TODO: Disable `clip_grad_norm_()` until it is supported:
+            # https://github.com/pytorch/pytorch/pull/120238
+            # torch.nn.utils.clip_grad_norm_(
+            # model.parameters(), job_config.training.max_norm
+            # )
+            # model.clip_grad_norm_(job_config.training.max_norm)

# optimizer step
# If gradients don't contain infs/NaNs, optimizer.step() is then called;
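For reference, a minimal sketch of the optimizer construction pattern used above with foreach=True, which asks PyTorch to batch the per-parameter update math into multi-tensor (foreach) kernels; the toy model and default arguments here are assumptions for illustration only.

import torch
import torch.nn as nn


def make_optimizer(model: nn.Module, name: str = "AdamW", lr: float = 8e-4):
    # Mirrors build_optimizer above: foreach=True selects the multi-tensor
    # implementation of the update step.
    if name == "Adam":
        return torch.optim.Adam(model.parameters(), lr=lr, foreach=True)
    if name == "AdamW":
        return torch.optim.AdamW(model.parameters(), lr=lr, foreach=True)
    raise NotImplementedError(f"optimizer {name} not added")


optimizer = make_optimizer(nn.Linear(16, 16))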
39 changes: 39 additions & 0 deletions train_configs/test_llama7b.toml
@@ -0,0 +1,39 @@
# TorchTrain Config.toml
[job]
dump_folder = "./outputs"

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10

[metrics]
enable_tensorboard = true
save_tb_folder = "tb"
log_freq = 10

[model]
name = "llama"
flavor = "7B"
tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 8e-4


[training]
batch_size = 8
seq_len = 2048
warmup_pct = 0.20 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
sequence_parallel_degree = 1
pipeline_parallel_degree = 1
compile = false
checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
