diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index e343df39e..a2786a402 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -186,7 +186,9 @@ jobs:
           pip install pytest-timeout
 
       - name: Run tests
-        run: pytest --timeout=600 tests
+        # run: pytest --timeout=600 tests
+        # run just the test we want for now
+        run: pytest --timeout=600 tests/test_training.py::MegDSTestTraining::test_layer_norm_consistent_0_bf16
 
   stop-runner:
     name: Stop self-hosted EC2 runner
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 2be64b77d..194a518ba 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -738,6 +738,7 @@ def _add_distributed_args(parser):
     group.add_argument('--use-cpu-initialization', action='store_true',
                        default=None, help='If set, affine parallel weights '
                        'initialization uses CPU' )
+    group.add_argument('--force-sync-layer-norm-parameters', action="store_true", help='Force averaging of layer norm parameters across tensor parallel ranks in the forward pass.')
     return parser
 
 
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index d9a30f468..3fe5dafeb 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -169,6 +169,69 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Trim off the filename and mp_rank_* directory.
         for _ in range(3):
             checkpoint_name = os.path.dirname(checkpoint_name)
+
+        # Debug: check that layer norm parameters and their optimizer states stay in sync across tp ranks
+        layer_norms_params_end_with = [
+            "word_embeddings.norm.weight", "word_embeddings.norm.bias",
+            "input_layernorm.weight", "input_layernorm.bias",
+            "post_attention_layernorm.weight", "post_attention_layernorm.bias",
+            "self_attention.dense.bias", "mlp.dense_4h_to_h.bias",
+        ]
+        for n, p in model[0].named_parameters():
+            # Here is how you can access the fp32 version of a bf16 param and the fp32 optim states
+            #
+            # Note that an all_reduce is called on all dp ranks when `get_full_hp_param` is called -
+            # so it's not free
+            #
+            # a. fp32 param
+            for end in layer_norms_params_end_with:
+                if n.endswith(end):
+                    fp32_param = p.get_full_hp_param()
+
+                    fp32_params_accumulator = [
+                        torch.zeros_like(fp32_param)
+                        for _ in range(mpu.get_tensor_model_parallel_world_size())
+                    ]
+                    torch.distributed.gather(
+                        fp32_param,
+                        fp32_params_accumulator,
+                        dst=0,
+                        group=mpu.get_tensor_model_parallel_group()
+                    )
+                    if mpu.get_tensor_model_parallel_rank() == 0:
+                        square = torch.tensor([
+                            [
+                                torch.max(torch.abs(c1 - c2))
+                                for c2 in fp32_params_accumulator
+                            ] for c1 in fp32_params_accumulator
+                        ])
+                        print(f"Parameter name = {n}")
+                        print(square)
+
+                    # b. fp32 optim states
+                    for key in ['exp_avg', 'exp_avg_sq']:
+                        full_optim_state = p.get_full_hp_param(optim_state_key=key)
+
+                        full_optim_state_accumulator = [
+                            torch.zeros_like(full_optim_state)
+                            for _ in range(mpu.get_tensor_model_parallel_world_size())
+                        ]
+                        torch.distributed.gather(
+                            full_optim_state,
+                            full_optim_state_accumulator,
+                            dst=0,
+                            group=mpu.get_tensor_model_parallel_group()
+                        )
+                        if mpu.get_tensor_model_parallel_rank() == 0:
+                            square = torch.tensor([
+                                [
+                                    torch.max(torch.abs(c1 - c2))
+                                    for c2 in full_optim_state_accumulator
+                                ] for c1 in full_optim_state_accumulator
+                            ])
+                            print(f"Optimizer state: parameter name = {n}, key = {key}")
+                            print(square)
+
         model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
 
     # Wait so everyone is done (necessary)
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index 1cbeac312..b933ff34e 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -52,6 +52,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
                                        num_workers=args.num_workers,
+                                       generator=torch.Generator().manual_seed(args.seed),
                                        pin_memory=True)
 
 class MegatronPretrainingSampler:
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 8430f528c..cd10e5b77 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -19,10 +19,11 @@
 import numbers
 import torch
 
-from megatron import mpu
+from megatron import mpu, get_args
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron import mpu
 
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -63,6 +64,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
 
   def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm, self).__init__()
+        args = get_args()
 
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
@@ -75,6 +77,7 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.weight = Parameter(torch.Tensor(*normalized_shape))
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        self.force_sync_layer_norm_parameters = args.force_sync_layer_norm_parameters
 
 
   def reset_parameters(self):
@@ -84,19 +87,17 @@ def reset_parameters(self):
 
 
   def forward(self, input):
