Perf bwd hip #23

Open · wants to merge 9 commits into base: performance

Conversation

@carlushuang commented Sep 28, 2022

python test/split_table_batched_embeddings_test.py SplitTableBatchedEmbeddingsTest.test_backward_adagrad_fp32_pmSUM
python test/split_table_batched_embeddings_test.py SplitTableBatchedEmbeddingsTest.test_backward_optimizers_adagrad

The above 2 UTs can pass (with some modifications). This PR:

  • supports fp16/fp32 data types for the emb_t and grad_t combination
  • supports D = 64, 128, 192, 256
  • supports exact mode for now
  • supports rowwise-adagrad for now; other optimizers can be supported via different template lambda functors (see the sketch after this list)
  • supports all 3 weight_decay_mode values in rowwise-adagrad
  • supports embedding table duplication (as in the test test_backward_adagrad_fp32_pmSUM)
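
To make the lambda-functor bullet concrete, here is a minimal sketch of the idea, not the PR's actual code; all names and signatures below are assumptions for illustration:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical optimizer functor: the kernel is templated on the
// update step, so adding an optimizer only means adding a functor.
struct RowwiseAdagradUpdate {
  float eps;
  float lr;
  float weight_decay; // interpretation depends on weight_decay_mode

  // Called once per embedding row: g2_mean is the mean squared
  // gradient over the row's D elements, momentum is per-row state.
  __device__ float operator()(float& momentum, float g2_mean) const {
    momentum += g2_mean;
    return lr / (sqrtf(momentum) + eps); // per-row step size
  }
};

// Sketch of the kernel side: Update is a template parameter.
template <typename emb_t, typename grad_t, typename Update>
__global__ void rowwise_bwd_kernel(
    emb_t* weights, float* momentum, const grad_t* grad, int D,
    Update update) {
  // Per row: reduce grad^2 to g2_mean, then
  //   step = update(momentum[row], g2_mean);
  //   weights[row * D + d] -= step * grad[...];
}
```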

carlushuang and others added 3 commits October 16, 2022 06:49
split_tbe_bwd.hip.cpp needed an .hpp counterpart to invoke the necessary macros that build all templates for split_tbe_bwd_hip_kernel_
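
As a rough illustration of what that commit describes, the .hpp counterpart might look like the following; the kernel signature here is an assumption, not the real one:

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

// Hypothetical sketch: a header whose macro stamps out one explicit
// instantiation per (emb_t, grad_t, D) combination, so the .hip.cpp
// translation unit builds every supported template variant.
template <typename emb_t, typename grad_t, int kDim>
__global__ void split_tbe_bwd_hip_kernel_(
    emb_t* weights, const grad_t* grad_output) {
  // body elided in this sketch
}

#define INSTANTIATE_SPLIT_TBE_BWD(emb_t, grad_t, kDim)                      \
  template __global__ void split_tbe_bwd_hip_kernel_<emb_t, grad_t, kDim>(  \
      emb_t*, const grad_t*);

// One line per supported combination (the real list covers the
// fp16/fp32 x D in {64, 128, 192, 256} matrix):
INSTANTIATE_SPLIT_TBE_BWD(float, float, 64)
INSTANTIATE_SPLIT_TBE_BWD(__half, float, 128)
```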
@@ -8,6 +8,10 @@
{% set wdesc = "weighted" if weighted else "unweighted" %}
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "hip_kernel/split_tbe_common_hip.h"

Use #ifdef HIP_PLATFORM_HCC / #endif around the new includes.
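
Applied to the hunk above, the suggestion would look roughly like this (HIP toolchains of that era define the macro as __HIP_PLATFORM_HCC__; the comment abbreviates it):

```cpp
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#ifdef __HIP_PLATFORM_HCC__
#include "hip_kernel/split_tbe_common_hip.h"
#endif
```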

@@ -45,6 +45,7 @@
# An optimization for ROCm
env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256

If we have use_rocm exposed to jinja, we can probably avoid needing an extra "items_per_warp" global.
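
A sketch of that suggestion, assuming a use_rocm flag were exposed as a jinja global: the template could branch on it directly instead of consuming a separately exported items_per_warp value.

```cpp
// Inside the jinja-templated kernel source (hypothetical; assumes
// use_rocm is available to the template):
{% if use_rocm %}
constexpr int32_t kItemsPerWarp = 256; // the ROCm optimization above
{% else %}
constexpr int32_t kItemsPerWarp = 128;
{% endif %}
```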

static hipFunction_t hip_kernel_func;
{% if optimizer == "rowwise_adagrad" and not dense %}
std::set<int> D_emb_s {64, 128, 192, 256};
bool hip_opt_kernel_supported = (D_emb_s.find(max_D) != D_emb_s.end()) &&

When we do mixed dimensions, this check will go away.
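
For context, a hypothetical standalone restatement of that guard; note the excerpt above ends in "&&", so the real check carries further conditions not shown here:

```cpp
#include <cstdint>
#include <set>

// The hand-tuned HIP kernel is only built for D in {64, 128, 192, 256},
// so the host code falls back to the generic templated kernel for any
// other dimension (and, per the review, this gate should disappear
// once mixed dimensions are handled).
bool hip_opt_kernel_supported_for(int32_t max_D) {
  static const std::set<int32_t> D_emb_s{64, 128, 192, 256};
  return D_emb_s.find(max_D) != D_emb_s.end();
}
```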
