From d90e21241ca111b38f0521f92b79acb766644b89 Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Mon, 11 Nov 2024 18:42:33 +0100 Subject: [PATCH] Fix tests --- .../kernels/scaled_attn/cache_rotation.hpp | 34 +++++++------ .../src/nodes/kernels/scaled_attn/common.hpp | 4 -- .../intel_cpu/tests/unit/CMakeLists.txt | 16 ++---- .../tests/unit/paged_attn_cache_rotation.cpp | 50 +++++++++++++------ 4 files changed, 57 insertions(+), 47 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp index f3a6a92f1fcf44..ac4f645bb3ed57 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp @@ -57,7 +57,7 @@ inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, CT* cu #if defined(HAVE_AVX2) template -inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, bool is_underutilizing, size_t num_vectorized_elements_per_iteration) { +inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, size_t num_vectorized_elements_per_iteration, size_t is_underutilizing) { using namespace ov::Extensions::Cpu::XARCH; auto result_x = _mm256_setzero_ps(); @@ -70,19 +70,19 @@ inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* curr auto cache_values_y = _mm256_undefined_ps(); if (!is_underutilizing) { - coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); - coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); - - cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); - cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); - } - else { coeffts_cos = mm256_uni_loadu_ps(current_rotation_coeffts_cos_ptr); coeffts_sin = mm256_uni_loadu_ps(current_rotation_coeffts_sin_ptr); cache_values_x = mm256_uni_loadu_ps(current_x_values_ptr); cache_values_y = mm256_uni_loadu_ps(current_y_values_ptr); } + else { + coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); + coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); + + cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); + cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); + } result_x = _mm256_fmadd_ps(cache_values_x, coeffts_cos, result_x); result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add @@ -115,28 +115,30 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2; #endif // defined(HAVE_AVX512F) - size_t num_processed_elements_per_iteration = vec_len_in_f32_elts * 2; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load - size_t num_iterations = embedding_size / num_processed_elements_per_iteration; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load + size_t num_processed_elements_per_iteration = vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load + size_t num_iterations = embedding_size / num_processed_elements_per_iteration; if (embedding_size >= num_processed_elements_per_iteration) { OPENVINO_ASSERT(!(num_processed_elements_per_iteration % vec_len_in_f32_elts)); } else { is_underutilizing = true; OPENVINO_ASSERT(!(embedding_size % 2)); - num_processed_elements_per_iteration = embedding_size; + num_processed_elements_per_iteration = embedding_size / 2; num_iterations = 1; } constexpr size_t vec_size_in_elts = vec_len_in_bytes / sizeof(CT); CT* current_cache_element_ptr = cache_block_ptr; - for (size_t head_idx = 0; head_idx < num_heads; head_idx++, current_cache_element_ptr += block_size * embedding_size) { + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { // the rotation coefficients are taken to be the same for all heads float* current_rotation_coeffts_ptr = block_rotation_coefficients_ptr; - - for (size_t tok_idx = 0; tok_idx <= block_size; tok_idx++, current_cache_element_ptr += embedding_size) { + for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++, + current_cache_element_ptr += embedding_size, + current_rotation_coeffts_ptr += embedding_size) { CT* current_x_values_ptr = current_cache_element_ptr; CT* current_y_values_ptr = current_cache_element_ptr + embedding_size / 2; + float* current_rotation_coeffts_cos_ptr = current_rotation_coeffts_ptr; float* current_rotation_coeffts_sin_ptr = current_rotation_coeffts_ptr + embedding_size / 2; @@ -146,9 +148,9 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro current_rotation_coeffts_cos_ptr += vec_size_in_elts, current_rotation_coeffts_sin_ptr += vec_size_in_elts) { #if defined(HAVE_AVX512F) - rotate_kv_cache_chunk_avx512(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, is_underutilizing, num_processed_elements_per_iteration); + rotate_kv_cache_chunk_avx512(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration, is_underutilizing); #else // HAVE_AVX2 - rotate_kv_cache_chunk_avx2(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, is_underutilizing, num_processed_elements_per_iteration); + rotate_kv_cache_chunk_avx2(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration, is_underutilizing); #endif // defined(HAVE_AVX512F) } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index 7a2770e47377d9..c307b338d1454f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -9,10 +9,6 @@ #include #include -// TODO (vshampor): remove this -#define HAVE_AVX2 -#define HAVE_AVX512F - #include "openvino/core/type/bfloat16.hpp" #include "openvino/core/type/float16.hpp" diff --git a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt index fa66a3e4730d88..2f56545e6d802f 100644 --- a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt @@ -78,17 +78,11 @@ if (ENABLE_SNIPPETS_LIBXSMM_TPP) target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $) endif() -if(ENABLE_AVX2) - ov_avx2_optimization_flags(avx2_flags) - message("VSHAMPOR: passing AVX flags ${avx2_flags}") - target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags}") -endif() - -if(ENABLE_AVX512F) - ov_avx512_optimization_flags(avx512_flags) - message("VSHAMPOR: passing AVX flags ${avx512_flags}") - target_compile_options(${TARGET_NAME} PRIVATE "${avx512_flags}") -endif() +ov_avx2_optimization_flags(avx2_flags) +ov_avx512_optimization_flags(avx512_flags) +message("VSHAMPOR: passing AVX flags ${avx2_flags};${avx512_flags}") +target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags};${avx512_flags}") +target_compile_definitions(${TARGET_NAME} PRIVATE HAVE_AVX2 HAVE_AVX512F) # LTO set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO}) diff --git a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp index e86badc60911b4..2bf77d651e6ef6 100644 --- a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp +++ b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp @@ -6,13 +6,17 @@ #include #include #include +#include #include -// TODO (vshampor): remove this, find a way to forward the compile flags to the unit tests -#define HAVE_AVX2 -#define HAVE_AVX512F - +// the includes in the block below are necessary in order for the common.hpp header to be +// instantiated correctly +#include +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif #include "kernels/scaled_attn/common.hpp" + #include "utils/plain_tensor.hpp" #include "nodes/kernels/scaled_attn/cache_rotation.hpp" @@ -175,18 +179,18 @@ enum class TargetInstructionSet { AVX512 }; -using CacheRotationHWKernelTest = ::testing::TestWithParam; +using CacheRotationHWKernelTest = ::testing::TestWithParam>; MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { if (ref_container.size() < n || arg.size() < n) return false; if (ref_container.size() != arg.size()) return false; bool is_ok = true; - for (size_t i = 0; i < ref_container.size(); i++) + for (size_t i = 0; i < n; i++) { if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), float(ref_container[i]), result_listener)) { - *result_listener << " for element at idx " << i; + *result_listener << " for element at idx " << i << '\n'; is_ok = false; } } @@ -194,7 +198,8 @@ MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { } TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) { - auto instruction_set = GetParam(); + auto instruction_set = std::get<0>(GetParam()); + auto num_elements_to_process = std::get<1>(GetParam()); constexpr size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16; using MemChunk = std::array; @@ -219,30 +224,43 @@ TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) { MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577, 0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 }; + // unprocessed elements should remain untouched + std::copy(chunk_x.begin() + num_elements_to_process, chunk_x.end(), ref_chunk_x.begin() + num_elements_to_process); + std::copy(chunk_y.begin() + num_elements_to_process, chunk_y.end(), ref_chunk_y.begin() + num_elements_to_process); - size_t vec_len_in_elts = 0; switch(instruction_set) { + using namespace ov::Extensions::Cpu::XARCH; case TargetInstructionSet::AVX2: - vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2; - rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), vec_len_in_elts, /* is_underutilizing = */ false); + rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2); break; case TargetInstructionSet::AVX512: - vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512; - rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), vec_len_in_elts, /* is_underutilizing = */false); + rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512); break; default: FAIL() << "unknown target instruction set"; } - EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, vec_len_in_elts)); - EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, vec_len_in_elts)); + EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, num_elements_to_process)); + EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, num_elements_to_process)); EXPECT_EQ(chunk_cos, ref_chunk_cos); EXPECT_EQ(chunk_sin, ref_chunk_sin); } -INSTANTIATE_TEST_SUITE_P(VariousInstructionSets, CacheRotationHWKernelTest, ::testing::Values{TargetInstructionSet::AVX2, TargetInstructionSet::AVX512}); +auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo& info) { + size_t num_elts = std::get<1>(info.param); + switch(std::get<0>(info.param)) { + case TargetInstructionSet::AVX2: + return std::string("avx2-") + std::to_string(num_elts); + case TargetInstructionSet::AVX512: + return std::string("avx512-") + std::to_string(num_elts); + } + return std::string("unknown"); +}; + +INSTANTIATE_TEST_SUITE_P(AVX2, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), TEST_STRUCT_TO_NAME_FN); +INSTANTIATE_TEST_SUITE_P(AVX512, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), TEST_STRUCT_TO_NAME_FN); TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) { auto raw_cache_mem_ptr = this->cache_mem_ptr.get();