Add xAI Provider Support for Grok Models (#2976)

Co-authored-by: jack <jack@deck.local>
This commit is contained in:
jack
2025-06-20 19:27:28 +02:00
committed by GitHub
parent 97cfddc593
commit bd25b15aab
10 changed files with 227 additions and 2 deletions

1
.gitignore vendored
View File

@@ -12,6 +12,7 @@ tmp/
# will have compiled files and executables # will have compiled files and executables
debug/ debug/
target/ target/
.goose/
# These are backup files generated by rustfmt # These are backup files generated by rustfmt
**/*.rs.bk **/*.rs.bk

View File

@@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
// Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1 // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
map.insert("llama3.2", 128_000); map.insert("llama3.2", 128_000);
map.insert("llama3.3", 128_000); map.insert("llama3.3", 128_000);
// x.ai Grok models, https://docs.x.ai/docs/overview
map.insert("grok", 131_072);
map map
}); });

View File

@@ -17,6 +17,7 @@ use super::{
sagemaker_tgi::SageMakerTgiProvider, sagemaker_tgi::SageMakerTgiProvider,
snowflake::SnowflakeProvider, snowflake::SnowflakeProvider,
venice::VeniceProvider, venice::VeniceProvider,
xai::XaiProvider,
}; };
use crate::model::ModelConfig; use crate::model::ModelConfig;
use anyhow::Result; use anyhow::Result;
@@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
SageMakerTgiProvider::metadata(), SageMakerTgiProvider::metadata(),
VeniceProvider::metadata(), VeniceProvider::metadata(),
SnowflakeProvider::metadata(), SnowflakeProvider::metadata(),
XaiProvider::metadata(),
] ]
} }
@@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)), "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)), "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)), "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
"xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)), _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
} }
} }
@@ -259,7 +262,7 @@ mod tests {
} }
// Set only the required lead model // Set only the required lead model
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o"); env::set_var("GOOSE_LEAD_MODEL", "grok-3");
// This should use defaults for all other values // This should use defaults for all other values
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

View File

@@ -24,5 +24,6 @@ pub mod toolshim;
pub mod utils; pub mod utils;
pub mod utils_universal_openai_stream; pub mod utils_universal_openai_stream;
pub mod venice; pub mod venice;
pub mod xai;
pub use factory::{create, providers}; pub use factory::{create, providers};

View File

@@ -0,0 +1,181 @@
use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use crate::providers::utils::get_model;
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;
pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
pub const XAI_KNOWN_MODELS: &[&str] = &[
"grok-3",
"grok-3-fast",
"grok-3-mini",
"grok-3-mini-fast",
"grok-2-vision-1212",
"grok-2-image-1212",
"grok-2-1212",
"grok-3-latest",
"grok-3-fast-latest",
"grok-3-mini-latest",
"grok-3-mini-fast-latest",
"grok-2-vision",
"grok-2-vision-latest",
"grok-2-image",
"grok-2-image-latest",
"grok-2",
"grok-2-latest",
];
pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";
#[derive(serde::Serialize)]
pub struct XaiProvider {
#[serde(skip)]
client: Client,
host: String,
api_key: String,
model: ModelConfig,
}
impl Default for XaiProvider {
fn default() -> Self {
let model = ModelConfig::new(XaiProvider::metadata().default_model);
XaiProvider::from_env(model).expect("Failed to initialize xAI provider")
}
}
impl XaiProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("XAI_API_KEY")?;
let host: String = config
.get_param("XAI_HOST")
.unwrap_or_else(|_| XAI_API_HOST.to_string());
let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;
Ok(Self {
client,
host,
api_key,
model,
})
}
async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
// Ensure the host ends with a slash for proper URL joining
let host = if self.host.ends_with('/') {
self.host.clone()
} else {
format!("{}/", self.host)
};
let base_url = Url::parse(&host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join("chat/completions").map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
tracing::debug!("xAI API URL: {}", url);
tracing::debug!("xAI request model: {:?}", self.model.model_name);
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&payload)
.send()
.await?;
let status = response.status();
let payload: Option<Value> = response.json().await.ok();
match status {
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
}
StatusCode::PAYLOAD_TOO_LARGE => {
Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
}
}
}
}
#[async_trait]
impl Provider for XaiProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"xai",
"xAI",
"Grok models from xAI, including reasoning and multimodal capabilities",
XAI_DEFAULT_MODEL,
XAI_KNOWN_MODELS.to_vec(),
XAI_DOC_URL,
vec![
ConfigKey::new("XAI_API_KEY", true, true, None),
ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(
&self.model,
system,
messages,
tools,
&super::utils::ImageFormat::OpenAi,
)?;
let response = self.post(payload.clone()).await?;
let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
}

View File

@@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` | | [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
| [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` | | [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` |
| [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) | | [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) |
| [xAI](https://x.ai/) | Access to xAI's Grok models including grok-3, grok-3-mini, and grok-3-fast with 131,072 token context window. | `XAI_API_KEY`, `XAI_HOST` (optional) |
## Configure Provider ## Configure Provider

View File

@@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
], ],
}, },
}, },
{
name: 'xAI',
details: {
id: 'xai',
name: 'xAI',
description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
parameters: [
{
name: 'XAI_API_KEY',
is_secret: true,
},
{
name: 'XAI_HOST',
is_secret: false,
default: 'https://api.x.ai/v1',
},
],
},
},
{ {
name: 'Google', name: 'Google',
details: { details: {

View File

@@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
import DatabricksLogo from './icons/databricks@3x.png'; import DatabricksLogo from './icons/databricks@3x.png';
import OpenRouterLogo from './icons/openrouter@3x.png'; import OpenRouterLogo from './icons/openrouter@3x.png';
import SnowflakeLogo from './icons/snowflake@3x.png'; import SnowflakeLogo from './icons/snowflake@3x.png';
import XaiLogo from './icons/xai@3x.png';
import DefaultLogo from './icons/default@3x.png'; import DefaultLogo from './icons/default@3x.png';
// Map provider names to their logos // Map provider names to their logos
@@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
databricks: DatabricksLogo, databricks: DatabricksLogo,
openrouter: OpenRouterLogo, openrouter: OpenRouterLogo,
snowflake: SnowflakeLogo, snowflake: SnowflakeLogo,
xai: XaiLogo,
default: DefaultLogo, default: DefaultLogo,
}; };
@@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
const logoKey = providerName.toLowerCase(); const logoKey = providerName.toLowerCase();
const logo = providerLogos[logoKey] || DefaultLogo; const logo = providerLogos[logoKey] || DefaultLogo;
// Special handling for xAI logo
const isXai = logoKey === 'xai';
const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};
// Use smaller size for xAI logo to fit better in circle
const imageClassName = isXai
? 'w-8 h-8 object-contain' // Smaller size for xAI
: 'w-16 h-16 object-contain'; // Default size for others
return ( return (
<div className="flex justify-center mb-2"> <div className="flex justify-center mb-2">
<div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center"> <div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center">
<img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" /> <img
src={logo}
alt={`${providerName} logo`}
className={imageClassName}
style={imageStyle}
/>
</div> </div>
</div> </div>
); );

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB