fix: improve url construction in provider (#985)
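The providers previously built request URLs by string formatting against `self.host` with `trim_end_matches('/')`; this change parses the host with the `url` crate and joins the endpoint path onto it, so a malformed host fails early with a clear error instead of producing a broken request URL. Below is a minimal standalone sketch of the `url` crate behavior the new code relies on; the host values are illustrative and not taken from any provider configuration.

    use url::Url;

    fn main() {
        // Joining a relative path onto a host with no path appends it.
        let base = Url::parse("https://api.example.com").unwrap();
        assert_eq!(
            base.join("v1/messages").unwrap().as_str(),
            "https://api.example.com/v1/messages"
        );

        // Invalid hosts are rejected at parse time rather than at request time.
        assert!(Url::parse("not a url").is_err());

        // Ollama-style hosts may be a bare 'host' or 'host:port'; the new code
        // prepends a scheme and fills in the default port before joining.
        let host = "localhost";
        let base = if host.starts_with("http://") || host.starts_with("https://") {
            host.to_string()
        } else {
            format!("http://{host}")
        };
        let mut base_url = Url::parse(&base).unwrap();
        if base_url.port().is_none() {
            base_url.set_port(Some(11434)).unwrap();
        }
        assert_eq!(
            base_url.join("v1/chat/completions").unwrap().as_str(),
            "http://localhost:11434/v1/chat/completions"
        );
    }

Note that `Url::join` follows RFC 3986 relative resolution: joining against a base whose path lacks a trailing slash replaces the last path segment, so hosts that carry a path prefix need a trailing slash for the prefix to be preserved.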
@@ -58,11 +58,15 @@ impl AnthropicProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!("{}/v1/messages", self.host.trim_end_matches('/'));
+        let base_url = url::Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let url = base_url.join("v1/messages").map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("x-api-key", &self.api_key)
             .header("anthropic-version", "2023-06-01")
             .json(&payload)
@@ -14,6 +14,7 @@ use crate::config::ConfigError;
 use crate::message::Message;
 use crate::model::ModelConfig;
 use mcp_core::tool::Tool;
+use url::Url;
 
 const DEFAULT_CLIENT_ID: &str = "databricks-cli";
 const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
@@ -137,16 +138,17 @@ impl DatabricksProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!(
-            "{}/serving-endpoints/{}/invocations",
-            self.host.trim_end_matches('/'),
-            self.model.model_name
-        );
+        let base_url = Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let path = format!("serving-endpoints/{}/invocations", self.model.model_name);
+        let url = base_url.join(&path).map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
 
         let auth_header = self.ensure_auth_header().await?;
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("Authorization", auth_header)
             .json(&payload)
             .send()
@@ -10,6 +10,7 @@ use mcp_core::tool::Tool;
 use reqwest::{Client, StatusCode};
 use serde_json::Value;
 use std::time::Duration;
+use url::Url;
 
 pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
 pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.0-flash-exp";
@@ -61,16 +62,21 @@ impl GoogleProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!(
-            "{}/v1beta/models/{}:generateContent?key={}",
-            self.host.trim_end_matches('/'),
-            self.model.model_name,
-            self.api_key
-        );
+        let base_url = Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+
+        let url = base_url
+            .join(&format!(
+                "v1beta/models/{}:generateContent?key={}",
+                self.model.model_name, self.api_key
+            ))
+            .map_err(|e| {
+                ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+            })?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("CONTENT_TYPE", "application/json")
             .json(&payload)
             .send()
@@ -10,6 +10,7 @@ use mcp_core::Tool;
 use reqwest::{Client, StatusCode};
 use serde_json::Value;
 use std::time::Duration;
+use url::Url;
 
 pub const GROQ_API_HOST: &str = "https://api.groq.com";
 pub const GROQ_DEFAULT_MODEL: &str = "llama-3.3-70b-versatile";
@@ -54,14 +55,15 @@ impl GroqProvider {
     }
 
     async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
-        let url = format!(
-            "{}/openai/v1/chat/completions",
-            self.host.trim_end_matches('/')
-        );
+        let base_url = Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let url = base_url.join("openai/v1/chat/completions").map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&payload)
             .send()
@@ -8,11 +8,11 @@ use serde_json::Value;
 use sha2::Digest;
 use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
 use tokio::sync::{oneshot, Mutex as TokioMutex};
+use url::Url;
 
 lazy_static! {
     static ref OAUTH_MUTEX: TokioMutex<()> = TokioMutex::new(());
 }
-use url::Url;
 
 #[derive(Debug, Clone)]
 struct OidcEndpoints {
@@ -76,16 +76,18 @@ impl TokenCache {
 }
 
 async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
-    let host = host.trim_end_matches('/');
-    let oidc_url = format!("{}/oidc/.well-known/oauth-authorization-server", host);
+    let base_url = Url::parse(host).expect("Invalid host URL");
+    let oidc_url = base_url
+        .join("oidc/.well-known/oauth-authorization-server")
+        .expect("Invalid OIDC URL");
 
     let client = reqwest::Client::new();
-    let resp = client.get(&oidc_url).send().await?;
+    let resp = client.get(oidc_url.clone()).send().await?;
 
     if !resp.status().is_success() {
         return Err(anyhow::anyhow!(
             "Failed to get OIDC configuration from {}",
-            oidc_url
+            oidc_url.to_string()
         ));
     }
 
@@ -10,8 +10,10 @@ use mcp_core::tool::Tool;
 use reqwest::Client;
 use serde_json::Value;
 use std::time::Duration;
+use url::Url;
 
-pub const OLLAMA_HOST: &str = "http://localhost:11434";
+pub const OLLAMA_HOST: &str = "localhost";
+pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
 pub const OLLAMA_DEFAULT_MODEL: &str = "qwen2.5";
 // Ollama can run many models, we only provide the default
 pub const OLLAMA_KNOWN_MODELS: &[&str] = &[OLLAMA_DEFAULT_MODEL];
@@ -51,9 +53,28 @@ impl OllamaProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/'));
+        // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
+        let base = if self.host.starts_with("http://") || self.host.starts_with("https://") {
+            self.host.clone()
+        } else {
+            format!("http://{}", self.host)
+        };
 
-        let response = self.client.post(&url).json(&payload).send().await?;
+        let mut base_url = Url::parse(&base)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+
+        // Set the default port if missing
+        if base_url.port().is_none() {
+            base_url.set_port(Some(OLLAMA_DEFAULT_PORT)).map_err(|_| {
+                ProviderError::RequestFailed("Failed to set default port".to_string())
+            })?;
+        }
+
+        let url = base_url.join("v1/chat/completions").map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
+
+        let response = self.client.post(url).json(&payload).send().await?;
 
         handle_response_openai_compat(response).await
     }
@@ -99,7 +120,6 @@ impl Provider for OllamaProvider {
             tools,
             &super::utils::ImageFormat::OpenAi,
         )?;
-
         let response = self.post(payload.clone()).await?;
 
         // Parse response
@@ -59,11 +59,15 @@ impl OpenAiProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/'));
+        let base_url = url::Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let url = base_url.join("v1/chat/completions").map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&payload)
             .send()
@@ -11,6 +11,7 @@ use crate::message::Message;
 use crate::model::ModelConfig;
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
 use mcp_core::tool::Tool;
+use url::Url;
 
 pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
 pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";
@@ -56,14 +57,15 @@ impl OpenRouterProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!(
-            "{}/api/v1/chat/completions",
-            self.host.trim_end_matches('/')
-        );
+        let base_url = Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let url = base_url.join("api/v1/chat/completions").map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", self.api_key))
             .header("HTTP-Referer", "https://github.com/block/goose")
@@ -7,6 +7,7 @@ use std::env;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::sync::Mutex;
+use url::Url;
 use uuid::Uuid;
 
 const DEFAULT_LANGFUSE_URL: &str = "http://localhost:3000";
@@ -77,11 +78,14 @@ impl LangfuseBatchManager {
         }
 
         let payload = json!({ "batch": self.batch });
-        let url = format!("{}/api/public/ingestion", self.base_url);
+        let base_url = Url::parse(&self.base_url).map_err(|e| format!("Invalid base URL: {e}"))?;
+        let url = base_url
+            .join("api/public/ingestion")
+            .map_err(|e| format!("Failed to construct endpoint URL: {e}"))?;
 
         let response = self
             .client
-            .post(&url)
+            .post(url)
             .basic_auth(&self.public_key, Some(&self.secret_key))
             .json(&payload)
             .send()
@@ -230,6 +230,15 @@ impl ProviderTester {
         dbg!(&result);
         println!("===================");
 
+        // Ollama and OpenRouter truncate by default even when the context window is exceeded
+        if self.name.to_lowercase() == "ollama" || self.name.to_lowercase() == "openrouter" {
+            assert!(
+                result.is_ok(),
+                "Expected to succeed because of default truncation"
+            );
+            return Ok(());
+        }
+
         assert!(
             result.is_err(),
             "Expected error when context window is exceeded"