[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral #12303

Open: wants to merge 2 commits into main

1 change: 1 addition & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -73,6 +73,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
is_prompt: bool
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata:
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]


class HPUPagedAttention:
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -182,6 +182,26 @@ def forward_cpu(
num_expert_group,
)

def forward_hpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
**kwargs,
):
assert not use_grouped_topk, "use_grouped_topk must be False on HPU"
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, "topk_group is not supported on HPU"
if layer is not None:
return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k)

def forward_tpu(
self,
layer: torch.nn.Module,
@@ -281,6 +301,9 @@ def __init__(
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
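Note on the HPU path added above: `forward_hpu` dispatches to `DynamicFusedMOE` from `vllm_hpu_extension`, whose kernel is outside this diff. As a point of reference only, the sketch below shows a plain-PyTorch equivalent of what a dynamic MoE forward computes for Mixtral-style experts; the weight layouts and the softmax top-k routing follow vLLM's fused-MoE convention and are assumptions here, not the extension's actual implementation.

```python
# Reference-only sketch, not the vllm_hpu_extension kernel.
# Assumed layouts: x [T, H], w13_weight [E, 2*I, H], w2_weight [E, H, I].
import torch
import torch.nn.functional as F

def dynamic_moe_reference(x, w13_weight, w2_weight, router_logits, top_k,
                          renormalize=True):
    routing = F.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(routing, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(w13_weight.shape[0]):
        # Tokens routed to expert e, and which of their top-k slots chose it.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        gate, up = (x[token_idx] @ w13_weight[e].t()).chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ w2_weight[e].t()  # SwiGLU, then down-projection
        scale = topk_weights[token_idx, slot].unsqueeze(-1).to(h.dtype)
        out.index_add_(0, token_idx, h * scale)
    return out
```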
3 changes: 3 additions & 0 deletions vllm/model_executor/models/mixtral.py
@@ -44,6 +44,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
@@ -495,4 +496,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if current_platform.is_hpu():
torch.hpu.synchronize()
return loaded_params
5 changes: 5 additions & 0 deletions vllm/utils.py
@@ -352,6 +352,11 @@ def reset(self):
self._index = 0


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
187 changes: 129 additions & 58 deletions vllm/worker/hpu_model_runner.py
@@ -23,6 +23,7 @@
import torch
import torch.nn as nn
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)

@@ -42,7 +43,7 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
@@ -258,10 +259,19 @@ def setup_profiler():
return profiler


def pad_list(list, k, v):
target_len = round_up(len(list), k)
padding = target_len - len(list)
return list + [v] * padding
def pad_list(input, k, v):
input_len = len(input)
target_len = round_up(input_len, k)
padding = target_len - input_len
return input + [v] * padding


def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]


def flatten(in_list):
return list(itertools.chain(*in_list))
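
For reference, a minimal illustration of the three list helpers above (assuming `round_up` rounds its first argument up to the nearest multiple of `k`):

```python
pad_list([7, 8, 9], 4, 0)                 # -> [7, 8, 9, 0]
gather_list([7, 8, 9], [2, None, 0], -1)  # -> [9, -1, 7]
flatten([[1, 2], [3]])                    # -> [1, 2, 3]
```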


def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
@@ -332,23 +342,60 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
num_classes=batch_size)
if not is_fake_hpu() and htorch.utils.internal.is_lazy():
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
else:
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# doesn't handle out of bounds classes so we need to convert
# all negative values to 0 (block_mapping) or bs (block_groups)
block_groups = metadata.block_groups.to(torch.long)
block_mapping = torch.nn.functional.relu(block_groups)
block_mapping = torch.nn.functional.one_hot(block_mapping,
num_classes=batch_size)
oob_values = block_groups.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
block_groups.masked_fill_(oob_values, batch_size)
metadata = metadata._replace(block_groups=block_groups)
block_mapping = block_mapping.to(dtype)
metadata = metadata._replace(block_mapping=block_mapping,
attn_bias=attn_bias)
return metadata
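
A toy walk-through of the eager/torch.compile fallback above, with illustrative values (`batch_size = 2`, one padding slot marked `-1`): `relu` folds the pad into class 0, its one-hot row is then zeroed, and the pad's group id is rewritten to `batch_size` so it never aliases a real batch entry.

```python
import torch

block_groups = torch.tensor([0, 0, 1, -1])       # last entry is padding
mapping = torch.nn.functional.one_hot(
    torch.nn.functional.relu(block_groups), num_classes=2)
oob = block_groups.lt(0)
mapping.masked_fill_(oob.unsqueeze(-1), 0)        # padded block maps to no batch
block_groups = block_groups.masked_fill(oob, 2)   # 2 == batch_size
# mapping      -> [[1, 0], [1, 0], [0, 1], [0, 0]]
# block_groups -> [0, 0, 1, 2]
```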

def _set_block_scales(self, metadata, device):
block_mapping = metadata.block_mapping
ones = torch.ones((block_mapping.size(0), ),
device=device,
dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(block_scales=block_scales)
return metadata
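
`_set_block_scales` relies on `batch2block`/`block2batch` from `vllm_hpu_extension`. Assuming they broadcast per-group values back to blocks and reduce per-block values into their group, respectively (consistent with the one-hot `block_mapping` built above), the net effect is that each block gets a scale of 1 / (number of blocks in its group). A small sketch of that effective computation:

```python
import torch

# 4 blocks, 2 decode groups; the all-zero row is a padded block.
block_mapping = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 0.]])
ones = torch.ones(block_mapping.size(0))
sums = block_mapping @ (ones @ block_mapping)  # blocks per group, broadcast back to blocks
block_scales = torch.reciprocal(torch.maximum(ones, sums))
# block_scales -> [0.5, 0.5, 1.0, 1.0]; padding keeps scale 1.0 instead of dividing by 0
```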

def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
slot_mapping = metadata.slot_mapping.flatten()
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
if is_prompt:
indices = indices.unflatten(0, (-1, block_size))[:, 0]
offsets = None
else:
offsets = torch.fmod(slot_mapping, block_size)
metadata = metadata._replace(block_offsets=offsets,
block_indices=indices)
return metadata
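
For intuition, `_set_indices_and_offsets` just splits each flat slot id into a block id and an in-block position. A tiny decode-path example with an assumed `block_size` of 4:

```python
import torch

slot_mapping = torch.tensor([[5], [14]]).flatten()
indices = torch.div(slot_mapping, 4, rounding_mode="floor")  # -> [1, 3]
offsets = torch.fmod(slot_mapping, 4)                        # -> [1, 2]
```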

def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt:
meta = attn_metadata
attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
device, dtype)
attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
seq_len, device, dtype)
else:
meta = attn_metadata
attn_metadata = self._set_block_mapping(meta, batch_size, device,
dtype)
attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
device, dtype)
attn_metadata = self._set_block_scales(attn_metadata, device)
attn_metadata = self._set_indices_and_offsets(attn_metadata,
self.block_size,
attn_metadata.is_prompt)
return attn_metadata

def forward(self, *args, **kwargs):
@@ -584,6 +631,8 @@ def __init__(
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
self._set_gc_threshold()
self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'

def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -888,6 +937,11 @@ def _prepare_prompt(

block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.long,
device='cpu')
context_lens_tensor = context_lens_tensor.to(self.device,
non_blocking=True)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
block_list=None,
@@ -896,8 +950,10 @@
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
context_lens_tensor=context_lens_tensor,
num_prefills=real_num_seqs,
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
@@ -922,6 +978,7 @@
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
output=None,
) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
@@ -952,8 +1009,9 @@

for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
if output is None:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])

seq_len = seq_data.get_len()
position = seq_len - 1
@@ -964,6 +1022,9 @@
seq_lens.append(seq_len)

block_table = seq_group_metadata.block_tables[seq_id]
num_fully_occupied_blocks = position // self.block_size
block_table = block_table[:num_fully_occupied_blocks + 1]

if len(block_table) == 0:
block_number = _PAD_BLOCK_ID
else:
@@ -983,76 +1044,85 @@
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
if output is None:
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]

input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)

num_decode_tokens = sum(seq_lens)

blocks_used = [len(bt) for bt in block_tables if bt]
block_list = []
block_scales = []
for i, bt in enumerate(block_tables):
block_list.extend(bt)
blocks_in_group = len(bt)
if blocks_in_group > 0:
scale = 1.0 / blocks_in_group
block_scales.extend([scale] * blocks_in_group)

block_mapping_nested: List[List[int]] = [
[i] * b_u for i, b_u in enumerate(blocks_used)
last_block_usage = [
slot[0] % self.block_size + 1 for slot in slot_mapping
]
block_mapping: List[int] = list(
itertools.chain.from_iterable(block_mapping_nested))
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt]

block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)

assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)

padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = find_bucket(
max(block_list) + 1,
self.bucketing_global_state.decode_block_bucket_cfg)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)

last_block = [
sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
]
block_usage = [[self.block_size] * (b_u - 1) + [lb]
for b_u, lb in zip(blocks_used, last_block)]
block_usage = list(itertools.chain(*block_usage))

block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_scales = pad_list(block_scales, block_bucket_size, 0.0)
block_list = padding_fn(block_list, _PAD_BLOCK_ID)
block_groups = padding_fn(block_groups, -1)
block_usage = padding_fn(block_usage, 1)
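
Note on the two padding strategies applied above (illustrative values; `block_bucket_size` assumed to round up to 8): with `VLLM_CONTIGUOUS_PA` enabled, `gather_list` scatters each per-block entry to the slot given by its physical block id, so entry `i` always describes block `i`; without it, the lists are simply right-padded by `pad_list`.

```python
# seq 0 uses blocks [5, 2], seq 1 uses block [6]
block_list, block_groups = [5, 2, 6], [0, 0, 1]
indices = [None, None, 1, None, None, 0, 2, None]  # indices[bid] = position in block_list
gather_list(block_list, indices, _PAD_BLOCK_ID)    # -> [PAD, PAD, 2, PAD, PAD, 5, 6, PAD]
gather_list(block_groups, indices, -1)             # -> [-1, -1, 0, -1, -1, 0, 1, -1]
pad_list(block_list, 8, _PAD_BLOCK_ID)             # non-contiguous path -> [5, 2, 6, PAD, PAD, PAD, PAD, PAD]
```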

block_list = torch.tensor(block_list,
dtype=torch.int,
device=self.device)
block_mapping = torch.tensor(block_mapping,
dtype=torch.long,
device=self.device)
block_groups = torch.tensor(block_groups,
dtype=torch.int,
device=self.device)
block_usage = torch.tensor(block_usage,
dtype=self.model_config.dtype,
device=self.device)

slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)

block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, False)
block_scales = torch.tensor(block_scales,
dtype=self.model_config.dtype,
device=self.device)

attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
block_list=block_list,
block_mapping=block_mapping,
block_mapping=None,
block_usage=block_usage,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=block_scales,
block_scales=None,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
@@ -1260,9 +1330,10 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
'block_offsets', 'block_scales'
'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
'block_groups'
])
return attention_metadata
