feat: retry parts (#3369)

Author: Aiden Cline
Date: 2025-10-22 18:31:36 -05:00 (committed by GitHub)
Parent: 9def7cff2d
Commit: 7c7ebb0a9d

5 changed files with 487 additions and 214 deletions
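
In short: the built-in `maxRetries: 10` on `streamText` is replaced with `maxRetries: 0` plus a hand-rolled loop. Both compaction.ts and prompt.ts wrap `streamText` in a `doStream()` factory, `process()` now reports `shouldRetry` instead of finalizing on retryable `APICallError`s, each attempt records a new `retry` part, and the new `SessionRetry` helper computes the backoff delay (honoring `retry-after` / `retry-after-ms` headers). A stripped-down sketch of that driver shape follows; `attemptStream` and `backoffMs` are made-up stand-ins for `processor.process(doStream())` and `SessionRetry.getRetryDelayInMs()`, not code from this commit.

// Minimal sketch of the retry driver this commit adds (not the real code).
const MAX_RETRIES = 10
const RETRY_INITIAL_DELAY = 2000
const RETRY_BACKOFF_FACTOR = 2

// Stand-in for processor.process(): pretend the first two attempts hit a
// retryable API error, then the third succeeds.
async function attemptStream(attempt: number): Promise<{ shouldRetry: boolean }> {
  return { shouldRetry: attempt < 2 }
}

// Pure exponential backoff, i.e. the header-less path of getRetryDelayInMs().
function backoffMs(attempt: number): number {
  return RETRY_INITIAL_DELAY * Math.pow(RETRY_BACKOFF_FACTOR, attempt - 1)
}

async function drive(): Promise<void> {
  let result = await attemptStream(0)
  if (result.shouldRetry) {
    for (let retry = 1; retry < MAX_RETRIES; retry++) {
      const delayMs = backoffMs(retry)
      console.log("retrying with backoff", { attempt: retry, delayMs })
      // the real code uses SessionRetry.sleep() with an AbortSignal instead
      await new Promise((resolve) => setTimeout(resolve, delayMs))
      result = await attemptStream(retry)
      if (!result.shouldRetry) break
    }
  }
}

void drive()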

View File

@@ -1,4 +1,4 @@
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
@@ -14,8 +14,8 @@ import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { NamedError } from "../util/error"
import { ProviderTransform } from "@/provider/transform"
import { SessionRetry } from "./retry"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -41,6 +41,7 @@ export namespace SessionCompaction {
export const PRUNE_MINIMUM = 20_000
export const PRUNE_PROTECT = 40_000
const MAX_RETRIES = 10
// goes backwards through parts until there are 40_000 tokens worth of tool
// calls. then erases output of previous tool calls. idea is to throw away old
@@ -142,112 +143,173 @@ export namespace SessionCompaction {
},
})) as MessageV2.TextPart
const stream = streamText({
maxRetries: 10,
model: model.language,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(toSummarize),
{
role: "user",
content: [
{
type: "text",
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
},
],
const doStream = () =>
streamText({
// set to 0, we handle loop
maxRetries: 0,
model: model.language,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
],
})
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(toSummarize),
{
role: "user",
content: [
{
type: "text",
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
},
],
},
],
})
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text) await Session.updatePart(part)
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
// TODO: reduce duplication between compaction.ts & prompt.ts
const process = async (
stream: StreamTextResult<Record<string, AITool>, never>,
retries: { count: number; max: number },
) => {
let shouldRetry = false
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text) await Session.updatePart(part)
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
} catch (e) {
log.error("compaction error", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
msg.error = error
Bus.publish(Session.Event.Error, {
sessionID: msg.sessionID,
error: msg.error,
})
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
msg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
msg.error = e
break
case LoadAPIKeyError.isInstance(e):
msg.error = new MessageV2.AuthError(
{
providerID: model.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
const parts = await Session.getParts(msg.id)
return {
info: msg,
parts,
shouldRetry,
}
}
let stream = doStream()
let result = await process(stream, {
count: 0,
max: MAX_RETRIES,
})
if (result.shouldRetry) {
for (let retry = 1; retry < MAX_RETRIES; retry++) {
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
log.info("retrying with backoff", {
attempt: retry,
delayMs,
})
const stop = await SessionRetry.sleep(delayMs, signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await process(stream, {
count: retry,
max: MAX_RETRIES,
})
if (!result.shouldRetry) {
break
}
}
Bus.publish(Session.Event.Error, {
sessionID: input.sessionID,
error: msg.error,
})
}
msg.time.completed = Date.now()
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
if (
!msg.error ||
(MessageV2.AbortedError.isInstance(msg.error) &&
result.parts.some((part) => part.type === "text" && part.text.length > 0))
) {
msg.summary = true
Bus.publish(Event.Compacted, {
sessionID: input.sessionID,
@@ -257,7 +319,7 @@ export namespace SessionCompaction {
return {
info: msg,
parts: [part],
parts: result.parts,
}
}
}

View File

@@ -2,7 +2,7 @@ import z from "zod/v4"
import { Bus } from "../bus"
import { NamedError } from "../util/error"
import { Message } from "./message"
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessage, type UIMessage } from "ai"
import { Identifier } from "../id/id"
import { LSP } from "../lsp"
import { Snapshot } from "@/snapshot"
@@ -18,6 +18,17 @@ export namespace MessageV2 {
message: z.string(),
}),
)
export const APIError = NamedError.create(
"APIError",
z.object({
message: z.string(),
statusCode: z.number().optional(),
isRetryable: z.boolean(),
responseHeaders: z.record(z.string(), z.string()).optional(),
responseBody: z.string().optional(),
}),
)
export type APIError = z.infer<typeof APIError.Schema>
const PartBase = z.object({
id: z.string(),
@@ -130,6 +141,18 @@ export namespace MessageV2 {
})
export type AgentPart = z.infer<typeof AgentPart>
export const RetryPart = PartBase.extend({
type: z.literal("retry"),
attempt: z.number(),
error: APIError.Schema,
time: z.object({
created: z.number(),
}),
}).meta({
ref: "RetryPart",
})
export type RetryPart = z.infer<typeof RetryPart>
export const StepStartPart = PartBase.extend({
type: z.literal("step-start"),
snapshot: z.string().optional(),
@@ -265,6 +288,7 @@ export namespace MessageV2 {
SnapshotPart,
PatchPart,
AgentPart,
RetryPart,
])
.meta({
ref: "Part",
@@ -283,6 +307,7 @@ export namespace MessageV2 {
NamedError.Unknown.Schema,
OutputLengthError.Schema,
AbortedError.Schema,
APIError.Schema,
])
.optional(),
system: z.string().array(),
@@ -610,4 +635,41 @@ export namespace MessageV2 {
if (i === -1) return msgs.slice()
return msgs.slice(i)
}
export function fromError(e: unknown, ctx: { providerID: string }) {
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
return new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
case MessageV2.OutputLengthError.isInstance(e):
return e
case LoadAPIKeyError.isInstance(e):
return new MessageV2.AuthError(
{
providerID: ctx.providerID,
message: e.message,
},
{ cause: e },
).toObject()
case APICallError.isInstance(e):
return new MessageV2.APIError(
{
message: e.message,
statusCode: e.statusCode,
isRetryable: e.isRetryable,
responseHeaders: e.responseHeaders,
responseBody: e.responseBody,
},
{ cause: e },
).toObject()
case e instanceof Error:
return new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
default:
return new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
}
}
}
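
For orientation, here is a rough picture of the data a `retry` part carries once `fromError` has mapped an AI SDK `APICallError` into `MessageV2.APIError`. The object literal below is illustrative only: it assumes `NamedError.toObject()` yields a `{ name, data }` envelope, and the IDs and header values are made up.

// Hypothetical example of a "retry" part, per the RetryPart schema above.
const examplePart = {
  id: "prt_example",            // Identifier.ascending("part") in the real code
  messageID: "msg_example",
  sessionID: "ses_example",
  type: "retry" as const,
  attempt: 1,                   // retries.count + 1
  time: { created: Date.now() },
  error: {
    // assumed NamedError envelope
    name: "APIError" as const,
    data: {
      message: "rate limited",
      statusCode: 429,
      isRetryable: true,        // the flag the retry loop keys off
      responseHeaders: { "retry-after": "30" },
    },
  },
}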

View File

@@ -17,7 +17,6 @@ import {
tool,
wrapLanguageModel,
type StreamTextResult,
LoadAPIKeyError,
stepCountIs,
jsonSchema,
} from "ai"
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
import { ProviderTransform } from "../provider/transform"
import { SystemPrompt } from "./system"
import { Plugin } from "../plugin"
import { SessionRetry } from "./retry"
import PROMPT_PLAN from "../session/prompt/plan.txt"
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
import { FileTime } from "../file/time"
import { Permission } from "../permission"
import { Snapshot } from "../snapshot"
import { NamedError } from "../util/error"
import { ulid } from "ulid"
import { spawn } from "child_process"
import { Command } from "../command"
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
export namespace SessionPrompt {
const log = Log.create({ service: "session.prompt" })
export const OUTPUT_TOKEN_MAX = 32_000
const MAX_RETRIES = 10
export const Event = {
Idle: Bus.event(
@@ -240,93 +240,145 @@ export namespace SessionPrompt {
await using _ = defer(async () => {
await processor.end()
})
const stream = streamText({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
const doStream = () =>
streamText({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
return {
...input.toolCall,
toolName: lower,
}
}
return {
...input.toolCall,
toolName: lower,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
},
headers:
model.providerID === "opencode"
? {
"x-opencode-session": input.sessionID,
"x-opencode-request": userMsg.info.id,
}
: undefined,
maxRetries: 10,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort.signal,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
},
headers:
model.providerID === "opencode"
? {
"x-opencode-session": input.sessionID,
"x-opencode-request": userMsg.info.id,
}
return args.params
},
},
: undefined,
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort.signal,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
}),
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
}
return args.params
},
},
],
}),
})
let stream = doStream()
let result = await processor.process(stream, {
count: 0,
max: MAX_RETRIES,
})
const result = await processor.process(stream)
if (result.shouldRetry) {
for (let retry = 1; retry < MAX_RETRIES; retry++) {
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
log.info("retrying with backoff", {
attempt: retry,
delayMs,
})
const stop = await SessionRetry.sleep(delayMs, abort.signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await processor.process(stream, {
count: retry,
max: MAX_RETRIES,
})
if (!result.shouldRetry) {
break
}
}
}
await processor.end()
const queued = state().queued.get(input.sessionID) ?? []
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
log.info("process")
if (!assistantMsg) throw new Error("call next() first before processing")
let shouldRetry = false
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
@@ -1314,37 +1367,27 @@ export namespace SessionPrompt {
log.error("process", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
assistantMsg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
assistantMsg.error = e
break
case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new MessageV2.AuthError(
{
providerID: input.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
assistantMsg.error = error
Bus.publish(Session.Event.Error, {
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
Bus.publish(Session.Event.Error, {
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
const p = await Session.getParts(assistantMsg.id)
for (const part of p) {
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
})
}
}
assistantMsg.time.completed = Date.now()
if (!shouldRetry) {
assistantMsg.time.completed = Date.now()
}
await Session.updateMessage(assistantMsg)
return { info: assistantMsg, parts: p, blocked }
return { info: assistantMsg, parts: p, blocked, shouldRetry }
},
}
return result
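
The practical effect in the processor: retryable failures no longer finalize the message. `process()` records a `retry` part, leaves `time.completed` unset, and returns `shouldRetry: true` so the caller can re-drive the stream; anything non-retryable is still published via `Session.Event.Error` as before. A condensed paraphrase of that decision point, with an assumed import path and a bare `providerID` parameter standing in for `input.providerID`:

import { MessageV2 } from "./message-v2"   // path assumed

// Condensed paraphrase of the catch block above (not the literal code).
function classify(e: unknown, providerID: string, retries: { count: number; max: number }) {
  const error = MessageV2.fromError(e, { providerID })
  const retryable =
    retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable
  return retryable
    ? { shouldRetry: true, retryPartError: error }  // becomes a "retry" part
    : { shouldRetry: false, messageError: error }   // published as the message error
}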

View File

@@ -0,0 +1,57 @@
import { MessageV2 } from "./message-v2"
export namespace SessionRetry {
export const RETRY_INITIAL_DELAY = 2000
export const RETRY_BACKOFF_FACTOR = 2
export async function sleep(ms: number, signal: AbortSignal): Promise<void> {
return new Promise((resolve, reject) => {
const timeout = setTimeout(resolve, ms)
signal.addEventListener(
"abort",
() => {
clearTimeout(timeout)
reject(new DOMException("Aborted", "AbortError"))
},
{ once: true },
)
})
}
export function getRetryDelayInMs(error: MessageV2.APIError, attempt: number): number {
const base = RETRY_INITIAL_DELAY * Math.pow(RETRY_BACKOFF_FACTOR, attempt - 1)
const headers = error.data.responseHeaders
if (!headers) return base
const retryAfterMs = headers["retry-after-ms"]
if (retryAfterMs) {
const parsed = Number.parseFloat(retryAfterMs)
const normalized = normalizeDelay({ base, candidate: parsed })
if (normalized != null) return normalized
}
const retryAfter = headers["retry-after"]
if (!retryAfter) return base
const seconds = Number.parseFloat(retryAfter)
if (!Number.isNaN(seconds)) {
const normalized = normalizeDelay({ base, candidate: seconds * 1000 })
if (normalized != null) return normalized
return base
}
const dateMs = Date.parse(retryAfter) - Date.now()
const normalized = normalizeDelay({ base, candidate: dateMs })
if (normalized != null) return normalized
return base
}
function normalizeDelay(input: { base: number; candidate: number }): number | undefined {
if (Number.isNaN(input.candidate)) return undefined
if (input.candidate < 0) return undefined
if (input.candidate < 60_000) return input.candidate
if (input.candidate < input.base) return input.candidate
return undefined
}
}
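
For context, `getRetryDelayInMs` prefers a `retry-after-ms` header, then `retry-after` (seconds or an HTTP-date), and otherwise falls back to 2s doubled per attempt; server-supplied values are only honored when they are under 60s or shorter than the exponential delay. A small caller sketch follows; the helper name `waitBeforeRetry` and the relative import paths are assumptions, not part of this file.

import { SessionRetry } from "./retry"        // paths assumed: same directory
import { MessageV2 } from "./message-v2"

// Sleep out the computed delay, but stop early if the session is aborted.
async function waitBeforeRetry(
  error: MessageV2.APIError,
  attempt: number,
  signal: AbortSignal,
): Promise<boolean> {
  const delayMs = SessionRetry.getRetryDelayInMs(error, attempt)
  try {
    await SessionRetry.sleep(delayMs, signal)
    return true // full delay elapsed, caller may retry
  } catch (e) {
    if (e instanceof DOMException && e.name === "AbortError") return false // aborted mid-wait
    throw e
  }
}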

View File

@@ -0,0 +1,47 @@
import { describe, expect, test } from "bun:test"
import { SessionRetry } from "../../src/session/retry"
import { MessageV2 } from "../../src/session/message-v2"
function apiError(headers?: Record<string, string>): MessageV2.APIError {
return new MessageV2.APIError({
message: "boom",
isRetryable: true,
responseHeaders: headers,
}).toObject() as MessageV2.APIError
}
describe("session.retry.getRetryDelayInMs", () => {
test("doubles delay on each attempt when headers missing", () => {
const error = apiError()
const delays = Array.from({ length: 7 }, (_, index) => SessionRetry.getRetryDelayInMs(error, index + 1))
expect(delays).toStrictEqual([2000, 4000, 8000, 16000, 32000, 64000, 128000])
})
test("prefers retry-after-ms when shorter than exponential", () => {
const error = apiError({ "retry-after-ms": "1500" })
expect(SessionRetry.getRetryDelayInMs(error, 4)).toBe(1500)
})
test("uses retry-after seconds when reasonable", () => {
const error = apiError({ "retry-after": "30" })
expect(SessionRetry.getRetryDelayInMs(error, 3)).toBe(30000)
})
test("falls back to exponential when server delay is long", () => {
const error = apiError({ "retry-after": "120" })
expect(SessionRetry.getRetryDelayInMs(error, 2)).toBe(4000)
})
test("accepts http-date retry-after values", () => {
const date = new Date(Date.now() + 20000).toUTCString()
const error = apiError({ "retry-after": date })
const delay = SessionRetry.getRetryDelayInMs(error, 1)
expect(delay).toBeGreaterThanOrEqual(19000)
expect(delay).toBeLessThanOrEqual(20000)
})
test("ignores invalid retry hints", () => {
const error = apiError({ "retry-after": "not-a-number" })
expect(SessionRetry.getRetryDelayInMs(error, 1)).toBe(2000)
})
})