From 54c8e143b89bd787d15ea20ad14db96801ade50e Mon Sep 17 00:00:00 2001
From: Dmytro Parfeniuk
Date: Wed, 26 Jun 2024 11:41:34 +0300
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20`create=5Fbackend`=20invalid=20c?=
 =?UTF-8?q?ondition=20is=20fixed?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* `create_backend` no longer fails due to an improper condition
* rename: `BackendTypes` -> `BackendType`
---
 src/guidellm/backend/__init__.py |  4 ++--
 src/guidellm/backend/base.py     | 18 ++++++++++--------
 src/guidellm/backend/openai.py   |  4 ++--
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py
index cc5c740..f9d5541 100644
--- a/src/guidellm/backend/__init__.py
+++ b/src/guidellm/backend/__init__.py
@@ -1,9 +1,9 @@
-from .base import Backend, BackendTypes, GenerativeResponse
+from .base import Backend, BackendType, GenerativeResponse
 from .openai import OpenAIBackend
 
 __all__ = [
     "Backend",
-    "BackendTypes",
+    "BackendType",
     "GenerativeResponse",
     "OpenAIBackend",
 ]
diff --git a/src/guidellm/backend/base.py b/src/guidellm/backend/base.py
index 22aab80..b09ce37 100644
--- a/src/guidellm/backend/base.py
+++ b/src/guidellm/backend/base.py
@@ -2,17 +2,17 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-from typing import Iterator, List, Optional, Type, Union
+from typing import Iterator, List, Optional, Type
 
 from loguru import logger
 
 from guidellm.core.request import TextGenerationRequest
 from guidellm.core.result import TextGenerationResult
 
-__all__ = ["Backend", "BackendTypes", "GenerativeResponse"]
+__all__ = ["Backend", "BackendType", "GenerativeResponse"]
 
 
-class BackendTypes(Enum):
+class BackendType(str, Enum):
     TEST = "test"
     OPENAI_SERVER = "openai_server"
 
@@ -39,12 +39,12 @@ class Backend(ABC):
     _registry = {}
 
     @staticmethod
-    def register_backend(backend_type: BackendTypes):
+    def register_backend(backend_type: BackendType):
         """
         A decorator to register a backend class in the backend registry.
 
         :param backend_type: The type of backend to register.
-        :type backend_type: BackendTypes
+        :type backend_type: BackendType
         """
 
         def inner_wrapper(wrapped_class: Type["Backend"]):
@@ -54,21 +54,23 @@ def inner_wrapper(wrapped_class: Type["Backend"]):
         return inner_wrapper
 
     @staticmethod
-    def create_backend(backend_type: Union[str, BackendTypes], **kwargs) -> "Backend":
+    def create_backend(backend_type: BackendType, **kwargs) -> "Backend":
         """
         Factory method to create a backend based on the backend type.
 
         :param backend_type: The type of backend to create.
-        :type backend_type: BackendTypes
+        :type backend_type: BackendType
         :param kwargs: Additional arguments for backend initialization.
         :type kwargs: dict
         :return: An instance of a subclass of Backend.
         :rtype: Backend
         """
         logger.info(f"Creating backend of type {backend_type}")
-        if backend_type not in Backend._registry:
+
+        if backend_type not in Backend._registry.keys():
             logger.error(f"Unsupported backend type: {backend_type}")
             raise ValueError(f"Unsupported backend type: {backend_type}")
+
         return Backend._registry[backend_type](**kwargs)
 
     def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index ce9f6c2..d2656be 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -4,13 +4,13 @@
 from loguru import logger
 from transformers import AutoTokenizer
 
-from guidellm.backend import Backend, BackendTypes, GenerativeResponse
+from guidellm.backend import Backend, BackendType, GenerativeResponse
 from guidellm.core.request import TextGenerationRequest
 
 __all__ = ["OpenAIBackend"]
 
 
-@Backend.register_backend(BackendTypes.OPENAI_SERVER)
+@Backend.register_backend(BackendType.OPENAI_SERVER)
 class OpenAIBackend(Backend):
     """
     An OpenAI backend implementation for the generative AI result.
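
Note (not part of the patch): a minimal, self-contained sketch of the registry/factory
behavior that the fixed condition in `create_backend` relies on. `DummyBackend` and the
final print line are hypothetical stand-ins for illustration; everything else mirrors
only what the hunks above show.

    from enum import Enum
    from typing import Dict, Type


    class BackendType(str, Enum):
        TEST = "test"
        OPENAI_SERVER = "openai_server"


    class Backend:
        # Maps a BackendType member to the class registered for it.
        _registry: Dict[BackendType, Type["Backend"]] = {}

        @staticmethod
        def register_backend(backend_type: BackendType):
            def inner_wrapper(wrapped_class: Type["Backend"]) -> Type["Backend"]:
                Backend._registry[backend_type] = wrapped_class
                return wrapped_class

            return inner_wrapper

        @staticmethod
        def create_backend(backend_type: BackendType, **kwargs) -> "Backend":
            # The fixed membership check: the key must be a registered
            # BackendType member, otherwise the factory refuses to construct.
            if backend_type not in Backend._registry.keys():
                raise ValueError(f"Unsupported backend type: {backend_type}")
            return Backend._registry[backend_type](**kwargs)


    @Backend.register_backend(BackendType.OPENAI_SERVER)
    class DummyBackend(Backend):  # hypothetical stand-in for OpenAIBackend
        pass


    print(type(Backend.create_backend(BackendType.OPENAI_SERVER)).__name__)  # DummyBackend

In this sketch, asking for an unregistered member (for example `BackendType.TEST`)
raises `ValueError`, which is the behavior the corrected condition restores.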