Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 11, 2024
1 parent 9d67675 commit d90e212
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, CT* cu

#if defined(HAVE_AVX2)
template<class CT>
inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, bool is_underutilizing, size_t num_vectorized_elements_per_iteration) {
inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* current_y_values_ptr, float* current_rotation_coeffts_cos_ptr, float* current_rotation_coeffts_sin_ptr, size_t num_vectorized_elements_per_iteration, size_t is_underutilizing) {
using namespace ov::Extensions::Cpu::XARCH;

auto result_x = _mm256_setzero_ps();
Expand All @@ -70,19 +70,19 @@ inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, CT* curr
auto cache_values_y = _mm256_undefined_ps();

if (!is_underutilizing) {
coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration);
coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration);

cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration);
cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration);
}
else {
coeffts_cos = mm256_uni_loadu_ps(current_rotation_coeffts_cos_ptr);
coeffts_sin = mm256_uni_loadu_ps(current_rotation_coeffts_sin_ptr);

cache_values_x = mm256_uni_loadu_ps(current_x_values_ptr);
cache_values_y = mm256_uni_loadu_ps(current_y_values_ptr);
}
else {
coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration);
coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration);

cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration);
cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration);
}

result_x = _mm256_fmadd_ps(cache_values_x, coeffts_cos, result_x);
result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add
Expand Down Expand Up @@ -115,28 +115,30 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro
constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2;
#endif // defined(HAVE_AVX512F)

size_t num_processed_elements_per_iteration = vec_len_in_f32_elts * 2; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load
size_t num_iterations = embedding_size / num_processed_elements_per_iteration; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load
size_t num_processed_elements_per_iteration = vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each elt is expanded to f32 on load
size_t num_iterations = embedding_size / num_processed_elements_per_iteration;
if (embedding_size >= num_processed_elements_per_iteration) {
OPENVINO_ASSERT(!(num_processed_elements_per_iteration % vec_len_in_f32_elts));
}
else {
is_underutilizing = true;
OPENVINO_ASSERT(!(embedding_size % 2));
num_processed_elements_per_iteration = embedding_size;
num_processed_elements_per_iteration = embedding_size / 2;
num_iterations = 1;
}
constexpr size_t vec_size_in_elts = vec_len_in_bytes / sizeof(CT);

CT* current_cache_element_ptr = cache_block_ptr;

for (size_t head_idx = 0; head_idx < num_heads; head_idx++, current_cache_element_ptr += block_size * embedding_size) {
for (size_t head_idx = 0; head_idx < num_heads; head_idx++) {
// the rotation coefficients are taken to be the same for all heads
float* current_rotation_coeffts_ptr = block_rotation_coefficients_ptr;

for (size_t tok_idx = 0; tok_idx <= block_size; tok_idx++, current_cache_element_ptr += embedding_size) {
for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++,
current_cache_element_ptr += embedding_size,
current_rotation_coeffts_ptr += embedding_size) {
CT* current_x_values_ptr = current_cache_element_ptr;
CT* current_y_values_ptr = current_cache_element_ptr + embedding_size / 2;

float* current_rotation_coeffts_cos_ptr = current_rotation_coeffts_ptr;
float* current_rotation_coeffts_sin_ptr = current_rotation_coeffts_ptr + embedding_size / 2;

Expand All @@ -146,9 +148,9 @@ inline static void rotate_kv_cache_block_hw(CT* cache_block_ptr, float* block_ro
current_rotation_coeffts_cos_ptr += vec_size_in_elts,
current_rotation_coeffts_sin_ptr += vec_size_in_elts) {
#if defined(HAVE_AVX512F)
rotate_kv_cache_chunk_avx512(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, is_underutilizing, num_processed_elements_per_iteration);
rotate_kv_cache_chunk_avx512(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration, is_underutilizing);
#else // HAVE_AVX2
rotate_kv_cache_chunk_avx2(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, is_underutilizing, num_processed_elements_per_iteration);
rotate_kv_cache_chunk_avx2(current_x_values_ptr, current_y_values_ptr, current_rotation_coeffts_cos_ptr, current_rotation_coeffts_sin_ptr, num_processed_elements_per_iteration, is_underutilizing);
#endif // defined(HAVE_AVX512F)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
#include <vector>
#include <cassert>

// TODO (vshampor): remove this
#define HAVE_AVX2
#define HAVE_AVX512F

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"

Expand Down
16 changes: 5 additions & 11 deletions src/plugins/intel_cpu/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,11 @@ if (ENABLE_SNIPPETS_LIBXSMM_TPP)
target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:xsmm,INCLUDE_DIRECTORIES>)
endif()

if(ENABLE_AVX2)
ov_avx2_optimization_flags(avx2_flags)
message("VSHAMPOR: passing AVX flags ${avx2_flags}")
target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags}")
endif()

if(ENABLE_AVX512F)
ov_avx512_optimization_flags(avx512_flags)
message("VSHAMPOR: passing AVX flags ${avx512_flags}")
target_compile_options(${TARGET_NAME} PRIVATE "${avx512_flags}")
endif()
ov_avx2_optimization_flags(avx2_flags)
ov_avx512_optimization_flags(avx512_flags)
message("VSHAMPOR: passing AVX flags ${avx2_flags};${avx512_flags}")
target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags};${avx512_flags}")
target_compile_definitions(${TARGET_NAME} PRIVATE HAVE_AVX2 HAVE_AVX512F)

# LTO
set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO})
Expand Down
50 changes: 34 additions & 16 deletions src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
#include <string>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <gtest/internal/gtest-param-util.h>
#include <memory>

