core: improve session API reliability with proper input validation

This commit is contained in:
Dax Raad
2025-10-06 19:37:30 -04:00
parent aee240150b
commit 10998d62b9
6 changed files with 139 additions and 118 deletions

View File

@@ -106,7 +106,7 @@ export const RunCommand = cmd({
if (args.session) return Session.get(args.session)
return Session.create()
return Session.create({})
})()
if (!session) {

View File

@@ -31,7 +31,6 @@ import { SessionRevert } from "../session/revert"
import { lazy } from "../util/lazy"
import { Todo } from "../session/todo"
import { InstanceBootstrap } from "../project/bootstrap"
import { Identifier } from "@/id/id"
const ERRORS = {
400: {
@@ -308,7 +307,7 @@ export namespace Server {
validator(
"param",
z.object({
id: z.string(),
id: Session.get.schema,
}),
),
async (c) => {
@@ -336,7 +335,7 @@ export namespace Server {
validator(
"param",
z.object({
id: z.string(),
id: Session.children.schema,
}),
),
async (c) => {
@@ -390,18 +389,10 @@ export namespace Server {
},
},
}),
validator(
"json",
z
.object({
parentID: z.string().optional(),
title: z.string().optional(),
})
.optional(),
),
validator("json", Session.create.schema.optional()),
async (c) => {
const body = c.req.valid("json") ?? {}
const session = await Session.create(body.parentID, body.title)
const session = await Session.create(body)
return c.json(session)
},
)
@@ -424,7 +415,7 @@ export namespace Server {
validator(
"param",
z.object({
id: z.string(),
id: Session.remove.schema,
}),
),
async (c) => {
@@ -495,14 +486,7 @@ export namespace Server {
id: z.string().meta({ description: "Session ID" }),
}),
),
validator(
"json",
z.object({
messageID: z.string(),
providerID: z.string(),
modelID: z.string(),
}),
),
validator("json", Session.initialize.schema.omit({ sessionID: true })),
async (c) => {
const sessionID = c.req.valid("param").id
const body = c.req.valid("json")
@@ -529,7 +513,7 @@ export namespace Server {
validator(
"param",
z.object({
id: Identifier.schema("session").meta({ description: "Session ID" }),
id: Session.fork.schema.shape.sessionID,
}),
),
validator("json", Session.fork.schema.omit({ sessionID: true })),
@@ -614,7 +598,7 @@ export namespace Server {
validator(
"param",
z.object({
id: z.string(),
id: Session.unshare.schema,
}),
),
async (c) => {
@@ -717,7 +701,7 @@ export namespace Server {
),
async (c) => {
const params = c.req.valid("param")
const message = await Session.getMessage(params.id, params.messageID)
const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
return c.json(message)
},
)

View File

@@ -144,7 +144,7 @@ export namespace SessionCompaction {
},
],
})
const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
const usage = Session.getUsage({ model: model.info, usage: generated.usage, metadata: generated.providerMetadata })
msg.cost += usage.cost
msg.tokens = usage.tokens
msg.summary = true

View File

@@ -93,13 +93,21 @@ export namespace Session {
),
}
export async function create(parentID?: string, title?: string) {
return createNext({
parentID,
directory: Instance.directory,
title,
export const create = fn(
z
.object({
parentID: Identifier.schema("session").optional(),
title: z.string().optional(),
})
}
.optional(),
async (input) => {
return createNext({
parentID: input?.parentID,
directory: Instance.directory,
title: input?.title,
})
},
)
export const fork = fn(
z.object({
@@ -132,11 +140,11 @@ export namespace Session {
},
)
export async function touch(sessionID: string) {
export const touch = fn(Identifier.schema("session"), async (sessionID) => {
await update(sessionID, (draft) => {
draft.time.updated = Date.now()
})
}
})
export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
const result: Info = {
@@ -170,16 +178,16 @@ export namespace Session {
return result
}
export async function get(id: string) {
export const get = fn(Identifier.schema("session"), async (id) => {
const read = await Storage.read<Info>(["session", Instance.project.id, id])
return read as Info
}
})
export async function getShare(id: string) {
export const getShare = fn(Identifier.schema("session"), async (id) => {
return Storage.read<ShareInfo>(["share", id])
}
})
export async function share(id: string) {
export const share = fn(Identifier.schema("session"), async (id) => {
const cfg = await Config.get()
if (cfg.share === "disabled") {
throw new Error("Sharing is disabled in configuration")
@@ -202,9 +210,9 @@ export namespace Session {
}
}
return share
}
})
export async function unshare(id: string) {
export const unshare = fn(Identifier.schema("session"), async (id) => {
const share = await getShare(id)
if (!share) return
await Storage.remove(["share", id])
@@ -212,7 +220,7 @@ export namespace Session {
draft.share = undefined
})
await Share.remove(id, share.secret)
}
})
export async function update(id: string, editor: (session: Info) => void) {
const project = Instance.project
@@ -226,7 +234,7 @@ export namespace Session {
return result
}
export async function messages(sessionID: string) {
export const messages = fn(Identifier.schema("session"), async (sessionID) => {
const result = [] as MessageV2.WithParts[]
for (const p of await Storage.list(["message", sessionID])) {
const read = await Storage.read<MessageV2.Info>(p)
@@ -237,16 +245,22 @@ export namespace Session {
}
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
return result
}
})
export async function getMessage(sessionID: string, messageID: string) {
export const getMessage = fn(
z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message"),
}),
async (input) => {
return {
info: await Storage.read<MessageV2.Info>(["message", sessionID, messageID]),
parts: await getParts(messageID),
}
info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
parts: await getParts(input.messageID),
}
},
)
export async function getParts(messageID: string) {
export const getParts = fn(Identifier.schema("message"), async (messageID) => {
const result = [] as MessageV2.Part[]
for (const item of await Storage.list(["part", messageID])) {
const read = await Storage.read<MessageV2.Part>(item)
@@ -254,7 +268,7 @@ export namespace Session {
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result
}
})
export async function* list() {
const project = Instance.project
@@ -263,7 +277,7 @@ export namespace Session {
}
}
export async function children(parentID: string) {
export const children = fn(Identifier.schema("session"), async (parentID) => {
const project = Instance.project
const result = [] as Session.Info[]
for (const item of await Storage.list(["session", project.id])) {
@@ -272,9 +286,9 @@ export namespace Session {
result.push(session)
}
return result
}
})
export async function remove(sessionID: string) {
export const remove = fn(Identifier.schema("session"), async (sessionID) => {
const project = Instance.project
try {
const session = await get(sessionID)
@@ -295,56 +309,69 @@ export namespace Session {
} catch (e) {
log.error(e)
}
}
})
export async function updateMessage(msg: MessageV2.Info) {
export const updateMessage = fn(MessageV2.Info, async (msg) => {
await Storage.write(["message", msg.sessionID, msg.id], msg)
Bus.publish(MessageV2.Event.Updated, {
info: msg,
})
return msg
}
export async function removeMessage(sessionID: string, messageID: string) {
await Storage.remove(["message", sessionID, messageID])
Bus.publish(MessageV2.Event.Removed, {
sessionID,
messageID,
})
return messageID
}
export async function updatePart(part: MessageV2.Part) {
export const removeMessage = fn(
z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message"),
}),
async (input) => {
await Storage.remove(["message", input.sessionID, input.messageID])
Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: input.messageID,
})
return input.messageID
},
)
export const updatePart = fn(MessageV2.Part, async (part) => {
await Storage.write(["part", part.messageID, part.id], part)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
})
return part
}
})
export function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
export const getUsage = fn(
z.object({
model: z.custom<ModelsDev.Model>(),
usage: z.custom<LanguageModelUsage>(),
metadata: z.custom<ProviderMetadata>().optional(),
}),
(input) => {
const tokens = {
input: usage.inputTokens ?? 0,
output: usage.outputTokens ?? 0,
reasoning: usage?.reasoningTokens ?? 0,
input: input.usage.inputTokens ?? 0,
output: input.usage.outputTokens ?? 0,
reasoning: input.usage?.reasoningTokens ?? 0,
cache: {
write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
write: (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
// @ts-expect-error
metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
0) as number,
read: usage.cachedInputTokens ?? 0,
read: input.usage.cachedInputTokens ?? 0,
},
}
return {
cost: new Decimal(0)
.add(new Decimal(tokens.input).mul(model.cost?.input ?? 0).div(1_000_000))
.add(new Decimal(tokens.output).mul(model.cost?.output ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.read).mul(model.cost?.cache_read ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.write).mul(model.cost?.cache_write ?? 0).div(1_000_000))
.add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
.add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
.toNumber(),
tokens,
}
}
},
)
export class BusyError extends Error {
constructor(public readonly sessionID: string) {
@@ -352,12 +379,14 @@ export namespace Session {
}
}
export async function initialize(input: {
sessionID: string
modelID: string
providerID: string
messageID: string
}) {
export const initialize = fn(
z.object({
sessionID: Identifier.schema("session"),
modelID: z.string(),
providerID: z.string(),
messageID: Identifier.schema("message"),
}),
async (input) => {
await SessionPrompt.prompt({
sessionID: input.sessionID,
messageID: input.messageID,
@@ -374,5 +403,6 @@ export namespace Session {
],
})
await Project.setInitialized(Instance.project.id)
}
},
)
}

View File

@@ -1031,7 +1031,11 @@ export namespace SessionPrompt {
break
case "finish-step":
const usage = Session.getUsage(input.model, value.usage, value.providerMetadata)
const usage = Session.getUsage({
model: input.model,
usage: value.usage,
metadata: value.providerMetadata,
})
assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens
await Session.updatePart({

View File

@@ -26,8 +26,11 @@ export const TaskTool = Tool.define("task", async () => {
async execute(params, ctx) {
const agent = await Agent.get(params.subagent_type)
if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`)
const session = await Session.create(ctx.sessionID, params.description + ` (@${agent.name} subagent)`)
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
const session = await Session.create({
parentID: ctx.sessionID,
title: params.description + ` (@${agent.name} subagent)`,
})
const msg = await Session.getMessage({ sessionID: ctx.sessionID, messageID: ctx.messageID })
if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
const messageID = Identifier.ascending("message")
const parts: Record<string, MessageV2.ToolPart> = {}