Skip to content

Commit

Permalink
added get_autocast_dtype("mps") for attention bias
Browse files Browse the repository at this point in the history
  • Loading branch information
aman-17 committed Feb 5, 2025
1 parent bc8ee29 commit cb30dc7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
target_dtype = torch.get_autocast_cpu_dtype()
elif bias.device.type == "mps":
target_dtype = torch.float32
target_dtype = torch.get_autocast_dtype("mps")
if bias.dtype != target_dtype:
bias = bias.to(target_dtype)
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
Expand Down

0 comments on commit cb30dc7

Please sign in to comment.