support reasoning effort flag (#1087)
pelikhan authored Feb 2, 2025
1 parent c2b51bf commit 9d937f2
Showing 18 changed files with 118 additions and 19 deletions.
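In short, this commit lets a script (or the CLI) request a reasoning effort for OpenAI `o*` models. A minimal sketch of the new option as documented in the diff below; the equivalent CLI flag is `-re, --reasoning-effort <high|medium|low>`:

```js
script({
    model: "openai:o3-mini",
    // new in this commit: one of "high", "medium", "low"
    reasoningEffort: "high",
})
```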
2 changes: 1 addition & 1 deletion .github/workflows/playwright.yml
@@ -36,4 +36,4 @@ jobs:
- name: download ollama docker
run: yarn ollama:start
- name: run browse-text
-run: yarn run:script browse-text --out ./temp/browse-text --model ollama:phi3.5
+run: yarn run:script browse-text --out ./temp/browse-text --model ollama:smollm2:135m
3 changes: 3 additions & 0 deletions docs/src/content/docs/reference/cli/commands.md
@@ -21,6 +21,7 @@ Options:
-sm, --small-model <string> 'small' alias model
-vm, --vision-model <string> 'vision' alias model
-ma, --model-alias <nameid...> model alias as name=modelid
-re, --reasoning-effort <string> Reasoning effort for o* models (choices: "high", "medium", "low")
-lp, --logprobs enable reporting token probabilities
-tlp, --top-logprobs <number> number of top logprobs (1 to 5)
-ef, --excluded-files <string...> excluded files
@@ -100,6 +101,8 @@ Options:
-sm, --small-model <string> 'small' alias model
-vm, --vision-model <string> 'vision' alias model
-ma, --model-alias <nameid...> model alias as name=modelid
-re, --reasoning-effort <string> Reasoning effort for o* models (choices:
"high", "medium", "low")
--models <models...> models to test where mode is the key
value pair list of m (model), s (small
model), t (temperature), p (top-p)
38 changes: 38 additions & 0 deletions docs/src/content/docs/reference/scripts/o-models.mdx
@@ -0,0 +1,38 @@
---
title: o1, o3 models
description: Specific information about OpenAI reasoning models.
sidebar:
order: 100
---

The OpenAI reasoning models, the `o1` and `o3` families, are optimized for reasoning tasks.

```js
script({
model: "openai:o1",
})
```

- You can experiment with these models on GitHub Models as well, but the context window is quite small.

```js
script({
model: "github:o3-mini",
})
```

## Reasoning effort

The reasoning effort parameter can be set to `low`, `medium`, or `high`.

```js 'reasoningEffort: "high"'
script({
model: "openai:o3-mini",
reasoningEffort: "high"
})
```

## Limitations

- `o1-preview` and `o1-mini` do not support streaming.
- `o1` models do not support tool calling, so GenAIScript uses [fallback tools](/genaiscript/reference/scripts/tools), as sketched below.
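The fallback-tools behavior means a tool-using script does not have to change when targeting `o1`. A minimal sketch, assuming the standard `defTool` API; the tool name and body are illustrative only:

```js
script({
    model: "openai:o1",
    reasoningEffort: "medium",
})

// Illustrative tool: since o1 does not support native tool calling,
// GenAIScript routes this call through the fallback tools mechanism.
defTool(
    "get_time",
    "Returns the current time as an ISO string",
    { type: "object", properties: {} },
    () => new Date().toISOString()
)

$`Use the get_time tool and report the current time.`
```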
6 changes: 6 additions & 0 deletions packages/cli/src/cli.ts
@@ -528,5 +528,11 @@ export async function cli() {
"-ma, --model-alias <nameid...>",
"model alias as name=modelid"
)
.addOption(
new Option(
"-re, --reasoning-effort <string>",
"Reasoning effort for o* models"
).choices(["high", "medium", "low"])
)
}
}
2 changes: 2 additions & 0 deletions packages/cli/src/nodehost.ts
@@ -140,6 +140,8 @@ export class NodeHost implements RuntimeHost {
(c as any).model = value.model
if (!isNaN(value.temperature))
(c as any).temperature = value.temperature
if (value.reasoningEffort)
(c as any).reasoningEffort = value.reasoningEffort
}

async pullModel(
2 changes: 2 additions & 0 deletions packages/cli/src/run.ts
@@ -183,6 +183,7 @@ export async function runScriptInternal(
const outData = options.outData
const label = options.label
const temperature = normalizeFloat(options.temperature)
const reasoningEffort = options.reasoningEffort
const topP = normalizeFloat(options.topP)
const seed = normalizeFloat(options.seed)
const maxTokens = normalizeInt(options.maxTokens)
@@ -374,6 +375,7 @@
label,
cache,
temperature,
reasoningEffort,
topP,
seed,
cancellationToken,
2 changes: 2 additions & 0 deletions packages/cli/src/test.ts
@@ -43,6 +43,7 @@ import { filterScripts } from "../../core/src/ast"
import { link } from "../../core/src/mkmd"
import { applyModelOptions } from "./modelalias"
import { normalizeFloat, normalizeInt } from "../../core/src/cleaners"
import { ChatCompletionReasoningEffort } from "../../core/src/chattypes"

/**
* Parses model specifications from a string and returns a ModelOptions object.
@@ -67,6 +68,7 @@ function parseModelSpec(m: string): ModelOptions & ModelAliasesOptions {
visionModel: values["v"],
temperature: normalizeFloat(values["t"]),
topP: normalizeFloat(values["p"]),
reasoningEffort: values["r"] as ChatCompletionReasoningEffort,
}
else return { model: m }
}
7 changes: 6 additions & 1 deletion packages/core/src/chat.ts
@@ -865,6 +865,9 @@ export function mergeGenerationOptions(
temperature:
runOptions?.temperature ??
runtimeHost.modelAliases.large.temperature,
reasoningEffort:
runOptions?.reasoningEffort ??
runtimeHost.modelAliases.large.reasoningEffort,
embeddingsModel:
runOptions?.embeddingsModel ??
options?.embeddingsModel ??
@@ -942,6 +945,7 @@ export async function executeChatSession(
trace,
model,
temperature,
reasoningEffort,
topP,
maxTokens,
seed,
@@ -1018,7 +1022,8 @@
)
req = {
model,
-temperature: temperature,
+temperature,
+reasoning_effort: reasoningEffort,
top_p: topP,
max_tokens: maxTokens,
logit_bias,
2 changes: 2 additions & 0 deletions packages/core/src/chattypes.ts
@@ -17,6 +17,8 @@ export interface AICIRequest extends ChatCompletionMessageParamCacheControl {
functionName: string // Name of the function being requested
}

export type ChatCompletionReasoningEffort = OpenAI.ChatCompletionReasoningEffort

// Aliases for OpenAI chat completion types
export type ChatCompletionUsage = OpenAI.Completions.CompletionUsage
export type ChatCompletionUsageCompletionTokensDetails =
3 changes: 2 additions & 1 deletion packages/core/src/clihelp.ts
@@ -10,7 +10,7 @@ export function generateCliArguments(
options: GenerationOptions,
command: "run" | "batch"
) {
const { model, temperature, topP, seed, cliInfo } = options
const { model, temperature, reasoningEffort, topP, seed, cliInfo } = options
const { files = [] } = cliInfo || {}

const cli = [
@@ -26,6 +26,7 @@
if (!isNaN(temperature)) cli.push(`--temperature`, temperature + "")
if (!isNaN(topP)) cli.push(`--top-p`, topP + "")
if (!isNaN(seed)) cli.push("--seed", seed + "")
if (reasoningEffort) cli.push("--reasoning-effort", reasoningEffort)

return cli.join(" ")
}
7 changes: 7 additions & 0 deletions packages/core/src/constants.ts
@@ -169,6 +169,13 @@ export const MODEL_PROVIDER_JAN = "jan"
export const MODEL_PROVIDER_DEEPSEEK = "deepseek"
export const MODEL_WHISPERASR_PROVIDER = "whisperasr"

export const MODEL_PROVIDER_OPENAI_HOSTS = Object.freeze([
MODEL_PROVIDER_OPENAI,
MODEL_PROVIDER_GITHUB,
MODEL_PROVIDER_AZURE_OPENAI,
MODEL_PROVIDER_AZURE_SERVERLESS_OPENAI
])

export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240

export const OPENROUTER_API_CHAT_URL =
13 changes: 12 additions & 1 deletion packages/core/src/expander.ts
@@ -23,7 +23,11 @@ import { parseModelIdentifier } from "./models"
import { runtimeHost } from "./host"
import { resolveSystems } from "./systems"
import { GenerationOptions } from "./generation"
import { AICIRequest, ChatCompletionMessageParam } from "./chattypes"
import {
AICIRequest,
ChatCompletionMessageParam,
ChatCompletionReasoningEffort,
} from "./chattypes"
import { GenerationStatus, Project } from "./server/messages"
import { dispose } from "./dispose"
import { normalizeFloat, normalizeInt } from "./cleaners"
@@ -193,6 +197,11 @@ export async function expandTemplate(
normalizeFloat(env.vars["temperature"]) ??
template.temperature ??
runtimeHost.modelAliases.large.temperature
const reasoningEffort: ChatCompletionReasoningEffort =
options.reasoningEffort ??
env.vars["reasoning_effort"] ??
template.reasoningEffort ??
runtimeHost.modelAliases.large.reasoningEffort
const topP =
options.topP ?? normalizeFloat(env.vars["top_p"]) ?? template.topP
const maxTokens =
@@ -236,6 +245,7 @@
seed,
topP,
temperature,
reasoningEffort,
lineNumbers,
fenceFormat,
})
@@ -369,6 +379,7 @@
statusText: statusText,
model,
temperature,
reasoningEffort,
topP,
maxTokens,
maxToolCalls,
2 changes: 1 addition & 1 deletion packages/core/src/host.ts
@@ -74,7 +74,7 @@ export interface AzureTokenResolver {
}

export type ModelConfiguration = Readonly<
Pick<ModelOptions, "model" | "temperature"> & {
Pick<ModelOptions, "model" | "temperature" | "reasoningEffort"> & {
source: "cli" | "env" | "script" | "config" | "default"
candidates?: string[]
}
2 changes: 2 additions & 0 deletions packages/core/src/models.ts
@@ -55,6 +55,7 @@ export function traceLanguageModelConnection(
const {
model,
temperature,
reasoningEffort,
topP,
maxTokens,
seed,
@@ -74,6 +75,7 @@
trace.itemValue(`source`, source)
trace.itemValue(`provider`, provider)
trace.itemValue(`temperature`, temperature)
trace.itemValue(`reasoningEffort`, reasoningEffort)
trace.itemValue(`topP`, topP)
trace.itemValue(`maxTokens`, maxTokens)
trace.itemValue(`base`, base)
38 changes: 24 additions & 14 deletions packages/core/src/openai.ts
@@ -3,6 +3,7 @@ import { host } from "./host"
import {
AZURE_AI_INFERENCE_VERSION,
AZURE_OPENAI_API_VERSION,
MODEL_PROVIDER_OPENAI_HOSTS,
MODEL_PROVIDERS,
OPENROUTER_API_CHAT_URL,
OPENROUTER_SITE_NAME_HEADER,
@@ -34,6 +35,7 @@ import {
ChatCompletionChoice,
CreateChatCompletionRequest,
ChatCompletionTokenLogprob,
ChatCompletionReasoningEffort,
} from "./chattypes"
import { resolveTokenEncoder } from "./encoders"
import { CancellationOptions } from "./cancellation"
@@ -90,7 +92,7 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async (
inner,
} = options
const { headers = {}, ...rest } = requestOptions || {}
-const { model } = parseModelIdentifier(req.model)
+const { provider, model } = parseModelIdentifier(req.model)
const { encode: encoder } = await resolveTokenEncoder(model)

const postReq = structuredClone({
@@ -108,21 +110,29 @@
delete postReq.stream_options
}

-    if (/^o1/i.test(model)) {
-        const preview = /^o1-(preview|mini)/i.test(model)
-        delete postReq.temperature
-        delete postReq.stream
-        delete postReq.stream_options
-        for (const msg of postReq.messages) {
-            if (msg.role === "system") {
-                ;(msg as any).role = preview ? "user" : "developer"
+    if (MODEL_PROVIDER_OPENAI_HOSTS.includes(provider)) {
+        if (/^o(1|3)/.test(model)) {
+            delete postReq.temperature
+            if (postReq.max_tokens) {
+                postReq.max_completion_tokens = postReq.max_tokens
+                delete postReq.max_tokens
             }
         }
-    } else if (/^o3/i.test(model)) {
-        delete postReq.temperature
-        for (const msg of postReq.messages) {
-            if (msg.role === "system") {
-                ;(msg as any).role = "developer"
+
+        if (/^o1/.test(model)) {
+            const preview = /^o1-(preview|mini)/i.test(model)
+            delete postReq.stream
+            delete postReq.stream_options
+            for (const msg of postReq.messages) {
+                if (msg.role === "system") {
+                    ;(msg as any).role = preview ? "user" : "developer"
+                }
+            }
+        } else if (/^o3/i.test(model)) {
+            for (const msg of postReq.messages) {
+                if (msg.role === "system") {
+                    ;(msg as any).role = "developer"
+                }
             }
         }
     }
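For illustration only (not part of the diff), the net effect of the `o*` handling above on an outgoing request, assuming a call to `openai:o1` with a system message and `max_tokens` set:

```js
// Hypothetical request body before the o* adjustments above.
const before = {
    model: "o1",
    temperature: 0.5,
    max_tokens: 4000,
    reasoning_effort: "high",
    messages: [{ role: "system", content: "Be concise." }],
}

// What is sent instead: temperature is dropped, max_tokens becomes
// max_completion_tokens, and the system message is rewritten to the
// "developer" role (o1-preview and o1-mini would get "user" instead).
const after = {
    model: "o1",
    max_completion_tokens: 4000,
    reasoning_effort: "high",
    messages: [{ role: "developer", content: "Be concise." }],
}
```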
2 changes: 2 additions & 0 deletions packages/core/src/promptrunner.ts
@@ -161,6 +161,7 @@ export async function runTemplate(
status,
statusText,
temperature,
reasoningEffort,
topP,
maxTokens,
seed,
@@ -249,6 +250,7 @@
responseSchema,
model,
temperature,
reasoningEffort,
maxTokens,
topP,
seed,
1 change: 1 addition & 0 deletions packages/core/src/server/messages.ts
@@ -135,6 +135,7 @@ export interface PromptScriptRunOptions {
outData: string
label: string
temperature: string | number
reasoningEffort: "high" | "low" | "medium"
topP: string | number
seed: string | number
maxTokens: string | number
5 changes: 5 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
@@ -278,6 +278,11 @@ interface ModelOptions extends ModelConnectionOptions, ModelTemplateOptions {
*/
temperature?: number

/**
* Some reasoning models support a reasoning effort parameter.
*/
reasoningEffort?: "high" | "medium" | "low"

/**
* A list of keywords that should be found in the output.
*/
