Skip to content
This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

Permalink
[zero] update zero example
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear authored Jun 23, 2022
2 parents b4cfda7 + 860f555 commit a7eae54
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 4 deletions.
21 changes: 17 additions & 4 deletions features/zero/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from transformers import GPT2Config, GPT2LMHeadModel
from time import time
from functools import partial


class GPTLMModel(nn.Module):
Expand Down Expand Up @@ -68,6 +70,10 @@ def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def main():
BATCH_SIZE = 8
SEQ_LEN = 1024
Expand All @@ -80,8 +86,11 @@ def main():
logger.info(get_mem_info(), ranks=[0])
# build GPT model
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True):
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True) as ctx:
model = gpt2_medium(checkpoint=True)
numel = ctx.model_numel_tensor.item()
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
# Set tensor_placement_policy='cpu', which will offload params, grads and os
model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
Expand All @@ -99,13 +108,17 @@ def main():
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'Forward [{n+1}/{NUM_STEPS}] '), ranks=[0])
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
logger.info(get_mem_info(prefix=f'Backward [{n+1}/{NUM_STEPS}] '), ranks=[0])
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'Optimizer step [{n+1}/{NUM_STEPS}] '), ranks=[0])
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])


if __name__ == '__main__':
Expand Down
131 changes: 131 additions & 0 deletions features/zero/train_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@

import colossalai
import psutil
import torch
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from transformers import GPT2Config, GPT2LMHeadModel
from time import time
from functools import partial
from colossalai.tensor import ChunkManager
from colossalai.gemini import GeminiManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import get_current_device
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer


class GPTLMModel(nn.Module):
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size))
if checkpoint:
self.model.gradient_checkpointing_enable()

def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()

def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask


def gpt2_medium(checkpoint=False):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def main():
BATCH_SIZE = 8
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
PLACEMENT_POLICY = 'cpu'
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()

logger.info(get_mem_info(), ranks=[0])
# build GPT model
with ColoInitContext(device=get_current_device()):
model = gpt2_medium(checkpoint=True)
numel = sum([p.numel() for p in model.parameters()])
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
logger.info(chunk_manager, ranks=[0])

# build criterion
criterion = GPTLMLoss()

# optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

model.train()
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])


if __name__ == '__main__':
main()

0 comments on commit a7eae54

Please sign in to comment.