diff --git a/packages/console/app/src/routes/workspace/[id]/model-section.tsx b/packages/console/app/src/routes/workspace/[id]/model-section.tsx index 0a608603..4dea0b98 100644 --- a/packages/console/app/src/routes/workspace/[id]/model-section.tsx +++ b/packages/console/app/src/routes/workspace/[id]/model-section.tsx @@ -2,7 +2,7 @@ import { Model } from "@opencode-ai/console-core/model.js" import { query, action, useParams, createAsync, json } from "@solidjs/router" import { createMemo, For, Show } from "solid-js" import { withActor } from "~/context/auth.withActor" -import { ZenModel } from "@opencode-ai/console-core/model.js" +import { ZenData } from "@opencode-ai/console-core/model.js" import styles from "./model-section.module.css" import { querySessionInfo } from "../common" import { IconAlibaba, IconAnthropic, IconMoonshotAI, IconOpenAI, IconStealth, IconXai, IconZai } from "~/component/icon" @@ -21,7 +21,7 @@ const getModelsInfo = query(async (workspaceID: string) => { "use server" return withActor(async () => { return { - all: Object.entries(ZenModel.list()) + all: Object.entries(ZenData.list().models) .filter(([id, _model]) => !["claude-3-5-haiku"].includes(id)) .filter(([id, _model]) => !id.startsWith("an-")) .sort(([_idA, modelA], [_idB, modelB]) => modelA.name.localeCompare(modelB.name)) diff --git a/packages/console/app/src/routes/zen/handler.ts b/packages/console/app/src/routes/zen/handler.ts index 1f93d971..67b03ab0 100644 --- a/packages/console/app/src/routes/zen/handler.ts +++ b/packages/console/app/src/routes/zen/handler.ts @@ -10,7 +10,7 @@ import { Resource } from "@opencode-ai/console-resource" import { Billing } from "../../../../core/src/billing" import { Actor } from "@opencode-ai/console-core/actor.js" import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js" -import { ZenModel } from "@opencode-ai/console-core/model.js" +import { ZenData } from "@opencode-ai/console-core/model.js" import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js" import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js" import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js" @@ -39,7 +39,8 @@ export async function handler( class UserLimitError extends Error {} class ModelError extends Error {} - type Model = z.infer + type ZenData = Awaited> + type Model = ZenData["models"][string] const FREE_WORKSPACES = [ "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank @@ -66,8 +67,9 @@ export async function handler( session: input.request.headers.get("x-opencode-session"), request: input.request.headers.get("x-opencode-request"), }) - const modelInfo = validateModel(body.model) - const providerInfo = selectProvider(modelInfo) + const zenData = ZenData.list() + const modelInfo = validateModel(zenData, body.model) + const providerInfo = selectProvider(zenData, modelInfo) const authInfo = await authenticate(modelInfo, providerInfo) validateBilling(modelInfo, authInfo) validateModelSettings(authInfo) @@ -211,27 +213,29 @@ export async function handler( ) } - function validateModel(reqModel: string) { - const json = JSON.parse(Resource.ZEN_MODELS.value) - - const allModels = ZenModel.ModelsSchema.parse(json) - - if (!(reqModel in allModels)) { + function validateModel(zenData: ZenData, reqModel: string) { + if (!(reqModel in zenData.models)) { throw new ModelError(`Model ${reqModel} not supported`) } - const modelId = reqModel as keyof typeof allModels - const modelData = allModels[modelId] + const modelId = reqModel as keyof typeof zenData.models + const modelData = zenData.models[modelId] logger.metric({ model: modelId }) return { id: modelId, ...modelData } } - function selectProvider(model: Awaited>) { + function selectProvider(zenData: ZenData, model: Awaited>) { const providers = model.providers .filter((provider) => !provider.disabled) .flatMap((provider) => Array(provider.weight ?? 1).fill(provider)) - return providers[Math.floor(Math.random() * providers.length)] + const provider = providers[Math.floor(Math.random() * providers.length)] + + if (!(provider.id in zenData.providers)) { + throw new ModelError(`Provider ${provider.id} not supported`) + } + + return { ...provider, ...zenData.providers[provider.id] } } async function authenticate( diff --git a/packages/console/core/script/promote-models.ts b/packages/console/core/script/promote-models.ts index 1a5cf2fd..67c2b6f3 100755 --- a/packages/console/core/script/promote-models.ts +++ b/packages/console/core/script/promote-models.ts @@ -2,7 +2,7 @@ import { $ } from "bun" import path from "path" -import { ZenModel } from "../src/model" +import { ZenData } from "../src/model" const stage = process.argv[2] if (!stage) throw new Error("Stage is required") @@ -18,7 +18,7 @@ const value = ret if (!value) throw new Error("ZEN_MODELS not found") // validate value -ZenModel.ModelsSchema.parse(JSON.parse(value)) +ZenData.validate(JSON.parse(value)) // update the secret await $`bun sst secret set ZEN_MODELS ${value} --stage ${stage}` diff --git a/packages/console/core/script/update-models.ts b/packages/console/core/script/update-models.ts index 7740fdcf..939af616 100755 --- a/packages/console/core/script/update-models.ts +++ b/packages/console/core/script/update-models.ts @@ -3,7 +3,7 @@ import { $ } from "bun" import path from "path" import os from "os" -import { ZenModel } from "../src/model" +import { ZenData } from "../src/model" const root = path.resolve(process.cwd(), "..", "..", "..") const models = await $`bun sst secret list`.cwd(root).text() @@ -26,7 +26,7 @@ console.log("tempFile", tempFile.name) // open temp file in vim and read the file on close await $`vim ${tempFile.name}` const newValue = JSON.parse(await tempFile.text()) -ZenModel.ModelsSchema.parse(newValue) +ZenData.validate(newValue) // update the secret await $`bun sst secret set ZEN_MODELS ${JSON.stringify(newValue)}` diff --git a/packages/console/core/src/model.ts b/packages/console/core/src/model.ts index 018df5b5..300f92ed 100644 --- a/packages/console/core/src/model.ts +++ b/packages/console/core/src/model.ts @@ -7,7 +7,7 @@ import { fn } from "./util/fn" import { Actor } from "./actor" import { Resource } from "@opencode-ai/console-resource" -export namespace ZenModel { +export namespace ZenData { const ModelCostSchema = z.object({ input: z.number(), output: z.number(), @@ -16,7 +16,7 @@ export namespace ZenModel { cacheWrite1h: z.number().optional(), }) - export const ModelSchema = z.object({ + const ModelSchema = z.object({ name: z.string(), cost: ModelCostSchema, cost200K: ModelCostSchema.optional(), @@ -24,19 +24,32 @@ export namespace ZenModel { providers: z.array( z.object({ id: z.string(), - api: z.string(), - apiKey: z.string(), model: z.string(), weight: z.number().optional(), - headerMappings: z.record(z.string(), z.string()).optional(), disabled: z.boolean().optional(), }), ), }) - export const ModelsSchema = z.record(z.string(), ModelSchema) + const ProviderSchema = z.object({ + api: z.string(), + apiKey: z.string(), + headerMappings: z.record(z.string(), z.string()).optional(), + }) - export const list = fn(z.void(), () => ModelsSchema.parse(JSON.parse(Resource.ZEN_MODELS.value))) + const ModelsSchema = z.object({ + models: z.record(z.string(), ModelSchema), + providers: z.record(z.string(), ProviderSchema), + }) + + export const validate = fn(ModelsSchema, (input) => { + return input + }) + + export const list = fn(z.void(), () => { + const json = JSON.parse(Resource.ZEN_MODELS.value) + return ModelsSchema.parse(json) + }) } export namespace Model {