Skip to content

Commit

Permalink
feat(leapfrogai_api/tests): fix race condition in test, make config a…
Browse files Browse the repository at this point in the history
… singleton, stop event stuff
  • Loading branch information
jamestexas committed Aug 20, 2024
1 parent 2c19064 commit af85681
Show file tree
Hide file tree
Showing 14 changed files with 1,520 additions and 860 deletions.
162 changes: 148 additions & 14 deletions src/leapfrogai_api/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import datetime
from enum import Enum
from typing import Literal
from typing import Any, Literal
import warnings

from fastapi import UploadFile, Form, File
from openai.types import FileObject
Expand All @@ -18,7 +19,7 @@
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.beta.vector_store import ExpiresAfter
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator, ValidationInfo, ConfigDict

##########
# DEFAULTS
Expand All @@ -27,6 +28,7 @@

DEFAULT_MAX_COMPLETION_TOKENS: int = 4096
DEFAULT_MAX_PROMPT_TOKENS: int = 4096
# Recommended upper bound for embedding dimensions; DimensionWarning uses it
# as the default threshold quoted in its message.
DEFAULT_DIMENSION: int = 2000

##########
# GENERIC
Expand All @@ -48,28 +50,158 @@ class Usage(BaseModel):
)


##########
# WARNINGS
##########


class LeapfrogAIWarning(UserWarning):
    """Root of the LeapfrogAI warning hierarchy.

    Subclass this for every LeapfrogAI-specific warning so callers can filter
    them as a group through the standard ``warnings`` machinery.
    """


class DimensionWarning(LeapfrogAIWarning):
    """Warning emitted when a dimension value is larger than recommended.

    Args:
        dimension: The offending dimension value.
        dimension_default: The recommended maximum quoted in the message;
            defaults to ``DEFAULT_DIMENSION``.
    """

    def __init__(
        self,
        dimension: int,
        dimension_default: int = DEFAULT_DIMENSION,
    ):
        message = f"Dimension {dimension} exceeds the recommended maximum of {dimension_default}."
        super().__init__(message)


class CapabilityWarning(LeapfrogAIWarning):
    """Warning emitted when a model's declared capabilities look inconsistent."""


##########
# MODELS
##########


class Modality(str, Enum):
    """Input/output modality a model can handle.

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    IMAGE = "image"
    TEXT = "text"
    SPEECH = "speech"


class Capability(str, Enum):
    """Functional capability a model can advertise.

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    CHAT = "chat"
    EMBEDDINGS = "embeddings"
    STT = "stt"  # speech-to-text
    TTS = "tts"  # text-to-speech


