diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
index 41826a9253e..2882dd52316 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -151,7 +151,10 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
 template <>
 void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_vmm(int idx, int offset) {
     const auto addr = EVEX_compress_addr(reg_src, offset);
-    if (conf_->isa == avx512_core_fp16) {
+    if (conf_->isa == avx512_core_fp16
+            && conf_->orig_wei_dt == data_type::f16) {
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         vcvtph2psx(get_vmm_copy(idx), addr);
     } else
         vmovdqu8(get_vmm_copy(idx), addr);
@@ -186,8 +189,12 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
         }
     };
 
-    const size_t dt_step
-            = conf_->is_bf32 || conf_->isa == avx512_core_fp16 ? 1 : typesize_;
+    // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` is used.
+    const size_t dt_step = conf_->is_bf32
+                    || (conf_->isa == avx512_core_fp16
+                            && conf_->orig_wei_dt == data_type::f16)
+            ? 1
+            : typesize_;
     const size_t tail_mask_load = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
     kmovx(kTail_load, tail_mask_load);
     const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
@@ -202,7 +209,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
         auto load_addr = EVEX_compress_addr(reg_src, offset * typesize_);
         if (conf_->is_bf32)
             vmovups(zmm_tail, load_addr);
-        else if (conf_->isa == avx512_core_fp16)
+        else if (conf_->isa == avx512_core_fp16
+                && conf_->orig_wei_dt == data_type::f16)
+            // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+            // is used.
             vcvtph2psx(zmm_tail, load_addr);
         else
             vmovdqu8(zmm_tail, load_addr);
@@ -223,7 +233,10 @@ void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_tail(
         Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(0).getIdx());
         vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(0));
         vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
-    } else if (conf_->isa == avx512_core_fp16) {
+    } else if (conf_->isa == avx512_core_fp16
+            && conf_->orig_wei_dt == data_type::f16) {
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         vmovups(tr_src_addr, get_vmm_copy(0) | kTail_store);
     } else
         vmovdqu8(tr_src_addr, get_vmm_copy(0) | kTail_store);
@@ -943,13 +956,17 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::transpose_f32(
         const auto addr = is_dynamic_src_ld
                 ? ptr[i % 2 == 0 ? reg_aux_src0 : reg_aux_src1]
                 : EVEX_compress_addr(src, i * src_stride);
-        if (i < nrows)
-            if (conf_->isa == avx512_core_fp16)
+        if (i < nrows) {
+            if (conf_->isa == avx512_core_fp16
+                    && conf_->orig_wei_dt == data_type::f16)
+                // See the note in `create_brgemm_matmul_copy_b` why
+                // `orig_wei_dt` is used.
                 vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
             else
                 vmovups(src_zmm(i) | kTail | T_z, addr);
-        else
+        } else {
             vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
+        }
     };
 
     auto store = [this, dst](Zmm r, int i) {
@@ -1075,7 +1092,11 @@ void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::transpose_f32(
 template <typename Vmm>
 void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::deploy_transpose(
         reg64_t dst, reg64_t src, int nrows, int ncolumns) {
-    if (is_f32 || conf_->isa == avx512_core_fp16)
+    if (is_f32
+            || (conf_->isa == avx512_core_fp16
+                    && conf_->orig_wei_dt == data_type::f16))
+        // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+        // is used.
         transpose_f32(dst, src, nrows, ncolumns);
     else
         transpose_bf16(dst, src, nrows, ncolumns);
@@ -3714,7 +3735,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
     assert(IMPLICATION(use_int4_mask, is_src_int4_));
     if (columns_tail > 0) {
-        const int dt_step = req_cvtps2bf16_ || conf_->isa == avx512_core_fp16
+
+        const int dt_step = req_cvtps2bf16_
+                || (conf_->isa == avx512_core_fp16
+                        && conf_->orig_wei_dt == data_type::f16)
+                // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+                // is used.
                 ? 1
                 : typesize_;
         const auto tail_mask = use_int4_mask
@@ -3870,11 +3896,14 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg;
         const auto addr = EVEX_compress_addr(reg_src, i * src_stride_);
-        if (conf_->isa == avx512_core_fp16)
+        if (conf_->isa == avx512_core_fp16
+                && conf_->orig_wei_dt == data_type::f16) {
+            // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt`
+            // is used.
             vcvtph2psx(src_load, addr);
-        else
+        } else {
             vmovdqu8(src_load, addr);
-
+        }
         L(load_done);
     };
@@ -4687,7 +4716,10 @@ status_t create_brgemm_matmul_copy_b(
         else
             CHECK(safe_ptr_assign(copy_ker,
                     new jit_brgemm_matmul_copy_b_bf16_t<Ymm>(conf)));
-    } else if (is_f32 || conf->isa == avx512_core_fp16) {
+    } else if (is_f32
+            || (conf->isa == avx512_core_fp16
+                    && conf->orig_wei_dt == data_type::f16)) {
+        // See the note above why `orig_wei_dt` is used.
         if (is_superset(conf->isa, avx512_core))
            CHECK(safe_ptr_assign(copy_ker,
                     new jit_brgemm_matmul_copy_b_f32_t<Zmm>(conf)));
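Note: every hunk above tightens the same predicate. The f16-to-f32 up-convert
path is now keyed to the weights' original data type rather than to the ISA
alone; before this change, `isa == avx512_core_fp16` by itself selected
`vcvtph2psx`, which would also fire when the weights are not actually f16.
Below is a minimal, self-contained sketch of the fp16 part of that gating
logic. The enum and struct here are illustrative stand-ins for the oneDNN
types named in the diff, and the helper `use_fp16_cvt` is hypothetical, not
part of the patch:

    #include <cstdio>

    // Stand-ins for the brgemm_matmul_conf_t fields the patch consults.
    enum class data_type { s8, f16, f32 };
    enum class cpu_isa { avx512_core, avx512_core_fp16 };

    struct conf_t {
        cpu_isa isa;
        data_type orig_wei_dt; // weights' type before any up-conversion
    };

    // The f16 -> f32 convert-load (vcvtph2psx) is emitted only when the ISA
    // is avx512_core_fp16 AND the original weights type really is f16;
    // otherwise the copy kernels fall back to a plain byte copy (vmovdqu8).
    static bool use_fp16_cvt(const conf_t &c) {
        return c.isa == cpu_isa::avx512_core_fp16
                && c.orig_wei_dt == data_type::f16;
    }

    int main() {
        const conf_t f16_wei {cpu_isa::avx512_core_fp16, data_type::f16};
        const conf_t s8_wei {cpu_isa::avx512_core_fp16, data_type::s8};
        // Prints "1 0": only true f16 weights take the convert path.
        std::printf("%d %d\n", use_fp16_cvt(f16_wei), use_fp16_cvt(s8_wei));
        return 0;
    }

In the patch itself the predicate is inlined at each site, sometimes OR-ed
with `is_bf32` or `req_cvtps2bf16_`, rather than factored into a helper.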