Add xAI Provider Support for Grok Models (#2976)
Co-authored-by: jack <jack@deck.local>
.gitignore (vendored) | 1 +

@@ -12,6 +12,7 @@ tmp/
 # will have compiled files and executables
 debug/
 target/
+.goose/
 
 # These are backup files generated by rustfmt
 **/*.rs.bk
@@ -30,6 +30,9 @@ static MODEL_SPECIFIC_LIMITS: Lazy<HashMap<&'static str, usize>> = Lazy::new(||
     // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
     map.insert("llama3.2", 128_000);
     map.insert("llama3.3", 128_000);
 
+    // x.ai Grok models, https://docs.x.ai/docs/overview
+    map.insert("grok", 131_072);
+
     map
 });
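The new "grok" key is shorter than any concrete model name (grok-3, grok-3-mini-fast-latest, ...), so the table is presumably consulted by substring or prefix match rather than exact lookup. Below is a minimal sketch of that kind of consumer, assuming substring matching; the real lookup in goose lives outside this diff and may differ.

use std::collections::HashMap;

// Hypothetical helper for illustration only: resolves a model's context limit
// by checking whether the model name contains one of the table's keys, so
// "grok-3-fast-latest" resolves via the short "grok" entry.
fn model_specific_limit(model_name: &str, limits: &HashMap<&'static str, usize>) -> Option<usize> {
    limits
        .iter()
        .find(|(key, _)| model_name.contains(**key))
        .map(|(_, limit)| *limit)
}

fn main() {
    let mut map: HashMap<&'static str, usize> = HashMap::new();
    map.insert("llama3.2", 128_000);
    map.insert("grok", 131_072);
    assert_eq!(model_specific_limit("grok-3-fast-latest", &map), Some(131_072));
}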
@@ -17,6 +17,7 @@ use super::{
     sagemaker_tgi::SageMakerTgiProvider,
     snowflake::SnowflakeProvider,
     venice::VeniceProvider,
+    xai::XaiProvider,
 };
 use crate::model::ModelConfig;
 use anyhow::Result;

@@ -52,6 +53,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         SageMakerTgiProvider::metadata(),
         VeniceProvider::metadata(),
         SnowflakeProvider::metadata(),
+        XaiProvider::metadata(),
     ]
 }
 
@@ -128,6 +130,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
         "venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
         "snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
         "github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
+        "xai" => Ok(Arc::new(XaiProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
 }

@@ -259,7 +262,7 @@ mod tests {
         }
 
         // Set only the required lead model
-        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
+        env::set_var("GOOSE_LEAD_MODEL", "grok-3");
 
         // This should use defaults for all other values
         let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
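Taken together, these hunks register the provider end to end: the import, the metadata entry that surfaces xAI in provider discovery, and the create_provider arm that constructs it from the environment. A rough caller sketch follows, assuming the crate is imported as `goose`, that `create` is the re-export visible in the next hunk, and that `XAI_API_KEY` is already in goose's config store; anything not shown in the hunks above is illustrative.

use goose::model::ModelConfig;
use goose::providers::create;

fn main() -> anyhow::Result<()> {
    // Dispatches through the new "xai" match arm to XaiProvider::from_env;
    // this returns an error if XAI_API_KEY has not been configured.
    let provider = create("xai", ModelConfig::new("grok-3".to_string()))?;
    println!("configured model: {}", provider.get_model_config().model_name);
    Ok(())
}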
@@ -24,5 +24,6 @@ pub mod toolshim;
 pub mod utils;
 pub mod utils_universal_openai_stream;
 pub mod venice;
+pub mod xai;
 
 pub use factory::{create, providers};
crates/goose/src/providers/xai.rs (new file, 181 lines):

use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use crate::providers::utils::get_model;
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::Tool;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use url::Url;

pub const XAI_API_HOST: &str = "https://api.x.ai/v1";
pub const XAI_DEFAULT_MODEL: &str = "grok-3";
pub const XAI_KNOWN_MODELS: &[&str] = &[
    "grok-3",
    "grok-3-fast",
    "grok-3-mini",
    "grok-3-mini-fast",
    "grok-2-vision-1212",
    "grok-2-image-1212",
    "grok-2-1212",
    "grok-3-latest",
    "grok-3-fast-latest",
    "grok-3-mini-latest",
    "grok-3-mini-fast-latest",
    "grok-2-vision",
    "grok-2-vision-latest",
    "grok-2-image",
    "grok-2-image-latest",
    "grok-2",
    "grok-2-latest",
];

pub const XAI_DOC_URL: &str = "https://docs.x.ai/docs/overview";

#[derive(serde::Serialize)]
pub struct XaiProvider {
    #[serde(skip)]
    client: Client,
    host: String,
    api_key: String,
    model: ModelConfig,
}

impl Default for XaiProvider {
    fn default() -> Self {
        let model = ModelConfig::new(XaiProvider::metadata().default_model);
        XaiProvider::from_env(model).expect("Failed to initialize xAI provider")
    }
}

impl XaiProvider {
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let api_key: String = config.get_secret("XAI_API_KEY")?;
        let host: String = config
            .get_param("XAI_HOST")
            .unwrap_or_else(|_| XAI_API_HOST.to_string());

        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(Self {
            client,
            host,
            api_key,
            model,
        })
    }

    async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
        // Ensure the host ends with a slash for proper URL joining
        let host = if self.host.ends_with('/') {
            self.host.clone()
        } else {
            format!("{}/", self.host)
        };
        let base_url = Url::parse(&host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let url = base_url.join("chat/completions").map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        tracing::debug!("xAI API URL: {}", url);
        tracing::debug!("xAI request model: {:?}", self.model.model_name);

        let response = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()
            .await?;

        let status = response.status();
        let payload: Option<Value> = response.json().await.ok();

        match status {
            StatusCode::OK => payload.ok_or_else(|| {
                ProviderError::RequestFailed("Response body is not valid JSON".to_string())
            }),
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                Err(ProviderError::Authentication(format!(
                    "Authentication failed. Please ensure your API keys are valid and have the required permissions. \
                     Status: {}. Response: {:?}",
                    status, payload
                )))
            }
            StatusCode::PAYLOAD_TOO_LARGE => {
                Err(ProviderError::ContextLengthExceeded(format!("{:?}", payload)))
            }
            StatusCode::TOO_MANY_REQUESTS => {
                Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
            }
            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
                Err(ProviderError::ServerError(format!("{:?}", payload)))
            }
            _ => {
                tracing::debug!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status, payload
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status: {}",
                    status
                )))
            }
        }
    }
}

#[async_trait]
impl Provider for XaiProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "xai",
            "xAI",
            "Grok models from xAI, including reasoning and multimodal capabilities",
            XAI_DEFAULT_MODEL,
            XAI_KNOWN_MODELS.to_vec(),
            XAI_DOC_URL,
            vec![
                ConfigKey::new("XAI_API_KEY", true, true, None),
                ConfigKey::new("XAI_HOST", false, false, Some(XAI_API_HOST)),
            ],
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request(
            &self.model,
            system,
            messages,
            tools,
            &super::utils::ImageFormat::OpenAi,
        )?;

        let response = self.post(payload.clone()).await?;

        let message = response_to_message(response.clone())?;
        let usage = match get_usage(&response) {
            Ok(usage) => usage,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage data: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);
        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
        Ok((message, ProviderUsage::new(model, usage)))
    }
}
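The provider is a thin wrapper over goose's OpenAI-compatible plumbing: create_request builds the payload, post() sends it to {XAI_HOST}/chat/completions with a bearer token, and response_to_message/get_usage unpack the reply. A rough end-to-end sketch under stated assumptions: a tokio runtime, XAI_API_KEY already configured, and a Message::user().with_text(...) builder matching the one goose uses elsewhere (treat that call as illustrative).

use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::xai::XaiProvider;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // XAI_HOST falls back to https://api.x.ai/v1 when unset.
    let provider = XaiProvider::from_env(ModelConfig::new("grok-3-mini".to_string()))?;

    // Illustrative message construction; swap in goose's actual builder if it differs.
    let messages = vec![Message::user().with_text("Say hello in one sentence.")];

    // No tools for this call; `reply` is assembled by response_to_message and
    // `usage` wraps the token counts from get_usage (Usage::default() if absent).
    let (reply, usage) = provider
        .complete("You are a helpful assistant.", &messages, &[])
        .await?;
    let _ = (reply, usage);
    Ok(())
}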
@@ -33,6 +33,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
 | [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
 | [Snowflake](https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#choosing-a-model) | Access the latest models using Snowflake Cortex services, including Claude models. **Requires a Snowflake account and programmatic access token (PAT)**. | `SNOWFLAKE_HOST`, `SNOWFLAKE_TOKEN` |
 | [Venice AI](https://venice.ai/home) | Provides access to open source models like Llama, Mistral, and Qwen while prioritizing user privacy. **Requires an account and an [API key](https://docs.venice.ai/overview/guides/generating-api-key)**. | `VENICE_API_KEY`, `VENICE_HOST` (optional), `VENICE_BASE_PATH` (optional), `VENICE_MODELS_PATH` (optional) |
+| [xAI](https://x.ai/) | Access to xAI's Grok models, including grok-3, grok-3-mini, and grok-3-fast, with a 131,072-token context window. | `XAI_API_KEY`, `XAI_HOST` (optional) |
 
 ## Configure Provider
@@ -97,6 +97,25 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
       ],
     },
   },
+  {
+    name: 'xAI',
+    details: {
+      id: 'xai',
+      name: 'xAI',
+      description: 'Access Grok models from xAI, including reasoning and multimodal capabilities',
+      parameters: [
+        {
+          name: 'XAI_API_KEY',
+          is_secret: true,
+        },
+        {
+          name: 'XAI_HOST',
+          is_secret: false,
+          default: 'https://api.x.ai/v1',
+        },
+      ],
+    },
+  },
   {
     name: 'Google',
     details: {
@@ -6,6 +6,7 @@ import OllamaLogo from './icons/ollama@3x.png';
 import DatabricksLogo from './icons/databricks@3x.png';
 import OpenRouterLogo from './icons/openrouter@3x.png';
 import SnowflakeLogo from './icons/snowflake@3x.png';
+import XaiLogo from './icons/xai@3x.png';
 import DefaultLogo from './icons/default@3x.png';
 
 // Map provider names to their logos

@@ -18,6 +19,7 @@ const providerLogos: Record<string, string> = {
   databricks: DatabricksLogo,
   openrouter: OpenRouterLogo,
   snowflake: SnowflakeLogo,
+  xai: XaiLogo,
   default: DefaultLogo,
 };
 
@@ -30,10 +32,24 @@ export default function ProviderLogo({ providerName }: ProviderLogoProps) {
   const logoKey = providerName.toLowerCase();
   const logo = providerLogos[logoKey] || DefaultLogo;
 
+  // Special handling for xAI logo
+  const isXai = logoKey === 'xai';
+  const imageStyle = isXai ? { filter: 'invert(1)', opacity: 0.9 } : {};
+
+  // Use smaller size for xAI logo to fit better in circle
+  const imageClassName = isXai
+    ? 'w-8 h-8 object-contain' // Smaller size for xAI
+    : 'w-16 h-16 object-contain'; // Default size for others
+
   return (
     <div className="flex justify-center mb-2">
       <div className="w-12 h-12 bg-black rounded-full overflow-hidden flex items-center justify-center">
-        <img src={logo} alt={`${providerName} logo`} className="w-16 h-16 object-contain" />
+        <img
+          src={logo}
+          alt={`${providerName} logo`}
+          className={imageClassName}
+          style={imageStyle}
+        />
       </div>
     </div>
   );
icons/xai@3x.png: new binary file, 3.7 KiB (the XaiLogo asset imported above; not shown)