Skip to content

Commit

Permalink
AIT softmax fix (#982)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #982

For S387149

Reviewed By: aakhundov

Differential Revision: D52530631

fbshipit-source-id: ef61868f13f46240cc2714773e8579cf34b19a0c
  • Loading branch information
zoranzhao authored and facebook-github-bot committed Jan 6, 2024
1 parent c1747f6 commit 7bccbc9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
13 changes: 7 additions & 6 deletions python/aitemplate/backend/cuda/softmax/softmax.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ __global__ void softmax_small_k(Arguments<T> args, size_t M) {

PRAGMA_UNROLL
for (size_t i = 0; i < m; i++) {
T max = std::numeric_limits<T>::lowest();
T max = input_tile[i * K];
// find max
PRAGMA_UNROLL
for (size_t j = 0; j < K; j++) {
for (size_t j = 1; j < K; j++) {
max = fast_max(input_tile[i * K + j], max);
}
// get sum
Expand Down Expand Up @@ -287,10 +287,10 @@ __global__ void softmax_small_k(Arguments<T> args, size_t M) {
input_tile[j] = input[i * K + j];
}

T max = std::numeric_limits<T>::lowest();
T max = input_tile[0];
// find max
PRAGMA_UNROLL
for (size_t j = 0; j < K; j++) {
for (size_t j = 1; j < K; j++) {
max = fast_max(input_tile[j], max);
}
// get sum
Expand Down Expand Up @@ -865,13 +865,14 @@ __global__ void softmax_general(const T* input, T* output, size_t outer_size) {
inner_index < InnerSize;
inner_index += blockDim.y * gridDim.y) {
const uint32_t data_offset = outer_offset + inner_index;
T local_max = std::numeric_limits<T>::lowest();
T local_max = input[data_offset + threadIdx.x * dim_stride];
// First we reduce locally on a per-thread basis. We reduce #InnerThreads
// consecutive rows of the tensor at once, so we read the #input values in
// contiguous chunks of size #InnerThreads. For small values of InnerSize,
// we have InnerThreads == InnerSize, and so we will read in one big
// contiguous range.
for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x) {
for (uint32_t d = threadIdx.x + blockDim.x; d < DimSize;
d += blockDim.x) {
const T value = input[data_offset + d * dim_stride];
local_max = fast_max(local_max, value);
}
Expand Down
14 changes: 14 additions & 0 deletions tests/unittest/ops/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ def _test_softmax(
(6, 8, 3, 3),
2,
),
(
"zero_batch_size",
"float16",
(0,),
(3, 3),
0,
),
(
"empty_tensor",
"float16",
(2,),
(0, 3),
1,
),
],
TestEnv.CUDA_SM80: [
("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),
Expand Down

0 comments on commit 7bccbc9

Please sign in to comment.