-    weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
-    torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
-    biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
-    torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
-    if any(torch.any(weight != self.weight) for weight in weights):
-        if mpu.get_tensor_model_parallel_rank() == 0:
-            print("Weight sync failed")
-            print(weights)
-    if any(torch.any(bias != self.bias) for bias in biases):
-        if mpu.get_tensor_model_parallel_rank() == 0:
-            print("Bias sync failed")
-            print(biases)
+    if self.force_sync_layer_norm_parameters:
+        tp_world_size = mpu.get_tensor_model_parallel_world_size()
+        # TODO: hack in order to synchronize all layer norms despite them being unsynced
+        weight = torch.clone(self.weight)
+        bias = torch.clone(self.bias)
+        weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size
+        bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size
+    else:
+        weight = self.weight
+        bias = self.bias
 
     return FusedLayerNormAffineFunction.apply(
-        input, self.weight, self.bias, self.normalized_shape,self.eps)
+        input, weight, bias, self.normalized_shape, self.eps)
diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py
index 4d94156ac..e649a1259 100644
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -241,7 +241,7 @@ def forward(self, input_):
                                       self.sparse)
         # Mask the output embedding.
         if self.tensor_model_parallel_size > 1:
-            output_parallel[input_mask, :] = 0.0
+            output_parallel = output_parallel.masked_fill(input_mask[..., None], 0.0)
         # Reduce across all the model parallel GPUs.
         output = reduce_from_tensor_model_parallel_region(output_parallel)
 
diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py
index 2143b610b..9521cb361 100644
--- a/megatron/testing_utils.py
+++ b/megatron/testing_utils.py
@@ -232,9 +232,9 @@ def get_gpu_count():
         return 0
 
 def torch_assert_equal(actual, expected, **kwargs):
-    # assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match
-    if hasattr(torch.testing, "assert_equal"):
-        return torch.testing.assert_equal(actual, expected, **kwargs)
+    # assert_close was added around pt-1.9, it does better checks - e.g. it will check that dimensions match
+    if hasattr(torch.testing, "assert_close"):
+        return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
     else:
         return torch.allclose(actual, expected, rtol=0.0, atol=0.0)
 
