feat: Add LiteLLM provider with automatic prompt caching support (#3380)

Signed-off-by: HikaruEgashira <account@egahika.dev>
Author: hikae
Date: 2025-07-19 19:07:01 +09:00
Committed by: GitHub
parent c38c7f16dc
commit 78b30cc0ca
8 changed files with 411 additions and 10 deletions
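For orientation, a minimal usage sketch of the new provider. The paths and names are taken from the diff below (goose::providers::litellm::LiteLLMProvider, goose::model::ModelConfig, assumed to be exported as they are used elsewhere in this tree); configuration is resolved through the global config, so LITELLM_HOST, LITELLM_API_KEY, LITELLM_BASE_PATH, LITELLM_CUSTOM_HEADERS (newline-separated "Key: Value" pairs) and LITELLM_TIMEOUT may come from the environment or a config file. The wrapper function name is mine and exists only so the fragment compiles on its own.

use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::litellm::LiteLLMProvider;

fn build_litellm_provider() -> LiteLLMProvider {
    // Uses the provider's advertised default model (gpt-4o-mini in this commit).
    let model = ModelConfig::new(LiteLLMProvider::metadata().default_model);
    // Reads the LITELLM_* settings via the global config, as from_env does below.
    LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider")
}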


@@ -157,7 +157,7 @@ impl ProviderTester {
.content
.iter()
.filter_map(|message| message.as_tool_request())
- .last()
+ .next_back()
.expect("got tool request")
.id;


@@ -275,6 +275,7 @@ mod tests {
}
#[test]
#[serial_test::serial]
fn test_model_config_context_limit_env_vars() {
use temp_env::with_vars;


@@ -41,6 +41,8 @@ pub struct ModelInfo {
pub output_token_cost: Option<f64>,
/// Currency for the costs (default: "$")
pub currency: Option<String>,
/// Whether this model supports cache control
pub supports_cache_control: Option<bool>,
}
impl ModelInfo {
@@ -52,6 +54,7 @@ impl ModelInfo {
input_token_cost: None,
output_token_cost: None,
currency: None,
supports_cache_control: None,
}
}
@@ -68,6 +71,7 @@ impl ModelInfo {
input_token_cost: Some(input_cost),
output_token_cost: Some(output_cost),
currency: Some("$".to_string()),
supports_cache_control: None,
}
}
}
@@ -115,6 +119,7 @@ impl ProviderMetadata {
input_token_cost: None,
output_token_cost: None,
currency: None,
supports_cache_control: None,
})
.collect(),
model_doc_link: model_doc_link.to_string(),
@@ -290,6 +295,11 @@ pub trait Provider: Send + Sync {
false
}
/// Check if this provider supports cache control
fn supports_cache_control(&self) -> bool {
false
}
/// Create embeddings if supported. Default implementation returns an error.
async fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
Err(ProviderError::ExecutionError(
@@ -435,6 +445,7 @@ mod tests {
input_token_cost: None,
output_token_cost: None,
currency: None,
supports_cache_control: None,
};
assert_eq!(info.context_limit, 1000);
@@ -445,6 +456,7 @@ mod tests {
input_token_cost: None,
output_token_cost: None,
currency: None,
supports_cache_control: None,
};
assert_eq!(info, info2);
@@ -455,6 +467,7 @@ mod tests {
input_token_cost: None,
output_token_cost: None,
currency: None,
supports_cache_control: None,
};
assert_ne!(info, info3);
}
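The two additions in this file (the supports_cache_control field on ModelInfo and the defaulted supports_cache_control() hook on Provider) are consumed further down: the LiteLLM provider fills the field from the supports_prompt_caching flag returned by /model/info, while OpenRouter simply overrides the hook. A tiny sketch of the field in isolation, with a placeholder model name and context limit:

use goose::providers::base::ModelInfo;

fn cached_model_info() -> ModelInfo {
    // Placeholder values; only supports_cache_control is the point here.
    let mut info = ModelInfo::new("example-model", 200_000);
    info.supports_cache_control = Some(true);
    info
}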


@@ -12,6 +12,7 @@ use super::{
google::GoogleProvider,
groq::GroqProvider,
lead_worker::LeadWorkerProvider,
litellm::LiteLLMProvider,
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
@@ -50,6 +51,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
// GithubCopilotProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
LiteLLMProvider::metadata(),
OllamaProvider::metadata(),
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
@@ -158,6 +160,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
"databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
"gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)),
"groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
"litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)),
"ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),


@@ -0,0 +1,357 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;
use url::Url;
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;
pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";
#[derive(Debug, serde::Serialize)]
pub struct LiteLLMProvider {
#[serde(skip)]
client: Client,
host: String,
base_path: String,
api_key: String,
model: ModelConfig,
custom_headers: Option<HashMap<String, String>>,
}
impl Default for LiteLLMProvider {
fn default() -> Self {
let model = ModelConfig::new(LiteLLMProvider::metadata().default_model);
LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider")
}
}
impl LiteLLMProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config
.get_secret("LITELLM_API_KEY")
.unwrap_or_else(|_| String::new());
let host: String = config
.get_param("LITELLM_HOST")
.unwrap_or_else(|_| "https://api.litellm.ai".to_string());
let base_path: String = config
.get_param("LITELLM_BASE_PATH")
.unwrap_or_else(|_| "v1/chat/completions".to_string());
let custom_headers: Option<HashMap<String, String>> = config
.get_secret("LITELLM_CUSTOM_HEADERS")
.or_else(|_| config.get_param("LITELLM_CUSTOM_HEADERS"))
.ok()
.map(parse_custom_headers);
let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600);
let client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()?;
Ok(Self {
client,
host,
base_path,
api_key,
model,
custom_headers,
})
}
fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(custom_headers) = &self.custom_headers {
for (key, value) in custom_headers {
request = request.header(key, value);
}
}
request
}
async fn fetch_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
let models_url = format!("{}/model/info", self.host);
let mut req = self
.client
.get(&models_url)
.header("Authorization", format!("Bearer {}", self.api_key));
req = self.add_headers(req);
let response = req
.send()
.await
.map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?;
if !response.status().is_success() {
return Err(ProviderError::RequestFailed(format!(
"Models endpoint returned status: {}",
response.status()
)));
}
let response_json: Value = response.json().await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to parse models response: {}", e))
})?;
let models_data = response_json["data"].as_array().ok_or_else(|| {
ProviderError::RequestFailed("Missing data field in models response".to_string())
})?;
let mut models = Vec::new();
for model_data in models_data {
if let Some(model_name) = model_data["model_name"].as_str() {
if model_name.contains("/*") {
continue;
}
let model_info = &model_data["model_info"];
let context_length =
model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize;
let supports_cache_control = model_info["supports_prompt_caching"].as_bool();
let mut model_info_obj = ModelInfo::new(model_name, context_length);
model_info_obj.supports_cache_control = supports_cache_control;
models.push(model_info_obj);
}
}
Ok(models)
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join(&self.base_path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let request = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key));
let request = self.add_headers(request);
let response = request.json(&payload).send().await?;
handle_response_openai_compat(response).await
}
}
#[async_trait]
impl Provider for LiteLLMProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"litellm",
"LiteLLM",
"LiteLLM proxy supporting multiple models with automatic prompt caching",
LITELLM_DEFAULT_MODEL,
vec![],
LITELLM_DOC_URL,
vec![
ConfigKey::new("LITELLM_API_KEY", false, true, None),
ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")),
ConfigKey::new(
"LITELLM_BASE_PATH",
true,
false,
Some("v1/chat/completions"),
),
ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None),
ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")),
],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
#[tracing::instrument(skip_all, name = "provider_complete")]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let mut payload = super::formats::openai::create_request(
&self.model,
system,
messages,
tools,
&ImageFormat::OpenAi,
)?;
if self.supports_cache_control() {
payload = update_request_for_cache_control(&payload);
}
let response = self.post(payload.clone()).await?;
let message = super::formats::openai::response_to_message(response.clone())?;
let usage = super::formats::openai::get_usage(&response);
let model = get_model(&response);
emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
fn supports_embeddings(&self) -> bool {
true
}
fn supports_cache_control(&self) -> bool {
if let Ok(models) = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(self.fetch_models())
}) {
if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) {
return model_info.supports_cache_control.unwrap_or(false);
}
}
self.model.model_name.to_lowercase().contains("claude")
}
async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
match self.fetch_models().await {
Ok(models) => {
let model_names: Vec<String> = models.into_iter().map(|m| m.name).collect();
Ok(Some(model_names))
}
Err(e) => {
tracing::warn!("Failed to fetch models from LiteLLM: {}", e);
Ok(None)
}
}
}
}
#[async_trait]
impl EmbeddingCapable for LiteLLMProvider {
async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, anyhow::Error> {
let endpoint = format!("{}/v1/embeddings", self.host);
let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let payload = json!({
"input": texts,
"model": embedding_model,
"encoding_format": "float"
});
let mut req = self
.client
.post(&endpoint)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload);
req = self.add_headers(req);
let response = req.send().await?;
let response_text = response.text().await?;
let response_json: Value = serde_json::from_str(&response_text)?;
let data = response_json["data"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("Missing data field"))?;
let mut embeddings = Vec::new();
for item in data {
let embedding: Vec<f32> = item["embedding"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("Missing embedding field"))?
.iter()
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
.collect();
embeddings.push(embedding);
}
Ok(embeddings)
}
}
/// Updates the request payload to add cache_control entries for automatic prompt caching.
/// Marks the last two user messages, the system message, and the last tool as ephemeral.
pub fn update_request_for_cache_control(original_payload: &Value) -> Value {
let mut payload = original_payload.clone();
if let Some(messages_spec) = payload
.as_object_mut()
.and_then(|obj| obj.get_mut("messages"))
.and_then(|messages| messages.as_array_mut())
{
let mut user_count = 0;
for message in messages_spec.iter_mut().rev() {
if message.get("role") == Some(&json!("user")) {
if let Some(content) = message.get_mut("content") {
if let Some(content_str) = content.as_str() {
*content = json!([{
"type": "text",
"text": content_str,
"cache_control": { "type": "ephemeral" }
}]);
}
}
user_count += 1;
if user_count >= 2 {
break;
}
}
}
if let Some(system_message) = messages_spec
.iter_mut()
.find(|msg| msg.get("role") == Some(&json!("system")))
{
if let Some(content) = system_message.get_mut("content") {
if let Some(content_str) = content.as_str() {
*system_message = json!({
"role": "system",
"content": [{
"type": "text",
"text": content_str,
"cache_control": { "type": "ephemeral" }
}]
});
}
}
}
}
if let Some(tools_spec) = payload
.as_object_mut()
.and_then(|obj| obj.get_mut("tools"))
.and_then(|tools| tools.as_array_mut())
{
if let Some(last_tool) = tools_spec.last_mut() {
if let Some(function) = last_tool.get_mut("function") {
function
.as_object_mut()
.unwrap()
.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
}
}
}
payload
}
fn parse_custom_headers(headers_str: String) -> HashMap<String, String> {
let mut headers = HashMap::new();
for line in headers_str.lines() {
if let Some((key, value)) = line.split_once(':') {
headers.insert(key.trim().to_string(), value.trim().to_string());
}
}
headers
}
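To make the cache-control rewrite concrete, here is a sketch of a possible unit test for update_request_for_cache_control. The function is declared pub above, so it should be reachable as goose::providers::litellm::update_request_for_cache_control; the payload values are illustrative only.

use goose::providers::litellm::update_request_for_cache_control;
use serde_json::json;

#[test]
fn cache_control_marks_last_user_and_system_messages() {
    let payload = json!({
        "model": "example-model",
        "messages": [
            { "role": "system", "content": "You are helpful." },
            { "role": "user", "content": "Hello" }
        ]
    });
    let cached = update_request_for_cache_control(&payload);
    // String contents are rewritten into content arrays carrying an ephemeral
    // cache_control marker, for the system message and the last (up to two)
    // user messages; the tools array is untouched here because the payload has none.
    assert_eq!(
        cached["messages"][1]["content"][0]["cache_control"],
        json!({ "type": "ephemeral" })
    );
    assert_eq!(cached["messages"][0]["content"][0]["type"], json!("text"));
}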


@@ -16,6 +16,7 @@ pub mod githubcopilot;
pub mod google;
pub mod groq;
pub mod lead_worker;
pub mod litellm;
pub mod oauth;
pub mod ollama;
pub mod openai;


@@ -199,23 +199,20 @@ fn update_request_for_anthropic(original_payload: &Value) -> Value {
}
fn create_request_based_on_model(
- model_config: &ModelConfig,
+ provider: &OpenRouterProvider,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> anyhow::Result<Value, Error> {
let mut payload = create_request(
- model_config,
+ &provider.model,
system,
messages,
tools,
&super::utils::ImageFormat::OpenAi,
)?;
- if model_config
-     .model_name
-     .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
- {
+ if provider.supports_cache_control() {
payload = update_request_for_anthropic(&payload);
}
@@ -259,7 +256,7 @@ impl Provider for OpenRouterProvider {
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Create the base payload
- let payload = create_request_based_on_model(&self.model, system, messages, tools)?;
+ let payload = create_request_based_on_model(self, system, messages, tools)?;
// Make request
let response = self.post(payload.clone()).await?;
@@ -365,4 +362,10 @@ impl Provider for OpenRouterProvider {
models.sort();
Ok(Some(models))
}
fn supports_cache_control(&self) -> bool {
self.model
.model_name
.starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
}
}


@@ -4,7 +4,8 @@ use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{
- anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake, xai,
+ anthropic, azure, bedrock, databricks, google, groq, litellm, ollama, openai, openrouter,
+ snowflake, xai,
};
use mcp_core::tool::Tool;
use rmcp::model::{AnnotateAble, Content, RawImageContent};
@@ -158,7 +159,7 @@ impl ProviderTester {
.content
.iter()
.filter_map(|message| message.as_tool_request())
- .last()
+ .next_back()
.expect("got tool request")
.id;
@@ -596,6 +597,28 @@ async fn test_sagemaker_tgi_provider() -> Result<()> {
.await
}
#[tokio::test]
async fn test_litellm_provider() -> Result<()> {
if std::env::var("LITELLM_HOST").is_err() {
println!("LITELLM_HOST not set, skipping test");
TEST_REPORT.record_skip("LiteLLM");
return Ok(());
}
let env_mods = HashMap::from_iter([
("LITELLM_HOST", Some("http://localhost:4000".to_string())),
("LITELLM_API_KEY", Some("".to_string())),
]);
test_provider(
"LiteLLM",
&[], // No required environment variables
Some(env_mods),
litellm::LiteLLMProvider::default,
)
.await
}
#[tokio::test]
async fn test_xai_provider() -> Result<()> {
test_provider("Xai", &["XAI_API_KEY"], None, xai::XaiProvider::default).await