Skip to content

Commit

Permalink
[rls-v3.6 backport] cpu: x64: matmul: copy_routines: associate fp16 support with dt, not only isa (oneapi-src#2332)
Browse files Browse the repository at this point in the history
  • Loading branch information
amakarev authored and liubo-intel committed Jan 26, 2025
1 parent b08b241 commit 7a6ee1c
Showing 1 changed file with 47 additions and 15 deletions.
62 changes: 47 additions & 15 deletions src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -151,7 +151,10 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
template <>
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_vmm(int idx, int offset) {
const auto addr = EVEX_compress_addr(reg_src, offset);
if (conf_->isa == avx512_core_fp16) {
if (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16) {
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
vcvtph2psx(get_vmm_copy(idx), addr);
} else
vmovdqu8(get_vmm_copy(idx), addr);
Expand Down Expand Up @@ -186,8 +189,12 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
}
};

const size_t dt_step
= conf_->is_bf32 || conf_->isa == avx512_core_fp16 ? 1 : typesize_;
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` is used.
const size_t dt_step = conf_->is_bf32
|| (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16)
? 1
: typesize_;
const size_t tail_mask_load = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
kmovx(kTail_load, tail_mask_load);
const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
Expand All @@ -202,7 +209,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
auto load_addr = EVEX_compress_addr(reg_src, offset * typesize_);
if (conf_->is_bf32)
vmovups(zmm_tail, load_addr);
else if (conf_->isa == avx512_core_fp16)
else if (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16)
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
vcvtph2psx(zmm_tail, load_addr);
else
vmovdqu8(zmm_tail, load_addr);
Expand All @@ -223,7 +233,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_tail(
Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(0).getIdx());
vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(0));
vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
} else if (conf_->isa == avx512_core_fp16) {
} else if (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16) {
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
vmovups(tr_src_addr, get_vmm_copy(0) | kTail_store);
} else
vmovdqu8(tr_src_addr, get_vmm_copy(0) | kTail_store);
Expand Down Expand Up @@ -943,13 +956,17 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
const auto addr = is_dynamic_src_ld
? ptr[i % 2 == 0 ? reg_aux_src0 : reg_aux_src1]
: EVEX_compress_addr(src, i * src_stride);
if (i < nrows)
if (conf_->isa == avx512_core_fp16)
if (i < nrows) {
if (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16)
// See the note in `create_brgemm_matmul_copy_b` why
// `orig_wei_dt` is used.
vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
else
vmovups(src_zmm(i) | kTail | T_z, addr);
else
} else {
vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
}
};

auto store = [this, dst](Zmm r, int i) {
Expand Down Expand Up @@ -1075,7 +1092,11 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
template <typename Vmm>
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::deploy_transpose(
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
if (is_f32 || conf_->isa == avx512_core_fp16)
if (is_f32
|| (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16))
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
transpose_f32(dst, src, nrows, ncolumns);
else
transpose_bf16(dst, src, nrows, ncolumns);
Expand Down Expand Up @@ -3714,7 +3735,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
const int columns_tail, const bool use_int4_mask) {
assert(IMPLICATION(use_int4_mask, is_src_int4_));
if (columns_tail > 0) {
const int dt_step = req_cvtps2bf16_ || conf_->isa == avx512_core_fp16

const int dt_step = req_cvtps2bf16_
|| (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16)
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
? 1
: typesize_;
const auto tail_mask = use_int4_mask
Expand Down Expand Up @@ -3870,11 +3896,14 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(

auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg;
const auto addr = EVEX_compress_addr(reg_src, i * src_stride_);
if (conf_->isa == avx512_core_fp16)
if (conf_->isa == avx512_core_fp16
&& conf_->orig_wei_dt == data_type::f16) {
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
// is used.
vcvtph2psx(src_load, addr);
else
} else {
vmovdqu8(src_load, addr);

}
L(load_done);
};

Expand Down Expand Up @@ -4687,7 +4716,10 @@ status_t create_brgemm_matmul_copy_b(
else
CHECK(safe_ptr_assign(copy_ker,
new jit_brgemm_matmul_copy_b_bf16_t<Ymm>(conf)));
} else if (is_f32 || conf->isa == avx512_core_fp16) {
} else if (is_f32
|| (conf->isa == avx512_core_fp16
&& conf->orig_wei_dt == data_type::f16)) {
// See the note above why `orig_wei_dt` is used.
if (is_superset(conf->isa, avx512_core))
CHECK(safe_ptr_assign(copy_ker,
new jit_brgemm_matmul_copy_b_f32_t<Zmm>(conf)));
Expand Down

0 comments on commit 7a6ee1c

Please sign in to comment.