[FSDP 8xMI300X] Llama3 8B FP8 is 21% slower than BF16 & OOMs on the same batch size #79
Comments
@OrenLeung This issue is due to our dev branch not yet having all of the recent DDP and FSDP optimizations from NVTE. We have a PR in review (#66) that should be merged soon and should resolve this issue. 8xMI300X FSDP TE FP8 (batch size 2): 442 TFLOP/s
Thanks @wenchenvincent for looking into this. This is quite competitive with the H100 on a perf-per-TCO basis, since MI300X TCO is 78% of an H100. Unfortunately, it is not competitive with the H200. Are there any other PRs in the pipeline that would help? Here are my results for this Llama3 8B full model:
Full response in the Llama3 70B proxy GH issue #78 (comment). cc: @hliuca
Thank you again @OrenLeung. We appreciate your data; it will be a great reference for our future optimization goals. We will see if we can surpass the H200 using MI300X :-)
Hi @hliuca, I am glad we were able to provide an optimization goal. Please note that all of the H100 & H200 numbers we shared are preliminary and will probably improve as well as I tune them. Also note that we are benchmarking and evaluating AMD/Nvidia on other real-world transformer models and real-world GEMM training shapes that we have not shared with Nvidia or AMD, to ensure that the patches made to PyTorch, TE, hipBLASLt, etc. generalize.
I now get a preliminary number of 464 TFLOP/s/GPU (batch size 4) after #66 was merged to main, running this model on our internal codebase. After 32 warmup steps: Mean TFLOP/s: 464.33, Mean MFU: 17.79%. @wenchenvincent & your team, very impressive work to boost perf by 25% in less than 7 days! It seems competitive with the H100 on perf per TCO, but still not on pure performance.
@OrenLeung We have the optimized cast-transpose Triton kernel merged in. With that, I got the following improvement: 8xMI300X FSDP TE FP8 (batch size 4): 475.76 TFLOP/s -> 523.27 TFLOP/s
Problem Description
Llama3 8B FP8 OOMs at the same batch size as BF16; I need to decrease the batch size to 2 for it not to OOM. At batch size 2, TE FP8 is 21% slower than torch.compile BF16 nightly. I have verified that on H100, TE FP8 fits the same batch size as BF16 and yields an 11% perf increase for this model.
Preliminary Results
Commands:
python ./train_fsdp_llama_8b.py --bsz=2
python ./train_fsdp_llama_8b.py --bsz=4
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 4.00 GiB. GPU 5 has a total capacity of 191.98 GiB of which 2.55 GiB is free. Of the allocated memory 181.21 GiB is allocated by PyTorch, and 3.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
cc: @hliuca
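As a possible mitigation for the fragmentation hinted at in the error message above (not verified on this setup), the suggested allocator setting can be applied before the first device allocation, for example:

```python
import os

# Suggested by the OOM message; must be set before the first GPU allocation,
# i.e. before anything touches the device.
os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True")

import torch  # imported after the allocator config is in place
```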
Operating System
Ubuntu
CPU
AMD CPU
GPU
MI300X
ROCm Version
ROCm 6.2.0
ROCm Component
No response
Steps to Reproduce
Versions
Docker Image
FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install -y nano
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch
RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]
TE Install Instructions (done inside the Docker container)
Repro Script
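The repro script itself is not included in this issue. As a rough, hypothetical sketch of the pattern being exercised (an FSDP-wrapped stack of TE layers trained under fp8_autocast), assuming a torchrun-style launch and a small stand-in model rather than the real Llama3 8B builder:

```python
# Hypothetical sketch only: model size, launch method, and the fire-based CLI
# are assumptions for illustration, not the actual train_fsdp_llama_8b.py.
import os
import fire
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format


def main(bsz: int = 2, seq_len: int = 4096, steps: int = 8):
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Small stand-in stack built from TE layers so its GEMMs can run in FP8.
    hidden, heads, layers = 4096, 32, 4
    model = torch.nn.Sequential(
        *[te.TransformerLayer(hidden, 4 * hidden, heads,
                              params_dtype=torch.bfloat16)
          for _ in range(layers)]
    ).cuda()
    model = FSDP(model, use_orig_params=True)

    optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
    recipe = DelayedScaling(fp8_format=Format.HYBRID)

    for _ in range(steps):
        # TE defaults to sequence-first (s, b, h) activations.
        x = torch.randn(seq_len, bsz, hidden, device="cuda", dtype=torch.bfloat16)
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            out = model(x)
        loss = out.float().pow(2).mean()  # dummy loss for the sketch
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    fire.Fire(main)
```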
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response