fix: improve url construction in provider (#985)

This commit is contained in:
Salman Mohammed
2025-01-31 12:20:16 -05:00
committed by GitHub
parent 31e0b72b4f
commit 23ae315907
10 changed files with 93 additions and 38 deletions

View File

@@ -58,11 +58,15 @@ impl AnthropicProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!("{}/v1/messages", self.host.trim_end_matches('/'));
let base_url = url::Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("v1/messages").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self
.client
.post(&url)
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&payload)

View File

@@ -14,6 +14,7 @@ use crate::config::ConfigError;
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;
use url::Url;
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
@@ -137,16 +138,17 @@ impl DatabricksProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!(
"{}/serving-endpoints/{}/invocations",
self.host.trim_end_matches('/'),
self.model.model_name
);
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let path = format!("serving-endpoints/{}/invocations", self.model.model_name);
let url = base_url.join(&path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(&url)
.post(url)
.header("Authorization", auth_header)
.json(&payload)
.send()

View File

@@ -10,6 +10,7 @@ use mcp_core::tool::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;
pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.0-flash-exp";
@@ -61,16 +62,21 @@ impl GoogleProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!(
"{}/v1beta/models/{}:generateContent?key={}",
self.host.trim_end_matches('/'),
self.model.model_name,
self.api_key
);
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url
.join(&format!(
"v1beta/models/{}:generateContent?key={}",
self.model.model_name, self.api_key
))
.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self
.client
.post(&url)
.post(url)
.header("CONTENT_TYPE", "application/json")
.json(&payload)
.send()

View File

@@ -10,6 +10,7 @@ use mcp_core::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;
pub const GROQ_API_HOST: &str = "https://api.groq.com";
pub const GROQ_DEFAULT_MODEL: &str = "llama-3.3-70b-versatile";
@@ -54,14 +55,15 @@ impl GroqProvider {
}
async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
let url = format!(
"{}/openai/v1/chat/completions",
self.host.trim_end_matches('/')
);
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("openai/v1/chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self
.client
.post(&url)
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.send()

View File

@@ -8,11 +8,11 @@ use serde_json::Value;
use sha2::Digest;
use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
use tokio::sync::{oneshot, Mutex as TokioMutex};
use url::Url;
lazy_static! {
static ref OAUTH_MUTEX: TokioMutex<()> = TokioMutex::new(());
}
use url::Url;
#[derive(Debug, Clone)]
struct OidcEndpoints {
@@ -76,16 +76,18 @@ impl TokenCache {
}
async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
let host = host.trim_end_matches('/');
let oidc_url = format!("{}/oidc/.well-known/oauth-authorization-server", host);
let base_url = Url::parse(host).expect("Invalid host URL");
let oidc_url = base_url
.join("oidc/.well-known/oauth-authorization-server")
.expect("Invalid OIDC URL");
let client = reqwest::Client::new();
let resp = client.get(&oidc_url).send().await?;
let resp = client.get(oidc_url.clone()).send().await?;
if !resp.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to get OIDC configuration from {}",
oidc_url
oidc_url.to_string()
));
}

View File

@@ -10,8 +10,10 @@ use mcp_core::tool::Tool;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;
use url::Url;
pub const OLLAMA_HOST: &str = "http://localhost:11434";
pub const OLLAMA_HOST: &str = "localhost";
pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
pub const OLLAMA_DEFAULT_MODEL: &str = "qwen2.5";
// Ollama can run many models, we only provide the default
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[OLLAMA_DEFAULT_MODEL];
@@ -51,9 +53,28 @@ impl OllamaProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/'));
// OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
let base = if self.host.starts_with("http://") || self.host.starts_with("https://") {
self.host.clone()
} else {
format!("http://{}", self.host)
};
let response = self.client.post(&url).json(&payload).send().await?;
let mut base_url = Url::parse(&base)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
// Set the default port if missing
if base_url.port().is_none() {
base_url.set_port(Some(OLLAMA_DEFAULT_PORT)).map_err(|_| {
ProviderError::RequestFailed("Failed to set default port".to_string())
})?;
}
let url = base_url.join("v1/chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self.client.post(url).json(&payload).send().await?;
handle_response_openai_compat(response).await
}
@@ -99,7 +120,6 @@ impl Provider for OllamaProvider {
tools,
&super::utils::ImageFormat::OpenAi,
)?;
let response = self.post(payload.clone()).await?;
// Parse response

View File

@@ -59,11 +59,15 @@ impl OpenAiProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!("{}/v1/chat/completions", self.host.trim_end_matches('/'));
let base_url = url::Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("v1/chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self
.client
.post(&url)
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.send()

View File

@@ -11,6 +11,7 @@ use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use mcp_core::tool::Tool;
use url::Url;
pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";
@@ -56,14 +57,15 @@ impl OpenRouterProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!(
"{}/api/v1/chat/completions",
self.host.trim_end_matches('/')
);
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("api/v1/chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let response = self
.client
.post(&url)
.post(url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.header("HTTP-Referer", "https://github.com/block/goose")

View File

@@ -7,6 +7,7 @@ use std::env;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use url::Url;
use uuid::Uuid;
const DEFAULT_LANGFUSE_URL: &str = "http://localhost:3000";
@@ -77,11 +78,14 @@ impl LangfuseBatchManager {
}
let payload = json!({ "batch": self.batch });
let url = format!("{}/api/public/ingestion", self.base_url);
let base_url = Url::parse(&self.base_url).map_err(|e| format!("Invalid base URL: {e}"))?;
let url = base_url
.join("api/public/ingestion")
.map_err(|e| format!("Failed to construct endpoint URL: {e}"))?;
let response = self
.client
.post(&url)
.post(url)
.basic_auth(&self.public_key, Some(&self.secret_key))
.json(&payload)
.send()

View File

@@ -230,6 +230,15 @@ impl ProviderTester {
dbg!(&result);
println!("===================");
// Ollama and OpenRouter truncate by default even when the context window is exceeded
if self.name.to_lowercase() == "ollama" || self.name.to_lowercase() == "openrouter" {
assert!(
result.is_ok(),
"Expected to succeed because of default truncation"
);
return Ok(());
}
assert!(
result.is_err(),
"Expected error when context window is exceeded"