diff --git a/.gitignore b/.gitignore
index d9b45401..eb423636 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ tmp/
 # will have compiled files and executables
 debug/
 target/
+.goose/
 
 # These are backup files generated by rustfmt
 **/*.rs.bk
diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs
index c9bcfe8e..c8e28e47 100644
--- a/crates/goose/src/model.rs
+++ b/crates/goose/src/model.rs
@@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
     // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
     map.insert("llama3.2", 128_000);
     map.insert("llama3.3", 128_000);
+
+    // x.ai Grok models, https://docs.x.ai/docs/overview
+    map.insert("grok", 131_072);
     map
 });
diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs
index af4fc2e3..11757aa9 100644
--- a/crates/goose/src/providers/factory.rs
+++ b/crates/goose/src/providers/factory.rs
@@ -17,6 +17,7 @@ use super::{
     sagemaker_tgi::SageMakerTgiProvider,
     snowflake::SnowflakeProvider,
     venice::VeniceProvider,
+    xai::XaiProvider,
 };
 use crate::model::ModelConfig;
 use anyhow::Result;
@@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         SageMakerTgiProvider::metadata(),
         VeniceProvider::metadata(),
         SnowflakeProvider::metadata(),
+        XaiProvider::metadata(),
     ]
 }
@@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
         "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
         "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
         "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
+        "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
 }
@@ -259,7 +262,7 @@ mod tests {
         }
 
         // Set only the required lead model
-        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
+        env::set_var("GOOSE_LEAD_MODEL", "grok-3");
 
         // This should use defaults for all other values
         let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index ccda29a4..decd346a 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -24,5 +24,6 @@ pub mod toolshim;
 pub mod utils;
 pub mod utils_universal_openai_stream;
 pub mod venice;
+pub mod xai;
 
 pub use factory::{create, providers};
diff --git a/crates/goose/src/providers/xai.rs b/crates/goose/src/providers/xai.rs
new file mode 100644
index 00000000..7e91a23f
--- /dev/null
+++ b/crates/goose/src/providers/xai.rs
@@ -0,0 +1,181 @@
+use super::errors::ProviderError;
+use crate::message::Message;
+use crate::model::ModelConfig;
+use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
+use crate::providers::utils::get_model;
+use anyhow::Result;
+use async_trait::async_trait;
+use mcp_core::Tool;
+use reqwest::{Client, StatusCode};
+use serde_json::Value;
+use std::time::Duration;
+use url::Url;
+
+pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
+pub const XAI_DEFAULT_MODEL: &str = "grok-3";
+pub const XAI_KNOWN_MODELS: &[&str] = &[
+    "grok-3",
+    "grok-3-fast",
+    "grok-3-mini",
+    "grok-3-mini-fast",
+    "grok-2-vision-1212",
+    "grok-2-image-1212",
+    "grok-2-1212",
+    "grok-3-latest",
+    "grok-3-fast-latest",
+    "grok-3-mini-latest",
+    "grok-3-mini-fast-latest",
+    "grok-2-vision",
+    "grok-2-vision-latest",
+    "grok-2-image",
+    "grok-2-image-latest",
+    "grok-2",
+    "grok-2-latest",
+];
+
+pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";
+
+#[derive(serde::Serialize)]
+pub struct XaiProvider {
+    #[serde(skip)]
+    client: Client,
+    host: String,
+    api_key: String,
+    model: ModelConfig,
+}
+
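+// Note: constructing via Default builds from the environment and panics (via
+// expect) when XAI_API_KEY is not configured; fallible call sites such as the
+// provider factory go through XaiProvider::from_env instead.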
"grok-2-latest", +]; + +pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview"; + +#[derive(serde::Serialize)] +pub struct XaiProvider { + #[serde(skip)] + client: Client, + host: String, + api_key: String, + model: ModelConfig, +} + +impl Default for XaiProvider { + fn default() -> Self { + let model = ModelConfig::new(XaiProvider::metadata().default_model); + XaiProvider::from_env(model).expect("Failed to initialize xAI provider") + } +} + +impl XaiProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = crate::config::Config::global(); + let api_key: String = config.get_secret("XAI_API_KEY")?; + let host: String = config + .get_param("XAI_HOST") + .unwrap_or_else(|_| XAI_API_HOST.to_string()); + + let client = Client::builder() + .timeout(Duration::from_secs(600)) + .build()?; + + Ok(Self { + client, + host, + api_key, + model, + }) + } + + async fn post(&self, payload: Value) -> anyhow::Result { + // Ensure the host ends with a slash for proper URL joining + let host = if self.host.ends_with('/') { + self.host.clone() + } else { + format!("{}/", self.host) + }; + let base_url = Url::parse(&host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("chat/completions").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + tracing::debug!("xAI API URL: {}", url); + tracing::debug!("xAI request model: {:?}", self.model.model_name); + + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload) + .send() + .await?; + + let status = response.status(); + let payload: Option = response.json().await.ok(); + + match status { + StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", status, payload))) + } + StatusCode::PAYLOAD_TOO_LARGE => { + Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload))) + } + StatusCode::TOO_MANY_REQUESTS => { + Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) + } + StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { + Err(ProviderError::ServerError(format!("{:?}", payload))) + } + _ => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. 
Payload: {:?}", status, payload) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) + } + } + } +} + +#[async_trait] +impl Provider for XaiProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "xai", + "xAI", + "Grok models from xAI, including reasoning and multimodal capabilities", + XAI_DEFAULT_MODEL, + XAI_KNOWN_MODELS.to_vec(), + XAI_DOC_URL, + vec![ + ConfigKey::new("XAI_API_KEY", true, true, None), + ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { + let payload = create_request( + &self.model, + system, + messages, + tools, + &super::utils::ImageFormat::OpenAi, + )?; + + let response = self.post(payload.clone()).await?; + + let message = response_to_message(response.clone())?; + let usage = match get_usage(&response) { + Ok(usage) => usage, + Err(ProviderError::UsageError(e)) => { + tracing::debug!("Failed to get usage data: {}", e); + Usage::default() + } + Err(e) => return Err(e), + }; + let model = get_model(&response); + super::utils::emit_debug_trace(&self.model, &payload, &response, &usage); + Ok((message, ProviderUsage::new(model, usage))) + } +} diff --git a/documentation/docs/getting-started/providers.md b/documentation/docs/getting-started/providers.md index 1aff07bf..1b8d1a53 100644 --- a/documentation/docs/getting-started/providers.md +++ b/documentation/docs/getting-started/providers.md @@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with | [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` | | [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` | | [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) | +| [xAI](https://x.ai/) | Access to xAI's Grok models including grok-3, grok-3-mini, and grok-3-fast with 131,072 token context window. 
 
 ## Configure Provider
diff --git a/ui/desktop/src/components/settings/providers/ProviderRegistry.tsx b/ui/desktop/src/components/settings/providers/ProviderRegistry.tsx
index 7285f778..56aa7ed7 100644
--- a/ui/desktop/src/components/settings/providers/ProviderRegistry.tsx
+++ b/ui/desktop/src/components/settings/providers/ProviderRegistry.tsx
@@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
       ],
     },
   },
+  {
+    name: 'xAI',
+    details: {
+      id: 'xai',
+      name: 'xAI',
+      description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
+      parameters: [
+        {
+          name: 'XAI_API_KEY',
+          is_secret: true,
+        },
+        {
+          name: 'XAI_HOST',
+          is_secret: false,
+          default: 'https://api.x.ai/v1',
+        },
+      ],
+    },
+  },
   {
     name: 'Google',
     details: {
diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/ProviderLogo.tsx b/ui/desktop/src/components/settings/providers/modal/subcomponents/ProviderLogo.tsx
index 585b436c..8f18c721 100644
--- a/ui/desktop/src/components/settings/providers/modal/subcomponents/ProviderLogo.tsx
+++ b/ui/desktop/src/components/settings/providers/modal/subcomponents/ProviderLogo.tsx
@@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
 import DatabricksLogo from './icons/databricks@3x.png';
 import OpenRouterLogo from './icons/openrouter@3x.png';
 import SnowflakeLogo from './icons/snowflake@3x.png';
+import XaiLogo from './icons/xai@3x.png';
 import DefaultLogo from './icons/default@3x.png';
 
 // Map provider names to their logos
@@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
   databricks: DatabricksLogo,
   openrouter: OpenRouterLogo,
   snowflake: SnowflakeLogo,
+  xai: XaiLogo,
   default: DefaultLogo,
 };
 
@@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
   const logoKey = providerName.toLowerCase();
   const logo = providerLogos[logoKey] || DefaultLogo;
 
+  // Special handling for xAI logo
+  const isXai = logoKey === 'xai';
+  const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};
+
+  // Use smaller size for xAI logo to fit better in circle
+  const imageClassName = isXai
+    ? 'w-8 h-8 object-contain' // Smaller size for xAI
+    : 'w-16 h-16 object-contain'; // Default size for others
+
   return (
-    <img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" />
+    <img src={logo} alt={`${providerName} logo`} className={imageClassName} style={imageStyle} />
   );
diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai.png b/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai.png
new file mode 100644
index 00000000..14bfe48b
Binary files /dev/null and b/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai.png differ
diff --git a/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai@3x.png b/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai@3x.png
new file mode 100644
index 00000000..14bfe48b
Binary files /dev/null and b/ui/desktop/src/components/settings/providers/modal/subcomponents/icons/xai@3x.png differ
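For reference, a minimal sketch of how the new provider is exercised end to end through the factory added above. `create`, `ModelConfig::new`, and `complete` come from this diff; the tokio `main` wrapper, consuming the crate as a dependency named `goose`, and the conversion of `ProviderError` into `anyhow::Error` are assumptions, not part of the change:

```rust
use goose::model::ModelConfig;
use goose::providers::create;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Requires XAI_API_KEY in the environment; XAI_HOST is optional and
    // defaults to https://api.x.ai/v1 per the ConfigKey declarations.
    let provider = create("xai", ModelConfig::new("grok-3".to_string()))?;

    // complete() builds an OpenAI-format request, POSTs it to
    // {XAI_HOST}/chat/completions, and returns the message plus usage.
    let (message, usage) = provider
        .complete("You are a helpful assistant.", &[], &[])
        .await?;
    println!("{:?}", message);
    println!("usage: {:?}", usage);
    Ok(())
}
```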