From 6f27e85ed8bf4aad1870f5ec92a051d586c7043e Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 18 Sep 2022 09:49:10 -0400 Subject: [PATCH 1/8] add init bwd kernel(not ready) --- fbgemm_gpu/CMakeLists.txt | 6 + fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp | 346 ++++++++++++++++ fbgemm_gpu/hip_kernel/split_tbe_common_hip.h | 408 +++++++++++++++++++ fbgemm_gpu/hip_kernel/split_tbe_fwd_hip.cpp | 181 +------- 4 files changed, 761 insertions(+), 180 deletions(-) create mode 100644 fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp create mode 100644 fbgemm_gpu/hip_kernel/split_tbe_common_hip.h diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 811c4bcf8..9a080fc5b 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -188,6 +188,12 @@ message(STATUS "${PYTHON_EXECUTABLE}" "${CMAKE_CODEGEN_DIR}/embedding_backward_c ${CMAKE_CURRENT_SOURCE_DIR}/hip_kernel/split_tbe_fwd_hip_kernel.hsp -o \ ${CMAKE_CURRENT_SOURCE_DIR}/hip_kernel/split_tbe_fwd_hip_kernel.hsaco") + execute_process( + COMMAND sh -c "${ROCM_PATH}/hip/bin/hipcc -x hip --cuda-device-only -save-temps -c -O3 \ + ${CMAKE_CURRENT_SOURCE_DIR}/hip_kernel/split_tbe_bwd_hip.cpp -o \ + ${CMAKE_CURRENT_SOURCE_DIR}/hip_kernel/split_tbe_bwd_hip_kernel.hsp -o \ + ${CMAKE_CURRENT_SOURCE_DIR}/hip_kernel/split_tbe_bwd_hip_kernel.hsaco") + else() add_custom_command( OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp new file mode 100644 index 000000000..cfeb9cc63 --- /dev/null +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + ******************************************************************************/ + +#include +#include +#include "split_tbe_common_hip.h" + +template +struct rowwise_adagrad_optimizer_t +{ + __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) + : karg(karg_) + { + } + + // template + // __device__ static void precompute(float * acc){ + // // compute per row square sum + // } + template + __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + { + if constexpr(segment_split == 0) + { + cache_t momentum = karg.p_momentum[row_index]; // should be s_load + // compute per row square sum + cache_t local_sum_squre = .0f; + if constexpr(weigiht_decay_mode == 1) + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i] + w * karg.weight_decay; + local_sum_squre += a * a; + } + } + else + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t a = acc[i]; + local_sum_squre += a * a; + } + } + + cache_t avg_square = + wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / + embedding_dim; + + cache_t multiplier; + cache_t correction; + + cache_t momentum_new = momentum + avg_square; + + multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); + + if constexpr(weigiht_decay_mode == 1) + { + correction = 1.0 - multiplier * karg.weight_decay; + } + else if constexpr(weigiht_decay_mode == 2) + { + correction = 1.0 - karg.learning_rate * karg.weight_decay; + } + else + { + correction = 1.0; + } + +// update new weight value +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i]; + w = correction * w - multiplier * a; + weight[i] = static_cast(w); + } + + karg.p_momentum[row_index] = momentum_new; + } + } + + rowwise_adagrad_kernel_arg_t karg; +}; + +template // 0-warp per row, 1-cta per row, 2-atomic(needed?) 
+__device__ void split_tbe_backward_unweighted_hip_kernel( + const grad_t* p_output_grad, + emb_t* p_emb_table, + const int64_t* p_sorted_linear_indices_run, + const int64_t* p_sorted_linear_indices_cumulative_run_lengths, + const int32_t* p_sorted_linear_indices_num_runs, + const int32_t* p_long_run_ids, + const int64_t* p_num_long_run_ids, + const int32_t* p_sorted_infos, + magic_div_u32_t batch_mdiv, + uint32_t max_segment_length_per_warp, + uint32_t emb_dim, + uint32_t batch, + uint32_t num_rows, + uint32_t num_tables, + optimizer_karg_t opt_karg) +{ + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; + constexpr uint32_t length_mask = ~(bag_unroll - 1); + const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); + const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; + const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; + + if(run_id >= p_sorted_linear_indices_num_runs[0]) + return; + + const int64_t linear_index = p_sorted_linear_indices_run[run_id]; + const int64_t emb_idx = linear_index - blockIdx.y; + const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; + const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; + + p_output_grad += blockIdx.y * emb_dim; + + uint64_t emb_table_stride = static_cast(num_rows) * emb_dim; + p_emb_table += blockIdx.y * emb_table_stride; + opt_karg.p_momentum += blockIdx.y * num_rows; + + const int32_t segment_length = segment_end - segment_start; + + if(segment_length >= max_segment_length_per_warp) + return; + + const int32_t segment_length_mod = segment_length & length_mask; + + cache_t grad_acc[dword_per_row]; + int32_t infos[bag_unroll]; + grad_t grad_data[dword_per_row * bag_prefetch]; + emb_t emb_data[dword_per_row]; + + int itr = 0; + if(segment_length_mod == 0) + goto L_tail_grad_acc; + +#pragma unroll + for(int i = 0; i < bag_unroll; i++) + { + infos[i] = p_sorted_infos[i]; + } + + itr += bag_unroll; + p_sorted_infos += bag_unroll; + + uint32_t row_index; + uint32_t table_index__unused; + + // LOOP + for(; itr < segment_length_mod; itr += bag_unroll) + { + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[0], row_index * num_tables, p_output_grad, lane_id); + + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); + +#pragma unroll + for(int j = 2; j < bag_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[0], row_index * num_tables, p_output_grad, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod( + batch_mdiv, infos[j + 1], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); + } + +#pragma unroll + for(int i = 0; i < bag_unroll; i++) + { + infos[i] = p_sorted_infos[i]; + } + p_sorted_infos += bag_unroll; + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + } + + // LAST + 
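+    // Note on the block below: the loop above always keeps one batch of `bag_unroll`
+    // sorted infos prefetched, so this "LAST" section drains that final batch without
+    // loading any further infos. Each info is decoded with magic_div_u32_run_with_mod
+    // into a quotient (the table index, unused here) and a remainder (the bag/sample
+    // index within the batch), which selects the output-gradient row to accumulate.
+    // Gradient rows are double-buffered in grad_data (bag_prefetch == 2): while one row
+    // is accumulated into grad_acc, the next is already being loaded. Segments whose
+    // length is not a multiple of bag_unroll fall through to L_tail_grad_acc, which
+    // handles the remaining entries one at a time.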
magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[0], row_index * num_tables, p_output_grad, lane_id); + + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); + +#pragma unroll + for(int j = 2; j < bag_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[0], row_index * num_tables, p_output_grad, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + +L_tail_grad_acc: + if(segment_length & (bag_unroll - 1)) + { + // last, load one by one + do + { + infos[0] = p_sorted_infos[0]; + p_sorted_infos++; + + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); + load_row_per_warp::run( + &grad_data[0], row_index * num_tables, p_output_grad, lane_id); + accumulate_row_per_warp::run( + &grad_data[0], &grad_data[0], lane_id); + + itr++; + } while(itr < segment_length); + } + + // load the old emb weight data + load_row_per_warp::run( + &emb_data[0], emb_idx, p_emb_table, lane_id); + optimizer_t optimizer(opt_karg); + optimizer.template update(grad_acc, emb_data, row_index); + + // store updated weight to grad + store_row_per_warp::run(&emb_data[0], p_emb_table, lane_id); +} + +#define SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(optimizer, \ + weight_decay_mode, \ + segment_split, \ + emb_prec, \ + emb_type, \ + embedding_dim, \ + bag_prefetch, \ + bag_unroll) \ + extern "C" __global__ void \ + split_tbe_bwd_hip_kernel_warp_per_row_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_e##embedding_dim( \ + const float* p_output_grad, \ + emb_type* p_emb_table, \ + const int64_t* p_sorted_linear_indices_run, \ + const int64_t* p_sorted_linear_indices_cumulative_run_lengths, \ + const int32_t* p_sorted_linear_indices_num_runs, \ + const int32_t* p_long_run_ids, \ + const int64_t* p_num_long_run_ids, \ + const int32_t* p_sorted_infos, \ + magic_div_u32_t batch_mdiv, \ + uint32_t max_segment_length_per_warp, \ + uint32_t emb_dim, \ + uint32_t batch, \ + uint32_t num_rows, \ + uint32_t num_tables, \ + optimizer##_kernel_arg_t opt_karg) \ + { \ + split_tbe_backward_unweighted_hip_kernel< \ + optimizer##_optimizer_t, \ + optimizer##_kernel_arg_t, \ + emb_type, \ + float, \ + float, \ + BLOCK_SIZE, \ + embedding_dim, \ + bag_prefetch, \ + bag_unroll, \ + segment_split>(p_output_grad, \ + p_emb_table, \ + p_sorted_linear_indices_run, \ + p_sorted_linear_indices_cumulative_run_lengths, \ + p_sorted_linear_indices_num_runs, \ + p_long_run_ids, \ + p_num_long_run_ids, \ + p_sorted_infos, \ + batch_mdiv, \ + max_segment_length_per_warp, \ + emb_dim, \ + batch, \ + num_rows, \ + num_tables, \ + opt_karg); \ + } + +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 64, 2, 8) +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 128, 2, 8) 
+SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 192, 2, 8) +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 256, 2, 8) + +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 64, 2, 8) +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 128, 2, 8) +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 192, 2, 8) +SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 256, 2, 8) diff --git a/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h new file mode 100644 index 000000000..6efe1534e --- /dev/null +++ b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h @@ -0,0 +1,408 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + ******************************************************************************/ +#pragma once +#include +#include +#include + +/******************************************************************************/ +typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); +typedef float floatx2_t __attribute__((ext_vector_type(2))); +#define AMDGCN_BUFFER_RES_3 0x00027000 +#define AMDGCN_WAVE_SIZE 64 + +template +union amdgcn_buffer_resource +{ + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + struct + { + T* address; + int32_t range; + int32_t config; + }; +}; + +template +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) +{ + amdgcn_buffer_resource buffer_resource; + buffer_resource.address = const_cast(addr); + buffer_resource.range = 0xffffffff; + buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 + + return buffer_resource.content; +} + +// buffer load fp32 +__device__ half +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ half2 +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x2(floatx2_t vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +/******************************************************************************/ + +#define THREADS_PER_ROW 64 +#define BLOCK_SIZE 256 + +template +struct load_row_per_warp +{ + static __device__ void + run(emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, int lane_id) + { + } +}; + +template +struct load_row_per_warp +{ + static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) + { + int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); +#pragma unroll + for(int i = 0; i < dword_per_row; i++) + { + emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + } + } +}; + +template +struct load_row_per_warp +{ + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) + { + int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp +{ + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) + { + int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + *reinterpret_cast(emb_data) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp +{ + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) + { + int32x4_t 
emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + *reinterpret_cast(emb_data) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp16(emb_res, (lane_id + 128) * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp +{ + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) + { + int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp +{ + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) + { + int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[4]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[6]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + } +}; + +template +struct accumulate_row_per_warp +{ + static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run(output_t* acc, emb_t* emb_data, int lane_id) + { +#pragma unroll + for(int i = 0; i < dword_per_row; i++) + { + acc[i] += static_cast(emb_data[i]); + } + } +}; + +template +struct store_row_per_warp +{ + static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run(output_t* acc, output_t* p_output, int lane_id) + { +#pragma unroll + for(int i = 0; i < dword_per_row; i++) + { + p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; + } + } +}; + +template <> +struct store_row_per_warp +{ + static __device__ void run(float* acc, float* p_output, int lane_id) + { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); + } +}; + +template <> +struct store_row_per_warp +{ + static __device__ void run(float* acc, float* p_output, int lane_id) + { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + } +}; + +template <> +struct store_row_per_warp +{ + static __device__ void run(float* acc, float* p_output, int lane_id) + { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); + llvm_amdgcn_raw_buffer_store_fp32x2(*reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t), + 0, + 0); + } +}; + +template +__device__ to_t bit_cast(const from_t& v) +{ + // TODO: how to deal with sizeof(to_t) larger than sizeof(from_t) + static_assert(sizeof(to_t) == sizeof(from_t), ""); + return __builtin_bit_cast(to_t, v); +} + +template +struct reduce_op_sum_t +{ + __device__ 
data_t operator()(const data_t& a, const data_t& b) { return a + b; } +}; + +template +__device__ inline data_t wave_reduce(const data_t& thread_data) +{ + // wave_size must be power of 2 + constexpr int row_mask = 0xf; + constexpr int bank_mask = 0xf; + constexpr bool bound_ctrl = false; + + reduce_op_t reduce_op; + data_t result = thread_data; + + if constexpr(wave_size > 1) + { + result = reduce_op( + result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0xb1, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[1,0,3,2] + } + if constexpr(wave_size > 2) + { + result = reduce_op( + result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0x4e, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[2,3,0,1] + } + if constexpr(wave_size > 4) + { + result = + reduce_op(result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0x114, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:4 + } + if constexpr(wave_size > 8) + { + result = + reduce_op(result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 + } +#if (__gfx1010__ || __gfx1011__ || __gfx1012__ || __gfx1030__ || __gfx1031__) + // TODO: current compiler seems fail on this branch + // if constexpr(wave_size > 16) + // { + // result = + // reduce_op(result, + // bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + // 0x1e0, + // row_mask, + // bank_mask, + // bound_ctrl))); // row_bcast:15 + // } +#else + if constexpr(wave_size > 16) + { + result = + reduce_op(result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0x142, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:15 + } + if constexpr(wave_size > 32) + { + result = + reduce_op(result, + bit_cast(__builtin_amdgcn_mov_dpp(bit_cast(result), + 0x143, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:31 + } +#endif + // now the reduced value is in the last lane of wave + return bit_cast( + __builtin_amdgcn_readlane(bit_cast(result), wave_size - 1)); +} + +template +struct rowwise_adagrad_kernel_arg_t +{ + cache_t* p_momentum; + float eps; + float learning_rate; + float weight_decay; + // int64_t weight_decay_mode; +}; + +typedef struct +{ + uint32_t magic; + uint32_t shift; // actually 8 bit is enough +} magic_div_u32_t; + +static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) +{ + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for(shift = 0; shift < 32; shift++) + if((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; +} + +// numer / denom = quotient, reminder +__device__ inline uint32_t magic_div_u32_run(const magic_div_u32_t& mdiv, const uint32_t& n) +{ + uint32_t tmp = __umulhi(n, mdiv.magic); + return (tmp + n) >> mdiv.shift; +} + +__device__ inline void magic_div_u32_run_with_mod( + const magic_div_u32_t& mdiv, const uint32_t& n, const uint32_t d, uint32_t& quo, uint32_t& rem) +{ + quo = magic_div_u32_run(mdiv, n); + rem = n - quo * d; +} diff --git a/fbgemm_gpu/hip_kernel/split_tbe_fwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_fwd_hip.cpp index bd816ebb0..90bfe8aa6 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_fwd_hip.cpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_fwd_hip.cpp @@ -23,186 +23,7 @@ #include #include - -typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); -typedef float floatx2_t __attribute__((ext_vector_type(2))); -#define AMDGCN_BUFFER_RES_3 
0x00027000 -#define AMDGCN_WAVE_SIZE 64 - -template -union amdgcn_buffer_resource{ - // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions - int32x4_t content; - struct { - T * address; - int32_t range; - int32_t config; - }; -}; - -template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) -{ - amdgcn_buffer_resource buffer_resource; - buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; - buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 - - return buffer_resource.content; -} - -// buffer load fp32 -__device__ half -llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); - -__device__ float -llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); - -__device__ half2 -llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); - -__device__ void -llvm_amdgcn_raw_buffer_store_fp32(float vdata, - int32x4_t rsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); - -__device__ void -llvm_amdgcn_raw_buffer_store_fp32x2(floatx2_t vdata, - int32x4_t rsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); - -/******************************************************************************/ - -#define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 - -template -struct load_row_per_warp { - static __device__ void run(emb_t * emb_data, index_t row_index, const emb_t * p_emb_table, int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(float * emb_data, index_t row_index, const float * p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); - #pragma unroll - for(int i = 0; i < dword_per_row; i++) - { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); - } - } -}; - -template -struct load_row_per_warp { - static __device__ void run(half * emb_data, index_t row_index, const half * p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void run(half * emb_data, index_t row_index, const half * p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void run(half * emb_data, index_t row_index, const half * p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void run(half * emb_data, index_t row_index, const half 
* p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 )* sizeof(half2), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void run(half * emb_data, index_t row_index, const half * p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 512); - *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 )* sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 * 2 )* sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, (lane_id + 64 * 3 )* sizeof(half2), 0, 0); - } -}; - -template -struct accumulate_row_per_warp { - static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t * acc, emb_t * emb_data, int lane_id) { - #pragma unroll - for(int i = 0; i < dword_per_row; i++){ - acc[i] += static_cast(emb_data[i]); - } - } -}; - -template -struct store_row_per_warp { - static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t * acc, output_t * p_output, int lane_id) - { - #pragma unroll - for(int i = 0; i < dword_per_row; i++){ - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } -}; - -template<> -struct store_row_per_warp { - static __device__ void run(float * acc, float * p_output, int lane_id) - { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); - } -}; - -template<> -struct store_row_per_warp { - static __device__ void run(float * acc, float * p_output, int lane_id) - { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); - llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128 )* sizeof(float), 0, 0); - } -}; - -template<> -struct store_row_per_warp { - static __device__ void run(float * acc, float * p_output, int lane_id) - { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(floatx2_t), 0, 0); - llvm_amdgcn_raw_buffer_store_fp32x2(*reinterpret_cast(&acc[2]), out_res, (lane_id + 64) * sizeof(floatx2_t), 0, 0); - } -}; +#include "split_tbe_common_hip.h" template < typename emb_t, From b546e59e081555bd469602f97298fe0ad0707e72 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 19 Sep 2022 09:00:51 -0400 Subject: [PATCH 2/8] compiler OK now --- .../embedding_backward_split_template.cu | 80 +++++++++++++++++++ fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp | 44 +++++----- fbgemm_gpu/hip_kernel/split_tbe_common_hip.h | 3 +- 3 files changed, 104 insertions(+), 23 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index afc391da9..657a09e28 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ 
b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -8,6 +8,8 @@ {% set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "hip_kernel/split_tbe_common_hip.h" +#include #define SHFL_SYNC(val, srcLane) shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask) @@ -926,7 +928,39 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } {% endif %} + static int init_hsaco = 0; + static hipModule_t hip_kernel_module; + static hipFunction_t hip_kernel_func; + bool hip_opt_kernel_supported = "{{ optimizer }}" == "rowwise_adagrad"; // TODO: figure out support range + // uint32_t grids[3] = {(B + bags_per_workgroup - 1) / bags_per_workgroup, (uint32_t)T, 1}; + // uint32_t blocks[3] = {256, 1, 1}; + + if(hip_opt_kernel_supported && init_hsaco == 0){ + int segment_split = 0; // warp per row + int weight_decay_mode = 1; // TODO: how to check this? + hipError_t hip_err = hipModuleLoad(&hip_kernel_module, "hip_kernel/split_tbe_bwd_hip_kernel.hsaco"); // hip kernel object + if (hip_err != hipSuccess) { + char cwd[PATH_MAX]; + getcwd(cwd, sizeof(cwd)); + printf("[hiperror](%d) line:%d, fail to call,(%s), cwd:%s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), cwd); + exit(1); + } + std::string prec = dev_weights.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32"; + std::string hip_kernel_name = std::string("split_tbe_bwd_hip_kernel_") +"{{ optimizer }}" +"_w" + std::to_string(weight_decay_mode) + + "_s" + std::to_string(segment_split) + "_" + prec + "_e" + std::to_string(max_D); + hip_err = hipModuleGetFunction(&hip_kernel_func, hip_kernel_module, hip_kernel_name.c_str()); + printf("kernel function: %s, B:%d, T:%d\n", + hip_kernel_name.c_str(), B, T); + if (hip_err != hipSuccess) { + printf("[hiperror](%d) line:%d, fail to call,(%s)", (int) hip_err, __LINE__, hipGetErrorString(hip_err)); + exit(1); + } + + init_hsaco = 1; + } + {% if not dense %} + DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), grad_output.scalar_type(), @@ -1200,6 +1234,52 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ used_shared_bytes); // V100: 64 KB; A100: 96 KB. 
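+            // Note on the block added below: because the backward kernel lives in a
+            // separately compiled .hsaco module, it cannot be launched with <<<...>>>;
+            // all arguments are instead packed into one POD struct and handed to
+            // hipModuleLaunchKernel through the HIP_LAUNCH_PARAM_BUFFER_POINTER /
+            // HIP_LAUNCH_PARAM_BUFFER_SIZE / HIP_LAUNCH_PARAM_END "extra" list, with
+            // kernelParams left NULL. A minimal sketch of that pattern, using
+            // hypothetical names (my_func, my_karg):
+            //     struct { float* p; uint32_t n; } my_karg{ptr, n};
+            //     size_t my_size = sizeof(my_karg);
+            //     void* my_cfg[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &my_karg,
+            //                       HIP_LAUNCH_PARAM_BUFFER_SIZE, &my_size,
+            //                       HIP_LAUNCH_PARAM_END};
+            //     hipModuleLaunchKernel(my_func, gx, gy, gz, bx, by, bz, 0, 0, NULL, my_cfg);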
#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); + if(hip_opt_kernel_supported){ + struct { + const void* p_output_grad; + void* p_emb_table; + const void* p_sorted_linear_indices_run; + const void* p_sorted_linear_indices_cumulative_run_lengths; + const void* p_sorted_linear_indices_num_runs; + const void* p_long_run_ids; + const void* p_num_long_run_ids; + const void* p_sorted_infos; + magic_div_u32_t batch_mdiv; + uint32_t max_segment_length_per_warp; + uint32_t emb_dim; + uint32_t batch; + uint32_t num_rows; + uint32_t num_tables; + rowwise_adagrad_kernel_arg_t opt_karg; + } karg; + size_t arg_size = sizeof(karg); + void* kconf[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, HIP_LAUNCH_PARAM_BUFFER_SIZE, + &arg_size, HIP_LAUNCH_PARAM_END}; + + karg.p_output_grad = grad_output_accessor.data(); + karg.p_emb_table = dev_weights.packed_accessor64().data(); + karg.p_sorted_linear_indices_run = sorted_linear_indices_run.packed_accessor32().data(); + karg.p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.packed_accessor32().data(); + karg.p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.packed_accessor32().data(); + karg.p_long_run_ids = long_run_ids.packed_accessor32().data(); + karg.p_num_long_run_ids = num_long_run_ids.packed_accessor32().data(); + karg.p_sorted_infos = infos_sorted.packed_accessor32().data(); + karg.batch_mdiv = magic_div_u32_gen(B); + karg.max_segment_length_per_warp = max_segment_length_per_warp; + karg.emb_dim = max_D; + karg.batch = B; + karg.num_rows = dev_weights.numel() / T / max_D; + karg.num_tables = T; + + constexpr int segments_per_workgroup = 4; + int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), (int32_t)T, 1}; + int32_t blocks[3] = {256, 1, 1}; + + hipModuleLaunchKernel(hip_kernel_func, + grids[0], grids[1], grids[2], + blocks[0], blocks[1], blocks[2], 0, 0, NULL, (void **) &kconf); + + }else split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} emb_t, diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp index cfeb9cc63..481d5ad03 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp @@ -25,10 +25,10 @@ #include #include "split_tbe_common_hip.h" -template +template struct rowwise_adagrad_optimizer_t { - __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) + __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) : karg(karg_) { } @@ -42,10 +42,11 @@ struct rowwise_adagrad_optimizer_t { if constexpr(segment_split == 0) { - cache_t momentum = karg.p_momentum[row_index]; // should be s_load + cache_t * p_momentum = reinterpret_cast(karg.p_momentum); + cache_t momentum = p_momentum[row_index]; // should be s_load // compute per row square sum cache_t local_sum_squre = .0f; - if constexpr(weigiht_decay_mode == 1) + if constexpr(weight_decay_mode == 1) { #pragma unroll for(auto i = 0; i < thread_length; i++) @@ -76,11 +77,11 @@ struct rowwise_adagrad_optimizer_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); - if constexpr(weigiht_decay_mode == 1) + if constexpr(weight_decay_mode == 1) { correction = 1.0 - multiplier * karg.weight_decay; } - else if constexpr(weigiht_decay_mode == 2) + else if constexpr(weight_decay_mode == 2) { correction = 1.0 - karg.learning_rate * karg.weight_decay; } @@ -99,11 +100,11 @@ struct 
rowwise_adagrad_optimizer_t weight[i] = static_cast(w); } - karg.p_momentum[row_index] = momentum_new; + p_momentum[row_index] = momentum_new; } } - rowwise_adagrad_kernel_arg_t karg; + rowwise_adagrad_kernel_arg_t karg; }; template (num_rows) * emb_dim; p_emb_table += blockIdx.y * emb_table_stride; - opt_karg.p_momentum += blockIdx.y * num_rows; + opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + blockIdx.y * num_rows); const int32_t segment_length = segment_end - segment_start; @@ -282,7 +283,7 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( store_row_per_warp::run(&emb_data[0], p_emb_table, lane_id); } -#define SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(optimizer, \ +#define SPLIT_TBE_BWD_KERNEL(optimizer, \ weight_decay_mode, \ segment_split, \ emb_prec, \ @@ -291,7 +292,7 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( bag_prefetch, \ bag_unroll) \ extern "C" __global__ void \ - split_tbe_bwd_hip_kernel_warp_per_row_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_e##embedding_dim( \ + split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_e##embedding_dim( \ const float* p_output_grad, \ emb_type* p_emb_table, \ const int64_t* p_sorted_linear_indices_run, \ @@ -306,11 +307,11 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( uint32_t batch, \ uint32_t num_rows, \ uint32_t num_tables, \ - optimizer##_kernel_arg_t opt_karg) \ + optimizer##_kernel_arg_t opt_karg) \ { \ split_tbe_backward_unweighted_hip_kernel< \ optimizer##_optimizer_t, \ - optimizer##_kernel_arg_t, \ + optimizer##_kernel_arg_t, \ emb_type, \ float, \ float, \ @@ -335,12 +336,13 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( opt_karg); \ } -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 64, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 128, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 192, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 256, 2, 8) +// warp per row +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 64, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 128, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 192, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 256, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 64, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 128, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 192, 2, 8) -SPLIT_TBE_BWD_WARP_PER_ROW_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 256, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 64, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 128, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 192, 2, 8) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 256, 2, 8) diff --git a/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h index 6efe1534e..e7d8ef0f9 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h +++ b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h @@ -359,10 +359,9 @@ __device__ inline data_t wave_reduce(const data_t& thread_data) __builtin_amdgcn_readlane(bit_cast(result), wave_size - 1)); } -template struct rowwise_adagrad_kernel_arg_t { - cache_t* p_momentum; + void* p_momentum; float eps; float learning_rate; float weight_decay; From 
369f076b2c266aa9c4b2a3a435094ca6b93e31ef Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 25 Sep 2022 10:58:50 -0400 Subject: [PATCH 3/8] fix bug in bwd --- .../embedding_backward_split_template.cu | 38 +- fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp | 740 ++++++++++-------- 2 files changed, 420 insertions(+), 358 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 657a09e28..44a7b28eb 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -10,6 +10,7 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" #include "hip_kernel/split_tbe_common_hip.h" #include +#include #define SHFL_SYNC(val, srcLane) shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask) @@ -931,13 +932,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ static int init_hsaco = 0; static hipModule_t hip_kernel_module; static hipFunction_t hip_kernel_func; - bool hip_opt_kernel_supported = "{{ optimizer }}" == "rowwise_adagrad"; // TODO: figure out support range - // uint32_t grids[3] = {(B + bags_per_workgroup - 1) / bags_per_workgroup, (uint32_t)T, 1}; - // uint32_t blocks[3] = {256, 1, 1}; +{% if optimizer == "rowwise_adagrad" and not dense %} + std::set D_emb_s {64, 128, 192, 256}; + bool hip_opt_kernel_supported = (D_emb_s.find(max_D) != D_emb_s.end()) && + (dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float); +{% else %} + bool hip_opt_kernel_supported = false; // TODO: figure out support range +{% endif %} +{% if optimizer == "rowwise_adagrad" and not dense %} if(hip_opt_kernel_supported && init_hsaco == 0){ int segment_split = 0; // warp per row - int weight_decay_mode = 1; // TODO: how to check this? hipError_t hip_err = hipModuleLoad(&hip_kernel_module, "hip_kernel/split_tbe_bwd_hip_kernel.hsaco"); // hip kernel object if (hip_err != hipSuccess) { char cwd[PATH_MAX]; @@ -945,19 +950,21 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ printf("[hiperror](%d) line:%d, fail to call,(%s), cwd:%s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), cwd); exit(1); } - std::string prec = dev_weights.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32"; + std::string w_prec = dev_weights.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32"; + std::string g_prec = grad_output.scalar_type() == at::ScalarType::Half ? 
"fp16" : "fp32"; std::string hip_kernel_name = std::string("split_tbe_bwd_hip_kernel_") +"{{ optimizer }}" +"_w" + std::to_string(weight_decay_mode) + - "_s" + std::to_string(segment_split) + "_" + prec + "_e" + std::to_string(max_D); + "_s" + std::to_string(segment_split) + "_" + w_prec + "_" + g_prec + "_e" + std::to_string(max_D); hip_err = hipModuleGetFunction(&hip_kernel_func, hip_kernel_module, hip_kernel_name.c_str()); - printf("kernel function: %s, B:%d, T:%d\n", - hip_kernel_name.c_str(), B, T); + printf("kernel function: %s, B:%d, T:%d(%d), wcnt:%ld, ocnt:%ld, mcnt:%ld\n", + hip_kernel_name.c_str(), B, T, hash_size_cumsum.size(0) - 1, dev_weights.numel(), grad_output.numel(), momentum1_dev.numel()); if (hip_err != hipSuccess) { - printf("[hiperror](%d) line:%d, fail to call,(%s)", (int) hip_err, __LINE__, hipGetErrorString(hip_err)); + printf("[hiperror](%d) line:%d, fail to call,(%s), %s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), hip_kernel_name.c_str()); exit(1); } init_hsaco = 1; } +{% endif %} {% if not dense %} @@ -1238,6 +1245,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ struct { const void* p_output_grad; void* p_emb_table; + const void* p_hash_size_cumsum; const void* p_sorted_linear_indices_run; const void* p_sorted_linear_indices_cumulative_run_lengths; const void* p_sorted_linear_indices_num_runs; @@ -1258,6 +1266,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ karg.p_output_grad = grad_output_accessor.data(); karg.p_emb_table = dev_weights.packed_accessor64().data(); + karg.p_hash_size_cumsum = hash_size_cumsum.packed_accessor32().data(); karg.p_sorted_linear_indices_run = sorted_linear_indices_run.packed_accessor32().data(); karg.p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.packed_accessor32().data(); karg.p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.packed_accessor32().data(); @@ -1271,8 +1280,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ karg.num_rows = dev_weights.numel() / T / max_D; karg.num_tables = T; + {% if optimizer == "rowwise_adagrad" and not dense %} + rowwise_adagrad_kernel_arg_t opt_karg; + opt_karg.p_momentum = momentum1_dev.packed_accessor64, 1, at::RestrictPtrTraits>().data(); + opt_karg.eps = eps; + opt_karg.learning_rate = learning_rate; + opt_karg.weight_decay = weight_decay; + karg.opt_karg = opt_karg; + {% endif %} + constexpr int segments_per_workgroup = 4; - int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), (int32_t)T, 1}; + int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), 1, 1}; int32_t blocks[3] = {256, 1, 1}; hipModuleLaunchKernel(hip_kernel_func, diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp index 481d5ad03..45f25e9f3 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp @@ -1,348 +1,392 @@ -/******************************************************************************* - * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - ******************************************************************************/ - -#include -#include -#include "split_tbe_common_hip.h" - -template -struct rowwise_adagrad_optimizer_t -{ - __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) - : karg(karg_) - { - } - - // template - // __device__ static void precompute(float * acc){ - // // compute per row square sum - // } - template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) - { - if constexpr(segment_split == 0) - { - cache_t * p_momentum = reinterpret_cast(karg.p_momentum); - cache_t momentum = p_momentum[row_index]; // should be s_load - // compute per row square sum - cache_t local_sum_squre = .0f; - if constexpr(weight_decay_mode == 1) - { -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t w = static_cast(weight[i]); - cache_t a = acc[i] + w * karg.weight_decay; - local_sum_squre += a * a; - } - } - else - { -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t a = acc[i]; - local_sum_squre += a * a; - } - } - - cache_t avg_square = - wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / - embedding_dim; - - cache_t multiplier; - cache_t correction; - - cache_t momentum_new = momentum + avg_square; - - multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); - - if constexpr(weight_decay_mode == 1) - { - correction = 1.0 - multiplier * karg.weight_decay; - } - else if constexpr(weight_decay_mode == 2) - { - correction = 1.0 - karg.learning_rate * karg.weight_decay; - } - else - { - correction = 1.0; - } - -// update new weight value -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t w = static_cast(weight[i]); - cache_t a = acc[i]; - w = correction * w - multiplier * a; - weight[i] = static_cast(w); - } - - p_momentum[row_index] = momentum_new; - } - } - - rowwise_adagrad_kernel_arg_t karg; -}; - -template // 0-warp per row, 1-cta per row, 2-atomic(needed?) 
-__device__ void split_tbe_backward_unweighted_hip_kernel( - const grad_t* p_output_grad, - emb_t* p_emb_table, - const int64_t* p_sorted_linear_indices_run, - const int64_t* p_sorted_linear_indices_cumulative_run_lengths, - const int32_t* p_sorted_linear_indices_num_runs, - const int32_t* p_long_run_ids, - const int64_t* p_num_long_run_ids, - const int32_t* p_sorted_infos, - magic_div_u32_t batch_mdiv, - uint32_t max_segment_length_per_warp, - uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, - uint32_t num_tables, - optimizer_karg_t opt_karg) -{ - constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; - constexpr uint32_t length_mask = ~(bag_unroll - 1); - const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); - const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; - const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; - - if(run_id >= p_sorted_linear_indices_num_runs[0]) - return; - - const int64_t linear_index = p_sorted_linear_indices_run[run_id]; - const int64_t emb_idx = linear_index - blockIdx.y; - const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; - const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - - p_output_grad += blockIdx.y * emb_dim; - - uint64_t emb_table_stride = static_cast(num_rows) * emb_dim; - p_emb_table += blockIdx.y * emb_table_stride; - opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + blockIdx.y * num_rows); - - const int32_t segment_length = segment_end - segment_start; - - if(segment_length >= max_segment_length_per_warp) - return; - - const int32_t segment_length_mod = segment_length & length_mask; - - cache_t grad_acc[dword_per_row]; - int32_t infos[bag_unroll]; - grad_t grad_data[dword_per_row * bag_prefetch]; - emb_t emb_data[dword_per_row]; - - int itr = 0; - if(segment_length_mod == 0) - goto L_tail_grad_acc; - -#pragma unroll - for(int i = 0; i < bag_unroll; i++) - { - infos[i] = p_sorted_infos[i]; - } - - itr += bag_unroll; - p_sorted_infos += bag_unroll; - - uint32_t row_index; - uint32_t table_index__unused; - - // LOOP - for(; itr < segment_length_mod; itr += bag_unroll) - { - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[0], row_index * num_tables, p_output_grad, lane_id); - - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); - -#pragma unroll - for(int j = 2; j < bag_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[0], row_index * num_tables, p_output_grad, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - magic_div_u32_run_with_mod( - batch_mdiv, infos[j + 1], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); - } - -#pragma unroll - for(int i = 0; i < bag_unroll; i++) - { - infos[i] = p_sorted_infos[i]; - } - p_sorted_infos += bag_unroll; - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - 
&grad_acc[0], &grad_data[dword_per_row], lane_id); - } - - // LAST - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[0], row_index * num_tables, p_output_grad, lane_id); - - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); - -#pragma unroll - for(int j = 2; j < bag_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[0], row_index * num_tables, p_output_grad, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[dword_per_row], row_index * num_tables, p_output_grad, lane_id); - } - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - -L_tail_grad_acc: - if(segment_length & (bag_unroll - 1)) - { - // last, load one by one - do - { - infos[0] = p_sorted_infos[0]; - p_sorted_infos++; - - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index__unused, row_index); - load_row_per_warp::run( - &grad_data[0], row_index * num_tables, p_output_grad, lane_id); - accumulate_row_per_warp::run( - &grad_data[0], &grad_data[0], lane_id); - - itr++; - } while(itr < segment_length); - } - - // load the old emb weight data - load_row_per_warp::run( - &emb_data[0], emb_idx, p_emb_table, lane_id); - optimizer_t optimizer(opt_karg); - optimizer.template update(grad_acc, emb_data, row_index); - - // store updated weight to grad - store_row_per_warp::run(&emb_data[0], p_emb_table, lane_id); -} - -#define SPLIT_TBE_BWD_KERNEL(optimizer, \ - weight_decay_mode, \ - segment_split, \ - emb_prec, \ - emb_type, \ - embedding_dim, \ - bag_prefetch, \ - bag_unroll) \ - extern "C" __global__ void \ - split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_e##embedding_dim( \ - const float* p_output_grad, \ - emb_type* p_emb_table, \ - const int64_t* p_sorted_linear_indices_run, \ - const int64_t* p_sorted_linear_indices_cumulative_run_lengths, \ - const int32_t* p_sorted_linear_indices_num_runs, \ - const int32_t* p_long_run_ids, \ - const int64_t* p_num_long_run_ids, \ - const int32_t* p_sorted_infos, \ - magic_div_u32_t batch_mdiv, \ - uint32_t max_segment_length_per_warp, \ - uint32_t emb_dim, \ - uint32_t batch, \ - uint32_t num_rows, \ - uint32_t num_tables, \ - optimizer##_kernel_arg_t opt_karg) \ - { \ - split_tbe_backward_unweighted_hip_kernel< \ - optimizer##_optimizer_t, \ - optimizer##_kernel_arg_t, \ - emb_type, \ - float, \ - float, \ - BLOCK_SIZE, \ - embedding_dim, \ - bag_prefetch, \ - bag_unroll, \ - segment_split>(p_output_grad, \ - p_emb_table, \ - p_sorted_linear_indices_run, \ - p_sorted_linear_indices_cumulative_run_lengths, \ - p_sorted_linear_indices_num_runs, \ - p_long_run_ids, \ - p_num_long_run_ids, \ - p_sorted_infos, \ - batch_mdiv, \ - max_segment_length_per_warp, \ - emb_dim, \ - batch, \ - num_rows, \ - num_tables, \ - opt_karg); \ - } - -// warp per row -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 64, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, 
float, 128, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 192, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp32, float, 256, 2, 8) - -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 64, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 128, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 192, 2, 8) -SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 1, 0, fp16, half, 256, 2, 8) +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ + +#include +#include +#include "split_tbe_common_hip.h" + +template +struct rowwise_adagrad_optimizer_t +{ + __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) + : karg(karg_) + { + } + + // template + // __device__ static void precompute(float * acc){ + // // compute per row square sum + // } + template + __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + { + if constexpr(segment_split == 0) + { + cache_t * p_momentum = reinterpret_cast(karg.p_momentum); + cache_t momentum = p_momentum[row_index]; // should be s_load + // compute per row square sum + cache_t local_sum_squre = .0f; + if constexpr(weight_decay_mode == 1) + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i] + w * karg.weight_decay; + local_sum_squre += a * a; + } + } + else + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t a = acc[i]; + local_sum_squre += a * a; + } + } + + cache_t avg_square = + wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / + embedding_dim; + + cache_t momentum_new = momentum + avg_square; + + cache_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); + cache_t correction; + + if constexpr(weight_decay_mode == 1) + { + correction = 1.0 - multiplier * karg.weight_decay; + } + else if constexpr(weight_decay_mode == 2) + { + correction = 1.0 - karg.learning_rate * karg.weight_decay; + } + else + { + correction = 1.0; + } + +// update new weight value +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i]; + w = correction * w - multiplier * a; + weight[i] = static_cast(w); + } + + // printf("momentum_new:%f, avg_square:%f, 
row_index:%d, momentum:%f\n", momentum_new, avg_square, row_index, momentum); + // printf("momentum_new:%f", momentum_new); + + p_momentum[row_index] = momentum_new; + } + } + + rowwise_adagrad_kernel_arg_t karg; +}; + +template // 0-warp per row, 1-cta per row, 2-atomic(needed?) +__device__ void split_tbe_backward_unweighted_hip_kernel( + const grad_t* p_output_grad, + emb_t* p_emb_table, + const int64_t* p_hash_size_cumsum, + const int64_t* p_sorted_linear_indices_run, + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, + const int32_t* p_sorted_linear_indices_num_runs, + const int32_t* p_long_run_ids, + const int32_t* p_num_long_run_ids, + const int32_t* p_sorted_infos, + magic_div_u32_t batch_mdiv, + uint32_t max_segment_length_per_warp, + uint32_t emb_dim, + uint32_t batch, + uint32_t num_rows, + uint32_t num_tables, + optimizer_karg_t opt_karg) +{ + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; + constexpr uint32_t length_mask = ~(segment_unroll - 1); + const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); + const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; + const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; + + // printf("wave_id:%d, run_id:%d(%d), batch:%d(%d, %d)\n", + // wave_id, run_id, p_sorted_linear_indices_num_runs[0], batch, batch_mdiv.magic, batch_mdiv.shift); + + if(run_id >= p_sorted_linear_indices_num_runs[0]) + return; + + const int64_t linear_index = p_sorted_linear_indices_run[run_id]; + + const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; + const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; + + int32_t info_0 = p_sorted_infos[segment_start]; + uint32_t t_0 = magic_div_u32_run(batch_mdiv, info_0); + int64_t hash_size = p_hash_size_cumsum[t_0]; + + const int64_t emb_idx = linear_index - hash_size; + + // printf("[%d] segment_start:%d, info_0:%d, t_0:%d, num_rows:%d, emb_dim:%d, linear_index:%ld\n", wave_id, segment_start, info_0, t_0, num_rows, emb_dim, linear_index); + + // p_output_grad += t_0 * emb_dim; + + p_emb_table += hash_size * emb_dim; + opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); + + const int32_t segment_length = segment_end - segment_start; + + if(segment_length >= max_segment_length_per_warp) + return; + + // printf("[%d] segment_length:%d\n", wave_id, segment_length); + + const int32_t segment_length_mod = segment_length & length_mask; + + cache_t grad_acc[dword_per_row]; + int32_t infos[segment_unroll]; + grad_t grad_data[dword_per_row * segment_prefetch]; + emb_t emb_data[dword_per_row]; + + #pragma unroll + for(int i=0; i < dword_per_row; i++) + { + grad_acc[i] = .0f; + } + + int itr = 0; + if(segment_length_mod == 0) + goto L_tail_grad_acc; + +#pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + + itr += segment_unroll; + p_sorted_infos += segment_unroll; + + uint32_t bag_index; + uint32_t table_index; + + // LOOP + for(; itr < segment_length_mod; itr += segment_unroll) + { + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + load_row_per_warp::run( + 
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + +#pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod( + batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + +#pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + p_sorted_infos += segment_unroll; + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + } + + // LAST + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + +#pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + +L_tail_grad_acc: + if(segment_length & (segment_unroll - 1)) + { + // last, load one by one + do + { + infos[0] = p_sorted_infos[segment_start]; + p_sorted_infos++; + + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + + itr++; + } while(itr < segment_length); + } + + // printf("[%d] segment_length:%d ==<< %f, emb_idx:%ld\n", wave_id, segment_length, grad_acc[0], emb_idx); + + // load the old emb weight data + load_row_per_warp::run( + &emb_data[0], emb_idx, p_emb_table, lane_id); + optimizer_t optimizer(opt_karg); + optimizer.template update(grad_acc, emb_data, emb_idx); + + // store updated weight to grad ?? 
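+    // note: emb_data now holds the optimizer-updated weights; the store below writes
+    // them back into the embedding table row at emb_idx (p_emb_table), not into the
+    // gradient buffer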
+ store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); +} + +#define __SPLIT_TBE_BWD_KERNEL(optimizer, \ + weight_decay_mode, \ + segment_split, \ + emb_prec, \ + emb_type, \ + grad_prec, \ + grad_type, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll) \ + extern "C" __global__ void \ + split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + const grad_type* p_output_grad, \ + emb_type* p_emb_table, \ + const int64_t* p_hash_size_cumsum, \ + const int64_t* p_sorted_linear_indices_run, \ + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, \ + const int32_t* p_sorted_linear_indices_num_runs, \ + const int32_t* p_long_run_ids, \ + const int32_t* p_num_long_run_ids, \ + const int32_t* p_sorted_infos, \ + magic_div_u32_t batch_mdiv, \ + uint32_t max_segment_length_per_warp, \ + uint32_t emb_dim, \ + uint32_t batch, \ + uint32_t num_rows, \ + uint32_t num_tables, \ + optimizer##_kernel_arg_t opt_karg) \ + { \ + split_tbe_backward_unweighted_hip_kernel< \ + optimizer##_optimizer_t, \ + optimizer##_kernel_arg_t, \ + emb_type, \ + float, \ + grad_type, \ + BLOCK_SIZE, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll, \ + segment_split>(p_output_grad, \ + p_emb_table, \ + p_hash_size_cumsum, \ + p_sorted_linear_indices_run, \ + p_sorted_linear_indices_cumulative_run_lengths, \ + p_sorted_linear_indices_num_runs, \ + p_long_run_ids, \ + p_num_long_run_ids, \ + p_sorted_infos, \ + batch_mdiv, \ + max_segment_length_per_warp, \ + emb_dim, \ + batch, \ + num_rows, \ + num_tables, \ + opt_karg); \ + } + +#define SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, \ + segment_split, \ + emb_prec, \ + emb_type, \ + grad_prec, \ + grad_type, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 0, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 1, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 2, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) + + +#define SPLIT_TBE_BWD_KERNEL(optimizer, \ + segment_split, \ + embedding_dim) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp32, float, fp32, float, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp32, float, fp16, half, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp16, half, fp32, float, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp16, half, fp16, half, embedding_dim, 2, 8) + +// warp per row +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 64) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 128) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 192) +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 256) + From 18fda9630fb87bf95fd8c5a39bc4d0753294a4f9 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 28 Sep 2022 07:38:14 -0400 Subject: [PATCH 4/8] modify 2 UTs --- .../split_table_batched_embeddings_test.py | 56 +++++++++---------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 6ee2b23f3..9dd9913b5 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ 
b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -1794,7 +1794,7 @@ def execute_backward_adagrad_( # noqa C901 np.random.choice( [ split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE, - split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED, + # split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED, ] ) for _ in range(T) @@ -2216,27 +2216,23 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 @given( T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), + D=st.just(16), # support 16, 32, 48, 64 B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP32), - stochastic_rounding=st.booleans(), - weighted=st.booleans(), - row_wise=st.booleans(), - mixed=st.booleans(), - use_cache=st.booleans(), + weights_precision=st.just(SparseType.FP32), # support fp16/fp32 + stochastic_rounding=st.just(False), + weighted=st.just(False), + row_wise=st.just(True), + mixed=st.just(False), + use_cache=st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - exact=st.booleans(), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + use_cpu=st.just(False), + exact=st.just(True), + output_dtype=st.just(SparseType.FP32), # support fp16/fp32 ) @settings( verbosity=Verbosity.verbose, @@ -2600,7 +2596,7 @@ def execute_backward_optimizers_( # noqa C901 np.random.choice( [ split_table_batched_embeddings_ops.EmbeddingLocation.DEVICE, - split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED, + # split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED, ] ) for _ in range(T) @@ -2975,36 +2971,36 @@ def test_backward_optimizers_adam( # noqa C901 @given( T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=256), + D=st.just(48), # 16, 32, 48, 64 B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), - weighted=st.booleans(), - mixed=st.booleans(), + weighted=st.just(False), + mixed=st.just(False), optimizer=st.sampled_from( [ - OptimType.EXACT_ADAGRAD, + # currently only support exact rowwise adagrad + #OptimType.EXACT_ADAGRAD, OptimType.EXACT_ROWWISE_ADAGRAD, - OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD, + #OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD, ] ), long_segments=st.booleans(), pooling_mode=st.sampled_from( [ + # can only uncomment this when fwd kernel support any pooling mode split_table_batched_embeddings_ops.PoolingMode.SUM, - split_table_batched_embeddings_ops.PoolingMode.MEAN, - split_table_batched_embeddings_ops.PoolingMode.NONE, + #split_table_batched_embeddings_ops.PoolingMode.MEAN, + #split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), + use_cpu=st.just(False), weight_decay_mode=st.sampled_from( [ - WeightDecayMode.L2, - WeightDecayMode.DECOUPLE, + # can change this within 3 modes + WeightDecayMode.NONE, + #WeightDecayMode.L2, + #WeightDecayMode.DECOUPLE, ] ), ) From 70bccb75cc68c549ced80c60061c0b68a7737a7f Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 16 Oct 2022 10:26:04 -0400 Subject: [PATCH 5/8] 
build inside a single so --- .../embedding_backward_code_generator.py | 1 + .../embedding_backward_split_template.cu | 120 +++++++++++++++++- 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 9d6735890..04321ff23 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -45,6 +45,7 @@ # An optimization for ROCm env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256 env.globals["dense"] = False +env.globals["is_rocm"] = args.is_rocm def write(filename: str, s: str) -> None: diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 6bd4cdff2..1a718dcaf 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -9,6 +9,7 @@ #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_utils.cuh" #include "hip_kernel/split_tbe_common_hip.h" +#include "hip_kernel/split_tbe_bwd.hip.hpp" #include #include @@ -952,6 +953,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ bool hip_opt_kernel_supported = false; // TODO: figure out support range {% endif %} +#if 0 {% if optimizer == "rowwise_adagrad" and not dense %} if(hip_opt_kernel_supported && init_hsaco == 0){ int segment_split = 0; // warp per row @@ -977,6 +979,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ init_hsaco = 1; } {% endif %} +#endif {% if not dense %} @@ -1080,7 +1083,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ // kMaxElemPerThread is # of elements handled by thread if we use a full warp for a row // We consider kMaxElemPerThread 1 and 2, and then a multiple of 4. {% for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} - {% if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} + {% if kMaxElemPerThread in ([1, 2, 3] if is_rocm else [1, 2]) or kMaxElemPerThread % 4 == 0 %} if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { // hipcc can't use max in constexpr constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? 
{{ kMaxElemPerThread }} / 4 : 1; @@ -1279,6 +1282,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } C10_CUDA_KERNEL_LAUNCH_CHECK(); + {% if optimizer == "rowwise_adagrad" and not dense and (items_per_warp // 4 * kMaxElemPerThread) in [64, 128, 192, 256] %} if(hip_opt_kernel_supported){ struct { const void* p_output_grad; @@ -1331,11 +1335,119 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), 1, 1}; int32_t blocks[3] = {256, 1, 1}; - hipModuleLaunchKernel(hip_kernel_func, - grids[0], grids[1], grids[2], - blocks[0], blocks[1], blocks[2], 0, 0, NULL, (void **) &kconf); + {% for weight_decay_mode_current in [0, 1, 2] %} + if(weight_decay_mode == {{ weight_decay_mode_current }}){ + if(dev_weights.scalar_type() == at::ScalarType::Half && grad_output.scalar_type() == at::ScalarType::Half){ + hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + dim3(grids[0], grids[1], grids[2]), + dim3(blocks[0], blocks[1], blocks[2]), + 0, 0, + (const half*)karg.p_output_grad , + (half*)karg.p_emb_table, + (const int64_t*)karg.p_hash_size_cumsum, + (const int64_t*)karg.p_sorted_linear_indices_run, + (const int32_t* )karg.p_sorted_linear_indices_cumulative_run_lengths, + (const int32_t*)karg.p_sorted_linear_indices_num_runs, + (const int32_t*)karg.p_long_run_ids , + (const int32_t*)karg.p_num_long_run_ids, + (const int32_t*)karg.p_sorted_infos , + karg.batch_mdiv, + karg.max_segment_length_per_warp, + karg.emb_dim , + karg.batch , + karg.num_rows, + karg.num_tables , + {% if optimizer == "rowwise_adagrad" and not dense %} + karg.opt_karg + {% endif %} + ); + }else if (!(dev_weights.scalar_type() == at::ScalarType::Half) && grad_output.scalar_type() == at::ScalarType::Half) + { + hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + dim3(grids[0], grids[1], grids[2]), + dim3(blocks[0], blocks[1], blocks[2]), + 0, 0, + (const half*)karg.p_output_grad , + (float*)karg.p_emb_table, + (const int64_t*)karg.p_hash_size_cumsum, + (const int64_t*)karg.p_sorted_linear_indices_run, + (const int32_t* )karg.p_sorted_linear_indices_cumulative_run_lengths, + (const int32_t*)karg.p_sorted_linear_indices_num_runs, + (const int32_t*)karg.p_long_run_ids , + (const int32_t*)karg.p_num_long_run_ids, + (const int32_t*)karg.p_sorted_infos , + karg.batch_mdiv, + karg.max_segment_length_per_warp, + karg.emb_dim , + karg.batch , + karg.num_rows, + karg.num_tables , + {% if optimizer == "rowwise_adagrad" and not dense %} + karg.opt_karg + {% endif %} + ); + + } + else if (dev_weights.scalar_type() == at::ScalarType::Half && !(grad_output.scalar_type() == at::ScalarType::Half)) + { + hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + dim3(grids[0], grids[1], grids[2]), + dim3(blocks[0], blocks[1], blocks[2]), + 0, 0, + (const float*)karg.p_output_grad , + (half*)karg.p_emb_table, + (const int64_t*)karg.p_hash_size_cumsum, + (const int64_t*)karg.p_sorted_linear_indices_run, + (const int32_t* )karg.p_sorted_linear_indices_cumulative_run_lengths, + (const int32_t*)karg.p_sorted_linear_indices_num_runs, + (const int32_t*)karg.p_long_run_ids , + (const 
int32_t*)karg.p_num_long_run_ids, + (const int32_t*)karg.p_sorted_infos , + karg.batch_mdiv, + karg.max_segment_length_per_warp, + karg.emb_dim , + karg.batch , + karg.num_rows, + karg.num_tables , + {% if optimizer == "rowwise_adagrad" and not dense %} + karg.opt_karg + {% endif %} + ); + } + else{ + hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + dim3(grids[0], grids[1], grids[2]), + dim3(blocks[0], blocks[1], blocks[2]), + 0, 0, + (const float*)karg.p_output_grad , + (float*)karg.p_emb_table, + (const int64_t*)karg.p_hash_size_cumsum, + (const int64_t*)karg.p_sorted_linear_indices_run, + (const int32_t* )karg.p_sorted_linear_indices_cumulative_run_lengths, + (const int32_t*)karg.p_sorted_linear_indices_num_runs, + (const int32_t*)karg.p_long_run_ids , + (const int32_t*)karg.p_num_long_run_ids, + (const int32_t*)karg.p_sorted_infos , + karg.batch_mdiv, + karg.max_segment_length_per_warp, + karg.emb_dim , + karg.batch , + karg.num_rows, + karg.num_tables , + {% if optimizer == "rowwise_adagrad" and not dense %} + karg.opt_karg + {% endif %} + ); + } + } + {% endfor %} + + // hipModuleLaunchKernel(hip_kernel_func, + // grids[0], grids[1], grids[2], + // blocks[0], blocks[1], blocks[2], 0, 0, NULL, (void **) &kconf); }else + {% endif %} split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} emb_t, From 1a94a2dafca2782f3b9bd8f03f2b22e20f52fbf3 Mon Sep 17 00:00:00 2001 From: Douglas Lehr Date: Wed, 19 Oct 2022 11:55:51 -0400 Subject: [PATCH 6/8] Add hpp counterpart for bwd pass split_tbe_bwd.hip.cpp needed a hpp counterpart to invoke the necessary macros to build all templates for split_tbe_bwd_hip_kernel_ --- fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp | 82 +++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp new file mode 100644 index 000000000..42a1849e4 --- /dev/null +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + ******************************************************************************/ +#pragma once + +#include +#include +#define __SPLIT_TBE_BWD_KERNEL(optimizer, \ + weight_decay_mode, \ + segment_split, \ + emb_prec, \ + emb_type, \ + grad_prec, \ + grad_type, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll) \ + extern "C" __global__ void \ + split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + const grad_type* p_output_grad, \ + emb_type* p_emb_table, \ + const int64_t* p_hash_size_cumsum, \ + const int64_t* p_sorted_linear_indices_run, \ + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, \ + const int32_t* p_sorted_linear_indices_num_runs, \ + const int32_t* p_long_run_ids, \ + const int32_t* p_num_long_run_ids, \ + const int32_t* p_sorted_infos, \ + magic_div_u32_t batch_mdiv, \ + uint32_t max_segment_length_per_warp, \ + uint32_t emb_dim, \ + uint32_t batch, \ + uint32_t num_rows, \ + uint32_t num_tables, \ + optimizer##_kernel_arg_t opt_karg); + +#define SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, \ + segment_split, \ + emb_prec, \ + emb_type, \ + grad_prec, \ + grad_type, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 0, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 1, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) \ + __SPLIT_TBE_BWD_KERNEL(optimizer, 2, segment_split, emb_prec, emb_type, grad_prec, grad_type, embedding_dim, segment_prefetch, segment_unroll) + + +#define SPLIT_TBE_BWD_KERNEL(optimizer, \ + segment_split, \ + embedding_dim) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp32, float, fp32, float, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp32, float, fp16, half, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp16, half, fp32, float, embedding_dim, 2, 8) \ + SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, segment_split, fp16, half, fp16, half, embedding_dim, 2, 8) + +// warp per row +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 64); +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 128); +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 192); +SPLIT_TBE_BWD_KERNEL(rowwise_adagrad, 0, 256); \ No newline at end of file From 502af32cdafb4df88ea434bad8c26668ad37b7a0 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 21 Nov 2022 20:48:39 +0000 Subject: [PATCH 7/8] add weighted to backward pipelined hip --- .../embedding_backward_split_template.cu | 26 +- fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp | 22 +- fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp | 289 +++++++++++++----- fbgemm_gpu/hip_kernel/split_tbe_common_hip.h | 22 +- .../split_table_batched_embeddings_test.py | 4 +- 5 files changed, 273 insertions(+), 90 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 1a718dcaf..81e6d90b0 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -1296,6 +1296,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ const void* p_sorted_infos; magic_div_u32_t batch_mdiv; uint32_t max_segment_length_per_warp; + {% if weighted %} + float *indice_weights_sorted; + {% endif %} uint32_t emb_dim; uint32_t batch; uint32_t num_rows; @@ 
-1317,6 +1320,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ karg.p_sorted_infos = infos_sorted.packed_accessor32().data(); karg.batch_mdiv = magic_div_u32_gen(B); karg.max_segment_length_per_warp = max_segment_length_per_warp; + {% if weighted %} + karg.indice_weights_sorted = indice_weights_sorted.packed_accessor32().data(); + {% endif %} karg.emb_dim = max_D; karg.batch = B; karg.num_rows = dev_weights.numel() / T / max_D; @@ -1338,7 +1344,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% for weight_decay_mode_current in [0, 1, 2] %} if(weight_decay_mode == {{ weight_decay_mode_current }}){ if(dev_weights.scalar_type() == at::ScalarType::Half && grad_output.scalar_type() == at::ScalarType::Half){ - hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, dim3(grids[0], grids[1], grids[2]), dim3(blocks[0], blocks[1], blocks[2]), 0, 0, @@ -1353,6 +1359,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ (const int32_t*)karg.p_sorted_infos , karg.batch_mdiv, karg.max_segment_length_per_warp, + {% if weighted %} + karg.indice_weights_sorted, + {% endif %} karg.emb_dim , karg.batch , karg.num_rows, @@ -1363,7 +1372,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ ); }else if (!(dev_weights.scalar_type() == at::ScalarType::Half) && grad_output.scalar_type() == at::ScalarType::Half) { - hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }}, dim3(grids[0], grids[1], grids[2]), dim3(blocks[0], blocks[1], blocks[2]), 0, 0, @@ -1378,6 +1387,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ (const int32_t*)karg.p_sorted_infos , karg.batch_mdiv, karg.max_segment_length_per_warp, + {% if weighted %} + karg.indice_weights_sorted, + {% endif %} karg.emb_dim , karg.batch , karg.num_rows, @@ -1390,7 +1402,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } else if (dev_weights.scalar_type() == at::ScalarType::Half && !(grad_output.scalar_type() == at::ScalarType::Half)) { - hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, dim3(grids[0], grids[1], grids[2]), dim3(blocks[0], blocks[1], blocks[2]), 0, 0, @@ -1405,6 +1417,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ (const int32_t*)karg.p_sorted_infos , karg.batch_mdiv, karg.max_segment_length_per_warp, + {% if weighted %} + karg.indice_weights_sorted, + {% endif %} karg.emb_dim , karg.batch , karg.num_rows, @@ -1415,7 +1430,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ ); } else{ - hipLaunchKernelGGL(split_tbe_bwd_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current 
}}_s0_fp32_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, + hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }}, dim3(grids[0], grids[1], grids[2]), dim3(blocks[0], blocks[1], blocks[2]), 0, 0, @@ -1430,6 +1445,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ (const int32_t*)karg.p_sorted_infos , karg.batch_mdiv, karg.max_segment_length_per_warp, + {% if weighted %} + karg.indice_weights_sorted, + {% endif %} karg.emb_dim , karg.batch , karg.num_rows, diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp index 42a1849e4..a69ff64af 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd.hip.hpp @@ -35,7 +35,7 @@ segment_prefetch, \ segment_unroll) \ extern "C" __global__ void \ - split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + split_tbe_bwd_unweighted_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ const grad_type* p_output_grad, \ emb_type* p_emb_table, \ const int64_t* p_hash_size_cumsum, \ @@ -51,6 +51,26 @@ uint32_t batch, \ uint32_t num_rows, \ uint32_t num_tables, \ + optimizer##_kernel_arg_t opt_karg); \ + \ + extern "C" __global__ void \ + split_tbe_bwd_weighted_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + const grad_type* p_output_grad, \ + emb_type* p_emb_table, \ + const int64_t* p_hash_size_cumsum, \ + const int64_t* p_sorted_linear_indices_run, \ + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, \ + const int32_t* p_sorted_linear_indices_num_runs, \ + const int32_t* p_long_run_ids, \ + const int32_t* p_num_long_run_ids, \ + const int32_t* p_sorted_infos, \ + magic_div_u32_t batch_mdiv, \ + uint32_t max_segment_length_per_warp, \ + const float * p_indice_weights, \ + uint32_t emb_dim, \ + uint32_t batch, \ + uint32_t num_rows, \ + uint32_t num_tables, \ optimizer##_kernel_arg_t opt_karg); #define SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, \ diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp index 45f25e9f3..be025ac35 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp @@ -115,10 +115,11 @@ template // 0-warp per row, 1-cta per row, 2-atomic(needed?) -__device__ void split_tbe_backward_unweighted_hip_kernel( + int32_t segment_prefetch, // 2 + int32_t segment_unroll, // 8 + int32_t segment_split, // 0-warp per row, 1-cta per row, 2-atomic(needed?) 
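+          // compile-time switch: when true, the kernel also reads p_sorted_indice_weights
+          // and scales each accumulated gradient row by its per-sample weight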
+ bool weighted> +__device__ void split_tbe_backward_hip_kernel( const grad_t* p_output_grad, emb_t* p_emb_table, const int64_t* p_hash_size_cumsum, @@ -134,11 +135,12 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( uint32_t batch, uint32_t num_rows, uint32_t num_tables, - optimizer_karg_t opt_karg) + optimizer_karg_t opt_karg, + const float * p_sorted_indice_weights = nullptr) { - constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; - constexpr uint32_t length_mask = ~(segment_unroll - 1); + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; // number of columns each thread will process + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; // 256 / 64 + constexpr uint32_t length_mask = ~(segment_unroll - 1); // ~(8-1) = 0b000 const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; @@ -146,7 +148,7 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( // printf("wave_id:%d, run_id:%d(%d), batch:%d(%d, %d)\n", // wave_id, run_id, p_sorted_linear_indices_num_runs[0], batch, batch_mdiv.magic, batch_mdiv.shift); - if(run_id >= p_sorted_linear_indices_num_runs[0]) + if(run_id >= p_sorted_linear_indices_num_runs[0]) // number of segment return; const int64_t linear_index = p_sorted_linear_indices_run[run_id]; @@ -154,17 +156,17 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - int32_t info_0 = p_sorted_infos[segment_start]; - uint32_t t_0 = magic_div_u32_run(batch_mdiv, info_0); - int64_t hash_size = p_hash_size_cumsum[t_0]; + int32_t info_0 = p_sorted_infos[segment_start]; // start of a segment in linear index + uint32_t t_0 = magic_div_u32_run(batch_mdiv, info_0); // determine which table info_0 stays + int64_t hash_size = p_hash_size_cumsum[t_0]; // p_hash_size_cumsum: the first element offset of a table in rows - const int64_t emb_idx = linear_index - hash_size; + const int64_t emb_idx = linear_index - hash_size; // the location of a row in a table // printf("[%d] segment_start:%d, info_0:%d, t_0:%d, num_rows:%d, emb_dim:%d, linear_index:%ld\n", wave_id, segment_start, info_0, t_0, num_rows, emb_dim, linear_index); // p_output_grad += t_0 * emb_dim; - p_emb_table += hash_size * emb_dim; + p_emb_table += hash_size * emb_dim; // start of the current talbe (p_emb_table is a pointer) opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); const int32_t segment_length = segment_end - segment_start; @@ -174,12 +176,13 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( // printf("[%d] segment_length:%d\n", wave_id, segment_length); - const int32_t segment_length_mod = segment_length & length_mask; + const int32_t segment_length_mod = segment_length & length_mask; // segment_length_mod is a multiplication of 8 cache_t grad_acc[dword_per_row]; int32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; + float indice_weights[segment_unroll]; #pragma unroll for(int i=0; i < dword_per_row; i++) @@ -191,15 +194,27 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( 
if(segment_length_mod == 0) goto L_tail_grad_acc; -#pragma unroll +if constexpr (!weighted) { + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } +} else { for(int i = 0; i < segment_unroll; i++) { infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; } +} itr += segment_unroll; p_sorted_infos += segment_unroll; +if constexpr (weighted) { + p_sorted_indice_weights += segment_unroll; +} + uint32_t bag_index; uint32_t table_index; @@ -213,35 +228,69 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + if constexpr (!weighted){ + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod( + batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } -#pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( + accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - accumulate_row_per_warp::run( + accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - magic_div_u32_run_with_mod( - batch_mdiv, infos[j + 1], batch, table_index, bag_index); - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } -#pragma unroll - for(int i = 0; i < segment_unroll; i++) - { - infos[i] = p_sorted_infos[segment_start + i]; - } - p_sorted_infos += segment_unroll; + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + p_sorted_infos += segment_unroll; - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); + + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + magic_div_u32_run_with_mod( + batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + 
accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; + } + p_sorted_infos += segment_unroll; + p_sorted_indice_weights += segment_unroll; + } } // LAST @@ -253,44 +302,85 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); -#pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + if constexpr (!weighted) { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } - accumulate_row_per_warp::run( + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + } L_tail_grad_acc: if(segment_length & (segment_unroll - 1)) { - // last, load one by one - do - { - infos[0] = p_sorted_infos[segment_start]; - p_sorted_infos++; - - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - accumulate_row_per_warp::run( - 
&grad_acc[0], &grad_data[0], lane_id); - - itr++; - } while(itr < segment_length); + if constexpr (!weighted){ + // last, load one by one + do + { + infos[0] = p_sorted_infos[segment_start]; + p_sorted_infos++; + + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + + itr++; + } while(itr < segment_length); + } else { + do + { + infos[0] = p_sorted_infos[segment_start]; + indice_weights[0] = p_sorted_indice_weights[segment_start]; + p_sorted_infos++; + p_sorted_indice_weights++; + + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); + + itr++; + } while(itr < segment_length); + } } // printf("[%d] segment_length:%d ==<< %f, emb_idx:%ld\n", wave_id, segment_length, grad_acc[0], emb_idx); @@ -316,7 +406,7 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( segment_prefetch, \ segment_unroll) \ extern "C" __global__ void \ - split_tbe_bwd_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + split_tbe_bwd_unweighted_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ const grad_type* p_output_grad, \ emb_type* p_emb_table, \ const int64_t* p_hash_size_cumsum, \ @@ -334,7 +424,7 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( uint32_t num_tables, \ optimizer##_kernel_arg_t opt_karg) \ { \ - split_tbe_backward_unweighted_hip_kernel< \ + split_tbe_backward_hip_kernel< \ optimizer##_optimizer_t, \ optimizer##_kernel_arg_t, \ emb_type, \ @@ -344,7 +434,8 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( embedding_dim, \ segment_prefetch, \ segment_unroll, \ - segment_split>(p_output_grad, \ + segment_split, \ + false>(p_output_grad, \ p_emb_table, \ p_hash_size_cumsum, \ p_sorted_linear_indices_run, \ @@ -360,6 +451,56 @@ __device__ void split_tbe_backward_unweighted_hip_kernel( num_rows, \ num_tables, \ opt_karg); \ + } \ + \ + extern "C" __global__ void \ + split_tbe_bwd_weighted_hip_kernel_##optimizer##_w##weight_decay_mode##_s##segment_split##_##emb_prec##_##grad_prec##_e##embedding_dim( \ + const grad_type* p_output_grad, \ + emb_type* p_emb_table, \ + const int64_t* p_hash_size_cumsum, \ + const int64_t* p_sorted_linear_indices_run, \ + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, \ + const int32_t* p_sorted_linear_indices_num_runs, \ + const int32_t* p_long_run_ids, \ + const int32_t* p_num_long_run_ids, \ + const int32_t* p_sorted_infos, \ + magic_div_u32_t batch_mdiv, \ + uint32_t max_segment_length_per_warp, \ + const float * p_indice_weights, \ + uint32_t emb_dim, \ + uint32_t batch, \ + uint32_t num_rows, \ + uint32_t num_tables, \ + optimizer##_kernel_arg_t opt_karg) \ + { \ + split_tbe_backward_hip_kernel< \ + optimizer##_optimizer_t, \ + optimizer##_kernel_arg_t, \ + emb_type, \ + float, \ + grad_type, \ + BLOCK_SIZE, \ + embedding_dim, \ + segment_prefetch, \ + segment_unroll, \ + segment_split, \ + true>(p_output_grad, \ + p_emb_table, \ + p_hash_size_cumsum, \ + p_sorted_linear_indices_run, \ + p_sorted_linear_indices_cumulative_run_lengths, \ + 
p_sorted_linear_indices_num_runs, \ + p_long_run_ids, \ + p_num_long_run_ids, \ + p_sorted_infos, \ + batch_mdiv, \ + max_segment_length_per_warp, \ + emb_dim, \ + batch, \ + num_rows, \ + num_tables, \ + opt_karg, \ + p_indice_weights); \ } #define SPLIT_TBE_BWD_KERNEL_ALL_WDM(optimizer, \ diff --git a/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h index e7d8ef0f9..9d690d43e 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h +++ b/fbgemm_gpu/hip_kernel/split_tbe_common_hip.h @@ -188,16 +188,20 @@ struct load_row_per_warp } }; -template -struct accumulate_row_per_warp -{ +template +struct accumulate_row_per_warp { static constexpr int dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, emb_t* emb_data, int lane_id) - { -#pragma unroll - for(int i = 0; i < dword_per_row; i++) - { - acc[i] += static_cast(emb_data[i]); + static __device__ void run(output_t * acc, emb_t * emb_data, int lane_id, float row_weight = 1.0) { + if constexpr (!weighted) { + #pragma unroll + for(int i = 0; i < dword_per_row; i++){ + acc[i] += static_cast(emb_data[i]); + } + } else { + #pragma unroll + for(int i = 0; i < dword_per_row; i++){ + acc[i] += static_cast((float)emb_data[i] * row_weight); + } } } }; diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 2acfda7de..f2add207d 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -2961,11 +2961,11 @@ def test_backward_optimizers_adam( # noqa C901 @given( T=st.integers(min_value=1, max_value=5), - D=st.just(48), # 16, 32, 48, 64 + D=st.sampled_from([16, 32, 48, 64]), # 16, 32, 48, 64 B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), - weighted=st.just(False), + weighted=st.sampled_from([True, False]), mixed=st.just(False), optimizer=st.sampled_from( [ From 984032c28fcc66b8c0fc19cb36271bc20ca7933f Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 21 Nov 2022 20:51:28 +0000 Subject: [PATCH 8/8] clean up comments --- fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp index be025ac35..bf93854fc 100644 --- a/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp +++ b/fbgemm_gpu/hip_kernel/split_tbe_bwd_hip.cpp @@ -138,9 +138,9 @@ __device__ void split_tbe_backward_hip_kernel( optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) { - constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; // number of columns each thread will process - constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; // 256 / 64 - constexpr uint32_t length_mask = ~(segment_unroll - 1); // ~(8-1) = 0b000 + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; + constexpr uint32_t length_mask = ~(segment_unroll - 1); const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; @@ -148,7 +148,7 @@ __device__ void split_tbe_backward_hip_kernel( // printf("wave_id:%d, 
run_id:%d(%d), batch:%d(%d, %d)\n", // wave_id, run_id, p_sorted_linear_indices_num_runs[0], batch, batch_mdiv.magic, batch_mdiv.shift); - if(run_id >= p_sorted_linear_indices_num_runs[0]) // number of segment + if(run_id >= p_sorted_linear_indices_num_runs[0]) return; const int64_t linear_index = p_sorted_linear_indices_run[run_id]; @@ -156,17 +156,17 @@ __device__ void split_tbe_backward_hip_kernel( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - int32_t info_0 = p_sorted_infos[segment_start]; // start of a segment in linear index - uint32_t t_0 = magic_div_u32_run(batch_mdiv, info_0); // determine which table info_0 stays - int64_t hash_size = p_hash_size_cumsum[t_0]; // p_hash_size_cumsum: the first element offset of a table in rows + int32_t info_0 = p_sorted_infos[segment_start]; + uint32_t t_0 = magic_div_u32_run(batch_mdiv, info_0); + int64_t hash_size = p_hash_size_cumsum[t_0]; - const int64_t emb_idx = linear_index - hash_size; // the location of a row in a table + const int64_t emb_idx = linear_index - hash_size; // printf("[%d] segment_start:%d, info_0:%d, t_0:%d, num_rows:%d, emb_dim:%d, linear_index:%ld\n", wave_id, segment_start, info_0, t_0, num_rows, emb_dim, linear_index); // p_output_grad += t_0 * emb_dim; - p_emb_table += hash_size * emb_dim; // start of the current talbe (p_emb_table is a pointer) + p_emb_table += hash_size * emb_dim; opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); const int32_t segment_length = segment_end - segment_start; @@ -176,7 +176,7 @@ __device__ void split_tbe_backward_hip_kernel( // printf("[%d] segment_length:%d\n", wave_id, segment_length); - const int32_t segment_length_mod = segment_length & length_mask; // segment_length_mod is a multiplication of 8 + const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; int32_t infos[segment_unroll];