Session management and prompt handling improvements (#2577)

Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
Dax
2025-09-13 05:46:14 -04:00
committed by GitHub
parent 535230dce4
commit 9bb25a9260
9 changed files with 1755 additions and 1682 deletions

View File

@@ -11,6 +11,7 @@ import { MessageV2 } from "../../session/message-v2"
import { Identifier } from "../../id/id" import { Identifier } from "../../id/id"
import { Agent } from "../../agent/agent" import { Agent } from "../../agent/agent"
import { Command } from "../../command" import { Command } from "../../command"
import { SessionPrompt } from "../../session/prompt"
const TOOL: Record<string, [string, string]> = { const TOOL: Record<string, [string, string]> = {
todowrite: ["Todo", UI.Style.TEXT_WARNING_BOLD], todowrite: ["Todo", UI.Style.TEXT_WARNING_BOLD],
@@ -185,7 +186,7 @@ export const RunCommand = cmd({
}) })
if (args.command) { if (args.command) {
await Session.command({ await SessionPrompt.command({
messageID: Identifier.ascending("message"), messageID: Identifier.ascending("message"),
sessionID: session.id, sessionID: session.id,
agent: agent.name, agent: agent.name,
@@ -197,7 +198,7 @@ export const RunCommand = cmd({
} }
const messageID = Identifier.ascending("message") const messageID = Identifier.ascending("message")
const result = await Session.prompt({ const result = await SessionPrompt.prompt({
sessionID: session.id, sessionID: session.id,
messageID, messageID,
model: { model: {

View File

@@ -25,6 +25,9 @@ import { Global } from "../global"
import { ProjectRoute } from "./project" import { ProjectRoute } from "./project"
import { ToolRegistry } from "../tool/registry" import { ToolRegistry } from "../tool/registry"
import { zodToJsonSchema } from "zod-to-json-schema" import { zodToJsonSchema } from "zod-to-json-schema"
import { SessionPrompt } from "../session/prompt"
import { SessionCompaction } from "../session/compaction"
import { SessionRevert } from "../session/revert"
const ERRORS = { const ERRORS = {
400: { 400: {
@@ -558,7 +561,7 @@ export namespace Server {
}), }),
), ),
async (c) => { async (c) => {
return c.json(Session.abort(c.req.valid("param").id)) return c.json(SessionPrompt.abort(c.req.valid("param").id))
}, },
) )
.post( .post(
@@ -651,7 +654,7 @@ export namespace Server {
async (c) => { async (c) => {
const id = c.req.valid("param").id const id = c.req.valid("param").id
const body = c.req.valid("json") const body = c.req.valid("json")
await Session.summarize({ ...body, sessionID: id }) await SessionCompaction.run({ ...body, sessionID: id })
return c.json(true) return c.json(true)
}, },
) )
@@ -665,14 +668,7 @@ export namespace Server {
description: "List of messages", description: "List of messages",
content: { content: {
"application/json": { "application/json": {
schema: resolver( schema: resolver(MessageV2.WithParts.array()),
z
.object({
info: MessageV2.Info,
parts: MessageV2.Part.array(),
})
.array(),
),
}, },
}, },
}, },
@@ -750,11 +746,11 @@ export namespace Server {
id: z.string().openapi({ description: "Session ID" }), id: z.string().openapi({ description: "Session ID" }),
}), }),
), ),
zValidator("json", Session.PromptInput.omit({ sessionID: true })), zValidator("json", SessionPrompt.PromptInput.omit({ sessionID: true })),
async (c) => { async (c) => {
const sessionID = c.req.valid("param").id const sessionID = c.req.valid("param").id
const body = c.req.valid("json") const body = c.req.valid("json")
const msg = await Session.prompt({ ...body, sessionID }) const msg = await SessionPrompt.prompt({ ...body, sessionID })
return c.json(msg) return c.json(msg)
}, },
) )
@@ -785,11 +781,11 @@ export namespace Server {
id: z.string().openapi({ description: "Session ID" }), id: z.string().openapi({ description: "Session ID" }),
}), }),
), ),
zValidator("json", Session.CommandInput.omit({ sessionID: true })), zValidator("json", SessionPrompt.CommandInput.omit({ sessionID: true })),
async (c) => { async (c) => {
const sessionID = c.req.valid("param").id const sessionID = c.req.valid("param").id
const body = c.req.valid("json") const body = c.req.valid("json")
const msg = await Session.command({ ...body, sessionID }) const msg = await SessionPrompt.command({ ...body, sessionID })
return c.json(msg) return c.json(msg)
}, },
) )
@@ -815,11 +811,11 @@ export namespace Server {
id: z.string().openapi({ description: "Session ID" }), id: z.string().openapi({ description: "Session ID" }),
}), }),
), ),
zValidator("json", Session.ShellInput.omit({ sessionID: true })), zValidator("json", SessionPrompt.ShellInput.omit({ sessionID: true })),
async (c) => { async (c) => {
const sessionID = c.req.valid("param").id const sessionID = c.req.valid("param").id
const body = c.req.valid("json") const body = c.req.valid("json")
const msg = await Session.shell({ ...body, sessionID }) const msg = await SessionPrompt.shell({ ...body, sessionID })
return c.json(msg) return c.json(msg)
}, },
) )
@@ -845,11 +841,11 @@ export namespace Server {
id: z.string(), id: z.string(),
}), }),
), ),
zValidator("json", Session.RevertInput.omit({ sessionID: true })), zValidator("json", SessionRevert.RevertInput.omit({ sessionID: true })),
async (c) => { async (c) => {
const id = c.req.valid("param").id const id = c.req.valid("param").id
log.info("revert", c.req.valid("json")) log.info("revert", c.req.valid("json"))
const session = await Session.revert({ sessionID: id, ...c.req.valid("json") }) const session = await SessionRevert.revert({ sessionID: id, ...c.req.valid("json") })
return c.json(session) return c.json(session)
}, },
) )
@@ -877,7 +873,7 @@ export namespace Server {
), ),
async (c) => { async (c) => {
const id = c.req.valid("param").id const id = c.req.valid("param").id
const session = await Session.unrevert({ sessionID: id }) const session = await SessionRevert.unrevert({ sessionID: id })
return c.json(session) return c.json(session)
}, },
) )

