WIP fragments: schema plus reading but not yet writing, refs #617
simonw committed Nov 17, 2024
1 parent 7382301 commit 6c355c1
Showing 3 changed files with 163 additions and 13 deletions.
39 changes: 39 additions & 0 deletions llm/cli.py
@@ -53,6 +53,23 @@
DEFAULT_TEMPLATE = "prompt: "


def resolve_fragments(fragments):
# These can be URLs or paths
resolved = []
for fragment in fragments:
if fragment.startswith("http://") or fragment.startswith("https://"):
response = httpx.get(fragment, follow_redirects=True)
response.raise_for_status()
resolved.append(response.text)
elif fragment == "-":
resolved.append(sys.stdin.read())
elif pathlib.Path(fragment).exists():
resolved.append(pathlib.Path(fragment).read_text())
else:
raise click.ClickException(f"Fragment {fragment} not found")
return resolved
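
A hedged usage sketch of the new resolve_fragments() helper, assuming llm/cli.py from this commit is importable; the URL and filename are placeholders:

from llm.cli import resolve_fragments

resolved = resolve_fragments(
    [
        "https://example.com/context.txt",  # fetched with httpx, redirects followed
        "notes.md",                         # read from disk because the path exists
    ]
)
# resolved holds the fragment texts in the order given; passing "-" instead
# reads one fragment from stdin, and any other value raises a ClickException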


class AttachmentType(click.ParamType):
name = "attachment"

@@ -174,6 +191,16 @@ def cli():
multiple=True,
help="key/value options for the model",
)
@click.option(
"fragments", "-f", "--fragment", multiple=True, help="Fragment to add to prompt"
)
@click.option(
"system_fragments",
"--sf",
"--system-fragment",
multiple=True,
help="Fragment to add to system prompt",
)
@click.option("-t", "--template", help="Template to use")
@click.option(
"-p",
@@ -209,6 +236,8 @@ def prompt(
attachments,
attachment_types,
options,
fragments,
system_fragments,
template,
param,
no_stream,
@@ -266,6 +295,7 @@ def read_prompt():
and sys.stdin.isatty()
and not attachments
and not attachment_types
and not fragments
):
# Hang waiting for input to stdin (unless --save)
prompt = sys.stdin.read()
@@ -377,6 +407,9 @@ def read_prompt():

prompt = read_prompt()

fragments = resolve_fragments(fragments)
system_fragments = resolve_fragments(system_fragments)

prompt_method = model.prompt
if conversation:
prompt_method = conversation.prompt
@@ -388,8 +421,10 @@ async def inner():
if should_stream:
async for chunk in prompt_method(
prompt,
fragments=fragments,
attachments=resolved_attachments,
system=system,
system_fragments=system_fragments,
**validated_options,
):
print(chunk, end="")
@@ -398,8 +433,10 @@ async def inner():
else:
response = prompt_method(
prompt,
fragments=fragments,
attachments=resolved_attachments,
system=system,
system_fragments=system_fragments,
**validated_options,
)
print(await response.text())
@@ -408,8 +445,10 @@ async def inner():
else:
response = prompt_method(
prompt,
fragments=fragments,
attachments=resolved_attachments,
system=system,
system_fragments=system_fragments,
**validated_options,
)
if should_stream:
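
For reference, a hedged sketch of the async streaming path these hunks wire up, via the public Python API; the model name and fragment text are illustrative and a configured API key is assumed:

import asyncio
import llm

async def main() -> None:
    model = llm.get_async_model("gpt-4o-mini")  # illustrative model name
    # prompt() returns an AsyncResponse that can be iterated for streamed chunks
    async for chunk in model.prompt(
        "Summarize the fragment",
        fragments=["Fragment text resolved from a file or URL"],
        system_fragments=["Reply in two sentences."],
    ):
        print(chunk, end="")

asyncio.run(main())
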
39 changes: 39 additions & 0 deletions llm/migrations.py
@@ -227,3 +227,42 @@ def m012_attachments_tables(db):
),
pk=("response_id", "attachment_id"),
)


@migration
def m013_fragments_tables(db):
db["fragments"].create(
{
"id": int,
"hash": str,
"content": str,
"alias": str,
"datetime_utc": str,
"source": str,
},
pk="id",
)
db["prompt_fragments"].create(
{
"response_id": str,
"fragment_id": str,
"order": int,
},
foreign_keys=(
("response_id", "responses", "id"),
("fragment_id", "fragments", "id"),
),
pk=("response_id", "fragment_id"),
)
db["system_fragments"].create(
{
"response_id": str,
"fragment_id": str,
"order": int,
},
foreign_keys=(
("response_id", "responses", "id"),
("fragment_id", "fragments", "id"),
),
pk=("response_id", "fragment_id"),
)
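
The commit message says fragments are read but not yet written; purely as a sketch of what a write path against this schema might look like (assuming sqlite-utils, which the migrations already use, and an existing responses row; the hashing and column values are illustrative):

import hashlib
import sqlite_utils

def save_fragment(db: sqlite_utils.Database, response_id: str, content: str,
                  order: int, table: str = "prompt_fragments") -> None:
    # Store the fragment text itself; alias, source and datetime are left unset here
    fragment_id = db["fragments"].insert(
        {
            "hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
            "content": content,
        }
    ).last_pk
    # Link it to the response, preserving position via the "order" column
    db[table].insert(
        {"response_id": response_id, "fragment_id": fragment_id, "order": order}
    )
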
98 changes: 85 additions & 13 deletions llm/models.py
@@ -89,10 +89,12 @@ def from_row(cls, row):

@dataclass
class Prompt:
prompt: str
_prompt: str
model: "Model"
fragments: Optional[List[str]]
attachments: Optional[List[Attachment]]
system: Optional[str]
_system: Optional[str]
system_fragments: Optional[List[str]]
prompt_json: Optional[str]
options: "Options"

@@ -101,18 +103,55 @@ def __init__(
prompt,
model,
*,
fragments=None,
attachments=None,
system=None,
system_fragments=None,
prompt_json=None,
options=None,
):
self.prompt = prompt
self._prompt = prompt
self.model = model
self.attachments = list(attachments or [])
self.system = system
self.fragments = fragments or []
self._system = system
self.system_fragments = system_fragments or []
self.prompt_json = prompt_json
self.options = options or {}

@property
def prompt(self):
return "\n".join(self.fragments + [self._prompt])

@property
def system(self):
bits = [
bit.strip()
for bit in (self.system_fragments + [self._system or ""])
if bit.strip()
]
return "\n\n".join(bits)

@classmethod
def from_row(cls, db, row, model):
all_fragments = list(db.query(FRAGMENT_SQL, {"response_id": row["id"]}))
fragments = [
row["content"] for row in all_fragments if row["fragment_type"] == "prompt"
]
system_fragments = [
row["content"] for row in all_fragments if row["fragment_type"] == "system"
]
return cls(
prompt=row["prompt"],
model=model,
fragments=fragments,
attachments=[],
system=row["system"],
system_fragments=system_fragments,
options=model.Options(**json.loads(row["options_json"])),
)
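
A small sketch of how the new prompt and system properties compose fragments with the user-supplied text; passing model=None is purely for illustration:

from llm.models import Prompt

p = Prompt(
    "What does this function do?",
    model=None,  # a real Model instance in normal use
    fragments=["def add(a, b):\n    return a + b"],
    system="Answer in one sentence.",
    system_fragments=["You are a code-review assistant."],
)
print(p.prompt)  # fragment text, then the prompt, joined with newlines
print(p.system)  # system fragments, then the system string, joined by blank lines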


@dataclass
class _BaseConversation:
@@ -138,17 +177,21 @@ def prompt(
self,
prompt: Optional[str],
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
**options,
) -> "Response":
return Response(
Prompt(
prompt,
model=self.model,
fragments=fragments,
attachments=attachments,
system=system,
system_fragments=system_fragments,
options=self.model.Options(**options),
),
self.model,
@@ -163,17 +206,21 @@ def prompt(
self,
prompt: Optional[str],
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
**options,
) -> "AsyncResponse":
return AsyncResponse(
Prompt(
prompt,
model=self.model,
fragments=fragments,
attachments=attachments,
system=system,
system_fragments=system_fragments,
options=self.model.Options(**options),
),
self.model,
Expand All @@ -182,6 +229,26 @@ def prompt(
)


FRAGMENT_SQL = """
select
'prompt' as fragment_type,
f.content,
pf."order" as ord
from prompt_fragments pf
join fragments f on pf.fragment_id = f.id
where pf.response_id = :response_id
union all
select
'system' as fragment_type,
f.content,
sf."order" as ord
from system_fragments sf
join fragments f on sf.fragment_id = f.id
where sf.response_id = :response_id
order by fragment_type desc, ord asc;
"""


class _BaseResponse:
"""Base response class shared between sync and async responses"""

@@ -217,13 +284,7 @@ def from_row(cls, db, row):

response = cls(
model=model,
prompt=Prompt(
prompt=row["prompt"],
model=model,
attachments=[],
system=row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
prompt=Prompt.from_row(db, row, model),
stream=False,
)
response.id = row["id"]
@@ -233,8 +294,8 @@ def from_row(cls, db, row):
response._chunks = [row["response"]]
# Attachments
response.attachments = [
Attachment.from_row(arow)
for arow in db.query(
Attachment.from_row(attachment_row)
for attachment_row in db.query(
"""
select attachments.* from attachments
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
@@ -328,6 +389,9 @@ def datetime_utc(self) -> str:
self._force()
return self._start_utcnow.isoformat() if self._start_utcnow else ""

def text_or_raise(self) -> str:
return self.text()

def __iter__(self) -> Iterator[str]:
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.utcnow()
@@ -541,17 +605,21 @@ def prompt(
self,
prompt: str,
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
**options,
) -> Response:
self._validate_attachments(attachments)
return Response(
Prompt(
prompt,
fragments=fragments,
attachments=attachments,
system=system,
system_fragments=system_fragments,
model=self,
options=self.Options(**options),
),
@@ -578,17 +646,21 @@ def prompt(
self,
prompt: str,
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
**options,
) -> AsyncResponse:
self._validate_attachments(attachments)
return AsyncResponse(
Prompt(
prompt,
fragments=fragments,
attachments=attachments,
system=system,
system_fragments=system_fragments,
model=self,
options=self.Options(**options),
),
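
Taken together, these changes extend both the CLI (the new -f/--fragment and --sf/--system-fragment options) and the Python API. A hedged end-to-end sketch of the synchronous API; the model name and fragment contents are illustrative and a configured API key is assumed:

import llm

model = llm.get_model("gpt-4o-mini")
response = model.prompt(
    "What changed between these two drafts?",
    fragments=["Text of draft one", "Text of draft two"],
    system_fragments=["You compare documents and answer tersely."],
)
print(response.text())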
