// Mirror of https://github.com/aljazceru/opencode.git (synced 2026-01-05 00:44:54 +01:00).
// 441 lines, 14 KiB, TypeScript.
import { z } from "zod"
|
|
import type { APIEvent } from "@solidjs/start/server"
|
|
import path from "node:path"
|
|
import { and, Database, eq, isNull, lt, or, sql } from "@opencode/console-core/drizzle/index.js"
|
|
import { KeyTable } from "@opencode/console-core/schema/key.sql.js"
|
|
import { BillingTable, UsageTable } from "@opencode/console-core/schema/billing.sql.js"
|
|
import { centsToMicroCents } from "@opencode/console-core/util/price.js"
|
|
import { Identifier } from "@opencode/console-core/identifier.js"
|
|
import { Resource } from "@opencode/console-resource"
|
|
import { Billing } from "../../../../core/src/billing"
|
|
import { Actor } from "@opencode/console-core/actor.js"
|
|
import { WorkspaceTable } from "@opencode/console-core/schema/workspace.sql.js"
|
|
|
|
export async function handler(
|
|
input: APIEvent,
|
|
opts: {
|
|
modifyBody?: (body: any) => any
|
|
setAuthHeader: (headers: Headers, apiKey: string) => void
|
|
parseApiKey: (headers: Headers) => string | undefined
|
|
onStreamPart: (chunk: string) => void
|
|
getStreamUsage: () => any
|
|
normalizeUsage: (body: any) => {
|
|
inputTokens: number
|
|
outputTokens: number
|
|
reasoningTokens?: number
|
|
cacheReadTokens?: number
|
|
cacheWrite5mTokens?: number
|
|
cacheWrite1hTokens?: number
|
|
}
|
|
},
|
|
) {
|
|
// Request-level failures that the handler reports to the client as a structured
// error payload. Each class sets `name` explicitly so that logs, stack traces and
// `String(err)` show the error kind instead of the inherited "Error".

/** The API key is missing or does not match a live key. */
class AuthError extends Error {
  override readonly name = "AuthError"
}

/** The workspace has no payment method or no remaining balance. */
class CreditsError extends Error {
  override readonly name = "CreditsError"
}

/** The workspace has reached its monthly spending limit. */
class MonthlyLimitError extends Error {
  override readonly name = "MonthlyLimitError"
}

/** The requested model is not part of the configured model list. */
class ModelError extends Error {
  override readonly name = "ModelError"
}
|
|
|
|
// Per-token rates for one pricing tier. Only `input` and `output` are required;
// the cache rates are optional.
const ModelCostSchema = z.object({
  input: z.number(),
  output: z.number(),
  cacheRead: z.number().optional(),
  cacheWrite5m: z.number().optional(),
  cacheWrite1h: z.number().optional(),
})

// One upstream provider able to serve a model.
const ModelProviderSchema = z.object({
  id: z.string(),
  api: z.string(),
  apiKey: z.string(),
  model: z.string(),
  // Relative share of traffic used by selectProvider; a missing weight counts as 1.
  weight: z.number().optional(),
  // Maps an outgoing header name to the incoming header whose value it should carry.
  headerMappings: z.record(z.string(), z.string()).optional(),
  disabled: z.boolean().optional(),
})

// A model entry as stored in the ZEN_MODELS resource.
const ModelSchema = z.object({
  cost: ModelCostSchema,
  // Alternative rates used by trackUsage for prompts above 200K tokens.
  cost200K: ModelCostSchema.optional(),
  allowAnonymous: z.boolean().optional(),
  providers: z.array(ModelProviderSchema),
})

type Model = z.infer<typeof ModelSchema>
|
|
|
|
// Workspaces flagged as free: validateBilling skips them and trackUsage records
// their usage at zero cost.
const FREE_WORKSPACES = [
  "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
]

// Console-backed logging helpers.
const logger = {
  /** Writes one structured metric line: the `_metric:` prefix followed by the JSON payload. */
  metric(values: Record<string, any>) {
    console.log("_metric:" + JSON.stringify(values))
  },
  log: console.log,
  /** Verbose output; emitted on every stage except production. */
  debug(message: string) {
    if (Resource.App.stage !== "production") console.debug(message)
  },
}
|
|
|
|
try {
  const url = new URL(input.request.url)
  const body = await input.request.json()
  logger.debug(JSON.stringify(body))
  logger.metric({
    // NOTE(review): the key is spelled "is_tream" (probably meant "is_stream"). It is kept
    // unchanged because metric consumers may already depend on this name.
    is_tream: !!body.stream,
    session: input.request.headers.get("x-opencode-session"),
    request: input.request.headers.get("x-opencode-request"),
  })
  const modelInfo = validateModel(body.model)
  const providerInfo = selectProvider(modelInfo)
  const authInfo = await authenticate(modelInfo)
  validateBilling(modelInfo, authInfo)
  logger.metric({ provider: providerInfo.id })

  // Request to model provider
  const startTimestamp = Date.now()

  // Build the upstream URL by concatenation. Only the path segment goes through
  // path.posix.join: running the whole URL through it collapses "https://" and
  // rewrites "//" or ".." sequences that appear inside the query string.
  const upstreamPath = url.pathname.replace(/^\/zen/, "")
  const upstreamUrl =
    providerInfo.api.replace(/\/+$/, "") + (upstreamPath ? path.posix.join("/", upstreamPath) : "") + url.search

  // Work on a copy so the caller's request headers are never mutated.
  const upstreamHeaders = new Headers(input.request.headers)
  upstreamHeaders.delete("host")
  upstreamHeaders.delete("content-length")
  opts.setAuthHeader(upstreamHeaders, providerInfo.apiKey)
  for (const [target, source] of Object.entries(providerInfo.headerMappings ?? {})) {
    const value = upstreamHeaders.get(source)
    // Skip mappings whose source header is absent instead of sending the string "null".
    if (value !== null) upstreamHeaders.set(target, value)
  }

  const res = await fetch(upstreamUrl, {
    method: "POST",
    headers: upstreamHeaders,
    body: JSON.stringify({
      ...(opts.modifyBody?.(body) ?? body),
      model: providerInfo.model,
    }),
  })

  // Scrub response headers
  const keepHeaders = ["content-type", "cache-control"]
  const resHeaders = new Headers()
  for (const [name, value] of res.headers.entries()) {
    if (keepHeaders.includes(name.toLowerCase())) resHeaders.set(name, value)
  }

  // Handle non-streaming response
  if (!body.stream) {
    const json = await res.json()
    const payload = JSON.stringify(json)
    logger.metric({ response_length: payload.length })
    logger.debug(payload)
    // Same guard as the streaming branch: only bill when the provider reported usage
    // (an upstream error payload typically carries none).
    if (json?.usage) {
      await trackUsage(authInfo, modelInfo, providerInfo.id, json.usage)
      await reload(authInfo)
    }
    return new Response(payload, {
      status: res.status,
      statusText: res.statusText,
      headers: resHeaders,
    })
  }

  // Handle streaming response
  const stream = new ReadableStream<Uint8Array>({
    async start(c) {
      const reader = res.body?.getReader()
      if (!reader) {
        // No upstream body: close right away so the client response does not hang.
        c.close()
        return
      }

      const decoder = new TextDecoder()
      let buffer = ""
      let responseLength = 0

      // Forward every chunk untouched while feeding complete "\n\n"-delimited parts to
      // opts.onStreamPart. A loop is used instead of a recursive promise chain so the
      // chain does not grow with the number of chunks.
      while (true) {
        const { done, value } = await reader.read()
        if (done) break

        if (responseLength === 0) {
          logger.metric({ time_to_first_byte: Date.now() - startTimestamp })
        }
        responseLength += value.length
        buffer += decoder.decode(value, { stream: true })

        const parts = buffer.split("\n\n")
        // The last element is an incomplete part (or ""); keep it for the next chunk.
        buffer = parts.pop() ?? ""

        for (const part of parts) {
          logger.debug(part)
          opts.onStreamPart(part.trim())
        }

        c.enqueue(value)
      }

      logger.metric({ response_length: responseLength })
      const usage = opts.getStreamUsage()
      if (usage) {
        await trackUsage(authInfo, modelInfo, providerInfo.id, usage)
        await reload(authInfo)
      }
      c.close()
    },
  })

  return new Response(stream, {
    status: res.status,
    statusText: res.statusText,
    headers: resHeaders,
  })
} catch (error: unknown) {
  // Anything can be thrown; normalise to an Error before reading its fields.
  const err = error instanceof Error ? error : new Error(String(error))
  const errorType = err.constructor.name

  logger.metric({
    "error.type": errorType,
    "error.message": err.message,
  })

  // Note: both top level "type" and "error.type" fields are used by the @ai-sdk/anthropic client to render the error message.
  const isRequestError =
    err instanceof AuthError ||
    err instanceof CreditsError ||
    err instanceof MonthlyLimitError ||
    err instanceof ModelError

  if (isRequestError)
    return new Response(
      JSON.stringify({
        type: "error",
        error: { type: errorType, message: err.message },
      }),
      { status: 401 },
    )

  return new Response(
    JSON.stringify({
      type: "error",
      error: {
        type: "error",
        message: err.message,
      },
    }),
    { status: 500 },
  )
}
|
|
|
|
/**
 * Looks up the requested model in the ZEN_MODELS resource.
 *
 * The resource value is parsed as JSON and validated against ModelSchema on every call.
 *
 * @param reqModel Model id taken from the request body (untrusted input).
 * @returns The model definition with its id attached.
 * @throws ModelError when the id is not a configured model.
 */
function validateModel(reqModel: string) {
  const allModels = z.record(z.string(), ModelSchema).parse(JSON.parse(Resource.ZEN_MODELS.value))

  // Own-property lookup only: the `in` operator also matches inherited keys such as
  // "constructor" or "toString", which would let a bogus model through. The typeof
  // check covers request bodies where `model` is missing or not a string.
  const modelData =
    typeof reqModel === "string" && Object.prototype.hasOwnProperty.call(allModels, reqModel)
      ? allModels[reqModel]
      : undefined
  if (!modelData) {
    throw new ModelError(`Model ${reqModel} not supported`)
  }

  logger.metric({ model: reqModel })

  return { id: reqModel, ...modelData }
}
|
|
|
|
/**
 * Picks one enabled provider for the model at random, proportionally to its weight.
 *
 * A missing weight counts as 1. Providers that are disabled or whose weight is not a
 * positive finite number are never selected. Fractional weights are supported.
 *
 * @param model Model definition whose `providers` list is drawn from.
 * @returns The selected provider entry.
 * @throws Error when no provider is eligible, instead of returning `undefined`.
 */
function selectProvider(model: Model) {
  type Provider = Model["providers"][number]

  const weightOf = (provider: Provider): number => provider.weight ?? 1
  const eligible: Provider[] = model.providers.filter((provider: Provider) => {
    const weight = weightOf(provider)
    return !provider.disabled && Number.isFinite(weight) && weight > 0
  })
  const totalWeight = eligible.reduce((sum: number, provider: Provider) => sum + weightOf(provider), 0)

  // Walk the cumulative weights until the random ticket is used up. This avoids
  // materialising one array slot per unit of weight.
  let ticket = Math.random() * totalWeight
  let last: Provider | undefined
  for (const provider of eligible) {
    last = provider
    ticket -= weightOf(provider)
    if (ticket < 0) return provider
  }

  if (!last) throw new Error("No enabled provider is configured for the requested model")
  // Only reachable through floating-point rounding at the very top of the range.
  return last
}
|
|
|
|
/**
 * Resolves the caller's API key to its workspace and billing state.
 *
 * @param model Requested model; its `allowAnonymous` flag permits requests without a key.
 * @returns The key id, workspace id, billing snapshot and free-workspace flag, or
 *   `undefined` for an anonymous request to a model that allows it.
 * @throws AuthError when the key is missing (and anonymous use is not allowed) or unknown.
 */
async function authenticate(model: Model) {
  const apiKey = opts.parseApiKey(input.request.headers)
  if (!apiKey && model.allowAnonymous) return
  if (!apiKey) throw new AuthError("Missing API key.")

  // One row per live (non-deleted) key, joined with its workspace and billing record.
  const data = await Database.use((tx) =>
    tx
      .select({
        apiKey: KeyTable.id,
        workspaceID: KeyTable.workspaceID,
        balance: BillingTable.balance,
        paymentMethodID: BillingTable.paymentMethodID,
        monthlyLimit: BillingTable.monthlyLimit,
        monthlyUsage: BillingTable.monthlyUsage,
        timeMonthlyUsageUpdated: BillingTable.timeMonthlyUsageUpdated,
      })
      .from(KeyTable)
      .innerJoin(WorkspaceTable, eq(WorkspaceTable.id, KeyTable.workspaceID))
      .innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
      .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
      .then(([firstRow]) => firstRow),
  )
  if (!data) throw new AuthError("Invalid API key.")

  const { apiKey: apiKeyId, workspaceID, ...billing } = data
  logger.metric({
    api_key: apiKeyId,
    workspace: workspaceID,
  })

  return {
    apiKeyId,
    workspaceID,
    billing: {
      paymentMethodID: billing.paymentMethodID,
      balance: billing.balance,
      monthlyLimit: billing.monthlyLimit,
      monthlyUsage: billing.monthlyUsage,
      timeMonthlyUsageUpdated: billing.timeMonthlyUsageUpdated,
    },
    isFree: FREE_WORKSPACES.includes(workspaceID),
  }
}
|
|
|
|
/**
 * Rejects the request when the workspace is not allowed to spend.
 *
 * Anonymous callers, free workspaces and models that allow anonymous use are not checked.
 *
 * @param model Requested model.
 * @param authInfo Result of `authenticate`; `undefined` for anonymous requests.
 * @throws CreditsError when there is no payment method or the balance is not positive.
 * @throws MonthlyLimitError when the monthly limit was reached in the current UTC month.
 */
function validateBilling(model: Model, authInfo: Awaited<ReturnType<typeof authenticate>>) {
  const exempt = !authInfo || authInfo.isFree || model.allowAnonymous
  if (exempt) return

  const { paymentMethodID, balance, monthlyLimit, monthlyUsage, timeMonthlyUsageUpdated } = authInfo.billing
  if (!paymentMethodID) throw new CreditsError("No payment method")
  if (balance <= 0) throw new CreditsError("Insufficient balance")

  // The limit only applies when a limit, a usage figure and its timestamp are all present.
  if (!monthlyLimit || !monthlyUsage || !timeMonthlyUsageUpdated) return
  if (!(monthlyUsage >= centsToMicroCents(monthlyLimit * 100))) return

  // Usage recorded in an earlier month is stale and does not count against the limit.
  const today = new Date()
  const usageIsFromThisMonth =
    today.getUTCFullYear() === timeMonthlyUsageUpdated.getUTCFullYear() &&
    today.getUTCMonth() === timeMonthlyUsageUpdated.getUTCMonth()
  if (usageIsFromThisMonth) {
    throw new MonthlyLimitError(`You have reached your monthly spending limit of $${monthlyLimit}.`)
  }
}
|
|
|
|
/**
 * Prices one completed request, emits token/cost metrics and, for authenticated
 * callers, records the usage and charges the workspace.
 *
 * @param authInfo Result of `authenticate`; when `undefined` only metrics are emitted.
 * @param modelInfo Validated model, including its pricing tiers.
 * @param providerId Id of the provider that served the request.
 * @param usage Raw provider usage payload; it is only read through `opts.normalizeUsage`.
 */
async function trackUsage(
  authInfo: Awaited<ReturnType<typeof authenticate>>,
  modelInfo: ReturnType<typeof validateModel>,
  providerId: string,
  usage: any,
) {
  const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
    opts.normalizeUsage(usage)

  // The tier check uses the normalised counts, with absent optional counts treated as 0.
  // Reading these fields from the raw `usage` payload gives NaN whenever one is missing,
  // so the 200K tier would not be applied.
  const promptTokens = inputTokens + (cacheReadTokens ?? 0) + (cacheWrite5mTokens ?? 0) + (cacheWrite1hTokens ?? 0)
  const modelCost = modelInfo.cost200K && promptTokens > 200_000 ? modelInfo.cost200K : modelInfo.cost

  // Each component is rate * tokens * 100; the total is treated as cents below.
  // Optional components stay undefined when there are no tokens or no configured rate.
  const optionalCost = (rate: number | undefined, tokens: number | undefined) =>
    rate && tokens ? rate * tokens * 100 : undefined

  const inputCost = modelCost.input * inputTokens * 100
  const outputCost = modelCost.output * outputTokens * 100
  // Reasoning tokens are charged at the output rate.
  const reasoningCost = reasoningTokens ? modelCost.output * reasoningTokens * 100 : undefined
  const cacheReadCost = optionalCost(modelCost.cacheRead, cacheReadTokens)
  const cacheWrite5mCost = optionalCost(modelCost.cacheWrite5m, cacheWrite5mTokens)
  const cacheWrite1hCost = optionalCost(modelCost.cacheWrite1h, cacheWrite1hTokens)
  const totalCostInCent =
    inputCost +
    outputCost +
    (reasoningCost ?? 0) +
    (cacheReadCost ?? 0) +
    (cacheWrite5mCost ?? 0) +
    (cacheWrite1hCost ?? 0)

  logger.metric({
    "tokens.input": inputTokens,
    "tokens.output": outputTokens,
    "tokens.reasoning": reasoningTokens,
    "tokens.cache_read": cacheReadTokens,
    "tokens.cache_write_5m": cacheWrite5mTokens,
    "tokens.cache_write_1h": cacheWrite1hTokens,
    "cost.input": Math.round(inputCost),
    "cost.output": Math.round(outputCost),
    "cost.reasoning": reasoningCost ? Math.round(reasoningCost) : undefined,
    "cost.cache_read": cacheReadCost ? Math.round(cacheReadCost) : undefined,
    "cost.cache_write_5m": cacheWrite5mCost ? Math.round(cacheWrite5mCost) : undefined,
    "cost.cache_write_1h": cacheWrite1hCost ? Math.round(cacheWrite1hCost) : undefined,
    "cost.total": Math.round(totalCostInCent),
  })

  // Anonymous requests are measured but not recorded or billed.
  if (!authInfo) return

  const cost = authInfo.isFree ? 0 : centsToMicroCents(totalCostInCent)

  // The usage row and the balance update are written in one transaction.
  await Database.transaction(async (tx) => {
    await tx.insert(UsageTable).values({
      workspaceID: authInfo.workspaceID,
      id: Identifier.create("usage"),
      model: modelInfo.id,
      provider: providerId,
      inputTokens,
      outputTokens,
      reasoningTokens,
      cacheReadTokens,
      cacheWrite5mTokens,
      cacheWrite1hTokens,
      cost,
    })
    await tx
      .update(BillingTable)
      .set({
        balance: sql`${BillingTable.balance} - ${cost}`,
        // Monthly usage accumulates within the same month and restarts from this cost otherwise.
        monthlyUsage: sql`
          CASE
            WHEN MONTH(${BillingTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${BillingTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${BillingTable.monthlyUsage} + ${cost}
            ELSE ${cost}
          END
        `,
        timeMonthlyUsageUpdated: sql`now()`,
      })
      .where(eq(BillingTable.workspaceID, authInfo.workspaceID))
  })

  // Stamp the key's last-used time outside the billing transaction.
  await Database.use((tx) =>
    tx
      .update(KeyTable)
      .set({ timeUsed: sql`now()` })
      .where(eq(KeyTable.id, authInfo.apiKeyId)),
  )
}
|
|
|
|
/**
 * Triggers a balance reload for the workspace when one is due.
 *
 * A reload is due when reload is enabled on the billing row, the balance is below
 * `Billing.CHARGE_THRESHOLD`, and no unexpired reload lock exists. The conditional
 * update sets a one-minute lock; `Billing.reload()` runs only if that update changed a row.
 *
 * @param authInfo Result of `authenticate`; nothing happens for anonymous requests.
 */
async function reload(authInfo: Awaited<ReturnType<typeof authenticate>>) {
  if (!authInfo) return
  const { workspaceID } = authInfo

  const lockIsFree = or(
    isNull(BillingTable.timeReloadLockedTill),
    lt(BillingTable.timeReloadLockedTill, sql`now()`),
  )
  const reloadIsDue = and(
    eq(BillingTable.workspaceID, workspaceID),
    eq(BillingTable.reload, true),
    lt(BillingTable.balance, centsToMicroCents(Billing.CHARGE_THRESHOLD)),
    lockIsFree,
  )

  // Setting the lock doubles as the check: zero affected rows means no reload is due
  // or the lock is currently held.
  const claim = await Database.use((tx) =>
    tx
      .update(BillingTable)
      .set({
        timeReloadLockedTill: sql`now() + interval 1 minute`,
      })
      .where(reloadIsDue),
  )
  if (claim.rowsAffected === 0) return

  await Actor.provide("system", { workspaceID }, async () => {
    await Billing.reload()
  })
}
|
|
}
|