defi jagged flash attention benchmark
Summary: added a benchmark for the defi_jagged_flash_attention kernel

Reviewed By: xuzhao9

Differential Revision: D55110675

fbshipit-source-id: 5e763f7717e754566f85b2d42c0d197afdac1ea2
chenyang78 authored and facebook-github-bot committed Apr 4, 2024
1 parent ef317e5 commit abf184d
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions torchbenchmark/util/triton_op.py
@@ -2,6 +2,7 @@
import numpy
from enum import Enum
import argparse
import random
import triton
import torch
import gc
@@ -305,6 +306,50 @@ def enable_fp16(self):
    self.example_inputs = input_cast(tensor_cond, tensor_action, self.example_inputs)


# A function copied from https://fburl.com/code/hdypvhjw, which generates offsets
# for jagged tensors with the given load_factor
def generate_offsets(
    self,
    batch_size: int,
    max_seq_len: int,
    load_factor: float,
    offsets_dtype: torch.dtype,
) -> torch.Tensor:
    total_length = int(batch_size * max_seq_len * load_factor)
    avg_length = total_length // batch_size
    std = avg_length // 3  # rather arbitrary, but likely reasonable
    lengths = [random.gauss(avg_length, std) for _ in range(batch_size)]
    lengths = [int(min(max_seq_len, max(L, 0))) for L in lengths]

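    # A load_factor of 1.0 is special-cased: every sequence is fully packed at max_seq_len.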
    if load_factor == 1.0:
        lengths = [max_seq_len] * batch_size

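    # The randomly sampled lengths rarely sum to exactly total_length; measure the gap
    # and fix it up below, visiting sequences in random order.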
    diff = sum(lengths) - total_length
    idx_and_lengths = list(enumerate(lengths))
    random.shuffle(idx_and_lengths)

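    # Trim or pad individual lengths (each kept within [0, max_seq_len]) until they
    # sum exactly to total_length.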
    for i, length in idx_and_lengths:
        if diff == 0:
            break
        elif diff > 0:
            delta = min(length, diff)
            lengths[i] -= delta
            diff -= delta
        else:
            delta = min(max_seq_len - length, -diff)
            lengths[i] += delta
            diff += delta

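    # Prefix-sum the lengths into offsets: offsets[i]:offsets[i + 1] delimits sequence i
    # in the flattened jagged tensor.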
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)

    return torch.tensor(
        offsets,
        dtype=offsets_dtype,
    )


def enable_channels_last(self):
    tensor_cond = lambda x: x.dim() == 4
    tensor_action = lambda x: x.to(memory_format=torch.channels_last)
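For readers unfamiliar with the jagged-tensor layout the new helper targets, the following standalone sketch (not part of the commit; the lengths and dimensions below are purely illustrative) shows how such an offsets tensor delimits variable-length sequences inside one flat values buffer:

import torch

# Hypothetical per-sequence lengths for a batch of 4 jagged sequences.
lengths = [3, 0, 5, 2]
embed_dim = 16

# Exclusive prefix sum: offsets[i]:offsets[i + 1] delimits sequence i.
offsets = torch.zeros(len(lengths) + 1, dtype=torch.int64)
offsets[1:] = torch.cumsum(torch.tensor(lengths), dim=0)
# offsets is now tensor([0, 3, 3, 8, 10])

# All sequences share one flat storage tensor of shape (sum(lengths), embed_dim).
values = torch.randn(int(offsets[-1]), embed_dim)

# Slicing with consecutive offsets recovers an individual sequence.
seq_2 = values[offsets[2] : offsets[3]]
assert seq_2.shape == (5, embed_dim)

Given how the diff builds the lengths, generate_offsets should return an offsets tensor of size batch_size + 1 whose final entry equals int(batch_size * max_seq_len * load_factor).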
