// Mirror of https://github.com/aljazceru/opencode.git (synced 2025-12-24 03:04:21 +01:00).
// Original listing: 550 lines, 18 KiB, TypeScript.
import type { APIEvent } from "@solidjs/start/server"
|
|
import { and, Database, eq, isNull, lt, or, sql } from "@opencode-ai/console-core/drizzle/index.js"
|
|
import { KeyTable } from "@opencode-ai/console-core/schema/key.sql.js"
|
|
import { BillingTable, UsageTable } from "@opencode-ai/console-core/schema/billing.sql.js"
|
|
import { centsToMicroCents } from "@opencode-ai/console-core/util/price.js"
|
|
import { Identifier } from "@opencode-ai/console-core/identifier.js"
|
|
import { Billing } from "@opencode-ai/console-core/billing.js"
|
|
import { Actor } from "@opencode-ai/console-core/actor.js"
|
|
import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js"
|
|
import { ZenData } from "@opencode-ai/console-core/model.js"
|
|
import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
|
|
import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
|
|
import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
|
|
import { logger } from "./logger"
|
|
import { AuthError, CreditsError, MonthlyLimitError, UserLimitError, ModelError } from "./error"
|
|
import {
|
|
createBodyConverter,
|
|
createStreamPartConverter,
|
|
createResponseConverter,
|
|
} from "./provider/provider"
|
|
import { anthropicHelper } from "./provider/anthropic"
|
|
import { openaiHelper } from "./provider/openai"
|
|
import { oaCompatHelper } from "./provider/openai-compatible"
|
|
|
|
/** Resolved result of `ZenData.list()`: the model and provider catalogue used by the handler. */
type ZenData = Awaited<ReturnType<typeof ZenData.list>>

/** A single entry of the `models` map in the catalogue. */
type Model = ZenData["models"][string]
|
|
|
|
/**
 * Proxies a single model request to an upstream provider.
 *
 * Flow: parse the JSON request body, resolve the requested model and pick a
 * provider, authenticate the caller's API key, enforce billing / spending
 * limits / per-workspace model settings, forward the (format-converted)
 * request upstream, relay the response back in `opts.format`, then record
 * usage and trigger a balance reload when applicable.
 *
 * @param input - Server event carrying the inbound `Request`.
 * @param opts.format - Wire format spoken by the caller; request and response
 *   bodies are converted between this and the selected provider's format.
 * @param opts.parseApiKey - Extracts the caller's API key from the request headers.
 * @returns The upstream status with a JSON body (non-streaming) or a relayed
 *   stream (when `body.stream` is truthy). `AuthError`, `CreditsError`,
 *   `MonthlyLimitError`, `UserLimitError` and `ModelError` are answered with a
 *   401 JSON error; anything else with a 500 JSON error.
 */
