Refactor agent loop (#4412)

This commit is contained in:
Dax
2025-11-17 10:57:18 -05:00
committed by GitHub
parent 9fd43ec616
commit a1214fff2e
22 changed files with 1297 additions and 1324 deletions

1
.gitignore vendored
View File

@@ -13,3 +13,4 @@ dist
.turbo
**/.serena
.serena/
refs

View File

@@ -1,5 +1,6 @@
---
description: Git commit and push
subtask: true
---
commit and push

View File

@@ -1,4 +0,0 @@
{
"$schema": "https://opencode.ai/config.json",
"plugin": ["opencode-openai-codex-auth"]
}

11
.opencode/opencode.jsonc Normal file
View File

@@ -0,0 +1,11 @@
{
"$schema": "https://opencode.ai/config.json",
"plugin": ["opencode-openai-codex-auth"],
"provider": {
"opencode": {
"options": {
// "baseURL": "http://localhost:8080"
},
},
},
}

0
a.out Normal file
View File

View File

@@ -2,6 +2,7 @@ import { EOL } from "os"
import { File } from "../../../file"
import { bootstrap } from "../../bootstrap"
import { cmd } from "../cmd"
import { Ripgrep } from "@/file/ripgrep"
const FileSearchCommand = cmd({
command: "search <query>",
@@ -62,6 +63,20 @@ const FileListCommand = cmd({
},
})
const FileTreeCommand = cmd({
command: "tree [dir]",
builder: (yargs) =>
yargs.positional("dir", {
type: "string",
description: "Directory to tree",
default: process.cwd(),
}),
async handler(args) {
const files = await Ripgrep.tree({ cwd: args.dir, limit: 200 })
console.log(files)
},
})
export const FileCommand = cmd({
command: "file",
builder: (yargs) =>
@@ -70,6 +85,7 @@ export const FileCommand = cmd({
.command(FileStatusCommand)
.command(FileListCommand)
.command(FileSearchCommand)
.command(FileTreeCommand)
.demandCommand(),
async handler() {},
})

View File

@@ -11,6 +11,7 @@ import type {
LspStatus,
McpStatus,
FormatterStatus,
SessionStatus,
} from "@opencode-ai/sdk"
import { createStore, produce, reconcile } from "solid-js/store"
import { useSDK } from "@tui/context/sdk"
@@ -33,6 +34,9 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
}
config: Config
session: Session[]
session_status: {
[sessionID: string]: SessionStatus
}
session_diff: {
[sessionID: string]: Snapshot.FileDiff[]
}
@@ -58,6 +62,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
command: [],
provider: [],
session: [],
session_status: {},
session_diff: {},
todo: {},
message: {},
@@ -140,6 +145,12 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
}),
)
break
case "session.status": {
setStore("session_status", event.properties.sessionID, event.properties.status)
break
}
case "message.updated": {
const messages = store.message[event.properties.info.sessionID]
if (!messages) {
@@ -240,6 +251,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
sdk.client.lsp.status().then((x) => setStore("lsp", x.data!)),
sdk.client.mcp.status().then((x) => setStore("mcp", x.data!)),
sdk.client.formatter.status().then((x) => setStore("formatter", x.data!)),
sdk.client.session.status().then((x) => setStore("session_status", x.data!)),
]).then(() => {
setStore("status", "complete")
})

View File

