Skip to content

Commit

Permalink
Fragments are now persisted, added basic CLI commands
Browse files Browse the repository at this point in the history
Refs #617
  • Loading branch information
simonw committed Nov 18, 2024
1 parent c07ba03 commit 84f8363
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 17 deletions.
107 changes: 101 additions & 6 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@

from .migrations import migrate
from .plugins import pm, load_plugins
from .utils import mimetype_from_path, mimetype_from_string
from .utils import (
FragmentString,
ensure_fragment,
mimetype_from_path,
mimetype_from_string,
)
import base64
import httpx
import pathlib
Expand All @@ -44,7 +49,7 @@
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Optional, Iterable, Union, Tuple
from typing import cast, Optional, Iterable, List, Union, Tuple
import warnings
import yaml

Expand All @@ -53,18 +58,22 @@
DEFAULT_TEMPLATE = "prompt: "


def resolve_fragments(fragments: Iterable[str]) -> List[FragmentString]:
    """
    Resolve fragment references into their text content.

    Each reference is handled by the first rule that matches:

    - starts with ``http://`` or ``https://``: fetched with httpx,
      following redirects
    - equals ``-``: read from standard input
    - names an existing filesystem path: read as text

    Returns a list of FragmentString objects in the same order as the
    input. Each one is the resolved text, and its ``source`` attribute is
    the URL, ``"-"``, or the resolved absolute path it was loaded from.

    Raises click.ClickException if a reference matches none of the rules.
    An HTTP error status propagates from ``response.raise_for_status()``.
    """
    resolved = []
    for fragment in fragments:
        if fragment.startswith(("http://", "https://")):
            response = httpx.get(fragment, follow_redirects=True)
            response.raise_for_status()
            resolved.append(FragmentString(response.text, fragment))
        elif fragment == "-":
            resolved.append(FragmentString(sys.stdin.read(), "-"))
        else:
            path = pathlib.Path(fragment)
            if not path.exists():
                raise click.ClickException(f"Fragment {fragment} not found")
            # Record the absolute path so the source stays meaningful
            # regardless of the directory the command was run from
            resolved.append(FragmentString(path.read_text(), str(path.resolve())))
    return resolved
Expand Down Expand Up @@ -1226,6 +1235,92 @@ def aliases_path():
click.echo(user_dir() / "aliases.json")


# Running the group with no subcommand falls through to "list"
@cli.group(cls=DefaultGroup, default="list", default_if_no_args=True)
def fragments():
    """Manage fragments"""


@fragments.command(name="list")
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def fragments_list(json_):
    """List current fragments"""
    # NOTE(review): json_ is accepted but never consulted in this body, so
    # the output is JSON whether or not --json is passed - confirm intent.
    database = sqlite_utils.Database(logs_db_path())
    migrate(database)
    # One row per fragment; its aliases are collected into a JSON array
    # (empty when the left join finds no alias rows for the fragment)
    query = """
    select
    fragments.id,
    fragments.hash,
    fragments.content,
    fragments.datetime_utc,
    fragments.source,
    json_group_array(fragment_aliases.alias) filter (
    where
    fragment_aliases.alias is not null
    ) as aliases
    from
    fragments
    left join
    fragment_aliases on fragment_aliases.fragment_id = fragments.id
    group by
    fragments.id, fragments.hash, fragments.content, fragments.datetime_utc, fragments.source;
    """
    rows = []
    for row in database.query(query):
        # aliases comes back from SQLite as a JSON-encoded string
        row["aliases"] = json.loads(row["aliases"])
        rows.append(row)
    click.echo(json.dumps(rows, indent=4))


@fragments.command(name="set")
@click.argument("alias")
@click.argument("fragment")
def fragments_set(alias, fragment):
    """
    Set an alias for a fragment
    Accepts an alias and a file path, URL or '-' for stdin
    Example usage:
    \b
    llm fragments set docs ./docs.md
    """
    # Resolve before touching the database so an unknown fragment fails
    # without opening or migrating it
    (content,) = resolve_fragments([fragment])
    database = sqlite_utils.Database(logs_db_path())
    migrate(database)
    # Upsert: an existing alias is re-pointed at the new fragment
    upsert_alias = """
    insert into fragment_aliases (alias, fragment_id)
    values (:alias, :fragment_id)
    on conflict(alias) do update set
    fragment_id = excluded.fragment_id;
    """
    with database.conn:
        params = {
            "alias": alias,
            "fragment_id": ensure_fragment(database, content),
        }
        database.conn.execute(upsert_alias, params)


@fragments.command(name="remove")
@click.argument("alias")
def fragments_remove(alias):
    """
    Remove a fragment alias
    Example usage:
    \b
    llm fragments remove docs
    """
    database = sqlite_utils.Database(logs_db_path())
    migrate(database)
    delete_alias = "delete from fragment_aliases where alias = :alias"
    # Only the alias row is deleted; the fragments table is not touched here
    with database.conn:
        database.conn.execute(delete_alias, {"alias": alias})


