
Add support for seed checkpoint creation for meta-init flow #172

Merged: 15 commits merged into gh/wconstab/3/base from gh/wconstab/3/head on May 2, 2024

Conversation

wconstab (Contributor) commented on Mar 27, 2024

Stack from ghstack (oldest at bottom):

Adds a new command, ./create_seed_checkpoint.sh, which largely reuses code inside train.py to create the model and then save its initial state as a step-0 checkpoint, for use with the meta-initialization loading flow.
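
For readers unfamiliar with the flow, a minimal sketch of what a seed checkpoint amounts to; the function name, constructor signature, init_weights hook, and checkpoint layout below are assumptions for illustration, not the actual train.py / create_seed_checkpoint.sh code:

# Hedged sketch: build the full model once in a single process, run its normal
# weight initialization, and save the result as a "step-0" checkpoint that
# meta-initialized ranks can later load instead of re-running init themselves.
import torch.distributed.checkpoint as dcp


def create_seed_checkpoint(model_cls, model_config, ckpt_dir: str) -> None:
    model = model_cls(model_config)   # assumed constructor; materialized on CPU
    model.init_weights()              # assumed init hook on the model
    # torch.distributed.checkpoint can save from a single process; the
    # "step-0" folder name mirrors the step-N convention used for training checkpoints.
    dcp.save({"model": model.state_dict()}, checkpoint_id=f"{ckpt_dir}/step-0")

On the training side, the meta-initialized model would then load this checkpoint before taking its first step.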

wconstab added a commit that referenced this pull request Mar 27, 2024
ghstack-source-id: c0f3c9c4605f933e8a21c61f79774cb9e5a22f85
Pull Request resolved: #172
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 27, 2024
wconstab added a commit that referenced this pull request Mar 27, 2024
ghstack-source-id: da50fcefcb67cdf7c5dae14f2a407e6adb17824c
Pull Request resolved: #172
wconstab added 3 commits April 5, 2024 10:25
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
Contributor:

I'm wondering whether we could run some benchmarks on the time needed to start training with the seed checkpoint. I'm a bit worried that if that takes very long, this approach might not be our desired solution.

wconstab (Author):

Yeah, it's a good point. What configuration should I benchmark? I have only been running tiny models.

But any 'real' training should expect to save and load checkpoints periodically due to faults, so I think checkpoint loading needs to be fast enough that we can live with it during training anyway.

It's probably a bigger deal for the UX of iterative development, but for small models it is not a noticeable amount of time.

The only workaround I can think of is to write a custom 'initializer' function for our model that we can call safely on a post-PP-split model chunk. That is complex/ugly but should be fast, and it only helps the first launch of training, not checkpoint resume.

Contributor:

Sorry, just saw this! I think ideally we'd benchmark a 13B/70B model loading a seed checkpoint.

Loading checkpoints when restarting real training actually makes sense to me, so I guess we can live with the seed checkpoint if that's the best UX.

For the workaround, IIUC pipeline splitting would only need to check that: 1. the first embedding layer exists; 2. the last projection layer exists. For the TransformerBlock module list it does not need to check anything, since it's a for loop anyway?
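
For illustration only, a rough sketch of that check-and-loop idea; the submodule names (tok_embeddings, layers, output) and the per-block init hook are assumptions modeled on a typical Llama-style module, not necessarily the exact torchtitan names:

import torch.nn as nn


def init_model_chunk(chunk: nn.Module) -> None:
    # Only the first pipeline stage keeps the embedding.
    if getattr(chunk, "tok_embeddings", None) is not None:
        nn.init.normal_(chunk.tok_embeddings.weight)
    # Every stage initializes whatever transformer blocks it owns.
    for block in getattr(chunk, "layers", []):
        block.init_weights()  # assumed per-block init hook
    # Only the last pipeline stage keeps the final projection.
    if getattr(chunk, "output", None) is not None:
        nn.init.normal_(chunk.output.weight)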

wconstab added 4 commits April 5, 2024 14:21
wconstab mentioned this pull request on Apr 29, 2024
wconstab mentioned this pull request on May 1, 2024
wconstab requested a review from wanchaol on May 2, 2024
kwen2501 (Contributor) left a comment:

LGTM!
Thanks for doing the foundational work!

Comment on lines 31 to 33
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $seed_checkpoint $overrides
Contributor:

If NGPU is always 1 for creating the seed checkpoint, shall we just launch the script with:
python train.py
?
Not sure if it would work directly, though.

wconstab (Author):

It doesn't work because the script still expects env vars like WORLD_SIZE to be set.

I considered just hardcoding the envs inside the launcher, but why not keep it simple?
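
For context, a rough sketch of the environment torchrun normally provides and that a bare python train.py would have to set by hand; whether train.py needs anything beyond these is an assumption here:

import os

import torch.distributed as dist

# torchrun exports these for each worker; a plain `python` invocation does not.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group(backend="gloo")  # single-rank "world" on CPU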

wanchaol (Contributor) left a comment:

Looks good; I have some questions and suggestions inlined.


if job_config.checkpoint.create_seed_checkpoint:
    assert (
        world_size == 1
Contributor:

I have some questions about the meta-init workflow. Assume we are planning to train a 70B model with PP, but initially we want to create a seed checkpoint so that we can use it later for the meta-init load. Does that mean we need to init the 70B model on CPU first? If we init a debug model and save its seed checkpoint, I guess that won't be reusable for the later 70B model?

wconstab (Author):

The seed checkpoint must match the model we're training.

I suppose I have not even tested seed creation on CPU vs. GPU; probably we need to expose more options or smarts to determine which device to use here.

> Does that mean we need to init the 70B model on CPU first?

Yes, that's right. If this isn't OK, I think we can try to hand-write some initializer functions that can work with the PP-traced function, but maybe that can come later?

Note: even if we say "ditch the pipeline tracer", it won't fix the initializer problems. We'd still need to customize the model's init_weights functions so that they work given a 'model chunk' instead of a whole model. And then we'd also have to tolerate some RNG divergence from the non-PP case, or implement RNG seed passing and coordination.

Contributor:

Yeah, that can come later. My main worry is that for larger models like 70B/400B this is going to be quite an initialization bottleneck, because we need to:

  1. create a 70B/400B model on CPU and save it to disk, which probably takes a couple of minutes
  2. trace the 70B/400B model, which takes another couple of minutes
  3. load the checkpoint, which takes another couple of minutes (this is fine, I think, as training always wants to save/load checkpoints anyway)

Maybe we should brainstorm more on how PP could support meta init in a more sound way, e.g. a PipelineStage.init_weights that waits for its RNG (if not the first stage), inits the current meta model's weights, and transmits the RNG to the next stage.
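
A rough sketch of what that RNG relay could look like; PipelineStage.init_weights is only a proposal in this thread, and the function below assumes an already-initialized process group, a stage index that equals the process-group rank, and an init_weights hook on each chunk:

import torch
import torch.distributed as dist


def init_stage_weights(chunk, stage_idx: int, num_stages: int, base_seed: int = 0) -> None:
    seed = torch.tensor([base_seed], dtype=torch.int64)
    if stage_idx > 0:
        # Wait for the previous stage to finish init and hand over the seed.
        dist.recv(seed, src=stage_idx - 1)
    torch.manual_seed(int(seed.item()) + stage_idx)
    chunk.init_weights()  # assumed per-chunk init hook
    if stage_idx < num_stages - 1:
        dist.send(seed, dst=stage_idx + 1)  # pass the seed down the pipeline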

wconstab (Author) commented on May 2, 2024:

> trace the 70B/400B model, which takes another couple of minutes

Do you think the tracing time will be long? I actually assumed that tracing time would not be an issue; it should depend more on the number of operators than on the parameter size.

> Maybe we should brainstorm more on how PP could support meta init in a more sound way

I'm somewhat ambivalent about this. Yes, I like the idea; I even proposed it at one point. But on the other hand I see quite a few more important issues to solve first, and this idea is quite complicated, so it would need to be justified. So this seems like a P1 to me.


torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $seed_checkpoint $overrides
Contributor:

Given that the seed checkpoint requires no parallelism, we should just provide the overrides here (i.e. training.dp_degree=1) to disable all parallelisms, given that we already specify NGPU=1.

wconstab (Author):

Yeah, that's a good point.

I didn't want to hardcode that kind of thing inside train.py, but in this script it seems exactly right to do. I'll change it.

# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Contributor:

Can we add some comments at the beginning of the file to explain what this script is used for and how to run it?

wconstab merged commit a57d458 into gh/wconstab/3/base on May 2, 2024
4 checks passed
wconstab added a commit that referenced this pull request May 2, 2024
Adds new command ./create_seed_checkpoint.sh which largely
reuses code inside train.py to create the model and then save its
initial state as a step-0 checkpoint for use with meta-initialization
loading flow.

ghstack-source-id: 3e1aa9eab847c1f1341f22772ca8ae3688883454
Pull Request resolved: #172
wconstab deleted the gh/wconstab/3/head branch on May 2, 2024
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
ghstack-source-id: 3e1aa9eab847c1f1341f22772ca8ae3688883454
Pull Request resolved: pytorch#172
tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
ghstack-source-id: eb584b26c23535d7d6db4e44c0074c2b4adf1515
Pull Request resolved: #172
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
ghstack-source-id: 3e1aa9eab847c1f1341f22772ca8ae3688883454
Pull Request resolved: pytorch#172
Labels
CLA Signed (managed by the Meta Open Source bot)
4 participants