[runtime] add BatchTranspose cuda kernel #309

Open · wants to merge 1 commit into main
@@ -16,11 +16,12 @@
//===----------------------------------------------------------------------===//

#include <cuda_fp16.h>
#include <stdio.h>

namespace brt {
namespace cuda {
namespace kernel {

// Cap on launched blocks; the kernel grid-strides over any remaining tiles.
constexpr int32_t kMaxGridDim = 65535;
template <typename T>
__global__ void transpose_naive_2d_kernel(const T *input, T *output, int m,
int n) {
@@ -40,12 +41,98 @@ void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid,
transpose_naive_2d_kernel<T><<<grid, block, 0, stream>>>(input, output, m, n);
}

template <typename T, int32_t TileSizeX, int32_t TileSizeY, int32_t BlockSize>
__global__ void batch_transpose_kernel(const int32_t total_tile_num,
const int32_t tile_num_in_dim0,
const int32_t tile_num_in_dim1,
const int32_t tile_per_sample,
const int32_t row, const int32_t col,
void *__restrict__ inp_ptr,
void *__restrict__ out_ptr) {
// Stage one tile in shared memory; accesses XOR-swizzle the column index so
// the transposed read-back is free of bank conflicts.
__shared__ T tile_in_shmem[TileSizeX][TileSizeY];
for (int32_t i = blockIdx.x, step_tile = gridDim.x; i < total_tile_num;
i += step_tile) {
// Decompose the flat tile index into (batch, tile-row, tile-col).
const int32_t batch_idx = i / tile_per_sample;
const int32_t remainder = i - batch_idx * tile_per_sample;
const int32_t dim0_idx = remainder / tile_num_in_dim1;
const int32_t dim1_idx = remainder - dim0_idx * tile_num_in_dim1;

T *inp_tile_gmem = reinterpret_cast<T *>(inp_ptr);
T *out_tile_gmem = reinterpret_cast<T *>(out_ptr);
inp_tile_gmem += batch_idx * row * col + dim0_idx * TileSizeX * col +
dim1_idx * TileSizeY;
out_tile_gmem += batch_idx * row * col + dim1_idx * TileSizeY * row +
dim0_idx * TileSizeX;

// Tiles on the bottom/right edges may be partial; clamp the active extents.
int32_t range_0 = dim0_idx < tile_num_in_dim0 - 1
? TileSizeX
: row - dim0_idx * TileSizeX;
int32_t range_1 = dim1_idx < tile_num_in_dim1 - 1
? TileSizeY
: col - dim1_idx * TileSizeY;
constexpr int32_t row_num_per_iter = BlockSize / TileSizeY;
constexpr int32_t col_num_per_iter = BlockSize / TileSizeX;

int32_t tile_row_idx = threadIdx.x / TileSizeY;
int32_t tile_col_idx = threadIdx.x - tile_row_idx * TileSizeY;
for (int32_t j = tile_row_idx; j < range_0; j += row_num_per_iter) {
if (tile_col_idx < range_1) {
tile_in_shmem[j][tile_col_idx ^ j] =
inp_tile_gmem[j * col + tile_col_idx];
}
}
__syncthreads();
tile_row_idx = threadIdx.x / TileSizeX;
tile_col_idx = threadIdx.x - tile_row_idx * TileSizeX;
for (int32_t j = tile_row_idx; j < range_1; j += col_num_per_iter) {
if (tile_col_idx < range_0) {
out_tile_gmem[j * row + tile_col_idx] =
tile_in_shmem[tile_col_idx][j ^ tile_col_idx];
}
}
__syncthreads();
}
}

template <typename T>
void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
T *out_ptr, cudaStream_t stream) {
constexpr int32_t kTileSize = 32;

const int32_t tile_num_in_dim0 = (row - 1) / kTileSize + 1;
const int32_t tile_num_in_dim1 = (col - 1) / kTileSize + 1;
const int32_t tile_per_sample = tile_num_in_dim0 * tile_num_in_dim1;
const int32_t total_tile_num = batch * tile_per_sample;
dim3 grid(total_tile_num >= kMaxGridDim ? kMaxGridDim : total_tile_num);
// With tiny rows/cols most of a 256-thread block would sit idle in each
// (mostly partial) tile, so fall back to 64 threads per block.
if (row < 8 || col < 8) {
constexpr int32_t kBlockSize = 64;
dim3 block(kBlockSize);
batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
<<<grid, block, 0, stream>>>(
total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
reinterpret_cast<void *>(out_ptr));
} else {
constexpr int32_t kBlockSize = 256;
dim3 block(kBlockSize);
batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
<<<grid, block, 0, stream>>>(
total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
reinterpret_cast<void *>(out_ptr));
}
}

// instantiate
template void transpose_naive_2d<float>(const float *, float *, int, int, dim3,
dim3, cudaStream_t);
template void transpose_naive_2d<__half>(const __half *, __half *, int, int,
dim3, dim3, cudaStream_t);
template void batch_transpose<float>(int32_t, int32_t, int32_t, const float *,
float *, cudaStream_t);

template void batch_transpose<__half>(int32_t, int32_t, int32_t, const __half *,
__half *, cudaStream_t);
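// Note: the template definitions live in this .cu file, so every element type
// used by callers needs an explicit instantiation above. Supporting double,
// for example, would additionally require (hypothetical, not in this change):
//   template void batch_transpose<double>(int32_t, int32_t, int32_t,
//                                         const double *, double *,
//                                         cudaStream_t);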
} // namespace kernel
} // namespace cuda
} // namespace brt
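An aside on the `tile_col_idx ^ j` indexing in the kernel above: it is the usual XOR swizzle for dodging shared-memory bank conflicts when a tile written row-wise is read back column-wise. A minimal host-side sketch of the property it relies on, assuming the 32x32 tile and 32 four-byte banks used here:

#include <cassert>
#include <set>

int main() {
  constexpr int kTile = 32; // matches kTileSize in the kernel
  // The store phase reads logical column j as tile[c][j ^ c] for c = 0..31.
  // For fixed j, c -> (j ^ c) is a bijection on 0..31, so a warp's 32
  // accesses fall into 32 distinct banks: no conflicts.
  for (int j = 0; j < kTile; ++j) {
    std::set<int> banks;
    for (int c = 0; c < kTile; ++c)
      banks.insert((j ^ c) % 32); // bank of a 4-byte word at column (j ^ c)
    assert(banks.size() == 32);
  }
  return 0;
}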
@@ -26,6 +26,9 @@ template <typename T>
void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid,
dim3 block, cudaStream_t stream);

template <typename T>
void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
T *out_ptr, cudaStream_t stream);
} // namespace kernel
} // namespace cuda
} // namespace brt
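For reference, a hedged usage sketch of the new entry point declared above; the buffer setup is illustrative only (it assumes this header is included and linked against the runtime, and omits error checking):

#include <cstdint>
#include <cuda_runtime.h>

void example(cudaStream_t stream) {
  const int32_t batch = 4, row = 128, col = 256; // arbitrary demo sizes
  const size_t bytes = size_t(batch) * row * col * sizeof(float);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, bytes);
  cudaMalloc(&d_out, bytes);
  // ... fill d_in with `batch` row-major row x col matrices ...
  // Each row x col matrix in the batch is transposed into col x row.
  brt::cuda::kernel::batch_transpose<float>(batch, row, col, d_in, d_out,
                                            stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_in);
  cudaFree(d_out);
}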
@@ -40,30 +40,35 @@ using namespace brt::ir;
namespace brt {
namespace cuda {

-template <typename T> Transpose2D<T>::Transpose2D(const OpAccessor &accessor) {
+template <typename T>
+BatchTranspose<T>::BatchTranspose(const OpAccessor &accessor) {
auto shape_input = accessor.GetArgShape(0);
auto shape_output = accessor.GetArgShape(1);

-BRT_ENFORCE(shape_input.size() == 2);
+BRT_ENFORCE((shape_input.size() == 2 || shape_input.size() == 3));
BRT_ENFORCE(shape_output ==
transpose::DeduceOutputShape(
shape_input, accessor.GetAttrAsIntArray("permutation")));
input_shape = shape_input;
}

template <typename T>
-void Transpose2D<T>::Execute(const T *input, T *output,
-cudnnHandle_t /*handle*/, cudaStream_t stream) {
+void BatchTranspose<T>::Execute(const T *input, T *output,
+cudnnHandle_t /*handle*/, cudaStream_t stream) {
-auto p = MakeCUDAGridAndBlock(input_shape[1], input_shape[0]);
-kernel::transpose_naive_2d<T>(input, output, static_cast<int>(input_shape[0]),
-static_cast<int>(input_shape[1]), p.first,
-p.second, stream);
+int32_t batch = 1, m, n;
+if (input_shape.size() == 2) {
+m = input_shape[0], n = input_shape[1];
+} else if (input_shape.size() == 3) {
+batch = input_shape[0], m = input_shape[1], n = input_shape[2];
+}
+kernel::batch_transpose<T>(batch, m, n, input, output, stream);
BRT_CUDA_CHECK(cudaGetLastError());
}

// instantiate
-template class Transpose2D<float>;
-template class Transpose2D<__half>;
+template class BatchTranspose<float>;
+template class BatchTranspose<__half>;

template <typename T> Transpose4D<T>::Transpose4D(const OpAccessor &accessor) {
auto shape_input = accessor.GetArgShape(0);
@@ -134,8 +139,14 @@ template class Transpose4D<__half>;
template <typename T>
TransposeImpl<T>::TransposeImpl(const OpAccessor &accessor) {
auto shape_input = accessor.GetArgShape(0);
-if (shape_input.size() == 2) {
-this->impl = new Transpose2D<T>(accessor);
+if (shape_input.size() == 2 || shape_input.size() == 3) {
+auto permutation = accessor.GetAttrAsIntArray("permutation");
+// BatchTranspose only handles a swap of the last two axes.
+if (permutation[permutation.size() - 2] == permutation.size() - 1 &&
+permutation[permutation.size() - 1] == permutation.size() - 2) {
+this->impl = new BatchTranspose<T>(accessor);
+} else {
+BRT_THROW("unsupported transpose");
+}
} else if (shape_input.size() == 4) {
this->impl = new Transpose4D<T>(accessor);
} else {
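The dispatch above sends rank-2 and rank-3 inputs to BatchTranspose only when the permutation swaps the last two axes; anything else at those ranks is rejected. A standalone sketch of that predicate (hypothetical helper, not part of the PR):

#include <cassert>
#include <cstdint>
#include <vector>

// True iff perm swaps its last two axes, mirroring the TransposeImpl check.
static bool SwapsLastTwo(const std::vector<int64_t> &perm) {
  const int64_t n = static_cast<int64_t>(perm.size());
  return perm[n - 2] == n - 1 && perm[n - 1] == n - 2;
}

int main() {
  assert(SwapsLastTwo({1, 0}));     // 2D transpose -> BatchTranspose
  assert(SwapsLastTwo({0, 2, 1}));  // batched transpose -> BatchTranspose
  assert(!SwapsLastTwo({0, 1, 2})); // identity -> BRT_THROW("unsupported transpose")
  return 0;
}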
@@ -34,11 +34,11 @@ template <typename T> class TransposeBase {
};

/**
- * Transpose2D
+ * BatchTranspose
*/
-template <typename T> class Transpose2D : public TransposeBase<T> {
+template <typename T> class BatchTranspose : public TransposeBase<T> {
public:
-explicit Transpose2D(const OpAccessor &accessor);
+explicit BatchTranspose(const OpAccessor &accessor);

virtual void Execute(const T *input, T *output, cudnnHandle_t handle,
cudaStream_t stream) override;
@@ -63,6 +63,38 @@ static void CheckTranspose2D(T *input, T *output,
free(h_output);
}

template <typename T>
static void CheckTranspose3D(T *input, T *output,
const std::vector<int64_t> &input_shape) {
T *h_input =
(T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
T *h_output =
(T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
cudaMemcpy(h_input, input,
input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_output, output,
input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();

int B = input_shape[0];
int m = input_shape[1];
int n = input_shape[2];
for (int64_t t = 0; t < B; ++t) {
for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
int in_idx = t * m * n + i * n + j;
int out_idx = t * m * n + j * m + i;
EXPECT_EQ(h_output[out_idx], h_input[in_idx]);
}
}
}

free(h_input);
free(h_output);
}

template <typename T>
static void CheckTranspose4D(T *input, T *output,
const std::vector<int64_t> &input_shape,
@@ -142,6 +174,8 @@ static void TestTranspose(std::vector<int64_t> shape_input,

if (shape_input.size() == 2) {
CheckTranspose2D<T>(d_input, d_output, shape_input);
} else if (shape_input.size() == 3) {
CheckTranspose3D<T>(d_input, d_output, shape_input);
} else if (shape_input.size() == 4) {
CheckTranspose4D<T>(d_input, d_output, shape_input, perm);
} else {
@@ -150,17 +184,31 @@
}

TEST(CUDAOpKerenlTest, TransposeOp) {
// 2D transpose
TestTranspose<float>({32, 64}, {64, 32}, {1, 0});
TestTranspose<float>({2, 1}, {1, 2}, {1, 0});
TestTranspose<float>({1007, 13}, {13, 1007}, {1, 0});
TestTranspose<float>({2007, 4339}, {4339, 2007}, {1, 0});
TestTranspose<float>({1000, 512}, {512, 1000}, {1, 0});
// 3D Batch transpose
TestTranspose<float>({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
TestTranspose<float>({65536, 32, 50}, {65536, 50, 32}, {0, 2, 1});
TestTranspose<float>({65536, 2, 50}, {65536, 50, 2}, {0, 2, 1});
// NCHW 2 NHWC
TestTranspose<float>({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
// NHWC 2 NCHW
TestTranspose<float>({10, 20, 30, 40}, {10, 40, 20, 30}, {0, 3, 1, 2});
}

TEST(CUDAOpKerenlTest, TransposeOpFp16) {
// 2D transpose
TestTranspose<__half>({32, 64}, {64, 32}, {1, 0});
TestTranspose<__half>({2, 1}, {1, 2}, {1, 0});
TestTranspose<__half>({1007, 13}, {13, 1007}, {1, 0});
TestTranspose<__half>({2007, 4339}, {4339, 2007}, {1, 0});
TestTranspose<__half>({1000, 512}, {512, 1000}, {1, 0});
// 3D Batch transpose
TestTranspose<__half>({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
// NCHW 2 NHWC
TestTranspose<__half>({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
// NHWC 2 NCHW