@@ -20,7 +20,6 @@ import { useTheme } from "@tui/context/theme"
import {
BoxRenderable,
ScrollBoxRenderable,
TextAttributes,
addDefaultParsers,
MacOSScrollAccel,
type ScrollAcceleration,
@@ -65,7 +64,6 @@ import { Editor } from "../../util/editor"
import { Global } from "@/global"
import fs from "fs/promises"
import stripAnsi from "strip-ansi"
import { LSP } from "@/lsp/index.ts"
addDefaultParsers(parsers.parsers)
@@ -101,7 +99,12 @@ export function Session() {
const permissions = createMemo(() => sync.data.permission[route.sessionID] ?? [])
const pending = createMemo(() => {
return messages().findLast((x) => x.role === "assistant" && !x.time?.completed)?.id
return messages().findLast((x) => x.role === "assistant" && !x.time.completed)?.id
})
const lastUserMessage = createMemo(() => {
const p = pending()
return messages().findLast((x) => x.role === "user" && (!p || x.id < p)) as UserMessage
})
const dimensions = useTerminalDimensions()
@@ -801,7 +804,7 @@ export function Session() {
</Match>
<Match when={message.role === "assistant"}>
<AssistantMessage
last={index() === messages().length - 1}
last={pending() === message.id}
message={message as AssistantMessage}
parts={sync.data.part[message.id] ?? []}
/>
@@ -856,64 +859,84 @@ function UserMessage(props: {
const queued = createMemo(() => props.pending && props.message.id > props.pending)
const color = createMemo(() => (queued() ? theme.accent : theme.secondary))
const compaction = createMemo(() => props.parts.find((x) => x.type === "compaction"))
return (
<Show when={text()}>
<box
id={props.message.id}
onMouseOver={() => {
setHover(true)
}}
onMouseOut={() => {
setHover(false)
}}
onMouseUp={props.onMouseUp}
border={["left"]}
paddingTop={1}
paddingBottom={1}
paddingLeft={2}
marginTop={props.index === 0 ? 0 : 1}
backgroundColor={hover() ? theme.backgroundElement : theme.backgroundPanel}
customBorderChars={SplitBorder.customBorderChars}
borderColor={color()}
flexShrink={0}
>
<text fg={theme.text}>{text()?.text}</text>
<Show when={files().length}>
<box flexDirection="row" paddingBottom={1} paddingTop={1} gap={1} flexWrap="wrap">
<For each={files()}>
{(file) => {
const bg = createMemo(() => {
if (file.mime.startsWith("image/")) return theme.accent
if (file.mime === "application/pdf") return theme.primary
return theme.secondary
})
return (
<text fg={theme.text}>
<span style={{ bg: bg(), fg: theme.background }}> {MIME_BADGE[file.mime] ?? file.mime} </span>
<span style={{ bg: theme.backgroundElement, fg: theme.textMuted }}> {file.filename} </span>
</text>
)
}}
</For>
</box>
</Show>
<text fg={theme.text}>
{sync.data.config.username ?? "You"}{" "}
<Show
when={queued()}
fallback={<span style={{ fg: theme.textMuted }}>({Locale.time(props.message.time.created)})</span>}
>
<span style={{ bg: theme.accent, fg: theme.backgroundPanel, bold: true }}> QUEUED </span>
<>
<Show when={text()}>
<box
id={props.message.id}
onMouseOver={() => {
setHover(true)
}}
onMouseOut={() => {
setHover(false)
}}
onMouseUp={props.onMouseUp}
border={["left"]}
paddingTop={1}
paddingBottom={1}
paddingLeft={2}
marginTop={props.index === 0 ? 0 : 1}
backgroundColor={hover() ? theme.backgroundElement : theme.backgroundPanel}
customBorderChars={SplitBorder.customBorderChars}
borderColor={color()}
flexShrink={0}
>
<text fg={theme.text}>{text()?.text}</text>
<Show when={files().length}>
<box flexDirection="row" paddingBottom={1} paddingTop={1} gap={1} flexWrap="wrap">
<For each={files()}>
{(file) => {
const bg = createMemo(() => {
if (file.mime.startsWith("image/")) return theme.accent
if (file.mime === "application/pdf") return theme.primary
return theme.secondary
})
return (
<text fg={theme.text}>
<span style={{ bg: bg(), fg: theme.background }}> {MIME_BADGE[file.mime] ?? file.mime} </span>
<span style={{ bg: theme.backgroundElement, fg: theme.textMuted }}> {file.filename} </span>
</text>
)
}}
</For>
</box>
</Show>
</text>
</box>
</Show>
<text fg={theme.text}>
{sync.data.config.username ?? "You"}{" "}
<Show
when={queued()}
fallback={<span style={{ fg: theme.textMuted }}>({Locale.time(props.message.time.created)})</span>}
>
<span style={{ bg: theme.accent, fg: theme.backgroundPanel, bold: true }}> QUEUED </span>
</Show>
</text>
</box>
</Show>
<Show when={compaction()}>
<box
marginTop={1}
border={["top"]}
title=" Compaction "
titleAlignment="center"
borderColor={theme.borderActive}
/>
</Show>
</>
)
}
function AssistantMessage(props: { message: AssistantMessage; parts: Part[]; last: boolean }) {
const local = useLocal()
const { theme } = useTheme()
const sync = useSync()
const status = createMemo(
() =>
sync.data.session_status[props.message.sessionID] ?? {
type: "idle",
},
)
return (
<>
<For each={props.parts}>
@@ -945,23 +968,15 @@ function AssistantMessage(props: { message: AssistantMessage; parts: Part[]; las
<text fg={theme.textMuted}>{props.message.error?.data.message}</text>
</box>
</Show>
<Show
when={
!props.message.time.completed ||
(props.last && props.parts.some((item) => item.type === "step-finish" && item.reason === "tool-calls"))
}
>
<box
paddingLeft={2}
marginTop={1}
flexDirection="row"
gap={1}
border={["left"]}
customBorderChars={SplitBorder.customBorderChars}
borderColor={theme.backgroundElement}
>
<Show when={props.last && status().type !== "idle"}>
<box paddingLeft={3} flexDirection="row" gap={1} marginTop={1}>
<text fg={local.agent.color(props.message.mode)}>{Locale.titlecase(props.message.mode)}</text>
<Shimmer text={`${props.message.modelID}`} color={theme.text} />
<Shimmer text={props.message.modelID} color={theme.text} />
<Show when={status().type === "retry"}>
<text fg={theme.error}>
{(status() as any).message} [attempt #{(status() as any).attempt}]
</text>
</Show>
</box>
</Show>
<Show

View File

@@ -8,8 +8,10 @@ import { lazy } from "../util/lazy"
import { $ } from "bun"
import { ZipReader, BlobReader, BlobWriter } from "@zip.js/zip.js"
import { Log } from "@/util/log"
export namespace Ripgrep {
const log = Log.create({ service: "ripgrep" })
const Stats = z.object({
elapsed: z.object({
secs: z.number(),
@@ -254,6 +256,7 @@ export namespace Ripgrep {
}
export async function tree(input: { cwd: string; limit?: number }) {
log.info("tree", input)
const files = await Array.fromAsync(Ripgrep.files({ cwd: input.cwd }))
interface Node {
path: string[]

View File

@@ -27,7 +27,6 @@ import { Global } from "../global"
import { ProjectRoute } from "./project"
import { ToolRegistry } from "../tool/registry"
import { zodToJsonSchema } from "zod-to-json-schema"
import { SessionLock } from "../session/lock"
import { SessionPrompt } from "../session/prompt"
import { SessionCompaction } from "../session/compaction"
import { SessionRevert } from "../session/revert"
@@ -41,6 +40,7 @@ import { TuiEvent } from "@/cli/cmd/tui/event"
import { Snapshot } from "@/snapshot"
import { SessionSummary } from "@/session/summary"
import { GlobalBus } from "@/bus/global"
import { SessionStatus } from "@/session/status"
const ERRORS = {
400: {
@@ -367,6 +367,28 @@ export namespace Server {
return c.json(sessions)
},
)
.get(
"/session/status",
describeRoute({
description: "Get session status",
operationId: "session.status",
responses: {
200: {
description: "Get session status",
content: {
"application/json": {
schema: resolver(z.record(z.string(), SessionStatus.Info)),
},
},
},
...errors(400),
},
}),
async (c) => {
const result = SessionStatus.list()
return c.json(result)
},
)
.get(
"/session/:id",
describeRoute({
@@ -637,7 +659,8 @@ export namespace Server {
}),
),
async (c) => {
return c.json(SessionLock.abort(c.req.valid("param").id))
SessionPrompt.cancel(c.req.valid("param").id)
return c.json(true)
},
)
.post(
@@ -771,7 +794,14 @@ export namespace Server {
async (c) => {
const id = c.req.valid("param").id
const body = c.req.valid("json")
await SessionCompaction.run({ ...body, sessionID: id })
await SessionCompaction.create({
sessionID: id,
model: {
providerID: body.providerID,
modelID: body.modelID,
},
})
await SessionPrompt.loop(id)
return c.json(true)
},
)

View File

@@ -1,9 +1,8 @@
import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai"
import { streamText, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
import { Provider } from "../provider/provider"
import { defer } from "../util/defer"
import { MessageV2 } from "./message-v2"
import { SystemPrompt } from "./system"
import { Bus } from "../bus"
@@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt"
import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { ProviderTransform } from "@/provider/transform"
import { SessionRetry } from "./retry"
import { Config } from "@/config/config"
import { SessionProcessor } from "./processor"
import { fn } from "@/util/fn"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -42,7 +40,6 @@ export namespace SessionCompaction {
export const PRUNE_MINIMUM = 20_000
export const PRUNE_PROTECT = 40_000
const MAX_RETRIES = 10
// goes backwards through parts until there are 40_000 tokens worth of tool
// calls. then erases output of previous tool calls. idea is to throw away old
@@ -87,38 +84,29 @@ export namespace SessionCompaction {
}
}
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
const signal = input.signal ?? lock!.signal
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = Date.now()
})
await using _ = defer(async () => {
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = undefined
})
})
const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID))
const model = await Provider.getModel(input.providerID, input.modelID)
const system = [
...SystemPrompt.summarize(model.providerID),
...(await SystemPrompt.environment()),
...(await SystemPrompt.custom()),
]
export async function process(input: {
parentID: string
messages: MessageV2.WithParts[]
sessionID: string
model: {
providerID: string
modelID: string
}
abort: AbortSignal
}) {
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
const system = [...SystemPrompt.summarize(model.providerID)]
const msg = (await Session.updateMessage({
id: Identifier.ascending("message"),
role: "assistant",
parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
parentID: input.parentID,
sessionID: input.sessionID,
mode: "build",
summary: true,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
summary: true,
cost: 0,
tokens: {
output: 0,
@@ -126,37 +114,27 @@ export namespace SessionCompaction {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.modelID,
modelID: input.model.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
})) as MessageV2.Assistant
const part = (await Session.updatePart({
type: "text",
const processor = SessionProcessor.create({
assistantMessage: msg,
sessionID: input.sessionID,
messageID: msg.id,
id: Identifier.ascending("part"),
text: "",
time: {
start: Date.now(),
},
})) as MessageV2.TextPart
const doStream = () =>
providerID: input.model.providerID,
model: model.info,
abort: input.abort,
})
const result = await processor.process(() =>
streamText({
// set to 0, we handle loop
maxRetries: 0,
model: model.language,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
headers: model.info.headers,
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
abortSignal: input.abort,
tools: model.info.tool_call ? {} : undefined,
messages: [
...system.map(
@@ -165,7 +143,7 @@ export namespace SessionCompaction {
content: x,
}),
),
...MessageV2.toModelMessage(toSummarize),
...MessageV2.toModelMessage(input.messages),
{
role: "user",
content: [
@@ -176,168 +154,60 @@ export namespace SessionCompaction {
],
},
],
})
// TODO: reduce duplication between compaction.ts & prompt.ts
const process = async (
stream: StreamTextResult<Record<string, AITool>, never>,
retries: { count: number; max: number },
) => {
let shouldRetry = false
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text)
await Session.updatePart({
part,
delta: value.text,
})
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
msg.error = error
Bus.publish(Session.Event.Error, {
sessionID: msg.sessionID,
error: msg.error,
})
}
}
const parts = await MessageV2.parts(msg.id)
return {
info: msg,
parts,
shouldRetry,
}
}
let stream = doStream()
const cfg = await Config.get()
const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
let result = await process(stream, {
count: 0,
max: maxRetries,
})
if (result.shouldRetry) {
const start = Date.now()
for (let retry = 1; retry < maxRetries; retry++) {
const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getBoundedDelay({
error: lastRetryPart.error,
attempt: retry,
startTime: start,
})
if (!delayMs) {
break
}
log.info("retrying with backoff", {
attempt: retry,
delayMs,
elapsed: Date.now() - start,
})
const stop = await SessionRetry.sleep(delayMs, signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await process(stream, {
count: retry,
max: maxRetries,
})
if (!result.shouldRetry) {
break
}
}
}
msg.time.completed = Date.now()
if (
!msg.error ||
(MessageV2.AbortedError.isInstance(msg.error) &&
result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0))
) {
msg.summary = true
Bus.publish(Event.Compacted, {
}),
)
if (result === "continue") {
const continueMsg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
agent: "build",
model: input.model,
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: continueMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: "Continue if you have next steps",
time: {
start: Date.now(),
end: Date.now(),
},
})
}
await Session.updateMessage(msg)
return {
info: msg,
parts: result.parts,
}
return "continue"
}
export const create = fn(
z.object({
sessionID: Identifier.schema("session"),
model: z.object({
providerID: z.string(),
modelID: z.string(),
}),
}),
async (input) => {
const msg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
model: input.model,
sessionID: input.sessionID,
agent: "build",
time: {
created: Date.now(),
},
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "compaction",
})
},
)
}

View File

@@ -1,7 +1,6 @@
import { Decimal } from "decimal.js"
import z from "zod"
import { type LanguageModelUsage, type ProviderMetadata } from "ai"
import { Bus } from "../bus"
import { Config } from "../config/config"
import { Flag } from "../flag/flag"

View File

@@ -1,97 +0,0 @@
import z from "zod"
import { Instance } from "../project/instance"
import { Log } from "../util/log"
import { NamedError } from "../util/error"
export namespace SessionLock {
const log = Log.create({ service: "session.lock" })
export const LockedError = NamedError.create(
"SessionLockedError",
z.object({
sessionID: z.string(),
message: z.string(),
}),
)
type LockState = {
controller: AbortController
created: number
}
const state = Instance.state(
() => {
const locks = new Map<string, LockState>()
return {
locks,
}
},
async (current) => {
for (const [sessionID, lock] of current.locks) {
log.info("force abort", { sessionID })
lock.controller.abort()
}
current.locks.clear()
},
)
function get(sessionID: string) {
return state().locks.get(sessionID)
}
function unset(input: { sessionID: string; controller: AbortController }) {
const lock = get(input.sessionID)
if (!lock) return false
if (lock.controller !== input.controller) return false
state().locks.delete(input.sessionID)
return true
}
export function acquire(input: { sessionID: string }) {
const lock = get(input.sessionID)
if (lock) {
throw new LockedError({
sessionID: input.sessionID,
message: `Session ${input.sessionID} is locked`,
})
}
const controller = new AbortController()
state().locks.set(input.sessionID, {
controller,
created: Date.now(),
})
log.info("locked", { sessionID: input.sessionID })
return {
signal: controller.signal,
abort() {
controller.abort()
unset({ sessionID: input.sessionID, controller })
},
async [Symbol.dispose]() {
const removed = unset({ sessionID: input.sessionID, controller })
if (removed) {
log.info("unlocked", { sessionID: input.sessionID })
}
},
}
}
export function abort(sessionID: string) {
const lock = get(sessionID)
if (!lock) return false
log.info("abort", { sessionID })
lock.controller.abort()
state().locks.delete(sessionID)
return true
}
export function isLocked(sessionID: string) {
return get(sessionID) !== undefined
}
export function assertUnlocked(sessionID: string) {
const lock = get(sessionID)
if (!lock) return
throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` })
}
}

View File

@@ -142,6 +142,21 @@ export namespace MessageV2 {
})
export type AgentPart = z.infer<typeof AgentPart>
export const CompactionPart = PartBase.extend({
type: z.literal("compaction"),
}).meta({
ref: "CompactionPart",
})
export type CompactionPart = z.infer<typeof CompactionPart>
export const SubtaskPart = PartBase.extend({
type: z.literal("subtask"),
prompt: z.string(),
description: z.string(),
agent: z.string(),
})
export type SubtaskPart = z.infer<typeof SubtaskPart>
export const RetryPart = PartBase.extend({
type: z.literal("retry"),
attempt: z.number(),
@@ -277,6 +292,13 @@ export namespace MessageV2 {
diffs: Snapshot.FileDiff.array(),
})
.optional(),
agent: z.string(),
model: z.object({
providerID: z.string(),
modelID: z.string(),
}),
system: z.string().optional(),
tools: z.record(z.string(), z.boolean()).optional(),
}).meta({
ref: "UserMessage",
})
@@ -285,6 +307,7 @@ export namespace MessageV2 {
export const Part = z
.discriminatedUnion("type", [
TextPart,
SubtaskPart,
ReasoningPart,
FilePart,
ToolPart,
@@ -294,6 +317,7 @@ export namespace MessageV2 {
PatchPart,
AgentPart,
RetryPart,
CompactionPart,
])
.meta({
ref: "Part",
@@ -334,6 +358,7 @@ export namespace MessageV2 {
write: z.number(),
}),
}),
finish: z.string().optional(),
}).meta({
ref: "AssistantMessage",
})
@@ -482,6 +507,11 @@ export namespace MessageV2 {
time: {
created: v1.metadata.time.created,
},
agent: "build",
model: {
providerID: "opencode",
modelID: "opencode",
},
}
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
@@ -529,107 +559,107 @@ export namespace MessageV2 {
if (msg.parts.length === 0) continue
if (msg.info.role === "user") {
result.push({
const userMessage: UIMessage = {
id: msg.info.id,
role: "user",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
return [
{
type: "text",
text: part.text,
},
]
// text/plain and directory files are converted into text parts, ignore them
if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory")
return [
{
type: "file",
url: part.url,
mediaType: part.mime,
filename: part.filename,
},
]
return []
}),
})
parts: [],
}
result.push(userMessage)
for (const part of msg.parts) {
if (part.type === "text")
userMessage.parts.push({
type: "text",
text: part.text,
})
// text/plain and directory files are converted into text parts, ignore them
if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory")
userMessage.parts.push({
type: "file",
url: part.url,
mediaType: part.mime,
filename: part.filename,
})
if (part.type === "compaction") {
userMessage.parts.push({
type: "text",
text: "What did we do so far?",
})
}
if (part.type === "subtask") {
userMessage.parts.push({
type: "text",
text: "The following tool was executed by the user",
})
}
}
}
if (msg.info.role === "assistant") {
result.push({
const assistantMessage: UIMessage = {
id: msg.info.id,
role: "assistant",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
return [
{
type: "text",
text: part.text,
providerMetadata: part.metadata,
},
]
if (part.type === "step-start")
return [
{
type: "step-start",
},
]
if (part.type === "tool") {
if (part.state.status === "completed") {
if (part.state.attachments?.length) {
result.push({
id: Identifier.ascending("message"),
role: "user",
parts: [
{
type: "text",
text: `Tool ${part.tool} returned an attachment:`,
},
...part.state.attachments.map((attachment) => ({
type: "file" as const,
url: attachment.url,
mediaType: attachment.mime,
filename: attachment.filename,
})),
],
})
}
return [
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.callID,
input: part.state.input,
output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output,
callProviderMetadata: part.metadata,
},
]
parts: [],
}
result.push(assistantMessage)
for (const part of msg.parts) {
if (part.type === "text")
assistantMessage.parts.push({
type: "text",
text: part.text,
providerMetadata: part.metadata,
})
if (part.type === "step-start")
assistantMessage.parts.push({
type: "step-start",
})
if (part.type === "tool") {
if (part.state.status === "completed") {
if (part.state.attachments?.length) {
result.push({
id: Identifier.ascending("message"),
role: "user",
parts: [
{
type: "text",
text: `Tool ${part.tool} returned an attachment:`,
},
...part.state.attachments.map((attachment) => ({
type: "file" as const,
url: attachment.url,
mediaType: attachment.mime,
filename: attachment.filename,
})),
],
})
}
if (part.state.status === "error")
return [
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
callProviderMetadata: part.metadata,
},
]
assistantMessage.parts.push({
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.callID,
input: part.state.input,
output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output,
callProviderMetadata: part.metadata,
})
}
if (part.type === "reasoning") {
return [
{
type: "reasoning",
text: part.text,
providerMetadata: part.metadata,
},
]
}
return []
}),
})
if (part.state.status === "error")
assistantMessage.parts.push({
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
callProviderMetadata: part.metadata,
})
}
if (part.type === "reasoning") {
assistantMessage.parts.push({
type: "reasoning",
text: part.text,
providerMetadata: part.metadata,
})
}
}
}
}
@@ -671,9 +701,16 @@ export namespace MessageV2 {
export async function filterCompacted(stream: AsyncIterable<MessageV2.WithParts>) {
const result = [] as MessageV2.WithParts[]
const completed = new Set<string>()
for await (const msg of stream) {
result.push(msg)
if (msg.info.role === "assistant" && msg.info.summary === true) break
if (
msg.info.role === "user" &&
completed.has(msg.info.id) &&
msg.parts.some((part) => part.type === "compaction")
)
break
if (msg.info.role === "assistant" && msg.info.summary && msg.info.finish) completed.add(msg.info.parentID)
}
result.reverse()
return result

View File

@@ -0,0 +1,372 @@
import type { ModelsDev } from "@/provider/models"
import { MessageV2 } from "./message-v2"
import { type StreamTextResult, type Tool as AITool, APICallError } from "ai"
import { Log } from "@/util/log"
import { Identifier } from "@/id/id"
import { Session } from "."
import { Agent } from "@/agent/agent"
import { Permission } from "@/permission"
import { Snapshot } from "@/snapshot"
import { SessionSummary } from "./summary"
import { Bus } from "@/bus"
import { SessionRetry } from "./retry"
import { SessionStatus } from "./status"
export namespace SessionProcessor {
const DOOM_LOOP_THRESHOLD = 3
const log = Log.create({ service: "session.processor" })
export type Info = Awaited<ReturnType<typeof create>>
export type Result = Awaited<ReturnType<Info["process"]>>
export function create(input: {
assistantMessage: MessageV2.Assistant
sessionID: string
providerID: string
model: ModelsDev.Model
abort: AbortSignal
}) {
const toolcalls: Record<string, MessageV2.ToolPart> = {}
let snapshot: string | undefined
let blocked = false
let attempt = 0
const result = {
get message() {
return input.assistantMessage
},
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(fn: () => StreamTextResult<Record<string, AITool>, never>) {
log.info("process")
while (true) {
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
const stream = fn()
for await (const value of stream.fullStream) {
input.abort.throwIfAborted()
switch (value.type) {
case "start":
SessionStatus.set(input.sessionID, { type: "busy" })
break
case "reasoning-start":
if (value.id in reasoningMap) {
continue
}
reasoningMap[value.id] = {
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "reasoning",
text: "",
time: {
start: Date.now(),
},
metadata: value.providerMetadata,
}
break
case "reasoning-delta":
if (value.id in reasoningMap) {
const part = reasoningMap[value.id]
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text) await Session.updatePart({ part, delta: value.text })
}
break
case "reasoning-end":
if (value.id in reasoningMap) {
const part = reasoningMap[value.id]
part.text = part.text.trimEnd()
part.time = {
...part.time,
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
delete reasoningMap[value.id]
}
break
case "tool-input-start":
const part = await Session.updatePart({
id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "tool",
tool: value.toolName,
callID: value.id,
state: {
status: "pending",
input: {},
raw: "",
},
})
toolcalls[value.id] = part as MessageV2.ToolPart
break
case "tool-input-delta":
break
case "tool-input-end":
break
case "tool-call": {
const match = toolcalls[value.toolCallId]
if (match) {
const part = await Session.updatePart({
...match,
tool: value.toolName,
state: {
status: "running",
input: value.input,
time: {
start: Date.now(),
},
},
metadata: value.providerMetadata,
})
toolcalls[value.toolCallId] = part as MessageV2.ToolPart
const parts = await MessageV2.parts(input.assistantMessage.id)
const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD)
if (
lastThree.length === DOOM_LOOP_THRESHOLD &&
lastThree.every(
(p) =>
p.type === "tool" &&
p.tool === value.toolName &&
p.state.status !== "pending" &&
JSON.stringify(p.state.input) === JSON.stringify(value.input),
)
) {
const permission = await Agent.get(input.assistantMessage.mode).then((x) => x.permission)
if (permission.doom_loop === "ask") {
await Permission.ask({
type: "doom_loop",
pattern: value.toolName,
sessionID: input.assistantMessage.sessionID,
messageID: input.assistantMessage.id,
callID: value.toolCallId,
title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`,
metadata: {
tool: value.toolName,
input: value.input,
},
})
}
}
}
break
}
case "tool-result": {
const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") {
await Session.updatePart({
...match,
state: {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
},
attachments: value.output.attachments,
},
})
delete toolcalls[value.toolCallId]
}
break
}
case "tool-error": {
const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") {
await Session.updatePart({
...match,
state: {
status: "error",
input: value.input,
error: (value.error as any).toString(),
metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
time: {
start: match.state.time.start,
end: Date.now(),
},
},
})
if (value.error instanceof Permission.RejectedError) {
blocked = true
}
delete toolcalls[value.toolCallId]
}
break
}
case "error":
throw value.error
case "start-step":
snapshot = await Snapshot.track()
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.sessionID,
snapshot,
type: "step-start",
})
break
case "finish-step":
const usage = Session.getUsage({
model: input.model,
usage: value.usage,
metadata: value.providerMetadata,
})
input.assistantMessage.finish = value.finishReason
input.assistantMessage.cost += usage.cost
input.assistantMessage.tokens = usage.tokens
await Session.updatePart({
id: Identifier.ascending("part"),
reason: value.finishReason,
snapshot: await Snapshot.track(),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "step-finish",
tokens: usage.tokens,
cost: usage.cost,
})
await Session.updateMessage(input.assistantMessage)
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
if (patch.files.length) {
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
snapshot = undefined
}
SessionSummary.summarize({
sessionID: input.sessionID,
messageID: input.assistantMessage.parentID,
})
break
case "text-start":
currentText = {
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "text",
text: "",
time: {
start: Date.now(),
},
metadata: value.providerMetadata,
}
break
case "text-delta":
if (currentText) {
currentText.text += value.text
if (value.providerMetadata) currentText.metadata = value.providerMetadata
if (currentText.text)
await Session.updatePart({
part: currentText,
delta: value.text,
})
}
break
case "text-end":
if (currentText) {
currentText.text = currentText.text.trimEnd()
currentText.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) currentText.metadata = value.providerMetadata
await Session.updatePart(currentText)
}
currentText = undefined
break
case "finish":
input.assistantMessage.time.completed = Date.now()
await Session.updateMessage(input.assistantMessage)
break
default:
log.info("unhandled", {
...value,
})
continue
}
}
} catch (e) {
log.error("process", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (error?.name === "APIError" && error.data.isRetryable) {
attempt++
const delay = SessionRetry.getRetryDelayInMs(error, attempt)
if (delay) {
SessionStatus.set(input.sessionID, {
type: "retry",
attempt,
message: error.data.message,
})
await SessionRetry.sleep(delay, input.abort).catch(() => {})
continue
}
}
input.assistantMessage.error = error
Bus.publish(Session.Event.Error, {
sessionID: input.assistantMessage.sessionID,
error: input.assistantMessage.error,
})
}
const p = await MessageV2.parts(input.assistantMessage.id)
for (const part of p) {
if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
await Session.updatePart({
...part,
state: {
...part.state,
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
},
},
})
}
}
input.assistantMessage.time.completed = Date.now()
await Session.updateMessage(input.assistantMessage)
if (blocked) return "stop"
if (input.assistantMessage.error) return "stop"
return "continue"
}
},
}
return result
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ import { Log } from "../util/log"
import { splitWhen } from "remeda"
import { Storage } from "../storage/storage"
import { Bus } from "../bus"
import { SessionLock } from "./lock"
import { SessionPrompt } from "./prompt"
export namespace SessionRevert {
const log = Log.create({ service: "session.revert" })
@@ -20,11 +20,7 @@ export namespace SessionRevert {
export type RevertInput = z.infer<typeof RevertInput>
export async function revert(input: RevertInput) {
SessionLock.assertUnlocked(input.sessionID)
using _ = SessionLock.acquire({
sessionID: input.sessionID,
})
SessionPrompt.assertNotBusy(input.sessionID)
const all = await Session.messages({ sessionID: input.sessionID })
let lastUser: MessageV2.User | undefined
const session = await Session.get(input.sessionID)
@@ -70,10 +66,7 @@ export namespace SessionRevert {
export async function unrevert(input: { sessionID: string }) {
log.info("unreverting", input)
SessionLock.assertUnlocked(input.sessionID)
using _ = SessionLock.acquire({
sessionID: input.sessionID,
})
SessionPrompt.assertNotBusy(input.sessionID)
const session = await Session.get(input.sessionID)
if (!session.revert) return session
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)

View File

@@ -0,0 +1,63 @@
import { Bus } from "@/bus"
import { Instance } from "@/project/instance"
import z from "zod"
export namespace SessionStatus {
export const Info = z
.union([
z.object({
type: z.literal("idle"),
}),
z.object({
type: z.literal("retry"),
attempt: z.number(),
message: z.string(),
}),
z.object({
type: z.literal("busy"),
}),
])
.meta({
ref: "SessionStatus",
})
export type Info = z.infer<typeof Info>
export const Event = {
Status: Bus.event(
"session.status",
z.object({
sessionID: z.string(),
status: Info,
}),
),
}
const state = Instance.state(() => {
const data: Record<string, Info> = {}
return data
})
export function get(sessionID: string) {
return (
state()[sessionID] ?? {
type: "idle",
}
)
}
export function list() {
return Object.values(state())
}
export function set(sessionID: string, status: Info) {
Bus.publish(Event.Status, {
sessionID,
status,
})
if (status.type === "idle") {
delete state()[sessionID]
return
}
state()[sessionID] = status
}
}

View File

@@ -43,7 +43,7 @@ export namespace SystemPrompt {
` Platform: ${process.platform}`,
` Today's date: ${new Date().toDateString()}`,
`</env>`,
`<project>`,
`<files>`,
` ${
project.vcs === "git"
? await Ripgrep.tree({
@@ -52,7 +52,7 @@ export namespace SystemPrompt {
})
: ""
}`,
`</project>`,
`</files>`,
].join("\n"),
]
}

View File

@@ -6,8 +6,8 @@ import { Bus } from "../bus"
import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id"
import { Agent } from "../agent/agent"
import { SessionLock } from "../session/lock"
import { SessionPrompt } from "../session/prompt"
import { defer } from "@/util/defer"
export const TaskTool = Tool.define("task", async () => {
const agents = await Agent.list().then((x) => x.filter((a) => a.mode !== "primary"))
@@ -62,9 +62,11 @@ export const TaskTool = Tool.define("task", async () => {
providerID: msg.info.providerID,
}
ctx.abort.addEventListener("abort", () => {
SessionLock.abort(session.id)
})
function cancel() {
SessionPrompt.cancel(session.id)
}
ctx.abort.addEventListener("abort", cancel)
using _ = defer(() => ctx.abort.removeEventListener("abort", cancel))
const promptParts = await SessionPrompt.resolvePromptParts(params.prompt)
const result = await SessionPrompt.prompt({
messageID,

View File

@@ -26,6 +26,9 @@ import type {
SessionCreateData,
SessionCreateResponses,
SessionCreateErrors,
SessionStatusData,
SessionStatusResponses,
SessionStatusErrors,
SessionDeleteData,
SessionDeleteResponses,
SessionDeleteErrors,
@@ -306,6 +309,16 @@ class Session extends _HeyApiClient {
})
}
/**
* Get session status
*/
public status<ThrowOnError extends boolean = false>(options?: Options<SessionStatusData, ThrowOnError>) {
return (options?.client ?? this._client).get<SessionStatusResponses, SessionStatusErrors, ThrowOnError>({
url: "/session/status",
...options,
})
}
/**
* Delete a session and all its data
*/

View File

@@ -42,6 +42,15 @@ export type UserMessage = {
body?: string
diffs: Array<FileDiff>
}
agent: string
model: {
providerID: string
modelID: string
}
system?: string
tools?: {
[key: string]: boolean
}
}
export type ProviderAuthError = {
@@ -114,6 +123,7 @@ export type AssistantMessage = {
write: number
}
}
finish?: string
}
export type Message = UserMessage | AssistantMessage
@@ -348,6 +358,13 @@ export type RetryPart = {
}
}
export type CompactionPart = {
id: string
sessionID: string
messageID: string
type: "compaction"
}
export type Part =
| TextPart
| ReasoningPart
@@ -359,6 +376,7 @@ export type Part =
| PatchPart
| AgentPart
| RetryPart
| CompactionPart
export type EventMessagePartUpdated = {
type: "message.part.updated"
@@ -377,13 +395,6 @@ export type EventMessagePartRemoved = {
}
}
export type EventSessionCompacted = {
type: "session.compacted"
properties: {
sessionID: string
}
}
export type Permission = {
id: string
type: string
@@ -414,6 +425,13 @@ export type EventPermissionReplied = {
}
}
export type EventSessionCompacted = {
type: "session.compacted"
properties: {
sessionID: string
}
}
export type EventFileEdited = {
type: "file.edited"
properties: {
@@ -458,6 +476,27 @@ export type EventCommandExecuted = {
}
}
export type SessionStatus =
| {
type: "idle"
}
| {
type: "retry"
attempt: number
message: string
}
| {
type: "busy"
}
export type EventSessionStatus = {
type: "session.status"
properties: {
sessionID: string
status: SessionStatus
}
}
export type EventSessionIdle = {
type: "session.idle"
properties: {
@@ -598,12 +637,13 @@ export type Event =
| EventMessageRemoved
| EventMessagePartUpdated
| EventMessagePartRemoved
| EventSessionCompacted
| EventPermissionUpdated
| EventPermissionReplied
| EventSessionCompacted
| EventFileEdited
| EventTodoUpdated
| EventCommandExecuted
| EventSessionStatus
| EventSessionIdle
| EventSessionCreated
| EventSessionUpdated
@@ -1613,6 +1653,35 @@ export type SessionCreateResponses = {
export type SessionCreateResponse = SessionCreateResponses[keyof SessionCreateResponses]
export type SessionStatusData = {
body?: never
path?: never
query?: {
directory?: string
}
url: "/session/status"
}
export type SessionStatusErrors = {
/**
* Bad request
*/
400: BadRequestError
}
export type SessionStatusError = SessionStatusErrors[keyof SessionStatusErrors]
export type SessionStatusResponses = {
/**
* Get session status
*/
200: {
[key: string]: SessionStatus
}
}
export type SessionStatusResponse = SessionStatusResponses[keyof SessionStatusResponses]
export type SessionDeleteData = {
body?: never
path: {