allow temperature to be configured per mode

This commit is contained in:
Dax Raad
2025-07-25 13:29:29 -04:00
parent 827469c725
commit e97613ef9f
22 changed files with 234 additions and 195 deletions

View File

@@ -99,6 +99,7 @@ export namespace Config {
export const Mode = z
.object({
model: z.string().optional(),
temperature: z.number().optional(),
prompt: z.string().optional(),
tools: z.record(z.string(), z.boolean()).optional(),
disable: z.boolean().optional(),

View File

@@ -5,22 +5,11 @@ import { mergeDeep, sortBy } from "remeda"
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
import { Log } from "../util/log"
import { BunProc } from "../bun"
import { BashTool } from "../tool/bash"
import { EditTool } from "../tool/edit"
import { WebFetchTool } from "../tool/webfetch"
import { GlobTool } from "../tool/glob"
import { GrepTool } from "../tool/grep"
import { ListTool } from "../tool/ls"
import { PatchTool } from "../tool/patch"
import { ReadTool } from "../tool/read"
import { WriteTool } from "../tool/write"
import { TodoReadTool, TodoWriteTool } from "../tool/todo"
import { AuthAnthropic } from "../auth/anthropic"
import { AuthCopilot } from "../auth/copilot"
import { ModelsDev } from "./models"
import { NamedError } from "../util/error"
import { Auth } from "../auth"
import { TaskTool } from "../tool/task"
export namespace Provider {
const log = Log.create({ service: "provider" })
@@ -468,137 +457,6 @@ export namespace Provider {
}
}
const TOOLS = [
BashTool,
EditTool,
WebFetchTool,
GlobTool,
GrepTool,
ListTool,
// LspDiagnosticTool,
// LspHoverTool,
PatchTool,
ReadTool,
// MultiEditTool,
WriteTool,
TodoWriteTool,
TodoReadTool,
TaskTool,
]
export async function tools(providerID: string) {
const result = await Promise.all(TOOLS.map((t) => t()))
switch (providerID) {
case "anthropic":
return result.filter((t) => t.id !== "patch")
case "openai":
return result.map((t) => ({
...t,
parameters: optionalToNullable(t.parameters),
}))
case "azure":
return result.map((t) => ({
...t,
parameters: optionalToNullable(t.parameters),
}))
case "google":
return result.map((t) => ({
...t,
parameters: sanitizeGeminiParameters(t.parameters),
}))
default:
return result
}
}
function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
if (!schema || visited.has(schema)) {
return schema
}
visited.add(schema)
if (schema instanceof z.ZodDefault) {
const innerSchema = schema.removeDefault()
// Handle Gemini's incompatibility with `default` on `anyOf` (unions).
if (innerSchema instanceof z.ZodUnion) {
// The schema was `z.union(...).default(...)`, which is not allowed.
// We strip the default and return the sanitized union.
return sanitizeGeminiParameters(innerSchema, visited)
}
// Otherwise, the default is on a regular type, which is allowed.
// We recurse on the inner type and then re-apply the default.
return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
}
if (schema instanceof z.ZodOptional) {
return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
}
if (schema instanceof z.ZodObject) {
const newShape: Record<string, z.ZodTypeAny> = {}
for (const [key, value] of Object.entries(schema.shape)) {
newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
}
return z.object(newShape)
}
if (schema instanceof z.ZodArray) {
return z.array(sanitizeGeminiParameters(schema.element, visited))
}
if (schema instanceof z.ZodUnion) {
// This schema corresponds to `anyOf` in JSON Schema.
// We recursively sanitize each option in the union.
const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
if (schema instanceof z.ZodString) {
const newSchema = z.string({ description: schema.description })
const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
// rome-ignore lint/suspicious/noExplicitAny: <explanation>
;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
safeChecks.includes(check.kind),
)
return newSchema
}
return schema
}
function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
if (schema instanceof z.ZodObject) {
const shape = schema.shape
const newShape: Record<string, z.ZodTypeAny> = {}
for (const [key, value] of Object.entries(shape)) {
const zodValue = value as z.ZodTypeAny
if (zodValue instanceof z.ZodOptional) {
newShape[key] = zodValue.unwrap().nullable()
} else {
newShape[key] = optionalToNullable(zodValue)
}
}
return z.object(newShape)
}
if (schema instanceof z.ZodArray) {
return z.array(optionalToNullable(schema.element))
}
if (schema instanceof z.ZodUnion) {
return z.union(
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
z.ZodTypeAny,
z.ZodTypeAny,
...z.ZodTypeAny[],
],
)
}
return schema
}
export const ModelNotFoundError = NamedError.create(
"ProviderModelNotFoundError",
z.object({

View File

@@ -44,4 +44,9 @@ export namespace ProviderTransform {
}
return msgs
}
export function temperature(_providerID: string, modelID: string) {
if (modelID.includes("qwen")) return 0.55
return 0
}
}

View File

@@ -39,7 +39,8 @@ import { MessageV2 } from "./message-v2"
import { Mode } from "./mode"
import { LSP } from "../lsp"
import { ReadTool } from "../tool/read"
import { splitWhen } from "remeda"
import { mergeDeep, pipe, splitWhen } from "remeda"
import { ToolRegistry } from "../tool/registry"
export namespace Session {
const log = Log.create({ service: "session" })
@@ -430,7 +431,7 @@ export namespace Session {
}
}
const args = { filePath, offset, limit }
const result = await ReadTool().then((t) =>
const result = await ReadTool.init().then((t) =>
t.execute(args, {
sessionID: input.sessionID,
abort: new AbortController().signal,
@@ -660,10 +661,13 @@ export namespace Session {
const processor = createProcessor(assistantMsg, model.info)
for (const item of await Provider.tools(input.providerID)) {
if (mode.tools[item.id] === false) continue
if (input.tools?.[item.id] === false) continue
if (session.parentID && item.id === "task") continue
const enabledTools = pipe(
mode.tools,
mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
mergeDeep(input.tools ?? {}),
)
for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
if (enabledTools[item.id] === false) continue
tools[item.id] = tool({
id: item.id as any,
description: item.description,
@@ -791,7 +795,9 @@ export namespace Session {
),
...MessageV2.toModelMessage(msgs),
],
temperature: model.info.temperature ? 0 : undefined,
temperature: model.info.temperature
? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
: undefined,
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,

View File

@@ -7,6 +7,7 @@ export namespace Mode {
export const Info = z
.object({
name: z.string(),
temperature: z.number().optional(),
model: z
.object({
modelID: z.string(),
@@ -50,6 +51,7 @@ export namespace Mode {
item.name = key
if (value.model) item.model = Provider.parseModel(value.model)
if (value.prompt) item.prompt = value.prompt
if (value.temperature) item.temperature = value.temperature
if (value.tools)
item.tools = {
...value.tools,

View File

@@ -7,8 +7,7 @@ const MAX_OUTPUT_LENGTH = 30000
const DEFAULT_TIMEOUT = 1 * 60 * 1000
const MAX_TIMEOUT = 10 * 60 * 1000
export const BashTool = Tool.define({
id: "bash",
export const BashTool = Tool.define("bash", {
description: DESCRIPTION,
parameters: z.object({
command: z.string().describe("The command to execute"),

View File

@@ -14,8 +14,7 @@ import { File } from "../file"
import { Bus } from "../bus"
import { FileTime } from "../file/time"
export const EditTool = Tool.define({
id: "edit",
export const EditTool = Tool.define("edit", {
description: DESCRIPTION,
parameters: z.object({
filePath: z.string().describe("The absolute path to the file to modify"),

View File

@@ -5,8 +5,7 @@ import { App } from "../app/app"
import DESCRIPTION from "./glob.txt"
import { Ripgrep } from "../file/ripgrep"
export const GlobTool = Tool.define({
id: "glob",
export const GlobTool = Tool.define("glob", {
description: DESCRIPTION,
parameters: z.object({
pattern: z.string().describe("The glob pattern to match files against"),

View File

@@ -5,8 +5,7 @@ import { Ripgrep } from "../file/ripgrep"
import DESCRIPTION from "./grep.txt"
export const GrepTool = Tool.define({
id: "grep",
export const GrepTool = Tool.define("grep", {
description: DESCRIPTION,
parameters: z.object({
pattern: z.string().describe("The regex pattern to search for in file contents"),

View File

@@ -33,8 +33,7 @@ export const IGNORE_PATTERNS = [
const LIMIT = 100
export const ListTool = Tool.define({
id: "list",
export const ListTool = Tool.define("list", {
description: DESCRIPTION,
parameters: z.object({
path: z.string().describe("The absolute path to the directory to list (must be absolute, not relative)").optional(),

View File

@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
import { App } from "../app/app"
import DESCRIPTION from "./lsp-diagnostics.txt"
export const LspDiagnosticTool = Tool.define({
id: "lsp_diagnostics",
export const LspDiagnosticTool = Tool.define("lsp_diagnostics", {
description: DESCRIPTION,
parameters: z.object({
path: z.string().describe("The path to the file to get diagnostics."),

View File

@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
import { App } from "../app/app"
import DESCRIPTION from "./lsp-hover.txt"
export const LspHoverTool = Tool.define({
id: "lsp_hover",
export const LspHoverTool = Tool.define("lsp_hover", {
description: DESCRIPTION,
parameters: z.object({
file: z.string().describe("The path to the file to get diagnostics."),

View File

@@ -5,8 +5,7 @@ import DESCRIPTION from "./multiedit.txt"
import path from "path"
import { App } from "../app/app"
export const MultiEditTool = Tool.define({
id: "multiedit",
export const MultiEditTool = Tool.define("multiedit", {
description: DESCRIPTION,
parameters: z.object({
filePath: z.string().describe("The absolute path to the file to modify"),
@@ -22,7 +21,7 @@ export const MultiEditTool = Tool.define({
.describe("Array of edit operations to perform sequentially on the file"),
}),
async execute(params, ctx) {
const tool = await EditTool()
const tool = await EditTool.init()
const results = []
for (const [, edit] of params.edits.entries()) {
const result = await tool.execute(

View File

@@ -210,8 +210,7 @@ async function applyCommit(
}
}
export const PatchTool = Tool.define({
id: "patch",
export const PatchTool = Tool.define("patch", {
description: DESCRIPTION,
parameters: PatchParams,
execute: async (params, ctx) => {

View File

@@ -10,8 +10,7 @@ import { App } from "../app/app"
const DEFAULT_READ_LIMIT = 2000
const MAX_LINE_LENGTH = 2000
export const ReadTool = Tool.define({
id: "read",
export const ReadTool = Tool.define("read", {
description: DESCRIPTION,
parameters: z.object({
filePath: z.string().describe("The path to the file to read"),

View File

@@ -0,0 +1,170 @@
import z from "zod"
import { BashTool } from "./bash"
import { EditTool } from "./edit"
import { GlobTool } from "./glob"
import { GrepTool } from "./grep"
import { ListTool } from "./ls"
import { PatchTool } from "./patch"
import { ReadTool } from "./read"
import { TaskTool } from "./task"
import { TodoWriteTool, TodoReadTool } from "./todo"
import { WebFetchTool } from "./webfetch"
import { WriteTool } from "./write"
export namespace ToolRegistry {
const ALL = [
BashTool,
EditTool,
WebFetchTool,
GlobTool,
GrepTool,
ListTool,
PatchTool,
ReadTool,
WriteTool,
TodoWriteTool,
TodoReadTool,
TaskTool,
]
export function ids() {
return ALL.map((t) => t.id)
}
export async function tools(providerID: string, _modelID: string) {
const result = await Promise.all(
ALL.map(async (t) => ({
id: t.id,
...(await t.init()),
})),
)
if (providerID === "openai") {
return result.map((t) => ({
...t,
parameters: optionalToNullable(t.parameters),
}))
}
if (providerID === "azure") {
return result.map((t) => ({
...t,
parameters: optionalToNullable(t.parameters),
}))
}
if (providerID === "google") {
return result.map((t) => ({
...t,
parameters: sanitizeGeminiParameters(t.parameters),
}))
}
return result
}
export function enabled(_providerID: string, modelID: string): Record<string, boolean> {
if (modelID.includes("claude")) {
return {
patch: false,
}
}
if (modelID.includes("qwen")) {
return {
patch: false,
todowrite: false,
todoread: false,
}
}
return {}
}
function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
if (!schema || visited.has(schema)) {
return schema
}
visited.add(schema)
if (schema instanceof z.ZodDefault) {
const innerSchema = schema.removeDefault()
// Handle Gemini's incompatibility with `default` on `anyOf` (unions).
if (innerSchema instanceof z.ZodUnion) {
// The schema was `z.union(...).default(...)`, which is not allowed.
// We strip the default and return the sanitized union.
return sanitizeGeminiParameters(innerSchema, visited)
}
// Otherwise, the default is on a regular type, which is allowed.
// We recurse on the inner type and then re-apply the default.
return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
}
if (schema instanceof z.ZodOptional) {
return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
}
if (schema instanceof z.ZodObject) {
const newShape: Record<string, z.ZodTypeAny> = {}
for (const [key, value] of Object.entries(schema.shape)) {
newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
}
return z.object(newShape)
}
if (schema instanceof z.ZodArray) {
return z.array(sanitizeGeminiParameters(schema.element, visited))
}
if (schema instanceof z.ZodUnion) {
// This schema corresponds to `anyOf` in JSON Schema.
// We recursively sanitize each option in the union.
const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
if (schema instanceof z.ZodString) {
const newSchema = z.string({ description: schema.description })
const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
// rome-ignore lint/suspicious/noExplicitAny: <explanation>
;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
safeChecks.includes(check.kind),
)
return newSchema
}
return schema
}
function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
if (schema instanceof z.ZodObject) {
const shape = schema.shape
const newShape: Record<string, z.ZodTypeAny> = {}
for (const [key, value] of Object.entries(shape)) {
const zodValue = value as z.ZodTypeAny
if (zodValue instanceof z.ZodOptional) {
newShape[key] = zodValue.unwrap().nullable()
} else {
newShape[key] = optionalToNullable(zodValue)
}
}
return z.object(newShape)
}
if (schema instanceof z.ZodArray) {
return z.array(optionalToNullable(schema.element))
}
if (schema instanceof z.ZodUnion) {
return z.union(
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
z.ZodTypeAny,
z.ZodTypeAny,
...z.ZodTypeAny[],
],
)
}
return schema
}
}

View File

@@ -7,11 +7,10 @@ import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id"
import { Agent } from "../agent/agent"
export const TaskTool = Tool.define(async () => {
export const TaskTool = Tool.define("task", async () => {
const agents = await Agent.list()
const description = DESCRIPTION.replace("{agents}", agents.map((a) => `- ${a.name}: ${a.description}`).join("\n"))
return {
id: "task",
description,
parameters: z.object({
description: z.string().describe("A short (3-5 words) description of the task"),
@@ -53,7 +52,10 @@ export const TaskTool = Tool.define(async () => {
providerID: model.providerID,
mode: msg.mode,
system: agent.prompt,
tools: agent.tools,
tools: {
...agent.tools,
task: false,
},
parts: [
{
id: Identifier.ascending("part"),

View File

@@ -18,8 +18,7 @@ const state = App.state("todo-tool", () => {
return todos
})
export const TodoWriteTool = Tool.define({
id: "todowrite",
export const TodoWriteTool = Tool.define("todowrite", {
description: DESCRIPTION_WRITE,
parameters: z.object({
todos: z.array(TodoInfo).describe("The updated todo list"),
@@ -37,8 +36,7 @@ export const TodoWriteTool = Tool.define({
},
})
export const TodoReadTool = Tool.define({
id: "todoread",
export const TodoReadTool = Tool.define("todoread", {
description: "Use this tool to read your todo list",
parameters: z.object({}),
async execute(_params, opts) {

View File

@@ -12,21 +12,30 @@ export namespace Tool {
}
export interface Info<Parameters extends StandardSchemaV1 = StandardSchemaV1, M extends Metadata = Metadata> {
id: string
description: string
parameters: Parameters
execute(
args: StandardSchemaV1.InferOutput<Parameters>,
ctx: Context,
): Promise<{
title: string
metadata: M
output: string
init: () => Promise<{
description: string
parameters: Parameters
execute(
args: StandardSchemaV1.InferOutput<Parameters>,
ctx: Context,
): Promise<{
title: string
metadata: M
output: string
}>
}>
}
export function define<Parameters extends StandardSchemaV1, Result extends Metadata>(
input: Info<Parameters, Result> | (() => Promise<Info<Parameters, Result>>),
): () => Promise<Info<Parameters, Result>> {
return input instanceof Function ? input : async () => input
id: string,
init: Info<Parameters, Result>["init"] | Awaited<ReturnType<Info<Parameters, Result>["init"]>>,
): Info<Parameters, Result> {
return {
id,
init: async () => {
if (init instanceof Function) return init()
return init
},
}
}
}

View File

@@ -7,8 +7,7 @@ const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
const MAX_TIMEOUT = 120 * 1000 // 2 minutes
export const WebFetchTool = Tool.define({
id: "webfetch",
export const WebFetchTool = Tool.define("webfetch", {
description: DESCRIPTION,
parameters: z.object({
url: z.string().describe("The URL to fetch content from"),

View File

@@ -9,8 +9,7 @@ import { Bus } from "../bus"
import { File } from "../file"
import { FileTime } from "../file/time"
export const WriteTool = Tool.define({
id: "write",
export const WriteTool = Tool.define("write", {
description: DESCRIPTION,
parameters: z.object({
filePath: z.string().describe("The absolute path to the file to write (must be absolute, not relative)"),

View File

@@ -9,8 +9,8 @@ const ctx = {
abort: AbortSignal.any([]),
metadata: () => {},
}
const glob = await GlobTool()
const list = await ListTool()
const glob = await GlobTool.init()
const list = await ListTool.init()
describe("tool.glob", () => {
test("truncate", async () => {