diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py index cc5c740..f9d5541 100644 --- a/src/guidellm/backend/__init__.py +++ b/src/guidellm/backend/__init__.py @@ -1,9 +1,9 @@ -from .base import Backend, BackendTypes, GenerativeResponse +from .base import Backend, BackendType, GenerativeResponse from .openai import OpenAIBackend __all__ = [ "Backend", - "BackendTypes", + "BackendType", "GenerativeResponse", "OpenAIBackend", ] diff --git a/src/guidellm/backend/base.py b/src/guidellm/backend/base.py index 22aab80..b09ce37 100644 --- a/src/guidellm/backend/base.py +++ b/src/guidellm/backend/base.py @@ -2,17 +2,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Iterator, List, Optional, Type, Union +from typing import Iterator, List, Optional, Type from loguru import logger from guidellm.core.request import TextGenerationRequest from guidellm.core.result import TextGenerationResult -__all__ = ["Backend", "BackendTypes", "GenerativeResponse"] +__all__ = ["Backend", "BackendType", "GenerativeResponse"] -class BackendTypes(Enum): +class BackendType(str, Enum): TEST = "test" OPENAI_SERVER = "openai_server" @@ -39,12 +39,12 @@ class Backend(ABC): _registry = {} @staticmethod - def register_backend(backend_type: BackendTypes): + def register_backend(backend_type: BackendType): """ A decorator to register a backend class in the backend registry. :param backend_type: The type of backend to register. - :type backend_type: BackendTypes + :type backend_type: BackendType """ def inner_wrapper(wrapped_class: Type["Backend"]): @@ -54,21 +54,23 @@ def inner_wrapper(wrapped_class: Type["Backend"]): return inner_wrapper @staticmethod - def create_backend(backend_type: Union[str, BackendTypes], **kwargs) -> "Backend": + def create_backend(backend_type: BackendType, **kwargs) -> "Backend": """ Factory method to create a backend based on the backend type. :param backend_type: The type of backend to create. - :type backend_type: BackendTypes + :type backend_type: BackendType :param kwargs: Additional arguments for backend initialization. :type kwargs: dict :return: An instance of a subclass of Backend. :rtype: Backend """ logger.info(f"Creating backend of type {backend_type}") - if backend_type not in Backend._registry: + + if backend_type not in Backend._registry.keys(): logger.error(f"Unsupported backend type: {backend_type}") raise ValueError(f"Unsupported backend type: {backend_type}") + return Backend._registry[backend_type](**kwargs) def submit(self, request: TextGenerationRequest) -> TextGenerationResult: diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index ce9f6c2..d2656be 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -4,13 +4,13 @@ from loguru import logger from transformers import AutoTokenizer -from guidellm.backend import Backend, BackendTypes, GenerativeResponse +from guidellm.backend import Backend, BackendType, GenerativeResponse from guidellm.core.request import TextGenerationRequest __all__ = ["OpenAIBackend"] -@Backend.register_backend(BackendTypes.OPENAI_SERVER) +@Backend.register_backend(BackendType.OPENAI_SERVER) class OpenAIBackend(Backend): """ An OpenAI backend implementation for the generative AI result.