View File

@@ -0,0 +1,120 @@
import { generateText, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
import { Provider } from "../provider/provider"
import { defer } from "../util/defer"
import { MessageV2 } from "./message-v2"
import { SystemPrompt } from "./system"
import { Bus } from "../bus"
import z from "zod"
import type { ModelsDev } from "../provider/models"
import { SessionPrompt } from "./prompt"
/**
 * Conversation compaction: asks the model for a summary of the session so far
 * and stores that summary as a new assistant message flagged `summary = true`.
 */
export namespace SessionCompaction {
  /** Bus events emitted by this namespace. */
  export const Event = {
    /** Published by {@link run} once the summary message and its text part are stored. */
    Compacted: Bus.event("session.compacted", z.object({ sessionID: z.string() })),
  }

  /**
   * Reports whether the token usage of an assistant message exceeds half of the
   * model's usable context.
   *
   * "Usable" is the model's context limit minus the output reservation, where the
   * reservation is the model's output limit capped at `SessionPrompt.OUTPUT_TOKEN_MAX`
   * (and falls back to that cap when the capped value is falsy, e.g. 0).
   *
   * @param input.tokens token counters of the assistant message being checked
   * @param input.model model metadata supplying the `limit.context` / `limit.output` values
   * @returns `true` when input + cache-read + output tokens are greater than half the usable context
   */
  export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
    const { tokens, model } = input
    const used = tokens.input + tokens.cache.read + tokens.output
    const reserved = Math.min(model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
    const usable = model.limit.context - reserved
    return used > usable / 2
  }

  /**
   * Summarizes a session with the given model and records the result.
   *
   * While running, `time.compacting` is set on the session; the deferred callback
   * resets it to `undefined` when this function's scope is disposed.
   *
   * @param input.sessionID session to compact
   * @param input.providerID provider used to resolve the model
   * @param input.modelID model used to generate the summary
   * @returns the stored summary message (`info`) together with its single text part (`parts`)
   */
  export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
    const { sessionID } = input

    // Mark the session as compacting for the lifetime of this call.
    await Session.update(sessionID, (draft) => {
      draft.time.compacting = Date.now()
    })
    await using _ = defer(async () => {
      await Session.update(sessionID, (draft) => {
        draft.time.compacting = undefined
      })
    })

    const history = MessageV2.filterSummarized(await Session.messages(sessionID))
    const model = await Provider.getModel(input.providerID, input.modelID)

    const summarizePrompt = SystemPrompt.summarize(model.providerID)
    const environmentPrompt = await SystemPrompt.environment()
    const customPrompt = await SystemPrompt.custom()
    const system = [...summarizePrompt, ...environmentPrompt, ...customPrompt]

    // Create the assistant message up front with zeroed accounting; it is completed below.
    const emptyTokens = {
      output: 0,
      input: 0,
      reasoning: 0,
      cache: { read: 0, write: 0 },
    }
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID,
      system,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: emptyTokens,
      modelID: input.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant

    const systemMessages = system.map((content): ModelMessage => ({ role: "system", content }))
    const summaryRequest: ModelMessage = {
      role: "user",
      content: [
        {
          type: "text",
          text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
        },
      ],
    }
    const generated = await generateText({
      maxRetries: 10,
      model: model.language,
      messages: [...systemMessages, ...MessageV2.toModelMessage(history), summaryRequest],
    })

    // Fill in cost/usage, flag the message as a summary and persist it.
    const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
    msg.cost += usage.cost
    msg.tokens = usage.tokens
    msg.summary = true
    msg.time.completed = Date.now()
    await Session.updateMessage(msg)

    const part = await Session.updatePart({
      type: "text",
      sessionID,
      messageID: msg.id,
      id: Identifier.ascending("part"),
      text: generated.text,
      time: {
        start: Date.now(),
        end: Date.now(),
      },
    })

    Bus.publish(Event.Compacted, { sessionID })

    return { info: msg, parts: [part] }
  }
}

