Part data model (#950)

This commit is contained in:
Dax
2025-07-13 17:22:11 -04:00
committed by GitHub
parent 736396fc70
commit 90d6c4ab41
27 changed files with 1447 additions and 965 deletions

View File

@@ -33,6 +33,7 @@
"@openauthjs/openauth": "0.4.3",
"@standard-schema/spec": "1.0.0",
"ai": "catalog:",
"cli-markdown": "3.5.1",
"decimal.js": "10.5.0",
"diff": "8.0.2",
"env-paths": "3.0.0",

View File

@@ -9,6 +9,7 @@ import { Config } from "../../config/config"
import { bootstrap } from "../bootstrap"
import { MessageV2 } from "../../session/message-v2"
import { Mode } from "../../session/mode"
import { Identifier } from "../../id/id"
const TOOL: Record<string, [string, string]> = {
todowrite: ["Todo", UI.Style.TEXT_WARNING_BOLD],
@@ -83,14 +84,9 @@ export const RunCommand = cmd({
return
}
const isPiped = !process.stdout.isTTY
UI.empty()
UI.println(UI.logo())
UI.empty()
const displayMessage = message.length > 300 ? message.slice(0, 300) + "..." : message
UI.println(UI.Style.TEXT_NORMAL_BOLD + "> ", displayMessage)
UI.empty()
const cfg = await Config.get()
if (cfg.share === "auto" || Flag.OPENCODE_AUTO_SHARE || args.share) {
@@ -120,8 +116,10 @@ export const RunCommand = cmd({
)
}
let text = ""
Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
if (evt.properties.sessionID !== session.id) return
if (evt.properties.part.sessionID !== session.id) return
if (evt.properties.part.messageID === messageID) return
const part = evt.properties.part
if (part.type === "tool" && part.state.status === "completed") {
@@ -130,13 +128,15 @@ export const RunCommand = cmd({
}
if (part.type === "text") {
if (part.text.includes("\n")) {
text = part.text
if (part.time?.end) {
UI.empty()
UI.println(part.text)
UI.println(UI.markdown(text))
UI.empty()
text = ""
return
}
printEvent(UI.Style.TEXT_NORMAL_BOLD, "Text", part.text)
}
})
@@ -156,8 +156,10 @@ export const RunCommand = cmd({
const mode = args.mode ? await Mode.get(args.mode) : await Mode.list().then((x) => x[0])
const messageID = Identifier.ascending("message")
const result = await Session.chat({
sessionID: session.id,
messageID,
...(mode.model
? mode.model
: {
@@ -167,15 +169,19 @@ export const RunCommand = cmd({
mode: mode.name,
parts: [
{
id: Identifier.ascending("part"),
sessionID: session.id,
messageID: messageID,
type: "text",
text: message,
},
],
})
const isPiped = !process.stdout.isTTY
if (isPiped) {
const match = result.parts.findLast((x) => x.type === "text")
if (match) process.stdout.write(match.text)
if (match) process.stdout.write(UI.markdown(match.text))
if (errorMsg) process.stdout.write(errorMsg)
}
UI.empty()

View File

@@ -1,7 +1,4 @@
import { Storage } from "../../storage/storage"
import { MessageV2 } from "../../session/message-v2"
import { cmd } from "./cmd"
import { bootstrap } from "../bootstrap"
interface SessionStats {
totalSessions: number
@@ -27,87 +24,10 @@ interface SessionStats {
export const StatsCommand = cmd({
command: "stats",
handler: async () => {
await bootstrap({ cwd: process.cwd() }, async () => {
const stats: SessionStats = {
totalSessions: 0,
totalMessages: 0,
totalCost: 0,
totalTokens: {
input: 0,
output: 0,
reasoning: 0,
cache: {
read: 0,
write: 0,
},
},
toolUsage: {},
dateRange: {
earliest: Date.now(),
latest: 0,
},
days: 0,
costPerDay: 0,
}
const sessionMap = new Map<string, number>()
try {
for await (const messagePath of Storage.list("session/message")) {
try {
const message = await Storage.readJSON<MessageV2.Info>(messagePath)
if (!message.parts.find((part) => part.type === "step-finish")) continue
stats.totalMessages++
const sessionId = message.sessionID
sessionMap.set(sessionId, (sessionMap.get(sessionId) || 0) + 1)
if (message.time.created < stats.dateRange.earliest) {
stats.dateRange.earliest = message.time.created
}
if (message.time.created > stats.dateRange.latest) {
stats.dateRange.latest = message.time.created
}
if (message.role === "assistant") {
stats.totalCost += message.cost
stats.totalTokens.input += message.tokens.input
stats.totalTokens.output += message.tokens.output
stats.totalTokens.reasoning += message.tokens.reasoning
stats.totalTokens.cache.read += message.tokens.cache.read
stats.totalTokens.cache.write += message.tokens.cache.write
for (const part of message.parts) {
if (part.type === "tool") {
stats.toolUsage[part.tool] = (stats.toolUsage[part.tool] || 0) + 1
}
}
}
} catch (e) {
continue
}
}
} catch (e) {
console.error("Failed to read storage:", e)
return
}
stats.totalSessions = sessionMap.size
if (stats.dateRange.latest > 0) {
const daysDiff = (stats.dateRange.latest - stats.dateRange.earliest) / (1000 * 60 * 60 * 24)
stats.days = Math.max(1, Math.ceil(daysDiff))
stats.costPerDay = stats.totalCost / stats.days
}
displayStats(stats)
})
},
handler: async () => {},
})
function displayStats(stats: SessionStats) {
export function displayStats(stats: SessionStats) {
const width = 56
function renderRow(label: string, value: string): string {

View File

@@ -1,6 +1,8 @@
import { z } from "zod"
import { EOL } from "os"
import { NamedError } from "../util/error"
// @ts-ignore
import cliMarkdown from "cli-markdown"
export namespace UI {
const LOGO = [
@@ -76,4 +78,18 @@ export namespace UI {
export function error(message: string) {
println(Style.TEXT_DANGER_BOLD + "Error: " + Style.TEXT_NORMAL + message)
}
export function markdown(text: string): string {
const rendered = cliMarkdown(text, {
width: process.stdout.columns || 80,
firstHeading: false,
tab: 0,
}).trim()
// Remove leading space from each line
return rendered
.split("\n")
.map((line: string) => line.replace(/^ /, ""))
.join("\n")
}
}

View File

@@ -6,6 +6,7 @@ export namespace Identifier {
session: "ses",
message: "msg",
user: "usr",
part: "prt",
} as const
export function schema(prefix: keyof typeof prefixes) {

View File

@@ -269,6 +269,7 @@ export namespace Server {
zValidator(
"json",
z.object({
messageID: z.string(),
providerID: z.string(),
modelID: z.string(),
}),
@@ -405,7 +406,14 @@ export namespace Server {
description: "List of messages",
content: {
"application/json": {
schema: resolver(MessageV2.Info.array()),
schema: resolver(
z
.object({
info: MessageV2.Info,
parts: MessageV2.Part.array(),
})
.array(),
),
},
},
},
@@ -446,10 +454,11 @@ export namespace Server {
zValidator(
"json",
z.object({
messageID: z.string(),
providerID: z.string(),
modelID: z.string(),
mode: z.string(),
parts: MessageV2.UserPart.array(),
parts: z.union([MessageV2.FilePart, MessageV2.TextPart]).array(),
}),
),
async (c) => {

View File

@@ -12,6 +12,7 @@ import {
type ProviderMetadata,
type ModelMessage,
stepCountIs,
type StreamTextResult,
} from "ai"
import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
@@ -190,7 +191,10 @@ export namespace Session {
await Storage.writeJSON<ShareInfo>("session/share/" + id, share)
await Share.sync("session/info/" + id, session)
for (const msg of await messages(id)) {
await Share.sync("session/message/" + id + "/" + msg.id, msg)
await Share.sync("session/message/" + id + "/" + msg.info.id, msg.info)
for (const part of msg.parts) {
await Share.sync("session/part/" + id + "/" + msg.info.id + "/" + part.id, part)
}
}
return share
}
@@ -220,13 +224,19 @@ export namespace Session {
}
export async function messages(sessionID: string) {
const result = [] as MessageV2.Info[]
const result = [] as {
info: MessageV2.Info
parts: MessageV2.Part[]
}[]
const list = Storage.list("session/message/" + sessionID)
for await (const p of list) {
const read = await Storage.readJSON<MessageV2.Info>(p)
result.push(read)
result.push({
info: read,
parts: await parts(sessionID, read.id),
})
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
return result
}
@@ -234,6 +244,16 @@ export namespace Session {
return Storage.readJSON<MessageV2.Info>("session/message/" + sessionID + "/" + messageID)
}
export async function parts(sessionID: string, messageID: string) {
const result = [] as MessageV2.Part[]
for await (const item of Storage.list("session/part/" + sessionID + "/" + messageID)) {
const read = await Storage.readJSON<MessageV2.Part>(item)
result.push(read)
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result
}
export async function* list() {
for await (const item of Storage.list("session/info")) {
const sessionID = path.basename(item, ".json")
@@ -289,12 +309,21 @@ export namespace Session {
})
}
async function updatePart(part: MessageV2.Part) {
await Storage.writeJSON(["session", "part", part.sessionID, part.messageID, part.id].join("/"), part)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
})
return part
}
export async function chat(input: {
sessionID: string
messageID: string
providerID: string
modelID: string
mode?: string
parts: MessageV2.UserPart[]
parts: (MessageV2.TextPart | MessageV2.FilePart)[]
}) {
const l = log.clone().tag("session", input.sessionID)
l.info("chatting")
@@ -306,16 +335,19 @@ export namespace Session {
if (session.revert) {
const trimmed = []
for (const msg of msgs) {
if (msg.id > session.revert.messageID || (msg.id === session.revert.messageID && session.revert.part === 0)) {
await Storage.remove("session/message/" + input.sessionID + "/" + msg.id)
if (
msg.info.id > session.revert.messageID ||
(msg.info.id === session.revert.messageID && session.revert.part === 0)
) {
await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
await Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: msg.id,
messageID: msg.info.id,
})
continue
}
if (msg.id === session.revert.messageID) {
if (msg.info.id === session.revert.messageID) {
if (session.revert.part === 0) break
msg.parts = msg.parts.slice(0, session.revert.part)
}
@@ -327,7 +359,7 @@ export namespace Session {
})
}
const previous = msgs.at(-1) as MessageV2.Assistant
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
// auto summarize if too long
@@ -346,12 +378,21 @@ export namespace Session {
using abort = lock(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
const userMsg: MessageV2.Info = {
id: input.messageID,
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
}
const app = App.info()
input.parts = await Promise.all(
input.parts.map(async (part): Promise<MessageV2.UserPart[]> => {
const userParts = await Promise.all(
input.parts.map(async (part): Promise<MessageV2.Part[]> => {
if (part.type === "file") {
const url = new URL(part.url)
switch (url.protocol) {
@@ -406,11 +447,17 @@ export namespace Session {
})
return [
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
},
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: result.output,
@@ -422,11 +469,17 @@ export namespace Session {
FileTime.read(input.sessionID, filePath)
return [
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
synthetic: true,
},
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "file",
url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
mime: part.mime,
@@ -440,7 +493,10 @@ export namespace Session {
).then((x) => x.flat())
if (input.mode === "plan")
input.parts.push({
userParts.push({
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
text: PROMPT_PLAN,
synthetic: true,
@@ -459,13 +515,15 @@ export namespace Session {
),
...MessageV2.toModelMessage([
{
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
parts: input.parts,
time: {
created: Date.now(),
info: {
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
},
parts: userParts,
},
]),
],
@@ -479,17 +537,11 @@ export namespace Session {
})
.catch(() => {})
}
const msg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
parts: input.parts,
time: {
created: Date.now(),
},
await updateMessage(userMsg)
for (const part of userParts) {
await updatePart(part)
}
await updateMessage(msg)
msgs.push(msg)
msgs.push({ info: userMsg, parts: userParts })
const mode = await Mode.get(input.mode ?? "build")
let system = mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.providerID, input.modelID)
@@ -499,10 +551,9 @@ export namespace Session {
const [first, ...rest] = system
system = [first, rest.join("\n")]
const next: MessageV2.Info = {
const assistantMsg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
parts: [],
system,
path: {
cwd: app.path.cwd,
@@ -522,7 +573,7 @@ export namespace Session {
},
sessionID: input.sessionID,
}
await updateMessage(next)
await updateMessage(assistantMsg)
const tools: Record<string, AITool> = {}
for (const item of await Provider.tools(input.providerID)) {
@@ -531,20 +582,29 @@ export namespace Session {
id: item.id as any,
description: item.description,
inputSchema: item.parameters as ZodSchema,
async execute(args, opts) {
async execute(args) {
const result = await item.execute(args, {
sessionID: input.sessionID,
abort: abort.signal,
messageID: next.id,
metadata: async (val) => {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === opts.toolCallId,
)
messageID: assistantMsg.id,
metadata: async () => {
/*
const match = toolCalls[opts.toolCallId]
if (match && match.state.status === "running") {
match.state.title = val.title
match.state.metadata = val.metadata
await updatePart({
...match,
state: {
title: val.title,
metadata: val.metadata,
status: "running",
input: args.input,
time: {
start: Date.now(),
},
},
})
}
await updateMessage(next)
*/
},
})
return result
@@ -582,10 +642,6 @@ export namespace Session {
tools[key] = item
}
let text: MessageV2.TextPart = {
type: "text",
text: "",
}
const result = streamText({
onError() {},
maxRetries: 10,
@@ -619,9 +675,20 @@ export namespace Session {
],
}),
})
return processStream(assistantMsg, model.info, result)
}
async function processStream(
assistantMsg: MessageV2.Assistant,
model: ModelsDev.Model,
stream: StreamTextResult<Record<string, AITool>, never>,
) {
try {
for await (const value of result.fullStream) {
l.info("part", {
let currentText: MessageV2.TextPart | undefined
const toolCalls: Record<string, MessageV2.ToolPart> = {}
for await (const value of stream.fullStream) {
log.info("part", {
type: value.type,
})
switch (value.type) {
@@ -629,88 +696,78 @@ export namespace Session {
break
case "tool-input-start":
next.parts.push({
const part = await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "tool",
tool: value.toolName,
id: value.id,
callID: value.id,
state: {
status: "pending",
},
})
Bus.publish(MessageV2.Event.PartUpdated, {
part: next.parts[next.parts.length - 1],
sessionID: next.sessionID,
messageID: next.id,
})
toolCalls[value.id] = part as MessageV2.ToolPart
break
case "tool-input-delta":
break
case "tool-call": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match) {
match.state = {
status: "running",
input: value.input,
time: {
start: Date.now(),
const part = await updatePart({
...match,
state: {
status: "running",
input: value.input,
time: {
start: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
toolCalls[value.toolCallId] = part as MessageV2.ToolPart
}
break
}
case "tool-result": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match && match.state.status === "running") {
match.state = {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
await updatePart({
...match,
state: {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
delete toolCalls[value.toolCallId]
}
break
}
case "tool-error": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match && match.state.status === "running") {
match.state = {
status: "error",
input: value.input,
error: (value.error as any).toString(),
time: {
start: match.state.time.start,
end: Date.now(),
await updatePart({
...match,
state: {
status: "error",
input: value.input,
error: (value.error as any).toString(),
time: {
start: match.state.time.start,
end: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
delete toolCalls[value.toolCallId]
}
break
}
@@ -719,53 +776,71 @@ export namespace Session {
throw value.error
case "start-step":
next.parts.push({
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "step-start",
})
break
case "finish-step":
const usage = getUsage(model.info, value.usage, value.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
next.parts.push({
const usage = getUsage(model, value.usage, value.providerMetadata)
assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "step-finish",
tokens: usage.tokens,
cost: usage.cost,
})
await updateMessage(assistantMsg)
break
case "text-start":
text = {
currentText = {
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "text",
text: "",
time: {
start: Date.now(),
},
}
break
case "text":
if (text.text === "") next.parts.push(text)
text.text += value.text
if (currentText) {
currentText.text += value.text
await updatePart(currentText)
}
break
case "text-end":
Bus.publish(MessageV2.Event.PartUpdated, {
part: text,
sessionID: next.sessionID,
messageID: next.id,
})
if (currentText && currentText.text) {
currentText.time = {
start: Date.now(),
end: Date.now(),
}
await updatePart(currentText)
}
currentText = undefined
break
case "finish":
next.time.completed = Date.now()
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
break
default:
l.info("unhandled", {
log.info("unhandled", {
...value,
})
continue
}
await updateMessage(next)
}
} catch (e) {
log.error("", {
@@ -773,7 +848,7 @@ export namespace Session {
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
next.error = new MessageV2.AbortedError(
assistantMsg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
@@ -781,44 +856,48 @@ export namespace Session {
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
next.error = e
assistantMsg.error = e
break
case LoadAPIKeyError.isInstance(e):
next.error = new Provider.AuthError(
assistantMsg.error = new Provider.AuthError(
{
providerID: input.providerID,
providerID: model.id,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
}
Bus.publish(Event.Error, {
sessionID: next.sessionID,
error: next.error,
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
for (const part of next.parts) {
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
for (const part of p) {
if (part.type === "tool" && part.state.status !== "completed") {
part.state = {
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
updatePart({
...part,
state: {
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
},
input: {},
},
input: {},
}
})
}
}
next.time.completed = Date.now()
await updateMessage(next)
return next
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
return { info: assistantMsg, parts: p }
}
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
@@ -867,8 +946,8 @@ export namespace Session {
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
using abort = lock(input.sessionID)
const msgs = await messages(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)?.id
const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
const filtered = msgs.filter((msg) => !lastSummary || msg.info.id >= lastSummary.info.id)
const model = await Provider.getModel(input.providerID, input.modelID)
const app = App.info()
const system = SystemPrompt.summarize(input.providerID)
@@ -876,7 +955,6 @@ export namespace Session {
const next: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
parts: [],
sessionID: input.sessionID,
system,
path: {
@@ -899,7 +977,6 @@ export namespace Session {
}
await updateMessage(next)
let text: MessageV2.TextPart | undefined
const result = streamText({
abortSignal: abort.signal,
model: model.language,
@@ -921,81 +998,9 @@ export namespace Session {
],
},
],
onStepFinish: async (step) => {
const usage = getUsage(model.info, step.usage, step.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
await updateMessage(next)
if (text) {
Bus.publish(MessageV2.Event.PartUpdated, {
part: text,
messageID: next.id,
sessionID: next.sessionID,
})
}
text = undefined
},
async onFinish(input) {
const usage = getUsage(model.info, input.usage, input.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
next.time.completed = Date.now()
await updateMessage(next)
},
})
try {
for await (const value of result.fullStream) {
switch (value.type) {
case "text":
if (!text) {
text = {
type: "text",
text: value.text,
}
next.parts.push(text)
} else text.text += value.text
await updateMessage(next)
break
}
}
} catch (e: any) {
log.error("summarize stream error", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
next.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
next.error = e
break
case LoadAPIKeyError.isInstance(e):
next.error = new Provider.AuthError(
{
providerID: input.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
}
Bus.publish(Event.Error, {
error: next.error,
})
}
next.time.completed = Date.now()
await updateMessage(next)
return processStream(next, model.info, result)
}
function lock(sessionID: string) {
@@ -1045,14 +1050,23 @@ export namespace Session {
}
}
export async function initialize(input: { sessionID: string; modelID: string; providerID: string }) {
export async function initialize(input: {
sessionID: string
modelID: string
providerID: string
messageID: string
}) {
const app = App.info()
await Session.chat({
sessionID: input.sessionID,
messageID: input.messageID,
providerID: input.providerID,
modelID: input.modelID,
parts: [
{
id: Identifier.ascending("part"),
sessionID: input.sessionID,
messageID: input.messageID,
type: "text",
text: PROMPT_INITIALIZE.replace("${path}", app.path.root),
},

View File

@@ -4,6 +4,7 @@ import { Provider } from "../provider/provider"
import { NamedError } from "../util/error"
import { Message } from "./message"
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
import { Identifier } from "../id/id"
export namespace MessageV2 {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
@@ -72,67 +73,69 @@ export namespace MessageV2 {
ref: "ToolState",
})
export const TextPart = z
.object({
type: z.literal("text"),
text: z.string(),
synthetic: z.boolean().optional(),
})
.openapi({
ref: "TextPart",
})
const PartBase = z.object({
id: z.string(),
sessionID: z.string(),
messageID: z.string(),
})
export const TextPart = PartBase.extend({
type: z.literal("text"),
text: z.string(),
synthetic: z.boolean().optional(),
time: z
.object({
start: z.number(),
end: z.number().optional(),
})
.optional(),
}).openapi({
ref: "TextPart",
})
export type TextPart = z.infer<typeof TextPart>
export const ToolPart = z
.object({
type: z.literal("tool"),
id: z.string(),
tool: z.string(),
state: ToolState,
})
.openapi({
ref: "ToolPart",
})
export const ToolPart = PartBase.extend({
type: z.literal("tool"),
callID: z.string(),
tool: z.string(),
state: ToolState,
}).openapi({
ref: "ToolPart",
})
export type ToolPart = z.infer<typeof ToolPart>
export const FilePart = z
.object({
type: z.literal("file"),
mime: z.string(),
filename: z.string().optional(),
url: z.string(),
})
.openapi({
ref: "FilePart",
})
export const FilePart = PartBase.extend({
type: z.literal("file"),
mime: z.string(),
filename: z.string().optional(),
url: z.string(),
}).openapi({
ref: "FilePart",
})
export type FilePart = z.infer<typeof FilePart>
export const StepStartPart = z
.object({
type: z.literal("step-start"),
})
.openapi({
ref: "StepStartPart",
})
export const StepStartPart = PartBase.extend({
type: z.literal("step-start"),
}).openapi({
ref: "StepStartPart",
})
export type StepStartPart = z.infer<typeof StepStartPart>
export const StepFinishPart = z
.object({
type: z.literal("step-finish"),
cost: z.number(),
tokens: z.object({
input: z.number(),
output: z.number(),
reasoning: z.number(),
cache: z.object({
read: z.number(),
write: z.number(),
}),
export const StepFinishPart = PartBase.extend({
type: z.literal("step-finish"),
cost: z.number(),
tokens: z.object({
input: z.number(),
output: z.number(),
reasoning: z.number(),
cache: z.object({
read: z.number(),
write: z.number(),
}),
})
.openapi({
ref: "StepFinishPart",
})
}),
}).openapi({
ref: "StepFinishPart",
})
export type StepFinishPart = z.infer<typeof StepFinishPart>
const Base = z.object({
@@ -140,14 +143,8 @@ export namespace MessageV2 {
sessionID: z.string(),
})
export const UserPart = z.discriminatedUnion("type", [TextPart, FilePart]).openapi({
ref: "UserMessagePart",
})
export type UserPart = z.infer<typeof UserPart>
export const User = Base.extend({
role: z.literal("user"),
parts: z.array(UserPart),
time: z.object({
created: z.number(),
}),
@@ -156,16 +153,15 @@ export namespace MessageV2 {
})
export type User = z.infer<typeof User>
export const AssistantPart = z
.discriminatedUnion("type", [TextPart, ToolPart, StepStartPart, StepFinishPart])
export const Part = z
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart])
.openapi({
ref: "AssistantMessagePart",
ref: "Part",
})
export type AssistantPart = z.infer<typeof AssistantPart>
export type Part = z.infer<typeof Part>
export const Assistant = Base.extend({
role: z.literal("assistant"),
parts: z.array(AssistantPart),
time: z.object({
created: z.number(),
completed: z.number().optional(),
@@ -223,16 +219,14 @@ export namespace MessageV2 {
PartUpdated: Bus.event(
"message.part.updated",
z.object({
part: AssistantPart,
sessionID: z.string(),
messageID: z.string(),
part: Part,
}),
),
}
export function fromV1(v1: Message.Info) {
if (v1.role === "assistant") {
const result: Assistant = {
const info: Assistant = {
id: v1.id,
sessionID: v1.metadata.sessionID,
role: "assistant",
@@ -248,109 +242,135 @@ export namespace MessageV2 {
providerID: v1.metadata.assistant!.providerID,
system: v1.metadata.assistant!.system,
error: v1.metadata.error,
parts: v1.parts.flatMap((part): AssistantPart[] => {
if (part.type === "text") {
return [
{
type: "text",
text: part.text,
},
]
}
if (part.type === "step-start") {
return [
{
type: "step-start",
},
]
}
if (part.type === "tool-invocation") {
return [
{
type: "tool",
id: part.toolInvocation.toolCallId,
tool: part.toolInvocation.toolName,
state: (() => {
if (part.toolInvocation.state === "partial-call") {
return {
status: "pending",
}
}
const { title, time, ...metadata } = v1.metadata.tool[part.toolInvocation.toolCallId] ?? {}
if (part.toolInvocation.state === "call") {
return {
status: "running",
input: part.toolInvocation.args,
time: {
start: time?.start,
},
}
}
if (part.toolInvocation.state === "result") {
return {
status: "completed",
input: part.toolInvocation.args,
output: part.toolInvocation.result,
title,
time,
metadata,
}
}
throw new Error("unknown tool invocation state")
})(),
},
]
}
return []
}),
}
return result
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
id: Identifier.ascending("part"),
messageID: v1.id,
sessionID: v1.metadata.sessionID,
}
if (part.type === "text") {
return [
{
...base,
type: "text",
text: part.text,
},
]
}
if (part.type === "step-start") {
return [
{
...base,
type: "step-start",
},
]
}
if (part.type === "tool-invocation") {
return [
{
...base,
type: "tool",
callID: part.toolInvocation.toolCallId,
tool: part.toolInvocation.toolName,
state: (() => {
if (part.toolInvocation.state === "partial-call") {
return {
status: "pending",
}
}
const { title, time, ...metadata } = v1.metadata.tool[part.toolInvocation.toolCallId] ?? {}
if (part.toolInvocation.state === "call") {
return {
status: "running",
input: part.toolInvocation.args,
time: {
start: time?.start,
},
}
}
if (part.toolInvocation.state === "result") {
return {
status: "completed",
input: part.toolInvocation.args,
output: part.toolInvocation.result,
title,
time,
metadata,
}
}
throw new Error("unknown tool invocation state")
})(),
},
]
}
return []
})
return {
info,
parts,
}
}
if (v1.role === "user") {
const result: User = {
const info: User = {
id: v1.id,
sessionID: v1.metadata.sessionID,
role: "user",
time: {
created: v1.metadata.time.created,
},
parts: v1.parts.flatMap((part): UserPart[] => {
if (part.type === "text") {
return [
{
type: "text",
text: part.text,
},
]
}
if (part.type === "file") {
return [
{
type: "file",
mime: part.mediaType,
filename: part.filename,
url: part.url,
},
]
}
return []
}),
}
return result
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
id: Identifier.ascending("part"),
messageID: v1.id,
sessionID: v1.metadata.sessionID,
}
if (part.type === "text") {
return [
{
...base,
type: "text",
text: part.text,
},
]
}
if (part.type === "file") {
return [
{
...base,
type: "file",
mime: part.mediaType,
filename: part.filename,
url: part.url,
},
]
}
return []
})
return { info, parts }
}
throw new Error("unknown message type")
}
export function toModelMessage(input: Info[]): ModelMessage[] {
export function toModelMessage(
input: {
info: Info
parts: Part[]
}[],
): ModelMessage[] {
const result: UIMessage[] = []
for (const msg of input) {
if (msg.parts.length === 0) continue
if (msg.role === "user") {
if (msg.info.role === "user") {
result.push({
id: msg.id,
id: msg.info.id,
role: "user",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
@@ -374,9 +394,9 @@ export namespace MessageV2 {
})
}
if (msg.role === "assistant") {
if (msg.info.role === "assistant") {
result.push({
id: msg.id,
id: msg.info.id,
role: "assistant",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
@@ -398,7 +418,7 @@ export namespace MessageV2 {
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.id,
toolCallId: part.callID,
input: part.state.input,
output: part.state.output,
},
@@ -408,7 +428,7 @@ export namespace MessageV2 {
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.id,
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
},

View File

@@ -5,6 +5,7 @@ import path from "path"
import z from "zod"
import fs from "fs/promises"
import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id"
export namespace Storage {
const log = Log.create({ service: "storage" })
@@ -28,13 +29,49 @@ export namespace Storage {
log.info("migrating to v2 message", { file })
try {
const result = MessageV2.fromV1(content)
await Bun.write(file, JSON.stringify(result, null, 2))
await Bun.write(
file,
JSON.stringify(
{
...result.info,
parts: result.parts,
},
null,
2,
),
)
} catch (e) {
await fs.rename(file, file.replace("storage", "broken"))
}
}
} catch {}
},
async (dir: string) => {
const files = new Bun.Glob("session/message/*/*.json").scanSync({
cwd: dir,
absolute: true,
})
for (const file of files) {
try {
const { parts, ...info } = await Bun.file(file).json()
if (!parts) continue
for (const part of parts) {
const id = Identifier.ascending("part")
await Bun.write(
[dir, "session", "part", info.sessionID, info.id, id + ".json"].join("/"),
JSON.stringify({
...part,
id,
sessionID: info.sessionID,
messageID: info.id,
...(part.type === "tool" ? { callID: part.id } : {}),
}),
)
}
await Bun.write(file, JSON.stringify(info, null, 2))
} catch (e) {}
}
},
]
const state = App.state("storage", async () => {

View File

@@ -4,6 +4,7 @@ import { z } from "zod"
import { Session } from "../session"
import { Bus } from "../bus"
import { MessageV2 } from "../session/message-v2"
import { Identifier } from "../id/id"
export const TaskTool = Tool.define({
id: "task",
@@ -16,9 +17,10 @@ export const TaskTool = Tool.define({
const session = await Session.create(ctx.sessionID)
const msg = (await Session.getMessage(ctx.sessionID, ctx.messageID)) as MessageV2.Assistant
function summary(input: MessageV2.Info) {
const parts: Record<string, MessageV2.Part> = {}
function summary(input: MessageV2.Part[]) {
const result = []
for (const part of input.parts) {
for (const part of input) {
if (part.type === "tool" && part.state.status === "completed") {
result.push(part)
}
@@ -26,12 +28,13 @@ export const TaskTool = Tool.define({
return result
}
const unsub = Bus.subscribe(MessageV2.Event.Updated, async (evt) => {
if (evt.properties.info.sessionID !== session.id) return
const unsub = Bus.subscribe(MessageV2.Event.PartUpdated, async (evt) => {
if (evt.properties.part.sessionID !== session.id) return
parts[evt.properties.part.id] = evt.properties.part
ctx.metadata({
title: params.description,
metadata: {
summary: summary(evt.properties.info),
summary: Object.values(parts).sort((a, b) => a.id?.localeCompare(b.id)),
},
})
})
@@ -39,12 +42,17 @@ export const TaskTool = Tool.define({
ctx.abort.addEventListener("abort", () => {
Session.abort(session.id)
})
const messageID = Identifier.ascending("message")
const result = await Session.chat({
messageID,
sessionID: session.id,
modelID: msg.modelID,
providerID: msg.providerID,
parts: [
{
id: Identifier.ascending("part"),
messageID,
sessionID: session.id,
type: "text",
text: params.prompt,
},
@@ -54,7 +62,7 @@ export const TaskTool = Tool.define({
return {
title: params.description,
metadata: {
summary: summary(result),
summary: summary(result.parts),
},
output: result.parts.findLast((x) => x.type === "text")!.text,
}