[1xMI300X] GPT-2 XL 1.5B FP8 Training ~30% slower than H100 FP8 #72
Comments
Hi @OrenLeung, this has been reported. Thank you.
For reference, GPT-2 XL 1.5B BF16 with torch.compile on PyTorch nightly, at the same batch size as the reprod script, has:
H100 Transformer Engine FP8 is 1.3x faster than H100 BF16. Even when building with
This means that AMD FP8 is slower by 20%!!
Hi @OrenLeung, we have more people on the issues you reported and we will drive the fixes. Thank you.
Hi @OrenLeung, I tried to run the same script on our single-GPU H100 machine, but I only got around 100 TFLOP/s. Thanks.
@OrenLeung Hi, I tried to reproduce your issue. I got 166 TFLOP/s on MI300. The major difference is that I am using the 2.6.0.dev20241014 nightly, since I cannot find the 20241012 version. Could you please provide the system configuration on your side? Thanks.
Hi @wangye805, it seems like you are on an H100 PCIe 350W card, while I am on an H100 SXM 700W card. The 638.4 TFLOP/s/GPU figure was from an H200 at batch size 38. I apologize for the error. I have updated the following script to use batch size 14. Feel free to use a larger batch size if that helps MI300X TFLOP/s/GPU.
To change batch size just do

H100/H200 Dockerfile

FROM nvcr.io/nvidia/pytorch:24.09-py3
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic
WORKDIR /workspace/llm-train-bench/
CMD ["/usr/bin/bash"]

Reprod Script

import contextlib
import torch
import torch.nn.functional as F
import torch.nn as nn
from pydantic.dataclasses import dataclass


@dataclass
class GPTConfig:
    n_layers: int  # L
    n_heads: int  # H
    d_embd: int  # E
    max_seq_len: int = 1024
    vocab_size: int = 50304  # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())
        # print param count in B
        print(f"Param count: {N/1e9}B")
        head_dim = config['d_embd'] // config['n_heads']
        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']
        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'


class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                d_embd,
                d_embd * 4,
                kwargs['n_heads'],
                layer_number=i+1,
                # Optional, for speedups
                fuse_qkv_params=True,
                attn_input_format='bshd'
            )
            for i in range(n_layers)
        )
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)
        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)
        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying
        return logits_BTV


def train(
    gpu_id: int = 0,
    bsz: int = 14,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }
    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)
    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)
    flops_promised = 2600e12
    model.train()

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                yield

    with ctx():
        for step_idx in range(100):
            input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()
            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised
            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s MFU={mfu:.2%}')


if __name__ == '__main__':
    import fire
    fire.Fire(train)
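As a rough sanity check on the numbers the reprod script prints, its FLOPs estimate and MFU arithmetic can be reproduced without a GPU. This is only a sketch: the parameter count of 1.5e9 is a rounded assumption (the script counts parameters from the live model), the 0.5 s step time is a made-up illustrative input rather than a measurement, and the helper names are hypothetical.

```python
def estimate_flops_per_token(n_params, n_layers, n_heads, d_embd, max_seq_len):
    """Same formula as the reprod script: 6*N plus an attention term."""
    head_dim = d_embd // n_heads
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * max_seq_len


def tflops_and_mfu(flops_per_token, bsz, seq_len, step_seconds, peak_flops):
    """Convert one timed step into TFLOP/s and model FLOPs utilization."""
    flops_per_iter = flops_per_token * bsz * seq_len
    flops_per_sec = flops_per_iter / step_seconds
    return flops_per_sec / 1e12, flops_per_sec / peak_flops


# GPT-2 XL shape from the script; n_params is an assumed round number.
flops_per_token = estimate_flops_per_token(
    n_params=1_500_000_000, n_layers=48, n_heads=25, d_embd=1600, max_seq_len=1024
)

# step_seconds is hypothetical; peak_flops matches flops_promised in the script.
tflops, mfu = tflops_and_mfu(
    flops_per_token, bsz=14, seq_len=1024, step_seconds=0.5, peak_flops=2600e12
)
print(f"{tflops:.2f} TFLOP/s MFU={mfu:.2%}")
```

Because the denominator is the 2600e12 value hard-coded as `flops_promised`, MFU figures from this script are only comparable across machines if that constant is changed to match each GPU's peak.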
Hi, it seems like we have similar results. I am getting 174 TFLOP/s/GPU on MI300X using the recommended hipBLASLt backend. Unfortunately, that is still slower than BF16 on MI300X. The system setup is already provided in this GitHub issue. Is there any specific information that you would like to know?
I found a bug in use_fused_attention filtering and sent out PR#81 to fix this issue. With CK fused attention enabled, I can see our MI300X achieves 250 TFLOP/s with hipBLASLt, compared to 300 TFLOP/s using H100, if we set the batch size to 8. We will keep working on other optimizations to match the performance.
Hi @wangye805, thanks for the fix! 250 TFLOP/s is much better now. However, it is still about 100 TFLOP/s off from H100 at the same batch size (bsz=8), and about 2.5x slower than H200 at bsz=38. Here are my preliminary single-GPU H100 700W SXM results:
cc: @hliuca
Hi @wangye805, can we keep this issue open until single-GPU MI300X is able to match H100?
@OrenLeung I guess if @wangye805 increases bsz, the perf will increase too.
Currently, when both are at batch size 8, MI300X gets 250 TFLOP/s and H100 gets 345.41 TFLOP/s, so about a 30% difference. My intuition is that this 30% difference will stay the same as we increase the batch size for both.
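The size of this gap depends on which GPU is taken as the baseline. A quick check, using only the two throughput figures quoted above (the function name is purely illustrative):

```python
def relative_gap(slower, faster):
    """Return (fraction the slower one trails by, fraction the faster one leads by)."""
    slowdown = 1 - slower / faster
    speedup = faster / slower - 1
    return slowdown, speedup


# TFLOP/s at batch size 8, as quoted in the thread.
slowdown, speedup = relative_gap(slower=250.0, faster=345.41)
print(f"MI300X trails H100 by {slowdown:.1%}; H100 leads MI300X by {speedup:.1%}")
```

Measured against H100 the gap is a little under 30%, while measured against MI300X it is a little under 40%, so it helps to state the baseline when quoting a percentage.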
Hi @OrenLeung, I am not very sure about this. For many LLMs I have seen, when we keep increasing the workload, the gap gets smaller and smaller, and we can surpass it. Of course, that also depends on optimizations. We will keep working on this.
@OrenLeung reopened this issue until we can match H100
With hipBLASLt auto-tuning, I can get 370 TFLOP/s with batch size 50.
Thanks @wangye805, can you share the exact env flags and the values I should set them to in order to turn on hipBLASLt auto-tuning?
@OrenLeung You can use the following environment variables to turn on hipBLASLt tuning: TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000
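One way to apply these settings from Python is to build the environment explicitly and pass it to the training process. This is a sketch only: the two variable names and values are copied from the comment above, `python3 ./train.py` is the command from the problem description, and the helper functions are hypothetical names introduced for illustration.

```python
import os
import subprocess

# Values quoted in the thread; their exact semantics are not described there.
HIPBLASLT_TUNING_ENV = {
    "TE_HIPBLASLT_TUNING_RUN_COUNT": "20",
    "TE_HIPBLASLT_TUNING_ALGO_COUNT": "1000",
}


def build_tuning_env(base_env=None):
    """Return a copy of the environment with the tuning variables added."""
    env = dict(os.environ if base_env is None else base_env)
    env.update(HIPBLASLT_TUNING_ENV)
    return env


def run_with_tuning(cmd):
    """Launch cmd (a list of arguments) with the tuning variables set."""
    return subprocess.run(cmd, env=build_tuning_env(), check=True)


# Example (not executed here):
# run_with_tuning(["python3", "./train.py"])
```

Copying the environment instead of mutating `os.environ` keeps the tuning flags scoped to the launched process, which makes it easier to compare tuned and untuned runs from the same shell.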
@wangye805 At a giant batch size, H100 can get 490 TFLOP/s and H200 can get 630 TFLOP/s. Good improvement, but it seems like there is still a 120 TFLOP/s difference.
Thanks! Will definitely try that.
Problem Description
Hi AMD team,
When trying to do FP8 training on MI300X, it is extremely slow due to very high CPU overhead, which takes up more than 81% of the time. As you can see from the profile, most of the time is spent on the CPU, doing hipFree. On GPT-2 XL 1.5B, throughput is 22 TFLOP/s, which is 10x slower than MI300X BF16. For comparison, on H100, FP8 makes GPT-2 XL 1.5B training 1.3x faster than BF16, not slower.
The reprod script is attached below and can be run using
NVTE_FUSED_ATTN_CK=0 python3 ./train.py
cc: @hliuca
Steps to Reproduce
Versions
Install Instructions
Reprod GPT2 XL 1.5B Training
Operating System: Ubuntu
CPU: AMD CPU
GPU: AMD Instinct MI300X
ROCm Version: ROCm 6.2.0