ThunderFX fails with FP8 #1424

Open
mpatel31415 opened this issue Nov 12, 2024 · 0 comments
Labels: dynamo+thunder, mixology, TransformerEngine

Comments

@mpatel31415 (Contributor) commented Nov 12, 2024

🐛 Bug

When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP and FP8, training fails with KeyError: 'scaling_fwd'. This might also be an issue with Transformer Engine, so I'm happy to move this issue to TE if needed.
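
The failing line at the bottom of the traceback below is a plain dictionary lookup on the tensor's FP8 metadata. A minimal sketch of the failure mode, under the assumption (suggested by the traceback, not verified) that Thunder's compiled backward reaches TE's quantization with per-module FP8 state that was never populated for the forward pass:

```python
# Illustration only; the names mirror the traceback, not a public TE API.
fp8_meta_key = "scaling_fwd"

# TE's Float8Tensor quantization reads dst._fp8_meta[fp8_meta_key]. If the
# forward-scaling metadata was never filled in for this tensor, the dict
# access raises exactly the reported error.
_fp8_meta = {}  # FP8 metadata left unpopulated
fp8_meta = _fp8_meta[fp8_meta_key]  # KeyError: 'scaling_fwd'
```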

Full traceback:

[rank7]: Traceback (most recent call last):
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 974, in <module>
7: [rank7]: CLI(benchmark_main)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
7: [rank7]: return _run_component(components, init)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
7: [rank7]: return component(**cfg)
7: [rank7]: ^^^^^^^^^^^^^^^^
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 871, in benchmark_main
7: [rank7]: benchmark.train()
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 765, in train
7: [rank7]: loss.backward()
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 624, in backward
7: [rank7]: torch.autograd.backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
7: [rank7]: _engine_run_backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
7: [rank7]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
7: [rank7]: return user_fn(self, *args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
7: [rank7]: outputs = fn(ctx, *args)
7: [rank7]: ^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
7: [rank7]: grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "thunder.backward_fn_13", line 28, in backward_fn
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
7: [rank7]: return self._call_impl(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
7: [rank7]: return forward_call(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 205, in forward
7: [rank7]: weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 273, in get_fp8_weight_version_compat
7: [rank7]: weight_fp8 = self.get_fp8_workspace(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 1086, in get_fp8_workspace
7: [rank7]: out.quantize(tensor, noop_flag=skip_update_flag)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 642, in quantize
7: [rank7]: fp8_meta = dst._fp8_meta[fp8_meta_key]
7: [rank7]: ~~~~~~~~~~~~~^^^^^^^^^^^^^^
7: [rank7]: KeyError: 'scaling_fwd'
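
For triage, it may help to check whether the module's FP8 state even contains the forward-scaling entry before the workspace fetch happens. A hedged diagnostic sketch, assuming only that TE modules expose their per-module FP8 state as an fp8_meta dict (which the traceback's _fp8_meta access suggests):

```python
# Hypothetical diagnostic helper, not a thunder or Transformer Engine API.
def has_forward_scaling_meta(te_module) -> bool:
    # "scaling_fwd" is the key whose absence triggers the reported KeyError.
    meta = getattr(te_module, "fp8_meta", None)
    return isinstance(meta, dict) and "scaling_fwd" in meta
```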

To Reproduce

Please use:
- 1 node(s), each with 8 GPUs
- Image "INTERNAL_IMAGE:pjnl_20241107"
- Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.2 \
    --distributed_mode fsdp \
    --shard_mode zero2 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode fp8-delayed-te \
    --micro_batch_size 1
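
For quicker triage without a full 8-GPU FSDP run, a minimal single-GPU sketch exercising the same code path (ThunderFX compiling a TE module under delayed-scaling FP8) might look like the following. This is an untested, assumption-laden sketch, not the benchmark script's logic, and the exact failure may still require FSDP and activation checkpointing to trigger:

```python
# Hypothetical single-GPU repro sketch (no FSDP); assumes an H100-class
# device with transformer_engine and lightning-thunder installed.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from thunder.dynamo import ThunderCompiler

model = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Compile via the ThunderFX (dynamo + thunder) frontend.
backend = ThunderCompiler()
compiled = torch.compile(model, backend=backend)

# Delayed-scaling FP8, matching --low_precision_mode fp8-delayed-te.
recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = compiled(x)

# The reported KeyError surfaces in the backward pass, when the compiled
# backward re-enters TE's FP8 weight-workspace handling.
out.sum().backward()
```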

Environment

system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.98.001
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.8
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gita9b4989
libraries.pip.torchao 0.6.1
libraries.pip.torchmetrics 1.5.1
libraries.pip.torchvision 0.19.0a0+d23a6e1

@mpatel31415 changed the title from "Dynamo + Thunder fails with FP8" to "ThunderFX fails with FP8" on Nov 12, 2024
@IvanYashchuk added the TransformerEngine and mixology labels on Nov 12, 2024
@tfogal added the dynamo+thunder label on Nov 15, 2024