fix compaction issues

This commit is contained in:
Dax Raad
2025-09-12 05:59:53 -04:00
parent 3c502861a7
commit 983e3b2ee3
2 changed files with 185 additions and 115 deletions

View File

@@ -53,6 +53,7 @@ import { defer } from "../util/defer"
import { Command } from "../command" import { Command } from "../command"
import { $ } from "bun" import { $ } from "bun"
import { ListTool } from "../tool/ls" import { ListTool } from "../tool/ls"
import { Token } from "../util/token"
export namespace Session { export namespace Session {
const log = Log.create({ service: "session" }) const log = Log.create({ service: "session" })
@@ -83,6 +84,12 @@ export namespace Session {
.optional(), .optional(),
title: z.string(), title: z.string(),
version: z.string(), version: z.string(),
compaction: z
.object({
full: z.string().optional(),
micro: z.string().optional(),
})
.optional(),
time: z.object({ time: z.object({
created: z.number(), created: z.number(),
updated: z.number(), updated: z.number(),
@@ -361,6 +368,7 @@ export namespace Session {
Bus.publish(MessageV2.Event.Updated, { Bus.publish(MessageV2.Event.Updated, {
info: msg, info: msg,
}) })
return msg
} }
async function updatePart(part: MessageV2.Part) { async function updatePart(part: MessageV2.Part) {
@@ -717,14 +725,34 @@ export namespace Session {
} }
return Provider.defaultModel() return Provider.defaultModel()
})().then((x) => Provider.getModel(x.providerID, x.modelID)) })().then((x) => Provider.getModel(x.providerID, x.modelID))
let msgs = await messages(input.sessionID) let msgs = await messages(input.sessionID)
const lastSummary = Math.max(
0,
msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
)
msgs = msgs.slice(lastSummary)
const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
if (
lastAssistant?.info.role === "assistant" &&
needsCompaction({
tokens: lastAssistant.info.tokens,
model: model.info,
})
) {
const msg = await summarize({
sessionID: input.sessionID,
providerID: model.providerID,
modelID: model.info.id,
})
msgs = [msg]
}
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
using abort = lock(input.sessionID) using abort = lock(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
const numRealUserMsgs = msgs.filter( const numRealUserMsgs = msgs.filter(
(m) => m.info.role === "user" && !m.parts.every((p) => "synthetic" in p && p.synthetic), (m) => m.info.role === "user" && !m.parts.every((p) => "synthetic" in p && p.synthetic),
).length ).length
@@ -819,39 +847,21 @@ export namespace Session {
const [first, ...rest] = system const [first, ...rest] = system
system = [first, rest.join("\n")] system = [first, rest.join("\n")]
const assistantMsg: MessageV2.Info = { const processor = await createProcessor({
id: Identifier.ascending("message"),
role: "assistant",
system,
mode: inputAgent,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
sessionID: input.sessionID, sessionID: input.sessionID,
} model: model.info,
await updateMessage(assistantMsg) providerID: model.providerID,
agent: inputAgent,
system,
})
await using _ = defer(async () => { await using _ = defer(async () => {
if (assistantMsg.time.completed) return if (processor.message.time.completed) return
await Storage.remove(["session", "message", input.sessionID, assistantMsg.id]) await Storage.remove(["session", "message", input.sessionID, processor.message.id])
await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: assistantMsg.id }) await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: processor.message.id })
}) })
const tools: Record<string, AITool> = {} const tools: Record<string, AITool> = {}
const processor = createProcessor(assistantMsg, model.info)
const enabledTools = pipe( const enabledTools = pipe(
agent.tools, agent.tools,
mergeDeep(await ToolRegistry.enabled(model.providerID, model.modelID, agent)), mergeDeep(await ToolRegistry.enabled(model.providerID, model.modelID, agent)),
@@ -878,7 +888,7 @@ export namespace Session {
const result = await item.execute(args, { const result = await item.execute(args, {
sessionID: input.sessionID, sessionID: input.sessionID,
abort: options.abortSignal!, abort: options.abortSignal!,
messageID: assistantMsg.id, messageID: processor.message.id,
callID: options.toolCallId, callID: options.toolCallId,
agent: agent.name, agent: agent.name,
metadata: async (val) => { metadata: async (val) => {
@@ -982,6 +992,8 @@ export namespace Session {
}, },
}, },
) )
let pointer = 0
const stream = streamText({ const stream = streamText({
onError(e) { onError(e) {
log.error("streamText error", { log.error("streamText error", {
@@ -989,39 +1001,32 @@ export namespace Session {
}) })
}, },
async prepareStep({ messages, steps }) { async prepareStep({ messages, steps }) {
// Auto compact if too long log.info("search", {
const tokens = (() => { length: messages.length,
if (steps.length) { })
const previous = steps.at(-1) const step = steps.at(-1)
if (previous) return getUsage(model.info, previous.usage, previous.providerMetadata).tokens if (
} step &&
const msg = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant needsCompaction({
if (msg && msg.tokens) { tokens: getUsage(model.info, step.usage, step.providerMetadata).tokens,
return msg.tokens model: model.info,
} })
})() ) {
if (tokens) { await processor.end()
log.info("compact check", tokens) const msg = await Session.summarize({
const count = tokens.input + tokens.cache.read + tokens.cache.write + tokens.output sessionID: input.sessionID,
if (model.info.limit.context && count > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) { providerID: model.providerID,
log.info("compacting in prepareStep") modelID: model.info.id,
const summarized = await summarize({ })
sessionID: input.sessionID, await processor.next()
providerID: model.providerID, pointer = messages.length - 1
modelID: model.info.id, messages.push(...MessageV2.toModelMessage([msg]))
})
const msgs = await Session.messages(input.sessionID).then((x) =>
x.filter((x) => x.info.id >= summarized.id),
)
return {
messages: MessageV2.toModelMessage(msgs),
}
}
} }
// Add queued messages to the stream // Add queued messages to the stream
const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed) const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
if (queue.length) { if (queue.length) {
await processor.end()
for (const item of queue) { for (const item of queue) {
if (item.processed) continue if (item.processed) continue
messages.push( messages.push(
@@ -1034,35 +1039,10 @@ export namespace Session {
) )
item.processed = true item.processed = true
} }
assistantMsg.time.completed = Date.now() await processor.next()
await updateMessage(assistantMsg)
Object.assign(assistantMsg, {
id: Identifier.ascending("message"),
role: "assistant",
system,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
providerID: model.providerID,
mode: inputAgent,
time: {
created: Date.now(),
},
sessionID: input.sessionID,
})
await updateMessage(assistantMsg)
} }
return { return {
messages, messages: messages.slice(pointer),
} }
}, },
async experimental_repairToolCall(input) { async experimental_repairToolCall(input) {
@@ -1421,11 +1401,60 @@ export namespace Session {
}) })
} }
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) { async function createProcessor(input: {
sessionID: string
providerID: string
model: ModelsDev.Model
system: string[]
agent: string
}) {
const toolcalls: Record<string, MessageV2.ToolPart> = {} const toolcalls: Record<string, MessageV2.ToolPart> = {}
let snapshot: string | undefined let snapshot: string | undefined
let shouldStop = false let shouldStop = false
return {
async function createMessage() {
const msg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
system: input.system,
mode: input.agent,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
input: 0,
output: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.model.id,
providerID: input.providerID,
time: {
created: Date.now(),
},
sessionID: input.sessionID,
}
await updateMessage(msg)
return msg
}
let assistantMsg = await createMessage()
const result = {
async end() {
if (assistantMsg) {
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
}
},
async next() {
assistantMsg = await createMessage()
},
get message() {
return assistantMsg
},
partFromToolCall(toolCallID: string) { partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID] return toolcalls[toolCallID]
}, },
@@ -1581,7 +1610,7 @@ export namespace Session {
break break
case "finish-step": case "finish-step":
const usage = getUsage(model, value.usage, value.providerMetadata) const usage = getUsage(input.model, value.usage, value.providerMetadata)
assistantMsg.cost += usage.cost assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens assistantMsg.tokens = usage.tokens
await updatePart({ await updatePart({
@@ -1672,7 +1701,7 @@ export namespace Session {
case LoadAPIKeyError.isInstance(e): case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new MessageV2.AuthError( assistantMsg.error = new MessageV2.AuthError(
{ {
providerID: model.id, providerID: input.providerID,
message: e.message, message: e.message,
}, },
{ cause: e }, { cause: e },
@@ -1711,6 +1740,7 @@ export namespace Session {
return { info: assistantMsg, parts: p } return { info: assistantMsg, parts: p }
}, },
} }
return result
} }
export const RevertInput = z.object({ export const RevertInput = z.object({
@@ -1789,9 +1819,8 @@ export namespace Session {
0, 0,
msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true), msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
) )
const split = start + Math.floor((msgs.length - start) / 2) log.info("summarizing", { start })
log.info("summarizing", { start, split }) const toSummarize = msgs.slice(start)
const toSummarize = msgs.slice(start, split)
const model = await Provider.getModel(input.providerID, input.modelID) const model = await Provider.getModel(input.providerID, input.modelID)
const system = [ const system = [
...SystemPrompt.summarize(model.providerID), ...SystemPrompt.summarize(model.providerID),
@@ -1799,6 +1828,29 @@ export namespace Session {
...(await SystemPrompt.custom()), ...(await SystemPrompt.custom()),
] ]
const msg = (await updateMessage({
id: Identifier.ascending("message"),
role: "assistant",
sessionID: input.sessionID,
system,
mode: "build",
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
cost: 0,
tokens: {
output: 0,
input: 0,
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
})) as MessageV2.Assistant
const generated = await generateText({ const generated = await generateText({
maxRetries: 10, maxRetries: 10,
model: model.language, model: model.language,
@@ -1822,28 +1874,12 @@ export namespace Session {
], ],
}) })
const usage = getUsage(model.info, generated.usage, generated.providerMetadata) const usage = getUsage(model.info, generated.usage, generated.providerMetadata)
const msg: MessageV2.Info = { msg.cost += usage.cost
id: Identifier.create("message", false, toSummarize.at(-1)!.info.time.created + 1), msg.tokens = usage.tokens
role: "assistant", msg.summary = true
sessionID: input.sessionID, msg.time.completed = Date.now()
system,
mode: "build",
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
summary: true,
cost: usage.cost,
tokens: usage.tokens,
modelID: input.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
completed: Date.now(),
},
}
await updateMessage(msg) await updateMessage(msg)
await updatePart({ const part = await updatePart({
type: "text", type: "text",
sessionID: input.sessionID, sessionID: input.sessionID,
messageID: msg.id, messageID: msg.id,
@@ -1859,7 +1895,34 @@ export namespace Session {
sessionID: input.sessionID, sessionID: input.sessionID,
}) })
return msg return {
info: msg,
parts: [part],
}
}
// Returns true when the session's token usage has consumed more than half of
// the usable context window, i.e. it is time to summarize/compact the session.
function needsCompaction(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
// Tokens consumed so far. NOTE(review): cache.write is not included here,
// unlike other usage sums in this module — confirm that is intentional.
const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
// Room reserved for the model's response, capped at OUTPUT_TOKEN_MAX; the
// `|| OUTPUT_TOKEN_MAX` fallback covers a zero/undefined model output limit.
const output = Math.min(input.model.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
// Context budget left for the conversation after reserving output space.
const usable = input.model.limit.context - output
// Compact once more than half the usable budget is spent.
return count > usable / 2
}
export async function microcompact(input: { sessionID: string }) {
const msgs = await messages(input.sessionID)
let sum = 0
for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
const msg = msgs[msgIndex]
for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
const part = msg.parts[partIndex]
if (part.type === "tool")
if (part.state.status === "completed") {
sum += Token.estimate(part.state.output)
if (sum > 40_000) {
}
}
}
}
} }
function isLocked(sessionID: string) { function isLocked(sessionID: string) {

View File

@@ -0,0 +1,7 @@
export namespace Token {
  // Rough heuristic: assume one token for every four characters of text.
  const CHARS_PER_TOKEN = 4

  /**
   * Estimate how many tokens a string will occupy.
   *
   * Defensive at runtime: a null/undefined input is treated as the empty
   * string, and the result is clamped so it is never negative.
   */
  export function estimate(input: string) {
    const text = input || ""
    const approximate = Math.round(text.length / CHARS_PER_TOKEN)
    return Math.max(0, approximate)
  }
}