Support Vllm sync model backend;
yhyu13 committed Dec 9, 2023
1 parent 705f04a commit 095ce50
Showing 7 changed files with 171 additions and 6 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -283,7 +283,7 @@ Optionally, you can use the following command-line flags:

| Flag | Description |
|--------------------------------------------|-------------|
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers. |
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlama_HF, ExLlamav2_HF, AutoGPTQ, AutoAWQ, GPTQ-for-LLaMa, ExLlama, ExLlamav2, ctransformers, vllm. |

#### Accelerate/transformers

@@ -385,6 +385,15 @@ Optionally, you can use the following command-line flags:
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |

#### vLLM

| Flag | Description |
|------------------|-------------|
| `--max-model-len MAX_MODEL_LEN` | Model context length. If unspecified, it is derived automatically from the model config. |
| `--dtype DTYPE` | Data type for model weights and activations. "auto" uses FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. |

Refer to https://docs.vllm.ai/en/latest/models/engine_args.html for the full list of engine arguments. Any vLLM engine argument can simply be passed on the text-generation-webui command line; flags that textgen-webui does not recognize are forwarded to vLLM's own argument parser (see the sketch below).
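
The pass-through works because `modules/shared.py` now parses with `parse_known_args()` and `modules/vllm.py` re-parses the command line with vLLM's `EngineArgs` parser. Below is a minimal sketch of that mechanism; the flag values (`--tensor-parallel-size 2`, the model name) are illustrative placeholders, not defaults.

```python
import argparse

from vllm import EngineArgs

# Example command line: textgen-webui flags plus an arbitrary vLLM engine flag
argv = ['--loader', 'vllm', '--model', 'my-model', '--tensor-parallel-size', '2']

# textgen-webui's parser collects unrecognized flags instead of erroring out
textgen_parser = argparse.ArgumentParser()
textgen_parser.add_argument('--loader')
textgen_parser.add_argument('--model')
textgen_args, unknown = textgen_parser.parse_known_args(argv)  # unknown == ['--tensor-parallel-size', '2']

# vLLM's own parser then picks up its flags from the same command line
vllm_parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
vllm_args, _ = vllm_parser.parse_known_args(argv)  # ignores '--loader', reads '--tensor-parallel-size'
engine_args = EngineArgs.from_cli_args(vllm_args)  # ready for LLMEngine.from_engine_args(engine_args)
```

In the actual modules both parses run against `sys.argv`, and `--model` is resolved to a directory under `--model-dir` before the engine is created.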

#### RoPE (for llama.cpp, ExLlama, ExLlamaV2, and transformers)

| Flag | Description |
2 changes: 1 addition & 1 deletion extensions/openai/typing.py
@@ -8,7 +8,7 @@
class GenerationOptions(BaseModel):
    preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
    min_p: float = 0
    top_k: int = 0
    top_k: int = 1
    repetition_penalty: float = 1
    repetition_penalty_range: int = 1024
    typical_p: float = 1
6 changes: 6 additions & 0 deletions modules/loaders.py
@@ -147,6 +147,9 @@
        'trust_remote_code',
        'no_use_fast',
        'no_flash_attn',
    ],
    'Vllm': [
        # Use Vllm's default settings
    ]
})

@@ -495,6 +498,9 @@
        'skip_special_tokens',
        'auto_max_new_tokens',
    },
    'Vllm': {
        # Use Vllm's default settings
    },
}

loaders_model_types = {
6 changes: 6 additions & 0 deletions modules/models.py
@@ -72,6 +72,7 @@ def load_model(model_name, loader=None):
        'ctransformers': ctransformers_loader,
        'AutoAWQ': AutoAWQ_loader,
        'QuIP#': QuipSharp_loader,
        'Vllm': Vllm_loader,
    }

    metadata = get_model_metadata(model_name)
@@ -421,6 +422,11 @@ def RWKV_loader(model_name):
    tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
    return model, tokenizer

def Vllm_loader(model_name):
    from modules.vllm import VllmModel

    return VllmModel.from_pretrained(model_name)


def get_max_memory_dict():
    max_memory = {}
5 changes: 4 additions & 1 deletion modules/shared.py
@@ -188,7 +188,8 @@
parser.add_argument('--llama_cpp_seed', type=int, default=0, help='DEPRECATED')
parser.add_argument('--use_fast', action='store_true', help='DEPRECATED')

args = parser.parse_args()
args, unknown = parser.parse_known_args()
logger.warning(f'Textgen-webui has been provided with unknown arguments: {unknown}')
args_defaults = parser.parse_args([])
provided_arguments = []
for arg in sys.argv[1:]:
@@ -243,6 +244,8 @@ def fix_loader_name(name):
        return 'AutoAWQ'
    elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
        return 'QuIP#'
    elif name in ['vllm', 'Vllm', 'VLLM']:
        return 'Vllm'


def add_extension(name, last=False):
6 changes: 3 additions & 3 deletions modules/text_generation.py
@@ -43,7 +43,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
        yield ''
        return

    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel']:
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel']:
        generate_func = generate_reply_custom
    else:
        generate_func = generate_reply_HF
@@ -114,7 +114,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    if shared.tokenizer is None:
        raise ValueError('No tokenizer is loaded')

    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model']:
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel', 'Exllamav2Model', 'VllmModel']:
        input_ids = shared.tokenizer.encode(str(prompt))
        if shared.model.__class__.__name__ not in ['Exllamav2Model']:
            input_ids = np.array(input_ids).reshape(1, len(input_ids))
@@ -129,7 +129,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    if truncation_length is not None:
        input_ids = input_ids[:, -truncation_length:]

    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel'] or shared.args.cpu:
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model', 'CtransformersModel', 'VllmModel'] or shared.args.cpu:
        return input_ids
    elif shared.args.deepspeed:
        return input_ids.to(device=local_rank)
141 changes: 141 additions & 0 deletions modules/vllm.py
@@ -0,0 +1,141 @@

import contextlib
from pathlib import Path
import argparse
import threading

from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.utils import random_uuid

from modules import shared
from modules.logging_colors import logger


# Lock vLLM to prevent multiple threads from using the engine at the same time
class LockContextManager:
    def __init__(self, lock):
        self.lock = lock

    def __enter__(self):
        self.lock.acquire()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.lock.release()


class VllmModel:
    __VLLM_DEBUG__ = False

    def __init__(self):
        self.inference_lock = threading.Lock()

    @classmethod
    def from_pretrained(cls, path_to_model):

        # Parse only the vLLM engine arguments from the command line; textgen-webui's own flags are ignored here
        vllm_parser = argparse.ArgumentParser(
            description='VllmModel uses the vLLM LLMEngine class directly, so the vLLM argument parser handles the engine arguments')
        vllm_parser = EngineArgs.add_cli_args(vllm_parser)
        vllm_args, unknown = vllm_parser.parse_known_args()

        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        assert path_to_model.exists(), f'Model {path_to_model} does not exist'

        vllm_args.model = str(path_to_model.absolute())

        logger.info(f'Parsed vllm_args: {vllm_args}')
        engine_args = EngineArgs.from_cli_args(vllm_args)
        engine = LLMEngine.from_engine_args(engine_args)

        result = cls()
        result.engine = engine
        result.tokenizer = engine.tokenizer

        logger.info(f'Loaded model into \n{result.engine}, \n{result.tokenizer}')

        return result, result.tokenizer

    def generate_with_streaming(self, prompt, state):

        # Copy the sampling settings that vLLM understands from the textgen state
        settings = SamplingParams()
        for key, value in state.items():
            if hasattr(settings, key) and value is not None:
                setattr(settings, key, value)
                if shared.args.verbose and self.__VLLM_DEBUG__:
                    logger.debug(f'Setting {key} to {value}')

        # Let vLLM verify the settings
        try:
            settings._verify_args()
        except ValueError as e:
            settings = SamplingParams()
            logger.warning(f'vLLM could not verify the sampler settings, using the default settings {settings} instead: {e}')

        # Encode the prompt into token ids
        prompt_token_ids = self.tokenizer.encode(prompt)

        # Get the maximum number of new tokens
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - len(prompt_token_ids)
        else:
            max_new_tokens = state['max_new_tokens']
        if max_new_tokens < 0:
            logger.warning(f'Max new tokens {max_new_tokens} < 0, setting to 0')
            max_new_tokens = 0
        settings.max_tokens = max_new_tokens

        if shared.args.verbose and self.__VLLM_DEBUG__:
            logger.debug(f'Generating with streaming, max_tokens={settings.max_tokens}')
            logger.debug(f'Prompt token ids {prompt_token_ids}')
            logger.debug(f'Prompt token ids length {len(prompt_token_ids)}')
            logger.debug(f'settings {settings}')

        # Can only handle 1 sample per generation
        assert settings.n == 1, f'Only 1 sample per generation is supported, got {settings.n}'

        request_id = f"{random_uuid()}"
        with LockContextManager(self.inference_lock):
            self.engine.add_request(request_id=request_id,
                                    prompt=prompt,
                                    sampling_params=settings,
                                    prompt_token_ids=prompt_token_ids)

        while True:
            # Abort generation if we are stopping everything
            if shared.stop_everything:
                with LockContextManager(self.inference_lock):
                    self.engine.abort(request_id)
                if shared.args.verbose and self.__VLLM_DEBUG__:
                    logger.debug('Aborted generation')
                break

            target_request_output = None
            with LockContextManager(self.inference_lock):
                request_outputs: List[RequestOutput] = self.engine.step()

            for request_output in request_outputs:
                if request_output.request_id != request_id:
                    logger.warning(f'Request id mismatch, expected {request_id}, got {request_output.request_id}')
                    continue
                # Can only handle 1 sample per generation
                assert len(request_output.outputs) == 1, f'Only 1 sample per generation is supported, got {len(request_output.outputs)}'
                target_request_output = request_output

            output = target_request_output.outputs[0]
            decoded_text = output.text
            # if shared.args.verbose and self.__VLLM_DEBUG__:
            #     logger.debug(f'{decoded_text}')
            yield decoded_text

            if target_request_output.finished:
                if shared.args.verbose and self.__VLLM_DEBUG__:
                    logger.debug('Finished generation')
                break

    def generate(self, prompt, state):
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output
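
A minimal usage sketch of `VllmModel` follows; the model directory name and the `state` values are placeholders, and in the webui the `state` dictionary is assembled by the UI while generation is routed through `generate_reply_custom` in `modules/text_generation.py`.

```python
from modules.vllm import VllmModel

# from_pretrained returns (model, tokenizer); the model directory name is a placeholder
model, tokenizer = VllmModel.from_pretrained('my-llama-model')

# Keys that SamplingParams recognizes (e.g. temperature, top_p) are copied onto it;
# max_new_tokens / auto_max_new_tokens / truncation_length determine settings.max_tokens
state = {
    'temperature': 0.7,
    'top_p': 0.9,
    'max_new_tokens': 200,
    'auto_max_new_tokens': False,
    'truncation_length': 4096,
}

# Each yield is the full decoded text generated so far for this request
for partial_text in model.generate_with_streaming('Hello, my name is', state):
    print(partial_text)

# Non-streaming wrapper that returns only the final text
full_text = model.generate('Hello, my name is', state)
```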
