From 78b30cc0ca5220c44db3122f4e79d778992b2f48 Mon Sep 17 00:00:00 2001
From: hikae
Date: Sat, 19 Jul 2025 19:07:01 +0900
Subject: [PATCH] feat: Add LiteLLM provider with automatic prompt caching support (#3380)

Signed-off-by: HikaruEgashira
---
 crates/goose-llm/tests/providers_complete.rs |   2 +-
 crates/goose/src/model.rs                    |   1 +
 crates/goose/src/providers/base.rs           |  13 +
 crates/goose/src/providers/factory.rs        |   3 +
 crates/goose/src/providers/litellm.rs        | 357 +++++++++++++++++++
 crates/goose/src/providers/mod.rs            |   1 +
 crates/goose/src/providers/openrouter.rs     |  17 +-
 crates/goose/tests/providers.rs              |  27 +-
 8 files changed, 411 insertions(+), 10 deletions(-)
 create mode 100644 crates/goose/src/providers/litellm.rs

diff --git a/crates/goose-llm/tests/providers_complete.rs b/crates/goose-llm/tests/providers_complete.rs
index 25706558..10ae75fc 100644
--- a/crates/goose-llm/tests/providers_complete.rs
+++ b/crates/goose-llm/tests/providers_complete.rs
@@ -157,7 +157,7 @@ impl ProviderTester {
             .content
             .iter()
             .filter_map(|message| message.as_tool_request())
-            .last()
+            .next_back()
             .expect("got tool request")
             .id;
 
diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs
index 9fe73aec..59f99763 100644
--- a/crates/goose/src/model.rs
+++ b/crates/goose/src/model.rs
@@ -275,6 +275,7 @@ mod tests {
     }
 
     #[test]
+    #[serial_test::serial]
     fn test_model_config_context_limit_env_vars() {
         use temp_env::with_vars;
 
diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs
index c3510fe8..260e1f22 100644
--- a/crates/goose/src/providers/base.rs
+++ b/crates/goose/src/providers/base.rs
@@ -41,6 +41,8 @@ pub struct ModelInfo {
     pub output_token_cost: Option<f64>,
     /// Currency for the costs (default: "$")
     pub currency: Option<String>,
+    /// Whether this model supports cache control
+    pub supports_cache_control: Option<bool>,
 }
 
 impl ModelInfo {
@@ -52,6 +54,7 @@ impl ModelInfo {
             input_token_cost: None,
             output_token_cost: None,
             currency: None,
+            supports_cache_control: None,
         }
     }
 
@@ -68,6 +71,7 @@ impl ModelInfo {
             input_token_cost: Some(input_cost),
             output_token_cost: Some(output_cost),
             currency: Some("$".to_string()),
+            supports_cache_control: None,
         }
     }
 }
@@ -115,6 +119,7 @@ impl ProviderMetadata {
                     input_token_cost: None,
                     output_token_cost: None,
                     currency: None,
+                    supports_cache_control: None,
                 })
                 .collect(),
             model_doc_link: model_doc_link.to_string(),
@@ -290,6 +295,11 @@ pub trait Provider: Send + Sync {
         false
     }
 
+    /// Check if this provider supports cache control
+    fn supports_cache_control(&self) -> bool {
+        false
+    }
+
     /// Create embeddings if supported. Default implementation returns an error.
     async fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
         Err(ProviderError::ExecutionError(
@@ -435,6 +445,7 @@ mod tests {
             input_token_cost: None,
             output_token_cost: None,
             currency: None,
+            supports_cache_control: None,
         };
 
         assert_eq!(info.context_limit, 1000);
@@ -445,6 +456,7 @@ mod tests {
             input_token_cost: None,
             output_token_cost: None,
             currency: None,
+            supports_cache_control: None,
         };
 
         assert_eq!(info, info2);
@@ -455,6 +467,7 @@ mod tests {
             input_token_cost: None,
             output_token_cost: None,
             currency: None,
+            supports_cache_control: None,
         };
         assert_ne!(info, info3);
     }
diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs
index c6123d0a..09142553 100644
--- a/crates/goose/src/providers/factory.rs
+++ b/crates/goose/src/providers/factory.rs
@@ -12,6 +12,7 @@ use super::{
     google::GoogleProvider,
     groq::GroqProvider,
     lead_worker::LeadWorkerProvider,
+    litellm::LiteLLMProvider,
     ollama::OllamaProvider,
     openai::OpenAiProvider,
     openrouter::OpenRouterProvider,
@@ -50,6 +51,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         // GithubCopilotProvider::metadata(),
         GoogleProvider::metadata(),
         GroqProvider::metadata(),
+        LiteLLMProvider::metadata(),
         OllamaProvider::metadata(),
         OpenAiProvider::metadata(),
         OpenRouterProvider::metadata(),
@@ -158,6 +160,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
         "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
         "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)),
         "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
+        "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)),
         "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
         "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
         "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs
new file mode 100644
index 00000000..591337ba
--- /dev/null
+++ b/crates/goose/src/providers/litellm.rs
@@ -0,0 +1,357 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use reqwest::Client;
+use serde_json::{json, Value};
+use std::collections::HashMap;
+use std::time::Duration;
+use url::Url;
+
+use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
+use super::embedding::EmbeddingCapable;
+use super::errors::ProviderError;
+use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
+use crate::message::Message;
+use crate::model::ModelConfig;
+use mcp_core::tool::Tool;
+
+pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
+pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";
+
+#[derive(Debug, serde::Serialize)]
+pub struct LiteLLMProvider {
+    #[serde(skip)]
+    client: Client,
+    host: String,
+    base_path: String,
+    api_key: String,
+    model: ModelConfig,
+    custom_headers: Option<HashMap<String, String>>,
+}
+
+impl Default for LiteLLMProvider {
+    fn default() -> Self {
+        let model = ModelConfig::new(LiteLLMProvider::metadata().default_model);
+        LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider")
+    }
+}
+
+impl LiteLLMProvider {
+    pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let config = crate::config::Config::global();
+        let api_key: String = config
+            .get_secret("LITELLM_API_KEY")
+            .unwrap_or_else(|_| String::new());
+        let host: String = config
+            .get_param("LITELLM_HOST")
+            .unwrap_or_else(|_| "https://api.litellm.ai".to_string());
+        let base_path: String = config
+            .get_param("LITELLM_BASE_PATH")
+            .unwrap_or_else(|_| "v1/chat/completions".to_string());
"v1/chat/completions".to_string()); + let custom_headers: Option> = config + .get_secret("LITELLM_CUSTOM_HEADERS") + .or_else(|_| config.get_param("LITELLM_CUSTOM_HEADERS")) + .ok() + .map(parse_custom_headers); + let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600); + let client = Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build()?; + + Ok(Self { + client, + host, + base_path, + api_key, + model, + custom_headers, + }) + } + + fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(custom_headers) = &self.custom_headers { + for (key, value) in custom_headers { + request = request.header(key, value); + } + } + + request + } + + async fn fetch_models(&self) -> Result, ProviderError> { + let models_url = format!("{}/model/info", self.host); + + let mut req = self + .client + .get(&models_url) + .header("Authorization", format!("Bearer {}", self.api_key)); + + req = self.add_headers(req); + + let response = req + .send() + .await + .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?; + + if !response.status().is_success() { + return Err(ProviderError::RequestFailed(format!( + "Models endpoint returned status: {}", + response.status() + ))); + } + + let response_json: Value = response.json().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to parse models response: {}", e)) + })?; + + let models_data = response_json["data"].as_array().ok_or_else(|| { + ProviderError::RequestFailed("Missing data field in models response".to_string()) + })?; + + let mut models = Vec::new(); + for model_data in models_data { + if let Some(model_name) = model_data["model_name"].as_str() { + if model_name.contains("/*") { + continue; + } + + let model_info = &model_data["model_info"]; + let context_length = + model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize; + let supports_cache_control = model_info["supports_prompt_caching"].as_bool(); + + let mut model_info_obj = ModelInfo::new(model_name, context_length); + model_info_obj.supports_cache_control = supports_cache_control; + models.push(model_info_obj); + } + } + + Ok(models) + } + + async fn post(&self, payload: Value) -> Result { + let base_url = Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join(&self.base_path).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let request = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)); + + let request = self.add_headers(request); + + let response = request.json(&payload).send().await?; + + handle_response_openai_compat(response).await + } +} + +#[async_trait] +impl Provider for LiteLLMProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "litellm", + "LiteLLM", + "LiteLLM proxy supporting multiple models with automatic prompt caching", + LITELLM_DEFAULT_MODEL, + vec![], + LITELLM_DOC_URL, + vec![ + ConfigKey::new("LITELLM_API_KEY", false, true, None), + ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")), + ConfigKey::new( + "LITELLM_BASE_PATH", + true, + false, + Some("v1/chat/completions"), + ), + ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None), + ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument(skip_all, name 
= "provider_complete")] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let mut payload = super::formats::openai::create_request( + &self.model, + system, + messages, + tools, + &ImageFormat::OpenAi, + )?; + + if self.supports_cache_control() { + payload = update_request_for_cache_control(&payload); + } + + let response = self.post(payload.clone()).await?; + + let message = super::formats::openai::response_to_message(response.clone())?; + let usage = super::formats::openai::get_usage(&response); + let model = get_model(&response); + emit_debug_trace(&self.model, &payload, &response, &usage); + Ok((message, ProviderUsage::new(model, usage))) + } + + fn supports_embeddings(&self) -> bool { + true + } + + fn supports_cache_control(&self) -> bool { + if let Ok(models) = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(self.fetch_models()) + }) { + if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) { + return model_info.supports_cache_control.unwrap_or(false); + } + } + + self.model.model_name.to_lowercase().contains("claude") + } + + async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + match self.fetch_models().await { + Ok(models) => { + let model_names: Vec = models.into_iter().map(|m| m.name).collect(); + Ok(Some(model_names)) + } + Err(e) => { + tracing::warn!("Failed to fetch models from LiteLLM: {}", e); + Ok(None) + } + } + } +} + +#[async_trait] +impl EmbeddingCapable for LiteLLMProvider { + async fn create_embeddings(&self, texts: Vec) -> Result>, anyhow::Error> { + let endpoint = format!("{}/v1/embeddings", self.host); + + let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + + let payload = json!({ + "input": texts, + "model": embedding_model, + "encoding_format": "float" + }); + + let mut req = self + .client + .post(&endpoint) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload); + + req = self.add_headers(req); + + let response = req.send().await?; + let response_text = response.text().await?; + let response_json: Value = serde_json::from_str(&response_text)?; + + let data = response_json["data"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Missing data field"))?; + + let mut embeddings = Vec::new(); + for item in data { + let embedding: Vec = item["embedding"] + .as_array() + .ok_or_else(|| anyhow::anyhow!("Missing embedding field"))? 
+                .iter()
+                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
+                .collect();
+            embeddings.push(embedding);
+        }
+
+        Ok(embeddings)
+    }
+}
+
+/// Updates the request payload to include cache control headers for automatic prompt caching
+/// Adds ephemeral cache control to the last 2 user messages, system message, and last tool
+pub fn update_request_for_cache_control(original_payload: &Value) -> Value {
+    let mut payload = original_payload.clone();
+
+    if let Some(messages_spec) = payload
+        .as_object_mut()
+        .and_then(|obj| obj.get_mut("messages"))
+        .and_then(|messages| messages.as_array_mut())
+    {
+        let mut user_count = 0;
+        for message in messages_spec.iter_mut().rev() {
+            if message.get("role") == Some(&json!("user")) {
+                if let Some(content) = message.get_mut("content") {
+                    if let Some(content_str) = content.as_str() {
+                        *content = json!([{
+                            "type": "text",
+                            "text": content_str,
+                            "cache_control": { "type": "ephemeral" }
+                        }]);
+                    }
+                }
+                user_count += 1;
+                if user_count >= 2 {
+                    break;
+                }
+            }
+        }
+
+        if let Some(system_message) = messages_spec
+            .iter_mut()
+            .find(|msg| msg.get("role") == Some(&json!("system")))
+        {
+            if let Some(content) = system_message.get_mut("content") {
+                if let Some(content_str) = content.as_str() {
+                    *system_message = json!({
+                        "role": "system",
+                        "content": [{
+                            "type": "text",
+                            "text": content_str,
+                            "cache_control": { "type": "ephemeral" }
+                        }]
+                    });
+                }
+            }
+        }
+    }
+
+    if let Some(tools_spec) = payload
+        .as_object_mut()
+        .and_then(|obj| obj.get_mut("tools"))
+        .and_then(|tools| tools.as_array_mut())
+    {
+        if let Some(last_tool) = tools_spec.last_mut() {
+            if let Some(function) = last_tool.get_mut("function") {
+                function
+                    .as_object_mut()
+                    .unwrap()
+                    .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
+            }
+        }
+    }
+    payload
+}
+
+fn parse_custom_headers(headers_str: String) -> HashMap<String, String> {
+    let mut headers = HashMap::new();
+    for line in headers_str.lines() {
+        if let Some((key, value)) = line.split_once(':') {
+            headers.insert(key.trim().to_string(), value.trim().to_string());
+        }
+    }
+    headers
+}
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index a7748044..4b155c69 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -16,6 +16,7 @@ pub mod githubcopilot;
 pub mod google;
 pub mod groq;
 pub mod lead_worker;
+pub mod litellm;
 pub mod oauth;
 pub mod ollama;
 pub mod openai;
diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs
index d629c37b..782e5adf 100644
--- a/crates/goose/src/providers/openrouter.rs
+++ b/crates/goose/src/providers/openrouter.rs
@@ -199,23 +199,20 @@ fn update_request_for_anthropic(original_payload: &Value) -> Value {
 }
 
 fn create_request_based_on_model(
-    model_config: &ModelConfig,
+    provider: &OpenRouterProvider,
     system: &str,
     messages: &[Message],
     tools: &[Tool],
 ) -> anyhow::Result<Value> {
     let mut payload = create_request(
-        model_config,
+        &provider.model,
         system,
         messages,
         tools,
         &super::utils::ImageFormat::OpenAi,
     )?;
 
-    if model_config
-        .model_name
-        .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
-    {
+    if provider.supports_cache_control() {
         payload = update_request_for_anthropic(&payload);
     }
 
@@ -259,7 +256,7 @@ impl Provider for OpenRouterProvider {
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
         // Create the base payload
-        let payload = create_request_based_on_model(&self.model, system, messages, tools)?;
+        let payload = create_request_based_on_model(self, system, messages, tools)?;
 
         // Make request
         let response = self.post(payload.clone()).await?;
@@ -365,4 +362,10 @@ impl Provider for OpenRouterProvider {
         models.sort();
         Ok(Some(models))
     }
+
+    fn supports_cache_control(&self) -> bool {
+        self.model
+            .model_name
+            .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
+    }
 }
diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs
index 5cebdf3d..c7b5c426 100644
--- a/crates/goose/tests/providers.rs
+++ b/crates/goose/tests/providers.rs
@@ -4,7 +4,8 @@ use goose::message::{Message, MessageContent};
 use goose::providers::base::Provider;
 use goose::providers::errors::ProviderError;
 use goose::providers::{
-    anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake, xai,
+    anthropic, azure, bedrock, databricks, google, groq, litellm, ollama, openai, openrouter,
+    snowflake, xai,
 };
 use mcp_core::tool::Tool;
 use rmcp::model::{AnnotateAble, Content, RawImageContent};
@@ -158,7 +159,7 @@ impl ProviderTester {
             .content
             .iter()
             .filter_map(|message| message.as_tool_request())
-            .last()
+            .next_back()
             .expect("got tool request")
             .id;
 
@@ -596,6 +597,28 @@ async fn test_sagemaker_tgi_provider() -> Result<()> {
     .await
 }
 
+#[tokio::test]
+async fn test_litellm_provider() -> Result<()> {
+    if std::env::var("LITELLM_HOST").is_err() {
+        println!("LITELLM_HOST not set, skipping test");
+        TEST_REPORT.record_skip("LiteLLM");
+        return Ok(());
+    }
+
+    let env_mods = HashMap::from_iter([
+        ("LITELLM_HOST", Some("http://localhost:4000".to_string())),
+        ("LITELLM_API_KEY", Some("".to_string())),
+    ]);
+
+    test_provider(
+        "LiteLLM",
+        &[], // No required environment variables
+        Some(env_mods),
+        litellm::LiteLLMProvider::default,
+    )
+    .await
+}
+
 #[tokio::test]
 async fn test_xai_provider() -> Result<()> {
     test_provider("Xai", &["XAI_API_KEY"], None, xai::XaiProvider::default).await
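
Reviewer note, not part of the patch: a minimal sketch of how the new provider gets wired up, mirroring the env setup used by test_litellm_provider above. It assumes a LiteLLM proxy at http://localhost:4000, that goose's global Config falls back to environment variables (as the integration test relies on), and that anyhow is available to the hypothetical binary; it constructs the provider exactly the way the Default impl in litellm.rs does.

// Sketch only: construct the LiteLLM provider the same way its Default impl does.
// Export LITELLM_HOST (e.g. http://localhost:4000) and LITELLM_API_KEY first;
// LITELLM_BASE_PATH, LITELLM_CUSTOM_HEADERS ("Key: Value" per line) and
// LITELLM_TIMEOUT are optional, with defaults registered in metadata().
use goose::model::ModelConfig;
use goose::providers::base::Provider; // brings the metadata()/complete() trait fns into scope
use goose::providers::litellm::LiteLLMProvider;

fn main() -> anyhow::Result<()> {
    let model = ModelConfig::new(LiteLLMProvider::metadata().default_model);
    let _provider = LiteLLMProvider::from_env(model)?;
    // complete() adds cache_control markers automatically when the proxy reports
    // supports_prompt_caching for the model (or the model name contains "claude").
    Ok(())
}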
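
Also not part of the patch: a small illustration of what update_request_for_cache_control does to an OpenAI-compatible request when supports_cache_control() returns true. The model name, prompt text, and read_file tool below are made up for the example; the function path assumes the module layout added by this patch.

use goose::providers::litellm::update_request_for_cache_control;
use serde_json::json;

fn main() {
    // A minimal OpenAI-style payload of the shape produced by create_request.
    let payload = json!({
        "model": "gpt-4o-mini",
        "messages": [
            { "role": "system", "content": "You are goose, a helpful developer agent." },
            { "role": "user", "content": "Summarize the open TODOs in this repository." }
        ],
        "tools": [
            { "type": "function", "function": { "name": "read_file", "parameters": { "type": "object" } } }
        ]
    });

    let cached = update_request_for_cache_control(&payload);

    // The system message, the last two user messages, and the last tool now carry
    // {"cache_control": {"type": "ephemeral"}}; string contents are rewritten into
    // one-element text arrays so the marker has somewhere to live.
    println!("{}", serde_json::to_string_pretty(&cached).unwrap());
}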