mirror of
https://github.com/aljazceru/opencode.git
synced 2026-01-10 03:14:55 +01:00
wip: more snapshot stuff
This commit is contained in:
@@ -5,76 +5,30 @@ import { cmd } from "../cmd"
|
||||
|
||||
export const SnapshotCommand = cmd({
|
||||
command: "snapshot",
|
||||
builder: (yargs) =>
|
||||
yargs.command(CreateCommand).command(RestoreCommand).command(DiffCommand).command(RevertCommand).demandCommand(),
|
||||
builder: (yargs) => yargs.command(TrackCommand).command(PatchCommand).demandCommand(),
|
||||
async handler() {},
|
||||
})
|
||||
|
||||
const CreateCommand = cmd({
|
||||
command: "create",
|
||||
const TrackCommand = cmd({
|
||||
command: "track",
|
||||
async handler() {
|
||||
await bootstrap({ cwd: process.cwd() }, async () => {
|
||||
const result = await Snapshot.create()
|
||||
console.log(result)
|
||||
console.log(await Snapshot.track())
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
const RestoreCommand = cmd({
|
||||
command: "restore <commit>",
|
||||
const PatchCommand = cmd({
|
||||
command: "patch <hash>",
|
||||
builder: (yargs) =>
|
||||
yargs.positional("commit", {
|
||||
yargs.positional("hash", {
|
||||
type: "string",
|
||||
description: "commit",
|
||||
description: "hash",
|
||||
demandOption: true,
|
||||
}),
|
||||
async handler(args) {
|
||||
await bootstrap({ cwd: process.cwd() }, async () => {
|
||||
await Snapshot.restore(args.commit)
|
||||
console.log("restored")
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
export const DiffCommand = cmd({
|
||||
command: "diff <commit>",
|
||||
describe: "diff",
|
||||
builder: (yargs) =>
|
||||
yargs.positional("commit", {
|
||||
type: "string",
|
||||
description: "commit",
|
||||
demandOption: true,
|
||||
}),
|
||||
async handler(args) {
|
||||
await bootstrap({ cwd: process.cwd() }, async () => {
|
||||
const diff = await Snapshot.diff(args.commit)
|
||||
console.log(diff)
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
export const RevertCommand = cmd({
|
||||
command: "revert <sessionID> <messageID>",
|
||||
describe: "revert",
|
||||
builder: (yargs) =>
|
||||
yargs
|
||||
.positional("sessionID", {
|
||||
type: "string",
|
||||
description: "sessionID",
|
||||
demandOption: true,
|
||||
})
|
||||
.positional("messageID", {
|
||||
type: "string",
|
||||
description: "messageID",
|
||||
demandOption: true,
|
||||
}),
|
||||
async handler(args) {
|
||||
await bootstrap({ cwd: process.cwd() }, async () => {
|
||||
const session = await Session.revert({
|
||||
sessionID: args.sessionID,
|
||||
messageID: args.messageID,
|
||||
})
|
||||
console.log(session?.revert)
|
||||
console.log(await Snapshot.patch(args.hash))
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
@@ -661,6 +661,7 @@ export namespace Session {
|
||||
description: item.description,
|
||||
inputSchema: item.parameters as ZodSchema,
|
||||
async execute(args, options) {
|
||||
await processor.track(options.toolCallId)
|
||||
const result = await item.execute(args, {
|
||||
sessionID: input.sessionID,
|
||||
abort: abort.signal,
|
||||
@@ -699,6 +700,7 @@ export namespace Session {
|
||||
const execute = item.execute
|
||||
if (!execute) continue
|
||||
item.execute = async (args, opts) => {
|
||||
await processor.track(opts.toolCallId)
|
||||
const result = await execute(args, opts)
|
||||
const output = result.content
|
||||
.filter((x: any) => x.type === "text")
|
||||
@@ -814,7 +816,12 @@ export namespace Session {
|
||||
|
||||
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
|
||||
const toolCalls: Record<string, MessageV2.ToolPart> = {}
|
||||
const snapshots: Record<string, string> = {}
|
||||
return {
|
||||
async track(toolCallID: string) {
|
||||
const hash = await Snapshot.track()
|
||||
if (hash) snapshots[toolCallID] = hash
|
||||
},
|
||||
partFromToolCall(toolCallID: string) {
|
||||
return toolCalls[toolCallID]
|
||||
},
|
||||
@@ -828,15 +835,6 @@ export namespace Session {
|
||||
})
|
||||
switch (value.type) {
|
||||
case "start":
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "snapshot",
|
||||
snapshot,
|
||||
})
|
||||
break
|
||||
|
||||
case "tool-input-start":
|
||||
@@ -857,6 +855,9 @@ export namespace Session {
|
||||
case "tool-input-delta":
|
||||
break
|
||||
|
||||
case "tool-input-end":
|
||||
break
|
||||
|
||||
case "tool-call": {
|
||||
const match = toolCalls[value.toolCallId]
|
||||
if (match) {
|
||||
@@ -892,15 +893,20 @@ export namespace Session {
|
||||
},
|
||||
})
|
||||
delete toolCalls[value.toolCallId]
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "snapshot",
|
||||
snapshot,
|
||||
})
|
||||
const snapshot = snapshots[value.toolCallId]
|
||||
if (snapshot) {
|
||||
const patch = await Snapshot.patch(snapshot)
|
||||
if (patch.files.length) {
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "patch",
|
||||
hash: patch.hash,
|
||||
files: patch.files,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -921,15 +927,18 @@ export namespace Session {
|
||||
},
|
||||
})
|
||||
delete toolCalls[value.toolCallId]
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
const snapshot = snapshots[value.toolCallId]
|
||||
if (snapshot) {
|
||||
const patch = await Snapshot.patch(snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "snapshot",
|
||||
snapshot,
|
||||
type: "patch",
|
||||
hash: patch.hash,
|
||||
files: patch.files,
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -1073,33 +1082,45 @@ export namespace Session {
|
||||
|
||||
export async function revert(input: RevertInput) {
|
||||
const all = await messages(input.sessionID)
|
||||
const session = await get(input.sessionID)
|
||||
let lastUser: MessageV2.User | undefined
|
||||
let lastSnapshot: MessageV2.SnapshotPart | undefined
|
||||
const session = await get(input.sessionID)
|
||||
|
||||
let revert: Info["revert"]
|
||||
const patches: Snapshot.Patch[] = []
|
||||
for (const msg of all) {
|
||||
if (msg.info.role === "user") lastUser = msg.info
|
||||
const remaining = []
|
||||
for (const part of msg.parts) {
|
||||
if (part.type === "snapshot") lastSnapshot = part
|
||||
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
||||
// if no useful parts left in message, same as reverting whole message
|
||||
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
||||
const snapshot = session.revert?.snapshot ?? (await Snapshot.create())
|
||||
log.info("revert snapshot", { snapshot })
|
||||
if (lastSnapshot) await Snapshot.restore(lastSnapshot.snapshot)
|
||||
const next = await update(input.sessionID, (draft) => {
|
||||
draft.revert = {
|
||||
// if not part id jump to the last user message
|
||||
if (revert) {
|
||||
if (part.type === "patch") {
|
||||
patches.push(part)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (!revert) {
|
||||
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
||||
// if no useful parts left in message, same as reverting whole message
|
||||
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
||||
revert = {
|
||||
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
|
||||
partID,
|
||||
snapshot,
|
||||
}
|
||||
})
|
||||
return next
|
||||
}
|
||||
remaining.push(part)
|
||||
}
|
||||
remaining.push(part)
|
||||
}
|
||||
}
|
||||
|
||||
if (revert) {
|
||||
const session = await get(input.sessionID)
|
||||
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
|
||||
await Snapshot.revert(patches)
|
||||
return update(input.sessionID, (draft) => {
|
||||
draft.revert = revert
|
||||
})
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
export async function unrevert(input: { sessionID: string }) {
|
||||
|
||||
@@ -94,6 +94,15 @@ export namespace MessageV2 {
|
||||
})
|
||||
export type SnapshotPart = z.infer<typeof SnapshotPart>
|
||||
|
||||
export const PatchPart = PartBase.extend({
|
||||
type: z.literal("patch"),
|
||||
hash: z.string(),
|
||||
files: z.string().array(),
|
||||
}).openapi({
|
||||
ref: "PatchPart",
|
||||
})
|
||||
export type PatchPart = z.infer<typeof PatchPart>
|
||||
|
||||
export const TextPart = PartBase.extend({
|
||||
type: z.literal("text"),
|
||||
text: z.string(),
|
||||
@@ -203,7 +212,7 @@ export namespace MessageV2 {
|
||||
export type User = z.infer<typeof User>
|
||||
|
||||
export const Part = z
|
||||
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart])
|
||||
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart, PatchPart])
|
||||
.openapi({
|
||||
ref: "Part",
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@ import { Ripgrep } from "../file/ripgrep"
|
||||
import { Log } from "../util/log"
|
||||
import { Global } from "../global"
|
||||
import { Installation } from "../installation"
|
||||
import { z } from "zod"
|
||||
|
||||
export namespace Snapshot {
|
||||
const log = Log.create({ service: "snapshot" })
|
||||
@@ -24,21 +25,9 @@ export namespace Snapshot {
|
||||
})
|
||||
}
|
||||
|
||||
export async function create() {
|
||||
log.info("creating snapshot")
|
||||
export async function track() {
|
||||
const app = App.info()
|
||||
|
||||
// not a git repo, check if too big to snapshot
|
||||
if (!app.git || !Installation.isDev()) {
|
||||
return
|
||||
const files = await Ripgrep.files({
|
||||
cwd: app.path.cwd,
|
||||
limit: 1000,
|
||||
})
|
||||
log.info("found files", { count: files.length })
|
||||
if (files.length >= 1000) return
|
||||
}
|
||||
|
||||
if (!app.git) return
|
||||
const git = gitdir()
|
||||
if (await fs.mkdir(git, { recursive: true })) {
|
||||
await $`git init`
|
||||
@@ -51,33 +40,52 @@ export namespace Snapshot {
|
||||
.nothrow()
|
||||
log.info("initialized")
|
||||
}
|
||||
|
||||
await $`git --git-dir ${git} add .`.quiet().cwd(app.path.cwd).nothrow()
|
||||
log.info("added files")
|
||||
const hash = await $`git --git-dir ${git} write-tree`.quiet().cwd(app.path.cwd).text()
|
||||
return hash.trim()
|
||||
}
|
||||
|
||||
const result =
|
||||
await $`git --git-dir ${git} commit --allow-empty -m "snapshot" --no-gpg-sign --author="opencode <mail@opencode.ai>"`
|
||||
.quiet()
|
||||
.cwd(app.path.cwd)
|
||||
.nothrow()
|
||||
export const Patch = z.object({
|
||||
hash: z.string(),
|
||||
files: z.string().array(),
|
||||
})
|
||||
export type Patch = z.infer<typeof Patch>
|
||||
|
||||
const match = result.stdout.toString().match(/\[.+ ([a-f0-9]+)\]/)
|
||||
if (!match) return
|
||||
return match![1]
|
||||
export async function patch(hash: string): Promise<Patch> {
|
||||
const app = App.info()
|
||||
const git = gitdir()
|
||||
const files = await $`git --git-dir ${git} diff --name-only ${hash} -- .`.cwd(app.path.cwd).text()
|
||||
return {
|
||||
hash,
|
||||
files: files
|
||||
.trim()
|
||||
.split("\n")
|
||||
.map((x) => x.trim())
|
||||
.filter(Boolean)
|
||||
.map((x) => path.join(app.path.cwd, x)),
|
||||
}
|
||||
}
|
||||
|
||||
export async function restore(snapshot: string) {
|
||||
log.info("restore", { commit: snapshot })
|
||||
const app = App.info()
|
||||
const git = gitdir()
|
||||
await $`git --git-dir=${git} reset --hard ${snapshot}`.quiet().cwd(app.path.root)
|
||||
await $`git --git-dir=${git} read-tree ${snapshot} && git --git-dir=${git} checkout-index -a -f`
|
||||
.quiet()
|
||||
.cwd(app.path.root)
|
||||
}
|
||||
|
||||
export async function diff(commit: string) {
|
||||
export async function revert(patches: Patch[]) {
|
||||
const files = new Set<string>()
|
||||
const git = gitdir()
|
||||
const result = await $`git --git-dir=${git} diff -R ${commit}`.quiet().cwd(App.info().path.root)
|
||||
const text = result.stdout.toString("utf8")
|
||||
return text
|
||||
for (const item of patches) {
|
||||
for (const file of item.files) {
|
||||
if (files.has(file)) continue
|
||||
log.info("reverting", { file, hash: item.hash })
|
||||
await $`git --git-dir=${git} checkout ${item.hash} -- ${file}`.quiet().cwd(App.info().path.root)
|
||||
files.add(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function gitdir() {
|
||||
|
||||
Reference in New Issue
Block a user