Showing 2 changed files with 148 additions and 4 deletions.
import colossalai
import psutil
import torch
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from transformers import GPT2Config, GPT2LMHeadModel
from time import time
from functools import partial
from colossalai.tensor import ChunkManager
from colossalai.gemini import GeminiManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import get_current_device
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer


class GPTLMModel(nn.Module):
    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
                                                n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
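        # use_cache is turned off when gradient checkpointing is enabled, since
        # Hugging Face transformers cannot combine KV caching with activation
        # recomputation.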
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
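        # Position i's logits score token i+1, so dropping the last logit and the
        # first label lines the two tensors up for next-token prediction; in the
        # training loop below, input_ids double as the labels.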
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_tflops(model_numel, batch_size, seq_len, step_time):
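    # Rough throughput estimate: ~2 FLOPs per parameter per token for the forward
    # pass, ~4 for the backward pass, plus ~2 for the extra forward that gradient
    # checkpointing recomputes, hence the factor of 8. For the ~0.35B-parameter
    # gpt2_medium at batch size 8 and sequence length 1024 this works out to
    # roughly 23 TFLOPs of work per step.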
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def main():
    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10
    PLACEMENT_POLICY = 'cpu'
    disable_existing_loggers()
    colossalai.launch_from_torch(config={})
    logger = get_dist_logger()

    logger.info(get_mem_info(), ranks=[0])
    # build GPT model
    with ColoInitContext(device=get_current_device()):
        model = gpt2_medium(checkpoint=True)
    numel = sum([p.numel() for p in model.parameters()])
    logger.info(f'Model numel: {numel}', ranks=[0])
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
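    # Gemini setup: parameters are grouped into fixed-size chunks (the search call
    # below picks the chunk size; the exact meaning of its budget arguments is
    # assumed here), enable_distributed_storage spreads chunk storage across ranks,
    # and PLACEMENT_POLICY decides where chunks sit between uses ('cpu' keeps them
    # in host memory); ZeroDDP then shards the model over these chunks, ZeRO-3 style.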
    chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
                                 init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
    gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
    logger.info(chunk_manager, ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
        optimizer.backward(loss)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
        optimizer.step()
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
        step_time = time() - start
        logger.info(
            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])


if __name__ == '__main__':
    main()
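Because the script calls colossalai.launch_from_torch(config={}), it is meant to be started through the PyTorch distributed launcher rather than run directly; a launch along the lines of torchrun --nproc_per_node=<num_gpus> <script>.py (the file name is not shown above) supplies the rank and world-size environment variables that launch_from_torch reads.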