
Commit

cleanup and format
libinta committed Jan 31, 2025
1 parent 3178d90 commit c0da1e0
Showing 1 changed file with 93 additions and 66 deletions.
159 changes: 93 additions & 66 deletions vllm/worker/hpu_model_runner.py
@@ -44,8 +44,8 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
@@ -76,6 +76,7 @@

LORA_WARMUP_RANK = 8


def subtuple(obj: object,
typename: str,
to_copy: List[str],
@@ -263,7 +264,8 @@ def _compile_region(self, model, name, module):

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
if (attn_metadata is None or (self.prefill_use_fusedsdpa and self.is_causal)
if (attn_metadata is None
or (self.prefill_use_fusedsdpa and self.is_causal)
or not attn_metadata.is_prompt):
return attn_metadata

@@ -290,19 +292,22 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
query_lens_t.unsqueeze(-1)).view(
batch_size, 1, 1, seq_len))
if self.is_causal:
attn_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
device=device,
dtype=torch.bool),diagonal=1)
attn_mask = torch.triu(torch.ones(
(batch_size, 1, seq_len, seq_len),
device=device,
dtype=torch.bool),
diagonal=1)
else:
attn_mask = torch.zeros((batch_size, 1, seq_len, seq_len),
device=device,
dtype=torch.bool)
device=device,
dtype=torch.bool)
if hasattr(self.model.model, "_pooler"):
len_mask_v = len_mask.view(batch_size, 1, seq_len, 1)
mask = attn_mask.logical_or(len_mask).logical_or(len_mask_v)
off_value = -3E38 #small number, avoid nan and overflow
off_value = -3E38 #small number, avoid nan and overflow
else:
mask = attn_mask.logical_or(len_mask)#no need for len_mask_v as decode overwrites it
mask = attn_mask.logical_or(
len_mask) #no need for len_mask_v as decode overwrites it
off_value = -math.inf

mask = torch.concat((past_mask, mask), dim=-1)
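
Note: a minimal, self-contained sketch of the prompt-mask construction that this hunk only reflows — a causal upper-triangular mask OR-ed with a per-sequence length mask. Shapes and the helper name are illustrative, not the runner's actual API; in the diff the result is further concatenated with a past-context mask and converted to -inf (or -3E38 for pooling models) before use.

import torch

def build_prompt_mask(seq_lens, seq_len, is_causal, device="cpu"):
    # Illustrative only: True marks positions that must NOT be attended to.
    batch_size = len(seq_lens)
    query_lens_t = torch.tensor(seq_lens, device=device, dtype=torch.int32)
    # Positions past each sequence's real length (padding).
    len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32)
                .view(1, seq_len)
                .ge(query_lens_t.unsqueeze(-1))
                .view(batch_size, 1, 1, seq_len))
    if is_causal:
        # Upper triangle above the diagonal: future tokens are masked.
        attn_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                          device=device, dtype=torch.bool),
                               diagonal=1)
    else:
        # Encoder-style (bidirectional) attention: only padding is masked.
        attn_mask = torch.zeros((batch_size, 1, seq_len, seq_len),
                                device=device, dtype=torch.bool)
    return attn_mask.logical_or(len_mask)

print(build_prompt_mask([3, 5], seq_len=5, is_causal=True).shape)
# torch.Size([2, 1, 5, 5])
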
@@ -418,7 +423,7 @@ def forward(self, *args, **kwargs):
kwargs['attn_metadata'] = self._update_metadata(
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
input_ids.device, self.dtype)
if 'lora_mask' in kwargs.keys():
if 'lora_mask' in kwargs:
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
if self.layer_names is not None:
self._prepare_cos_sin(kwargs['positions'])
@@ -428,8 +433,8 @@
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if selected_token_indices is not None:
hidden_states = hidden_states.index_select(0,
selected_token_indices)
hidden_states = hidden_states.index_select(
0, selected_token_indices)
return hidden_states

def compute_logits(self, *args, **kwargs):
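
Note: the reflowed index_select call keeps only the hidden states that sampling actually needs. A tiny sketch with made-up shapes and indices:

import torch

hidden_states = torch.randn(2, 6, 16)  # (batch, padded_seq_len, hidden)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])  # (12, 16)

# Hypothetical flat indices of the tokens whose logits are actually needed.
selected_token_indices = torch.tensor([4, 9])

print(hidden_states.index_select(0, selected_token_indices).shape)
# torch.Size([2, 16])
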
@@ -716,14 +721,11 @@ def _set_gc_threshold(self) -> None:

self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
'false').lower() == 'true'

def is_causal(self) -> bool:
non_causal = ["Bert", "Roberta", "Bart"]
model_name = get_architecture_class_name(self.model_config)
if any([m in model_name for m in non_causal]):
return False
else:
return True
return any([m in model_name for m in non_causal])

def load_model(self) -> None:
model_arch_causal = self.is_causal()
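
Note: the removed branch returns False exactly when the architecture name contains one of the encoder-style substrings, so a behavior-preserving one-liner negates the membership test. A standalone sketch (the constant and helper names here are invented for illustration):

NON_CAUSAL_ARCHS = ["Bert", "Roberta", "Bart"]

def is_causal(model_arch_name: str) -> bool:
    # True for decoder-style (causal) models, False for encoder-style ones.
    return not any(m in model_arch_name for m in NON_CAUSAL_ARCHS)

print(is_causal("BertForSequenceClassification"))  # False
print(is_causal("LlamaForCausalLM"))               # True
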
@@ -870,6 +872,7 @@ def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
phase = 'prompt' if is_prompt else 'decode'
logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
phase, batch_size, seq_len)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -1019,8 +1022,8 @@ def _prepare_prompt(
lora_index_mapping += [lora_id] * max_prompt_len
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len
if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs else 1))
(max_prompt_len if seq_group_metadata.sampling_params and
seq_group_metadata.sampling_params.prompt_logprobs else 1))

if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
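
Note: the reflowed call extends the LoRA prompt mapping with one entry per prompt token only when prompt logprobs are requested, otherwise a single entry per sequence. A minimal sketch with a hypothetical helper:

def extend_lora_prompt_mapping(mapping, lora_id, max_prompt_len,
                               wants_prompt_logprobs):
    # One entry per prompt token when prompt logprobs are requested,
    # otherwise a single entry for the sequence's sampled token.
    repeat = max_prompt_len if wants_prompt_logprobs else 1
    mapping.extend([lora_id] * repeat)
    return mapping

print(extend_lora_prompt_mapping([], 3, 4, False))  # [3]
print(extend_lora_prompt_mapping([], 3, 4, True))   # [3, 3, 3, 3]
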
@@ -1455,10 +1458,9 @@ def prepare_input_tensors(
) = self._prepare_decode(decode_reqs)

if not hasattr(self.model.model, "_pooler"):
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
seq_lens, query_lens,
self.device,
self.pin_memory)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)

if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
@@ -1492,12 +1494,14 @@

if not hasattr(self.model.model, "_pooler"):
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs is not None \
and seq_group_metadata.is_prompt:
if seq_group_metadata.sampling_params \
and seq_group_metadata.sampling_params.prompt_logprobs \
is not None and seq_group_metadata.is_prompt:
paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])

paddings = torch.tensor(
paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
paddings_prompt_logprobs
if paddings_prompt_logprobs else paddings,
dtype=sampling_metadata.selected_token_indices.dtype,
device=sampling_metadata.selected_token_indices.device)
sampling_metadata.selected_token_indices.add_(paddings)
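
Note: the reflowed tensor construction shifts selected_token_indices by per-sequence padding offsets in place. A small sketch with made-up values:

import torch

# Made-up values: number of padding tokens inserted before each sequence,
# and token indices computed on the unpadded layout.
paddings = [0, 2]
selected_token_indices = torch.tensor([3, 9])

offsets = torch.tensor(paddings, dtype=selected_token_indices.dtype)
selected_token_indices.add_(offsets)  # shift indices into the padded layout
print(selected_token_indices)         # tensor([ 3, 11])
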
@@ -1522,19 +1526,33 @@ prepare_input_tensors(
batch_type = BatchType.DECODE

metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices": sampling_metadata.selected_token_indices if sampling_metadata else None,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
"seq_lens": seq_lens,
"query_lens": query_lens
"input_tokens":
input_tokens,
"input_positions":
input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices
if sampling_metadata else None,
"lora_requests":
lora_requests,
"lora_mapping":
lora_mapping,
"multi_modal_kwargs":
multi_modal_kwargs,
"num_prefill_tokens":
num_prefill_tokens,
"num_decode_tokens":
num_decode_tokens,
"slot_mapping":
slot_mapping,
"num_prefills":
num_prefills,
"batch_type":
batch_type,
"seq_lens":
seq_lens,
"query_lens":
query_lens
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
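
Note: the reformatted dict is the metadata broadcast to the other ranks; the only behavioral subtlety is guarding selected_token_indices for pooling models, which have no sampling metadata. A toy sketch with placeholder values:

# Placeholder values; pooling models leave sampling_metadata as None.
sampling_metadata = None
input_tokens = [[1, 2, 3]]

metadata_dict = {
    "input_tokens": input_tokens,
    "selected_token_indices": (sampling_metadata.selected_token_indices
                               if sampling_metadata else None),
}
print(metadata_dict["selected_token_indices"])  # None
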
@@ -1910,10 +1928,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
kv_caches)
if not hasattr(self.model.model, "_pooler"):
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
kv_caches)
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets,
False, kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy() and not hasattr(self.model.model, "_pooler"):
if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
@@ -1925,8 +1943,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
prompt_available_memory = (prompt_graph_mem_ratio *
graph_free_mem)
decode_available_memory = (graph_free_mem -
prompt_available_memory)
if hasattr(self.model.model, "_pooler"):
decode_available_memory = 0
else:
decode_available_memory = (graph_free_mem -
prompt_available_memory)
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
@@ -1949,42 +1970,49 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches, decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more prompt buckets.
# Not all prompt buckets were captured, but all decode
# buckets were captured and we have some free
# graph-allocated space left. Let's try to use it for
# capturing more prompt buckets.
if (mem_post_decode + mem_post_prompt < graph_free_mem
and not prompt_captured_all and decode_captured_all):
and not prompt_captured_all
and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))
# Not all decode buckets were captured, but all prompt buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more decode buckets.
prompt_strategy,
self.bucketing_ctx.prompt_buckets, True,
kv_caches, graph_free_mem - mem_post_prompt -
mem_post_decode, mem_post_prompt,
prompt_batch_seq))
# Not all decode buckets were captured, but all prompt
# buckets were captured and we have some free
# graph-allocated space left. Let's try to use it for
# capturing more decode buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)
decode_strategy,
self.bucketing_ctx.decode_buckets, False,
kv_caches, graph_free_mem - mem_post_prompt -
mem_post_decode, mem_post_decode,
decode_batch_seq)
else:
if mem_post_prompt < graph_free_mem and not prompt_captured_all:
mem_post_prompt, _, prompt_captured_all = (
if mem_post_prompt < graph_free_mem and \
not prompt_captured_all:
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt,
prompt_strategy,
self.bucketing_ctx.prompt_buckets, True,
kv_caches, graph_free_mem - mem_post_prompt,
mem_post_prompt, prompt_batch_seq))

self.log_graph_warmup_summary(
self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
if not hasattr(self.model.model, "_pooler"):
self.log_graph_warmup_summary(
self.bucketing_ctx.decode_buckets, False, mem_post_decode)
self.bucketing_ctx.decode_buckets, False,
mem_post_decode)

end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
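
Note: a simplified, self-contained sketch of the leftover-budget pass these reflowed blocks implement — whichever phase still has uncaptured buckets gets the graph memory the other phase left unused. Function and parameter names are invented for illustration:

def spend_leftover_graph_mem(graph_free_mem, used_prompt, used_decode,
                             prompt_captured_all, decode_captured_all,
                             warmup_more):
    # Whichever phase still has uncaptured buckets gets the budget the
    # other phase did not use.
    leftover = graph_free_mem - used_prompt - used_decode
    if leftover <= 0:
        return used_prompt, used_decode
    if not prompt_captured_all and decode_captured_all:
        used_prompt += warmup_more("prompt", leftover)
    elif not decode_captured_all and prompt_captured_all:
        used_decode += warmup_more("decode", leftover)
    return used_prompt, used_decode

# Toy run: pretend extra warmup consumes half of whatever budget it is given.
print(spend_leftover_graph_mem(100, 30, 40, False, True,
                               lambda phase, mem: mem // 2))  # (45, 40)
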
@@ -2160,7 +2188,6 @@ def prepare_model_input(
is_prompt=is_prompt,
virtual_engine=virtual_engine)


def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
is_prompt: bool):
'''
