diff --git a/src/cpp/src/cache_eviction.hpp b/src/cpp/src/cache_eviction.hpp
index 3b22512b1f..afc280c16b 100644
--- a/src/cpp/src/cache_eviction.hpp
+++ b/src/cpp/src/cache_eviction.hpp
@@ -119,11 +119,11 @@ class CacheEvictionAlgorithm {
 
 class CacheRotationCalculator {
 public:
-    CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta = 10000.0f) : m_block_size(block_size) {
+    CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta = 10000.0f) : m_block_size(block_size), m_head_size(kv_head_size) {
         size_t max_position_angle_multiplier = max_context_length / 2 + 1; // adding +1 here and below for good measure in case of odd dividends
         size_t num_freqs = kv_head_size / 2 + 1;
-        m_rope_sin_lut.reserve(max_position_angle_multiplier);
-        m_rope_cos_lut.reserve(max_position_angle_multiplier);
+        m_rope_sin_lut.resize(max_position_angle_multiplier);
+        m_rope_cos_lut.resize(max_position_angle_multiplier);
 
         for (size_t i = 0; i < max_position_angle_multiplier; i++) {
             m_rope_sin_lut[i].reserve(num_freqs);
@@ -138,45 +138,53 @@ class CacheRotationCalculator {
     };
 
     using RotationCoefficientsPerToken = std::vector<std::vector<float>>;
 
-    std::pair<RotationCoefficientsPerToken, RotationCoefficientsPerToken> get_rotation_multipliers(const std::set<size_t>& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction) {
-        std::pair<RotationCoefficientsPerToken, RotationCoefficientsPerToken> retval;
+    struct BlockRotationData {
+        size_t logical_block_idx;
+        RotationCoefficientsPerToken sines;
+        RotationCoefficientsPerToken cosines;
+    };
+
+    std::vector<BlockRotationData> get_rotation_multipliers(const std::set<size_t>& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction) {
+        std::vector<BlockRotationData> retval;
         if (evicted_block_logical_indices.empty()) {
             return retval;
         }
 
+        retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size());
+
         ptrdiff_t current_rotation_delta_in_positions = 0;
         std::vector<size_t> logical_block_space(num_logical_blocks_before_eviction);
         std::iota(logical_block_space.begin(), logical_block_space.end(), 0);
 
-        std::vector<ptrdiff_t> rotation_deltas;
-        rotation_deltas.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size());
-
         for (size_t logical_block_idx : logical_block_space) {
             if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) {
                 current_rotation_delta_in_positions += 1;
             } else {
                 if (current_rotation_delta_in_positions != 0) {
-                    rotation_deltas.push_back(current_rotation_delta_in_positions);
+                    BlockRotationData block_rotation_data;
+                    block_rotation_data.logical_block_idx = logical_block_idx;
+                    block_rotation_data.cosines.reserve(m_block_size / 2);
+                    block_rotation_data.sines.reserve(m_block_size / 2);
+                    for (size_t i = 0; i < m_block_size / 2; i++) {
+                        block_rotation_data.cosines.push_back(m_rope_cos_lut[current_rotation_delta_in_positions]);
+                        block_rotation_data.sines.push_back(m_rope_sin_lut[current_rotation_delta_in_positions]);
+                    }
+
+                    retval.push_back(block_rotation_data);
                 }
             }
         }
 
-        size_t num_tokens_to_rotate = rotation_deltas.size() * m_block_size;
-        retval.first.reserve(num_tokens_to_rotate);
-        retval.second.reserve(num_tokens_to_rotate);
-        for (ptrdiff_t delta : rotation_deltas) {
-            for (size_t i = 0; i < m_block_size; i++) {
-                retval.first.push_back(m_rope_cos_lut[delta]);
-                retval.second.push_back(m_rope_sin_lut[delta]);
-            }
-        }
-
         return retval;
     }
 
+    size_t get_head_size() const {
+        return m_head_size;
+    }
+
 private:
     size_t m_block_size;
+    size_t m_head_size;
    std::vector<std::vector<float>> m_rope_sin_lut;
    std::vector<std::vector<float>> m_rope_cos_lut;
 };
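
For reference, the constructor hunk above ends before the loop that actually fills the LUT. A minimal standalone sketch of how such a RoPE sin/cos lookup table is conventionally populated — the head size, context length, and theta here are illustrative values, not taken from the patch:

```cpp
// Sketch only: reconstructs the presumable LUT fill that the truncated hunk omits.
// Entry [delta][k] holds sin/cos of (delta * inv_freq(k)), so rotating a token
// back by `delta` positions becomes a table lookup instead of a trig call.
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
    const size_t head_size = 64;             // illustrative kv_head_size
    const size_t max_context_length = 2048;  // illustrative
    const double rope_theta = 10000.0;

    const size_t num_deltas = max_context_length / 2 + 1;
    const size_t num_freqs = head_size / 2 + 1;

    std::vector<std::vector<float>> sin_lut(num_deltas), cos_lut(num_deltas);
    for (size_t delta = 0; delta < num_deltas; delta++) {
        sin_lut[delta].reserve(num_freqs);
        cos_lut[delta].reserve(num_freqs);
        for (size_t k = 0; k < num_freqs; k++) {
            // Standard RoPE inverse frequency for embedding pair k.
            double inv_freq = std::pow(rope_theta, -2.0 * static_cast<double>(k) / static_cast<double>(head_size));
            sin_lut[delta].push_back(static_cast<float>(std::sin(delta * inv_freq)));
            cos_lut[delta].push_back(static_cast<float>(std::cos(delta * inv_freq)));
        }
    }
    return 0;
}
```
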
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 4896ba0e79..90fd32969a 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -76,11 +76,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
     m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config, device_config.get_num_layers(), /* m_collect_attention_scores = */ true);
 
     m_rotation_coefficient_stores.reserve(device_config.get_num_layers());
-    ov::Shape rotation_coefficient_store_shape{ device_config.get_head_size(), scheduler_config.block_size * scheduler_config.num_kv_blocks };
+    ov::Shape rotation_coefficient_store_shape{ device_config.get_head_size() * (scheduler_config.block_size * scheduler_config.num_kv_blocks) };
     for (size_t i = 0; i < device_config.get_num_layers(); i++) {
-        ov::Tensor store(device_config.get_cache_precision(), rotation_coefficient_store_shape);
+        ov::Tensor store(ov::element::f32, rotation_coefficient_store_shape);
+        std::memset(store.data(), 0, store.get_byte_size());
         m_rotation_coefficient_stores.push_back(store);
     }
+    m_next_step_rotation_coefficients.resize(device_config.get_num_layers());
+    m_next_step_rotated_block_logical_indices_per_sequence.resize(device_config.get_num_layers());
     m_cache_rotation_calculator = std::make_shared<CacheRotationCalculator>(scheduler_config.block_size,
                                                                             // TODO (vshampor): LUT size equal to max cache size in tokens
                                                                             // is overkill - find a way to pass the max sequence length instead
@@ -205,7 +208,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
     // evict unimportant blocks from KV cache, if requested
     if (sched_config.use_cache_eviction) {
         maybe_evict_cache_blocks(sched_config);
-        m_model_runner->set_cache_rotation_coefficients(m_next_step_rotation_coefficients);
+        m_model_runner->set_cache_rotation_data(m_next_step_rotation_coefficients, m_next_step_rotated_block_logical_indices_per_sequence);
     }
 
 #ifdef DEBUG_CACHE_STATE_DUMP
@@ -389,12 +392,20 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_a
 void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
     std::unordered_map<SequenceGroup::Ptr, size_t> seq_group_to_num_blocks_evicted_map;
     auto sequence_attention_scores = m_model_runner->get_last_attention_scores();
+
+    OPENVINO_ASSERT(!sequence_attention_scores.empty());
+    size_t num_decoder_layers = sequence_attention_scores.begin()->second.size();
+    std::vector<size_t> num_blocks_to_rotate_for_each_layer(num_decoder_layers, 0);
+    size_t head_size = m_cache_rotation_calculator->get_head_size();
+
+    m_next_step_rotation_coefficients.clear();
+    m_next_step_rotated_block_logical_indices_per_sequence.clear();
+    m_next_step_rotated_block_logical_indices_per_sequence.resize(num_decoder_layers);
+
     for (auto& seq_id_and_attention_scores : sequence_attention_scores) {
         auto seq_id = seq_id_and_attention_scores.first;
         const auto& attention_scores_for_all_decoder_layers = seq_id_and_attention_scores.second;
         if (m_seq_group_id_to_cache_eviction_algo_map.find(seq_id) == m_seq_group_id_to_cache_eviction_algo_map.end()) {
-            auto num_decoder_layers = attention_scores_for_all_decoder_layers.size();
-
             m_seq_group_id_to_cache_eviction_algo_map[seq_id] = CacheEvictionAlgorithm(sched_config.cache_eviction_config, sched_config.block_size, num_decoder_layers);
         }
         auto& cache_eviction_algo = m_seq_group_id_to_cache_eviction_algo_map[seq_id];
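
The store tensors above are now flat 1-D buffers of `head_size * block_size * num_kv_blocks` floats per layer. A sketch of the offset arithmetic this layout implies for the fill code in the next hunk — the helper name is illustrative, not part of the patch:

```cpp
#include <cstddef>

// Coefficients for (rotated block b, token-in-block t, channel c) live at
// ((b * block_size) + t) * head_size + c in the flat per-layer store.
inline size_t coeff_offset(size_t block_idx, size_t tok_idx, size_t channel_idx,
                           size_t block_size, size_t head_size) {
    return (block_idx * block_size + tok_idx) * head_size + channel_idx;
}
```
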
@@ -403,31 +414,37 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block
         auto logical_blocks_to_evict = cache_eviction_algo.evict_logical_blocks();
 
-        for (size_t i = 0; i < logical_blocks_to_evict.size(); i++) {
-            size_t num_blocks_before_eviction = m_scheduler->get_block_tables(seq_id)[i].size();
+        for (size_t layer_idx = 0; layer_idx < logical_blocks_to_evict.size(); layer_idx++) {
+            if (logical_blocks_to_evict[layer_idx].empty()) {
+                continue;
+            }
+            size_t num_blocks_before_eviction = m_scheduler->get_block_tables(seq_id)[layer_idx].size();
             auto rotation_multipliers =
-                m_cache_rotation_calculator->get_rotation_multipliers(logical_blocks_to_evict[i],
+                m_cache_rotation_calculator->get_rotation_multipliers(logical_blocks_to_evict[layer_idx],
                                                                       num_blocks_before_eviction);
-            const auto& rotation_multipliers_cos = rotation_multipliers.first;
-            const auto& rotation_multipliers_sin = rotation_multipliers.second;
-            OPENVINO_ASSERT(rotation_multipliers_cos.size() == rotation_multipliers_sin.size());
-            const size_t num_kv_heads = m_rotation_coefficient_stores[i].get_shape()[0];
-            const size_t num_tokens = rotation_multipliers_cos.size() * 2;
-
-            ov::Tensor rotation_multipliers_tensor(m_rotation_coefficient_stores[i],
-                                                   ov::Coordinate{0, 0},
-                                                   ov::Coordinate{num_kv_heads, num_tokens});
-
-            // Fill the ROI tensor with rotation coefficient data - cos and sin coefficients are interleaved.
-            auto rotation_multipliers_tensor_data = rotation_multipliers_tensor.data<float>();
-            for (size_t head_idx = 0; head_idx < num_kv_heads; head_idx++) {
-                for (size_t pos_idx = 0; pos_idx < rotation_multipliers_cos.size(); pos_idx++) {
-                    size_t head_offset = head_idx * num_tokens;
-                    rotation_multipliers_tensor_data[head_offset + 2 * pos_idx] = rotation_multipliers_cos[head_idx][pos_idx];
-                    rotation_multipliers_tensor_data[head_offset + 2 * pos_idx + 1] = rotation_multipliers_sin[head_idx][pos_idx];
+            for (size_t rotated_block_idx = 0; rotated_block_idx < rotation_multipliers.size(); rotated_block_idx++) {
+                const auto& block_rotation_data = rotation_multipliers[rotated_block_idx];
+                const auto& rotation_multipliers_cos = block_rotation_data.cosines;
+                const auto& rotation_multipliers_sin = block_rotation_data.sines;
+                OPENVINO_ASSERT(rotation_multipliers_cos.size() == rotation_multipliers_sin.size());
+                OPENVINO_ASSERT(rotation_multipliers_cos.size() * 2 == sched_config.block_size);
+
+                m_next_step_rotated_block_logical_indices_per_sequence[layer_idx][seq_id].push_back(block_rotation_data.logical_block_idx);
+
+                // Fill the store tensor with rotation coefficient data - cos and sin coefficients are interleaved
+                // NB: the order of seq_id in each per-sequence iteration of the `for (auto& seq_id_and_attention_scores ...` loop must be the same
+                // as the order of seq_ids in which the "rotated_block_indices.N" inputs are filled
+                size_t sequence_offset = num_blocks_to_rotate_for_each_layer[layer_idx] * sched_config.block_size * head_size;
+                auto rotation_multipliers_tensor_data = m_rotation_coefficient_stores[layer_idx].data<float>() + sequence_offset;
+                for (size_t tok_idx = 0; tok_idx < rotation_multipliers_cos.size(); tok_idx++) {
+                    size_t position_offset = head_size * tok_idx;
+                    for (size_t embedding_pair_idx = 0; embedding_pair_idx < head_size / 2; embedding_pair_idx++) {
+                        rotation_multipliers_tensor_data[position_offset + 2 * embedding_pair_idx] = rotation_multipliers_cos[tok_idx][embedding_pair_idx];
+                        rotation_multipliers_tensor_data[position_offset + 2 * embedding_pair_idx + 1] = rotation_multipliers_sin[tok_idx][embedding_pair_idx];
+                    }
                 }
+                num_blocks_to_rotate_for_each_layer[layer_idx] += 1;
             }
-
-            m_next_step_rotation_coefficients[i] = rotation_multipliers_tensor;
         }
 
         m_scheduler->free_blocks_from_sequence(seq_id, logical_blocks_to_evict);
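
Read in isolation, the interleaving in the hunk above packs each token position's channels as cos/sin pairs: channel 2*k carries the cosine and channel 2*k + 1 the sine for embedding pair k. A self-contained sketch of that layout (the helper is illustrative, not from the patch):

```cpp
#include <cstddef>
#include <vector>

// Builds one token position's channel vector from per-pair cos/sin values,
// interleaved the same way as the fill loop above.
std::vector<float> interleave_cos_sin(const std::vector<float>& cosines,
                                      const std::vector<float>& sines) {
    std::vector<float> channels(cosines.size() * 2);
    for (size_t k = 0; k < cosines.size(); k++) {
        channels[2 * k] = cosines[k];    // even channel: cos for pair k
        channels[2 * k + 1] = sines[k];  // odd channel: sin for pair k
    }
    return channels;
}
```
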
@@ -444,6 +461,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block
         }
     }
 
+
+    // Select the previously filled rotation coefficients from the store tensor
+    for (size_t i = 0; i < num_decoder_layers; i++) {
+        m_next_step_rotation_coefficients.emplace_back(m_rotation_coefficient_stores[i], ov::Coordinate{0}, ov::Coordinate{num_blocks_to_rotate_for_each_layer[i] * head_size * sched_config.block_size});
+    }
+
     for (const auto& seq_group_ptr_and_num_blocks_evicted : seq_group_to_num_blocks_evicted_map) {
         // Assuming that the evicted blocks are always full (since they by design are only selected from intermediate-age blocks)
         auto seq_group_ptr = seq_group_ptr_and_num_blocks_evicted.first;
diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp
index ed39ed905f..ea01843255 100644
--- a/src/cpp/src/continuous_batching_impl.hpp
+++ b/src/cpp/src/continuous_batching_impl.hpp
@@ -26,7 +26,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
     static const size_t AVG_CACHE_USAGE_WINDOW_SIZE_IN_STEPS = 1000;
     std::deque<float> m_previous_step_cache_usages;
-    
+
     // flag to enable validation mode for sampler
     bool m_is_validation_mode_enabled = false;
@@ -37,6 +37,9 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
     // re-rotation coefficients to be sent to the proper model inputs at the *next* pipeline step.
     std::vector<ov::Tensor> m_next_step_rotation_coefficients;
 
+    using SeqIdToRotatedLogicalBlocksMap = std::map<size_t, std::vector<size_t>>;
+    std::vector<SeqIdToRotatedLogicalBlocksMap> m_next_step_rotated_block_logical_indices_per_sequence;
+
     std::shared_ptr<CacheRotationCalculator> m_cache_rotation_calculator;
 
 #ifdef DEBUG_CACHE_STATE_DUMP
@@ -96,4 +99,4 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
         const std::vector<GenerationConfig>& sampling_params,
         const StreamerVariant& streamer) override;
 };
-}
\ No newline at end of file
+}
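
The `emplace_back` in the first hunk above constructs `ov::Tensor` regions of interest: zero-copy views over a prefix of each preallocated store, sized to the blocks actually rotated this step. A sketch of that mechanism for the 1-D case (the helper name is illustrative):

```cpp
#include <openvino/runtime/tensor.hpp>

// Returns a zero-copy view over the first n_elements of a 1-D tensor; the ROI
// constructor takes begin/end coordinates with one entry per dimension.
ov::Tensor view_first_n(const ov::Tensor& store, size_t n_elements) {
    return ov::Tensor(store, ov::Coordinate{0}, ov::Coordinate{n_elements});
}
```
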
diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp
index 9f2e3f8562..a4225e2b98 100644
--- a/src/cpp/src/model_runner.hpp
+++ b/src/cpp/src/model_runner.hpp
@@ -34,6 +34,8 @@ class ModelRunner {
     size_t m_num_decoder_layers;
     bool m_collect_attention_scores;
     std::vector<ov::Tensor> m_cache_rotation_coefficients;
+    std::vector<std::map<size_t, std::vector<size_t>>> m_rotated_block_logical_indices_per_sequence_for_each_layer;
+
 public:
     /**
      * Constructs the ModelRunner.
@@ -67,9 +69,10 @@ class ModelRunner {
         return m_last_attention_scores;
     }
 
-    void set_cache_rotation_coefficients(const std::vector<ov::Tensor>& cache_rotation_coefficients_for_each_layer) {
+    void set_cache_rotation_data(const std::vector<ov::Tensor>& cache_rotation_coefficients_for_each_layer, const std::vector<std::map<size_t, std::vector<size_t>>>& rotated_logical_block_indices_per_sequence_for_each_layer) {
         // TODO (vshampor): avoid vector copy
         m_cache_rotation_coefficients = cache_rotation_coefficients_for_each_layer;
+        m_rotated_block_logical_indices_per_sequence_for_each_layer = rotated_logical_block_indices_per_sequence_for_each_layer;
     }
 
     /**
@@ -209,10 +212,12 @@ class ModelRunner {
     void _fill_indices_from_block_tables(const std::vector<std::string>& dst_tensor_names,
                                          const std::vector<SequenceGroup::Ptr>& sequence_groups,
                                          const Scheduler::Output& scheduler_output,
-                                         const std::vector<size_t>& fill_n_last_vec) {
-        OPENVINO_ASSERT(fill_n_last_vec.size() == dst_tensor_names.size() || fill_n_last_vec.empty());
+                                         const std::vector<std::map<size_t, std::vector<size_t>>>& seq_id_to_select_logical_idx_maps) {
+        OPENVINO_ASSERT(seq_id_to_select_logical_idx_maps.size() == dst_tensor_names.size() || seq_id_to_select_logical_idx_maps.empty());
+        bool is_fill_all = seq_id_to_select_logical_idx_maps.empty();
         size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size();
-        size_t block_offset = 0;
+        std::vector<size_t> block_offsets_per_layer(dst_tensor_names.size(), 0);
+
         for (size_t i = 0; i < num_sequence_groups; ++i) {
             size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
             SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
@@ -226,21 +231,31 @@ class ModelRunner {
                 const auto& kv_blocks = scheduler_output.m_block_tables.at(sequence->get_id());
 
                 for (size_t layer_idx = 0; layer_idx < dst_tensor_names.size(); layer_idx++) {
-                    size_t fill_n_last = num_blocks;
-                    if (!fill_n_last_vec.empty()) {
-                        fill_n_last = fill_n_last_vec[layer_idx];
-                    }
-                    OPENVINO_ASSERT(num_blocks >= fill_n_last);
-                    size_t starting_offset = num_blocks - fill_n_last;
                     auto input_tensor = m_request.get_tensor(dst_tensor_names[layer_idx]);
-                    auto block_indices_data = input_tensor.data<int32_t>() + block_offset;
-                    for (size_t block_id = 0; block_id < fill_n_last; ++block_id)
-                        // In case no cache eviction is requested, all per-layer block tables are expected to be identical
-                        // at all times
-                        block_indices_data[block_id] = kv_blocks[layer_idx][starting_offset + block_id]->get_index();
+                    auto block_indices_data = input_tensor.data<int32_t>() + block_offsets_per_layer[layer_idx];
+
+                    if (is_fill_all) {
+                        for (size_t block_id = 0; block_id < num_blocks; ++block_id) {
+                            // In case no cache eviction is requested, all per-layer block tables are expected to be identical
+                            // at all times
+                            block_indices_data[block_id] = kv_blocks[layer_idx][block_id]->get_index();
+                        }
+                        block_offsets_per_layer[layer_idx] += num_blocks;
+                    } else {
+                        auto seq_id_to_select_logical_idx_map = seq_id_to_select_logical_idx_maps[layer_idx];
+                        if (seq_id_to_select_logical_idx_map.find(seq_id) == seq_id_to_select_logical_idx_map.end()) {
+                            continue;
+                        }
+                        auto select_logical_idxs = seq_id_to_select_logical_idx_map[seq_id];
+                        for (size_t block_id = 0; block_id < select_logical_idxs.size(); ++block_id) {
+                            size_t logical_block_idx = select_logical_idxs[block_id];
+                            OPENVINO_ASSERT(logical_block_idx < num_blocks);
+                            block_indices_data[block_id] = kv_blocks[layer_idx][logical_block_idx]->get_index();
+                        }
+
+                        block_offsets_per_layer[layer_idx] += select_logical_idxs.size();
+                    }
                 }
-
-                block_offset += num_blocks;
             }
         }
     }
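
Stripped of the scheduler types, the select branch added above is a per-layer gather: each rotated logical block index is mapped through the layer's block table to a physical block index. A sketch under assumed stand-in types (`KVBlock` and the helper are illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct KVBlock { int physical_index; };  // stand-in for the real block-table entry type

// Writes the physical index of each selected logical block into dst,
// mirroring the else-branch of _fill_indices_from_block_tables above.
void gather_physical_indices(const std::vector<KVBlock*>& layer_block_table,
                             const std::vector<size_t>& logical_idxs,
                             int32_t* dst) {
    for (size_t i = 0; i < logical_idxs.size(); i++) {
        dst[i] = static_cast<int32_t>(layer_block_table[logical_idxs[i]]->physical_index);
    }
}
```
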
@@ -264,21 +279,24 @@ class ModelRunner {
     void _set_cache_rotation_coefficients(const std::vector<SequenceGroup::Ptr>& sequence_groups,
                                           const Scheduler::Output& scheduler_output) {
         for (size_t i = 0; i < m_num_decoder_layers; i++) {
-            auto tensor_name = std::string("cache_rotation_coefficients.") + std::to_string(i);
+            auto tensor_name = std::string("rotation_coefficients.") + std::to_string(i);
             m_request.set_tensor(tensor_name, m_cache_rotation_coefficients[i]);
         }
 
         std::vector<std::string> rotation_indices_tensor_names(m_num_decoder_layers);
-        std::vector<size_t> rotation_indices_sizes_in_blocks(m_num_decoder_layers);
         for (size_t i = 0; i < m_num_decoder_layers; i++) {
             auto tensor_name = std::string("rotated_block_indices.") + std::to_string(i);
             rotation_indices_tensor_names[i] = tensor_name;
-            size_t size_in_blocks = m_cache_rotation_coefficients[i].get_size() / m_scheduler_config.block_size;
-            m_request.get_tensor(tensor_name).set_shape({size_in_blocks});
-            rotation_indices_sizes_in_blocks[i] = size_in_blocks;
+            size_t num_indices = 0;
+            for (const auto& entry : m_rotated_block_logical_indices_per_sequence_for_each_layer[i]) {
+                num_indices += entry.second.size();
+            }
+            m_request.get_tensor(tensor_name).set_shape({num_indices});
         }
-        _fill_indices_from_block_tables(rotation_indices_tensor_names, sequence_groups, scheduler_output, rotation_indices_sizes_in_blocks);
+        // NB: the order of per-sequence index filling in the function below must be the same
+        // as the order of `seq_id`s in which the "rotation_coefficients.N" inputs are filled
+        _fill_indices_from_block_tables(rotation_indices_tensor_names, sequence_groups, scheduler_output, m_rotated_block_logical_indices_per_sequence_for_each_layer);
     }
 
     void _collect_attention_scores(const std::vector<SequenceGroup::Ptr>& sequence_groups, const Scheduler::Output& scheduler_output) {
diff --git a/src/cpp/src/utils/paged_attention_transformations.cpp b/src/cpp/src/utils/paged_attention_transformations.cpp
index 53690f770c..43b0a89b17 100644
--- a/src/cpp/src/utils/paged_attention_transformations.cpp
+++ b/src/cpp/src/utils/paged_attention_transformations.cpp
@@ -32,7 +32,8 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, boo
     bool use_block_indices_inputs = per_layer_cache_control;
     bool use_score_outputs = per_layer_cache_control;
-    ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs).run_on_model(model);
+    bool allow_cache_rotation = per_layer_cache_control;
+    ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, allow_cache_rotation).run_on_model(model);
 }
 
 void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig& device_config) {
@@ -80,4 +81,4 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, Dev
 } // namespace utils
 } // namespace genai
-} // namespace ov
\ No newline at end of file
+} // namespace ov
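
The reshaped `rotated_block_indices.N` inputs above get one entry per rotated block, summed across all sequences for that layer. Isolated for clarity (an illustrative helper over the same map type the patch uses):

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Total number of rotated blocks for one layer: the sum of the per-sequence
// rotated-block index lists, i.e. the shape set for rotated_block_indices.N.
size_t count_rotated_blocks(const std::map<size_t, std::vector<size_t>>& per_seq_logical_idxs) {
    size_t num_indices = 0;
    for (const auto& entry : per_seq_logical_idxs) {
        num_indices += entry.second.size();
    }
    return num_indices;
}
```
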
diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py
index 49cb04ca1f..f8104c0028 100644
--- a/tests/python_tests/test_cache_optimizations.py
+++ b/tests/python_tests/test_cache_optimizations.py
@@ -22,7 +22,7 @@ def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]:
     file_path = TESTS_ROOT / 'data' / file_name
     with open(file_path, 'r') as f:
-        return {"prompts": [s for s in f]}
+        return {"questions": [s for s in f]}
 
 def get_scheduler_config(num_kv_blocks: int) -> SchedulerConfig:
     scheduler_config = SchedulerConfig()
@@ -118,7 +118,7 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t
     data_dict = load_prompts_dataset(test_struct.prompt_file)
 
-    evaluator = whowhatbench.TextEvaluator(base_model=model_cb_noopt, tokenizer=tokenizer, test_data=data_dict,
+    evaluator = whowhatbench.Evaluator(base_model=model_cb_noopt, tokenizer=tokenizer, test_data=data_dict,
                                        generation_config=generation_config, generation_config_base=generation_config,
                                        max_new_tokens=test_struct.max_new_tokens, seqs_per_request=seqs_per_request)