From 43371f5bad504e931fca9365db35734085167f3d Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 16 Jun 2022 16:31:19 +0800
Subject: [PATCH 1/4] update zero example

---
 features/zero/train.py    |  20 ++++--
 features/zero/train_v2.py | 130 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+), 4 deletions(-)
 create mode 100644 features/zero/train_v2.py

diff --git a/features/zero/train.py b/features/zero/train.py
index ca208f7..60adadc 100644
--- a/features/zero/train.py
+++ b/features/zero/train.py
@@ -10,6 +10,8 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from transformers import GPT2Config, GPT2LMHeadModel
+from time import time
+from functools import partial
 
 
 class GPTLMModel(nn.Module):
@@ -68,6 +70,10 @@ def get_mem_info(prefix=''):
     return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
 
 
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
 def main():
     BATCH_SIZE = 8
     SEQ_LEN = 1024
@@ -80,8 +86,11 @@ def main():
     logger.info(get_mem_info(), ranks=[0])
     # build GPT model
     shard_strategy = TensorShardStrategy()
-    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True):
+    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True) as ctx:
         model = gpt2_medium(checkpoint=True)
+    numel = ctx.model_numel_tensor.item()
+    logger.info(f'Model numel: {numel}')
+    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
     # Set tensor_placement_policy='cpu', which will offload params, grads and os
     model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
@@ -99,13 +108,16 @@ def main():
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
         optimizer.zero_grad()
+        start = time()
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
-        logger.info(get_mem_info(prefix=f'Forward [{n+1}/{NUM_STEPS}] '), ranks=[0])
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
         optimizer.backward(loss)
-        logger.info(get_mem_info(prefix=f'Backward [{n+1}/{NUM_STEPS}] '), ranks=[0])
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
         optimizer.step()
-        logger.info(get_mem_info(prefix=f'Optimizer step [{n+1}/{NUM_STEPS}] '), ranks=[0])
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+        step_time = time() - start
+        logger.info(f'[{n+1}/{NUM_STEPS}] Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
 
 
 if __name__ == '__main__':
diff --git a/features/zero/train_v2.py b/features/zero/train_v2.py
new file mode 100644
index 0000000..fb17cf7
--- /dev/null
+++ b/features/zero/train_v2.py
@@ -0,0 +1,130 @@
+
+import colossalai
+import psutil
+import torch
+import torch.nn as nn
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from transformers import GPT2Config, GPT2LMHeadModel
+from time import time
+from functools import partial
+from colossalai.tensor.chunk import ChunkManager
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils import get_current_device
+from colossalai.nn.parallel import ColoDDPV2
+from colossalai.zero import ZeroOptimizer
+
+
+class GPTLMModel(nn.Module):
+    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
+                                                n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+def get_data(batch_size, seq_len, vocab_size):
+    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
+    attention_mask = torch.ones_like(input_ids)
+    return input_ids, attention_mask
+
+
+def gpt2_medium(checkpoint=False):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_xl(checkpoint=True):
+    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
+
+
+def gpt2_10b(checkpoint=True):
+    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def get_cpu_mem():
+    return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+    return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def main():
+    BATCH_SIZE = 8
+    SEQ_LEN = 1024
+    VOCAB_SIZE = 50257
+    NUM_STEPS = 10
+    disable_existing_loggers()
+    colossalai.launch_from_torch(config={})
+    logger = get_dist_logger()
+
+    logger.info(get_mem_info(), ranks=[0])
+    # build GPT model
+    with ColoInitContext(device=get_current_device()):
+        model = gpt2_medium(checkpoint=True)
+    numel = sum([p.numel() for p in model.parameters()])
+    logger.info(f'Model numel: {numel}')
+    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+
+    chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
+                                 init_device=GeminiManager.get_default_device('cpu'))
+    gemini_manager = GeminiManager('cpu', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    logger.info(chunk_manager)
+
+    # build criterion
+    criterion = GPTLMLoss()
+
+    # optimizer
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+
+    model.train()
+    for n in range(NUM_STEPS):
+        # we just use randomly generated data here
+        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
+        optimizer.zero_grad()
+        start = time()
+        outputs = model(input_ids, attn_mask)
+        loss = criterion(outputs, input_ids)
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
+        optimizer.backward(loss)
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
+        optimizer.step()
+        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+        step_time = time() - start
+        logger.info(f'[{n+1}/{NUM_STEPS}] Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
+
+
+if __name__ == '__main__':
+    main()

From 2b05c0de0a7b023e85a21439c9ed33dd4bb17b18 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 16 Jun 2022 17:30:07 +0800
Subject: [PATCH 2/4] polish code

---
 features/zero/train.py    | 2 +-
 features/zero/train_v2.py | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/features/zero/train.py b/features/zero/train.py
index 60adadc..e16a4c5 100644
--- a/features/zero/train.py
+++ b/features/zero/train.py
@@ -89,7 +89,7 @@ def main():
     with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True) as ctx:
         model = gpt2_medium(checkpoint=True)
     numel = ctx.model_numel_tensor.item()
-    logger.info(f'Model numel: {numel}')
+    logger.info(f'Model numel: {numel}', ranks=[0])
     get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
     # Set tensor_placement_policy='cpu', which will offload params, grads and os
     model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
diff --git a/features/zero/train_v2.py b/features/zero/train_v2.py
index fb17cf7..249f2fc 100644
--- a/features/zero/train_v2.py
+++ b/features/zero/train_v2.py
@@ -90,16 +90,15 @@ def main():
     with ColoInitContext(device=get_current_device()):
         model = gpt2_medium(checkpoint=True)
     numel = sum([p.numel() for p in model.parameters()])
-    logger.info(f'Model numel: {numel}')
+    logger.info(f'Model numel: {numel}', ranks=[0])
     get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
-
     chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
                                  init_device=GeminiManager.get_default_device('cpu'))
     gemini_manager = GeminiManager('cpu', chunk_manager)
     model = ColoDDPV2(model, gemini_manager)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-    logger.info(chunk_manager)
+    logger.info(chunk_manager, ranks=[0])
 
     # build criterion
     criterion = GPTLMLoss()

From c49644cfb2980d0804c7e72260e82ce944d5a68a Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 17 Jun 2022 12:22:57 +0800
Subject: [PATCH 3/4] polish code

---
 features/zero/train.py    | 3 ++-
 features/zero/train_v2.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/features/zero/train.py b/features/zero/train.py
index e16a4c5..f652d28 100644
--- a/features/zero/train.py
+++ b/features/zero/train.py
@@ -117,7 +117,8 @@ def main():
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
-        logger.info(f'[{n+1}/{NUM_STEPS}] Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
+        logger.info(
+            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
 
 
 if __name__ == '__main__':
diff --git a/features/zero/train_v2.py b/features/zero/train_v2.py
index 249f2fc..e07169d 100644
--- a/features/zero/train_v2.py
+++ b/features/zero/train_v2.py
@@ -122,7 +122,8 @@ def main():
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
-        logger.info(f'[{n+1}/{NUM_STEPS}] Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
+        logger.info(
+            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', ranks=[0])
 
 
 if __name__ == '__main__':

From 860f5559a6e9f42cdc190ef06f91377ca4d5c2b8 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 22 Jun 2022 17:26:57 +0800
Subject: [PATCH 4/4] update example

---
 features/zero/train_v2.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/features/zero/train_v2.py b/features/zero/train_v2.py
index e07169d..c5654f7 100644
--- a/features/zero/train_v2.py
+++ b/features/zero/train_v2.py
@@ -8,11 +8,11 @@ from transformers import GPT2Config, GPT2LMHeadModel
 from time import time
 from functools import partial
-from colossalai.tensor.chunk import ChunkManager
-from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.tensor import ChunkManager
+from colossalai.gemini import GeminiManager
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.utils import get_current_device
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.zero import ZeroOptimizer
 
 
@@ -81,6 +81,7 @@ def main():
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
     NUM_STEPS = 10
+    PLACEMENT_POLICY = 'cpu'
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
     logger = get_dist_logger()
@@ -94,9 +95,9 @@ def main():
     get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
     chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
-                                 init_device=GeminiManager.get_default_device('cpu'))
-    gemini_manager = GeminiManager('cpu', chunk_manager)
-    model = ColoDDPV2(model, gemini_manager)
+                                 init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
+    gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
     logger.info(chunk_manager, ranks=[0])
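
Usage note: both train.py and train_v2.py initialize distributed state through colossalai.launch_from_torch(config={}), which reads the process-group environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that torchrun sets. A minimal single-node launch could therefore look like the sketch below; the GPU count and script path are illustrative assumptions, adjust them to your setup.

    # minimal sketch, assuming a single-GPU node; raise --nproc_per_node for more GPUs
    torchrun --standalone --nproc_per_node=1 features/zero/train_v2.py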