
HF Qwen 2 with Thunder returns a slightly different loss function output #1407

Open · IvanYashchuk opened this issue Nov 7, 2024 · 1 comment
Labels: high priority, huggingface (For supporting HF models)

IvanYashchuk commented Nov 7, 2024

🐛 Bug

We need to determine whether Thunder has real accuracy problems computing HF's Qwen 2 model.

The test added in #1406 might fail because the loss computed by the Thunder-generated function is slightly different from HF's implementation. Here's the snippet to reproduce the problem:

import torch
from thunder.dynamo import ThunderCompiler
from transformers import Qwen2Config, Qwen2ForCausalLM
torch.manual_seed(0)

# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
configuration = Qwen2Config(
    # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default
    # config uses Multi-Head Attention
    num_attention_heads=28,
    num_key_value_heads=4,
    # Scaled down for testing
    hidden_size=56,
    vocab_size=2,
    max_position_embeddings=32,
)
configuration.num_hidden_layers = 1
with torch.device("cuda"):
    model = Qwen2ForCausalLM(configuration).to(torch.bfloat16)

# thunder.jit doesn't work with Qwen2, so we use torch.compile
# https://github.com/Lightning-AI/lightning-thunder/issues/1405
backend = ThunderCompiler()
compiled_model = torch.compile(model, backend=backend, fullgraph=True)

input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
# input_ids = torch.ones_like(input_ids) * 0
ref_output = model(input_ids=input_ids, labels=input_ids)
ref_loss = ref_output.loss

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_loss = compiled_output.loss
torch.testing.assert_close(compiled_loss, ref_loss)
The assertion fails with:

AssertionError: Scalars are not close!

Expected 0.7005462646484375 but got 0.7004587650299072.
Absolute difference: 8.749961853027344e-05 (up to 1e-05 allowed)
Relative difference: 0.00012490198427392128 (up to 1.3e-06 allowed)

Thunder may return a different result because its upcasting and downcasting around bf16 differ from HF's implementation. However, we need to confirm that Thunder's result is actually more accurate by comparing both losses against an fp64 reference; if it is, the tolerances in the test may simply need to be tweaked.
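One way to check this is a minimal sketch reusing the reproducer above: upcast a copy of the model to fp64, recompute the loss on the same inputs, and see which bf16 loss is closer. The fp64 copy and the printout below are illustrative only, not part of the test added in #1406.

import copy

# Upcast a copy of the bf16 model to fp64 and recompute the loss on the same
# inputs, so both bf16 results are compared against one higher-precision reference.
model_fp64 = copy.deepcopy(model).to(torch.float64)
fp64_loss = model_fp64(input_ids=input_ids, labels=input_ids).loss

# Whichever bf16 loss is closer to the fp64 loss is the more accurate one.
print("eager bf16 vs fp64:  ", (ref_loss.double() - fp64_loss).abs().item())
print("thunder bf16 vs fp64:", (compiled_loss.double() - fp64_loss).abs().item())

If Thunder's loss turns out to be closer to the fp64 reference, loosening the test tolerances (e.g. passing explicit rtol/atol to torch.testing.assert_close) would be the appropriate fix rather than changing Thunder.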

cc @apaz-cli

IvanYashchuk added the huggingface (For supporting HF models) label on Nov 7, 2024
IvanYashchuk (Collaborator, Author) commented:

The accuracy problem could be related to #1338.
