diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index e5dbffac..71e894f8 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -34,6 +34,7 @@ import type { ModelsDev } from "../provider/models" import { Installation } from "../installation" import { Config } from "../config/config" import { ProviderTransform } from "../provider/transform" +import { Snapshot } from "../snapshot" export namespace Session { const log = Log.create({ service: "session" }) @@ -53,6 +54,13 @@ export namespace Session { created: z.number(), updated: z.number(), }), + revert: z + .object({ + messageID: z.string(), + part: z.number(), + snapshot: z.string().optional(), + }) + .optional(), }) .openapi({ ref: "Session", @@ -285,6 +293,37 @@ export namespace Session { l.info("chatting") const model = await Provider.getModel(input.providerID, input.modelID) let msgs = await messages(input.sessionID) + const session = await get(input.sessionID) + + if (session.revert) { + const trimmed = [] + for (const msg of msgs) { + if ( + msg.id > session.revert.messageID || + (msg.id === session.revert.messageID && session.revert.part === 0) + ) { + await Storage.remove( + "session/message/" + input.sessionID + "/" + msg.id, + ) + await Bus.publish(Message.Event.Removed, { + sessionID: input.sessionID, + messageID: msg.id, + }) + continue + } + + if (msg.id === session.revert.messageID) { + if (session.revert.part === 0) break + msg.parts = msg.parts.slice(0, session.revert.part) + } + trimmed.push(msg) + } + msgs = trimmed + await update(input.sessionID, (draft) => { + draft.revert = undefined + }) + } + const previous = msgs.at(-1) // auto summarize if too long @@ -319,7 +358,6 @@ export namespace Session { if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id) const app = App.info() - const session = await get(input.sessionID) if (msgs.length === 0 && !session.parentID) { generateText({ maxTokens: input.providerID === "google" ? 1024 : 20, @@ -349,6 +387,7 @@ export namespace Session { }) .catch(() => {}) } + const snapshot = await Snapshot.create(input.sessionID) const msg: Message.Info = { role: "user", id: Identifier.ascending("message"), @@ -359,6 +398,7 @@ export namespace Session { }, sessionID: input.sessionID, tool: {}, + snapshot, }, } await updateMessage(msg) @@ -373,6 +413,7 @@ export namespace Session { role: "assistant", parts: [], metadata: { + snapshot, assistant: { system, path: { @@ -424,6 +465,7 @@ export namespace Session { }) next.metadata!.tool![opts.toolCallId] = { ...result.metadata, + snapshot: await Snapshot.create(input.sessionID), time: { start, end: Date.now(), @@ -436,6 +478,7 @@ export namespace Session { error: true, message: e.toString(), title: e.toString(), + snapshot: await Snapshot.create(input.sessionID), time: { start, end: Date.now(), @@ -457,6 +500,7 @@ export namespace Session { const result = await execute(args, opts) next.metadata!.tool![opts.toolCallId] = { ...result.metadata, + snapshot: await Snapshot.create(input.sessionID), time: { start, end: Date.now(), @@ -471,6 +515,7 @@ export namespace Session { next.metadata!.tool![opts.toolCallId] = { error: true, message: e.toString(), + snapshot: await Snapshot.create(input.sessionID), title: "mcp", time: { start, @@ -735,6 +780,51 @@ export namespace Session { return next } + export async function revert(input: { + sessionID: string + messageID: string + part: number + }) { + const message = await getMessage(input.sessionID, input.messageID) + if (!message) return + const part = message.parts[input.part] + if (!part) return + const session = await get(input.sessionID) + const snapshot = + session.revert?.snapshot ?? (await Snapshot.create(input.sessionID)) + const old = (() => { + if (message.role === "assistant") { + const lastTool = message.parts.findLast( + (part, index) => + part.type === "tool-invocation" && index < input.part, + ) + if (lastTool && lastTool.type === "tool-invocation") + return message.metadata.tool[lastTool.toolInvocation.toolCallId] + .snapshot + } + return message.metadata.snapshot + })() + if (old) await Snapshot.restore(input.sessionID, old) + await update(input.sessionID, (draft) => { + draft.revert = { + messageID: input.messageID, + part: input.part, + snapshot, + } + }) + } + + export async function unrevert(sessionID: string) { + const session = await get(sessionID) + if (!session) return + if (!session.revert) return + if (session.revert.snapshot) + await Snapshot.restore(sessionID, session.revert.snapshot) + update(sessionID, (draft) => { + draft.revert = undefined + }) + } + export async function summarize(input: { sessionID: string providerID: string diff --git a/packages/opencode/src/session/message.ts b/packages/opencode/src/session/message.ts index b2171fa4..2d319e87 100644 --- a/packages/opencode/src/session/message.ts +++ b/packages/opencode/src/session/message.ts @@ -159,6 +159,7 @@ export namespace Message { z .object({ title: z.string(), + snapshot: z.string().optional(), time: z.object({ start: z.number(), end: z.number(), @@ -188,11 +189,7 @@ export namespace Message { }), }) .optional(), - user: z - .object({ - snapshot: z.string().optional(), - }) - .optional(), + snapshot: z.string().optional(), }) .openapi({ ref: "MessageMetadata" }), }) @@ -208,6 +205,13 @@ export namespace Message { info: Info, }), ), + Removed: Bus.event( + "message.removed", + z.object({ + sessionID: z.string(), + messageID: z.string(), + }), + ), PartUpdated: Bus.event( "message.part.updated", z.object({