Skip to content

Commit

Permalink
properly resolve fallbackt ools
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Feb 2, 2025
1 parent 04ce680 commit ce217f7
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 23 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/genai-pr-review.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ env:
# Configure default GenAIScript models
# using Ollama's models
GENAISCRIPT_DEFAULT_MODEL: ollama:qwen2.5-coder:7b
GENAISCRIPT_DEFAULT_REASONING_MODEL: ollama:deepseek-r1:1.5b
GENAISCRIPT_DEFAULT_SMALL_MODEL: ollama:qwen2.5-coder:1.5b
GENAISCRIPT_DEFAULT_VISION_MODEL: ollama:llama3.2-vision:11b
jobs:
Expand All @@ -40,11 +41,11 @@ jobs:
run: git fetch origin && git pull origin main:main
- name: genaiscript pr-describe
continue-on-error: true
run: node packages/cli/built/genaiscript.cjs run pr-describe --out ./temp/genai/pr-describe -prd --out-trace $GITHUB_STEP_SUMMARY
run: node packages/cli/built/genaiscript.cjs run pr-describe --out ./temp/genai/pr-describe -prd -m reasoning --out-trace $GITHUB_STEP_SUMMARY
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: genaiscript pr-review
run: node packages/cli/built/genaiscript.cjs run pr-review --out ./temp/genai/pr-review -prc --out-trace $GITHUB_STEP_SUMMARY
run: node packages/cli/built/genaiscript.cjs run pr-review --out ./temp/genai/pr-review -prc -m reasoning --out-trace $GITHUB_STEP_SUMMARY
continue-on-error: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion packages/cli/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ export async function runScriptInternal(
const removeOut = options.removeOut
const cancellationToken = options.cancellationToken
const jsSource = options.jsSource
const fallbackTools = !!options.fallbackTools
const fallbackTools = options.fallbackTools
const logprobs = options.logprobs
const topLogprobs = normalizeInt(options.topLogprobs)
const fenceFormat = options.fenceFormat
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/expander.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { addToolDefinitionsMessage, appendSystemMessage } from "./chat"
import { importPrompt } from "./importprompt"
import { parseModelIdentifier } from "./models"
import { runtimeHost } from "./host"
import { resolveSystems } from "./systems"
import { addFallbackToolSystems, resolveSystems } from "./systems"
import { GenerationOptions } from "./generation"
import {
AICIRequest,
Expand Down Expand Up @@ -356,7 +356,7 @@ export async function expandTemplate(
trace.endDetails()
}

if (systems.includes("system.tool_calls")) {
if (addFallbackToolSystems(systems, tools, template, options)) {
addToolDefinitionsMessage(messages, tools)
options.fallbackTools = true
}
Expand Down
25 changes: 15 additions & 10 deletions packages/core/src/runpromptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ import {
SPEECH_MODEL_ID,
} from "./constants"
import { renderAICI } from "./aici"
import { resolveSystems, resolveTools } from "./systems"
import { addFallbackToolSystems, resolveSystems, resolveTools } from "./systems"
import { callExpander } from "./expander"
import {
errorMessage,
Expand Down Expand Up @@ -992,19 +992,24 @@ export function createChatGenerationContext(
} finally {
runTrace.endDetails()
}
if (systemScripts.includes("system.tool_calls")) {

if (
addFallbackToolSystems(
systemScripts,
tools,
runOptions,
genOptions
)
) {
addToolDefinitionsMessage(messages, tools)
genOptions.fallbackTools = true
}

finalizeMessages(
messages,
{
...(runOptions || {}),
fileOutputs,
trace,
}
)
finalizeMessages(messages, {
...(runOptions || {}),
fileOutputs,
trace,
})
const { completer } = await resolveLanguageModel(
configuration.provider
)
Expand Down
20 changes: 15 additions & 5 deletions packages/core/src/systems.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,27 @@ export function resolveSystems(
.filter((s) => !!s)
.filter((s) => !excludedSystem.includes(s))

const fallbackTools =
isToolsSupported(options?.model) === false || options?.fallbackTools
if (fallbackTools && (tools.length || resolvedTools?.length))
systems.push("system.tool_calls")

// Return a unique list of non-empty systems
// Filters out duplicates and empty entries using unique utility
const res = uniq(systems)
return res
}

export function addFallbackToolSystems(
systems: string[],
tools: ToolCallback[],
options?: ModelOptions,
genOptions?: GenerationOptions
) {
if (!tools?.length || systems.includes("system.tool_calls")) return false

const fallbackTools =
isToolsSupported(options?.model || genOptions?.model) === false ||
genOptions?.fallbackTools
if (fallbackTools) systems.push("system.tool_calls")
return fallbackTools
}

/**
* Helper function to resolve tools in the project and return their system IDs.
* Finds systems in the project associated with a specific tool.
Expand Down
8 changes: 5 additions & 3 deletions packages/core/src/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import { parseModelIdentifier } from "./models"

export function isToolsSupported(modelId: string): boolean | undefined {
if (!modelId) return undefined
const { provider, model } = parseModelIdentifier(modelId)
const { provider, family } = parseModelIdentifier(modelId)

const info = MODEL_PROVIDERS.find(({ id }) => provider === id)
if (info?.tools === false) return false

if (/^o1-(mini|preview)/.test(model)) return false
if (/^o1-(mini|preview)/.test(family)) return false

const oai = {
"o1-preview": false,
Expand All @@ -39,6 +39,7 @@ export function isToolsSupported(modelId: string): boolean | undefined {
["llama2"]: false,
["codellama"]: false,
["phi"]: false,
["deepseek-r1"]: false,
},
[MODEL_PROVIDER_OPENAI]: oai,
[MODEL_PROVIDER_AZURE_OPENAI]: oai,
Expand All @@ -48,5 +49,6 @@ export function isToolsSupported(modelId: string): boolean | undefined {
},
}

return data[provider]?.[model]
const res = data[provider]?.[family]
return res
}

0 comments on commit ce217f7

Please sign in to comment.