[CPU] sns f16_mha_on_avx512_core_amx_f16_target #27514

Open. Wants to merge 4 commits into base: master.

Changes from 3 commits
3 changes: 2 additions & 1 deletion src/common/snippets/src/op/brgemm.cpp
@@ -87,7 +87,8 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
const bool is_f16 = utils::everyone_is(element::f16, in_type0, in_type1);
if (is_f32 || is_bf16 || is_f16) {
return element::f32;
} else if (is_int8) {
return element::i32;
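
For quick reference, the precision mapping implied by this hunk can be restated as a small standalone sketch (illustrative only, not part of the PR; the real logic lives in Brgemm::get_output_type above): f32|f32, bf16|bf16 and the newly added f16|f16 all produce an f32 output, while i8/u8|i8 produces i32.

#include <openvino/core/type/element_type.hpp>

// Hypothetical restatement of the output-type rule shown in the diff above.
inline ov::element::Type brgemm_output_type_sketch(ov::element::Type in0, ov::element::Type in1) {
    const bool is_f32  = in0 == ov::element::f32  && in1 == ov::element::f32;
    const bool is_bf16 = in0 == ov::element::bf16 && in1 == ov::element::bf16;
    const bool is_f16  = in0 == ov::element::f16  && in1 == ov::element::f16;  // new in this PR
    const bool is_int8 = (in0 == ov::element::i8 || in0 == ov::element::u8) && in1 == ov::element::i8;
    if (is_f32 || is_bf16 || is_f16)
        return ov::element::f32;
    if (is_int8)
        return ov::element::i32;
    return ov::element::undefined;  // assumption: other combinations are rejected elsewhere
}
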
@@ -21,7 +21,7 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {

size_t get_inputs_num() const override {return 1;}
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::bf16}, {element::f32}};
return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}};
}

private:
@@ -66,7 +66,8 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
}
12 changes: 9 additions & 3 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -449,11 +449,16 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
config.inConfs.resize(inputShapes.size());
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
const auto precision = ((originalInputPrecision == ov::element::f32) &&
auto precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::bf16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::bf16) :
originalInputPrecision;
precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::f16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::f16) :
precision;
if (supportedPrecisions.count(precision) == 0)
OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");

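As a side note, the two chained ternaries above boil down to a single rule: an f32 input of a domain-sensitive subgraph is downgraded to the enforced inference precision (bf16 or f16), and every other input keeps its original precision. A non-normative sketch (hypothetical helper, names are illustrative and not part of the PR):

#include <openvino/core/type/element_type.hpp>

// Hypothetical equivalent of the precision selection performed in initSupportedPrimitiveDescriptors().
static ov::element::Type enforced_input_precision(ov::element::Type original,
                                                  ov::element::Type inference_precision,
                                                  bool has_domain_sensitive_ops) {
    const bool enforce = original == ov::element::f32 && has_domain_sensitive_ops &&
                         (inference_precision == ov::element::bf16 || inference_precision == ov::element::f16);
    return enforce ? inference_precision : original;
}
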
@@ -638,13 +643,14 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
if (context->getConfig().inferencePrecision == ov::element::bf16 && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16)
&& subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16/FP16 precision on supported operations
// MatMul has to be decomposed to Brgemm operations before the enforcement
// Note: MatMul decomposition will be run again later in case the BF16/FP16 enforcement does not happen
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineStart, ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision, element::f32, element::bf16);
pass::EnforcePrecision, element::f32, context->getConfig().inferencePrecision);
}
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision,
ov::intel_cpu::pass::BrgemmToBrgemmCPU);
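
For context (not part of the diff): the whole FP16 path is driven by the inference-precision hint that ends up in context->getConfig().inferencePrecision. A minimal way to exercise it from user code, assuming an MHA-style model and an avx512_core_amx_fp16 machine, would be:

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // assumption: a model containing an MHA pattern
    // Requesting f16 inference precision lets the Snippets passes above enforce f16 on the tokenized MHA.
    auto compiled = core.compile_model(model, "CPU", ov::hint::inference_precision(ov::element::f16));
    return 0;
}
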
@@ -91,7 +91,7 @@ void BrgemmCopyB::validate_and_infer_types() {
}

void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) {
OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::i8),
OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::f16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
}

@@ -27,7 +27,10 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {

// Note: AMX might not be used even if it's supported by the hardware; check the BrgemmToBrgemmCPU pass for details
if (is_with_amx) {
SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
if (dt_in0 == ov::element::f16)
SUPPORT_ONE(avx512_core_amx_fp16, "Unsupported hardware configuration: amx fp16 is supported only on avx512_core_amx_fp16 platforms")
else
SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
} else if (dt_in0 == ov::element::bf16) {
SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
} else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
@@ -46,13 +49,19 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimen
return transpose_b ? BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE;

OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16),
"BF16 precision is not supported on this hardware");
"BrgemmCPU BF16 precision is not supported on systems without avx512_core_bf16");
OPENVINO_ASSERT(element_type_a != element::f16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16),
"BrgemmCPU FP16 precision is not supported on systems without avx512_core_amx_fp16");

const auto brgemmVNNIFactor = 4 / element_type_a.size();
if (one_of(element_type_a, element::u8, element::i8, element::bf16) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) &&
K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0)
return BRGEMM_TYPE::WITH_AMX;
if (element_type_a == ov::element::f16 &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0)
return BRGEMM_TYPE::WITH_AMX;
// Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the
// backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp
if (element_type_a == ov::element::i8)
@@ -79,6 +88,7 @@ size_t compute_inner_n_block(const ov::element::Type& precision) {
switch (precision) {
case element::i8: return 64;
case element::bf16: return 32;
case element::f16: return 32;
case element::f32: return 16;
default: OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision);
}
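
Two numeric relations in this file are worth spelling out (an informal restatement derived only from the code above, not a separate change): brgemmVNNIFactor is 4 divided by the element size, and compute_inner_n_block() returns 64 for i8, 32 for bf16/f16 and 16 for f32. WITH_AMX is only chosen for i8/bf16/f16 when K is static and divisible by that factor, so for example an f16 MatMul with an odd static K is not mapped to WITH_AMX. A tiny compile-time sketch of the factor arithmetic:

#include <cstddef>

// Hypothetical restatement of the VNNI factor used in get_brgemm_type() above (illustrative only).
constexpr std::size_t vnni_factor_sketch(std::size_t element_size_in_bytes) {
    return 4 / element_size_in_bytes;  // 4 for i8/u8, 2 for bf16/f16, 1 for f32
}

static_assert(vnni_factor_sketch(1) == 4, "i8/u8");
static_assert(vnni_factor_sketch(2) == 2, "bf16/f16");
static_assert(vnni_factor_sketch(4) == 1, "f32");
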
@@ -16,9 +16,9 @@ namespace brgemm_utils {

enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
REPACKING_ONLY, // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
};

dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
@@ -25,7 +25,7 @@ namespace pass {
* \ Buffer (with repacked data) Buffer (with compensations)
* \ | /
* BrgemmCPU
* - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system:
* - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system:
* \ BrgemmCopyB
* \ Buffer (with repacked data) Buffer (with new memory)
* \ | /
@@ -926,7 +926,7 @@ void Transformations::MainSnippets(void) {
// However there may be Convert [f32->bf16] before Result since:
// - bf16 Brgemm has f32 output;
// - CPU Node Subgraph requires bf16 on output when inference precision is bf16.
// To avoid sitations when Transpose is not alone node between MatMul and Result,
// To avoid situations when Transpose is not the only node between MatMul and Result,
// Plugin disables Transpose tokenization on output
bool mha_token_enable_transpose_on_output = one_of(config.inferencePrecision, element::f32, element::undefined);
size_t concurrency = config.streamExecutorConfig.get_threads_per_stream();
@@ -959,6 +959,7 @@

ov::pass::Manager snippetsManager("CPU:Snippets");
snippetsManager.set_per_pass_validation(false);
// If the callback is needed for better performance, enable SnippetsMarkSkipped and disable TokenizeFCSnippets.
if (!ignoreCallback) {
#if defined(OPENVINO_ARCH_ARM64)
CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped);
@@ -969,17 +970,17 @@
}
CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);

// - MHA has BRGEMM that is supported only on AVX512 platforms
// - CPU Plugin Subgraph supports only f32, bf16 (and quantized) BRGEMM
// [122494] Need to add support of f16
// - CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on avx512_core_amx_fp16 target) BRGEMM
const bool isMHASupported =
#if defined(OPENVINO_ARCH_ARM64)
false;
#else
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined));
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
one_of(config.inferencePrecision, ov::element::f16));
#endif
if (!isMHASupported) {
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
@@ -995,12 +996,11 @@
const auto in_type1 = matmul->get_input_element_type(1);
const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 &&
one_of(config.inferencePrecision, element::f32, element::undefined));
const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16);
const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
(in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16);
const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16));
const auto is_int8 = in_type0 == ov::element::i8;
if (is_fp16)
return false;
if (is_fp32)
return true;
// Only FP32 dynamic MHA is supported
@@ -1010,19 +1010,22 @@
// The current solution with ExtractExplicitMatMulTranspose pass is slower for non-f32 cases than using the brgemm_copy_b kernel
if (matmul->get_transpose_a() || matmul->get_transpose_b())
return false;
// [150842] The execution of Brgemm INT8/BF16 on AMX platforms depends on the value of "K % VNNIFactor".
// [150842] The execution of Brgemm INT8/BF16/FP16 on AMX platforms depends on the value of "K % VNNIFactor".
// For more details, please take a look at ticket 150842
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
const auto& b_shape = matmul->get_input_partial_shape(1);
const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin();
if (is_bf16) return K.is_static() && (K.get_length() % 2 == 0);
const size_t brgemm_vnni_factor_for_real16 = 2;  // 4 / 2 (2 is the size in bytes of bf16/fp16)
if (is_bf16 || is_fp16) return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0);
if (is_int8) return K.is_static();
}
if (is_int8)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni);
if (is_bf16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
if (is_fp16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
return true;
};
auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr<const ov::Node>& n, const ov::PartialShape& shape) {
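To summarize the tokenization gate that transformations.cpp now applies for the f16 flavour (an informal sketch only; the authoritative logic is the is_supported_matmul lambda above, and the helper name below is hypothetical): f16 MHA requires avx512_core_amx_fp16, no explicit transposes on the MatMul, and, on AMX, a static K divisible by 2.

#include <cstdint>

// Hedged restatement of the f16 branch of the MatMul tokenization callback (illustrative only).
static bool fp16_matmul_tokenizable_sketch(bool has_amx_fp16, bool transpose_a, bool transpose_b,
                                           bool k_is_static, std::int64_t k) {
    if (!has_amx_fp16)                 // f16 Brgemm is only emitted for avx512_core_amx_fp16
        return false;
    if (transpose_a || transpose_b)    // explicit transposes are extracted and are slower for non-f32
        return false;
    return k_is_static && (k % 2 == 0);  // VNNI factor for f16 is 2 on AMX
}
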
@@ -189,7 +189,7 @@ class MHATest : public testing::WithParamInterface<MHATuple>, virtual public Sub
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
if (funcInput.get_element_type() == ov::element::bf16) {
if (funcInput.get_element_type() == ov::element::bf16 || funcInput.get_element_type() == ov::element::f16) {
ov::test::utils::InputGenerateData in_data;
in_data.start_from = -1;
in_data.range = 2;
@@ -232,6 +232,9 @@
configuration.insert({ov::hint::inference_precision(ov::element::bf16)});
}

if (inputPrecisions[0] == ElementType::f16)
configuration.insert({ov::hint::inference_precision(ov::element::f16)});

// Snippets MHA tokenization has limitations to avoid performance degradations. These limitations depend on
// target machine. Just for testing, we disable these limitations to allow Snippets to tokenize pattern on all
// machines for validation.
@@ -253,6 +256,9 @@ TEST_P(MHATest, CompareWithRefs) {
if (inputPrecisions[0] == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
GTEST_SKIP();

if (inputPrecisions[0] == ElementType::f16 && !ov::with_cpu_x86_avx512_core_amx_fp16())
GTEST_SKIP();

if (!ov::with_cpu_x86_avx512_core())
GTEST_SKIP();

@@ -308,6 +314,20 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(
smoke_MHA_FP16,
MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::Values(
std::vector<ElementType>{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values(ExpectedNodes{{"Subgraph", 1},
{"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

} // namespace

static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes,
Reviewer comment (Contributor):
Could you tell me please if you launched these added tests on GNR or using SDE?

@@ -124,6 +124,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(1), // MHA + 5 Converts on inputs and output
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
@@ -16,6 +16,10 @@ static inline bool is_bf16_supported_by_brgemm() {
return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16();
}

static inline bool is_fp16_supported_by_brgemm() {
return ov::with_cpu_x86_avx512_core_amx_fp16();
}

static inline bool is_i8_supported_by_brgemm() {
return ov::with_cpu_x86_avx512_core_vnni() || ov::with_cpu_x86_avx512_core_amx_int8();
}
@@ -33,6 +37,13 @@ static inline std::vector<std::vector<element::Type>> precision_bf16_if_supporte
return prc;
}

static inline std::vector<std::vector<element::Type>> precision_fp16_if_supported(size_t count) {
std::vector<std::vector<element::Type>> prc;
if (is_fp16_supported_by_brgemm())
prc.emplace_back(std::vector<element::Type>(count, element::f16));
return prc;
}

static inline std::vector<std::vector<element::Type>> quantized_precisions_if_supported() {
std::vector<std::vector<element::Type>> prc = {};
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms