Skip to content
This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

Permalink
[hotfix]fit to the refactored pipelinable api (#137)
Browse files Browse the repository at this point in the history
* change to fit refactored schedule

* [hotfix]fit to the refactored pipelinable api

* polish
  • Loading branch information
YuliangLiu0306 authored Jun 13, 2022
1 parent d1ce233 commit 1371a5b
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 11 deletions.
9 changes: 1 addition & 8 deletions image/mlpmixer/train_pipline.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,12 @@ def train():

use_pipeline = is_using_pp()

# pipelinable = PipelinableContext()
# with pipelinable:
# model = mixer_s32(num_classes,image_size,patch_size)
# pipelinable.to_layer_list()
# pipelinable.load_policy("uniform")
# model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))

if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = mixer_s32(num_classes, image_size, patch_size)
pipelinable.to_layer_list()
pipelinable.load_policy("uniform")
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = mixer_s32(num_classes, image_size, patch_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main():
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.load_policy("uniform")
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def train_imagenet():
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.load_policy("uniform")
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def train_imagenet():
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.load_policy("uniform")
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
Expand Down

0 comments on commit 1371a5b

Please sign in to comment.