Perf bwd hip #23

Open · wants to merge 9 commits into base: performance
1 change: 1 addition & 0 deletions fbgemm_gpu/CMakeLists.txt
@@ -305,6 +305,7 @@ if(NOT FBGEMM_CPU_ONLY)
codegen/embedding_backward_dense_host.cpp
codegen/embedding_bounds_check_host.cpp
hip_kernel/split_tbe_fwd_hip.cpp
hip_kernel/split_tbe_bwd_hip.cpp
src/cumem_utils_host.cpp
src/layout_transform_ops_gpu.cpp
src/permute_pooled_embedding_ops_gpu.cpp
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -45,6 +45,7 @@
# An optimization for ROCm
env.globals["items_per_warp"] = 256 if args.is_rocm else 128

Review comment: If use_rocm is exposed to jinja, we can probably avoid needing an extra "items_per_warp".
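A minimal sketch of that suggestion, assuming is_rocm stays exposed as a jinja global (as the added line below does): the constant could then be derived inside the templates instead of being passed in as a second global:

    {% set items_per_warp = 256 if is_rocm else 128 %}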

env.globals["dense"] = False
env.globals["is_rocm"] = args.is_rocm


def write(filename: str, s: str) -> None:
230 changes: 229 additions & 1 deletion fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -8,6 +8,10 @@
{% set wdesc = "weighted" if weighted else "unweighted" %}
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "hip_kernel/split_tbe_common_hip.h"

Review comment: Use #ifdef HIP_PLATFORM_HCC / #endif around the new includes.

#include "hip_kernel/split_tbe_bwd.hip.hpp"
#include <unistd.h>
#include <iostream>
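A minimal sketch of the reviewer's suggestion, assuming HIP_PLATFORM_HCC is the macro this build defines for ROCm targets:

    #ifdef HIP_PLATFORM_HCC
    #include "hip_kernel/split_tbe_common_hip.h"
    #include "hip_kernel/split_tbe_bwd.hip.hpp"
    #include <unistd.h> // getcwd() in the hsaco-loading path below
    #include <iostream>
    #endif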

#define SHFL_SYNC(val, srcLane) shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask)

@@ -938,7 +942,47 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
}
{% endif %}

static int init_hsaco = 0;
static hipModule_t hip_kernel_module;
static hipFunction_t hip_kernel_func;
{% if optimizer == "rowwise_adagrad" and not dense %}
std::set<int> D_emb_s {64, 128, 192, 256};
bool hip_opt_kernel_supported = (D_emb_s.find(max_D) != D_emb_s.end()) &&

Review comment: When we do mixed dimensions, this check will go away.

(dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float);
{% else %}
bool hip_opt_kernel_supported = false; // TODO: figure out support range
{% endif %}

#if 0
{% if optimizer == "rowwise_adagrad" and not dense %}
if (hip_opt_kernel_supported && init_hsaco == 0) {
int segment_split = 0; // warp per row
hipError_t hip_err = hipModuleLoad(&hip_kernel_module, "hip_kernel/split_tbe_bwd_hip_kernel.hsaco"); // hip kernel object
if (hip_err != hipSuccess) {
char cwd[PATH_MAX];
getcwd(cwd, sizeof(cwd));
printf("[hiperror](%d) line:%d, fail to call,(%s), cwd:%s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), cwd);
exit(1);
}
std::string w_prec = dev_weights.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32";
std::string g_prec = grad_output.scalar_type() == at::ScalarType::Half ? "fp16" : "fp32";
std::string hip_kernel_name = std::string("split_tbe_bwd_hip_kernel_") +"{{ optimizer }}" +"_w" + std::to_string(weight_decay_mode) +
"_s" + std::to_string(segment_split) + "_" + w_prec + "_" + g_prec + "_e" + std::to_string(max_D);
hip_err = hipModuleGetFunction(&hip_kernel_func, hip_kernel_module, hip_kernel_name.c_str());
printf("kernel function: %s, B:%d, T:%d(%d), wcnt:%ld, ocnt:%ld, mcnt:%ld\n",
hip_kernel_name.c_str(), B, T, hash_size_cumsum.size(0) - 1, dev_weights.numel(), grad_output.numel(), momentum1_dev.numel());
if (hip_err != hipSuccess) {
printf("[hiperror](%d) line:%d, fail to call,(%s), %s", (int) hip_err, __LINE__, hipGetErrorString(hip_err), hip_kernel_name.c_str());
exit(1);
}

init_hsaco = 1;
}
{% endif %}
#endif

{% if not dense %}

DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
grad_output.scalar_type(),
@@ -1039,7 +1083,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
// kMaxElemPerThread is # of elements handled by thread if we use a full warp for a row
// We consider kMaxElemPerThread 1 and 2, and then a multiple of 4.
{% for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{% if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{% if kMaxElemPerThread in ([1, 2, 3] if is_rocm else [1, 2]) or kMaxElemPerThread % 4 == 0 %}
if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) {
// hipcc can't use max in constexpr
constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1;
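For reference, a sketch of the cascade this loop unrolls to on ROCm (items_per_warp = 256, so each kMaxElemPerThread step covers another 256 // 4 = 64 columns; allowing 3 here is what gives max_D = 192 a dedicated branch instead of falling through to the 256-wide one):

    // kMaxElemPerThread = 1 -> if (max_D <=  64) { constexpr int kMaxVecsPerThread = 1; ... }
    // kMaxElemPerThread = 2 -> if (max_D <= 128) { constexpr int kMaxVecsPerThread = 1; ... }
    // kMaxElemPerThread = 3 -> if (max_D <= 192) { constexpr int kMaxVecsPerThread = 1; ... }  (ROCm only)
    // kMaxElemPerThread = 4 -> if (max_D <= 256) { constexpr int kMaxVecsPerThread = 1; ... }
    // kMaxElemPerThread = 8 -> if (max_D <= 512) { constexpr int kMaxVecsPerThread = 2; ... }
    // ...multiples of 4 thereafter, up to max_embedding_dim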
@@ -1238,6 +1282,190 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
{% if optimizer == "rowwise_adagrad" and not dense and (items_per_warp // 4 * kMaxElemPerThread) in [64, 128, 192, 256] %}
if (hip_opt_kernel_supported) {
struct {
const void* p_output_grad;
void* p_emb_table;
const void* p_hash_size_cumsum;
const void* p_sorted_linear_indices_run;
const void* p_sorted_linear_indices_cumulative_run_lengths;
const void* p_sorted_linear_indices_num_runs;
const void* p_long_run_ids;
const void* p_num_long_run_ids;
const void* p_sorted_infos;
magic_div_u32_t batch_mdiv;
uint32_t max_segment_length_per_warp;
{% if weighted %}
float *indice_weights_sorted;
{% endif %}
uint32_t emb_dim;
uint32_t batch;
uint32_t num_rows;
uint32_t num_tables;
rowwise_adagrad_kernel_arg_t opt_karg;
} karg;
size_t arg_size = sizeof(karg);
void* kconf[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size, HIP_LAUNCH_PARAM_END};

karg.p_output_grad = grad_output_accessor.data();
karg.p_emb_table = dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>().data();
karg.p_hash_size_cumsum = hash_size_cumsum.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
karg.p_sorted_linear_indices_run = sorted_linear_indices_run.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
karg.p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
karg.p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
karg.p_long_run_ids = long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
karg.p_num_long_run_ids = num_long_run_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
karg.p_sorted_infos = infos_sorted.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
karg.batch_mdiv = magic_div_u32_gen(B);
karg.max_segment_length_per_warp = max_segment_length_per_warp;
{% if weighted %}
karg.indice_weights_sorted = indice_weights_sorted.packed_accessor32<float, 1, at::RestrictPtrTraits>().data();
{% endif %}
karg.emb_dim = max_D;
karg.batch = B;
karg.num_rows = dev_weights.numel() / T / max_D;
karg.num_tables = T;

{% if optimizer == "rowwise_adagrad" and not dense %}
rowwise_adagrad_kernel_arg_t opt_karg;
opt_karg.p_momentum = momentum1_dev.packed_accessor64<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits>().data();
opt_karg.eps = eps;
opt_karg.learning_rate = learning_rate;
opt_karg.weight_decay = weight_decay;
karg.opt_karg = opt_karg;
{% endif %}

constexpr int segments_per_workgroup = 4;
int32_t grids[3] = {div_round_up(sorted_linear_indices_run.numel(), segments_per_workgroup), 1, 1};
int32_t blocks[3] = {256, 1, 1};
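For context, batch_mdiv holds precomputed constants that let the kernel divide and take remainders by the batch size B without hardware integer division. The real magic_div_u32_gen lives in hip_kernel/split_tbe_common_hip.h; the following is a generic sketch of the standard round-up magic-number technique (field names illustrative, valid for 1 <= d < 2^31):

    #include <cstdint>

    struct magic_div_u32_t { uint32_t magic; uint32_t shift; };

    // Precompute (magic, shift) so that for every uint32_t n:
    //   n / d == (n + (uint64_t(n) * magic >> 32)) >> shift
    magic_div_u32_t magic_div_u32_gen(uint32_t d) {
        uint32_t s = 0;
        while ((uint64_t{1} << s) < d) ++s;                   // s = ceil(log2(d))
        uint64_t m = ((uint64_t{1} << (32 + s)) + d - 1) / d; // ceil(2^(32+s) / d), in [2^32, 2^33)
        return {static_cast<uint32_t>(m), s};                 // keep the low 32 bits of m
    }

    // One 32x32->64 multiply, an add, and a shift instead of a divide.
    uint32_t magic_div_u32_run(magic_div_u32_t md, uint32_t n) {
        uint64_t t = (uint64_t{n} * md.magic) >> 32;          // high half of n * (m - 2^32)
        return static_cast<uint32_t>((t + n) >> md.shift);
    }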

{% for weight_decay_mode_current in [0, 1, 2] %}
if (weight_decay_mode == {{ weight_decay_mode_current }}) {
if (dev_weights.scalar_type() == at::ScalarType::Half && grad_output.scalar_type() == at::ScalarType::Half) {
hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }},
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(const half*)karg.p_output_grad,
(half*)karg.p_emb_table,
(const int64_t*)karg.p_hash_size_cumsum,
(const int64_t*)karg.p_sorted_linear_indices_run,
(const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
(const int32_t*)karg.p_sorted_linear_indices_num_runs,
(const int32_t*)karg.p_long_run_ids,
(const int32_t*)karg.p_num_long_run_ids,
(const int32_t*)karg.p_sorted_infos,
karg.batch_mdiv,
karg.max_segment_length_per_warp,
{% if weighted %}
karg.indice_weights_sorted,
{% endif %}
karg.emb_dim,
karg.batch,
karg.num_rows,
karg.num_tables,
{% if optimizer == "rowwise_adagrad" and not dense %}
karg.opt_karg
{% endif %}
);
} else if (dev_weights.scalar_type() != at::ScalarType::Half && grad_output.scalar_type() == at::ScalarType::Half) {
hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp16_e{{ items_per_warp // 4 * kMaxElemPerThread }},
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(const half*)karg.p_output_grad,
(float*)karg.p_emb_table,
(const int64_t*)karg.p_hash_size_cumsum,
(const int64_t*)karg.p_sorted_linear_indices_run,
(const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
(const int32_t*)karg.p_sorted_linear_indices_num_runs,
(const int32_t*)karg.p_long_run_ids,
(const int32_t*)karg.p_num_long_run_ids,
(const int32_t*)karg.p_sorted_infos,
karg.batch_mdiv,
karg.max_segment_length_per_warp,
{% if weighted %}
karg.indice_weights_sorted,
{% endif %}
karg.emb_dim,
karg.batch,
karg.num_rows,
karg.num_tables,
{% if optimizer == "rowwise_adagrad" and not dense %}
karg.opt_karg
{% endif %}
);

} else if (dev_weights.scalar_type() == at::ScalarType::Half && grad_output.scalar_type() != at::ScalarType::Half) {
hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp16_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }},
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(const float*)karg.p_output_grad,
(half*)karg.p_emb_table,
(const int64_t*)karg.p_hash_size_cumsum,
(const int64_t*)karg.p_sorted_linear_indices_run,
(const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
(const int32_t*)karg.p_sorted_linear_indices_num_runs,
(const int32_t*)karg.p_long_run_ids,
(const int32_t*)karg.p_num_long_run_ids,
(const int32_t*)karg.p_sorted_infos,
karg.batch_mdiv,
karg.max_segment_length_per_warp,
{% if weighted %}
karg.indice_weights_sorted,
{% endif %}
karg.emb_dim,
karg.batch,
karg.num_rows,
karg.num_tables,
{% if optimizer == "rowwise_adagrad" and not dense %}
karg.opt_karg
{% endif %}
);
} else {
hipLaunchKernelGGL(split_tbe_bwd_{{wdesc}}_hip_kernel_{{ optimizer }}_w{{ weight_decay_mode_current }}_s0_fp32_fp32_e{{ items_per_warp // 4 * kMaxElemPerThread }},
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(const float*)karg.p_output_grad,
(float*)karg.p_emb_table,
(const int64_t*)karg.p_hash_size_cumsum,
(const int64_t*)karg.p_sorted_linear_indices_run,
(const int32_t*)karg.p_sorted_linear_indices_cumulative_run_lengths,
(const int32_t*)karg.p_sorted_linear_indices_num_runs,
(const int32_t*)karg.p_long_run_ids,
(const int32_t*)karg.p_num_long_run_ids,
(const int32_t*)karg.p_sorted_infos,
karg.batch_mdiv,
karg.max_segment_length_per_warp,
{% if weighted %}
karg.indice_weights_sorted,
{% endif %}
karg.emb_dim,
karg.batch,
karg.num_rows,
karg.num_tables,
{% if optimizer == "rowwise_adagrad" and not dense %}
karg.opt_karg
{% endif %}
);
}
}
{% endfor %}

// hipModuleLaunchKernel(hip_kernel_func,
// grids[0], grids[1], grids[2],
// blocks[0], blocks[1], blocks[2], 0, 0, NULL, (void **) &kconf);

} else
{% endif %}
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
{% if not dense %}
emb_t,