make compact interruptable (#3251)

This commit is contained in:
Aiden Cline
2025-10-18 11:49:29 -05:00
committed by GitHub
parent 8da8c9e78c
commit 1f869bccc1
3 changed files with 105 additions and 41 deletions

View File

@@ -59,9 +59,7 @@ const ERRORS = {
description: "Not found",
content: {
"application/json": {
schema: resolver(
Storage.NotFoundError.Schema
)
schema: resolver(Storage.NotFoundError.Schema),
},
},
},
@@ -87,12 +85,9 @@ export namespace Server {
})
if (err instanceof NamedError) {
let status: ContentfulStatusCode
if (err instanceof Storage.NotFoundError)
status = 404
else if (err instanceof Provider.ModelNotFoundError)
status = 400
else
status = 500
if (err instanceof Storage.NotFoundError) status = 404
else if (err instanceof Provider.ModelNotFoundError) status = 400
else status = 500
return c.json(err.toObject(), { status })
}
const message = err instanceof Error && err.stack ? err.stack : err.toString()

View File

@@ -1,4 +1,4 @@
import { streamText, type ModelMessage } from "ai"
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
@@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt"
import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { NamedError } from "../util/error"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -82,7 +84,11 @@ export namespace SessionCompaction {
}
}
export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
const signal = input.signal ?? lock!.signal
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = Date.now()
})
@@ -122,6 +128,7 @@ export namespace SessionCompaction {
created: Date.now(),
},
})) as MessageV2.Assistant
const part = (await Session.updatePart({
type: "text",
sessionID: input.sessionID,
@@ -133,13 +140,18 @@ export namespace SessionCompaction {
},
})) as MessageV2.TextPart
let summaryText = ""
const stream = streamText({
maxRetries: 10,
model: model.language,
providerOptions: {
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
},
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
messages: [
...system.map(
(x): ModelMessage => ({
@@ -160,38 +172,88 @@ export namespace SessionCompaction {
],
})
for await (const value of stream.fullStream) {
switch (value.type) {
case "text-delta":
summaryText += value.text
await Session.updatePart({
...part,
text: summaryText,
})
break
case "text-end":
part.text = summaryText
await Session.updatePart({
...part,
})
break
case "finish": {
const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined })
msg.cost += usage.cost
msg.tokens = usage.tokens
msg.summary = true
msg.time.completed = Date.now()
await Session.updateMessage(msg)
part.time!.end = Date.now()
await Session.updatePart(part)
break
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text) await Session.updatePart(part)
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
msg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
msg.error = e
break
case LoadAPIKeyError.isInstance(e):
msg.error = new MessageV2.AuthError(
{
providerID: model.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
}
Bus.publish(Session.Event.Error, {
sessionID: input.sessionID,
error: msg.error,
})
}
Bus.publish(Event.Compacted, {
sessionID: input.sessionID,
})
msg.time.completed = Date.now()
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
msg.summary = true
Bus.publish(Event.Compacted, {
sessionID: input.sessionID,
})
}
await Session.updateMessage(msg)
return {
info: msg,

View File

@@ -211,6 +211,7 @@ export namespace SessionPrompt {
sessionID: input.sessionID,
model: model.info,
providerID: model.providerID,
signal: abort.signal,
}),
(messages) => insertReminders({ messages, agent }),
)
@@ -339,7 +340,12 @@ export namespace SessionPrompt {
}
}
async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) {
async function getMessages(input: {
sessionID: string
model: ModelsDev.Model
providerID: string
signal: AbortSignal
}) {
let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
if (
@@ -353,6 +359,7 @@ export namespace SessionPrompt {
sessionID: input.sessionID,
providerID: input.providerID,
modelID: input.model.id,
signal: input.signal,
})
const resumeMsgID = Identifier.ascending("message")
const resumeMsg = {