[Hardware][Gaudi] Support fused MOE
Signed-off-by: zhenwei <[email protected]>
zhenwei-intel committed Jan 23, 2025
1 parent 68c4421 commit 25c3c49
Showing 6 changed files with 164 additions and 63 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -73,7 +73,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
-
+    context_lens_tensor: Optional[torch.Tensor]

 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
     """
2 changes: 1 addition & 1 deletion vllm/attention/ops/hpu_paged_attn.py
@@ -21,7 +21,7 @@ class HPUPagedAttentionMetadata:
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
-
+    block_groups: Optional[torch.Tensor]

 class HPUPagedAttention:

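Both attention-side changes simply add another optional tensor to the HPU metadata containers (context_lens_tensor on HPUAttentionMetadata, block_groups on HPUPagedAttentionMetadata). Below is a minimal standalone sketch of that pattern, using a toy dataclass rather than the real vLLM classes; the field semantics in the comments are a reading of the names, not taken from the diff.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyHPUMetadata:
    # Simplified stand-in for the HPU attention metadata: every tensor is
    # Optional because prefill and decode batches populate different fields.
    block_indices: Optional[torch.Tensor] = None
    block_groups: Optional[torch.Tensor] = None          # added in this commit
    context_lens_tensor: Optional[torch.Tensor] = None   # added in this commit


# Hypothetical decode-time values: three blocks shared by two sequences.
meta = ToyHPUMetadata(
    block_indices=torch.tensor([0, 1, 2]),
    block_groups=torch.tensor([0, 0, 1]),        # which sequence owns each block
    context_lens_tensor=torch.tensor([17, 5]),   # context length per sequence
)
print(meta.block_groups)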
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -182,6 +182,26 @@ def forward_cpu(
             num_expert_group,
         )

+    def forward_hpu(self,
+                    layer: torch.nn.Module,
+                    x: torch.Tensor,
+                    use_grouped_topk: bool,
+                    top_k: int,
+                    router_logits: torch.Tensor,
+                    renormalize: bool,
+                    topk_group: Optional[int] = None,
+                    num_expert_group: Optional[int] = None,
+                    **kwargs,
+                    ):
+        assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
+        assert num_expert_group is None, ('num_expert_group is '
+                                          'not supported on HPU')
+        assert topk_group is None, 'topk_group is not supported on HPU'
+        if layer is not None:
+            return layer.hpu_fused_moe(x, layer.w13_weight,
+                                       layer.w2_weight, router_logits,
+                                       top_k)
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,
@@ -281,6 +301,9 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+        if current_platform.is_hpu():
+            from vllm_hpu_extension.ops import DynamicFusedMOE
+            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
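In the HPU path, forward_hpu rejects the grouped-top-k options it does not support and hands the whole computation to the DynamicFusedMOE object created in __init__. The sketch below is a rough, unofficial reference of what such a fused-MoE call computes, with a stub in place of the real vllm_hpu_extension op (which needs Gaudi hardware); the routing details, such as unconditional renormalization, are assumptions made for illustration.

import torch
import torch.nn.functional as F


class StubDynamicFusedMOE:
    # Mimics only the call signature forward_hpu uses:
    # hpu_fused_moe(x, w13_weight, w2_weight, router_logits, top_k).
    def __init__(self, num_experts: int):
        self.num_experts = num_experts

    def __call__(self, x, w13_weight, w2_weight, router_logits, top_k):
        weights = F.softmax(router_logits, dim=-1)
        topk_w, topk_ids = weights.topk(top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize
        out = torch.zeros_like(x)
        for e in range(self.num_experts):
            mask = topk_ids == e                 # (tokens, top_k) routing mask
            if not mask.any():
                continue
            # w13 fuses the gate and up projections; w2 is the down projection.
            gate_up = x @ w13_weight[e].t()
            gate, up = gate_up.chunk(2, dim=-1)
            expert_out = (F.silu(gate) * up) @ w2_weight[e].t()
            out += expert_out * (topk_w * mask).sum(dim=-1, keepdim=True)
        return out


# Hypothetical shapes: 4 tokens, hidden 8, 4 experts, intermediate 16, top-2.
x = torch.randn(4, 8)
w13 = torch.randn(4, 32, 8)    # (experts, 2 * intermediate, hidden)
w2 = torch.randn(4, 8, 16)     # (experts, hidden, intermediate)
logits = torch.randn(4, 4)
moe = StubDynamicFusedMOE(num_experts=4)
print(moe(x, w13, w2, logits, top_k=2).shape)   # torch.Size([4, 8])

For simplicity the stub evaluates every expert on all tokens and zeroes out unrouted contributions through the mask; that is numerically equivalent to per-token dispatch but not how the fused Gaudi kernel is organized.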
4 changes: 3 additions & 1 deletion vllm/model_executor/models/mixtral.py
@@ -45,7 +45,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-
+from vllm.platforms import current_platform
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -495,4 +495,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
+        if current_platform.is_hpu():
+            torch.hpu.synchronize()
         return loaded_params
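The mixtral change is a single guard at the end of load_weights: on HPU it waits on the device, presumably so that any still-queued weight copies finish before loading is reported complete. Below is a minimal sketch of the same pattern outside vLLM, written defensively because torch.hpu only exists once the optional habana_frameworks runtime has been imported.

import torch


def wait_for_hpu() -> None:
    # No-op on CUDA/CPU builds; on Gaudi this drains queued device work,
    # mirroring the torch.hpu.synchronize() added to Mixtral's load_weights.
    hpu = getattr(torch, "hpu", None)
    if hpu is not None and hpu.is_available():
        hpu.synchronize()


wait_for_hpu()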
3 changes: 3 additions & 0 deletions vllm/utils.py
@@ -351,6 +351,9 @@ def reset(self):
         """
         self._index = 0

+@lru_cache(maxsize=None)
+def is_fake_hpu() -> bool:
+    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'

 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
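The new is_fake_hpu() helper in vllm/utils.py just memoizes an environment check: any value of VLLM_USE_FAKE_HPU other than '0' turns the flag on. Below is a standalone copy showing the caching behaviour; cache_clear() is only needed here because the example flips the variable at runtime.

import os
from functools import lru_cache


@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
    # Same logic as the helper added above.
    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'


print(is_fake_hpu())                   # False unless the variable was pre-set
os.environ['VLLM_USE_FAKE_HPU'] = '1'
print(is_fake_hpu())                   # still the cached first answer
is_fake_hpu.cache_clear()
print(is_fake_hpu())                   # True: re-reads the environment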
(diff for the sixth changed file did not load)
