Skip to content

Commit

Permalink
[CrossEntropy] Fix Triton cross_entropy_loss illegal memory access (IMA) for >=2B elements
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Oct 24, 2023
1 parent 02ac572 commit c79de85
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flash_attn/ops/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def cross_entropy_fwd_kernel(
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
- logits_ptr = logits_ptr + row_idx * logits_row_stride
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
Expand Down Expand Up @@ -107,8 +107,8 @@ def cross_entropy_bwd_kernel(
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
- logits_ptr = logits_ptr + row_idx * logits_row_stride
- dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
if label_idx != ignored_index:
Expand Down

0 comments on commit c79de85

Please sign in to comment.