diff --git a/llm/cli.py b/llm/cli.py
index ad7aeb4f..6a6fb2cf 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -30,10 +30,10 @@
 
 from .migrations import migrate
 from .plugins import pm
+from .utils import mimetype_from_path, mimetype_from_string
 import base64
 import httpx
 import pathlib
-import puremagic
 import pydantic
 import readline
 from runpy import run_module
@@ -58,9 +58,8 @@ def convert(self, value, param, ctx):
         if value == "-":
             content = sys.stdin.buffer.read()
             # Try to guess type
-            try:
-                mimetype = puremagic.from_string(content, mime=True)
-            except puremagic.PureError:
+            mimetype = mimetype_from_string(content)
+            if mimetype is None:
                 raise click.BadParameter("Could not determine mimetype of stdin")
             return Attachment(type=mimetype, path=None, url=None, content=content)
         if "://" in value:
@@ -78,7 +77,9 @@ def convert(self, value, param, ctx):
             self.fail(f"File {value} does not exist", param, ctx)
         path = path.resolve()
         # Try to guess type
-        mimetype = puremagic.from_file(str(path), mime=True)
+        mimetype = mimetype_from_path(str(path))
+        if mimetype is None:
+            raise click.BadParameter(f"Could not determine mimetype of {value}")
         return Attachment(type=mimetype, path=str(path), url=None, content=None)
 
 
diff --git a/llm/models.py b/llm/models.py
index 838e25b1..485d9720 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -5,10 +5,10 @@
 import hashlib
 import httpx
 from itertools import islice
-import puremagic
 import re
 import time
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
+from .utils import mimetype_from_path, mimetype_from_string
 from abc import ABC, abstractmethod
 import json
 from pydantic import BaseModel
@@ -43,13 +43,13 @@ def resolve_type(self):
             return self.type
         # Derive it from path or url or content
         if self.path:
-            return puremagic.from_file(self.path, mime=True)
+            return mimetype_from_path(self.path)
         if self.url:
             response = httpx.head(self.url)
             response.raise_for_status()
             return response.headers.get("content-type")
         if self.content:
-            return puremagic.from_string(self.content, mime=True)
+            return mimetype_from_string(self.content)
         raise ValueError("Attachment has no type and no content to derive it from")
 
     def content_bytes(self):
diff --git a/llm/utils.py b/llm/utils.py
index 2ea9870a..d2618dd4 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -1,8 +1,41 @@
 import click
 import httpx
 import json
+import puremagic
 import textwrap
-from typing import List, Dict
+from typing import List, Dict, Optional
+
+# Normalizes MIME types reported by puremagic to the spelling used for
+# attachment types (see ``attachment_types`` on models).
+MIME_TYPE_FIXES = {
+    "audio/wave": "audio/wav",
+}
+
+
+def _detect_mimetype(detector, source) -> Optional[str]:
+    """Run a puremagic detector and return the normalized MIME type.
+
+    ``detector`` is ``puremagic.from_string`` or ``puremagic.from_file`` and
+    ``source`` is the content or path handed to it. Returns ``None`` when no
+    type can be determined.
+    """
+    try:
+        type_ = detector(source, mime=True)
+    except (puremagic.PureError, ValueError):
+        # PureError: no signature matched. ValueError: the input was empty.
+        return None
+    # An empty MIME string means "unknown", so report it the same way
+    return MIME_TYPE_FIXES.get(type_, type_) or None
+
+
+def mimetype_from_string(content) -> Optional[str]:
+    """Guess the MIME type of in-memory content, or None if undetermined."""
+    return _detect_mimetype(puremagic.from_string, content)
+
+
+def mimetype_from_path(path) -> Optional[str]:
+    """Guess the MIME type of the file at ``path``, or None if undetermined."""
+    return _detect_mimetype(puremagic.from_file, path)
 
 
 def dicts_to_table_string(
diff --git a/tests/conftest.py b/tests/conftest.py
index bcdb8854..7d44b757 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -49,7 +49,7 @@ def env_setup(monkeypatch, user_path):
 
 class MockModel(llm.Model):
     model_id = "mock"
-    attachment_types = {"image/png"}
+    attachment_types = {"image/png", "audio/wav"}
 
     class Options(llm.Options):
         max_tokens: Optional[int] = Field(
diff --git a/tests/test_attachments.py b/tests/test_attachments.py
index 89a5b81a..e5417d47 100644
--- a/tests/test_attachments.py
+++ b/tests/test_attachments.py
@@ -1,6 +1,8 @@
 from click.testing import CliRunner
 from unittest.mock import ANY
 import llm
+from llm import cli
+import pytest
 
 TINY_PNG = (
     b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xa6\x00\x00\x01\x1a"
@@ -12,20 +14,29 @@
     b"\x82"
 )
 
+TINY_WAV = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00D\xac\x00\x00"
 
-def test_prompt_image(mock_model, logs_db):
+
+@pytest.mark.parametrize(
+    "attachment_type,attachment_content",
+    [
+        ("image/png", TINY_PNG),
+        ("audio/wav", TINY_WAV),
+    ],
+)
+def test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):
     runner = CliRunner()
     mock_model.enqueue(["two boxes"])
     result = runner.invoke(
-        llm.cli.cli,
-        ["prompt", "-m", "mock", "describe image", "-a", "-"],
-        input=TINY_PNG,
+        cli.cli,
+        ["prompt", "-m", "mock", "describe file", "-a", "-"],
+        input=attachment_content,
         catch_exceptions=False,
     )
-    assert result.exit_code == 0
+    assert result.exit_code == 0, result.output
     assert result.output == "two boxes\n"
     assert mock_model.history[0][0].attachments[0] == llm.Attachment(
-        type="image/png", path=None, url=None, content=TINY_PNG, _id=ANY
+        type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY
     )
 
     # Check it was logged correctly
@@ -33,15 +44,15 @@ def test_prompt_image(mock_model, logs_db):
     assert len(conversations) == 1
     conversation = conversations[0]
     assert conversation["model"] == "mock"
-    assert conversation["name"] == "describe image"
+    assert conversation["name"] == "describe file"
     response = list(logs_db["responses"].rows)[0]
     attachment = list(logs_db["attachments"].rows)[0]
     assert attachment == {
         "id": ANY,
-        "type": "image/png",
+        "type": attachment_type,
         "path": None,
         "url": None,
-        "content": TINY_PNG,
+        "content": attachment_content,
     }
     prompt_attachment = list(logs_db["prompt_attachments"].rows)[0]
     assert prompt_attachment["attachment_id"] == attachment["id"]