diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2a085bb..6bcf150 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,6 +27,8 @@ repos:
           pyyaml,
           requests,
           rich,
+          pillow,
+          beautifulsoup4,
           transformers,
 
           # dev dependencies
diff --git a/pyproject.toml b/pyproject.toml
index 6ab2c6e..b83abfd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,8 @@ dependencies = [
     "pyyaml>=6.0.0",
     "requests",
     "rich",
+    "pillow",
+    "beautifulsoup4",
     "transformers",
 ]
 
diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py
index e562018..b10b445 100644
--- a/src/guidellm/__init__.py
+++ b/src/guidellm/__init__.py
@@ -6,6 +6,7 @@
 # flake8: noqa
 
 import os
+
 import transformers  # type: ignore
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Silence warnings for tokenizers
diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index 8c83f91..90d2791 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -1,3 +1,5 @@
+import base64
+import io
 from typing import AsyncGenerator, Dict, List, Optional
 
 from loguru import logger
@@ -103,11 +105,11 @@ async def make_request(
         request_args.update(self._request_args)
 
+        messages = self._build_messages(request)
+
         stream = await self._async_client.chat.completions.create(
             model=self.model,
-            messages=[
-                {"role": "user", "content": request.prompt},
-            ],
+            messages=messages,
             stream=True,
             **request_args,
         )
@@ -167,3 +169,22 @@ def validate_connection(self):
         except Exception as error:
             logger.error("Failed to validate OpenAI connection: {}", error)
             raise error
+
+    def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
+        if request.number_images == 0:
+            messages = [{"role": "user", "content": request.prompt}]
+        else:
+            content = []
+            for image in request.images:
+                # Re-encode each PIL image and inline it as a base64 data URL
+                stream = io.BytesIO()
+                im_format = image.image.format or "PNG"
+                image.image.save(stream, format=im_format)
+                im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
+                image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
+                content.append({"type": "image_url", "image_url": image_url})
+
+            content.append({"type": "text", "text": request.prompt})
+            messages = [{"role": "user", "content": content}]
+
+        return messages
diff --git a/src/guidellm/config.py b/src/guidellm/config.py
index c3d950e..df750ea 100644
--- a/src/guidellm/config.py
+++ b/src/guidellm/config.py
@@ -90,6 +90,7 @@ class EmulatedDataSettings(BaseModel):
             "force_new_line_punctuation": True,
         }
     )
+    image_source: str = "https://www.gutenberg.org/cache/epub/1342/pg1342-images.html"
 
 
 class OpenAISettings(BaseModel):
diff --git a/src/guidellm/core/request.py b/src/guidellm/core/request.py
index 4f7315c..a1ff199 100644
--- a/src/guidellm/core/request.py
+++ b/src/guidellm/core/request.py
@@ -1,9 +1,10 @@
 import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import Field
 
 from guidellm.core.serializable import Serializable
+from guidellm.utils import ImageDescriptor
 
 
 class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
         description="The unique identifier for the request.",
     )
     prompt: str = Field(description="The input prompt for the text generation.")
+    images: Optional[List[ImageDescriptor]] = Field(
+        default=None,
+        description="Input images.",
+    )
     prompt_token_count: Optional[int] = Field(
         default=None,
         description="The number of tokens in the input prompt.",
@@ -29,6 +34,13 @@ class TextGenerationRequest(Serializable):
         description="The parameters for the text generation request.",
     )
 
+    @property
+    def number_images(self) -> int:
+        if self.images is None:
+            return 0
+        else:
+            return len(self.images)
+
     def __str__(self) -> str:
         prompt_short = (
             self.prompt[:32] + "..."
@@ -41,4 +53,5 @@ def __str__(self) -> str:
         f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
         f"output_token_count={self.output_token_count}, "
-        f"params={self.params})"
+        f"params={self.params}, "
+        f"images={self.number_images})"
     )
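Reviewer note: for clarity, a minimal sketch of the chat-completions payload `_build_messages` assembles when a request carries one image. The prompt text and the base64 payload below are illustrative placeholders, not values from this diff:

```python
# Shape of the messages list returned by _build_messages when
# request.number_images == 1; the data URL is truncated for readability.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
            },
            {"type": "text", "text": "Describe the attached image."},
        ],
    }
]
```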
diff --git a/src/guidellm/request/emulated.py b/src/guidellm/request/emulated.py
index 7d481cb..f15387e 100644
--- a/src/guidellm/request/emulated.py
+++ b/src/guidellm/request/emulated.py
@@ -11,7 +11,7 @@
 from guidellm.config import settings
 from guidellm.core.request import TextGenerationRequest
 from guidellm.request.base import GenerationMode, RequestGenerator
-from guidellm.utils import clean_text, filter_text, load_text, split_text
+from guidellm.utils import clean_text, filter_text, load_images, load_text, split_text
 
 __all__ = ["EmulatedConfig", "EmulatedRequestGenerator", "EndlessTokens"]
 
@@ -30,6 +30,7 @@ class EmulatedConfig:
         generated_tokens_variance (Optional[int]): Variance for generated tokens.
         generated_tokens_min (Optional[int]): Minimum number of generated tokens.
         generated_tokens_max (Optional[int]): Maximum number of generated tokens.
+        images (int): Number of input images per request.
     """
 
     @staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
         """
         if not config:
             logger.debug("Creating default configuration")
-            return EmulatedConfig(prompt_tokens=1024, generated_tokens=256)
+            return EmulatedConfig(prompt_tokens=1024, generated_tokens=256, images=0)
 
         if isinstance(config, dict):
             logger.debug("Loading configuration from dict: {}", config)
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
     generated_tokens_min: Optional[int] = None
     generated_tokens_max: Optional[int] = None
 
+    images: int = 0
+
     @property
     def prompt_tokens_range(self) -> Tuple[int, int]:
         """
@@ -327,6 +330,8 @@ def __init__(
             settings.emulated_data.filter_start,
             settings.emulated_data.filter_end,
         )
+        if self._config.images > 0:
+            self._images = load_images(settings.emulated_data.image_source)
         self._rng = np.random.default_rng(random_seed)
 
         # NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
         logger.debug("Creating new text generation request")
         target_prompt_token_count = self._config.sample_prompt_tokens(self._rng)
         prompt = self.sample_prompt(target_prompt_token_count)
+        images = self.sample_images()
         prompt_token_count = len(self.tokenizer.tokenize(prompt))
         output_token_count = self._config.sample_output_tokens(self._rng)
         logger.debug("Generated prompt: {}", prompt)
@@ -363,6 +369,7 @@ def create_item(self) -> TextGenerationRequest:
             prompt=prompt,
             prompt_token_count=prompt_token_count,
             output_token_count=output_token_count,
+            images=images,
         )
 
     def sample_prompt(self, tokens: int) -> str:
@@ -395,3 +402,17 @@ def sample_prompt(self, tokens: int) -> str:
                 right = mid
 
         return self._tokens.create_text(start_line_index, left)
+
+    def sample_images(self):
+        # Text-only configurations (images == 0) carry no images; this also
+        # avoids touching self._images, which is only set when images > 0.
+        if self._config.images == 0:
+            return None
+
+        image_indices = self._rng.choice(
+            len(self._images),
+            size=self._config.images,
+            replace=False,
+        )
+
+        return [self._images[i] for i in image_indices]
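Reviewer note: a short usage sketch of the new `images` knob, assuming the existing `EmulatedRequestGenerator` constructor accepts a config dict and a tokenizer as used elsewhere in this file (the tokenizer name is illustrative):

```python
from guidellm.request.emulated import EmulatedRequestGenerator

# Each generated request samples two distinct images (replace=False in
# sample_images) from the page set by settings.emulated_data.image_source.
generator = EmulatedRequestGenerator(
    config={"prompt_tokens": 1024, "generated_tokens": 256, "images": 2},
    tokenizer="gpt2",
)
request = generator.create_item()
assert request.number_images == 2
```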
diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py
index 2fdd8ca..eb4931b 100644
--- a/src/guidellm/utils/__init__.py
+++ b/src/guidellm/utils/__init__.py
@@ -1,3 +1,4 @@
+from .images import ImageDescriptor, load_images
 from .injector import create_report, inject_data
 from .progress import BenchmarkReportProgress
 from .text import (
@@ -37,4 +38,6 @@
     "resolve_transformers_dataset_split",
     "split_lines_by_punctuation",
     "split_text",
+    "ImageDescriptor",
+    "load_images",
 ]
diff --git a/src/guidellm/utils/images.py b/src/guidellm/utils/images.py
new file mode 100644
index 0000000..5d73bc0
--- /dev/null
+++ b/src/guidellm/utils/images.py
@@ -0,0 +1,74 @@
+from io import BytesIO
+from typing import List, Optional
+from urllib.parse import urljoin
+
+import requests
+from bs4 import BeautifulSoup
+from loguru import logger
+from PIL import Image
+from pydantic import ConfigDict, Field
+
+from guidellm.config import settings
+from guidellm.core.serializable import Serializable
+
+__all__ = ["load_images", "ImageDescriptor"]
+
+
+class ImageDescriptor(Serializable):
+    """
+    A class to represent image data in serializable format.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    url: Optional[str] = Field(description="url address for image.")
+    image: Image.Image = Field(description="PIL image", exclude=True)
+    filename: Optional[str] = Field(
+        default=None,
+        description="Image filename.",
+    )
+
+
+def load_images(data: str) -> List[ImageDescriptor]:
+    """
+    Download all images referenced by an HTML page.
+
+    :param data: the URL of the HTML page to scrape for img tags
+    :type data: str
+    :return: descriptors containing each image URL and its data in
+        PIL.Image.Image format
+    :rtype: List[ImageDescriptor]
+    """
+    images = []
+
+    if not data:
+        return images
+
+    if isinstance(data, str) and data.startswith("http"):
+        response = requests.get(data, timeout=settings.request_timeout)
+        response.raise_for_status()
+
+        soup = BeautifulSoup(response.text, "html.parser")
+        for img_tag in soup.find_all("img"):
+            img_url = img_tag.get("src")
+
+            if img_url:
+                # Handle relative URLs
+                img_url = urljoin(data, img_url)
+
+                # Download the image
+                logger.debug("Loading image: {}", img_url)
+                img_response = requests.get(
+                    img_url, timeout=settings.request_timeout
+                )
+                img_response.raise_for_status()
+
+                # Load image into Pillow
+                images.append(
+                    ImageDescriptor(
+                        url=img_url,
+                        image=Image.open(BytesIO(img_response.content)),
+                    )
+                )
+
+    return images
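Reviewer note: an end-to-end sketch of the new helpers (requires network access; the URL is the default `image_source` introduced in this diff):

```python
from guidellm.core.request import TextGenerationRequest
from guidellm.utils import load_images

# Download every <img> referenced by the page, then attach the first
# image to a text generation request.
images = load_images("https://www.gutenberg.org/cache/epub/1342/pg1342-images.html")
request = TextGenerationRequest(
    prompt="What is shown in this picture?",
    images=images[:1],
)
print(request.number_images)  # 1
```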