forked from pytorch/FBGEMM
Perf bwd hip #23 (Open)

carlushuang wants to merge 9 commits into performance from perf-bwd-hip.
Changes from all commits (9):
6f27e85  add init bwd kernel (not ready)  (carlushuang)
b546e59  compiler OK now  (carlushuang)
369f076  fix bug in bwd  (carlushuang)
18fda96  modify 2 UTs  (carlushuang)
920ef65  Merge remote-tracking branch 'origin/performance' into perf-bwd-hip  (carlushuang)
70bccb7  build inside a single so  (carlushuang)
1a94a2d  Add hpp counterpart for bwd pass  (dllehr-amd)
502af32  add weighted to backward pipelined hip  (liligwu)
984032c  clean up comments  (liligwu)
@@ -8,6 +8,10 @@
 {% set wdesc = "weighted" if weighted else "unweighted" %}
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
+#include "hip_kernel/split_tbe_common_hip.h"
+#include "hip_kernel/split_tbe_bwd.hip.hpp"
+#include <unistd.h>
+#include <iostream>

 #define SHFL_SYNC(val, srcLane) shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask)

@@ -938,7 +942,47 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
 }
 {% endif %}

+static int init_hsaco = 0;
+static hipModule_t hip_kernel_module;
+static hipFunction_t hip_kernel_func;
+{% if optimizer == "rowwise_adagrad" and not dense %}
+std::set<int> D_emb_s {64, 128, 192, 256};
+bool hip_opt_kernel_supported = (D_emb_s.find(max_D) != D_emb_s.end()) &&
+    (dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float);
+{% else %}
+bool hip_opt_kernel_supported = false; // TODO: figure out support range
+{% endif %}
+
+#if 0
+{% if optimizer == "rowwise_adagrad" and not dense %}
+if(hip_opt_kernel_supported && init_hsaco == 0){
+    int segment_split = 0; // warp per row
+    hipError_t hip_err = hipModuleLoad(&hip_kernel_module, "hip_kernel/split_tbe_bwd_hip_kernel.hsaco"); // hip kernel object
+    if (hip_err != hipSuccess) {
+        char cwd[PATH_MAX];
+        getcwd(cwd, sizeof(cwd));
+        printf("[hiperror](%d) line:%d, fail to call,(%s), cwd:%s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), cwd);
+        exit(1);
+    }
+    std::string w_prec = dev_weights.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32";
+    std::string g_prec = grad_output.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32";
+    std::string hip_kernel_name = std::string("split_tbe_bwd_hip_kernel_") + "{{ optimizer }}" + "_w" + std::to_string(weight_decay_mode) +
+        "_s" + std::to_string(segment_split) + "_" + w_prec + "_" + g_prec + "_e" + std::to_string(max_D);
+    hip_err = hipModuleGetFunction(&hip_kernel_func, hip_kernel_module, hip_kernel_name.c_str());
+    printf("kernel function: %s, B:%d, T:%d(%d), wcnt:%ld, ocnt:%ld, mcnt:%ld\n",
+           hip_kernel_name.c_str(), B, T, hash_size_cumsum.size(0) - 1, dev_weights.numel(), grad_output.numel(), momentum1_dev.numel());
+    if (hip_err != hipSuccess) {
+        printf("[hiperror](%d) line:%d, fail to call,(%s), %s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), hip_kernel_name.c_str());
+        exit(1);
+    }
+
+    init_hsaco = 1;
+}
+{% endif %}
+#endif
+
 {% if not dense %}

 DISPATCH_EMB_GRAD_CACHE_TYPES(
     dev_weights.scalar_type(),
     grad_output.scalar_type(),

Review comment on the max_D support check: "When we do mixed dimension this check will go away."
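The block compiled out with #if 0 above sketches the intended runtime path: load a precompiled HSACO code object, resolve one of the name-mangled kernel variants (optimizer, weight-decay mode, segment split, weight/gradient precision, embedding dim), and launch it through the HIP module API. A minimal, self-contained version of that pattern is shown below; my_kernels.hsaco, my_kernel, and MyKernelArgs are illustrative placeholders, while the hipModule* calls and the HIP_LAUNCH_PARAM_* protocol are the actual API the PR uses.

```c++
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

struct MyKernelArgs {        // placeholder for the packed kernel argument block
    const void* p_in;
    void*       p_out;
    uint32_t    n;
};

static void check(hipError_t err, const char* what) {
    if (err != hipSuccess) {
        std::fprintf(stderr, "%s failed: %s\n", what, hipGetErrorString(err));
        std::exit(1);
    }
}

int main() {
    hipModule_t module;
    hipFunction_t func;
    // Load a precompiled code object from disk, then look up one kernel by name.
    check(hipModuleLoad(&module, "my_kernels.hsaco"), "hipModuleLoad");
    check(hipModuleGetFunction(&func, module, "my_kernel"), "hipModuleGetFunction");

    MyKernelArgs karg{nullptr, nullptr, 0};
    size_t arg_size = sizeof(karg);
    // The "extra" protocol hands the whole argument struct to the kernel as one
    // packed buffer; this is exactly the shape of the kconf array built later
    // in this diff.
    void* kconf[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg,
                     HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                     HIP_LAUNCH_PARAM_END};
    // Grid 1x1x1, workgroup 256x1x1, no dynamic LDS, default stream; the
    // kernelParams argument is nullptr because kconf carries everything.
    check(hipModuleLaunchKernel(func, 1, 1, 1, 256, 1, 1, 0, nullptr, nullptr, kconf),
          "hipModuleLaunchKernel");
    check(hipDeviceSynchronize(), "hipDeviceSynchronize");
    return 0;
}
```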
@@ -1039,7 +1083,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
 // kMaxElemPerThread is # of elements handled by thread if we use a full warp for a row
 // We consider kMaxElemPerThread 1 and 2, and then a multiple of 4.
 {% for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
-{% if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
+{% if kMaxElemPerThread in ([1, 2, 3] if is_rocm else [1, 2]) or kMaxElemPerThread % 4 == 0 %}
 if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) {
 // hipcc can't use max in constexpr
 constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1;
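On ROCm this one-line template change adds kMaxElemPerThread == 3 to the emitted specializations. The optimized HIP kernels referenced later come in e64/e128/e192/e256 variants, so the extra case appears to exist to generate a max_D <= 192 branch; without it, 192-wide rows would fall through to the 256 branch. Below is a small host-side sketch of which branches the Jinja loop emits; items_per_warp = 256 is an assumption inferred from those kernel-name suffixes (items_per_warp // 4 == 64), not stated in the diff.

```c++
#include <cstdio>

int main() {
    const int items_per_warp = 256;  // assumed codegen parameter, not from the diff
    for (int is_rocm = 0; is_rocm <= 1; ++is_rocm) {
        std::printf("is_rocm=%d:\n", is_rocm);
        for (int k = 1; k <= 8; ++k) {  // kMaxElemPerThread candidates
            bool emit = (k == 1 || k == 2 || (is_rocm && k == 3)) || (k % 4 == 0);
            if (emit)
                std::printf("  kMaxElemPerThread=%d -> max_D <= %d\n",
                            k, items_per_warp / 4 * k);
        }
    }
    return 0;
}
```

With these assumptions the CUDA side emits 64/128/256/512 thresholds while the ROCm side also emits 192, matching the 192 entry in the D_emb_s support set above.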
@@ -1238,6 +1282,190 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
 }

 C10_CUDA_KERNEL_LAUNCH_CHECK();
+{% if optimizer == "rowwise_adagrad" and not dense and (items_per_warp // 4 * kMaxElemPerThread) in [64, 128, 192, 256] %}
+if(hip_opt_kernel_supported){
+    struct {
+        const void* p_output_grad;
+        void* p_emb_table;
+        const void* p_hash_size_cumsum;
+        const void* p_sorted_linear_indices_run;
+        const void* p_sorted_linear_indices_cumulative_run_lengths;
+        const void* p_sorted_linear_indices_num_runs;
+        const void* p_long_run_ids;
+        const void* p_num_long_run_ids;
+        const void* p_sorted_infos;
+        magic_div_u32_t batch_mdiv;
+        uint32_t max_segment_length_per_warp;
+        {% if weighted %}
+        float* indice_weights_sorted;
+        {% endif %}
+        uint32_t emb_dim;
+        uint32_t batch;
+        uint32_t num_rows;
+        uint32_t num_tables;
+        rowwise_adagrad_kernel_arg_t opt_karg;
+    } karg;
+    size_t arg_size = sizeof(karg);
+    void* kconf[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, HIP_LAUNCH_PARAM_BUFFER_SIZE,
+                     &arg_size, HIP_LAUNCH_PARAM_END};
+
+    karg.p_output_grad = grad_output_accessor.data();
+    karg.p_emb_table = dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_hash_size_cumsum = hash_size_cumsum.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_sorted_linear_indices_run = sorted_linear_indices_run.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_long_run_ids = long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_num_long_run_ids = num_long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
+    karg.p_sorted_infos = infos_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
+    karg.batch_mdiv = magic_div_u32_gen(B);
+    karg.max_segment_length_per_warp = max_segment_length_per_warp;
+    {% if weighted %}
+    karg.indice_weights_sorted = indice_weights_sorted.packed_accessor32<float, 1, at::RestrictPtrTraits>().data();
+    {% endif %}
+    karg.emb_dim = max_D;
+    karg.batch = B;
+    karg.num_rows = dev_weights.numel() / T / max_D;
+    karg.num_tables = T;
+
+    {% if optimizer == "rowwise_adagrad" and not dense %}
+    rowwise_adagrad_kernel_arg_t opt_karg;
+    opt_karg.p_momentum = momentum1_dev.packed_accessor64<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits>().data();
+    opt_karg.eps = eps;
+    opt_karg.learning_rate = learning_rate;
+    opt_karg.weight_decay = weight_decay;
+    karg.opt_karg = opt_karg;
+    {% endif %}
+
+    constexpr int segments_per_workgroup = 4;
+    int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), 1, 1};
+    int32_t blocks[3] = {256, 1, 1};
+
+    {% for weight_decay_mode_current in [0, 1, 2] %}
+    if(weight_decay_mode == {{ weight_decay_mode_current }}){
+        if(dev_weights.scalar_type() == at::ScalarType::Half && grad_output.scalar_type() == at::ScalarType::Half){
+            hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }},
+                dim3(grids[0], grids[1], grids[2]),
+                dim3(blocks[0], blocks[1], blocks[2]),
+                0, 0,
+                (const half*)karg.p_output_grad,
+                (half*)karg.p_emb_table,
+                (const int64_t*)karg.p_hash_size_cumsum,
+                (const int64_t*)karg.p_sorted_linear_indices_run,
+                (const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
+                (const int32_t*)karg.p_sorted_linear_indices_num_runs,
+                (const int32_t*)karg.p_long_run_ids,
+                (const int32_t*)karg.p_num_long_run_ids,
+                (const int32_t*)karg.p_sorted_infos,
+                karg.batch_mdiv,
+                karg.max_segment_length_per_warp,
+                {% if weighted %}
+                karg.indice_weights_sorted,
+                {% endif %}
+                karg.emb_dim,
+                karg.batch,
+                karg.num_rows,
+                karg.num_tables,
+                {% if optimizer == "rowwise_adagrad" and not dense %}
+                karg.opt_karg
+                {% endif %}
+            );
+        } else if(!(dev_weights.scalar_type() == at::ScalarType::Half) && grad_output.scalar_type() == at::ScalarType::Half){
+            hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }},
+                dim3(grids[0], grids[1], grids[2]),
+                dim3(blocks[0], blocks[1], blocks[2]),
+                0, 0,
+                (const half*)karg.p_output_grad,
+                (float*)karg.p_emb_table,
+                (const int64_t*)karg.p_hash_size_cumsum,
+                (const int64_t*)karg.p_sorted_linear_indices_run,
+                (const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
+                (const int32_t*)karg.p_sorted_linear_indices_num_runs,
+                (const int32_t*)karg.p_long_run_ids,
+                (const int32_t*)karg.p_num_long_run_ids,
+                (const int32_t*)karg.p_sorted_infos,
+                karg.batch_mdiv,
+                karg.max_segment_length_per_warp,
+                {% if weighted %}
+                karg.indice_weights_sorted,
+                {% endif %}
+                karg.emb_dim,
+                karg.batch,
+                karg.num_rows,
+                karg.num_tables,
+                {% if optimizer == "rowwise_adagrad" and not dense %}
+                karg.opt_karg
+                {% endif %}
+            );
+        } else if(dev_weights.scalar_type() == at::ScalarType::Half && !(grad_output.scalar_type() == at::ScalarType::Half)){
+            hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }},
+                dim3(grids[0], grids[1], grids[2]),
+                dim3(blocks[0], blocks[1], blocks[2]),
+                0, 0,
+                (const float*)karg.p_output_grad,
+                (half*)karg.p_emb_table,
+                (const int64_t*)karg.p_hash_size_cumsum,
+                (const int64_t*)karg.p_sorted_linear_indices_run,
+                (const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
+                (const int32_t*)karg.p_sorted_linear_indices_num_runs,
+                (const int32_t*)karg.p_long_run_ids,
+                (const int32_t*)karg.p_num_long_run_ids,
+                (const int32_t*)karg.p_sorted_infos,
+                karg.batch_mdiv,
+                karg.max_segment_length_per_warp,
+                {% if weighted %}
+                karg.indice_weights_sorted,
+                {% endif %}
+                karg.emb_dim,
+                karg.batch,
+                karg.num_rows,
+                karg.num_tables,
+                {% if optimizer == "rowwise_adagrad" and not dense %}
+                karg.opt_karg
+                {% endif %}
+            );
+        } else {
+            hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }},
+                dim3(grids[0], grids[1], grids[2]),
+                dim3(blocks[0], blocks[1], blocks[2]),
+                0, 0,
+                (const float*)karg.p_output_grad,
+                (float*)karg.p_emb_table,
+                (const int64_t*)karg.p_hash_size_cumsum,
+                (const int64_t*)karg.p_sorted_linear_indices_run,
+                (const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
+                (const int32_t*)karg.p_sorted_linear_indices_num_runs,
+                (const int32_t*)karg.p_long_run_ids,
+                (const int32_t*)karg.p_num_long_run_ids,
+                (const int32_t*)karg.p_sorted_infos,
+                karg.batch_mdiv,
+                karg.max_segment_length_per_warp,
+                {% if weighted %}
+                karg.indice_weights_sorted,
+                {% endif %}
+                karg.emb_dim,
+                karg.batch,
+                karg.num_rows,
+                karg.num_tables,
+                {% if optimizer == "rowwise_adagrad" and not dense %}
+                karg.opt_karg
+                {% endif %}
+            );
+        }
+    }
+    {% endfor %}
+
+    // hipModuleLaunchKernel(hip_kernel_func,
+    //     grids[0], grids[1], grids[2],
+    //     blocks[0], blocks[1], blocks[2], 0, 0, NULL, (void **) &kconf);
+
+} else
+{% endif %}
 split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
 {% if not dense %}
 emb_t,
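The launch block above sets karg.batch_mdiv = magic_div_u32_gen(B) so the kernel can recover table/batch coordinates from the sorted info words without a per-thread integer divide. magic_div_u32_t and magic_div_u32_gen live in the PR's hip_kernel headers, which this diff does not show, so the following is a hedged sketch of the standard multiply-and-shift ("magic number") division such helpers typically implement; the struct layout and names are illustrative stand-ins, not the PR's definitions.

```c++
#include <cassert>
#include <cstdint>
#include <cstdio>

struct MagicDivU32 {   // illustrative stand-in for magic_div_u32_t
    uint32_t magic;    // precomputed multiplier
    uint32_t shift;    // precomputed post-shift
};

// Precompute magic/shift on the host. The quotient ((n * magic) >> 32) >> shift
// then equals n / d for all 0 <= n < 2^31, which comfortably covers batch
// sizes and linearized TBE indices.
MagicDivU32 magic_div_u32_gen(uint32_t d) {
    assert(d >= 1 && d <= (1u << 31));
    if (d == 1) return {0u, 0u};               // special-cased in the divider
    uint32_t s = 0;
    while ((1u << s) < d) ++s;                 // s = ceil(log2(d))
    // magic = ceil(2^(31+s) / d); fits in 32 bits because 2^(s-1) < d.
    uint64_t m = ((1ull << (31 + s)) + d - 1) / d;
    return {static_cast<uint32_t>(m), s - 1};
}

uint32_t magic_div_u32(MagicDivU32 md, uint32_t n) {
    if (md.magic == 0) return n;               // divisor was 1
    uint32_t hi = static_cast<uint32_t>((static_cast<uint64_t>(n) * md.magic) >> 32);
    return hi >> md.shift;                     // = n / d for n < 2^31
}

int main() {
    // Sanity-check against ordinary division for a few divisors.
    for (uint32_t d : {1u, 2u, 3u, 7u, 10u, 192u, 65536u}) {
        MagicDivU32 md = magic_div_u32_gen(d);
        for (uint32_t n = 0; n < 100000; ++n)
            assert(magic_div_u32(md, n) == n / d);
    }
    std::printf("ok\n");
    return 0;
}
```

On the GPU, splitting an info word into (table, batch) by dividing by B then costs one multiply-high and one shift per thread instead of a hardware division.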
Review comment: If we have use_rocm being exposed to jinja, we can probably avoid needing an extra "items_per_warp".