mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-23 10:44:21 +01:00
allow temperature to be configured per mode
This commit is contained in:
@@ -99,6 +99,7 @@ export namespace Config {
|
|||||||
export const Mode = z
|
export const Mode = z
|
||||||
.object({
|
.object({
|
||||||
model: z.string().optional(),
|
model: z.string().optional(),
|
||||||
|
temperature: z.number().optional(),
|
||||||
prompt: z.string().optional(),
|
prompt: z.string().optional(),
|
||||||
tools: z.record(z.string(), z.boolean()).optional(),
|
tools: z.record(z.string(), z.boolean()).optional(),
|
||||||
disable: z.boolean().optional(),
|
disable: z.boolean().optional(),
|
||||||
|
|||||||
@@ -5,22 +5,11 @@ import { mergeDeep, sortBy } from "remeda"
|
|||||||
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
||||||
import { Log } from "../util/log"
|
import { Log } from "../util/log"
|
||||||
import { BunProc } from "../bun"
|
import { BunProc } from "../bun"
|
||||||
import { BashTool } from "../tool/bash"
|
|
||||||
import { EditTool } from "../tool/edit"
|
|
||||||
import { WebFetchTool } from "../tool/webfetch"
|
|
||||||
import { GlobTool } from "../tool/glob"
|
|
||||||
import { GrepTool } from "../tool/grep"
|
|
||||||
import { ListTool } from "../tool/ls"
|
|
||||||
import { PatchTool } from "../tool/patch"
|
|
||||||
import { ReadTool } from "../tool/read"
|
|
||||||
import { WriteTool } from "../tool/write"
|
|
||||||
import { TodoReadTool, TodoWriteTool } from "../tool/todo"
|
|
||||||
import { AuthAnthropic } from "../auth/anthropic"
|
import { AuthAnthropic } from "../auth/anthropic"
|
||||||
import { AuthCopilot } from "../auth/copilot"
|
import { AuthCopilot } from "../auth/copilot"
|
||||||
import { ModelsDev } from "./models"
|
import { ModelsDev } from "./models"
|
||||||
import { NamedError } from "../util/error"
|
import { NamedError } from "../util/error"
|
||||||
import { Auth } from "../auth"
|
import { Auth } from "../auth"
|
||||||
import { TaskTool } from "../tool/task"
|
|
||||||
|
|
||||||
export namespace Provider {
|
export namespace Provider {
|
||||||
const log = Log.create({ service: "provider" })
|
const log = Log.create({ service: "provider" })
|
||||||
@@ -468,137 +457,6 @@ export namespace Provider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOOLS = [
|
|
||||||
BashTool,
|
|
||||||
EditTool,
|
|
||||||
WebFetchTool,
|
|
||||||
GlobTool,
|
|
||||||
GrepTool,
|
|
||||||
ListTool,
|
|
||||||
// LspDiagnosticTool,
|
|
||||||
// LspHoverTool,
|
|
||||||
PatchTool,
|
|
||||||
ReadTool,
|
|
||||||
// MultiEditTool,
|
|
||||||
WriteTool,
|
|
||||||
TodoWriteTool,
|
|
||||||
TodoReadTool,
|
|
||||||
TaskTool,
|
|
||||||
]
|
|
||||||
|
|
||||||
export async function tools(providerID: string) {
|
|
||||||
const result = await Promise.all(TOOLS.map((t) => t()))
|
|
||||||
switch (providerID) {
|
|
||||||
case "anthropic":
|
|
||||||
return result.filter((t) => t.id !== "patch")
|
|
||||||
case "openai":
|
|
||||||
return result.map((t) => ({
|
|
||||||
...t,
|
|
||||||
parameters: optionalToNullable(t.parameters),
|
|
||||||
}))
|
|
||||||
case "azure":
|
|
||||||
return result.map((t) => ({
|
|
||||||
...t,
|
|
||||||
parameters: optionalToNullable(t.parameters),
|
|
||||||
}))
|
|
||||||
case "google":
|
|
||||||
return result.map((t) => ({
|
|
||||||
...t,
|
|
||||||
parameters: sanitizeGeminiParameters(t.parameters),
|
|
||||||
}))
|
|
||||||
default:
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
|
|
||||||
if (!schema || visited.has(schema)) {
|
|
||||||
return schema
|
|
||||||
}
|
|
||||||
visited.add(schema)
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodDefault) {
|
|
||||||
const innerSchema = schema.removeDefault()
|
|
||||||
// Handle Gemini's incompatibility with `default` on `anyOf` (unions).
|
|
||||||
if (innerSchema instanceof z.ZodUnion) {
|
|
||||||
// The schema was `z.union(...).default(...)`, which is not allowed.
|
|
||||||
// We strip the default and return the sanitized union.
|
|
||||||
return sanitizeGeminiParameters(innerSchema, visited)
|
|
||||||
}
|
|
||||||
// Otherwise, the default is on a regular type, which is allowed.
|
|
||||||
// We recurse on the inner type and then re-apply the default.
|
|
||||||
return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodOptional) {
|
|
||||||
return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodObject) {
|
|
||||||
const newShape: Record<string, z.ZodTypeAny> = {}
|
|
||||||
for (const [key, value] of Object.entries(schema.shape)) {
|
|
||||||
newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
|
|
||||||
}
|
|
||||||
return z.object(newShape)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodArray) {
|
|
||||||
return z.array(sanitizeGeminiParameters(schema.element, visited))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodUnion) {
|
|
||||||
// This schema corresponds to `anyOf` in JSON Schema.
|
|
||||||
// We recursively sanitize each option in the union.
|
|
||||||
const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
|
|
||||||
return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodString) {
|
|
||||||
const newSchema = z.string({ description: schema.description })
|
|
||||||
const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
|
|
||||||
// rome-ignore lint/suspicious/noExplicitAny: <explanation>
|
|
||||||
;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
|
|
||||||
safeChecks.includes(check.kind),
|
|
||||||
)
|
|
||||||
return newSchema
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema
|
|
||||||
}
|
|
||||||
function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
|
|
||||||
if (schema instanceof z.ZodObject) {
|
|
||||||
const shape = schema.shape
|
|
||||||
const newShape: Record<string, z.ZodTypeAny> = {}
|
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(shape)) {
|
|
||||||
const zodValue = value as z.ZodTypeAny
|
|
||||||
if (zodValue instanceof z.ZodOptional) {
|
|
||||||
newShape[key] = zodValue.unwrap().nullable()
|
|
||||||
} else {
|
|
||||||
newShape[key] = optionalToNullable(zodValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return z.object(newShape)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodArray) {
|
|
||||||
return z.array(optionalToNullable(schema.element))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema instanceof z.ZodUnion) {
|
|
||||||
return z.union(
|
|
||||||
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
|
|
||||||
z.ZodTypeAny,
|
|
||||||
z.ZodTypeAny,
|
|
||||||
...z.ZodTypeAny[],
|
|
||||||
],
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema
|
|
||||||
}
|
|
||||||
|
|
||||||
export const ModelNotFoundError = NamedError.create(
|
export const ModelNotFoundError = NamedError.create(
|
||||||
"ProviderModelNotFoundError",
|
"ProviderModelNotFoundError",
|
||||||
z.object({
|
z.object({
|
||||||
|
|||||||
@@ -44,4 +44,9 @@ export namespace ProviderTransform {
|
|||||||
}
|
}
|
||||||
return msgs
|
return msgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function temperature(_providerID: string, modelID: string) {
|
||||||
|
if (modelID.includes("qwen")) return 0.55
|
||||||
|
return 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ import { MessageV2 } from "./message-v2"
|
|||||||
import { Mode } from "./mode"
|
import { Mode } from "./mode"
|
||||||
import { LSP } from "../lsp"
|
import { LSP } from "../lsp"
|
||||||
import { ReadTool } from "../tool/read"
|
import { ReadTool } from "../tool/read"
|
||||||
import { splitWhen } from "remeda"
|
import { mergeDeep, pipe, splitWhen } from "remeda"
|
||||||
|
import { ToolRegistry } from "../tool/registry"
|
||||||
|
|
||||||
export namespace Session {
|
export namespace Session {
|
||||||
const log = Log.create({ service: "session" })
|
const log = Log.create({ service: "session" })
|
||||||
@@ -430,7 +431,7 @@ export namespace Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
const args = { filePath, offset, limit }
|
const args = { filePath, offset, limit }
|
||||||
const result = await ReadTool().then((t) =>
|
const result = await ReadTool.init().then((t) =>
|
||||||
t.execute(args, {
|
t.execute(args, {
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
abort: new AbortController().signal,
|
abort: new AbortController().signal,
|
||||||
@@ -660,10 +661,13 @@ export namespace Session {
|
|||||||
|
|
||||||
const processor = createProcessor(assistantMsg, model.info)
|
const processor = createProcessor(assistantMsg, model.info)
|
||||||
|
|
||||||
for (const item of await Provider.tools(input.providerID)) {
|
const enabledTools = pipe(
|
||||||
if (mode.tools[item.id] === false) continue
|
mode.tools,
|
||||||
if (input.tools?.[item.id] === false) continue
|
mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
|
||||||
if (session.parentID && item.id === "task") continue
|
mergeDeep(input.tools ?? {}),
|
||||||
|
)
|
||||||
|
for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
|
||||||
|
if (enabledTools[item.id] === false) continue
|
||||||
tools[item.id] = tool({
|
tools[item.id] = tool({
|
||||||
id: item.id as any,
|
id: item.id as any,
|
||||||
description: item.description,
|
description: item.description,
|
||||||
@@ -791,7 +795,9 @@ export namespace Session {
|
|||||||
),
|
),
|
||||||
...MessageV2.toModelMessage(msgs),
|
...MessageV2.toModelMessage(msgs),
|
||||||
],
|
],
|
||||||
temperature: model.info.temperature ? 0 : undefined,
|
temperature: model.info.temperature
|
||||||
|
? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
|
||||||
|
: undefined,
|
||||||
tools: model.info.tool_call === false ? undefined : tools,
|
tools: model.info.tool_call === false ? undefined : tools,
|
||||||
model: wrapLanguageModel({
|
model: wrapLanguageModel({
|
||||||
model: model.language,
|
model: model.language,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ export namespace Mode {
|
|||||||
export const Info = z
|
export const Info = z
|
||||||
.object({
|
.object({
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
|
temperature: z.number().optional(),
|
||||||
model: z
|
model: z
|
||||||
.object({
|
.object({
|
||||||
modelID: z.string(),
|
modelID: z.string(),
|
||||||
@@ -50,6 +51,7 @@ export namespace Mode {
|
|||||||
item.name = key
|
item.name = key
|
||||||
if (value.model) item.model = Provider.parseModel(value.model)
|
if (value.model) item.model = Provider.parseModel(value.model)
|
||||||
if (value.prompt) item.prompt = value.prompt
|
if (value.prompt) item.prompt = value.prompt
|
||||||
|
if (value.temperature) item.temperature = value.temperature
|
||||||
if (value.tools)
|
if (value.tools)
|
||||||
item.tools = {
|
item.tools = {
|
||||||
...value.tools,
|
...value.tools,
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ const MAX_OUTPUT_LENGTH = 30000
|
|||||||
const DEFAULT_TIMEOUT = 1 * 60 * 1000
|
const DEFAULT_TIMEOUT = 1 * 60 * 1000
|
||||||
const MAX_TIMEOUT = 10 * 60 * 1000
|
const MAX_TIMEOUT = 10 * 60 * 1000
|
||||||
|
|
||||||
export const BashTool = Tool.define({
|
export const BashTool = Tool.define("bash", {
|
||||||
id: "bash",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
command: z.string().describe("The command to execute"),
|
command: z.string().describe("The command to execute"),
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ import { File } from "../file"
|
|||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
import { FileTime } from "../file/time"
|
import { FileTime } from "../file/time"
|
||||||
|
|
||||||
export const EditTool = Tool.define({
|
export const EditTool = Tool.define("edit", {
|
||||||
id: "edit",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
filePath: z.string().describe("The absolute path to the file to modify"),
|
filePath: z.string().describe("The absolute path to the file to modify"),
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import { App } from "../app/app"
|
|||||||
import DESCRIPTION from "./glob.txt"
|
import DESCRIPTION from "./glob.txt"
|
||||||
import { Ripgrep } from "../file/ripgrep"
|
import { Ripgrep } from "../file/ripgrep"
|
||||||
|
|
||||||
export const GlobTool = Tool.define({
|
export const GlobTool = Tool.define("glob", {
|
||||||
id: "glob",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
pattern: z.string().describe("The glob pattern to match files against"),
|
pattern: z.string().describe("The glob pattern to match files against"),
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import { Ripgrep } from "../file/ripgrep"
|
|||||||
|
|
||||||
import DESCRIPTION from "./grep.txt"
|
import DESCRIPTION from "./grep.txt"
|
||||||
|
|
||||||
export const GrepTool = Tool.define({
|
export const GrepTool = Tool.define("grep", {
|
||||||
id: "grep",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
pattern: z.string().describe("The regex pattern to search for in file contents"),
|
pattern: z.string().describe("The regex pattern to search for in file contents"),
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ export const IGNORE_PATTERNS = [
|
|||||||
|
|
||||||
const LIMIT = 100
|
const LIMIT = 100
|
||||||
|
|
||||||
export const ListTool = Tool.define({
|
export const ListTool = Tool.define("list", {
|
||||||
id: "list",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
path: z.string().describe("The absolute path to the directory to list (must be absolute, not relative)").optional(),
|
path: z.string().describe("The absolute path to the directory to list (must be absolute, not relative)").optional(),
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
|
|||||||
import { App } from "../app/app"
|
import { App } from "../app/app"
|
||||||
import DESCRIPTION from "./lsp-diagnostics.txt"
|
import DESCRIPTION from "./lsp-diagnostics.txt"
|
||||||
|
|
||||||
export const LspDiagnosticTool = Tool.define({
|
export const LspDiagnosticTool = Tool.define("lsp_diagnostics", {
|
||||||
id: "lsp_diagnostics",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
path: z.string().describe("The path to the file to get diagnostics."),
|
path: z.string().describe("The path to the file to get diagnostics."),
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import { LSP } from "../lsp"
|
|||||||
import { App } from "../app/app"
|
import { App } from "../app/app"
|
||||||
import DESCRIPTION from "./lsp-hover.txt"
|
import DESCRIPTION from "./lsp-hover.txt"
|
||||||
|
|
||||||
export const LspHoverTool = Tool.define({
|
export const LspHoverTool = Tool.define("lsp_hover", {
|
||||||
id: "lsp_hover",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
file: z.string().describe("The path to the file to get diagnostics."),
|
file: z.string().describe("The path to the file to get diagnostics."),
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import DESCRIPTION from "./multiedit.txt"
|
|||||||
import path from "path"
|
import path from "path"
|
||||||
import { App } from "../app/app"
|
import { App } from "../app/app"
|
||||||
|
|
||||||
export const MultiEditTool = Tool.define({
|
export const MultiEditTool = Tool.define("multiedit", {
|
||||||
id: "multiedit",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
filePath: z.string().describe("The absolute path to the file to modify"),
|
filePath: z.string().describe("The absolute path to the file to modify"),
|
||||||
@@ -22,7 +21,7 @@ export const MultiEditTool = Tool.define({
|
|||||||
.describe("Array of edit operations to perform sequentially on the file"),
|
.describe("Array of edit operations to perform sequentially on the file"),
|
||||||
}),
|
}),
|
||||||
async execute(params, ctx) {
|
async execute(params, ctx) {
|
||||||
const tool = await EditTool()
|
const tool = await EditTool.init()
|
||||||
const results = []
|
const results = []
|
||||||
for (const [, edit] of params.edits.entries()) {
|
for (const [, edit] of params.edits.entries()) {
|
||||||
const result = await tool.execute(
|
const result = await tool.execute(
|
||||||
|
|||||||
@@ -210,8 +210,7 @@ async function applyCommit(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const PatchTool = Tool.define({
|
export const PatchTool = Tool.define("patch", {
|
||||||
id: "patch",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: PatchParams,
|
parameters: PatchParams,
|
||||||
execute: async (params, ctx) => {
|
execute: async (params, ctx) => {
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ import { App } from "../app/app"
|
|||||||
const DEFAULT_READ_LIMIT = 2000
|
const DEFAULT_READ_LIMIT = 2000
|
||||||
const MAX_LINE_LENGTH = 2000
|
const MAX_LINE_LENGTH = 2000
|
||||||
|
|
||||||
export const ReadTool = Tool.define({
|
export const ReadTool = Tool.define("read", {
|
||||||
id: "read",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
filePath: z.string().describe("The path to the file to read"),
|
filePath: z.string().describe("The path to the file to read"),
|
||||||
|
|||||||
170
packages/opencode/src/tool/registry.ts
Normal file
170
packages/opencode/src/tool/registry.ts
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import z from "zod"
|
||||||
|
import { BashTool } from "./bash"
|
||||||
|
import { EditTool } from "./edit"
|
||||||
|
import { GlobTool } from "./glob"
|
||||||
|
import { GrepTool } from "./grep"
|
||||||
|
import { ListTool } from "./ls"
|
||||||
|
import { PatchTool } from "./patch"
|
||||||
|
import { ReadTool } from "./read"
|
||||||
|
import { TaskTool } from "./task"
|
||||||
|
import { TodoWriteTool, TodoReadTool } from "./todo"
|
||||||
|
import { WebFetchTool } from "./webfetch"
|
||||||
|
import { WriteTool } from "./write"
|
||||||
|
|
||||||
|
export namespace ToolRegistry {
|
||||||
|
const ALL = [
|
||||||
|
BashTool,
|
||||||
|
EditTool,
|
||||||
|
WebFetchTool,
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
ListTool,
|
||||||
|
PatchTool,
|
||||||
|
ReadTool,
|
||||||
|
WriteTool,
|
||||||
|
TodoWriteTool,
|
||||||
|
TodoReadTool,
|
||||||
|
TaskTool,
|
||||||
|
]
|
||||||
|
|
||||||
|
export function ids() {
|
||||||
|
return ALL.map((t) => t.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function tools(providerID: string, _modelID: string) {
|
||||||
|
const result = await Promise.all(
|
||||||
|
ALL.map(async (t) => ({
|
||||||
|
id: t.id,
|
||||||
|
...(await t.init()),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (providerID === "openai") {
|
||||||
|
return result.map((t) => ({
|
||||||
|
...t,
|
||||||
|
parameters: optionalToNullable(t.parameters),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (providerID === "azure") {
|
||||||
|
return result.map((t) => ({
|
||||||
|
...t,
|
||||||
|
parameters: optionalToNullable(t.parameters),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (providerID === "google") {
|
||||||
|
return result.map((t) => ({
|
||||||
|
...t,
|
||||||
|
parameters: sanitizeGeminiParameters(t.parameters),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
export function enabled(_providerID: string, modelID: string): Record<string, boolean> {
|
||||||
|
if (modelID.includes("claude")) {
|
||||||
|
return {
|
||||||
|
patch: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (modelID.includes("qwen")) {
|
||||||
|
return {
|
||||||
|
patch: false,
|
||||||
|
todowrite: false,
|
||||||
|
todoread: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
function sanitizeGeminiParameters(schema: z.ZodTypeAny, visited = new Set()): z.ZodTypeAny {
|
||||||
|
if (!schema || visited.has(schema)) {
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
visited.add(schema)
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodDefault) {
|
||||||
|
const innerSchema = schema.removeDefault()
|
||||||
|
// Handle Gemini's incompatibility with `default` on `anyOf` (unions).
|
||||||
|
if (innerSchema instanceof z.ZodUnion) {
|
||||||
|
// The schema was `z.union(...).default(...)`, which is not allowed.
|
||||||
|
// We strip the default and return the sanitized union.
|
||||||
|
return sanitizeGeminiParameters(innerSchema, visited)
|
||||||
|
}
|
||||||
|
// Otherwise, the default is on a regular type, which is allowed.
|
||||||
|
// We recurse on the inner type and then re-apply the default.
|
||||||
|
return sanitizeGeminiParameters(innerSchema, visited).default(schema._def.defaultValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodOptional) {
|
||||||
|
return z.optional(sanitizeGeminiParameters(schema.unwrap(), visited))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodObject) {
|
||||||
|
const newShape: Record<string, z.ZodTypeAny> = {}
|
||||||
|
for (const [key, value] of Object.entries(schema.shape)) {
|
||||||
|
newShape[key] = sanitizeGeminiParameters(value as z.ZodTypeAny, visited)
|
||||||
|
}
|
||||||
|
return z.object(newShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodArray) {
|
||||||
|
return z.array(sanitizeGeminiParameters(schema.element, visited))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodUnion) {
|
||||||
|
// This schema corresponds to `anyOf` in JSON Schema.
|
||||||
|
// We recursively sanitize each option in the union.
|
||||||
|
const sanitizedOptions = schema.options.map((option: z.ZodTypeAny) => sanitizeGeminiParameters(option, visited))
|
||||||
|
return z.union(sanitizedOptions as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodString) {
|
||||||
|
const newSchema = z.string({ description: schema.description })
|
||||||
|
const safeChecks = ["min", "max", "length", "regex", "startsWith", "endsWith", "includes", "trim"]
|
||||||
|
// rome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
;(newSchema._def as any).checks = (schema._def as z.ZodStringDef).checks.filter((check) =>
|
||||||
|
safeChecks.includes(check.kind),
|
||||||
|
)
|
||||||
|
return newSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny {
|
||||||
|
if (schema instanceof z.ZodObject) {
|
||||||
|
const shape = schema.shape
|
||||||
|
const newShape: Record<string, z.ZodTypeAny> = {}
|
||||||
|
|
||||||
|
for (const [key, value] of Object.entries(shape)) {
|
||||||
|
const zodValue = value as z.ZodTypeAny
|
||||||
|
if (zodValue instanceof z.ZodOptional) {
|
||||||
|
newShape[key] = zodValue.unwrap().nullable()
|
||||||
|
} else {
|
||||||
|
newShape[key] = optionalToNullable(zodValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return z.object(newShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodArray) {
|
||||||
|
return z.array(optionalToNullable(schema.element))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (schema instanceof z.ZodUnion) {
|
||||||
|
return z.union(
|
||||||
|
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
|
||||||
|
z.ZodTypeAny,
|
||||||
|
z.ZodTypeAny,
|
||||||
|
...z.ZodTypeAny[],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,11 +7,10 @@ import { MessageV2 } from "../session/message-v2"
|
|||||||
import { Identifier } from "../id/id"
|
import { Identifier } from "../id/id"
|
||||||
import { Agent } from "../agent/agent"
|
import { Agent } from "../agent/agent"
|
||||||
|
|
||||||
export const TaskTool = Tool.define(async () => {
|
export const TaskTool = Tool.define("task", async () => {
|
||||||
const agents = await Agent.list()
|
const agents = await Agent.list()
|
||||||
const description = DESCRIPTION.replace("{agents}", agents.map((a) => `- ${a.name}: ${a.description}`).join("\n"))
|
const description = DESCRIPTION.replace("{agents}", agents.map((a) => `- ${a.name}: ${a.description}`).join("\n"))
|
||||||
return {
|
return {
|
||||||
id: "task",
|
|
||||||
description,
|
description,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
description: z.string().describe("A short (3-5 words) description of the task"),
|
description: z.string().describe("A short (3-5 words) description of the task"),
|
||||||
@@ -53,7 +52,10 @@ export const TaskTool = Tool.define(async () => {
|
|||||||
providerID: model.providerID,
|
providerID: model.providerID,
|
||||||
mode: msg.mode,
|
mode: msg.mode,
|
||||||
system: agent.prompt,
|
system: agent.prompt,
|
||||||
tools: agent.tools,
|
tools: {
|
||||||
|
...agent.tools,
|
||||||
|
task: false,
|
||||||
|
},
|
||||||
parts: [
|
parts: [
|
||||||
{
|
{
|
||||||
id: Identifier.ascending("part"),
|
id: Identifier.ascending("part"),
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ const state = App.state("todo-tool", () => {
|
|||||||
return todos
|
return todos
|
||||||
})
|
})
|
||||||
|
|
||||||
export const TodoWriteTool = Tool.define({
|
export const TodoWriteTool = Tool.define("todowrite", {
|
||||||
id: "todowrite",
|
|
||||||
description: DESCRIPTION_WRITE,
|
description: DESCRIPTION_WRITE,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
todos: z.array(TodoInfo).describe("The updated todo list"),
|
todos: z.array(TodoInfo).describe("The updated todo list"),
|
||||||
@@ -37,8 +36,7 @@ export const TodoWriteTool = Tool.define({
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
export const TodoReadTool = Tool.define({
|
export const TodoReadTool = Tool.define("todoread", {
|
||||||
id: "todoread",
|
|
||||||
description: "Use this tool to read your todo list",
|
description: "Use this tool to read your todo list",
|
||||||
parameters: z.object({}),
|
parameters: z.object({}),
|
||||||
async execute(_params, opts) {
|
async execute(_params, opts) {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ export namespace Tool {
|
|||||||
}
|
}
|
||||||
export interface Info<Parameters extends StandardSchemaV1 = StandardSchemaV1, M extends Metadata = Metadata> {
|
export interface Info<Parameters extends StandardSchemaV1 = StandardSchemaV1, M extends Metadata = Metadata> {
|
||||||
id: string
|
id: string
|
||||||
|
init: () => Promise<{
|
||||||
description: string
|
description: string
|
||||||
parameters: Parameters
|
parameters: Parameters
|
||||||
execute(
|
execute(
|
||||||
@@ -22,11 +23,19 @@ export namespace Tool {
|
|||||||
metadata: M
|
metadata: M
|
||||||
output: string
|
output: string
|
||||||
}>
|
}>
|
||||||
|
}>
|
||||||
}
|
}
|
||||||
|
|
||||||
export function define<Parameters extends StandardSchemaV1, Result extends Metadata>(
|
export function define<Parameters extends StandardSchemaV1, Result extends Metadata>(
|
||||||
input: Info<Parameters, Result> | (() => Promise<Info<Parameters, Result>>),
|
id: string,
|
||||||
): () => Promise<Info<Parameters, Result>> {
|
init: Info<Parameters, Result>["init"] | Awaited<ReturnType<Info<Parameters, Result>["init"]>>,
|
||||||
return input instanceof Function ? input : async () => input
|
): Info<Parameters, Result> {
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
init: async () => {
|
||||||
|
if (init instanceof Function) return init()
|
||||||
|
return init
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
|
|||||||
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
|
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
|
||||||
const MAX_TIMEOUT = 120 * 1000 // 2 minutes
|
const MAX_TIMEOUT = 120 * 1000 // 2 minutes
|
||||||
|
|
||||||
export const WebFetchTool = Tool.define({
|
export const WebFetchTool = Tool.define("webfetch", {
|
||||||
id: "webfetch",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
url: z.string().describe("The URL to fetch content from"),
|
url: z.string().describe("The URL to fetch content from"),
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ import { Bus } from "../bus"
|
|||||||
import { File } from "../file"
|
import { File } from "../file"
|
||||||
import { FileTime } from "../file/time"
|
import { FileTime } from "../file/time"
|
||||||
|
|
||||||
export const WriteTool = Tool.define({
|
export const WriteTool = Tool.define("write", {
|
||||||
id: "write",
|
|
||||||
description: DESCRIPTION,
|
description: DESCRIPTION,
|
||||||
parameters: z.object({
|
parameters: z.object({
|
||||||
filePath: z.string().describe("The absolute path to the file to write (must be absolute, not relative)"),
|
filePath: z.string().describe("The absolute path to the file to write (must be absolute, not relative)"),
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ const ctx = {
|
|||||||
abort: AbortSignal.any([]),
|
abort: AbortSignal.any([]),
|
||||||
metadata: () => {},
|
metadata: () => {},
|
||||||
}
|
}
|
||||||
const glob = await GlobTool()
|
const glob = await GlobTool.init()
|
||||||
const list = await ListTool()
|
const list = await ListTool.init()
|
||||||
|
|
||||||
describe("tool.glob", () => {
|
describe("tool.glob", () => {
|
||||||
test("truncate", async () => {
|
test("truncate", async () => {
|
||||||
|
|||||||
Reference in New Issue
Block a user