diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 6992238b5dd82..07f9c31117e49 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,7 +8,7 @@
 pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@a69bb99
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@61334c5
 neural-compressor @ git+https://github.com/intel/neural-compressor.git@b196432
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index a45036ba764fe..394c427fed330 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -115,7 +115,7 @@ def _generic_padding_fn(self, batch_size, max_seq_len) -> int:
         return batch_size * max_seq_len
 
     def _hpu_padding_fn(self, batch_size, max_seq_len):
-        from vllm.worker.hpu_model_runner import (HPUBucketingGlobalState,
+        from vllm_hpu_extension.bucketing import (HPUBucketingGlobalState,
                                                   find_bucket)
         padded_bs = batch_size
         padded_seq = max_seq_len
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b064da65db67d..7aa68d1e98abf 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -9,11 +9,9 @@
 import gc
 import itertools
 import math
-import operator
 import os
 import time
 from array import array
-from dataclasses import dataclass, field
 from enum import IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
@@ -21,6 +19,7 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
+from vllm_hpu_extension.bucketing import HPUBucketingContext
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -68,262 +67,6 @@
 LORA_WARMUP_RANK = 8
 
 
-class Singleton(type):
-    _instances: Dict[type, object] = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super().__call__(*args, **kwargs)
-        return cls._instances[cls]
-
-
-@dataclass
-class HPUBucketingGlobalState(metaclass=Singleton):
-    prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
-    decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
-    prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False)
-    decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False)
-    prompt_buckets: List[Tuple[int, int]] = field(init=False)
-    decode_buckets: List[Tuple[int, int]] = field(init=False)
-
-
-class HPUBucketingContext(metaclass=Singleton):
-    global_state = HPUBucketingGlobalState()
-
-    def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size,
-                 max_num_batched_tokens):
-        self.max_num_seqs = max_num_seqs
-        self.max_num_prefill_seqs = max_num_prefill_seqs
-        self.block_size = block_size
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self._setup_buckets()
-
-    def _setup_buckets(self) -> None:
-        align_bs = lambda x: min(self.max_num_seqs, x)
-        #FIXME: The default values should be max_model_len
-        max_prompt_seq = 1024
-        max_decode_seq = 2048
-        self.global_state.prompt_bs_bucket_cfg = read_bucket_settings(
-            'prompt',
-            'bs',
-            min=1,
-            step=align_bs(32),
-            max=self.max_num_prefill_seqs)
-        self.global_state.decode_bs_bucket_cfg = read_bucket_settings(
-            'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
-        self.global_state.prompt_seq_bucket_cfg = \
-            read_bucket_settings(
-                'prompt',
-                'seq',
-                min=self.block_size,
-                step=self.block_size,
-                max=max_prompt_seq)
-        self.global_state.decode_block_bucket_cfg = \
-            read_bucket_settings(
-                'decode',
-                'block',
-                min=self.block_size,
-                step=self.block_size,
-                max=max(self.block_size,
-                        self.max_num_seqs * max_decode_seq // self.block_size))
-
-        msg = ("Prompt bucket config (min, step, max_warmup) "
-               f"bs:{self.global_state.prompt_bs_bucket_cfg}, "
-               f"seq:{self.global_state.prompt_seq_bucket_cfg}")
-        logger.info(msg)
-
-        msg = ("Decode bucket config (min, step, max_warmup) "
-               f"bs:{self.global_state.decode_bs_bucket_cfg}, "
-               f"block:{self.global_state.decode_block_bucket_cfg}")
-        logger.info(msg)
-
-    def generate_prompt_buckets(self):
-        self.global_state.prompt_buckets, prompt_omitted_buckets = \
-            generate_prompt_buckets(
-                self.global_state.prompt_bs_bucket_cfg,
-                self.global_state.prompt_seq_bucket_cfg,
-                self.max_num_batched_tokens)
-
-        msg = (f"Generated {len(self.global_state.prompt_buckets)} "
-               f"prompt buckets [bs, seq]: \
-                {list(sorted(self.global_state.prompt_buckets))}")
-        logger.info(msg)
-
-        msg = (f"Omitted {len(prompt_omitted_buckets)} "
-               "prompt buckets due to exceeded token budget "
-               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
-        logger.info(msg)
-
-        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
-        logger.debug(msg)
-
-    def generate_decode_buckets(self, max_blocks):
-        self.global_state.decode_buckets = generate_decode_buckets(
-            self.global_state.decode_bs_bucket_cfg,
-            self.global_state.decode_block_bucket_cfg, max_blocks)
-        logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
-                    len(self.global_state.decode_buckets),
-                    list(sorted(self.global_state.decode_buckets)))
-
-    def get_padded_prompt_batch_size(self, batch_size):
-        return find_bucket(batch_size, self.global_state.prompt_bs_bucket_cfg)
-
-    def get_padded_decode_batch_size(self, batch_size):
-        return find_bucket(batch_size, self.global_state.decode_bs_bucket_cfg)
-
-    def get_padded_prompt_seq_len(self, seq_len):
-        return find_bucket(seq_len, self.global_state.prompt_seq_bucket_cfg)
-
-    def get_padded_decode_num_blocks(self, num_blocks):
-        return find_bucket(num_blocks,
-                           self.global_state.decode_block_bucket_cfg)
-
-    def get_padded_batch_size(self, batch_size, is_prompt):
-        if is_prompt:
-            return self.get_padded_prompt_batch_size(batch_size)
-        return self.get_padded_decode_batch_size(batch_size)
-
-    def get_padded_seq_or_block(self, seq_or_block, is_prompt):
-        if is_prompt:
-            return self.get_padded_prompt_seq_len(seq_or_block)
-        return self.get_padded_decode_num_blocks(seq_or_block)
-
-    @property
-    def prompt_buckets(self):
-        return self.global_state.prompt_buckets
-
-    @property
-    def decode_buckets(self):
-        return self.global_state.decode_buckets
-
-
-def read_bucket_settings(phase: str, dim: str, **defaults):
-    """Read bucketing configuration from env variables.
-
-    phase is either 'prompt' or 'decode'
-    dim is either 'bs', 'seq' or 'block'
-    param is either 'min', 'step' or 'max'
-    example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
-    """
-    params = ['min', 'step', 'max']
-    env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
-    default_values = [defaults[p] for p in params]
-    values = [
-        int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
-    ]
-    for e, v, d in zip(env_vars, values, default_values):
-        logger.info('%s=%s (default:%s)', e, v, d)
-    return values
-
-
-def warmup_range(config: Tuple[int, int, int]):
-    """Generate a warmup range.
-
-    Start from bmin and multiply by 2 until you reach bstep.
-    Then, increase the values in the range by the value of bstep until you
-    reach bmax.
-
-    Example:
-    bmin = 2, bstep = 32, bmax = 64
-    => ramp_up = (2, 4, 8, 16)
-    => stable = (32, 64)
-    => return ramp_up + stable => (2, 4, 8, 16, 32, 64)
-    """
-    bmin, bstep, bmax = config
-    assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
-                          "batch size. If you want to skip warmup, "
-                          "set VLLM_SKIP_WARMUP=true")
-    base = itertools.repeat(2)
-    ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
-    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
-        ramp_up_acc)
-    stable = range(bstep, bmax + 1, bstep)
-    buckets = list(ramp_up_tw) + list(stable)
-    return list(filter(lambda bucket: bucket >= bmin, buckets))
-
-
-def generate_prompt_buckets(bs_bucket_config,
-                            seq_bucket_config,
-                            max_num_batched_tokens=None):
-    buckets = list(
-        itertools.product(warmup_range(bs_bucket_config),
-                          warmup_range(seq_bucket_config)))
-    if len(buckets) == 0:
-        msg = ("No buckets could be captured with following config "
-               f"(min, step, max_warmup): "
-               f"bs:{bs_bucket_config}, "
-               f"seq:{seq_bucket_config}")
-        raise ValueError(msg)
-
-    filtered_buckets = buckets
-    if max_num_batched_tokens is not None:
-        # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-                buckets))
-
-        if len(filtered_buckets) == 0:
-            # we can handle this if we ignore max_num_batched_tokens
-            min_bucket_bs, min_bucket_seq = min(buckets,
-                                                key=lambda b: (b[0] * b[1]))
-            min_reqd_budget = min_bucket_bs * min_bucket_seq
-            msg = (
-                "The current bucketing configuration "
-                f"(min, step, max_warmup): "
-                f"bs:{bs_bucket_config}, "
-                f"seq:{seq_bucket_config} cannot be used with specified "
-                f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
-                f"smallest bucket ({min_reqd_budget}) would exceed token "
-                "budget. Please increase max_num_batched_tokens or decrease "
-                "bucket minimum Ignoring max_num_batched_tokens at risk of "
-                "out-of-memory errors.")
-            logger.error(msg)
-            return list(
-                sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), []
-
-    captured_buckets = list(
-        sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
-    omitted_buckets = list(
-        sorted([x for x in buckets if x not in filtered_buckets]))
-    return captured_buckets, omitted_buckets
-
-
-def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
-                            max_blocks):
-    buckets = []
-    bs_buckets = warmup_range(bs_bucket_config)
-    block_buckets = warmup_range(blocks_bucket_config)
-    bmin, bstep, bmax = blocks_bucket_config
-    last_bucket = max_blocks
-    for bs in bs_buckets:
-        for blocks in block_buckets:
-            if blocks >= last_bucket:
-                buckets.append((bs, last_bucket))
-                break
-            buckets.append((bs, blocks))
-    return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
-
-
-def next_pow2(value: int, base: int):
-    res = base
-    while value > 1:
-        value = (value + 1) // 2
-        res *= 2
-    return res
-
-
-def round_up(value: int, k: int):
-    return (value + k - 1) // k * k
-
-
-def find_bucket(value: int, config: Tuple[int, int, int]):
-    bmin, bstep, _ = config
-    next_step = round_up(value, bstep)
-    next_pow = next_pow2(value, bmin)
-    return max(bmin, min(next_step, next_pow))
-
-
 def subtuple(obj: object,
              typename: str,
              to_copy: List[str],
@@ -355,17 +98,13 @@ def align_workers(value, op):
 
 def setup_profiler():
     schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
-    DEVICE = 'hpu'
-    activities = [torch.profiler.ProfilerActivity.CPU]
-    activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
-                      'hpu' else [])
-    #from habana_frameworks.torch.activity_profiler import DebugActivity
-    #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]
-
+    activities = [
+        torch.profiler.ProfilerActivity.CPU,
+        torch.profiler.ProfilerActivity.HPU
+    ]
     profiler = torch.profiler.profile(
         schedule=schedule,
         activities=activities,
-        #debug_activities=debug_activities,
         on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
                                                                 use_gzip=True),
         record_shapes=False,
@@ -373,6 +112,10 @@
     return profiler
 
 
+def round_up(value: int, k: int) -> int:
+    return (value + k - 1) // k * k
+
+
 def pad_list(input, k, v):
     input_len = len(input)
     target_len = round_up(input_len, k)
@@ -1523,11 +1266,9 @@ def create_dummy_seq_group_metadata(self,
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
-        max_batch_size = self.bucketing_ctx.global_state.prompt_bs_bucket_cfg[
-            -1]
-        max_seq_len = min(
-            self.bucketing_ctx.global_state.prompt_seq_bucket_cfg[-1],
-            self.max_num_batched_tokens // max_batch_size)
+        max_batch_size, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
+        max_seq_len = min(max_seq_len,
+                          self.max_num_batched_tokens // max_batch_size)
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
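
Below is a minimal usage sketch of the relocated bucketing context through its new import path. It assumes vllm-hpu-extension keeps the constructor and padding methods of the removed in-tree HPUBucketingContext; get_max_prompt_shape() is the accessor the updated profile_run() relies on. The numeric arguments are illustrative only.

# Hypothetical usage of the relocated class; interface assumed to match the
# removed in-tree implementation deleted in the hunk above.
from vllm_hpu_extension.bucketing import HPUBucketingContext

ctx = HPUBucketingContext(max_num_seqs=256,          # illustrative values
                          max_num_prefill_seqs=16,
                          block_size=128,
                          max_num_batched_tokens=8192)
padded_bs = ctx.get_padded_batch_size(5, is_prompt=True)   # padded to a bs bucket
padded_seq = ctx.get_padded_prompt_seq_len(300)            # padded to a seq bucket
max_bs, max_seq = ctx.get_max_prompt_shape()               # used by profile_run()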
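
The padding math itself is unchanged by the move; the standalone sketch below mirrors the removed round_up/next_pow2/find_bucket/warmup_range helpers so bucket selection can be reproduced without an HPU stack (the extension's actual implementation may differ in details).

# Standalone sketch of the bucket-padding math removed above.
import itertools
import operator
from typing import List, Tuple


def round_up(value: int, k: int) -> int:
    return (value + k - 1) // k * k


def next_pow2(value: int, base: int) -> int:
    res = base
    while value > 1:
        value = (value + 1) // 2
        res *= 2
    return res


def find_bucket(value: int, config: Tuple[int, int, int]) -> int:
    bmin, bstep, _ = config
    # Small values pad to the next power-of-2 multiple of bmin, large values
    # to the next multiple of bstep; the smaller candidate wins.
    return max(bmin, min(round_up(value, bstep), next_pow2(value, bmin)))


def warmup_range(config: Tuple[int, int, int]) -> List[int]:
    bmin, bstep, bmax = config
    ramp_up = itertools.takewhile(
        lambda x: x < bstep and x <= bmax,
        itertools.accumulate(itertools.repeat(2), operator.mul, initial=bmin))
    stable = range(bstep, bmax + 1, bstep)
    return [b for b in list(ramp_up) + list(stable) if b >= bmin]


print(warmup_range((2, 32, 64)))     # [2, 4, 8, 16, 32, 64]
print(find_bucket(5, (1, 32, 64)))   # 8  -> power-of-2 padding wins
print(find_bucket(47, (1, 32, 64)))  # 64 -> round_up(47, 32) = 64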
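
Bucket boundaries remain configurable through environment variables named VLLM_{PHASE}_{DIM}_BUCKET_{PARAM}, as documented in the removed read_bucket_settings docstring (e.g. VLLM_DECODE_BS_BUCKET_STEP=128). The sketch below shows how such a (min, step, max) triple is assembled; the default values here are placeholders, not the ones vLLM ships with.

# Assemble a (min, step, max) bucket config the same way the removed
# read_bucket_settings helper did; the defaults below are placeholders.
import os

phase, dim = 'decode', 'bs'   # phase: 'prompt'|'decode'; dim: 'bs'|'seq'|'block'
defaults = {'min': 1, 'step': 32, 'max': 256}
cfg = tuple(
    int(os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(), d))
    for p, d in defaults.items())
print(cfg)  # e.g. (1, 128, 256) when VLLM_DECODE_BS_BUCKET_STEP=128 is set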