Skip to content

Commit

Permalink
[FORK][FIX] DQ IP: allocate aux accums via stack
Browse files Browse the repository at this point in the history
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
  • Loading branch information
dmitry-gorokhov committed Jan 21, 2025
1 parent c7ecd8f commit 55058f1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
3 changes: 0 additions & 3 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,6 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2;

max_bcast_block /= adj_ld_block2;
if (brg->with_src_dyn_quant) {
max_bcast_block /= 2;
}

return max_bcast_block;
}
Expand Down
28 changes: 15 additions & 13 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2298,11 +2298,6 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
};

auto vmm_accm_tmp = [&](int ld_block, int bd, int ld) {
int idx = max_effective_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld_block - (bd * ld_block + ld);
return Vmm(idx);
};

auto vmm_zero_point = [&](int ld) {
int idx = isa_num_vregs(brg.isa_impl) - 3 - ld;
return Vmm(idx);
Expand Down Expand Up @@ -2368,9 +2363,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]);

const int vec_size = vreg_traits<Vmm>::vlen;
auto accums_stack_space = bd_e * ld_block2 * vec_size;
sub(rsp, accums_stack_space);
for (int bd = bd_b; bd < bd_e; bd++) {
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm_accm = accm(ld_block2, bd, ld);
vmovups(ptr[rsp + (bd * ld_block2 + ld) * vec_size], vmm_accm);

uni_vxorps(vmm_accm, vmm_accm, vmm_accm);
}
}
Expand Down Expand Up @@ -2409,14 +2409,14 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
+ brg.LDB * brg.rd_block * brg.typesize_B]);
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm = accm(ld_block2, bd, ld);
vpdpbusd(vmm, load(ld), bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
}
if (brg.with_wei_decomp_zero_points) {
uni_vpxor(bcst(), bcst(), vmm_neg_one);
uni_vpsubb(bcst(), bcst(), vmm_neg_one);
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm = accm(ld_block2, bd, ld);
Vmm vmm_zp = brg.wei_decomp_zero_points_stride == 0 ? vmm_zero_point(0) : vmm_zero_point(ld);
vpdpbusd(vmm, vmm_zp, bcst(), is_superset(brg.isa_impl, avx512_core) ? EvexEncoding : VexEncoding);
}
Expand All @@ -2426,7 +2426,7 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,

auto reg_local_src_scales = reg_local_wei_zp;
auto vmm_src_scales = bcst();
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_]);
mov(reg_local_src_scales, ptr[rsp + reg_aux2_src_scales_offs_ + accums_stack_space]);

for (int bd = bd_b; bd < bd_e; bd++) {
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
Expand All @@ -2438,15 +2438,17 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
}
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
auto vmm_accm = accm(ld_block2, bd, ld);

uni_vcvtdq2ps(vmm_accm_aux, vmm_accm_aux);
uni_vmulps(vmm_accm_aux, vmm_accm_aux, vmm_src_scales);
uni_vfmadd231ps(vmm_accm, vmm_accm_aux, load(ld));
uni_vcvtdq2ps(vmm_accm, vmm_accm);
uni_vmulps(vmm_accm, vmm_accm, vmm_src_scales);
uni_vmulps(load(ld), vmm_accm, load(ld));
uni_vmovups(vmm_accm, ptr[rsp + (bd * ld_block2 + ld) * vec_size]);
uni_vaddps(vmm_accm, vmm_accm, load(ld));
}
}

add(rsp, accums_stack_space);
mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);

Expand Down

0 comments on commit 55058f1

Please sign in to comment.