export async function handler(
  input: APIEvent,
  opts: {
    format: ZenData.Format
    parseApiKey: (headers: Headers) => string | undefined
  },
) {
  // Workspaces that skip billing checks and are recorded at zero cost.
  const FREE_WORKSPACES = [
    "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
    "wrk_01K6W1A3VE0KMNVSCQT43BG2SX", // opencode bench
  ]

  try {
    const body = await input.request.json()
    logger.metric({
      // Was misspelled "is_tream"; queries keyed on the old name need updating.
      is_stream: !!body.stream,
      session: input.request.headers.get("x-opencode-session"),
      request: input.request.headers.get("x-opencode-request"),
    })
    const zenData = ZenData.list()
    const modelInfo = validateModel(zenData, body.model)
    const providerInfo = selectProvider(
      zenData,
      modelInfo,
      input.request.headers.get("x-real-ip") ?? "",
    )
    const authInfo = await authenticate(modelInfo, providerInfo)
    validateBilling(modelInfo, authInfo)
    validateModelSettings(authInfo)
    updateProviderKey(authInfo, providerInfo)
    logger.metric({ provider: providerInfo.id })

    // Request to model provider
    const startTimestamp = Date.now()
    const reqUrl = providerInfo.modifyUrl(providerInfo.api)
    const reqBody = JSON.stringify(
      providerInfo.modifyBody({
        ...createBodyConverter(opts.format, providerInfo.format)(body),
        model: providerInfo.model,
      }),
    )
    logger.debug("REQUEST URL: " + reqUrl)
    logger.debug("REQUEST: " + reqBody)
    const res = await fetch(reqUrl, {
      method: "POST",
      headers: (() => {
        // Work on a copy: the inbound request's Headers must not be mutated
        // (and may be immutable depending on the runtime).
        const headers = new Headers(input.request.headers)
        headers.delete("host")
        headers.delete("content-length")
        providerInfo.modifyHeaders(headers, body, providerInfo.apiKey)
        Object.entries(providerInfo.headerMappings ?? {}).forEach(([k, v]) => {
          // Only mirror headers that are actually present; setting a missing
          // one would send the literal string "null" upstream.
          const mapped = headers.get(v)
          if (mapped !== null) headers.set(k, mapped)
        })
        return headers
      })(),
      body: reqBody,
    })

    // Scrub response headers
    const resHeaders = new Headers()
    const keepHeaders = ["content-type", "cache-control"]
    for (const [k, v] of res.headers.entries()) {
      if (keepHeaders.includes(k.toLowerCase())) {
        resHeaders.set(k, v)
      }
    }
    logger.debug("STATUS: " + res.status + " " + res.statusText)
    if (res.status === 400 || res.status === 503) {
      // Read a clone so the original body is still available to relay below.
      logger.debug("RESPONSE: " + (await res.clone().text()))
    }

    // Handle non-streaming response
    if (!body.stream) {
      const responseConverter = createResponseConverter(providerInfo.format, opts.format)
      const json = await res.json()
      const resBody = JSON.stringify(responseConverter(json))
      logger.metric({ response_length: resBody.length })
      logger.debug("RESPONSE: " + resBody)
      // Upstream error payloads carry no usage; skip accounting for them, as
      // the streaming path below does.
      if (json?.usage) {
        await trackUsage(authInfo, modelInfo, providerInfo, json.usage)
      }
      await reload(authInfo)
      return new Response(resBody, {
        status: res.status,
        statusText: res.statusText,
        headers: resHeaders,
      })
    }

    // Handle streaming response
    const streamConverter = createStreamPartConverter(providerInfo.format, opts.format)
    const usageParser = providerInfo.createUsageParser()
    const needsConversion = providerInfo.format !== opts.format
    const stream = new ReadableStream({
      start(c) {
        const upstream = res.body
        if (!upstream) {
          // Nothing to relay: close right away so the client is not left
          // waiting on a stream that never ends.
          c.close()
          return
        }
        const reader = upstream.getReader()
        const decoder = new TextDecoder()
        const encoder = new TextEncoder()
        let buffer = ""
        let responseLength = 0

        async function pump(): Promise<void> {
          for (;;) {
            const { done, value } = await reader.read()
            if (done) {
              logger.metric({
                response_length: responseLength,
                "timestamp.last_byte": Date.now(),
              })
              const usage = usageParser.retrieve()
              if (usage) {
                await trackUsage(authInfo, modelInfo, providerInfo, usage)
                await reload(authInfo)
              }
              c.close()
              return
            }

            if (responseLength === 0) {
              const now = Date.now()
              logger.metric({
                time_to_first_byte: now - startTimestamp,
                "timestamp.first_byte": now,
              })
            }
            responseLength += value.length
            buffer += decoder.decode(value, { stream: true })

            // Parts are delimited by a blank line; the trailing fragment may
            // be incomplete, so it stays in the buffer until the next chunk.
            const parts = buffer.split("\n\n")
            buffer = parts.pop() ?? ""

            for (let part of parts) {
              logger.debug("PART: " + part)

              part = part.trim()
              usageParser.parse(part)

              if (needsConversion) {
                part = streamConverter(part)
                c.enqueue(encoder.encode(part + "\n\n"))
              }
            }

            // Same format on both sides: pass the raw bytes through untouched.
            if (!needsConversion) {
              c.enqueue(value)
            }
          }
        }

        return pump()
      },
    })

    return new Response(stream, {
      status: res.status,
      statusText: res.statusText,
      headers: resHeaders,
    })
  } catch (error: any) {
    logger.metric({
      "error.type": error.constructor.name,
      "error.message": error.message,
    })

    // Note: both top level "type" and "error.type" fields are used by the @ai-sdk/anthropic client to render the error message.
    if (
      error instanceof AuthError ||
      error instanceof CreditsError ||
      error instanceof MonthlyLimitError ||
      error instanceof UserLimitError ||
      error instanceof ModelError
    ) {
      return new Response(
        JSON.stringify({
          type: "error",
          error: { type: error.constructor.name, message: error.message },
        }),
        { status: 401 },
      )
    }

    return new Response(
      JSON.stringify({
        type: "error",
        error: {
          type: "error",
          message: error.message,
        },
      }),
      { status: 500 },
    )
  }

  /**
   * Looks up the requested model in the catalogue.
   *
   * @returns The catalogue entry with its `id` attached.
   * @throws ModelError when the model is not listed in `zenData.models`.
   */
  function validateModel(zenData: ZenData, reqModel: string) {
    if (!(reqModel in zenData.models)) {
      throw new ModelError(`Model ${reqModel} not supported`)
    }
    const modelId = reqModel as keyof typeof zenData.models
    const modelData = zenData.models[modelId]

    logger.metric({ model: modelId })

    return { id: modelId, ...modelData }
  }

  /**
   * Picks one of the model's enabled providers, weighted by `provider.weight`
   * (default 1), using the caller's IP so the choice is deterministic per IP.
   *
   * @returns The model's provider entry merged with the catalogue's provider
   *   settings and the format-specific helper (anthropic / openai /
   *   openai-compatible).
   * @throws ModelError when the model has no enabled provider, or when the
   *   chosen provider is missing from `zenData.providers`.
   */
  function selectProvider(
    zenData: ZenData,
    model: Awaited<ReturnType<typeof validateModel>>,
    ip: string,
  ) {
    // Each provider appears `weight` times so a uniform pick is weighted.
    const providers = model.providers
      .filter((provider) => !provider.disabled)
      .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))

    if (providers.length === 0) {
      throw new ModelError(`No provider available for model ${model.id}`)
    }

    // Use the last 2 characters of IP address to select a provider
    const lastChars = ip.slice(-2)
    const index = parseInt(lastChars, 16) % providers.length
    // parseInt yields NaN for non-hex input (e.g. an empty IP); fall back to 0.
    const provider = providers[index || 0]

    if (!(provider.id in zenData.providers)) {
      throw new ModelError(`Provider ${provider.id} not supported`)
    }

    const format = zenData.providers[provider.id].format

    return {
      ...provider,
      ...zenData.providers[provider.id],
      ...(format === "anthropic"
        ? anthropicHelper
        : format === "openai"
          ? openaiHelper
          : oaCompatHelper),
    }
  }

  /**
   * Resolves the caller's API key to its workspace, billing state, user
   * limits and any workspace-supplied provider credentials.
   *
   * @returns `undefined` for key-less requests to models with
   *   `allowAnonymous`; otherwise the resolved auth record.
   * @throws AuthError when the key is missing (and the model is not
   *   anonymous) or does not match a non-deleted key.
   */
  async function authenticate(
    model: Awaited<ReturnType<typeof validateModel>>,
    providerInfo: Awaited<ReturnType<typeof selectProvider>>,
  ) {
    const apiKey = opts.parseApiKey(input.request.headers)
    if (!apiKey) {
      if (model.allowAnonymous) return
      throw new AuthError("Missing API key.")
    }

    const data = await Database.use((tx) =>
      tx
        .select({
          apiKey: KeyTable.id,
          workspaceID: KeyTable.workspaceID,
          billing: {
            balance: BillingTable.balance,
            paymentMethodID: BillingTable.paymentMethodID,
            monthlyLimit: BillingTable.monthlyLimit,
            monthlyUsage: BillingTable.monthlyUsage,
            timeMonthlyUsageUpdated: BillingTable.timeMonthlyUsageUpdated,
          },
          user: {
            id: UserTable.id,
            monthlyLimit: UserTable.monthlyLimit,
            monthlyUsage: UserTable.monthlyUsage,
            timeMonthlyUsageUpdated: UserTable.timeMonthlyUsageUpdated,
          },
          provider: {
            credentials: ProviderTable.credentials,
          },
          // Non-null only when the left join below found a ModelTable row for
          // this workspace/model; that presence is what `isDisabled` reports.
          // NOTE(review): presumably rows exist only for disabled models — confirm.
          timeDisabled: ModelTable.timeCreated,
        })
        .from(KeyTable)
        .innerJoin(WorkspaceTable, eq(WorkspaceTable.id, KeyTable.workspaceID))
        .innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
        .innerJoin(
          UserTable,
          and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)),
        )
        .leftJoin(
          ModelTable,
          and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, model.id)),
        )
        .leftJoin(
          ProviderTable,
          and(
            eq(ProviderTable.workspaceID, KeyTable.workspaceID),
            eq(ProviderTable.provider, providerInfo.id),
          ),
        )
        .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
        .then((rows) => rows[0]),
    )

    if (!data) throw new AuthError("Invalid API key.")
    logger.metric({
      api_key: data.apiKey,
      workspace: data.workspaceID,
    })

    return {
      apiKeyId: data.apiKey,
      workspaceID: data.workspaceID,
      billing: data.billing,
      user: data.user,
      provider: data.provider,
      isFree: FREE_WORKSPACES.includes(data.workspaceID),
      isDisabled: !!data.timeDisabled,
    }
  }

  /**
   * Rejects the request when the workspace cannot pay for it.
   *
   * Skipped entirely for anonymous requests, workspaces using their own
   * provider credentials, free workspaces, and models with `allowAnonymous`.
   *
   * @throws CreditsError when there is no payment method or the balance is
   *   not positive.
   * @throws MonthlyLimitError when the workspace reached its monthly limit
   *   in the current UTC month.
   * @throws UserLimitError when the user reached their monthly limit in the
   *   current UTC month.
   */
  function validateBilling(model: Model, authInfo: Awaited<ReturnType<typeof authenticate>>) {
    if (!authInfo) return
    if (authInfo.provider?.credentials) return
    if (authInfo.isFree) return
    if (model.allowAnonymous) return

    const billing = authInfo.billing
    if (!billing.paymentMethodID)
      throw new CreditsError(
        `No payment method. Add a payment method here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
      )
    if (billing.balance <= 0)
      throw new CreditsError(
        `Insufficient balance. Manage your billing here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
      )

    // Stored monthly usage only counts if it was last updated in the current
    // UTC month; older values belong to a previous period.
    const now = new Date()
    const currentYear = now.getUTCFullYear()
    const currentMonth = now.getUTCMonth()
    if (
      billing.monthlyLimit &&
      billing.monthlyUsage &&
      billing.timeMonthlyUsageUpdated &&
      billing.monthlyUsage >= centsToMicroCents(billing.monthlyLimit * 100)
    ) {
      const dateYear = billing.timeMonthlyUsageUpdated.getUTCFullYear()
      const dateMonth = billing.timeMonthlyUsageUpdated.getUTCMonth()
      if (currentYear === dateYear && currentMonth === dateMonth)
        throw new MonthlyLimitError(
          `Your workspace has reached its monthly spending limit of $${billing.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
        )
    }

    if (
      authInfo.user.monthlyLimit &&
      authInfo.user.monthlyUsage &&
      authInfo.user.timeMonthlyUsageUpdated &&
      authInfo.user.monthlyUsage >= centsToMicroCents(authInfo.user.monthlyLimit * 100)
    ) {
      const dateYear = authInfo.user.timeMonthlyUsageUpdated.getUTCFullYear()
      const dateMonth = authInfo.user.timeMonthlyUsageUpdated.getUTCMonth()
      if (currentYear === dateYear && currentMonth === dateMonth)
        throw new UserLimitError(
          `You have reached your monthly spending limit of $${authInfo.user.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/members`,
        )
    }
  }

  /**
   * Rejects the request when the workspace has disabled the requested model.
   *
   * @throws ModelError when `authInfo.isDisabled` is set.
   */
  function validateModelSettings(authInfo: Awaited<ReturnType<typeof authenticate>>) {
    if (!authInfo) return
    if (authInfo.isDisabled) throw new ModelError("Model is disabled")
  }

  /**
   * Swaps in the workspace's own provider credentials, when it has any, by
   * overwriting `providerInfo.apiKey` in place.
   */
  function updateProviderKey(
    authInfo: Awaited<ReturnType<typeof authenticate>>,
    providerInfo: Awaited<ReturnType<typeof selectProvider>>,
  ) {
    if (!authInfo) return
    if (!authInfo.provider?.credentials) return
    providerInfo.apiKey = authInfo.provider.credentials
  }

  /**
   * Computes the cost of a completed request, logs token and cost metrics,
   * and (for authenticated callers) persists the usage row, debits the
   * workspace balance, updates workspace and user monthly usage, and stamps
   * the API key's last-used time.
   *
   * @param usage - Raw provider usage payload; normalized through
   *   `providerInfo.normalizeUsage`.
   */
  async function trackUsage(
    authInfo: Awaited<ReturnType<typeof authenticate>>,
    modelInfo: ReturnType<typeof validateModel>,
    providerInfo: Awaited<ReturnType<typeof selectProvider>>,
    usage: any,
  ) {
    const {
      inputTokens,
      outputTokens,
      reasoningTokens,
      cacheReadTokens,
      cacheWrite5mTokens,
      cacheWrite1hTokens,
    } = providerInfo.normalizeUsage(usage)

    // Use the `cost200K` price table, when the model defines one, once the
    // total input side (fresh + cached tokens) exceeds 200k tokens.
    const modelCost =
      modelInfo.cost200K &&
      inputTokens + (cacheReadTokens ?? 0) + (cacheWrite5mTokens ?? 0) + (cacheWrite1hTokens ?? 0) >
        200_000
        ? modelInfo.cost200K
        : modelInfo.cost

    // The `* 100` converts to cents (see `totalCostInCent`).
    // NOTE(review): assumes model prices are expressed in dollars per token — confirm.
    const inputCost = modelCost.input * inputTokens * 100
    const outputCost = modelCost.output * outputTokens * 100
    // Reasoning tokens are priced at the output rate.
    const reasoningCost = (() => {
      if (!reasoningTokens) return undefined
      return modelCost.output * reasoningTokens * 100
    })()
    const cacheReadCost = (() => {
      if (!cacheReadTokens) return undefined
      if (!modelCost.cacheRead) return undefined
      return modelCost.cacheRead * cacheReadTokens * 100
    })()
    const cacheWrite5mCost = (() => {
      if (!cacheWrite5mTokens) return undefined
      if (!modelCost.cacheWrite5m) return undefined
      return modelCost.cacheWrite5m * cacheWrite5mTokens * 100
    })()
    const cacheWrite1hCost = (() => {
      if (!cacheWrite1hTokens) return undefined
      if (!modelCost.cacheWrite1h) return undefined
      return modelCost.cacheWrite1h * cacheWrite1hTokens * 100
    })()
    const totalCostInCent =
      inputCost +
      outputCost +
      (reasoningCost ?? 0) +
      (cacheReadCost ?? 0) +
      (cacheWrite5mCost ?? 0) +
      (cacheWrite1hCost ?? 0)

    logger.metric({
      "tokens.input": inputTokens,
      "tokens.output": outputTokens,
      "tokens.reasoning": reasoningTokens,
      "tokens.cache_read": cacheReadTokens,
      "tokens.cache_write_5m": cacheWrite5mTokens,
      "tokens.cache_write_1h": cacheWrite1hTokens,
      "cost.input": Math.round(inputCost),
      "cost.output": Math.round(outputCost),
      "cost.reasoning": reasoningCost ? Math.round(reasoningCost) : undefined,
      "cost.cache_read": cacheReadCost ? Math.round(cacheReadCost) : undefined,
      "cost.cache_write_5m": cacheWrite5mCost ? Math.round(cacheWrite5mCost) : undefined,
      "cost.cache_write_1h": cacheWrite1hCost ? Math.round(cacheWrite1hCost) : undefined,
      "cost.total": Math.round(totalCostInCent),
    })

    // Anonymous requests are only measured, never persisted or billed.
    if (!authInfo) return

    // Free workspaces and bring-your-own-credentials workspaces are recorded
    // at zero cost.
    const cost =
      authInfo.isFree || authInfo.provider?.credentials ? 0 : centsToMicroCents(totalCostInCent)
    await Database.transaction(async (tx) => {
      await tx.insert(UsageTable).values({
        workspaceID: authInfo.workspaceID,
        id: Identifier.create("usage"),
        model: modelInfo.id,
        provider: providerInfo.id,
        inputTokens,
        outputTokens,
        reasoningTokens,
        cacheReadTokens,
        cacheWrite5mTokens,
        cacheWrite1hTokens,
        cost,
        keyID: authInfo.apiKeyId,
      })
      // Monthly usage accumulates within the same calendar month and restarts
      // from this request's cost when the stored timestamp is from another month.
      await tx
        .update(BillingTable)
        .set({
          balance: sql`${BillingTable.balance} - ${cost}`,
          monthlyUsage: sql`
            CASE
              WHEN MONTH(${BillingTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${BillingTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${BillingTable.monthlyUsage} + ${cost}
              ELSE ${cost}
            END
          `,
          timeMonthlyUsageUpdated: sql`now()`,
        })
        .where(eq(BillingTable.workspaceID, authInfo.workspaceID))
      await tx
        .update(UserTable)
        .set({
          monthlyUsage: sql`
            CASE
              WHEN MONTH(${UserTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${UserTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${UserTable.monthlyUsage} + ${cost}
              ELSE ${cost}
            END
          `,
          timeMonthlyUsageUpdated: sql`now()`,
        })
        .where(
          and(eq(UserTable.workspaceID, authInfo.workspaceID), eq(UserTable.id, authInfo.user.id)),
        )
    })

    await Database.use((tx) =>
      tx
        .update(KeyTable)
        .set({ timeUsed: sql`now()` })
        .where(eq(KeyTable.id, authInfo.apiKeyId)),
    )
  }

  /**
   * Triggers `Billing.reload()` for the workspace when auto-reload is enabled
   * and the balance is below `Billing.CHARGE_THRESHOLD`.
   *
   * The conditional UPDATE doubles as a one-minute lock: only the request
   * whose UPDATE affects a row goes on to reload. Skipped for anonymous
   * requests, free workspaces, and workspaces using their own provider
   * credentials.
   */
  async function reload(authInfo: Awaited<ReturnType<typeof authenticate>>) {
    if (!authInfo) return
    if (authInfo.isFree) return
    if (authInfo.provider?.credentials) return

    const lock = await Database.use((tx) =>
      tx
        .update(BillingTable)
        .set({
          timeReloadLockedTill: sql`now() + interval 1 minute`,
        })
        .where(
          and(
            eq(BillingTable.workspaceID, authInfo.workspaceID),
            eq(BillingTable.reload, true),
            lt(BillingTable.balance, centsToMicroCents(Billing.CHARGE_THRESHOLD)),
            or(
              isNull(BillingTable.timeReloadLockedTill),
              lt(BillingTable.timeReloadLockedTill, sql`now()`),
            ),
          ),
        ),
    )
    // No row updated: reload not needed, or another request holds the lock.
    if (lock.rowsAffected === 0) return

    await Actor.provide("system", { workspaceID: authInfo.workspaceID }, async () => {
      await Billing.reload()
    })
  }
}
|