def resolve_fragments(fragments):
    """Turn fragment references into the text they refer to.

    Each entry of ``fragments`` is handled according to its form:

    - starts with ``http://`` or ``https://``: fetched with ``httpx.get``
      (redirects followed) and the response text is used
    - exactly ``-``: everything available on standard input is used
    - anything else: treated as a filesystem path and read as text

    Returns a list of strings, one per input entry, in the same order.

    Raises ``click.ClickException`` if a URL cannot be fetched (transport
    error or non-success HTTP status) or if a path does not exist or is a
    directory.
    """
    resolved = []
    for fragment in fragments:
        if fragment.startswith(("http://", "https://")):
            try:
                response = httpx.get(fragment, follow_redirects=True)
                response.raise_for_status()
            except httpx.HTTPError as ex:
                # Surface network / status failures as a clean CLI error
                # rather than a raw traceback
                raise click.ClickException(
                    f"Could not fetch fragment {fragment}: {ex}"
                ) from ex
            resolved.append(response.text)
        elif fragment == "-":
            resolved.append(sys.stdin.read())
        else:
            path = pathlib.Path(fragment)
            # A directory "exists" but cannot be read as text, so report it
            # the same way as a missing path instead of crashing in read_text()
            if not path.exists() or path.is_dir():
                raise click.ClickException(f"Fragment {fragment} not found")
            resolved.append(path.read_text())
    return resolved
@migration
def m013_fragments_tables(db):
    """Create the tables used to store prompt and system fragments.

    - ``fragments`` holds each fragment (integer primary key ``id``).
    - ``prompt_fragments`` and ``system_fragments`` link a response to the
      fragments used in its prompt / system prompt, with an ``order`` column
      recording their position.
    """
    db["fragments"].create(
        {
            "id": int,
            "hash": str,
            "content": str,
            "alias": str,
            "datetime_utc": str,
            "source": str,
        },
        pk="id",
    )
    # The two link tables have exactly the same shape, so build them in a loop
    for link_table in ("prompt_fragments", "system_fragments"):
        db[link_table].create(
            {
                "response_id": str,
                # Integer so that it matches fragments.id, which it references
                "fragment_id": int,
                "order": int,
            },
            foreign_keys=(
                ("response_id", "responses", "id"),
                ("fragment_id", "fragments", "id"),
            ),
            # NOTE(review): this key allows a given fragment only once per
            # response - confirm that repeating a fragment is not needed
            pk=("response_id", "fragment_id"),
        )
@dataclass
class Prompt:
    """A prompt to send to a model, optionally assembled from fragments.

    The raw prompt and system prompt passed to the constructor are kept in
    ``_prompt`` and ``_system``. The ``prompt`` and ``system`` properties
    return the text with any fragments placed in front of it.
    """

    _prompt: str
    model: "Model"
    fragments: Optional[List[str]]
    attachments: Optional[List[Attachment]]
    _system: Optional[str]
    system_fragments: Optional[List[str]]
    prompt_json: Optional[str]
    options: "Options"

    def __init__(
        self,
        prompt,
        model,
        *,
        fragments=None,
        attachments=None,
        system=None,
        system_fragments=None,
        prompt_json=None,
        options=None,
    ):
        self._prompt = prompt
        self.model = model
        self.attachments = list(attachments or [])
        # Copy into lists so that tuples (or any other iterable) are accepted
        # and the properties below can safely concatenate with a list
        self.fragments = list(fragments or [])
        self._system = system
        self.system_fragments = list(system_fragments or [])
        self.prompt_json = prompt_json
        self.options = options or {}

    @property
    def prompt(self):
        """Fragments followed by the raw prompt, joined with newlines.

        A raw prompt of ``None`` is left out rather than joined, so the
        result is always a string.
        """
        pieces = list(self.fragments)
        if self._prompt is not None:
            pieces.append(self._prompt)
        return "\n".join(pieces)

    @property
    def system(self):
        """System fragments followed by the raw system prompt.

        Each piece is stripped, blank pieces are dropped, and the rest are
        joined with a blank line between them. Returns ``""`` if nothing
        remains.
        """
        bits = [
            bit.strip()
            for bit in (self.system_fragments + [self._system or ""])
            if bit.strip()
        ]
        return "\n\n".join(bits)

    @classmethod
    def from_row(cls, db, row, model):
        """Rebuild a Prompt from a stored response row.

        ``row`` must provide the keys ``id``, ``prompt``, ``system`` and
        ``options_json``. The fragments linked to ``row["id"]`` are loaded
        from ``db`` using ``FRAGMENT_SQL`` and split into prompt and system
        fragments by their ``fragment_type``. Attachments are not loaded
        here; the prompt is created with an empty attachment list.
        """
        all_fragments = list(db.query(FRAGMENT_SQL, {"response_id": row["id"]}))
        fragments = [
            frag["content"]
            for frag in all_fragments
            if frag["fragment_type"] == "prompt"
        ]
        system_fragments = [
            frag["content"]
            for frag in all_fragments
            if frag["fragment_type"] == "system"
        ]
        return cls(
            prompt=row["prompt"],
            model=model,
            fragments=fragments,
            attachments=[],
            system=row["system"],
            system_fragments=system_fragments,
            options=model.Options(**json.loads(row["options_json"])),
        )
# Query returning every fragment linked to a single response, identified by
# the :response_id named parameter. Each result row has three columns:
#   fragment_type - the literal 'prompt' or 'system', showing which link
#                   table the row came from
#   content       - the fragment text from the fragments table
#   ord           - the "order" value stored in the link table
# Rows are sorted by fragment_type descending ('system' rows before 'prompt'
# rows) and then by ord ascending within each type.
FRAGMENT_SQL = """
select
    'prompt' as fragment_type,
    f.content,
    pf."order" as ord
from prompt_fragments pf
join fragments f on pf.fragment_id = f.id
where pf.response_id = :response_id
union all
select
    'system' as fragment_type,
    f.content,
    sf."order" as ord
from system_fragments sf
join fragments f on sf.fragment_id = f.id
where sf.response_id = :response_id
order by fragment_type desc, ord asc;
"""
-> str: + return self.text() + def __iter__(self) -> Iterator[str]: self._start = time.monotonic() self._start_utcnow = datetime.datetime.utcnow() @@ -541,8 +605,10 @@ def prompt( self, prompt: str, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, **options, ) -> Response: @@ -550,8 +616,10 @@ def prompt( return Response( Prompt( prompt, + fragments=fragments, attachments=attachments, system=system, + system_fragments=system_fragments, model=self, options=self.Options(**options), ), @@ -578,8 +646,10 @@ def prompt( self, prompt: str, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, **options, ) -> AsyncResponse: @@ -587,8 +657,10 @@ def prompt( return AsyncResponse( Prompt( prompt, + fragments=fragments, attachments=attachments, system=system, + system_fragments=system_fragments, model=self, options=self.Options(**options), ),