Showing 7 changed files with 277 additions and 0 deletions.
@@ -0,0 +1,2 @@
[WIP]
@@ -0,0 +1,12 @@
SEQ_LENGTH = 512
BATCH_SIZE = 8
NUM_EPOCHS = 10
WARMUP_EPOCHS = 1

parallel = dict(
    tensor=dict(mode="1d", size=4),
)

model = dict(
    type="bert_base",
)
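A minimal sketch of how these module-level values are read back elsewhere in the example: ColossalAI loads the config file at launch and exposes it, with attribute access on nested dicts, through the global context. The config path is a placeholder, and launch_from_torch assumes the process was started under torchrun so the rank environment variables exist.

import colossalai
from colossalai.core import global_context as gpc

colossalai.launch_from_torch(config="./config.py")  # placeholder path to the file above
print(gpc.config.SEQ_LENGTH)            # 512
print(gpc.config.parallel.tensor.size)  # 4: the 1D tensor-parallel world size
print(gpc.config.model.type)            # "bert_base", consumed by build_model()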
@@ -0,0 +1,20 @@
from language.bert.colotensor.dataset.wikitext import build_data_from_wikitext
from colossalai.core import global_context as gpc

_datasets = {
    "wikitext": build_data_from_wikitext,
}


def build_data(**kwargs):
    # Dispatch on the optional `dataset` entry of the config; default to wikitext.
    if hasattr(gpc.config, "dataset"):
        assert (
            gpc.config.dataset in _datasets
        ), f"Invalid dataset name: expected one of {list(_datasets.keys())}, or omit `dataset` to use the default wikitext"
        builder = _datasets[gpc.config.dataset]
    else:
        builder = _datasets["wikitext"]
    return builder(**kwargs)


__all__ = ["build_data"]
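The _datasets dict acts as a small registry keyed by the config's dataset entry. A hedged sketch of how another corpus could be wired in; build_data_from_openwebtext is a hypothetical builder, not part of this commit.

# Hypothetical extension, for illustration only: register a second corpus and
# select it by adding `dataset = "openwebtext"` to the config file.
def build_data_from_openwebtext(**kwargs):
    raise NotImplementedError  # placeholder builder, not part of this commit

_datasets["openwebtext"] = build_data_from_openwebtext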
@@ -0,0 +1,73 @@
import random

import numpy as np
import torch
from itertools import chain
from datasets import load_from_disk, set_progress_bar_enabled
from torch.utils.data import DataLoader, DistributedSampler
from torch.distributed import get_world_size
from transformers import BertTokenizer, default_data_collator
from colossalai.logging import get_dist_logger


def build_data_from_wikitext(dataset_path: str, tokenizer_path: str, seq_len: int = 512, batch_size: int = 8):
    logger = get_dist_logger("build_data_from_wikitext")
    logger.info("Building Wikitext-2 ...", ranks=[0])
    world_size = get_world_size()

    set_progress_bar_enabled(False)
    dataset = load_from_disk(dataset_path)

    tokenizer = BertTokenizer(vocab_file=tokenizer_path + "/vocab.txt")

    def tokenize(examples):
        # Tokenize each batch, concatenate all samples into one long stream,
        # then re-split into fixed seq_len-sized blocks; the ragged tail
        # shorter than seq_len is dropped.
        examples = tokenizer(examples["text"])
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        if total_length >= seq_len:
            total_length = (total_length // seq_len) * seq_len
        result = {
            k: [t[i : i + seq_len] for i in range(0, total_length, seq_len)]
            for k, t in concatenated_examples.items()
        }
        return result

    tokenized_dataset = dataset.map(
        tokenize, batched=True, num_proc=16, load_from_cache_file=False, keep_in_memory=True, remove_columns="text"
    )

    def seed_worker(worker_id):
        # worker_init_fn is called with the worker id; seed every RNG
        # deterministically so workers produce reproducible batches.
        worker_seed = 1024
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    train_sampler = DistributedSampler(tokenized_dataset["train"], shuffle=True) if world_size > 1 else None
    train_data = DataLoader(
        tokenized_dataset["train"],
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        drop_last=True,
        collate_fn=default_data_collator,
        worker_init_fn=seed_worker,
        batch_size=batch_size,
        pin_memory=True,
    )
    test_sampler = DistributedSampler(tokenized_dataset["validation"], shuffle=False) if world_size > 1 else None
    test_data = DataLoader(
        tokenized_dataset["validation"],
        sampler=test_sampler,
        drop_last=True,
        collate_fn=default_data_collator,
        worker_init_fn=seed_worker,
        batch_size=batch_size,
        pin_memory=True,
    )

    return train_data, test_data
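The tokenize closure is the usual group-texts trick: concatenate every tokenized sample and re-split into seq_len-sized blocks. A self-contained toy example of the same arithmetic, with invented token ids:

from itertools import chain

# Toy stand-in for tokenizer output: two "sentences" of token ids.
examples = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9]]}
seq_len = 4

concatenated = {k: list(chain(*v)) for k, v in examples.items()}
total_length = (len(concatenated["input_ids"]) // seq_len) * seq_len  # 9 -> 8
chunks = {
    k: [t[i : i + seq_len] for i in range(0, total_length, seq_len)]
    for k, t in concatenated.items()
}
print(chunks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]; token 9 is dropped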
@@ -0,0 +1,47 @@
from language.bert.colotensor.model.hfmodel import ModelFromHF
from colossalai.core import global_context as gpc
from transformers import BertConfig, BertForMaskedLM

_bert_base = dict(
    seq_length=512,
    vocab_size=50304,
    hidden_size=768,
    num_heads=12,
    depth=12,
    ff_size=3072,
    checkpoint=False,
    evaluation='ppl',
)

_bert_large = dict(
    seq_length=512,
    vocab_size=50304,
    hidden_size=1024,
    num_heads=16,
    depth=24,
    ff_size=4096,  # 4 * hidden_size, the standard feed-forward width for BERT-large
    checkpoint=False,
    evaluation='ppl',
)

_bert_configurations = dict(
    bert=_bert_base,
    bert_base=_bert_base,
    bert_large=_bert_large,
)


def build_model():
    model_cfg = _bert_configurations[gpc.config.model.type]
    # Gradient checkpointing and the HF attention cache are mutually exclusive,
    # so use_cache is disabled whenever checkpointing is requested in the config.
    bert_cfg = BertConfig(vocab_size=model_cfg['vocab_size'],
                          hidden_size=model_cfg['hidden_size'],
                          num_hidden_layers=model_cfg['depth'],
                          num_attention_heads=model_cfg['num_heads'],
                          intermediate_size=model_cfg['ff_size'],
                          max_position_embeddings=model_cfg['seq_length'],
                          use_cache=not gpc.config.model.get('checkpoint', False))

    model = ModelFromHF(bert_cfg, BertForMaskedLM)

    return model


__all__ = ["build_model"]
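Which entry of _bert_configurations is used is controlled entirely by the model dict in the config file; note that the checkpoint flag is read from gpc.config.model, not from the preset dicts. A sketch of switching the example to BERT-large with activation checkpointing:

# In the config file: select BERT-large and enable activation checkpointing.
# `checkpoint` is read both in build_model (to disable use_cache) and in
# ModelFromHF below (to flip gradient_checkpointing on each submodule).
model = dict(
    type="bert_large",
    checkpoint=True,
)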
@@ -0,0 +1,17 @@
import torch
from colossalai.core import global_context as gpc


class ModelFromHF(torch.nn.Module):
    def __init__(self, config, model_cls):
        super().__init__()
        self.module = model_cls(config)
        if gpc.config.model.get('checkpoint'):
            self.module.apply(self.set_checkpointing)

    def set_checkpointing(self, module):
        # Hugging Face submodules that support activation checkpointing expose
        # a `gradient_checkpointing` attribute; enable it recursively.
        if hasattr(module, 'gradient_checkpointing'):
            module.gradient_checkpointing = True

    def forward(self, *args, **kwargs):
        # Return raw logits so a plain nn.CrossEntropyLoss can be used downstream.
        output = self.module(*args, **kwargs)
        return output.logits
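For reference, a hedged sketch of the wrapper in use. Shapes and token ids are invented, and it assumes an already-launched ColossalAI context, since the constructor reads gpc.config.model:

import torch
from transformers import BertConfig, BertForMaskedLM

# Assumes colossalai.launch_from_torch(...) has already populated gpc.config
# with a `model` dict, as in the config file above.
cfg = BertConfig(vocab_size=50304, hidden_size=768)
model = ModelFromHF(cfg, BertForMaskedLM)

input_ids = torch.randint(0, 50304, (2, 16))  # (batch, seq_len), made-up values
logits = model(input_ids=input_ids)           # (2, 16, 50304) raw MLM logits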
@@ -0,0 +1,106 @@
import os

import colossalai
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_current_device, MultiTimer
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.nn._ops import *  # wildcard import registers ColoTensor ops as a side effect
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import ComputePattern, ParallelAction

from language.bert.colotensor.dataset import build_data
from language.bert.colotensor.model import build_model


def calc_local_model_size(model: torch.nn.Module):
    numel_per_device = 0
    for p in model.parameters():
        numel_per_device += p.numel()
    return numel_per_device


def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=True, action='store_true')
    args = parser.parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()

    logger.info('Build data loader', ranks=[0])
    train_dataloader, test_dataloader = build_data(
        dataset_path=os.environ["DATA"],
        tokenizer_path=os.environ["TOKENIZER"],
        seq_len=gpc.config.SEQ_LENGTH,
        batch_size=gpc.config.BATCH_SIZE,
    )

    logger.info('Build model', ranks=[0])
    use_zero = hasattr(gpc.config, 'zero')

    # TODO(jzy) Add ZeRO support.
    if use_zero:
        raise NotImplementedError
    else:
        with ColoInitContext(device=get_current_device()):
            model = build_model()

    # Shard every supported module column-wise (1D tensor parallelism).
    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode='col')

    if use_zero:
        raise NotImplementedError
    else:
        numel = calc_local_model_size(model)

    # Rough per-step TFLOP estimate: ~8 FLOPs per parameter per token,
    # aggregated over the model- and data-parallel groups.
    tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LENGTH \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

    criterion = nn.CrossEntropyLoss()

    logger.info('Build optimizer', ranks=[0])
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-2)

    # The scheduler is stepped once per epoch (see LRSchedulerHook below),
    # with the warmup length taken from the config.
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS,
                                  warmup_steps=gpc.config.WARMUP_EPOCHS)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
                                                                      optimizer,
                                                                      criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)
    global_batch_size = gpc.config.BATCH_SIZE * \
        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
    ]

    # Note: max_steps=5 caps the run at five steps, useful for smoke testing;
    # remove it for a full training run.
    trainer.fit(train_dataloader=train_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                test_interval=1,
                hooks=hook_list,
                display_progress=True,
                return_output_label=False,
                max_steps=5)


if __name__ == '__main__':
    main()
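As a usage note (paths and the script name are placeholders): the script reads the preprocessed dataset and tokenizer locations from the DATA and TOKENIZER environment variables, and since it calls launch_from_torch it is meant to be started under torchrun, e.g. DATA=/path/to/wikitext TOKENIZER=/path/to/bert-tokenizer torchrun --nproc_per_node 4 train.py --config config.py, with four processes to match the tensor-parallel size of 4 in the config above.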