@cli.command(name="plugins")
@click.option("--all", help="Include built-in default plugins", is_flag=True)
def plugins_list(all):
Expand Down
15 changes: 11 additions & 4 deletions llm/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,24 @@ def m013_fragments_tables(db):
"id": int,
"hash": str,
"content": str,
"alias": str,
"datetime_utc": str,
"source": str,
},
pk="id",
)
db["fragments"].create_index(["alias"], unique=True)
db["fragments"].create_index(["hash"], unique=True)
db["fragment_aliases"].create(
{
"alias": str,
"fragment_id": int,
},
foreign_keys=(("fragment_id", "fragments", "id"),),
pk="alias",
)
db["prompt_fragments"].create(
{
"response_id": str,
"fragment_id": str,
"fragment_id": int,
"order": int,
},
foreign_keys=(
Expand All @@ -258,7 +265,7 @@ def m013_fragments_tables(db):
db["system_fragments"].create(
{
"response_id": str,
"fragment_id": str,
"fragment_id": int,
"order": int,
},
foreign_keys=(
Expand Down
33 changes: 26 additions & 7 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Set,
Union,
)
from .utils import mimetype_from_path, mimetype_from_string
from .utils import ensure_fragment, mimetype_from_path, mimetype_from_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
Expand Down Expand Up @@ -231,18 +231,18 @@ def prompt(
FRAGMENT_SQL = """
select
'prompt' as fragment_type,
f.content,
fragments.content,
pf."order" as ord
from prompt_fragments pf
join fragments f on pf.fragment_id = f.id
join fragments on pf.fragment_id = fragments.id
where pf.response_id = :response_id
union all
select
'system' as fragment_type,
f.content,
fragments.content,
sf."order" as ord
from system_fragments sf
join fragments f on sf.fragment_id = f.id
join fragments on sf.fragment_id = fragments.id
where sf.response_id = :response_id
order by fragment_type desc, ord asc;
"""
Expand Down Expand Up @@ -324,8 +324,8 @@ def log_to_db(self, db):
response = {
"id": response_id,
"model": self.model.model_id,
"prompt": self.prompt.prompt,
"system": self.prompt.system,
"prompt": self.prompt._prompt,
"system": self.prompt._system,
"prompt_json": self._prompt_json,
"options_json": {
key: value
Expand All @@ -339,6 +339,25 @@ def log_to_db(self, db):
"datetime_utc": self.datetime_utc(),
}
db["responses"].insert(response)
# Persist any fragments
for i, fragment in enumerate(self.prompt.fragments):
fragment_id = ensure_fragment(db, fragment)
db["prompt_fragments"].insert(
{
"response_id": response_id,
"fragment_id": fragment_id,
"order": i,
},
)
for i, fragment in enumerate(self.prompt.system_fragments):
fragment_id = ensure_fragment(db, fragment)
db["system_fragments"].insert(
{
"response_id": response_id,
"fragment_id": fragment_id,
"order": i,
},
)
# Persist any attachments - loop through with index
for index, attachment in enumerate(self.prompt.attachments):
attachment_id = attachment.id()
Expand Down
34 changes: 34 additions & 0 deletions llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import click
import hashlib
import httpx
import json
import puremagic
Expand All @@ -10,6 +11,22 @@
}


class FragmentString(str):
    """
    A str that also records where its content came from.

    Instances compare, hash and print exactly like the plain string they
    wrap. The extra ``source`` attribute holds whatever origin value was
    passed at construction time.
    """

    def __new__(cls, content, source):
        # str is immutable, so the text must be handed to __new__; the
        # source is attached to the freshly created instance here too
        instance = super().__new__(cls, content)
        instance.source = source
        return instance

    def __getnewargs__(self):
        # copy and pickle rebuild an instance by calling __new__ with these
        # arguments. The inherited str version supplies only the text, which
        # would fail because __new__ also requires source.
        return (str(self), self.source)


def mimetype_from_string(content) -> Optional[str]:
try:
type_ = puremagic.from_string(content, mime=True)
Expand Down Expand Up @@ -127,3 +144,20 @@ def logging_client() -> httpx.Client:
transport=_LogTransport(httpx.HTTPTransport()),
event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
)


def ensure_fragment(db, content):
    """
    Make sure a fragment row exists for content and return its id.

    Fragments are keyed by the SHA-256 hex digest of the UTF-8 encoded
    content. The insert is skipped when a row with that hash already
    exists, and the id is then looked up by hash either way.

    If content is a FragmentString its ``source`` is stored with a newly
    inserted row; for any other string the source is stored as NULL.
    """
    insert_sql = """
    insert into fragments (hash, content, datetime_utc, source)
    values (:hash, :content, datetime('now'), :source)
    on conflict(hash) do nothing
    """
    digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
    source = content.source if isinstance(content, FragmentString) else None
    with db.conn:
        db.execute(
            insert_sql, {"hash": digest, "content": content, "source": source}
        )
    # Look the id up by hash rather than relying on the insert, since the
    # insert is a no-op when the fragment was already stored
    matches = list(
        db.query("select id from fragments where hash = :hash", {"hash": digest})
    )
    return matches[0]["id"]

0 comments on commit 84f8363

Please sign in to comment.