Skip to content
This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

Permalink
update example
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 committed Jun 22, 2022
1 parent c49644c commit 860f555
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions features/zero/train_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from transformers import GPT2Config, GPT2LMHeadModel
from time import time
from functools import partial
from colossalai.tensor.chunk import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ChunkManager
from colossalai.gemini import GeminiManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import get_current_device
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer


Expand Down Expand Up @@ -81,6 +81,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
PLACEMENT_POLICY = 'cpu'
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
Expand All @@ -94,9 +95,9 @@ def main():
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True,
init_device=GeminiManager.get_default_device('cpu'))
gemini_manager = GeminiManager('cpu', chunk_manager)
model = ColoDDPV2(model, gemini_manager)
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
logger.info(chunk_manager, ranks=[0])

Expand Down

0 comments on commit 860f555

Please sign in to comment.