[Bert] Add colotensor Example
feifeibear authored Jun 23, 2022
2 parents 3d902cb + 7e94e3e commit f71dd4f
Showing 7 changed files with 277 additions and 0 deletions.
2 changes: 2 additions & 0 deletions language/bert/colotensor/README.md
@@ -0,0 +1,2 @@
[WIP]

12 changes: 12 additions & 0 deletions language/bert/colotensor/configs/bert_base_tp1d.py
@@ -0,0 +1,12 @@
SEQ_LENGTH = 512
BATCH_SIZE = 8
NUM_EPOCHS = 10
WARMUP_EPOCHS = 1

parallel = dict(
tensor=dict(mode="1d", size=4),
)

model = dict(
type="bert_base",
)
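
Colossal-AI loads this file via `colossalai.launch_from_torch(config=...)` and exposes every top-level name on `gpc.config`, with attribute access on nested dicts. A minimal sketch of that lookup behaviour using `colossalai.context.Config` directly; the standalone construction here is an assumption for illustration, the training script itself only goes through `gpc.config`:

from colossalai.context import Config

# Build the same nested structure this config file defines (illustration only).
cfg = Config(dict(
    SEQ_LENGTH=512,
    BATCH_SIZE=8,
    parallel=dict(tensor=dict(mode="1d", size=4)),
    model=dict(type="bert_base"),
))

assert cfg.parallel.tensor.size == 4   # 1D tensor parallelism across 4 ranks
print(cfg.SEQ_LENGTH, cfg.model.type)  # 512 bert_base
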
20 changes: 20 additions & 0 deletions language/bert/colotensor/dataset/__init__.py
@@ -0,0 +1,20 @@

from language.bert.colotensor.dataset.wikitext import build_data_from_wikitext
from colossalai.core import global_context as gpc

_datasets = {
"wikitext": build_data_from_wikitext,
}

def build_data(**args):
if hasattr(gpc.config, "dataset"):
assert (
gpc.config.dataset in _datasets.keys()
), f"Invalid dataset name. dataset should be in {_datasets.keys()} or use default wikitext"
builder = _datasets[gpc.config.dataset]
else:
builder = _datasets["wikitext"]
return builder(**args)


__all__ = ["build_data"]
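
A minimal usage sketch of `build_data`; the paths are placeholders, and a prior `colossalai.launch*()` call is assumed to have initialized `torch.distributed` and `gpc.config`, both of which the wikitext builder below relies on:

# Hypothetical direct call; paths are placeholders and a launched
# Colossal-AI context (torch.distributed + gpc.config) is assumed.
from language.bert.colotensor.dataset import build_data

train_loader, test_loader = build_data(
    dataset_path="/path/to/wikitext-2",        # dataset saved via datasets.save_to_disk
    tokenizer_path="/path/to/bert-tokenizer",  # directory containing vocab.txt
    seq_len=512,
    batch_size=8,
)

for batch in train_loader:
    print(batch["input_ids"].shape)  # torch.Size([8, 512])
    break
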
73 changes: 73 additions & 0 deletions language/bert/colotensor/dataset/wikitext.py
@@ -0,0 +1,73 @@
import random
import torch
import numpy as np
import copy

from itertools import chain
from datasets import load_from_disk, set_progress_bar_enabled

from torch.utils.data import DataLoader, DistributedSampler
from torch.distributed import get_world_size

from transformers import BertTokenizer, default_data_collator
from colossalai.logging import get_dist_logger


def build_data_from_wikitext(dataset_path: str, tokenizer_path: str, seq_len: int = 512, batch_size: int = 8):
logger = get_dist_logger("build_data_from_wikitext")
logger.info("Building Wikitext-2 ...", ranks=[0])
world_size = get_world_size()

set_progress_bar_enabled(False)
dataset = load_from_disk(dataset_path)

tokenizer = BertTokenizer(vocab_file=tokenizer_path + "/vocab.txt")

def tokenize(examples):
seq_length = seq_len
examples = tokenizer(examples["text"])
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
if total_length >= seq_length:
total_length = (total_length // seq_length) * seq_length

result = {
            k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
for k, t in concatenated_examples.items()
}

return result

tokenized_dataset = dataset.map(
tokenize, batched=True, num_proc=16, load_from_cache_file=False, keep_in_memory=True, remove_columns="text"
)

    def seed_worker(worker_id):  # DataLoader passes the worker index to worker_init_fn
worker_seed = 1024
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)

train_sampler = DistributedSampler(tokenized_dataset["train"], shuffle=True) if world_size > 1 else None
train_data = DataLoader(
tokenized_dataset["train"],
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
collate_fn=default_data_collator,
worker_init_fn=seed_worker,
batch_size=batch_size,
pin_memory=True,
)
test_sampler = DistributedSampler(tokenized_dataset["validation"], shuffle=False) if world_size > 1 else None
test_data = DataLoader(
tokenized_dataset["validation"],
sampler=test_sampler,
drop_last=True,
collate_fn=default_data_collator,
worker_init_fn=seed_worker,
batch_size=batch_size,
pin_memory=True,
)

return train_data, test_data
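
The `tokenize` closure above follows the standard Hugging Face "group texts" recipe: tokenize, concatenate every column, drop the tail that does not fill a whole block, and re-split into fixed-length chunks. A self-contained toy version of that grouping step, using synthetic token ids so no tokenizer or dataset is needed:

from itertools import chain

def group_into_blocks(examples: dict, block_size: int) -> dict:
    # Concatenate each column, then cut it into block_size-length chunks,
    # dropping the remainder that does not fill a complete block.
    concatenated = {k: list(chain(*v)) for k, v in examples.items()}
    total_length = len(next(iter(concatenated.values())))
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    return {
        k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }

toy = {"input_ids": [[101, 7592, 102], [101, 2088, 2003, 2307, 102]]}
print(group_into_blocks(toy, block_size=4))
# {'input_ids': [[101, 7592, 102, 101], [2088, 2003, 2307, 102]]}
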
47 changes: 47 additions & 0 deletions language/bert/colotensor/model/__init__.py
@@ -0,0 +1,47 @@
from language.bert.colotensor.model.hfmodel import ModelFromHF
from colossalai.core import global_context as gpc
from transformers import BertConfig, BertForMaskedLM

_bert_base = dict(
seq_length=512,
vocab_size=50304,
hidden_size=768,
num_heads=12,
depth=12,
ff_size=3072,
checkpoint=False,
evaluation='ppl',
)

_bert_large = dict(
seq_length=512,
vocab_size=50304,
hidden_size=1024,
num_heads=16,
depth=24,
    ff_size=4096,  # BERT-large uses a 4096-d feed-forward layer
checkpoint=False,
evaluation='ppl',
)

_bert_configurations = dict(
bert=_bert_base,
bert_base=_bert_base,
bert_large=_bert_large
)

def build_model():
model_cfg = _bert_configurations[gpc.config.model.type]
bert_cfg = BertConfig(vocab_size=model_cfg['vocab_size'],
hidden_size=model_cfg['hidden_size'],
num_hidden_layers=model_cfg['depth'],
num_attention_heads=model_cfg['num_heads'],
intermediate_size=model_cfg['ff_size'],
max_position_embeddings=model_cfg['seq_length'],
use_cache=not gpc.config.model.get('checkpoint', False))

model = ModelFromHF(bert_cfg, BertForMaskedLM)

return model

__all__ = ["build_model"]
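
For `model.type == "bert_base"`, `build_model` above amounts to constructing the Hugging Face config below and wrapping `BertForMaskedLM` in `ModelFromHF`. A standalone sketch with the `gpc.config` indirection stripped out:

from transformers import BertConfig, BertForMaskedLM

# The values build_model() would pass for _bert_base, written out directly.
bert_cfg = BertConfig(
    vocab_size=50304,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=512,
    use_cache=True,  # checkpoint defaults to False, so caching stays enabled
)
model = BertForMaskedLM(bert_cfg)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
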
17 changes: 17 additions & 0 deletions language/bert/colotensor/model/hfmodel.py
@@ -0,0 +1,17 @@
from colossalai.core import global_context as gpc
import torch

class ModelFromHF(torch.nn.Module):
def __init__(self, config, model_cls):
super().__init__()
self.module = model_cls(config)
if gpc.config.model.get('checkpoint'):
self.module.apply(self.set_checkpointing)

def set_checkpointing(self, module):
if hasattr(module, 'gradient_checkpointing'):
module.gradient_checkpointing = True

def forward(self, *args, **kwargs):
output = self.module(*args, **kwargs)
return output.logits
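
`set_checkpointing` relies on `torch.nn.Module.apply` visiting every submodule and only touches modules that already expose a `gradient_checkpointing` flag (for BERT, the encoder and/or its layers, depending on the transformers version). A small sketch of the same pattern on a bare Hugging Face model, outside of `ModelFromHF`:

import torch
from transformers import BertConfig, BertForMaskedLM

tiny = BertForMaskedLM(BertConfig(num_hidden_layers=2, hidden_size=64,
                                  num_attention_heads=2, intermediate_size=128))

def set_checkpointing(module: torch.nn.Module):
    # Same predicate as ModelFromHF.set_checkpointing: only modules that
    # already define the flag get it switched on.
    if hasattr(module, 'gradient_checkpointing'):
        module.gradient_checkpointing = True

tiny.apply(set_checkpointing)
print(any(getattr(m, 'gradient_checkpointing', False) for m in tiny.modules()))  # True
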
106 changes: 106 additions & 0 deletions language/bert/colotensor/train.py
@@ -0,0 +1,106 @@
import os

import colossalai
import colossalai.utils as utils
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import colo_set_process_memory_fraction, get_current_device, MultiTimer
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.nn._ops import *
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

from language.bert.colotensor.dataset import build_data
from language.bert.colotensor.model import build_model

def calc_local_model_size(model: torch.nn.Module):
numel_per_device = 0
for p in model.parameters():
numel_per_device += p.numel()
return numel_per_device


def main():
parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=True, action='store_true')
args = parser.parse_args()
disable_existing_loggers()
colossalai.launch_from_torch(config=args.config)

logger = get_dist_logger()

logger.info('Build data loader', ranks=[0])
train_dataloader, test_dataloader = build_data(
dataset_path=os.environ["DATA"],
tokenizer_path=os.environ["TOKENIZER"],
seq_len=gpc.config.SEQ_LENGTH,
batch_size=gpc.config.BATCH_SIZE,
)

logger.info('Build model', ranks=[0])
use_zero = hasattr(gpc.config, 'zero')

# TODO(jzy) Add ZERO
if use_zero:
        raise NotImplementedError
else:
with ColoInitContext(device=get_current_device()):
model = build_model()

parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, mode='col')

if use_zero:
        raise NotImplementedError
else:
numel = calc_local_model_size(model)

tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LENGTH \
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

criterion = nn.CrossEntropyLoss()

logger.info('Build optimizer', ranks=[0])
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(model.parameters(), lr=0.001, weight_decay=1e-2)

    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS)

engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
optimizer,
criterion,
train_dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
global_batch_size = gpc.config.BATCH_SIZE * \
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)

hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
hooks.LogMetricByEpochHook(logger),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
hooks.LogMetricByStepHook(),
hooks.LogMemoryByEpochHook(logger),
]

trainer.fit(train_dataloader=train_dataloader,
epochs=gpc.config.NUM_EPOCHS,
test_interval=1,
hooks=hook_list,
display_progress=True,
return_output_label=False,
max_steps=5)


if __name__ == '__main__':
main()

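The `tflop_per_step` value fed to `ThroughputHook` is the usual parameters-times-tokens estimate: `numel` is the local (per tensor-parallel rank) parameter count, so multiplying back by the MODEL and DATA world sizes recovers the global work, and the factor of 8 is consistent with roughly 2 FLOPs per parameter per token for the forward pass plus 4 for the backward pass and 2 for activation recomputation. A back-of-the-envelope sketch for `bert_base` under the 4-GPU 1D tensor-parallel config above (the ~125M parameter count and the even sharding are approximations):

# Rough per-step TFLOP estimate mirroring the expression in train.py.
numel_global = 125e6                    # approximate bert_base parameter count
tp_size, dp_size = 4, 1                 # parallel.tensor.size = 4, single data-parallel group
numel_local = numel_global / tp_size    # roughly what calc_local_model_size returns per rank
batch_size, seq_length = 8, 512         # BATCH_SIZE and SEQ_LENGTH from the config

tflop = numel_local * batch_size * seq_length * tp_size * dp_size * 8 / (1024 ** 4)
print(f"{tflop:.2f} TFLOP per step")    # ~3.7
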