dont hard code the url

This commit is contained in:
Asuka Minato
2024-06-18 03:07:37 +09:00
parent c6c9dc316f
commit 5d4a04ac0e
2 changed files with 10 additions and 10 deletions

View File

@@ -433,7 +433,7 @@ mod test {
"models": {
"model1": {
"type": "gemini",
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=",
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY",
},

View File

@@ -70,6 +70,8 @@ impl Gemini {
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?
.to_owned()
+ self.config.model.as_ref()
+ ":generateContent?key="
+ token.as_ref(),
)
.header("Content-Type", "application/json")
@@ -110,11 +112,7 @@ impl Gemini {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<String> {
async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: serde_json::Value = client
@@ -124,6 +122,8 @@ impl Gemini {
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?
.to_owned()
+ self.config.model.as_ref()
+ ":generateContent?key="
+ token.as_ref(),
)
.header("Content-Type", "application/json")
@@ -189,16 +189,16 @@ mod test {
#[tokio::test]
async fn gemini_completion_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = from_value(json!({
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=",
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let anthropic = Gemini::new(configuration);
let gemini = Gemini::new(configuration);
let prompt = Prompt::default_fim();
let run_params = json!({
"max_tokens": 2
});
let response = anthropic.do_generate(&prompt, run_params).await?;
let response = gemini.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
dbg!(response.generated_text);
Ok(())
@@ -206,7 +206,7 @@ mod test {
#[tokio::test]
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = serde_json::from_value(json!({
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=",
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"model": "gemini-1.5-flash",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;