Add xAI Provider Support for Grok Models (#2976)
Co-authored-by: jack <jack@deck.local>
.gitignore (vendored) | 1 +
@@ -12,6 +12,7 @@ tmp/
 # will have compiled files and executables
 debug/
 target/
+.goose/

 # These are backup files generated by rustfmt
 **/*.rs.bk
@@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
     // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
     map.insert("llama3.2", 128_000);
     map.insert("llama3.3", 128_000);
+
+    // x.ai Grok models, https://docs.x.ai/docs/overview
+    map.insert("grok", 131_072);

     map
 });
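The new `"grok"` entry gives Grok models a 131,072-token context limit. Below is a minimal standalone sketch of how such a map can resolve a concrete model name; the prefix-matching step is an assumption about how goose consults `MODEL_SPECIFIC_LIMITS`, not something shown in this diff:

```rust
use std::collections::HashMap;

fn main() {
    // Mirrors the entry added above.
    let mut limits: HashMap<&str, usize> = HashMap::new();
    limits.insert("grok", 131_072);

    // Hypothetical prefix match: "grok-3-mini" falls under the "grok" key.
    let model = "grok-3-mini";
    let limit = limits
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|(_, limit)| *limit);
    assert_eq!(limit, Some(131_072));
}
```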
@@ -17,6 +17,7 @@ use super::{
     sagemaker_tgi::SageMakerTgiProvider,
     snowflake::SnowflakeProvider,
     venice::VeniceProvider,
+    xai::XaiProvider,
 };
 use crate::model::ModelConfig;
 use anyhow::Result;
@@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         SageMakerTgiProvider::metadata(),
         VeniceProvider::metadata(),
         SnowflakeProvider::metadata(),
+        XaiProvider::metadata(),
     ]
 }
@@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
         "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
         "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
         "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
+        "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
 }
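With the factory arm in place, selecting the provider by name is all a caller needs. A minimal sketch, assuming `XAI_API_KEY` is already present in goose's config and that the crate re-exports `create` as shown in mod.rs further down; the call shape mirrors the `create("openai", ...)` usage in this commit's own tests:

```rust
use goose::model::ModelConfig;
use goose::providers::create;

fn main() -> anyhow::Result<()> {
    // "xai" routes through the new match arm to XaiProvider::from_env.
    let provider = create("xai", ModelConfig::new("grok-3".to_string()))?;
    println!("configured model: {}", provider.get_model_config().model_name);
    Ok(())
}
```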
@@ -259,7 +262,7 @@ mod tests {
     }

     // Set only the required lead model
-    env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
+    env::set_var("GOOSE_LEAD_MODEL", "grok-3");

     // This should use defaults for all other values
     let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
@@ -24,5 +24,6 @@ pub mod toolshim;
 pub mod utils;
 pub mod utils_universal_openai_stream;
 pub mod venice;
+pub mod xai;

 pub use factory::{create, providers};
crates/goose/src/providers/xai.rs (new file) | 181 +
@@ -0,0 +1,181 @@
use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use crate::providers::utils::get_model;
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;

pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
pub const XAI_KNOWN_MODELS: &[&str] = &[
    "grok-3",
    "grok-3-fast",
    "grok-3-mini",
    "grok-3-mini-fast",
    "grok-2-vision-1212",
    "grok-2-image-1212",
    "grok-2-1212",
    "grok-3-latest",
    "grok-3-fast-latest",
    "grok-3-mini-latest",
    "grok-3-mini-fast-latest",
    "grok-2-vision",
    "grok-2-vision-latest",
    "grok-2-image",
    "grok-2-image-latest",
    "grok-2",
    "grok-2-latest",
];

pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";

#[derive(serde::Serialize)]
pub struct XaiProvider {
    #[serde(skip)]
    client: Client,
    host: String,
    api_key: String,
    model: ModelConfig,
}

impl Default for XaiProvider {
    fn default() -> Self {
        let model = ModelConfig::new(XaiProvider::metadata().default_model);
        XaiProvider::from_env(model).expect("Failed to initialize xAI provider")
    }
}

impl XaiProvider {
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let api_key: String = config.get_secret("XAI_API_KEY")?;
        let host: String = config
            .get_param("XAI_HOST")
            .unwrap_or_else(|_| XAI_API_HOST.to_string());

        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(Self {
            client,
            host,
            api_key,
            model,
        })
    }

    async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
        // Ensure the host ends with a slash for proper URL joining
        let host = if self.host.ends_with('/') {
            self.host.clone()
        } else {
            format!("{}/", self.host)
        };
        let base_url = Url::parse(&host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let url = base_url.join("chat/completions").map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        tracing::debug!("xAI API URL: {}", url);
        tracing::debug!("xAI request model: {:?}", self.model.model_name);

        let response = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()
            .await?;

        let status = response.status();
        let payload: Option<Value> = response.json().await.ok();

        match status {
            StatusCode::OK => payload.ok_or_else(|| {
                ProviderError::RequestFailed("Response body is not valid JSON".to_string())
            }),
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                Err(ProviderError::Authentication(format!(
                    "Authentication failed. Please ensure your API keys are valid and have the required permissions. \
                     Status: {}. Response: {:?}",
                    status, payload
                )))
            }
            StatusCode::PAYLOAD_TOO_LARGE => {
                Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
            }
            StatusCode::TOO_MANY_REQUESTS => {
                Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
            }
            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
                Err(ProviderError::ServerError(format!("{:?}", payload)))
            }
            _ => {
                tracing::debug!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status,
                    payload
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status: {}",
                    status
                )))
            }
        }
    }
}

#[async_trait]
impl Provider for XaiProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "xai",
            "xAI",
            "Grok models from xAI, including reasoning and multimodal capabilities",
            XAI_DEFAULT_MODEL,
            XAI_KNOWN_MODELS.to_vec(),
            XAI_DOC_URL,
            vec![
                ConfigKey::new("XAI_API_KEY", true, true, None),
                ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
            ],
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request(
            &self.model,
            system,
            messages,
            tools,
            &super::utils::ImageFormat::OpenAi,
        )?;

        let response = self.post(payload.clone()).await?;

        let message = response_to_message(response.clone())?;
        let usage = match get_usage(&response) {
            Ok(usage) => usage,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage data: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);
        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
        Ok((message, ProviderUsage::new(model, usage)))
    }
}
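Since `post` speaks the OpenAI-compatible `chat/completions` shape, exercising the provider end to end is straightforward. A sketch, assuming a tokio runtime, a valid `XAI_API_KEY` in goose's config, and empty message/tool slices just to show the call shape:

```rust
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::xai::XaiProvider;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let provider = XaiProvider::from_env(ModelConfig::new("grok-3-mini".to_string()))?;

    // complete() builds an OpenAI-style request and posts it to
    // XAI_HOST (default https://api.x.ai/v1) via post() above.
    let (_message, _usage) = provider
        .complete("You are a helpful assistant.", &[], &[])
        .await?;
    Ok(())
}
```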
@@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
 | [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
 | [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` |
 | [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) |
+| [xAI](https://x.ai/) | Access to xAI's Grok models, including grok-3, grok-3-mini, and grok-3-fast, with a 131,072-token context window. | `XAI_API_KEY`, `XAI_HOST` (optional) |

 ## Configure Provider
@@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
       ],
     },
   },
+  {
+    name: 'xAI',
+    details: {
+      id: 'xai',
+      name: 'xAI',
+      description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
+      parameters: [
+        {
+          name: 'XAI_API_KEY',
+          is_secret: true,
+        },
+        {
+          name: 'XAI_HOST',
+          is_secret: false,
+          default: 'https://api.x.ai/v1',
+        },
+      ],
+    },
+  },
   {
     name: 'Google',
     details: {
@@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
 import DatabricksLogo from './icons/databricks@3x.png';
 import OpenRouterLogo from './icons/openrouter@3x.png';
 import SnowflakeLogo from './icons/snowflake@3x.png';
+import XaiLogo from './icons/xai@3x.png';
 import DefaultLogo from './icons/default@3x.png';

 // Map provider names to their logos
@@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
   databricks: DatabricksLogo,
   openrouter: OpenRouterLogo,
   snowflake: SnowflakeLogo,
+  xai: XaiLogo,
   default: DefaultLogo,
 };
@@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
   const logoKey = providerName.toLowerCase();
   const logo = providerLogos[logoKey] || DefaultLogo;

+  // Special handling for xAI logo
+  const isXai = logoKey === 'xai';
+  const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};
+
+  // Use smaller size for xAI logo to fit better in circle
+  const imageClassName = isXai
+    ? 'w-8 h-8 object-contain' // Smaller size for xAI
+    : 'w-16 h-16 object-contain'; // Default size for others
+
   return (
     <div className="flex justify-center mb-2">
       <div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center">
-        <img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" />
+        <img
+          src={logo}
+          alt={`${providerName} logo`}
+          className={imageClassName}
+          style={imageStyle}
+        />
       </div>
     </div>
   );
Binary file not shown. (After: 3.7 KiB)