feat: Add LiteLLM provider with automatic prompt caching support (#3380)
Signed-off-by: HikaruEgashira <account@egahika.dev>
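A minimal usage sketch, assuming the configuration keys and constructors introduced below (the ModelConfig::new signature is inferred from its use in this diff):

// Sketch only, not part of the commit. LITELLM_HOST, LITELLM_API_KEY,
// LITELLM_BASE_PATH, LITELLM_CUSTOM_HEADERS (one "Key: Value" pair per line)
// and LITELLM_TIMEOUT are read inside from_env() via the global config.
use goose::model::ModelConfig;
use goose::providers::litellm::{LiteLLMProvider, LITELLM_DEFAULT_MODEL};

fn build_litellm() -> anyhow::Result<LiteLLMProvider> {
    // Unset keys fall back to the defaults registered in metadata(),
    // e.g. host http://localhost:4000 and base path v1/chat/completions.
    let model = ModelConfig::new(LITELLM_DEFAULT_MODEL.to_string());
    LiteLLMProvider::from_env(model)
}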
@@ -157,7 +157,7 @@ impl ProviderTester {
            .content
            .iter()
            .filter_map(|message| message.as_tool_request())
            .last()
            .next_back()
            .expect("got tool request")
            .id;
@@ -275,6 +275,7 @@ mod tests {
    }

    #[test]
    #[serial_test::serial]
    fn test_model_config_context_limit_env_vars() {
        use temp_env::with_vars;
@@ -41,6 +41,8 @@ pub struct ModelInfo {
    pub output_token_cost: Option<f64>,
    /// Currency for the costs (default: "$")
    pub currency: Option<String>,
    /// Whether this model supports cache control
    pub supports_cache_control: Option<bool>,
}

impl ModelInfo {
@@ -52,6 +54,7 @@ impl ModelInfo {
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        }
    }
@@ -68,6 +71,7 @@ impl ModelInfo {
            input_token_cost: Some(input_cost),
            output_token_cost: Some(output_cost),
            currency: Some("$".to_string()),
            supports_cache_control: None,
        }
    }
}
@@ -115,6 +119,7 @@ impl ProviderMetadata {
                    input_token_cost: None,
                    output_token_cost: None,
                    currency: None,
                    supports_cache_control: None,
                })
                .collect(),
            model_doc_link: model_doc_link.to_string(),
@@ -290,6 +295,11 @@ pub trait Provider: Send + Sync {
        false
    }

    /// Check if this provider supports cache control
    fn supports_cache_control(&self) -> bool {
        false
    }

    /// Create embeddings if supported. Default implementation returns an error.
    async fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        Err(ProviderError::ExecutionError(
@@ -435,6 +445,7 @@ mod tests {
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info.context_limit, 1000);
@@ -445,6 +456,7 @@ mod tests {
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info, info2);
@@ -455,6 +467,7 @@ mod tests {
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_ne!(info, info3);
    }
@@ -12,6 +12,7 @@ use super::{
    google::GoogleProvider,
    groq::GroqProvider,
    lead_worker::LeadWorkerProvider,
    litellm::LiteLLMProvider,
    ollama::OllamaProvider,
    openai::OpenAiProvider,
    openrouter::OpenRouterProvider,
@@ -50,6 +51,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
        // GithubCopilotProvider::metadata(),
        GoogleProvider::metadata(),
        GroqProvider::metadata(),
        LiteLLMProvider::metadata(),
        OllamaProvider::metadata(),
        OpenAiProvider::metadata(),
        OpenRouterProvider::metadata(),
@@ -158,6 +160,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
        "databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
        "gemini-cli" => Ok(Arc::new(GeminiCliProvider::from_env(model)?)),
        "groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
        "litellm" => Ok(Arc::new(LiteLLMProvider::from_env(model)?)),
        "ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
        "openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
        "gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
crates/goose/src/providers/litellm.rs (new file, 357 lines added)
@@ -0,0 +1,357 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;
use url::Url;

use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;

pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";

#[derive(Debug, serde::Serialize)]
pub struct LiteLLMProvider {
    #[serde(skip)]
    client: Client,
    host: String,
    base_path: String,
    api_key: String,
    model: ModelConfig,
    custom_headers: Option<HashMap<String, String>>,
}

impl Default for LiteLLMProvider {
    fn default() -> Self {
        let model = ModelConfig::new(LiteLLMProvider::metadata().default_model);
        LiteLLMProvider::from_env(model).expect("Failed to initialize LiteLLM provider")
    }
}
impl LiteLLMProvider {
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let api_key: String = config
            .get_secret("LITELLM_API_KEY")
            .unwrap_or_else(|_| String::new());
        let host: String = config
            .get_param("LITELLM_HOST")
            .unwrap_or_else(|_| "https://api.litellm.ai".to_string());
        let base_path: String = config
            .get_param("LITELLM_BASE_PATH")
            .unwrap_or_else(|_| "v1/chat/completions".to_string());
        let custom_headers: Option<HashMap<String, String>> = config
            .get_secret("LITELLM_CUSTOM_HEADERS")
            .or_else(|_| config.get_param("LITELLM_CUSTOM_HEADERS"))
            .ok()
            .map(parse_custom_headers);
        let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600);
        let client = Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()?;

        Ok(Self {
            client,
            host,
            base_path,
            api_key,
            model,
            custom_headers,
        })
    }
    fn add_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some(custom_headers) = &self.custom_headers {
            for (key, value) in custom_headers {
                request = request.header(key, value);
            }
        }

        request
    }
    async fn fetch_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
        let models_url = format!("{}/model/info", self.host);

        let mut req = self
            .client
            .get(&models_url)
            .header("Authorization", format!("Bearer {}", self.api_key));

        req = self.add_headers(req);

        let response = req
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Failed to fetch models: {}", e)))?;

        if !response.status().is_success() {
            return Err(ProviderError::RequestFailed(format!(
                "Models endpoint returned status: {}",
                response.status()
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to parse models response: {}", e))
        })?;

        let models_data = response_json["data"].as_array().ok_or_else(|| {
            ProviderError::RequestFailed("Missing data field in models response".to_string())
        })?;

        let mut models = Vec::new();
        for model_data in models_data {
            if let Some(model_name) = model_data["model_name"].as_str() {
                if model_name.contains("/*") {
                    continue;
                }

                let model_info = &model_data["model_info"];
                let context_length =
                    model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize;
                let supports_cache_control = model_info["supports_prompt_caching"].as_bool();

                let mut model_info_obj = ModelInfo::new(model_name, context_length);
                model_info_obj.supports_cache_control = supports_cache_control;
                models.push(model_info_obj);
            }
        }

        Ok(models)
    }
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url = Url::parse(&self.host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let url = base_url.join(&self.base_path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        let request = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key));

        let request = self.add_headers(request);

        let response = request.json(&payload).send().await?;

        handle_response_openai_compat(response).await
    }
}
#[async_trait]
impl Provider for LiteLLMProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "litellm",
            "LiteLLM",
            "LiteLLM proxy supporting multiple models with automatic prompt caching",
            LITELLM_DEFAULT_MODEL,
            vec![],
            LITELLM_DOC_URL,
            vec![
                ConfigKey::new("LITELLM_API_KEY", false, true, None),
                ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")),
                ConfigKey::new(
                    "LITELLM_BASE_PATH",
                    true,
                    false,
                    Some("v1/chat/completions"),
                ),
                ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None),
                ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")),
            ],
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }
    #[tracing::instrument(skip_all, name = "provider_complete")]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let mut payload = super::formats::openai::create_request(
            &self.model,
            system,
            messages,
            tools,
            &ImageFormat::OpenAi,
        )?;

        if self.supports_cache_control() {
            payload = update_request_for_cache_control(&payload);
        }

        let response = self.post(payload.clone()).await?;

        let message = super::formats::openai::response_to_message(response.clone())?;
        let usage = super::formats::openai::get_usage(&response);
        let model = get_model(&response);
        emit_debug_trace(&self.model, &payload, &response, &usage);
        Ok((message, ProviderUsage::new(model, usage)))
    }
    fn supports_embeddings(&self) -> bool {
        true
    }

    fn supports_cache_control(&self) -> bool {
        // Prefer the proxy's own model info; fall back to name-based detection below.
        if let Ok(models) = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.fetch_models())
        }) {
            if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) {
                return model_info.supports_cache_control.unwrap_or(false);
            }
        }

        self.model.model_name.to_lowercase().contains("claude")
    }
    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
        match self.fetch_models().await {
            Ok(models) => {
                let model_names: Vec<String> = models.into_iter().map(|m| m.name).collect();
                Ok(Some(model_names))
            }
            Err(e) => {
                tracing::warn!("Failed to fetch models from LiteLLM: {}", e);
                Ok(None)
            }
        }
    }
}
#[async_trait]
impl EmbeddingCapable for LiteLLMProvider {
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, anyhow::Error> {
        let endpoint = format!("{}/v1/embeddings", self.host);

        let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL")
            .unwrap_or_else(|_| "text-embedding-3-small".to_string());

        let payload = json!({
            "input": texts,
            "model": embedding_model,
            "encoding_format": "float"
        });

        let mut req = self
            .client
            .post(&endpoint)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload);

        req = self.add_headers(req);

        let response = req.send().await?;
        let response_text = response.text().await?;
        let response_json: Value = serde_json::from_str(&response_text)?;

        let data = response_json["data"]
            .as_array()
            .ok_or_else(|| anyhow::anyhow!("Missing data field"))?;

        let mut embeddings = Vec::new();
        for item in data {
            let embedding: Vec<f32> = item["embedding"]
                .as_array()
                .ok_or_else(|| anyhow::anyhow!("Missing embedding field"))?
                .iter()
                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                .collect();
            embeddings.push(embedding);
        }

        Ok(embeddings)
    }
}
/// Updates the request payload to include cache control headers for automatic prompt caching
/// Adds ephemeral cache control to the last 2 user messages, system message, and last tool
pub fn update_request_for_cache_control(original_payload: &Value) -> Value {
    let mut payload = original_payload.clone();

    if let Some(messages_spec) = payload
        .as_object_mut()
        .and_then(|obj| obj.get_mut("messages"))
        .and_then(|messages| messages.as_array_mut())
    {
        let mut user_count = 0;
        for message in messages_spec.iter_mut().rev() {
            if message.get("role") == Some(&json!("user")) {
                if let Some(content) = message.get_mut("content") {
                    if let Some(content_str) = content.as_str() {
                        *content = json!([{
                            "type": "text",
                            "text": content_str,
                            "cache_control": { "type": "ephemeral" }
                        }]);
                    }
                }
                user_count += 1;
                if user_count >= 2 {
                    break;
                }
            }
        }

        if let Some(system_message) = messages_spec
            .iter_mut()
            .find(|msg| msg.get("role") == Some(&json!("system")))
        {
            if let Some(content) = system_message.get_mut("content") {
                if let Some(content_str) = content.as_str() {
                    *system_message = json!({
                        "role": "system",
                        "content": [{
                            "type": "text",
                            "text": content_str,
                            "cache_control": { "type": "ephemeral" }
                        }]
                    });
                }
            }
        }
    }

    if let Some(tools_spec) = payload
        .as_object_mut()
        .and_then(|obj| obj.get_mut("tools"))
        .and_then(|tools| tools.as_array_mut())
    {
        if let Some(last_tool) = tools_spec.last_mut() {
            if let Some(function) = last_tool.get_mut("function") {
                function
                    .as_object_mut()
                    .unwrap()
                    .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
            }
        }
    }
    payload
}
fn parse_custom_headers(headers_str: String) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    for line in headers_str.lines() {
        if let Some((key, value)) = line.split_once(':') {
            headers.insert(key.trim().to_string(), value.trim().to_string());
        }
    }
    headers
}
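For reference, a small sketch (not part of the file above) of what update_request_for_cache_control produces for an OpenAI-style payload; the expected shape in the comments follows directly from the function's behavior:

// Sketch only: exercising the public helper added above with an illustrative payload.
use serde_json::json;

fn main() {
    let before = json!({
        "model": "example-model",
        "messages": [
            {"role": "system", "content": "You are goose."},
            {"role": "user", "content": "Summarize the repo."}
        ],
        "tools": [
            {"type": "function", "function": {"name": "read_file", "parameters": {}}}
        ]
    });

    let after = goose::providers::litellm::update_request_for_cache_control(&before);

    // The system message and the last (up to two) user messages have their
    // string content rewritten into a one-element content array whose text part
    // carries "cache_control": {"type": "ephemeral"}, and the last tool's
    // "function" object gains the same marker, e.g.:
    //   {"role": "user", "content": [{"type": "text",
    //     "text": "Summarize the repo.", "cache_control": {"type": "ephemeral"}}]}
    println!("{}", serde_json::to_string_pretty(&after).unwrap());
}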
@@ -16,6 +16,7 @@ pub mod githubcopilot;
pub mod google;
pub mod groq;
pub mod lead_worker;
pub mod litellm;
pub mod oauth;
pub mod ollama;
pub mod openai;
@@ -199,23 +199,20 @@ fn update_request_for_anthropic(original_payload: &Value) -> Value {
}

fn create_request_based_on_model(
    model_config: &ModelConfig,
    provider: &OpenRouterProvider,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> anyhow::Result<Value, Error> {
    let mut payload = create_request(
        model_config,
        &provider.model,
        system,
        messages,
        tools,
        &super::utils::ImageFormat::OpenAi,
    )?;

    if model_config
        .model_name
        .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
    {
    if provider.supports_cache_control() {
        payload = update_request_for_anthropic(&payload);
    }
@@ -259,7 +256,7 @@ impl Provider for OpenRouterProvider {
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Create the base payload
        let payload = create_request_based_on_model(&self.model, system, messages, tools)?;
        let payload = create_request_based_on_model(self, system, messages, tools)?;

        // Make request
        let response = self.post(payload.clone()).await?;
@@ -365,4 +362,10 @@ impl Provider for OpenRouterProvider {
        models.sort();
        Ok(Some(models))
    }

    fn supports_cache_control(&self) -> bool {
        self.model
            .model_name
            .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
    }
}
@@ -4,7 +4,8 @@ use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{
    anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake, xai,
    anthropic, azure, bedrock, databricks, google, groq, litellm, ollama, openai, openrouter,
    snowflake, xai,
};
use mcp_core::tool::Tool;
use rmcp::model::{AnnotateAble, Content, RawImageContent};
@@ -158,7 +159,7 @@ impl ProviderTester {
            .content
            .iter()
            .filter_map(|message| message.as_tool_request())
            .last()
            .next_back()
            .expect("got tool request")
            .id;
@@ -596,6 +597,28 @@ async fn test_sagemaker_tgi_provider() -> Result<()> {
    .await
}

#[tokio::test]
async fn test_litellm_provider() -> Result<()> {
    if std::env::var("LITELLM_HOST").is_err() {
        println!("LITELLM_HOST not set, skipping test");
        TEST_REPORT.record_skip("LiteLLM");
        return Ok(());
    }

    let env_mods = HashMap::from_iter([
        ("LITELLM_HOST", Some("http://localhost:4000".to_string())),
        ("LITELLM_API_KEY", Some("".to_string())),
    ]);

    test_provider(
        "LiteLLM",
        &[], // No required environment variables
        Some(env_mods),
        litellm::LiteLLMProvider::default,
    )
    .await
}

#[tokio::test]
async fn test_xai_provider() -> Result<()> {
    test_provider("Xai", &["XAI_API_KEY"], None, xai::XaiProvider::default).await