fix(#243): claude on aws bedrock (#241)

Co-authored-by: Dax Raad <d@ironbay.co>
This commit is contained in:
Dmytro Yankovskyi
2025-06-20 20:57:33 +02:00
committed by GitHub
parent 2fd0e7dd6b
commit 91c4da5dbd
3 changed files with 103 additions and 25 deletions

View File

@@ -12,13 +12,14 @@
"./*": "./src/*.ts"
},
"devDependencies": {
"@ai-sdk/amazon-bedrock": "2.2.10",
"@ai-sdk/anthropic": "1.2.12",
"@tsconfig/bun": "1.0.7",
"@types/bun": "latest",
"@types/turndown": "5.0.5",
"@types/yargs": "17.0.33",
"typescript": "catalog:",
"zod-to-json-schema": "3.24.5",
"@ai-sdk/anthropic": "1.2.12"
"zod-to-json-schema": "3.24.5"
},
"dependencies": {
"@clack/prompts": "0.11.0",

View File

@@ -27,9 +27,13 @@ import { TaskTool } from "../tool/task"
export namespace Provider {
const log = Log.create({ service: "provider" })
type CustomLoader = (
provider: ModelsDev.Provider,
) => Promise<Record<string, any> | false>
type CustomLoader = (provider: ModelsDev.Provider) => Promise<
| {
getModel?: (sdk: any, modelID: string) => Promise<any>
options: Record<string, any>
}
| false
>
type Source = "env" | "config" | "custom" | "api"
@@ -44,30 +48,52 @@ export namespace Provider {
}
}
return {
apiKey: "",
async fetch(input: any, init: any) {
const access = await AuthAnthropic.access()
const headers = {
...init.headers,
authorization: `Bearer ${access}`,
"anthropic-beta": "oauth-2025-04-20",
}
delete headers["x-api-key"]
return fetch(input, {
...init,
headers,
})
options: {
apiKey: "",
async fetch(input: any, init: any) {
const access = await AuthAnthropic.access()
const headers = {
...init.headers,
authorization: `Bearer ${access}`,
"anthropic-beta": "oauth-2025-04-20",
}
delete headers["x-api-key"]
return fetch(input, {
...init,
headers,
})
},
},
}
},
openai: async () => {
return {
async getModel(sdk: any, modelID: string) {
return sdk.responses(modelID)
},
options: {},
}
},
"amazon-bedrock": async () => {
if (!process.env["AWS_PROFILE"]) return false
if (!process.env["AWS_PROFILE"]) false
const region = process.env["AWS_REGION"] ?? "us-east-1"
const { fromNodeProviderChain } = await import(
await BunProc.install("@aws-sdk/credential-providers")
)
return {
region: process.env["AWS_REGION"] ?? "us-east-1",
credentialProvider: fromNodeProviderChain(),
options: {
region,
credentialProvider: fromNodeProviderChain(),
},
async getModel(sdk: any, modelID: string) {
if (modelID.includes("claude")) {
const prefix = region.split("-")[0]
modelID = `${prefix}.${modelID}`
}
return sdk.languageModel(modelID)
},
}
},
}
@@ -80,6 +106,7 @@ export namespace Provider {
[providerID: string]: {
source: Source
info: ModelsDev.Provider
getModel?: (sdk: any, modelID: string) => Promise<any>
options: Record<string, any>
}
} = {}
@@ -95,6 +122,7 @@ export namespace Provider {
id: string,
options: Record<string, any>,
source: Source,
getModel?: (sdk: any, modelID: string) => Promise<any>,
) {
const provider = providers[id]
if (!provider) {
@@ -110,6 +138,7 @@ export namespace Provider {
}
provider.options = mergeDeep(provider.options, options)
provider.source = source
provider.getModel = getModel ?? provider.getModel
}
const configProviders = Object.entries(config.provider ?? {})
@@ -173,7 +202,8 @@ export namespace Provider {
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
if (disabled.has(providerID)) continue
const result = await fn(database[providerID])
if (result) mergeProvider(providerID, result, "custom")
if (result)
mergeProvider(providerID, result.options, "custom", result.getModel)
}
// load config
@@ -236,9 +266,9 @@ export namespace Provider {
const sdk = await getSDK(provider.info)
try {
const language =
// @ts-expect-error
"responses" in sdk ? sdk.responses(modelID) : sdk.languageModel(modelID)
const language = provider.getModel
? await provider.getModel(sdk, modelID)
: sdk.languageModel(modelID)
log.info("found", { providerID, modelID })
s.models.set(key, {
info,