feat(models): enable Kimi k2 ⇄ Claude trajectory handoff (#1525)

This commit is contained in:
Ricardo Gonzalez
2025-08-01 20:05:06 -07:00
committed by GitHub
parent 6581741318
commit 8f45a0e227

View File

@@ -1,48 +1,76 @@
import type { ModelMessage } from "ai"
import { unique } from "remeda"
export namespace ProviderTransform {
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
const providerOptions = {
anthropic: {
cacheControl: { type: "ephemeral" },
},
openrouter: {
cache_control: { type: "ephemeral" },
},
bedrock: {
cachePoint: { type: "ephemeral" },
},
openaiCompatible: {
cache_control: { type: "ephemeral" },
},
export namespace ProviderTransform {
function normalizeToolCallIds(msgs: ModelMessage[]): ModelMessage[] {
return msgs.map((msg) => {
if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
msg.content = msg.content.map((part) => {
if ((part.type === "tool-call" || part.type === "tool-result") && "toolCallId" in part) {
return {
...part,
toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, '_')
}
}
return part
})
}
return msg
})
}
function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] {
const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
const providerOptions = {
anthropic: {
cacheControl: { type: "ephemeral" },
},
openrouter: {
cache_control: { type: "ephemeral" },
},
bedrock: {
cachePoint: { type: "ephemeral" },
},
openaiCompatible: {
cache_control: { type: "ephemeral" },
},
}
for (const msg of unique([...system, ...final])) {
const shouldUseContentOptions =
providerID !== "anthropic" && Array.isArray(msg.content) && msg.content.length > 0
if (shouldUseContentOptions) {
const lastContent = msg.content[msg.content.length - 1]
if (lastContent && typeof lastContent === "object") {
lastContent.providerOptions = {
...lastContent.providerOptions,
...providerOptions,
}
continue
}
}
for (const msg of unique([...system, ...final])) {
const shouldUseContentOptions =
providerID !== "anthropic" && Array.isArray(msg.content) && msg.content.length > 0
if (shouldUseContentOptions) {
const lastContent = msg.content[msg.content.length - 1]
if (lastContent && typeof lastContent === "object") {
lastContent.providerOptions = {
...lastContent.providerOptions,
...providerOptions,
}
continue
}
}
msg.providerOptions = {
...msg.providerOptions,
...providerOptions,
}
msg.providerOptions = {
...msg.providerOptions,
...providerOptions,
}
}
return msgs
}
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
if (modelID.includes("claude")) {
msgs = normalizeToolCallIds(msgs)
}
if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
msgs = applyCaching(msgs, providerID)
}
return msgs
}
@@ -50,9 +78,4 @@ export namespace ProviderTransform {
if (modelID.toLowerCase().includes("qwen")) return 0.55
return 0
}
export function topP(_providerID: string, modelID: string) {
if (modelID.toLowerCase().includes("qwen")) return 1
return undefined
}
}