From a1214fff2eaa71b3346f53580e6e94376ac9050d Mon Sep 17 00:00:00 2001 From: Dax Date: Mon, 17 Nov 2025 10:57:18 -0500 Subject: [PATCH] Refactor agent loop (#4412) --- .gitignore | 1 + .opencode/command/commit.md | 1 + .opencode/opencode.json | 4 - .opencode/opencode.jsonc | 11 + a.out | 0 packages/opencode/src/cli/cmd/debug/file.ts | 16 + .../opencode/src/cli/cmd/tui/context/sync.tsx | 12 + .../src/cli/cmd/tui/routes/session/index.tsx | 155 ++- packages/opencode/src/file/ripgrep.ts | 3 + packages/opencode/src/server/server.ts | 36 +- packages/opencode/src/session/compaction.ts | 288 ++-- packages/opencode/src/session/index.ts | 1 - packages/opencode/src/session/lock.ts | 97 -- packages/opencode/src/session/message-v2.ts | 223 +-- packages/opencode/src/session/processor.ts | 372 +++++ packages/opencode/src/session/prompt.ts | 1213 ++++++----------- packages/opencode/src/session/revert.ts | 13 +- packages/opencode/src/session/status.ts | 63 + packages/opencode/src/session/system.ts | 4 +- packages/opencode/src/tool/task.ts | 10 +- packages/sdk/js/src/gen/sdk.gen.ts | 13 + packages/sdk/js/src/gen/types.gen.ts | 85 +- 22 files changed, 1297 insertions(+), 1324 deletions(-) delete mode 100644 .opencode/opencode.json create mode 100644 .opencode/opencode.jsonc create mode 100644 a.out delete mode 100644 packages/opencode/src/session/lock.ts create mode 100644 packages/opencode/src/session/processor.ts create mode 100644 packages/opencode/src/session/status.ts diff --git a/.gitignore b/.gitignore index f69a7079..d1d98396 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist .turbo **/.serena .serena/ +refs diff --git a/.opencode/command/commit.md b/.opencode/command/commit.md index 2e3d759b..9626f172 100644 --- a/.opencode/command/commit.md +++ b/.opencode/command/commit.md @@ -1,5 +1,6 @@ --- description: Git commit and push +subtask: true --- commit and push diff --git a/.opencode/opencode.json b/.opencode/opencode.json deleted file mode 100644 index 7da874d3..00000000 --- a/.opencode/opencode.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://opencode.ai/config.json", - "plugin": ["opencode-openai-codex-auth"] -} diff --git a/.opencode/opencode.jsonc b/.opencode/opencode.jsonc new file mode 100644 index 00000000..02278ce3 --- /dev/null +++ b/.opencode/opencode.jsonc @@ -0,0 +1,11 @@ +{ + "$schema": "https://opencode.ai/config.json", + "plugin": ["opencode-openai-codex-auth"], + "provider": { + "opencode": { + "options": { + // "baseURL": "http://localhost:8080" + }, + }, + }, +} diff --git a/a.out b/a.out new file mode 100644 index 00000000..e69de29b diff --git a/packages/opencode/src/cli/cmd/debug/file.ts b/packages/opencode/src/cli/cmd/debug/file.ts index 3d1e707d..51196614 100644 --- a/packages/opencode/src/cli/cmd/debug/file.ts +++ b/packages/opencode/src/cli/cmd/debug/file.ts @@ -2,6 +2,7 @@ import { EOL } from "os" import { File } from "../../../file" import { bootstrap } from "../../bootstrap" import { cmd } from "../cmd" +import { Ripgrep } from "@/file/ripgrep" const FileSearchCommand = cmd({ command: "search ", @@ -62,6 +63,20 @@ const FileListCommand = cmd({ }, }) +const FileTreeCommand = cmd({ + command: "tree [dir]", + builder: (yargs) => + yargs.positional("dir", { + type: "string", + description: "Directory to tree", + default: process.cwd(), + }), + async handler(args) { + const files = await Ripgrep.tree({ cwd: args.dir, limit: 200 }) + console.log(files) + }, +}) + export const FileCommand = cmd({ command: "file", builder: (yargs) => @@ -70,6 +85,7 @@ export const FileCommand = cmd({ .command(FileStatusCommand) .command(FileListCommand) .command(FileSearchCommand) + .command(FileTreeCommand) .demandCommand(), async handler() {}, }) diff --git a/packages/opencode/src/cli/cmd/tui/context/sync.tsx b/packages/opencode/src/cli/cmd/tui/context/sync.tsx index eb8c3dfa..2c994a4a 100644 --- a/packages/opencode/src/cli/cmd/tui/context/sync.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/sync.tsx @@ -11,6 +11,7 @@ import type { LspStatus, McpStatus, FormatterStatus, + SessionStatus, } from "@opencode-ai/sdk" import { createStore, produce, reconcile } from "solid-js/store" import { useSDK } from "@tui/context/sdk" @@ -33,6 +34,9 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ } config: Config session: Session[] + session_status: { + [sessionID: string]: SessionStatus + } session_diff: { [sessionID: string]: Snapshot.FileDiff[] } @@ -58,6 +62,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ command: [], provider: [], session: [], + session_status: {}, session_diff: {}, todo: {}, message: {}, @@ -140,6 +145,12 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ }), ) break + + case "session.status": { + setStore("session_status", event.properties.sessionID, event.properties.status) + break + } + case "message.updated": { const messages = store.message[event.properties.info.sessionID] if (!messages) { @@ -240,6 +251,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ sdk.client.lsp.status().then((x) => setStore("lsp", x.data!)), sdk.client.mcp.status().then((x) => setStore("mcp", x.data!)), sdk.client.formatter.status().then((x) => setStore("formatter", x.data!)), + sdk.client.session.status().then((x) => setStore("session_status", x.data!)), ]).then(() => { setStore("status", "complete") }) diff --git a/packages/opencode/src/cli/cmd/tui/routes/session/index.tsx b/packages/opencode/src/cli/cmd/tui/routes/session/index.tsx index de1dd727..192123e0 100644 --- a/packages/opencode/src/cli/cmd/tui/routes/session/index.tsx +++ b/packages/opencode/src/cli/cmd/tui/routes/session/index.tsx @@ -20,7 +20,6 @@ import { useTheme } from "@tui/context/theme" import { BoxRenderable, ScrollBoxRenderable, - TextAttributes, addDefaultParsers, MacOSScrollAccel, type ScrollAcceleration, @@ -65,7 +64,6 @@ import { Editor } from "../../util/editor" import { Global } from "@/global" import fs from "fs/promises" import stripAnsi from "strip-ansi" -import { LSP } from "@/lsp/index.ts" addDefaultParsers(parsers.parsers) @@ -101,7 +99,12 @@ export function Session() { const permissions = createMemo(() => sync.data.permission[route.sessionID] ?? []) const pending = createMemo(() => { - return messages().findLast((x) => x.role === "assistant" && !x.time?.completed)?.id + return messages().findLast((x) => x.role === "assistant" && !x.time.completed)?.id + }) + + const lastUserMessage = createMemo(() => { + const p = pending() + return messages().findLast((x) => x.role === "user" && (!p || x.id < p)) as UserMessage }) const dimensions = useTerminalDimensions() @@ -801,7 +804,7 @@ export function Session() { @@ -856,64 +859,84 @@ function UserMessage(props: { const queued = createMemo(() => props.pending && props.message.id > props.pending) const color = createMemo(() => (queued() ? theme.accent : theme.secondary)) + const compaction = createMemo(() => props.parts.find((x) => x.type === "compaction")) + return ( - - { - setHover(true) - }} - onMouseOut={() => { - setHover(false) - }} - onMouseUp={props.onMouseUp} - border={["left"]} - paddingTop={1} - paddingBottom={1} - paddingLeft={2} - marginTop={props.index === 0 ? 0 : 1} - backgroundColor={hover() ? theme.backgroundElement : theme.backgroundPanel} - customBorderChars={SplitBorder.customBorderChars} - borderColor={color()} - flexShrink={0} - > - {text()?.text} - - - - {(file) => { - const bg = createMemo(() => { - if (file.mime.startsWith("image/")) return theme.accent - if (file.mime === "application/pdf") return theme.primary - return theme.secondary - }) - return ( - - {MIME_BADGE[file.mime] ?? file.mime} - {file.filename} - - ) - }} - - - - - {sync.data.config.username ?? "You"}{" "} - ({Locale.time(props.message.time.created)})} - > - QUEUED + <> + + { + setHover(true) + }} + onMouseOut={() => { + setHover(false) + }} + onMouseUp={props.onMouseUp} + border={["left"]} + paddingTop={1} + paddingBottom={1} + paddingLeft={2} + marginTop={props.index === 0 ? 0 : 1} + backgroundColor={hover() ? theme.backgroundElement : theme.backgroundPanel} + customBorderChars={SplitBorder.customBorderChars} + borderColor={color()} + flexShrink={0} + > + {text()?.text} + + + + {(file) => { + const bg = createMemo(() => { + if (file.mime.startsWith("image/")) return theme.accent + if (file.mime === "application/pdf") return theme.primary + return theme.secondary + }) + return ( + + {MIME_BADGE[file.mime] ?? file.mime} + {file.filename} + + ) + }} + + - - - + + {sync.data.config.username ?? "You"}{" "} + ({Locale.time(props.message.time.created)})} + > + QUEUED + + + + + + + + ) } function AssistantMessage(props: { message: AssistantMessage; parts: Part[]; last: boolean }) { const local = useLocal() const { theme } = useTheme() + const sync = useSync() + const status = createMemo( + () => + sync.data.session_status[props.message.sessionID] ?? { + type: "idle", + }, + ) return ( <> @@ -945,23 +968,15 @@ function AssistantMessage(props: { message: AssistantMessage; parts: Part[]; las {props.message.error?.data.message} - item.type === "step-finish" && item.reason === "tool-calls")) - } - > - + + {Locale.titlecase(props.message.mode)} - + + + + {(status() as any).message} [attempt #{(status() as any).attempt}] + + { + const result = SessionStatus.list() + return c.json(result) + }, + ) .get( "/session/:id", describeRoute({ @@ -637,7 +659,8 @@ export namespace Server { }), ), async (c) => { - return c.json(SessionLock.abort(c.req.valid("param").id)) + SessionPrompt.cancel(c.req.valid("param").id) + return c.json(true) }, ) .post( @@ -771,7 +794,14 @@ export namespace Server { async (c) => { const id = c.req.valid("param").id const body = c.req.valid("json") - await SessionCompaction.run({ ...body, sessionID: id }) + await SessionCompaction.create({ + sessionID: id, + model: { + providerID: body.providerID, + modelID: body.modelID, + }, + }) + await SessionPrompt.loop(id) return c.json(true) }, ) diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts index ff988845..0bb949ba 100644 --- a/packages/opencode/src/session/compaction.ts +++ b/packages/opencode/src/session/compaction.ts @@ -1,9 +1,8 @@ -import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai" +import { streamText, type ModelMessage } from "ai" import { Session } from "." import { Identifier } from "../id/id" import { Instance } from "../project/instance" import { Provider } from "../provider/provider" -import { defer } from "../util/defer" import { MessageV2 } from "./message-v2" import { SystemPrompt } from "./system" import { Bus } from "../bus" @@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt" import { Flag } from "../flag/flag" import { Token } from "../util/token" import { Log } from "../util/log" -import { SessionLock } from "./lock" import { ProviderTransform } from "@/provider/transform" -import { SessionRetry } from "./retry" -import { Config } from "@/config/config" +import { SessionProcessor } from "./processor" +import { fn } from "@/util/fn" export namespace SessionCompaction { const log = Log.create({ service: "session.compaction" }) @@ -42,7 +40,6 @@ export namespace SessionCompaction { export const PRUNE_MINIMUM = 20_000 export const PRUNE_PROTECT = 40_000 - const MAX_RETRIES = 10 // goes backwards through parts until there are 40_000 tokens worth of tool // calls. then erases output of previous tool calls. idea is to throw away old @@ -87,38 +84,29 @@ export namespace SessionCompaction { } } - export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) { - if (!input.signal) SessionLock.assertUnlocked(input.sessionID) - await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined - const signal = input.signal ?? lock!.signal - - await Session.update(input.sessionID, (draft) => { - draft.time.compacting = Date.now() - }) - await using _ = defer(async () => { - await Session.update(input.sessionID, (draft) => { - draft.time.compacting = undefined - }) - }) - const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID)) - const model = await Provider.getModel(input.providerID, input.modelID) - const system = [ - ...SystemPrompt.summarize(model.providerID), - ...(await SystemPrompt.environment()), - ...(await SystemPrompt.custom()), - ] - + export async function process(input: { + parentID: string + messages: MessageV2.WithParts[] + sessionID: string + model: { + providerID: string + modelID: string + } + abort: AbortSignal + }) { + const model = await Provider.getModel(input.model.providerID, input.model.modelID) + const system = [...SystemPrompt.summarize(model.providerID)] const msg = (await Session.updateMessage({ id: Identifier.ascending("message"), role: "assistant", - parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!, + parentID: input.parentID, sessionID: input.sessionID, mode: "build", + summary: true, path: { cwd: Instance.directory, root: Instance.worktree, }, - summary: true, cost: 0, tokens: { output: 0, @@ -126,37 +114,27 @@ export namespace SessionCompaction { reasoning: 0, cache: { read: 0, write: 0 }, }, - modelID: input.modelID, + modelID: input.model.modelID, providerID: model.providerID, time: { created: Date.now(), }, })) as MessageV2.Assistant - - const part = (await Session.updatePart({ - type: "text", + const processor = SessionProcessor.create({ + assistantMessage: msg, sessionID: input.sessionID, - messageID: msg.id, - id: Identifier.ascending("part"), - text: "", - time: { - start: Date.now(), - }, - })) as MessageV2.TextPart - - const doStream = () => + providerID: input.model.providerID, + model: model.info, + abort: input.abort, + }) + const result = await processor.process(() => streamText({ // set to 0, we handle loop maxRetries: 0, model: model.language, providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options), headers: model.info.headers, - abortSignal: signal, - onError(error) { - log.error("stream error", { - error, - }) - }, + abortSignal: input.abort, tools: model.info.tool_call ? {} : undefined, messages: [ ...system.map( @@ -165,7 +143,7 @@ export namespace SessionCompaction { content: x, }), ), - ...MessageV2.toModelMessage(toSummarize), + ...MessageV2.toModelMessage(input.messages), { role: "user", content: [ @@ -176,168 +154,60 @@ export namespace SessionCompaction { ], }, ], - }) - - // TODO: reduce duplication between compaction.ts & prompt.ts - const process = async ( - stream: StreamTextResult, never>, - retries: { count: number; max: number }, - ) => { - let shouldRetry = false - try { - for await (const value of stream.fullStream) { - signal.throwIfAborted() - switch (value.type) { - case "text-delta": - part.text += value.text - if (value.providerMetadata) part.metadata = value.providerMetadata - if (part.text) - await Session.updatePart({ - part, - delta: value.text, - }) - continue - case "text-end": { - part.text = part.text.trimEnd() - part.time = { - start: Date.now(), - end: Date.now(), - } - if (value.providerMetadata) part.metadata = value.providerMetadata - await Session.updatePart(part) - continue - } - case "finish-step": { - const usage = Session.getUsage({ - model: model.info, - usage: value.usage, - metadata: value.providerMetadata, - }) - msg.cost += usage.cost - msg.tokens = usage.tokens - await Session.updateMessage(msg) - continue - } - case "error": - throw value.error - default: - continue - } - } - } catch (e) { - log.error("compaction error", { - error: e, - }) - const error = MessageV2.fromError(e, { providerID: input.providerID }) - if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) { - shouldRetry = true - await Session.updatePart({ - id: Identifier.ascending("part"), - messageID: msg.id, - sessionID: msg.sessionID, - type: "retry", - attempt: retries.count + 1, - time: { - created: Date.now(), - }, - error, - }) - } else { - msg.error = error - Bus.publish(Session.Event.Error, { - sessionID: msg.sessionID, - error: msg.error, - }) - } - } - - const parts = await MessageV2.parts(msg.id) - return { - info: msg, - parts, - shouldRetry, - } - } - - let stream = doStream() - const cfg = await Config.get() - const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES - let result = await process(stream, { - count: 0, - max: maxRetries, - }) - if (result.shouldRetry) { - const start = Date.now() - for (let retry = 1; retry < maxRetries; retry++) { - const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry") - - if (lastRetryPart) { - const delayMs = SessionRetry.getBoundedDelay({ - error: lastRetryPart.error, - attempt: retry, - startTime: start, - }) - if (!delayMs) { - break - } - - log.info("retrying with backoff", { - attempt: retry, - delayMs, - elapsed: Date.now() - start, - }) - - const stop = await SessionRetry.sleep(delayMs, signal) - .then(() => false) - .catch((error) => { - if (error instanceof DOMException && error.name === "AbortError") { - const err = new MessageV2.AbortedError( - { message: error.message }, - { - cause: error, - }, - ).toObject() - result.info.error = err - Bus.publish(Session.Event.Error, { - sessionID: result.info.sessionID, - error: result.info.error, - }) - return true - } - throw error - }) - - if (stop) break - } - - stream = doStream() - result = await process(stream, { - count: retry, - max: maxRetries, - }) - if (!result.shouldRetry) { - break - } - } - } - - msg.time.completed = Date.now() - - if ( - !msg.error || - (MessageV2.AbortedError.isInstance(msg.error) && - result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0)) - ) { - msg.summary = true - Bus.publish(Event.Compacted, { + }), + ) + if (result === "continue") { + const continueMsg = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", sessionID: input.sessionID, + time: { + created: Date.now(), + }, + agent: "build", + model: input.model, + }) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: continueMsg.id, + sessionID: input.sessionID, + type: "text", + synthetic: true, + text: "Continue if you have next steps", + time: { + start: Date.now(), + end: Date.now(), + }, }) } - await Session.updateMessage(msg) - - return { - info: msg, - parts: result.parts, - } + return "continue" } + + export const create = fn( + z.object({ + sessionID: Identifier.schema("session"), + model: z.object({ + providerID: z.string(), + modelID: z.string(), + }), + }), + async (input) => { + const msg = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", + model: input.model, + sessionID: input.sessionID, + agent: "build", + time: { + created: Date.now(), + }, + }) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: msg.id, + sessionID: msg.sessionID, + type: "compaction", + }) + }, + ) } diff --git a/packages/opencode/src/session/index.ts b/packages/opencode/src/session/index.ts index 395014ff..a9ab8ea9 100644 --- a/packages/opencode/src/session/index.ts +++ b/packages/opencode/src/session/index.ts @@ -1,7 +1,6 @@ import { Decimal } from "decimal.js" import z from "zod" import { type LanguageModelUsage, type ProviderMetadata } from "ai" - import { Bus } from "../bus" import { Config } from "../config/config" import { Flag } from "../flag/flag" diff --git a/packages/opencode/src/session/lock.ts b/packages/opencode/src/session/lock.ts deleted file mode 100644 index 22eb8187..00000000 --- a/packages/opencode/src/session/lock.ts +++ /dev/null @@ -1,97 +0,0 @@ -import z from "zod" -import { Instance } from "../project/instance" -import { Log } from "../util/log" -import { NamedError } from "../util/error" - -export namespace SessionLock { - const log = Log.create({ service: "session.lock" }) - - export const LockedError = NamedError.create( - "SessionLockedError", - z.object({ - sessionID: z.string(), - message: z.string(), - }), - ) - - type LockState = { - controller: AbortController - created: number - } - - const state = Instance.state( - () => { - const locks = new Map() - return { - locks, - } - }, - async (current) => { - for (const [sessionID, lock] of current.locks) { - log.info("force abort", { sessionID }) - lock.controller.abort() - } - current.locks.clear() - }, - ) - - function get(sessionID: string) { - return state().locks.get(sessionID) - } - - function unset(input: { sessionID: string; controller: AbortController }) { - const lock = get(input.sessionID) - if (!lock) return false - if (lock.controller !== input.controller) return false - state().locks.delete(input.sessionID) - return true - } - - export function acquire(input: { sessionID: string }) { - const lock = get(input.sessionID) - if (lock) { - throw new LockedError({ - sessionID: input.sessionID, - message: `Session ${input.sessionID} is locked`, - }) - } - const controller = new AbortController() - state().locks.set(input.sessionID, { - controller, - created: Date.now(), - }) - log.info("locked", { sessionID: input.sessionID }) - return { - signal: controller.signal, - abort() { - controller.abort() - unset({ sessionID: input.sessionID, controller }) - }, - async [Symbol.dispose]() { - const removed = unset({ sessionID: input.sessionID, controller }) - if (removed) { - log.info("unlocked", { sessionID: input.sessionID }) - } - }, - } - } - - export function abort(sessionID: string) { - const lock = get(sessionID) - if (!lock) return false - log.info("abort", { sessionID }) - lock.controller.abort() - state().locks.delete(sessionID) - return true - } - - export function isLocked(sessionID: string) { - return get(sessionID) !== undefined - } - - export function assertUnlocked(sessionID: string) { - const lock = get(sessionID) - if (!lock) return - throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` }) - } -} diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 5435b9e2..69087356 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -142,6 +142,21 @@ export namespace MessageV2 { }) export type AgentPart = z.infer + export const CompactionPart = PartBase.extend({ + type: z.literal("compaction"), + }).meta({ + ref: "CompactionPart", + }) + export type CompactionPart = z.infer + + export const SubtaskPart = PartBase.extend({ + type: z.literal("subtask"), + prompt: z.string(), + description: z.string(), + agent: z.string(), + }) + export type SubtaskPart = z.infer + export const RetryPart = PartBase.extend({ type: z.literal("retry"), attempt: z.number(), @@ -277,6 +292,13 @@ export namespace MessageV2 { diffs: Snapshot.FileDiff.array(), }) .optional(), + agent: z.string(), + model: z.object({ + providerID: z.string(), + modelID: z.string(), + }), + system: z.string().optional(), + tools: z.record(z.string(), z.boolean()).optional(), }).meta({ ref: "UserMessage", }) @@ -285,6 +307,7 @@ export namespace MessageV2 { export const Part = z .discriminatedUnion("type", [ TextPart, + SubtaskPart, ReasoningPart, FilePart, ToolPart, @@ -294,6 +317,7 @@ export namespace MessageV2 { PatchPart, AgentPart, RetryPart, + CompactionPart, ]) .meta({ ref: "Part", @@ -334,6 +358,7 @@ export namespace MessageV2 { write: z.number(), }), }), + finish: z.string().optional(), }).meta({ ref: "AssistantMessage", }) @@ -482,6 +507,11 @@ export namespace MessageV2 { time: { created: v1.metadata.time.created, }, + agent: "build", + model: { + providerID: "opencode", + modelID: "opencode", + }, } const parts = v1.parts.flatMap((part): Part[] => { const base = { @@ -529,107 +559,107 @@ export namespace MessageV2 { if (msg.parts.length === 0) continue if (msg.info.role === "user") { - result.push({ + const userMessage: UIMessage = { id: msg.info.id, role: "user", - parts: msg.parts.flatMap((part): UIMessage["parts"] => { - if (part.type === "text") - return [ - { - type: "text", - text: part.text, - }, - ] - // text/plain and directory files are converted into text parts, ignore them - if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory") - return [ - { - type: "file", - url: part.url, - mediaType: part.mime, - filename: part.filename, - }, - ] - return [] - }), - }) + parts: [], + } + result.push(userMessage) + for (const part of msg.parts) { + if (part.type === "text") + userMessage.parts.push({ + type: "text", + text: part.text, + }) + // text/plain and directory files are converted into text parts, ignore them + if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory") + userMessage.parts.push({ + type: "file", + url: part.url, + mediaType: part.mime, + filename: part.filename, + }) + + if (part.type === "compaction") { + userMessage.parts.push({ + type: "text", + text: "What did we do so far?", + }) + } + if (part.type === "subtask") { + userMessage.parts.push({ + type: "text", + text: "The following tool was executed by the user", + }) + } + } } if (msg.info.role === "assistant") { - result.push({ + const assistantMessage: UIMessage = { id: msg.info.id, role: "assistant", - parts: msg.parts.flatMap((part): UIMessage["parts"] => { - if (part.type === "text") - return [ - { - type: "text", - text: part.text, - providerMetadata: part.metadata, - }, - ] - if (part.type === "step-start") - return [ - { - type: "step-start", - }, - ] - if (part.type === "tool") { - if (part.state.status === "completed") { - if (part.state.attachments?.length) { - result.push({ - id: Identifier.ascending("message"), - role: "user", - parts: [ - { - type: "text", - text: `Tool ${part.tool} returned an attachment:`, - }, - ...part.state.attachments.map((attachment) => ({ - type: "file" as const, - url: attachment.url, - mediaType: attachment.mime, - filename: attachment.filename, - })), - ], - }) - } - return [ - { - type: ("tool-" + part.tool) as `tool-${string}`, - state: "output-available", - toolCallId: part.callID, - input: part.state.input, - output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output, - callProviderMetadata: part.metadata, - }, - ] + parts: [], + } + result.push(assistantMessage) + for (const part of msg.parts) { + if (part.type === "text") + assistantMessage.parts.push({ + type: "text", + text: part.text, + providerMetadata: part.metadata, + }) + if (part.type === "step-start") + assistantMessage.parts.push({ + type: "step-start", + }) + if (part.type === "tool") { + if (part.state.status === "completed") { + if (part.state.attachments?.length) { + result.push({ + id: Identifier.ascending("message"), + role: "user", + parts: [ + { + type: "text", + text: `Tool ${part.tool} returned an attachment:`, + }, + ...part.state.attachments.map((attachment) => ({ + type: "file" as const, + url: attachment.url, + mediaType: attachment.mime, + filename: attachment.filename, + })), + ], + }) } - if (part.state.status === "error") - return [ - { - type: ("tool-" + part.tool) as `tool-${string}`, - state: "output-error", - toolCallId: part.callID, - input: part.state.input, - errorText: part.state.error, - callProviderMetadata: part.metadata, - }, - ] + assistantMessage.parts.push({ + type: ("tool-" + part.tool) as `tool-${string}`, + state: "output-available", + toolCallId: part.callID, + input: part.state.input, + output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output, + callProviderMetadata: part.metadata, + }) } - if (part.type === "reasoning") { - return [ - { - type: "reasoning", - text: part.text, - providerMetadata: part.metadata, - }, - ] - } - - return [] - }), - }) + if (part.state.status === "error") + assistantMessage.parts.push({ + type: ("tool-" + part.tool) as `tool-${string}`, + state: "output-error", + toolCallId: part.callID, + input: part.state.input, + errorText: part.state.error, + callProviderMetadata: part.metadata, + }) + } + if (part.type === "reasoning") { + assistantMessage.parts.push({ + type: "reasoning", + text: part.text, + providerMetadata: part.metadata, + }) + } + } } } @@ -671,9 +701,16 @@ export namespace MessageV2 { export async function filterCompacted(stream: AsyncIterable) { const result = [] as MessageV2.WithParts[] + const completed = new Set() for await (const msg of stream) { result.push(msg) - if (msg.info.role === "assistant" && msg.info.summary === true) break + if ( + msg.info.role === "user" && + completed.has(msg.info.id) && + msg.parts.some((part) => part.type === "compaction") + ) + break + if (msg.info.role === "assistant" && msg.info.summary && msg.info.finish) completed.add(msg.info.parentID) } result.reverse() return result diff --git a/packages/opencode/src/session/processor.ts b/packages/opencode/src/session/processor.ts new file mode 100644 index 00000000..de96c5ee --- /dev/null +++ b/packages/opencode/src/session/processor.ts @@ -0,0 +1,372 @@ +import type { ModelsDev } from "@/provider/models" +import { MessageV2 } from "./message-v2" +import { type StreamTextResult, type Tool as AITool, APICallError } from "ai" +import { Log } from "@/util/log" +import { Identifier } from "@/id/id" +import { Session } from "." +import { Agent } from "@/agent/agent" +import { Permission } from "@/permission" +import { Snapshot } from "@/snapshot" +import { SessionSummary } from "./summary" +import { Bus } from "@/bus" +import { SessionRetry } from "./retry" +import { SessionStatus } from "./status" + +export namespace SessionProcessor { + const DOOM_LOOP_THRESHOLD = 3 + const log = Log.create({ service: "session.processor" }) + + export type Info = Awaited> + export type Result = Awaited> + + export function create(input: { + assistantMessage: MessageV2.Assistant + sessionID: string + providerID: string + model: ModelsDev.Model + abort: AbortSignal + }) { + const toolcalls: Record = {} + let snapshot: string | undefined + let blocked = false + let attempt = 0 + + const result = { + get message() { + return input.assistantMessage + }, + partFromToolCall(toolCallID: string) { + return toolcalls[toolCallID] + }, + async process(fn: () => StreamTextResult, never>) { + log.info("process") + while (true) { + try { + let currentText: MessageV2.TextPart | undefined + let reasoningMap: Record = {} + const stream = fn() + + for await (const value of stream.fullStream) { + input.abort.throwIfAborted() + switch (value.type) { + case "start": + SessionStatus.set(input.sessionID, { type: "busy" }) + break + + case "reasoning-start": + if (value.id in reasoningMap) { + continue + } + reasoningMap[value.id] = { + id: Identifier.ascending("part"), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "reasoning", + text: "", + time: { + start: Date.now(), + }, + metadata: value.providerMetadata, + } + break + + case "reasoning-delta": + if (value.id in reasoningMap) { + const part = reasoningMap[value.id] + part.text += value.text + if (value.providerMetadata) part.metadata = value.providerMetadata + if (part.text) await Session.updatePart({ part, delta: value.text }) + } + break + + case "reasoning-end": + if (value.id in reasoningMap) { + const part = reasoningMap[value.id] + part.text = part.text.trimEnd() + + part.time = { + ...part.time, + end: Date.now(), + } + if (value.providerMetadata) part.metadata = value.providerMetadata + await Session.updatePart(part) + delete reasoningMap[value.id] + } + break + + case "tool-input-start": + const part = await Session.updatePart({ + id: toolcalls[value.id]?.id ?? Identifier.ascending("part"), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "tool", + tool: value.toolName, + callID: value.id, + state: { + status: "pending", + input: {}, + raw: "", + }, + }) + toolcalls[value.id] = part as MessageV2.ToolPart + break + + case "tool-input-delta": + break + + case "tool-input-end": + break + + case "tool-call": { + const match = toolcalls[value.toolCallId] + if (match) { + const part = await Session.updatePart({ + ...match, + tool: value.toolName, + state: { + status: "running", + input: value.input, + time: { + start: Date.now(), + }, + }, + metadata: value.providerMetadata, + }) + toolcalls[value.toolCallId] = part as MessageV2.ToolPart + + const parts = await MessageV2.parts(input.assistantMessage.id) + const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD) + if ( + lastThree.length === DOOM_LOOP_THRESHOLD && + lastThree.every( + (p) => + p.type === "tool" && + p.tool === value.toolName && + p.state.status !== "pending" && + JSON.stringify(p.state.input) === JSON.stringify(value.input), + ) + ) { + const permission = await Agent.get(input.assistantMessage.mode).then((x) => x.permission) + if (permission.doom_loop === "ask") { + await Permission.ask({ + type: "doom_loop", + pattern: value.toolName, + sessionID: input.assistantMessage.sessionID, + messageID: input.assistantMessage.id, + callID: value.toolCallId, + title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`, + metadata: { + tool: value.toolName, + input: value.input, + }, + }) + } + } + } + break + } + case "tool-result": { + const match = toolcalls[value.toolCallId] + if (match && match.state.status === "running") { + await Session.updatePart({ + ...match, + state: { + status: "completed", + input: value.input, + output: value.output.output, + metadata: value.output.metadata, + title: value.output.title, + time: { + start: match.state.time.start, + end: Date.now(), + }, + attachments: value.output.attachments, + }, + }) + + delete toolcalls[value.toolCallId] + } + break + } + + case "tool-error": { + const match = toolcalls[value.toolCallId] + if (match && match.state.status === "running") { + await Session.updatePart({ + ...match, + state: { + status: "error", + input: value.input, + error: (value.error as any).toString(), + metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined, + time: { + start: match.state.time.start, + end: Date.now(), + }, + }, + }) + + if (value.error instanceof Permission.RejectedError) { + blocked = true + } + delete toolcalls[value.toolCallId] + } + break + } + case "error": + throw value.error + + case "start-step": + snapshot = await Snapshot.track() + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: input.assistantMessage.id, + sessionID: input.sessionID, + snapshot, + type: "step-start", + }) + break + + case "finish-step": + const usage = Session.getUsage({ + model: input.model, + usage: value.usage, + metadata: value.providerMetadata, + }) + input.assistantMessage.finish = value.finishReason + input.assistantMessage.cost += usage.cost + input.assistantMessage.tokens = usage.tokens + await Session.updatePart({ + id: Identifier.ascending("part"), + reason: value.finishReason, + snapshot: await Snapshot.track(), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "step-finish", + tokens: usage.tokens, + cost: usage.cost, + }) + await Session.updateMessage(input.assistantMessage) + if (snapshot) { + const patch = await Snapshot.patch(snapshot) + if (patch.files.length) { + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: input.assistantMessage.id, + sessionID: input.sessionID, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + } + snapshot = undefined + } + SessionSummary.summarize({ + sessionID: input.sessionID, + messageID: input.assistantMessage.parentID, + }) + break + + case "text-start": + currentText = { + id: Identifier.ascending("part"), + messageID: input.assistantMessage.id, + sessionID: input.assistantMessage.sessionID, + type: "text", + text: "", + time: { + start: Date.now(), + }, + metadata: value.providerMetadata, + } + break + + case "text-delta": + if (currentText) { + currentText.text += value.text + if (value.providerMetadata) currentText.metadata = value.providerMetadata + if (currentText.text) + await Session.updatePart({ + part: currentText, + delta: value.text, + }) + } + break + + case "text-end": + if (currentText) { + currentText.text = currentText.text.trimEnd() + currentText.time = { + start: Date.now(), + end: Date.now(), + } + if (value.providerMetadata) currentText.metadata = value.providerMetadata + await Session.updatePart(currentText) + } + currentText = undefined + break + + case "finish": + input.assistantMessage.time.completed = Date.now() + await Session.updateMessage(input.assistantMessage) + break + + default: + log.info("unhandled", { + ...value, + }) + continue + } + } + } catch (e) { + log.error("process", { + error: e, + }) + const error = MessageV2.fromError(e, { providerID: input.providerID }) + if (error?.name === "APIError" && error.data.isRetryable) { + attempt++ + const delay = SessionRetry.getRetryDelayInMs(error, attempt) + if (delay) { + SessionStatus.set(input.sessionID, { + type: "retry", + attempt, + message: error.data.message, + }) + await SessionRetry.sleep(delay, input.abort).catch(() => {}) + continue + } + } + input.assistantMessage.error = error + Bus.publish(Session.Event.Error, { + sessionID: input.assistantMessage.sessionID, + error: input.assistantMessage.error, + }) + } + const p = await MessageV2.parts(input.assistantMessage.id) + for (const part of p) { + if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") { + await Session.updatePart({ + ...part, + state: { + ...part.state, + status: "error", + error: "Tool execution aborted", + time: { + start: Date.now(), + end: Date.now(), + }, + }, + }) + } + } + input.assistantMessage.time.completed = Date.now() + await Session.updateMessage(input.assistantMessage) + if (blocked) return "stop" + if (input.assistantMessage.error) return "stop" + return "continue" + } + }, + } + return result + } +} diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index 8e27d714..b8f3d112 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -16,22 +16,18 @@ import { type Tool as AITool, tool, wrapLanguageModel, - type StreamTextResult, stepCountIs, jsonSchema, } from "ai" import { SessionCompaction } from "./compaction" -import { SessionLock } from "./lock" import { Instance } from "../project/instance" import { Bus } from "../bus" import { ProviderTransform } from "../provider/transform" import { SystemPrompt } from "./system" import { Plugin } from "../plugin" -import { SessionRetry } from "./retry" import PROMPT_PLAN from "../session/prompt/plan.txt" import BUILD_SWITCH from "../session/prompt/build-switch.txt" -import { ModelsDev } from "../provider/models" import { defer } from "../util/defer" import { mergeDeep, pipe } from "remeda" import { ToolRegistry } from "../tool/registry" @@ -40,24 +36,22 @@ import { MCP } from "../mcp" import { LSP } from "../lsp" import { ReadTool } from "../tool/read" import { ListTool } from "../tool/ls" -import { TaskTool } from "../tool/task" import { FileTime } from "../file/time" -import { Permission } from "../permission" -import { Snapshot } from "../snapshot" import { ulid } from "ulid" import { spawn } from "child_process" import { Command } from "../command" import { $, fileURLToPath } from "bun" import { ConfigMarkdown } from "../config/markdown" import { SessionSummary } from "./summary" -import { Config } from "@/config/config" import { NamedError } from "@/util/error" +import { fn } from "@/util/fn" +import { SessionProcessor } from "./processor" +import { TaskTool } from "@/tool/task" +import { SessionStatus } from "./status" export namespace SessionPrompt { const log = Log.create({ service: "session.prompt" }) export const OUTPUT_TOKEN_MAX = 32_000 - const MAX_RETRIES = 10 - const DOOM_LOOP_THRESHOLD = 3 export const Event = { Idle: Bus.event( @@ -70,32 +64,30 @@ export namespace SessionPrompt { const state = Instance.state( () => { - const queued = new Map< + const data: Record< string, { - messageID: string - callback: (input: MessageV2.WithParts) => void - }[] - >() - const pending = new Set>() - - const track = (promise: Promise) => { - pending.add(promise) - promise.finally(() => pending.delete(promise)) - } - - return { - queued, - pending, - track, - } + abort: AbortController + callbacks: { + resolve(input: MessageV2.WithParts): void + reject(): void + }[] + } + > = {} + return data }, async (current) => { - current.queued.clear() - await Promise.allSettled([...current.pending]) + for (const item of Object.values(current)) { + item.abort.abort() + } }, ) + export function assertNotBusy(sessionID: string) { + const match = state()[sessionID] + if (match) throw new Session.BusyError(sessionID) + } + export const PromptInput = z.object({ sessionID: Identifier.schema("session"), messageID: Identifier.schema("message").optional(), @@ -141,6 +133,16 @@ export namespace SessionPrompt { .meta({ ref: "AgentPartInput", }), + MessageV2.SubtaskPart.omit({ + messageID: true, + sessionID: true, + }) + .partial({ + id: true, + }) + .meta({ + ref: "SubtaskPartInput", + }), ]), ), }) @@ -193,118 +195,310 @@ export namespace SessionPrompt { ) return parts } - export async function prompt(input: PromptInput): Promise { - const l = log.clone().tag("session", input.sessionID) - l.info("prompt") + export const prompt = fn(PromptInput, async (input) => { const session = await Session.get(input.sessionID) await SessionRevert.cleanup(session) - const userMsg = await createUserMessage(input) + await createUserMessage(input) await Session.touch(input.sessionID) - // Early return for context-only messages (no AI inference) - if (input.noReply) { - return userMsg - } + return loop(input.sessionID) + }) - if (isBusy(input.sessionID)) { - return new Promise((resolve) => { - const queue = state().queued.get(input.sessionID) ?? [] - queue.push({ - messageID: userMsg.info.id, - callback: resolve, - }) - state().queued.set(input.sessionID, queue) + function start(sessionID: string) { + const s = state() + if (s[sessionID]) return + const controller = new AbortController() + s[sessionID] = { + abort: controller, + callbacks: [], + } + return controller.signal + } + + export function cancel(sessionID: string) { + log.info("cancel", { sessionID }) + const s = state() + const match = s[sessionID] + if (!match) return + match.abort.abort() + for (const item of match.callbacks) { + item.reject() + } + delete s[sessionID] + SessionStatus.set(sessionID, { type: "idle" }) + return + } + + export const loop = fn(Identifier.schema("session"), async (sessionID) => { + const abort = start(sessionID) + if (!abort) { + return new Promise((resolve, reject) => { + const callbacks = state()[sessionID].callbacks + callbacks.push({ resolve, reject }) }) } - const agent = await Agent.get(input.agent ?? "build") - const model = await resolveModel({ - agent, - model: input.model, - }).then((x) => Provider.getModel(x.providerID, x.modelID)) - using abort = lock(input.sessionID) - - const system = await resolveSystemPrompt({ - providerID: model.providerID, - modelID: model.info.id, - agent, - system: input.system, - }) - - const processor = await createProcessor({ - sessionID: input.sessionID, - model: model.info, - providerID: model.providerID, - agent: agent.name, - system, - abort: abort.signal, - }) - - const tools = await resolveTools({ - agent, - sessionID: input.sessionID, - modelID: model.modelID, - providerID: model.providerID, - tools: input.tools, - processor, - }) - - const params = await Plugin.trigger( - "chat.params", - { - sessionID: input.sessionID, - agent: agent.name, - model: model.info, - provider: await Provider.getProvider(model.providerID), - message: userMsg, - }, - { - temperature: model.info.temperature - ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID)) - : undefined, - topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID), - options: { - ...ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID), - ...model.info.options, - ...agent.options, - }, - }, - ) + using _ = defer(() => cancel(sessionID)) let step = 0 while (true) { - const msgs: MessageV2.WithParts[] = pipe( - await getMessages({ - sessionID: input.sessionID, - model: model.info, - providerID: model.providerID, - signal: abort.signal, - }), - (messages) => insertReminders({ messages, agent }), - ) + log.info("loop", { step, sessionID }) + if (abort.aborted) break + let msgs = await MessageV2.filterCompacted(MessageV2.stream(sessionID)) + + let lastUser: MessageV2.User | undefined + let lastAssistant: MessageV2.Assistant | undefined + let lastFinished: MessageV2.Assistant | undefined + let tasks: (MessageV2.CompactionPart | MessageV2.SubtaskPart)[] = [] + for (let i = msgs.length - 1; i >= 0; i--) { + const msg = msgs[i] + if (!lastUser && msg.info.role === "user") lastUser = msg.info as MessageV2.User + if (!lastAssistant && msg.info.role === "assistant") lastAssistant = msg.info as MessageV2.Assistant + if (!lastFinished && msg.info.role === "assistant" && msg.info.finish) + lastFinished = msg.info as MessageV2.Assistant + if (lastUser && lastFinished) break + const task = msg.parts.filter((part) => part.type === "compaction" || part.type === "subtask") + if (task && !lastFinished) { + tasks.push(...task) + } + } + + if (!lastUser) throw new Error("No user message found in stream. This should never happen.") + if (lastAssistant?.finish && lastAssistant.finish !== "tool-calls" && lastUser.id < lastAssistant.id) { + log.info("exiting loop", { sessionID }) + break + } + step++ - await processor.next(msgs.findLast((m) => m.info.role === "user")?.info.id!) - if (step === 1) { - state().track( - ensureTitle({ - session, - history: msgs, - message: userMsg, + if (step === 1) + ensureTitle({ + session: await Session.get(sessionID), + modelID: lastUser.model.modelID, + providerID: lastUser.model.providerID, + message: msgs.find((m) => m.info.role === "user")!, + history: msgs, + }) + + const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID) + const task = tasks.pop() + + // pending subtask + // TODO: centralize "invoke tool" logic + if (task?.type === "subtask") { + const taskTool = await TaskTool.init() + const assistantMessage = (await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "assistant", + parentID: lastUser.id, + sessionID, + mode: task.agent, + path: { + cwd: Instance.directory, + root: Instance.worktree, + }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.modelID, + providerID: model.providerID, + time: { + created: Date.now(), + }, + })) as MessageV2.Assistant + let part = (await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMessage.id, + sessionID: assistantMessage.sessionID, + type: "tool", + callID: ulid(), + tool: TaskTool.id, + state: { + status: "running", + input: { + prompt: task.prompt, + description: task.description, + subagent_type: task.agent, + }, + time: { + start: Date.now(), + }, + }, + })) as MessageV2.ToolPart + const result = await taskTool + .execute( + { + prompt: task.prompt, + description: task.description, + subagent_type: task.agent, + }, + { + agent: task.agent, + messageID: assistantMessage.id, + sessionID: sessionID, + abort, + async metadata(input) { + await Session.updatePart({ + ...part, + type: "tool", + state: { + ...part.state, + ...input, + }, + } satisfies MessageV2.ToolPart) + }, + }, + ) + .catch(() => {}) + assistantMessage.finish = "tool-calls" + assistantMessage.time.completed = Date.now() + await Session.updateMessage(assistantMessage) + if (result && part.state.status === "running") { + await Session.updatePart({ + ...part, + state: { + status: "completed", + input: part.state.input, + title: result.title, + metadata: result.metadata, + output: result.output, + attachments: result.attachments, + time: { + ...part.state.time, + end: Date.now(), + }, + }, + } satisfies MessageV2.ToolPart) + } + if (!result) { + await Session.updatePart({ + ...part, + state: { + status: "error", + error: "Tool execution failed", + time: { + start: part.state.status === "running" ? part.state.time.start : Date.now(), + end: Date.now(), + }, + metadata: part.metadata, + input: part.state.input, + }, + } satisfies MessageV2.ToolPart) + } + continue + } + + // pending compaction + if (task?.type === "compaction") { + await SessionCompaction.process({ + messages: msgs, + parentID: lastUser.id, + abort, + model: { providerID: model.providerID, - modelID: model.info.id, - }), - ) + modelID: model.modelID, + }, + sessionID, + }) + continue + } + + // context overflow, needs compaction + if ( + lastFinished && + lastFinished.summary !== true && + SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model: model.info }) + ) { + await SessionCompaction.create({ + sessionID, + model: lastUser.model, + }) + continue + } + + // normal processing + const agent = await Agent.get(lastUser.agent) + msgs = insertReminders({ + messages: msgs, + agent, + }) + const processor = SessionProcessor.create({ + assistantMessage: (await Session.updateMessage({ + id: Identifier.ascending("message"), + parentID: lastUser.id, + role: "assistant", + mode: agent.name, + path: { + cwd: Instance.directory, + root: Instance.worktree, + }, + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: model.modelID, + providerID: model.providerID, + time: { + created: Date.now(), + }, + sessionID, + })) as MessageV2.Assistant, + sessionID: sessionID, + model: model.info, + providerID: model.providerID, + abort, + }) + const system = await resolveSystemPrompt({ + providerID: model.providerID, + modelID: model.info.id, + agent, + system: lastUser.system, + }) + const tools = await resolveTools({ + agent, + sessionID, + model: lastUser.model, + tools: lastUser.tools, + processor, + }) + const params = await Plugin.trigger( + "chat.params", + { + sessionID: sessionID, + agent: lastUser.agent, + model: model.info, + provider: await Provider.getProvider(model.providerID), + message: lastUser, + }, + { + temperature: model.info.temperature + ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID)) + : undefined, + topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID), + options: { + ...ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID), + ...model.info.options, + ...agent.options, + }, + }, + ) + + if (step === 1) { SessionSummary.summarize({ - sessionID: input.sessionID, - messageID: userMsg.info.id, + sessionID: sessionID, + messageID: lastUser.id, }) } - await using _ = defer(async () => { - await processor.end() - }) - const doStream = () => + + const result = await processor.process(() => streamText({ onError(error) { log.error("stream error", { @@ -335,8 +529,8 @@ export namespace SessionPrompt { headers: { ...(model.providerID === "opencode" ? { - "x-opencode-session": input.sessionID, - "x-opencode-request": userMsg.info.id, + "x-opencode-session": sessionID, + "x-opencode-request": lastUser.id, } : undefined), ...model.info.headers, @@ -345,12 +539,12 @@ export namespace SessionPrompt { maxRetries: 0, activeTools: Object.keys(tools).filter((x) => x !== "invalid"), maxOutputTokens: ProviderTransform.maxOutputTokens( - model.npm ?? "", + model.providerID, params.options, model.info.limit.output, OUTPUT_TOKEN_MAX, ), - abortSignal: abort.signal, + abortSignal: abort, providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options), stopWhen: stepCountIs(1), temperature: params.temperature, @@ -393,142 +587,22 @@ export namespace SessionPrompt { }, ], }), - }) - - let stream = doStream() - const cfg = await Config.get() - const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES - let result = await processor.process(stream, { - count: 0, - max: maxRetries, - }) - if (result.shouldRetry) { - const start = Date.now() - for (let retry = 1; retry < maxRetries; retry++) { - const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry") - - if (lastRetryPart) { - const delayMs = SessionRetry.getBoundedDelay({ - error: lastRetryPart.error, - attempt: retry, - startTime: start, - }) - if (!delayMs) { - break - } - - log.info("retrying with backoff", { - attempt: retry, - delayMs, - elapsed: Date.now() - start, - }) - - const stop = await SessionRetry.sleep(delayMs, abort.signal) - .then(() => false) - .catch((error) => { - let err = error - if (error instanceof DOMException && error.name === "AbortError") { - err = new MessageV2.AbortedError( - { message: error.message }, - { - cause: error, - }, - ).toObject() - } - result.info.error = err - Bus.publish(Session.Event.Error, { - sessionID: result.info.sessionID, - error: result.info.error, - }) - return true - }) - - if (stop) break - } - - stream = doStream() - result = await processor.process(stream, { - count: retry, - max: maxRetries, - }) - if (!result.shouldRetry) { - break - } - } - } - await processor.end() - - const queued = state().queued.get(input.sessionID) ?? [] - - if (!result.blocked && !result.info.error) { - if ((await stream.finishReason) === "tool-calls") { - continue - } - - const unprocessed = queued.filter((x) => x.messageID > result.info.id) - if (unprocessed.length) { - continue - } - } - for (const item of queued) { - item.callback(result) - } - state().queued.delete(input.sessionID) - SessionCompaction.prune(input) - return result - } - } - - async function getMessages(input: { - sessionID: string - model: ModelsDev.Model - providerID: string - signal: AbortSignal - }) { - let msgs = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID)) - const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant") - if ( - lastAssistant?.info.role === "assistant" && - SessionCompaction.isOverflow({ - tokens: lastAssistant.info.tokens, - model: input.model, - }) - ) { - const summaryMsg = await SessionCompaction.run({ - sessionID: input.sessionID, - providerID: input.providerID, - modelID: input.model.id, - signal: input.signal, - }) - const resumeMsgID = Identifier.ascending("message") - const resumeMsg = { - info: await Session.updateMessage({ - id: resumeMsgID, - role: "user", - sessionID: input.sessionID, - time: { - created: Date.now(), - }, }), - parts: [ - await Session.updatePart({ - type: "text", - sessionID: input.sessionID, - messageID: resumeMsgID, - id: Identifier.ascending("part"), - text: "Use the above summary generated from your last session to resume from where you left off.", - time: { - start: Date.now(), - end: Date.now(), - }, - synthetic: true, - }), - ], - } - msgs = [summaryMsg, resumeMsg] + ) + if (result === "stop") break + continue } - return msgs - } + SessionCompaction.prune({ sessionID }) + for await (const item of MessageV2.stream(sessionID)) { + if (item.info.role === "user") continue + const queued = state()[sessionID]?.callbacks ?? [] + for (const q of queued) { + q.resolve(item) + } + return item + } + throw new Error("Impossible") + }) async function resolveModel(input: { model: PromptInput["model"]; agent: Agent.Info }) { if (input.model) { @@ -564,21 +638,27 @@ export namespace SessionPrompt { async function resolveTools(input: { agent: Agent.Info + model: { + providerID: string + modelID: string + } sessionID: string - modelID: string - providerID: string tools?: Record - processor: Processor + processor: SessionProcessor.Info }) { const tools: Record = {} const enabledTools = pipe( input.agent.tools, - mergeDeep(await ToolRegistry.enabled(input.providerID, input.modelID, input.agent)), + mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)), mergeDeep(input.tools ?? {}), ) - for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) { + for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) { if (Wildcard.all(item.id, enabledTools) === false) continue - const schema = ProviderTransform.schema(input.providerID, input.modelID, z.toJSONSchema(item.parameters)) + const schema = ProviderTransform.schema( + input.model.providerID, + input.model.modelID, + z.toJSONSchema(item.parameters), + ) tools[item.id] = tool({ id: item.id as any, description: item.description, @@ -600,10 +680,7 @@ export namespace SessionPrompt { abort: options.abortSignal!, messageID: input.processor.message.id, callID: options.toolCallId, - extra: { - modelID: input.modelID, - providerID: input.providerID, - }, + extra: input.model, agent: input.agent.name, metadata: async (val) => { const match = input.processor.partFromToolCall(options.toolCallId) @@ -710,6 +787,7 @@ export namespace SessionPrompt { } async function createUserMessage(input: PromptInput) { + const agent = await Agent.get(input.agent ?? "build") const info: MessageV2.Info = { id: input.messageID ?? Identifier.ascending("message"), role: "user", @@ -717,6 +795,13 @@ export namespace SessionPrompt { time: { created: Date.now(), }, + tools: input.tools, + system: input.system, + agent: agent.name, + model: await resolveModel({ + model: input.model, + agent, + }), } const parts = await Promise.all( @@ -1007,428 +1092,6 @@ export namespace SessionPrompt { return input.messages } - export type Processor = Awaited> - async function createProcessor(input: { - sessionID: string - providerID: string - model: ModelsDev.Model - system: string[] - agent: string - abort: AbortSignal - }) { - const toolcalls: Record = {} - let snapshot: string | undefined - let blocked = false - - async function createMessage(parentID: string) { - const msg: MessageV2.Info = { - id: Identifier.ascending("message"), - parentID, - role: "assistant", - mode: input.agent, - path: { - cwd: Instance.directory, - root: Instance.worktree, - }, - cost: 0, - tokens: { - input: 0, - output: 0, - reasoning: 0, - cache: { read: 0, write: 0 }, - }, - modelID: input.model.id, - providerID: input.providerID, - time: { - created: Date.now(), - }, - sessionID: input.sessionID, - } - await Session.updateMessage(msg) - return msg - } - - let assistantMsg: MessageV2.Assistant | undefined - - const result = { - async end() { - if (assistantMsg) { - assistantMsg.time.completed = Date.now() - await Session.updateMessage(assistantMsg) - assistantMsg = undefined - } - }, - async next(parentID: string) { - if (assistantMsg) { - throw new Error("end previous assistant message first") - } - assistantMsg = await createMessage(parentID) - return assistantMsg - }, - get message() { - if (!assistantMsg) throw new Error("call next() first before accessing message") - return assistantMsg - }, - partFromToolCall(toolCallID: string) { - return toolcalls[toolCallID] - }, - async process(stream: StreamTextResult, never>, retries: { count: number; max: number }) { - log.info("process") - if (!assistantMsg) throw new Error("call next() first before processing") - let shouldRetry = false - try { - let currentText: MessageV2.TextPart | undefined - let reasoningMap: Record = {} - - for await (const value of stream.fullStream) { - input.abort.throwIfAborted() - switch (value.type) { - case "start": - break - - case "reasoning-start": - if (value.id in reasoningMap) { - continue - } - reasoningMap[value.id] = { - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "reasoning", - text: "", - time: { - start: Date.now(), - }, - metadata: value.providerMetadata, - } - break - - case "reasoning-delta": - if (value.id in reasoningMap) { - const part = reasoningMap[value.id] - part.text += value.text - if (value.providerMetadata) part.metadata = value.providerMetadata - if (part.text) await Session.updatePart({ part, delta: value.text }) - } - break - - case "reasoning-end": - if (value.id in reasoningMap) { - const part = reasoningMap[value.id] - part.text = part.text.trimEnd() - - part.time = { - ...part.time, - end: Date.now(), - } - if (value.providerMetadata) part.metadata = value.providerMetadata - await Session.updatePart(part) - delete reasoningMap[value.id] - } - break - - case "tool-input-start": - const part = await Session.updatePart({ - id: toolcalls[value.id]?.id ?? Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "tool", - tool: value.toolName, - callID: value.id, - state: { - status: "pending", - input: {}, - raw: "", - }, - }) - toolcalls[value.id] = part as MessageV2.ToolPart - break - - case "tool-input-delta": - break - - case "tool-input-end": - break - - case "tool-call": { - const match = toolcalls[value.toolCallId] - if (match) { - const part = await Session.updatePart({ - ...match, - tool: value.toolName, - state: { - status: "running", - input: value.input, - time: { - start: Date.now(), - }, - }, - metadata: value.providerMetadata, - }) - toolcalls[value.toolCallId] = part as MessageV2.ToolPart - - const parts = await MessageV2.parts(assistantMsg.id) - const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD) - if ( - lastThree.length === DOOM_LOOP_THRESHOLD && - lastThree.every( - (p) => - p.type === "tool" && - p.tool === value.toolName && - p.state.status !== "pending" && - JSON.stringify(p.state.input) === JSON.stringify(value.input), - ) - ) { - const permission = await Agent.get(input.agent).then((x) => x.permission) - if (permission.doom_loop === "ask") { - await Permission.ask({ - type: "doom_loop", - pattern: value.toolName, - sessionID: assistantMsg.sessionID, - messageID: assistantMsg.id, - callID: value.toolCallId, - title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`, - metadata: { - tool: value.toolName, - input: value.input, - }, - }) - } - } - } - break - } - case "tool-result": { - const match = toolcalls[value.toolCallId] - if (match && match.state.status === "running") { - await Session.updatePart({ - ...match, - state: { - status: "completed", - input: value.input, - output: value.output.output, - metadata: value.output.metadata, - title: value.output.title, - time: { - start: match.state.time.start, - end: Date.now(), - }, - attachments: value.output.attachments, - }, - }) - - delete toolcalls[value.toolCallId] - } - break - } - - case "tool-error": { - const match = toolcalls[value.toolCallId] - if (match && match.state.status === "running") { - await Session.updatePart({ - ...match, - state: { - status: "error", - input: value.input, - error: (value.error as any).toString(), - metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined, - time: { - start: match.state.time.start, - end: Date.now(), - }, - }, - }) - - if (value.error instanceof Permission.RejectedError) { - blocked = true - } - delete toolcalls[value.toolCallId] - } - break - } - case "error": - throw value.error - - case "start-step": - snapshot = await Snapshot.track() - await Session.updatePart({ - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - snapshot, - type: "step-start", - }) - break - - case "finish-step": - const usage = Session.getUsage({ - model: input.model, - usage: value.usage, - metadata: value.providerMetadata, - }) - assistantMsg.cost += usage.cost - assistantMsg.tokens = usage.tokens - await Session.updatePart({ - id: Identifier.ascending("part"), - reason: value.finishReason, - snapshot: await Snapshot.track(), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "step-finish", - tokens: usage.tokens, - cost: usage.cost, - }) - await Session.updateMessage(assistantMsg) - if (snapshot) { - const patch = await Snapshot.patch(snapshot) - if (patch.files.length) { - await Session.updatePart({ - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "patch", - hash: patch.hash, - files: patch.files, - }) - } - snapshot = undefined - } - SessionSummary.summarize({ - sessionID: input.sessionID, - messageID: assistantMsg.parentID, - }) - break - - case "text-start": - currentText = { - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "text", - text: "", - time: { - start: Date.now(), - }, - metadata: value.providerMetadata, - } - break - - case "text-delta": - if (currentText) { - currentText.text += value.text - if (value.providerMetadata) currentText.metadata = value.providerMetadata - if (currentText.text) - await Session.updatePart({ - part: currentText, - delta: value.text, - }) - } - break - - case "text-end": - if (currentText) { - currentText.text = currentText.text.trimEnd() - currentText.time = { - start: Date.now(), - end: Date.now(), - } - if (value.providerMetadata) currentText.metadata = value.providerMetadata - await Session.updatePart(currentText) - } - currentText = undefined - break - - case "finish": - assistantMsg.time.completed = Date.now() - await Session.updateMessage(assistantMsg) - break - - default: - log.info("unhandled", { - ...value, - }) - continue - } - } - } catch (e) { - log.error("process", { - error: e, - }) - const error = MessageV2.fromError(e, { providerID: input.providerID }) - if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) { - shouldRetry = true - await Session.updatePart({ - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: assistantMsg.sessionID, - type: "retry", - attempt: retries.count + 1, - time: { - created: Date.now(), - }, - error, - }) - } else { - assistantMsg.error = error - Bus.publish(Session.Event.Error, { - sessionID: assistantMsg.sessionID, - error: assistantMsg.error, - }) - } - } - const p = await MessageV2.parts(assistantMsg.id) - for (const part of p) { - if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") { - await Session.updatePart({ - ...part, - state: { - ...part.state, - status: "error", - error: "Tool execution aborted", - time: { - start: Date.now(), - end: Date.now(), - }, - }, - }) - } - } - if (!shouldRetry) { - assistantMsg.time.completed = Date.now() - } - await Session.updateMessage(assistantMsg) - return { info: assistantMsg, parts: p, blocked, shouldRetry } - }, - } - return result - } - - function isBusy(sessionID: string) { - return SessionLock.isLocked(sessionID) - } - - function lock(sessionID: string) { - const handle = SessionLock.acquire({ - sessionID, - }) - log.info("locking", { sessionID }) - return { - signal: handle.signal, - abort: handle.abort, - async [Symbol.dispose]() { - handle[Symbol.dispose]() - log.info("unlocking", { sessionID }) - - const session = await Session.get(sessionID) - if (session.parentID) return - - Bus.publish(Event.Idle, { - sessionID, - }) - }, - } - } - export const ShellInput = z.object({ sessionID: Identifier.schema("session"), agent: z.string(), @@ -1436,11 +1099,12 @@ export namespace SessionPrompt { }) export type ShellInput = z.infer export async function shell(input: ShellInput) { - using abort = lock(input.sessionID) const session = await Session.get(input.sessionID) if (session.revert) { SessionRevert.cleanup(session) } + const agent = await Agent.get(input.agent) + const model = await resolveModel({ agent, model: undefined }) const userMsg: MessageV2.User = { id: Identifier.ascending("message"), sessionID: input.sessionID, @@ -1448,6 +1112,11 @@ export namespace SessionPrompt { created: Date.now(), }, role: "user", + agent: input.agent, + model: { + providerID: model.providerID, + modelID: model.modelID, + }, } await Session.updateMessage(userMsg) const userPart: MessageV2.Part = { @@ -1480,8 +1149,8 @@ export namespace SessionPrompt { reasoning: 0, cache: { read: 0, write: 0 }, }, - modelID: "", - providerID: "", + modelID: model.modelID, + providerID: model.providerID, } await Session.updateMessage(msg) const part: MessageV2.Part = { @@ -1544,7 +1213,6 @@ export namespace SessionPrompt { const proc = spawn(shell, args, { cwd: Instance.directory, - signal: abort.signal, detached: true, stdio: ["ignore", "pipe", "pipe"], env: { @@ -1553,11 +1221,6 @@ export namespace SessionPrompt { }, }) - abort.signal.addEventListener("abort", () => { - if (!proc.pid) return - process.kill(-proc.pid) - }) - let output = "" proc.stdout?.on("data", (chunk) => { @@ -1669,8 +1332,6 @@ export namespace SessionPrompt { } template = template.trim() - const parts = await resolvePromptParts(template) - const model = await (async () => { if (command.model) { return Provider.parseModel(command.model) @@ -1686,128 +1347,28 @@ export namespace SessionPrompt { } return await Provider.defaultModel() })() - const agent = await Agent.get(agentName) - let result: MessageV2.WithParts - if ((agent.mode === "subagent" && command.subtask !== false) || command.subtask === true) { - using abort = lock(input.sessionID) + const parts = + (agent.mode === "subagent" && command.subtask !== false) || command.subtask === true + ? [ + { + type: "subtask" as const, + agent: agent.name, + description: command.description ?? "", + // TODO: how can we make task tool accept a more complex input? + prompt: await resolvePromptParts(template).then((x) => x.find((y) => y.type === "text")?.text ?? ""), + }, + ] + : await resolvePromptParts(template) - const userMsg: MessageV2.User = { - id: Identifier.ascending("message"), - sessionID: input.sessionID, - time: { - created: Date.now(), - }, - role: "user", - } - await Session.updateMessage(userMsg) - const userPart: MessageV2.Part = { - type: "text", - id: Identifier.ascending("part"), - messageID: userMsg.id, - sessionID: input.sessionID, - text: "The following tool was executed by the user", - synthetic: true, - } - await Session.updatePart(userPart) - - const assistantMsg: MessageV2.Assistant = { - id: Identifier.ascending("message"), - sessionID: input.sessionID, - parentID: userMsg.id, - mode: agentName, - cost: 0, - path: { - cwd: Instance.directory, - root: Instance.worktree, - }, - time: { - created: Date.now(), - }, - role: "assistant", - tokens: { - input: 0, - output: 0, - reasoning: 0, - cache: { read: 0, write: 0 }, - }, - modelID: model.modelID, - providerID: model.providerID, - } - await Session.updateMessage(assistantMsg) - - const args = { - description: "Consulting " + agent.name, - subagent_type: agent.name, - prompt: template, - } - const toolPart: MessageV2.ToolPart = { - type: "tool", - id: Identifier.ascending("part"), - messageID: assistantMsg.id, - sessionID: input.sessionID, - tool: "task", - callID: ulid(), - state: { - status: "running", - time: { - start: Date.now(), - }, - input: { - description: args.description, - subagent_type: args.subagent_type, - // truncate prompt to preserve context - prompt: args.prompt.length > 100 ? args.prompt.substring(0, 97) + "..." : args.prompt, - }, - }, - } - await Session.updatePart(toolPart) - - const taskResult = await TaskTool.init().then((t) => - t.execute(args, { - sessionID: input.sessionID, - abort: abort.signal, - agent: agent.name, - messageID: assistantMsg.id, - extra: {}, - metadata: async (metadata) => { - if (toolPart.state.status === "running") { - toolPart.state.metadata = metadata.metadata - toolPart.state.title = metadata.title - await Session.updatePart(toolPart) - } - }, - }), - ) - - assistantMsg.time.completed = Date.now() - await Session.updateMessage(assistantMsg) - if (toolPart.state.status === "running") { - toolPart.state = { - status: "completed", - time: { - ...toolPart.state.time, - end: Date.now(), - }, - input: toolPart.state.input, - title: "", - metadata: taskResult.metadata, - output: taskResult.output, - } - await Session.updatePart(toolPart) - } - - result = { info: assistantMsg, parts: [toolPart] } - } else { - result = await prompt({ - sessionID: input.sessionID, - messageID: input.messageID, - model, - agent: agentName, - parts, - }) - } + const result = (await prompt({ + sessionID: input.sessionID, + messageID: input.messageID, + model, + agent: agentName, + parts, + })) as MessageV2.WithParts Bus.publish(Command.Event.Executed, { name: input.command, @@ -1819,6 +1380,7 @@ export namespace SessionPrompt { return result } + // TODO: wire this back up async function ensureTitle(input: { session: Session.Info message: MessageV2.WithParts @@ -1871,6 +1433,11 @@ export namespace SessionPrompt { time: { created: Date.now(), }, + agent: input.message.info.role === "user" ? input.message.info.agent : "build", + model: { + providerID: input.providerID, + modelID: input.modelID, + }, }, parts: input.message.parts, }, diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index dbf81edc..35c7b9a6 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -7,7 +7,7 @@ import { Log } from "../util/log" import { splitWhen } from "remeda" import { Storage } from "../storage/storage" import { Bus } from "../bus" -import { SessionLock } from "./lock" +import { SessionPrompt } from "./prompt" export namespace SessionRevert { const log = Log.create({ service: "session.revert" }) @@ -20,11 +20,7 @@ export namespace SessionRevert { export type RevertInput = z.infer export async function revert(input: RevertInput) { - SessionLock.assertUnlocked(input.sessionID) - using _ = SessionLock.acquire({ - sessionID: input.sessionID, - }) - + SessionPrompt.assertNotBusy(input.sessionID) const all = await Session.messages({ sessionID: input.sessionID }) let lastUser: MessageV2.User | undefined const session = await Session.get(input.sessionID) @@ -70,10 +66,7 @@ export namespace SessionRevert { export async function unrevert(input: { sessionID: string }) { log.info("unreverting", input) - SessionLock.assertUnlocked(input.sessionID) - using _ = SessionLock.acquire({ - sessionID: input.sessionID, - }) + SessionPrompt.assertNotBusy(input.sessionID) const session = await Session.get(input.sessionID) if (!session.revert) return session if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) diff --git a/packages/opencode/src/session/status.ts b/packages/opencode/src/session/status.ts new file mode 100644 index 00000000..ecac222f --- /dev/null +++ b/packages/opencode/src/session/status.ts @@ -0,0 +1,63 @@ +import { Bus } from "@/bus" +import { Instance } from "@/project/instance" +import z from "zod" + +export namespace SessionStatus { + export const Info = z + .union([ + z.object({ + type: z.literal("idle"), + }), + z.object({ + type: z.literal("retry"), + attempt: z.number(), + message: z.string(), + }), + z.object({ + type: z.literal("busy"), + }), + ]) + .meta({ + ref: "SessionStatus", + }) + export type Info = z.infer + + export const Event = { + Status: Bus.event( + "session.status", + z.object({ + sessionID: z.string(), + status: Info, + }), + ), + } + + const state = Instance.state(() => { + const data: Record = {} + return data + }) + + export function get(sessionID: string) { + return ( + state()[sessionID] ?? { + type: "idle", + } + ) + } + + export function list() { + return Object.values(state()) + } + + export function set(sessionID: string, status: Info) { + Bus.publish(Event.Status, { + sessionID, + status, + }) + if (status.type === "idle") { + delete state()[sessionID] + return + } + state()[sessionID] = status + } +} diff --git a/packages/opencode/src/session/system.ts b/packages/opencode/src/session/system.ts index 7d44bbda..aaccccc4 100644 --- a/packages/opencode/src/session/system.ts +++ b/packages/opencode/src/session/system.ts @@ -43,7 +43,7 @@ export namespace SystemPrompt { ` Platform: ${process.platform}`, ` Today's date: ${new Date().toDateString()}`, ``, - ``, + ``, ` ${ project.vcs === "git" ? await Ripgrep.tree({ @@ -52,7 +52,7 @@ export namespace SystemPrompt { }) : "" }`, - ``, + ``, ].join("\n"), ] } diff --git a/packages/opencode/src/tool/task.ts b/packages/opencode/src/tool/task.ts index a5369d33..8f27f570 100644 --- a/packages/opencode/src/tool/task.ts +++ b/packages/opencode/src/tool/task.ts @@ -6,8 +6,8 @@ import { Bus } from "../bus" import { MessageV2 } from "../session/message-v2" import { Identifier } from "../id/id" import { Agent } from "../agent/agent" -import { SessionLock } from "../session/lock" import { SessionPrompt } from "../session/prompt" +import { defer } from "@/util/defer" export const TaskTool = Tool.define("task", async () => { const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary")) @@ -62,9 +62,11 @@ export const TaskTool = Tool.define("task", async () => { providerID: msg.info.providerID, } - ctx.abort.addEventListener("abort", () => { - SessionLock.abort(session.id) - }) + function cancel() { + SessionPrompt.cancel(session.id) + } + ctx.abort.addEventListener("abort", cancel) + using _ = defer(() => ctx.abort.removeEventListener("abort", cancel)) const promptParts = await SessionPrompt.resolvePromptParts(params.prompt) const result = await SessionPrompt.prompt({ messageID, diff --git a/packages/sdk/js/src/gen/sdk.gen.ts b/packages/sdk/js/src/gen/sdk.gen.ts index e1c62204..04dc29cc 100644 --- a/packages/sdk/js/src/gen/sdk.gen.ts +++ b/packages/sdk/js/src/gen/sdk.gen.ts @@ -26,6 +26,9 @@ import type { SessionCreateData, SessionCreateResponses, SessionCreateErrors, + SessionStatusData, + SessionStatusResponses, + SessionStatusErrors, SessionDeleteData, SessionDeleteResponses, SessionDeleteErrors, @@ -306,6 +309,16 @@ class Session extends _HeyApiClient { }) } + /** + * Get session status + */ + public status(options?: Options) { + return (options?.client ?? this._client).get({ + url: "/session/status", + ...options, + }) + } + /** * Delete a session and all its data */ diff --git a/packages/sdk/js/src/gen/types.gen.ts b/packages/sdk/js/src/gen/types.gen.ts index ea43490a..2309f8b7 100644 --- a/packages/sdk/js/src/gen/types.gen.ts +++ b/packages/sdk/js/src/gen/types.gen.ts @@ -42,6 +42,15 @@ export type UserMessage = { body?: string diffs: Array } + agent: string + model: { + providerID: string + modelID: string + } + system?: string + tools?: { + [key: string]: boolean + } } export type ProviderAuthError = { @@ -114,6 +123,7 @@ export type AssistantMessage = { write: number } } + finish?: string } export type Message = UserMessage | AssistantMessage @@ -348,6 +358,13 @@ export type RetryPart = { } } +export type CompactionPart = { + id: string + sessionID: string + messageID: string + type: "compaction" +} + export type Part = | TextPart | ReasoningPart @@ -359,6 +376,7 @@ export type Part = | PatchPart | AgentPart | RetryPart + | CompactionPart export type EventMessagePartUpdated = { type: "message.part.updated" @@ -377,13 +395,6 @@ export type EventMessagePartRemoved = { } } -export type EventSessionCompacted = { - type: "session.compacted" - properties: { - sessionID: string - } -} - export type Permission = { id: string type: string @@ -414,6 +425,13 @@ export type EventPermissionReplied = { } } +export type EventSessionCompacted = { + type: "session.compacted" + properties: { + sessionID: string + } +} + export type EventFileEdited = { type: "file.edited" properties: { @@ -458,6 +476,27 @@ export type EventCommandExecuted = { } } +export type SessionStatus = + | { + type: "idle" + } + | { + type: "retry" + attempt: number + message: string + } + | { + type: "busy" + } + +export type EventSessionStatus = { + type: "session.status" + properties: { + sessionID: string + status: SessionStatus + } +} + export type EventSessionIdle = { type: "session.idle" properties: { @@ -598,12 +637,13 @@ export type Event = | EventMessageRemoved | EventMessagePartUpdated | EventMessagePartRemoved - | EventSessionCompacted | EventPermissionUpdated | EventPermissionReplied + | EventSessionCompacted | EventFileEdited | EventTodoUpdated | EventCommandExecuted + | EventSessionStatus | EventSessionIdle | EventSessionCreated | EventSessionUpdated @@ -1613,6 +1653,35 @@ export type SessionCreateResponses = { export type SessionCreateResponse = SessionCreateResponses[keyof SessionCreateResponses] +export type SessionStatusData = { + body?: never + path?: never + query?: { + directory?: string + } + url: "/session/status" +} + +export type SessionStatusErrors = { + /** + * Bad request + */ + 400: BadRequestError +} + +export type SessionStatusError = SessionStatusErrors[keyof SessionStatusErrors] + +export type SessionStatusResponses = { + /** + * Get session status + */ + 200: { + [key: string]: SessionStatus + } +} + +export type SessionStatusResponse = SessionStatusResponses[keyof SessionStatusResponses] + export type SessionDeleteData = { body?: never path: {