From 548382d704db9310e7613a542bb81efc96ef3890 Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Wed, 31 Jan 2024 14:47:40 +0800 Subject: [PATCH] add workflow & precommit (#16) 1. update dependencies: use torch>=2.0.0 & triton>=2.2.0(although torch release still requires triton 2.1.0) 2. add pre-commit & use ruff --- .github/workflows/code-check.yml | 52 ++++++++++++++++ .pre-commit-config.yaml | 41 +++++++++++++ pyproject.toml | 11 +++- src/flag_attn/__init__.py | 7 +-- src/flag_attn/flash.py | 62 +++++++++---------- src/flag_attn/piecewise.py | 96 +++++++++++++++--------------- src/flag_attn/testing/__init__.py | 4 +- src/flag_attn/testing/flash.py | 7 +-- src/flag_attn/testing/piecewise.py | 8 +-- src/flag_attn/total.py | 23 ++++--- 10 files changed, 203 insertions(+), 108 deletions(-) create mode 100644 .github/workflows/code-check.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml new file mode 100644 index 0000000..0c553fa --- /dev/null +++ b/.github/workflows/code-check.yml @@ -0,0 +1,52 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python application + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + with: + ssh-key: ${{ secrets.SSH_PRIVATE_KEY }} + + # - name: Set up Python 3.10 + # uses: actions/setup-python@v3 + # with: + # python-version: "3.10" + + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest + # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + # pip install . + + - name: Activate Virtualenv + run: | + source /home/flagattn_ci/.virtualenvs/release/bin/activate + echo PATH=$PATH >> $GITHUB_ENV + + - name: Editable Install + run: | + pip install --no-dependencies -e . + + # - name: Lint with flake8 + # run: | + # # stop the build if there are Python syntax errors or undefined names + # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest tests \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2ebef37 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,41 @@ +files: '^src/.*' +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-ast + - id: check-added-large-files + - id: check-merge-conflict + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable + - id: detect-private-key + - id: debug-statements + + # - repo: https://github.com/google/yapf + # rev: v0.40.2 + # hooks: + # - id: yapf + # args: ["-p", "-i"] + # stages: [commit, push, manual] + + # - repo: https://github.com/pylint-dev/pylint + # rev: v3.0.3 + # hooks: + # - id: pylint + + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.14 + hooks: + - id: ruff + args: ["--fix"] + stages: [commit, push, manual] + # - id: ruff-format + # stages: [commit, push, manual] + + + diff --git a/pyproject.toml b/pyproject.toml index 1c24bd1..cd80acc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,12 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] + +# Not specifing triton version here because torch has its own required triton version +# FlagAttention needs a recent version of triton (triton nightly or 2.2.0) to run. dependencies = [ - "torch>=2.1.0", + "torch>=2.0.0", + "triton" ] [project.optional-dependencies] @@ -38,7 +42,7 @@ where = ["src"] include = ["flag_attn"] namespaces = false -# helps for setting up pytest in pyprojects +# helps for setting up pytest in pyprojects # https://docs.pytest.org/en/7.3.x/reference/customize.html#rootdir # https://docs.pytest.org/en/7.3.x/reference/reference.html#confval-pythonpath [tool.pytest.ini_options] @@ -50,3 +54,6 @@ pythonpath = [ "tests/flag_attn", ] +[tool.ruff] +ignore = ["E741"] +line-length = 120 diff --git a/src/flag_attn/__init__.py b/src/flag_attn/__init__.py index 2f1883d..7556803 100644 --- a/src/flag_attn/__init__.py +++ b/src/flag_attn/__init__.py @@ -6,8 +6,7 @@ version_tuple = (0, 0, 0) -from flag_attn.piecewise import attention as piecewise_attention -from flag_attn.flash import attention as flash_attention - -from flag_attn import testing +from flag_attn.piecewise import attention as piecewise_attention # noqa: F401 +from flag_attn.flash import attention as flash_attention # noqa: F401 +from flag_attn import testing # noqa: F401 diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index dcd12c9..20958b7 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -21,7 +21,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ if sm_scale is None: sm_scale = 1. 
/ math.sqrt(D) - + # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) with torch.cuda.device(device): @@ -42,9 +42,9 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), B, H, M, N, P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, IS_CAUSAL=causal, LARGER_M=larger_m, - DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, + DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_warps=num_warps, num_stages=num_stages, ) @@ -52,11 +52,11 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ tot_attn = torch.empty((B, H, N), device=q.device, dtype=torch.float32) grid = (triton.cdiv(N, BLOCK_N), H, B) _total_attention_kernel[grid]( - q, k, L, tot_attn, sm_scale, + q, k, L, tot_attn, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), B, H, M, N, P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps=num_warps, @@ -91,7 +91,7 @@ def backward(ctx, do, *ignored): if sm_scale is None: sm_scale = 1. / math.sqrt(D) - + # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) with torch.cuda.device(device): @@ -118,7 +118,7 @@ def backward(ctx, do, *ignored): dv = torch.empty_like(v) grid = (triton.cdiv(N, BLOCK_N), H, B) _bwd_kv_kernel[grid]( - q, k, v, sm_scale, do, + q, k, v, sm_scale, do, dk, dv, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -136,7 +136,7 @@ def backward(ctx, do, *ignored): dq = torch.zeros_like(q) grid = (triton.cdiv(M, BLOCK_M), H, B) _bwd_q_kernel[grid]( - q, k, v, sm_scale, do, + q, k, v, sm_scale, do, dq, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -145,7 +145,7 @@ def backward(ctx, do, *ignored): do.stride(0), do.stride(1), do.stride(2), do.stride(3), dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), B, H, M, N, P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, LARGER_M=larger_m, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps = num_warps, @@ -154,7 +154,7 @@ def backward(ctx, do, *ignored): return dq, dk, dv, None, None, None, None -def attention(q, k, v, causal=False, sm_scale=None, +def attention(q, k, v, causal=False, sm_scale=None, return_log_normalizer=False, return_total_attention=False, ): """ @@ -183,7 +183,7 @@ def attention(q, k, v, causal=False, sm_scale=None, # --------------------------- Forward --------------------------- # NOTE: this function can be overwritten at runtime to use your custom config def get_fwd_config(B, H, M, N, D, causal): - if torch.cuda.get_device_capability() == (8, 0): + if torch.cuda.get_device_capability() == (8, 0): if not causal: if D <= 64: BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 @@ -225,7 +225,7 @@ def _fwd_kernel( stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_ok, Z, H, M, N, P_SEQ, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, 
DIVISIBLE_N: tl.constexpr, ): @@ -253,7 +253,7 @@ def _fwd_kernel( offs_n_base = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data + # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) l_ptrs = L + offs_m @@ -269,7 +269,7 @@ def _fwd_kernel( else: mask_m = offs_m < M q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") - + #Dot I trick: to place q in registers, it saves shared memory if BLOCK_DMODEL < 128: I = tl.where(offs_k[:, None] == offs_k, @@ -285,10 +285,10 @@ def _fwd_kernel( # NOTE: Loop-Bound-For-N # The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`. # According to the rule of causal masking, then max index in n-dimension that this block may access - # is `P_SEQ + (start_m + 1) * BLOCK_M`. - # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`). + # is `P_SEQ + (start_m + 1) * BLOCK_M`. + # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`). # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`. - # At this case, there would be illegal memory access when loading k & v tiles + # At this case, there would be illegal memory access when loading k & v tiles # if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true). # See also https://github.com/FlagOpen/FlagAttention/pull/8 if IS_CAUSAL: @@ -305,7 +305,7 @@ def _fwd_kernel( for start_n in range(0, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) offs_n = start_n + offs_n_base - + # -- load k, v -- if DIVISIBLE_N: k = tl.load(k_ptrs, cache_modifier=".cg") @@ -318,7 +318,7 @@ def _fwd_kernel( # -- compute qk --- s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) s += tl.dot(q, k) - + if not DIVISIBLE_N: s = tl.where(mask_n[None, :], s, float("-inf")) if IS_CAUSAL: @@ -349,7 +349,7 @@ def _fwd_kernel( else: acc = acc * (1.0 / l_i[:, None]) l = m_i * sm_scale + tl.log(l_i) # log(normalizer) - + if DIVISIBLE_M: tl.store(l_ptrs, l, cache_modifier=".cg") tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") @@ -406,7 +406,7 @@ def _bwd_preprocess( Delta += off_z * stride_dz + off_h * stride_dh # compute (Out * Dout).sum() for vector interpretation - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ -420,7 +420,7 @@ def _bwd_preprocess( mask_m = off_m < M o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) - + # compute delta = tl.sum(o * do, axis=1) # write-back @@ -480,8 +480,8 @@ def _bwd_kv_kernel( offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m_base = tl.arange(0, BLOCK_M) offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data + + # initialize pointers to value-like data q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) @@ -502,7 +502,7 @@ def _bwd_kv_kernel( # initialize dk amd dv dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], 
dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - + # loop over a col for start_m in range(lo, M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) @@ -599,7 +599,7 @@ def _bwd_q_kernel( start_m = tl.program_id(0) off_h = tl.program_id(1) off_z = tl.program_id(2) - + # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop @@ -622,7 +622,7 @@ def _bwd_q_kernel( offs_n_init = offs_n_base offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data + # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) @@ -647,7 +647,7 @@ def _bwd_q_kernel( delta = tl.load(d_ptrs, mask=mask_m) l = tl.load(l_ptrs, mask=mask_m) - # initialize dq + # initialize dq dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # loop over k, v and update accumulator @@ -662,7 +662,7 @@ def _bwd_q_kernel( # loop over a row for start_n in range(0, hi, BLOCK_N): offs_n = start_n + offs_n_base - + # load k1, k2, v on chip if DIVISIBLE_N: v = tl.load(v_ptrs) @@ -675,7 +675,7 @@ def _bwd_q_kernel( # recompute p = softmax(qk * sm_scale, dim=-1) if not DIVISIBLE_N: - valid_mask = mask_n # & mask_m[:, None] + valid_mask = mask_n # & mask_m[:, None] if CAUSAL: causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -713,7 +713,7 @@ def _bwd_q_kernel( # increment pointers k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vn - + dq *= sm_scale if DIVISIBLE_M: tl.store(dq_ptrs, dq.to(input_dtype)) diff --git a/src/flag_attn/piecewise.py b/src/flag_attn/piecewise.py index f6c8fbe..8ad01fc 100644 --- a/src/flag_attn/piecewise.py +++ b/src/flag_attn/piecewise.py @@ -2,13 +2,13 @@ Piecewise Attention ==================== -This is a extension to Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) that performs piecewise computation -of attention scores(The scores to which softmax is applied). This design originates from -the need to make better predictions when the predicted sequence is longer than sequences +This is a extension to Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) that performs piecewise computation +of attention scores(The scores to which softmax is applied). This design originates from +the need to make better predictions when the predicted sequence is longer than sequences in the training set. -It takes as input two q's and two k's as inputs. The attention score is the dot product +It takes as input two q's and two k's as inputs. The attention score is the dot product of (q1, k1) or (q2, k2) depending on whether the distance between q & k exceeds a threshold. The code is adapted from triton's [tutorial](https://github.com/openai/triton/blob/5162871c6cae01a8508a309cf21a8e6b68a4c091/python/tutorials/06-fused-attention.py). @@ -37,7 +37,7 @@ def forward(ctx, q1, k1, q2, k2, v, w, causal, sm_scale): if sm_scale is None: sm_scale = 1. 
/ math.sqrt(D) - + # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q1) with torch.cuda.device(device): @@ -50,7 +50,7 @@ def forward(ctx, q1, k1, q2, k2, v, w, causal, sm_scale): grid = (triton.cdiv(M, BLOCK_M), H, B) o = torch.empty_like(q1) L = torch.empty((B, H, M), device=q1.device, dtype=torch.float32) - + _fwd_kernel[grid]( q1, k1, q2, k2, v, sm_scale, L, @@ -116,7 +116,7 @@ def backward(ctx, do): dv = torch.empty_like(v) grid = (triton.cdiv(N, BLOCK_N), H, B) _bwd_kv_kernel[grid]( - q1, k1, q2, k2, v, sm_scale, do, + q1, k1, q2, k2, v, sm_scale, do, dk1,dk2, dv, L, delta, q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), @@ -128,7 +128,7 @@ def backward(ctx, do): dk1.stride(0), dk1.stride(1), dk1.stride(2), dk1.stride(3), dk2.stride(0), dk2.stride(1), dk2.stride(2), dk2.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), - B, H, M, N, P_SEQ, + B, H, M, N, P_SEQ, w=w, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, @@ -137,12 +137,12 @@ def backward(ctx, do): num_stages=num_stages, num_warps=num_warps, ) - + dq1 = torch.zeros_like(q1) dq2 = torch.zeros_like(q2) grid = (triton.cdiv(M, BLOCK_M), H, B) _bwd_q_kernel[grid]( - q1, k1, q2, k2, v, sm_scale, do, + q1, k1, q2, k2, v, sm_scale, do, dq1, dq2, L, delta, q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), @@ -153,7 +153,7 @@ def backward(ctx, do): do.stride(0), do.stride(1), do.stride(2), do.stride(3), dq1.stride(0), dq1.stride(1), dq1.stride(2), dq1.stride(3), dq2.stride(0), dq2.stride(1), dq2.stride(2), dq2.stride(3), - B, H, M, N, P_SEQ, + B, H, M, N, P_SEQ, w=w, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, @@ -170,9 +170,9 @@ def attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None): """ PiecewiseAttention - Piecewise deviates from standard scaled dot product attention in that takes - as inputs two q's and two k's as inputs. The attention score is dot product - of (q1, k1) or (q2, k2) depending on whether the distance between q & k + Piecewise deviates from standard scaled dot product attention in that takes + as inputs two q's and two k's as inputs. The attention score is dot product + of (q1, k1) or (q2, k2) depending on whether the distance between q & k exceeds a threshold. 
Arguments: @@ -193,7 +193,7 @@ def attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None): # --------------------------- Forward --------------------------- def get_fwd_config(B, H, M, N, D, causal): # A100 - if torch.cuda.get_device_capability() == (8, 0): + if torch.cuda.get_device_capability() == (8, 0): if not causal: if D <= 64: BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 @@ -243,7 +243,7 @@ def _fwd_kernel( start_m = tl.program_id(0) off_h = tl.program_id(1) off_z = tl.program_id(2) - + # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop @@ -264,7 +264,7 @@ def _fwd_kernel( offs_n_init = offs_n_base offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to v alue-like data + # initialize pointers to v alue-like data q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL) @@ -289,8 +289,8 @@ def _fwd_kernel( # Dot I trick: it converts q1, q2 into mma layout and saves shared memory # better way to generate a eye matrix. avoid casting from bool - I = tl.where(offs_k[:, None] == offs_k, - tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), + I = tl.where(offs_k[:, None] == offs_k, + tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype)) q1 = tl.dot(q1, I).to(input_dtype) q2 = tl.dot(q2, I).to(input_dtype) @@ -303,7 +303,7 @@ def _fwd_kernel( hi = tl.maximum(0, hi) else: hi = N - + for start_n in range(0, hi, BLOCK_N): # -- offsets & masking -- start_n = tl.multiple_of(start_n, BLOCK_N) @@ -325,8 +325,8 @@ def _fwd_kernel( s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # TODO: more careful masking - s += tl.where(piecewise_mask, - tl.dot(q2, tl.trans(k2)), + s += tl.where(piecewise_mask, + tl.dot(q2, tl.trans(k2)), tl.dot(q1, tl.trans(k1))) if not DIVISIBLE_N: s = tl.where(mask_n, s, float("-inf")) @@ -372,7 +372,7 @@ def _fwd_kernel( # --------------------------- Backward --------------------------- def get_bwd_config(B, H, M, N, D, causal): # A100 - if torch.cuda.get_device_capability() == (8, 0): + if torch.cuda.get_device_capability() == (8, 0): if not causal: if D <= 64: BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 @@ -383,7 +383,7 @@ def get_bwd_config(B, H, M, N, D, causal): BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 else: BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 4 - + # BLOCK_M = 64 if D<=64 else 128 # BLOCK_N = 64 # num_stages = 1 if D<=64 else (2 if not causal else 1) @@ -422,7 +422,7 @@ def _bwd_preprocess( Delta += off_z * stride_dz + off_h * stride_dh # compute (Out * Dout).sum() for vector interpretation - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ -436,7 +436,7 @@ def _bwd_preprocess( mask_m = off_m < M o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) - + # compute delta = tl.sum(o * do, axis=1) # write-back @@ -462,7 +462,7 @@ def _bwd_kv_kernel( stride_dk1z, stride_dk1h, stride_dk1n, stride_dk1k, stride_dk2z, stride_dk2h, stride_dk2n, stride_dk2k, stride_dvz, stride_dvh, stride_dvn, stride_dvk, - Z, 
H, M, N, P_SEQ, + Z, H, M, N, P_SEQ, w: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -503,9 +503,9 @@ def _bwd_kv_kernel( offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m_base = tl.arange(0, BLOCK_M) offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data + + # initialize pointers to value-like data q1_ptrs = Q1 + (offs_m_init[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) q2_ptrs = Q2 + (offs_m_init[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) k1_ptrs = K1 + (offs_k[:, None] * stride_k1k + offs_n[None, :] * stride_k1n) # (BLOCK_DMODEL, BLOCK_N) @@ -532,12 +532,12 @@ def _bwd_kv_kernel( dk1 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - + # loop over a column for start_m in range(lo, M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m = start_m + offs_m_base - + # load q1, k1, q2, k2, v, do on-chip if DIVISIBLE_M: q1 = tl.load(q1_ptrs) @@ -556,10 +556,10 @@ def _bwd_kv_kernel( # recompute p = softmax(qk, dim=-1).T piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N) s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - s += tl.where(piecewise_mask, - tl.dot(q2, k2), + s += tl.where(piecewise_mask, + tl.dot(q2, k2), tl.dot(q1, k1)) - + # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) # So masking on s is not needed. # if CAUSAL: @@ -576,7 +576,7 @@ def _bwd_kv_kernel( if CAUSAL: causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) p = tl.where(causal_mask, p, 0.0) - + # compute dv = dot(p, do) # do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) @@ -592,16 +592,16 @@ def _bwd_kv_kernel( # else: # dp = tl.where(valid_mask, dp, 0.0) - # compute ds = p * (dp - delta[:, None]) + # compute ds = p * (dp - delta[:, None]) # move scale out to dk at last ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) - + # mask ds To ensure no small values if not DIVISIBLE_M: ds = tl.where(valid_mask, ds, 0.0) if CAUSAL: ds = tl.where(causal_mask, ds, 0.0) - + ds2 = tl.where(piecewise_mask, ds, 0.0).to(input_dtype) ds1 = tl.where(piecewise_mask, 0.0, ds).to(input_dtype) @@ -641,7 +641,7 @@ def _bwd_q_kernel( stride_doz, stride_doh, stride_dom, stride_dok, stride_dq1z, stride_dq1h, stride_dq1m, stride_dq1k, stride_dq2z, stride_dq2h, stride_dq2m, stride_dq2k, - Z, H, M, N, P_SEQ, + Z, H, M, N, P_SEQ, w: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -653,7 +653,7 @@ def _bwd_q_kernel( start_m = tl.program_id(0) off_h = tl.program_id(1) off_z = tl.program_id(2) - + # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop @@ -679,7 +679,7 @@ def _bwd_q_kernel( offs_n_init = offs_n_base offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data + # initialize pointers to value-like data q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL) @@ -709,7 +709,7 @@ def _bwd_q_kernel( delta = tl.load(d_ptrs, mask=mask_m) l = 
tl.load(l_ptrs, mask=mask_m) - # initialize dq + # initialize dq dq1 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) dq2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -725,7 +725,7 @@ def _bwd_q_kernel( # loop over a row for start_n in range(0, hi, BLOCK_N): offs_n = start_n + offs_n_base - + # load k1, k2, v on chip if DIVISIBLE_N: v = tl.load(v_ptrs) @@ -737,11 +737,11 @@ def _bwd_q_kernel( k1 = tl.load(k1_ptrs, mask=mask_n[:, None]) k2 = tl.load(k2_ptrs, mask=mask_n[:, None]) - # recompute p = softmax(qk * sm_scale, dim=-1) + # recompute p = softmax(qk * sm_scale, dim=-1) piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N) s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - s += tl.where(piecewise_mask, - tl.dot(q2, tl.trans(k2)), + s += tl.where(piecewise_mask, + tl.dot(q2, tl.trans(k2)), tl.dot(q1, tl.trans(k1))) # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) # So masking on s is not needed. @@ -781,7 +781,7 @@ def _bwd_q_kernel( k1_ptrs += BLOCK_N * stride_k1n k2_ptrs += BLOCK_N * stride_k2n v_ptrs += BLOCK_N * stride_vn - + dq1 *= sm_scale dq2 *= sm_scale if DIVISIBLE_M: diff --git a/src/flag_attn/testing/__init__.py b/src/flag_attn/testing/__init__.py index 873d70b..780acb2 100644 --- a/src/flag_attn/testing/__init__.py +++ b/src/flag_attn/testing/__init__.py @@ -1,2 +1,2 @@ -from flag_attn.testing.flash import attention as flash_attention -from flag_attn.testing.piecewise import attention as piecewise_attention \ No newline at end of file +from flag_attn.testing.flash import attention as flash_attention # noqa: F401 +from flag_attn.testing.piecewise import attention as piecewise_attention # noqa: F401 diff --git a/src/flag_attn/testing/flash.py b/src/flag_attn/testing/flash.py index 7e1557e..7f272ed 100644 --- a/src/flag_attn/testing/flash.py +++ b/src/flag_attn/testing/flash.py @@ -1,9 +1,6 @@ import math import torch -import math -import torch - def attention(q, k, v, @@ -47,8 +44,8 @@ def attention(q, has_extra_return = return_log_normalizer or return_total_attention if has_extra_return: - outs = (attn_output, - log_normalizer if return_log_normalizer else None, + outs = (attn_output, + log_normalizer if return_log_normalizer else None, tot_attn if return_total_attention else None) return outs else: diff --git a/src/flag_attn/testing/piecewise.py b/src/flag_attn/testing/piecewise.py index 8e8f56f..4047773 100644 --- a/src/flag_attn/testing/piecewise.py +++ b/src/flag_attn/testing/piecewise.py @@ -17,12 +17,12 @@ def attention(q1, k1, q2, k2, v, dist_threshold, causal, sm_scale=None, upcast=F ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) ns = torch.arange(kv_seq_len, device=device) - + S1 = torch.matmul(q1, k1.transpose(2, 3)) S2 = torch.matmul(q2, k2.transpose(2, 3)) long_distance = ((ms + p_seq - ns) >= dist_threshold) S = torch.where(long_distance, S2, S1) * sm_scale - + if causal: S = torch.where(ms + p_seq >= ns, S, torch.finfo(S.dtype).min) @@ -45,12 +45,12 @@ def attention_grad(q1, k1, q2, k2, v, w, causal, sm_scale, o, do, upcast=False): ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) ns = torch.arange(kv_seq_len, device=device) - + S1 = torch.matmul(q1, k1.transpose(2, 3)) S2 = torch.matmul(q2, k2.transpose(2, 3)) long_distance = ((ms + p_seq - ns) >= w) S = torch.where(long_distance, S2, S1) * sm_scale - + if causal: S = torch.where((ms + p_seq) >= ns, S, torch.finfo(S.dtype).min) diff --git a/src/flag_attn/total.py b/src/flag_attn/total.py index 
70a6198..318be34 100644 --- a/src/flag_attn/total.py +++ b/src/flag_attn/total.py @@ -18,7 +18,7 @@ def total_attention(q, k, l, causal=False, sm_scale=None): if sm_scale is None: sm_scale = 1. / math.sqrt(D) - + # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) with torch.cuda.device(device): @@ -31,11 +31,11 @@ def total_attention(q, k, l, causal=False, sm_scale=None): grid = (triton.cdiv(N, BLOCK_N), H, B) tot_attn = torch.empty((B, H, N), dtype=torch.float32, device=q.device) _total_attention_kernel[grid]( - q, k, l, tot_attn, sm_scale, + q, k, l, tot_attn, sm_scale, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), B, H, M, N, P_SEQ, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps=num_warps, @@ -45,15 +45,14 @@ def total_attention(q, k, l, causal=False, sm_scale=None): @triton.jit def _total_attention_kernel( - Q, K, L, TA, sm_scale, + Q, K, L, TA, sm_scale, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, Z, H, M, N, P_SEQ, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): - input_dtype = Q.dtype.element_ty # -- grid id -- start_n = tl.program_id(0) off_h = tl.program_id(1) @@ -77,8 +76,8 @@ def _total_attention_kernel( offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m_base = tl.arange(0, BLOCK_M) offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data + + # initialize pointers to value-like data q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) ta_ptrs = TA + offs_n # (BLOCK_N, ) @@ -92,7 +91,7 @@ def _total_attention_kernel( # initialize total attention tot_attn = tl.zeros([BLOCK_N], dtype=tl.float32) - + # loop over a col for start_m in range(lo, M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) @@ -126,11 +125,11 @@ def _total_attention_kernel( p = tl.where(valid_mask, p, 0.0) if CAUSAL: p = tl.where(causal_mask, p, 0.0) - + tot_attn += tl.sum(p, 0) # increment pointers q_ptrs += BLOCK_M * stride_qm - + if DIVISIBLE_N: tl.store(ta_ptrs, tot_attn) # (BLOCK_N,)
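
Usage sketch for the exported flash attention entry point (illustrative, not part of this patch). The call signature (q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention) is taken from src/flag_attn/flash.py above; the shapes, dtype, and the assumption that the Triton path mirrors the (output, log_normalizer, total_attention) tuple of flag_attn.testing.flash are illustrative only.

import torch
import flag_attn

# illustrative shapes: batch, heads, query length, key/value length, head dim
B, H, M, N, D = 2, 4, 1024, 1024, 64
q = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

# plain call: only the attention output
o = flag_attn.flash_attention(q, k, v, causal=True)

# with extra outputs; slots that were not requested are assumed to be None,
# matching the reference implementation in flag_attn.testing.flash
o, log_normalizer, tot_attn = flag_attn.flash_attention(
    q, k, v, causal=True,
    return_log_normalizer=True, return_total_attention=True,
)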
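
The "# NOTE: this function can be overwritten at runtime to use your custom config" comment above get_fwd_config in flash.py suggests swapping in a custom launch configuration. A minimal sketch, assuming that rebinding the module attribute is the intended mechanism; my_fwd_config and its block/warp/stage numbers are made-up placeholders, not tuned values.

from flag_attn import flash

def my_fwd_config(B, H, M, N, D, causal):
    # hypothetical tuning: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    return (128, 64, 3, 4)

# subsequent flag_attn.flash_attention calls would pick up this config
flash.get_fwd_config = my_fwd_config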
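
Usage sketch for piecewise attention, checked against the pure-PyTorch reference in flag_attn.testing.piecewise (illustrative, not part of this patch). Per the docstring in piecewise.py, the score is dot(q2, k2) where the q/k distance reaches dist_threshold and dot(q1, k1) otherwise; the threshold, shapes, and dtype below are assumptions.

import torch
import flag_attn

B, H, M, D = 2, 4, 1024, 64
w = 512  # dist_threshold, illustrative

def rand():
    return torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)

q1, k1, q2, k2, v = rand(), rand(), rand(), rand(), rand()

o = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=True)
o_ref = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=True, upcast=True)
print((o.float() - o_ref.float()).abs().max())  # expected to be small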
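
A hedged PyTorch reconstruction of what _total_attention_kernel in total.py accumulates: for each key position, the sum over query positions of the softmax probability, recovered from q, k and the saved log-normalizer l. The function name and the exact recovery formula are inferred from the kernel and the forward pass, not stated verbatim in the patch.

import math
import torch

def total_attention_ref(q, k, l, causal=False, sm_scale=None):
    B, H, M, D = q.shape
    N = k.shape[2]
    p_seq = N - M
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    # scaled scores and softmax probabilities recovered via the saved log-normalizer
    s = torch.matmul(q.float(), k.float().transpose(2, 3)) * sm_scale  # (B, H, M, N)
    p = torch.exp(s - l.unsqueeze(-1))
    if causal:
        ms = torch.arange(M, device=q.device).unsqueeze(-1)
        ns = torch.arange(N, device=q.device)
        p = torch.where(ms + p_seq >= ns, p, torch.zeros_like(p))
    return p.sum(dim=2)  # (B, H, N): total attention received by each key position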