make compact interruptable (#3251)

This commit is contained in:
Aiden Cline
2025-10-18 11:49:29 -05:00
committed by GitHub
parent 8da8c9e78c
commit 1f869bccc1
3 changed files with 105 additions and 41 deletions

View File

@@ -59,9 +59,7 @@ const ERRORS = {
description: "Not found", description: "Not found",
content: { content: {
"application/json": { "application/json": {
schema: resolver( schema: resolver(Storage.NotFoundError.Schema),
Storage.NotFoundError.Schema
)
}, },
}, },
}, },
@@ -87,12 +85,9 @@ export namespace Server {
}) })
if (err instanceof NamedError) { if (err instanceof NamedError) {
let status: ContentfulStatusCode let status: ContentfulStatusCode
if (err instanceof Storage.NotFoundError) if (err instanceof Storage.NotFoundError) status = 404
status = 404 else if (err instanceof Provider.ModelNotFoundError) status = 400
else if (err instanceof Provider.ModelNotFoundError) else status = 500
status = 400
else
status = 500
return c.json(err.toObject(), { status }) return c.json(err.toObject(), { status })
} }
const message = err instanceof Error && err.stack ? err.stack : err.toString() const message = err instanceof Error && err.stack ? err.stack : err.toString()

View File

@@ -1,4 +1,4 @@
import { streamText, type ModelMessage } from "ai" import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
import { Session } from "." import { Session } from "."
import { Identifier } from "../id/id" import { Identifier } from "../id/id"
import { Instance } from "../project/instance" import { Instance } from "../project/instance"
@@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt"
import { Flag } from "../flag/flag" import { Flag } from "../flag/flag"
import { Token } from "../util/token" import { Token } from "../util/token"
import { Log } from "../util/log" import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { NamedError } from "../util/error"
export namespace SessionCompaction { export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" }) const log = Log.create({ service: "session.compaction" })
@@ -82,7 +84,11 @@ export namespace SessionCompaction {
} }
} }
export async function run(input: { sessionID: string; providerID: string; modelID: string }) { export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
const signal = input.signal ?? lock!.signal
await Session.update(input.sessionID, (draft) => { await Session.update(input.sessionID, (draft) => {
draft.time.compacting = Date.now() draft.time.compacting = Date.now()
}) })
@@ -122,6 +128,7 @@ export namespace SessionCompaction {
created: Date.now(), created: Date.now(),
}, },
})) as MessageV2.Assistant })) as MessageV2.Assistant
const part = (await Session.updatePart({ const part = (await Session.updatePart({
type: "text", type: "text",
sessionID: input.sessionID, sessionID: input.sessionID,
@@ -133,13 +140,18 @@ export namespace SessionCompaction {
}, },
})) as MessageV2.TextPart })) as MessageV2.TextPart
let summaryText = ""
const stream = streamText({ const stream = streamText({
maxRetries: 10, maxRetries: 10,
model: model.language, model: model.language,
providerOptions: { providerOptions: {
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options, [model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
}, },
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
messages: [ messages: [
...system.map( ...system.map(
(x): ModelMessage => ({ (x): ModelMessage => ({
@@ -160,38 +172,88 @@ export namespace SessionCompaction {
], ],
}) })
try {
for await (const value of stream.fullStream) { for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) { switch (value.type) {
case "text-delta": case "text-delta":
summaryText += value.text part.text += value.text
await Session.updatePart({ if (value.providerMetadata) part.metadata = value.providerMetadata
...part, if (part.text) await Session.updatePart(part)
text: summaryText, continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
}) })
break
case "text-end":
part.text = summaryText
await Session.updatePart({
...part,
})
break
case "finish": {
const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined })
msg.cost += usage.cost msg.cost += usage.cost
msg.tokens = usage.tokens msg.tokens = usage.tokens
msg.summary = true
msg.time.completed = Date.now()
await Session.updateMessage(msg) await Session.updateMessage(msg)
part.time!.end = Date.now() continue
await Session.updatePart(part) }
case "error":
throw value.error
default:
continue
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
msg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break break
case MessageV2.OutputLengthError.isInstance(e):
msg.error = e
break
case LoadAPIKeyError.isInstance(e):
msg.error = new MessageV2.AuthError(
{
providerID: model.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
} }
} Bus.publish(Session.Event.Error, {
sessionID: input.sessionID,
error: msg.error,
})
} }
msg.time.completed = Date.now()
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
msg.summary = true
Bus.publish(Event.Compacted, { Bus.publish(Event.Compacted, {
sessionID: input.sessionID, sessionID: input.sessionID,
}) })
}
await Session.updateMessage(msg)
return { return {
info: msg, info: msg,

View File

@@ -211,6 +211,7 @@ export namespace SessionPrompt {
sessionID: input.sessionID, sessionID: input.sessionID,
model: model.info, model: model.info,
providerID: model.providerID, providerID: model.providerID,
signal: abort.signal,
}), }),
(messages) => insertReminders({ messages, agent }), (messages) => insertReminders({ messages, agent }),
) )
@@ -339,7 +340,12 @@ export namespace SessionPrompt {
} }
} }
async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) { async function getMessages(input: {
sessionID: string
model: ModelsDev.Model
providerID: string
signal: AbortSignal
}) {
let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized) let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant") const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
if ( if (
@@ -353,6 +359,7 @@ export namespace SessionPrompt {
sessionID: input.sessionID, sessionID: input.sessionID,
providerID: input.providerID, providerID: input.providerID,
modelID: input.model.id, modelID: input.model.id,
signal: input.signal,
}) })
const resumeMsgID = Identifier.ascending("message") const resumeMsgID = Identifier.ascending("message")
const resumeMsg = { const resumeMsg = {