Skip to content

Commit

Permalink
Allow for arbitrary examples containing DSPy.Images (#1801)
Browse files Browse the repository at this point in the history
* Fix dataset download

* WIP complex image types

* WIP fixing complex images

* Refactor chat_adapter somewhat working

* Ruff fixes

* remove print and update notebooks

* Tests failing on purpose - added None support and new str repr

* remove extra notebook

* Tests passing

* Clean comments

* Allow for proper image serialization

* ruff

* Remove assume text

* ruff

* remove excess prints

* fix test docstring

* Fix image repr

* tests

* Refactor tests to be more readable

* clean: ruff fix and add test

* fix: change test to use model_dump instead of model_dump_json

---------

Co-authored-by: Omar Khattab <[email protected]>
  • Loading branch information
isaacbmiller and okhat authored Feb 3, 2025
1 parent 0bf4c50 commit 9739f82
Show file tree
Hide file tree
Showing 10 changed files with 460 additions and 355 deletions.
138 changes: 40 additions & 98 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import textwrap
from collections.abc import Mapping
from itertools import chain
from typing import Any, Dict, List, Literal, NamedTuple, Union

from typing import Any, Dict, Literal, NamedTuple

import pydantic
from pydantic import TypeAdapter
Expand All @@ -17,6 +18,7 @@
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
from dspy.adapters.image_utils import try_expand_image_tags

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")

Expand Down Expand Up @@ -50,12 +52,12 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict

prepared_instructions = prepare_instructions(signature)
messages.append({"role": "system", "content": prepared_instructions})

for demo in demos:
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

messages.append(format_turn(signature, inputs, role="user"))
messages = try_expand_image_tags(messages)
return messages

def parse(self, signature, completion):
Expand Down Expand Up @@ -110,11 +112,10 @@ def format_fields(self, signature, values, role):
for field_name, field_info in signature.fields.items()
if field_name in values
}

return format_fields(fields_with_values)


def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
Expand All @@ -124,23 +125,14 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
fields_with_values: A dictionary mapping information about a field to its corresponding
value.
Returns:
The joined formatted values of the fields, represented as a string or a list of dicts
The joined formatted values of the fields, represented as a string
"""
output = []
for field, field_value in fields_with_values.items():
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
if assume_text:
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
else:
output.append({"type": "text", "text": f"[[ ## {field.name} ## ]]\n"})
if isinstance(formatted_field_value, dict) and formatted_field_value.get("type") == "image_url":
output.append(formatted_field_value)
else:
output[-1]["text"] += formatted_field_value["text"]
if assume_text:
return "\n\n".join(output).strip()
else:
return output
formatted_field_value = format_field_value(field_info=field.info, value=field_value)
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")

return "\n\n".join(output).strip()


def parse_value(value, annotation):
Expand Down Expand Up @@ -180,92 +172,43 @@ def format_turn(signature, values, role, incomplete=False):
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
"""
fields_to_collapse = []
content = []

if role == "user":
fields = signature.input_fields
if incomplete:
fields_to_collapse.append(
{
"type": "text",
"text": "This is an example of the task, though some input or output fields are not supplied.",
}
)
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
else:
fields = signature.output_fields
# Add the built-in field indicating that the chat turn has been completed
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
# Add the completed field for the assistant turn
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
field_names = fields.keys()
if not incomplete:
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")
message_prefix = ""

fields_to_collapse.extend(
format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
},
assume_text=False,
)
)

if role == "user":
output_fields = list(signature.output_fields.keys())
if not incomplete and not set(values).issuperset(fields.keys()):
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")

def type_info(v):
return (
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
if v.annotation is not str
else ""
)
messages = []
if message_prefix:
messages.append(message_prefix)

if output_fields:
fields_to_collapse.append(
{
"type": "text",
"text": "Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`.",
}
)

# flatmap the list if any items are lists otherwise keep the item
flattened_list = list(
chain.from_iterable(item if isinstance(item, list) else [item] for item in fields_to_collapse)
field_messages = format_fields(
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
for k, v in fields.items()},
)

