Skip to content

Commit

Permalink
Special case treat audio/wave as audio/wav, closes #603
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 8, 2024
1 parent febbc04 commit 5d1d723
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 19 deletions.
11 changes: 6 additions & 5 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@

from .migrations import migrate
from .plugins import pm
from .utils import mimetype_from_path, mimetype_from_string
import base64
import httpx
import pathlib
import puremagic
import pydantic
import readline
from runpy import run_module
Expand All @@ -58,9 +58,8 @@ def convert(self, value, param, ctx):
if value == "-":
content = sys.stdin.buffer.read()
# Try to guess type
try:
mimetype = puremagic.from_string(content, mime=True)
except puremagic.PureError:
mimetype = mimetype_from_string(content)
if mimetype is None:
raise click.BadParameter("Could not determine mimetype of stdin")
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
Expand All @@ -78,7 +77,9 @@ def convert(self, value, param, ctx):
self.fail(f"File {value} does not exist", param, ctx)
path = path.resolve()
# Try to guess type
mimetype = puremagic.from_file(str(path), mime=True)
mimetype = mimetype_from_path(str(path))
if mimetype is None:
raise click.BadParameter(f"Could not determine mimetype of {value}")
return Attachment(type=mimetype, path=str(path), url=None, content=None)


Expand Down
6 changes: 3 additions & 3 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import hashlib
import httpx
from itertools import islice
import puremagic
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from .utils import mimetype_from_path, mimetype_from_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
Expand Down Expand Up @@ -43,13 +43,13 @@ def resolve_type(self):
return self.type
# Derive it from path or url or content
if self.path:
return puremagic.from_file(self.path, mime=True)
return mimetype_from_path(self.path)
if self.url:
response = httpx.head(self.url)
response.raise_for_status()
return response.headers.get("content-type")
if self.content:
return puremagic.from_string(self.content, mime=True)
return mimetype_from_string(self.content)
raise ValueError("Attachment has no type and no content to derive it from")

def content_bytes(self):
Expand Down
23 changes: 22 additions & 1 deletion llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
import click
import httpx
import json
import puremagic
import textwrap
from typing import List, Dict
from typing import List, Dict, Optional

MIME_TYPE_FIXES = {
"audio/wave": "audio/wav",
}


def mimetype_from_string(content) -> Optional[str]:
try:
type_ = puremagic.from_string(content, mime=True)
return MIME_TYPE_FIXES.get(type_, type_)
except puremagic.PureError:
return None


def mimetype_from_path(path) -> Optional[str]:
try:
type_ = puremagic.from_file(path, mime=True)
return MIME_TYPE_FIXES.get(type_, type_)
except puremagic.PureError:
return None


def dicts_to_table_string(
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def env_setup(monkeypatch, user_path):

class MockModel(llm.Model):
model_id = "mock"
attachment_types = {"image/png"}
attachment_types = {"image/png", "audio/wav"}

class Options(llm.Options):
max_tokens: Optional[int] = Field(
Expand Down
29 changes: 20 additions & 9 deletions tests/test_attachments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from click.testing import CliRunner
from unittest.mock import ANY
import llm
from llm import cli
import pytest

TINY_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xa6\x00\x00\x01\x1a"
Expand All @@ -12,36 +14,45 @@
b"\x82"
)

TINY_WAV = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00D\xac\x00\x00"

def test_prompt_image(mock_model, logs_db):

@pytest.mark.parametrize(
"attachment_type,attachment_content",
[
("image/png", TINY_PNG),
("audio/wav", TINY_WAV),
],
)
def test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):
runner = CliRunner()
mock_model.enqueue(["two boxes"])
result = runner.invoke(
llm.cli.cli,
["prompt", "-m", "mock", "describe image", "-a", "-"],
input=TINY_PNG,
cli.cli,
["prompt", "-m", "mock", "describe file", "-a", "-"],
input=attachment_content,
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert result.output == "two boxes\n"
assert mock_model.history[0][0].attachments[0] == llm.Attachment(
type="image/png", path=None, url=None, content=TINY_PNG, _id=ANY
type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY
)

# Check it was logged correctly
conversations = list(logs_db["conversations"].rows)
assert len(conversations) == 1
conversation = conversations[0]
assert conversation["model"] == "mock"
assert conversation["name"] == "describe image"
assert conversation["name"] == "describe file"
response = list(logs_db["responses"].rows)[0]
attachment = list(logs_db["attachments"].rows)[0]
assert attachment == {
"id": ANY,
"type": "image/png",
"type": attachment_type,
"path": None,
"url": None,
"content": TINY_PNG,
"content": attachment_content,
}
prompt_attachment = list(logs_db["prompt_attachments"].rows)[0]
assert prompt_attachment["attachment_id"] == attachment["id"]
Expand Down

0 comments on commit 5d1d723

Please sign in to comment.