diff --git a/python/aitemplate/backend/cuda/softmax/softmax.cuh b/python/aitemplate/backend/cuda/softmax/softmax.cuh
index 8b10e520b..640425815 100644
--- a/python/aitemplate/backend/cuda/softmax/softmax.cuh
+++ b/python/aitemplate/backend/cuda/softmax/softmax.cuh
@@ -240,10 +240,10 @@ __global__ void softmax_small_k(Arguments args, size_t M) {
     PRAGMA_UNROLL
     for (size_t i = 0; i < m; i++) {
-      T max = std::numeric_limits<T>::lowest();
+      T max = input_tile[i * K];
       // find max
       PRAGMA_UNROLL
-      for (size_t j = 0; j < K; j++) {
+      for (size_t j = 1; j < K; j++) {
         max = fast_max(input_tile[i * K + j], max);
       }
       // get sum
@@ -287,10 +287,10 @@ __global__ void softmax_small_k(Arguments args, size_t M) {
       input_tile[j] = input[i * K + j];
     }
-    T max = std::numeric_limits<T>::lowest();
+    T max = input_tile[0];
     // find max
     PRAGMA_UNROLL
-    for (size_t j = 0; j < K; j++) {
+    for (size_t j = 1; j < K; j++) {
       max = fast_max(input_tile[j], max);
     }
     // get sum
@@ -865,13 +865,14 @@ __global__ void softmax_general(const T* input, T* output, size_t outer_size) {
          inner_index < InnerSize;
          inner_index += blockDim.y * gridDim.y) {
       const uint32_t data_offset = outer_offset + inner_index;
-      T local_max = std::numeric_limits<T>::lowest();
+      T local_max = input[data_offset + threadIdx.x * dim_stride];
       // First we reduce locally on a per-thread basis. We reduce #InnerThreads
       // consecutive rows of the tensor at once, so we read the #input values in
       // contiguous chunks of size #InnerThreads. For small values of InnerSize,
       // we have InnerThreads == InnerSize, and so we will read in one big
       // contiguous range.
-      for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x) {
+      for (uint32_t d = threadIdx.x + blockDim.x; d < DimSize;
+           d += blockDim.x) {
         const T value = input[data_offset + d * dim_stride];
         local_max = fast_max(local_max, value);
       }
diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py
index e477e5e12..c4fd47eda 100644
--- a/tests/unittest/ops/test_softmax.py
+++ b/tests/unittest/ops/test_softmax.py
@@ -130,6 +130,20 @@ def _test_softmax(
                 (6, 8, 3, 3),
                 2,
             ),
+            (
+                "zero_batch_size",
+                "float16",
+                (0,),
+                (3, 3),
+                0,
+            ),
+            (
+                "empty_tensor",
+                "float16",
+                (2,),
+                (0, 3),
+                1,
+            ),
         ],
         TestEnv.CUDA_SM80: [
             ("dim_1_bf16", "bfloat16", (1, 2), (6,), 1),