
Commit cd4f1c8
Apply to torch model
yelite committed Feb 23, 2024
1 parent 0b574aa commit cd4f1c8
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions serve/mlc_serve/model/torch_model.py
@@ -168,6 +168,8 @@ def profile_and_init_cache(
     hf_config,
     num_shards,
     max_num_batched_tokens,
+    max_num_seq,
+    gpu_memory_utilization,
 ):
     num_kv_heads = hf_config.num_key_value_heads // num_shards
     num_hidden_layers = hf_config.num_hidden_layers
@@ -177,7 +179,9 @@
 
     if max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
-        seq_lens = [1] * max_num_batched_tokens
+        seq_len = max_num_batched_tokens // max_num_seq
+        seq_lens = [seq_len] * max_num_seq
+        seq_lens[-1] += max_num_batched_tokens % max_num_seq
         used_memory_bytes = profile_memory_usage(
             pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size
         )
@@ -187,7 +191,7 @@
             hf_config.num_hidden_layers,
             num_kv_heads,
             head_size,
-            gpu_memory_utilization=0.9,
+            gpu_memory_utilization,
         )
     else:
         num_blocks = 500
@@ -424,6 +428,8 @@ def exposed_init_model(
             hf_config,
             num_shards,
             engine_config.max_num_batched_tokens,
+            engine_config.max_num_seq,
+            engine_config.gpu_memory_utilization,
         )
 
         return num_blocks
@@ -594,6 +600,8 @@ def __init__(
             hf_config,
             1,
             engine_config.max_num_batched_tokens,
+            engine_config.max_num_seq,
+            engine_config.gpu_memory_utilization,
         )
         self.model_rpc = None
 
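The profiling change above mirrors what tvm_model.py already does (see the hunk further down): instead of profiling with max_num_batched_tokens sequences of length 1, the token budget is split evenly across max_num_seq sequences, with the remainder added to the last sequence. A minimal sketch of that splitting logic, using the names from the diff; the concrete values are example assumptions, not values from the commit:

max_num_batched_tokens = 4096  # example value, not from the commit
max_num_seq = 96               # example value, not from the commit

seq_len = max_num_batched_tokens // max_num_seq       # 42 tokens per sequence
seq_lens = [seq_len] * max_num_seq                     # 96 sequences of 42 tokens each
seq_lens[-1] += max_num_batched_tokens % max_num_seq   # remaining 64 tokens go to the last sequence

assert sum(seq_lens) == max_num_batched_tokens         # the full token budget is still profiled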
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -588,7 +588,7 @@ def init_tvm_model(
     if engine_config.max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
         try:
-                max_num_seq = engine_config.max_num_seq
+            max_num_seq = engine_config.max_num_seq
             max_num_batched_tokens = engine_config.max_num_batched_tokens
             seq_len = max_num_batched_tokens // max_num_seq
             seq_lens = [seq_len] * max_num_seq
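The other change in torch_model.py replaces the hard-coded gpu_memory_utilization=0.9 with engine_config.gpu_memory_utilization. Below is a hypothetical helper sketching how such a fraction is typically combined with the profiled usage to size the KV cache; the function name, total_gpu_memory_bytes, and cache_block_bytes are assumptions for illustration, not the repository's actual cache-initialization code:

def estimate_num_cache_blocks(
    total_gpu_memory_bytes: int,    # assumed: total device memory
    used_memory_bytes: int,         # as returned by profile_memory_usage in the diff
    cache_block_bytes: int,         # assumed: size of one KV-cache block
    gpu_memory_utilization: float,  # e.g. engine_config.gpu_memory_utilization
) -> int:
    # Only a gpu_memory_utilization fraction of the device may be used; the
    # model's profiled peak usage is subtracted first, and the remainder is
    # handed to the KV cache in whole blocks.
    budget_bytes = int(total_gpu_memory_bytes * gpu_memory_utilization)
    free_for_cache = budget_bytes - used_memory_bytes
    return max(free_for_cache // cache_block_bytes, 0)

# Example: 80 GiB GPU, 30 GiB profiled peak usage, 2 MiB blocks, 90% utilization -> 21504 blocks.
num_blocks = estimate_num_cache_blocks(80 * 1024**3, 30 * 1024**3, 2 * 1024**2, 0.9)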