class Precision(str, Enum):
    """Numerical precision used for a model's parameters.

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    FLOAT16 = "float16"
    FLOAT32 = "float32"
    BFLOAT16 = "bfloat16"
    INT8 = "int8"
    INT4 = "int4"


class Format(str, Enum):
    """Storage or quantization format of a model.

    Members are ``str`` subclasses, so they compare equal to their raw values.
    Note that the values are case-sensitive (e.g. ``"SqueezeLLM"``).
    """

    AWQ = "AWQ"
    GGUF = "GGUF"
    GPTQ = "GPTQ"
    SQUEEZELLM = "SqueezeLLM"


class ModelMetadataResponse(BaseModel):
    """Metadata for the model, including type, dimensions (for embeddings), and precision.

    Validation in this class never rejects a payload on its own; it only emits
    warnings:

    * ``DimensionWarning`` when ``dimensions`` exceeds ``DEFAULT_DIMENSION``.
    * ``CapabilityWarning`` when ``dimensions`` and the ``embeddings``
      capability disagree with each other.
    """

    # Enum-typed fields are stored as their plain values rather than enum members.
    model_config = ConfigDict(use_enum_values=True)

    # NOTE: `dimensions` is intentionally declared BEFORE `capabilities`.
    # pydantic validates fields in declaration order and only exposes fields
    # that were already validated through `ValidationInfo.data`, so the
    # `capabilities` validator can only see `dimensions` if it comes first.
    dimensions: int | None = Field(
        default=None,
        description="Embedding dimensions (for embeddings models)",
    )
    capabilities: list[Capability] | None = Field(
        default=None,  # TODO: should we define this as an empty string if it's None?
        # Run the validator even when the field is omitted, so a payload with
        # `dimensions` but no `capabilities` still triggers the warning.
        validate_default=True,
        description="Model capabilities (e.g., 'embeddings', 'chat', 'tts', 'stt')",
    )
    format: Format | None = Field(
        default=None,
        description="Model format (e.g., None, 'GPTQ', 'GGUF', 'SqueezeLLM', 'AWQ')",
    )
    modalities: list[Modality] | None = Field(
        default=None,  # TODO: should we define this as an empty string if it's None?
        description="The modalities of the model (e.g., 'image', 'text', 'speech')",
    )
    precision: Precision | None = Field(
        default=None,
        description="Model precision (e.g., 'float16', 'float32', 'bfloat16', 'int8', 'int4')",
    )
    type: Literal["embeddings", "llm"] | None = Field(
        default=None,
        description="The type of the model e.g. ('embeddings' or 'llm')",
    )

    @field_validator("dimensions")
    @classmethod
    def check_dimensions(
        cls,
        v: int | None,
        info: ValidationInfo,
    ) -> int | None:
        """
        Validates the 'dimensions' field of a model's metadata.

        Args:
            v: The dimension value to be validated.
            info: The validation information (unused; kept for signature stability).

        Returns:
            The dimension value, unchanged.

        Warns:
            DimensionWarning: If the value exceeds ``DEFAULT_DIMENSION``.
        """
        # Compare against the shared constant so the threshold and the
        # warning message cannot drift apart.
        if v is not None and v > DEFAULT_DIMENSION:
            warnings.warn(DimensionWarning(dimension=v))
        return v

    @field_validator("capabilities")
    @classmethod
    def validate_capabilities(
        cls,
        v: list[Capability] | None,
        info: ValidationInfo,
    ) -> list[Capability] | None:
        """
        Validates the 'capabilities' field of a model's metadata, ensuring that 'dimensions'
        is correctly set when 'embeddings' is in capabilities.

        Args:
            v: The capabilities to be validated (``None`` when omitted).
            info: The validation information; ``info.data`` holds the fields
                validated so far, which includes 'dimensions'.

        Returns:
            The capabilities, unchanged.

        Warns:
            CapabilityWarning: If 'dimensions' is set without the 'embeddings'
                capability, or 'embeddings' is present without 'dimensions'.
        """
        # TODO: Actually error here when 'embeddings' is not in capabilities, once this is actually implemented

        # Evaluate both flags eagerly: binding them inside a short-circuiting
        # boolean expression can leave one of them undefined.
        dimensions_set = info.data.get("dimensions") is not None
        # Capability is a str-based enum, so this membership test matches both
        # enum members and their raw string values.
        embeddings_set = v is not None and Capability.EMBEDDINGS in v

        if dimensions_set and not embeddings_set:
            warnings.warn(
                CapabilityWarning(
                    "'dimensions' should only be set for models with 'embeddings' capability"
                )
            )
        elif embeddings_set and not dimensions_set:
            warnings.warn(
                CapabilityWarning(
                    "'dimensions' must be set for models with 'embeddings' capability"
                )
            )
        return v


class ModelResponseModel(BaseModel):
Expand Down Expand Up @@ -122,7 +254,7 @@ class CompletionRequest(BaseModel):
model: str = Field(
...,
description="The ID of the model to use for completion.",
example="llama-cpp-python",
examples=["llama-cpp-python"],
)
prompt: str | list[int] = Field(
...,
Expand Down Expand Up @@ -154,7 +286,9 @@ class CompletionChoice(BaseModel):
description="Log probabilities for the generated tokens. Only returned if requested.",
)
finish_reason: str = Field(
"", description="The reason why the completion finished.", example="length"
"",
description="The reason why the completion finished.",
examples=["length"],
)


Expand Down Expand Up @@ -601,12 +735,12 @@ class CreateVectorStoreRequest(BaseModel):
file_ids: list[str] | None = Field(
default=[],
description="List of file IDs to be included in the vector store.",
example=["file-abc123", "file-def456"],
examples=["file-abc123", "file-def456"],
)
name: str | None = Field(
default=None,
description="Optional name for the vector store.",
example="My Vector Store",
examples=["My Vector Store"],
)
expires_after: ExpiresAfter | None = Field(
default=None,
Expand All @@ -616,7 +750,7 @@ class CreateVectorStoreRequest(BaseModel):
metadata: dict | None = Field(
default=None,
description="Optional metadata for the vector store.",
example={"project": "AI Research", "version": "1.0"},
examples=[{"project": "AI Research", "version": "1.0"}],
)

def add_days_to_timestamp(self, timestamp: int, days: int) -> int:
Expand Down
Loading

0 comments on commit af85681

Please sign in to comment.