add workflow & precommit (#16)
1. update dependencies: use torch>=2.0.0 & triton>=2.2.0 (although the torch release still requires triton 2.1.0)
2. add pre-commit & use ruff
iclementine authored Jan 31, 2024
1 parent 14f5f1a commit 548382d
Showing 10 changed files with 203 additions and 108 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/code-check.yml
@@ -0,0 +1,52 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python application

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

permissions:
  contents: read

jobs:
  build:
    runs-on: self-hosted
    steps:
    - uses: actions/checkout@v4
      with:
        ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}

    # - name: Set up Python 3.10
    #   uses: actions/setup-python@v3
    #   with:
    #     python-version: "3.10"

    # - name: Install dependencies
    #   run: |
    #     python -m pip install --upgrade pip
    #     pip install flake8 pytest
    #     if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    #     pip install .

    - name: Activate Virtualenv
      run: |
        source /home/flagattn_ci/.virtualenvs/release/bin/activate
        echo PATH=$PATH >> $GITHUB_ENV
    - name: Editable Install
      run: |
        pip install --no-dependencies -e .
    # - name: Lint with flake8
    #   run: |
    #     # stop the build if there are Python syntax errors or undefined names
    #     flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
    #     # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
    #     flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest tests
41 changes: 41 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,41 @@
files: '^src/.*'
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-ast
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-executables-have-shebangs
      - id: check-shebang-scripts-are-executable
      - id: detect-private-key
      - id: debug-statements

  # - repo: https://github.com/google/yapf
  #   rev: v0.40.2
  #   hooks:
  #     - id: yapf
  #       args: ["-p", "-i"]
  #       stages: [commit, push, manual]

  # - repo: https://github.com/pylint-dev/pylint
  #   rev: v3.0.3
  #   hooks:
  #     - id: pylint

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.14
    hooks:
      - id: ruff
        args: ["--fix"]
        stages: [commit, push, manual]
      # - id: ruff-format
      #   stages: [commit, push, manual]
11 changes: 9 additions & 2 deletions pyproject.toml
@@ -13,8 +13,12 @@ classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]

# Not specifying the triton version here because torch has its own required triton version
# FlagAttention needs a recent version of triton (triton nightly or 2.2.0) to run.
dependencies = [
"torch>=2.1.0",
"torch>=2.0.0",
"triton"
]

[project.optional-dependencies]
@@ -38,7 +42,7 @@ where = ["src"]
include = ["flag_attn"]
namespaces = false

# helps for setting up pytest in pyprojects
# https://docs.pytest.org/en/7.3.x/reference/customize.html#rootdir
# https://docs.pytest.org/en/7.3.x/reference/reference.html#confval-pythonpath
[tool.pytest.ini_options]
@@ -50,3 +54,6 @@ pythonpath = [
"tests/flag_attn",
]

[tool.ruff]
ignore = ["E741"]
line-length = 120
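
Aside on the new dependency comment above: torch pins its own triton build, while FlagAttention itself wants triton 2.2.0 or a nightly. As a rough, hypothetical illustration (not part of this commit), a project depending on FlagAttention could check the installed triton at startup; the _check_triton helper and the use of `packaging` below are assumptions, not repository code.

# Hypothetical sketch, not from this commit: fail fast if the installed triton
# is older than what FlagAttention expects (2.2.0 or a nightly build).
import triton
from packaging.version import InvalidVersion, Version  # assumes `packaging` is installed


def _check_triton(min_version: str = "2.2.0") -> None:
    try:
        installed = Version(triton.__version__)
    except InvalidVersion:
        return  # nightly builds may use non-PEP 440 version strings; skip the check
    if installed < Version(min_version):
        raise RuntimeError(
            f"FlagAttention expects triton>={min_version} (or a nightly), "
            f"found {triton.__version__}"
        )


_check_triton()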
7 changes: 3 additions & 4 deletions src/flag_attn/__init__.py
@@ -6,8 +6,7 @@
version_tuple = (0, 0, 0)


from flag_attn.piecewise import attention as piecewise_attention
from flag_attn.flash import attention as flash_attention

from flag_attn import testing
from flag_attn.piecewise import attention as piecewise_attention # noqa: F401
from flag_attn.flash import attention as flash_attention # noqa: F401

from flag_attn import testing # noqa: F401
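
The re-exports above (kept with `# noqa: F401` so ruff does not flag them as unused) are the package's public entry points. A minimal usage sketch follows, based only on those exports and the `attention(q, k, v, causal=False, sm_scale=None, ...)` signature visible in the flash.py diff below; the (batch, heads, seqlen, headdim) layout, float16 dtype, and shapes are illustrative assumptions.

# Minimal usage sketch (not from the repository).
import torch
from flag_attn import flash_attention

B, H, M, D = 2, 4, 1024, 64  # assumed (batch, heads, seqlen, headdim) layout
q = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)

o = flash_attention(q, k, v, causal=True)  # sm_scale defaults to 1/sqrt(D)
# return_log_normalizer / return_total_attention switch on extra outputs;
# see the flash.py diff below for the full keyword list.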
62 changes: 31 additions & 31 deletions src/flag_attn/flash.py
@@ -21,7 +21,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_

if sm_scale is None:
sm_scale = 1. / math.sqrt(D)

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
with torch.cuda.device(device):
@@ -42,21 +42,21 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
B, H, M, N, P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_warps=num_warps, num_stages=num_stages,
)

if return_total_attention:
tot_attn = torch.empty((B, H, N), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(N, BLOCK_N), H, B)
_total_attention_kernel[grid](
q, k, L, tot_attn, sm_scale,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
B, H, M, N, P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
CAUSAL=causal,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_stages=num_stages, num_warps=num_warps,
@@ -91,7 +91,7 @@ def backward(ctx, do, *ignored):

if sm_scale is None:
sm_scale = 1. / math.sqrt(D)

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
with torch.cuda.device(device):
@@ -118,7 +118,7 @@ def backward(ctx, do, *ignored):
dv = torch.empty_like(v)
grid = (triton.cdiv(N, BLOCK_N), H, B)
_bwd_kv_kernel[grid](
q, k, v, sm_scale, do,
dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
@@ -136,7 +136,7 @@ def backward(ctx, do, *ignored):
dq = torch.zeros_like(q)
grid = (triton.cdiv(M, BLOCK_M), H, B)
_bwd_q_kernel[grid](
q, k, v, sm_scale, do,
dq,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
@@ -145,7 +145,7 @@ def backward(ctx, do, *ignored):
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
B, H, M, N, P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_stages=num_stages, num_warps = num_warps,
@@ -154,7 +154,7 @@ def backward(ctx, do, *ignored):
return dq, dk, dv, None, None, None, None


def attention(q, k, v, causal=False, sm_scale=None,
return_log_normalizer=False, return_total_attention=False,
):
"""
@@ -183,7 +183,7 @@ def attention(q, k, v, causal=False, sm_scale=None,
# --------------------------- Forward ---------------------------
# NOTE: this function can be overwritten at runtime to use your custom config
def get_fwd_config(B, H, M, N, D, causal):
if torch.cuda.get_device_capability() == (8, 0):
if not causal:
if D <= 64:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
@@ -225,7 +225,7 @@ def _fwd_kernel(
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_ok,
Z, H, M, N, P_SEQ,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
):
@@ -253,7 +253,7 @@ def _fwd_kernel(
offs_n_base = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)

# initialize pointers to value-like data
q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
l_ptrs = L + offs_m
@@ -269,7 +269,7 @@ def _fwd_kernel(
else:
mask_m = offs_m < M
q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")

#Dot I trick: to place q in registers, it saves shared memory
if BLOCK_DMODEL < 128:
I = tl.where(offs_k[:, None] == offs_k,
@@ -285,10 +285,10 @@ def _fwd_kernel(
# NOTE: Loop-Bound-For-N
# The indices in m-dimension that this block may access are in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
# According to the rule of causal masking, the max index in n-dimension that this block may access
# is `P_SEQ + (start_m + 1) * BLOCK_M`.
# However, the upper bound of the index in n-dimension should never exceed the sequence length of k/v (`P_SEQ + N_CTX`).
# `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
# In this case, there would be illegal memory access when loading k & v tiles
# if mask_n is not applied for loading (only when `DIVISIBLE_N` is true).
# See also https://github.com/FlagOpen/FlagAttention/pull/8
if IS_CAUSAL:
@@ -305,7 +305,7 @@ def _fwd_kernel(
for start_n in range(0, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
offs_n = start_n + offs_n_base

# -- load k, v --
if DIVISIBLE_N:
k = tl.load(k_ptrs, cache_modifier=".cg")
@@ -318,7 +318,7 @@ def _fwd_kernel(
# -- compute qk ---
s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
s += tl.dot(q, k)

if not DIVISIBLE_N:
s = tl.where(mask_n[None, :], s, float("-inf"))
if IS_CAUSAL:
@@ -349,7 +349,7 @@ def _fwd_kernel(
else:
acc = acc * (1.0 / l_i[:, None])
l = m_i * sm_scale + tl.log(l_i) # log(normalizer)

if DIVISIBLE_M:
tl.store(l_ptrs, l, cache_modifier=".cg")
tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
@@ -406,7 +406,7 @@ def _bwd_preprocess(
Delta += off_z * stride_dz + off_h * stride_dh

# compute (Out * Dout).sum() for vector interpretation
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)

# load
@@ -420,7 +420,7 @@ def _bwd_preprocess(
mask_m = off_m < M
o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)

# compute
delta = tl.sum(o * do, axis=1)
# write-back
@@ -480,8 +480,8 @@ def _bwd_kv_kernel(
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m_base = tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_DMODEL)

# initialize pointers to value-like data
q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
@@ -502,7 +502,7 @@ def _bwd_kv_kernel(
# initialize dk and dv
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

# loop over a col
for start_m in range(lo, M, BLOCK_M):
start_m = tl.multiple_of(start_m, BLOCK_M)
@@ -599,7 +599,7 @@ def _bwd_q_kernel(
start_m = tl.program_id(0)
off_h = tl.program_id(1)
off_z = tl.program_id(2)

# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
@@ -622,7 +622,7 @@ def _bwd_q_kernel(
offs_n_init = offs_n_base
offs_k = tl.arange(0, BLOCK_DMODEL)

# initialize pointers to value-like data
q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
@@ -647,7 +647,7 @@ def _bwd_q_kernel(
delta = tl.load(d_ptrs, mask=mask_m)
l = tl.load(l_ptrs, mask=mask_m)

# initialize dq
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

# loop over k, v and update accumulator
Expand All @@ -662,7 +662,7 @@ def _bwd_q_kernel(
# loop over a row
for start_n in range(0, hi, BLOCK_N):
offs_n = start_n + offs_n_base

# load k1, k2, v on chip
if DIVISIBLE_N:
v = tl.load(v_ptrs)
@@ -675,7 +675,7 @@ def _bwd_q_kernel(

# recompute p = softmax(qk * sm_scale, dim=-1)
if not DIVISIBLE_N:
valid_mask = mask_n # & mask_m[:, None]
if CAUSAL:
causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -713,7 +713,7 @@ def _bwd_q_kernel(
# increment pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vn

dq *= sm_scale
if DIVISIBLE_M:
tl.store(dq_ptrs, dq.to(input_dtype))
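
One more aside: the `# NOTE: this function can be overwritten at runtime to use your custom config` comment above `get_fwd_config` in the diff suggests a simple tuning hook. A hedged sketch of what that override might look like; the illustrative block sizes are assumptions, and the tuple layout (BLOCK_M, BLOCK_N, num_stages, num_warps) follows the assignments visible above.

# Hypothetical runtime override of the forward-pass config heuristic, as hinted
# by the NOTE above get_fwd_config; the values below are placeholders, not tuned.
from flag_attn import flash

_default_fwd_config = flash.get_fwd_config  # keep the library heuristic around


def my_fwd_config(B, H, M, N, D, causal):
    if not causal and D <= 64:
        return 128, 64, 3, 4  # (BLOCK_M, BLOCK_N, num_stages, num_warps)
    return _default_fwd_config(B, H, M, N, D, causal)


flash.get_fwd_config = my_fwd_config  # subsequent forward calls should pick this up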