mirror of
https://github.com/aljazceru/opencode.git
synced 2025-12-21 09:44:21 +01:00
core: improve session API reliability with proper input validation
This commit is contained in:
@@ -106,7 +106,7 @@ export const RunCommand = cmd({
|
||||
|
||||
if (args.session) return Session.get(args.session)
|
||||
|
||||
return Session.create()
|
||||
return Session.create({})
|
||||
})()
|
||||
|
||||
if (!session) {
|
||||
|
||||
@@ -31,7 +31,6 @@ import { SessionRevert } from "../session/revert"
|
||||
import { lazy } from "../util/lazy"
|
||||
import { Todo } from "../session/todo"
|
||||
import { InstanceBootstrap } from "../project/bootstrap"
|
||||
import { Identifier } from "@/id/id"
|
||||
|
||||
const ERRORS = {
|
||||
400: {
|
||||
@@ -308,7 +307,7 @@ export namespace Server {
|
||||
validator(
|
||||
"param",
|
||||
z.object({
|
||||
id: z.string(),
|
||||
id: Session.get.schema,
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
@@ -336,7 +335,7 @@ export namespace Server {
|
||||
validator(
|
||||
"param",
|
||||
z.object({
|
||||
id: z.string(),
|
||||
id: Session.children.schema,
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
@@ -390,18 +389,10 @@ export namespace Server {
|
||||
},
|
||||
},
|
||||
}),
|
||||
validator(
|
||||
"json",
|
||||
z
|
||||
.object({
|
||||
parentID: z.string().optional(),
|
||||
title: z.string().optional(),
|
||||
})
|
||||
.optional(),
|
||||
),
|
||||
validator("json", Session.create.schema.optional()),
|
||||
async (c) => {
|
||||
const body = c.req.valid("json") ?? {}
|
||||
const session = await Session.create(body.parentID, body.title)
|
||||
const session = await Session.create(body)
|
||||
return c.json(session)
|
||||
},
|
||||
)
|
||||
@@ -424,7 +415,7 @@ export namespace Server {
|
||||
validator(
|
||||
"param",
|
||||
z.object({
|
||||
id: z.string(),
|
||||
id: Session.remove.schema,
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
@@ -495,14 +486,7 @@ export namespace Server {
|
||||
id: z.string().meta({ description: "Session ID" }),
|
||||
}),
|
||||
),
|
||||
validator(
|
||||
"json",
|
||||
z.object({
|
||||
messageID: z.string(),
|
||||
providerID: z.string(),
|
||||
modelID: z.string(),
|
||||
}),
|
||||
),
|
||||
validator("json", Session.initialize.schema.omit({ sessionID: true })),
|
||||
async (c) => {
|
||||
const sessionID = c.req.valid("param").id
|
||||
const body = c.req.valid("json")
|
||||
@@ -529,7 +513,7 @@ export namespace Server {
|
||||
validator(
|
||||
"param",
|
||||
z.object({
|
||||
id: Identifier.schema("session").meta({ description: "Session ID" }),
|
||||
id: Session.fork.schema.shape.sessionID,
|
||||
}),
|
||||
),
|
||||
validator("json", Session.fork.schema.omit({ sessionID: true })),
|
||||
@@ -614,7 +598,7 @@ export namespace Server {
|
||||
validator(
|
||||
"param",
|
||||
z.object({
|
||||
id: z.string(),
|
||||
id: Session.unshare.schema,
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
@@ -717,7 +701,7 @@ export namespace Server {
|
||||
),
|
||||
async (c) => {
|
||||
const params = c.req.valid("param")
|
||||
const message = await Session.getMessage(params.id, params.messageID)
|
||||
const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
|
||||
return c.json(message)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -144,7 +144,7 @@ export namespace SessionCompaction {
|
||||
},
|
||||
],
|
||||
})
|
||||
const usage = Session.getUsage(model.info, generated.usage, generated.providerMetadata)
|
||||
const usage = Session.getUsage({ model: model.info, usage: generated.usage, metadata: generated.providerMetadata })
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
msg.summary = true
|
||||
|
||||
@@ -93,13 +93,21 @@ export namespace Session {
|
||||
),
|
||||
}
|
||||
|
||||
export async function create(parentID?: string, title?: string) {
|
||||
return createNext({
|
||||
parentID,
|
||||
directory: Instance.directory,
|
||||
title,
|
||||
export const create = fn(
|
||||
z
|
||||
.object({
|
||||
parentID: Identifier.schema("session").optional(),
|
||||
title: z.string().optional(),
|
||||
})
|
||||
}
|
||||
.optional(),
|
||||
async (input) => {
|
||||
return createNext({
|
||||
parentID: input?.parentID,
|
||||
directory: Instance.directory,
|
||||
title: input?.title,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
export const fork = fn(
|
||||
z.object({
|
||||
@@ -132,11 +140,11 @@ export namespace Session {
|
||||
},
|
||||
)
|
||||
|
||||
export async function touch(sessionID: string) {
|
||||
export const touch = fn(Identifier.schema("session"), async (sessionID) => {
|
||||
await update(sessionID, (draft) => {
|
||||
draft.time.updated = Date.now()
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
|
||||
const result: Info = {
|
||||
@@ -170,16 +178,16 @@ export namespace Session {
|
||||
return result
|
||||
}
|
||||
|
||||
export async function get(id: string) {
|
||||
export const get = fn(Identifier.schema("session"), async (id) => {
|
||||
const read = await Storage.read<Info>(["session", Instance.project.id, id])
|
||||
return read as Info
|
||||
}
|
||||
})
|
||||
|
||||
export async function getShare(id: string) {
|
||||
export const getShare = fn(Identifier.schema("session"), async (id) => {
|
||||
return Storage.read<ShareInfo>(["share", id])
|
||||
}
|
||||
})
|
||||
|
||||
export async function share(id: string) {
|
||||
export const share = fn(Identifier.schema("session"), async (id) => {
|
||||
const cfg = await Config.get()
|
||||
if (cfg.share === "disabled") {
|
||||
throw new Error("Sharing is disabled in configuration")
|
||||
@@ -202,9 +210,9 @@ export namespace Session {
|
||||
}
|
||||
}
|
||||
return share
|
||||
}
|
||||
})
|
||||
|
||||
export async function unshare(id: string) {
|
||||
export const unshare = fn(Identifier.schema("session"), async (id) => {
|
||||
const share = await getShare(id)
|
||||
if (!share) return
|
||||
await Storage.remove(["share", id])
|
||||
@@ -212,7 +220,7 @@ export namespace Session {
|
||||
draft.share = undefined
|
||||
})
|
||||
await Share.remove(id, share.secret)
|
||||
}
|
||||
})
|
||||
|
||||
export async function update(id: string, editor: (session: Info) => void) {
|
||||
const project = Instance.project
|
||||
@@ -226,7 +234,7 @@ export namespace Session {
|
||||
return result
|
||||
}
|
||||
|
||||
export async function messages(sessionID: string) {
|
||||
export const messages = fn(Identifier.schema("session"), async (sessionID) => {
|
||||
const result = [] as MessageV2.WithParts[]
|
||||
for (const p of await Storage.list(["message", sessionID])) {
|
||||
const read = await Storage.read<MessageV2.Info>(p)
|
||||
@@ -237,16 +245,22 @@ export namespace Session {
|
||||
}
|
||||
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
export async function getMessage(sessionID: string, messageID: string) {
|
||||
export const getMessage = fn(
|
||||
z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
messageID: Identifier.schema("message"),
|
||||
}),
|
||||
async (input) => {
|
||||
return {
|
||||
info: await Storage.read<MessageV2.Info>(["message", sessionID, messageID]),
|
||||
parts: await getParts(messageID),
|
||||
}
|
||||
info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
|
||||
parts: await getParts(input.messageID),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
export async function getParts(messageID: string) {
|
||||
export const getParts = fn(Identifier.schema("message"), async (messageID) => {
|
||||
const result = [] as MessageV2.Part[]
|
||||
for (const item of await Storage.list(["part", messageID])) {
|
||||
const read = await Storage.read<MessageV2.Part>(item)
|
||||
@@ -254,7 +268,7 @@ export namespace Session {
|
||||
}
|
||||
result.sort((a, b) => (a.id > b.id ? 1 : -1))
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
export async function* list() {
|
||||
const project = Instance.project
|
||||
@@ -263,7 +277,7 @@ export namespace Session {
|
||||
}
|
||||
}
|
||||
|
||||
export async function children(parentID: string) {
|
||||
export const children = fn(Identifier.schema("session"), async (parentID) => {
|
||||
const project = Instance.project
|
||||
const result = [] as Session.Info[]
|
||||
for (const item of await Storage.list(["session", project.id])) {
|
||||
@@ -272,9 +286,9 @@ export namespace Session {
|
||||
result.push(session)
|
||||
}
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
export async function remove(sessionID: string) {
|
||||
export const remove = fn(Identifier.schema("session"), async (sessionID) => {
|
||||
const project = Instance.project
|
||||
try {
|
||||
const session = await get(sessionID)
|
||||
@@ -295,56 +309,69 @@ export namespace Session {
|
||||
} catch (e) {
|
||||
log.error(e)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
export async function updateMessage(msg: MessageV2.Info) {
|
||||
export const updateMessage = fn(MessageV2.Info, async (msg) => {
|
||||
await Storage.write(["message", msg.sessionID, msg.id], msg)
|
||||
Bus.publish(MessageV2.Event.Updated, {
|
||||
info: msg,
|
||||
})
|
||||
return msg
|
||||
}
|
||||
|
||||
export async function removeMessage(sessionID: string, messageID: string) {
|
||||
await Storage.remove(["message", sessionID, messageID])
|
||||
Bus.publish(MessageV2.Event.Removed, {
|
||||
sessionID,
|
||||
messageID,
|
||||
})
|
||||
return messageID
|
||||
}
|
||||
|
||||
export async function updatePart(part: MessageV2.Part) {
|
||||
export const removeMessage = fn(
|
||||
z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
messageID: Identifier.schema("message"),
|
||||
}),
|
||||
async (input) => {
|
||||
await Storage.remove(["message", input.sessionID, input.messageID])
|
||||
Bus.publish(MessageV2.Event.Removed, {
|
||||
sessionID: input.sessionID,
|
||||
messageID: input.messageID,
|
||||
})
|
||||
return input.messageID
|
||||
},
|
||||
)
|
||||
|
||||
export const updatePart = fn(MessageV2.Part, async (part) => {
|
||||
await Storage.write(["part", part.messageID, part.id], part)
|
||||
Bus.publish(MessageV2.Event.PartUpdated, {
|
||||
part,
|
||||
})
|
||||
return part
|
||||
}
|
||||
})
|
||||
|
||||
export function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
|
||||
export const getUsage = fn(
|
||||
z.object({
|
||||
model: z.custom<ModelsDev.Model>(),
|
||||
usage: z.custom<LanguageModelUsage>(),
|
||||
metadata: z.custom<ProviderMetadata>().optional(),
|
||||
}),
|
||||
(input) => {
|
||||
const tokens = {
|
||||
input: usage.inputTokens ?? 0,
|
||||
output: usage.outputTokens ?? 0,
|
||||
reasoning: usage?.reasoningTokens ?? 0,
|
||||
input: input.usage.inputTokens ?? 0,
|
||||
output: input.usage.outputTokens ?? 0,
|
||||
reasoning: input.usage?.reasoningTokens ?? 0,
|
||||
cache: {
|
||||
write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
|
||||
write: (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
|
||||
// @ts-expect-error
|
||||
metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
|
||||
input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
|
||||
0) as number,
|
||||
read: usage.cachedInputTokens ?? 0,
|
||||
read: input.usage.cachedInputTokens ?? 0,
|
||||
},
|
||||
}
|
||||
return {
|
||||
cost: new Decimal(0)
|
||||
.add(new Decimal(tokens.input).mul(model.cost?.input ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.output).mul(model.cost?.output ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.read).mul(model.cost?.cache_read ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.write).mul(model.cost?.cache_write ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
|
||||
.toNumber(),
|
||||
tokens,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
export class BusyError extends Error {
|
||||
constructor(public readonly sessionID: string) {
|
||||
@@ -352,12 +379,14 @@ export namespace Session {
|
||||
}
|
||||
}
|
||||
|
||||
export async function initialize(input: {
|
||||
sessionID: string
|
||||
modelID: string
|
||||
providerID: string
|
||||
messageID: string
|
||||
}) {
|
||||
export const initialize = fn(
|
||||
z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
modelID: z.string(),
|
||||
providerID: z.string(),
|
||||
messageID: Identifier.schema("message"),
|
||||
}),
|
||||
async (input) => {
|
||||
await SessionPrompt.prompt({
|
||||
sessionID: input.sessionID,
|
||||
messageID: input.messageID,
|
||||
@@ -374,5 +403,6 @@ export namespace Session {
|
||||
],
|
||||
})
|
||||
await Project.setInitialized(Instance.project.id)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1031,7 +1031,11 @@ export namespace SessionPrompt {
|
||||
break
|
||||
|
||||
case "finish-step":
|
||||
const usage = Session.getUsage(input.model, value.usage, value.providerMetadata)
|
||||
const usage = Session.getUsage({
|
||||
model: input.model,
|
||||
usage: value.usage,
|
||||
metadata: value.providerMetadata,
|
||||
})
|
||||
assistantMsg.cost += usage.cost
|
||||
assistantMsg.tokens = usage.tokens
|
||||
await Session.updatePart({
|
||||
|
||||
@@ -26,8 +26,11 @@ export const TaskTool = Tool.define("task", async () => {
|
||||
async execute(params, ctx) {
|
||||
const agent = await Agent.get(params.subagent_type)
|
||||
if (!agent) throw new Error(`Unknown agent type: ${params.subagent_type} is not a valid agent type`)
|
||||
const session = await Session.create(ctx.sessionID, params.description + ` (@${agent.name} subagent)`)
|
||||
const msg = await Session.getMessage(ctx.sessionID, ctx.messageID)
|
||||
const session = await Session.create({
|
||||
parentID: ctx.sessionID,
|
||||
title: params.description + ` (@${agent.name} subagent)`,
|
||||
})
|
||||
const msg = await Session.getMessage({ sessionID: ctx.sessionID, messageID: ctx.messageID })
|
||||
if (msg.info.role !== "assistant") throw new Error("Not an assistant message")
|
||||
const messageID = Identifier.ascending("message")
|
||||
const parts: Record<string, MessageV2.ToolPart> = {}
|
||||
|
||||
Reference in New Issue
Block a user