diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 971fd4db..ded9f0e2 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -25,6 +25,7 @@ import { Global } from "../global" import { ProjectRoute } from "./project" import { ToolRegistry } from "../tool/registry" import { zodToJsonSchema } from "zod-to-json-schema" +import { SessionLock } from "../session/lock" import { SessionPrompt } from "../session/prompt" import { SessionCompaction } from "../session/compaction" import { SessionRevert } from "../session/revert" @@ -549,7 +550,7 @@ export namespace Server { }), ), async (c) => { - return c.json(SessionPrompt.abort(c.req.valid("param").id)) + return c.json(SessionLock.abort(c.req.valid("param").id)) }, ) .post( diff --git a/packages/opencode/src/session/lock.ts b/packages/opencode/src/session/lock.ts new file mode 100644 index 00000000..4b510dc9 --- /dev/null +++ b/packages/opencode/src/session/lock.ts @@ -0,0 +1,94 @@ +import z from "zod/v4" +import { Instance } from "../project/instance" +import { Log } from "../util/log" +import { NamedError } from "../util/error" + +export namespace SessionLock { + const log = Log.create({ service: "session.lock" }) + + export const LockedError = NamedError.create( + "SessionLockedError", + z.object({ + sessionID: z.string(), + message: z.string(), + }), + ) + + type LockState = { + controller: AbortController + created: number + } + + const state = Instance.state( + () => { + const locks = new Map() + return { + locks, + } + }, + async (current) => { + for (const [sessionID, lock] of current.locks) { + log.info("force abort", { sessionID }) + lock.controller.abort() + } + current.locks.clear() + }, + ) + + function get(sessionID: string) { + return state().locks.get(sessionID) + } + + function unset(input: { sessionID: string; controller: AbortController }) { + const lock = get(input.sessionID) + if (!lock) return false + if (lock.controller !== input.controller) return false + state().locks.delete(input.sessionID) + return true + } + + export function acquire(input: { sessionID: string }) { + const lock = get(input.sessionID) + if (lock) { + throw new LockedError({ sessionID: input.sessionID, message: `Session ${input.sessionID} is locked` }) + } + const controller = new AbortController() + state().locks.set(input.sessionID, { + controller, + created: Date.now(), + }) + log.info("locked", { sessionID: input.sessionID }) + return { + signal: controller.signal, + abort() { + controller.abort() + unset({ sessionID: input.sessionID, controller }) + }, + async [Symbol.dispose]() { + const removed = unset({ sessionID: input.sessionID, controller }) + if (removed) { + log.info("unlocked", { sessionID: input.sessionID }) + } + }, + } + } + + export function abort(sessionID: string) { + const lock = get(sessionID) + if (!lock) return false + log.info("abort", { sessionID }) + lock.controller.abort() + state().locks.delete(sessionID) + return true + } + + export function isLocked(sessionID: string) { + return get(sessionID) !== undefined + } + + export function assertUnlocked(sessionID: string) { + const lock = get(sessionID) + if (!lock) return + throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` }) + } +} diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index 949eae6b..29940dda 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -22,6 +22,7 @@ import { jsonSchema, } from "ai" import { SessionCompaction } from "./compaction" +import { SessionLock } from "./lock" import { Instance } from "../project/instance" import { Bus } from "../bus" import { ProviderTransform } from "../provider/transform" @@ -65,7 +66,6 @@ export namespace SessionPrompt { const state = Instance.state( () => { - const pending = new Map() const queued = new Map< string, { @@ -75,14 +75,11 @@ export namespace SessionPrompt { >() return { - pending, queued, } }, - async (state) => { - for (const [_, controller] of state.pending) { - controller.abort() - } + async (current) => { + current.queued.clear() }, ) @@ -1179,30 +1176,20 @@ export namespace SessionPrompt { } function isBusy(sessionID: string) { - return state().pending.has(sessionID) - } - - export function abort(sessionID: string) { - const controller = state().pending.get(sessionID) - if (!controller) return false - log.info("aborting", { - sessionID, - }) - controller.abort() - state().pending.delete(sessionID) - return true + return SessionLock.isLocked(sessionID) } function lock(sessionID: string) { + const handle = SessionLock.acquire({ + sessionID, + }) log.info("locking", { sessionID }) - if (state().pending.has(sessionID)) throw new Error("TODO") - const controller = new AbortController() - state().pending.set(sessionID, controller) return { - signal: controller.signal, + signal: handle.signal, + abort: handle.abort, async [Symbol.dispose]() { + handle[Symbol.dispose]() log.info("unlocking", { sessionID }) - state().pending.delete(sessionID) const session = await Session.get(sessionID) if (session.parentID) return diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index 052e582f..0b0f4294 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -7,6 +7,7 @@ import { Log } from "../util/log" import { splitWhen } from "remeda" import { Storage } from "../storage/storage" import { Bus } from "../bus" +import { SessionLock } from "./lock" export namespace SessionRevert { const log = Log.create({ service: "session.revert" }) @@ -19,6 +20,11 @@ export namespace SessionRevert { export type RevertInput = z.infer export async function revert(input: RevertInput) { + SessionLock.assertUnlocked(input.sessionID) + using _ = SessionLock.acquire({ + sessionID: input.sessionID, + }) + const all = await Session.messages(input.sessionID) let lastUser: MessageV2.User | undefined const session = await Session.get(input.sessionID) @@ -64,6 +70,10 @@ export namespace SessionRevert { export async function unrevert(input: { sessionID: string }) { log.info("unreverting", input) + SessionLock.assertUnlocked(input.sessionID) + using _ = SessionLock.acquire({ + sessionID: input.sessionID, + }) const session = await Session.get(input.sessionID) if (!session.revert) return session if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index 302e0cce..95f650e0 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -6,6 +6,7 @@ import { Bus } from "../bus" import { MessageV2 } from "../session/message-v2" import { Identifier } from "../id/id" import { Agent } from "../agent/agent" +import { SessionLock } from "../session/lock" import { SessionPrompt } from "../session/prompt" export const TaskTool = Tool.define("task", async () => { @@ -53,7 +54,7 @@ export const TaskTool = Tool.define("task", async () => { } ctx.abort.addEventListener("abort", () => { - SessionPrompt.abort(session.id) + SessionLock.abort(session.id) }) const result = await SessionPrompt.prompt({ messageID, diff --git a/packages/opencode/src/util/error.ts b/packages/opencode/src/util/error.ts index f93c4d71..6e5414f4 100644 --- a/packages/opencode/src/util/error.ts +++ b/packages/opencode/src/util/error.ts @@ -1,7 +1,4 @@ import z from "zod/v4" -// import { Log } from "./log" - -// const log = Log.create() export abstract class NamedError extends Error { abstract schema(): z.core.$ZodType