Skip to content

Commit

Permalink
[CPU] Enable bf16 acdb layout for transpose (openvinotoolkit#21030)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel authored Nov 14, 2023
1 parent bb3ed2d commit 7cb3bf5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void Transpose::initSupportedPrimitiveDescriptors() {
supportedPrimitiveDescriptorsBuilder(config, transposeParams);
}
#endif // OPENVINO_ARCH_X86_64
if (prec == Precision::FP32 || prec == Precision::FP16 || prec == Precision::I8 || prec == Precision::U8) {
if (prec == Precision::FP32 || prec == Precision::FP16 || prec == Precision::I8 || prec == Precision::U8 || prec == Precision::BF16) {
config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::nspc)->createSharedDesc(prec, inputDataShape));
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::nspc)->createSharedDesc(prec, outputDataShape));
supportedPrimitiveDescriptorsBuilder(config, transposeParams);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const auto cpuParams_nCdhw16c = CPUSpecificParams {{nCdhw16c}, {}, {}, {}};

const auto cpuParams_nChw8c = CPUSpecificParams {{nChw8c}, {}, {}, {}};
const auto cpuParams_nCdhw8c = CPUSpecificParams {{nCdhw8c}, {}, {}, {}};
const auto cpuParams_nspc = CPUSpecificParams {{acdb}, {}, {}, {}};

const std::vector<InferenceEngine::Precision> netPrecisions = {
Precision::I8,
Expand Down Expand Up @@ -64,7 +65,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamicShapes4D_Transpose, TransposeLayerCPUTest,
::testing::Values(Precision::BF16),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(additional_config),
::testing::Values(CPUSpecificParams{})),
::testing::ValuesIn({CPUSpecificParams{}, cpuParams_nspc})),
TransposeLayerCPUTest::getTestCaseName);

const std::vector<InputShape> staticInputShapes5DC16 = {InputShape{
Expand Down

0 comments on commit 7cb3bf5

Please sign in to comment.