mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 09:44:21 +01:00
make compact interruptable (#3251)
This commit is contained in:
@@ -59,9 +59,7 @@ const ERRORS = {
|
||||
description: "Not found",
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: resolver(
|
||||
Storage.NotFoundError.Schema
|
||||
)
|
||||
schema: resolver(Storage.NotFoundError.Schema),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -87,12 +85,9 @@ export namespace Server {
|
||||
})
|
||||
if (err instanceof NamedError) {
|
||||
let status: ContentfulStatusCode
|
||||
if (err instanceof Storage.NotFoundError)
|
||||
status = 404
|
||||
else if (err instanceof Provider.ModelNotFoundError)
|
||||
status = 400
|
||||
else
|
||||
status = 500
|
||||
if (err instanceof Storage.NotFoundError) status = 404
|
||||
else if (err instanceof Provider.ModelNotFoundError) status = 400
|
||||
else status = 500
|
||||
return c.json(err.toObject(), { status })
|
||||
}
|
||||
const message = err instanceof Error && err.stack ? err.stack : err.toString()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { streamText, type ModelMessage } from "ai"
|
||||
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
@@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt"
|
||||
import { Flag } from "../flag/flag"
|
||||
import { Token } from "../util/token"
|
||||
import { Log } from "../util/log"
|
||||
import { SessionLock } from "./lock"
|
||||
import { NamedError } from "../util/error"
|
||||
|
||||
export namespace SessionCompaction {
|
||||
const log = Log.create({ service: "session.compaction" })
|
||||
@@ -82,7 +84,11 @@ export namespace SessionCompaction {
|
||||
}
|
||||
}
|
||||
|
||||
export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
|
||||
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
|
||||
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
|
||||
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
|
||||
const signal = input.signal ?? lock!.signal
|
||||
|
||||
await Session.update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = Date.now()
|
||||
})
|
||||
@@ -122,6 +128,7 @@ export namespace SessionCompaction {
|
||||
created: Date.now(),
|
||||
},
|
||||
})) as MessageV2.Assistant
|
||||
|
||||
const part = (await Session.updatePart({
|
||||
type: "text",
|
||||
sessionID: input.sessionID,
|
||||
@@ -133,13 +140,18 @@ export namespace SessionCompaction {
|
||||
},
|
||||
})) as MessageV2.TextPart
|
||||
|
||||
let summaryText = ""
|
||||
const stream = streamText({
|
||||
maxRetries: 10,
|
||||
model: model.language,
|
||||
providerOptions: {
|
||||
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
|
||||
},
|
||||
abortSignal: signal,
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
@@ -160,38 +172,88 @@ export namespace SessionCompaction {
|
||||
],
|
||||
})
|
||||
|
||||
try {
|
||||
for await (const value of stream.fullStream) {
|
||||
signal.throwIfAborted()
|
||||
switch (value.type) {
|
||||
case "text-delta":
|
||||
summaryText += value.text
|
||||
await Session.updatePart({
|
||||
...part,
|
||||
text: summaryText,
|
||||
part.text += value.text
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
if (part.text) await Session.updatePart(part)
|
||||
continue
|
||||
case "text-end": {
|
||||
part.text = part.text.trimEnd()
|
||||
part.time = {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
}
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
await Session.updatePart(part)
|
||||
continue
|
||||
}
|
||||
case "finish-step": {
|
||||
const usage = Session.getUsage({
|
||||
model: model.info,
|
||||
usage: value.usage,
|
||||
metadata: value.providerMetadata,
|
||||
})
|
||||
break
|
||||
case "text-end":
|
||||
part.text = summaryText
|
||||
await Session.updatePart({
|
||||
...part,
|
||||
})
|
||||
break
|
||||
case "finish": {
|
||||
const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined })
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
msg.summary = true
|
||||
msg.time.completed = Date.now()
|
||||
await Session.updateMessage(msg)
|
||||
part.time!.end = Date.now()
|
||||
await Session.updatePart(part)
|
||||
continue
|
||||
}
|
||||
case "error":
|
||||
throw value.error
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
log.error("compaction error", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
msg.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
msg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
msg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: model.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
}
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: input.sessionID,
|
||||
error: msg.error,
|
||||
})
|
||||
}
|
||||
|
||||
msg.time.completed = Date.now()
|
||||
|
||||
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
||||
msg.summary = true
|
||||
Bus.publish(Event.Compacted, {
|
||||
sessionID: input.sessionID,
|
||||
})
|
||||
}
|
||||
await Session.updateMessage(msg)
|
||||
|
||||
return {
|
||||
info: msg,
|
||||
|
||||
@@ -211,6 +211,7 @@ export namespace SessionPrompt {
|
||||
sessionID: input.sessionID,
|
||||
model: model.info,
|
||||
providerID: model.providerID,
|
||||
signal: abort.signal,
|
||||
}),
|
||||
(messages) => insertReminders({ messages, agent }),
|
||||
)
|
||||
@@ -339,7 +340,12 @@ export namespace SessionPrompt {
|
||||
}
|
||||
}
|
||||
|
||||
async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) {
|
||||
async function getMessages(input: {
|
||||
sessionID: string
|
||||
model: ModelsDev.Model
|
||||
providerID: string
|
||||
signal: AbortSignal
|
||||
}) {
|
||||
let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
|
||||
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
|
||||
if (
|
||||
@@ -353,6 +359,7 @@ export namespace SessionPrompt {
|
||||
sessionID: input.sessionID,
|
||||
providerID: input.providerID,
|
||||
modelID: input.model.id,
|
||||
signal: input.signal,
|
||||
})
|
||||
const resumeMsgID = Identifier.ascending("message")
|
||||
const resumeMsg = {
|
||||
|
||||
Reference in New Issue
Block a user