Commit c5c5af1
Apply Alexandra comments
chenhu-wang committed Nov 15, 2024
1 parent: a80aea6
Showing 4 changed files with 38 additions and 22 deletions.
15 changes: 4 additions & 11 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -449,16 +449,10 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
     config.inConfs.resize(inputShapes.size());
     for (size_t i = 0; i < inputShapes.size(); i++) {
         const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
-        auto precision = ((originalInputPrecision == ov::element::f32) &&
-                          context->getConfig().inferencePrecision == ov::element::bf16 &&
+        const auto precision = ((originalInputPrecision == ov::element::f32) &&
+                                one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
                                 subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
-                                   static_cast<ov::element::Type>(ov::element::bf16) :
-                                   originalInputPrecision;
-        precision = ((originalInputPrecision == ov::element::f32) &&
-                     context->getConfig().inferencePrecision == ov::element::f16 &&
-                     subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
-                        static_cast<ov::element::Type>(ov::element::f16) :
-                        precision;
+                                   context->getConfig().inferencePrecision : originalInputPrecision;
         if (supportedPrecisions.count(precision) == 0)
             OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");
 
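Note: the hunk above folds two chained ternaries into one by taking the enforced precision directly from the config. A minimal self-contained sketch of the resulting pattern, with a stand-in for the plugin's one_of helper (the real helper is declared elsewhere in the CPU plugin; the semantics shown here are an assumption, and pick_precision is a hypothetical name for illustration):

    #include <openvino/core/type/element_type.hpp>

    // Stand-in with the assumed semantics of the plugin's one_of helper:
    // true when `value` equals any of the listed candidates.
    template <typename T, typename... Ts>
    bool one_of(const T& value, const Ts&... candidates) {
        return ((value == candidates) || ...);
    }

    // The folded selection from the hunk above: when an f32 input meets a
    // bf16/f16 inference precision on a domain-sensitive subgraph, the
    // inference precision itself is taken, so no per-precision branch is needed.
    ov::element::Type pick_precision(ov::element::Type original,
                                     ov::element::Type inference,
                                     bool domain_sensitive) {
        return (original == ov::element::f32 &&
                one_of(inference, ov::element::bf16, ov::element::f16) &&
                domain_sensitive)
                   ? inference
                   : original;
    }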
@@ -643,8 +637,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
     SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
     SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
                                            ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
-    if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16)
-        && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
+    if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
         // enforce BF16 precisions to supported operations
         // MatMul has to be decomposed to Brgemm operations before enforcement
         // Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
3 changes: 3 additions & 0 deletions
@@ -125,5 +125,8 @@ std::set<std::vector<ov::element::Type>> EnforcePrecision::get_supported_precisi
     if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && ov::is_type<snippets::op::Brgemm>(op)) {
         return {{element::bf16, element::bf16}};
     }
+    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) && ov::is_type<snippets::op::Brgemm>(op)) {
+        return {{element::f16, element::f16}};
+    }
     return {};
 }
35 changes: 24 additions & 11 deletions
@@ -124,30 +124,43 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
                                             ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);
 
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
                          MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::bf16),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(7),
+                                            ::testing::Values(6),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                         MHA::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
+                         MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
                                             ::testing::ValuesIn(precision_fp16_if_supported(4)),
-                                            ::testing::Values(ov::element::f32),
+                                            ::testing::ValuesIn(mha_infer_precision_fp16_if_supported()),
                                             ::testing::ValuesIn({false, true}),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output
-                                            ::testing::Values(1), // MHA + 5 Converts on inputs and output
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
                                             ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);
 
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16,
                          MHA,
-                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
                                             ::testing::ValuesIn(precision_f32(4)),
-                                            ::testing::Values(ov::element::bf16),
-                                            ::testing::ValuesIn({false}),
+                                            ::testing::ValuesIn(mha_infer_precision_fp16_if_supported()),
+                                            ::testing::ValuesIn({false, true}),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(7),
-                                            ::testing::Values(6),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
-                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                                            ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
                          MHA::getTestCaseName);
 
 } // namespace
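Note on the positional arguments in these instantiations, inferred from the deleted inline comments rather than from the MHA fixture itself (which is outside this diff): the two bare integer axes appear to be the expected node and subgraph counts after tokenization, so smoke_Snippets_MHAEnforceFP16, for example, would expect 2 nodes and 1 subgraph, mirroring the BF16 suite's 7 and 6.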
7 changes: 7 additions & 0 deletions
@@ -44,6 +44,13 @@ static inline std::vector<std::vector<element::Type>> precision_fp16_if_supporte
     return prc;
 }
 
+static inline std::vector<element::Type> mha_infer_precision_fp16_if_supported() {
+    std::vector<element::Type> prc;
+    if (is_fp16_supported_by_brgemm())
+        prc.emplace_back(element::f16);
+    return prc;
+}
+
 static inline std::vector<std::vector<element::Type>> quantized_precisions_if_supported() {
     std::vector<std::vector<element::Type>> prc = {};
     // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
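The new helper returns an empty vector when FP16 Brgemm is unavailable, so the ::testing::ValuesIn(mha_infer_precision_fp16_if_supported()) axes above presumably generate no parameter combinations on such hardware, effectively skipping the FP16 suites there. A hedged stand-in for the guard it relies on (the real is_fp16_supported_by_brgemm() is defined elsewhere in these test utilities; this assumed version mirrors the avx512_core_amx_fp16 check added to EnforcePrecision above):

    #include <cpu/x64/cpu_isa_traits.hpp>  // dnnl::impl::cpu::x64::mayiuse

    // Assumed implementation, for illustration only: FP16 Brgemm is treated
    // as available exactly when the CPU reports AMX-FP16 support.
    static inline bool is_fp16_supported_by_brgemm_assumed() {
        return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
    }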
