From 5c29257e962c4bf9eb1e5105bdbf711642b762fc Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 4 Aug 2023 15:37:36 -0700 Subject: [PATCH 01/15] refactoring configs --- mii/config.py | 253 ++++++++++++++++++++--------- mii/deployment.py | 227 ++++++++------------------ mii/models/score/generate.py | 25 +-- mii/models/score/score_template.py | 25 +-- 4 files changed, 254 insertions(+), 276 deletions(-) diff --git a/mii/config.py b/mii/config.py index 2714cb40..f2a41f3e 100644 --- a/mii/config.py +++ b/mii/config.py @@ -5,78 +5,59 @@ import torch from typing import Union, List from enum import Enum -from pydantic import BaseModel, validator, root_validator +from pydantic import validator, root_validator +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from deepspeed.runtime.config import DTypeEnum from deepspeed.launcher.runner import DLTS_HOSTFILE -class DtypeEnum(Enum): - # The torch dtype must always be the first value (so we return torch.dtype) - fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" - bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16" - fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" - int8 = torch.int8, "torch.int8", "int8" - - # Copied from https://stackoverflow.com/a/43210118 - # Allows us to use multiple values for each Enum index and returns first - # listed value when Enum is called - def __new__(cls, *values): - obj = object.__new__(cls) - # first value is canonical value - obj._value_ = values[0] - for other_value in values[1:]: - cls._value2member_map_[other_value] = obj - obj._all_values = values - return obj - - def __repr__(self): - return "<%s.%s: %s>" % ( - self.__class__.__name__, - self._name_, - ", ".join([repr(v) for v in self._all_values]), - ) +class DeploymentType(Enum): + LOCAL = "local" + AML = "aml" + NON_PERSISTENT = "non-persistent" + + +class TaskType(Enum): + TEXT_GENERATION = "text-generation" + TEXT_CLASSIFICATION = "text-classification" + QUESTION_ANSWERING = "question-answering" + FILL_MASK = "fill-mask" + TOKEN_CLASSIFICATION = "token-classification" + CONVERSATIONAL = "conversational" + TEXT2IMG = "text-to-image" -class MIIConfig(BaseModel): +class ReplicaConfig(BaseModel): + hostname: str = "" + tensor_parallel_ports: List[int] = [] + torch_dist_port: int = None + gpu_indices: List[int] = [] + + +class DeploymentConfig(DeepSpeedConfigModel): + deployment_name: str + model: str + task: TaskType tensor_parallel: int = 1 - port_number: int = 50050 dtype: DtypeEnum = torch.float32 meta_tensor: bool = False load_with_sys_mem: bool = False enable_cuda_graph: bool = False - checkpoint_dict: Union[dict, None] = None - deploy_rank: Union[int, List[int]] = -1 + checkpoint_dict: Optional[Dict[str, Any]] = None + deploy_rank: Optional[List[int]] = None torch_dist_port: int = 29500 - hf_auth_token: str = None replace_with_kernel_inject: bool = True profile_model_time: bool = False skip_model_check: bool = False max_tokens: int = 1024 - enable_restful_api: bool = False - restful_api_port: int = 51080 - replica_num: int = 1 - hostfile: str = DLTS_HOSTFILE trust_remote_code: bool = False - - @validator("deploy_rank") - def deploy_valid(cls, field_value, values): - if "tensor_parallel" not in values: - raise ValueError( - "'tensor_parallel' must be defined in the pydantic model before 'deploy_rank'" - ) - - # if deploy rank is not given, default to align with TP value - if field_value == -1: - field_value = list(range(values["tensor_parallel"])) - - # ensure deploy rank type is always list for 
easier consumption later - if not isinstance(field_value, list): - field_value = [field_value] - - # number of ranks provided must be equal to TP size, DP is handled outside MII currently - assert values["tensor_parallel"] == len(field_value), \ - f"{len(field_value)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {values['tensor_parallel']}" - return field_value + enable_deepspeed: bool = True + enable_zero: bool = False + ds_config: Dict[str, Any] = {} + model_path: Optional[str] = "" + replica_num: int = 1 + replica_configs: List[ReplicaConfig] = [] @validator('checkpoint_dict') def checkpoint_dict_valid(cls, value): @@ -91,6 +72,49 @@ def checkpoint_dict_valid(cls, value): raise ValueError(f"Missing key={k} in checkpoint_dict") return value + @validator("deploy_rank", pre=True) + def deploy_rank_to_list(cls, field_value, values): + if not isinstance(field_value, list): + field_value = [field_value] + return field_value + + @root_validator + def deploy_rank_valid(cls, values): + tensor_parallel = values.get("tensor_parallel") + deploy_rank = values.get("deploy_rank") + + # if deploy rank is not given, default to align with TP value + if deploy_rank is None: + deploy_rank = list(range(tensor_parallel)) + + # number of ranks provided must be equal to TP size, DP is handled outside MII currently + assert tensor_parallel == len(deploy_rank), \ + f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}" + + values["deploy_rank"] = deploy_rank + return values + + @root_validator + def set_model_path(cls, values): + if values.get("model_path") is None: + deployment_type = values.get("deployment_type") + if deployment_type == DeploymentType.LOCAL: + model_path = MII_MODEL_PATH_DEFAULT + if deployment_tpye == DeploymentType.AML: + model_path = "model" + values["model_path"] = model_path + return values + + @root_validator + def validate_model_and_task(cls, values): + task = values.get("task") + model = values.get("model") + if not values.get("skip_model_check"): + mii.utils.check_if_task_and_model_is_valid(task, model) + if values.get("ds_optimize"): + mii.utils.check_if_task_and_model_is_supported(task, model) + return values + @root_validator def meta_tensor_or_sys_mem(cls, values): if values.get("meta_tensor") and values.get("load_with_sys_mem"): @@ -99,29 +123,108 @@ def meta_tensor_or_sys_mem(cls, values): ) return values - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - extra = 'forbid' - json_encoders = {torch.dtype: lambda x: str(x)} + @root_validator + def zero_dtype_valid(cls, values): + if values.get("enable_zero"): + if values.get("ds_config").get("fp16", {}).get("enabled", False): + assert ( + values.get("dtype") == DtypeEnum.float16 + ), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`" + else: + assert ( + values.get("dtype") == DtypeEnum.float32 + ), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`" + return values + + @root_validator + def deepspeed_or_zero(cls, values): + assert not ( + values.get("enable_deepspeed") and values.get("enable_zero") + ), "DeepSpeed and ZeRO cannot both be enabled, select only one" + return values + +class MIIConfig(DeepSpeedConfigModel): + hf_auth_token: str = None + port_number: int = 50050 + enable_restful_api: bool = False + restful_api_port: int = 51080 + hostfile: str = DLTS_HOSTFILE + deployment_type: DeploymentType = DeploymentType.LOCAL + version: int = 1 -class 
ReplicaConfig(BaseModel): - hostname: str = "" - tensor_parallel_ports: List[int] = [] - torch_dist_port: int = None - gpu_indices: List[int] = [] + @root_validator + def AML_name_valid(cls, fields): + if fields.get("deployment_type") == DeploymentType.AML: + allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + + string.digits + "-") + assert ( + set(fields.get("deployment_config").deployment_name) <= allowed_chars + ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." + return fields - class Config: - validate_all = True - validate_assignment = True + @root_validator + def generate_replica_configs(cls, values): + replica_configs = values.get("deployment_config").replica_configs + num_replicas = values.get("deployment_config").num_replicas + if replica_configs: + assert len(replica_confgs) == num_replicas + return values + + hostfile = values.get("hostfile") + port_number = values.get("port_number") + torch_dist_port = values.get("deployment_config").torch_dist_port + tensor_parallel = values.get("deployment_config").tensor_parallel + replica_num = values.get("deployment_config").replica_num + replica_pool = _allocate_processes(hostfile, tensor_parallel, replica_num) + replica_configs = [] + for i, (hostname, gpu_indices) in enumerate(replica_pool): + # Reserver port for a LB proxy when replication is enabled + port_offset = 1 + base_port = port_number + i * tensor_parallel + port_offset + tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel)) + replica_torch_dist_port = torch_dist_port + i + replica_configs.append( + ReplicaConfig(hostname=hostname, + tensor_parallel_ports=tensor_parallel_ports, + torch_dist_port=replica_torch_dist_port, + gpu_indices=gpu_indices)) + + values.get("deployment_config").replica_configs = replica_configs + return values -class LoadBalancerConfig(BaseModel): - port: int = None - replica_configs: List[ReplicaConfig] = [] +def _allocate_processes(hostfile_path, tensor_parallel, num_replicas): + resource_pool = fetch_hostfile(hostfile_path) + assert resource_pool is not None and len( + resource_pool) > 0, f'No hosts found in {hostfile_path}' + + replica_pool = [] + allocated_num = 0 + for host, slots in resource_pool.items(): + available_on_host = slots + while available_on_host >= tensor_parallel: + if allocated_num >= num_replicas: + break + if slots < tensor_parallel: + raise ValueError( + f'Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required' + ) + + allocated_num_on_host = slots - available_on_host + replica_pool.append( + (host, + [ + i for i in range(allocated_num_on_host, + allocated_num_on_host + tensor_parallel) + ])) + allocated_num += 1 + + available_on_host -= tensor_parallel + + if allocated_num < num_replicas: + raise ValueError( + f'No sufficient GPUs for {num_replicas} replica(s), only {allocated_num} replica(s) can be deployed' + ) - class Config: - validate_all = True - validate_assignment = True + return replica_pool diff --git a/mii/deployment.py b/mii/deployment.py index 3cadd994..449b3801 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -16,80 +16,56 @@ from .config import ReplicaConfig, LoadBalancerConfig -def deploy(task, - model, - deployment_name, - deployment_type=DeploymentType.LOCAL, - model_path=None, - enable_deepspeed=True, - enable_zero=False, - ds_config=None, - mii_config={}, - version=1): - """Deploy a task using specified model. 
For usage examples see: - - mii/examples/local/text-generation-example.py - - - Arguments: - task: Name of the machine learning task to be deployed.Currently MII supports the following list of tasks - ``['text-generation', 'text-classification', 'question-answering', 'fill-mask', 'token-classification', 'conversational', 'text-to-image']`` - - model: Name of a supported model for the task. Models in MII are sourced from multiple open-source projects - such as Huggingface Transformer, FairSeq, EluetherAI etc. For the list of supported models for each task, please - see here [TODO]. - - deployment_name: Name of the deployment. Used as an identifier for posting queries for ``LOCAL`` deployment. - - deployment_type: One of the ``enum mii.DeploymentTypes: [LOCAL]``. - *``LOCAL`` uses a grpc server to create a local deployment, and query the model must be done by creating a query handle using - `mii.mii_query_handle` and posting queries using ``mii_request_handle.query`` API, - - model_path: Optional: In LOCAL deployments this is the local path where model checkpoints are available. In AML deployments this - is an optional relative path with AZURE_MODEL_DIR for the deployment. - - enable_deepspeed: Optional: Defaults to True. Use this flag to enable or disable DeepSpeed-Inference optimizations - - enable_zero: Optional: Defaults to False. Use this flag to enable or disable DeepSpeed-ZeRO inference - - ds_config: Optional: Defaults to None. Use this to specify the DeepSpeed configuration when enabling DeepSpeed-ZeRO inference - - force_register_model: Optional: Defaults to False. For AML deployments, set it to True if you want to re-register your model - with the same ``aml_model_tags`` using checkpoints from ``model_path``. - - mii_config: Optional: Dictionary specifying optimization and deployment configurations that should override defaults in ``mii.config.MIIConfig``. - mii_config is future looking to support extensions in optimization strategies supported by DeepSpeed Inference as we extend mii. - As of now, it can be used to set tensor-slicing degree using 'tensor_parallel' and port number for deployment using 'port_number'. - - version: Optional: Version to be set for AML deployment, useful if you want to deploy the same model with different settings. 
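        A legacy-style call matching the signature documented here remains supported
        through the support_legacy_api() shim added below in this patch. A minimal
        illustrative invocation (model, task, and deployment name are placeholder
        values, not taken from this patch) could look like:

            import mii

            mii.deploy(
                task="text-generation",
                model="gpt2",
                deployment_name="gpt2-local-deployment",
                deployment_type=mii.DeploymentType.LOCAL,
                enable_deepspeed=True,
            )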
- Returns: - If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)` - - """ - - # parse and validate mii config - mii_config = mii.config.MIIConfig(**mii_config) - if enable_zero: - if ds_config.get("fp16", {}).get("enabled", False): - assert (mii_config.dtype == torch.half), "MII Config Error: MII dtype and ZeRO dtype must match" - else: - assert (mii_config.dtype == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match" - assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one" - - # aml only allows certain characters for deployment names - if deployment_type == DeploymentType.AML: - allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase + - string.digits + '-') - assert set(deployment_name) <= allowed_chars, "AML deployment names can only contain a-z, A-Z, 0-9, and '-'" - - task = mii.utils.get_task(task) - - if not mii_config.skip_model_check: - mii.utils.check_if_task_and_model_is_valid(task, model) - if enable_deepspeed: - mii.utils.check_if_task_and_model_is_supported(task, model) - - if enable_deepspeed: +def support_legacy_api(task, + model, + deployment_type=DeploymentType.LOCAL, + model_path="", + enable_deepspeed=True, + enable_zero=False, + ds_config=None, + mii_config=None, + version=1): + deployment_tag = deployment_name + + if ds_config is None: + ds_config = {} + if mii_config is None: + mii_config = {} + + deployment_config = { + "deployment_name": deployment_name, + "task": task, + "model": model, + "model_path": model_path, + "ds_optimize": enable_deepspeed, + "ds_zero": enable_zero, + "ds_config": ds_config, + } + for key, val in mii_config.items(): + if not hasattr(MIIConfig, key): + deployment_config[key] = val + deployments = [deployment_config] + + mii_config = {k: v for k, v in mii_config.items() if hasattr(MIIConfig, k)} + mii_config["version"] = version + mii_config["deployment_type"] = deployment_type + + return deployment_tag, deployments, mii_config + + +def deploy(deployment_name, deployment_config=None, mii_config=None, *args, **kwargs): + if mii_config is None: + mii_config = {} + + if args or kwargs: + assert not deployment_config, "We do not support mixture of legacy and new API options, use latest API." + kwargs["mii_config"] = mii_config + deployment_config, mii_config = support_legacy_api(*args, **kwargs) + + deployment_config["deployment_name"] = deployment_name + mii_config = mii.config.MIIConfig(**mii_config, deployment_config=deployment_config) + + if mii_config.deployment_config.enable_deepspeed: logger.info( f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" ) @@ -98,54 +74,24 @@ def deploy(task, f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" ) - # In local deployments use default path if no model path set - if model_path is None and deployment_type == DeploymentType.LOCAL: - model_path = MII_MODEL_PATH_DEFAULT - elif model_path is None and deployment_type == DeploymentType.AML: - model_path = "model" - - # add fields for replica deployment - replica_pool = _allocate_processes(mii_config.hostfile, - mii_config.tensor_parallel, - mii_config.replica_num) - replica_configs = [] - for i, (hostname, gpu_indices) in enumerate(replica_pool): - # Reserver port for a LB proxy when replication is enabled - port_offset = 1 - base_port = mii_config.port_number + i * mii_config.tensor_parallel + port_offset - tensor_parallel_ports = list( - range(base_port, - base_port + mii_config.tensor_parallel)) - torch_dist_port = mii_config.torch_dist_port + i - replica_configs.append( - ReplicaConfig(hostname=hostname, - tensor_parallel_ports=tensor_parallel_ports, - torch_dist_port=torch_dist_port, - gpu_indices=gpu_indices)) - lb_config = LoadBalancerConfig(port=mii_config.port_number, - replica_configs=replica_configs) - if deployment_type != DeploymentType.NON_PERSISTENT: - create_score_file(deployment_name=deployment_name, - deployment_type=deployment_type, - task=task, - model_name=model, - ds_optimize=enable_deepspeed, - ds_zero=enable_zero, - ds_config=ds_config, - mii_config=mii_config, - model_path=model_path, - lb_config=lb_config) + create_score_file(mii_config) if deployment_type == DeploymentType.AML: - _deploy_aml(deployment_name=deployment_name, model_name=model, version=version) + _deploy_aml(mii_config) elif deployment_type == DeploymentType.LOCAL: - return _deploy_local(deployment_name, model_path=model_path) + return _deploy_local(mii_config) elif deployment_type == DeploymentType.NON_PERSISTENT: - assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.tensor_parallel, "World Size does not equal number of tensors. When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" + assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.deployment_config.tensor_parallel, "World Size does not equal number of tensors. 
When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" + deployment_name = mii_config.deployment_config.deployment_name + model = mii_config.deployment_config.model + task = mii_config.deployment_config.task + model_path = mii_config.deployment_config.model_path + enable_deepspeed = mii_config.deployment_config.enable_deepspeed + enable_zero = mii_config.deployment_config.enable_zero provider = MODEL_PROVIDER_MAP[get_provider_name(model, task)] mii.non_persistent_models[deployment_name] = (load_models( - get_task_name(task), + task, model, model_path, enable_deepspeed, @@ -157,53 +103,18 @@ def deploy(task, raise Exception(f"Unknown deployment type: {deployment_type}") -def _deploy_local(deployment_name, model_path): - mii.utils.import_score_file(deployment_name).init() +def _deploy_local(mii_config): + mii.utils.import_score_file(mii_config.deployment_config.deployment_name).init() -def _deploy_aml(deployment_name, model_name, version): +def _deploy_aml(mii_config): acr_name = mii.aml_related.utils.get_acr_name() - mii.aml_related.utils.generate_aml_scripts(acr_name=acr_name, - deployment_name=deployment_name, - model_name=model_name, - version=version) + mii.aml_related.utils.generate_aml_scripts( + acr_name=acr_name, + deployment_name=mii_config.deployment_config.deployment_name, + model_name=mii_config.deployment_config.model, + version=mii_config.version) print( f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}" ) print("Please run 'deploy.sh' to bring your deployment online") - - -def _allocate_processes(hostfile_path, tensor_parallel, num_replicas): - resource_pool = fetch_hostfile(hostfile_path) - assert resource_pool is not None and len( - resource_pool) > 0, f'No hosts found in {hostfile_path}' - - replica_pool = [] - allocated_num = 0 - for host, slots in resource_pool.items(): - available_on_host = slots - while available_on_host >= tensor_parallel: - if allocated_num >= num_replicas: - break - if slots < tensor_parallel: - raise ValueError( - f'Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required' - ) - - allocated_num_on_host = slots - available_on_host - replica_pool.append( - (host, - [ - i for i in range(allocated_num_on_host, - allocated_num_on_host + tensor_parallel) - ])) - allocated_num += 1 - - available_on_host -= tensor_parallel - - if allocated_num < num_replicas: - raise ValueError( - f'No sufficient GPUs for {num_replicas} replica(s), only {allocated_num} replica(s) can be deployed' - ) - - return replica_pool diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index 1184d70e..f347202c 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -9,29 +9,7 @@ from mii.constants import DeploymentType -def create_score_file(deployment_name, - deployment_type, - task, - model_name, - ds_optimize, - ds_zero, - ds_config, - mii_config, - model_path, - lb_config): - config_dict = {} - config_dict[mii.constants.DEPLOYMENT_NAME_KEY] = deployment_name - config_dict[mii.constants.TASK_NAME_KEY] = mii.utils.get_task_name(task) - config_dict[mii.constants.MODEL_NAME_KEY] = model_name - config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize - config_dict[mii.constants.MII_CONFIGS_KEY] = mii_config.dict() - config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero - config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config - config_dict[mii.constants.MODEL_PATH_KEY] = model_path - - if lb_config is not None: - 
config_dict[mii.constants.LOAD_BALANCER_CONFIG_KEY] = lb_config - +def create_score_file(mii_config): if len(mii.__path__) > 1: logger.warning( f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" @@ -43,6 +21,7 @@ def create_score_file(deployment_name, score_src = fd.read() # update score file w. global config dict + config_dict = mii_config.dict() source_with_config = f"{score_src}\n" source_with_config += f"configs = {pprint.pformat(config_dict, indent=4)}" diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 04e47fae..0f59652e 100644 --- a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -8,32 +8,17 @@ import json import torch import mii -from mii.config import LoadBalancerConfig, ReplicaConfig +from mii.config import MIIConfig import time model = None def init(): - model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) - - deployment_name = configs[mii.constants.DEPLOYMENT_NAME_KEY] - model_name = configs[mii.constants.MODEL_NAME_KEY] - task_name = configs[mii.constants.TASK_NAME_KEY] - - assert model_name is not None, "The model name should be set before calling init" - assert task_name is not None, "The task name should be set before calling init" - - mii.MIIServer(deployment_name, - task_name, - model_name, - model_path, - ds_optimize=configs[mii.constants.ENABLE_DEEPSPEED_KEY], - ds_zero=configs[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY], - ds_config=configs[mii.constants.DEEPSPEED_CONFIG_KEY], - mii_configs=configs[mii.constants.MII_CONFIGS_KEY], - lb_config=configs.get(mii.constants.LOAD_BALANCER_CONFIG_KEY, - None)) + mii_configs = MIIConfig(**mii_configs) + #model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) + + mii.MIIServer(mii_config) global model model = None From 206fc2c7a5805dbb8f9b85adb31043ac4f82eff1 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 7 Aug 2023 17:31:46 -0700 Subject: [PATCH 02/15] further refactoring --- mii/__init__.py | 4 +- mii/aml_related/templates.py | 30 +- mii/aml_related/utils.py | 99 ++--- mii/client.py | 89 ++-- mii/config.py | 157 ++++--- mii/constants.py | 124 ++---- mii/deployment.py | 81 ++-- mii/grpc_related/modelresponse_server.py | 74 ++-- mii/grpc_related/proto/modelresponse_pb2.py | 13 +- .../proto/modelresponse_pb2_grpc.py | 382 +++++++++--------- mii/grpc_related/restful_gateway.py | 16 +- mii/launch/multi_gpu_server.py | 113 +++--- mii/method_table.py | 131 +++--- mii/models/load_models.py | 89 ++-- mii/models/providers/diffusers.py | 13 +- mii/models/providers/huggingface.py | 96 +++-- mii/models/score/generate.py | 19 +- mii/models/score/score_template.py | 23 +- mii/models/utils.py | 14 +- mii/server.py | 306 ++++---------- mii/terminate.py | 4 +- mii/utils.py | 143 ++----- 22 files changed, 902 insertions(+), 1118 deletions(-) diff --git a/mii/__init__.py b/mii/__init__.py index ab409d4c..e0b4fe21 100644 --- a/mii/__init__.py +++ b/mii/__init__.py @@ -7,10 +7,10 @@ from .client import MIIClient, mii_query_handle from .deployment import deploy from .terminate import terminate -from .constants import DeploymentType, Tasks +from .constants import DeploymentType, TaskType from .aml_related.utils import aml_output_path -from .config import MIIConfig, LoadBalancerConfig +from .config import MIIConfig from .grpc_related.proto import modelresponse_pb2_grpc __version__ = "0.0.0" diff --git a/mii/aml_related/templates.py b/mii/aml_related/templates.py index 5e67255b..5ad97a31 100644 
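As a rough sketch of the nested-config pattern this refactor moves toward (simplified
stand-ins for the DeploymentConfig and MIIConfig classes in mii/config.py below; the
real classes derive from DeepSpeedConfigModel and carry many more fields and
validators), the pydantic v1 style validation works roughly like this:

from typing import List, Optional
from pydantic import BaseModel, root_validator  # pydantic v1 API, matching these patches


class DeploymentConfig(BaseModel):
    # Simplified, illustrative stand-in for mii.config.DeploymentConfig.
    deployment_name: str
    tensor_parallel: int = 1
    deploy_rank: Optional[List[int]] = None

    @root_validator
    def default_deploy_rank(cls, values):
        # Mirrors the patch: if deploy_rank is unset, align it with tensor_parallel.
        if values.get("deploy_rank") is None:
            values["deploy_rank"] = list(range(values["tensor_parallel"]))
        return values


class MIIConfig(BaseModel):
    # Simplified, illustrative stand-in for mii.config.MIIConfig, which now nests
    # the per-deployment settings instead of passing them as flat keyword arguments.
    deployment_config: DeploymentConfig
    port_number: int = 50050


cfg = MIIConfig(deployment_config={"deployment_name": "example", "tensor_parallel": 2})
print(cfg.deployment_config.deploy_rank)  # -> [0, 1]

The real MIIConfig in this patch additionally generates per-replica port assignments and
validates AML deployment names at this top level.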
--- a/mii/aml_related/templates.py +++ b/mii/aml_related/templates.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -deployment = \ -"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json +deployment = """$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json name: endpoint_name: model: @@ -36,14 +35,12 @@ instance_count: 1 """ -endpoint = \ -"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json +endpoint = """$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json name: auth_mode: key """ -environment = \ -"""$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json +environment = """$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json name: version: image: .azurecr.io/: @@ -59,8 +56,7 @@ port: 5001 """ -model_download = \ -"""import os +model_download = """import os import glob import shutil @@ -97,8 +93,7 @@ shutil.rmtree(tmp_download_path) """ -deploy = \ -"""set -e +deploy = """set -e python3 model_download.py az acr build -r --build-arg no-cache=True -t ":" build az ml environment create -f environment.yml @@ -106,8 +101,7 @@ az ml online-deployment create -n "" -f deployment.yml """ -dockerfile = \ -"""FROM nvidia/cuda:11.3.1-devel-ubuntu18.04 +dockerfile = """FROM nvidia/cuda:11.3.1-devel-ubuntu18.04 ENV AML_APP_ROOT=/var/azureml-model/code \ BUILD_DIR=/tmp/build \ @@ -176,8 +170,7 @@ CMD sudo service nginx start && cd $AZUREML_MODEL_DIR/code && azmlinfsrv --model_dir $AZUREML_MODEL_DIR --entry_script $AZUREML_MODEL_DIR/code/score.py --port 31311 """ -gunicorn = \ -"""upstream gunicorn { +gunicorn = """upstream gunicorn { server 127.0.0.1:31311; } @@ -202,8 +195,7 @@ } """ -gunicorn_run = \ -"""#!/bin/bash +gunicorn_run = """#!/bin/bash SCRIPT_PATH=$(dirname $(realpath -s "$0")) @@ -377,8 +369,7 @@ fi """ -gunicorn_finish = \ -"""#!/bin/bash +gunicorn_finish = """#!/bin/bash exit_code="$1" # The exit code from gunicorn signal="$2" # The signal which caused gunicorn to exit (or 0) @@ -389,8 +380,7 @@ killall -SIGHUP runsvdir """ -requirements = \ -"""torch==1.12.0 +requirements = """torch==1.12.0 grpcio grpcio-tools pydantic diff --git a/mii/aml_related/utils.py b/mii/aml_related/utils.py index f6e1520b..c2065010 100644 --- a/mii/aml_related/utils.py +++ b/mii/aml_related/utils.py @@ -11,14 +11,10 @@ def get_acr_name(): try: acr_name = subprocess.check_output( - ["az", - "ml", - "workspace", - "show", - "--query", - "container_registry"], - text=True) - return acr_name.strip().replace('"', '').rsplit('/', 1)[-1] + ["az", "ml", "workspace", "show", "--query", "container_registry"], + text=True, + ) + return acr_name.strip().replace('"', "").rsplit("/", 1)[-1] except subprocess.CalledProcessError as e: print("\n", "-" * 30, "\n") print("Unable to obtain ACR name from Azure-CLI. 
Please verify that you:") @@ -80,60 +76,49 @@ def generate_aml_scripts(acr_name, deployment_name, model_name, version): } # Docker files - write_out_script(os.path.join(output_dir, - "build", - "Dockerfile"), - fill_template(mii.aml_related.templates.dockerfile, - replace_dict)) - write_out_script(os.path.join(output_dir, - "build", - "gunicorn_app"), - fill_template(mii.aml_related.templates.gunicorn, - replace_dict)) - write_out_script(os.path.join(output_dir, - "build", - "runit", - "gunicorn", - "run"), - fill_template(mii.aml_related.templates.gunicorn_run, - replace_dict)) write_out_script( - os.path.join(output_dir, - "build", - "runit", - "gunicorn", - "finish"), - fill_template(mii.aml_related.templates.gunicorn_finish, - replace_dict)) - write_out_script(os.path.join(output_dir, - "build", - "requirements.txt"), - fill_template(mii.aml_related.templates.requirements, - replace_dict)) + os.path.join(output_dir, "build", "Dockerfile"), + fill_template(mii.aml_related.templates.dockerfile, replace_dict), + ) + write_out_script( + os.path.join(output_dir, "build", "gunicorn_app"), + fill_template(mii.aml_related.templates.gunicorn, replace_dict), + ) + write_out_script( + os.path.join(output_dir, "build", "runit", "gunicorn", "run"), + fill_template(mii.aml_related.templates.gunicorn_run, replace_dict), + ) + write_out_script( + os.path.join(output_dir, "build", "runit", "gunicorn", "finish"), + fill_template(mii.aml_related.templates.gunicorn_finish, replace_dict), + ) + write_out_script( + os.path.join(output_dir, "build", "requirements.txt"), + fill_template(mii.aml_related.templates.requirements, replace_dict), + ) # Model download script write_out_script( - os.path.join(output_dir, - "model_download.py"), - fill_template(mii.aml_related.templates.model_download, - replace_dict)) + os.path.join(output_dir, "model_download.py"), + fill_template(mii.aml_related.templates.model_download, replace_dict), + ) # Deployment script - write_out_script(os.path.join(output_dir, - "deploy.sh"), - fill_template(mii.aml_related.templates.deploy, - replace_dict)) + write_out_script( + os.path.join(output_dir, "deploy.sh"), + fill_template(mii.aml_related.templates.deploy, replace_dict), + ) # Yaml configs - write_out_yaml(os.path.join(output_dir, - "deployment.yml"), - fill_template(mii.aml_related.templates.deployment, - replace_dict)) - write_out_yaml(os.path.join(output_dir, - "endpoint.yml"), - fill_template(mii.aml_related.templates.endpoint, - replace_dict)) - write_out_yaml(os.path.join(output_dir, - "environment.yml"), - fill_template(mii.aml_related.templates.environment, - replace_dict)) + write_out_yaml( + os.path.join(output_dir, "deployment.yml"), + fill_template(mii.aml_related.templates.deployment, replace_dict), + ) + write_out_yaml( + os.path.join(output_dir, "endpoint.yml"), + fill_template(mii.aml_related.templates.endpoint, replace_dict), + ) + write_out_yaml( + os.path.join(output_dir, "environment.yml"), + fill_template(mii.aml_related.templates.environment, replace_dict), + ) diff --git a/mii/client.py b/mii/client.py index 535b55c8..b776df38 100644 --- a/mii/client.py +++ b/mii/client.py @@ -6,9 +6,8 @@ import grpc import requests import mii -from mii.utils import get_task from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc -from mii.constants import GRPC_MAX_MSG_SIZE, Tasks +from mii.constants import GRPC_MAX_MSG_SIZE, TaskType from mii.method_table import GRPC_METHOD_TABLE @@ -39,27 +38,30 @@ def mii_query_handle(deployment_name): 
inference_pipeline, task = mii.non_persistent_models[deployment_name] return MIINonPersistentClient(task, deployment_name) - task_name, mii_configs = _get_deployment_info(deployment_name) - return MIIClient(task_name, "localhost", mii_configs.port_number) + task, mii_configs = _get_deployment_info(deployment_name) + return MIIClient(task, "localhost", mii_configs.port_number) def create_channel(host, port): - return grpc.aio.insecure_channel(f'{host}:{port}', - options=[('grpc.max_send_message_length', - GRPC_MAX_MSG_SIZE), - ('grpc.max_receive_message_length', - GRPC_MAX_MSG_SIZE)]) + return grpc.aio.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE), + ], + ) -class MIIClient(): +class MIIClient: """ Client to send queries to a single endpoint. """ - def __init__(self, task_name, host, port): + + def __init__(self, task, host, port): self.asyncio_loop = asyncio.get_event_loop() channel = create_channel(host, port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - self.task = get_task(task_name) + self.task = task async def _request_async_response(self, request_dict, **query_kwargs): if self.task not in GRPC_METHOD_TABLE: @@ -72,42 +74,51 @@ async def _request_async_response(self, request_dict, **query_kwargs): def query(self, request_dict, **query_kwargs): return self.asyncio_loop.run_until_complete( - self._request_async_response(request_dict, - **query_kwargs)) + self._request_async_response(request_dict, **query_kwargs) + ) async def terminate_async(self): await self.stub.Terminate( - modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) + modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty() + ) def terminate(self): self.asyncio_loop.run_until_complete(self.terminate_async()) async def create_session_async(self, session_id): return await self.stub.CreateSession( - modelresponse_pb2.SessionID(session_id=session_id)) + modelresponse_pb2.SessionID(session_id=session_id) + ) def create_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'." + assert ( + self.task == TaskType.TEXT_GENERATION + ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'." return self.asyncio_loop.run_until_complete( - self.create_session_async(session_id)) + self.create_session_async(session_id) + ) async def destroy_session_async(self, session_id): - await self.stub.DestroySession(modelresponse_pb2.SessionID(session_id=session_id) - ) + await self.stub.DestroySession( + modelresponse_pb2.SessionID(session_id=session_id) + ) def destroy_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'." + assert ( + self.task == TaskType.TEXT_GENERATION + ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'." self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id)) -class MIITensorParallelClient(): +class MIITensorParallelClient: """ Client to send queries to multiple endpoints in parallel. This is used to call multiple servers deployed for tensor parallelism. 
""" - def __init__(self, task_name, host, ports): - self.task = get_task(task_name) - self.clients = [MIIClient(task_name, host, port) for port in ports] + + def __init__(self, task, host, ports): + self.task = task + self.clients = [MIIClient(task, host, port) for port in ports] self.asyncio_loop = asyncio.get_event_loop() # runs task in parallel and return the result from the first task @@ -116,8 +127,9 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs): for client in self.clients: responses.append( self.asyncio_loop.create_task( - client._request_async_response(request_string, - **query_kwargs))) + client._request_async_response(request_string, **query_kwargs) + ) + ) await responses[0] return responses[0] @@ -136,8 +148,8 @@ def query(self, request_dict, **query_kwargs): response: Response of the model """ response = self.asyncio_loop.run_until_complete( - self._query_in_tensor_parallel(request_dict, - query_kwargs)) + self._query_in_tensor_parallel(request_dict, query_kwargs) + ) ret = response.result() return ret @@ -155,30 +167,33 @@ def destroy_session(self, session_id): client.destroy_session(session_id) -class MIINonPersistentClient(): +class MIINonPersistentClient: def __init__(self, task, deployment_name): self.task = task self.deployment_name = deployment_name def query(self, request_dict, **query_kwargs): - assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found" + assert ( + self.deployment_name in mii.non_persistent_models + ), f"deployment: {self.deployment_name} not found" task_methods = GRPC_METHOD_TABLE[self.task] inference_pipeline = mii.non_persistent_models[self.deployment_name][0] - if self.task == Tasks.QUESTION_ANSWERING: - if 'question' not in request_dict or 'context' not in request_dict: + if self.task == TaskType.QUESTION_ANSWERING: + if "question" not in request_dict or "context" not in request_dict: raise Exception( - "Question Answering Task requires 'question' and 'context' keys") + "Question Answering Task requires 'question' and 'context' keys" + ) args = (request_dict["question"], request_dict["context"]) kwargs = query_kwargs - elif self.task == Tasks.CONVERSATIONAL: + elif self.task == TaskType.CONVERSATIONAL: conv = task_methods.create_conversation(request_dict, **query_kwargs) - args = (conv, ) + args = (conv,) kwargs = {} else: - args = (request_dict['query'], ) + args = (request_dict["query"],) kwargs = query_kwargs return task_methods.run_inference(inference_pipeline, args, query_kwargs) diff --git a/mii/config.py b/mii/config.py index f2a41f3e..28a21c8f 100644 --- a/mii/config.py +++ b/mii/config.py @@ -3,32 +3,19 @@ # DeepSpeed Team import torch -from typing import Union, List +import os +from typing import Union, List, Optional, Dict, Any from enum import Enum from pydantic import validator, root_validator +import mii +from mii.constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT from deepspeed.runtime.config_utils import DeepSpeedConfigModel -from deepspeed.runtime.config import DTypeEnum -from deepspeed.launcher.runner import DLTS_HOSTFILE +from deepspeed.inference.config import DtypeEnum +from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile -class DeploymentType(Enum): - LOCAL = "local" - AML = "aml" - NON_PERSISTENT = "non-persistent" - - -class TaskType(Enum): - TEXT_GENERATION = "text-generation" - TEXT_CLASSIFICATION = "text-classification" - QUESTION_ANSWERING = "question-answering" - FILL_MASK = "fill-mask" - TOKEN_CLASSIFICATION = 
"token-classification" - CONVERSATIONAL = "conversational" - TEXT2IMG = "text-to-image" - - -class ReplicaConfig(BaseModel): +class ReplicaConfig(DeepSpeedConfigModel): hostname: str = "" tensor_parallel_ports: List[int] = [] torch_dist_port: int = None @@ -55,29 +42,50 @@ class DeploymentConfig(DeepSpeedConfigModel): enable_deepspeed: bool = True enable_zero: bool = False ds_config: Dict[str, Any] = {} - model_path: Optional[str] = "" + model_path: str = "" replica_num: int = 1 replica_configs: List[ReplicaConfig] = [] - @validator('checkpoint_dict') - def checkpoint_dict_valid(cls, value): - if value is None: - return value - if value.get('base_dir', ''): + @validator("checkpoint_dict") + def checkpoint_dict_valid(cls, field_value, values): + if field_value is None: + return field_value + if field_value.get("base_dir", ""): raise ValueError( "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'" ) - for k in ['checkpoints', 'parallelization', 'version', 'type']: - if not value.get(k, ''): + for k in ["checkpoints", "parallelization", "version", "type"]: + if not field_value.get(k, ""): raise ValueError(f"Missing key={k} in checkpoint_dict") - return value + return vfield_alue @validator("deploy_rank", pre=True) def deploy_rank_to_list(cls, field_value, values): - if not isinstance(field_value, list): + if field_value and not isinstance(field_value, list): field_value = [field_value] return field_value + @root_validator + def zero_or_meta(cls, values): + if values.get("enable_zero"): + assert not values.get( + "meta_tensor" + ), "ZeRO-Inference does not support meta tensors." + return values + + @root_validator + def bloom_model_valid(cls, values): + if "bigscience/bloom" in values.get("model"): + dtype = values.get("dtype") + assert dtype in [ + DtypeEnum.int8, + DtypeEnum.fp16, + ], "Bloom models only support fp16/int8." + assert not values.get( + "enable_cuda_graph" + ), "Bloom models do not support CUDA Graph." + return values + @root_validator def deploy_rank_valid(cls, values): tensor_parallel = values.get("tensor_parallel") @@ -88,21 +96,32 @@ def deploy_rank_valid(cls, values): deploy_rank = list(range(tensor_parallel)) # number of ranks provided must be equal to TP size, DP is handled outside MII currently - assert tensor_parallel == len(deploy_rank), \ - f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}" + assert tensor_parallel == len( + deploy_rank + ), f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}" values["deploy_rank"] = deploy_rank return values @root_validator def set_model_path(cls, values): - if values.get("model_path") is None: - deployment_type = values.get("deployment_type") - if deployment_type == DeploymentType.LOCAL: - model_path = MII_MODEL_PATH_DEFAULT - if deployment_tpye == DeploymentType.AML: + model_path = values.get("model_path") + if not model_path: + if values.get("deployment_type") == DeploymentType.AML: model_path = "model" - values["model_path"] = model_path + else: + model_path = MII_MODEL_PATH_DEFAULT + aml_model_dir = os.environ.get("AZUREML_MODEL_DIR", None) + if aml_model_dir: + assert os.path.isabs( + aml_model_dir + ), "AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path." + assert not os.path.isabs( + model_path + ), f"model_path={model_path} must be relative to append w/ AML path." 
+ model_path = os.path.join(aml_model_dir, model_path) + + values["model_path"] = model_path return values @root_validator @@ -145,6 +164,7 @@ def deepspeed_or_zero(cls, values): class MIIConfig(DeepSpeedConfigModel): + deployment_config: DeploymentConfig = {} hf_auth_token: str = None port_number: int = 50050 enable_restful_api: bool = False @@ -154,21 +174,22 @@ class MIIConfig(DeepSpeedConfigModel): version: int = 1 @root_validator - def AML_name_valid(cls, fields): - if fields.get("deployment_type") == DeploymentType.AML: - allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + - string.digits + "-") + def AML_name_valid(cls, values): + if values.get("deployment_type") == DeploymentType.AML: + allowed_chars = set( + string.ascii_lowercase + string.ascii_uppercaes + string.digits + "-" + ) assert ( - set(fields.get("deployment_config").deployment_name) <= allowed_chars + set(values.get("deployment_config").deployment_name) <= allowed_chars ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." - return fields + return values @root_validator def generate_replica_configs(cls, values): replica_configs = values.get("deployment_config").replica_configs - num_replicas = values.get("deployment_config").num_replicas + replica_num = values.get("deployment_config").replica_num if replica_configs: - assert len(replica_confgs) == num_replicas + assert len(replica_configs) == replica_num return values hostfile = values.get("hostfile") @@ -183,48 +204,58 @@ def generate_replica_configs(cls, values): port_offset = 1 base_port = port_number + i * tensor_parallel + port_offset tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel)) - replica_torch_dist_port = torch_dist_port + i + replica_torch_dist_port = torch_dist_port + (100 * i) replica_configs.append( - ReplicaConfig(hostname=hostname, - tensor_parallel_ports=tensor_parallel_ports, - torch_dist_port=replica_torch_dist_port, - gpu_indices=gpu_indices)) + ReplicaConfig( + hostname=hostname, + tensor_parallel_ports=tensor_parallel_ports, + torch_dist_port=replica_torch_dist_port, + gpu_indices=gpu_indices, + ) + ) values.get("deployment_config").replica_configs = replica_configs return values -def _allocate_processes(hostfile_path, tensor_parallel, num_replicas): +def _allocate_processes(hostfile_path, tensor_parallel, replica_num): resource_pool = fetch_hostfile(hostfile_path) - assert resource_pool is not None and len( - resource_pool) > 0, f'No hosts found in {hostfile_path}' + assert ( + resource_pool is not None and len(resource_pool) > 0 + ), f"No hosts found in {hostfile_path}" replica_pool = [] allocated_num = 0 for host, slots in resource_pool.items(): available_on_host = slots while available_on_host >= tensor_parallel: - if allocated_num >= num_replicas: + if allocated_num >= replica_num: break if slots < tensor_parallel: raise ValueError( - f'Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required' + f"Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required" ) allocated_num_on_host = slots - available_on_host replica_pool.append( - (host, - [ - i for i in range(allocated_num_on_host, - allocated_num_on_host + tensor_parallel) - ])) + ( + host, + [ + i + for i in range( + allocated_num_on_host, + allocated_num_on_host + tensor_parallel, + ) + ], + ) + ) allocated_num += 1 available_on_host -= tensor_parallel - if allocated_num < num_replicas: + if allocated_num < replica_num: raise ValueError( - f'No sufficient GPUs for {num_replicas} replica(s), only 
{allocated_num} replica(s) can be deployed' + f"No sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed" ) return replica_pool diff --git a/mii/constants.py b/mii/constants.py index ba4cfa2f..55487701 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -2,104 +2,60 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -import enum +from enum import Enum -#TODO naming.. -class DeploymentType(enum.Enum): - LOCAL = 1 - AML = 2 - NON_PERSISTENT = 3 +class DeploymentType(str, Enum): + LOCAL = "local" + AML = "aml" + NON_PERSISTENT = "non-persistent" -MII_CONFIGS_KEY = 'mii_configs' +class TaskType(str, Enum): + TEXT_GENERATION = "text-generation" + TEXT_CLASSIFICATION = "text-classification" + QUESTION_ANSWERING = "question-answering" + FILL_MASK = "fill-mask" + TOKEN_CLASSIFICATION = "token-classification" + CONVERSATIONAL = "conversational" + TEXT2IMG = "text-to-image" -class Tasks(enum.Enum): - TEXT_GENERATION = 1 - TEXT_CLASSIFICATION = 2 - QUESTION_ANSWERING = 3 - FILL_MASK = 4 - TOKEN_CLASSIFICATION = 5 - CONVERSATIONAL = 6 - TEXT2IMG = 7 +class ModelProvider(str, Enum): + HUGGING_FACE = "hugging-face" + ELEUTHER_AI = "eleuther-ai" + DIFFUSERS = "diffusers" -TEXT_GENERATION_NAME = 'text-generation' -TEXT_CLASSIFICATION_NAME = 'text-classification' -QUESTION_ANSWERING_NAME = 'question-answering' -FILL_MASK_NAME = 'fill-mask' -TOKEN_CLASSIFICATION_NAME = 'token-classification' -CONVERSATIONAL_NAME = 'conversational' -TEXT2IMG_NAME = "text-to-image" - - -class ModelProvider(enum.Enum): - HUGGING_FACE = 1 - ELEUTHER_AI = 2 - DIFFUSERS = 3 - - -MODEL_PROVIDER_NAME_HF = "hugging-face" -MODEL_PROVIDER_NAME_EA = "eleuther-ai" -MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers" - -MODEL_PROVIDER_MAP = { - MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE, - MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI, - MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS -} - SUPPORTED_MODEL_TYPES = { - 'roberta': ModelProvider.HUGGING_FACE, - 'xlm-roberta': ModelProvider.HUGGING_FACE, - 'gpt2': ModelProvider.HUGGING_FACE, - 'bert': ModelProvider.HUGGING_FACE, - 'gpt_neo': ModelProvider.HUGGING_FACE, - 'gptj': ModelProvider.HUGGING_FACE, - 'opt': ModelProvider.HUGGING_FACE, - 'bloom': ModelProvider.HUGGING_FACE, - 'gpt-neox': ModelProvider.ELEUTHER_AI, - 'stable-diffusion': ModelProvider.DIFFUSERS, - 'llama': ModelProvider.HUGGING_FACE + "roberta": ModelProvider.HUGGING_FACE, + "xlm-roberta": ModelProvider.HUGGING_FACE, + "gpt2": ModelProvider.HUGGING_FACE, + "bert": ModelProvider.HUGGING_FACE, + "gpt_neo": ModelProvider.HUGGING_FACE, + "gptj": ModelProvider.HUGGING_FACE, + "opt": ModelProvider.HUGGING_FACE, + "bloom": ModelProvider.HUGGING_FACE, + "gpt-neox": ModelProvider.ELEUTHER_AI, + "stable-diffusion": ModelProvider.DIFFUSERS, + "llama": ModelProvider.HUGGING_FACE, } -SUPPORTED_TASKS = [ - TEXT_GENERATION_NAME, - TEXT_CLASSIFICATION_NAME, - QUESTION_ANSWERING_NAME, - FILL_MASK_NAME, - TOKEN_CLASSIFICATION_NAME, - CONVERSATIONAL_NAME, - TEXT2IMG_NAME -] - REQUIRED_KEYS_PER_TASK = { - TEXT_GENERATION_NAME: ["query"], - TEXT_CLASSIFICATION_NAME: ["query"], - QUESTION_ANSWERING_NAME: ["context", - "question"], - FILL_MASK_NAME: ["query"], - TOKEN_CLASSIFICATION_NAME: ["query"], - CONVERSATIONAL_NAME: - ['text', - 'conversation_id', - 'past_user_inputs', - 'generated_responses'], - TEXT2IMG_NAME: ["query"] + TaskType.TEXT_GENERATION: ["query"], + TaskType.TEXT_CLASSIFICATION: ["query"], + TaskType.QUESTION_ANSWERING: ["context", "question"], + 
TaskType.FILL_MASK: ["query"], + TaskType.TOKEN_CLASSIFICATION: ["query"], + TaskType.CONVERSATIONAL: [ + "text", + "conversation_id", + "past_user_inputs", + "generated_responses", + ], + TaskType.TEXT2IMG: ["query"], } -MODEL_NAME_KEY = 'model_name' -TASK_NAME_KEY = 'task_name' -DEPLOYMENT_NAME_KEY = 'deployment_name' -MODEL_PATH_KEY = 'model_path' -LOAD_BALANCER_CONFIG_KEY = 'load_balancer_config' - -ENABLE_DEEPSPEED_KEY = 'ds_optimize' -ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero' -DEEPSPEED_CONFIG_KEY = 'ds_config' -CHECKPOINT_KEY = "checkpoint" - MII_CACHE_PATH = "MII_CACHE_PATH" MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache" diff --git a/mii/deployment.py b/mii/deployment.py index 449b3801..d4e2bdc4 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -7,50 +7,46 @@ import os import mii -from deepspeed.launcher.runner import fetch_hostfile -from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT, MODEL_PROVIDER_MAP -from .utils import logger, get_task_name, get_provider_name +from .utils import logger from .models.score import create_score_file from .models import load_models -from .config import ReplicaConfig, LoadBalancerConfig - - -def support_legacy_api(task, - model, - deployment_type=DeploymentType.LOCAL, - model_path="", - enable_deepspeed=True, - enable_zero=False, - ds_config=None, - mii_config=None, - version=1): - deployment_tag = deployment_name - +from .config import MIIConfig, DeploymentType + + +def support_legacy_api( + task, + model, + deployment_type=DeploymentType.LOCAL, + model_path="", + enable_deepspeed=True, + enable_zero=False, + ds_config=None, + mii_config=None, + version=1, +): if ds_config is None: ds_config = {} if mii_config is None: mii_config = {} deployment_config = { - "deployment_name": deployment_name, "task": task, "model": model, "model_path": model_path, - "ds_optimize": enable_deepspeed, - "ds_zero": enable_zero, + "enable_deepspeed": enable_deepspeed, + "enable_zero": enable_zero, "ds_config": ds_config, } for key, val in mii_config.items(): if not hasattr(MIIConfig, key): deployment_config[key] = val - deployments = [deployment_config] mii_config = {k: v for k, v in mii_config.items() if hasattr(MIIConfig, k)} mii_config["version"] = version mii_config["deployment_type"] = deployment_type - return deployment_tag, deployments, mii_config + return deployment_config, mii_config def deploy(deployment_name, deployment_config=None, mii_config=None, *args, **kwargs): @@ -58,12 +54,15 @@ def deploy(deployment_name, deployment_config=None, mii_config=None, *args, **kw mii_config = {} if args or kwargs: - assert not deployment_config, "We do not support mixture of legacy and new API options, use latest API." + assert ( + not deployment_config + ), "We do not support mixture of legacy and new API options, use latest API." kwargs["mii_config"] = mii_config deployment_config, mii_config = support_legacy_api(*args, **kwargs) deployment_config["deployment_name"] = deployment_name - mii_config = mii.config.MIIConfig(**mii_config, deployment_config=deployment_config) + mii_config["deployment_config"] = deployment_config + mii_config = mii.config.MIIConfig(**mii_config) if mii_config.deployment_config.enable_deepspeed: logger.info( @@ -74,33 +73,24 @@ def deploy(deployment_name, deployment_config=None, mii_config=None, *args, **kw f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" ) - if deployment_type != DeploymentType.NON_PERSISTENT: + if mii_config.deployment_type != DeploymentType.NON_PERSISTENT: create_score_file(mii_config) - if deployment_type == DeploymentType.AML: + if mii_config.deployment_type == DeploymentType.AML: _deploy_aml(mii_config) - elif deployment_type == DeploymentType.LOCAL: + elif mii_config.deployment_type == DeploymentType.LOCAL: return _deploy_local(mii_config) - elif deployment_type == DeploymentType.NON_PERSISTENT: - assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.deployment_config.tensor_parallel, "World Size does not equal number of tensors. When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" + elif mii_config.deployment_type == DeploymentType.NON_PERSISTENT: + assert ( + int(os.getenv("WORLD_SIZE", "1")) + == mii_config.deployment_config.tensor_parallel + ), "World Size does not equal number of tensors. When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" deployment_name = mii_config.deployment_config.deployment_name - model = mii_config.deployment_config.model task = mii_config.deployment_config.task - model_path = mii_config.deployment_config.model_path - enable_deepspeed = mii_config.deployment_config.enable_deepspeed - enable_zero = mii_config.deployment_config.enable_zero - provider = MODEL_PROVIDER_MAP[get_provider_name(model, task)] - mii.non_persistent_models[deployment_name] = (load_models( + mii.non_persistent_models[deployment_name] = ( + load_models(deployment_config), task, - model, - model_path, - enable_deepspeed, - enable_zero, - provider, - mii_config), - task) - else: - raise Exception(f"Unknown deployment type: {deployment_type}") + ) def _deploy_local(mii_config): @@ -113,7 +103,8 @@ def _deploy_aml(mii_config): acr_name=acr_name, deployment_name=mii_config.deployment_config.deployment_name, model_name=mii_config.deployment_config.model, - version=mii_config.version) + version=mii_config.version, + ) print( f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}" ) diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 4a0a5d00..f6eab456 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -13,7 +13,15 @@ import threading import time -from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, Tasks +from mii.constants import ( + GRPC_MAX_MSG_SIZE, + CREATE_SESSION_METHOD, + DESTROY_SESSION_METHOD, + TERMINATE_METHOD, + LB_MAX_WORKER_THREADS, + SERVER_SHUTDOWN_TIMEOUT, + Tasks, +) from mii.method_table import GRPC_METHOD_TABLE from mii.client import create_channel from mii.utils import get_task, unpack_proto_query_kwargs @@ -23,6 +31,7 @@ class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): """ Base class to provide common features of an inference server """ + def __init__(self): self._stop_event = threading.Event() @@ -38,6 +47,7 @@ class ModelResponse(ServiceBase): """ Implementation class of an MII inference server """ + def __init__(self, inference_pipeline): super().__init__() self.inference_pipeline = inference_pipeline @@ -87,10 +97,11 @@ def _run_inference(self, method_name, request_proto): response = task_methods.run_inference(self.inference_pipeline, args, kwargs) end = time.time() - model_time = 
self._get_model_time(self.inference_pipeline.model, - sum_times=True) if hasattr( - self.inference_pipeline, - "model") else -1 + model_time = ( + self._get_model_time(self.inference_pipeline.model, sum_times=True) + if hasattr(self.inference_pipeline, "model") + else -1 + ) return task_methods.pack_response_to_proto(response, end - start, model_time) @@ -138,6 +149,7 @@ class ParallelStubInvoker: This class aims to call gRPC methods without conversions between proto and python object. TensorParallelClient can be used for invocation with the conversions. """ + def __init__(self, host, ports): # Assumption: target services are all on the same host self.stubs = [] @@ -158,23 +170,21 @@ async def _invoke_async(self, method_name, proto_request): def invoke(self, method_name, proto_request): # This is needed because gRPC calls from interceptor are launched from return asyncio.run_coroutine_threadsafe( - self._invoke_async(method_name, - proto_request), - self.asyncio_loop).result() + self._invoke_async(method_name, proto_request), self.asyncio_loop + ).result() class LoadBalancingInterceptor(grpc.ServerInterceptor): - def __init__(self, task_name, replica_configs): + def __init__(self, deployment_config): super().__init__() self.asyncio_loop = asyncio.get_event_loop() self.stubs = [ - ParallelStubInvoker(replica.hostname, - replica.tensor_parallel_ports) - for replica in replica_configs + ParallelStubInvoker(replica.hostname, replica.tensor_parallel_ports) + for replica in deployment_config.replica_configs ] self.counter = AtomicCounter() - self.task = get_task(task_name) + self.task = deployment_config.task self.replica_sessions = {} # Start the asyncio loop in a separate thread @@ -182,7 +192,7 @@ def run_asyncio_loop(loop): asyncio.set_event_loop(loop) loop.run_forever() - threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start() + threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop,)).start() def choose_stub(self, call_count): return self.stubs[call_count % len(self.stubs)] @@ -196,8 +206,9 @@ def invoke_intercept_method(request_proto, context): if method_name == TERMINATE_METHOD: for stub in self.stubs: - stub.invoke(TERMINATE_METHOD, - google_dot_protobuf_dot_empty__pb2.Empty()) + stub.invoke( + TERMINATE_METHOD, google_dot_protobuf_dot_empty__pb2.Empty() + ) self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop) return next_handler.unary_unary(request_proto, context) @@ -207,7 +218,8 @@ def invoke_intercept_method(request_proto, context): if method_name == CREATE_SESSION_METHOD: if request_proto.session_id in self.sessions: raise ValueError( - f"session {request_proto.session_id} already exists") + f"session {request_proto.session_id} already exists" + ) self.replica_sessions[request_proto.session_id] = replica_index self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() @@ -230,19 +242,22 @@ def invoke_intercept_method(request_proto, context): return grpc.unary_unary_rpc_method_handler( invoke_intercept_method, request_deserializer=next_handler.request_deserializer, - response_serializer=next_handler.response_serializer) + response_serializer=next_handler.response_serializer, + ) def _do_serve(service_impl, port, interceptors=[]): stop_event = service_impl.get_stop_event() - server = grpc.server(futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), - interceptors=interceptors, - options=[('grpc.max_send_message_length', - GRPC_MAX_MSG_SIZE), - 
('grpc.max_receive_message_length', - GRPC_MAX_MSG_SIZE)]) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), + interceptors=interceptors, + options=[ + ("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE), + ], + ) modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server) - server.add_insecure_port(f'[::]:{port}') + server.add_insecure_port(f"[::]:{port}") print(f"About to start server") server.start() print(f"Started") @@ -254,14 +269,11 @@ def serve_inference(inference_pipeline, port): _do_serve(ModelResponse(inference_pipeline), port) -def serve_load_balancing(task_name, lb_config): - _do_serve(ServiceBase(), - lb_config.port, - [LoadBalancingInterceptor(task_name, - lb_config.replica_configs)]) +def serve_load_balancing(deployment_config, lb_port): + _do_serve(ServiceBase(), lb_port, [LoadBalancingInterceptor(deployment_config)]) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() print(sys.argv[1]) serve_inference(None, sys.argv[1]) diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 76b1f994..677ecd0d 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -10,6 +10,7 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,22 +18,22 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 
\x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 
\x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply"\x00\x12]\n\x13\x43lassificationReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply"\x00\x12W\n\rFillMaskReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12\x62\n\x18TokenClassificationReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12]\n\x13\x43onversationalReply\x12".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply"\x00\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "modelresponse_pb2", globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" _VALUE._serialized_start = 67 _VALUE._serialized_end = 162 _SESSIONID._serialized_start = 164 diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 95cfa825..af1b1898 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ 
b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -13,6 +13,7 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): """Constructor. @@ -20,53 +21,52 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Terminate = channel.unary_unary( - '/modelresponse.ModelResponse/Terminate', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, + "/modelresponse.ModelResponse/Terminate", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', + "/modelresponse.ModelResponse/CreateSession", request_serializer=modelresponse__pb2.SessionID.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', + "/modelresponse.ModelResponse/DestroySession", request_serializer=modelresponse__pb2.SessionID.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.GeneratorReply = channel.unary_unary( - '/modelresponse.ModelResponse/GeneratorReply', + "/modelresponse.ModelResponse/GeneratorReply", request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.MultiStringReply.FromString, ) self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', + "/modelresponse.ModelResponse/ClassificationReply", request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', + "/modelresponse.ModelResponse/QuestionAndAnswerReply", request_serializer=modelresponse__pb2.QARequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', + "/modelresponse.ModelResponse/FillMaskReply", request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', + "/modelresponse.ModelResponse/TokenClassificationReply", request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', + "/modelresponse.ModelResponse/ConversationalReply", request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, response_deserializer=modelresponse__pb2.ConversationReply.FromString, ) self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', + "/modelresponse.ModelResponse/Txt2ImgReply", request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.ImageReply.FromString, ) @@ -74,156 +74,148 @@ def __init__(self, channel): class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" + def Terminate(self, request, context): """Missing associated documentation comment 
in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def CreateSession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def DestroySession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def GeneratorReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def ClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def QuestionAndAnswerReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def FillMaskReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def TokenClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def ConversationalReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def Txt2ImgReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not 
implemented!") def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - 'Terminate': - grpc.unary_unary_rpc_method_handler( + "Terminate": grpc.unary_unary_rpc_method_handler( servicer.Terminate, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), - 'CreateSession': - grpc.unary_unary_rpc_method_handler( + "CreateSession": grpc.unary_unary_rpc_method_handler( servicer.CreateSession, request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), - 'DestroySession': - grpc.unary_unary_rpc_method_handler( + "DestroySession": grpc.unary_unary_rpc_method_handler( servicer.DestroySession, request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), - 'GeneratorReply': - grpc.unary_unary_rpc_method_handler( + "GeneratorReply": grpc.unary_unary_rpc_method_handler( servicer.GeneratorReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, ), - 'ClassificationReply': - grpc.unary_unary_rpc_method_handler( + "ClassificationReply": grpc.unary_unary_rpc_method_handler( servicer.ClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - 'QuestionAndAnswerReply': - grpc.unary_unary_rpc_method_handler( + "QuestionAndAnswerReply": grpc.unary_unary_rpc_method_handler( servicer.QuestionAndAnswerReply, request_deserializer=modelresponse__pb2.QARequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - 'FillMaskReply': - grpc.unary_unary_rpc_method_handler( + "FillMaskReply": grpc.unary_unary_rpc_method_handler( servicer.FillMaskReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - 'TokenClassificationReply': - grpc.unary_unary_rpc_method_handler( + "TokenClassificationReply": grpc.unary_unary_rpc_method_handler( servicer.TokenClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - 'ConversationalReply': - grpc.unary_unary_rpc_method_handler( + "ConversationalReply": grpc.unary_unary_rpc_method_handler( servicer.ConversationalReply, request_deserializer=modelresponse__pb2.ConversationRequest.FromString, response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, ), - 'Txt2ImgReply': - grpc.unary_unary_rpc_method_handler( + "Txt2ImgReply": grpc.unary_unary_rpc_method_handler( servicer.Txt2ImgReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.ImageReply.SerializeToString, ), } - generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', - rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler, )) + generic_handler = grpc.method_handlers_generic_handler( + 
"modelresponse.ModelResponse", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) # This class is part of an EXPERIMENTAL API. class ModelResponse(object): """Missing associated documentation comment in .proto file.""" + @staticmethod - def Terminate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Terminate( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/Terminate', + "/modelresponse.ModelResponse/Terminate", google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -233,23 +225,26 @@ def Terminate(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def CreateSession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def CreateSession( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/CreateSession', + "/modelresponse.ModelResponse/CreateSession", modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -259,23 +254,26 @@ def CreateSession(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def DestroySession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def DestroySession( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/DestroySession', + "/modelresponse.ModelResponse/DestroySession", modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -285,23 +283,26 @@ def DestroySession(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def GeneratorReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def GeneratorReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/GeneratorReply', + "/modelresponse.ModelResponse/GeneratorReply", modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.MultiStringReply.FromString, options, @@ -311,23 +312,26 @@ def GeneratorReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def ClassificationReply(request, - target, - options=(), - channel_credentials=None, - 
call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def ClassificationReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/ClassificationReply', + "/modelresponse.ModelResponse/ClassificationReply", modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -337,23 +341,26 @@ def ClassificationReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def QuestionAndAnswerReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def QuestionAndAnswerReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/QuestionAndAnswerReply', + "/modelresponse.ModelResponse/QuestionAndAnswerReply", modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -363,23 +370,26 @@ def QuestionAndAnswerReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def FillMaskReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def FillMaskReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/FillMaskReply', + "/modelresponse.ModelResponse/FillMaskReply", modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -389,23 +399,26 @@ def FillMaskReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def TokenClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def TokenClassificationReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/TokenClassificationReply', + "/modelresponse.ModelResponse/TokenClassificationReply", modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -415,23 +428,26 @@ def TokenClassificationReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def ConversationalReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def ConversationalReply( + request, + target, + options=(), + channel_credentials=None, + 
call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/ConversationalReply', + "/modelresponse.ModelResponse/ConversationalReply", modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, options, @@ -441,23 +457,26 @@ def ConversationalReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) @staticmethod - def Txt2ImgReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Txt2ImgReply( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/modelresponse.ModelResponse/Txt2ImgReply', + "/modelresponse.ModelResponse/Txt2ImgReply", modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, options, @@ -467,4 +486,5 @@ def Txt2ImgReply(request, compression, wait_for_ready, timeout, - metadata) + metadata, + ) diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index e8cfa934..f3eeaa46 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -17,9 +17,9 @@ def shutdown(thread): thread.server.shutdown() -def createRestfulGatewayApp(deployment_name, task, mii_config, server_thread): +def createRestfulGatewayApp(deployment_config, lb_port, server_thread): # client must be thread-safe - client = mii.MIIClient(task, "localhost", mii_config.port_number) + client = mii.MIIClient(task, "localhost", lb_port) class RestfulGatewayService(Resource): def __init__(self): @@ -33,26 +33,26 @@ def post(self): app = Flask("RestfulGateway") - @app.route("/terminate", methods=['GET']) + @app.route("/terminate", methods=["GET"]) def terminate(): # Need to shutdown *after* completing the request - threading.Thread(target=shutdown, args=(server_thread, )).start() + threading.Thread(target=shutdown, args=(server_thread,)).start() return "Shutting down RESTful API gateway server" api = Api(app) - path = '/{}/{}'.format(RESTFUL_API_PATH, deployment_name) + path = "/{}/{}".format(RESTFUL_API_PATH, deployment_config.deployment_name) api.add_resource(RestfulGatewayService, path) return app class RestfulGatewayThread(threading.Thread): - def __init__(self, deployment_name, task, mii_config): + def __init__(self, deployment_config, lb_port, rest_port): threading.Thread.__init__(self) self.mii_config = mii_config - app = createRestfulGatewayApp(deployment_name, task, mii_config, self) - self.server = make_server('127.0.0.1', mii_config.restful_api_port, app) + app = createRestfulGatewayApp(deployment_config, lb_port, self) + self.server = make_server("127.0.0.1", rest_port, app) self.ctx = app.app_context() self.ctx.push() diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 27878725..708ec80a 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -8,95 +8,88 @@ import base64 import json -from mii import MIIConfig, LoadBalancerConfig +from mii import DeploymentConfig from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing from 
mii.grpc_related.restful_gateway import RestfulGatewayThread -def decode_config_from_str(config_str): +def b64_encoded_config(config_str): # str -> bytes b64_bytes = config_str.encode() # decode b64 bytes -> json bytes config_bytes = base64.urlsafe_b64decode(b64_bytes) # convert json bytes -> str -> dict - return json.loads(config_bytes.decode()) + config_dict = json.loads(config_bytes.decode()) + # return mii.DeploymentConfig object + return DeploymentConfig(**config_dict) def main(): parser = argparse.ArgumentParser() - parser.add_argument("-n", "--deployment-name", type=str, help="deployment name") - parser.add_argument("-t", "--task-name", type=str, help="task name") - parser.add_argument("-m", "--model", type=str, help="model name") - parser.add_argument("-d", "--model-path", type=str, help="path to model") - parser.add_argument('-b', '--provider', type=str, help="model provider") - parser.add_argument("-o", - "--ds-optimize", - action='store_true', - help="Enable DeepSpeed") - parser.add_argument("-z", - "--ds-zero", - action='store_true', - help="Enable DeepSpeed ZeRO") - parser.add_argument("--ds-config", type=str, help="path to DeepSpeed ZeRO config") - parser.add_argument( - "-p", - "--port", + "--deployment-config", + type=b64_encoded_config, + help="base64 encoded deployment config", + ) + parser.add_argument( + "--server-port", type=int, - help="base server port, each rank will have unique port based on this value") - parser.add_argument("-c", "--config", type=str, help="base64 encoded mii config") - parser.add_argument("--load-balancer", - type=str, - default=None, - help="base64 encoded load balancer config") - parser.add_argument("-r", - "--restful-gateway", - action='store_true', - help="launch restful api gateway") - + default=0, + help="Port to user for DeepSpeed inference server.", + ) + parser.add_argument( + "--load-balancer", action="store_true", help="Launch load balancer process." + ) + parser.add_argument( + "--load-balancer-port", + type=int, + default=0, + help="Port to use for load balancer.", + ) + parser.add_argument( + "--restful-gateway", + action="store_true", + help="Launches restful gateway process.", + ) + parser.add_argument( + "--restful-gateway-port", + type=int, + default=0, + help="Port to use for restful gateway.", + ) args = parser.parse_args() - - # de-serialize config object - config_dict = decode_config_from_str(args.config) - # convert dict -> mii config - mii_config = MIIConfig(**config_dict) + assert not ( + args.load_balancer and args.restful_gateway + ), "Select only load-balancer OR restful-gateway." if args.restful_gateway: - print(f"Starting RESTful API gateway on port: {mii_config.restful_api_port}") - gateway_thread = RestfulGatewayThread(args.deployment_name, - args.task_name, - mii_config) + assert args.restful_gateway_port, "--restful-gateway-port must be provided." + print(f"Starting RESTful API gateway on port: {args.restful_gateway_port}") + gateway_thread = RestfulGatewayThread( + deployment_config, + lb_port=args.load_balancer_port, + rest_port=args.restful_gateway_port, + ) stop_event = gateway_thread.get_stop_event() gateway_thread.start() stop_event.wait() - elif args.load_balancer is None: - provider = mii.constants.MODEL_PROVIDER_MAP.get(args.provider, None) - assert provider is not None, f"Unknown model provider: {args.provider}" + elif args.load_balancer: + assert args.load_balancer_port, "--load-balancer-port must be provided." 
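# Illustrative sketch (an assumption for clarity, not part of this patch): the
# refactored entry point takes a single base64-encoded DeploymentConfig instead
# of many per-field flags, so a parent process is expected to serialize the
# config and pass it on the command line. The helper name, the module-style
# invocation, and the pydantic-style .dict() call below are assumptions, not
# MII API.
import base64
import json
import subprocess

def launch_load_balancer_process(deployment_config, lb_port):
    # Mirror of b64_encoded_config(): dict -> JSON -> utf-8 bytes -> urlsafe
    # base64 -> str. Fields that are not JSON-serializable (e.g. a torch dtype)
    # may need custom encoding before this step.
    config_json = json.dumps(deployment_config.dict())
    b64_config = base64.urlsafe_b64encode(config_json.encode()).decode()
    cmd = [
        "python", "-m", "mii.launch.multi_gpu_server",
        "--deployment-config", b64_config,
        "--load-balancer",
        "--load-balancer-port", str(lb_port),
    ]
    return subprocess.Popen(cmd)
# Passing one opaque blob keeps this CLI stable as config fields change; only
# b64_encoded_config() and DeploymentConfig need to agree on the schema.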
+ print(f"Starting load balancer on port: {lb_config.port}") + serve_load_balancing(deployment_config, args.load_balancer_port) - assert args.port is not None, "port is required for inference server" - local_rank = int(os.getenv('LOCAL_RANK', '0')) - port = args.port + local_rank + else: + assert args.server_port, "--server-port must be provided." + local_rank = int(os.getenv("LOCAL_RANK", "0")) + port = args.server_port + local_rank - inference_pipeline = load_models(task_name=args.task_name, - model_name=args.model, - model_path=args.model_path, - ds_optimize=args.ds_optimize, - ds_zero=args.ds_zero, - ds_config_path=args.ds_config, - provider=provider, - mii_config=mii_config) + inference_pipeline = load_models(deployment_config) print(f"Starting server on port: {port}") serve_inference(inference_pipeline, port) - else: - lb_config_dict = decode_config_from_str(args.load_balancer) - lb_config = LoadBalancerConfig(**lb_config_dict) - - print(f"Starting load balancer on port: {lb_config.port}") - serve_load_balancing(args.task_name, lb_config) if __name__ == "__main__": diff --git a/mii/method_table.py b/mii/method_table.py index c412f446..a5457f24 100644 --- a/mii/method_table.py +++ b/mii/method_table.py @@ -4,7 +4,7 @@ # DeepSpeed Team from abc import ABC, abstractmethod from transformers import Conversation -from mii.constants import Tasks +from mii.constants import TaskType from mii.grpc_related.proto import modelresponse_pb2 from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs from mii.models.utils import ImageResponse @@ -12,31 +12,33 @@ def single_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.SingleStringRequest( - request=request_dict['query'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + request=request_dict["query"], query_kwargs=kwarg_dict_to_proto(query_kwargs) + ) def single_string_response_to_proto(self, response, time_taken, model_time_taken): - return modelresponse_pb2.SingleStringReply(response=f"{response}", - time_taken=time_taken, - model_time_taken=model_time_taken) + return modelresponse_pb2.SingleStringReply( + response=f"{response}", time_taken=time_taken, model_time_taken=model_time_taken + ) def multi_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.MultiStringRequest( - request=request_dict['query'] if isinstance(request_dict['query'], - list) else [request_dict['query']], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + request=request_dict["query"] + if isinstance(request_dict["query"], list) + else [request_dict["query"]], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) def proto_request_to_single_input(self, request): - args = (request.request, ) + args = (request.request,) kwargs = unpack_proto_query_kwargs(request.query_kwargs) return args, kwargs def proto_request_to_list(self, request): - args = ([r for r in request.request], ) + args = ([r for r in request.request],) kwargs = unpack_proto_query_kwargs(request.query_kwargs) return args, kwargs @@ -91,7 +93,7 @@ def preprocess_session(self, session_id, args): if len(args[0]) != 1: raise ValueError(f"You can pass only one prompt with a session_id") - args = ([self.session_context[session_id] + args[0][0]], ) + args = ([self.session_context[session_id] + args[0][0]],) return args def run_inference(self, inference_pipeline, args, kwargs): @@ -108,18 +110,20 @@ def run_inference(self, inference_pipeline, args, kwargs): def postprocess_session(self, session_id, args, response): generated_text = 
response[0][0]["generated_text"] self.session_context[session_id] = generated_text - response[0][0]["generated_text"] = generated_text[len(args[0][0]):] + response[0][0]["generated_text"] = generated_text[len(args[0][0]) :] return response def pack_response_to_proto(self, response, time_taken, model_time_taken): text_responses = [] for response in response: - text = response[0]['generated_text'] + text = response[0]["generated_text"] text_responses.append(text) - return modelresponse_pb2.MultiStringReply(response=text_responses, - time_taken=time_taken, - model_time_taken=model_time_taken) + return modelresponse_pb2.MultiStringReply( + response=text_responses, + time_taken=time_taken, + model_time_taken=model_time_taken, + ) class TextClassificationMethods(TaskMethods): @@ -141,9 +145,10 @@ def method(self): def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.QARequest( - question=request_dict['question'], - context=request_dict['context'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + question=request_dict["question"], + context=request_dict["context"], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) @@ -180,24 +185,31 @@ def method(self): def create_conversation(self, request, **kwargs): if isinstance(request, dict): - assert 'text' in request and 'past_user_inputs' in request and 'generated_responses' in request, "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys" - text = request['text'] - conversation_id = request[ - 'conversation_id'] if 'conversation_id' in request else None - past_user_inputs = request['past_user_inputs'] - generated_responses = request['generated_responses'] + assert ( + "text" in request + and "past_user_inputs" in request + and "generated_responses" in request + ), "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys" + text = request["text"] + conversation_id = ( + request["conversation_id"] if "conversation_id" in request else None + ) + past_user_inputs = request["past_user_inputs"] + generated_responses = request["generated_responses"] else: - text = getattr(request, 'text') - conversation_id = getattr(request, 'conversation_id') - past_user_inputs = getattr(request, 'past_user_inputs') - generated_responses = getattr(request, 'generated_responses') - - conv = Conversation(text=text, - conversation_id=conversation_id, - past_user_inputs=past_user_inputs, - generated_responses=generated_responses, - **kwargs) + text = getattr(request, "text") + conversation_id = getattr(request, "conversation_id") + past_user_inputs = getattr(request, "past_user_inputs") + generated_responses = getattr(request, "generated_responses") + + conv = Conversation( + text=text, + conversation_id=conversation_id, + past_user_inputs=past_user_inputs, + generated_responses=generated_responses, + **kwargs, + ) return conv def pack_response_to_proto(self, conv, time_taken, model_time_taken): @@ -206,23 +218,26 @@ def pack_response_to_proto(self, conv, time_taken, model_time_taken): past_user_inputs=conv.past_user_inputs, generated_responses=conv.generated_responses, time_taken=time_taken, - model_time_taken=model_time_taken) + model_time_taken=model_time_taken, + ) def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) conv = self.create_conversation(request, **kwargs) - args = (conv, ) + args = (conv,) kwargs = {} return args, 
kwargs def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.ConversationRequest( - text=request_dict['text'], - conversation_id=request_dict['conversation_id'] - if 'conversation_id' in request_dict else None, - past_user_inputs=request_dict['past_user_inputs'], - generated_responses=request_dict['generated_responses'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + text=request_dict["text"], + conversation_id=request_dict["conversation_id"] + if "conversation_id" in request_dict + else None, + past_user_inputs=request_dict["past_user_inputs"], + generated_responses=request_dict["generated_responses"], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) class Text2ImgMethods(TaskMethods): @@ -245,23 +260,25 @@ def pack_response_to_proto(self, response, time_taken, model_time_taken): img_mode = response.images[0].mode img_size_w, img_size_h = response.images[0].size - return modelresponse_pb2.ImageReply(images=images_bytes, - nsfw_content_detected=nsfw_content_detected, - mode=img_mode, - size_w=img_size_w, - size_h=img_size_h, - time_taken=time_taken) + return modelresponse_pb2.ImageReply( + images=images_bytes, + nsfw_content_detected=nsfw_content_detected, + mode=img_mode, + size_w=img_size_w, + size_h=img_size_h, + time_taken=time_taken, + ) def unpack_response_from_proto(self, response): return ImageResponse(response) GRPC_METHOD_TABLE = { - Tasks.TEXT_GENERATION: TextGenerationMethods(), - Tasks.TEXT_CLASSIFICATION: TextClassificationMethods(), - Tasks.QUESTION_ANSWERING: QuestionAnsweringMethods(), - Tasks.FILL_MASK: FillMaskMethods(), - Tasks.TOKEN_CLASSIFICATION: TokenClassificationMethods(), - Tasks.CONVERSATIONAL: ConversationalMethods(), - Tasks.TEXT2IMG: Text2ImgMethods(), + TaskType.TEXT_GENERATION: TextGenerationMethods(), + TaskType.TEXT_CLASSIFICATION: TextClassificationMethods(), + TaskType.QUESTION_ANSWERING: QuestionAnsweringMethods(), + TaskType.FILL_MASK: FillMaskMethods(), + TaskType.TOKEN_CLASSIFICATION: TokenClassificationMethods(), + TaskType.CONVERSATIONAL: ConversationalMethods(), + TaskType.TEXT2IMG: Text2ImgMethods(), } diff --git a/mii/models/load_models.py b/mii/models/load_models.py index a7d5e861..98b10df3 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -10,49 +10,45 @@ import deepspeed from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.runtime.zero.config import ZeroStageEnum +from mii.utils import get_provider -def load_models(task_name, - model_name, - model_path, - ds_optimize, - ds_zero, - provider, - mii_config, - ds_config_path=None): - global generator - local_rank = int(os.getenv('LOCAL_RANK', '0')) - world_size = int(os.getenv('WORLD_SIZE', '1')) +def load_models(deployment_config): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) inf_config = { - "tensor_parallel": { - "tp_size": world_size, - "mpu": None - }, - "dtype": mii_config.dtype, + "tensor_parallel": {"tp_size": deployment_config.tensor_parallel, "mpu": None}, + "dtype": deployment_config.dtype, "replace_method": "auto", - "enable_cuda_graph": mii_config.enable_cuda_graph, + "enable_cuda_graph": deployment_config.enable_cuda_graph, "checkpoint": None, "config": None, "training_mp_size": 1, - "replace_with_kernel_inject": mii_config.replace_with_kernel_inject, - "max_tokens": mii_config.max_tokens + "replace_with_kernel_inject": deployment_config.replace_with_kernel_inject, + "max_tokens": deployment_config.max_tokens, } + provider = 
get_provider(deployment_config.model, deployment_config.task) if provider == mii.constants.ModelProvider.HUGGING_FACE: from mii.models.providers.huggingface import hf_provider - if "bigscience/bloom" in model_name: - assert mii_config.dtype == torch.half or mii_config.dtype == torch.int8, "Bloom models only support fp16/int8" - assert mii_config.enable_cuda_graph == False, "Bloom models do no support Cuda Graphs" - inference_pipeline = hf_provider(model_path, model_name, task_name, mii_config) - if mii_config.meta_tensor: + + inference_pipeline = hf_provider( + model_path, model_name, task_name, deployment_config + ) + if deployment_config.meta_tensor: inf_config["checkpoint"] = inference_pipeline.checkpoint_dict - if mii_config.dtype == torch.int8: + if deployment_config.dtype == torch.int8: # Support for older DeepSpeed versions - if "enable_qkv_quantization" in inspect.signature( - deepspeed.init_inference).parameters: + if ( + "enable_qkv_quantization" + in inspect.signature(deepspeed.init_inference).parameters + ): inf_config["enable_qkv_quantization"] = True elif provider == mii.constants.ModelProvider.ELEUTHER_AI: + assert False, "Eleuther AI support is currently disabled." + # TODO: Re-enable EleutherAI model support + """ from mii.models.providers.eleutherai import eleutherai_provider assert mii_config.dtype == torch.half, "gpt-neox only support fp16" assert mii_config.enable_cuda_graph == False, "Provider EleutherAI not supported with Cuda Graphs" @@ -64,44 +60,45 @@ def load_models(task_name, mii_config) inf_config["training_mp_size"] = 2 inf_config["config"] = inference_pipeline.neox_args + """ elif provider == mii.constants.ModelProvider.DIFFUSERS: from mii.models.providers.diffusers import diffusers_provider - inference_pipeline = diffusers_provider(model_path, - model_name, - task_name, - mii_config) - inf_config["replace_with_kernel_inject"] = False #not supported yet + + inference_pipeline = diffusers_provider( + model_path, model_name, task_name, deployment_config + ) inf_config["enable_cuda_graph"] = True else: raise ValueError(f"Unknown model provider {provider}") + """ print( f"> --------- MII Settings: ds_optimize={ds_optimize}, replace_with_kernel_inject={mii_config.replace_with_kernel_inject}, enable_cuda_graph={mii_config.enable_cuda_graph} " ) - if ds_optimize: - engine = deepspeed.init_inference(getattr(inference_pipeline, - "model", - inference_pipeline), - config=inf_config) - if mii_config.profile_model_time: + """ + if deployment_config.ds_optimize: + engine = deepspeed.init_inference( + getattr(inference_pipeline, "model", inference_pipeline), config=inf_config + ) + if deployment_config.profile_model_time: engine.profile_model_time() if hasattr(inference_pipeline, "model"): inference_pipeline.model = engine elif ds_zero: - assert not mii_config.meta_tensor, "ZeRO-Inference does not support meta tensors" - ds_config = DeepSpeedConfig(ds_config_path) - #TODO: don't read ds-config from disk, we should pass this around as a dict instead - ds_config_dict = json.load(open(ds_config_path, 'r')) - assert ds_config.zero_optimization_stage == ZeroStageEnum.weights, "DeepSpeed ZeRO inference is only supported for ZeRO-3" + ds_config = DeepSpeedConfig(deployment_config.ds_config) + assert ( + ds_config.zero_optimization_stage == ZeroStageEnum.weights + ), "DeepSpeed ZeRO inference is only supported for ZeRO-3" # initialise Deepspeed ZeRO and store only the engine object - ds_engine = deepspeed.initialize(model=inference_pipeline.model, - 
config_params=ds_config_dict)[0] + ds_engine = deepspeed.initialize( + model=inference_pipeline.model, config_params=ds_config + )[0] ds_engine.module.eval() # inference inference_pipeline.model = ds_engine.module - if mii_config.load_with_sys_mem: + if deployment_config.load_with_sys_mem: inference_pipeline.device = torch.device(f"cuda:{local_rank}") return inference_pipeline diff --git a/mii/models/providers/diffusers.py b/mii/models/providers/diffusers.py index 82768db0..65713fbf 100644 --- a/mii/models/providers/diffusers.py +++ b/mii/models/providers/diffusers.py @@ -6,18 +6,19 @@ import torch -def diffusers_provider(model_path, model_name, task_name, mii_config): +def diffusers_provider(model_path, model_name, task_name, deployment_config): from diffusers import DiffusionPipeline - local_rank = int(os.getenv('LOCAL_RANK', '0')) + + local_rank = int(os.getenv("LOCAL_RANK", "0")) kwargs = {} - if mii_config.dtype == torch.half: + if deployment_config.dtype == torch.half: kwargs["torch_dtype"] = torch.float16 kwargs["revision"] = "fp16" - pipeline = DiffusionPipeline.from_pretrained(model_name, - use_auth_token=mii_config.hf_auth_token, - **kwargs) + pipeline = DiffusionPipeline.from_pretrained( + model_name, use_auth_token=deployment_config.hf_auth_token, **kwargs + ) pipeline = pipeline.to(f"cuda:{local_rank}") pipeline.set_progress_bar_config(disable=True) return pipeline diff --git a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py index 27f456aa..6b3cf64c 100644 --- a/mii/models/providers/huggingface.py +++ b/mii/models/providers/huggingface.py @@ -18,9 +18,11 @@ try: from transformers.utils import cached_path, hf_bucket_url + USE_NEW_HF_CACHE = False except ImportError: from huggingface_hub import snapshot_download + USE_NEW_HF_CACHE = True @@ -28,6 +30,7 @@ class MetaTensorPipeline(object): """ Class for loading HuggingFace models using meta tensors """ + def __init__(self, model, tokenizer, checkpoint_dict): self.model = model self.tokenizer = tokenizer @@ -41,9 +44,9 @@ def __call__(self, inputs, **kwargs): # expand proto list into py-list inputs = [i for i in inputs] - tokens = self.tokenizer.batch_encode_plus(inputs, - return_tensors="pt", - padding=True) + tokens = self.tokenizer.batch_encode_plus( + inputs, return_tensors="pt", padding=True + ) for t in tokens: if torch.is_tensor(tokens[t]): tokens[t] = tokens[t].to(device) @@ -54,7 +57,7 @@ def __call__(self, inputs, **kwargs): # construct output to align w. 
HF pipeline output_dicts = [] for output in outputs: - output_dicts.append([{'generated_text': output}]) + output_dicts.append([{"generated_text": output}]) return output_dicts @@ -63,7 +66,7 @@ def get_device(load_with_sys_mem=False): if load_with_sys_mem: device = torch.device("cpu") else: - local_rank = int(os.getenv('LOCAL_RANK', '0')) + local_rank = int(os.getenv("LOCAL_RANK", "0")) device = torch.device(f"cuda:{local_rank}") return device @@ -72,7 +75,7 @@ def _attempt_load(load_fn, model_name, cache_path, kwargs={}): try: value = load_fn(model_name, **kwargs) except OSError: - print(f'Attempted load but failed, retrying using cache_dir={cache_path}') + print(f"Attempted load but failed, retrying using cache_dir={cache_path}") value = load_fn(model_name, cache_dir=cache_path, **kwargs) return value @@ -84,9 +87,9 @@ def get_checkpoint_files(pretrained_model_name_or_path): local_files_only = False filename = WEIGHTS_NAME - archive_file = hf_bucket_url(pretrained_model_name_or_path, - filename=filename, - revision=revision) + archive_file = hf_bucket_url( + pretrained_model_name_or_path, filename=filename, revision=revision + ) try: resolved_archive_file = cached_path( @@ -117,25 +120,27 @@ def get_checkpoint_files(pretrained_model_name_or_path): pretrained_model_name_or_path, resolved_archive_file, cache_dir=cache_dir, - revision=revision + revision=revision, ) return resolved_archive_file -def create_checkpoint_dict(model_name, model_path, mii_config): +def create_checkpoint_dict(model_name, model_path, deployment_config): if USE_NEW_HF_CACHE: - model_path = snapshot_download(model_name, - cache_dir=model_path, - allow_patterns=[ - "*.bin", - "*.json", - "*.pt", - ], - revision=None) - if mii_config.checkpoint_dict: - mii_config.checkpoint_dict['base_dir'] = model_path - return mii_config.checkpoint_dict + model_path = snapshot_download( + model_name, + cache_dir=model_path, + allow_patterns=[ + "*.bin", + "*.json", + "*.pt", + ], + revision=None, + ) + if deployment_config.checkpoint_dict: + deployment_config.checkpoint_dict["base_dir"] = model_path + return deployment_config.checkpoint_dict elif os.path.isfile(os.path.join(model_path, "ds_inference_config.json")): with open(os.path.join(model_path, "ds_inference_config.json")) as f: data = json.load(f) @@ -144,8 +149,9 @@ def create_checkpoint_dict(model_name, model_path, mii_config): else: if USE_NEW_HF_CACHE: checkpoint_files = [ - str(entry).split('/')[-1] - for entry in Path(model_path).rglob("*.[bp][it][n]") if entry.is_file() + str(entry).split("/")[-1] + for entry in Path(model_path).rglob("*.[bp][it][n]") + if entry.is_file() ] else: checkpoint_files = get_checkpoint_files(model_name) @@ -153,47 +159,51 @@ def create_checkpoint_dict(model_name, model_path, mii_config): "type": "DS_MODEL", "checkpoints": checkpoint_files, "version": 1.0, - "base_dir": model_path + "base_dir": model_path, } return data -def load_with_meta_tensor(model_path, model_name, task_name, mii_config): - deepspeed.init_distributed('nccl') +def load_with_meta_tensor(model_path, model_name, task_name, deployment_config): + deepspeed.init_distributed("nccl") cache_path = mii_cache_path() - tokenizer = _attempt_load(AutoTokenizer.from_pretrained, - model_name, - cache_path, - kwargs={"padding_side": 'left'}) + tokenizer = _attempt_load( + AutoTokenizer.from_pretrained, + model_name, + cache_path, + kwargs={"padding_side": "left"}, + ) tokenizer.pad_token = tokenizer.eos_token config = _attempt_load(AutoConfig.from_pretrained, model_name, cache_path) - with 
OnDevice(dtype=torch.float16, device='meta', enabled=True): + with OnDevice(dtype=torch.float16, device="meta", enabled=True): model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) model = model.eval() - checkpoint_dict = create_checkpoint_dict(model_name, model_path, mii_config) + checkpoint_dict = create_checkpoint_dict(model_name, model_path, deployment_config) torch.distributed.barrier() - inference_pipeline = MetaTensorPipeline(model=model, - tokenizer=tokenizer, - checkpoint_dict=checkpoint_dict) + inference_pipeline = MetaTensorPipeline( + model=model, tokenizer=tokenizer, checkpoint_dict=checkpoint_dict + ) return inference_pipeline -def hf_provider(model_path, model_name, task_name, mii_config): - if mii_config.meta_tensor: - return load_with_meta_tensor(model_path, model_name, task_name, mii_config) +def hf_provider(model_path, model_name, task_name, deployment_config): + if deployment_config.meta_tensor: + return load_with_meta_tensor( + model_path, model_name, task_name, deployment_config + ) else: - device = get_device(load_with_sys_mem=mii_config.load_with_sys_mem) + device = get_device(load_with_sys_mem=deployment_config.load_with_sys_mem) inference_pipeline = pipeline( task_name, model=model_name, device=device, framework="pt", - use_auth_token=mii_config.hf_auth_token, - torch_dtype=mii_config.dtype, - trust_remote_code=mii_config.trust_remote_code, + use_auth_token=deployment_config.hf_auth_token, + torch_dtype=deployment_config.dtype, + trust_remote_code=deployment_config.trust_remote_code, ) return inference_pipeline diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index f347202c..ccc35695 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -15,27 +15,30 @@ def create_score_file(mii_config): f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" ) - with open(os.path.join(mii.__path__[0], - "models/score/score_template.py"), - "r") as fd: + with open( + os.path.join(mii.__path__[0], "models/score/score_template.py"), "r" + ) as fd: score_src = fd.read() # update score file w. 
global config dict config_dict = mii_config.dict() source_with_config = f"{score_src}\n" - source_with_config += f"configs = {pprint.pformat(config_dict, indent=4)}" + source_with_config += f"mii_config = {pprint.pformat(config_dict, indent=4)}" - with open(generated_score_path(deployment_name, deployment_type), "w") as fd: + with open(generated_score_path(mii_config), "w") as fd: fd.write(source_with_config) fd.write("\n") -def generated_score_path(deployment_name, deployment_type): +def generated_score_path(mii_config): + deployment_type = mii_config.deployment_type + deployment_name = mii_config.deployment_config.deployment_name if deployment_type == DeploymentType.LOCAL: score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) elif deployment_type == DeploymentType.AML: - score_path = os.path.join(mii.aml_related.utils.aml_output_path(deployment_name), - "code") + score_path = os.path.join( + mii.aml_related.utils.aml_output_path(deployment_name), "code" + ) if not os.path.isdir(score_path): os.makedirs(score_path) return os.path.join(score_path, "score.py") diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 0f59652e..353a60c6 100644 --- a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -6,18 +6,17 @@ # flake8: noqa import os import json +import time import torch + import mii -from mii.config import MIIConfig -import time model = None def init(): - mii_configs = MIIConfig(**mii_configs) - #model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) - + global mii_config + mii_config = mii.MIIConfig(**mii_config) mii.MIIServer(mii_config) global model @@ -25,25 +24,25 @@ def init(): # In AML deployments both the GRPC client and server are used in the same process if mii.utils.is_aml(): - model = mii.MIIClient(task_name, - mii_configs=configs[mii.constants.MII_CONFIGS_KEY]) + model = mii.MIIClient(mii_config=mii_config) def run(request): - global model - assert model is not None, "grpc client has not been setup when this model was created" + global mii_config, model + assert ( + model is not None + ), "grpc client has not been setup when this model was created" request_dict = json.loads(request) - query_dict = mii.utils.extract_query_dict(configs[mii.constants.TASK_NAME_KEY], - request_dict) + query_dict = mii.utils.extract_query_dict(mii_config.task, request_dict) response = model.query(query_dict, **request_dict) time_taken = response.time_taken if not isinstance(response.response, str): response = [r for r in response.response] - return json.dumps({'responses': response, 'time': time_taken}) + return json.dumps({"responses": response, "time": time_taken}) ### Auto-generated config will be appended below at run-time diff --git a/mii/models/utils.py b/mii/models/utils.py index 8baa7333..d44b2871 100644 --- a/mii/models/utils.py +++ b/mii/models/utils.py @@ -7,23 +7,24 @@ def supported_models_from_huggingface(): - return ['gpt2', "deepset/roberta-large-squad2"] + return ["gpt2", "deepset/roberta-large-squad2"] -'''TODO make this more robust. If the pipeline has already been imported then -this might not work since the cache is set by the first import''' +"""TODO make this more robust. 
If the pipeline has already been imported then +this might not work since the cache is set by the first import""" def _download_hf_model_to_path(task, model_name, model_path): os.environ["TRANSFORMERS_CACHE"] = model_path from transformers import pipeline + inference_pipeline = pipeline(task, model=model_name) -'''generic method that will allow downloading all models that we support. +"""generic method that will allow downloading all models that we support. Currently only supports HF models, but will be extended to support model checkpoints -from other sources''' +from other sources""" def download_model_and_get_path(task, model_name): @@ -40,7 +41,7 @@ def download_model_and_get_path(task, model_name): return model_path -class ImageResponse(): +class ImageResponse: def __init__(self, response): self._response = response self.nsfw_content_detected = response.nsfw_content_detected @@ -50,6 +51,7 @@ def __init__(self, response): def images(self): if self._deserialized_images is None: from PIL import Image + images = [] for idx, img_bytes in enumerate(self._response.images): size = (self._response.size_w, self._response.size_h) diff --git a/mii/server.py b/mii/server.py index 0825e060..d33743ae 100644 --- a/mii/server.py +++ b/mii/server.py @@ -14,7 +14,7 @@ from collections import defaultdict import mii -from mii.utils import get_num_gpus, logger, get_provider_name +from mii.utils import get_num_gpus, logger def config_to_b64_str(config): @@ -26,43 +26,36 @@ def config_to_b64_str(config): return b64_config_bytes.decode() -class MIIServer(): - '''Initialize the model, setup the server for the model under model_path''' - def __init__(self, - deployment_name, - task_name, - model_name, - model_path, - ds_optimize=True, - ds_zero=False, - ds_config=None, - mii_configs={}, - lb_config=None): +class MIIServer: + """Initialize the model, setup the server for the model under model_path""" - mii_configs = mii.config.MIIConfig(**mii_configs) + def __init__(self, mii_config): - self.task = mii.utils.get_task(task_name) - - self.num_gpus = get_num_gpus(mii_configs) + self.task = mii_config.deployment_config.task + self.num_gpus = get_num_gpus(mii_config) assert self.num_gpus > 0, "GPU count must be greater than 0" self.port_number = mii_configs.port_number + """ if mii_configs.hostfile is None: hostfile = tempfile.NamedTemporaryFile(delete=False) num_gpu = torch.cuda.device_count() with open(hostfile, "w") as f: f.write(f"localhost slots={num_gpu}") - mii.configs.hostfile = hostfile + mii_config.hostfile = hostfile + """ - processes = self._initialize_service(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config) + processes = self._initialize_service( + deployment_name, + model_name, + model_path, + ds_optimize, + ds_zero, + ds_config, + mii_configs, + lb_config, + ) self._wait_until_server_is_live(processes, lb_config.replica_configs) def _wait_until_server_is_live(self, processes, deployment): @@ -70,20 +63,23 @@ def _wait_until_server_is_live(self, processes, deployment): sockets_open = False while not sockets_open: sockets_open = all( - self._is_socket_open(repl_config.hostname, - port) - for port in repl_config.tensor_parallel_ports) + self._is_socket_open(repl_config.hostname, port) + for port in repl_config.tensor_parallel_ports + ) process_alive = self._is_server_process_alive(process) if not process_alive: raise RuntimeError( - "server crashed for some reason, unable to proceed") + "server crashed for some reason, unable to proceed" + ) 
time.sleep(4) logger.info("waiting for server to start...") logger.info( - f"server has started on ports {repl_config.tensor_parallel_ports}") + f"server has started on ports {repl_config.tensor_parallel_ports}" + ) def _is_socket_open(self, host, port): import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) result = sock.connect_ex((host, port)) sock.close() @@ -102,242 +98,82 @@ def _is_server_process_alive(self, process): is_alive = False return is_alive - def _build_server_args(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port): - # serialize mii config - b64_config_str = config_to_b64_str(mii_configs) - - server_args_str = f"--deployment-name {deployment_name} --task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {port}" - server_args_str += " --ds-optimize" if ds_optimize else "" - - # XXX: fetch model provider based on model name in a more general way - provider = get_provider_name(model_name, self.task) - server_args_str += f" --provider {provider}" - - server_args_str += f" --config {b64_config_str}" - server_args_str += " --ds-zero" if ds_zero else "" - if ds_zero and ds_config is not None: - if isinstance(ds_config, dict): - - def create_config_from_dict(tmpdir, config_dict): - if not os.path.exists(tmpdir): - os.makedirs(tmpdir) - config_path = os.path.join(tmpdir, 'temp_config.json') - with open(config_path, 'w') as fd: - json.dump(config_dict, fd) - return config_path - - model_dir = Path(model_path).parent.resolve() - ds_config_path = create_config_from_dict(model_dir, ds_config) - elif isinstance(ds_config, str): - ds_config_path = ds_config - else: - raise ValueError( - f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {ds_config}" - ) - server_args_str += f" --ds-config {ds_config_path}" - printable_config = f"task-name {mii.utils.get_task_name(self.task)} model {model_name} model-path {model_path} port {self.port_number} provider {provider}" - logger.info(f"MII using multi-gpu deepspeed launcher:\n" + - self.print_helper(printable_config)) - return server_args_str - - def print_helper(self, args): - # convert to list - args = args.split(" ") - # convert to dict - dct = {args[i]: args[i + 1] for i in range(0, len(args), 2)} - printable_string = "" - printable_string += " " + "-" * 60 + "\n" - for k, v in dct.items(): - dots = "." 
* (29 - len(k)) - printable_string += f" {k} {dots} {v} \n" - printable_string += " " + "-" * 60 - return printable_string - - def _launch_load_balancer(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): - - # serialize mii config - b64_config_str = config_to_b64_str(lb_config) - - return self._launch_server_process( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number, - "load balancer", - ex_server_args=[f"--load-balancer {b64_config_str}"]) - - def _launch_restful_gateway(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port): - return self._launch_server_process(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - "restful api gateway", - ex_server_args=["--restful-gateway"]) - - def _launch_server_process(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - msg_server_type, - ds_launch_str=None, - ex_server_args=[]): + def _launch_server_process( + self, deployment_config, msg_server_type, ds_launch_str="", server_args=[] + ): launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" - server_args_str = self._build_server_args(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port) - server_args_str += f" " + \ - " ".join(ex_server_args) if ex_server_args else "" - - if ds_launch_str is None: - cmd = f'{launch_str} {server_args_str}'.split(" ") - else: - cmd = f'{ds_launch_str} {launch_str} {server_args_str}'.split(" ") + b64_config_str = config_to_b64_str(deployment_config) + server_args.append(f"--deployment-config {b64_config_str}") + server_args_str = " ".join(server_args) + cmd = f"{ds_launch_str} {launch_str} {server_args_str}".split(" ") mii_env = os.environ.copy() - mii_env["TRANSFORMERS_CACHE"] = model_path + mii_env["TRANSFORMERS_CACHE"] = deployment_config.model_path logger.info(f"{msg_server_type} server launch: {cmd}") return subprocess.Popen(cmd, env=mii_env) - def _launch_deepspeed(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - hostfile, - host, - port, - master_port, - deploy_ranks): + def _generate_ds_launch_str(self, replica_config, hostfile): # use different hostfiles for replica instances # pass /dev/null when no replica is used worker_str = f"-H {hostfile} " # pin deepspeed launch to specific gpu id(s) - included_gpus = f"{host}:{','.join(map(str, deploy_ranks))}" + included_gpus = f"{host}:{','.join(map(str, replica_config.gpu_indices))}" worker_str += f"-i {included_gpus} " # adjust torch dist port depending on rank, otherwise multi-replica deployments will conflict # assign different ports to replicas because they could be on the same host - worker_str += f"--master_port {master_port}" + worker_str += f"--master_port {replica_config.torch_dist_port}" ds_launch_str = f"deepspeed {worker_str} --no_local_rank --no_python" - return self._launch_server_process(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - "MII server", - ds_launch_str=ds_launch_str) - - def _initialize_service(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): + return ds_launch_str + def _initialize_service(self, mii_config): 
processes = [] + server_args = [ + f"--load-balancer-port {mii_config.port_number}", + f"--restful-gateway-port {mii_config.restful_api_port}", + ] host_gpus = defaultdict(list) - for repl_config in lb_config.replica_configs: + for repl_config in mii_config.deployment_config.replica_configs: host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) # Start replica instances - for i, repl_config in enumerate(lb_config.replica_configs): + for repl_config in mii_config.deployment_config.replica_configs: hostfile = tempfile.NamedTemporaryFile(delete=False) hostfile.write( - f'{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n' - .encode()) + f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n".encode() + ) + ds_launch_str = self._generate_ds_launch_str(replica_config, hostfile.name) processes.append( - self._launch_deepspeed( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - hostfile.name, - repl_config.hostname, - repl_config.tensor_parallel_ports[0], - mii_configs.torch_dist_port + (100 * i) + repl_config.gpu_indices[0], - repl_config.gpu_indices)) - + self._launch_server_process( + deployment_config, + "MII server", + ds_launch_str=ds_launch_str, + server_args=server_args + + [f"--server-port {replica_config.tensor_parallel_port[0]}"], + ) + ) # start load balancer here. # we don't use deepspeed launcher for the load balancer because it does not need a GPU. # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, # and it is expected to assign one GPU to one process. processes.append( - self._launch_load_balancer(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config)) - - if mii_configs.enable_restful_api: - # start rest api server + self._launch_server_process( + deployment_config, + "load balancer", + server_args=server_args + ["--load-balancer"], + ) + ) + + if mii_config.enable_restful_api: processes.append( - self._launch_restful_gateway(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number)) + self._launch_server_process( + deployment_config, + "restful api gateway", + server_args=server_args + ["--restful-gateway"], + ) + ) return processes diff --git a/mii/terminate.py b/mii/terminate.py index 167c5a5a..a0a6b99e 100644 --- a/mii/terminate.py +++ b/mii/terminate.py @@ -10,11 +10,11 @@ def terminate(deployment_name): mii.utils.logger.info(f"Terminating server for {deployment_name}") generator = mii.mii_query_handle(deployment_name) - if (deployment_name in mii.non_persistent_models): + if deployment_name in mii.non_persistent_models: generator.terminate() return try: - generator.query({'query': ''}) + generator.query({"query": ""}) except grpc.aio._call.AioRpcError as error: if error._code == grpc.StatusCode.UNAVAILABLE: mii.utils.logger.warn(f"Server for {deployment_name} not found") diff --git a/mii/utils.py b/mii/utils.py index 674c64a3..14a4b74a 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -10,78 +10,26 @@ import mii from huggingface_hub import HfApi -from mii.constants import (CONVERSATIONAL_NAME, - FILL_MASK_NAME, - MII_CACHE_PATH, - MII_CACHE_PATH_DEFAULT, - TEXT_GENERATION_NAME, - TEXT_CLASSIFICATION_NAME, - QUESTION_ANSWERING_NAME, - TOKEN_CLASSIFICATION_NAME, - SUPPORTED_MODEL_TYPES, - ModelProvider, - REQUIRED_KEYS_PER_TASK, - TEXT2IMG_NAME) +from mii.constants import ( + 
MII_CACHE_PATH, + MII_CACHE_PATH_DEFAULT, + ModelProvider, + SUPPORTED_MODEL_TYPES, + REQUIRED_KEYS_PER_TASK, +) -from mii.constants import Tasks - - -def get_task_name(task): - if task == Tasks.QUESTION_ANSWERING: - return QUESTION_ANSWERING_NAME - - if task == Tasks.TEXT_GENERATION: - return TEXT_GENERATION_NAME - - if task == Tasks.TEXT_CLASSIFICATION: - return TEXT_CLASSIFICATION_NAME - - if task == Tasks.FILL_MASK: - return FILL_MASK_NAME - - if task == Tasks.TOKEN_CLASSIFICATION: - return TOKEN_CLASSIFICATION_NAME - - if task == Tasks.CONVERSATIONAL: - return CONVERSATIONAL_NAME - - if task == Tasks.TEXT2IMG: - return TEXT2IMG_NAME - - raise ValueError(f"Unknown Task {task}") - - -def get_task(task_name): - if task_name == QUESTION_ANSWERING_NAME: - return Tasks.QUESTION_ANSWERING - - if task_name == TEXT_GENERATION_NAME: - return Tasks.TEXT_GENERATION - - if task_name == TEXT_CLASSIFICATION_NAME: - return Tasks.TEXT_CLASSIFICATION - - if task_name == FILL_MASK_NAME: - return Tasks.FILL_MASK - - if task_name == TOKEN_CLASSIFICATION_NAME: - return Tasks.TOKEN_CLASSIFICATION - - if task_name == CONVERSATIONAL_NAME: - return Tasks.CONVERSATIONAL - - if task_name == TEXT2IMG_NAME: - return Tasks.TEXT2IMG - - assert False, f"Unknown Task {task_name}" +from mii.config import TaskType def _get_hf_models_by_type(model_type, task=None): api = HfApi() models = api.list_models(filter=model_type) - models = ([m.modelId for m in models] - if task is None else [m.modelId for m in models if m.pipeline_tag == task]) - if task == TEXT_GENERATION_NAME: + models = ( + [m.modelId for m in models] + if task is None + else [m.modelId for m in models if m.pipeline_tag == task] + ) + if task == TaskType.TEXT_GENERATION: # TODO: this is a temp solution to get around some HF models not having the correct tags models.append("microsoft/bloom-deepspeed-inference-fp16") models.append("microsoft/bloom-deepspeed-inference-int8") @@ -92,16 +40,15 @@ def _get_hf_models_by_type(model_type, task=None): # TODO read this from a file containing list of files supported for each task def _get_supported_models_name(task): supported_models = [] - task_name = get_task_name(task) for model_type, provider in SUPPORTED_MODEL_TYPES.items(): if provider == ModelProvider.HUGGING_FACE: - models = _get_hf_models_by_type(model_type, task_name) + models = _get_hf_models_by_type(model_type, task) elif provider == ModelProvider.ELEUTHER_AI: - if task_name == TEXT_GENERATION_NAME: + if task == TaskType.TEXT_GENERATION: models = [model_type] elif provider == ModelProvider.DIFFUSERS: - models = _get_hf_models_by_type(model_type, task_name) + models = _get_hf_models_by_type(model_type, task) supported_models.extend(models) if not supported_models: raise ValueError(f"Task {task} not supported") @@ -115,27 +62,8 @@ def check_if_task_and_model_is_supported(task, model_name): def check_if_task_and_model_is_valid(task, model_name): - task_name = get_task_name(task) - valid_task_models = _get_hf_models_by_type(None, task_name) - assert ( - model_name in valid_task_models - ), f"{task_name} only supports {valid_task_models}" - - -def full_model_path(model_path): - aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None) - if aml_model_dir: - # (potentially) append relative model_path w. aml path - assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path" - if model_path: - assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to append w. 
AML path" - return os.path.join(aml_model_dir, model_path) - else: - return aml_model_dir - elif model_path: - return model_path - else: - return mii.constants.MII_MODEL_PATH_DEFAULT + valid_task_models = _get_hf_models_by_type(None, task) + assert model_name in valid_task_models, f"{task} only supports {valid_task_models}" def is_aml(): @@ -151,10 +79,8 @@ def mii_cache_path(): def import_score_file(deployment_name): spec = importlib.util.spec_from_file_location( - "score", - os.path.join(mii_cache_path(), - deployment_name, - "score.py")) + "score", os.path.join(mii_cache_path(), deployment_name, "score.py") + ) score = importlib.util.module_from_spec(spec) spec.loader.exec_module(score) return score @@ -179,10 +105,7 @@ def get_proto_value(value): def unpack_proto_query_kwargs(query_kwargs): query_kwargs = { - k: getattr(v, - v.WhichOneof("oneof_values")) - for k, - v in query_kwargs.items() + k: getattr(v, v.WhichOneof("oneof_values")) for k, v in query_kwargs.items() } return query_kwargs @@ -198,21 +121,22 @@ def extract_query_dict(task, request_dict): return query_dict -def get_num_gpus(mii_configs): - num_gpus = mii_configs.tensor_parallel +def get_num_gpus(mii_config): + num_gpus = mii_config.deployment_config.tensor_parallel - assert torch.cuda.device_count( - ) >= num_gpus, f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}" + assert ( + torch.cuda.device_count() >= num_gpus + ), f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}" return num_gpus -def get_provider_name(model_name, task): +def get_provider(model_name, task): if model_name == "gpt-neox": - provider = mii.constants.MODEL_PROVIDER_NAME_EA - elif task == mii.Tasks.TEXT2IMG: - provider = mii.constants.MODEL_PROVIDER_NAME_DIFFUSERS + provider = ModelProvider.ELEUTHER_AI + elif task == TaskType.TEXT2IMG: + provider = ModelProvider.DIFFUSERS else: - provider = mii.constants.MODEL_PROVIDER_NAME_HF + provider = ModelProvider.HUGGING_FACE return provider @@ -241,7 +165,8 @@ def create_logger(name=None, level=logging.INFO): formatter = logging.Formatter( "[%(asctime)s] [%(levelname)s] " - "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s") + "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s" + ) logger_ = logging.getLogger(name) logger_.setLevel(level) From 926795de02e95d02c93091d0950eba9c4c88be84 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 10:37:24 -0700 Subject: [PATCH 03/15] fix errors, get things running --- mii/__init__.py | 2 +- mii/aml_related/templates.py | 30 ++++--- mii/aml_related/utils.py | 99 ++++++++++++++---------- mii/client.py | 17 ++-- mii/config.py | 31 +++++--- mii/grpc_related/modelresponse_server.py | 4 +- mii/launch/multi_gpu_server.py | 8 +- mii/models/load_models.py | 12 +-- mii/models/providers/diffusers.py | 4 +- mii/models/providers/huggingface.py | 26 +++---- mii/server.py | 29 +++---- 11 files changed, 138 insertions(+), 124 deletions(-) diff --git a/mii/__init__.py b/mii/__init__.py index e0b4fe21..94bbdccf 100644 --- a/mii/__init__.py +++ b/mii/__init__.py @@ -10,7 +10,7 @@ from .constants import DeploymentType, TaskType from .aml_related.utils import aml_output_path -from .config import MIIConfig +from .config import MIIConfig, DeploymentConfig from .grpc_related.proto import modelresponse_pb2_grpc __version__ = "0.0.0" diff --git a/mii/aml_related/templates.py b/mii/aml_related/templates.py index 5ad97a31..5e67255b 100644 --- a/mii/aml_related/templates.py +++ 
b/mii/aml_related/templates.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -deployment = """$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json +deployment = \ +"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json name: endpoint_name: model: @@ -35,12 +36,14 @@ instance_count: 1 """ -endpoint = """$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json +endpoint = \ +"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json name: auth_mode: key """ -environment = """$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json +environment = \ +"""$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json name: version: image: .azurecr.io/: @@ -56,7 +59,8 @@ port: 5001 """ -model_download = """import os +model_download = \ +"""import os import glob import shutil @@ -93,7 +97,8 @@ shutil.rmtree(tmp_download_path) """ -deploy = """set -e +deploy = \ +"""set -e python3 model_download.py az acr build -r --build-arg no-cache=True -t ":" build az ml environment create -f environment.yml @@ -101,7 +106,8 @@ az ml online-deployment create -n "" -f deployment.yml """ -dockerfile = """FROM nvidia/cuda:11.3.1-devel-ubuntu18.04 +dockerfile = \ +"""FROM nvidia/cuda:11.3.1-devel-ubuntu18.04 ENV AML_APP_ROOT=/var/azureml-model/code \ BUILD_DIR=/tmp/build \ @@ -170,7 +176,8 @@ CMD sudo service nginx start && cd $AZUREML_MODEL_DIR/code && azmlinfsrv --model_dir $AZUREML_MODEL_DIR --entry_script $AZUREML_MODEL_DIR/code/score.py --port 31311 """ -gunicorn = """upstream gunicorn { +gunicorn = \ +"""upstream gunicorn { server 127.0.0.1:31311; } @@ -195,7 +202,8 @@ } """ -gunicorn_run = """#!/bin/bash +gunicorn_run = \ +"""#!/bin/bash SCRIPT_PATH=$(dirname $(realpath -s "$0")) @@ -369,7 +377,8 @@ fi """ -gunicorn_finish = """#!/bin/bash +gunicorn_finish = \ +"""#!/bin/bash exit_code="$1" # The exit code from gunicorn signal="$2" # The signal which caused gunicorn to exit (or 0) @@ -380,7 +389,8 @@ killall -SIGHUP runsvdir """ -requirements = """torch==1.12.0 +requirements = \ +"""torch==1.12.0 grpcio grpcio-tools pydantic diff --git a/mii/aml_related/utils.py b/mii/aml_related/utils.py index c2065010..f6e1520b 100644 --- a/mii/aml_related/utils.py +++ b/mii/aml_related/utils.py @@ -11,10 +11,14 @@ def get_acr_name(): try: acr_name = subprocess.check_output( - ["az", "ml", "workspace", "show", "--query", "container_registry"], - text=True, - ) - return acr_name.strip().replace('"', "").rsplit("/", 1)[-1] + ["az", + "ml", + "workspace", + "show", + "--query", + "container_registry"], + text=True) + return acr_name.strip().replace('"', '').rsplit('/', 1)[-1] except subprocess.CalledProcessError as e: print("\n", "-" * 30, "\n") print("Unable to obtain ACR name from Azure-CLI. 
Please verify that you:") @@ -76,49 +80,60 @@ def generate_aml_scripts(acr_name, deployment_name, model_name, version): } # Docker files + write_out_script(os.path.join(output_dir, + "build", + "Dockerfile"), + fill_template(mii.aml_related.templates.dockerfile, + replace_dict)) + write_out_script(os.path.join(output_dir, + "build", + "gunicorn_app"), + fill_template(mii.aml_related.templates.gunicorn, + replace_dict)) + write_out_script(os.path.join(output_dir, + "build", + "runit", + "gunicorn", + "run"), + fill_template(mii.aml_related.templates.gunicorn_run, + replace_dict)) write_out_script( - os.path.join(output_dir, "build", "Dockerfile"), - fill_template(mii.aml_related.templates.dockerfile, replace_dict), - ) - write_out_script( - os.path.join(output_dir, "build", "gunicorn_app"), - fill_template(mii.aml_related.templates.gunicorn, replace_dict), - ) - write_out_script( - os.path.join(output_dir, "build", "runit", "gunicorn", "run"), - fill_template(mii.aml_related.templates.gunicorn_run, replace_dict), - ) - write_out_script( - os.path.join(output_dir, "build", "runit", "gunicorn", "finish"), - fill_template(mii.aml_related.templates.gunicorn_finish, replace_dict), - ) - write_out_script( - os.path.join(output_dir, "build", "requirements.txt"), - fill_template(mii.aml_related.templates.requirements, replace_dict), - ) + os.path.join(output_dir, + "build", + "runit", + "gunicorn", + "finish"), + fill_template(mii.aml_related.templates.gunicorn_finish, + replace_dict)) + write_out_script(os.path.join(output_dir, + "build", + "requirements.txt"), + fill_template(mii.aml_related.templates.requirements, + replace_dict)) # Model download script write_out_script( - os.path.join(output_dir, "model_download.py"), - fill_template(mii.aml_related.templates.model_download, replace_dict), - ) + os.path.join(output_dir, + "model_download.py"), + fill_template(mii.aml_related.templates.model_download, + replace_dict)) # Deployment script - write_out_script( - os.path.join(output_dir, "deploy.sh"), - fill_template(mii.aml_related.templates.deploy, replace_dict), - ) + write_out_script(os.path.join(output_dir, + "deploy.sh"), + fill_template(mii.aml_related.templates.deploy, + replace_dict)) # Yaml configs - write_out_yaml( - os.path.join(output_dir, "deployment.yml"), - fill_template(mii.aml_related.templates.deployment, replace_dict), - ) - write_out_yaml( - os.path.join(output_dir, "endpoint.yml"), - fill_template(mii.aml_related.templates.endpoint, replace_dict), - ) - write_out_yaml( - os.path.join(output_dir, "environment.yml"), - fill_template(mii.aml_related.templates.environment, replace_dict), - ) + write_out_yaml(os.path.join(output_dir, + "deployment.yml"), + fill_template(mii.aml_related.templates.deployment, + replace_dict)) + write_out_yaml(os.path.join(output_dir, + "endpoint.yml"), + fill_template(mii.aml_related.templates.endpoint, + replace_dict)) + write_out_yaml(os.path.join(output_dir, + "environment.yml"), + fill_template(mii.aml_related.templates.environment, + replace_dict)) diff --git a/mii/client.py b/mii/client.py index b776df38..df9bab48 100644 --- a/mii/client.py +++ b/mii/client.py @@ -9,16 +9,7 @@ from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc from mii.constants import GRPC_MAX_MSG_SIZE, TaskType from mii.method_table import GRPC_METHOD_TABLE - - -def _get_deployment_info(deployment_name): - configs = mii.utils.import_score_file(deployment_name).configs - task = configs[mii.constants.TASK_NAME_KEY] - mii_configs_dict = 
configs[mii.constants.MII_CONFIGS_KEY] - mii_configs = mii.config.MIIConfig(**mii_configs_dict) - - assert task is not None, "The task name should be set before calling init" - return task, mii_configs +from mii.config import MIIConfig def mii_query_handle(deployment_name): @@ -38,8 +29,10 @@ def mii_query_handle(deployment_name): inference_pipeline, task = mii.non_persistent_models[deployment_name] return MIINonPersistentClient(task, deployment_name) - task, mii_configs = _get_deployment_info(deployment_name) - return MIIClient(task, "localhost", mii_configs.port_number) + mii_config = mii.utils.import_score_file(deployment_name).mii_config + # TODO: Avoid model checking when we laod the config + mii_config = MIIConfig(**mii_config) + return MIIClient(mii_config.deployment_config.task, "localhost", mii_config.port_number) def create_channel(host, port): diff --git a/mii/config.py b/mii/config.py index 28a21c8f..e4392b9c 100644 --- a/mii/config.py +++ b/mii/config.py @@ -27,10 +27,11 @@ class DeploymentConfig(DeepSpeedConfigModel): model: str task: TaskType tensor_parallel: int = 1 - dtype: DtypeEnum = torch.float32 + dtype: DtypeEnum = DtypeEnum.fp32 meta_tensor: bool = False load_with_sys_mem: bool = False enable_cuda_graph: bool = False + hf_auth_token: str = "" checkpoint_dict: Optional[Dict[str, Any]] = None deploy_rank: Optional[List[int]] = None torch_dist_port: int = 29500 @@ -46,6 +47,9 @@ class DeploymentConfig(DeepSpeedConfigModel): replica_num: int = 1 replica_configs: List[ReplicaConfig] = [] + class Config: + json_encoders = {torch.dtype: lambda x: str(x)} + @validator("checkpoint_dict") def checkpoint_dict_valid(cls, field_value, values): if field_value is None: @@ -76,14 +80,14 @@ def zero_or_meta(cls, values): @root_validator def bloom_model_valid(cls, values): if "bigscience/bloom" in values.get("model"): - dtype = values.get("dtype") - assert dtype in [ - DtypeEnum.int8, - DtypeEnum.fp16, + # TODO: SHould be albe to use DtypeEnum here + assert values.get("dtype") in [ + torch.int8, + torch.float16, ], "Bloom models only support fp16/int8." - assert not values.get( + assert (not values.get( "enable_cuda_graph" - ), "Bloom models do not support CUDA Graph." + )), "Bloom models do not support CUDA Graph." 
return values @root_validator @@ -130,7 +134,7 @@ def validate_model_and_task(cls, values): model = values.get("model") if not values.get("skip_model_check"): mii.utils.check_if_task_and_model_is_valid(task, model) - if values.get("ds_optimize"): + if values.get("enable_deepspeed"): mii.utils.check_if_task_and_model_is_supported(task, model) return values @@ -165,7 +169,7 @@ def deepspeed_or_zero(cls, values): class MIIConfig(DeepSpeedConfigModel): deployment_config: DeploymentConfig = {} - hf_auth_token: str = None + hf_auth_token: str = "" port_number: int = 50050 enable_restful_api: bool = False restful_api_port: int = 51080 @@ -173,6 +177,15 @@ class MIIConfig(DeepSpeedConfigModel): deployment_type: DeploymentType = DeploymentType.LOCAL version: int = 1 + @root_validator + def propagate_hf_auth(cls, values): + # This validator is for when we support multiple models in a deployment + hf_auth_token = values.get("hf_auth_token") + deployment_config = values.get("deployment_config") + if not deployment_config.hf_auth_token: + deployment_config.hf_auth_token = hf_auth_token + return values + @root_validator def AML_name_valid(cls, values): if values.get("deployment_type") == DeploymentType.AML: diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index f6eab456..30055286 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -20,11 +20,11 @@ TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, - Tasks, + TaskType, ) from mii.method_table import GRPC_METHOD_TABLE from mii.client import create_channel -from mii.utils import get_task, unpack_proto_query_kwargs +from mii.utils import unpack_proto_query_kwargs class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 708ec80a..5e7988aa 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -68,7 +68,7 @@ def main(): assert args.restful_gateway_port, "--restful-gateway-port must be provided." print(f"Starting RESTful API gateway on port: {args.restful_gateway_port}") gateway_thread = RestfulGatewayThread( - deployment_config, + args.deployment_config, lb_port=args.load_balancer_port, rest_port=args.restful_gateway_port, ) @@ -78,15 +78,15 @@ def main(): elif args.load_balancer: assert args.load_balancer_port, "--load-balancer-port must be provided." - print(f"Starting load balancer on port: {lb_config.port}") - serve_load_balancing(deployment_config, args.load_balancer_port) + print(f"Starting load balancer on port: {args.load_balancer_port}") + serve_load_balancing(args.deployment_config, args.load_balancer_port) else: assert args.server_port, "--server-port must be provided." 
local_rank = int(os.getenv("LOCAL_RANK", "0")) port = args.server_port + local_rank - inference_pipeline = load_models(deployment_config) + inference_pipeline = load_models(args.deployment_config) print(f"Starting server on port: {port}") serve_inference(inference_pipeline, port) diff --git a/mii/models/load_models.py b/mii/models/load_models.py index 98b10df3..09011c31 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -33,9 +33,7 @@ def load_models(deployment_config): if provider == mii.constants.ModelProvider.HUGGING_FACE: from mii.models.providers.huggingface import hf_provider - inference_pipeline = hf_provider( - model_path, model_name, task_name, deployment_config - ) + inference_pipeline = hf_provider(deployment_config) if deployment_config.meta_tensor: inf_config["checkpoint"] = inference_pipeline.checkpoint_dict if deployment_config.dtype == torch.int8: @@ -64,9 +62,7 @@ def load_models(deployment_config): elif provider == mii.constants.ModelProvider.DIFFUSERS: from mii.models.providers.diffusers import diffusers_provider - inference_pipeline = diffusers_provider( - model_path, model_name, task_name, deployment_config - ) + inference_pipeline = diffusers_provider(deployment_config) inf_config["enable_cuda_graph"] = True else: raise ValueError(f"Unknown model provider {provider}") @@ -76,7 +72,7 @@ def load_models(deployment_config): f"> --------- MII Settings: ds_optimize={ds_optimize}, replace_with_kernel_inject={mii_config.replace_with_kernel_inject}, enable_cuda_graph={mii_config.enable_cuda_graph} " ) """ - if deployment_config.ds_optimize: + if deployment_config.enable_deepspeed: engine = deepspeed.init_inference( getattr(inference_pipeline, "model", inference_pipeline), config=inf_config ) @@ -85,7 +81,7 @@ def load_models(deployment_config): if hasattr(inference_pipeline, "model"): inference_pipeline.model = engine - elif ds_zero: + elif deployment_config.enable_zero: ds_config = DeepSpeedConfig(deployment_config.ds_config) assert ( ds_config.zero_optimization_stage == ZeroStageEnum.weights diff --git a/mii/models/providers/diffusers.py b/mii/models/providers/diffusers.py index 65713fbf..3109223e 100644 --- a/mii/models/providers/diffusers.py +++ b/mii/models/providers/diffusers.py @@ -6,7 +6,7 @@ import torch -def diffusers_provider(model_path, model_name, task_name, deployment_config): +def diffusers_provider(deployment_config): from diffusers import DiffusionPipeline local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -17,7 +17,7 @@ def diffusers_provider(model_path, model_name, task_name, deployment_config): kwargs["revision"] = "fp16" pipeline = DiffusionPipeline.from_pretrained( - model_name, use_auth_token=deployment_config.hf_auth_token, **kwargs + deployment_config.model, use_auth_token=deployment_config.hf_auth_token, **kwargs ) pipeline = pipeline.to(f"cuda:{local_rank}") pipeline.set_progress_bar_config(disable=True) diff --git a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py index 6b3cf64c..60862b8e 100644 --- a/mii/models/providers/huggingface.py +++ b/mii/models/providers/huggingface.py @@ -126,7 +126,7 @@ def get_checkpoint_files(pretrained_model_name_or_path): return resolved_archive_file -def create_checkpoint_dict(model_name, model_path, deployment_config): +def create_checkpoint_dict(model_name, model_path, checkpoint_dict): if USE_NEW_HF_CACHE: model_path = snapshot_download( model_name, @@ -138,9 +138,9 @@ def create_checkpoint_dict(model_name, model_path, deployment_config): ], revision=None, ) - if 
deployment_config.checkpoint_dict: - deployment_config.checkpoint_dict["base_dir"] = model_path - return deployment_config.checkpoint_dict + if checkpoint_dict: + checkpoint_dict["base_dir"] = model_path + return checkpoint_dict elif os.path.isfile(os.path.join(model_path, "ds_inference_config.json")): with open(os.path.join(model_path, "ds_inference_config.json")) as f: data = json.load(f) @@ -164,25 +164,25 @@ def create_checkpoint_dict(model_name, model_path, deployment_config): return data -def load_with_meta_tensor(model_path, model_name, task_name, deployment_config): +def load_with_meta_tensor(deployment_config): deepspeed.init_distributed("nccl") cache_path = mii_cache_path() tokenizer = _attempt_load( AutoTokenizer.from_pretrained, - model_name, + deployment_config.model, cache_path, kwargs={"padding_side": "left"}, ) tokenizer.pad_token = tokenizer.eos_token - config = _attempt_load(AutoConfig.from_pretrained, model_name, cache_path) + config = _attempt_load(AutoConfig.from_pretrained, deployment_config.model, cache_path) with OnDevice(dtype=torch.float16, device="meta", enabled=True): model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) model = model.eval() - checkpoint_dict = create_checkpoint_dict(model_name, model_path, deployment_config) + checkpoint_dict = create_checkpoint_dict(deployment_config.model, deployment_config.model_path, deployment_config.checkpoint_dict) torch.distributed.barrier() inference_pipeline = MetaTensorPipeline( model=model, tokenizer=tokenizer, checkpoint_dict=checkpoint_dict @@ -190,16 +190,14 @@ def load_with_meta_tensor(model_path, model_name, task_name, deployment_config): return inference_pipeline -def hf_provider(model_path, model_name, task_name, deployment_config): +def hf_provider(deployment_config): if deployment_config.meta_tensor: - return load_with_meta_tensor( - model_path, model_name, task_name, deployment_config - ) + return load_with_meta_tensor(deployment_config) else: device = get_device(load_with_sys_mem=deployment_config.load_with_sys_mem) inference_pipeline = pipeline( - task_name, - model=model_name, + deployment_config.task, + model=deployment_config.model, device=device, framework="pt", use_auth_token=deployment_config.hf_auth_token, diff --git a/mii/server.py b/mii/server.py index d33743ae..f2f2a42f 100644 --- a/mii/server.py +++ b/mii/server.py @@ -35,8 +35,6 @@ def __init__(self, mii_config): self.num_gpus = get_num_gpus(mii_config) assert self.num_gpus > 0, "GPU count must be greater than 0" - self.port_number = mii_configs.port_number - """ if mii_configs.hostfile is None: hostfile = tempfile.NamedTemporaryFile(delete=False) @@ -46,17 +44,8 @@ def __init__(self, mii_config): mii_config.hostfile = hostfile """ - processes = self._initialize_service( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config, - ) - self._wait_until_server_is_live(processes, lb_config.replica_configs) + processes = self._initialize_service(mii_config) + self._wait_until_server_is_live(processes, mii_config.deployment_config.replica_configs) def _wait_until_server_is_live(self, processes, deployment): for process, repl_config in zip(processes, deployment): @@ -105,7 +94,7 @@ def _launch_server_process( b64_config_str = config_to_b64_str(deployment_config) server_args.append(f"--deployment-config {b64_config_str}") server_args_str = " ".join(server_args) - cmd = f"{ds_launch_str} {launch_str} {server_args_str}".split(" ") + cmd = f"{ds_launch_str} {launch_str} 
{server_args_str}".strip().split(" ") mii_env = os.environ.copy() mii_env["TRANSFORMERS_CACHE"] = deployment_config.model_path @@ -117,7 +106,7 @@ def _generate_ds_launch_str(self, replica_config, hostfile): # pass /dev/null when no replica is used worker_str = f"-H {hostfile} " # pin deepspeed launch to specific gpu id(s) - included_gpus = f"{host}:{','.join(map(str, replica_config.gpu_indices))}" + included_gpus = f"{replica_config.hostname}:{','.join(map(str, replica_config.gpu_indices))}" worker_str += f"-i {included_gpus} " # adjust torch dist port depending on rank, otherwise multi-replica deployments will conflict @@ -145,14 +134,14 @@ def _initialize_service(self, mii_config): hostfile.write( f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n".encode() ) - ds_launch_str = self._generate_ds_launch_str(replica_config, hostfile.name) + ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name) processes.append( self._launch_server_process( - deployment_config, + mii_config.deployment_config, "MII server", ds_launch_str=ds_launch_str, server_args=server_args - + [f"--server-port {replica_config.tensor_parallel_port[0]}"], + + [f"--server-port {repl_config.tensor_parallel_ports[0]}"], ) ) # start load balancer here. @@ -161,7 +150,7 @@ def _initialize_service(self, mii_config): # and it is expected to assign one GPU to one process. processes.append( self._launch_server_process( - deployment_config, + mii_config.deployment_config, "load balancer", server_args=server_args + ["--load-balancer"], ) @@ -170,7 +159,7 @@ def _initialize_service(self, mii_config): if mii_config.enable_restful_api: processes.append( self._launch_server_process( - deployment_config, + mii_config.deployment_config, "restful api gateway", server_args=server_args + ["--restful-gateway"], ) From 0b04ed23b8bfe405f0f8dafb10f7f81bd2affec2 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 10:44:19 -0700 Subject: [PATCH 04/15] formatting and flake fixes --- mii/client.py | 62 +++++++++---------- mii/config.py | 35 +++++------ mii/constants.py | 3 +- mii/deployment.py | 5 +- mii/grpc_related/modelresponse_server.py | 40 ++++++------ .../proto/modelresponse_pb2_grpc.py | 52 +++++++++------- mii/grpc_related/restful_gateway.py | 5 +- mii/launch/multi_gpu_server.py | 10 ++- mii/method_table.py | 33 +++++----- mii/models/load_models.py | 25 ++++---- mii/models/providers/diffusers.py | 5 +- mii/models/providers/huggingface.py | 30 ++++----- mii/models/score/generate.py | 11 ++-- mii/server.py | 46 ++++++-------- mii/utils.py | 21 ++++--- 15 files changed, 186 insertions(+), 197 deletions(-) diff --git a/mii/client.py b/mii/client.py index df9bab48..8b1b9235 100644 --- a/mii/client.py +++ b/mii/client.py @@ -12,6 +12,12 @@ from mii.config import MIIConfig +def _get_mii_config(deployment_name): + mii_config = mii.utils.import_score_file(deployment_name).mii_config + # TODO: Avoid model checking when we load the config + return MIIConfig(**mii_config) + + def mii_query_handle(deployment_name): """Get a query handle for a local deployment: @@ -29,18 +35,20 @@ def mii_query_handle(deployment_name): inference_pipeline, task = mii.non_persistent_models[deployment_name] return MIINonPersistentClient(task, deployment_name) - mii_config = mii.utils.import_score_file(deployment_name).mii_config - # TODO: Avoid model checking when we laod the config - mii_config = MIIConfig(**mii_config) - return MIIClient(mii_config.deployment_config.task, "localhost", mii_config.port_number) + 
mii_config = _get_mii_config(deployment_name) + return MIIClient(mii_config.deployment_config.task, + "localhost", + mii_config.port_number) def create_channel(host, port): return grpc.aio.insecure_channel( f"{host}:{port}", options=[ - ("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE), - ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE), + ("grpc.max_send_message_length", + GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", + GRPC_MAX_MSG_SIZE), ], ) @@ -49,7 +57,6 @@ class MIIClient: """ Client to send queries to a single endpoint. """ - def __init__(self, task, host, port): self.asyncio_loop = asyncio.get_event_loop() channel = create_channel(host, port) @@ -67,34 +74,30 @@ async def _request_async_response(self, request_dict, **query_kwargs): def query(self, request_dict, **query_kwargs): return self.asyncio_loop.run_until_complete( - self._request_async_response(request_dict, **query_kwargs) - ) + self._request_async_response(request_dict, + **query_kwargs)) async def terminate_async(self): await self.stub.Terminate( - modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty() - ) + modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) def terminate(self): self.asyncio_loop.run_until_complete(self.terminate_async()) async def create_session_async(self, session_id): return await self.stub.CreateSession( - modelresponse_pb2.SessionID(session_id=session_id) - ) + modelresponse_pb2.SessionID(session_id=session_id)) def create_session(self, session_id): assert ( self.task == TaskType.TEXT_GENERATION ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'." return self.asyncio_loop.run_until_complete( - self.create_session_async(session_id) - ) + self.create_session_async(session_id)) async def destroy_session_async(self, session_id): - await self.stub.DestroySession( - modelresponse_pb2.SessionID(session_id=session_id) - ) + await self.stub.DestroySession(modelresponse_pb2.SessionID(session_id=session_id) + ) def destroy_session(self, session_id): assert ( @@ -108,7 +111,6 @@ class MIITensorParallelClient: Client to send queries to multiple endpoints in parallel. This is used to call multiple servers deployed for tensor parallelism. 
""" - def __init__(self, task, host, ports): self.task = task self.clients = [MIIClient(task, host, port) for port in ports] @@ -120,9 +122,8 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs): for client in self.clients: responses.append( self.asyncio_loop.create_task( - client._request_async_response(request_string, **query_kwargs) - ) - ) + client._request_async_response(request_string, + **query_kwargs))) await responses[0] return responses[0] @@ -141,8 +142,8 @@ def query(self, request_dict, **query_kwargs): response: Response of the model """ response = self.asyncio_loop.run_until_complete( - self._query_in_tensor_parallel(request_dict, query_kwargs) - ) + self._query_in_tensor_parallel(request_dict, + query_kwargs)) ret = response.result() return ret @@ -175,18 +176,17 @@ def query(self, request_dict, **query_kwargs): if self.task == TaskType.QUESTION_ANSWERING: if "question" not in request_dict or "context" not in request_dict: raise Exception( - "Question Answering Task requires 'question' and 'context' keys" - ) + "Question Answering Task requires 'question' and 'context' keys") args = (request_dict["question"], request_dict["context"]) kwargs = query_kwargs elif self.task == TaskType.CONVERSATIONAL: conv = task_methods.create_conversation(request_dict, **query_kwargs) - args = (conv,) + args = (conv, ) kwargs = {} else: - args = (request_dict["query"],) + args = (request_dict["query"], ) kwargs = query_kwargs return task_methods.run_inference(inference_pipeline, args, query_kwargs) @@ -197,6 +197,6 @@ def terminate(self): def terminate_restful_gateway(deployment_name): - _, mii_configs = _get_deployment_info(deployment_name) - if mii_configs.enable_restful_api: - requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") + mii_config = _get_mii_config(deployment_name) + if mii_config.enable_restful_api: + requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate") diff --git a/mii/config.py b/mii/config.py index e4392b9c..12295305 100644 --- a/mii/config.py +++ b/mii/config.py @@ -4,8 +4,8 @@ # DeepSpeed Team import torch import os -from typing import Union, List, Optional, Dict, Any -from enum import Enum +import string +from typing import List, Optional, Dict, Any from pydantic import validator, root_validator import mii from mii.constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT @@ -61,7 +61,7 @@ def checkpoint_dict_valid(cls, field_value, values): for k in ["checkpoints", "parallelization", "version", "type"]: if not field_value.get(k, ""): raise ValueError(f"Missing key={k} in checkpoint_dict") - return vfield_alue + return field_value @validator("deploy_rank", pre=True) def deploy_rank_to_list(cls, field_value, values): @@ -189,9 +189,8 @@ def propagate_hf_auth(cls, values): @root_validator def AML_name_valid(cls, values): if values.get("deployment_type") == DeploymentType.AML: - allowed_chars = set( - string.ascii_lowercase + string.ascii_uppercaes + string.digits + "-" - ) + allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + + string.digits + "-") assert ( set(values.get("deployment_config").deployment_name) <= allowed_chars ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." 
@@ -224,8 +223,7 @@ def generate_replica_configs(cls, values): tensor_parallel_ports=tensor_parallel_ports, torch_dist_port=replica_torch_dist_port, gpu_indices=gpu_indices, - ) - ) + )) values.get("deployment_config").replica_configs = replica_configs return values @@ -250,18 +248,15 @@ def _allocate_processes(hostfile_path, tensor_parallel, replica_num): ) allocated_num_on_host = slots - available_on_host - replica_pool.append( - ( - host, - [ - i - for i in range( - allocated_num_on_host, - allocated_num_on_host + tensor_parallel, - ) - ], - ) - ) + replica_pool.append(( + host, + [ + i for i in range( + allocated_num_on_host, + allocated_num_on_host + tensor_parallel, + ) + ], + )) allocated_num += 1 available_on_host -= tensor_parallel diff --git a/mii/constants.py b/mii/constants.py index 55487701..6fc80854 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -44,7 +44,8 @@ class ModelProvider(str, Enum): REQUIRED_KEYS_PER_TASK = { TaskType.TEXT_GENERATION: ["query"], TaskType.TEXT_CLASSIFICATION: ["query"], - TaskType.QUESTION_ANSWERING: ["context", "question"], + TaskType.QUESTION_ANSWERING: ["context", + "question"], TaskType.FILL_MASK: ["query"], TaskType.TOKEN_CLASSIFICATION: ["query"], TaskType.CONVERSATIONAL: [ diff --git a/mii/deployment.py b/mii/deployment.py index d4e2bdc4..3bec6c09 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -2,12 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -import torch -import string import os import mii - from .utils import logger from .models.score import create_score_file from .models import load_models @@ -106,6 +103,6 @@ def _deploy_aml(mii_config): version=mii_config.version, ) print( - f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}" + f"AML deployment assets at {mii.aml_related.utils.aml_output_path(mii_config.deployment_config.deployment_name)}" ) print("Please run 'deploy.sh' to bring your deployment online") diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 30055286..adec8e03 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -31,7 +31,6 @@ class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): """ Base class to provide common features of an inference server """ - def __init__(self): self._stop_event = threading.Event() @@ -47,7 +46,6 @@ class ModelResponse(ServiceBase): """ Implementation class of an MII inference server """ - def __init__(self, inference_pipeline): super().__init__() self.inference_pipeline = inference_pipeline @@ -72,12 +70,12 @@ def _get_model_time(self, model, sum_times=False): return model_time def CreateSession(self, request, context): - task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION] + task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION] task_methods.create_session(request.session_id) return google_dot_protobuf_dot_empty__pb2.Empty() def DestroySession(self, request, context): - task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION] + task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION] task_methods.destroy_session(request.session_id) return google_dot_protobuf_dot_empty__pb2.Empty() @@ -97,11 +95,10 @@ def _run_inference(self, method_name, request_proto): response = task_methods.run_inference(self.inference_pipeline, args, kwargs) end = time.time() - model_time = ( - self._get_model_time(self.inference_pipeline.model, sum_times=True) - if hasattr(self.inference_pipeline, "model") - else -1 - ) + model_time = 
(self._get_model_time(self.inference_pipeline.model, + sum_times=True) if hasattr( + self.inference_pipeline, + "model") else -1) return task_methods.pack_response_to_proto(response, end - start, model_time) @@ -149,7 +146,6 @@ class ParallelStubInvoker: This class aims to call gRPC methods without conversions between proto and python object. TensorParallelClient can be used for invocation with the conversions. """ - def __init__(self, host, ports): # Assumption: target services are all on the same host self.stubs = [] @@ -170,8 +166,9 @@ async def _invoke_async(self, method_name, proto_request): def invoke(self, method_name, proto_request): # This is needed because gRPC calls from interceptor are launched from return asyncio.run_coroutine_threadsafe( - self._invoke_async(method_name, proto_request), self.asyncio_loop - ).result() + self._invoke_async(method_name, + proto_request), + self.asyncio_loop).result() class LoadBalancingInterceptor(grpc.ServerInterceptor): @@ -180,7 +177,8 @@ def __init__(self, deployment_config): self.asyncio_loop = asyncio.get_event_loop() self.stubs = [ - ParallelStubInvoker(replica.hostname, replica.tensor_parallel_ports) + ParallelStubInvoker(replica.hostname, + replica.tensor_parallel_ports) for replica in deployment_config.replica_configs ] self.counter = AtomicCounter() @@ -192,7 +190,7 @@ def run_asyncio_loop(loop): asyncio.set_event_loop(loop) loop.run_forever() - threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop,)).start() + threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start() def choose_stub(self, call_count): return self.stubs[call_count % len(self.stubs)] @@ -206,9 +204,8 @@ def invoke_intercept_method(request_proto, context): if method_name == TERMINATE_METHOD: for stub in self.stubs: - stub.invoke( - TERMINATE_METHOD, google_dot_protobuf_dot_empty__pb2.Empty() - ) + stub.invoke(TERMINATE_METHOD, + google_dot_protobuf_dot_empty__pb2.Empty()) self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop) return next_handler.unary_unary(request_proto, context) @@ -218,8 +215,7 @@ def invoke_intercept_method(request_proto, context): if method_name == CREATE_SESSION_METHOD: if request_proto.session_id in self.sessions: raise ValueError( - f"session {request_proto.session_id} already exists" - ) + f"session {request_proto.session_id} already exists") self.replica_sessions[request_proto.session_id] = replica_index self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() @@ -252,8 +248,10 @@ def _do_serve(service_impl, port, interceptors=[]): futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), interceptors=interceptors, options=[ - ("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE), - ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE), + ("grpc.max_send_message_length", + GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", + GRPC_MAX_MSG_SIZE), ], ) modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index af1b1898..0de21c3e 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -13,7 +13,6 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" - def __init__(self, channel): """Constructor. 
@@ -22,7 +21,8 @@ def __init__(self, channel): """ self.Terminate = channel.unary_unary( "/modelresponse.ModelResponse/Terminate", - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.CreateSession = channel.unary_unary( @@ -74,7 +74,6 @@ def __init__(self, channel): class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" - def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -138,67 +137,78 @@ def Txt2ImgReply(self, request, context): def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - "Terminate": grpc.unary_unary_rpc_method_handler( + "Terminate": + grpc.unary_unary_rpc_method_handler( servicer.Terminate, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, ), - "CreateSession": grpc.unary_unary_rpc_method_handler( + "CreateSession": + grpc.unary_unary_rpc_method_handler( servicer.CreateSession, request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, ), - "DestroySession": grpc.unary_unary_rpc_method_handler( + "DestroySession": + grpc.unary_unary_rpc_method_handler( servicer.DestroySession, request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. 
+ SerializeToString, ), - "GeneratorReply": grpc.unary_unary_rpc_method_handler( + "GeneratorReply": + grpc.unary_unary_rpc_method_handler( servicer.GeneratorReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, ), - "ClassificationReply": grpc.unary_unary_rpc_method_handler( + "ClassificationReply": + grpc.unary_unary_rpc_method_handler( servicer.ClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "QuestionAndAnswerReply": grpc.unary_unary_rpc_method_handler( + "QuestionAndAnswerReply": + grpc.unary_unary_rpc_method_handler( servicer.QuestionAndAnswerReply, request_deserializer=modelresponse__pb2.QARequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "FillMaskReply": grpc.unary_unary_rpc_method_handler( + "FillMaskReply": + grpc.unary_unary_rpc_method_handler( servicer.FillMaskReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "TokenClassificationReply": grpc.unary_unary_rpc_method_handler( + "TokenClassificationReply": + grpc.unary_unary_rpc_method_handler( servicer.TokenClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "ConversationalReply": grpc.unary_unary_rpc_method_handler( + "ConversationalReply": + grpc.unary_unary_rpc_method_handler( servicer.ConversationalReply, request_deserializer=modelresponse__pb2.ConversationRequest.FromString, response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, ), - "Txt2ImgReply": grpc.unary_unary_rpc_method_handler( + "Txt2ImgReply": + grpc.unary_unary_rpc_method_handler( servicer.Txt2ImgReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.ImageReply.SerializeToString, ), } - generic_handler = grpc.method_handlers_generic_handler( - "modelresponse.ModelResponse", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) + generic_handler = grpc.method_handlers_generic_handler("modelresponse.ModelResponse", + rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) # This class is part of an EXPERIMENTAL API. 
class ModelResponse(object): """Missing associated documentation comment in .proto file.""" - @staticmethod def Terminate( request, diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index f3eeaa46..68b17c73 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -19,7 +19,7 @@ def shutdown(thread): def createRestfulGatewayApp(deployment_config, lb_port, server_thread): # client must be thread-safe - client = mii.MIIClient(task, "localhost", lb_port) + client = mii.MIIClient(deployment_config.task, "localhost", lb_port) class RestfulGatewayService(Resource): def __init__(self): @@ -36,7 +36,7 @@ def post(self): @app.route("/terminate", methods=["GET"]) def terminate(): # Need to shutdown *after* completing the request - threading.Thread(target=shutdown, args=(server_thread,)).start() + threading.Thread(target=shutdown, args=(server_thread, )).start() return "Shutting down RESTful API gateway server" api = Api(app) @@ -49,7 +49,6 @@ def terminate(): class RestfulGatewayThread(threading.Thread): def __init__(self, deployment_config, lb_port, rest_port): threading.Thread.__init__(self) - self.mii_config = mii_config app = createRestfulGatewayApp(deployment_config, lb_port, self) self.server = make_server("127.0.0.1", rest_port, app) diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 5e7988aa..4831593a 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -4,12 +4,10 @@ # DeepSpeed Team import os import argparse -import mii import base64 import json -from mii import DeploymentConfig - +from mii.config import DeploymentConfig from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing from mii.grpc_related.restful_gateway import RestfulGatewayThread @@ -39,9 +37,9 @@ def main(): default=0, help="Port to user for DeepSpeed inference server.", ) - parser.add_argument( - "--load-balancer", action="store_true", help="Launch load balancer process." 
- ) + parser.add_argument("--load-balancer", + action="store_true", + help="Launch load balancer process.") parser.add_argument( "--load-balancer-port", type=int, diff --git a/mii/method_table.py b/mii/method_table.py index a5457f24..d1eb0b9d 100644 --- a/mii/method_table.py +++ b/mii/method_table.py @@ -12,33 +12,32 @@ def single_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.SingleStringRequest( - request=request_dict["query"], query_kwargs=kwarg_dict_to_proto(query_kwargs) - ) + request=request_dict["query"], + query_kwargs=kwarg_dict_to_proto(query_kwargs)) def single_string_response_to_proto(self, response, time_taken, model_time_taken): - return modelresponse_pb2.SingleStringReply( - response=f"{response}", time_taken=time_taken, model_time_taken=model_time_taken - ) + return modelresponse_pb2.SingleStringReply(response=f"{response}", + time_taken=time_taken, + model_time_taken=model_time_taken) def multi_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.MultiStringRequest( - request=request_dict["query"] - if isinstance(request_dict["query"], list) - else [request_dict["query"]], + request=request_dict["query"] if isinstance(request_dict["query"], + list) else [request_dict["query"]], query_kwargs=kwarg_dict_to_proto(query_kwargs), ) def proto_request_to_single_input(self, request): - args = (request.request,) + args = (request.request, ) kwargs = unpack_proto_query_kwargs(request.query_kwargs) return args, kwargs def proto_request_to_list(self, request): - args = ([r for r in request.request],) + args = ([r for r in request.request], ) kwargs = unpack_proto_query_kwargs(request.query_kwargs) return args, kwargs @@ -93,7 +92,7 @@ def preprocess_session(self, session_id, args): if len(args[0]) != 1: raise ValueError(f"You can pass only one prompt with a session_id") - args = ([self.session_context[session_id] + args[0][0]],) + args = ([self.session_context[session_id] + args[0][0]], ) return args def run_inference(self, inference_pipeline, args, kwargs): @@ -110,7 +109,7 @@ def run_inference(self, inference_pipeline, args, kwargs): def postprocess_session(self, session_id, args, response): generated_text = response[0][0]["generated_text"] self.session_context[session_id] = generated_text - response[0][0]["generated_text"] = generated_text[len(args[0][0]) :] + response[0][0]["generated_text"] = generated_text[len(args[0][0]):] return response def pack_response_to_proto(self, response, time_taken, model_time_taken): @@ -191,9 +190,8 @@ def create_conversation(self, request, **kwargs): and "generated_responses" in request ), "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys" text = request["text"] - conversation_id = ( - request["conversation_id"] if "conversation_id" in request else None - ) + conversation_id = (request["conversation_id"] + if "conversation_id" in request else None) past_user_inputs = request["past_user_inputs"] generated_responses = request["generated_responses"] @@ -224,7 +222,7 @@ def pack_response_to_proto(self, conv, time_taken, model_time_taken): def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) conv = self.create_conversation(request, **kwargs) - args = (conv,) + args = (conv, ) kwargs = {} return args, kwargs @@ -232,8 +230,7 @@ def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.ConversationRequest( text=request_dict["text"], 
conversation_id=request_dict["conversation_id"] - if "conversation_id" in request_dict - else None, + if "conversation_id" in request_dict else None, past_user_inputs=request_dict["past_user_inputs"], generated_responses=request_dict["generated_responses"], query_kwargs=kwarg_dict_to_proto(query_kwargs), diff --git a/mii/models/load_models.py b/mii/models/load_models.py index 09011c31..6da0fa9a 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -4,7 +4,6 @@ # DeepSpeed Team import os import mii -import json import torch import inspect import deepspeed @@ -18,7 +17,10 @@ def load_models(deployment_config): world_size = int(os.getenv("WORLD_SIZE", "1")) inf_config = { - "tensor_parallel": {"tp_size": deployment_config.tensor_parallel, "mpu": None}, + "tensor_parallel": { + "tp_size": deployment_config.tensor_parallel, + "mpu": None + }, "dtype": deployment_config.dtype, "replace_method": "auto", "enable_cuda_graph": deployment_config.enable_cuda_graph, @@ -38,10 +40,8 @@ def load_models(deployment_config): inf_config["checkpoint"] = inference_pipeline.checkpoint_dict if deployment_config.dtype == torch.int8: # Support for older DeepSpeed versions - if ( - "enable_qkv_quantization" - in inspect.signature(deepspeed.init_inference).parameters - ): + if ("enable_qkv_quantization" + in inspect.signature(deepspeed.init_inference).parameters): inf_config["enable_qkv_quantization"] = True elif provider == mii.constants.ModelProvider.ELEUTHER_AI: assert False, "Eleuther AI support is currently disabled." @@ -66,16 +66,16 @@ def load_models(deployment_config): inf_config["enable_cuda_graph"] = True else: raise ValueError(f"Unknown model provider {provider}") - """ print( f"> --------- MII Settings: ds_optimize={ds_optimize}, replace_with_kernel_inject={mii_config.replace_with_kernel_inject}, enable_cuda_graph={mii_config.enable_cuda_graph} " ) """ if deployment_config.enable_deepspeed: - engine = deepspeed.init_inference( - getattr(inference_pipeline, "model", inference_pipeline), config=inf_config - ) + engine = deepspeed.init_inference(getattr(inference_pipeline, + "model", + inference_pipeline), + config=inf_config) if deployment_config.profile_model_time: engine.profile_model_time() if hasattr(inference_pipeline, "model"): @@ -88,9 +88,8 @@ def load_models(deployment_config): ), "DeepSpeed ZeRO inference is only supported for ZeRO-3" # initialise Deepspeed ZeRO and store only the engine object - ds_engine = deepspeed.initialize( - model=inference_pipeline.model, config_params=ds_config - )[0] + ds_engine = deepspeed.initialize(model=inference_pipeline.model, + config_params=ds_config)[0] ds_engine.module.eval() # inference inference_pipeline.model = ds_engine.module diff --git a/mii/models/providers/diffusers.py b/mii/models/providers/diffusers.py index 3109223e..517559fe 100644 --- a/mii/models/providers/diffusers.py +++ b/mii/models/providers/diffusers.py @@ -17,8 +17,9 @@ def diffusers_provider(deployment_config): kwargs["revision"] = "fp16" pipeline = DiffusionPipeline.from_pretrained( - deployment_config.model, use_auth_token=deployment_config.hf_auth_token, **kwargs - ) + deployment_config.model, + use_auth_token=deployment_config.hf_auth_token, + **kwargs) pipeline = pipeline.to(f"cuda:{local_rank}") pipeline.set_progress_bar_config(disable=True) return pipeline diff --git a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py index 60862b8e..52b55730 100644 --- a/mii/models/providers/huggingface.py +++ b/mii/models/providers/huggingface.py @@ 
-30,7 +30,6 @@ class MetaTensorPipeline(object): """ Class for loading HuggingFace models using meta tensors """ - def __init__(self, model, tokenizer, checkpoint_dict): self.model = model self.tokenizer = tokenizer @@ -44,9 +43,9 @@ def __call__(self, inputs, **kwargs): # expand proto list into py-list inputs = [i for i in inputs] - tokens = self.tokenizer.batch_encode_plus( - inputs, return_tensors="pt", padding=True - ) + tokens = self.tokenizer.batch_encode_plus(inputs, + return_tensors="pt", + padding=True) for t in tokens: if torch.is_tensor(tokens[t]): tokens[t] = tokens[t].to(device) @@ -87,9 +86,9 @@ def get_checkpoint_files(pretrained_model_name_or_path): local_files_only = False filename = WEIGHTS_NAME - archive_file = hf_bucket_url( - pretrained_model_name_or_path, filename=filename, revision=revision - ) + archive_file = hf_bucket_url(pretrained_model_name_or_path, + filename=filename, + revision=revision) try: resolved_archive_file = cached_path( @@ -150,8 +149,7 @@ def create_checkpoint_dict(model_name, model_path, checkpoint_dict): if USE_NEW_HF_CACHE: checkpoint_files = [ str(entry).split("/")[-1] - for entry in Path(model_path).rglob("*.[bp][it][n]") - if entry.is_file() + for entry in Path(model_path).rglob("*.[bp][it][n]") if entry.is_file() ] else: checkpoint_files = get_checkpoint_files(model_name) @@ -177,16 +175,20 @@ def load_with_meta_tensor(deployment_config): ) tokenizer.pad_token = tokenizer.eos_token - config = _attempt_load(AutoConfig.from_pretrained, deployment_config.model, cache_path) + config = _attempt_load(AutoConfig.from_pretrained, + deployment_config.model, + cache_path) with OnDevice(dtype=torch.float16, device="meta", enabled=True): model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) model = model.eval() - checkpoint_dict = create_checkpoint_dict(deployment_config.model, deployment_config.model_path, deployment_config.checkpoint_dict) + checkpoint_dict = create_checkpoint_dict(deployment_config.model, + deployment_config.model_path, + deployment_config.checkpoint_dict) torch.distributed.barrier() - inference_pipeline = MetaTensorPipeline( - model=model, tokenizer=tokenizer, checkpoint_dict=checkpoint_dict - ) + inference_pipeline = MetaTensorPipeline(model=model, + tokenizer=tokenizer, + checkpoint_dict=checkpoint_dict) return inference_pipeline diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index ccc35695..b351fab6 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -15,9 +15,9 @@ def create_score_file(mii_config): f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" ) - with open( - os.path.join(mii.__path__[0], "models/score/score_template.py"), "r" - ) as fd: + with open(os.path.join(mii.__path__[0], + "models/score/score_template.py"), + "r") as fd: score_src = fd.read() # update score file w. 
global config dict @@ -36,9 +36,8 @@ def generated_score_path(mii_config): if deployment_type == DeploymentType.LOCAL: score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) elif deployment_type == DeploymentType.AML: - score_path = os.path.join( - mii.aml_related.utils.aml_output_path(deployment_name), "code" - ) + score_path = os.path.join(mii.aml_related.utils.aml_output_path(deployment_name), + "code") if not os.path.isdir(score_path): os.makedirs(score_path) return os.path.join(score_path, "score.py") diff --git a/mii/server.py b/mii/server.py index f2f2a42f..04002e2e 100644 --- a/mii/server.py +++ b/mii/server.py @@ -3,17 +3,13 @@ # DeepSpeed Team import base64 -import json import os import subprocess import sys import tempfile import time -import torch -from pathlib import Path from collections import defaultdict -import mii from mii.utils import get_num_gpus, logger @@ -28,13 +24,11 @@ def config_to_b64_str(config): class MIIServer: """Initialize the model, setup the server for the model under model_path""" - def __init__(self, mii_config): self.task = mii_config.deployment_config.task self.num_gpus = get_num_gpus(mii_config) assert self.num_gpus > 0, "GPU count must be greater than 0" - """ if mii_configs.hostfile is None: hostfile = tempfile.NamedTemporaryFile(delete=False) @@ -45,26 +39,25 @@ def __init__(self, mii_config): """ processes = self._initialize_service(mii_config) - self._wait_until_server_is_live(processes, mii_config.deployment_config.replica_configs) + self._wait_until_server_is_live(processes, + mii_config.deployment_config.replica_configs) def _wait_until_server_is_live(self, processes, deployment): for process, repl_config in zip(processes, deployment): sockets_open = False while not sockets_open: sockets_open = all( - self._is_socket_open(repl_config.hostname, port) - for port in repl_config.tensor_parallel_ports - ) + self._is_socket_open(repl_config.hostname, + port) + for port in repl_config.tensor_parallel_ports) process_alive = self._is_server_process_alive(process) if not process_alive: raise RuntimeError( - "server crashed for some reason, unable to proceed" - ) + "server crashed for some reason, unable to proceed") time.sleep(4) logger.info("waiting for server to start...") logger.info( - f"server has started on ports {repl_config.tensor_parallel_ports}" - ) + f"server has started on ports {repl_config.tensor_parallel_ports}") def _is_socket_open(self, host, port): import socket @@ -87,9 +80,11 @@ def _is_server_process_alive(self, process): is_alive = False return is_alive - def _launch_server_process( - self, deployment_config, msg_server_type, ds_launch_str="", server_args=[] - ): + def _launch_server_process(self, + deployment_config, + msg_server_type, + ds_launch_str="", + server_args=[]): launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" b64_config_str = config_to_b64_str(deployment_config) server_args.append(f"--deployment-config {b64_config_str}") @@ -132,18 +127,17 @@ def _initialize_service(self, mii_config): for repl_config in mii_config.deployment_config.replica_configs: hostfile = tempfile.NamedTemporaryFile(delete=False) hostfile.write( - f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n".encode() - ) + f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n" + .encode()) ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name) processes.append( self._launch_server_process( mii_config.deployment_config, "MII server", ds_launch_str=ds_launch_str, - 
server_args=server_args - + [f"--server-port {repl_config.tensor_parallel_ports[0]}"], - ) - ) + server_args=server_args + + [f"--server-port {repl_config.tensor_parallel_ports[0]}"], + )) # start load balancer here. # we don't use deepspeed launcher for the load balancer because it does not need a GPU. # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, @@ -153,8 +147,7 @@ def _initialize_service(self, mii_config): mii_config.deployment_config, "load balancer", server_args=server_args + ["--load-balancer"], - ) - ) + )) if mii_config.enable_restful_api: processes.append( @@ -162,7 +155,6 @@ def _initialize_service(self, mii_config): mii_config.deployment_config, "restful api gateway", server_args=server_args + ["--restful-gateway"], - ) - ) + )) return processes diff --git a/mii/utils.py b/mii/utils.py index 14a4b74a..72e22ea2 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -24,11 +24,8 @@ def _get_hf_models_by_type(model_type, task=None): api = HfApi() models = api.list_models(filter=model_type) - models = ( - [m.modelId for m in models] - if task is None - else [m.modelId for m in models if m.pipeline_tag == task] - ) + models = ([m.modelId for m in models] + if task is None else [m.modelId for m in models if m.pipeline_tag == task]) if task == TaskType.TEXT_GENERATION: # TODO: this is a temp solution to get around some HF models not having the correct tags models.append("microsoft/bloom-deepspeed-inference-fp16") @@ -79,8 +76,10 @@ def mii_cache_path(): def import_score_file(deployment_name): spec = importlib.util.spec_from_file_location( - "score", os.path.join(mii_cache_path(), deployment_name, "score.py") - ) + "score", + os.path.join(mii_cache_path(), + deployment_name, + "score.py")) score = importlib.util.module_from_spec(spec) spec.loader.exec_module(score) return score @@ -105,7 +104,10 @@ def get_proto_value(value): def unpack_proto_query_kwargs(query_kwargs): query_kwargs = { - k: getattr(v, v.WhichOneof("oneof_values")) for k, v in query_kwargs.items() + k: getattr(v, + v.WhichOneof("oneof_values")) + for k, + v in query_kwargs.items() } return query_kwargs @@ -165,8 +167,7 @@ def create_logger(name=None, level=logging.INFO): formatter = logging.Formatter( "[%(asctime)s] [%(levelname)s] " - "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s" - ) + "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s") logger_ = logging.getLogger(name) logger_.setLevel(level) From e5aea2098a43a67ea36141f1ecf722b1246f0863 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 10:48:48 -0700 Subject: [PATCH 05/15] remove formatting changes --- mii/grpc_related/proto/modelresponse_pb2.py | 13 +- .../proto/modelresponse_pb2_grpc.py | 352 ++++++++---------- 2 files changed, 167 insertions(+), 198 deletions(-) diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 677ecd0d..76b1f994 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -10,7 +10,6 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -18,22 +17,22 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply"\x00\x12]\n\x13\x43lassificationReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply"\x00\x12W\n\rFillMaskReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12\x62\n\x18TokenClassificationReply\x12".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply"\x00\x12]\n\x13\x43onversationalReply\x12".modelresponse.ConversationRequest\x1a 
.modelresponse.ConversationReply"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply"\x00\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a 
.modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "modelresponse_pb2", globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b"8\001" + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _VALUE._serialized_start = 67 _VALUE._serialized_end = 162 _SESSIONID._serialized_start = 164 diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 0de21c3e..95cfa825 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -20,53 +20,53 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Terminate = channel.unary_unary( - "/modelresponse.ModelResponse/Terminate", + '/modelresponse.ModelResponse/Terminate', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. 
SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.CreateSession = channel.unary_unary( - "/modelresponse.ModelResponse/CreateSession", + '/modelresponse.ModelResponse/CreateSession', request_serializer=modelresponse__pb2.SessionID.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.DestroySession = channel.unary_unary( - "/modelresponse.ModelResponse/DestroySession", + '/modelresponse.ModelResponse/DestroySession', request_serializer=modelresponse__pb2.SessionID.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.GeneratorReply = channel.unary_unary( - "/modelresponse.ModelResponse/GeneratorReply", + '/modelresponse.ModelResponse/GeneratorReply', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.MultiStringReply.FromString, ) self.ClassificationReply = channel.unary_unary( - "/modelresponse.ModelResponse/ClassificationReply", + '/modelresponse.ModelResponse/ClassificationReply', request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.QuestionAndAnswerReply = channel.unary_unary( - "/modelresponse.ModelResponse/QuestionAndAnswerReply", + '/modelresponse.ModelResponse/QuestionAndAnswerReply', request_serializer=modelresponse__pb2.QARequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.FillMaskReply = channel.unary_unary( - "/modelresponse.ModelResponse/FillMaskReply", + '/modelresponse.ModelResponse/FillMaskReply', request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.TokenClassificationReply = channel.unary_unary( - "/modelresponse.ModelResponse/TokenClassificationReply", + '/modelresponse.ModelResponse/TokenClassificationReply', request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.SingleStringReply.FromString, ) self.ConversationalReply = channel.unary_unary( - "/modelresponse.ModelResponse/ConversationalReply", + '/modelresponse.ModelResponse/ConversationalReply', request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, response_deserializer=modelresponse__pb2.ConversationReply.FromString, ) self.Txt2ImgReply = channel.unary_unary( - "/modelresponse.ModelResponse/Txt2ImgReply", + '/modelresponse.ModelResponse/Txt2ImgReply', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, response_deserializer=modelresponse__pb2.ImageReply.FromString, ) @@ -77,131 +77,131 @@ class ModelResponseServicer(object): def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def CreateSession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def 
DestroySession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GeneratorReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def ClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def QuestionAndAnswerReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def FillMaskReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def TokenClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def ConversationalReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def Txt2ImgReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - "Terminate": + 'Terminate': grpc.unary_unary_rpc_method_handler( servicer.Terminate, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. SerializeToString, ), - "CreateSession": + 'CreateSession': grpc.unary_unary_rpc_method_handler( servicer.CreateSession, request_deserializer=modelresponse__pb2.SessionID.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. 
SerializeToString, ), - "DestroySession": + 'DestroySession': grpc.unary_unary_rpc_method_handler( servicer.DestroySession, request_deserializer=modelresponse__pb2.SessionID.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. SerializeToString, ), - "GeneratorReply": + 'GeneratorReply': grpc.unary_unary_rpc_method_handler( servicer.GeneratorReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, ), - "ClassificationReply": + 'ClassificationReply': grpc.unary_unary_rpc_method_handler( servicer.ClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "QuestionAndAnswerReply": + 'QuestionAndAnswerReply': grpc.unary_unary_rpc_method_handler( servicer.QuestionAndAnswerReply, request_deserializer=modelresponse__pb2.QARequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "FillMaskReply": + 'FillMaskReply': grpc.unary_unary_rpc_method_handler( servicer.FillMaskReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "TokenClassificationReply": + 'TokenClassificationReply': grpc.unary_unary_rpc_method_handler( servicer.TokenClassificationReply, request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, ), - "ConversationalReply": + 'ConversationalReply': grpc.unary_unary_rpc_method_handler( servicer.ConversationalReply, request_deserializer=modelresponse__pb2.ConversationRequest.FromString, response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, ), - "Txt2ImgReply": + 'Txt2ImgReply': grpc.unary_unary_rpc_method_handler( servicer.Txt2ImgReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, response_serializer=modelresponse__pb2.ImageReply.SerializeToString, ), } - generic_handler = grpc.method_handlers_generic_handler("modelresponse.ModelResponse", + generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler, )) @@ -210,22 +210,20 @@ def add_ModelResponseServicer_to_server(servicer, server): class ModelResponse(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Terminate( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def Terminate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/Terminate", + '/modelresponse.ModelResponse/Terminate', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -235,26 +233,23 @@ def Terminate( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def CreateSession( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def CreateSession(request, 
+ target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/CreateSession", + '/modelresponse.ModelResponse/CreateSession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -264,26 +259,23 @@ def CreateSession( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def DestroySession( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def DestroySession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/DestroySession", + '/modelresponse.ModelResponse/DestroySession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, @@ -293,26 +285,23 @@ def DestroySession( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def GeneratorReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def GeneratorReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/GeneratorReply", + '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.MultiStringReply.FromString, options, @@ -322,26 +311,23 @@ def GeneratorReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def ClassificationReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def ClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/ClassificationReply", + '/modelresponse.ModelResponse/ClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -351,26 +337,23 @@ def ClassificationReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def QuestionAndAnswerReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def QuestionAndAnswerReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/QuestionAndAnswerReply", + 
'/modelresponse.ModelResponse/QuestionAndAnswerReply', modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -380,26 +363,23 @@ def QuestionAndAnswerReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def FillMaskReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def FillMaskReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/FillMaskReply", + '/modelresponse.ModelResponse/FillMaskReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -409,26 +389,23 @@ def FillMaskReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def TokenClassificationReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def TokenClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/TokenClassificationReply", + '/modelresponse.ModelResponse/TokenClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, options, @@ -438,26 +415,23 @@ def TokenClassificationReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def ConversationalReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def ConversationalReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/ConversationalReply", + '/modelresponse.ModelResponse/ConversationalReply', modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, options, @@ -467,26 +441,23 @@ def ConversationalReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) @staticmethod - def Txt2ImgReply( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): + def Txt2ImgReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary( request, target, - "/modelresponse.ModelResponse/Txt2ImgReply", + '/modelresponse.ModelResponse/Txt2ImgReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, options, @@ -496,5 +467,4 @@ def Txt2ImgReply( compression, wait_for_ready, timeout, - metadata, - ) + metadata) From 
b0f3f650290e3d9d34e0d4b10141ba5f5d6bb67c Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 14:33:28 -0700 Subject: [PATCH 06/15] get tests working again --- mii/config.py | 43 +++++++------ mii/deployment.py | 6 +- tests/conftest.py | 83 ++++++++++++------------- tests/test_config.py | 46 +++++--------- tests/test_deployment_options.py | 10 +-- tests/test_local_deployment.py | 4 +- tests/test_non_persistent_deployment.py | 2 +- 7 files changed, 96 insertions(+), 98 deletions(-) diff --git a/mii/config.py b/mii/config.py index 12295305..0d2180a9 100644 --- a/mii/config.py +++ b/mii/config.py @@ -23,29 +23,34 @@ class ReplicaConfig(DeepSpeedConfigModel): class DeploymentConfig(DeepSpeedConfigModel): + # Deployment configs deployment_name: str - model: str - task: TaskType - tensor_parallel: int = 1 - dtype: DtypeEnum = DtypeEnum.fp32 - meta_tensor: bool = False load_with_sys_mem: bool = False - enable_cuda_graph: bool = False + meta_tensor: bool = False hf_auth_token: str = "" - checkpoint_dict: Optional[Dict[str, Any]] = None deploy_rank: Optional[List[int]] = None torch_dist_port: int = 29500 - replace_with_kernel_inject: bool = True + replica_num: int = 1 + replica_configs: List[ReplicaConfig] = [] profile_model_time: bool = False skip_model_check: bool = False - max_tokens: int = 1024 trust_remote_code: bool = False + + # Model configs + model: str + task: TaskType + dtype: DtypeEnum = DtypeEnum.fp32 + model_path: str = "" + checkpoint_dict: Optional[Dict[str, Any]] = None + max_tokens: int = 1024 + + # Performance configs enable_deepspeed: bool = True enable_zero: bool = False ds_config: Dict[str, Any] = {} - model_path: str = "" - replica_num: int = 1 - replica_configs: List[ReplicaConfig] = [] + tensor_parallel: int = 1 + enable_cuda_graph: bool = False + replace_with_kernel_inject: bool = True class Config: json_encoders = {torch.dtype: lambda x: str(x)} @@ -136,6 +141,8 @@ def validate_model_and_task(cls, values): mii.utils.check_if_task_and_model_is_valid(task, model) if values.get("enable_deepspeed"): mii.utils.check_if_task_and_model_is_supported(task, model) + # Skip any future checks + values["skip_model_check"] = True return values @root_validator @@ -168,16 +175,16 @@ def deepspeed_or_zero(cls, values): class MIIConfig(DeepSpeedConfigModel): - deployment_config: DeploymentConfig = {} + deployment_config: DeploymentConfig + deployment_type: DeploymentType = DeploymentType.LOCAL hf_auth_token: str = "" port_number: int = 50050 enable_restful_api: bool = False restful_api_port: int = 51080 hostfile: str = DLTS_HOSTFILE - deployment_type: DeploymentType = DeploymentType.LOCAL version: int = 1 - @root_validator + @root_validator(skip_on_failure=True) def propagate_hf_auth(cls, values): # This validator is for when we support multiple models in a deployment hf_auth_token = values.get("hf_auth_token") @@ -186,7 +193,7 @@ def propagate_hf_auth(cls, values): deployment_config.hf_auth_token = hf_auth_token return values - @root_validator + @root_validator(skip_on_failure=True) def AML_name_valid(cls, values): if values.get("deployment_type") == DeploymentType.AML: allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + @@ -196,7 +203,7 @@ def AML_name_valid(cls, values): ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." 
return values - @root_validator + @root_validator(skip_on_failure=True) def generate_replica_configs(cls, values): replica_configs = values.get("deployment_config").replica_configs replica_num = values.get("deployment_config").replica_num @@ -263,7 +270,7 @@ def _allocate_processes(hostfile_path, tensor_parallel, replica_num): if allocated_num < replica_num: raise ValueError( - f"No sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed" + f"Not sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed" ) return replica_pool diff --git a/mii/deployment.py b/mii/deployment.py index 3bec6c09..51de2189 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -46,7 +46,11 @@ def support_legacy_api( return deployment_config, mii_config -def deploy(deployment_name, deployment_config=None, mii_config=None, *args, **kwargs): +def deploy(deployment_name: str, + deployment_config: dict, + mii_config: dict = None, + *args, + **kwargs): if mii_config is None: mii_config = {} diff --git a/tests/conftest.py b/tests/conftest.py index cb812069..4cfa7dd9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,11 +9,6 @@ from types import SimpleNamespace -# Add pytest.skip here for configs that we do not want to test -def validate_config(config): - pass - - @pytest.fixture(scope="function", params=['fp16']) def dtype(request): return request.param @@ -54,30 +49,6 @@ def restful_api_port(request): return request.param -@pytest.fixture(scope="function") -def mii_config( - tmpdir: str, - dtype: str, - tensor_parallel: int, - port_number: int, - meta_tensor: bool, - load_with_sys_mem: bool, - replica_num: int, - enable_restful_api: bool, - restful_api_port: int, -): - return { - 'dtype': dtype, - 'tensor_parallel': tensor_parallel, - 'port_number': port_number, - 'meta_tensor': meta_tensor, - 'load_with_sys_mem': load_with_sys_mem, - 'replica_num': replica_num, - 'enable_restful_api': enable_restful_api, - 'restful_api_port': restful_api_port, - } - - @pytest.fixture(scope="function", params=["text-generation"]) def task_name(request): return request.param @@ -88,6 +59,11 @@ def model_name(request): return request.param +@pytest.fixture(scope="function") +def deployment_name(model_name): + return model_name + "-deployment" + + @pytest.fixture(scope="function", params=[mii.DeploymentType.LOCAL]) def deployment_type(request): return request.param @@ -111,23 +87,42 @@ def ds_config(request): @pytest.fixture(scope="function") def deployment_config(task_name: str, model_name: str, - deployment_type: str, - mii_config: dict, + dtype: str, + tensor_parallel: int, + meta_tensor: bool, + load_with_sys_mem: bool, + replica_num: int, enable_deepspeed: bool, enable_zero: bool, ds_config: dict): config = SimpleNamespace(task=task_name, model=model_name, - deployment_type=deployment_type, - deployment_name=model_name + "-deployment", + dtype=dtype, + tensor_parallel=tensor_parallel, model_path=os.getenv("TRANSFORMERS_CACHE", - None), - mii_config=mii_config, + ""), + meta_tensor=meta_tensor, + replica_num=replica_num, enable_deepspeed=enable_deepspeed, enable_zero=enable_zero, ds_config=ds_config) - validate_config(config) - return config + return config.__dict__ + + +@pytest.fixture(scope="function") +def mii_config( + deployment_type: str, + port_number: int, + enable_restful_api: bool, + restful_api_port: int, +): + config = SimpleNamespace( + deployment_type=deployment_type, + port_number=port_number, + enable_restful_api=enable_restful_api, 
+ restful_api_port=restful_api_port, + ) + return config.__dict__ @pytest.fixture(scope="function", params=[None]) @@ -136,15 +131,19 @@ def expected_failure(request): @pytest.fixture(scope="function") -def deployment(deployment_config, expected_failure): +def deployment(deployment_name, mii_config, deployment_config, expected_failure): if expected_failure is not None: with pytest.raises(expected_failure) as excinfo: - mii.deploy(**deployment_config.__dict__) + mii.deploy(deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config) yield excinfo else: - mii.deploy(**deployment_config.__dict__) - yield deployment_config - mii.terminate(deployment_config.deployment_name) + mii.deploy(deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config) + yield deployment_name + mii.terminate(deployment_name) @pytest.fixture(scope="function", params=[{"query": "DeepSpeed is the greatest"}]) diff --git a/tests/test_config.py b/tests/test_config.py index 2d8de70b..73e68eeb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,33 +9,21 @@ import mii -def test_base_config(): - config = {'port_number': 12345, 'tensor_parallel': 4} - mii_config = mii.config.MIIConfig(**config) - - assert mii_config.port_number == config['port_number'] - assert mii_config.tensor_parallel == config['tensor_parallel'] - - -@pytest.mark.parametrize("config", - [ - { - 'port_number': 'fail', - 'tensor_parallel': 'fail' - }, - { - 'port_number': 'fail', - 'tensor_parallel': 4 - }, - { - 'port_number': 12345, - 'tensor_parallel': 'fail' - }, - { - 'port_fail': 12345, - 'tensor_parallel': 4 - }, - ]) -def test_base_config_literalfail(config): +@pytest.mark.parametrize("port_number", [12345]) +@pytest.mark.parametrize("tensor_parallel", [4]) +def test_base_configs(deployment_name, mii_config, deployment_config): + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_config"] = deployment_config + mii_config = mii.config.MIIConfig(**mii_config) + + assert mii_config.port_number == 12345 + assert mii_config.deployment_config.tensor_parallel == 4 + + +@pytest.mark.parametrize("port_number", ["fail"]) +@pytest.mark.parametrize("tensor_parallel", [3.5]) +def test_base_configs_literalfail(deployment_name, mii_config, deployment_config): with pytest.raises(pydantic.ValidationError): - mii_config = mii.config.MIIConfig(**config) + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_config"] = deployment_config + mii_config = mii.config.MIIConfig(**mii_config) diff --git a/tests/test_deployment_options.py b/tests/test_deployment_options.py index a84318d7..2c80cc26 100644 --- a/tests/test_deployment_options.py +++ b/tests/test_deployment_options.py @@ -13,14 +13,14 @@ @pytest.mark.parametrize("meta_tensor", [True]) @pytest.mark.parametrize("tensor_parallel", [2]) def test_meta_tensor(deployment, query): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result @pytest.mark.parametrize("enable_restful_api", [True]) def test_restful_api(deployment, query, restful_api_port): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) for _ in range(2): result = generator.query(query) @@ -36,14 +36,14 @@ def test_restful_api(deployment, query, restful_api_port): @pytest.mark.parametrize("load_with_sys_mem", [True]) def test_load_to_sys_mem(deployment, query): - 
generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result @pytest.mark.parametrize("replica_num", [2]) def test_replicas(deployment, query, replica_num): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) # Replicas are given queries in round-robin, so test each model is responding for _ in range(replica_num): result = generator.query(query) @@ -72,7 +72,7 @@ def test_replicas(deployment, query, replica_num): }, ]) def test_zero_config(deployment, query): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py index 7e0ab7b4..73cadc50 100644 --- a/tests/test_local_deployment.py +++ b/tests/test_local_deployment.py @@ -64,7 +64,7 @@ ], ) def test_single_GPU(deployment, query): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result @@ -83,6 +83,6 @@ def test_single_GPU(deployment, query): ], ) def test_multi_GPU(deployment, query): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result diff --git a/tests/test_non_persistent_deployment.py b/tests/test_non_persistent_deployment.py index 2f555b64..92cb454a 100644 --- a/tests/test_non_persistent_deployment.py +++ b/tests/test_non_persistent_deployment.py @@ -66,6 +66,6 @@ ], ) def test_single_GPU(deployment, query): - generator = mii.mii_query_handle(deployment.deployment_name) + generator = mii.mii_query_handle(deployment) result = generator.query(query) assert result From 7e206efea7691de02fd7a60679faf30f3abd2094 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 16:01:39 -0700 Subject: [PATCH 07/15] fix non-persistent deployment error --- mii/client.py | 1 - mii/deployment.py | 37 +++++++++++++++++++++---------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/mii/client.py b/mii/client.py index 8b1b9235..1ffb656b 100644 --- a/mii/client.py +++ b/mii/client.py @@ -14,7 +14,6 @@ def _get_mii_config(deployment_name): mii_config = mii.utils.import_score_file(deployment_name).mii_config - # TODO: Avoid model checking when we load the config return MIIConfig(**mii_config) diff --git a/mii/deployment.py b/mii/deployment.py index 51de2189..4caea601 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -46,11 +46,13 @@ def support_legacy_api( return deployment_config, mii_config -def deploy(deployment_name: str, - deployment_config: dict, - mii_config: dict = None, - *args, - **kwargs): +def deploy( + deployment_name: str, + deployment_config: dict, + mii_config: dict = None, + *args, + **kwargs, +): if mii_config is None: mii_config = {} @@ -80,18 +82,9 @@ def deploy(deployment_name: str, if mii_config.deployment_type == DeploymentType.AML: _deploy_aml(mii_config) elif mii_config.deployment_type == DeploymentType.LOCAL: - return _deploy_local(mii_config) + _deploy_local(mii_config) elif mii_config.deployment_type == DeploymentType.NON_PERSISTENT: - assert ( - int(os.getenv("WORLD_SIZE", "1")) - == mii_config.deployment_config.tensor_parallel - ), "World Size does not equal number of tensors. 
When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" - deployment_name = mii_config.deployment_config.deployment_name - task = mii_config.deployment_config.task - mii.non_persistent_models[deployment_name] = ( - load_models(deployment_config), - task, - ) + _deploy_nonpersistent(mii_config) def _deploy_local(mii_config): @@ -110,3 +103,15 @@ def _deploy_aml(mii_config): f"AML deployment assets at {mii.aml_related.utils.aml_output_path(mii_config.deployment_config.deployment_name)}" ) print("Please run 'deploy.sh' to bring your deployment online") + + +def _deploy_nonpersistent(mii_config): + assert ( + int(os.getenv("WORLD_SIZE", "1")) + == mii_config.deployment_config.tensor_parallel + ), "World Size does not equal number of tensors. When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" + deployment_name = mii_config.deployment_config.deployment_name + mii.non_persistent_models[deployment_name] = ( + load_models(mii_config.deployment_config), + mii_config.deployment_config.task, + ) From 018215efd4342e367466d3f773da038f410431ac Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 16:31:01 -0700 Subject: [PATCH 08/15] fix restful test --- tests/test_deployment_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_deployment_options.py b/tests/test_deployment_options.py index 2c80cc26..9f204eaa 100644 --- a/tests/test_deployment_options.py +++ b/tests/test_deployment_options.py @@ -24,7 +24,7 @@ def test_restful_api(deployment, query, restful_api_port): for _ in range(2): result = generator.query(query) - url = f'http://localhost:{restful_api_port}/mii/{deployment.deployment_name}' + url = f'http://localhost:{restful_api_port}/mii/{deployment}' params = {"request": query} json_params = json.dumps(params) result = requests.post(url, From c77c78d9cb8b60922652830b0f74bc7f93a77d11 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 17:32:10 -0700 Subject: [PATCH 09/15] fix zero config test failures --- mii/config.py | 5 +- mii/models/load_models.py | 2 +- tests/conftest.py | 64 ++++++++------- tests/test_deployment_options.py | 102 +++++++++++------------- tests/test_local_deployment.py | 16 ++-- tests/test_non_persistent_deployment.py | 14 ++-- 6 files changed, 105 insertions(+), 98 deletions(-) diff --git a/mii/config.py b/mii/config.py index 0d2180a9..dbf0ac35 100644 --- a/mii/config.py +++ b/mii/config.py @@ -157,12 +157,13 @@ def meta_tensor_or_sys_mem(cls, values): def zero_dtype_valid(cls, values): if values.get("enable_zero"): if values.get("ds_config").get("fp16", {}).get("enabled", False): + # TODO: We should be able to use DtypeEnum instead of torch.float assert ( - values.get("dtype") == DtypeEnum.float16 + values.get("dtype") == torch.float16 ), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`" else: assert ( - values.get("dtype") == DtypeEnum.float32 + values.get("dtype") == torch.float32 ), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`" return values diff --git a/mii/models/load_models.py b/mii/models/load_models.py index 6da0fa9a..1e28d9aa 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -89,7 +89,7 @@ def load_models(deployment_config): # initialise Deepspeed ZeRO and store only the engine object ds_engine = deepspeed.initialize(model=inference_pipeline.model, - config_params=ds_config)[0] + config=deployment_config.ds_config)[0] ds_engine.module.eval() # inference inference_pipeline.model = 
ds_engine.module diff --git a/tests/conftest.py b/tests/conftest.py index 4cfa7dd9..9dac6e69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from types import SimpleNamespace -@pytest.fixture(scope="function", params=['fp16']) +@pytest.fixture(scope="function", params=["fp16"]) def dtype(request): return request.param @@ -85,27 +85,31 @@ def ds_config(request): @pytest.fixture(scope="function") -def deployment_config(task_name: str, - model_name: str, - dtype: str, - tensor_parallel: int, - meta_tensor: bool, - load_with_sys_mem: bool, - replica_num: int, - enable_deepspeed: bool, - enable_zero: bool, - ds_config: dict): - config = SimpleNamespace(task=task_name, - model=model_name, - dtype=dtype, - tensor_parallel=tensor_parallel, - model_path=os.getenv("TRANSFORMERS_CACHE", - ""), - meta_tensor=meta_tensor, - replica_num=replica_num, - enable_deepspeed=enable_deepspeed, - enable_zero=enable_zero, - ds_config=ds_config) +def deployment_config( + task_name: str, + model_name: str, + dtype: str, + tensor_parallel: int, + meta_tensor: bool, + load_with_sys_mem: bool, + replica_num: int, + enable_deepspeed: bool, + enable_zero: bool, + ds_config: dict, +): + config = SimpleNamespace( + task=task_name, + model=model_name, + dtype=dtype, + tensor_parallel=tensor_parallel, + model_path=os.getenv("TRANSFORMERS_CACHE", + ""), + meta_tensor=meta_tensor, + replica_num=replica_num, + enable_deepspeed=enable_deepspeed, + enable_zero=enable_zero, + ds_config=ds_config, + ) return config.__dict__ @@ -134,14 +138,18 @@ def expected_failure(request): def deployment(deployment_name, mii_config, deployment_config, expected_failure): if expected_failure is not None: with pytest.raises(expected_failure) as excinfo: - mii.deploy(deployment_name=deployment_name, - mii_config=mii_config, - deployment_config=deployment_config) + mii.deploy( + deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config, + ) yield excinfo else: - mii.deploy(deployment_name=deployment_name, - mii_config=mii_config, - deployment_config=deployment_config) + mii.deploy( + deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config, + ) yield deployment_name mii.terminate(deployment_name) diff --git a/tests/test_deployment_options.py b/tests/test_deployment_options.py index 9f204eaa..5b1dcf22 100644 --- a/tests/test_deployment_options.py +++ b/tests/test_deployment_options.py @@ -6,6 +6,7 @@ import pytest import json import requests +import pydantic import mii @@ -24,7 +25,7 @@ def test_restful_api(deployment, query, restful_api_port): for _ in range(2): result = generator.query(query) - url = f'http://localhost:{restful_api_port}/mii/{deployment}' + url = f"http://localhost:{restful_api_port}/mii/{deployment}" params = {"request": query} json_params = json.dumps(params) result = requests.post(url, @@ -53,24 +54,26 @@ def test_replicas(deployment, query, replica_num): @pytest.mark.deepspeed @pytest.mark.parametrize("enable_deepspeed", [False]) @pytest.mark.parametrize("enable_zero", [True]) -@pytest.mark.parametrize("ds_config", - [ - { - "fp16": { - "enabled": True - }, - "bf16": { - "enabled": False - }, - "zero_optimization": { - "stage": 3, - "offload_param": { - "device": "cpu", - }, - }, - "train_micro_batch_size_per_gpu": 1, - }, - ]) +@pytest.mark.parametrize( + "ds_config", + [ + { + "fp16": { + "enabled": True + }, + "bf16": { + "enabled": False + }, + "zero_optimization": { + "stage": 3, + "offload_param": { + "device": "cpu", + }, 
+ }, + "train_micro_batch_size_per_gpu": 1, + }, + ], +) def test_zero_config(deployment, query): generator = mii.mii_query_handle(deployment) result = generator.query(query) @@ -78,44 +81,35 @@ def test_zero_config(deployment, query): @pytest.mark.deepspeed -@pytest.mark.parametrize("expected_failure", [AssertionError]) -@pytest.mark.parametrize("enable_deepspeed, enable_zero, dtype", - [(True, - True, - 'fp32'), - (False, - True, - 'fp16')]) -@pytest.mark.parametrize("ds_config", - [ - { - "fp16": { - "enabled": False - }, - "bf16": { - "enabled": False - }, - "zero_optimization": { - "stage": 3, - "offload_param": { - "device": "cpu", - }, - }, - "train_micro_batch_size_per_gpu": 1, - }, - ]) +@pytest.mark.parametrize("expected_failure", [pydantic.ValidationError]) +@pytest.mark.parametrize( + "enable_deepspeed, enable_zero, dtype", + [(True, + True, + "fp32"), + (False, + True, + "fp16")], +) @pytest.mark.parametrize( - "task_name, model_name, query", + "ds_config", [ - ( - "text-generation", - "distilgpt2", - { - "query": "DeepSpeed is the greatest" + { + "fp16": { + "enabled": False + }, + "bf16": { + "enabled": False + }, + "zero_optimization": { + "stage": 3, + "offload_param": { + "device": "cpu", + }, }, - ), + "train_micro_batch_size_per_gpu": 1, + }, ], ) def test_zero_config_fail(deployment, query): - print(deployment) - assert "MII Config Error" in str(deployment.value) + assert "assertion_error" in str(deployment.value) diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py index 73cadc50..ec77613d 100644 --- a/tests/test_local_deployment.py +++ b/tests/test_local_deployment.py @@ -46,14 +46,16 @@ "bigscience/bloom-560m", { "query": ["DeepSpeed is the greatest", - 'Seattle is'] + "Seattle is"] + }, + ), + ( + "token-classification", + "Jean-Baptiste/roberta-large-ner-english", + { + "query": "My name is jean-baptiste and I live in montreal." }, ), - ("token-classification", - "Jean-Baptiste/roberta-large-ner-english", - { - "query": "My name is jean-baptiste and I live in montreal." - }), ( "text-classification", "roberta-large-mnli", @@ -77,7 +79,7 @@ def test_single_GPU(deployment, query): "bigscience/bloom-560m", { "query": ["DeepSpeed is the greatest", - 'Seattle is'] + "Seattle is"] }, ), ], diff --git a/tests/test_non_persistent_deployment.py b/tests/test_non_persistent_deployment.py index 92cb454a..71234201 100644 --- a/tests/test_non_persistent_deployment.py +++ b/tests/test_non_persistent_deployment.py @@ -48,14 +48,16 @@ "bigscience/bloom-560m", { "query": ["DeepSpeed is the greatest", - 'Seattle is'] + "Seattle is"] + }, + ), + ( + "token-classification", + "Jean-Baptiste/roberta-large-ner-english", + { + "query": "My name is jean-baptiste and I live in montreal." }, ), - ("token-classification", - "Jean-Baptiste/roberta-large-ner-english", - { - "query": "My name is jean-baptiste and I live in montreal." 
- }), ( "text-classification", "roberta-large-mnli", From 3f50926efb4313bceb3d95be93599d9fb2150178 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 8 Aug 2023 18:45:35 -0700 Subject: [PATCH 10/15] attempt to fix hf_auth error --- mii/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mii/config.py b/mii/config.py index dbf0ac35..f6de7016 100644 --- a/mii/config.py +++ b/mii/config.py @@ -27,7 +27,7 @@ class DeploymentConfig(DeepSpeedConfigModel): deployment_name: str load_with_sys_mem: bool = False meta_tensor: bool = False - hf_auth_token: str = "" + hf_auth_token: Optional[str] = None deploy_rank: Optional[List[int]] = None torch_dist_port: int = 29500 replica_num: int = 1 @@ -178,7 +178,7 @@ def deepspeed_or_zero(cls, values): class MIIConfig(DeepSpeedConfigModel): deployment_config: DeploymentConfig deployment_type: DeploymentType = DeploymentType.LOCAL - hf_auth_token: str = "" + hf_auth_token: Optional[str] = None port_number: int = 50050 enable_restful_api: bool = False restful_api_port: int = 51080 From 4f3b63e6d463f68e1b2a53e14249def4865083d4 Mon Sep 17 00:00:00 2001 From: Tosin Segun Date: Wed, 9 Aug 2023 18:11:10 +0000 Subject: [PATCH 11/15] Allocating multiple processes --- mii/config.py | 90 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/mii/config.py b/mii/config.py index f6de7016..898421f9 100644 --- a/mii/config.py +++ b/mii/config.py @@ -43,6 +43,7 @@ class DeploymentConfig(DeepSpeedConfigModel): model_path: str = "" checkpoint_dict: Optional[Dict[str, Any]] = None max_tokens: int = 1024 + GPU_index_map: dict = None # Performance configs enable_deepspeed: bool = True @@ -174,9 +175,16 @@ def deepspeed_or_zero(cls, values): ), "DeepSpeed and ZeRO cannot both be enabled, select only one" return values + @root_validator + def index_map_valid(cls, values): + if values.get("GPU_index_map"): + for host in gpu_index_map: + assert host in resource_pool, f"Host: {host} was not found" + assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" + return values class MIIConfig(DeepSpeedConfigModel): - deployment_config: DeploymentConfig + deployment_config: List[DeploymentConfig] deployment_type: DeploymentType = DeploymentType.LOCAL hf_auth_token: Optional[str] = None port_number: int = 50050 @@ -184,14 +192,15 @@ class MIIConfig(DeepSpeedConfigModel): restful_api_port: int = 51080 hostfile: str = DLTS_HOSTFILE version: int = 1 - + port_map: dict = {} @root_validator(skip_on_failure=True) def propagate_hf_auth(cls, values): # This validator is for when we support multiple models in a deployment hf_auth_token = values.get("hf_auth_token") - deployment_config = values.get("deployment_config") - if not deployment_config.hf_auth_token: - deployment_config.hf_auth_token = hf_auth_token + deployment_config_list = values.get("deployment_config") + for deployment_config in deployment_config_list: + if not deployment_config.hf_auth_token: + deployment_config.hf_auth_token = hf_auth_token return values @root_validator(skip_on_failure=True) @@ -206,44 +215,61 @@ def AML_name_valid(cls, values): @root_validator(skip_on_failure=True) def generate_replica_configs(cls, values): - replica_configs = values.get("deployment_config").replica_configs - replica_num = values.get("deployment_config").replica_num - if replica_configs: - assert len(replica_configs) == replica_num - return values - + port_map = 
values.get("port_map") hostfile = values.get("hostfile") port_number = values.get("port_number") - torch_dist_port = values.get("deployment_config").torch_dist_port - tensor_parallel = values.get("deployment_config").tensor_parallel - replica_num = values.get("deployment_config").replica_num - replica_pool = _allocate_processes(hostfile, tensor_parallel, replica_num) - replica_configs = [] - for i, (hostname, gpu_indices) in enumerate(replica_pool): - # Reserver port for a LB proxy when replication is enabled - port_offset = 1 - base_port = port_number + i * tensor_parallel + port_offset - tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel)) - replica_torch_dist_port = torch_dist_port + (100 * i) - replica_configs.append( - ReplicaConfig( - hostname=hostname, - tensor_parallel_ports=tensor_parallel_ports, - torch_dist_port=replica_torch_dist_port, - gpu_indices=gpu_indices, - )) - - values.get("deployment_config").replica_configs = replica_configs + port_offset = 1 + for deployment_config in values.get("deployment_config"): + + replica_configs = deployment_config.replica_configs + replica_num = deployment_config.replica_num + if replica_configs: + assert len(replica_configs) == replica_num + return values + + torch_dist_port = deployment_config.torch_dist_port + tensor_parallel = deployment_config.tensor_parallel + replica_num = deployment_config.replica_num + GPU_index_map = deployment_config.GPU_index_map + replica_pool, GPU_index_map = _allocate_processes(hostfile, tensor_parallel, replica_num, GPU_index_map) + deployment_config.GPU_index_map = GPU_index_map + replica_configs = [] + for i, (hostname, gpu_indices) in enumerate(replica_pool): + # Reserver port for a LB proxy when replication is enabled + if hostname not in port_map: + port_map[hostname] = set() + base_port = port_number + i * tensor_parallel + port_offset + if base_port in port_map[hostname]: + base_port = max(port_map[hostname]) + 1 + tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel)) + for i in range(base_port, base_port + tensor_parallel): + port_map[hostname].add(i) + replica_torch_dist_port = torch_dist_port + (100 * i) + replica_configs.append( + ReplicaConfig( + hostname=hostname, + tensor_parallel_ports=tensor_parallel_ports, + torch_dist_port=replica_torch_dist_port, + gpu_indices=gpu_indices, + )) + + deployment_config.replica_configs = replica_configs return values -def _allocate_processes(hostfile_path, tensor_parallel, replica_num): +def _allocate_processes(hostfile_path, tensor_parallel, replica_num, GPU_index_map=None): resource_pool = fetch_hostfile(hostfile_path) assert ( resource_pool is not None and len(resource_pool) > 0 ), f"No hosts found in {hostfile_path}" replica_pool = [] + + if gpu_index_map is not None: + for host in gpu_index_map: + replica_pool.append((host, gpu_index_map[host])) + return replica_pool + allocated_num = 0 for host, slots in resource_pool.items(): available_on_host = slots From 9487437a50c8846576334b5e76cb782af69ac516 Mon Sep 17 00:00:00 2001 From: Tosin Segun Date: Wed, 9 Aug 2023 18:45:13 +0000 Subject: [PATCH 12/15] Support for multiple deployments in deploy() API --- mii/config.py | 9 +++++---- mii/deployment.py | 36 ++++++++++++++++++++++-------------- mii/models/score/generate.py | 2 +- mii/server.py | 11 ++++++++--- mii/utils.py | 4 ++-- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/mii/config.py b/mii/config.py index 898421f9..73d64a6f 100644 --- a/mii/config.py +++ b/mii/config.py @@ -184,8 +184,9 @@ def 
index_map_valid(cls, values): return values class MIIConfig(DeepSpeedConfigModel): - deployment_config: List[DeploymentConfig] + deployment_configs: List[DeploymentConfig] deployment_type: DeploymentType = DeploymentType.LOCAL + deployment_tag: str = None hf_auth_token: Optional[str] = None port_number: int = 50050 enable_restful_api: bool = False @@ -197,7 +198,7 @@ class MIIConfig(DeepSpeedConfigModel): def propagate_hf_auth(cls, values): # This validator is for when we support multiple models in a deployment hf_auth_token = values.get("hf_auth_token") - deployment_config_list = values.get("deployment_config") + deployment_config_list = values.get("deployment_configs") for deployment_config in deployment_config_list: if not deployment_config.hf_auth_token: deployment_config.hf_auth_token = hf_auth_token @@ -209,7 +210,7 @@ def AML_name_valid(cls, values): allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + string.digits + "-") assert ( - set(values.get("deployment_config").deployment_name) <= allowed_chars + set(values.get("deployment_configs").deployment_name) <= allowed_chars ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." return values @@ -219,7 +220,7 @@ def generate_replica_configs(cls, values): hostfile = values.get("hostfile") port_number = values.get("port_number") port_offset = 1 - for deployment_config in values.get("deployment_config"): + for deployment_config in values.get("deployment_configs"): replica_configs = deployment_config.replica_configs replica_num = deployment_config.replica_num diff --git a/mii/deployment.py b/mii/deployment.py index 4caea601..2ca4fad6 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -4,11 +4,11 @@ # DeepSpeed Team import os import mii - +from typing import List from .utils import logger from .models.score import create_score_file from .models import load_models -from .config import MIIConfig, DeploymentType +from .config import MIIConfig, DeploymentType, DeploymentConfig def support_legacy_api( @@ -48,7 +48,7 @@ def support_legacy_api( def deploy( deployment_name: str, - deployment_config: dict, + deployment_configs: List[dict], mii_config: dict = None, *args, **kwargs, @@ -63,18 +63,26 @@ def deploy( kwargs["mii_config"] = mii_config deployment_config, mii_config = support_legacy_api(*args, **kwargs) - deployment_config["deployment_name"] = deployment_name - mii_config["deployment_config"] = deployment_config + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_tag"] = deployment_name + mii_config["deployment_configs"] = [DeploymentConfig(**deployment_config)] + else: + for deployment_config in deployment_configs: + deployment_config = DeploymentConfig(**deployment_config) + mii_config["deployment_configs"] = deployment_configs + mii_config["deployment_tag"] = kwargs["deployment_tag"] + mii_config = mii.config.MIIConfig(**mii_config) - if mii_config.deployment_config.enable_deepspeed: - logger.info( - f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" - ) - else: - logger.info( - f"************* DeepSpeed Optimizations not enabled. Please use enable_deepspeed to get better performance *************" - ) + for deployment_config in mii_config.deployment_configs: + if deployment_config.enable_deepspeed: + logger.info( + f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" + ) + else: + logger.info( + f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" + ) if mii_config.deployment_type != DeploymentType.NON_PERSISTENT: create_score_file(mii_config) @@ -88,7 +96,7 @@ def deploy( def _deploy_local(mii_config): - mii.utils.import_score_file(mii_config.deployment_config.deployment_name).init() + mii.utils.import_score_file(mii_config.deployment_tag).init() def _deploy_aml(mii_config): diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index b351fab6..0839a793 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -32,7 +32,7 @@ def create_score_file(mii_config): def generated_score_path(mii_config): deployment_type = mii_config.deployment_type - deployment_name = mii_config.deployment_config.deployment_name + deployment_name = mii_config.deployment_tag if deployment_type == DeploymentType.LOCAL: score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) elif deployment_type == DeploymentType.AML: diff --git a/mii/server.py b/mii/server.py index 04002e2e..3dc9e028 100644 --- a/mii/server.py +++ b/mii/server.py @@ -26,9 +26,14 @@ class MIIServer: """Initialize the model, setup the server for the model under model_path""" def __init__(self, mii_config): - self.task = mii_config.deployment_config.task - self.num_gpus = get_num_gpus(mii_config) - assert self.num_gpus > 0, "GPU count must be greater than 0" + #self.task = mii_config.deployment_config.task + + + #self.num_gpus = get_num_gpus(mii_config) + #assert self.num_gpus > 0, "GPU count must be greater than 0" + + for deployment_config in mii_config.deployment_configs: + assert get_num_gpus(deployment_config) > 0, f"GPU count for {deployment.deployment_name} must be greater than 0" """ if mii_configs.hostfile is None: hostfile = tempfile.NamedTemporaryFile(delete=False) diff --git a/mii/utils.py b/mii/utils.py index 72e22ea2..9a93b3c1 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -123,8 +123,8 @@ def extract_query_dict(task, request_dict): return query_dict -def get_num_gpus(mii_config): - num_gpus = mii_config.deployment_config.tensor_parallel +def get_num_gpus(deployment_config): + num_gpus = deployment_config.tensor_parallel assert ( torch.cuda.device_count() >= num_gpus From 511fd16f7fcbe33cda2dcaf12dd77235028cccd8 Mon Sep 17 00:00:00 2001 From: Tosin Segun Date: Fri, 11 Aug 2023 17:29:53 +0000 Subject: [PATCH 13/15] Load balancer support for multiple models --- mii/client.py | 38 +- mii/deployment.py | 4 +- mii/grpc_related/modelresponse_server.py | 69 ++- mii/grpc_related/proto/modelresponse.proto | 6 +- mii/grpc_related/proto/modelresponse_pb2.py | 97 ++-- .../proto/modelresponse_pb2_grpc.py | 520 +++++++----------- mii/launch/multi_gpu_server.py | 18 +- mii/server.py | 36 +- 8 files changed, 367 insertions(+), 421 deletions(-) diff --git a/mii/client.py b/mii/client.py index 1ffb656b..b262cece 100644 --- a/mii/client.py +++ b/mii/client.py @@ -56,24 +56,44 @@ class MIIClient: """ Client to send queries to a single endpoint. 
""" - def __init__(self, task, host, port): + def __init__(self, mii_config, host, port): self.asyncio_loop = asyncio.get_event_loop() channel = create_channel(host, port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - self.task = task - - async def _request_async_response(self, request_dict, **query_kwargs): - if self.task not in GRPC_METHOD_TABLE: - raise ValueError(f"unknown task: {self.task}") - - task_methods = GRPC_METHOD_TABLE[self.task] + self.mii_config = mii_config + + def _get_deployment_task(self, deployment_name=None): + task = None + if deployment_name is None: #mii.terminate() or single model + if deployment_name is None: + assert len(self.deployments) == 1, "Must pass deployment_name to query when using multiple deployments" + deployment = self.mii_config.deployment_configs[0] + deployment_name = getattr(deployment, deployment_name) + task = getattr(deployment, task) + else: + if deployment_name in self.deployments: + deployment = self.mii_config.deployment_configs[deployment_name] + task = getattr(deployment, task) + else: + assert False, f"{deployment_name} not found in list of deployments" + return deployment_name, task + + async def _request_async_response(self, request_dict, task, **query_kwargs): + if task not in GRPC_METHOD_TABLE: + raise ValueError(f"unknown task: {task}") + + task_methods = GRPC_METHOD_TABLE[task] proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) - proto_response = await getattr(self.stub, task_methods.method)(proto_request) + proto_response = await getattr(self.mr_stub, task_methods.method)(proto_request) return task_methods.unpack_response_from_proto(proto_response) def query(self, request_dict, **query_kwargs): + deployment_name = request_dict.get(mii.constants.DEPLOYMENT_NAME_KEY) + deployment_name, task = self._get_deployment_task(deployment_name) + request_dict['deployment_name'] = deployment_name return self.asyncio_loop.run_until_complete( self._request_async_response(request_dict, + task, **query_kwargs)) async def terminate_async(self): diff --git a/mii/deployment.py b/mii/deployment.py index 2ca4fad6..d76d1f35 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -65,11 +65,11 @@ def deploy( deployment_config["deployment_name"] = deployment_name mii_config["deployment_tag"] = deployment_name - mii_config["deployment_configs"] = [DeploymentConfig(**deployment_config)] + mii_config["deployment_configs"] = {deployment_name: DeploymentConfig(**deployment_config)} else: for deployment_config in deployment_configs: deployment_config = DeploymentConfig(**deployment_config) - mii_config["deployment_configs"] = deployment_configs + mii_config["deployment_configs"] = {dep.deployment_name: dep for dep in deployment_configs} mii_config["deployment_tag"] = kwargs["deployment_tag"] mii_config = mii.config.MIIConfig(**mii_config) diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index adec8e03..a7a812c8 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -171,19 +171,29 @@ def invoke(self, method_name, proto_request): self.asyncio_loop).result() + + class LoadBalancingInterceptor(grpc.ServerInterceptor): - def __init__(self, deployment_config): + def __init__(self, mii_config): super().__init__() self.asyncio_loop = asyncio.get_event_loop() - self.stubs = [ - ParallelStubInvoker(replica.hostname, - replica.tensor_parallel_ports) - for replica in deployment_config.replica_configs - ] - self.counter = 
AtomicCounter() - self.task = deployment_config.task - self.replica_sessions = {} + self.stubs = {} + self.counter = {} + self.replica_configs = replica_configs + self.tasks = {} + for deployment in mii_config.deployment_configs: + self.stubs[deployment.deployment_name] = [] + self.counter[deployment.deployment_name] = AtomicCounter() + self.tasks[deployment.deployment_name] = repl.task + + for deployment in mii_config.deployment_configs: + deployment_name = deployment.deployment_name + for repl in deployment.replica_configs: + self.stubs[deployment_name].append( + ParallelStubInvoker(repl.hostname, + repl.tensor_parallel_ports, + self.asyncio_loop)) # Start the asyncio loop in a separate thread def run_asyncio_loop(loop): @@ -201,46 +211,51 @@ def intercept_service(self, continuation, handler_call_details): def invoke_intercept_method(request_proto, context): method_name = _get_grpc_method_name(handler_call_details.method) - if method_name == TERMINATE_METHOD: - for stub in self.stubs: - stub.invoke(TERMINATE_METHOD, - google_dot_protobuf_dot_empty__pb2.Empty()) + for deployment_name in self.stubs: + for stub in self.stubs[deployment_name]: + stub.invoke(TERMINATE_METHOD, + google_dot_protobuf_dot_empty__pb2.Empty()) self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop) return next_handler.unary_unary(request_proto, context) - call_count = self.counter.get_and_increment() - replica_index = call_count % len(self.stubs) - if method_name == CREATE_SESSION_METHOD: if request_proto.session_id in self.sessions: raise ValueError( f"session {request_proto.session_id} already exists") self.replica_sessions[request_proto.session_id] = replica_index - self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + CREATE_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() if method_name == DESTROY_SESSION_METHOD: replica_index = self.replica_sessions.pop(request_proto.session_id) - self.stubs[replica_index].invoke(DESTROY_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + DESTROY_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() - kwargs = unpack_proto_query_kwargs(request_proto.query_kwargs) - if "session_id" in kwargs: - session_id = kwargs["session_id"] + if "session_id" in request_proto.query_kwargs: + session_id = request_proto.query_kwargs["session_id"] if session_id not in self.replica_sessions: raise ValueError(f"session not found") replica_index = self.replica_sessions[session_id] - ret = self.stubs[replica_index].invoke(method_name, request_proto) + deployment_name = getattr(request_proto, 'deployment_name') + assert deployment_name in self.stubs, f"Deployment: {deployment_name} not found" + call_count = self.counter[deployment_name].get_and_increment() + replica_index = call_count % len(self.stubs[deployment_name]) + + ret = self.stubs[deployment_name][replica_index].invoke( + method_name, + request_proto) return ret return grpc.unary_unary_rpc_method_handler( invoke_intercept_method, request_deserializer=next_handler.request_deserializer, - response_serializer=next_handler.response_serializer, - ) - + response_serializer=next_handler.response_serializer) def _do_serve(service_impl, port, interceptors=[]): stop_event = service_impl.get_stop_event() @@ -267,8 +282,8 @@ def serve_inference(inference_pipeline, port): _do_serve(ModelResponse(inference_pipeline), port) -def serve_load_balancing(deployment_config, lb_port): - 
_do_serve(ServiceBase(), lb_port, [LoadBalancingInterceptor(deployment_config)]) +def serve_load_balancing(mii_config, lb_port): + _do_serve(ServiceBase(), lb_port, [LoadBalancingInterceptor(mii_config)]) if __name__ == "__main__": diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index a0698899..f3e5ff29 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -52,18 +52,20 @@ message SessionID { message SingleStringRequest { string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message MultiStringRequest { repeated string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message SingleStringReply { string response = 1; float time_taken = 2; float model_time_taken = 3; -} +} message MultiStringReply { repeated string response = 1; @@ -75,6 +77,7 @@ message QARequest { string question = 1; string context = 2; map query_kwargs = 3; + optional string deployment_name = 4; } message ConversationRequest { @@ -83,6 +86,7 @@ message ConversationRequest { repeated string past_user_inputs = 3; repeated string generated_responses = 4; map query_kwargs = 5; + optional string deployment_name = 6; } message ConversationReply { diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 76b1f994..c9f12538 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,66 +1,63 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() + from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 
\x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' -) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xed\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 
\x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xeb\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xd3\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x01\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_idB\x12\n\x10_deployment_name\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - 
_SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _VALUE._serialized_start = 67 - _VALUE._serialized_end = 162 - _SESSIONID._serialized_start = 164 - _SESSIONID._serialized_end = 195 - _SINGLESTRINGREQUEST._serialized_start = 198 - _SINGLESTRINGREQUEST._serialized_end = 385 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _MULTISTRINGREQUEST._serialized_start = 388 - _MULTISTRINGREQUEST._serialized_end = 573 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _SINGLESTRINGREPLY._serialized_start = 575 - _SINGLESTRINGREPLY._serialized_end = 658 - _MULTISTRINGREPLY._serialized_start = 660 - _MULTISTRINGREPLY._serialized_end = 742 - _QAREQUEST._serialized_start = 745 - _QAREQUEST._serialized_end = 930 - _QAREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _QAREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREQUEST._serialized_start = 933 - _CONVERSATIONREQUEST._serialized_end = 1222 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREPLY._serialized_start = 1225 - _CONVERSATIONREPLY._serialized_end = 1370 - _IMAGEREPLY._serialized_start = 1372 - _IMAGEREPLY._serialized_end = 1497 - _MODELRESPONSE._serialized_start = 1500 - _MODELRESPONSE._serialized_end = 2352 + DESCRIPTOR._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _QAREQUEST_QUERYKWARGSENTRY._options = None + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _globals['_VALUE']._serialized_start=67 + _globals['_VALUE']._serialized_end=162 + _globals['_SESSIONID']._serialized_start=164 + _globals['_SESSIONID']._serialized_end=195 + _globals['_SINGLESTRINGREQUEST']._serialized_start=198 + _globals['_SINGLESTRINGREQUEST']._serialized_end=435 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=343 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=415 + _globals['_MULTISTRINGREQUEST']._serialized_start=438 + _globals['_MULTISTRINGREQUEST']._serialized_end=673 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=343 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=415 + _globals['_SINGLESTRINGREPLY']._serialized_start=675 + _globals['_SINGLESTRINGREPLY']._serialized_end=758 + _globals['_MULTISTRINGREPLY']._serialized_start=760 + _globals['_MULTISTRINGREPLY']._serialized_end=842 + _globals['_QAREQUEST']._serialized_start=845 + _globals['_QAREQUEST']._serialized_end=1080 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start=343 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end=415 + 
_globals['_CONVERSATIONREQUEST']._serialized_start=1083 + _globals['_CONVERSATIONREQUEST']._serialized_end=1422 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start=343 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end=415 + _globals['_CONVERSATIONREPLY']._serialized_start=1425 + _globals['_CONVERSATIONREPLY']._serialized_end=1570 + _globals['_IMAGEREPLY']._serialized_start=1572 + _globals['_IMAGEREPLY']._serialized_end=1697 + _globals['_MODELRESPONSE']._serialized_start=1700 + _globals['_MODELRESPONSE']._serialized_end=2552 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 95cfa825..683e4962 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -1,8 +1,3 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -13,6 +8,7 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): """Constructor. @@ -20,60 +16,60 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Terminate = channel.unary_unary( - '/modelresponse.ModelResponse/Terminate', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/Terminate', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/CreateSession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/DestroySession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.GeneratorReply = channel.unary_unary( - '/modelresponse.ModelResponse/GeneratorReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.MultiStringReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.MultiStringReply.FromString, + ) self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/ClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + 
response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/QuestionAndAnswerReply', + request_serializer=modelresponse__pb2.QARequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/FillMaskReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/TokenClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) + '/modelresponse.ModelResponse/ConversationalReply', + request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ConversationReply.FromString, + ) self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, - ) + '/modelresponse.ModelResponse/Txt2ImgReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ImageReply.FromString, + ) class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" + def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -137,334 +133,232 @@ def Txt2ImgReply(self, request, context): def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - 'Terminate': - grpc.unary_unary_rpc_method_handler( - servicer.Terminate, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'CreateSession': - grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'DestroySession': - grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. 
- SerializeToString, - ), - 'GeneratorReply': - grpc.unary_unary_rpc_method_handler( - servicer.GeneratorReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, - ), - 'ClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': - grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': - grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': - grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': - grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, - ), + 'Terminate': grpc.unary_unary_rpc_method_handler( + servicer.Terminate, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'CreateSession': grpc.unary_unary_rpc_method_handler( + servicer.CreateSession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'DestroySession': grpc.unary_unary_rpc_method_handler( + servicer.DestroySession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'GeneratorReply': grpc.unary_unary_rpc_method_handler( + servicer.GeneratorReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, + ), + 'ClassificationReply': grpc.unary_unary_rpc_method_handler( + servicer.ClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'QuestionAndAnswerReply': grpc.unary_unary_rpc_method_handler( + servicer.QuestionAndAnswerReply, + request_deserializer=modelresponse__pb2.QARequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'FillMaskReply': grpc.unary_unary_rpc_method_handler( + servicer.FillMaskReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'TokenClassificationReply': grpc.unary_unary_rpc_method_handler( + servicer.TokenClassificationReply, + 
request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'ConversationalReply': grpc.unary_unary_rpc_method_handler( + servicer.ConversationalReply, + request_deserializer=modelresponse__pb2.ConversationRequest.FromString, + response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, + ), + 'Txt2ImgReply': grpc.unary_unary_rpc_method_handler( + servicer.Txt2ImgReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', - rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler, )) + generic_handler = grpc.method_handlers_generic_handler( + 'modelresponse.ModelResponse', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class ModelResponse(object): """Missing associated documentation comment in .proto file.""" + @staticmethod def Terminate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/Terminate', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Terminate', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def CreateSession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/CreateSession', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/CreateSession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def DestroySession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/DestroySession', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, 
'/modelresponse.ModelResponse/DestroySession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def GeneratorReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/GeneratorReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.MultiStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/ClassificationReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def QuestionAndAnswerReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/QuestionAndAnswerReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/QuestionAndAnswerReply', modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def FillMaskReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/FillMaskReply', + options=(), + channel_credentials=None, + 
call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/FillMaskReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def TokenClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/TokenClassificationReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/TokenClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ConversationalReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/ConversationalReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ConversationalReply', modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Txt2ImgReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/Txt2ImgReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Txt2ImgReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 4831593a..195094fe 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ 
-7,7 +7,7 @@ import base64 import json -from mii.config import DeploymentConfig +from mii.config import DeploymentConfig, MIIConfig from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing from mii.grpc_related.restful_gateway import RestfulGatewayThread @@ -23,6 +23,15 @@ def b64_encoded_config(config_str): # return mii.DeploymentConfig object return DeploymentConfig(**config_dict) +def b64_encoded_config(config_str): + # str -> bytes + b64_bytes = config_str.encode() + # decode b64 bytes -> json bytes + config_bytes = base64.urlsafe_b64decode(b64_bytes) + # convert json bytes -> str -> dict + config_dict = json.loads(config_bytes.decode()) + # return mii.MIIConfig object + return MIIConfig(**config_dict) def main(): parser = argparse.ArgumentParser() @@ -31,6 +40,11 @@ def main(): type=b64_encoded_config, help="base64 encoded deployment config", ) + parser.add_argument( + "--mii-config", + type=b64_encoded_config, + help="base64 encoded mii config", + ) parser.add_argument( "--server-port", type=int, @@ -77,7 +91,7 @@ def main(): elif args.load_balancer: assert args.load_balancer_port, "--load-balancer-port must be provided." print(f"Starting load balancer on port: {args.load_balancer_port}") - serve_load_balancing(args.deployment_config, args.load_balancer_port) + serve_load_balancing(args.mii_config, args.load_balancer_port) else: assert args.server_port, "--server-port must be provided." diff --git a/mii/server.py b/mii/server.py index 3dc9e028..24e0b59c 100644 --- a/mii/server.py +++ b/mii/server.py @@ -125,31 +125,33 @@ def _initialize_service(self, mii_config): ] host_gpus = defaultdict(list) - for repl_config in mii_config.deployment_config.replica_configs: - host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) + for deployment in mii_config.deployment_configs: + for repl_config in deployment.replica_configs: + host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) # Start replica instances - for repl_config in mii_config.deployment_config.replica_configs: - hostfile = tempfile.NamedTemporaryFile(delete=False) - hostfile.write( - f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n" - .encode()) - ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name) - processes.append( - self._launch_server_process( - mii_config.deployment_config, - "MII server", - ds_launch_str=ds_launch_str, - server_args=server_args + - [f"--server-port {repl_config.tensor_parallel_ports[0]}"], - )) + for deployment_config in mii_config.deployment_configs: + for repl_config in mii_config.deployment_config.replica_configs: + hostfile = tempfile.NamedTemporaryFile(delete=False) + hostfile.write( + f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n" + .encode()) + ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name) + processes.append( + self._launch_server_process( + deployment_config, + "MII server", + ds_launch_str=ds_launch_str, + server_args=server_args + + [f"--server-port {repl_config.tensor_parallel_ports[0]}"], + )) # start load balancer here. # we don't use deepspeed launcher for the load balancer because it does not need a GPU. # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, # and it is expected to assign one GPU to one process. 
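A minimal, self-contained sketch of the base64 config round-trip that the multi_gpu_server.py hunk above relies on; the helper names here are illustrative, not MII's, and only the stdlib calls visible in the hunk are assumed. Note also that the hunk defines b64_encoded_config twice, so the second definition shadows the first; a single decoder parameterized by the target config class, as below, avoids that.

import base64
import json

def encode_config(config_dict):
    # dict -> JSON str -> bytes -> urlsafe base64 str (what the CLI flag carries)
    return base64.urlsafe_b64encode(json.dumps(config_dict).encode()).decode()

def decode_config(config_str, config_cls=dict):
    # urlsafe base64 str -> JSON bytes -> dict -> config object of config_cls
    config_dict = json.loads(base64.urlsafe_b64decode(config_str.encode()).decode())
    return config_dict if config_cls is dict else config_cls(**config_dict)

encoded = encode_config({"deployment_name": "example", "tensor_parallel": 1})
assert decode_config(encoded)["deployment_name"] == "example"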
processes.append( self._launch_server_process( - mii_config.deployment_config, + mii_config, "load balancer", server_args=server_args + ["--load-balancer"], )) From 04f7989f051ae3c1e38998db6599d7a7027f80ff Mon Sep 17 00:00:00 2001 From: Tosin Segun Date: Fri, 11 Aug 2023 23:46:29 +0000 Subject: [PATCH 14/15] Refacoring and some formatting --- examples/multi_model/deploy.py | 51 ++ examples/multi_model/query.py | 50 ++ examples/multi_model/shutdown.py | 7 + mii/client.py | 10 +- mii/config.py | 38 +- mii/deployment.py | 26 +- mii/grpc_related/modelresponse_server.py | 7 +- mii/grpc_related/proto/modelresponse.proto | 2 +- mii/grpc_related/proto/modelresponse_pb2.py | 85 ++- .../proto/modelresponse_pb2_grpc.py | 515 +++++++++++------- mii/launch/multi_gpu_server.py | 2 + mii/models/score/score_template.py | 1 + mii/server.py | 15 +- 13 files changed, 520 insertions(+), 289 deletions(-) create mode 100644 examples/multi_model/deploy.py create mode 100644 examples/multi_model/query.py create mode 100644 examples/multi_model/shutdown.py diff --git a/examples/multi_model/deploy.py b/examples/multi_model/deploy.py new file mode 100644 index 00000000..1e8b7aed --- /dev/null +++ b/examples/multi_model/deploy.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +gpu_index_map1 = {'master': [0]} +gpu_index_map2 = {'master': [1]} +gpu_index_map3 = {'master': [0, 1]} + +deployments = [] + +mii_configs1 = {"tensor_parallel": 2, "dtype": "fp16"} +mii_configs2 = {"tensor_parallel": 1} + +name = "bigscience/bloom-560m" +deployments.append({ + 'task': 'text-generation', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map3, + 'tensor_parallel': 2, + 'dtype': "fp16" +}) + +# gpt2 +name = "microsoft/DialogRPT-human-vs-rand" +deployments.append({ + 'task': 'text-classification', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map2 +}) + +name = "microsoft/DialoGPT-large" +deployments.append({ + 'task': 'conversational', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map1, +}) + +name = "deepset/roberta-large-squad2" +deployments.append({ + 'task': "question-answering", + 'model': name, + 'deployment_name': name + "-qa-deployment", + 'GPU_index_map': gpu_index_map2 +}) + +mii.deploy(deployment_tag="multi_models", deployment_configs=deployments[:2]) diff --git a/examples/multi_model/query.py b/examples/multi_model/query.py new file mode 100644 index 00000000..f506830f --- /dev/null +++ b/examples/multi_model/query.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import mii + +results = [] +generator = mii.mii_query_handle("multi_models") +result = generator.query( + { + "query": ["DeepSpeed is", + "Seattle is"], + "deployment_name": "bigscience/bloom-560m_deployment" + }, + do_sample=True, + max_new_tokens=30, +) +results.append(result) +print(result) + +result = generator.query({ + 'query': + "DeepSpeed is the greatest", + "deployment_name": + "microsoft/DialogRPT-human-vs-rand_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'text': "DeepSpeed is the greatest", + 'conversation_id': 3, + 'past_user_inputs': [], + 'generated_responses': [], + "deployment_name": "microsoft/DialoGPT-large_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'question': + "What is the greatest?", + 'context': + "DeepSpeed is the greatest", + "deployment_name": + "deepset/roberta-large-squad2" + "-qa-deployment" +}) +results.append(result) +print(result) diff --git a/examples/multi_model/shutdown.py b/examples/multi_model/shutdown.py new file mode 100644 index 00000000..6b718a4d --- /dev/null +++ b/examples/multi_model/shutdown.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +mii.terminate("multi_models") diff --git a/mii/client.py b/mii/client.py index b262cece..fbc166e0 100644 --- a/mii/client.py +++ b/mii/client.py @@ -35,9 +35,7 @@ def mii_query_handle(deployment_name): return MIINonPersistentClient(task, deployment_name) mii_config = _get_mii_config(deployment_name) - return MIIClient(mii_config.deployment_config.task, - "localhost", - mii_config.port_number) + return MIIClient(mii_config, "localhost", mii_config.port_number) def create_channel(host, port): @@ -61,13 +59,13 @@ def __init__(self, mii_config, host, port): channel = create_channel(host, port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) self.mii_config = mii_config - - def _get_deployment_task(self, deployment_name=None): + + def _get_deployment_task(self, deployment_name=None): task = None if deployment_name is None: #mii.terminate() or single model if deployment_name is None: assert len(self.deployments) == 1, "Must pass deployment_name to query when using multiple deployments" - deployment = self.mii_config.deployment_configs[0] + deployment = self.mii_config.deployment_configs[0] deployment_name = getattr(deployment, deployment_name) task = getattr(deployment, task) else: diff --git a/mii/config.py b/mii/config.py index 73d64a6f..1a581e3e 100644 --- a/mii/config.py +++ b/mii/config.py @@ -6,7 +6,7 @@ import os import string from typing import List, Optional, Dict, Any -from pydantic import validator, root_validator +from pydantic import validator, root_validator, BaseModel import mii from mii.constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT @@ -175,16 +175,20 @@ def deepspeed_or_zero(cls, values): ), "DeepSpeed and ZeRO cannot both be enabled, select only one" return values + +""" @root_validator def index_map_valid(cls, values): if values.get("GPU_index_map"): - for host in gpu_index_map: + for host in values.get("GPU_index_map"): assert host in resource_pool, f"Host: {host} was not found" assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" return values +""" + class MIIConfig(DeepSpeedConfigModel): - deployment_configs: List[DeploymentConfig] + 
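The _get_deployment_task change in the client.py hunk above resolves which deployment a query targets. A minimal sketch of the intended lookup, assuming deployment_configs is the name-keyed dict introduced elsewhere in this patch; the getattr calls in the hunk appear to want the deployment's own deployment_name and task fields, which is what this stand-in returns.

def resolve_deployment(deployment_configs, deployment_name=None):
    # deployment_configs: dict mapping deployment_name -> config dict/object.
    if deployment_name is None:
        # Single-model case: fall back to the only deployment.
        assert len(deployment_configs) == 1, \
            "Must pass deployment_name to query when using multiple deployments"
        deployment_name = next(iter(deployment_configs))
    deployment = deployment_configs[deployment_name]
    return deployment_name, deployment["task"]

configs = {"bigscience/bloom-560m_deployment": {"task": "text-generation"}}
print(resolve_deployment(configs))  # ('bigscience/bloom-560m_deployment', 'text-generation')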
deployment_configs: dict[str, DeploymentConfig] = None deployment_type: DeploymentType = DeploymentType.LOCAL deployment_tag: str = None hf_auth_token: Optional[str] = None @@ -194,12 +198,14 @@ class MIIConfig(DeepSpeedConfigModel): hostfile: str = DLTS_HOSTFILE version: int = 1 port_map: dict = {} + @root_validator(skip_on_failure=True) def propagate_hf_auth(cls, values): # This validator is for when we support multiple models in a deployment hf_auth_token = values.get("hf_auth_token") deployment_config_list = values.get("deployment_configs") - for deployment_config in deployment_config_list: + print(deployment_config_list) + for deployment_config in deployment_config_list.values(): if not deployment_config.hf_auth_token: deployment_config.hf_auth_token = hf_auth_token return values @@ -214,13 +220,13 @@ def AML_name_valid(cls, values): ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." return values - @root_validator(skip_on_failure=True) + @root_validator() def generate_replica_configs(cls, values): port_map = values.get("port_map") hostfile = values.get("hostfile") port_number = values.get("port_number") port_offset = 1 - for deployment_config in values.get("deployment_configs"): + for deployment_config in values.get("deployment_configs").values(): replica_configs = deployment_config.replica_configs replica_num = deployment_config.replica_num @@ -235,6 +241,7 @@ def generate_replica_configs(cls, values): replica_pool, GPU_index_map = _allocate_processes(hostfile, tensor_parallel, replica_num, GPU_index_map) deployment_config.GPU_index_map = GPU_index_map replica_configs = [] + print(replica_pool) for i, (hostname, gpu_indices) in enumerate(replica_pool): # Reserver port for a LB proxy when replication is enabled if hostname not in port_map: @@ -242,7 +249,9 @@ def generate_replica_configs(cls, values): base_port = port_number + i * tensor_parallel + port_offset if base_port in port_map[hostname]: base_port = max(port_map[hostname]) + 1 - tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel)) + tensor_parallel_ports = list( + range(base_port, + base_port + tensor_parallel)) for i in range(base_port, base_port + tensor_parallel): port_map[hostname].add(i) replica_torch_dist_port = torch_dist_port + (100 * i) @@ -266,11 +275,14 @@ def _allocate_processes(hostfile_path, tensor_parallel, replica_num, GPU_index_m replica_pool = [] - if gpu_index_map is not None: - for host in gpu_index_map: - replica_pool.append((host, gpu_index_map[host])) - return replica_pool - + if GPU_index_map is not None: + for host in GPU_index_map: + assert host in resource_pool, f"Host: {host} was not found" + assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" + for host in GPU_index_map: + replica_pool.append((host, GPU_index_map[host])) + return replica_pool, GPU_index_map + allocated_num = 0 for host, slots in resource_pool.items(): available_on_host = slots @@ -301,4 +313,4 @@ def _allocate_processes(hostfile_path, tensor_parallel, replica_num, GPU_index_m f"Not sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed" ) - return replica_pool + return replica_pool, GPU_index_map diff --git a/mii/deployment.py b/mii/deployment.py index d76d1f35..115f6c55 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -47,9 +47,11 @@ def support_legacy_api( def deploy( - deployment_name: str, - deployment_configs: List[dict], + deployment_name: str = None, + 
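generate_replica_configs in the config.py hunk above reserves one port for the load-balancer proxy and then hands each replica a contiguous block of tensor_parallel ports. A simplified, stand-alone sketch of that layout, ignoring the per-host port_map deduplication the real validator also performs:

def plan_replica_ports(port_number, tensor_parallel, replica_num, port_offset=1):
    # port_number itself is left for the load-balancer proxy; replica i gets the
    # next contiguous block of tensor_parallel ports after the offset.
    return [
        list(range(port_number + i * tensor_parallel + port_offset,
                   port_number + i * tensor_parallel + port_offset + tensor_parallel))
        for i in range(replica_num)
    ]

# e.g. the default port 50050 with tensor_parallel=2 and two replicas:
print(plan_replica_ports(50050, 2, 2))  # [[50051, 50052], [50053, 50054]]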
deployment_config: dict = None, mii_config: dict = None, + deployment_configs: list[dict] = None, + deployment_tag: str = None, *args, **kwargs, ): @@ -60,21 +62,29 @@ def deploy( assert ( not deployment_config ), "We do not support mixture of legacy and new API options, use latest API." + assert deployment_name, "deployment_name required for singular deployment" kwargs["mii_config"] = mii_config deployment_config, mii_config = support_legacy_api(*args, **kwargs) deployment_config["deployment_name"] = deployment_name mii_config["deployment_tag"] = deployment_name - mii_config["deployment_configs"] = {deployment_name: DeploymentConfig(**deployment_config)} + mii_config["deployment_configs"] = { + deployment_name: DeploymentConfig(**deployment_config) + } else: + assert all((deployment_tag, deployment_configs)), "To deploy multiple models you must use deployment_tag and deployment_configs" + deployment_dict = {} for deployment_config in deployment_configs: - deployment_config = DeploymentConfig(**deployment_config) - mii_config["deployment_configs"] = {dep.deployment_name: dep for dep in deployment_configs} - mii_config["deployment_tag"] = kwargs["deployment_tag"] - + deployment_dict[deployment_config.get('deployment_name')] = DeploymentConfig( + **deployment_config) + #print(deployment_dict) + mii_config["deployment_configs"] = deployment_dict + mii_config["deployment_tag"] = deployment_tag + + print(mii_config.keys()) mii_config = mii.config.MIIConfig(**mii_config) - for deployment_config in mii_config.deployment_configs: + for deployment_config in mii_config.deployment_configs.values(): if deployment_config.enable_deepspeed: logger.info( f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index a7a812c8..c5815b26 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -171,10 +171,8 @@ def invoke(self, method_name, proto_request): self.asyncio_loop).result() - - class LoadBalancingInterceptor(grpc.ServerInterceptor): - def __init__(self, mii_config): + def __init__(self, mii_config): super().__init__() self.asyncio_loop = asyncio.get_event_loop() @@ -186,7 +184,7 @@ def __init__(self, mii_config): self.stubs[deployment.deployment_name] = [] self.counter[deployment.deployment_name] = AtomicCounter() self.tasks[deployment.deployment_name] = repl.task - + for deployment in mii_config.deployment_configs: deployment_name = deployment.deployment_name for repl in deployment.replica_configs: @@ -257,6 +255,7 @@ def invoke_intercept_method(request_proto, context): request_deserializer=next_handler.request_deserializer, response_serializer=next_handler.response_serializer) + def _do_serve(service_impl, port, interceptors=[]): stop_event = service_impl.get_stop_event() server = grpc.server( diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index f3e5ff29..74ca4913 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -65,7 +65,7 @@ message SingleStringReply { string response = 1; float time_taken = 2; float model_time_taken = 3; -} +} message MultiStringReply { repeated string response = 1; diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index c9f12538..0a219bf3 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ 
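The LoadBalancingInterceptor changes above keep per-deployment stub lists and counters so that requests rotate across a deployment's replicas. A minimal thread-safe stand-in for that rotation; MII's AtomicCounter API is not shown in this hunk, so a lock-based counter is used here instead.

import threading

class RoundRobin:
    def __init__(self, replicas):
        self._replicas = list(replicas)
        self._lock = threading.Lock()
        self._next = 0

    def pick(self):
        # Return replicas in rotation, safely across concurrent RPC threads.
        with self._lock:
            choice = self._replicas[self._next % len(self._replicas)]
            self._next += 1
            return choice

lb = RoundRobin(["replica-0:50051", "replica-1:50053"])
print([lb.pick() for _ in range(4)])  # alternates between the two replicas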
b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" @@ -10,54 +9,54 @@ _sym_db = _symbol_database.Default() - from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xed\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xeb\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xd3\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x01\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_idB\x12\n\x10_deployment_name\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 
\x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xed\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xeb\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xd3\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x01\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 
\x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_idB\x12\n\x10_deployment_name\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _globals['_VALUE']._serialized_start=67 - _globals['_VALUE']._serialized_end=162 - _globals['_SESSIONID']._serialized_start=164 - _globals['_SESSIONID']._serialized_end=195 - _globals['_SINGLESTRINGREQUEST']._serialized_start=198 - _globals['_SINGLESTRINGREQUEST']._serialized_end=435 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=343 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=415 - _globals['_MULTISTRINGREQUEST']._serialized_start=438 - _globals['_MULTISTRINGREQUEST']._serialized_end=673 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=343 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=415 - _globals['_SINGLESTRINGREPLY']._serialized_start=675 - _globals['_SINGLESTRINGREPLY']._serialized_end=758 - _globals['_MULTISTRINGREPLY']._serialized_start=760 - _globals['_MULTISTRINGREPLY']._serialized_end=842 - _globals['_QAREQUEST']._serialized_start=845 - _globals['_QAREQUEST']._serialized_end=1080 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start=343 - 
_globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end=415 - _globals['_CONVERSATIONREQUEST']._serialized_start=1083 - _globals['_CONVERSATIONREQUEST']._serialized_end=1422 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start=343 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end=415 - _globals['_CONVERSATIONREPLY']._serialized_start=1425 - _globals['_CONVERSATIONREPLY']._serialized_end=1570 - _globals['_IMAGEREPLY']._serialized_start=1572 - _globals['_IMAGEREPLY']._serialized_end=1697 - _globals['_MODELRESPONSE']._serialized_start=1700 - _globals['_MODELRESPONSE']._serialized_end=2552 + DESCRIPTOR._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _QAREQUEST_QUERYKWARGSENTRY._options = None + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _globals['_VALUE']._serialized_start = 67 + _globals['_VALUE']._serialized_end = 162 + _globals['_SESSIONID']._serialized_start = 164 + _globals['_SESSIONID']._serialized_end = 195 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 198 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 435 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_MULTISTRINGREQUEST']._serialized_start = 438 + _globals['_MULTISTRINGREQUEST']._serialized_end = 673 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_SINGLESTRINGREPLY']._serialized_start = 675 + _globals['_SINGLESTRINGREPLY']._serialized_end = 758 + _globals['_MULTISTRINGREPLY']._serialized_start = 760 + _globals['_MULTISTRINGREPLY']._serialized_end = 842 + _globals['_QAREQUEST']._serialized_start = 845 + _globals['_QAREQUEST']._serialized_end = 1080 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREQUEST']._serialized_start = 1083 + _globals['_CONVERSATIONREQUEST']._serialized_end = 1422 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREPLY']._serialized_start = 1425 + _globals['_CONVERSATIONREPLY']._serialized_end = 1570 + _globals['_IMAGEREPLY']._serialized_start = 1572 + _globals['_IMAGEREPLY']._serialized_end = 1697 + _globals['_MODELRESPONSE']._serialized_start = 1700 + _globals['_MODELRESPONSE']._serialized_end = 2552 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 683e4962..f6515b0a 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -8,7 +8,6 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" - def __init__(self, channel): """Constructor. @@ -16,60 +15,60 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Terminate = channel.unary_unary( - '/modelresponse.ModelResponse/Terminate', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/Terminate', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/CreateSession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/DestroySession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.GeneratorReply = channel.unary_unary( - '/modelresponse.ModelResponse/GeneratorReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.MultiStringReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.MultiStringReply.FromString, + ) self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/ClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/QuestionAndAnswerReply', + request_serializer=modelresponse__pb2.QARequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/FillMaskReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/TokenClassificationReply', + 
request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) + '/modelresponse.ModelResponse/ConversationalReply', + request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ConversationReply.FromString, + ) self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, - ) + '/modelresponse.ModelResponse/Txt2ImgReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ImageReply.FromString, + ) class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" - def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -133,232 +132,334 @@ def Txt2ImgReply(self, request, context): def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - 'Terminate': grpc.unary_unary_rpc_method_handler( - servicer.Terminate, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'CreateSession': grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'DestroySession': grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'GeneratorReply': grpc.unary_unary_rpc_method_handler( - servicer.GeneratorReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, - ), - 'ClassificationReply': grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - 
request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, - ), + 'Terminate': + grpc.unary_unary_rpc_method_handler( + servicer.Terminate, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'CreateSession': + grpc.unary_unary_rpc_method_handler( + servicer.CreateSession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'DestroySession': + grpc.unary_unary_rpc_method_handler( + servicer.DestroySession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'GeneratorReply': + grpc.unary_unary_rpc_method_handler( + servicer.GeneratorReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, + ), + 'ClassificationReply': + grpc.unary_unary_rpc_method_handler( + servicer.ClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'QuestionAndAnswerReply': + grpc.unary_unary_rpc_method_handler( + servicer.QuestionAndAnswerReply, + request_deserializer=modelresponse__pb2.QARequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'FillMaskReply': + grpc.unary_unary_rpc_method_handler( + servicer.FillMaskReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'TokenClassificationReply': + grpc.unary_unary_rpc_method_handler( + servicer.TokenClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'ConversationalReply': + grpc.unary_unary_rpc_method_handler( + servicer.ConversationalReply, + request_deserializer=modelresponse__pb2.ConversationRequest.FromString, + response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, + ), + 'Txt2ImgReply': + grpc.unary_unary_rpc_method_handler( + servicer.Txt2ImgReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler( - 'modelresponse.ModelResponse', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) + generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', + rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. 
class ModelResponse(object): """Missing associated documentation comment in .proto file.""" - @staticmethod def Terminate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Terminate', + '/modelresponse.ModelResponse/Terminate', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def CreateSession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/CreateSession', + '/modelresponse.ModelResponse/CreateSession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def DestroySession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/DestroySession', + '/modelresponse.ModelResponse/DestroySession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def GeneratorReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/GeneratorReply', + '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.MultiStringReply.FromString, - options, 
channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def ClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ClassificationReply', + '/modelresponse.ModelResponse/ClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def QuestionAndAnswerReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/QuestionAndAnswerReply', + '/modelresponse.ModelResponse/QuestionAndAnswerReply', modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def FillMaskReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/FillMaskReply', + '/modelresponse.ModelResponse/FillMaskReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def TokenClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, 
'/modelresponse.ModelResponse/TokenClassificationReply', + '/modelresponse.ModelResponse/TokenClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def ConversationalReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ConversationalReply', + '/modelresponse.ModelResponse/ConversationalReply', modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def Txt2ImgReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Txt2ImgReply', + '/modelresponse.ModelResponse/Txt2ImgReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 195094fe..8aa080b2 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -23,6 +23,7 @@ def b64_encoded_config(config_str): # return mii.DeploymentConfig object return DeploymentConfig(**config_dict) + def b64_encoded_config(config_str): # str -> bytes b64_bytes = config_str.encode() @@ -33,6 +34,7 @@ def b64_encoded_config(config_str): # return mii.MIIConfig object return MIIConfig(**config_dict) + def main(): parser = argparse.ArgumentParser() parser.add_argument( diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 353a60c6..8000c34c 100644 --- a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -10,6 +10,7 @@ import torch import mii +from mii.config import DeploymentConfig model = None diff --git a/mii/server.py b/mii/server.py index 24e0b59c..742f2488 100644 --- a/mii/server.py +++ b/mii/server.py @@ -27,12 +27,11 @@ class MIIServer: def __init__(self, mii_config): #self.task = mii_config.deployment_config.task - #self.num_gpus = get_num_gpus(mii_config) #assert self.num_gpus > 0, "GPU count must be greater than 0" - - for deployment_config in 
mii_config.deployment_configs:
+
+        for deployment_config in mii_config.deployment_configs.values():
+            assert get_num_gpus(deployment_config) > 0, f"GPU count for {deployment_config.deployment_name} must be greater than 0"
         """
         if mii_configs.hostfile is None:
@@ -91,8 +90,10 @@ def _launch_server_process(self, ds_launch_str="", server_args=[]):
 
         launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server"
 
+        print(deployment_config)
         b64_config_str = config_to_b64_str(deployment_config)
         server_args.append(f"--deployment-config {b64_config_str}")
+
         server_args_str = " ".join(server_args)
 
         cmd = f"{ds_launch_str} {launch_str} {server_args_str}".strip().split(" ")
@@ -125,13 +126,13 @@ def _initialize_service(self, mii_config):
         ]
 
         host_gpus = defaultdict(list)
-        for deployment in mii_config.deployment_configs:
+        for deployment in mii_config.deployment_configs.values():
             for repl_config in deployment.replica_configs:
                 host_gpus[repl_config.hostname].extend(repl_config.gpu_indices)
 
         # Start replica instances
-        for deployment_config in mii_config.deployment_configs:
-            for repl_config in mii_config.deployment_config.replica_configs:
+        for deployment_config in mii_config.deployment_configs.values():
+            for repl_config in deployment_config.replica_configs:
                 hostfile = tempfile.NamedTemporaryFile(delete=False)
                 hostfile.write(
                     f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n"
@@ -159,7 +160,7 @@ def _initialize_service(self, mii_config):
         if mii_config.enable_restful_api:
             processes.append(
                 self._launch_server_process(
-                    mii_config.deployment_config,
+                    next(iter(mii_config.deployment_configs.values())),
                     "restful api gateway",
                     server_args=server_args + ["--restful-gateway"],
                 ))

From de998bf52f4cf11fad9ac613643f0a9ab1a054b4 Mon Sep 17 00:00:00 2001
From: Tosin Segun
Date: Sat, 12 Aug 2023 00:07:25 +0000
Subject: [PATCH 15/15] Fixing deserialization

---
 mii/launch/multi_gpu_server.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py
index 8aa080b2..34b46845 100644
--- a/mii/launch/multi_gpu_server.py
+++ b/mii/launch/multi_gpu_server.py
@@ -24,7 +24,7 @@ def b64_encoded_config(config_str):
     return DeploymentConfig(**config_dict)
 
 
-def b64_encoded_config(config_str):
+def b64_encoded_config_MII(config_str):  #TODO: Remove duplicated function
     # str -> bytes
     b64_bytes = config_str.encode()
     # decode b64 bytes -> json bytes
@@ -44,7 +44,7 @@ def main():
     )
     parser.add_argument(
         "--mii-config",
-        type=b64_encoded_config,
+        type=b64_encoded_config_MII,
         help="base64 encoded mii config",
     )
     parser.add_argument(