Skip to content

Commit

Permalink
[FORK][FIX] DQ IP: allocate aux accums via stack
Browse files Browse the repository at this point in the history
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
  • Loading branch information
dmitry-gorokhov committed Jan 24, 2025
1 parent 1789b1e commit d421730
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
3 changes: 0 additions & 3 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;

max_bcast_block /= adj_ld_block2;
if (brg->with_src_dyn_quant) {
max_bcast_block /= 2;
}

return max_bcast_block;
}
Expand Down
46 changes: 25 additions & 21 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,12 +874,13 @@ void jit_brgemm_kernel_t<Wmm>::ldb_regs_shift(int ld_block2, bool is_tail) {
mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
}

if (brg.with_wei_decomp) {
if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) {
mov(reg_aux_wei_scales, ptr[rsp + reg_aux_wei_scales_offs_]);
add(reg_aux_wei_scales, (is_tail) ? wei_scales_offset(1, true) : wei_scales_offset(ld_block2));
mov(ptr[rsp + reg_aux_wei_scales_offs_], reg_aux_wei_scales);
mov(ptr[rsp + reg_aux2_wei_scales_offs_], reg_aux_wei_scales);

}
if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) {
mov(reg_aux_wei_zp, ptr[rsp + reg_aux_wei_zero_points_offs_]);
add(reg_aux_wei_zp, (is_tail) ? wei_zp_offset(1, true) : wei_zp_offset(ld_block2));
mov(ptr[rsp + reg_aux_wei_zero_points_offs_], reg_aux_wei_zp);
Expand Down Expand Up @@ -966,10 +967,6 @@ void jit_brgemm_kernel_t<Wmm>::copy_post_ops_stack_values_to_aux(
}

}
if (brg.with_grouped_wei_decomp) {
mov(reg_ic, ptr[rsp + reg_ic_offs_]);
mov(ptr[rsp + reg_aux_ic_offs_], reg_ic);
}
if (brg.with_src_dyn_quant) {
mov(reg_src_scales, ptr[rsp + reg_src_scales_offs_]);
mov(ptr[rsp + reg_aux_src_scales_offs_], reg_src_scales);
Expand Down Expand Up @@ -2298,11 +2295,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
};

auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
return Vmm(idx);
};

auto vmm_zero_point = [&](int ld) {
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
return Vmm(idx);
Expand Down Expand Up @@ -2368,9 +2360,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);

const int vec_size = vreg_traits<Vmm>::vlen;
auto accums_stack_space = bd_e * ld_block2 * vec_size;
sub(rsp, accums_stack_space);
for (int bd = bd_b; bd < bd_e; bd++) {
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm_accm = accm(ld_block2, bd, ld);
vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);

uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
}
}
Expand Down Expand Up @@ -2409,14 +2406,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
+ brg.LDB * brg.rd_block * brg.typesize_B]);
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm = accm(ld_block2, bd, ld);
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
}
if (brg.with_wei_decomp_zero_points) {
uni_vpxor(bcst(), bcst(), vmm_neg_one);
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm = accm(ld_block2, bd, ld);
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
}
Expand All @@ -2426,7 +2423,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

auto reg_local_src_scales = reg_local_wei_zp;
auto vmm_src_scales = bcst();
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);

for (int bd = bd_b; bd < bd_e; bd++) {
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
Expand All @@ -2438,15 +2435,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
}
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm_accm = accm(ld_block2, bd, ld);

uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
uni_vcvtdq2ps(vmm_accm, vmm_accm);
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
uni_vmulps(load(ld), vmm_accm, load(ld));
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
uni_vaddps(vmm_accm, vmm_accm, load(ld));
}
}

add(rsp, accums_stack_space);
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);

Expand Down Expand Up @@ -3014,6 +3013,11 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
copy_post_ops_stack_values_to_aux(is_reg_tail);

auto ld_loop_body = [&](int vpad) {
if (brg.with_grouped_wei_decomp) {
mov(reg_ic, ptr[rsp + reg_ic_offs_]);
mov(ptr[rsp + reg_aux_ic_offs_], reg_ic);
}

set_A_B_matrices();

int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
Expand All @@ -3033,8 +3037,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
mov(reg_rdb_loop, brg.rdb);
L_aligned(rdb_loop_label, 64);
{
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) {
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || brg.wei_decomp_zero_points_stride != 0))
|| brg.with_src_dyn_quant) {
auto reg_local_ic = reg_aux_D;
auto reg_local_wei_params = reg_bdb_loop;
auto reg_local_ic_group = reg_ldb_loop;
Expand Down
5 changes: 0 additions & 5 deletions src/cpu/x64/jit_brgemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,11 +1441,6 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
jbgp.wei_zero_points_ic_group_size = div_up(jbgp.ic, attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[1]);
}

// todo: fix avx2 brgemm kernel behavior for non scalar zp
if (!is_superset(isa, avx512_core) && attr.zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] != 1) {
jbgp.with_src_dynamic_quant = false;
}

jbgp.wei_decomp_zero_points_dt = attr.zero_points_.get_data_type(DNNL_ARG_WEIGHTS);
if (!one_of(jbgp.wei_decomp_zero_points_dt, f32, u8))
return status::unimplemented;
Expand Down

0 comments on commit d421730

Please sign in to comment.