diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts index 9e095f5b..9a955f5c 100644 --- a/packages/opencode/src/provider/transform.ts +++ b/packages/opencode/src/provider/transform.ts @@ -3,21 +3,77 @@ import { unique } from "remeda" import type { JSONSchema } from "zod/v4/core" export namespace ProviderTransform { - function normalizeToolCallIds(msgs: ModelMessage[]): ModelMessage[] { - return msgs.map((msg) => { - if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) { - msg.content = msg.content.map((part) => { - if ((part.type === "tool-call" || part.type === "tool-result") && "toolCallId" in part) { - return { - ...part, - toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, "_"), + function normalizeMessages( + msgs: ModelMessage[], + providerID: string, + modelID: string, + ): ModelMessage[] { + if (modelID.includes("claude")) { + return msgs.map((msg) => { + if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) { + msg.content = msg.content.map((part) => { + if ( + (part.type === "tool-call" || part.type === "tool-result") && + "toolCallId" in part + ) { + return { + ...part, + toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, "_"), + } } - } - return part - }) + return part + }) + } + return msg + }) + } + if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) { + const result: ModelMessage[] = [] + for (let i = 0; i < msgs.length; i++) { + const msg = msgs[i] + const prevMsg = msgs[i - 1] + const nextMsg = msgs[i + 1] + + if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) { + msg.content = msg.content.map((part) => { + if ( + (part.type === "tool-call" || part.type === "tool-result") && + "toolCallId" in part + ) { + // Mistral requires alphanumeric tool call IDs with exactly 9 characters + const normalizedId = part.toolCallId + .replace(/[^a-zA-Z0-9]/g, "") // Remove non-alphanumeric characters + .substring(0, 9) // Take first 9 characters + .padEnd(9, "0") // Pad with zeros if less than 9 characters + + return { + ...part, + toolCallId: normalizedId, + } + } + return part + }) + } + + result.push(msg) + + // Fix message sequence: tool messages cannot be followed by user messages + if (msg.role === "tool" && nextMsg?.role === "user") { + result.push({ + role: "assistant", + content: [ + { + type: "text", + text: "Done.", + }, + ], + }) + } } - return msg - }) + return result + } + + return msgs } function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] { @@ -64,9 +120,7 @@ export namespace ProviderTransform { } export function message(msgs: ModelMessage[], providerID: string, modelID: string) { - if (modelID.includes("claude")) { - msgs = normalizeToolCallIds(msgs) - } + msgs = normalizeMessages(msgs, providerID, modelID) if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) { msgs = applyCaching(msgs, providerID) }