mirror of
https://github.com/aljazceru/opencode.git
synced 2026-01-25 18:54:56 +01:00
feat: retry parts (#3369)
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
|
||||||
import { Session } from "."
|
import { Session } from "."
|
||||||
import { Identifier } from "../id/id"
|
import { Identifier } from "../id/id"
|
||||||
import { Instance } from "../project/instance"
|
import { Instance } from "../project/instance"
|
||||||
@@ -14,8 +14,8 @@ import { Flag } from "../flag/flag"
|
|||||||
import { Token } from "../util/token"
|
import { Token } from "../util/token"
|
||||||
import { Log } from "../util/log"
|
import { Log } from "../util/log"
|
||||||
import { SessionLock } from "./lock"
|
import { SessionLock } from "./lock"
|
||||||
import { NamedError } from "../util/error"
|
|
||||||
import { ProviderTransform } from "@/provider/transform"
|
import { ProviderTransform } from "@/provider/transform"
|
||||||
|
import { SessionRetry } from "./retry"
|
||||||
|
|
||||||
export namespace SessionCompaction {
|
export namespace SessionCompaction {
|
||||||
const log = Log.create({ service: "session.compaction" })
|
const log = Log.create({ service: "session.compaction" })
|
||||||
@@ -41,6 +41,7 @@ export namespace SessionCompaction {
|
|||||||
|
|
||||||
export const PRUNE_MINIMUM = 20_000
|
export const PRUNE_MINIMUM = 20_000
|
||||||
export const PRUNE_PROTECT = 40_000
|
export const PRUNE_PROTECT = 40_000
|
||||||
|
const MAX_RETRIES = 10
|
||||||
|
|
||||||
// goes backwards through parts until there are 40_000 tokens worth of tool
|
// goes backwards through parts until there are 40_000 tokens worth of tool
|
||||||
// calls. then erases output of previous tool calls. idea is to throw away old
|
// calls. then erases output of previous tool calls. idea is to throw away old
|
||||||
@@ -142,112 +143,173 @@ export namespace SessionCompaction {
|
|||||||
},
|
},
|
||||||
})) as MessageV2.TextPart
|
})) as MessageV2.TextPart
|
||||||
|
|
||||||
const stream = streamText({
|
const doStream = () =>
|
||||||
maxRetries: 10,
|
streamText({
|
||||||
model: model.language,
|
// set to 0, we handle loop
|
||||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
maxRetries: 0,
|
||||||
abortSignal: signal,
|
model: model.language,
|
||||||
onError(error) {
|
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
||||||
log.error("stream error", {
|
abortSignal: signal,
|
||||||
error,
|
onError(error) {
|
||||||
})
|
log.error("stream error", {
|
||||||
},
|
error,
|
||||||
messages: [
|
})
|
||||||
...system.map(
|
|
||||||
(x): ModelMessage => ({
|
|
||||||
role: "system",
|
|
||||||
content: x,
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
...MessageV2.toModelMessage(toSummarize),
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: "text",
|
|
||||||
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
],
|
messages: [
|
||||||
})
|
...system.map(
|
||||||
|
(x): ModelMessage => ({
|
||||||
|
role: "system",
|
||||||
|
content: x,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
...MessageV2.toModelMessage(toSummarize),
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
try {
|
// TODO: reduce duplication between compaction.ts & prompt.ts
|
||||||
for await (const value of stream.fullStream) {
|
const process = async (
|
||||||
signal.throwIfAborted()
|
stream: StreamTextResult<Record<string, AITool>, never>,
|
||||||
switch (value.type) {
|
retries: { count: number; max: number },
|
||||||
case "text-delta":
|
) => {
|
||||||
part.text += value.text
|
let shouldRetry = false
|
||||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
try {
|
||||||
if (part.text) await Session.updatePart(part)
|
for await (const value of stream.fullStream) {
|
||||||
continue
|
signal.throwIfAborted()
|
||||||
case "text-end": {
|
switch (value.type) {
|
||||||
part.text = part.text.trimEnd()
|
case "text-delta":
|
||||||
part.time = {
|
part.text += value.text
|
||||||
start: Date.now(),
|
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||||
end: Date.now(),
|
if (part.text) await Session.updatePart(part)
|
||||||
|
continue
|
||||||
|
case "text-end": {
|
||||||
|
part.text = part.text.trimEnd()
|
||||||
|
part.time = {
|
||||||
|
start: Date.now(),
|
||||||
|
end: Date.now(),
|
||||||
|
}
|
||||||
|
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||||
|
await Session.updatePart(part)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
case "finish-step": {
|
||||||
await Session.updatePart(part)
|
const usage = Session.getUsage({
|
||||||
continue
|
model: model.info,
|
||||||
|
usage: value.usage,
|
||||||
|
metadata: value.providerMetadata,
|
||||||
|
})
|
||||||
|
msg.cost += usage.cost
|
||||||
|
msg.tokens = usage.tokens
|
||||||
|
await Session.updateMessage(msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case "error":
|
||||||
|
throw value.error
|
||||||
|
default:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
case "finish-step": {
|
}
|
||||||
const usage = Session.getUsage({
|
} catch (e) {
|
||||||
model: model.info,
|
log.error("compaction error", {
|
||||||
usage: value.usage,
|
error: e,
|
||||||
metadata: value.providerMetadata,
|
})
|
||||||
})
|
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||||
msg.cost += usage.cost
|
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||||
msg.tokens = usage.tokens
|
shouldRetry = true
|
||||||
await Session.updateMessage(msg)
|
await Session.updatePart({
|
||||||
continue
|
id: Identifier.ascending("part"),
|
||||||
}
|
messageID: msg.id,
|
||||||
case "error":
|
sessionID: msg.sessionID,
|
||||||
throw value.error
|
type: "retry",
|
||||||
default:
|
attempt: retries.count + 1,
|
||||||
continue
|
time: {
|
||||||
|
created: Date.now(),
|
||||||
|
},
|
||||||
|
error,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
msg.error = error
|
||||||
|
Bus.publish(Session.Event.Error, {
|
||||||
|
sessionID: msg.sessionID,
|
||||||
|
error: msg.error,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (e) {
|
|
||||||
log.error("compaction error", {
|
const parts = await Session.getParts(msg.id)
|
||||||
error: e,
|
return {
|
||||||
})
|
info: msg,
|
||||||
switch (true) {
|
parts,
|
||||||
case e instanceof DOMException && e.name === "AbortError":
|
shouldRetry,
|
||||||
msg.error = new MessageV2.AbortedError(
|
}
|
||||||
{ message: e.message },
|
}
|
||||||
{
|
|
||||||
cause: e,
|
let stream = doStream()
|
||||||
},
|
let result = await process(stream, {
|
||||||
).toObject()
|
count: 0,
|
||||||
break
|
max: MAX_RETRIES,
|
||||||
case MessageV2.OutputLengthError.isInstance(e):
|
})
|
||||||
msg.error = e
|
if (result.shouldRetry) {
|
||||||
break
|
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||||
case LoadAPIKeyError.isInstance(e):
|
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||||
msg.error = new MessageV2.AuthError(
|
|
||||||
{
|
if (lastRetryPart) {
|
||||||
providerID: model.providerID,
|
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||||
message: e.message,
|
|
||||||
},
|
log.info("retrying with backoff", {
|
||||||
{ cause: e },
|
attempt: retry,
|
||||||
).toObject()
|
delayMs,
|
||||||
break
|
})
|
||||||
case e instanceof Error:
|
|
||||||
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
const stop = await SessionRetry.sleep(delayMs, signal)
|
||||||
break
|
.then(() => false)
|
||||||
default:
|
.catch((error) => {
|
||||||
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
if (error instanceof DOMException && error.name === "AbortError") {
|
||||||
|
const err = new MessageV2.AbortedError(
|
||||||
|
{ message: error.message },
|
||||||
|
{
|
||||||
|
cause: error,
|
||||||
|
},
|
||||||
|
).toObject()
|
||||||
|
result.info.error = err
|
||||||
|
Bus.publish(Session.Event.Error, {
|
||||||
|
sessionID: result.info.sessionID,
|
||||||
|
error: result.info.error,
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
})
|
||||||
|
|
||||||
|
if (stop) break
|
||||||
|
}
|
||||||
|
|
||||||
|
stream = doStream()
|
||||||
|
result = await process(stream, {
|
||||||
|
count: retry,
|
||||||
|
max: MAX_RETRIES,
|
||||||
|
})
|
||||||
|
if (!result.shouldRetry) {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Bus.publish(Session.Event.Error, {
|
|
||||||
sessionID: input.sessionID,
|
|
||||||
error: msg.error,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msg.time.completed = Date.now()
|
msg.time.completed = Date.now()
|
||||||
|
|
||||||
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
if (
|
||||||
|
!msg.error ||
|
||||||
|
(MessageV2.AbortedError.isInstance(msg.error) &&
|
||||||
|
result.parts.some((part) => part.type === "text" && part.text.length > 0))
|
||||||
|
) {
|
||||||
msg.summary = true
|
msg.summary = true
|
||||||
Bus.publish(Event.Compacted, {
|
Bus.publish(Event.Compacted, {
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
@@ -257,7 +319,7 @@ export namespace SessionCompaction {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
info: msg,
|
info: msg,
|
||||||
parts: [part],
|
parts: result.parts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import z from "zod/v4"
|
|||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
import { NamedError } from "../util/error"
|
import { NamedError } from "../util/error"
|
||||||
import { Message } from "./message"
|
import { Message } from "./message"
|
||||||
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessage, type UIMessage } from "ai"
|
||||||
import { Identifier } from "../id/id"
|
import { Identifier } from "../id/id"
|
||||||
import { LSP } from "../lsp"
|
import { LSP } from "../lsp"
|
||||||
import { Snapshot } from "@/snapshot"
|
import { Snapshot } from "@/snapshot"
|
||||||
@@ -18,6 +18,17 @@ export namespace MessageV2 {
|
|||||||
message: z.string(),
|
message: z.string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
export const APIError = NamedError.create(
|
||||||
|
"APIError",
|
||||||
|
z.object({
|
||||||
|
message: z.string(),
|
||||||
|
statusCode: z.number().optional(),
|
||||||
|
isRetryable: z.boolean(),
|
||||||
|
responseHeaders: z.record(z.string(), z.string()).optional(),
|
||||||
|
responseBody: z.string().optional(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
export type APIError = z.infer<typeof APIError.Schema>
|
||||||
|
|
||||||
const PartBase = z.object({
|
const PartBase = z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
@@ -130,6 +141,18 @@ export namespace MessageV2 {
|
|||||||
})
|
})
|
||||||
export type AgentPart = z.infer<typeof AgentPart>
|
export type AgentPart = z.infer<typeof AgentPart>
|
||||||
|
|
||||||
|
export const RetryPart = PartBase.extend({
|
||||||
|
type: z.literal("retry"),
|
||||||
|
attempt: z.number(),
|
||||||
|
error: APIError.Schema,
|
||||||
|
time: z.object({
|
||||||
|
created: z.number(),
|
||||||
|
}),
|
||||||
|
}).meta({
|
||||||
|
ref: "RetryPart",
|
||||||
|
})
|
||||||
|
export type RetryPart = z.infer<typeof RetryPart>
|
||||||
|
|
||||||
export const StepStartPart = PartBase.extend({
|
export const StepStartPart = PartBase.extend({
|
||||||
type: z.literal("step-start"),
|
type: z.literal("step-start"),
|
||||||
snapshot: z.string().optional(),
|
snapshot: z.string().optional(),
|
||||||
@@ -265,6 +288,7 @@ export namespace MessageV2 {
|
|||||||
SnapshotPart,
|
SnapshotPart,
|
||||||
PatchPart,
|
PatchPart,
|
||||||
AgentPart,
|
AgentPart,
|
||||||
|
RetryPart,
|
||||||
])
|
])
|
||||||
.meta({
|
.meta({
|
||||||
ref: "Part",
|
ref: "Part",
|
||||||
@@ -283,6 +307,7 @@ export namespace MessageV2 {
|
|||||||
NamedError.Unknown.Schema,
|
NamedError.Unknown.Schema,
|
||||||
OutputLengthError.Schema,
|
OutputLengthError.Schema,
|
||||||
AbortedError.Schema,
|
AbortedError.Schema,
|
||||||
|
APIError.Schema,
|
||||||
])
|
])
|
||||||
.optional(),
|
.optional(),
|
||||||
system: z.string().array(),
|
system: z.string().array(),
|
||||||
@@ -610,4 +635,41 @@ export namespace MessageV2 {
|
|||||||
if (i === -1) return msgs.slice()
|
if (i === -1) return msgs.slice()
|
||||||
return msgs.slice(i)
|
return msgs.slice(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function fromError(e: unknown, ctx: { providerID: string }) {
|
||||||
|
switch (true) {
|
||||||
|
case e instanceof DOMException && e.name === "AbortError":
|
||||||
|
return new MessageV2.AbortedError(
|
||||||
|
{ message: e.message },
|
||||||
|
{
|
||||||
|
cause: e,
|
||||||
|
},
|
||||||
|
).toObject()
|
||||||
|
case MessageV2.OutputLengthError.isInstance(e):
|
||||||
|
return e
|
||||||
|
case LoadAPIKeyError.isInstance(e):
|
||||||
|
return new MessageV2.AuthError(
|
||||||
|
{
|
||||||
|
providerID: ctx.providerID,
|
||||||
|
message: e.message,
|
||||||
|
},
|
||||||
|
{ cause: e },
|
||||||
|
).toObject()
|
||||||
|
case APICallError.isInstance(e):
|
||||||
|
return new MessageV2.APIError(
|
||||||
|
{
|
||||||
|
message: e.message,
|
||||||
|
statusCode: e.statusCode,
|
||||||
|
isRetryable: e.isRetryable,
|
||||||
|
responseHeaders: e.responseHeaders,
|
||||||
|
responseBody: e.responseBody,
|
||||||
|
},
|
||||||
|
{ cause: e },
|
||||||
|
).toObject()
|
||||||
|
case e instanceof Error:
|
||||||
|
return new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||||
|
default:
|
||||||
|
return new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import {
|
|||||||
tool,
|
tool,
|
||||||
wrapLanguageModel,
|
wrapLanguageModel,
|
||||||
type StreamTextResult,
|
type StreamTextResult,
|
||||||
LoadAPIKeyError,
|
|
||||||
stepCountIs,
|
stepCountIs,
|
||||||
jsonSchema,
|
jsonSchema,
|
||||||
} from "ai"
|
} from "ai"
|
||||||
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
|
|||||||
import { ProviderTransform } from "../provider/transform"
|
import { ProviderTransform } from "../provider/transform"
|
||||||
import { SystemPrompt } from "./system"
|
import { SystemPrompt } from "./system"
|
||||||
import { Plugin } from "../plugin"
|
import { Plugin } from "../plugin"
|
||||||
|
import { SessionRetry } from "./retry"
|
||||||
|
|
||||||
import PROMPT_PLAN from "../session/prompt/plan.txt"
|
import PROMPT_PLAN from "../session/prompt/plan.txt"
|
||||||
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
|
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
|
||||||
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
|
|||||||
import { FileTime } from "../file/time"
|
import { FileTime } from "../file/time"
|
||||||
import { Permission } from "../permission"
|
import { Permission } from "../permission"
|
||||||
import { Snapshot } from "../snapshot"
|
import { Snapshot } from "../snapshot"
|
||||||
import { NamedError } from "../util/error"
|
|
||||||
import { ulid } from "ulid"
|
import { ulid } from "ulid"
|
||||||
import { spawn } from "child_process"
|
import { spawn } from "child_process"
|
||||||
import { Command } from "../command"
|
import { Command } from "../command"
|
||||||
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
|
|||||||
export namespace SessionPrompt {
|
export namespace SessionPrompt {
|
||||||
const log = Log.create({ service: "session.prompt" })
|
const log = Log.create({ service: "session.prompt" })
|
||||||
export const OUTPUT_TOKEN_MAX = 32_000
|
export const OUTPUT_TOKEN_MAX = 32_000
|
||||||
|
const MAX_RETRIES = 10
|
||||||
|
|
||||||
export const Event = {
|
export const Event = {
|
||||||
Idle: Bus.event(
|
Idle: Bus.event(
|
||||||
@@ -240,93 +240,145 @@ export namespace SessionPrompt {
|
|||||||
await using _ = defer(async () => {
|
await using _ = defer(async () => {
|
||||||
await processor.end()
|
await processor.end()
|
||||||
})
|
})
|
||||||
const stream = streamText({
|
const doStream = () =>
|
||||||
onError(error) {
|
streamText({
|
||||||
log.error("stream error", {
|
onError(error) {
|
||||||
error,
|
log.error("stream error", {
|
||||||
})
|
error,
|
||||||
},
|
|
||||||
async experimental_repairToolCall(input) {
|
|
||||||
const lower = input.toolCall.toolName.toLowerCase()
|
|
||||||
if (lower !== input.toolCall.toolName && tools[lower]) {
|
|
||||||
log.info("repairing tool call", {
|
|
||||||
tool: input.toolCall.toolName,
|
|
||||||
repaired: lower,
|
|
||||||
})
|
})
|
||||||
|
},
|
||||||
|
async experimental_repairToolCall(input) {
|
||||||
|
const lower = input.toolCall.toolName.toLowerCase()
|
||||||
|
if (lower !== input.toolCall.toolName && tools[lower]) {
|
||||||
|
log.info("repairing tool call", {
|
||||||
|
tool: input.toolCall.toolName,
|
||||||
|
repaired: lower,
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
...input.toolCall,
|
||||||
|
toolName: lower,
|
||||||
|
}
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
...input.toolCall,
|
...input.toolCall,
|
||||||
toolName: lower,
|
input: JSON.stringify({
|
||||||
|
tool: input.toolCall.toolName,
|
||||||
|
error: input.error.message,
|
||||||
|
}),
|
||||||
|
toolName: "invalid",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
return {
|
headers:
|
||||||
...input.toolCall,
|
model.providerID === "opencode"
|
||||||
input: JSON.stringify({
|
? {
|
||||||
tool: input.toolCall.toolName,
|
"x-opencode-session": input.sessionID,
|
||||||
error: input.error.message,
|
"x-opencode-request": userMsg.info.id,
|
||||||
}),
|
|
||||||
toolName: "invalid",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
headers:
|
|
||||||
model.providerID === "opencode"
|
|
||||||
? {
|
|
||||||
"x-opencode-session": input.sessionID,
|
|
||||||
"x-opencode-request": userMsg.info.id,
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
maxRetries: 10,
|
|
||||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
|
||||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
|
||||||
model.providerID,
|
|
||||||
params.options,
|
|
||||||
model.info.limit.output,
|
|
||||||
OUTPUT_TOKEN_MAX,
|
|
||||||
),
|
|
||||||
abortSignal: abort.signal,
|
|
||||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
|
||||||
stopWhen: stepCountIs(1),
|
|
||||||
temperature: params.temperature,
|
|
||||||
topP: params.topP,
|
|
||||||
messages: [
|
|
||||||
...system.map(
|
|
||||||
(x): ModelMessage => ({
|
|
||||||
role: "system",
|
|
||||||
content: x,
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
...MessageV2.toModelMessage(
|
|
||||||
msgs.filter((m) => {
|
|
||||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
|
||||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
|
||||||
) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tools: model.info.tool_call === false ? undefined : tools,
|
|
||||||
model: wrapLanguageModel({
|
|
||||||
model: model.language,
|
|
||||||
middleware: [
|
|
||||||
{
|
|
||||||
async transformParams(args) {
|
|
||||||
if (args.type === "stream") {
|
|
||||||
// @ts-expect-error
|
|
||||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
|
||||||
}
|
}
|
||||||
return args.params
|
: undefined,
|
||||||
},
|
// set to 0, we handle loop
|
||||||
},
|
maxRetries: 0,
|
||||||
|
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||||
|
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||||
|
model.providerID,
|
||||||
|
params.options,
|
||||||
|
model.info.limit.output,
|
||||||
|
OUTPUT_TOKEN_MAX,
|
||||||
|
),
|
||||||
|
abortSignal: abort.signal,
|
||||||
|
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
||||||
|
stopWhen: stepCountIs(1),
|
||||||
|
temperature: params.temperature,
|
||||||
|
topP: params.topP,
|
||||||
|
messages: [
|
||||||
|
...system.map(
|
||||||
|
(x): ModelMessage => ({
|
||||||
|
role: "system",
|
||||||
|
content: x,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
...MessageV2.toModelMessage(
|
||||||
|
msgs.filter((m) => {
|
||||||
|
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||||
|
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||||
|
) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
}),
|
tools: model.info.tool_call === false ? undefined : tools,
|
||||||
|
model: wrapLanguageModel({
|
||||||
|
model: model.language,
|
||||||
|
middleware: [
|
||||||
|
{
|
||||||
|
async transformParams(args) {
|
||||||
|
if (args.type === "stream") {
|
||||||
|
// @ts-expect-error
|
||||||
|
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||||
|
}
|
||||||
|
return args.params
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
let stream = doStream()
|
||||||
|
let result = await processor.process(stream, {
|
||||||
|
count: 0,
|
||||||
|
max: MAX_RETRIES,
|
||||||
})
|
})
|
||||||
const result = await processor.process(stream)
|
if (result.shouldRetry) {
|
||||||
|
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||||
|
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||||
|
|
||||||
|
if (lastRetryPart) {
|
||||||
|
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||||
|
|
||||||
|
log.info("retrying with backoff", {
|
||||||
|
attempt: retry,
|
||||||
|
delayMs,
|
||||||
|
})
|
||||||
|
|
||||||
|
const stop = await SessionRetry.sleep(delayMs, abort.signal)
|
||||||
|
.then(() => false)
|
||||||
|
.catch((error) => {
|
||||||
|
if (error instanceof DOMException && error.name === "AbortError") {
|
||||||
|
const err = new MessageV2.AbortedError(
|
||||||
|
{ message: error.message },
|
||||||
|
{
|
||||||
|
cause: error,
|
||||||
|
},
|
||||||
|
).toObject()
|
||||||
|
result.info.error = err
|
||||||
|
Bus.publish(Session.Event.Error, {
|
||||||
|
sessionID: result.info.sessionID,
|
||||||
|
error: result.info.error,
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
})
|
||||||
|
|
||||||
|
if (stop) break
|
||||||
|
}
|
||||||
|
|
||||||
|
stream = doStream()
|
||||||
|
result = await processor.process(stream, {
|
||||||
|
count: retry,
|
||||||
|
max: MAX_RETRIES,
|
||||||
|
})
|
||||||
|
if (!result.shouldRetry) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
await processor.end()
|
await processor.end()
|
||||||
|
|
||||||
const queued = state().queued.get(input.sessionID) ?? []
|
const queued = state().queued.get(input.sessionID) ?? []
|
||||||
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
|
|||||||
partFromToolCall(toolCallID: string) {
|
partFromToolCall(toolCallID: string) {
|
||||||
return toolcalls[toolCallID]
|
return toolcalls[toolCallID]
|
||||||
},
|
},
|
||||||
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
|
||||||
log.info("process")
|
log.info("process")
|
||||||
if (!assistantMsg) throw new Error("call next() first before processing")
|
if (!assistantMsg) throw new Error("call next() first before processing")
|
||||||
|
let shouldRetry = false
|
||||||
try {
|
try {
|
||||||
let currentText: MessageV2.TextPart | undefined
|
let currentText: MessageV2.TextPart | undefined
|
||||||
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
||||||
@@ -1314,37 +1367,27 @@ export namespace SessionPrompt {
|
|||||||
log.error("process", {
|
log.error("process", {
|
||||||
error: e,
|
error: e,
|
||||||
})
|
})
|
||||||
switch (true) {
|
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||||
case e instanceof DOMException && e.name === "AbortError":
|
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||||
assistantMsg.error = new MessageV2.AbortedError(
|
shouldRetry = true
|
||||||
{ message: e.message },
|
await Session.updatePart({
|
||||||
{
|
id: Identifier.ascending("part"),
|
||||||
cause: e,
|
messageID: assistantMsg.id,
|
||||||
},
|
sessionID: assistantMsg.sessionID,
|
||||||
).toObject()
|
type: "retry",
|
||||||
break
|
attempt: retries.count + 1,
|
||||||
case MessageV2.OutputLengthError.isInstance(e):
|
time: {
|
||||||
assistantMsg.error = e
|
created: Date.now(),
|
||||||
break
|
},
|
||||||
case LoadAPIKeyError.isInstance(e):
|
error,
|
||||||
assistantMsg.error = new MessageV2.AuthError(
|
})
|
||||||
{
|
} else {
|
||||||
providerID: input.providerID,
|
assistantMsg.error = error
|
||||||
message: e.message,
|
Bus.publish(Session.Event.Error, {
|
||||||
},
|
sessionID: assistantMsg.sessionID,
|
||||||
{ cause: e },
|
error: assistantMsg.error,
|
||||||
).toObject()
|
})
|
||||||
break
|
|
||||||
case e instanceof Error:
|
|
||||||
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
|
||||||
}
|
}
|
||||||
Bus.publish(Session.Event.Error, {
|
|
||||||
sessionID: assistantMsg.sessionID,
|
|
||||||
error: assistantMsg.error,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
const p = await Session.getParts(assistantMsg.id)
|
const p = await Session.getParts(assistantMsg.id)
|
||||||
for (const part of p) {
|
for (const part of p) {
|
||||||
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assistantMsg.time.completed = Date.now()
|
if (!shouldRetry) {
|
||||||
|
assistantMsg.time.completed = Date.now()
|
||||||
|
}
|
||||||
await Session.updateMessage(assistantMsg)
|
await Session.updateMessage(assistantMsg)
|
||||||
return { info: assistantMsg, parts: p, blocked }
|
return { info: assistantMsg, parts: p, blocked, shouldRetry }
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|||||||
57
packages/opencode/src/session/retry.ts
Normal file
57
packages/opencode/src/session/retry.ts
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import { MessageV2 } from "./message-v2"
|
||||||
|
|
||||||
|
export namespace SessionRetry {
|
||||||
|
export const RETRY_INITIAL_DELAY = 2000
|
||||||
|
export const RETRY_BACKOFF_FACTOR = 2
|
||||||
|
|
||||||
|
export async function sleep(ms: number, signal: AbortSignal): Promise<void> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const timeout = setTimeout(resolve, ms)
|
||||||
|
signal.addEventListener(
|
||||||
|
"abort",
|
||||||
|
() => {
|
||||||
|
clearTimeout(timeout)
|
||||||
|
reject(new DOMException("Aborted", "AbortError"))
|
||||||
|
},
|
||||||
|
{ once: true },
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getRetryDelayInMs(error: MessageV2.APIError, attempt: number): number {
|
||||||
|
const base = RETRY_INITIAL_DELAY * Math.pow(RETRY_BACKOFF_FACTOR, attempt - 1)
|
||||||
|
const headers = error.data.responseHeaders
|
||||||
|
if (!headers) return base
|
||||||
|
|
||||||
|
const retryAfterMs = headers["retry-after-ms"]
|
||||||
|
if (retryAfterMs) {
|
||||||
|
const parsed = Number.parseFloat(retryAfterMs)
|
||||||
|
const normalized = normalizeDelay({ base, candidate: parsed })
|
||||||
|
if (normalized != null) return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
const retryAfter = headers["retry-after"]
|
||||||
|
if (!retryAfter) return base
|
||||||
|
|
||||||
|
const seconds = Number.parseFloat(retryAfter)
|
||||||
|
if (!Number.isNaN(seconds)) {
|
||||||
|
const normalized = normalizeDelay({ base, candidate: seconds * 1000 })
|
||||||
|
if (normalized != null) return normalized
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
const dateMs = Date.parse(retryAfter) - Date.now()
|
||||||
|
const normalized = normalizeDelay({ base, candidate: dateMs })
|
||||||
|
if (normalized != null) return normalized
|
||||||
|
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeDelay(input: { base: number; candidate: number }): number | undefined {
|
||||||
|
if (Number.isNaN(input.candidate)) return undefined
|
||||||
|
if (input.candidate < 0) return undefined
|
||||||
|
if (input.candidate < 60_000) return input.candidate
|
||||||
|
if (input.candidate < input.base) return input.candidate
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
47
packages/opencode/test/session/retry.test.ts
Normal file
47
packages/opencode/test/session/retry.test.ts
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import { describe, expect, test } from "bun:test"
|
||||||
|
import { SessionRetry } from "../../src/session/retry"
|
||||||
|
import { MessageV2 } from "../../src/session/message-v2"
|
||||||
|
|
||||||
|
function apiError(headers?: Record<string, string>): MessageV2.APIError {
|
||||||
|
return new MessageV2.APIError({
|
||||||
|
message: "boom",
|
||||||
|
isRetryable: true,
|
||||||
|
responseHeaders: headers,
|
||||||
|
}).toObject() as MessageV2.APIError
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("session.retry.getRetryDelayInMs", () => {
|
||||||
|
test("doubles delay on each attempt when headers missing", () => {
|
||||||
|
const error = apiError()
|
||||||
|
const delays = Array.from({ length: 7 }, (_, index) => SessionRetry.getRetryDelayInMs(error, index + 1))
|
||||||
|
expect(delays).toStrictEqual([2000, 4000, 8000, 16000, 32000, 64000, 128000])
|
||||||
|
})
|
||||||
|
|
||||||
|
test("prefers retry-after-ms when shorter than exponential", () => {
|
||||||
|
const error = apiError({ "retry-after-ms": "1500" })
|
||||||
|
expect(SessionRetry.getRetryDelayInMs(error, 4)).toBe(1500)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("uses retry-after seconds when reasonable", () => {
|
||||||
|
const error = apiError({ "retry-after": "30" })
|
||||||
|
expect(SessionRetry.getRetryDelayInMs(error, 3)).toBe(30000)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("falls back to exponential when server delay is long", () => {
|
||||||
|
const error = apiError({ "retry-after": "120" })
|
||||||
|
expect(SessionRetry.getRetryDelayInMs(error, 2)).toBe(4000)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("accepts http-date retry-after values", () => {
|
||||||
|
const date = new Date(Date.now() + 20000).toUTCString()
|
||||||
|
const error = apiError({ "retry-after": date })
|
||||||
|
const delay = SessionRetry.getRetryDelayInMs(error, 1)
|
||||||
|
expect(delay).toBeGreaterThanOrEqual(19000)
|
||||||
|
expect(delay).toBeLessThanOrEqual(20000)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("ignores invalid retry hints", () => {
|
||||||
|
const error = apiError({ "retry-after": "not-a-number" })
|
||||||
|
expect(SessionRetry.getRetryDelayInMs(error, 1)).toBe(2000)
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user