if all(message.get("type", None) == "text" for message in flattened_list):
content = "\n\n".join(message.get("text") for message in flattened_list)
return {"role": role, "content": content}

# Collapse all consecutive text messages into a single message.
collapsed_messages = []
for item in flattened_list:
# First item is always added
if not collapsed_messages:
collapsed_messages.append(item)
continue

# If the current item is image, add to collapsed_messages
if item.get("type") == "image_url":
if collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n"
collapsed_messages.append(item)
# If the previous item is text and current item is text, append to the previous item
elif collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n\n" + item["text"]
# If the previous item is not text(aka image), add the current item as a new item
else:
item["text"] = "\n\n" + item["text"]
collapsed_messages.append(item)

return {"role": role, "content": collapsed_messages}

messages.append(field_messages)

# Add output field instructions for user messages
if role == "user" and signature.output_fields:
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
", and then ending with the marker for `[[ ## completed ## ]]`."
messages.append(field_instructions)
joined_messages = "\n\n".join(msg for msg in messages)
return {"role": role, "content": joined_messages}

def flatten_messages(messages):
    """Flatten one level of nesting in a sequence of messages.

    Args:
        messages: An iterable whose items are either single messages or
            lists of messages.

    Returns:
        A new list in which every item that was a ``list`` has been replaced
        by its elements, in order. Only one level is unpacked: lists nested
        inside those lists are kept as they are.
    """
    flat = []
    for entry in messages:
        if isinstance(entry, list):
            # Splice the sub-list's elements into the result.
            flat.extend(entry)
        else:
            flat.append(entry)
    return flat

def enumerate_fields(fields: dict) -> str:
parts = []
Expand Down Expand Up @@ -328,12 +271,11 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
},
assume_text=True,
)

parts.append(format_signature_fields_for_instructions(signature.input_fields))
parts.append(format_signature_fields_for_instructions(signature.output_fields))
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True))
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
instructions = textwrap.dedent(signature.instructions)
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")
Expand Down
91 changes: 77 additions & 14 deletions dspy/adapters/image_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import base64
import io
import os
from typing import Union
from typing import Any, Dict, List, Union
from urllib.parse import urlparse
import re

import pydantic
import requests
Expand All @@ -17,13 +18,20 @@

class Image(pydantic.BaseModel):
url: str


model_config = {
'frozen': True,
'str_strip_whitespace': True,
'validate_assignment': True,
'extra': 'forbid',
}

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, values):
# Allow the model to accept either a URL string or a dictionary with a single 'url' key
if isinstance(values, str):
# if a string, assume its the URL directly and wrap it in a dict
# if a string, assume it's the URL directly and wrap it in a dict
return {"url": values}
elif isinstance(values, dict) and set(values.keys()) == {"url"}:
# if it's a dict, ensure it has only the 'url' key
Expand All @@ -44,14 +52,21 @@ def from_file(cls, file_path: str):

@classmethod
def from_PIL(cls, pil_image):
import PIL
return cls(url=encode_image(pil_image))

return cls(url=encode_image(PIL.Image.open(pil_image)))
@pydantic.model_serializer()
def serialize_model(self):
return "<DSPY_IMAGE_START>" + self.url + "<DSPY_IMAGE_END>"

def __repr__(self):
len_base64 = len(self.url.split("base64,")[1])
return f"Image(url = {self.url.split('base64,')[0]}base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
def __str__(self):
return self.serialize_model()

def __repr__(self):
if "base64" in self.url:
len_base64 = len(self.url.split("base64,")[1])
image_type = self.url.split(";")[0].split("/")[-1]
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
return f"Image(url='{self.url}')"

