towards audio #1048

Open · wants to merge 1 commit into base: main
34 changes: 27 additions & 7 deletions packages/core/src/chat.ts
@@ -1,6 +1,11 @@
// cspell: disable
import { MarkdownTrace, TraceOptions } from "./trace"
import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
import {
PromptAudio,
PromptImage,
PromptPrediction,
renderPromptNode,
} from "./promptdom"
import { host, runtimeHost } from "./host"
import { GenerationOptions } from "./generation"
import { dispose } from "./dispose"
@@ -46,6 +51,7 @@ import { parseModelIdentifier, traceLanguageModelConnection } from "./models"
import {
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImage,
ChatCompletionContentPartInputAudio,
ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionsOptions,
@@ -95,9 +101,11 @@ import { deleteUndefinedValues } from "./cleaners"

export function toChatCompletionUserMessage(
expanded: string,
images?: PromptImage[]
images?: PromptImage[],
audios?: PromptAudio[]
): ChatCompletionUserMessageParam {
const imgs = images?.filter(({ url }) => url) || []
const auds = audios?.filter(({ data }) => data) || []
if (imgs.length || auds.length)
return <ChatCompletionUserMessageParam>{
role: "user",
@@ -108,13 +116,23 @@ export function toChatCompletionUserMessage(
},
...imgs.map(
({ url, detail }) =>
<ChatCompletionContentPartImage>{
({
type: "image_url",
image_url: {
url,
detail,
},
}
}) satisfies ChatCompletionContentPartImage
),
...auds.map(
({ data, format }) =>
({
type: "input_audio",
input_audio: {
data,
format,
},
}) satisfies ChatCompletionContentPartInputAudio
),
],
}
@@ -135,9 +153,11 @@ export type ChatCompletionHandler = (
export type ListModelsFunction = (
cfg: LanguageModelConfiguration,
options: TraceOptions & CancellationOptions
) => Promise<ResponseStatus & {
models?: LanguageModelInfo[]
}>
) => Promise<
ResponseStatus & {
models?: LanguageModelInfo[]
}
>

export type PullModelFunction = (
cfg: LanguageModelConfiguration,
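Reviewer note: a minimal sketch of how the extended helper is intended to be called once this lands. The file name and base64 payload are placeholders, and base64 is assumed to be the expected encoding of `PromptAudio.data`:

```ts
import { toChatCompletionUserMessage } from "./chat"
import type { PromptAudio } from "./promptdom"

// Hypothetical attachment; data is assumed to be a base64-encoded WAV payload.
const narration: PromptAudio = {
    filename: "narration.wav",
    data: "<base64 audio bytes>",
    format: "wav",
}

// Produces a user message whose content mixes a text part with an
// { type: "input_audio", input_audio: { data, format } } part.
const message = toChatCompletionUserMessage(
    "Transcribe this recording",
    [],
    [narration]
)
```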
3 changes: 3 additions & 0 deletions packages/core/src/chattypes.ts
@@ -100,6 +100,9 @@ export type ChatCompletionUserMessageParam =
export type ChatCompletionContentPartImage =
OpenAI.Chat.Completions.ChatCompletionContentPartImage

export type ChatCompletionContentPartInputAudio =
OpenAI.Chat.Completions.ChatCompletionContentPartInputAudio

// Parameters for creating embeddings
export type EmbeddingCreateParams = OpenAI.Embeddings.EmbeddingCreateParams

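For reference, the shape the new alias re-exports from the OpenAI SDK (sketch only; the base64 string is a placeholder):

```ts
import type { ChatCompletionContentPartInputAudio } from "./chattypes"

// An input_audio content part as the SDK defines it: base64 data plus format.
const part: ChatCompletionContentPartInputAudio = {
    type: "input_audio",
    input_audio: { data: "<base64 wav bytes>", format: "wav" },
}
```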
12 changes: 10 additions & 2 deletions packages/core/src/expander.ts
@@ -10,6 +10,7 @@ import {
} from "./constants"
import {
finalizeMessages,
PromptAudio,
PromptImage,
PromptPrediction,
renderPromptNode,
@@ -50,6 +51,7 @@ export async function callExpander(
let logs = ""
let messages: ChatCompletionMessageParam[] = []
let images: PromptImage[] = []
let audios: PromptAudio[] = []
let schemas: Record<string, JSONSchema> = {}
let functions: ToolCallback[] = []
let fileMerges: FileMergeHandler[] = []
@@ -82,6 +84,7 @@
const {
messages: msgs,
images: imgs,
audios: auds,
errors,
schemas: schs,
functions: fns,
@@ -98,6 +101,7 @@
})
messages = msgs
images = imgs
audios = auds
schemas = schs
functions = fns
fileMerges = fms
@@ -136,6 +140,7 @@
statusText,
messages,
images,
audios,
schemas,
functions: Object.freeze(functions),
fileMerges,
@@ -247,6 +252,7 @@ export async function expandTemplate(

const { status, statusText, messages } = prompt
const images = prompt.images.slice(0)
const audios = prompt.audios.slice(0)
const schemas = structuredClone(prompt.schemas)
const tools = prompt.functions.slice(0)
const fileMerges = prompt.fileMerges.slice(0)
@@ -279,8 +285,8 @@
}
}

if (prompt.images?.length)
messages.push(toChatCompletionUserMessage("", prompt.images))
if (images?.length || audios?.length)
messages.push(toChatCompletionUserMessage("", images, audios))
if (prompt.aici) messages.push(prompt.aici)

const addSystemMessage = (content: string) => {
@@ -314,6 +320,7 @@
const sysr = await callExpander(prj, system, env, trace, options)

if (sysr.images) images.push(...sysr.images)
if (sysr.audios) audios.push(...sysr.audios)
if (sysr.schemas) Object.assign(schemas, sysr.schemas)
if (sysr.functions) tools.push(...sysr.functions)
if (sysr.fileMerges) fileMerges.push(...sysr.fileMerges)
@@ -394,6 +401,7 @@ ${schemaTs}
cache,
messages,
images,
audios,
schemas,
tools,
status: <GenerationStatus>status,
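Reviewer note: an illustrative condensation of the audio plumbing added to expandTemplate, not code from the PR. `systemAudios` stands in for the audios returned by the per-system callExpander calls:

```ts
import { toChatCompletionUserMessage } from "./chat"
import type { ChatCompletionMessageParam } from "./chattypes"
import type { PromptAudio, PromptImage } from "./promptdom"

// Audios gathered for the main script and for each system script are merged,
// then emitted as a single synthetic user message alongside the images.
function appendMediaMessage(
    messages: ChatCompletionMessageParam[],
    images: PromptImage[],
    promptAudios: PromptAudio[],
    systemAudios: PromptAudio[][]
) {
    const audios = promptAudios.slice(0)
    for (const extra of systemAudios) audios.push(...extra)
    if (images.length || audios.length)
        messages.push(toChatCompletionUserMessage("", images, audios))
}
```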
48 changes: 48 additions & 0 deletions packages/core/src/promptdom.ts
@@ -45,6 +45,7 @@ export interface PromptNode extends ContextExpansionOptions {
type?:
| "text"
| "image"
| "audio"
| "schema"
| "tool"
| "fileMerge"
@@ -150,6 +151,18 @@ export interface PromptImageNode extends PromptNode {
resolved?: PromptImage // Resolved image information
}

export interface PromptAudio {
filename?: string
data: string
format: "mp3" | "wav"
}

export interface PromptAudioNode extends PromptNode {
type: "audio"
value: Awaitable<PromptAudio> // Audio information
resolved?: PromptAudio // Resolved audio information
}

// Interface for a schema node.
export interface PromptSchemaNode extends PromptNode {
type: "schema"
@@ -418,6 +431,15 @@ export function createImageNode(
return { type: "image", value, ...(options || {}) }
}

// Function to create an audio node.
export function createAudioNode(
value: Awaitable<PromptAudio>,
options?: ContextExpansionOptions
): PromptAudioNode {
assert(value !== undefined)
return { type: "audio", value, ...(options || {}) }
}

// Function to create a schema node.
export function createSchemaNode(
name: string,
@@ -556,6 +578,7 @@ export interface PromptNodeVisitor {
def?: (node: PromptDefNode) => Awaitable<void> // Definition node visitor
defData?: (node: PromptDefDataNode) => Awaitable<void> // Definition data node visitor
image?: (node: PromptImageNode) => Awaitable<void> // Image node visitor
audio?: (node: PromptAudioNode) => Awaitable<void> // Audio node visitor
schema?: (node: PromptSchemaNode) => Awaitable<void> // Schema node visitor
tool?: (node: PromptToolNode) => Awaitable<void> // Function node visitor
fileMerge?: (node: PromptFileMergeNode) => Awaitable<void> // File merge node visitor
@@ -585,6 +608,9 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
case "image":
await visitor.image?.(node as PromptImageNode)
break
case "audio":
await visitor.audio?.(node as PromptAudioNode)
break
case "schema":
await visitor.schema?.(node as PromptSchemaNode)
break
@@ -632,6 +658,7 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) {
// Interface for representing a rendered prompt node.
export interface PromptNodeRender {
images: PromptImage[] // Images included in the prompt
audios: PromptAudio[] // Audio attachments included in the prompt
errors: unknown[] // Errors encountered during rendering
schemas: Record<string, JSONSchema> // Schemas included in the prompt
functions: ToolCallback[] // Functions included in the prompt
@@ -847,6 +874,15 @@ async function resolvePromptNode(
n.error = e
}
},
audio: async (n) => {
try {
const v = await n.value
n.resolved = v
n.preview = n.resolved ? `<audio />` : undefined
} catch (e) {
n.error = e
}
},
})
return { errors: err }
}
@@ -1186,6 +1222,7 @@ export async function renderPromptNode(
) => appendAssistantMessage(messages, content, options)

const images: PromptImage[] = []
const audios: PromptAudio[] = []
const errors: unknown[] = []
const schemas: Record<string, JSONSchema> = {}
const tools: ToolCallback[] = []
@@ -1248,6 +1285,16 @@
}
}
},
audio: async (n) => {
const value = n.resolved
if (value?.data) {
audios.push(value)
if (trace) {
trace.startDetails(`🎤 audio ${value.filename || ""}`)
trace.endDetails()
}
}
},
schema: (n) => {
const { name: schemaName, value: schema, options } = n
if (schemas[schemaName])
@@ -1333,6 +1380,7 @@ ${trimNewlines(schemaText)}

const res = Object.freeze<PromptNodeRender>({
images,
audios,
schemas,
functions: tools,
fileMerges,
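Reviewer note: a hypothetical helper showing how a caller might build a PromptAudioNode; resolvePromptNode then awaits the value, previews it as `<audio />`, and renderPromptNode collects it into `PromptNodeRender.audios`. Base64 encoding of `data` is an assumption the diff does not pin down:

```ts
import { readFile } from "node:fs/promises"
import { createAudioNode, type PromptAudioNode } from "./promptdom"

// Hypothetical loader: wrap a local WAV file as an audio node. The lazy value
// is resolved later by resolvePromptNode and surfaced by renderPromptNode.
async function loadAudio(filename: string): Promise<PromptAudioNode> {
    const data = (await readFile(filename)).toString("base64")
    return createAudioNode({ filename, data, format: "wav" })
}
```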
2 changes: 2 additions & 0 deletions packages/core/src/runpromptcontext.ts
@@ -938,6 +938,8 @@ export function createChatGenerationContext(
)
if (sysr.images?.length)
throw new NotSupportedError("images")
if (sysr.audios?.length)
throw new NotSupportedError("audios")
if (sysr.schemas) Object.assign(schemas, sysr.schemas)
if (sysr.functions) tools.push(...sysr.functions)
if (sysr.fileMerges?.length)
5 changes: 5 additions & 0 deletions packages/vscode/src/lmaccess.ts
@@ -55,6 +55,11 @@ function messagesToChatMessages(messages: ChatCompletionMessageParam[]) {
m.content.some((c) => c.type === "image_url")
)
throw new Error("Vision model not supported")
if (
Array.isArray(m.content) &&
m.content.some((c) => c.type === "input_audio")
)
throw new Error("Ayudiuo model not supported")
return vscode.LanguageModelChatMessage.User(
renderMessageContent(m),
"genaiscript"
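For context, a sketch of the message shape the VS Code bridge will now reject, mirroring the existing vision guard. The base64 string is a placeholder and the import path is assumed to follow the pattern used elsewhere in packages/vscode:

```ts
import type { ChatCompletionMessageParam } from "../../core/src/chattypes"

// A user message carrying an input_audio part; messagesToChatMessages(...)
// now throws "Audio model not supported" when it encounters one.
const audioMessage: ChatCompletionMessageParam = {
    role: "user",
    content: [
        { type: "text", text: "What does this recording say?" },
        {
            type: "input_audio",
            input_audio: { data: "<base64 mp3 bytes>", format: "mp3" },
        },
    ],
}
```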