

Fix dtype mismatch in fused_linear_cross_entropy_forward #307

Open · wants to merge 3 commits into base: main
Conversation


@kostum123 commented Oct 12, 2024

Fixes #305

Fix dtype mismatch in fused_linear_cross_entropy_forward function.

  • Cast logits_chunk to the data type of _input_chunk before performing operations on it.

I tested this in Colab after the change and it solved the problem.

```json
{
  "epoch": 1.0,
  "eval_loss": 1.885668396949768,
  "eval_runtime": 0.1708,
  "eval_samples_per_second": 5.856,
  "eval_steps_per_second": 5.856,
  "total_flos": 1766475165597696.0,
  "train_loss": 1.9928909236309575,
  "train_runtime": 115.5799,
  "train_samples_per_second": 0.441,
  "train_steps_per_second": 0.441
}
```

For more details, open the Copilot Workspace session.

@yundai424 (Collaborator) left a comment


The logit chunk has to stay at fp32 while computing the forward & backward w.r.t. the CE loss, to ensure numerical stability and consistency with the HF model code (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2/modeling_qwen2.py#L1187). I think the issue here is that after the CE computation is done and the logit has been cast back to the torch autocast dtype (line 102 in https://github.com/linkedin/Liger-Kernel/blob/v0.3.0/src/liger_kernel/ops/fused_linear_cross_entropy.py#L122), there is somehow a mismatch between `_input.dtype` and the inferred autocast dtype 🤔 we might need a different solution here. Will think about it

Successfully merging this pull request may close these issues.

RuntimeError due to dtype mismatch in fused_linear_cross_entropy_forward
3 participants