@@ -886,4 +886,4 @@ def flatten_arguments(args):
 
     Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
     """
-    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
\ No newline at end of file
+    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
diff --git a/requirements.txt b/requirements.txt
index da76b5e44..f0ec53a7d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,9 +6,10 @@ pybind11
 regex
 six
 tensorboard
-torch>=1.7
+torch>=1.11
 transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
+# for now using this branch for bf16 work
+DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git@olruwase/bf16-updates
 # versions from HF transformers
 black==21.4b0
 isort>=5.5.4
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index ed383e17a..5e821cf37 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 from random import randint
 from unittest.mock import patch
@@ -8,8 +9,11 @@
 import numpy as np
 import pytest
 
+from parameterized import parameterized
+
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, require_deepspeed, require_torch_multi_gpu
+from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, require_deepspeed, \
+    require_torch_multi_gpu, torch_assert_equal, CaptureStdout, execute_subprocess_async
 from megatron.training import setup_model_and_optimizer
 from megatron.mpu.mappings import gather_from_tensor_model_parallel_region
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
@@ -44,7 +48,6 @@
             "--merge-file": f"{data_dir}/gpt2-tiny-merges.txt",
             "--vocab-file": f"{data_dir}/gpt2-tiny-vocab.json",
             "--data-impl": "mmap",
-            "--split": "949,50,1",
             "--distributed-backend": "nccl",
"--weight-decay": "1e-2", "--clip-grad": "1.0", @@ -57,14 +60,14 @@ def get_default_args(self): # OUTPUT_ARGS "--log-interval": "10", - "--save-interval": "500", - "--eval-interval": "100", - "--eval-iters": "10", + "--save-interval": "10", + "--eval-interval": "10", + "--eval-iters": "5", "--checkpoint-activations": "", #ds args "--deepspeed": "", - "--deepspeed_config":f"{self.test_file_dir_str}/ds_config.json", + "--deepspeed_config": f"{self.test_file_dir_str}/ds_config.json", "--zero-stage": "1", "--deepspeed-activation-checkpointing": "" # DATA_ARGS @@ -115,8 +118,6 @@ def create_model_inputs(tokens): tokenizer = get_tokenizer() - model, _, _ = setup_model_and_optimizer(gpt_model_provider) - model = model[0] if load is not None: # Hack (same as in eval_harness/evaluate.py) # Loading pipelined models in deepspeed with different TP than it was originally trained on fails @@ -127,6 +128,10 @@ def create_model_inputs(tokens): # Deepspeed does however manage to load the model if we just turn off this sanity check. deepspeed.runtime.state_dict_factory.MegatronSDLoader.sanity_check = lambda self, ckpt_file_name: None + model, _, _ = setup_model_and_optimizer(gpt_model_provider) + model = model[0] + + if load is not None: zero_enabled = model._config.zero_enabled model._config.zero_enabled = False _, _ = model.load_checkpoint(load, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True) @@ -190,7 +195,7 @@ def test_alibi_tp(self): output2, tokens = result[0] logging.getLogger().critical(output-output2) - self.assertTrue(np.allclose(output,output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2") + self.assertTrue(np.allclose(output, output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2") @@ -293,5 +298,167 @@ def test_tokenizer_raise_error_make_vocab_size_divisible_by(self): self.assertEqual(str(exc_info.value), "5121 is not divisible by 128") + @parameterized.expand(["bf16", "fp16"]) + def test_layer_norm_consistent(self, variation): + src_dir = self.src_dir + output_dir = self.get_auto_remove_tmp_dir() + tp_size = 2 + pp_size = 1 + num_gpus = tp_size * pp_size # dp = 1 + seq_len = 128 + data_dir = f"{self.data_dir}/gpt2" + + command_args = self.get_default_args() + command_args["--pad-vocab-size-to"] = "5120" # This is equal to 128 * 40 which is above the len of gp2-tiny vocabulary + command_args["--position-embedding-type"] = "alibi" + command_args["--embed-layernorm"] = "" + command_args["--tensor-model-parallel-size"] = f"{tp_size}" + command_args["--pipeline-model-parallel-size"] = f"{pp_size}" + command_args["--save"] = f"{output_dir}/checkpoints" + command_args["--load"] = f"{output_dir}/checkpoints" + command_args["--data-path"] = f"{data_dir}/meg-gpt2-openwebtext_text_document" + command_args["--train-samples"] = "200" + command_args["--rampup-batch-size"] = "4 4 200" + command_args["--seq-length"] = f"{seq_len}" + command_args["--exit-interval"] = "20" + del command_args["--train-iters"] + del command_args["--lr-decay-iters"] + command_args["--tensorboard-dir"] = f"{output_dir}/tensorboard" + command_args["--lr"] = "1e-1" + + if variation == "bf16": + command_args["--bf16"] = "" + del command_args["--fp16"] + command_args["--deepspeed_config"] = f"{self.test_file_dir_str}/ds_config_bf16.json" + command_args["--zero-stage"] = "0" + elif variation == "fp16": + command_args["--fp16"] = "" + command_args["--deepspeed_config"] = f"{self.test_file_dir_str}/ds_config.json" + command_args["--zero-stage"] = "1" + 
+ # args, ds_args, num_gpus = self.get_variation_config("base", output_dir, n_samples=200) + + script = [f"{src_dir}/pretrain_gpt.py"] + launcher = f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split() + cmd = launcher + script + [elt for elts in [f"{key} {value}".split() for key, value in command_args.items()] for elt in elts] + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # 1. test that the layer norm weights and biases are synchronized + checkpoints = ["global_step10", "global_step20"] + + # Check transformer layer norm + keys_to_compare = [ + "input_layernorm.weight", + "input_layernorm.bias", + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + "self_attention.dense.bias", + "mlp.dense_4h_to_h.bias" + ] + files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(tp_size)] for + layer_id in [3, 4]] + for checkpoint in checkpoints: + checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) + for key in keys_to_compare: + for files in files_to_compare: + weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] + ref = weights[0] + for weight in weights[1:]: + torch_assert_equal(ref, weight, check_device=False) + + # Check embed layer norm + keys_to_compare = [ + "word_embeddings.norm.weight", + "word_embeddings.norm.bias" + ] + files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(tp_size)] for + layer_id in [1]] + for checkpoint in checkpoints: + checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) + for key in keys_to_compare: + for files in files_to_compare: + weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] + ref = weights[0] + for weight in weights[1:]: + torch_assert_equal(ref, weight, check_device=False) + + # Final layer norm + keys_to_compare = [ + "weight", + "bias" + ] + files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(tp_size)] + for + layer_id in [6]] + for checkpoint in checkpoints: + checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) + for key in keys_to_compare: + for files in files_to_compare: + weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] + ref = weights[0] + for weight in weights[1:]: + torch_assert_equal(ref, weight, check_device=False) + + keys_to_compare = ["torch_rng_state"] + files_to_compare = [[f"mp_rank_{tp + pp*tp_size:02d}_model_states.pt" for tp in range(tp_size)] for pp in range(pp_size)] + for checkpoint in checkpoints: + checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) + for key in keys_to_compare: + for files in files_to_compare: + weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] + ref = weights[0] + for weight in weights[1:]: + assert (ref == weight).all(), f"key: {key} ref: {ref}, weight: {weight}" + + + # 2. 
test training from checkpoint: resume + command_args["--exit-interval"] = "30" + cmd = launcher + script + [elt for elts in [f"{key} {value}".split() for key, value in command_args.items()] for elt in elts] + + # now do it again, this time resuming from the checkpoint + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test checkpoint loading + self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # 3. test that inference with changes TP works. + mp.set_start_method('spawn', force=True) + del command_args["--rampup-batch-size"] + command_args["--tensor-model-parallel-size"] = "1" + del command_args["--load"] + del command_args["--save"] + command_args["--force-sync-layer-norm-parameters"] = "" + + checkpoints_path = os.path.join(output_dir, "checkpoints") + pool = Pool(1) + result = pool.map(MegDSTestTP.infer_model, [((0, 1, command_args, None, None, checkpoints_path))]) + pool.close() + pool.join() + + output, tokens = result[0] + logging.getLogger().info("First done!") + + command_args["--tensor-model-parallel-size"] = "2" + + pool = Pool(2) + result = pool.map(MegDSTestTP.infer_model, + [((0, 2, command_args, tokens, None, checkpoints_path)), ((1, 2, command_args, tokens, None, checkpoints_path))]) + pool.close() + pool.join() + + output2, tokens = result[0] + + logging.getLogger().critical(output - output2) + self.assertTrue(np.allclose(output, output2, atol=0, rtol=0), + "Different results when running with TP=1 and TP=2") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_training.py b/tests/test_training.py index fb72e59c6..bf31bf904 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -594,112 +594,3 @@ def test_skip_train_iteration(self): train_iterations = range(1,10) for i in train_iterations: self.assertTrue(f"iteration {i:8d}/" in cs.out) - - @parameterized.expand(["bf16", "fp16"]) - def test_layer_norm_consistent(self, variation): - src_dir = self.src_dir - output_dir = self.get_auto_remove_tmp_dir() - num_gpus = 2 - seq_len = 128 - data_dir = f"{self.data_dir}/gpt2" - args = f""" - --tensor-model-parallel-size {2} - --pipeline-model-parallel-size {1} - --distributed-backend nccl - - --log-interval 1 - --save-interval 10 - --eval-interval 10 - --eval-iters 5 - --checkpoint-activations - --partition-activations - --exit-interval {20} - - --merge-file {data_dir}/gpt2-tiny-merges.txt - --vocab-file {data_dir}/gpt2-tiny-vocab.json - --save {output_dir}/checkpoints - --load {output_dir}/checkpoints - --data-path {data_dir}/meg-gpt2-openwebtext_text_document - --tensorboard-dir {output_dir}/tensorboard - --tensorboard-queue-size 5 - --log-timers-to-tensorboard - --log-batch-size-to-tensorboard - --log-validation-ppl-to-tensorboard - - --num-layers 2 - --hidden-size 64 - --num-attention-heads 2 - --seq-length {seq_len} - --max-position-embeddings 1024 - --micro-batch-size 2 - --global-batch-size 16 - - --optimizer adam - --adam-beta1 0.9 - --adam-beta2 0.95 - --adam-eps 1e-8 - --lr 1e-1 - --clip-grad 1.0 - --weight-decay 1e-1 - --embed-layernorm - - --log-level debug - --log-level-replica info - - --rampup-batch-size 2 2 200 - --train-samples 200 - - --position-embedding-type alibi - """.split() - - ds_args = f""" - --deepspeed - --deepspeed-activation-checkpointing - """.split() - - if variation == "bf16": - args.append("--bf16") - ds_args += [ - "--zero-stage", "0", - "--deepspeed_config", 
f"{self.test_file_dir_str}/ds_config_bf16.json" - ] - elif variation == "fp16": - args.append("--fp16") - ds_args += [ - "--zero-stage", "1", - "--deepspeed_config", f"{self.test_file_dir_str}/ds_config.json" - ] - - # args, ds_args, num_gpus = self.get_variation_config("base", output_dir, n_samples=200) - - script = [f"{src_dir}/pretrain_gpt.py"] - launcher = get_launcher(num_gpus) - cmd = launcher + script + args + ds_args - # keep for quick debug - # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die - - with CaptureStdout() as cs: - execute_subprocess_async(cmd, env=self.get_env()) - - checkpoints = ["global_step10", "global_step20"] - keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"] - files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3,4]] - for checkpoint in checkpoints: - checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) - for key in keys_to_compare: - for files in files_to_compare: - weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] - ref = weights[0] - for weight in weights[1:]: - torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False) - - keys_to_compare = ["word_embeddings.norm.weight"] - files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]] - for checkpoint in checkpoints: - checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint) - for key in keys_to_compare: - for files in files_to_compare: - weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files] - ref = weights[0] - for weight in weights[1:]: - torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False) \ No newline at end of file