Rework T5 RPE test
Refactor and add a default torch implementation against which we allclose.
Set sizes to known-good values that pass the checks; it is easy to fall off the cliff with various size combinations.

Additionally, after installing with the following, one can remove the inplace hack:
```
pip install -r pytorch-rocm-requirements.txt -e .
```
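For context, the check the reworked test performs can be sketched in plain torch as follows. This is a standalone float32 simplification (the test itself runs float16 inputs on device through the TKW kernels, using `device_randn`/`device_zeros`); tensor names and sizes are illustrative, mirroring the shapes and the `t5_rpe_masked_cond` helper in the diff below.
```
import math
import torch
from torch.testing import assert_close

# Illustrative shapes mirroring the test: (heads, seq_len, head_dim), all 128.
B, M, K1, K2, N = 128, 128, 128, 128, 128
q = torch.randn(B, M, K1)
k = torch.randn(B, K2, K1)
v = torch.randn(B, K2, N)
dk_sqrt = math.sqrt(1.0 / K1)

# T5 RPE bias table, zeroed at the clip boundaries, scattered into a
# (seq, seq) matrix the same way t5_rpe_masked_cond does below.
max_context_length = 33
rpe = torch.randn(max_context_length + 2)
rpe[0] = 0
rpe[max_context_length + 1] = 0
pos = torch.arange(M)
pos_diff = pos.unsqueeze(1) - pos.unsqueeze(0)
mask = (pos_diff >= 0) & (pos_diff <= max_context_length)
rpe_cond = torch.zeros(M, K2)
rpe_cond[mask] = rpe[pos_diff[mask]]

# Plain attention, cross-checked against torch's SDPA reference.
a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt
torch_attention = torch.matmul(torch.softmax(a, dim=-1), v)
torch_attention_ref = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None)
assert_close(torch_attention, torch_attention_ref, atol=2e-3, rtol=2e-3)

# RPE-biased attention: add the bias to the logits pre-softmax.
torch_attention_with_rpe = torch.matmul(
    torch.softmax(a + rpe_cond.unsqueeze(0), dim=-1), v)

# The test compares this delta against the delta between the two TKW kernels.
torch_rpe_delta = torch_attention_with_rpe - torch_attention
```
As in the diff, only the delta between the RPE and non-RPE outputs is compared, with atol/rtol of 2e-3.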
nicolasvasilache committed Feb 3, 2025
1 parent 48aa3ec commit c75e45c
Showing 2 changed files with 137 additions and 72 deletions.
16 changes: 4 additions & 12 deletions iree/turbine/kernel/wave/utils.py
@@ -663,18 +663,10 @@ def invoke_vmfb(
if not (run or run_bench):
return

# TODO: the following crashes with:
#
# File "/home/nico/dev/iree-turbine/iree/turbine/kernel/wave/utils.py",
# line 550, in get_device_uuid
# uuid = str(torch.cuda.get_device_properties(device).uuid)
# AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute
# 'uuid'. Hack it out for now.
#
# if inplace:
# # Select device as the GPU, where input tensors are coming from.
# device_uuid = get_device_uuid(kernel_inputs + kernel_outputs)
# device = f"{device}://GPU-{device_uuid}"
if inplace:
# Select device as the GPU, where input tensors are coming from.
device_uuid = get_device_uuid(kernel_inputs + kernel_outputs)
device = f"{device}://GPU-{device_uuid}"
rt_config = rt.Config(device)
device = rt_config.device
vm_instance = rt_config.vm_instance
193 changes: 133 additions & 60 deletions playground/vanilla_attention.py
@@ -4,66 +4,34 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import math
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.nn import functional as F
from torch.testing import assert_close
from typing import Any, Callable

from iree.turbine.kernel.gen import TestLaunchContext
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
from iree.turbine.kernel.wave.templates.vanilla_attention import (
get_vanilla_attention_kernel as get_vanilla_attention_kernel_reference
get_vanilla_attention_kernel as get_vanilla_attention_kernel_reference)
from iree.turbine.kernel.wave.utils import (
device_randn,
device_zeros,
get_default_run_config,
to_default_device,
)
from iree.turbine.kernel.wave.utils import (device_randn, device_zeros,
get_default_run_config)
from vanilla_attention_template import get_vanilla_attention_kernel

torch.manual_seed(0)
torch.set_printoptions(linewidth=10000000)

# num_query_heads, num_kv_heads, head_size, head_size_kv
shape = AttentionShape(1, 128, 8, 64)
shape.query_seq_len = 64
shape.kv_seq_len = 64

# T5 RPE parameter
max_context_length = 24

base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \
get_vanilla_attention_kernel(
shape,
mfma_variant=[MMAType.F32_16x16x16_F16,
MMAType.F32_16x16x16_F16],
dynamic_dims=False,
max_context_length = max_context_length + 2)

base_attention_reference, _, _, _ = \
get_vanilla_attention_kernel_reference(
shape,
mfma_variant=[MMAType.F32_16x16x16_F16,
MMAType.F32_16x16x16_F16],
dynamic_dims=False)


vB = shape.num_query_heads
vM = int(shape.query_seq_len)
vN = shape.head_size_kv
vK1 = shape.head_size
vK2 = int(shape.kv_seq_len)
q = device_randn(vB, vM, vK1, dtype=torch.float16)
k = device_randn(vB, vK2, vK1, dtype=torch.float16)
v = device_randn(vB, vN, vK2, dtype=torch.float16)
output = device_zeros(vB, vM, vN, dtype=torch.float32)
output_reference = device_zeros(vB, vM, vN, dtype=torch.float32)
output_reference_2 = device_zeros(vB, vM, vN, dtype=torch.float32)

# Applied pre-softmax on the MMA'ed result so f32.
# Provision more room for clipping and adding 0 at the boundaries.
rpe = device_randn(max_context_length + 2, dtype=torch.float32)
rpe[0] = 0
rpe[max_context_length + 1] = 0

torch.set_printoptions(
linewidth=1000000,
threshold=1000000,
precision=3,
)

### TKW Harness
def run(fun: Callable, hparams, *args) -> Any:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
@@ -83,23 +51,128 @@ def run(fun: Callable, hparams, *args) -> Any:
prof.key_averages(group_by_input_shape=True).table(
sort_by="self_cuda_time_total", row_limit=10))

#################################################################################
# INIT VALS
#################################################################################
# num_query_heads, num_kv_heads, head_size, head_size_kv
shape = AttentionShape(128, 128, 128, 128)
shape.query_seq_len = 128
shape.kv_seq_len = 128

assert shape.num_query_heads == shape.num_kv_heads, \
"expected query and kv to have the same number of heads!"

q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size)
k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size)
v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv)
o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv)

q = device_randn(q_shape, dtype=torch.float16)
k = device_randn(k_shape, dtype=torch.float16)
v = device_randn(v_shape, dtype=torch.float16)
tkw_attention = device_zeros(o_shape, dtype=torch.float32)
tkw_attention_with_rpe = device_zeros(o_shape, dtype=torch.float32)

log2e = 1.44269504089
dk_sqrt = math.sqrt(1.0 / q.shape[-1])

#################################################################################
# T5 RPE INIT VALS
#################################################################################
# T5 RPE parameter
max_context_length = 33

# Applied pre-softmax on the MMA'ed result so f32.
# Provision more room for clipping and adding 0 at the boundaries.
rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32)
rpe = rpe[:max_context_length + 2].view(max_context_length + 2)
rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32))
rpe[0] = 0
rpe[max_context_length + 1] = 0

def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int,
dtype):
positions = to_default_device(torch.arange(sequence_length))
pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
mask = to_default_device((pos_diff >= 0)
& (pos_diff <= max_context_length))
rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype)
rpe_cond[mask] = rpe[pos_diff[mask]]
return rpe_cond

# rpe_cond is used by torch only
rpe_cond = t5_rpe_masked_cond(rpe,
max_context_length=max_context_length,
sequence_length=shape.kv_seq_len,
dtype=tkw_attention_with_rpe.dtype)

#################################################################################
# TORCH ATTENTION and ATTENTION + RPE
#################################################################################
torch_attention_ref = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None
)

a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt
torch_attention = torch.matmul(torch.softmax(a, dim=-1), v)

assert_close(torch_attention, torch_attention_ref, atol=2e-3, rtol=2e-3)

a += rpe_cond.unsqueeze(0)
torch_attention_with_rpe = torch.matmul(F.softmax(a, dim=-1), v)

torch_rpe_delta = torch_attention_with_rpe - torch_attention

#################################################################################
# TKW BASE ATTENTION
#################################################################################
### Reference version
# base_attention_reference, hyperparams, dynamic_symbols, dynamic_symbols_map = \
# get_vanilla_attention_kernel_reference(
# shape,
# mfma_variant=[MMAType.F32_16x16x16_F16,
# MMAType.F32_16x16x16_F16],
# dynamic_dims=False)


# def attention_reference(tq, tk, tv, toutput):
# base_attention_reference(tq, tk, tv, toutput)

# run(attention_reference, hyperparams, q * dk_sqrt * log2e, k,
# v.permute([0, 2, 1]), tkw_attention)

# print(torch_attention)
# print(tkw_attention)

# assert_close(torch_attention.to(dtype=tkw_attention.dtype),
# tkw_attention,
# atol=2e-3,
# rtol=2e-3)

### RPE version
rpe_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \
get_vanilla_attention_kernel(
shape,
mfma_variant=[MMAType.F32_16x16x16_F16,
MMAType.F32_16x16x16_F16],
dynamic_dims=False,
max_context_length = max_context_length + 2)


def attention_with_rpe(tq, tk, tv, trpe, toutput):
# Print IR if needed.
# print(base_attention(q, k, v, rpe, output).module_op)
base_attention(tq, tk, tv, trpe, toutput)

mb = rpe_attention(tq, tk, tv, trpe, toutput)
print(mb.module_op)

def attention_reference(tq, tk, tv, toutput):
base_attention_reference(tq, tk, tv, toutput)
run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k,
v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe)

tkw_rpe_delta = tkw_attention_with_rpe - tkw_attention

run(attention_with_rpe, hyperparams, q, k, v, rpe, output)
run(attention_reference, hyperparams, q, k, v, output_reference)
run(attention_reference, hyperparams, q, k, v, output_reference_2)
print(torch_rpe_delta)
print(tkw_rpe_delta)

print(f"\n\nreference:\n{output_reference.cpu()[0]}")
print(f"RPE:\n{rpe.cpu()}")
print(f"ATTENTION RPE:\n{output.cpu()[0]}")
print(f"delta:\n{(output - output_reference).cpu()[0]}")
print(f"truth sanity check should be zero:\n{(output_reference - output_reference_2).cpu()[0]}")
assert_close(torch_rpe_delta.to(dtype=tkw_rpe_delta.dtype),
tkw_rpe_delta,
atol=2e-3,
rtol=2e-3)
