From 84f8363581c7672966d53fde8fc6eb606313c97f Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 17 Nov 2024 21:18:04 -0800 Subject: [PATCH] Fragments are now persisted, added basic CLI commands Refs #617 --- llm/cli.py | 107 +++++++++++++++++++++++++++++++++++++++++++--- llm/migrations.py | 15 +++++-- llm/models.py | 33 +++++++++++--- llm/utils.py | 34 +++++++++++++++ 4 files changed, 172 insertions(+), 17 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index d1261d2e..1e766c7d 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -32,7 +32,12 @@ from .migrations import migrate from .plugins import pm, load_plugins -from .utils import mimetype_from_path, mimetype_from_string +from .utils import ( + FragmentString, + ensure_fragment, + mimetype_from_path, + mimetype_from_string, +) import base64 import httpx import pathlib @@ -44,7 +49,7 @@ from sqlite_utils.utils import rows_from_file, Format import sys import textwrap -from typing import cast, Optional, Iterable, Union, Tuple +from typing import cast, Optional, Iterable, List, Union, Tuple import warnings import yaml @@ -53,18 +58,22 @@ DEFAULT_TEMPLATE = "prompt: " -def resolve_fragments(fragments): +def resolve_fragments(fragments: Iterable[str]) -> List[Tuple[str, str]]: + """ + Resolve fragments into a list of (content, source) tuples + """ # These can be URLs or paths resolved = [] for fragment in fragments: if fragment.startswith("http://") or fragment.startswith("https://"): response = httpx.get(fragment, follow_redirects=True) response.raise_for_status() - resolved.append(response.text) + resolved.append(FragmentString(response.text, fragment)) elif fragment == "-": - resolved.append(sys.stdin.read()) + resolved.append(FragmentString(sys.stdin.read(), "-")) elif pathlib.Path(fragment).exists(): - resolved.append(pathlib.Path(fragment).read_text()) + path = pathlib.Path(fragment) + resolved.append(FragmentString(path.read_text(), str(path.resolve()))) else: raise click.ClickException(f"Fragment {fragment} not found") return resolved @@ -1226,6 +1235,92 @@ def aliases_path(): click.echo(user_dir() / "aliases.json") +@cli.group( + cls=DefaultGroup, + default="list", + default_if_no_args=True, +) +def fragments(): + "Manage fragments" + + +@fragments.command(name="list") +@click.option("json_", "--json", is_flag=True, help="Output as JSON") +def fragments_list(json_): + "List current fragments" + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + sql = """ + select + fragments.id, + fragments.hash, + fragments.content, + fragments.datetime_utc, + fragments.source, + json_group_array(fragment_aliases.alias) filter ( + where + fragment_aliases.alias is not null + ) as aliases + from + fragments + left join + fragment_aliases on fragment_aliases.fragment_id = fragments.id + group by + fragments.id, fragments.hash, fragments.content, fragments.datetime_utc, fragments.source; + """ + results = list(db.query(sql)) + for result in results: + result["aliases"] = json.loads(result["aliases"]) + click.echo(json.dumps(results, indent=4)) + + +@fragments.command(name="set") +@click.argument("alias") +@click.argument("fragment") +def fragments_set(alias, fragment): + """ + Set an alias for a fragment + + Accepts an alias and a file path, URL or '-' for stdin + + Example usage: + + \b + llm fragments set docs ./docs.md + """ + resolved = resolve_fragments([fragment])[0] + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + alias_sql = """ + insert into fragment_aliases (alias, fragment_id) + values (:alias, :fragment_id) + on conflict(alias) do update set + fragment_id = excluded.fragment_id; + """ + with db.conn: + fragment_id = ensure_fragment(db, resolved) + db.conn.execute(alias_sql, {"alias": alias, "fragment_id": fragment_id}) + + +@fragments.command(name="remove") +@click.argument("alias") +def fragments_remove(alias): + """ + Remove a fragment alias + + Example usage: + + \b + llm fragments remove docs + """ + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + with db.conn: + db.conn.execute( + "delete from fragment_aliases where alias = :alias", {"alias": alias} + ) + + @cli.command(name="plugins") @click.option("--all", help="Include built-in default plugins", is_flag=True) def plugins_list(all): diff --git a/llm/migrations.py b/llm/migrations.py index e0e7c36b..eb607422 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -236,17 +236,24 @@ def m013_fragments_tables(db): "id": int, "hash": str, "content": str, - "alias": str, "datetime_utc": str, "source": str, }, pk="id", ) - db["fragments"].create_index(["alias"], unique=True) + db["fragments"].create_index(["hash"], unique=True) + db["fragment_aliases"].create( + { + "alias": str, + "fragment_id": int, + }, + foreign_keys=(("fragment_id", "fragments", "id"),), + pk="alias", + ) db["prompt_fragments"].create( { "response_id": str, - "fragment_id": str, + "fragment_id": int, "order": int, }, foreign_keys=( @@ -258,7 +265,7 @@ def m013_fragments_tables(db): db["system_fragments"].create( { "response_id": str, - "fragment_id": str, + "fragment_id": int, "order": int, }, foreign_keys=( diff --git a/llm/models.py b/llm/models.py index f61d526c..328de4d1 100644 --- a/llm/models.py +++ b/llm/models.py @@ -18,7 +18,7 @@ Set, Union, ) -from .utils import mimetype_from_path, mimetype_from_string +from .utils import ensure_fragment, mimetype_from_path, mimetype_from_string from abc import ABC, abstractmethod import json from pydantic import BaseModel @@ -231,18 +231,18 @@ def prompt( FRAGMENT_SQL = """ select 'prompt' as fragment_type, - f.content, + fragments.content, pf."order" as ord from prompt_fragments pf -join fragments f on pf.fragment_id = f.id +join fragments on pf.fragment_id = fragments.id where pf.response_id = :response_id union all select 'system' as fragment_type, - f.content, + fragments.content, sf."order" as ord from system_fragments sf -join fragments f on sf.fragment_id = f.id +join fragments on sf.fragment_id = fragments.id where sf.response_id = :response_id order by fragment_type desc, ord asc; """ @@ -324,8 +324,8 @@ def log_to_db(self, db): response = { "id": response_id, "model": self.model.model_id, - "prompt": self.prompt.prompt, - "system": self.prompt.system, + "prompt": self.prompt._prompt, + "system": self.prompt._system, "prompt_json": self._prompt_json, "options_json": { key: value @@ -339,6 +339,25 @@ def log_to_db(self, db): "datetime_utc": self.datetime_utc(), } db["responses"].insert(response) + # Persist any fragments + for i, fragment in enumerate(self.prompt.fragments): + fragment_id = ensure_fragment(db, fragment) + db["prompt_fragments"].insert( + { + "response_id": response_id, + "fragment_id": fragment_id, + "order": i, + }, + ) + for i, fragment in enumerate(self.prompt.system_fragments): + fragment_id = ensure_fragment(db, fragment) + db["system_fragments"].insert( + { + "response_id": response_id, + "fragment_id": fragment_id, + "order": i, + }, + ) # Persist any attachments - loop through with index for index, attachment in enumerate(self.prompt.attachments): attachment_id = attachment.id() diff --git a/llm/utils.py b/llm/utils.py index d2618dd4..94db09d7 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -1,4 +1,5 @@ import click +import hashlib import httpx import json import puremagic @@ -10,6 +11,22 @@ } +class FragmentString(str): + def __new__(cls, content, source): + # We need to use __new__ since str is immutable + instance = super().__new__(cls, content) + return instance + + def __init__(self, content, source): + self.source = source + + def __str__(self): + return super().__str__() + + def __repr__(self): + return super().__repr__() + + def mimetype_from_string(content) -> Optional[str]: try: type_ = puremagic.from_string(content, mime=True) @@ -127,3 +144,20 @@ def logging_client() -> httpx.Client: transport=_LogTransport(httpx.HTTPTransport()), event_hooks={"request": [_no_accept_encoding], "response": [_log_response]}, ) + + +def ensure_fragment(db, content): + sql = """ + insert into fragments (hash, content, datetime_utc, source) + values (:hash, :content, datetime('now'), :source) + on conflict(hash) do nothing + """ + hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + source = None + if isinstance(content, FragmentString): + source = content.source + with db.conn: + db.execute(sql, {"hash": hash, "content": content, "source": source}) + return list( + db.query("select id from fragments where hash = :hash", {"hash": hash}) + )[0]["id"]