mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
Damien/sagemaker tgi (#2888)
This commit is contained in:
26
Cargo.lock
generated
26
Cargo.lock
generated
@@ -704,6 +704,7 @@ dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws-sigv4",
|
||||
"aws-smithy-async",
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-http 0.60.12",
|
||||
"aws-smithy-runtime",
|
||||
"aws-smithy-runtime-api",
|
||||
@@ -767,6 +768,29 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-sagemakerruntime"
|
||||
version = "1.63.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3188bb9f962a9e1781c917dbe7f016ab9430e4bd81ba7daf422e58d86a3595"
|
||||
dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws-runtime",
|
||||
"aws-smithy-async",
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-http 0.61.1",
|
||||
"aws-smithy-json",
|
||||
"aws-smithy-runtime",
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"aws-types",
|
||||
"bytes",
|
||||
"http 0.2.12",
|
||||
"once_cell",
|
||||
"regex-lite",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-sso"
|
||||
version = "1.61.0"
|
||||
@@ -841,6 +865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051"
|
||||
dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-http 0.60.12",
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
@@ -3394,6 +3419,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"aws-config",
|
||||
"aws-sdk-bedrockruntime",
|
||||
"aws-sdk-sagemakerruntime",
|
||||
"aws-smithy-types",
|
||||
"axum",
|
||||
"base64 0.21.7",
|
||||
|
||||
@@ -68,6 +68,9 @@ aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
|
||||
aws-smithy-types = "1.2.13"
|
||||
aws-sdk-bedrockruntime = "1.74.0"
|
||||
|
||||
# For SageMaker TGI provider
|
||||
aws-sdk-sagemakerruntime = "1.62.0"
|
||||
|
||||
# For GCP Vertex AI provider auth
|
||||
jsonwebtoken = "9.3.1"
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ use super::{
|
||||
ollama::OllamaProvider,
|
||||
openai::OpenAiProvider,
|
||||
openrouter::OpenRouterProvider,
|
||||
sagemaker_tgi::SageMakerTgiProvider,
|
||||
snowflake::SnowflakeProvider,
|
||||
venice::VeniceProvider,
|
||||
};
|
||||
@@ -48,6 +49,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
||||
OllamaProvider::metadata(),
|
||||
OpenAiProvider::metadata(),
|
||||
OpenRouterProvider::metadata(),
|
||||
SageMakerTgiProvider::metadata(),
|
||||
VeniceProvider::metadata(),
|
||||
SnowflakeProvider::metadata(),
|
||||
]
|
||||
@@ -122,6 +124,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
|
||||
"openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
|
||||
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
|
||||
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
|
||||
"sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)),
|
||||
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
|
||||
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
|
||||
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
|
||||
|
||||
@@ -18,6 +18,7 @@ pub mod oauth;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openrouter;
|
||||
pub mod sagemaker_tgi;
|
||||
pub mod snowflake;
|
||||
pub mod toolshim;
|
||||
pub mod utils;
|
||||
|
||||
365
crates/goose/src/providers/sagemaker_tgi.rs
Normal file
365
crates/goose/src/providers/sagemaker_tgi.rs
Normal file
@@ -0,0 +1,365 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use aws_config;
|
||||
use aws_sdk_bedrockruntime::config::ProvideCredentials;
|
||||
use aws_sdk_sagemakerruntime::Client as SageMakerClient;
|
||||
use mcp_core::Tool;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use super::errors::ProviderError;
|
||||
use super::utils::emit_debug_trace;
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::model::ModelConfig;
|
||||
use chrono::Utc;
|
||||
use mcp_core::content::TextContent;
|
||||
use mcp_core::role::Role;
|
||||
|
||||
pub const SAGEMAKER_TGI_DOC_LINK: &str =
|
||||
"https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html";
|
||||
|
||||
pub const SAGEMAKER_TGI_DEFAULT_MODEL: &str = "sagemaker-tgi-endpoint";
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct SageMakerTgiProvider {
|
||||
#[serde(skip)]
|
||||
sagemaker_client: SageMakerClient,
|
||||
endpoint_name: String,
|
||||
model: ModelConfig,
|
||||
}
|
||||
|
||||
impl SageMakerTgiProvider {
|
||||
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
||||
let config = crate::config::Config::global();
|
||||
|
||||
// Get SageMaker endpoint name (just the name, not full URL)
|
||||
let endpoint_name: String = config.get_param("SAGEMAKER_ENDPOINT_NAME").map_err(|_| {
|
||||
anyhow::anyhow!("SAGEMAKER_ENDPOINT_NAME is required for SageMaker TGI provider")
|
||||
})?;
|
||||
|
||||
// Attempt to load config and secrets to get AWS_ prefixed keys
|
||||
let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
|
||||
if let Ok(map) = res {
|
||||
map.into_iter()
|
||||
.filter(|(key, _)| key.starts_with("AWS_"))
|
||||
.filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
|
||||
.for_each(|(key, s)| std::env::set_var(key, s));
|
||||
}
|
||||
};
|
||||
|
||||
set_aws_env_vars(config.load_values());
|
||||
set_aws_env_vars(config.load_secrets());
|
||||
|
||||
let aws_config = futures::executor::block_on(aws_config::load_from_env());
|
||||
|
||||
// Validate credentials
|
||||
futures::executor::block_on(
|
||||
aws_config
|
||||
.credentials_provider()
|
||||
.unwrap()
|
||||
.provide_credentials(),
|
||||
)?;
|
||||
|
||||
// Create client with longer timeout for model initialization
|
||||
let timeout_config = aws_config::timeout::TimeoutConfig::builder()
|
||||
.operation_timeout(Duration::from_secs(300)) // 5 minutes for cold starts
|
||||
.build();
|
||||
|
||||
let config_with_timeout = aws_config
|
||||
.into_builder()
|
||||
.timeout_config(timeout_config)
|
||||
.build();
|
||||
|
||||
let sagemaker_client = SageMakerClient::new(&config_with_timeout);
|
||||
|
||||
Ok(Self {
|
||||
sagemaker_client,
|
||||
endpoint_name,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
fn create_tgi_request(&self, system: &str, messages: &[Message]) -> Result<Value> {
|
||||
// Create a simplified prompt for TGI models using recent user and assistant messages.
|
||||
// Uses a minimal system prompt and avoids HTML or tool-related formatting.
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Use a very simple system prompt if provided, but ensure it doesn't contain HTML instructions
|
||||
if !system.is_empty()
|
||||
&& !system.contains("Available tools")
|
||||
&& system.len() < 200
|
||||
&& !system.contains("HTML")
|
||||
&& !system.contains("markdown")
|
||||
{
|
||||
prompt.push_str(&format!("System: {}\n\n", system));
|
||||
} else {
|
||||
// Use a minimal system prompt for TGI that explicitly avoids HTML
|
||||
prompt.push_str("System: You are a helpful AI assistant. Provide responses in plain text only. Do not use HTML tags, markup, or formatting.\n\n");
|
||||
}
|
||||
|
||||
// Only include the most recent user messages to avoid overwhelming the model
|
||||
let recent_messages: Vec<_> = messages.iter().rev().take(3).collect();
|
||||
for message in recent_messages.iter().rev() {
|
||||
match &message.role {
|
||||
Role::User => {
|
||||
prompt.push_str("User: ");
|
||||
for content in &message.content {
|
||||
if let MessageContent::Text(text) = content {
|
||||
prompt.push_str(&text.text);
|
||||
}
|
||||
}
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
Role::Assistant => {
|
||||
prompt.push_str("Assistant: ");
|
||||
for content in &message.content {
|
||||
if let MessageContent::Text(text) = content {
|
||||
// Skip responses that look like tool descriptions or contain HTML
|
||||
if !text.text.contains("__")
|
||||
&& !text.text.contains("Available tools")
|
||||
&& !text.text.contains("<")
|
||||
{
|
||||
prompt.push_str(&text.text);
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Skip tool descriptions entirely for TGI models to avoid confusion
|
||||
// TGI models don't support tools natively and including tool descriptions
|
||||
// causes them to mimic that format in their responses
|
||||
|
||||
// Build TGI request with reasonable parameters
|
||||
let request = json!({
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": self.model.max_tokens.unwrap_or(150),
|
||||
"temperature": self.model.temperature.unwrap_or(0.7),
|
||||
"do_sample": true,
|
||||
"return_full_text": false
|
||||
}
|
||||
});
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn invoke_endpoint(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
let body = serde_json::to_string(&payload).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to serialize request: {}", e))
|
||||
})?;
|
||||
|
||||
let response = self
|
||||
.sagemaker_client
|
||||
.invoke_endpoint()
|
||||
.endpoint_name(&self.endpoint_name)
|
||||
.content_type("application/json")
|
||||
.body(body.into_bytes().into())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?;
|
||||
|
||||
let response_body = response
|
||||
.body
|
||||
.as_ref()
|
||||
.ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;
|
||||
let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
|
||||
})?;
|
||||
|
||||
serde_json::from_str(response_text).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_tgi_response(&self, response: Value) -> Result<Message, ProviderError> {
|
||||
// Handle standard TGI response: [{"generated_text": "..."}]
|
||||
let response_array = response
|
||||
.as_array()
|
||||
.ok_or_else(|| ProviderError::RequestFailed("Expected array response".to_string()))?;
|
||||
|
||||
if response_array.is_empty() {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"Empty response array".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let first_result = &response_array[0];
|
||||
let generated_text = first_result
|
||||
.get("generated_text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
ProviderError::RequestFailed("No generated_text in response".to_string())
|
||||
})?;
|
||||
|
||||
// Strip any HTML tags that might have been generated
|
||||
let clean_text = self.strip_html_tags(generated_text);
|
||||
|
||||
Ok(Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
text: clean_text,
|
||||
annotations: None,
|
||||
})],
|
||||
})
|
||||
}
|
||||
|
||||
/// Strip HTML tags from text to ensure clean output
|
||||
fn strip_html_tags(&self, text: &str) -> String {
|
||||
// Simple regex-free approach to strip common HTML tags
|
||||
let mut result = text.to_string();
|
||||
|
||||
// Remove common HTML tags like <b>, <i>, <strong>, <em>, etc.
|
||||
let tags_to_remove = [
|
||||
"<b>",
|
||||
"</b>",
|
||||
"<i>",
|
||||
"</i>",
|
||||
"<strong>",
|
||||
"</strong>",
|
||||
"<em>",
|
||||
"</em>",
|
||||
"<u>",
|
||||
"</u>",
|
||||
"<br>",
|
||||
"<br/>",
|
||||
"<p>",
|
||||
"</p>",
|
||||
"<div>",
|
||||
"</div>",
|
||||
"<span>",
|
||||
"</span>",
|
||||
];
|
||||
|
||||
for tag in &tags_to_remove {
|
||||
result = result.replace(tag, "");
|
||||
}
|
||||
|
||||
// Remove any remaining HTML-like tags using a simple pattern
|
||||
// This is a basic implementation - for production use, consider using a proper HTML parser
|
||||
while let Some(start) = result.find('<') {
|
||||
if let Some(end) = result[start..].find('>') {
|
||||
result.replace_range(start..start + end + 1, "");
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result.trim().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SageMakerTgiProvider {
|
||||
fn default() -> Self {
|
||||
let model = ModelConfig::new(SageMakerTgiProvider::metadata().default_model);
|
||||
SageMakerTgiProvider::from_env(model).expect("Failed to initialize SageMaker TGI provider")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for SageMakerTgiProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::new(
|
||||
"sagemaker_tgi",
|
||||
"Amazon SageMaker TGI",
|
||||
"Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint URL.",
|
||||
SAGEMAKER_TGI_DEFAULT_MODEL,
|
||||
vec![SAGEMAKER_TGI_DEFAULT_MODEL],
|
||||
SAGEMAKER_TGI_DOC_LINK,
|
||||
vec![
|
||||
ConfigKey::new("SAGEMAKER_ENDPOINT_NAME", false, false, None),
|
||||
ConfigKey::new("AWS_REGION", true, false, Some("us-east-1")),
|
||||
ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model.clone()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip(self, system, messages, tools),
|
||||
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
|
||||
)]
|
||||
async fn complete(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let model_name = &self.model.model_name;
|
||||
|
||||
let request_payload = self.create_tgi_request(system, messages).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to create request: {}", e))
|
||||
})?;
|
||||
|
||||
// Retry configuration
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const INITIAL_BACKOFF_MS: u64 = 1000; // 1 second
|
||||
const MAX_BACKOFF_MS: u64 = 30000; // 30 seconds
|
||||
|
||||
let mut attempts = 0;
|
||||
let mut backoff_ms = INITIAL_BACKOFF_MS;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
|
||||
match self.invoke_endpoint(request_payload.clone()).await {
|
||||
Ok(response) => {
|
||||
let message = self.parse_tgi_response(response)?;
|
||||
|
||||
// TGI doesn't provide usage statistics, so we estimate
|
||||
let usage = Usage {
|
||||
input_tokens: Some(0), // Would need to tokenize input to get accurate count
|
||||
output_tokens: Some(0), // Would need to tokenize output to get accurate count
|
||||
total_tokens: Some(0),
|
||||
};
|
||||
|
||||
// Add debug trace
|
||||
let debug_payload = serde_json::json!({
|
||||
"system": system,
|
||||
"messages": messages,
|
||||
"tools": tools
|
||||
});
|
||||
emit_debug_trace(
|
||||
&self.model,
|
||||
&debug_payload,
|
||||
&serde_json::to_value(&message).unwrap_or_default(),
|
||||
&usage,
|
||||
);
|
||||
|
||||
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
||||
return Ok((message, provider_usage));
|
||||
}
|
||||
Err(err) => {
|
||||
if attempts > MAX_RETRIES {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Log retry attempt
|
||||
tracing::warn!(
|
||||
"SageMaker TGI request failed (attempt {}/{}), retrying in {} ms: {:?}",
|
||||
attempts,
|
||||
MAX_RETRIES,
|
||||
backoff_ms,
|
||||
err
|
||||
);
|
||||
|
||||
// Wait before retry
|
||||
sleep(Duration::from_millis(backoff_ms)).await;
|
||||
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -490,6 +490,17 @@ async fn test_snowflake_provider() -> Result<()> {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sagemaker_tgi_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"SageMakerTgi",
|
||||
&["SAGEMAKER_ENDPOINT_NAME"],
|
||||
None,
|
||||
goose::providers::sagemaker_tgi::SageMakerTgiProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Print the final test report
|
||||
#[ctor::dtor]
|
||||
fn print_test_report() {
|
||||
|
||||
Reference in New Issue
Block a user