diff --git a/models/llama3/api/datatypes.py b/models/llama3/api/datatypes.py index 05a067c1..6d1e1179 100644 --- a/models/llama3/api/datatypes.py +++ b/models/llama3/api/datatypes.py @@ -5,15 +5,17 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. +import base64 from enum import Enum + +from io import BytesIO from typing import Dict, List, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from typing_extensions import Annotated from ...datatypes import * # noqa -from io import BytesIO from ...schema_utils import json_schema_type @@ -123,6 +125,19 @@ class RawMediaItem(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + @field_serializer("data") + def serialize_data(self, data: Optional[bytes], _info): + if data is None: + return None + return base64.b64encode(data).decode("utf-8") + + @field_validator("data", mode="before") + @classmethod + def validate_data(cls, v): + if isinstance(v, str): + return base64.b64decode(v) + return v + class RawTextItem(BaseModel): type: Literal["text"] = "text"