Multimodal #66

Open · wants to merge 9 commits into main
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -27,6 +27,7 @@ repos:
pyyaml,
requests,
rich,
pillow,
transformers,

# dev dependencies
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"pyyaml>=6.0.0",
"requests",
"rich",
"pillow",
"base64",
"io",
"transformers",
]

1 change: 1 addition & 0 deletions src/guidellm/__init__.py
@@ -6,6 +6,7 @@
# flake8: noqa

import os

import transformers # type: ignore

os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence warnings for tokenizers
26 changes: 23 additions & 3 deletions src/guidellm/backend/openai.py
@@ -1,3 +1,5 @@
import base64
import io
from typing import AsyncGenerator, Dict, List, Optional

from loguru import logger
@@ -103,11 +105,11 @@ async def make_request(

request_args.update(self._request_args)

messages = self._build_messages(request)

stream = await self._async_client.chat.completions.create(
model=self.model,
messages=[
{"role": "user", "content": request.prompt},
],
messages=messages,
stream=True,
**request_args,
)
@@ -167,3 +169,21 @@ def validate_connection(self):
except Exception as error:
logger.error("Failed to validate OpenAI connection: {}", error)
raise error

def _build_messages(self, request: TextGenerationRequest) -> List[Dict]:
    if request.number_images == 0:
        messages = [{"role": "user", "content": request.prompt}]
    else:
        content = []
        for image in request.images:
            stream = io.BytesIO()
            im_format = image.image.format or "PNG"
            image.image.save(stream, format=im_format)
            im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8")
            image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
            content.append({"type": "image_url", "image_url": image_url})

        content.append({"type": "text", "text": request.prompt})
        messages = [{"role": "user", "content": content}]

    return messages
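
For reference, this builds the standard OpenAI chat-completions vision payload. A minimal sketch of what _build_messages returns for a request with one PNG image (base64 data truncated, prompt text hypothetical):

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgoAAAANS..."},
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]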
1 change: 1 addition & 0 deletions src/guidellm/config.py
@@ -90,6 +90,7 @@ class EmulatedDataSettings(BaseModel):
"force_new_line_punctuation": True,
}
)
image_source: str = "https://www.gutenberg.org/cache/epub/1342/pg1342-images.html"


class OpenAISettings(BaseModel):
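
A quick sketch of pointing the generator at a different page at runtime (hypothetical URL; this assumes the settings object allows attribute assignment, as plain pydantic models do by default):

from guidellm.config import settings

# Hypothetical override: scrape images from a different HTML page.
settings.emulated_data.image_source = "https://example.com/gallery.html"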
15 changes: 14 additions & 1 deletion src/guidellm/core/request.py
@@ -1,9 +1,10 @@
import uuid
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from pydantic import Field

from guidellm.core.serializable import Serializable
from guidellm.utils import ImageDescriptor


class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
description="The unique identifier for the request.",
)
prompt: str = Field(description="The input prompt for the text generation.")
images: Optional[List[ImageDescriptor]] = Field(
default=None,
description="Input images.",
)
prompt_token_count: Optional[int] = Field(
default=None,
description="The number of tokens in the input prompt.",
@@ -29,6 +34,13 @@ class TextGenerationRequest(Serializable):
description="The parameters for the text generation request.",
)

@property
def number_images(self) -> int:
    if self.images is None:
        return 0
    else:
        return len(self.images)

def __str__(self) -> str:
prompt_short = (
self.prompt[:32] + "..."
@@ -41,4 +53,5 @@ def __str__(self) -> str:
f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
f"output_token_count={self.output_token_count}, "
f"params={self.params})"
f"images={self.number_images}"
)
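
A minimal sketch of the new field and property, assuming an in-memory PIL image:

from PIL import Image
from guidellm.core.request import TextGenerationRequest
from guidellm.utils import ImageDescriptor

# Hypothetical 64x64 placeholder image for illustration.
image = Image.new("RGB", (64, 64))
request = TextGenerationRequest(
    prompt="Describe this image.",
    images=[ImageDescriptor(url=None, image=image)],
)
assert request.number_images == 1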
19 changes: 17 additions & 2 deletions src/guidellm/request/emulated.py
@@ -11,7 +11,7 @@
from guidellm.config import settings
from guidellm.core.request import TextGenerationRequest
from guidellm.request.base import GenerationMode, RequestGenerator
from guidellm.utils import clean_text, filter_text, load_text, split_text
from guidellm.utils import clean_text, filter_text, load_images, load_text, split_text

__all__ = ["EmulatedConfig", "EmulatedRequestGenerator", "EndlessTokens"]

@@ -30,6 +30,7 @@ class EmulatedConfig:
generated_tokens_variance (Optional[int]): Variance for generated tokens.
generated_tokens_min (Optional[int]): Minimum number of generated tokens.
generated_tokens_max (Optional[int]): Maximum number of generated tokens.
images (int): Number of input images sampled for each request.
"""

@staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
"""
if not config:
logger.debug("Creating default configuration")
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256)
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256, images=0)

if isinstance(config, dict):
logger.debug("Loading configuration from dict: {}", config)
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
generated_tokens_min: Optional[int] = None
generated_tokens_max: Optional[int] = None

images: int = 0

@property
def prompt_tokens_range(self) -> Tuple[int, int]:
"""
@@ -327,6 +330,8 @@ def __init__(
settings.emulated_data.filter_start,
settings.emulated_data.filter_end,
)
if self._config.images > 0:
self._images = load_images(settings.emulated_data.image_source)
self._rng = np.random.default_rng(random_seed)

# NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
logger.debug("Creating new text generation request")
target_prompt_token_count = self._config.sample_prompt_tokens(self._rng)
prompt = self.sample_prompt(target_prompt_token_count)
images = self.sample_images()
prompt_token_count = len(self.tokenizer.tokenize(prompt))
output_token_count = self._config.sample_output_tokens(self._rng)
logger.debug("Generated prompt: {}", prompt)
@@ -363,6 +369,7 @@
prompt=prompt,
prompt_token_count=prompt_token_count,
output_token_count=output_token_count,
images=images,
)

def sample_prompt(self, tokens: int) -> str:
@@ -395,3 +402,11 @@ def sample_prompt(self, tokens: int) -> str:
right = mid

return self._tokens.create_text(start_line_index, left)


def sample_images(self):
    # Guard: when no images are configured, self._images was never
    # loaded in __init__, so return None rather than sampling.
    if self._config.images <= 0:
        return None

    image_indices = self._rng.choice(
        len(self._images), size=self._config.images, replace=False,
    )

    return [self._images[i] for i in image_indices]
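
To exercise the new path, the dict form of the config accepts the added key. A sketch with hypothetical token counts:

from guidellm.request.emulated import EmulatedConfig

# Hypothetical config: sample two of the scraped images per request.
config = EmulatedConfig.create_config(
    {"prompt_tokens": 1024, "generated_tokens": 256, "images": 2}
)
assert config.images == 2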
3 changes: 3 additions & 0 deletions src/guidellm/utils/__init__.py
@@ -1,3 +1,4 @@
from .images import ImageDescriptor, load_images
from .injector import create_report, inject_data
from .progress import BenchmarkReportProgress
from .text import (
@@ -37,4 +38,6 @@
"resolve_transformers_dataset_split",
"split_lines_by_punctuation",
"split_text",
"ImageDescriptor",
"load_images",
]
68 changes: 68 additions & 0 deletions src/guidellm/utils/images.py
@@ -0,0 +1,68 @@
from io import BytesIO
from typing import List, Optional
from urllib.parse import urljoin

import requests
from bs4 import BeautifulSoup
from loguru import logger
from PIL import Image
from pydantic import ConfigDict, Field

from guidellm.config import settings
from guidellm.core.serializable import Serializable

__all__ = ["load_images", "ImageDescriptor"]

class ImageDescriptor(Serializable):
    """
    A class to represent image data in a serializable format.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    url: Optional[str] = Field(description="URL of the image.")
    image: Image.Image = Field(description="PIL image", exclude=True)
    filename: Optional[str] = Field(
        default=None,
        description="Image filename.",
    )


def load_images(data: str) -> Optional[List[ImageDescriptor]]:
    """
    Load an HTML page from a URL and collect the images it references.

    :param data: the URL of the HTML page to scan for images
    :type data: str
    :return: descriptors pairing each image URL with its data as a
        PIL.Image.Image, or None if no source was given
    :rtype: Optional[List[ImageDescriptor]]
    """
    if not data:
        return None

    images = []
    if isinstance(data, str) and data.startswith("http"):
        response = requests.get(data, timeout=settings.request_timeout)
        response.raise_for_status()

        soup = BeautifulSoup(response.text, "html.parser")
        for img_tag in soup.find_all("img"):
            img_url = img_tag.get("src")

            if img_url:
                # Handle relative URLs
                img_url = urljoin(data, img_url)

                # Download the image
                logger.debug("Loading image: {}", img_url)
                img_response = requests.get(
                    img_url, timeout=settings.request_timeout
                )
                img_response.raise_for_status()

                # Load image into Pillow
                images.append(
                    ImageDescriptor(
                        url=img_url,
                        image=Image.open(BytesIO(img_response.content)),
                    )
                )

    return images
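
A minimal usage sketch, assuming network access to the default Project Gutenberg source configured above:

from guidellm.config import settings
from guidellm.utils import load_images

# Download every <img> referenced by the configured HTML page.
descriptors = load_images(settings.emulated_data.image_source)
for descriptor in descriptors or []:
    print(descriptor.url, descriptor.image.size)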