wip: session revert/unrevert

This commit is contained in:
Dax Raad
2025-07-02 13:00:46 -04:00
parent b89d4a16fd
commit 35d6273fb3
2 changed files with 100 additions and 6 deletions

View File

@@ -34,6 +34,7 @@ import type { ModelsDev } from "../provider/models"
import { Installation } from "../installation"
import { Config } from "../config/config"
import { ProviderTransform } from "../provider/transform"
import { Snapshot } from "../snapshot"
export namespace Session {
const log = Log.create({ service: "session" })
@@ -53,6 +54,13 @@ export namespace Session {
created: z.number(),
updated: z.number(),
}),
revert: z
.object({
messageID: z.string(),
part: z.number(),
snapshot: z.string().optional(),
})
.optional(),
})
.openapi({
ref: "Session",
@@ -285,6 +293,37 @@ export namespace Session {
l.info("chatting")
const model = await Provider.getModel(input.providerID, input.modelID)
let msgs = await messages(input.sessionID)
const session = await get(input.sessionID)
if (session.revert) {
const trimmed = []
for (const msg of msgs) {
if (
msg.id > session.revert.messageID ||
(msg.id === session.revert.messageID && session.revert.part === 0)
) {
await Storage.remove(
"session/message/" + input.sessionID + "/" + msg.id,
)
await Bus.publish(Message.Event.Removed, {
sessionID: input.sessionID,
messageID: msg.id,
})
continue
}
if (msg.id === session.revert.messageID) {
if (session.revert.part === 0) break
msg.parts = msg.parts.slice(0, session.revert.part)
}
trimmed.push(msg)
}
msgs = trimmed
await update(input.sessionID, (draft) => {
draft.revert = undefined
})
}
const previous = msgs.at(-1)
// auto summarize if too long
@@ -319,7 +358,6 @@ export namespace Session {
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
const app = App.info()
const session = await get(input.sessionID)
if (msgs.length === 0 && !session.parentID) {
generateText({
maxTokens: input.providerID === "google" ? 1024 : 20,
@@ -349,6 +387,7 @@ export namespace Session {
})
.catch(() => {})
}
const snapshot = await Snapshot.create(input.sessionID)
const msg: Message.Info = {
role: "user",
id: Identifier.ascending("message"),
@@ -359,6 +398,7 @@ export namespace Session {
},
sessionID: input.sessionID,
tool: {},
snapshot,
},
}
await updateMessage(msg)
@@ -373,6 +413,7 @@ export namespace Session {
role: "assistant",
parts: [],
metadata: {
snapshot,
assistant: {
system,
path: {
@@ -424,6 +465,7 @@ export namespace Session {
})
next.metadata!.tool![opts.toolCallId] = {
...result.metadata,
snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -436,6 +478,7 @@ export namespace Session {
error: true,
message: e.toString(),
title: e.toString(),
snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -457,6 +500,7 @@ export namespace Session {
const result = await execute(args, opts)
next.metadata!.tool![opts.toolCallId] = {
...result.metadata,
snapshot: await Snapshot.create(input.sessionID),
time: {
start,
end: Date.now(),
@@ -471,6 +515,7 @@ export namespace Session {
next.metadata!.tool![opts.toolCallId] = {
error: true,
message: e.toString(),
snapshot: await Snapshot.create(input.sessionID),
title: "mcp",
time: {
start,
@@ -735,6 +780,51 @@ export namespace Session {
return next
}
export async function revert(input: {
sessionID: string
messageID: string
part: number
}) {
const message = await getMessage(input.sessionID, input.messageID)
if (!message) return
const part = message.parts[input.part]
if (!part) return
const session = await get(input.sessionID)
const snapshot =
session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
const old = (() => {
if (message.role === "assistant") {
const lastTool = message.parts.findLast(
(part, index) =>
part.type === "tool-invocation" && index < input.part,
)
if (lastTool && lastTool.type === "tool-invocation")
return message.metadata.tool[lastTool.toolInvocation.toolCallId]
.snapshot
}
return message.metadata.snapshot
})()
if (old) await Snapshot.restore(input.sessionID, old)
await update(input.sessionID, (draft) => {
draft.revert = {
messageID: input.messageID,
part: input.part,
snapshot,
}
})
}
export async function unrevert(sessionID: string) {
const session = await get(sessionID)
if (!session) return
if (!session.revert) return
if (session.revert.snapshot)
await Snapshot.restore(sessionID, session.revert.snapshot)
update(sessionID, (draft) => {
draft.revert = undefined
})
}
export async function summarize(input: {
sessionID: string
providerID: string

View File

@@ -159,6 +159,7 @@ export namespace Message {
z
.object({
title: z.string(),
snapshot: z.string().optional(),
time: z.object({
start: z.number(),
end: z.number(),
@@ -188,11 +189,7 @@ export namespace Message {
}),
})
.optional(),
user: z
.object({
snapshot: z.string().optional(),
})
.optional(),
snapshot: z.string().optional(),
})
.openapi({ ref: "MessageMetadata" }),
})
@@ -208,6 +205,13 @@ export namespace Message {
info: Info,
}),
),
Removed: Bus.event(
"message.removed",
z.object({
sessionID: z.string(),
messageID: z.string(),
}),
),
PartUpdated: Bus.event(
"message.part.updated",
z.object({