mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 17:54:23 +01:00
tweak: consolidate session lock logic (#3185)
This commit is contained in:
@@ -25,6 +25,7 @@ import { Global } from "../global"
|
|||||||
import { ProjectRoute } from "./project"
|
import { ProjectRoute } from "./project"
|
||||||
import { ToolRegistry } from "../tool/registry"
|
import { ToolRegistry } from "../tool/registry"
|
||||||
import { zodToJsonSchema } from "zod-to-json-schema"
|
import { zodToJsonSchema } from "zod-to-json-schema"
|
||||||
|
import { SessionLock } from "../session/lock"
|
||||||
import { SessionPrompt } from "../session/prompt"
|
import { SessionPrompt } from "../session/prompt"
|
||||||
import { SessionCompaction } from "../session/compaction"
|
import { SessionCompaction } from "../session/compaction"
|
||||||
import { SessionRevert } from "../session/revert"
|
import { SessionRevert } from "../session/revert"
|
||||||
@@ -549,7 +550,7 @@ export namespace Server {
|
|||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
async (c) => {
|
async (c) => {
|
||||||
return c.json(SessionPrompt.abort(c.req.valid("param").id))
|
return c.json(SessionLock.abort(c.req.valid("param").id))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.post(
|
.post(
|
||||||
|
|||||||
94
packages/opencode/src/session/lock.ts
Normal file
94
packages/opencode/src/session/lock.ts
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import z from "zod/v4"
|
||||||
|
import { Instance } from "../project/instance"
|
||||||
|
import { Log } from "../util/log"
|
||||||
|
import { NamedError } from "../util/error"
|
||||||
|
|
||||||
|
export namespace SessionLock {
|
||||||
|
const log = Log.create({ service: "session.lock" })
|
||||||
|
|
||||||
|
export const LockedError = NamedError.create(
|
||||||
|
"SessionLockedError",
|
||||||
|
z.object({
|
||||||
|
sessionID: z.string(),
|
||||||
|
message: z.string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
type LockState = {
|
||||||
|
controller: AbortController
|
||||||
|
created: number
|
||||||
|
}
|
||||||
|
|
||||||
|
const state = Instance.state(
|
||||||
|
() => {
|
||||||
|
const locks = new Map<string, LockState>()
|
||||||
|
return {
|
||||||
|
locks,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
async (current) => {
|
||||||
|
for (const [sessionID, lock] of current.locks) {
|
||||||
|
log.info("force abort", { sessionID })
|
||||||
|
lock.controller.abort()
|
||||||
|
}
|
||||||
|
current.locks.clear()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
function get(sessionID: string) {
|
||||||
|
return state().locks.get(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
function unset(input: { sessionID: string; controller: AbortController }) {
|
||||||
|
const lock = get(input.sessionID)
|
||||||
|
if (!lock) return false
|
||||||
|
if (lock.controller !== input.controller) return false
|
||||||
|
state().locks.delete(input.sessionID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
export function acquire(input: { sessionID: string }) {
|
||||||
|
const lock = get(input.sessionID)
|
||||||
|
if (lock) {
|
||||||
|
throw new LockedError({ sessionID: input.sessionID, message: `Session ${input.sessionID} is locked` })
|
||||||
|
}
|
||||||
|
const controller = new AbortController()
|
||||||
|
state().locks.set(input.sessionID, {
|
||||||
|
controller,
|
||||||
|
created: Date.now(),
|
||||||
|
})
|
||||||
|
log.info("locked", { sessionID: input.sessionID })
|
||||||
|
return {
|
||||||
|
signal: controller.signal,
|
||||||
|
abort() {
|
||||||
|
controller.abort()
|
||||||
|
unset({ sessionID: input.sessionID, controller })
|
||||||
|
},
|
||||||
|
async [Symbol.dispose]() {
|
||||||
|
const removed = unset({ sessionID: input.sessionID, controller })
|
||||||
|
if (removed) {
|
||||||
|
log.info("unlocked", { sessionID: input.sessionID })
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function abort(sessionID: string) {
|
||||||
|
const lock = get(sessionID)
|
||||||
|
if (!lock) return false
|
||||||
|
log.info("abort", { sessionID })
|
||||||
|
lock.controller.abort()
|
||||||
|
state().locks.delete(sessionID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isLocked(sessionID: string) {
|
||||||
|
return get(sessionID) !== undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export function assertUnlocked(sessionID: string) {
|
||||||
|
const lock = get(sessionID)
|
||||||
|
if (!lock) return
|
||||||
|
throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` })
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ import {
|
|||||||
jsonSchema,
|
jsonSchema,
|
||||||
} from "ai"
|
} from "ai"
|
||||||
import { SessionCompaction } from "./compaction"
|
import { SessionCompaction } from "./compaction"
|
||||||
|
import { SessionLock } from "./lock"
|
||||||
import { Instance } from "../project/instance"
|
import { Instance } from "../project/instance"
|
||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
import { ProviderTransform } from "../provider/transform"
|
import { ProviderTransform } from "../provider/transform"
|
||||||
@@ -65,7 +66,6 @@ export namespace SessionPrompt {
|
|||||||
|
|
||||||
const state = Instance.state(
|
const state = Instance.state(
|
||||||
() => {
|
() => {
|
||||||
const pending = new Map<string, AbortController>()
|
|
||||||
const queued = new Map<
|
const queued = new Map<
|
||||||
string,
|
string,
|
||||||
{
|
{
|
||||||
@@ -75,14 +75,11 @@ export namespace SessionPrompt {
|
|||||||
>()
|
>()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
pending,
|
|
||||||
queued,
|
queued,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
async (state) => {
|
async (current) => {
|
||||||
for (const [_, controller] of state.pending) {
|
current.queued.clear()
|
||||||
controller.abort()
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1179,30 +1176,20 @@ export namespace SessionPrompt {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function isBusy(sessionID: string) {
|
function isBusy(sessionID: string) {
|
||||||
return state().pending.has(sessionID)
|
return SessionLock.isLocked(sessionID)
|
||||||
}
|
|
||||||
|
|
||||||
export function abort(sessionID: string) {
|
|
||||||
const controller = state().pending.get(sessionID)
|
|
||||||
if (!controller) return false
|
|
||||||
log.info("aborting", {
|
|
||||||
sessionID,
|
|
||||||
})
|
|
||||||
controller.abort()
|
|
||||||
state().pending.delete(sessionID)
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function lock(sessionID: string) {
|
function lock(sessionID: string) {
|
||||||
|
const handle = SessionLock.acquire({
|
||||||
|
sessionID,
|
||||||
|
})
|
||||||
log.info("locking", { sessionID })
|
log.info("locking", { sessionID })
|
||||||
if (state().pending.has(sessionID)) throw new Error("TODO")
|
|
||||||
const controller = new AbortController()
|
|
||||||
state().pending.set(sessionID, controller)
|
|
||||||
return {
|
return {
|
||||||
signal: controller.signal,
|
signal: handle.signal,
|
||||||
|
abort: handle.abort,
|
||||||
async [Symbol.dispose]() {
|
async [Symbol.dispose]() {
|
||||||
|
handle[Symbol.dispose]()
|
||||||
log.info("unlocking", { sessionID })
|
log.info("unlocking", { sessionID })
|
||||||
state().pending.delete(sessionID)
|
|
||||||
|
|
||||||
const session = await Session.get(sessionID)
|
const session = await Session.get(sessionID)
|
||||||
if (session.parentID) return
|
if (session.parentID) return
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { Log } from "../util/log"
|
|||||||
import { splitWhen } from "remeda"
|
import { splitWhen } from "remeda"
|
||||||
import { Storage } from "../storage/storage"
|
import { Storage } from "../storage/storage"
|
||||||
import { Bus } from "../bus"
|
import { Bus } from "../bus"
|
||||||
|
import { SessionLock } from "./lock"
|
||||||
|
|
||||||
export namespace SessionRevert {
|
export namespace SessionRevert {
|
||||||
const log = Log.create({ service: "session.revert" })
|
const log = Log.create({ service: "session.revert" })
|
||||||
@@ -19,6 +20,11 @@ export namespace SessionRevert {
|
|||||||
export type RevertInput = z.infer<typeof RevertInput>
|
export type RevertInput = z.infer<typeof RevertInput>
|
||||||
|
|
||||||
export async function revert(input: RevertInput) {
|
export async function revert(input: RevertInput) {
|
||||||
|
SessionLock.assertUnlocked(input.sessionID)
|
||||||
|
using _ = SessionLock.acquire({
|
||||||
|
sessionID: input.sessionID,
|
||||||
|
})
|
||||||
|
|
||||||
const all = await Session.messages(input.sessionID)
|
const all = await Session.messages(input.sessionID)
|
||||||
let lastUser: MessageV2.User | undefined
|
let lastUser: MessageV2.User | undefined
|
||||||
const session = await Session.get(input.sessionID)
|
const session = await Session.get(input.sessionID)
|
||||||
@@ -64,6 +70,10 @@ export namespace SessionRevert {
|
|||||||
|
|
||||||
export async function unrevert(input: { sessionID: string }) {
|
export async function unrevert(input: { sessionID: string }) {
|
||||||
log.info("unreverting", input)
|
log.info("unreverting", input)
|
||||||
|
SessionLock.assertUnlocked(input.sessionID)
|
||||||
|
using _ = SessionLock.acquire({
|
||||||
|
sessionID: input.sessionID,
|
||||||
|
})
|
||||||
const session = await Session.get(input.sessionID)
|
const session = await Session.get(input.sessionID)
|
||||||
if (!session.revert) return session
|
if (!session.revert) return session
|
||||||
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
|
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import { Bus } from "../bus"
|
|||||||
import { MessageV2 } from "../session/message-v2"
|
import { MessageV2 } from "../session/message-v2"
|
||||||
import { Identifier } from "../id/id"
|
import { Identifier } from "../id/id"
|
||||||
import { Agent } from "../agent/agent"
|
import { Agent } from "../agent/agent"
|
||||||
|
import { SessionLock } from "../session/lock"
|
||||||
import { SessionPrompt } from "../session/prompt"
|
import { SessionPrompt } from "../session/prompt"
|
||||||
|
|
||||||
export const TaskTool = Tool.define("task", async () => {
|
export const TaskTool = Tool.define("task", async () => {
|
||||||
@@ -53,7 +54,7 @@ export const TaskTool = Tool.define("task", async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx.abort.addEventListener("abort", () => {
|
ctx.abort.addEventListener("abort", () => {
|
||||||
SessionPrompt.abort(session.id)
|
SessionLock.abort(session.id)
|
||||||
})
|
})
|
||||||
const result = await SessionPrompt.prompt({
|
const result = await SessionPrompt.prompt({
|
||||||
messageID,
|
messageID,
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
import z from "zod/v4"
|
import z from "zod/v4"
|
||||||
// import { Log } from "./log"
|
|
||||||
|
|
||||||
// const log = Log.create()
|
|
||||||
|
|
||||||
export abstract class NamedError extends Error {
|
export abstract class NamedError extends Error {
|
||||||
abstract schema(): z.core.$ZodType
|
abstract schema(): z.core.$ZodType
|
||||||
|
|||||||
Reference in New Issue
Block a user