File diff suppressed because it is too large Load Diff

View File

@@ -331,6 +331,12 @@ export namespace MessageV2 {
), ),
} }
/** Schema for a message's `info` record paired with the list of its parts. */
export const WithParts = z.object({
  info: Info,
  parts: Part.array(),
})
/** Static type inferred from the `WithParts` schema. */
export type WithParts = z.infer<typeof WithParts>
export function fromV1(v1: Message.Info) { export function fromV1(v1: Message.Info) {
if (v1.role === "assistant") { if (v1.role === "assistant") {
const info: Assistant = { const info: Assistant = {
@@ -552,4 +558,10 @@ export namespace MessageV2 {
return convertToModelMessages(result) return convertToModelMessages(result)
} }
/**
 * Returns the tail of `msgs` starting at the most recent assistant message whose
 * `summary` flag is truthy (that message included). When no such message exists,
 * a shallow copy of the whole list is returned. The input array is not modified.
 */
export function filterSummarized(msgs: { info: MessageV2.Info; parts: MessageV2.Part[] }[]) {
  let start = 0
  // Walk backwards so the first hit is the latest summary message.
  for (let idx = msgs.length - 1; idx >= 0; idx--) {
    const info = msgs[idx].info
    if (info.role === "assistant" && !!info.summary) {
      start = idx
      break
    }
  }
  return msgs.slice(start)
}
} }

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
import z from "zod"
import { Identifier } from "../id/id"
import { Snapshot } from "../snapshot"
import { MessageV2 } from "./message-v2"
import { Session } from "."
import { Log } from "../util/log"
import { splitWhen } from "remeda"
import { Storage } from "../storage/storage"
import { Bus } from "../bus"
/**
 * Reverting a session to an earlier message or part, undoing a revert, and
 * permanently discarding the reverted messages/parts.
 */
export namespace SessionRevert {
  const log = Log.create({ service: "session.revert" })

  /**
   * Input for {@link revert}: the session, the message to revert to and, optionally,
   * a specific part inside that message.
   */
  export const RevertInput = z.object({
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message"),
    partID: Identifier.schema("part").optional(),
  })
  export type RevertInput = z.infer<typeof RevertInput>

  /**
   * Marks the session as reverted at the requested message/part.
   *
   * Walks all messages in order until the revert point is found; every `patch` part
   * after that point is collected and handed to `Snapshot.revert`. The resulting
   * marker (`messageID`, optional `partID`, `snapshot`, `diff`) is stored on the session.
   *
   * @param input see {@link RevertInput}
   * @returns the updated session, or the unchanged session when the requested
   *   message/part was not found
   */
  export async function revert(input: RevertInput) {
    const all = await Session.messages(input.sessionID)
    const session = await Session.get(input.sessionID)

    let lastUser: MessageV2.User | undefined
    let revert: Session.Info["revert"]
    const patches: Snapshot.Patch[] = []

    for (const msg of all) {
      if (msg.info.role === "user") lastUser = msg.info
      // Parts of the current message seen so far (up to and including the revert point).
      const remaining: MessageV2.Part[] = []
      for (const part of msg.parts) {
        if (revert) {
          // Past the revert point: only patch parts matter, they get undone below.
          if (part.type === "patch") patches.push(part)
          continue
        }
        if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
          // if no useful parts left in message, same as reverting whole message
          const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
          revert = {
            // A whole-message revert is anchored at the preceding user message when there is one.
            messageID: !partID && lastUser ? lastUser.id : msg.info.id,
            partID,
          }
        }
        remaining.push(part)
      }
    }

    if (!revert) return session

    // Reuse the snapshot of an already active revert; otherwise take a new one.
    // `session` is reused from above: no await separates that read from this point.
    revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
    await Snapshot.revert(patches)
    if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
    return Session.update(input.sessionID, (draft) => {
      draft.revert = revert
    })
  }

  /**
   * Undoes an active revert: restores the stored snapshot (when there is one) and
   * clears the revert marker.
   *
   * @param input.sessionID session to unrevert
   * @returns the session — unchanged when it has no active revert
   */
  export async function unrevert(input: { sessionID: string }) {
    log.info("unreverting", input)
    const session = await Session.get(input.sessionID)
    if (!session.revert) return session
    if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
    return Session.update(input.sessionID, (draft) => {
      draft.revert = undefined
    })
  }

  /**
   * Permanently removes everything covered by the session's active revert, then
   * clears the revert marker. Does nothing when the session has no active revert.
   *
   * - Whole-message revert (no `partID`): the message with `revert.messageID` and all
   *   later messages are removed.
   * - Part-level revert (`partID` set): the message with `revert.messageID` is kept;
   *   only its parts from `partID` onward are removed, together with all later messages.
   *
   * A `Removed` / `PartRemoved` event is published for each deleted message / part.
   *
   * @param session session whose reverted content should be discarded
   */
  export async function cleanup(session: Session.Info) {
    if (!session.revert) return
    const sessionID = session.id
    const { messageID, partID } = session.revert
    const msgs = await Session.messages(sessionID)

    // splitWhen puts the matching message at the head of the second list.
    const [, fromTarget] = splitWhen(msgs, (x) => x.info.id === messageID)

    // For a part-level revert the target message itself must survive, because
    // `revert()` records the id of the message that *contains* `partID`.
    const target = partID ? fromTarget.at(0) : undefined
    const remove = target ? fromTarget.slice(1) : fromTarget

    for (const msg of remove) {
      await Storage.remove(["message", sessionID, msg.info.id])
      await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
    }

    if (target && partID) {
      const [, removeParts] = splitWhen(target.parts, (x) => x.id === partID)
      for (const part of removeParts) {
        await Storage.remove(["part", target.info.id, part.id])
        await Bus.publish(MessageV2.Event.PartRemoved, {
          sessionID: sessionID,
          messageID: target.info.id,
          partID: part.id,
        })
      }
    }

    await Session.update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
}

View File

@@ -6,6 +6,7 @@ import { Bus } from "../bus"
import { MessageV2 } from "../session/message-v2" import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id" import { Identifier } from "../id/id"
import { Agent } from "../agent/agent" import { Agent } from "../agent/agent"
import { SessionPrompt } from "../session/prompt"
export const TaskTool = Tool.define("task", async () => { export const TaskTool = Tool.define("task", async () => {
const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary")) const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary"))
@@ -49,9 +50,9 @@ export const TaskTool = Tool.define("task", async () => {
} }
ctx.abort.addEventListener("abort", () => { ctx.abort.addEventListener("abort", () => {
Session.abort(session.id) SessionPrompt.abort(session.id)
}) })
const result = await Session.prompt({ const result = await SessionPrompt.prompt({
messageID, messageID,
sessionID: session.id, sessionID: session.id,
model: { model: {

View File

@@ -1,5 +1,5 @@
import { describe, expect, test } from "bun:test" import { describe, expect, test } from "bun:test"
import { Session } from "../../src/session/index" import { SessionPrompt } from "../../src/session/prompt"
describe("fileRegex", () => { describe("fileRegex", () => {
const template = `This is a @valid/path/to/a/file and it should also match at const template = `This is a @valid/path/to/a/file and it should also match at
@@ -23,7 +23,7 @@ as well as @~/home-files and @~/paths/under/home.txt.
If the reference is \`@quoted/in/backticks\` then it shouldn't match at all.` If the reference is \`@quoted/in/backticks\` then it shouldn't match at all.`
const matches = Array.from(template.matchAll(Session.fileRegex)) const matches = Array.from(template.matchAll(SessionPrompt.fileRegex))
test("should extract exactly 12 file references", () => { test("should extract exactly 12 file references", () => {
expect(matches.length).toBe(12) expect(matches.length).toBe(12)
@@ -79,13 +79,13 @@ If the reference is \`@quoted/in/backticks\` then it shouldn't match at all.`
test("should not match when preceded by backtick", () => { test("should not match when preceded by backtick", () => {
const backtickTest = "This `@should/not/match` should be ignored" const backtickTest = "This `@should/not/match` should be ignored"
const backtickMatches = Array.from(backtickTest.matchAll(Session.fileRegex)) const backtickMatches = Array.from(backtickTest.matchAll(SessionPrompt.fileRegex))
expect(backtickMatches.length).toBe(0) expect(backtickMatches.length).toBe(0)
}) })
test("should not match email addresses", () => { test("should not match email addresses", () => {
const emailTest = "Contact user@example.com for help" const emailTest = "Contact user@example.com for help"
const emailMatches = Array.from(emailTest.matchAll(Session.fileRegex)) const emailMatches = Array.from(emailTest.matchAll(SessionPrompt.fileRegex))
expect(emailMatches.length).toBe(0) expect(emailMatches.length).toBe(0)
}) })
}) })