mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 09:44:21 +01:00
Session management and prompt handling improvements (#2577)
Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
@@ -11,6 +11,7 @@ import { MessageV2 } from "../../session/message-v2"
|
||||
import { Identifier } from "../../id/id"
|
||||
import { Agent } from "../../agent/agent"
|
||||
import { Command } from "../../command"
|
||||
import { SessionPrompt } from "../../session/prompt"
|
||||
|
||||
const TOOL: Record<string, [string, string]> = {
|
||||
todowrite: ["Todo", UI.Style.TEXT_WARNING_BOLD],
|
||||
@@ -185,7 +186,7 @@ export const RunCommand = cmd({
|
||||
})
|
||||
|
||||
if (args.command) {
|
||||
await Session.command({
|
||||
await SessionPrompt.command({
|
||||
messageID: Identifier.ascending("message"),
|
||||
sessionID: session.id,
|
||||
agent: agent.name,
|
||||
@@ -197,7 +198,7 @@ export const RunCommand = cmd({
|
||||
}
|
||||
|
||||
const messageID = Identifier.ascending("message")
|
||||
const result = await Session.prompt({
|
||||
const result = await SessionPrompt.prompt({
|
||||
sessionID: session.id,
|
||||
messageID,
|
||||
model: {
|
||||
|
||||
@@ -25,6 +25,9 @@ import { Global } from "../global"
|
||||
import { ProjectRoute } from "./project"
|
||||
import { ToolRegistry } from "../tool/registry"
|
||||
import { zodToJsonSchema } from "zod-to-json-schema"
|
||||
import { SessionPrompt } from "../session/prompt"
|
||||
import { SessionCompaction } from "../session/compaction"
|
||||
import { SessionRevert } from "../session/revert"
|
||||
|
||||
const ERRORS = {
|
||||
400: {
|
||||
@@ -558,7 +561,7 @@ export namespace Server {
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
return c.json(Session.abort(c.req.valid("param").id))
|
||||
return c.json(SessionPrompt.abort(c.req.valid("param").id))
|
||||
},
|
||||
)
|
||||
.post(
|
||||
@@ -651,7 +654,7 @@ export namespace Server {
|
||||
async (c) => {
|
||||
const id = c.req.valid("param").id
|
||||
const body = c.req.valid("json")
|
||||
await Session.summarize({ ...body, sessionID: id })
|
||||
await SessionCompaction.run({ ...body, sessionID: id })
|
||||
return c.json(true)
|
||||
},
|
||||
)
|
||||
@@ -665,14 +668,7 @@ export namespace Server {
|
||||
description: "List of messages",
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: resolver(
|
||||
z
|
||||
.object({
|
||||
info: MessageV2.Info,
|
||||
parts: MessageV2.Part.array(),
|
||||
})
|
||||
.array(),
|
||||
),
|
||||
schema: resolver(MessageV2.WithParts.array()),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -750,11 +746,11 @@ export namespace Server {
|
||||
id: z.string().openapi({ description: "Session ID" }),
|
||||
}),
|
||||
),
|
||||
zValidator("json", Session.PromptInput.omit({ sessionID: true })),
|
||||
zValidator("json", SessionPrompt.PromptInput.omit({ sessionID: true })),
|
||||
async (c) => {
|
||||
const sessionID = c.req.valid("param").id
|
||||
const body = c.req.valid("json")
|
||||
const msg = await Session.prompt({ ...body, sessionID })
|
||||
const msg = await SessionPrompt.prompt({ ...body, sessionID })
|
||||
return c.json(msg)
|
||||
},
|
||||
)
|
||||
@@ -785,11 +781,11 @@ export namespace Server {
|
||||
id: z.string().openapi({ description: "Session ID" }),
|
||||
}),
|
||||
),
|
||||
zValidator("json", Session.CommandInput.omit({ sessionID: true })),
|
||||
zValidator("json", SessionPrompt.CommandInput.omit({ sessionID: true })),
|
||||
async (c) => {
|
||||
const sessionID = c.req.valid("param").id
|
||||
const body = c.req.valid("json")
|
||||
const msg = await Session.command({ ...body, sessionID })
|
||||
const msg = await SessionPrompt.command({ ...body, sessionID })
|
||||
return c.json(msg)
|
||||
},
|
||||
)
|
||||
@@ -815,11 +811,11 @@ export namespace Server {
|
||||
id: z.string().openapi({ description: "Session ID" }),
|
||||
}),
|
||||
),
|
||||
zValidator("json", Session.ShellInput.omit({ sessionID: true })),
|
||||
zValidator("json", SessionPrompt.ShellInput.omit({ sessionID: true })),
|
||||
async (c) => {
|
||||
const sessionID = c.req.valid("param").id
|
||||
const body = c.req.valid("json")
|
||||
const msg = await Session.shell({ ...body, sessionID })
|
||||
const msg = await SessionPrompt.shell({ ...body, sessionID })
|
||||
return c.json(msg)
|
||||
},
|
||||
)
|
||||
@@ -845,11 +841,11 @@ export namespace Server {
|
||||
id: z.string(),
|
||||
}),
|
||||
),
|
||||
zValidator("json", Session.RevertInput.omit({ sessionID: true })),
|
||||
zValidator("json", SessionRevert.RevertInput.omit({ sessionID: true })),
|
||||
async (c) => {
|
||||
const id = c.req.valid("param").id
|
||||
log.info("revert", c.req.valid("json"))
|
||||
const session = await Session.revert({ sessionID: id, ...c.req.valid("json") })
|
||||
const session = await SessionRevert.revert({ sessionID: id, ...c.req.valid("json") })
|
||||
return c.json(session)
|
||||
},
|
||||
)
|
||||
@@ -877,7 +873,7 @@ export namespace Server {
|
||||
),
|
||||
async (c) => {
|
||||
const id = c.req.valid("param").id
|
||||
const session = await Session.unrevert({ sessionID: id })
|
||||
const session = await SessionRevert.unrevert({ sessionID: id })
|
||||
return c.json(session)
|
||||
},
|
||||
)
|
||||
|
||||
120
packages/opencode/src/session/compaction.ts
Normal file
120
packages/opencode/src/session/compaction.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
import { generateText, type ModelMessage } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
import { Provider } from "../provider/provider"
|
||||
import { defer } from "../util/defer"
|
||||
import { MessageV2 } from "./message-v2"
|
||||
import { SystemPrompt } from "./system"
|
||||
import { Bus } from "../bus"
|
||||
import z from "zod"
|
||||
import type { ModelsDev } from "../provider/models"
|
||||
import { SessionPrompt } from "./prompt"
|
||||
|
||||
export namespace SessionCompaction {
|
||||
export const Event = {
|
||||
Compacted: Bus.event(
|
||||
"session.compacted",
|
||||
z.object({
|
||||
sessionID: z.string(),
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
|
||||
const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
|
||||
const output = Math.min(input.model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
|
||||
const usable = input.model.limit.context - output
|
||||
return count > usable / 2
|
||||
}
|
||||
|
||||
export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
|
||||
await Session.update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = Date.now()
|
||||
})
|
||||
await using _ = defer(async () => {
|
||||
await Session.update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = undefined
|
||||
})
|
||||
})
|
||||
const toSummarize = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
|
||||
const model = await Provider.getModel(input.providerID, input.modelID)
|
||||
const system = [
|
||||
...SystemPrompt.summarize(model.providerID),
|
||||
...(await SystemPrompt.environment()),
|
||||
...(await SystemPrompt.custom()),
|
||||
]
|
||||
|
||||
const msg = (await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
role: "assistant",
|
||||
sessionID: input.sessionID,
|
||||
system,
|
||||
mode: "build",
|
||||
path: {
|
||||
cwd: Instance.directory,
|
||||
root: Instance.worktree,
|
||||
},
|
||||
cost: 0,
|
||||
tokens: {
|
||||
output: 0,
|
||||
input: 0,
|
||||
reasoning: 0,
|
||||
cache: { read: 0, write: 0 },
|
||||
},
|
||||
modelID: input.modelID,
|
||||
providerID: model.providerID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
})) as MessageV2.Assistant
|
||||
const generated = await generateText({
|
||||
maxRetries: 10,
|
||||
model: model.language,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(toSummarize),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
msg.summary = true
|
||||
msg.time.completed = Date.now()
|
||||
await Session.updateMessage(msg)
|
||||
const part = await Session.updatePart({
|
||||
type: "text",
|
||||
sessionID: input.sessionID,
|
||||
messageID: msg.id,
|
||||
id: Identifier.ascending("part"),
|
||||
text: generated.text,
|
||||
time: {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
},
|
||||
})
|
||||
|
||||
Bus.publish(Event.Compacted, {
|
||||
sessionID: input.sessionID,
|
||||
})
|
||||
|
||||
return {
|
||||
info: msg,
|
||||
parts: [part],
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -331,6 +331,12 @@ export namespace MessageV2 {
|
||||
),
|
||||
}
|
||||
|
||||
export const WithParts = z.object({
|
||||
info: Info,
|
||||
parts: z.array(Part),
|
||||
})
|
||||
export type WithParts = z.infer<typeof WithParts>
|
||||
|
||||
export function fromV1(v1: Message.Info) {
|
||||
if (v1.role === "assistant") {
|
||||
const info: Assistant = {
|
||||
@@ -552,4 +558,10 @@ export namespace MessageV2 {
|
||||
|
||||
return convertToModelMessages(result)
|
||||
}
|
||||
|
||||
export function filterSummarized(msgs: { info: MessageV2.Info; parts: MessageV2.Part[] }[]) {
|
||||
const i = msgs.findLastIndex((m) => m.info.role === "assistant" && !!m.info.summary)
|
||||
if (i === -1) return msgs.slice()
|
||||
return msgs.slice(i)
|
||||
}
|
||||
}
|
||||
|
||||
1470
packages/opencode/src/session/prompt.ts
Normal file
1470
packages/opencode/src/session/prompt.ts
Normal file
File diff suppressed because it is too large
Load Diff
105
packages/opencode/src/session/revert.ts
Normal file
105
packages/opencode/src/session/revert.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import z from "zod"
|
||||
import { Identifier } from "../id/id"
|
||||
import { Snapshot } from "../snapshot"
|
||||
import { MessageV2 } from "./message-v2"
|
||||
import { Session } from "."
|
||||
import { Log } from "../util/log"
|
||||
import { splitWhen } from "remeda"
|
||||
import { Storage } from "../storage/storage"
|
||||
import { Bus } from "../bus"
|
||||
|
||||
export namespace SessionRevert {
|
||||
const log = Log.create({ service: "session.revert" })
|
||||
|
||||
export const RevertInput = z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
messageID: Identifier.schema("message"),
|
||||
partID: Identifier.schema("part").optional(),
|
||||
})
|
||||
export type RevertInput = z.infer<typeof RevertInput>
|
||||
|
||||
export async function revert(input: RevertInput) {
|
||||
const all = await Session.messages(input.sessionID)
|
||||
let lastUser: MessageV2.User | undefined
|
||||
const session = await Session.get(input.sessionID)
|
||||
|
||||
let revert: Session.Info["revert"]
|
||||
const patches: Snapshot.Patch[] = []
|
||||
for (const msg of all) {
|
||||
if (msg.info.role === "user") lastUser = msg.info
|
||||
const remaining = []
|
||||
for (const part of msg.parts) {
|
||||
if (revert) {
|
||||
if (part.type === "patch") {
|
||||
patches.push(part)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (!revert) {
|
||||
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
||||
// if no useful parts left in message, same as reverting whole message
|
||||
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
||||
revert = {
|
||||
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
|
||||
partID,
|
||||
}
|
||||
}
|
||||
remaining.push(part)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (revert) {
|
||||
const session = await Session.get(input.sessionID)
|
||||
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
|
||||
await Snapshot.revert(patches)
|
||||
if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
|
||||
return Session.update(input.sessionID, (draft) => {
|
||||
draft.revert = revert
|
||||
})
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
export async function unrevert(input: { sessionID: string }) {
|
||||
log.info("unreverting", input)
|
||||
const session = await Session.get(input.sessionID)
|
||||
if (!session.revert) return session
|
||||
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
|
||||
const next = await Session.update(input.sessionID, (draft) => {
|
||||
draft.revert = undefined
|
||||
})
|
||||
return next
|
||||
}
|
||||
|
||||
export async function cleanup(session: Session.Info) {
|
||||
if (!session.revert) return
|
||||
const sessionID = session.id
|
||||
let msgs = await Session.messages(sessionID)
|
||||
const messageID = session.revert.messageID
|
||||
const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
|
||||
msgs = preserve
|
||||
for (const msg of remove) {
|
||||
await Storage.remove(["message", sessionID, msg.info.id])
|
||||
await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
|
||||
}
|
||||
const last = preserve.at(-1)
|
||||
if (session.revert.partID && last) {
|
||||
const partID = session.revert.partID
|
||||
const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
|
||||
last.parts = preserveParts
|
||||
for (const part of removeParts) {
|
||||
await Storage.remove(["part", last.info.id, part.id])
|
||||
await Bus.publish(MessageV2.Event.PartRemoved, {
|
||||
sessionID: sessionID,
|
||||
messageID: last.info.id,
|
||||
partID: part.id,
|
||||
})
|
||||
}
|
||||
}
|
||||
await Session.update(sessionID, (draft) => {
|
||||
draft.revert = undefined
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import { Bus } from "../bus"
|
||||
import { MessageV2 } from "../session/message-v2"
|
||||
import { Identifier } from "../id/id"
|
||||
import { Agent } from "../agent/agent"
|
||||
import { SessionPrompt } from "../session/prompt"
|
||||
|
||||
export const TaskTool = Tool.define("task", async () => {
|
||||
const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary"))
|
||||
@@ -49,9 +50,9 @@ export const TaskTool = Tool.define("task", async () => {
|
||||
}
|
||||
|
||||
ctx.abort.addEventListener("abort", () => {
|
||||
Session.abort(session.id)
|
||||
SessionPrompt.abort(session.id)
|
||||
})
|
||||
const result = await Session.prompt({
|
||||
const result = await SessionPrompt.prompt({
|
||||
messageID,
|
||||
sessionID: session.id,
|
||||
model: {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { Session } from "../../src/session/index"
|
||||
import { SessionPrompt } from "../../src/session/prompt"
|
||||
|
||||
describe("fileRegex", () => {
|
||||
const template = `This is a @valid/path/to/a/file and it should also match at
|
||||
@@ -23,7 +23,7 @@ as well as @~/home-files and @~/paths/under/home.txt.
|
||||
|
||||
If the reference is \`@quoted/in/backticks\` then it shouldn't match at all.`
|
||||
|
||||
const matches = Array.from(template.matchAll(Session.fileRegex))
|
||||
const matches = Array.from(template.matchAll(SessionPrompt.fileRegex))
|
||||
|
||||
test("should extract exactly 12 file references", () => {
|
||||
expect(matches.length).toBe(12)
|
||||
@@ -79,13 +79,13 @@ If the reference is \`@quoted/in/backticks\` then it shouldn't match at all.`
|
||||
|
||||
test("should not match when preceded by backtick", () => {
|
||||
const backtickTest = "This `@should/not/match` should be ignored"
|
||||
const backtickMatches = Array.from(backtickTest.matchAll(Session.fileRegex))
|
||||
const backtickMatches = Array.from(backtickTest.matchAll(SessionPrompt.fileRegex))
|
||||
expect(backtickMatches.length).toBe(0)
|
||||
})
|
||||
|
||||
test("should not match email addresses", () => {
|
||||
const emailTest = "Contact user@example.com for help"
|
||||
const emailMatches = Array.from(emailTest.matchAll(Session.fileRegex))
|
||||
const emailMatches = Array.from(emailTest.matchAll(SessionPrompt.fileRegex))
|
||||
expect(emailMatches.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user