mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 09:44:21 +01:00
make compact interruptable (#3251)
This commit is contained in:
@@ -59,9 +59,7 @@ const ERRORS = {
|
|||||||
description: "Not found",
|
description: "Not found",
|
||||||
content: {
|
content: {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
schema: resolver(
|
schema: resolver(Storage.NotFoundError.Schema),
|
||||||
Storage.NotFoundError.Schema
|
|
||||||
)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -87,12 +85,9 @@ export namespace Server {
|
|||||||
})
|
})
|
||||||
if (err instanceof NamedError) {
|
if (err instanceof NamedError) {
|
||||||
let status: ContentfulStatusCode
|
let status: ContentfulStatusCode
|
||||||
if (err instanceof Storage.NotFoundError)
|
if (err instanceof Storage.NotFoundError) status = 404
|
||||||
status = 404
|
else if (err instanceof Provider.ModelNotFoundError) status = 400
|
||||||
else if (err instanceof Provider.ModelNotFoundError)
|
else status = 500
|
||||||
status = 400
|
|
||||||
else
|
|
||||||
status = 500
|
|
||||||
return c.json(err.toObject(), { status })
|
return c.json(err.toObject(), { status })
|
||||||
}
|
}
|
||||||
const message = err instanceof Error && err.stack ? err.stack : err.toString()
|
const message = err instanceof Error && err.stack ? err.stack : err.toString()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { streamText, type ModelMessage } from "ai"
|
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
||||||
import { Session } from "."
|
import { Session } from "."
|
||||||
import { Identifier } from "../id/id"
|
import { Identifier } from "../id/id"
|
||||||
import { Instance } from "../project/instance"
|
import { Instance } from "../project/instance"
|
||||||
@@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt"
|
|||||||
import { Flag } from "../flag/flag"
|
import { Flag } from "../flag/flag"
|
||||||
import { Token } from "../util/token"
|
import { Token } from "../util/token"
|
||||||
import { Log } from "../util/log"
|
import { Log } from "../util/log"
|
||||||
|
import { SessionLock } from "./lock"
|
||||||
|
import { NamedError } from "../util/error"
|
||||||
|
|
||||||
export namespace SessionCompaction {
|
export namespace SessionCompaction {
|
||||||
const log = Log.create({ service: "session.compaction" })
|
const log = Log.create({ service: "session.compaction" })
|
||||||
@@ -82,7 +84,11 @@ export namespace SessionCompaction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
|
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
|
||||||
|
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
|
||||||
|
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
|
||||||
|
const signal = input.signal ?? lock!.signal
|
||||||
|
|
||||||
await Session.update(input.sessionID, (draft) => {
|
await Session.update(input.sessionID, (draft) => {
|
||||||
draft.time.compacting = Date.now()
|
draft.time.compacting = Date.now()
|
||||||
})
|
})
|
||||||
@@ -122,6 +128,7 @@ export namespace SessionCompaction {
|
|||||||
created: Date.now(),
|
created: Date.now(),
|
||||||
},
|
},
|
||||||
})) as MessageV2.Assistant
|
})) as MessageV2.Assistant
|
||||||
|
|
||||||
const part = (await Session.updatePart({
|
const part = (await Session.updatePart({
|
||||||
type: "text",
|
type: "text",
|
||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
@@ -133,13 +140,18 @@ export namespace SessionCompaction {
|
|||||||
},
|
},
|
||||||
})) as MessageV2.TextPart
|
})) as MessageV2.TextPart
|
||||||
|
|
||||||
let summaryText = ""
|
|
||||||
const stream = streamText({
|
const stream = streamText({
|
||||||
maxRetries: 10,
|
maxRetries: 10,
|
||||||
model: model.language,
|
model: model.language,
|
||||||
providerOptions: {
|
providerOptions: {
|
||||||
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
|
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
|
||||||
},
|
},
|
||||||
|
abortSignal: signal,
|
||||||
|
onError(error) {
|
||||||
|
log.error("stream error", {
|
||||||
|
error,
|
||||||
|
})
|
||||||
|
},
|
||||||
messages: [
|
messages: [
|
||||||
...system.map(
|
...system.map(
|
||||||
(x): ModelMessage => ({
|
(x): ModelMessage => ({
|
||||||
@@ -160,38 +172,88 @@ export namespace SessionCompaction {
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
for await (const value of stream.fullStream) {
|
try {
|
||||||
switch (value.type) {
|
for await (const value of stream.fullStream) {
|
||||||
case "text-delta":
|
signal.throwIfAborted()
|
||||||
summaryText += value.text
|
switch (value.type) {
|
||||||
await Session.updatePart({
|
case "text-delta":
|
||||||
...part,
|
part.text += value.text
|
||||||
text: summaryText,
|
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||||
})
|
if (part.text) await Session.updatePart(part)
|
||||||
break
|
continue
|
||||||
case "text-end":
|
case "text-end": {
|
||||||
part.text = summaryText
|
part.text = part.text.trimEnd()
|
||||||
await Session.updatePart({
|
part.time = {
|
||||||
...part,
|
start: Date.now(),
|
||||||
})
|
end: Date.now(),
|
||||||
break
|
}
|
||||||
case "finish": {
|
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||||
const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined })
|
await Session.updatePart(part)
|
||||||
msg.cost += usage.cost
|
continue
|
||||||
msg.tokens = usage.tokens
|
}
|
||||||
msg.summary = true
|
case "finish-step": {
|
||||||
msg.time.completed = Date.now()
|
const usage = Session.getUsage({
|
||||||
await Session.updateMessage(msg)
|
model: model.info,
|
||||||
part.time!.end = Date.now()
|
usage: value.usage,
|
||||||
await Session.updatePart(part)
|
metadata: value.providerMetadata,
|
||||||
break
|
})
|
||||||
|
msg.cost += usage.cost
|
||||||
|
msg.tokens = usage.tokens
|
||||||
|
await Session.updateMessage(msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case "error":
|
||||||
|
throw value.error
|
||||||
|
default:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (e) {
|
||||||
|
log.error("compaction error", {
|
||||||
|
error: e,
|
||||||
|
})
|
||||||
|
switch (true) {
|
||||||
|
case e instanceof DOMException && e.name === "AbortError":
|
||||||
|
msg.error = new MessageV2.AbortedError(
|
||||||
|
{ message: e.message },
|
||||||
|
{
|
||||||
|
cause: e,
|
||||||
|
},
|
||||||
|
).toObject()
|
||||||
|
break
|
||||||
|
case MessageV2.OutputLengthError.isInstance(e):
|
||||||
|
msg.error = e
|
||||||
|
break
|
||||||
|
case LoadAPIKeyError.isInstance(e):
|
||||||
|
msg.error = new MessageV2.AuthError(
|
||||||
|
{
|
||||||
|
providerID: model.providerID,
|
||||||
|
message: e.message,
|
||||||
|
},
|
||||||
|
{ cause: e },
|
||||||
|
).toObject()
|
||||||
|
break
|
||||||
|
case e instanceof Error:
|
||||||
|
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||||
|
}
|
||||||
|
Bus.publish(Session.Event.Error, {
|
||||||
|
sessionID: input.sessionID,
|
||||||
|
error: msg.error,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
Bus.publish(Event.Compacted, {
|
msg.time.completed = Date.now()
|
||||||
sessionID: input.sessionID,
|
|
||||||
})
|
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
||||||
|
msg.summary = true
|
||||||
|
Bus.publish(Event.Compacted, {
|
||||||
|
sessionID: input.sessionID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
await Session.updateMessage(msg)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
info: msg,
|
info: msg,
|
||||||
|
|||||||
@@ -211,6 +211,7 @@ export namespace SessionPrompt {
|
|||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
model: model.info,
|
model: model.info,
|
||||||
providerID: model.providerID,
|
providerID: model.providerID,
|
||||||
|
signal: abort.signal,
|
||||||
}),
|
}),
|
||||||
(messages) => insertReminders({ messages, agent }),
|
(messages) => insertReminders({ messages, agent }),
|
||||||
)
|
)
|
||||||
@@ -339,7 +340,12 @@ export namespace SessionPrompt {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) {
|
async function getMessages(input: {
|
||||||
|
sessionID: string
|
||||||
|
model: ModelsDev.Model
|
||||||
|
providerID: string
|
||||||
|
signal: AbortSignal
|
||||||
|
}) {
|
||||||
let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
|
let msgs = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
|
||||||
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
|
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
|
||||||
if (
|
if (
|
||||||
@@ -353,6 +359,7 @@ export namespace SessionPrompt {
|
|||||||
sessionID: input.sessionID,
|
sessionID: input.sessionID,
|
||||||
providerID: input.providerID,
|
providerID: input.providerID,
|
||||||
modelID: input.model.id,
|
modelID: input.model.id,
|
||||||
|
signal: input.signal,
|
||||||
})
|
})
|
||||||
const resumeMsgID = Identifier.ascending("message")
|
const resumeMsgID = Identifier.ascending("message")
|
||||||
const resumeMsg = {
|
const resumeMsg = {
|
||||||
|
|||||||
Reference in New Issue
Block a user