core: improve session API reliability with proper input validation

This commit is contained in:
Dax Raad
2025-10-06 19:37:30 -04:00
parent aee240150b
commit 10998d62b9
6 changed files with 139 additions and 118 deletions

View File

@@ -106,7 +106,7 @@ export const RunCommand = cmd({
if (args.session) return Session.get(args.session) if (args.session) return Session.get(args.session)
return Session.create() return Session.create({})
})() })()
if (!session) { if (!session) {

View File

@@ -31,7 +31,6 @@ import { SessionRevert } from "../session/revert"
import { lazy } from "../util/lazy" import { lazy } from "../util/lazy"
import { Todo } from "../session/todo" import { Todo } from "../session/todo"
import { InstanceBootstrap } from "../project/bootstrap" import { InstanceBootstrap } from "../project/bootstrap"
import { Identifier } from "@/id/id"
const ERRORS = { const ERRORS = {
400: { 400: {
@@ -308,7 +307,7 @@ export namespace Server {
validator( validator(
"param", "param",
z.object({ z.object({
id: z.string(), id: Session.get.schema,
}), }),
), ),
async (c) => { async (c) => {
@@ -336,7 +335,7 @@ export namespace Server {
validator( validator(
"param", "param",
z.object({ z.object({
id: z.string(), id: Session.children.schema,
}), }),
), ),
async (c) => { async (c) => {
@@ -390,18 +389,10 @@ export namespace Server {
}, },
}, },
}), }),
validator( validator("json", Session.create.schema.optional()),
"json",
z
.object({
parentID: z.string().optional(),
title: z.string().optional(),
})
.optional(),
),
async (c) => { async (c) => {
const body = c.req.valid("json") ?? {} const body = c.req.valid("json") ?? {}
const session = await Session.create(body.parentID, body.title) const session = await Session.create(body)
return c.json(session) return c.json(session)
}, },
) )
@@ -424,7 +415,7 @@ export namespace Server {
validator( validator(
"param", "param",
z.object({ z.object({
id: z.string(), id: Session.remove.schema,
}), }),
), ),
async (c) => { async (c) => {
@@ -495,14 +486,7 @@ export namespace Server {
id: z.string().meta({ description: "Session ID" }), id: z.string().meta({ description: "Session ID" }),
}), }),
), ),
validator( validator("json", Session.initialize.schema.omit({ sessionID: true })),
"json",
z.object({
messageID: z.string(),
providerID: z.string(),
modelID: z.string(),
}),
),
async (c) => { async (c) => {
const sessionID = c.req.valid("param").id const sessionID = c.req.valid("param").id
const body = c.req.valid("json") const body = c.req.valid("json")
@@ -529,7 +513,7 @@ export namespace Server {
validator( validator(
"param", "param",
z.object({ z.object({
id: Identifier.schema("session").meta({ description: "Session ID" }), id: Session.fork.schema.shape.sessionID,
}), }),
), ),
validator("json", Session.fork.schema.omit({ sessionID: true })), validator("json", Session.fork.schema.omit({ sessionID: true })),
@@ -614,7 +598,7 @@ export namespace Server {
validator( validator(
"param", "param",
z.object({ z.object({
id: z.string(), id: Session.unshare.schema,
}), }),
), ),
async (c) => { async (c) => {
@@ -717,7 +701,7 @@ export namespace Server {
), ),
async (c) => { async (c) => {
const params = c.req.valid("param") const params = c.req.valid("param")
const message = await Session.getMessage(params.id, params.messageID) const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
return c.json(message) return c.json(message)
}, },
) )

View File

@@ -144,7 +144,7 @@ export namespace SessionCompaction {
}, },
], ],
}) })
const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata) const usage = Session.getUsage({ model: model.info, usage: generated.usage, metadata: generated.providerMetadata })
msg.cost += usage.cost msg.cost += usage.cost
msg.tokens = usage.tokens msg.tokens = usage.tokens
msg.summary = true msg.summary = true

View File

