basic undo feature (#1268)

Co-authored-by: adamdotdevin <2363879+adamdottv@users.noreply.github.com>
Co-authored-by: Jay V <air@live.ca>
Co-authored-by: Aiden Cline <63023139+rekram1-node@users.noreply.github.com>
Co-authored-by: Andrew Joslin <andrew@ajoslin.com>
Co-authored-by: GitHub Action <action@github.com>
Co-authored-by: Tobias Walle <9933601+tobias-walle@users.noreply.github.com>
This commit is contained in:
Dax
2025-07-23 20:30:46 -04:00
committed by GitHub
parent 507c975e92
commit 96866e52ce
26 changed files with 768 additions and 127 deletions

View File

@@ -3,6 +3,7 @@ import { ConfigHooks } from "../config/hooks"
import { Format } from "../format"
import { LSP } from "../lsp"
import { Share } from "../share/share"
import { Snapshot } from "../snapshot"
export async function bootstrap<T>(input: App.Input, cb: (app: App.Info) => Promise<T>) {
return App.provide(input, async (app) => {
@@ -10,6 +11,7 @@ export async function bootstrap<T>(input: App.Input, cb: (app: App.Info) => Prom
Format.init()
ConfigHooks.init()
LSP.init()
Snapshot.init()
return cb(app)
})

View File

@@ -1,10 +1,12 @@
import { Session } from "../../../session"
import { Snapshot } from "../../../snapshot"
import { bootstrap } from "../../bootstrap"
import { cmd } from "../cmd"
export const SnapshotCommand = cmd({
command: "snapshot",
builder: (yargs) => yargs.command(CreateCommand).command(RestoreCommand).command(DiffCommand).demandCommand(),
builder: (yargs) =>
yargs.command(CreateCommand).command(RestoreCommand).command(DiffCommand).command(RevertCommand).demandCommand(),
async handler() {},
})
@@ -12,7 +14,7 @@ const CreateCommand = cmd({
command: "create",
async handler() {
await bootstrap({ cwd: process.cwd() }, async () => {
const result = await Snapshot.create("test")
const result = await Snapshot.create()
console.log(result)
})
},
@@ -28,7 +30,7 @@ const RestoreCommand = cmd({
}),
async handler(args) {
await bootstrap({ cwd: process.cwd() }, async () => {
await Snapshot.restore("test", args.commit)
await Snapshot.restore(args.commit)
console.log("restored")
})
},
@@ -45,8 +47,34 @@ export const DiffCommand = cmd({
}),
async handler(args) {
await bootstrap({ cwd: process.cwd() }, async () => {
const diff = await Snapshot.diff("test", args.commit)
const diff = await Snapshot.diff(args.commit)
console.log(diff)
})
},
})
export const RevertCommand = cmd({
command: "revert <sessionID> <messageID>",
describe: "revert",
builder: (yargs) =>
yargs
.positional("sessionID", {
type: "string",
description: "sessionID",
demandOption: true,
})
.positional("messageID", {
type: "string",
description: "messageID",
demandOption: true,
}),
async handler(args) {
await bootstrap({ cwd: process.cwd() }, async () => {
const session = await Session.revert({
sessionID: args.sessionID,
messageID: args.messageID,
})
console.log(session?.revert)
})
},
})

View File

@@ -26,6 +26,9 @@ export namespace Config {
if (result.autoshare === true && !result.share) {
result.share = "auto"
}
if (result.keybinds?.messages_revert && !result.keybinds.messages_undo) {
result.keybinds.messages_undo = result.keybinds.messages_revert
}
if (!result.username) {
const os = await import("os")
@@ -89,7 +92,7 @@ export namespace Config {
session_new: z.string().optional().default("<leader>n").describe("Create a new session"),
session_list: z.string().optional().default("<leader>l").describe("List all sessions"),
session_share: z.string().optional().default("<leader>s").describe("Share current session"),
session_unshare: z.string().optional().default("<leader>u").describe("Unshare current session"),
session_unshare: z.string().optional().default("none").describe("Unshare current session"),
session_interrupt: z.string().optional().default("esc").describe("Interrupt current session"),
session_compact: z.string().optional().default("<leader>c").describe("Compact the session"),
tool_details: z.string().optional().default("<leader>d").describe("Toggle tool details"),
@@ -118,7 +121,9 @@ export namespace Config {
messages_last: z.string().optional().default("ctrl+alt+g").describe("Navigate to last message"),
messages_layout_toggle: z.string().optional().default("<leader>p").describe("Toggle layout"),
messages_copy: z.string().optional().default("<leader>y").describe("Copy message"),
messages_revert: z.string().optional().default("<leader>r").describe("Revert message"),
messages_revert: z.string().optional().default("none").describe("@deprecated use messages_undo. Revert message"),
messages_undo: z.string().optional().default("<leader>u").describe("Undo message"),
messages_redo: z.string().optional().default("<leader>r").describe("Redo message"),
app_exit: z.string().optional().default("ctrl+c,<leader>q").describe("Exit the application"),
})
.strict()

View File

@@ -58,15 +58,20 @@ export namespace Server {
})
})
.use(async (c, next) => {
log.info("request", {
method: c.req.method,
path: c.req.path,
})
const skipLogging = c.req.path === "/log"
if (!skipLogging) {
log.info("request", {
method: c.req.method,
path: c.req.path,
})
}
const start = Date.now()
await next()
log.info("response", {
duration: Date.now() - start,
})
if (!skipLogging) {
log.info("response", {
duration: Date.now() - start,
})
}
})
.get(
"/doc",
@@ -461,6 +466,61 @@ export namespace Server {
return c.json(msg)
},
)
.post(
"/session/:id/revert",
describeRoute({
description: "Revert a message",
responses: {
200: {
description: "Updated session",
content: {
"application/json": {
schema: resolver(Session.Info),
},
},
},
},
}),
zValidator(
"param",
z.object({
id: z.string(),
}),
),
zValidator("json", Session.RevertInput.omit({ sessionID: true })),
async (c) => {
const id = c.req.valid("param").id
const session = await Session.revert({ sessionID: id, ...c.req.valid("json") })
return c.json(session)
},
)
.post(
"/session/:id/unrevert",
describeRoute({
description: "Restore all reverted messages",
responses: {
200: {
description: "Updated session",
content: {
"application/json": {
schema: resolver(Session.Info),
},
},
},
},
}),
zValidator(
"param",
z.object({
id: z.string(),
}),
),
async (c) => {
const id = c.req.valid("param").id
const session = await Session.unrevert({ sessionID: id })
return c.json(session)
},
)
.get(
"/config/providers",
describeRoute({

View File

@@ -40,6 +40,7 @@ import { MessageV2 } from "./message-v2"
import { Mode } from "./mode"
import { LSP } from "../lsp"
import { ReadTool } from "../tool/read"
import { splitWhen } from "remeda"
export namespace Session {
const log = Log.create({ service: "session" })
@@ -64,7 +65,7 @@ export namespace Session {
revert: z
.object({
messageID: z.string(),
part: z.number(),
partID: z.string().optional(),
snapshot: z.string().optional(),
})
.optional(),
@@ -246,7 +247,7 @@ export namespace Session {
const read = await Storage.readJSON<MessageV2.Info>(p)
result.push({
info: read,
parts: await parts(sessionID, read.id),
parts: await getParts(sessionID, read.id),
})
}
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
@@ -257,7 +258,7 @@ export namespace Session {
return Storage.readJSON<MessageV2.Info>("session/message/" + sessionID + "/" + messageID)
}
export async function parts(sessionID: string, messageID: string) {
export async function getParts(sessionID: string, messageID: string) {
const result = [] as MessageV2.Part[]
for (const item of await Storage.list("session/part/" + sessionID + "/" + messageID)) {
const read = await Storage.readJSON<MessageV2.Part>(item)
@@ -531,30 +532,26 @@ export namespace Session {
const session = await get(input.sessionID)
if (session.revert) {
const trimmed = []
for (const msg of msgs) {
if (
msg.info.id > session.revert.messageID ||
(msg.info.id === session.revert.messageID && session.revert.part === 0)
) {
await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
await Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: msg.info.id,
})
continue
}
if (msg.info.id === session.revert.messageID) {
if (session.revert.part === 0) break
msg.parts = msg.parts.slice(0, session.revert.part)
}
trimmed.push(msg)
const messageID = session.revert.messageID
const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
msgs = preserve
for (const msg of remove) {
await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
}
const last = preserve.at(-1)
if (session.revert.partID && last) {
const partID = session.revert.partID
const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
last.parts = preserveParts
for (const part of removeParts) {
await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
await Bus.publish(MessageV2.Event.PartRemoved, {
messageID: last.info.id,
partID: part.id,
})
}
}
msgs = trimmed
await update(input.sessionID, (draft) => {
draft.revert = undefined
})
}
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
@@ -831,7 +828,7 @@ export namespace Session {
})
switch (value.type) {
case "start":
const snapshot = await Snapshot.create(assistantMsg.sessionID)
const snapshot = await Snapshot.create()
if (snapshot)
await updatePart({
id: Identifier.ascending("part"),
@@ -895,7 +892,7 @@ export namespace Session {
},
})
delete toolCalls[value.toolCallId]
const snapshot = await Snapshot.create(assistantMsg.sessionID)
const snapshot = await Snapshot.create()
if (snapshot)
await updatePart({
id: Identifier.ascending("part"),
@@ -924,7 +921,7 @@ export namespace Session {
},
})
delete toolCalls[value.toolCallId]
const snapshot = await Snapshot.create(assistantMsg.sessionID)
const snapshot = await Snapshot.create()
if (snapshot)
await updatePart({
id: Identifier.ascending("part"),
@@ -1043,7 +1040,7 @@ export namespace Session {
error: assistantMsg.error,
})
}
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
const p = await getParts(assistantMsg.sessionID, assistantMsg.id)
for (const part of p) {
if (part.type === "tool" && part.state.status !== "completed") {
updatePart({
@@ -1067,47 +1064,53 @@ export namespace Session {
}
}
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
// TODO
/*
const message = await getMessage(input.sessionID, input.messageID)
if (!message) return
const part = message.parts[input.part]
if (!part) return
export const RevertInput = z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message"),
partID: Identifier.schema("part").optional(),
})
export type RevertInput = z.infer<typeof RevertInput>
export async function revert(input: RevertInput) {
const all = await messages(input.sessionID)
const session = await get(input.sessionID)
const snapshot =
session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
const old = (() => {
if (message.role === "assistant") {
const lastTool = message.parts.findLast(
(part, index) =>
part.type === "tool-invocation" && index < input.part,
)
if (lastTool && lastTool.type === "tool-invocation")
return message.metadata.tool[lastTool.toolInvocation.toolCallId]
.snapshot
let lastUser: MessageV2.User | undefined
let lastSnapshot: MessageV2.SnapshotPart | undefined
for (const msg of all) {
if (msg.info.role === "user") lastUser = msg.info
const remaining = []
for (const part of msg.parts) {
if (part.type === "snapshot") lastSnapshot = part
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
// if no useful parts left in message, same as reverting whole message
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
const snapshot = session.revert?.snapshot ?? (await Snapshot.create(true))
log.info("revert snapshot", { snapshot })
if (lastSnapshot) await Snapshot.restore(lastSnapshot.snapshot)
const next = await update(input.sessionID, (draft) => {
draft.revert = {
// if not part id jump to the last user message
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
partID,
snapshot,
}
})
return next
}
remaining.push(part)
}
return message.metadata.snapshot
})()
if (old) await Snapshot.restore(input.sessionID, old)
await update(input.sessionID, (draft) => {
draft.revert = {
messageID: input.messageID,
part: input.part,
snapshot,
}
})
*/
}
}
export async function unrevert(sessionID: string) {
const session = await get(sessionID)
if (!session) return
if (!session.revert) return
if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
update(sessionID, (draft) => {
export async function unrevert(input: { sessionID: string }) {
log.info("unreverting", input)
const session = await get(input.sessionID)
if (!session.revert) return session
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
const next = await update(input.sessionID, (draft) => {
draft.revert = undefined
})
return next
}
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {

View File

@@ -272,6 +272,13 @@ export namespace MessageV2 {
part: Part,
}),
),
PartRemoved: Bus.event(
"message.part.removed",
z.object({
messageID: z.string(),
partID: z.string(),
}),
),
}
export function fromV1(v1: Message.Info) {

View File

@@ -4,11 +4,26 @@ import path from "path"
import fs from "fs/promises"
import { Ripgrep } from "../file/ripgrep"
import { Log } from "../util/log"
import { Global } from "../global"
export namespace Snapshot {
const log = Log.create({ service: "snapshot" })
export async function create(sessionID: string) {
export function init() {
Array.fromAsync(
new Bun.Glob("**/snapshot").scan({
absolute: true,
onlyFiles: false,
cwd: Global.Path.data,
}),
).then((files) => {
for (const file of files) {
fs.rmdir(file, { recursive: true })
}
})
}
export async function create(force?: boolean) {
log.info("creating snapshot")
const app = App.info()
@@ -23,7 +38,7 @@ export namespace Snapshot {
if (files.length >= 1000) return
}
const git = gitdir(sessionID)
const git = gitdir()
if (await fs.mkdir(git, { recursive: true })) {
await $`git init`
.env({
@@ -40,7 +55,7 @@ export namespace Snapshot {
log.info("added files")
const result =
await $`git --git-dir ${git} commit -m "snapshot" --no-gpg-sign --author="opencode <mail@opencode.ai>"`
await $`git --git-dir ${git} commit ${force ? "--allow-empty" : ""} -m "snapshot" --no-gpg-sign --author="opencode <mail@opencode.ai>"`
.quiet()
.cwd(app.path.cwd)
.nothrow()
@@ -50,21 +65,22 @@ export namespace Snapshot {
return match![1]
}
export async function restore(sessionID: string, snapshot: string) {
export async function restore(snapshot: string) {
log.info("restore", { commit: snapshot })
const app = App.info()
const git = gitdir(sessionID)
await $`git --git-dir=${git} checkout ${snapshot} --force`.quiet().cwd(app.path.root)
const git = gitdir()
await $`git --git-dir=${git} reset --hard ${snapshot}`.quiet().cwd(app.path.root)
}
export async function diff(sessionID: string, commit: string) {
const git = gitdir(sessionID)
export async function diff(commit: string) {
const git = gitdir()
const result = await $`git --git-dir=${git} diff -R ${commit}`.quiet().cwd(App.info().path.root)
return result.stdout.toString("utf8")
const text = result.stdout.toString("utf8")
return text
}
function gitdir(sessionID: string) {
function gitdir() {
const app = App.info()
return path.join(app.path.data, "snapshot", sessionID)
return path.join(app.path.data, "snapshots")
}
}