
Commit 78245b0
removed the cub::Sum to fix compile issue
Signed-off-by: Vincent Huang <[email protected]>
ttyio authored and rajeevsrao committed Sep 22, 2023
1 parent c0f5044 commit 78245b0
Showing 5 changed files with 11 additions and 18 deletions.
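
Every hunk below applies the same pattern: the cub::Sum() functor passed to cub::BlockReduce::Reduce is replaced, either by a generic lambda that adds its two arguments or, for the kvp accumulations, by a plain operator+ expression. A minimal standalone sketch of the lambda form (the kernel name, block size, and buffer layout are illustrative assumptions, not code from this commit):

    #include <cub/cub.cuh>

    // Sum one value per thread across the block; TPB is threads per block.
    template <int TPB>
    __global__ void blockSumSketch(float const* input, float* output)
    {
        using BlockReduce = cub::BlockReduce<float, TPB>;
        __shared__ typename BlockReduce::TempStorage tempStorage;

        float const threadData = input[blockIdx.x * TPB + threadIdx.x];

        // Any binary callable works as the reduction operator; this generic
        // lambda (C++14) is the drop-in replacement for cub::Sum().
        float const blockSum = BlockReduce(tempStorage).Reduce(
            threadData, [](auto const& lhs, auto const& rhs) { return lhs + rhs; });

        if (threadIdx.x == 0)
        {
            output[blockIdx.x] = blockSum;
        }
    }

BlockReduce::Reduce accepts any binary reduction operator, so the lambda computes the same block-wide sum as cub::Sum without depending on that functor.
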
10 changes: 4 additions & 6 deletions plugin/common/common.cuh
@@ -254,7 +254,7 @@ __device__ inline void layerNorm(
__shared__ R mu; // mean
__shared__ R rsigma; // 1 / std.dev.

- const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, cub::Sum());
+ const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
@@ -286,7 +286,7 @@ __device__ inline void layerNormSmall(
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.

- const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, cub::Sum());
+ const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
@@ -318,7 +318,6 @@ __device__ inline void scaledSoftmaxSmall(
const int32_t offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;

const float w(rsqrtHeadSize);
- cub::Sum sum;
float threadData(-FLT_MAX);

const int32_t idx = offset + threadIdx.x;
@@ -343,7 +342,7 @@ __device__ inline void scaledSoftmaxSmall(
threadData = 0;
}

- const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+ const auto Z = BlockReduce(tmpStorage).Reduce(threadData, [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
@@ -371,7 +370,6 @@ __device__ inline void scaledSoftmax(
const int32_t offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;

const float w(rsqrtHeadSize);
- cub::Sum sum;
float threadData(-FLT_MAX);

if (lastValid >= blockDim.x)
@@ -399,7 +397,7 @@
threadData += exp((static_cast<float>(input[idx]) - fMax) * w);
}

- const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+ const auto Z = BlockReduce(tmpStorage).Reduce(threadData, [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
3 changes: 1 addition & 2 deletions plugin/embLayerNormPlugin/embLayerNormKernel.cu
@@ -186,7 +186,6 @@ __global__ void embLayerNormKernel(int ld, int32_t const* inputIds, int32_t cons
int32_t const tokSize, T* output)
{

- cub::Sum pairSum;
// 1. lookup word and token of the block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
@@ -225,7 +224,7 @@

output[outOffset + it] = val;
T const rldval = rld * val;
- threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
+ threadData = threadData + kvp<T>(rldval, rldval * val);
}
}

@@ -43,7 +43,6 @@ __global__ void embLayerNormKernelHFace(int32_t ld, int32_t const* inputIds, int
// this code currently assumes the input shape is SxB, row-major => seqPos = s * B + b
// instead we want BxS, row-major => seqPos = b * S + s

- cub::Sum pairSum;
// 1. lookup word and token of the block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
@@ -95,7 +94,7 @@ __global__ void embLayerNormKernelHFace(int32_t ld, int32_t const* inputIds, int

output[outOffset + it] = val;
T const rldval = rld * val;
- threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
+ threadData = threadData + kvp<T>(rldval, rldval * val);
}
}

@@ -43,7 +43,6 @@ __global__ void embLayerNormKernelMTron(int32_t ld, int32_t const* inputIds, int
// this code currently assumes the input shape is SxB, row-major => seqPos = s * B + b
// instead we want BxS, row-major => seqPos = b * S + s

- cub::Sum pairSum;
// 1. lookup word and token of the block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
@@ -96,7 +95,7 @@ __global__ void embLayerNormKernelMTron(int32_t ld, int32_t const* inputIds, int
output[outOffset + it] = val;
skip[outOffset + it] = val;
T const rldval = rld * val;
- threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
+ threadData = threadData + kvp<T>(rldval, rldval * val);
}
}

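The embLayerNorm* and skipLayerNorm* hunks rewrite pairSum(threadData, kvp<T>(...)) as threadData + kvp<T>(...), which assumes an operator+ for the key/value pair type is in scope at those call sites; the generic reduction lambdas in common.cuh appear to reduce the same kind of pair values. A sketch of what such an overload could look like, assuming kvp<T> aliases cub::KeyValuePair<T, T>; this is illustrative only, not code taken from the repository:

    #include <cub/cub.cuh>

    // Assumed alias, mirroring how the diffs use kvp<T>.
    template <typename T>
    using kvp = cub::KeyValuePair<T, T>;

    // Component-wise sum of two pairs, so that `a + b` (and a generic
    // lambda reducing kvp<T> values) has something to call.
    template <typename T>
    __device__ inline kvp<T> operator+(kvp<T> const& a, kvp<T> const& b)
    {
        return kvp<T>(a.key + b.key, a.value + b.value);
    }
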
10 changes: 4 additions & 6 deletions plugin/skipLayerNormPlugin/skipLayerNormKernel.cu
@@ -79,7 +79,7 @@ __global__ void skiplnDQQ(int32_t const ld, int8_t const* input, int8_t const* s
__shared__ __half mu; // mean
__shared__ __half rsigma; // 1 / std.dev.

- const __half2 sum2 = BlockReduce(tempStorage).Reduce(loc, cub::Sum());
+ const __half2 sum2 = BlockReduce(tempStorage).Reduce(loc, [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
@@ -139,7 +139,7 @@ __global__ void skipln_vec(
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.

- auto const sumKV = BlockReduce(tempStorage).Reduce(kvp<T>(local, local2), cub::Sum());
+ auto const sumKV = BlockReduce(tempStorage).Reduce(kvp<T>(local, local2), [](auto const& lhs, auto const& rhs){return lhs + rhs;});

if (threadIdx.x == 0)
{
@@ -166,7 +166,6 @@ __global__ void skipLayerNormKernelSmall(
const T rld = T(1) / T(ld);
int32_t const offset = blockIdx.x * ld;

- cub::Sum pairSum;
// reduce x and x^2
kvp<T> threadData(0, 0);
int32_t const idx = offset + threadIdx.x;
@@ -182,7 +181,7 @@
}

const T rldval = rld * val;
- threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
+ threadData = threadData + kvp<T>(rldval, rldval * val);
}

layerNormSmall<T, T, TPB>(val, threadData, ld, idx, beta, gamma, output);
@@ -195,7 +194,6 @@ __global__ void skipLayerNormKernel(
const T rld = T(1) / T(ld);
int32_t const offset = blockIdx.x * ld;

- cub::Sum pairSum;
// reduce x and x^2
kvp<T> threadData(0, 0);

@@ -209,7 +207,7 @@
val += T(bias[i]);
}
const T rldval = rld * val;
- threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
+ threadData = threadData + kvp<T>(rldval, rldval * val);
output[idx] = val;
}