def is_url(string: str) -> bool:
"""Check if a string is a valid URL."""
Expand Down Expand Up @@ -95,6 +110,7 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
return image
else:
# Unsupported string format
print(f"Unsupported image string: {image}")
raise ValueError(f"Unsupported image string: {image}")
elif PIL_AVAILABLE and isinstance(image, PILImage.Image):
# PIL Image
Expand All @@ -103,11 +119,12 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
# Raw bytes
if not PIL_AVAILABLE:
raise ImportError("Pillow is required to process image bytes.")
img = Image.open(io.BytesIO(image))
img = PILImage.open(io.BytesIO(image))
return _encode_pil_image(img)
elif isinstance(image, Image):
return image.url
else:
print(f"Unsupported image type: {type(image)}")
raise ValueError(f"Unsupported image type: {type(image)}")


Expand All @@ -133,8 +150,7 @@ def _encode_image_from_url(image_url: str) -> str:
encoded_image = base64.b64encode(response.content).decode("utf-8")
return f"data:image/{file_extension};base64,{encoded_image}"


def _encode_pil_image(image: "Image.Image") -> str:
def _encode_pil_image(image: 'PILImage') -> str:
"""Encode a PIL Image object to a base64 data URI."""
buffered = io.BytesIO()
file_extension = (image.format or "PNG").lower()
Expand All @@ -151,9 +167,7 @@ def _get_file_extension(path_or_url: str) -> str:

def is_image(obj) -> bool:
"""Check if the object is an image or a valid image reference."""
if PIL_AVAILABLE and isinstance(obj, Image.Image):
return True
if isinstance(obj, (bytes, bytearray)):
if PIL_AVAILABLE and isinstance(obj, PILImage.Image):
return True
if isinstance(obj, str):
if obj.startswith("data:image/"):
Expand All @@ -163,3 +177,52 @@ def is_image(obj) -> bool:
elif is_url(obj):
return True
return False

def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Expand serialized image tags in every message whose content is a string.

    The messages are modified in place: when a message's ``content`` is a
    string containing ``<DSPY_IMAGE_START>``, it is replaced by the result of
    ``expand_image_tags``. Messages with no ``content`` key, or whose content
    is not a string (for example ``None`` or a list), are left untouched
    instead of raising.

    Args:
        messages: Chat messages, each a dict that may carry a ``content`` entry.

    Returns:
        The same ``messages`` list that was passed in.
    """
    for message in messages:
        content = message.get("content")
        # Only plain strings can carry the image tags; guarding on the type
        # keeps the substring test from failing on None or being applied as a
        # membership test on list content.
        if isinstance(content, str) and "<DSPY_IMAGE_START>" in content:
            message["content"] = expand_image_tags(content)
    return messages

def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
    """Split a content string into text and image parts at each image tag.

    An image tag is ``<DSPY_IMAGE_START>url<DSPY_IMAGE_END>``, optionally
    wrapped in double quotes (the quotes are consumed with the tag).

    Args:
        text: The text content that may contain image tags.

    Returns:
        ``text`` unchanged if it contains no image tag. Otherwise a list of
        content dicts in order of appearance: ``{"type": "text", "text": ...}``
        for each non-blank, whitespace-stripped run of text around the tags,
        and ``{"type": "image_url", "image_url": {"url": ...}}`` for each tag.
    """
    image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'

    parts = []
    cursor = 0
    for tag in re.finditer(image_tag_regex, text):
        # Text between the previous tag (or the start) and this tag.
        leading = text[cursor:tag.start()].strip()
        if leading:
            parts.append({"type": "text", "text": leading})

        parts.append({"type": "image_url", "image_url": {"url": tag.group(1)}})
        cursor = tag.end()

    # No tag was found: hand back the original string untouched.
    if not parts:
        return text

    # Whatever follows the last tag.
    trailing = text[cursor:].strip()
    if trailing:
        parts.append({"type": "text", "text": trailing})

    return parts
3 changes: 2 additions & 1 deletion dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,11 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
Returns:
The formatted value of the field, represented as a string.
"""
# TODO: Wasn't this easy to fix?
if field_info.annotation is Image:
raise NotImplementedError("Images are not yet supported in JSON mode.")

return format_field_value(field_info=field_info, value=value, assume_text=True)
return format_field_value(field_info=field_info, value=value)


def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
Expand Down
Loading

0 comments on commit 9739f82

Please sign in to comment.