This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

[hotfix] fit to the refactored pipeline API (#136)
* change to fit the refactored schedule

* [hotfix] fit to the refactored pipeline API
YuliangLiu0306 authored Jun 13, 2022
1 parent 8946b49 commit d1ce233
Showing 7 changed files with 47 additions and 62 deletions.
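
Note: the core change is identical in every file touched by this commit: `PipelinableContext` has moved out of `colossalai.utils.model` into the refactored `colossalai.pipeline` package, so only the import path changes. The larger diffs below additionally drop some manual schedule setup that the refactored API appears to handle and apply formatting cleanups. A minimal before/after sketch:

```python
# Old location (removed by this commit):
# from colossalai.utils.model.pipelinable import PipelinableContext

# New location after the pipeline refactor:
from colossalai.pipeline.pipelinable import PipelinableContext
```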
2 changes: 1 addition & 1 deletion features/pipeline_parallel/resnet.py
@@ -10,7 +10,7 @@
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext

from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
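
For orientation, here is a rough sketch of how the updated import is used in these examples, mirroring the `PipelinableContext` flow that appears in the mlpmixer diff further down (trace the model, flatten it to a layer list, choose a partition policy, partition by pipeline rank). The ResNet-50/CIFAR-10 pairing follows this file's imports, but the snippet is only an illustration and assumes ColossalAI has already been launched with a pipeline-parallel config.

```python
# Sketch only: assumes colossalai.launch*(...) has already set up a
# pipeline-parallel context (gpc.pipeline_parallel_size > 1).
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.pipeline.pipelinable import PipelinableContext
from torchvision.models import resnet50

pipelinable = PipelinableContext()
with pipelinable:                     # trace model construction inside the context
    model = resnet50(num_classes=10)  # CIFAR-10 head; an assumption for this sketch
pipelinable.to_layer_list()           # flatten traced modules into an ordered layer list
pipelinable.load_policy("uniform")    # split layers evenly across pipeline stages
model = pipelinable.partition(1,      # one chunk per stage (no interleaving)
                              gpc.pipeline_parallel_size,
                              gpc.get_local_rank(ParallelMode.PIPELINE))
```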
35 changes: 17 additions & 18 deletions image/mlpmixer/train_pipline.py
@@ -10,8 +10,7 @@
import torch
import torch.nn as nn
from colossalai.builder import build_pipeline_model
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
@@ -27,24 +26,21 @@
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, is_using_pp

from colossalai.utils.model.pipelinable import PipelinableContext

from colossalai.pipeline.pipelinable import PipelinableContext


class MlpBlock(nn.Module):

def __init__(self, hidden_dim, mlp_dim):
super(MlpBlock, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, hidden_dim)
)
self.mlp = nn.Sequential(nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, hidden_dim))

def forward(self, x):
return self.mlp(x)


class MixerBlock(nn.Module):

def __init__(self, num_tokens, hidden_dim, tokens_mlp_dim, channels_mlp_dim):
super(MixerBlock, self).__init__()
self.ln_token = nn.LayerNorm(hidden_dim)
@@ -84,41 +80,46 @@ def forward(self, x):


class data_flatten(nn.Module):

def __init__(self):
super().__init__()

def forward(self, x):
return x.flatten(2).transpose(1, 2)


class data_mean(nn.Module):

def __init__(self):
super().__init__()

def forward(self, x):
return x.mean(dim=1)


def MlpMixer(num_classes, num_blocks, patch_size, hidden_dim, tokens_mlp_dim, channels_mlp_dim, image_size=224):
num_tokens = (image_size // patch_size) ** 2
num_tokens = (image_size // patch_size)**2
patch_emb = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size, bias=False)

mlp = nn.Sequential(
*[MixerBlock(num_tokens, hidden_dim, tokens_mlp_dim, channels_mlp_dim) for _ in range(num_blocks)])
ln = nn.LayerNorm(hidden_dim)
fc = nn.Linear(hidden_dim, num_classes)

return nn.Sequential(patch_emb,data_flatten(),mlp,ln, data_mean(), fc)
return nn.Sequential(patch_emb, data_flatten(), mlp, ln, data_mean(), fc)


# def data_flatten(x):
# return x.flatten(2).transpose(1, 2)
#
# def data_mean(x):
# return x.mean(dim=1)


# def mixer_s32(num_classes=10, image_size=32, patch_size=4,**kwargs):
# return MlpMixer(num_classes, 8, patch_size, 512, 256, 2048, image_size)

def mixer_s32(num_classes=1000, image_size=224, patch_size=32,**kwargs):

def mixer_s32(num_classes=1000, image_size=224, patch_size=32, **kwargs):
return MlpMixer(num_classes, 8, patch_size, 512, 256, 2048, image_size, **kwargs)


@@ -153,7 +154,6 @@ def build_cifar(batch_size):
patch_size = 4



def train():
# initialize distributed setting
parser = colossalai.get_default_parser()
@@ -185,13 +185,12 @@ def train():
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = mixer_s32(num_classes,image_size,patch_size)
model = mixer_s32(num_classes, image_size, patch_size)
pipelinable.to_layer_list()
pipelinable.load_policy("uniform")
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = mixer_s32(num_classes,image_size,patch_size)

model = mixer_s32(num_classes, image_size, patch_size)

# count number of parameters
total_numel = 0
@@ -204,7 +203,7 @@ def train():
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")

    # create dataloaders

train_dataloader, test_dataloader = build_cifar(BATCH_SIZE)

# create loss function
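
As a quick sanity check of the reformatted Mixer definition above, the sketch below builds the CIFAR-sized model this script trains (hyperparameters taken from the commented-out `mixer_s32` variant and the `patch_size = 4` setting in the diff) and runs a dummy forward pass. It assumes the `MlpMixer` and `MixerBlock` definitions above are in scope; since `MixerBlock.forward` is collapsed in the diff, the expected output shape assumes the usual token-mixing/channel-mixing residual forward.

```python
import torch

# Assumes MlpMixer (and its MixerBlock/MlpBlock dependencies) from the diff
# above are importable or defined in the current module.
model = MlpMixer(num_classes=10, num_blocks=8, patch_size=4,
                 hidden_dim=512, tokens_mlp_dim=256, channels_mlp_dim=2048,
                 image_size=32)

x = torch.randn(2, 3, 32, 32)   # dummy CIFAR-10 sized batch
logits = model(x)
print(logits.shape)             # expected: torch.Size([2, 10])
```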
@@ -9,7 +9,7 @@
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm

@@ -92,9 +92,6 @@ def main():

logger.info("Engine is built", ranks=[0])

# create schedule
schedule = None
tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
if gpc.is_initialized(ParallelMode.PARALLEL_1D):
scatter_gather = True
else:
@@ -14,7 +14,7 @@
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp
from titans.dataloader.imagenet import build_dali_imagenet
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm

@@ -14,7 +14,7 @@
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, is_using_pp
from titans.dataloader.imagenet import build_dali_imagenet
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model


2 changes: 1 addition & 1 deletion language/DeepNet/train_deepnet_decoder.py
@@ -9,7 +9,7 @@
from colossalai.trainer import hooks, Trainer
from colossalai import nn as col_nn
from colossalai.nn import LinearWarmupLR
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
import torch.nn as nn
from dataset.webtext import WebtextDataset
import contextlib
61 changes: 25 additions & 36 deletions language/gpt/train_gpt.py
@@ -8,15 +8,14 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai import nn as col_nn
from colossalai.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule)
from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import is_using_pp, colo_set_process_memory_fraction
from colossalai.utils.timer import MultiTimer
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.loss.lm_loss import GPTLMLoss

from dataset.webtext import WebtextDataset
@@ -37,10 +36,7 @@ def main():
if args.from_torch:
colossalai.launch_from_torch(config=args.config)
else:
colossalai.launch_from_slurm(config=args.config,
host=args.host,
port=29500,
seed=42)
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)

logger = get_dist_logger()

@@ -63,23 +59,24 @@ def main():
ctx = contextlib.nullcontext()
if use_zero3:
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True
)
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True)
with ctx:
model = gpc.config.model.pop('type')(**gpc.config.model)
else:
pipelinable = PipelinableContext()
with pipelinable:
model = gpc.config.model.pop('type')(**gpc.config.model)

def mask_function(attention_mask=None):
if attention_mask is not None:
batch_size = gpc.config.BATCH_SIZE//gpc.config.NUM_MICRO_BATCHES
batch_size = gpc.config.BATCH_SIZE // gpc.config.NUM_MICRO_BATCHES
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0
return attention_mask

# GPT2_small exec_seq
# (lyl)TODO: The exec_seq for gpt3 will be added here and to_layer_list should be more friendly to use.
exec_seq = ['embed', mask_function, 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4', 'blocks.5', (mask_function, "front"), \
@@ -89,11 +86,11 @@ def mask_function(attention_mask=None):
# (lyl)TODO: Zero context and pipelinable context should be integrated into one context.
if use_zero3:
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True
)
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True)
with ctx:
model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size,
gpc.get_local_rank(ParallelMode.PIPELINE))

if use_zero3:
numel = ctx.model_numel_tensor.item()
@@ -110,11 +107,9 @@ def mask_function(attention_mask=None):
criterion = GPTLMLoss()

logger.info('Build optimizer', ranks=[0])
optimizer = gpc.config.optimizer.pop('type')(
model.parameters(), **gpc.config.optimizer)
optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)

lr_scheduler = LinearWarmupLR(
optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)

engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
optimizer,
@@ -127,33 +122,27 @@

timier = MultiTimer()

trainer = Trainer(
engine=engine,
logger=logger,
timer=timier
)
trainer = Trainer(engine=engine, logger=logger, timer=timier)

hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
hooks.LogMetricByEpochHook(logger),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
hooks.LogMetricByStepHook(),
# hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
# hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
hooks.LogMemoryByEpochHook(logger),
# hooks.LogTimingByEpochHook(timer, logger),
# hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
# hooks.LogTimingByEpochHook(timer, logger),
# hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]

trainer.fit(
train_dataloader=train_dataloader,
epochs=gpc.config.NUM_EPOCHS,
test_interval=1,
hooks=hook_list,
display_progress=True,
return_output_label=False,
max_steps=5
)
trainer.fit(train_dataloader=train_dataloader,
epochs=gpc.config.NUM_EPOCHS,
test_interval=1,
hooks=hook_list,
display_progress=True,
return_output_label=False,
max_steps=5)


if __name__ == '__main__':
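
For reference, the additive attention-mask arithmetic inside `mask_function` above can be reproduced with plain PyTorch. The sketch below omits the ColossalAI-specific `col_nn.partition_batch` call and the micro-batch reshaping, so it only illustrates the masking math, not the tensor-parallel behavior; the helper name is hypothetical.

```python
import torch

def additive_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Turn a 0/1 padding mask of shape (batch, seq_len) into the additive form
    used by attention scores: 0.0 where attending, -10000.0 where masked."""
    # (batch, seq_len) -> (batch, 1, 1, seq_len), broadcastable over heads and queries
    mask = attention_mask.unsqueeze(1).unsqueeze(2).float()
    return (1.0 - mask) * -10000.0

# Example: mask out the last two positions of each sequence.
mask = torch.tensor([[1, 1, 1, 0, 0]])
print(additive_attention_mask(mask))  # zeros for kept tokens, -10000.0 for masked ones
```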