@@ -93,13 +93,21 @@ export namespace Session {
), ),
} }
export async function create(parentID?: string, title?: string) { export const create = fn(
return createNext({ z
parentID, .object({
directory: Instance.directory, parentID: Identifier.schema("session").optional(),
title, title: z.string().optional(),
}) })
} .optional(),
async (input) => {
return createNext({
parentID: input?.parentID,
directory: Instance.directory,
title: input?.title,
})
},
)
export const fork = fn( export const fork = fn(
z.object({ z.object({
@@ -132,11 +140,11 @@ export namespace Session {
}, },
) )
export async function touch(sessionID: string) { export const touch = fn(Identifier.schema("session"), async (sessionID) => {
await update(sessionID, (draft) => { await update(sessionID, (draft) => {
draft.time.updated = Date.now() draft.time.updated = Date.now()
}) })
} })
export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) { export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
const result: Info = { const result: Info = {
@@ -170,16 +178,16 @@ export namespace Session {
return result return result
} }
export async function get(id: string) { export const get = fn(Identifier.schema("session"), async (id) => {
const read = await Storage.read<Info>(["session", Instance.project.id, id]) const read = await Storage.read<Info>(["session", Instance.project.id, id])
return read as Info return read as Info
} })
export async function getShare(id: string) { export const getShare = fn(Identifier.schema("session"), async (id) => {
return Storage.read<ShareInfo>(["share", id]) return Storage.read<ShareInfo>(["share", id])
} })
export async function share(id: string) { export const share = fn(Identifier.schema("session"), async (id) => {
const cfg = await Config.get() const cfg = await Config.get()
if (cfg.share === "disabled") { if (cfg.share === "disabled") {
throw new Error("Sharing is disabled in configuration") throw new Error("Sharing is disabled in configuration")
@@ -202,9 +210,9 @@ export namespace Session {
} }
} }
return share return share
} })
export async function unshare(id: string) { export const unshare = fn(Identifier.schema("session"), async (id) => {
const share = await getShare(id) const share = await getShare(id)
if (!share) return if (!share) return
await Storage.remove(["share", id]) await Storage.remove(["share", id])
@@ -212,7 +220,7 @@ export namespace Session {
draft.share = undefined draft.share = undefined
}) })
await Share.remove(id, share.secret) await Share.remove(id, share.secret)
} })
export async function update(id: string, editor: (session: Info) => void) { export async function update(id: string, editor: (session: Info) => void) {
const project = Instance.project const project = Instance.project
@@ -226,7 +234,7 @@ export namespace Session {
return result return result
} }
export async function messages(sessionID: string) { export const messages = fn(Identifier.schema("session"), async (sessionID) => {
const result = [] as MessageV2.WithParts[] const result = [] as MessageV2.WithParts[]
for (const p of await Storage.list(["message", sessionID])) { for (const p of await Storage.list(["message", sessionID])) {
const read = await Storage.read<MessageV2.Info>(p) const read = await Storage.read<MessageV2.Info>(p)
@@ -237,16 +245,22 @@ export namespace Session {
} }
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1)) result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
return result return result
} })
export async function getMessage(sessionID: string, messageID: string) { export const getMessage = fn(
return { z.object({
info: await Storage.read<MessageV2.Info>(["message", sessionID, messageID]), sessionID: Identifier.schema("session"),
parts: await getParts(messageID), messageID: Identifier.schema("message"),
} }),
} async (input) => {
return {
info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
parts: await getParts(input.messageID),
}
},
)
export async function getParts(messageID: string) { export const getParts = fn(Identifier.schema("message"), async (messageID) => {
const result = [] as MessageV2.Part[] const result = [] as MessageV2.Part[]
for (const item of await Storage.list(["part", messageID])) { for (const item of await Storage.list(["part", messageID])) {
const read = await Storage.read<MessageV2.Part>(item) const read = await Storage.read<MessageV2.Part>(item)
@@ -254,7 +268,7 @@ export namespace Session {
} }
result.sort((a, b) => (a.id > b.id ? 1 : -1)) result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result return result
} })
export async function* list() { export async function* list() {
const project = Instance.project const project = Instance.project
@@ -263,7 +277,7 @@ export namespace Session {
} }
} }
export async function children(parentID: string) { export const children = fn(Identifier.schema("session"), async (parentID) => {
const project = Instance.project const project = Instance.project
const result = [] as Session.Info[] const result = [] as Session.Info[]
for (const item of await Storage.list(["session", project.id])) { for (const item of await Storage.list(["session", project.id])) {
@@ -272,9 +286,9 @@ export namespace Session {
result.push(session) result.push(session)
} }
return result return result
} })
export async function remove(sessionID: string) { export const remove = fn(Identifier.schema("session"), async (sessionID) => {
const project = Instance.project const project = Instance.project
try { try {
const session = await get(sessionID) const session = await get(sessionID)
@@ -295,56 +309,69 @@ export namespace Session {
} catch (e) { } catch (e) {
log.error(e) log.error(e)
} }
} })
export async function updateMessage(msg: MessageV2.Info) { export const updateMessage = fn(MessageV2.Info, async (msg) => {
await Storage.write(["message", msg.sessionID, msg.id], msg) await Storage.write(["message", msg.sessionID, msg.id], msg)
Bus.publish(MessageV2.Event.Updated, { Bus.publish(MessageV2.Event.Updated, {
info: msg, info: msg,
}) })
return msg return msg
} })
export async function removeMessage(sessionID: string, messageID: string) { export const removeMessage = fn(
await Storage.remove(["message", sessionID, messageID]) z.object({
Bus.publish(MessageV2.Event.Removed, { sessionID: Identifier.schema("session"),
sessionID, messageID: Identifier.schema("message"),
messageID, }),
}) async (input) => {
return messageID await Storage.remove(["message", input.sessionID, input.messageID])
} Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: input.messageID,
})
return input.messageID
},
)
export async function updatePart(part: MessageV2.Part) { export const updatePart = fn(MessageV2.Part, async (part) => {
await Storage.write(["part", part.messageID, part.id], part) await Storage.write(["part", part.messageID, part.id], part)
Bus.publish(MessageV2.Event.PartUpdated, { Bus.publish(MessageV2.Event.PartUpdated, {
part, part,
}) })
return part return part
} })
export function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) { export const getUsage = fn(
const tokens = { z.object({
input: usage.inputTokens ?? 0, model: z.custom<ModelsDev.Model>(),
output: usage.outputTokens ?? 0, usage: z.custom<LanguageModelUsage>(),
reasoning: usage?.reasoningTokens ?? 0, metadata: z.custom<ProviderMetadata>().optional(),
cache: { }),
write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ?? (input) => {
// @ts-expect-error const tokens = {
metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ?? input: input.usage.inputTokens ?? 0,
0) as number, output: input.usage.outputTokens ?? 0,
read: usage.cachedInputTokens ?? 0, reasoning: input.usage?.reasoningTokens ?? 0,
}, cache: {
} write: (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
return { // @ts-expect-error
cost: new Decimal(0) input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
.add(new Decimal(tokens.input).mul(model.cost?.input ?? 0).div(1_000_000)) 0) as number,
.add(new Decimal(tokens.output).mul(model.cost?.output ?? 0).div(1_000_000)) read: input.usage.cachedInputTokens ?? 0,
.add(new Decimal(tokens.cache.read).mul(model.cost?.cache_read ?? 0).div(1_000_000)) },
.add(new Decimal(tokens.cache.write).mul(model.cost?.cache_write ?? 0).div(1_000_000)) }
.toNumber(), return {
tokens, cost: new Decimal(0)
} .add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
} .add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
.toNumber(),
tokens,
}
},
)
export class BusyError extends Error { export class BusyError extends Error {
constructor(public readonly sessionID: string) { constructor(public readonly sessionID: string) {
@@ -352,27 +379,30 @@ export namespace Session {
} }
} }
export async function initialize(input: { export const initialize = fn(
sessionID: string z.object({
modelID: string sessionID: Identifier.schema("session"),
providerID: string modelID: z.string(),
messageID: string providerID: z.string(),
}) { messageID: Identifier.schema("message"),
await SessionPrompt.prompt({ }),
sessionID: input.sessionID, async (input) => {
messageID: input.messageID, await SessionPrompt.prompt({
model: { sessionID: input.sessionID,
providerID: input.providerID, messageID: input.messageID,
modelID: input.modelID, model: {
}, providerID: input.providerID,
parts: [ modelID: input.modelID,
{
id: Identifier.ascending("part"),
type: "text",
text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
}, },
], parts: [
}) {
await Project.setInitialized(Instance.project.id) id: Identifier.ascending("part"),
} type: "text",
text: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
},
],
})
await Project.setInitialized(Instance.project.id)
},
)
} }

