fix compaction issues

This commit is contained in:
Dax Raad
2025-09-12 05:59:53 -04:00
parent 3c502861a7
commit 983e3b2ee3
2 changed files with 185 additions and 115 deletions

View File

@@ -53,6 +53,7 @@ import { defer } from "../util/defer"
import { Command } from "../command"
import { $ } from "bun"
import { ListTool } from "../tool/ls"
import { Token } from "../util/token"
export namespace Session {
const log = Log.create({ service: "session" })
@@ -83,6 +84,12 @@ export namespace Session {
.optional(),
title: z.string(),
version: z.string(),
compaction: z
.object({
full: z.string().optional(),
micro: z.string().optional(),
})
.optional(),
time: z.object({
created: z.number(),
updated: z.number(),
@@ -361,6 +368,7 @@ export namespace Session {
Bus.publish(MessageV2.Event.Updated, {
info: msg,
})
return msg
}
async function updatePart(part: MessageV2.Part) {
@@ -717,14 +725,34 @@ export namespace Session {
}
return Provider.defaultModel()
})().then((x) => Provider.getModel(x.providerID, x.modelID))
let msgs = await messages(input.sessionID)
const lastSummary = Math.max(
0,
msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
)
msgs = msgs.slice(lastSummary)
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
if (
lastAssistant?.info.role === "assistant" &&
needsCompaction({
tokens: lastAssistant.info.tokens,
model: model.info,
})
) {
const msg = await summarize({
sessionID: input.sessionID,
providerID: model.providerID,
modelID: model.info.id,
})
msgs = [msg]
}
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
using abort = lock(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
const numRealUserMsgs = msgs.filter(
(m) => m.info.role === "user" && !m.parts.every((p) => "synthetic" in p && p.synthetic),
).length
@@ -819,39 +847,21 @@ export namespace Session {
const [first, ...rest] = system
system = [first, rest.join("\n")]
const assistantMsg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
system,
mode: inputAgent,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
const processor = await createProcessor({
sessionID: input.sessionID,
}
await updateMessage(assistantMsg)
model: model.info,
providerID: model.providerID,
agent: inputAgent,
system,
})
await using _ = defer(async () => {
if (assistantMsg.time.completed) return
await Storage.remove(["session", "message", input.sessionID, assistantMsg.id])
await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: assistantMsg.id })
if (processor.message.time.completed) return
await Storage.remove(["session", "message", input.sessionID, processor.message.id])
await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: processor.message.id })
})
const tools: Record<string, AITool> = {}
const processor = createProcessor(assistantMsg, model.info)
const enabledTools = pipe(
agent.tools,
mergeDeep(await ToolRegistry.enabled(model.providerID, model.modelID, agent)),
@@ -878,7 +888,7 @@ export namespace Session {
const result = await item.execute(args, {
sessionID: input.sessionID,
abort: options.abortSignal!,
messageID: assistantMsg.id,
messageID: processor.message.id,
callID: options.toolCallId,
agent: agent.name,
metadata: async (val) => {
@@ -982,6 +992,8 @@ export namespace Session {
},
},
)
let pointer = 0
const stream = streamText({
onError(e) {
log.error("streamText error", {
@@ -989,39 +1001,32 @@ export namespace Session {
})
},
async prepareStep({ messages, steps }) {
// Auto compact if too long
const tokens = (() => {
if (steps.length) {
const previous = steps.at(-1)
if (previous) return getUsage(model.info, previous.usage, previous.providerMetadata).tokens
}
const msg = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant
if (msg && msg.tokens) {
return msg.tokens
}
})()
if (tokens) {
log.info("compact check", tokens)
const count = tokens.input + tokens.cache.read + tokens.cache.write + tokens.output
if (model.info.limit.context && count > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
log.info("compacting in prepareStep")
const summarized = await summarize({
sessionID: input.sessionID,
providerID: model.providerID,
modelID: model.info.id,
})
const msgs = await Session.messages(input.sessionID).then((x) =>
x.filter((x) => x.info.id >= summarized.id),
)
return {
messages: MessageV2.toModelMessage(msgs),
}
}
log.info("search", {
length: messages.length,
})
const step = steps.at(-1)
if (
step &&
needsCompaction({
tokens: getUsage(model.info, step.usage, step.providerMetadata).tokens,
model: model.info,
})
) {
await processor.end()
const msg = await Session.summarize({
sessionID: input.sessionID,
providerID: model.providerID,
modelID: model.info.id,
})
await processor.next()
pointer = messages.length - 1
messages.push(...MessageV2.toModelMessage([msg]))
}
// Add queued messages to the stream
const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
if (queue.length) {
await processor.end()
for (const item of queue) {
if (item.processed) continue
messages.push(
@@ -1034,35 +1039,10 @@ export namespace Session {
)
item.processed = true
}
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
Object.assign(assistantMsg, {
id: Identifier.ascending("message"),
role: "assistant",
system,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
providerID: model.providerID,
mode: inputAgent,
time: {
created: Date.now(),
},
sessionID: input.sessionID,
})
await updateMessage(assistantMsg)
await processor.next()
}
return {
messages,
messages: messages.slice(pointer),
}
},
async experimental_repairToolCall(input) {
@@ -1421,11 +1401,60 @@ export namespace Session {
})
}
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
async function createProcessor(input: {
sessionID: string
providerID: string
model: ModelsDev.Model
system: string[]
agent: string
}) {
const toolcalls: Record<string, MessageV2.ToolPart> = {}
let snapshot: string | undefined
let shouldStop = false
return {
async function createMessage() {
const msg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
system: input.system,
mode: input.agent,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.model.id,
providerID: input.providerID,
time: {
created: Date.now(),
},
sessionID: input.sessionID,
}
await updateMessage(msg)
return msg
}
let assistantMsg = await createMessage()
const result = {
async end() {
if (assistantMsg) {
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
}
},
async next() {
assistantMsg = await createMessage()
},
get message() {
return assistantMsg
},
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
@@ -1581,7 +1610,7 @@ export namespace Session {
break
case "finish-step":
const usage = getUsage(model, value.usage, value.providerMetadata)
const usage = getUsage(input.model, value.usage, value.providerMetadata)
assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens
await updatePart({
@@ -1672,7 +1701,7 @@ export namespace Session {
case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new MessageV2.AuthError(
{
providerID: model.id,
providerID: input.providerID,
message: e.message,
},
{ cause: e },
@@ -1711,6 +1740,7 @@ export namespace Session {
return { info: assistantMsg, parts: p }
},
}
return result
}
export const RevertInput = z.object({
@@ -1789,9 +1819,8 @@ export namespace Session {
0,
msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
)
const split = start + Math.floor((msgs.length - start) / 2)
log.info("summarizing", { start, split })
const toSummarize = msgs.slice(start, split)
log.info("summarizing", { start })
const toSummarize = msgs.slice(start)
const model = await Provider.getModel(input.providerID, input.modelID)
const system = [
...SystemPrompt.summarize(model.providerID),
@@ -1799,6 +1828,29 @@ export namespace Session {
...(await SystemPrompt.custom()),
]
const msg = (await updateMessage({
id: Identifier.ascending("message"),
role: "assistant",
sessionID: input.sessionID,
system,
mode: "build",
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
output: 0,
input: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
})) as MessageV2.Assistant
const generated = await generateText({
maxRetries: 10,
model: model.language,
@@ -1822,28 +1874,12 @@ export namespace Session {
],
})
const usage = getUsage(model.info, generated.usage, generated.providerMetadata)
const msg: MessageV2.Info = {
id: Identifier.create("message", false, toSummarize.at(-1)!.info.time.created + 1),
role: "assistant",
sessionID: input.sessionID,
system,
mode: "build",
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
summary: true,
cost: usage.cost,
tokens: usage.tokens,
modelID: input.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
completed: Date.now(),
},
}
msg.cost += usage.cost
msg.tokens = usage.tokens
msg.summary = true
msg.time.completed = Date.now()
await updateMessage(msg)
await updatePart({
const part = await updatePart({
type: "text",
sessionID: input.sessionID,
messageID: msg.id,
@@ -1859,7 +1895,34 @@ export namespace Session {
sessionID: input.sessionID,
})
return msg
return {
info: msg,
parts: [part],
}
}
function needsCompaction(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
const output = Math.min(input.model.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
const usable = input.model.limit.context - output
return count > usable / 2
}
export async function microcompact(input: { sessionID: string }) {
const msgs = await messages(input.sessionID)
let sum = 0
for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
const msg = msgs[msgIndex]
for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
const part = msg.parts[partIndex]
if (part.type === "tool")
if (part.state.status === "completed") {
sum += Token.estimate(part.state.output)
if (sum > 40_000) {
}
}
}
}
}
function isLocked(sessionID: string) {

View File

@@ -0,0 +1,7 @@
export namespace Token {
const CHARS_PER_TOKEN = 4
export function estimate(input: string) {
return Math.max(0, Math.round((input || "").length / CHARS_PER_TOKEN))
}
}