Skip to content

Commit

Permalink
[FORK][FEATURE] DQ IP: performance enhansments
Browse files Browse the repository at this point in the history
- allocate aux accums regs on stack
- precompute grouped src sums
- optimize pointer arithmetic
- reduce aux vecs count requred for the microkernel
  • Loading branch information
dmitry-gorokhov committed Feb 3, 2025
1 parent 1789b1e commit b145489
Show file tree
Hide file tree
Showing 13 changed files with 667 additions and 201 deletions.
1 change: 1 addition & 0 deletions src/common/memory_tracking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ enum {
key_decompression_zero_points,
key_src_quantized,
key_src_dequantized_scales,
key_src_grouped_sum,
// These two keys should always be the last ones,
// even though they are not in alphabetical order
key_nested,
Expand Down
22 changes: 18 additions & 4 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ void brgemm_desc_t::cleanup_dst_md() {
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand All @@ -105,6 +106,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;

assert(brg_kernel);
Expand All @@ -116,7 +118,8 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
const void *addr_A, const void *addr_B,
const brgemm_batch_element_t *batch, void *ptr_C, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand All @@ -133,6 +136,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand All @@ -148,7 +152,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand Down Expand Up @@ -178,6 +183,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand All @@ -194,7 +200,8 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
const brgemm_post_ops_data_t &post_ops_data, void *scratch,
const brgemm_dynamic_values_t *dynamic_values,
const void *ptr_wei_scales, const void *ptr_wei_zero_points, const void *ptr_src_scales, size_t ic) {
const void *ptr_wei_scales, const void *ptr_wei_zero_points,
const void *ptr_src_scales, const void *ptr_src_grouped_sum, size_t ic) {
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = batch;
Expand Down Expand Up @@ -224,6 +231,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
brgemm_p.ptr_wei_scales = ptr_wei_scales;
brgemm_p.ptr_wei_zero_points = ptr_wei_zero_points;
brgemm_p.ptr_src_scales = ptr_src_scales;
brgemm_p.ptr_src_grouped_sum = ptr_src_grouped_sum;
brgemm_p.ic = ic;
if (dynamic_values) {
brgemm_p.dynamic_LDA = dynamic_values->dynamic_LDA;
Expand Down Expand Up @@ -318,6 +326,12 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,

CHECK(brgemm_blocking(brg));

brg->src_sum_group_size = wei_d.dims()[1];
if (brg->with_src_dyn_quant) {
brg->src_sum_group_size = brg->rd_block;
brg->src_grouped_sum_stride = div_up(wei_d.dims()[1], brg->src_sum_group_size);
}

// avx2_vnni_2 kernel with xf16 data type requires blocked weights.
if (brg->isa_impl == avx2_vnni_2 && brg->is_xf16()
&& brg->LDB % brg->ld_block > 0)
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/x64/brgemm/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
///
Expand Down Expand Up @@ -205,7 +205,7 @@ void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_addr version)
///
Expand Down Expand Up @@ -234,7 +234,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
///
Expand Down Expand Up @@ -267,7 +267,7 @@ void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr,
const brgemm_dynamic_values_t *dynamic_values = nullptr,
const void *ptr_wei_scales = nullptr, const void *ptr_wei_zero_points = nullptr,
const void *ptr_src_scales = nullptr, size_t ic = 0);
const void *ptr_src_scales = nullptr, const void *ptr_src_grouped_sum = nullptr, size_t ic = 0);

/// AMX utilities: Creates a palette based on BRGEMM descriptor
///
Expand Down
3 changes: 3 additions & 0 deletions src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ struct brgemm_desc_t {
bool with_src_dyn_quant = false;
int src_scales_group_size = 0;
int src_scales_stride = 0;
int src_sum_group_size = 0;
int src_grouped_sum_stride = 0;

bool is_row_major() const {
assert(layout != brgemm_layout_undef);
Expand Down Expand Up @@ -500,6 +502,7 @@ struct brgemm_kernel_params_t {
const void *ptr_wei_scales = nullptr;
const void *ptr_wei_zero_points = nullptr;
const void *ptr_src_scales = nullptr;
const void *ptr_src_grouped_sum = nullptr;
size_t ic;
dim_t dynamic_LDA = 0;
dim_t dynamic_LDB = 0;
Expand Down
19 changes: 11 additions & 8 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,10 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5;
if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2;
if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1;
if (brg->with_src_dyn_quant) max_bcast_block -= 2;
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;
if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0 && !brg->with_src_dyn_quant) max_bcast_block -= 1;
if (brg->with_src_dyn_quant) max_bcast_block -= 1;

max_bcast_block /= adj_ld_block2;
if (brg->with_src_dyn_quant) {
max_bcast_block /= 2;
}

return max_bcast_block;
}
Expand Down Expand Up @@ -301,15 +297,22 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
= (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
? 1
: data_type_vnni_granularity(brg->dt_a);

int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4;
if (brg->with_grouped_wei_decomp) {
if (brg->with_grouped_wei_decomp && !brg->with_src_dyn_quant) {
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
min_group_size = nstl::min(min_group_size, brg->src_scales_group_size);
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
rd_unroll = nstl::min(rd_unroll, min_group_size / vnni_granularity);
brg->rd_block = rd_unroll * vnni_granularity;
} else if (brg->with_src_dyn_quant) {
brg->rd_block = brg->src_scales_group_size;
auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size);
brg->rd_block = nstl::min(brg->rd_block, min_group_size);
} else {
brg->rd_block = rd_unroll * vnni_granularity;
}

brg->rd_block = rd_unroll * vnni_granularity;
brg->rdb = brg->reduce_dim / brg->rd_block;
brg->rdb_tail = brg->reduce_dim % brg->rd_block;

Expand Down
Loading

0 comments on commit b145489

Please sign in to comment.