View File

@@ -1031,7 +1031,11 @@ export namespace SessionPrompt {
break break
case "finish-step": case "finish-step":
const usage = Session.getUsage(input.model, value.usage, value.providerMetadata) const usage = Session.getUsage({
model: input.model,
usage: value.usage,
metadata: value.providerMetadata,
})
assistantMsg.cost += usage.cost assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens assistantMsg.tokens = usage.tokens
await Session.updatePart({ await Session.updatePart({

View File

@@ -26,8 +26,11 @@ export const TaskTool = Tool.define("task", async () => {
async execute(params, ctx) { async execute(params, ctx) {
const agent = await Agent.get(params.subagent_type) const agent = await Agent.get(params.subagent_type)
if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`) if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`)
const session = await Session.create(ctx.sessionID, params.description + ` (@${agent.name} subagent)`) const session = await Session.create({
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID) parentID: ctx.sessionID,
title: params.description + ` (@${agent.name} subagent)`,
})
const msg = await Session.getMessage({ sessionID: ctx.sessionID, messageID: ctx.messageID })
if (msg.info.role !== "assistant") throw new Error("Not an assistant message") if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
const messageID = Identifier.ascending("message") const messageID = Identifier.ascending("message")
const parts: Record<string, MessageV2.ToolPart> = {} const parts: Record<string, MessageV2.ToolPart> = {}