-
Notifications
You must be signed in to change notification settings - Fork 216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sync layer norm #271
Draft
thomasw21
wants to merge
38
commits into
thomas/test_different_layer_norm
Choose a base branch
from
thomas/fix_layer_norm
base: thomas/test_different_layer_norm
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Sync layer norm #271
Changes from 11 commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
07ccb3d
Better
thomasw21 391ed48
Force synchronize the layer norms parameters across all TP
thomasw21 98d0e7c
import mpu
stas00 279a77e
use the bf16 branch for testing
stas00 87a9dba
`torch.testing.assert_equal` didn't make it (#273)
stas00 dbb5914
Merge remote-tracking branch 'origin/main' into thomas/fix_layer_norm
stas00 70f91f8
bf16 comms requite pt-1.11
stas00 835a3e5
already part of the function
stas00 37795a9
reproduce the crashing on resume
stas00 3ec65f7
run just the test we want for now
stas00 8271d41
all_reduce is an in_place operation
thomasw21 b418b47
Make a test that TP reshaping works
thomasw21 4b7207b
Woops
thomasw21 3bc5824
Woops
thomasw21 05c99db
Woops
thomasw21 55e10c6
Woops
thomasw21 2ab8a3a
Woops
thomasw21 d357839
Woops
thomasw21 5fb231c
Woops
thomasw21 cc7ff45
Woops
thomasw21 7cdb1be
Woops
thomasw21 4574ec9
Fix load issue
thomasw21 04e89d1
Woops
thomasw21 e943100
Fix checkpoint path
thomasw21 09cead3
Test that force sync will allow TP changes
thomasw21 77abee6
Nit
thomasw21 64a62c8
Now that we have a force sync mechanism, let's try to reproduce
thomasw21 0b7afcc
Compare model_states_rank
thomasw21 ce01733
test
thomasw21 89ab0b7
Row column bias should be synchronized as well
thomasw21 42997b2
New list of matching embeddings
thomasw21 e0ef168
Figure out why state differs
thomasw21 1fc4fe8
Test for final weight
thomasw21 7ebbed1
Test that torch_rng_state
thomasw21 2c49216
Fix non matching torch_rng_state for tp_rank=0
thomasw21 007ecb4
Update test
thomasw21 c3844b5
I'm surprised one can apply inplace operation here
thomasw21 189f054
Test out the loss from the fp32 weights and optimizer states
thomasw21 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -682,6 +682,8 @@ def test_layer_norm_consistent(self, variation): | |
execute_subprocess_async(cmd, env=self.get_env()) | ||
|
||
checkpoints = ["global_step10", "global_step20"] | ||
|
||
# Check transformer layer norm | ||
keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"] | ||
files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3,4]] | ||
for checkpoint in checkpoints: | ||
|
@@ -691,8 +693,9 @@ def test_layer_norm_consistent(self, variation): | |
weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] | ||
ref = weights[0] | ||
for weight in weights[1:]: | ||
torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False) | ||
torch_assert_equal(ref, weight, check_device=False) | ||
|
||
# Check embed layer norm | ||
keys_to_compare = ["word_embeddings.norm.weight"] | ||
files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]] | ||
for checkpoint in checkpoints: | ||
|
@@ -702,4 +705,15 @@ def test_layer_norm_consistent(self, variation): | |
weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] | ||
ref = weights[0] | ||
for weight in weights[1:]: | ||
torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False) | ||
torch_assert_equal(ref, weight, check_device=False) | ||
|
||
# 2. test training from checkpoint: resume | ||
# now do it again, this time resuming from the checkpoint | ||
with CaptureStdout() as cs: | ||
execute_subprocess_async(cmd, env=self.get_env()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so it crashes on resume:
|
||
|
||
# test checkpoint loading | ||
self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out) | ||
|
||
# test reports | ||
self.assertIn("consumed samples", cs.out) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stas00
Essentially the reduce is an in-place operator, which means at each forward pass,
self.weight
was updated with the sum of all the weights of all tp_ranks. We could try thinking of a better fix by doing a average reduce, but I'm scared back propagation doesn't play well with this in place logic.New test fails with:
This is more expected since the previous run should have consumed all the tokens. Going to update #272 and restart the training.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we extend:
Megatron-DeepSpeed/megatron/mpu/mappings.py
Lines 22 to 30 in 87a9dba
to support an optional
ReduceOp.AVG
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is tricky. The reason why is this means that we need to implement custom backward function (since you compute the average, the gradient needs to be divided by the tp world size).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I don't think we save much compute by supporting that.