From 23ae31590766b8ffbcddc5a7f29fc635d902488b Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 31 Jan 2025 12:20:16 -0500 Subject: [PATCH] fix: improve url construction in provider (#985) --- crates/goose/src/providers/anthropic.rs | 8 +++++-- crates/goose/src/providers/databricks.rs | 14 ++++++----- crates/goose/src/providers/google.rs | 20 ++++++++++------ crates/goose/src/providers/groq.rs | 12 ++++++---- crates/goose/src/providers/oauth.rs | 12 ++++++---- crates/goose/src/providers/ollama.rs | 28 ++++++++++++++++++---- crates/goose/src/providers/openai.rs | 8 +++++-- crates/goose/src/providers/openrouter.rs | 12 ++++++---- crates/goose/src/tracing/langfuse_layer.rs | 8 +++++-- crates/goose/tests/providers.rs | 9 +++++++ 10 files changed, 93 insertions(+), 38 deletions(-) diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 31d320e6..3cfd51a2 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -58,11 +58,15 @@ impl AnthropicProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!("{}/v1/messages", self.host.trim_end_matches('/')); + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("v1/messages").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let response = self .client - .post(&url) + .post(url) .header("x-api-key", &self.api_key) .header("anthropic-version", "2023-06-01") .json(&payload) diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 3d140262..2ed09a18 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -14,6 +14,7 @@ use crate::config::ConfigError; use crate::message::Message; use crate::model::ModelConfig; use mcp_core::tool::Tool; +use url::Url; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; @@ -137,16 +138,17 @@ impl DatabricksProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!( - "{}/serving-endpoints/{}/invocations", - self.host.trim_end_matches('/'), - self.model.model_name - ); + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let path = format!("serving-endpoints/{}/invocations", self.model.model_name); + let url = base_url.join(&path).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let auth_header = self.ensure_auth_header().await?; let response = self .client - .post(&url) + .post(url) .header("Authorization", auth_header) .json(&payload) .send() diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index bf7d73bd..42ccb68e 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -10,6 +10,7 @@ use mcp_core::tool::Tool; use reqwest::{Client, StatusCode}; use serde_json::Value; use std::time::Duration; +use url::Url; pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com"; pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.0-flash-exp"; @@ -61,16 +62,21 @@ impl GoogleProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!( - "{}/v1beta/models/{}:generateContent?key={}", - self.host.trim_end_matches('/'), - self.model.model_name, - self.api_key - ); + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + + let url = base_url + .join(&format!( + "v1beta/models/{}:generateContent?key={}", + self.model.model_name, self.api_key + )) + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let response = self .client - .post(&url) + .post(url) .header("CONTENT_TYPE", "application/json") .json(&payload) .send() diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index c230952e..f0c3ace4 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -10,6 +10,7 @@ use mcp_core::Tool; use reqwest::{Client, StatusCode}; use serde_json::Value; use std::time::Duration; +use url::Url; pub const GROQ_API_HOST: &str = "https://api.groq.com"; pub const GROQ_DEFAULT_MODEL: &str = "llama-3.3-70b-versatile"; @@ -54,14 +55,15 @@ impl GroqProvider { } async fn post(&self, payload: Value) -> anyhow::Result { - let url = format!( - "{}/openai/v1/chat/completions", - self.host.trim_end_matches('/') - ); + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("openai/v1/chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let response = self .client - .post(&url) + .post(url) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&payload) .send() diff --git a/crates/goose/src/providers/oauth.rs b/crates/goose/src/providers/oauth.rs index 00bc6673..f4bf1f85 100644 --- a/crates/goose/src/providers/oauth.rs +++ b/crates/goose/src/providers/oauth.rs @@ -8,11 +8,11 @@ use serde_json::Value; use sha2::Digest; use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc}; use tokio::sync::{oneshot, Mutex as TokioMutex}; +use url::Url; lazy_static! { static ref OAUTH_MUTEX: TokioMutex<()> = TokioMutex::new(()); } -use url::Url; #[derive(Debug, Clone)] struct OidcEndpoints { @@ -76,16 +76,18 @@ impl TokenCache { } async fn get_workspace_endpoints(host: &str) -> Result { - let host = host.trim_end_matches('/'); - let oidc_url = format!("{}/oidc/.well-known/oauth-authorization-server", host); + let base_url = Url::parse(host).expect("Invalid host URL"); + let oidc_url = base_url + .join("oidc/.well-known/oauth-authorization-server") + .expect("Invalid OIDC URL"); let client = reqwest::Client::new(); - let resp = client.get(&oidc_url).send().await?; + let resp = client.get(oidc_url.clone()).send().await?; if !resp.status().is_success() { return Err(anyhow::anyhow!( "Failed to get OIDC configuration from {}", - oidc_url + oidc_url.to_string() )); } diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index db3f1f59..f8d0f749 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -10,8 +10,10 @@ use mcp_core::tool::Tool; use reqwest::Client; use serde_json::Value; use std::time::Duration; +use url::Url; -pub const OLLAMA_HOST: &str = "http://localhost:11434"; +pub const OLLAMA_HOST: &str = "localhost"; +pub const OLLAMA_DEFAULT_PORT: u16 = 11434; pub const OLLAMA_DEFAULT_MODEL: &str = "qwen2.5"; // Ollama can run many models, we only provide the default pub const OLLAMA_KNOWN_MODELS: &[&str] = &[OLLAMA_DEFAULT_MODEL]; @@ -51,9 +53,28 @@ impl OllamaProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/')); + // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme + let base = if self.host.starts_with("http://") || self.host.starts_with("https://") { + self.host.clone() + } else { + format!("http://{}", self.host) + }; - let response = self.client.post(&url).json(&payload).send().await?; + let mut base_url = Url::parse(&base) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + + // Set the default port if missing + if base_url.port().is_none() { + base_url.set_port(Some(OLLAMA_DEFAULT_PORT)).map_err(|_| { + ProviderError::RequestFailed("Failed to set default port".to_string()) + })?; + } + + let url = base_url.join("v1/chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let response = self.client.post(url).json(&payload).send().await?; handle_response_openai_compat(response).await } @@ -99,7 +120,6 @@ impl Provider for OllamaProvider { tools, &super::utils::ImageFormat::OpenAi, )?; - let response = self.post(payload.clone()).await?; // Parse response diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 31c98a18..35149bff 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -59,11 +59,15 @@ impl OpenAiProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/')); + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("v1/chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let response = self .client - .post(&url) + .post(url) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&payload) .send() diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 1750776b..ae94a5e0 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -11,6 +11,7 @@ use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; use mcp_core::tool::Tool; +use url::Url; pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet"; pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic"; @@ -56,14 +57,15 @@ impl OpenRouterProvider { } async fn post(&self, payload: Value) -> Result { - let url = format!( - "{}/api/v1/chat/completions", - self.host.trim_end_matches('/') - ); + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("api/v1/chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; let response = self .client - .post(&url) + .post(url) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", self.api_key)) .header("HTTP-Referer", "https://github.com/block/goose") diff --git a/crates/goose/src/tracing/langfuse_layer.rs b/crates/goose/src/tracing/langfuse_layer.rs index ba41aa9d..2ac418cf 100644 --- a/crates/goose/src/tracing/langfuse_layer.rs +++ b/crates/goose/src/tracing/langfuse_layer.rs @@ -7,6 +7,7 @@ use std::env; use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; +use url::Url; use uuid::Uuid; const DEFAULT_LANGFUSE_URL: &str = "http://localhost:3000"; @@ -77,11 +78,14 @@ impl LangfuseBatchManager { } let payload = json!({ "batch": self.batch }); - let url = format!("{}/api/public/ingestion", self.base_url); + let base_url = Url::parse(&self.base_url).map_err(|e| format!("Invalid base URL: {e}"))?; + let url = base_url + .join("api/public/ingestion") + .map_err(|e| format!("Failed to construct endpoint URL: {e}"))?; let response = self .client - .post(&url) + .post(url) .basic_auth(&self.public_key, Some(&self.secret_key)) .json(&payload) .send() diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 03b3ccef..4ef98125 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -230,6 +230,15 @@ impl ProviderTester { dbg!(&result); println!("==================="); + // Ollama and OpenRouter truncate by default even when the context window is exceeded + if self.name.to_lowercase() == "ollama" || self.name.to_lowercase() == "openrouter" { + assert!( + result.is_ok(), + "Expected to succeed because of default truncation" + ); + return Ok(()); + } + assert!( result.is_err(), "Expected error when context window is exceeded"