diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index 328045ad4ca7f3..203cc6fea0e7e2 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -449,16 +449,10 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
     config.inConfs.resize(inputShapes.size());
     for (size_t i = 0; i < inputShapes.size(); i++) {
         const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
-        auto precision = ((originalInputPrecision == ov::element::f32) &&
-                          context->getConfig().inferencePrecision == ov::element::bf16 &&
+        const auto precision = ((originalInputPrecision == ov::element::f32) &&
+                          one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
                           subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
-                             static_cast<ov::element::Type>(ov::element::bf16) :
-                             originalInputPrecision;
-        precision = ((originalInputPrecision == ov::element::f32) &&
-                     context->getConfig().inferencePrecision == ov::element::f16 &&
-                     subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
-                        static_cast<ov::element::Type>(ov::element::f16) :
-                        precision;
+                             context->getConfig().inferencePrecision : originalInputPrecision;

         if (supportedPrecisions.count(precision) == 0)
             OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");
@@ -643,8 +637,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
     SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
     SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
                                            ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
-    if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16)
-        && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
+    if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
         // enforce BF16 precisions to supported operations
         // MatMul has to be decomposed to Brgemm operations before enforcement
         // Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
index b90b35f9359aa4..05a68f1538be4c 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
@@ -125,5 +125,8 @@ std::set<std::vector<element::Type>> EnforcePrecision::get_supported_precisi
     if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && ov::is_type<snippets::op::Brgemm>(op)) {
         return {{element::bf16, element::bf16}};
     }
+    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) && ov::is_type<snippets::op::Brgemm>(op)) {
+        return {{element::f16, element::f16}};
+    }
     return {};
 }
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
index 93e4e3df4e856b..d3afe93b47d14a 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
@@ -124,30 +124,43 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
                                            ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);

-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
                          MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::bf16),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(7),
+                                            ::testing::Values(6),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                         MHA::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
                                             ::testing::ValuesIn(precision_fp16_if_supported(4)),
-                                            ::testing::Values(ov::element::f32),
+                                            ::testing::ValuesIn(mha_infer_precision_fp16_if_supported()),
                                             ::testing::ValuesIn({false, true}),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output
-                                            ::testing::Values(1), // MHA + 5 Converts on inputs and output
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
                                             ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);

-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16,
                          MHA,
-                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
                                             ::testing::ValuesIn(precision_f32(4)),
-                                            ::testing::Values(ov::element::bf16),
-                                            ::testing::ValuesIn({false}),
+                                            ::testing::ValuesIn(mha_infer_precision_fp16_if_supported()),
+                                            ::testing::ValuesIn({false, true}),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(7),
-                                            ::testing::Values(6),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
-                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                                            ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
                          MHA::getTestCaseName);

 } // namespace
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
index 6815cdab671cea..e5b04a117d995e 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
@@ -44,6 +44,13 @@ static inline std::vector<std::vector<element::Type>> precision_fp16_if_supporte
     return prc;
 }

+static inline std::vector<element::Type> mha_infer_precision_fp16_if_supported() {
+    std::vector<element::Type> prc;
+    if (is_fp16_supported_by_brgemm())
+        prc.emplace_back(element::f16);
+    return prc;
+}
+
 static inline std::vector<std::vector<element::Type>> quantized_precisions_if_supported() {
     std::vector<std::vector<element::Type>> prc = {};
     // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
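For context on how the new code path is reached: the CPU plugin's `context->getConfig().inferencePrecision` is populated from the `ov::hint::inference_precision` property, so the f16 enforcement extended in this patch fires when an FP32 model containing domain-sensitive ops (e.g. an MHA pattern) is compiled with that hint set to `ov::element::f16` on a CPU supporting avx512_core_amx_fp16. Below is a minimal usage sketch against the public OpenVINO 2.0 API; the model path `mha_model.xml` is hypothetical, and the hardware capability check happens inside the plugin as in the `EnforcePrecision` change above.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Hypothetical FP32 IR containing an MHA (MatMul + Softmax + MatMul) pattern.
    auto model = core.read_model("mha_model.xml");

    // Requesting f16 inference precision makes the CPU plugin enforce f16 on
    // domain-sensitive Snippets subgraphs; with this patch, EnforcePrecision
    // also permits {f16, f16} Brgemm on avx512_core_amx_fp16 hardware.
    auto compiled = core.compile_model(model, "CPU",
                                       ov::hint::inference_precision(ov::element::f16));

    auto request = compiled.create_infer_request();
    request.infer();  // runs the Snippets-fused MHA in f16 on supported platforms
    return 0;
}

Note that the subgraph.cpp change itself is behavior-preserving cleanup: the two back-to-back bf16/f16 ternaries collapse into a single `one_of(...)` check, since both branches resolve to `context->getConfig().inferencePrecision`. The functional additions are the avx512_core_amx_fp16 branch in enforce_precision.cpp and the corresponding MHAFP16_4D/MHAEnforceFP16 test instantiations.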