diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index c33e6deb..bea61a65 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -59,9 +59,7 @@ const ERRORS = { description: "Not found", content: { "application/json": { - schema: resolver( - Storage.NotFoundError.Schema - ) + schema: resolver(Storage.NotFoundError.Schema), }, }, }, @@ -87,12 +85,9 @@ export namespace Server { }) if (err instanceof NamedError) { let status: ContentfulStatusCode - if (err instanceof Storage.NotFoundError) - status = 404 - else if (err instanceof Provider.ModelNotFoundError) - status = 400 - else - status = 500 + if (err instanceof Storage.NotFoundError) status = 404 + else if (err instanceof Provider.ModelNotFoundError) status = 400 + else status = 500 return c.json(err.toObject(), { status }) } const message = err instanceof Error && err.stack ? err.stack : err.toString() diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts index f9c5f363..6ef56a7e 100644 --- a/packages/opencode/src/session/compaction.ts +++ b/packages/opencode/src/session/compaction.ts @@ -1,4 +1,4 @@ -import { streamText, type ModelMessage } from "ai" +import { streamText, type ModelMessage, LoadAPIKeyError } from "ai" import { Session } from "." import { Identifier } from "../id/id" import { Instance } from "../project/instance" @@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt" import { Flag } from "../flag/flag" import { Token } from "../util/token" import { Log } from "../util/log" +import { SessionLock } from "./lock" +import { NamedError } from "../util/error" export namespace SessionCompaction { const log = Log.create({ service: "session.compaction" }) @@ -82,7 +84,11 @@ export namespace SessionCompaction { } } - export async function run(input: { sessionID: string; providerID: string; modelID: string }) { + export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) { + if (!input.signal) SessionLock.assertUnlocked(input.sessionID) + await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined + const signal = input.signal ?? lock!.signal + await Session.update(input.sessionID, (draft) => { draft.time.compacting = Date.now() }) @@ -122,6 +128,7 @@ export namespace SessionCompaction { created: Date.now(), }, })) as MessageV2.Assistant + const part = (await Session.updatePart({ type: "text", sessionID: input.sessionID, @@ -133,13 +140,18 @@ export namespace SessionCompaction { }, })) as MessageV2.TextPart - let summaryText = "" const stream = streamText({ maxRetries: 10, model: model.language, providerOptions: { [model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options, }, + abortSignal: signal, + onError(error) { + log.error("stream error", { + error, + }) + }, messages: [ ...system.map( (x): ModelMessage => ({ @@ -160,38 +172,88 @@ export namespace SessionCompaction { ], }) - for await (const value of stream.fullStream) { - switch (value.type) { - case "text-delta": - summaryText += value.text - await Session.updatePart({ - ...part, - text: summaryText, - }) - break - case "text-end": - part.text = summaryText - await Session.updatePart({ - ...part, - }) - break - case "finish": { - const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined }) - msg.cost += usage.cost - msg.tokens = usage.tokens - msg.summary = true - msg.time.completed = Date.now() - await Session.updateMessage(msg) - part.time!.end = Date.now() - await Session.updatePart(part) - break + try { + for await (const value of stream.fullStream) { + signal.throwIfAborted() + switch (value.type) { + case "text-delta": + part.text += value.text + if (value.providerMetadata) part.metadata = value.providerMetadata + if (part.text) await Session.updatePart(part) + continue + case "text-end": { + part.text = part.text.trimEnd() + part.time = { + start: Date.now(), + end: Date.now(), + } + if (value.providerMetadata) part.metadata = value.providerMetadata + await Session.updatePart(part) + continue + } + case "finish-step": { + const usage = Session.getUsage({ + model: model.info, + usage: value.usage, + metadata: value.providerMetadata, + }) + msg.cost += usage.cost + msg.tokens = usage.tokens + await Session.updateMessage(msg) + continue + } + case "error": + throw value.error + default: + continue } } + } catch (e) { + log.error("compaction error", { + error: e, + }) + switch (true) { + case e instanceof DOMException && e.name === "AbortError": + msg.error = new MessageV2.AbortedError( + { message: e.message }, + { + cause: e, + }, + ).toObject() + break + case MessageV2.OutputLengthError.isInstance(e): + msg.error = e + break + case LoadAPIKeyError.isInstance(e): + msg.error = new MessageV2.AuthError( + { + providerID: model.providerID, + message: e.message, + }, + { cause: e }, + ).toObject() + break + case e instanceof Error: + msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject() + break + default: + msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }) + } + Bus.publish(Session.Event.Error, { + sessionID: input.sessionID, + error: msg.error, + }) } - Bus.publish(Event.Compacted, { - sessionID: input.sessionID, - }) + msg.time.completed = Date.now() + + if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) { + msg.summary = true + Bus.publish(Event.Compacted, { + sessionID: input.sessionID, + }) + } + await Session.updateMessage(msg) return { info: msg, diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index 29940dda..cced1271 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -211,6 +211,7 @@ export namespace SessionPrompt { sessionID: input.sessionID, model: model.info, providerID: model.providerID, + signal: abort.signal, }), (messages) => insertReminders({ messages, agent }), ) @@ -339,7 +340,12 @@ export namespace SessionPrompt { } } - async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) { + async function getMessages(input: { + sessionID: string + model: ModelsDev.Model + providerID: string + signal: AbortSignal + }) { let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized) const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant") if ( @@ -353,6 +359,7 @@ export namespace SessionPrompt { sessionID: input.sessionID, providerID: input.providerID, modelID: input.model.id, + signal: input.signal, }) const resumeMsgID = Identifier.ascending("message") const resumeMsg = {