diff --git a/src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py b/src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py index 8274a5861..f48ab98b3 100644 --- a/src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py +++ b/src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py @@ -2,7 +2,7 @@ import os import leapfrogai_sdk as lfai -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as global_config from leapfrogai_api.backend.grpc_client import create_embeddings import logging @@ -61,9 +61,9 @@ async def _get_model( Raises: ValueError: If the embeddings model is not found. """ - - if not (model := get_model_config().get_model_backend(model=model_name)): - logger.error(f"Embeddings model {model_name} not found.") + config = await global_config.create() + if not (model := config.get_model_backend(model=model_name)): + logging.error(f"Embeddings model {model_name} not found.") raise ValueError("Embeddings model not found.") return model diff --git a/src/leapfrogai_api/backend/types.py b/src/leapfrogai_api/backend/types.py index 59011003c..efaab824a 100644 --- a/src/leapfrogai_api/backend/types.py +++ b/src/leapfrogai_api/backend/types.py @@ -4,7 +4,8 @@ import datetime from enum import Enum -from typing import Literal +from typing import Any, Literal +import warnings from fastapi import UploadFile, Form, File from openai.types import FileObject @@ -18,7 +19,7 @@ ) from openai.types.beta.threads.text_content_block_param import TextContentBlockParam from openai.types.beta.vector_store import ExpiresAfter -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, ValidationInfo, ConfigDict ########## # DEFAULTS @@ -27,7 +28,7 @@ DEFAULT_MAX_COMPLETION_TOKENS = 4096 DEFAULT_MAX_PROMPT_TOKENS = 4096 - +DEFAULT_DIMENSION: int = 2000 ########## # GENERIC @@ -49,11 +50,160 @@ class Usage(BaseModel): ) +########## +# WARNINGS +########## + + +class LeapfrogAIWarning(UserWarning): + """Base warning class for LeapfrogAI.""" + + +class DimensionWarning(LeapfrogAIWarning): + """Warning for dimension-related issues.""" + + def __init__( + self, + dimension: int, + dimension_default: int = DEFAULT_DIMENSION, + ): + super().__init__( + f"Dimension {dimension} exceeds the recommended maximum of {dimension_default}." 
+ ) + + +class CapabilityWarning(LeapfrogAIWarning): + """Warning for capability-related issues.""" + + ########## # MODELS ########## +class Modality(str, Enum): + """Defines the input/output modality of a model (image, text, speech, or null).""" + + IMAGE = "image" + TEXT = "text" + SPEECH = "speech" + + +class Capability(str, Enum): + """Specifies the functional capabilities of a model (chat, embeddings, speech-to-text, text-to-speech, or null).""" + + CHAT = "chat" + EMBEDDINGS = "embeddings" + STT = "stt" + TTS = "tts" + + +class Precision(str, Enum): + """Indicates the numerical precision used in the model's parameters (float16, float32, bfloat16, int8, int4, or null).""" + + FLOAT16 = "float16" + FLOAT32 = "float32" + BFLOAT16 = "bfloat16" + INT8 = "int8" + INT4 = "int4" + + +class Format(str, Enum): + """Describes the storage or quantization format of the model (None, GPTQ, GGUF, SqueezeLLM, AWQ, or null).""" + + AWQ = "AWQ" + GGUF = "GGUF" + GPTQ = "GPTQ" + SQUEEZELLM = "SqueezeLLM" + + +class ModelMetadataResponse(BaseModel): + """Metadata for the model, including type, dimensions (for embeddings), and precision.""" + + model_config = ConfigDict(use_enum_values=True) + + capabilities: list[Capability] | None = Field( + default=None, # TODO: should we define this as an empty string if it's None? + description="Model capabilities (e.g., 'embeddings', 'chat', 'tts', 'stt')", + ) + dimensions: int | None = Field( + default=None, + description="Embedding dimensions (for embeddings models)", + ) + format: Format | None = Field( + default=None, + description="Model format (e.g., None, 'GPTQ', 'GGUF', 'SqueezeLLM', 'AWQ')", + ) + modalities: list[Modality] | None = Field( + default=None, # TODO: should we define this as an empty string if it's None? + description="The modalities of the model (e.g., 'image', 'text', 'speech')", + ) + + precision: Precision | None = Field( + default=None, + description="Model precision (e.g., 'float16', 'float32', 'bfloat16', 'int8', 'int4')", + ) + type: Literal["embeddings", "llm"] | None = Field( + default=None, + description="The type of the model e.g. ('embeddings' or 'llm')", + ) + + @field_validator("dimensions") + @classmethod + def check_dimensions( + cls, + v: int | None, + info: ValidationInfo, + ) -> int | None: + """ + Validates the 'dimensions' field of a model's metadata. + + Args: + v: The dimension value to be validated. + info: The validation information. + + Returns: + The validated dimension value. + + Raises: + CapabilityError: If the 'dimensions' field is not set for models with 'embeddings' capability or vice versa. + """ + if v is not None and v > 2000: + warnings.warn(DimensionWarning(dimension=v)) + return v + + @field_validator("capabilities") + @classmethod + def validate_capabilities( + cls, + v: list[Capability] | None, + values: dict[str, Any], + ) -> list[Capability] | None: + """ + Validates the 'capabilities' field of a model's metadata, ensuring that 'dimensions' + is correctly set when 'embeddings' is in capabilities. 
+ """ + # TODO: Actually error here when 'embeddings' is not in capabilities, once this is actually implemented + + # Check if dimensions is set, warn if 'embeddings' is not in capabilities + if (dimensions_set := (values.get("dimensions", None) is not None)) and not ( + embeddings_set := (v is not None and "embeddings" in v) + ): + warnings.warn( + CapabilityWarning( + "'dimensions' should only be set for models with 'embeddings' capability" + ) + ) + # Check if dimensions is not set, warn if 'embeddings' is in capabilities + elif not dimensions_set and embeddings_set: + warnings.warn( + CapabilityWarning( + "'dimensions' must be set for models with 'embeddings' capability" + ) + ) + return v + + class ModelResponseModel(BaseModel): """Represents a single model in the response.""" @@ -75,6 +225,10 @@ class ModelResponseModel(BaseModel): default="leapfrogai", description="The organization that owns the model. Always 'leapfrogai' for LeapfrogAI models.", ) + metadata: ModelMetadataResponse | None = Field( + default=None, + description="Metadata for the model, including type, dimensions (for embeddings), and precision.", + ) class ModelResponse(BaseModel): @@ -100,7 +254,7 @@ class CompletionRequest(BaseModel): model: str = Field( ..., description="The ID of the model to use for completion.", - example="llama-cpp-python", + examples=["llama-cpp-python"], ) prompt: str | list[int] = Field( ..., @@ -132,7 +286,9 @@ class CompletionChoice(BaseModel): description="Log probabilities for the generated tokens. Only returned if requested.", ) finish_reason: str = Field( - "", description="The reason why the completion finished.", example="length" + "", + description="The reason why the completion finished.", + examples=["length"], ) @@ -579,12 +735,12 @@ class CreateVectorStoreRequest(BaseModel): file_ids: list[str] | None = Field( default=[], description="List of file IDs to be included in the vector store.", - example=["file-abc123", "file-def456"], + examples=["file-abc123", "file-def456"], ) name: str | None = Field( default=None, description="Optional name for the vector store.", - example="My Vector Store", + examples=["My Vector Store"], ) expires_after: ExpiresAfter | None = Field( default=None, @@ -594,7 +750,7 @@ class CreateVectorStoreRequest(BaseModel): metadata: dict | None = Field( default=None, description="Optional metadata for the vector store.", - example={"project": "AI Research", "version": "1.0"}, + examples=[{"project": "AI Research", "version": "1.0"}], ) def add_days_to_timestamp(self, timestamp: int, days: int) -> int: diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py index c3c806dbd..31b9ce8f6 100644 --- a/src/leapfrogai_api/main.py +++ b/src/leapfrogai_api/main.py @@ -1,11 +1,13 @@ """Main FastAPI application for the LeapfrogAI API.""" -import asyncio import logging import os from contextlib import asynccontextmanager +import os + +from typing import AsyncContextManager, Callable -from fastapi import FastAPI +from fastapi import FastAPI, APIRouter from fastapi.exception_handlers import request_validation_exception_handler from fastapi.exceptions import RequestValidationError @@ -27,7 +29,35 @@ threads, vector_stores, ) -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.config import Config + +# TODO: Add in `if __name__ == "__main__":` block to allow uvicorn to be invoked here instead. 
+
+logger = logging.getLogger(__name__)
+handler = logging.StreamHandler()
+handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
+logger.addHandler(handler)
+
+API_ROUTERS: list[APIRouter] = [
+    base_router,
+    auth.router,
+    models.router,
+    completions.router,
+    chat.router,
+    audio.router,
+    embeddings.router,
+    assistants.router,
+    files.router,
+    vector_stores.router,
+    runs.router,
+    messages.router,
+    runs_steps.router,
+    lfai_vector_stores.router,
+    lfai_models.router,
+    # This should be at the bottom to prevent it preempting more specific runs endpoints
+    # https://fastapi.tiangolo.com/tutorial/path-params/#order-matters
+    threads.router,
+]

 logging.basicConfig(
     level=os.getenv("LFAI_LOG_LEVEL", logging.INFO),
@@ -36,43 +66,84 @@
 logger = logging.getLogger(__name__)

-# handle startup & shutdown tasks
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Handle startup and shutdown tasks for the FastAPI app."""
-    # startup
-    logger.info("Starting to watch for configs with this being an info")
-    asyncio.create_task(get_model_config().watch_and_load_configs())
-    yield
-    # shutdown
-    logger.info("Clearing model configs")
-    asyncio.create_task(get_model_config().clear_all_models())
+def get_lifespan(
+    testing: bool | None = None,
+) -> Callable[[FastAPI], AsyncContextManager]:
+    """
+    Returns a lifespan function based on the testing environment.
+    NOTE: This function is never called directly; it exists to make the app more testable.
+
+    Args:
+        testing (bool | None): A boolean indicating whether the testing environment is active. Defaults to None.
+
+    Returns:
+        Callable[[FastAPI], AsyncContextManager]: A lifespan function that handles the application's lifecycle.
+    """
+    # Convenience wrapper: the inner function below is the actual lifespan
+    # context manager handed to FastAPI.
+
+    lifespan_name = "TESTING" if testing else "DEVELOPMENT"
+
+    @asynccontextmanager
+    async def _lifespan(app: FastAPI):
+        logger.info(f"Entering {lifespan_name} lifespan")
+        config = await Config.create(testing=testing)
+        app.state.config = config
+        await config.start_watching(testing=testing)
+        logger.info(f"Yielding control to FastAPI in {lifespan_name} mode")
+        yield
+        logger.info(f"Shutting down {lifespan_name} lifespan")
-app = FastAPI(lifespan=lifespan)
+        await app.state.config.cleanup()
+        logger.info("Lifespan shutdown complete")
+        logger.info(f"Cleanup complete in {lifespan_name} lifespan mode")
+
+    return _lifespan
+
+
+def create_app(
+    testing: bool | None = None,
+    lifespan: Callable[[FastAPI], AsyncContextManager] | None = None,
+    **kwargs,
+) -> FastAPI:
+    """
+    Creates a FastAPI application instance.
+
+    Args:
+        testing (bool | None): A boolean indicating whether the application is in testing mode.
+            If None, the value will be determined from the LFAI_TESTING environment variable.
+        lifespan (Callable[[FastAPI], AsyncContextManager] | None): A callable that defines the lifespan of the application.
+            If None, the lifespan will be determined based on the testing mode.
+        **kwargs: Additional keyword arguments to pass to the FastAPI application constructor.
+
+    Returns:
+        FastAPI: The created FastAPI application instance.
+ """ + + # Set the lifespan based off of the testing mode and if it was provided + lifespan = lifespan if callable(lifespan) else get_lifespan(testing=testing) + testing = testing or os.environ.get("LFAI_TESTING", "false").lower() == "true" + + if "debug" not in kwargs: + kwargs["debug"] = testing + + app = FastAPI(lifespan=lifespan, **kwargs) + for router in API_ROUTERS: + app.include_router(router) + + return app -@app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): logger.error(f"The client sent invalid data!: {exc}") return await request_validation_exception_handler(request, exc) -app.include_router(base_router) -app.include_router(auth.router) -app.include_router(models.router) -app.include_router(completions.router) -app.include_router(chat.router) -app.include_router(audio.router) -app.include_router(embeddings.router) -app.include_router(assistants.router) -app.include_router(files.router) -app.include_router(vector_stores.router) -app.include_router(runs.router) -app.include_router(messages.router) -app.include_router(runs_steps.router) -app.include_router(lfai_vector_stores.router) -app.include_router(lfai_models.router) -# This should be at the bottom to prevent it preempting more specific runs endpoints -# https://fastapi.tiangolo.com/tutorial/path-params/#order-matters -app.include_router(threads.router) +app = create_app( + testing=False, + lifespan=get_lifespan(testing=False), + exception_handlers={ + RequestValidationError: validation_exception_handler, + }, +) diff --git a/src/leapfrogai_api/routers/leapfrogai/models.py b/src/leapfrogai_api/routers/leapfrogai/models.py index 27b750a2a..778107ddd 100644 --- a/src/leapfrogai_api/routers/leapfrogai/models.py +++ b/src/leapfrogai_api/routers/leapfrogai/models.py @@ -1,10 +1,17 @@ -from fastapi import APIRouter -from leapfrogai_api.utils import get_model_config +from typing import Any + +from fastapi import APIRouter, Request +import logging + router = APIRouter(prefix="/leapfrogai/v1/models", tags=["leapfrogai/models"]) @router.get("") -async def models(): +async def models( + request: Request, +) -> dict[str, dict[str, Any]]: """List all the models.""" - return get_model_config() + config = request.app.state.config + logging.debug(f"CONFIG IN models.py: {config}") + return config.to_dict() diff --git a/src/leapfrogai_api/routers/openai/audio.py b/src/leapfrogai_api/routers/openai/audio.py index aae238cb5..7faf9a10a 100644 --- a/src/leapfrogai_api/routers/openai/audio.py +++ b/src/leapfrogai_api/routers/openai/audio.py @@ -11,7 +11,8 @@ CreateTranslationRequest, ) from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as global_config + from leapfrogai_api.utils.config import Config import leapfrogai_sdk as lfai @@ -21,7 +22,7 @@ @router.post("/transcriptions") async def transcribe( session: Session, # pylint: disable=unused-argument # required for authorizing endpoint - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, Depends(global_config.create)], req: CreateTranscriptionRequest = Depends(CreateTranscriptionRequest.as_form), ) -> CreateTranscriptionResponse: """Create a transcription from the given audio file.""" @@ -50,7 +51,7 @@ async def transcribe( @router.post("/translations") async def translate( session: Session, - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, 
Depends(global_config.create)], req: CreateTranslationRequest = Depends(CreateTranslationRequest.as_form), ) -> CreateTranscriptionResponse: """Create a translation to english from the given audio file.""" diff --git a/src/leapfrogai_api/routers/openai/chat.py b/src/leapfrogai_api/routers/openai/chat.py index 23d09beb4..2110090cd 100644 --- a/src/leapfrogai_api/routers/openai/chat.py +++ b/src/leapfrogai_api/routers/openai/chat.py @@ -11,7 +11,7 @@ from leapfrogai_api.backend.helpers import grpc_chat_role from leapfrogai_api.backend.types import ChatCompletionRequest from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as config_global from leapfrogai_api.utils.config import Config from leapfrogai_sdk.chat.chat_pb2 import ( ChatCompletionResponse as ProtobufChatCompletionResponse, @@ -23,7 +23,7 @@ @router.post("/completions") async def chat_complete( req: ChatCompletionRequest, - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, Depends(config_global.create)], session: Session, # pylint: disable=unused-argument # required for authorizing endpoint ): """Complete a chat conversation with the given model.""" @@ -32,7 +32,7 @@ async def chat_complete( if model is None: raise HTTPException( status_code=405, - detail=f"Model {req.model} not found. Currently supported models are {list(model_config.models.keys())}", + detail=f"Model {req.model} not found. Currently supported models are {model_config}", ) chat_items: list[lfai.ChatItem] = [] @@ -61,7 +61,7 @@ async def chat_complete( async def chat_complete_stream_raw( req: ChatCompletionRequest, - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, Depends(config_global.create)], ) -> AsyncGenerator[ProtobufChatCompletionResponse, Any]: """Complete a prompt with the given model.""" # Get the model backend configuration @@ -69,7 +69,7 @@ async def chat_complete_stream_raw( if model is None: raise HTTPException( status_code=405, - detail=f"Model {req.model} not found. Currently supported models are {list(model_config.models.keys())}", + detail=f"Model {req.model} not found. Currently supported models are {model_config}", ) chat_items: list[lfai.ChatItem] = [] diff --git a/src/leapfrogai_api/routers/openai/completions.py b/src/leapfrogai_api/routers/openai/completions.py index 0c0e4fd76..98aa57574 100644 --- a/src/leapfrogai_api/routers/openai/completions.py +++ b/src/leapfrogai_api/routers/openai/completions.py @@ -10,7 +10,7 @@ CompletionRequest, ) from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as config_global from leapfrogai_api.utils.config import Config import leapfrogai_sdk as lfai @@ -21,7 +21,7 @@ async def complete( session: Session, # pylint: disable=unused-argument # required for authorizing endpoint req: CompletionRequest, - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, Depends(config_global.create)], ): """Complete a prompt with the given model.""" # Get the model backend configuration @@ -29,7 +29,7 @@ async def complete( if model is None: raise HTTPException( status_code=405, - detail=f"Model {req.model} not found. Currently supported models are {list(model_config.models.keys())}", + detail=f"Model {req.model} not found. 
Currently supported models are {model_config}", ) request = lfai.CompletionRequest( diff --git a/src/leapfrogai_api/routers/openai/embeddings.py b/src/leapfrogai_api/routers/openai/embeddings.py index fd7a8c7fc..b890c6113 100644 --- a/src/leapfrogai_api/routers/openai/embeddings.py +++ b/src/leapfrogai_api/routers/openai/embeddings.py @@ -7,9 +7,12 @@ from leapfrogai_api.backend.grpc_client import create_embeddings from leapfrogai_api.backend.types import CreateEmbeddingRequest, CreateEmbeddingResponse from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as global_config from leapfrogai_api.utils.config import Config +import logging + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/openai/v1/embeddings", tags=["openai/embeddings"]) @@ -17,10 +20,10 @@ async def embeddings( session: Session, # pylint: disable=unused-argument # required for authorizing endpoint req: CreateEmbeddingRequest, - model_config: Annotated[Config, Depends(get_model_config)], + model_config: Annotated[Config, Depends(global_config.create)], ) -> CreateEmbeddingResponse: """Create embeddings from the given input.""" - model = model_config.get_model_backend(req.model) + model = model_config.get_model_backend(model=req.model) if model is None: raise HTTPException( status_code=status.HTTP_405_METHOD_NOT_ALLOWED, diff --git a/src/leapfrogai_api/routers/openai/models.py b/src/leapfrogai_api/routers/openai/models.py index c71167d05..a121fd48d 100644 --- a/src/leapfrogai_api/routers/openai/models.py +++ b/src/leapfrogai_api/routers/openai/models.py @@ -4,10 +4,16 @@ from leapfrogai_api.backend.types import ( ModelResponse, ModelResponseModel, + ModelMetadataResponse, ) +from typing import TYPE_CHECKING from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config -from leapfrogai_api.utils.config import Config + +import logging + +logger = logging.getLogger(__file__) +if TYPE_CHECKING: + from leapfrogai_api.utils.config import Config router = APIRouter(prefix="/openai/v1/models", tags=["openai/models"]) @@ -18,8 +24,14 @@ async def models( ) -> ModelResponse: """List all available models.""" res = ModelResponse(data=[]) - model_config: Config = get_model_config() - for model in model_config.models: - m = ModelResponseModel(id=model) + # shared config object from the app + model_config: "Config" = session.app.state.config + + for model_name, model_data in model_config.models.items(): + meta = ModelMetadataResponse(**dict(model_data.metadata)) + m = ModelResponseModel( + id=model_name, + metadata=meta, + ) res.data.append(m) return res diff --git a/src/leapfrogai_api/routers/openai/requests/run_create_params_request.py b/src/leapfrogai_api/routers/openai/requests/run_create_params_request.py index 2ea23555b..03007af97 100644 --- a/src/leapfrogai_api/routers/openai/requests/run_create_params_request.py +++ b/src/leapfrogai_api/routers/openai/requests/run_create_params_request.py @@ -55,7 +55,7 @@ class RunCreateParamsRequest(RunCreateParamsRequestBase): stream: bool | None = Field( default=None, description="If set to true, the response will be streamed as it's generated.", - example=False, + examples=[False], ) async def create_additional_messages(self, session: Session, thread_id: str): diff --git a/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py b/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py 
index ba9d440bb..f43cdef67 100644 --- a/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py +++ b/src/leapfrogai_api/routers/openai/requests/run_create_params_request_base.py @@ -64,7 +64,7 @@ CreateMessageRequest, ) from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.__init__ import config as global_config from leapfrogai_sdk.chat.chat_pb2 import ( ChatCompletionResponse as ProtobufChatCompletionResponse, ) @@ -326,6 +326,7 @@ async def generate_message_for_thread( chat_messages, file_ids = await self.create_chat_messages( session, thread, additional_instructions, tool_resources ) + config = await global_config.create() # Generate a new message and add it to the thread creation request chat_response: ChatCompletionResponse = await chat_complete( @@ -339,7 +340,7 @@ async def generate_message_for_thread( stop=None, max_tokens=self.max_completion_tokens, ), - model_config=get_model_config(), + model_config=config, session=session, ) @@ -405,7 +406,7 @@ async def stream_generate_message_for_thread( stop=None, max_tokens=self.max_completion_tokens, ), - model_config=get_model_config(), + model_config=await global_config.create(), ) ) diff --git a/src/leapfrogai_api/utils/config.py b/src/leapfrogai_api/utils/config.py index 60edfb3e8..49f93d781 100644 --- a/src/leapfrogai_api/utils/config.py +++ b/src/leapfrogai_api/utils/config.py @@ -1,149 +1,523 @@ -import fnmatch -import glob +# src/leapfrogai_api/utils/config.py +from __future__ import annotations +import asyncio import logging import os -from typing import List +import functools +import traceback +from typing import Any, Callable, ClassVar, Generator, Literal, Self +from anyio import Event +# from pathlib import Path +from anyio import Path + +import anyio import toml import yaml from watchfiles import Change, awatch +from dataclasses import dataclass, asdict +from leapfrogai_api.backend.types import Capability, Precision, Modality, Format +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +DEFAULT_CONFIG_FILE: str = "config.yaml" + + +def async_locked(method: Callable) -> Callable: + @functools.wraps(method) + async def wrapper( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + # Acquire the lock if it's not already held + if getattr(self, "_lock", None) is None: + self._lock = anyio.Lock() + try: + # Wrap the method with the lock. + # NOTE: The lock must be released by the task that acquired it + async with self._lock: + return await method(self, *args, **kwargs) + except asyncio.CancelledError: + raise + finally: + # Release the lock regardless + if getattr(self, "_lock", None) is not None: + self._lock = None + + return wrapper + + +@dataclass +class ModelMetadata: + """ + Initializes a ModelMetadata object with the specified model type, dimensions, and precision. + + Parameters: + capabilities (list[Capability], optional): The capabilities of the model e.g. ('embeddings' or 'chat'). + dimensions (Optional[int], optional): Embedding dimensions (for embeddings models). Defaults to None. + precision (str, optional): Model precision (e.g., 'float16', 'float32'). Defaults to 'float32'. + type (Literal["embeddings", "llm"], optional): The type of the model e.g ('embeddings' or 'llm'). 
+ """ + + capabilities: list[Capability] | None = None + dimensions: int | None = None + format: Format | None = None + modalities: list[Modality] | None = None + precision: Precision | None = None + type: Literal["embeddings", "llm"] | None = None + + def has_values(self) -> bool: + """ + Returns True if any of the attributes 'type', 'dimensions', or 'precision' of the object are not None, + and False otherwise. + + :return: bool + """ + # returns true if any public attribute is not None + return any(value is not None for value in asdict(self).values()) + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + """Make it iterable / possible to use `dict(class_instance)""" + yield from asdict(self).items() + + +@dataclass class Model: + """Represents a model in the LeapFrogAI API.""" + name: str backend: str + metadata: ModelMetadata | None = None + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + """Make it iterable / possible to use `dict(class_instance)""" + yield from asdict(self).items() + + +class ConfigFile: + def __init__(self, path: Path): + self.path = path + self.filename = str(path.name) + self.models: dict[str, Model] = {} + self._loaded = False + # https://anyio.readthedocs.io/en/stable/synchronization.html#locks + self._lock: anyio.Lock | None = None + # methods used to parse models from a file + self.parsers: dict[str, Callable] = { + ".toml": toml.loads, + ".yaml": yaml.safe_load, + ".yml": yaml.safe_load, + } + + def parse_models(self, loaded_artifact: dict[str, Any]) -> Self: + # Step 1: Clear models to avoid duplicates for the given config file + self.models.clear() + + # Step 2: Make sure that the config file contains models + models_to_load = loaded_artifact.get("models", []) + if not models_to_load: + logger.error(f"Failed to load and parse config from {self.path}") + return self + + # Step 3: Load and parse models + for m in models_to_load: + model_name = m["name"] + model_config = Model( + name=model_name, + backend=m["backend"], + metadata=ModelMetadata(**m["metadata"]) if m.get("metadata") else None, + ) + self.models[model_name] = model_config + logger.debug(f"Added {model_name} to model config") + logger.debug(f"Successfully loaded and parsed config from {self.path}") + self._loaded = True + return self - def __init__(self, name: str, backend: str, capabilities: List[str] | None = None): - self.name = name - self.backend = backend + async def _load_from_file(self, path: Path) -> dict[str, Any]: + """ + Asynchronously loads the content of a file from the given path and returns it as a dictionary. + + Args: + path (Path): The path to the file to be loaded. + + Returns: + dict[str, Any]: A dictionary containing the content of the file. If the file type is not supported or an error occurs during loading, an empty dictionary is returned. 
+ """ + + try: + async with await path.open("r") as contents: + # If a known file type is found, use the corresponding parser + if (parser := self.parsers.get(path.suffix)) is not None: + logger.debug(f"Loading config file: {path}") + loaded_artifact = parser(await contents.read()) + logger.debug(f"Loaded artifact content: {loaded_artifact}") + return loaded_artifact + + # Else, return an empty dict if the file type is not supported + logger.error(f"Unsupported file type: {path}") + return {} + + except Exception as e: + logger.error(f"Error loading config file {path}: {e}") + return {} # Return an empty dict if there's an error + + @async_locked + async def load_config_file(self) -> None: + logger.debug(f"Loading config file: {self.path}") + try: + if not (loaded_artifact := await self._load_from_file(path=self.path)): + return + self.parse_models(loaded_artifact) + except Exception as e: + logger.error(f"Error loading config file {self.path}: {e}") + + async def aload(self) -> None: + # We make a new lock for each config file to avoid race conditions + async with anyio.Lock(): + if not await self.path.exists(): + logger.error(f"Config file does not exist: {self.path}") + return + await self.load_config_file() + return self + + @async_locked + async def aunload(self) -> None: + self.models.clear() + logger.debug(f"Unloaded config file: {self.path}") + self._loaded = False + + def __await__(self): + # Load the config file on await + return self.aload().__await__() + + def __str__(self) -> str: + return f"Path: {self.path}, Models: {self.models}" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path}, models={self.models})" class Config: - models: dict[str, Model] = {} - config_sources: dict[str, list] = {} - - def __init__( - self, models: dict[str, Model] = {}, config_sources: dict[str, list] = {} - ): - self.models = models - self.config_sources = config_sources - - def __str__(self): - return f"Models: {self.models}" - - async def watch_and_load_configs(self, directory=".", filename="config.yaml"): - # Get the config directory and filename from the environment variables if provided - env_directory = os.environ.get("LFAI_CONFIG_PATH", directory) - if env_directory is not None and env_directory != "": - directory = env_directory - env_filename = os.environ.get("LFAI_CONFIG_FILENAME", filename) - if env_filename is not None and env_filename != "": - filename = env_filename - - # Process all the configs that were already in the directory - self.load_all_configs(directory, filename) - - # Watch the directory for changes until the end of time - while True: - async for changes in awatch(directory, recursive=False, step=50): - # get two unique lists of files that have been (updated files and deleted files) - # (awatch can return duplicates depending on the type of updates that happen) - logger.info("Config changes detected: {}".format(changes)) - unique_new_files = set() - unique_deleted_files = set() - for change in changes: - if change[0] == Change.deleted: - unique_deleted_files.add(os.path.basename(change[1])) - else: - unique_new_files.add(os.path.basename(change[1])) - - # filter the files to those that match the filename or glob pattern - filtered_new_matches = fnmatch.filter(unique_new_files, filename) - filtered_deleted_matches = fnmatch.filter( - unique_deleted_files, filename - ) - - # load all the updated config files - for match in filtered_new_matches: - self.load_config_file(directory, match) - - # remove deleted models - for match in 
filtered_deleted_matches: - self.remove_model_by_config(match) - - async def clear_all_models(self): - # reset the model config on shutdown (so old model configs don't get cached) - self.models = {} - self.config_sources = {} - logger.info("All models have been removed") - - def load_config_file(self, directory: str, config_file: str): - logger.info("Loading config file: {}/{}".format(directory, config_file)) - - # load the config file into the config object - config_path = os.path.join(directory, config_file) - with open(config_path) as c: - # Load the file into a python object - loaded_artifact = {} - if config_path.endswith(".toml"): - loaded_artifact = toml.load(c) - elif config_path.endswith(".yaml"): - loaded_artifact = yaml.safe_load(c) + """Configuration class for the Leapfrog AI API. + + This class is used to dynamically load and manage the configuration files for the Leapfrog AI API. + """ + + _instance: ClassVar["Config | None"] = None + _watch_task: ClassVar[asyncio.Task | None] = None + # https://anyio.readthedocs.io/en/latest/synchronization.html#events + _stop_event: ClassVar[asyncio.Event] = Event() + _testing: ClassVar[bool] = False + + def __new__(cls): + """This method is used to ensure that only one instance of the Config class is created.""" + if cls._instance is None: + cls._instance = super(Config, cls).__new__(cls) + return cls._instance + + def __init__(self) -> None: + """Only one instance of the Config class can be created.""" + if not hasattr(self, "initialized"): + self.config_files: dict[str, ConfigFile] = {} + self.models: dict[str, Model] = {} + self._config_dir = None + self._config_filename = None + self._lock = asyncio.Lock() + self.initialized = True + + @classmethod + async def create(cls, testing: bool = False) -> "Config": + """ + Creates a new instance of the Config class. + + Args: + testing (bool): A flag indicating if the Config instance is being created for testing purposes. + Defaults to False. + + Returns: + Config: The created instance of the Config class. + """ + logger.debug("Entering Config.create") + if cls._instance is None: + cls._instance = cls() + cls._testing = testing + await cls._instance.initialize() + return cls._instance + + async def initialize(self, testing: bool = False) -> Self: + logger.debug("Initializing Config") + try: + self._initialize_from_env() + await self.load_all_configs() + await self.start_watching(testing) + + if not hasattr(self, "_initialized") or self._initialized is None: + self._initialized = True + logger.debug("Config initialized successfully") else: - # TODO: Return an error ??? - logger.error(f"Unsupported file type: {config_path}") - return + logger.debug("Existing Config found, skipping initialization") - # parse the object into our config - self.parse_models(loaded_artifact, config_file) - logger.info("loaded artifact at {}".format(config_path)) + except Exception as e: + logger.error(f"Error during Config initialization: {e}") + raise + return self - return + def _initialize_from_env(self): + """Updates the config directory and filename from environment variables. - def load_all_configs(self, directory="", filename="config.yaml"): - logger.info( - "Loading all configs in {} that match the name '{}'".format( - directory, filename - ) + NOTE: At present this gets called as part of `initialize`. As such, it means that + Any iterations of `watch_for_changes` will use the environment variables, even if they are updated + AFTER the class has been instantiated. 
This behavior is intentional, but only for testing purposes. + """ + logger.debug("Initializing from environment") + self._config_dir = os.environ.get("LFAI_CONFIG_PATH", ".") + self._config_filename = os.environ.get("LFAI_CONFIG_FILENAME", "*config.yaml") + + async def _load_config_file(self, path: Path) -> Self: + config_file = ConfigFile(path=path) + await config_file + self.config_files[path.name] = config_file + self.models.update(config_file.models) + return self + + async def load_all_configs(self) -> None: + logger.debug( + f"Loading all configs in {self._config_dir} matching {self._config_filename}" + ) + try: + path = Path(self._config_dir) + config_files = path.glob(self._config_filename) + async for config_path in config_files: + logger.debug(f"Loading config file: {config_path}") + await self._load_config_file(path=config_path) + logger.debug(f"Loaded configs: {list(self.config_files.keys())}") + logger.debug(f"Current models: {list(self.models.keys())}") + except Exception as e: + logger.error(f"Error loading configs: {e}") + raise e + + async def start_watching(self, testing: bool = False): + if self._watch_task is not None: + if self._watch_task.done(): + self._watch_task = None + else: + logger.warning("Watch task is already running") + return self._watch_task + + logger.debug("Starting config watcher") + self._watch_task = asyncio.create_task( + self._watch_wrapper(testing=testing), + name="Config Watcher Wrapper Worker", + ) + logger.debug("Started watching for config changes") + return self._watch_task + + @classmethod + async def stop_watching(cls) -> None: + """ + Stops the configuration watcher. + + This method stops the configuration watcher by setting the stop event and + waiting for the watch task to finish. If the watch task does not finish within + the specified timeout, a warning is logged. If the watch task is cancelled or + an error occurs while stopping the watch task, an error is logged and the + exception is re-raised. + + This method is idempotent. If the watcher is already stopped, it is ignored. + + Raises: + Exception: If an error occurs while stopping the watch task. + """ + if cls._watch_task is None: + logger.warning("No watch task is running") + return + + logger.info("Stopping config watcher") + cls._stop_event.set() + + try: + await asyncio.wait_for(cls._watch_task, timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Timeout while waiting for watch task to finish") + except asyncio.CancelledError: + logger.debug("Watch task was cancelled") + except Exception as e: + logger.error(f"Error while stopping watch task: {e}") + raise e + finally: + if cls._watch_task and not cls._watch_task.done(): + logger.debug("Force cancelling the watch task") + cls._watch_task.cancel() + try: + await cls._watch_task + except asyncio.CancelledError: + pass + + cls._watch_task = None + cls._stop_event.clear() + logger.info("Stopped watching for config changes") + + async def _watch_wrapper(self, testing: bool) -> Self: + """ + Executes the watch task for configuration changes. + + This function is responsible for running the watch task for configuration changes. It calls the `watch_for_changes` method with the provided `testing` flag. If the task is cancelled, it logs a message. If an exception occurs during the execution, it logs an error message and re-raises the exception. Finally, it ensures that the `cleanup` method is called after the task ends. + + Args: + testing (bool): A flag indicating if the watch task is running in a testing environment. 
+ + Returns: + self: The instance of the class. + + Raises: + Exception: If an error occurs during the watch task execution. + """ + try: + await self.watch_for_changes(testing=testing) + except asyncio.CancelledError: + logger.info("Watch task was cancelled") + except Exception: + logger.error(f"Error in watch task: {traceback.format_exc()}") + raise + finally: + logger.debug("Cleaning up watch task") + await self.cleanup() # Ensure cleanup is called after the watch ends + return self + + async def watch_for_changes(self, testing: bool = False) -> None: + """ + Watches for changes in the configuration directory. + + Args: + testing (bool): A flag indicating if the watch task is running in a testing environment. Defaults to False. + + Returns: + None + """ + if self._watch_task is None: + logger.warning("No watch task is running") + return + + logger.debug("Watching for changes") + + try: + async for changes in awatch( + self._config_dir, + recursive=False, + step=50, # Normal interval + stop_event=self.__class__._stop_event, + debug=False, + ): + logger.debug(f"Detected changes: {changes}") + await self.initialize() + await self._handle_config_changes(changes) + except asyncio.CancelledError as e: + logger.warning("Watch for changes task was cancelled") + raise e + finally: + logger.debug("Finished watch_for_changes") + + async def _handle_config_changes( + self, + changes: list[tuple[Change, str]], + ) -> None: + """ + Handles the detected changes in the configuration files. + + Args: + changes (list[tuple[Change, str]]): A list of tuples representing the changes detected. Each tuple contains the type of change (Change.added, Change.modified, or Change.deleted) and the file path. + + Returns: + None + + This function iterates over the list of changes and performs the necessary actions based on the type of change. If a change is detected in a configuration file, it updates or removes the corresponding config file and model from the internal state. If a change is detected in a non-configuration file, it is ignored. + + After handling the changes, the function logs the updated state of the config files and models. + + Note: This function assumes that the configuration files and models are stored in the `config_files` and `models` attributes of the object, respectively. + """ + logger.info(f"Detected changes: {changes}") + for change_type, file_path in changes: + path = Path(file_path) + if path.match(self._config_filename): + if change_type in (Change.added, Change.modified): + logger.info(f"Adding or updating config file: {path}") + await self._load_config_file(path=path) + elif change_type == Change.deleted: + logger.info(f"Removing config file from handler: {path}") + if config_file := self.config_files.pop(path.name, None): + for model_name in config_file.models: + self.models.pop(model_name, None) + else: + logger.debug(f"Ignoring change to non-config file: {path}") + + logger.debug( + f"Updated state - Config files: {list(self.config_files.keys())}, Models: {list(self.models.keys())}" ) - if not os.path.exists(directory): - logger.error("The config directory ({}) does not exist".format(directory)) - return "THE CONFIG DIRECTORY DOES NOT EXIST" + @classmethod + async def cleanup(cls): + """ + Cleans up the Config instance by stopping the watch task and clearing all models. 
- # Get all config files and load them into the config object - config_files = glob.glob(os.path.join(directory, filename)) - for config_path in config_files: - dir_path, file_path = os.path.split(config_path) - self.load_config_file(directory=dir_path, config_file=file_path) + This method is a class method and is used to clean up the Config instance when it is no longer needed. - return + Args: + None + + Returns: + None + """ + await cls.stop_watching() + if cls._instance: + await cls._instance.clear_all_models() + cls._instance = None + logger.debug("Config instance cleanup complete") def get_model_backend(self, model: str) -> Model | None: - if model in self.models: - return self.models[model] - else: - return None - - def parse_models(self, loaded_artifact, config_file): - for m in loaded_artifact["models"]: - model_config = Model(name=m["name"], backend=m["backend"]) - - self.models[m["name"]] = model_config - try: - self.config_sources[config_file].append(m["name"]) - except KeyError: - self.config_sources[config_file] = [m["name"]] - logger.info("added {} to model config".format(m["name"])) - - def remove_model_by_config(self, config_file): - for model_name in self.config_sources[config_file]: - self.models.pop(model_name) - logger.info("removed {} from model config".format(model_name)) - - # clear config once all corresponding models are deleted - self.config_sources.pop(config_file) + """Get the backend for a model.""" + return self.models.get(model) + + async def clear_all_models(self) -> None: + """Clear all models.""" + logger.debug("Clearing all models") + self.models.clear() + for config_file in list(self.config_files.values()): + logger.debug(f"Removing config file: {config_file.filename}") + await config_file.aunload() + self.config_files.clear() + logger.debug("All models have been removed") + + def to_dict(self) -> dict[str, Any]: + """The method used to serialize the Config instance to a dictionary.""" + models_dict = { + # TODO: Make this dynamically generated / structured + name: { + "name": name, + "backend": model.backend, + "metadata": model.metadata.dict() if model.metadata else None, + } + for name, model in self.models.items() + } + + config_sources = { + config_file.filename: list(config_file.models.keys()) + for config_file in self.config_files.values() + } + + return { + "config_sources": config_sources, + "models": models_dict, + } + + def __del__(self): + if self.__class__._watch_task: + self.__class__._watch_task.cancel() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.to_dict()})" diff --git a/src/leapfrogai_ui/supabase/seed.sql b/src/leapfrogai_ui/supabase/seed.sql index 77c42ed92..c6787c15e 100644 --- a/src/leapfrogai_ui/supabase/seed.sql +++ b/src/leapfrogai_ui/supabase/seed.sql @@ -65,4 +65,3 @@ INSERT INTO from auth.users ); - diff --git a/tests/data/test.txt b/tests/data/test.txt deleted file mode 100644 index 0a9012568..000000000 --- a/tests/data/test.txt +++ /dev/null @@ -1 +0,0 @@ -Testing \ No newline at end of file diff --git a/tests/pytest/leapfrogai_api/fixtures/repeater-test-config.yaml b/tests/pytest/leapfrogai_api/fixtures/repeater-test-config.yaml index c866f1bce..abda621b3 100644 --- a/tests/pytest/leapfrogai_api/fixtures/repeater-test-config.yaml +++ b/tests/pytest/leapfrogai_api/fixtures/repeater-test-config.yaml @@ -1,3 +1,3 @@ models: -- name: repeater - backend: localhost:50051 + - name: repeater + backend: 0.0.0.0:50051 diff --git a/tests/pytest/leapfrogai_api/test_api.py 
b/tests/pytest/leapfrogai_api/test_api.py index a80df6b6c..7687568b9 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -1,20 +1,33 @@ +import asyncio +from dataclasses import dataclass, field import json +import logging + import os import shutil -import time -from typing import Optional + +from typing import Any, TYPE_CHECKING + +from leapfrogai_api.utils.config import Config import pytest +import pytest_asyncio + from fastapi.applications import BaseHTTPMiddleware from fastapi.security import HTTPBearer from fastapi.testclient import TestClient from starlette.middleware.base import _CachedRequest from supabase import ClientOptions import leapfrogai_api.backend.types as lfai_types -from leapfrogai_api.main import app + +from leapfrogai_api.main import create_app from leapfrogai_api.routers.supabase_session import init_supabase_client +if TYPE_CHECKING: + from fastapi import FastAPI + security = HTTPBearer() +logger = logging.getLogger(__name__) # Set environment variables that the TestClient will use LFAI_CONFIG_FILENAME = os.environ["LFAI_CONFIG_FILENAME"] = "repeater-test-config.yaml" @@ -22,33 +35,56 @@ os.path.dirname(__file__), "fixtures" ) LFAI_CONFIG_FILEPATH = os.path.join(LFAI_CONFIG_PATH, LFAI_CONFIG_FILENAME) +# Set reusable variables for test runs +NO_MODEL_METADATA: dict[str, Any] = dict( + models=dict( + repeater=dict( + backend="0.0.0.0:50051", + name="repeater", + metadata=None, + ) + ) +) +EMPTY_CONFIG: dict[str, Any] = dict( + config_sources={}, + models={}, +) +REQUEST_URI = "/leapfrogai/v1/models" + +# Set pytest markers / fixtures etc. +SKIP_IF_NO_REPEATER_ENV_VAR = pytest.mark.skipif( + os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", + reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", +) + + +@pytest.fixture(autouse=True) +def anyio_backend(): + """This is necessary to prevent `watchfiles` from keeping an open thread with anyio""" + return "asyncio" ######################### ######################### -class AsyncClient: - """Supabase client class.""" +async def mock_init_supabase_client(): + """Returns a mocked supabase client""" - def __init__( - self, - supabase_url: str, - supabase_key: str, - access_token: Optional[str] = None, - options: ClientOptions = ClientOptions(), - ): - self.supabase_url = supabase_url - self.supabase_key = supabase_key - self.access_token = access_token - self.options = options + @dataclass + class AsyncClient: + """Supabase client class.""" + supabase_url: str = "" + supabase_key: str = "" + access_token: str | None = None + options: ClientOptions = field(default_factory=ClientOptions) -async def mock_init_supabase_client() -> AsyncClient: - return AsyncClient("", "", "", ClientOptions()) + return AsyncClient() async def pack_dummy_bearer_token(request: _CachedRequest, call_next): + """Creates a callable that adds a dummy bearer token to the request header""" request.headers._list.append( ( "authorization".encode(), @@ -59,58 +95,128 @@ async def pack_dummy_bearer_token(request: _CachedRequest, call_next): @pytest.fixture -def dummy_auth_middleware(): +def auth_client(): + """Creates a client with dummy auth middleware configured""" + app = create_app(testing=True) app.dependency_overrides[init_supabase_client] = mock_init_supabase_client app.user_middleware.clear() app.middleware_stack = None app.add_middleware(BaseHTTPMiddleware, dispatch=pack_dummy_bearer_token) app.middleware_stack = app.build_middleware_stack() + with TestClient(app) as client: + yield client + + 
+@pytest_asyncio.fixture(scope="function") +async def test_app_factory(monkeypatch): + """Factory fixture for creating an app and config.""" + + # NOTE: This primarily existst to make it easy to override env vars / lifespan + async def _create_app( + config_path: str | None = None, + config_filename: str | None = None, + ) -> tuple["FastAPI", "Config"]: + if config_path is None: + config_path = os.environ.get("LFAI_CONFIG_PATH", LFAI_CONFIG_PATH) + if config_filename is None: + config_filename = os.environ.get( + "LFAI_CONFIG_FILENAME", LFAI_CONFIG_FILENAME + ) + monkeypatch.setenv("LFAI_CONFIG_PATH", config_path) + monkeypatch.setenv("LFAI_CONFIG_FILENAME", config_filename) + config = await Config.create(testing=True) + app = create_app(testing=True, lifespan=None) + return app, config + + try: + yield _create_app + finally: + pass + + +@pytest.mark.anyio +async def test_config_load(test_app_factory): + """Test that the config is loaded correctly.""" + config_path = LFAI_CONFIG_PATH + model_name = "repeater" + + app, _ = await test_app_factory( + config_path=config_path, + config_filename=LFAI_CONFIG_FILENAME, + ) + with TestClient(app=app) as client: + response = client.get(REQUEST_URI) -def test_config_load(): - """Test that the config is loaded correctly.""" - with TestClient(app) as client: - response = client.get("/leapfrogai/v1/models") + result = response.json() + assert response.status_code == 200, response.json() - assert response.status_code == 200 - assert response.json() == { - "config_sources": {"repeater-test-config.yaml": ["repeater"]}, - "models": {"repeater": {"backend": "localhost:50051", "name": "repeater"}}, - } + expected_response: dict[str, dict[str, Any]] = dict( + config_sources={LFAI_CONFIG_FILENAME: [model_name]}, + **NO_MODEL_METADATA, + ) + assert ( + expected_response == result + ), f"Assertions failed due to {expected_response} != {result}" -def test_config_delete(tmp_path): +@pytest.mark.anyio +async def test_config_delete(test_app_factory, tmp_path): """Test that the config is deleted correctly.""" - # move repeater-test-config.yaml to temp dir so that we can remove it at a later step + + # Step 1: Copy the config file to the temporary directory tmp_config_filepath = shutil.copyfile( - LFAI_CONFIG_FILEPATH, os.path.join(tmp_path, LFAI_CONFIG_FILENAME) + LFAI_CONFIG_FILEPATH, + tmp_path / LFAI_CONFIG_FILENAME, + ) + app, _ = await test_app_factory( + config_path=str(tmp_path), + config_filename=LFAI_CONFIG_FILENAME, ) - os.environ["LFAI_CONFIG_PATH"] = str(tmp_path) + model_name = "repeater" + expected_response = { + "config_sources": {LFAI_CONFIG_FILENAME: [model_name]}, + "models": { + model_name: { + "backend": "0.0.0.0:50051", + "name": model_name, + "metadata": None, + } + }, + } - with TestClient(app) as client: - # ensure the API loads the temp config - response = client.get("/leapfrogai/v1/models") - assert response.status_code == 200 - - assert response.json() == { - "config_sources": {"repeater-test-config.yaml": ["repeater"]}, - "models": {"repeater": {"backend": "localhost:50051", "name": "repeater"}}, - } - # delete source config from temp dir + with TestClient(app=app) as client: + # Step 2: Ensure the API loads the temp config + response = client.get(REQUEST_URI) + result = response.json() + assert response.status_code == 200, response.json() + assert ( + expected_response == result + ), f"Assertions failed due to {expected_response} != {result}" + + # Step 3: Delete the source config file from temp dir os.remove(tmp_config_filepath) + 
logger.debug(f"Deleted config file: {tmp_config_filepath}") + + # Step 4: Await a context switch to allow the API to detect the change. + await asyncio.sleep(0.1) - # wait for the api to be able to detect the change - time.sleep(0.5) - # assert response is now empty - response = client.get("/leapfrogai/v1/models") - assert response.status_code == 200 - assert response.json() == {"config_sources": {}, "models": {}} + # Step 5: Make another request that should have no models loaded + response = client.get(REQUEST_URI) + logger.debug(f"Received response after deletion: {response}") + assert response.status_code == 200, response.json() - os.environ["LFAI_CONFIG_PATH"] = os.path.join(os.path.dirname(__file__), "fixtures") + assert response.json() == dict( + config_sources={}, + models={}, + ) -def test_routes(): +@pytest.mark.anyio +async def test_routes(test_app_factory): """Test that the expected routes are present.""" + app, _ = await test_app_factory() + expected_routes = { "/docs": ["GET", "HEAD"], "/healthz": ["GET"], @@ -168,165 +274,150 @@ def test_routes(): ["DELETE"], ), ] - - actual_routes = app.routes - for route in actual_routes: - if hasattr(route, "path") and route.path in expected_routes: - assert route.methods == set(expected_routes[route.path]) - del expected_routes[route.path] - - for route, name, methods in openai_routes: - found = False - for actual_route in actual_routes: - if ( - hasattr(actual_route, "path") - and actual_route.path == route - and actual_route.name == name - ): - assert actual_route.methods == set(methods) - found = True - break - assert found, f"Missing route: {route}, {name}, {methods}" - - assert len(expected_routes) == 0 + # test the expected routes + for path, methods in expected_routes.items(): + route = next( # iterate through the routes to find the route with the expected path, if any + (r for r in app.routes if getattr(r, "path", None) == path), + None, + ) + assert route is not None, f"Route {path} not found." + assert route.methods == set(methods), f"Methods for {path} do not match." + + for path, name, methods in openai_routes: + route = next( + ( + r + for r in app.routes + if getattr(r, "path", None) == path and r.name == name + ), + None, + ) + assert route is not None, f"Route {path} with name {name} not found." + assert route.methods == set( + methods + ), f"Methods for {path} with name {name} do not match." 
def test_healthz(): """Test the healthz endpoint.""" + + app = create_app(testing=True) with TestClient(app) as client: response = client.get("/healthz") - assert response.status_code == 200 - assert response.json() == {"status": "ok"} + assert response.status_code == 200, response.json() + assert response.json() == {"status": "ok"} -@pytest.mark.skipif( - os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", - reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", -) -def test_embedding(dummy_auth_middleware): + +@SKIP_IF_NO_REPEATER_ENV_VAR +def test_embedding(auth_client): """Test the embedding endpoint.""" - expected_embedding = [0.0 for _ in range(10)] - with TestClient(app) as client: - # Send request to client - embedding_request = lfai_types.CreateEmbeddingRequest( - model="repeater", - input="This is the embedding input text.", - ) - response = client.post( - "/openai/v1/embeddings", json=embedding_request.model_dump() - ) - assert response.status_code == 200 + # Send request to client + embedding_request = lfai_types.CreateEmbeddingRequest( + model="repeater", + input="This is the embedding input text.", + ) + response = auth_client.post( + "/openai/v1/embeddings", + json=embedding_request.model_dump(), + ) + response_obj = response.json() + assert response.status_code == 200, response_obj - # parse through the response - response_obj = response.json() - assert "data" in response_obj - assert len(response_obj.get("data")) == 1 + # parse through the response + assert (data := response_obj.get("data")) is not None, response_obj + assert len(data) == 1 - # validate the expected response - data_obj = response_obj.get("data")[0] - assert "embedding" in data_obj - assert data_obj.get("embedding") == expected_embedding + # validate the expected response + data_obj = data[0] # type: ignore + assert "embedding" in data_obj + assert data_obj.get("embedding") == ([0.0] * 10) # list of 10 floats -@pytest.mark.skipif( - os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", - reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", -) -def test_chat_completion(dummy_auth_middleware): +@SKIP_IF_NO_REPEATER_ENV_VAR +def test_chat_completion(auth_client): """Test the chat completion endpoint.""" - with TestClient(app) as client: - input_content = "this is the chat completion input." 
- chat_completion_request = lfai_types.ChatCompletionRequest( - model="repeater", - messages=[lfai_types.ChatMessage(role="user", content=input_content)], - ) - response = client.post( - "/openai/v1/chat/completions", json=chat_completion_request.model_dump() - ) - assert response.status_code == 200 - - assert response - - # parse through the chat completion response - response_obj = response.json() - assert "choices" in response_obj - - # parse the choices from the response - response_choices = response_obj.get("choices") - assert len(response_choices) == 1 - assert "message" in response_choices[0] - assert "content" in response_choices[0].get("message") - - # parse finish reason - assert "finish_reason" in response_choices[0] - assert "stop" == response_choices[0].get("finish_reason") - - # parse usage data - response_usage = response_obj.get("usage") - prompt_tokens = response_usage.get("prompt_tokens") - completion_tokens = response_usage.get("completion_tokens") - total_tokens = response_usage.get("total_tokens") - assert prompt_tokens == len(input_content) - assert completion_tokens == len(input_content) - assert total_tokens == len(input_content) * 2 - - # validate that the repeater repeated - assert response_choices[0].get("message").get("content") == input_content - - -@pytest.mark.skipif( - os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", - reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", -) -def test_stream_chat_completion(dummy_auth_middleware): + input_content = "this is the chat completion input." + chat_completion_request = lfai_types.ChatCompletionRequest( + model="repeater", + messages=[lfai_types.ChatMessage(role="user", content=input_content)], + ) + response = auth_client.post( + "/openai/v1/chat/completions", + json=chat_completion_request.model_dump(), + ) + response_obj = response.json() + assert response.status_code == 200, response_obj + + # parse through the chat completion response + response_choices: list[dict[str, Any]] = response_obj.get("choices") + assert response_choices is not None, response_obj + assert len(response_choices) == 1 + first_choice = response_choices[0] + + assert (response_message := first_choice.get("message")) is not None + assert (response_content := response_message.get("content")) is not None + assert first_choice.get("finish_reason") == "stop", first_choice + + # parse usage data + response_usage = response_obj.get("usage") + prompt_tokens = response_usage.get("prompt_tokens") + completion_tokens = response_usage.get("completion_tokens") + total_tokens = response_usage.get("total_tokens") + assert prompt_tokens == len(input_content) + assert completion_tokens == len(input_content) + assert total_tokens == len(input_content) * 2 + + # validate that the repeater repeated + assert response_content == input_content + + +@SKIP_IF_NO_REPEATER_ENV_VAR +def test_stream_chat_completion(auth_client): """Test the stream chat completion endpoint.""" - with TestClient(app) as client: - input_content = "this is the stream chat completion input." + input_content = "this is the stream chat completion input." 
-        chat_completion_request = lfai_types.ChatCompletionRequest(
-            model="repeater",
-            messages=[lfai_types.ChatMessage(role="user", content=input_content)],
-            stream=True,
-        )
-
-        response = client.post(
-            "/openai/v1/chat/completions", json=chat_completion_request.model_dump()
-        )
-        assert response.status_code == 200
-        assert (
-            response.headers.get("content-type") == "text/event-stream; charset=utf-8"
-        )
+    chat_completion_request = lfai_types.ChatCompletionRequest(
+        model="repeater",
+        messages=[lfai_types.ChatMessage(role="user", content=input_content)],
+        stream=True,
+    )

-        # parse through the streamed response
-        iter_length = 0
-        iter_lines = response.iter_lines()
-        for line in iter_lines:
-            # skip the empty, and non-data lines
-            if ": " in line:
-                strings = line.split(": ", 1)
-
-                # Process all the data responses that is not the sig_stop signal
-                if strings[0] == "data" and strings[1] != "[DONE]":
-                    stream_response = json.loads(strings[1])
-                    assert "choices" in stream_response
-                    choices = stream_response.get("choices")
-                    assert len(choices) == 1
-                    assert "delta" in choices[0]
-                    assert "content" in choices[0].get("delta")
-                    assert choices[0].get("delta").get("content") == input_content
-                    iter_length += 1
-                    # parse finish reason
-                    assert "finish_reason" in choices[0]
-                    assert "stop" == choices[0].get("finish_reason")
-                    # parse usage data
-                    response_usage = stream_response.get("usage")
-                    prompt_tokens = response_usage.get("prompt_tokens")
-                    completion_tokens = response_usage.get("completion_tokens")
-                    total_tokens = response_usage.get("total_tokens")
-                    assert prompt_tokens == len(input_content)
-                    assert completion_tokens == len(input_content)
-                    assert total_tokens == len(input_content) * 2
-
-        # The repeater only response with 5 messages
-        assert iter_length == 5
+    response = auth_client.post(
+        "/openai/v1/chat/completions",
+        json=chat_completion_request.model_dump(),
+    )
+    assert response.status_code == 200, response.json()
+
+    assert response.headers.get("content-type") == "text/event-stream; charset=utf-8"
+
+    # parse through the streamed response
+    iter_length = 0
+    iter_lines = response.iter_lines()
+    for line in iter_lines:
+        # skip the empty and non-data lines
+        if ": " in line:
+            # split the streamed line into its key and content
+            key, content = line.split(": ", 1)
+            # Process all the data responses that are not the sig_stop signal
+            if key == "data" and content != "[DONE]":
+                # Check the content of the response
+                stream_response = json.loads(content)
+                assert (choices := stream_response.get("choices")) is not None
+                assert len(choices) == 1
+                first_choice = choices[0]
+                response_usage = stream_response.get("usage")
+
+                # Check the content of the "first choice"
+                assert first_choice.get("delta", {}).get("content") == input_content
+                iter_length += 1
+
+                # parse finish reason
+                assert "stop" == first_choice.get("finish_reason", None), first_choice
+                assert response_usage.get("prompt_tokens", 0) == len(input_content)
+                assert response_usage.get("completion_tokens", 0) == len(input_content)
+                assert response_usage.get("total_tokens", 0) == (len(input_content) * 2)
+
+    # The repeater only responds with 5 messages
+    assert iter_length == 5
diff --git a/tests/unit/leapfrogai_api/utils/conftest.py b/tests/unit/leapfrogai_api/utils/conftest.py
new file mode 100644
index 000000000..5c6ee54fd
--- /dev/null
+++ b/tests/unit/leapfrogai_api/utils/conftest.py
@@ -0,0 +1,64 @@
+import os
+import pytest
+import pytest_asyncio
+from anyio import Path
+from typing import AsyncGenerator, TypeAlias, 
Callable
+from leapfrogai_api.utils.config import Config
+
+TOML_CONFIG_FILE: str = "test_config.toml"
+YAML_CONFIG_FILE: str = "test_config.yaml"
+INVALID_CONFIG_FILE: str = "invalid_config.fake"
+NON_EXISTENT_DIR: str = "/path/to/non/existent/directory"
+
+ConfigMaker: TypeAlias = Callable[[str, str], Config]
+
+
+@pytest_asyncio.fixture(scope="function")
+async def config_files() -> dict[str, Path]:
+    """Fixture returning resolved paths to the test config files."""
+    # NOTE: this is an async `anyio.Path` object that is _mostly_ compatible with `pathlib.Path`
+    test_dir = await Path(str(__file__)).resolve()
+    dir_name = Path(
+        os.path.dirname(test_dir)
+    )  # directory the file lives in, e.g. `/tests/unit/leapfrogai_api/utils`
+    return {
+        YAML_CONFIG_FILE: await (dir_name / YAML_CONFIG_FILE).resolve(),
+        TOML_CONFIG_FILE: await (dir_name / TOML_CONFIG_FILE).resolve(),
+        INVALID_CONFIG_FILE: await (dir_name / INVALID_CONFIG_FILE).resolve(),
+    }
+
+
+@pytest_asyncio.fixture
+async def config_factory(monkeypatch) -> AsyncGenerator[ConfigMaker, None]:
+    """Instantiate a Config object while overriding the env vars that dictate which config folders/files to use."""
+
+    async def _create_config(
+        config_path: str | None = None,
+        config_filename: str | None = None,
+    ) -> Config:
+        # Fall back to the env vars when arguments are not provided
+        config_path = config_path or os.environ.get("LFAI_CONFIG_PATH")
+        config_filename = config_filename or os.environ.get("LFAI_CONFIG_FILENAME")
+
+        # If either is set, patch the env vars
+        if config_path is not None:
+            monkeypatch.setenv("LFAI_CONFIG_PATH", config_path)
+        if config_filename is not None:
+            monkeypatch.setenv("LFAI_CONFIG_FILENAME", config_filename)
+        # Instantiate the config now that we have patched our environment variables
+        config = await Config.create(testing=True)
+        return config
+
+    yield _create_config
+
+
+@pytest.fixture(autouse=True)
+def anyio_backend():
+    """This is necessary to prevent `watchfiles` from keeping an open thread with anyio"""
+    return "asyncio"
+
+
+@pytest_asyncio.fixture
+def parent_dir():
+    """Return the parent directory of the current file."""
+    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
diff --git a/tests/unit/leapfrogai_api/utils/invalid_config.fake b/tests/unit/leapfrogai_api/utils/invalid_config.fake
new file mode 100644
index 000000000..4b05e443c
--- /dev/null
+++ b/tests/unit/leapfrogai_api/utils/invalid_config.fake
@@ -0,0 +1 @@
+# This file exists to verify that files with unsupported ("invalid") extensions are skipped
diff --git a/tests/unit/leapfrogai_api/utils/test_config.py b/tests/unit/leapfrogai_api/utils/test_config.py
new file mode 100644
index 000000000..31b552ede
--- /dev/null
+++ b/tests/unit/leapfrogai_api/utils/test_config.py
@@ -0,0 +1,373 @@
+from typing import AsyncGenerator, TypeAlias, Callable
+import pytest
+import asyncio
+
+from anyio import Path
+import pytest_asyncio
+from unittest.mock import patch, AsyncMock, MagicMock
+import toml
+from watchfiles import Change
+
+from leapfrogai_api.utils.config import (
+    Config,
+    ConfigFile,
+    Model,
+)
+
+# This is just to make the IDE happy about what the factory does
+ConfigFileMaker: TypeAlias = Callable[[str], ConfigFile]
+
+
+TOML_CONFIG_FILE: str = "test_config.toml"
+YAML_CONFIG_FILE: str = "test_config.yaml"
+INVALID_CONFIG_FILE: str = "invalid_config.fake"
+NON_EXISTENT_DIR: str = "/path/to/non/existent/directory"
+
+
+@pytest.fixture
+def config_file():
+    return ConfigFile(Path(YAML_CONFIG_FILE))
+
+
+@pytest_asyncio.fixture
+async def 
config_file_factory(monkeypatch) -> AsyncGenerator[ConfigFile, None]:
+    async def _create_config_file(path: str, mock_load: bool = False) -> ConfigFile:
+        config_file = ConfigFile(path=Path(path))
+
+        if mock_load:
+            # Mock the aload method
+            config_file.aload = AsyncMock()
+
+            # Simulate a successful load and pre-populate the parsed models
+            config_file.aload.return_value = None  # a successful load returns nothing
+            config_file.models = {
+                "test_model": Model(name="test_model", backend="test_backend")
+            }
+
+        return config_file
+
+    yield _create_config_file
+
+
+@pytest.mark.anyio
+async def test_config_singleton():
+    """Test that Config is a singleton"""
+    config1 = await Config.create()
+    config2 = await Config.create()
+    assert config1 is config2
+
+
+@pytest.mark.anyio
+async def test_initialize_from_env(config_factory):
+    """Test that the config initializes from the env vars we specify and does not error even when the path is fake"""
+    config_path = "/test/path"
+    config_filename = "test*.yaml"
+
+    # Create the config object via the factory, which patches the env vars
+    config = await config_factory(
+        config_path=config_path,
+        config_filename=config_filename,
+    )
+    assert config._config_dir == "/test/path"
+    assert config._config_filename == "test*.yaml"
+    await config.cleanup()
+
+
+@pytest.mark.anyio
+async def test_config_load_config_file(config_factory, config_file_factory):
+    """Test that a single config file can be loaded into the Config object"""
+    config_filename = "test_config.yaml"
+    config: Config = await config_factory(
+        config_path="/test/path",
+        config_filename="test_config.yaml",
+    )
+    # Create a real ConfigFile instance with a mocked `aload` method
+    config_file: ConfigFile = await config_file_factory(
+        path=config_filename, mock_load=True
+    )
+
+    with patch(
+        "leapfrogai_api.utils.config.ConfigFile",  # Patching the class constructor to return our instance
+        return_value=config_file,
+    ):
+        await config._load_config_file(path=Path(config_filename))
+
+    # Verify that the config file was loaded into the config object
+    test_config = config.config_files.get(config_filename, None)
+    assert test_config is not None, f"{config_filename} not loaded: {test_config}"
+
+    # Verify that the model was loaded into the config's models
+    test_model = config.models.get("test_model", None)
+    assert test_model is not None, f"test_model not loaded: {test_model}"
+    assert test_model.name == "test_model"
+    assert test_model.backend == "test_backend"
+
+
+@pytest.mark.anyio
+async def test_load_all_configs(config_factory, parent_dir):
+    config = await config_factory()
+
+    mock_glob = MagicMock()
+    mock_glob.__aiter__.return_value = iter(
+        [
+            await (Path(parent_dir) / "test-config.yaml").resolve(),
+            await (Path(parent_dir) / "test-config.toml").resolve(),
+        ]
+    )
+
+    with patch("leapfrogai_api.utils.config.Path.glob", return_value=mock_glob):
+        await config.load_all_configs()
+
+    assert mock_glob.__aiter__.called
+
+
+@pytest.mark.anyio
+async def test_watch_for_changes(config_factory, parent_dir):
+    config = await config_factory(config_path=parent_dir, config_filename="*.yaml")
+
+    # Create an async generator for mocking awatch
+    awatch_response = [(Change.added, "test-config.yaml")]
+
+    async def mock_awatch_generator():
+        yield awatch_response
+        # Add a small delay to allow other coroutines to run
+        await asyncio.sleep(0.1)
+
+    with (
+        patch(
+            "leapfrogai_api.utils.config.awatch",
+            return_value=mock_awatch_generator(),
+        ),
+        patch.object(config, 
"initialize", new_callable=AsyncMock) as mock_initialize, + patch.object( + config, "_handle_config_changes", new_callable=AsyncMock + ) as mock_handle_config_changes, + ): + # Start watching in a separate task + watch_task = asyncio.create_task(config.start_watching()) + + # Wait a short time to allow the watch task to start and process the mock changes + await asyncio.sleep(0.2) + + # Stop the watching + await config.stop_watching() + + # Wait for the watch task to complete + await watch_task + + mock_initialize.assert_called_once() + mock_handle_config_changes.assert_called_once_with(awatch_response) + + +@pytest.mark.anyio +async def test_handle_config_changes(config_factory): + """Test that we can detect and handle config changes""" + config = await config_factory() + mock_load_config = AsyncMock() + changes = [ + (Change.added, "new_config.yaml"), + (Change.modified, "existing_config.yaml"), + (Change.deleted, "old_config.yaml"), + (Change.added, "not_a_config.txt"), + ] + + config.config_files = {"old_config.yaml": AsyncMock()} + + with patch.object(config, "_load_config_file", mock_load_config): + await config._handle_config_changes(changes) + + assert mock_load_config.call_count == 2 # for added and modified + assert "old_config.yaml" not in config.config_files + + +@pytest.mark.anyio +async def test_get_model_backend(config_factory): + """Test that we can access a model's backend""" + config = await config_factory() + + config.models = {"test_model": Model(name="test_model", backend="test_backend")} + assert config.get_model_backend("test_model").backend == "test_backend" + assert config.get_model_backend("non_existent_model") is None + + +@pytest.mark.anyio +async def test_clear_all_models(config_factory, parent_dir): + config = await config_factory(config_path=parent_dir) + + mock_config_file = AsyncMock(spec=ConfigFile) + mock_config_file.filename = "test_config.yaml" # Mock the filename attribute + config.config_files = {"test_config.yaml": mock_config_file} + config.models = {"test_model": Model(name="test_model", backend="test_backend")} + + await config.clear_all_models() + + assert len(config.models) == 0 + assert len(config.config_files) == 0 + mock_config_file.aunload.assert_called_once() + + +@pytest.mark.anyio +async def test_to_dict(config_factory): + """Test the dict conversion is correct""" + config = await config_factory() + config.models = { + "model1": Model(name="model1", backend="backend1"), + "model2": Model(name="model2", backend="backend2"), + } + config.config_files = { + "config1.yaml": MagicMock( + spec=ConfigFile, filename="config1.yaml", models={"model1": None} + ), + "config2.yaml": MagicMock( + spec=ConfigFile, filename="config2.yaml", models={"model2": None} + ), + } + + result = config.to_dict() + + assert "config_sources" in result + assert "models" in result + assert len(result["models"]) == 2 + assert result["config_sources"]["config1.yaml"] == ["model1"] + assert result["config_sources"]["config2.yaml"] == ["model2"] + + +@pytest.mark.asyncio +async def test_parse_models(config_files): + """Test that the models are parsed correctly""" + config_path: Path = config_files.get(YAML_CONFIG_FILE, None) + assert config_path is not None, f"Could not find config file: {YAML_CONFIG_FILE}" + config = ConfigFile( + path=Path(config_path), + ) + await config.aload() + test_data = { + "models": [ + {"name": "model1", "backend": "backend1"}, + {"name": "model2", "backend": "backend2"}, + ] + } + + config.parse_models(test_data) + + assert len(config.models) 
== 2 + # Model 1 tests + model1 = config.models["model1"] + assert model1.name == "model1" + assert model1.backend == "backend1" + + # Model 2 tests + model2 = config.models["model2"] + + assert model2.name == "model2" + assert model2.backend == "backend2" + + assert config._loaded is True + + +@pytest.mark.asyncio +async def test_load_from_file_yaml(config_files): + config_path: Path = config_files.get(YAML_CONFIG_FILE, None) + assert config_path is not None, f"Could not find config file: {YAML_CONFIG_FILE}" + config_path = await config_path.resolve() + config = ConfigFile(path=config_path) + result = await config._load_from_file(config_path) + + # See the above mentioned YAML_CONFIG_FILE for the expected data + expected_data = { + "models": [ + { + "name": "model1", + "backend": "backend1", + "type": "type1", + "dimensions": 768, + "precision": 32, + "capabilities": ["embeddings"], + }, + {"name": "model2", "backend": "backend2"}, + ] + } + + assert result == expected_data + + +@pytest.mark.asyncio +async def test_load_from_file_toml(config_files): + config_path: Path = config_files.get(TOML_CONFIG_FILE, None) + assert config_path is not None, f"Could not find config file: {TOML_CONFIG_FILE}" + config_path = await config_path.resolve() + + config = ConfigFile(path=config_path) + result = await config._load_from_file(config_path) + + # Load the TOML file directly for comparison + with open(config_path, "r") as f: + expected_data = toml.load(f) + + assert result == expected_data + + +@pytest.mark.asyncio +async def test_load_from_file_unsupported(): + config = ConfigFile(Path("test_config.txt")) + + with patch("pathlib.Path.open") as mock_open: + mock_file = MagicMock() + mock_file.read.return_value = "dummy_content" + mock_open.return_value.__aenter__.return_value = mock_file + + result = await config._load_from_file(Path("test_config.txt")) + + assert result == {} + + +@pytest.mark.asyncio +async def test_load_config_file(config_file): + test_data = dict( + models=[dict(name="test_model", backend="test_backend")], + ) + + with ( + patch.object(config_file, "_load_from_file") as mock_load, + patch.object(config_file, "parse_models") as mock_parse, + ): + mock_load.return_value = test_data + + await config_file.load_config_file() + + mock_load.assert_called_once_with(path=config_file.path) + mock_parse.assert_called_once_with(test_data) + + +@pytest.mark.asyncio +async def test_aload(config_file): + with ( + patch("pathlib.Path.exists") as mock_exists, + patch.object(config_file, "load_config_file") as mock_load, + ): + mock_exists.return_value = True + + await config_file.aload() + + mock_exists.assert_called_once() + mock_load.assert_called_once() + + +@pytest.mark.asyncio +async def test_aunload(config_file): + # config = ConfigFile(Path(YAML_CONFIG_FILE)) + config_file.models = {"model1": Model(name="model1", backend="backend1")} + config_file._loaded = True + + await config_file.aunload() + + assert len(config_file.models) == 0 + assert config_file._loaded is False, f"config._loaded = {config_file._loaded}" + + +def test_str_representation(config_file): + assert str(config_file) == f"Path: {config_file.path}, Models: {{}}" + + +def test_repr_representation(config_file): + assert repr(config_file) == f"ConfigFile(path={config_file.path}, models={{}})" diff --git a/tests/unit/leapfrogai_api/utils/test_config.toml b/tests/unit/leapfrogai_api/utils/test_config.toml new file mode 100644 index 000000000..60d669f87 --- /dev/null +++ b/tests/unit/leapfrogai_api/utils/test_config.toml @@ -0,0 
+1,11 @@ +[[models]] +name = "model1" +backend = "backend1" +type = "type1" +dimensions = 768 +precision = 32 +capabilities = ["embeddings"] + +[[models]] +name = "model2" +backend = "backend2" diff --git a/tests/unit/leapfrogai_api/utils/test_config.yaml b/tests/unit/leapfrogai_api/utils/test_config.yaml new file mode 100644 index 000000000..b146d99a4 --- /dev/null +++ b/tests/unit/leapfrogai_api/utils/test_config.yaml @@ -0,0 +1,11 @@ +# This file is meant to serve as a sample config file. +# `test_config.toml` is supposed to be a TOML file with the same contents. +models: + - name: "model1" + backend: "backend1" + type: "type1" + dimensions: 768 + precision: 32 + capabilities: ["embeddings"] + - name: model2 + backend: backend2
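Since `test_config.toml` is intended to mirror `test_config.yaml`, the two fixtures can be cross-checked with a short standalone script. This is a sketch only, not part of the suite; it assumes PyYAML and the `toml` package are installed and that it is run from `tests/unit/leapfrogai_api/utils`:

import toml
import yaml

# Load both fixture files and confirm they describe the same model list.
with open("test_config.yaml") as yaml_file:
    yaml_models = yaml.safe_load(yaml_file)["models"]
with open("test_config.toml") as toml_file:
    toml_models = toml.load(toml_file)["models"]

assert yaml_models == toml_models, (yaml_models, toml_models)
print("YAML and TOML fixtures describe the same models.")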