// TODO (vshampor): remove this, find a way to forward the compile flags to the unit tests
#define HAVE_AVX2
#define HAVE_AVX512F

// the includes in the block below are necessary in order for the common.hpp header to be
// instantiated correctly
#include <cstring>
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
# include <immintrin.h>
#endif
#include "kernels/scaled_attn/common.hpp"

#include "utils/plain_tensor.hpp"
#include "nodes/kernels/scaled_attn/cache_rotation.hpp"

Expand Down Expand Up @@ -175,26 +179,27 @@ enum class TargetInstructionSet {
AVX512
};

using CacheRotationHWKernelTest = ::testing::TestWithParam<TargetInstructionSet>;
using CacheRotationHWKernelTest = ::testing::TestWithParam<std::tuple<TargetInstructionSet, size_t>>;

MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
if (ref_container.size() < n || arg.size() < n) return false;
if (ref_container.size() != arg.size()) return false;

bool is_ok = true;
for (size_t i = 0; i < ref_container.size(); i++)
for (size_t i = 0; i < n; i++)
{
if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), float(ref_container[i]), result_listener))
{
*result_listener << " for element at idx " << i;
*result_listener << " for element at idx " << i << '\n';
is_ok = false;
}
}
return is_ok;
}

TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
auto instruction_set = GetParam();
auto instruction_set = std::get<0>(GetParam());
auto num_elements_to_process = std::get<1>(GetParam());

constexpr size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
using MemChunk = std::array<float, MAX_CHUNK_SIZE_IN_ELEMENTS>;
Expand All @@ -219,30 +224,43 @@ TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577,
0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 };

// unprocessed elements should remain untouched
std::copy(chunk_x.begin() + num_elements_to_process, chunk_x.end(), ref_chunk_x.begin() + num_elements_to_process);
std::copy(chunk_y.begin() + num_elements_to_process, chunk_y.end(), ref_chunk_y.begin() + num_elements_to_process);

size_t vec_len_in_elts = 0;
switch(instruction_set) {
using namespace ov::Extensions::Cpu::XARCH;
case TargetInstructionSet::AVX2:
vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2;
rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), vec_len_in_elts, /* is_underutilizing = */ false);
rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2);
break;
case TargetInstructionSet::AVX512:
vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512;
rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), vec_len_in_elts, /* is_underutilizing = */false);
rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512);
break;
default:
FAIL() << "unknown target instruction set";
}

EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, vec_len_in_elts));
EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, vec_len_in_elts));
EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, num_elements_to_process));
EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, num_elements_to_process));

EXPECT_EQ(chunk_cos, ref_chunk_cos);
EXPECT_EQ(chunk_sin, ref_chunk_sin);

}

INSTANTIATE_TEST_SUITE_P(VariousInstructionSets, CacheRotationHWKernelTest, ::testing::Values{TargetInstructionSet::AVX2, TargetInstructionSet::AVX512});
auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo<CacheRotationHWKernelTest::ParamType>& info) {
size_t num_elts = std::get<1>(info.param);
switch(std::get<0>(info.param)) {
case TargetInstructionSet::AVX2:
return std::string("avx2-") + std::to_string(num_elts);
case TargetInstructionSet::AVX512:
return std::string("avx512-") + std::to_string(num_elts);
}
return std::string("unknown");
};

INSTANTIATE_TEST_SUITE_P(AVX2, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), TEST_STRUCT_TO_NAME_FN);
INSTANTIATE_TEST_SUITE_P(AVX512, CacheRotationHWKernelTest, ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), ::testing::Range(size_t(0), ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), TEST_STRUCT_TO_NAME_FN);

TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
auto raw_cache_mem_ptr = this->cache_mem_ptr.get();
Expand Down

0 comments on commit d90e212

Please sign in to comment.