Damien/sagemaker tgi (#2888)

This commit is contained in:
damienrj
2025-06-13 11:35:42 -07:00
committed by GitHub
parent b1d2a02f66
commit a7ad73197d
6 changed files with 409 additions and 0 deletions

26
Cargo.lock generated
View File

@@ -704,6 +704,7 @@ dependencies = [
"aws-credential-types",
"aws-sigv4",
"aws-smithy-async",
"aws-smithy-eventstream",
"aws-smithy-http 0.60.12",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
@@ -767,6 +768,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "aws-sdk-sagemakerruntime"
version = "1.63.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3188bb9f962a9e1781c917dbe7f016ab9430e4bd81ba7daf422e58d86a3595"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-eventstream",
"aws-smithy-http 0.61.1",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"http 0.2.12",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sso"
version = "1.61.0"
@@ -841,6 +865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051"
dependencies = [
"aws-credential-types",
"aws-smithy-eventstream",
"aws-smithy-http 0.60.12",
"aws-smithy-runtime-api",
"aws-smithy-types",
@@ -3394,6 +3419,7 @@ dependencies = [
"async-trait",
"aws-config",
"aws-sdk-bedrockruntime",
"aws-sdk-sagemakerruntime",
"aws-smithy-types",
"axum",
"base64 0.21.7",

View File

@@ -68,6 +68,9 @@ aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
aws-smithy-types = "1.2.13"
aws-sdk-bedrockruntime = "1.74.0"
# For SageMaker TGI provider
aws-sdk-sagemakerruntime = "1.62.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"

View File

@@ -14,6 +14,7 @@ use super::{
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
sagemaker_tgi::SageMakerTgiProvider,
snowflake::SnowflakeProvider,
venice::VeniceProvider,
};
@@ -48,6 +49,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
OllamaProvider::metadata(),
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
SageMakerTgiProvider::metadata(),
VeniceProvider::metadata(),
SnowflakeProvider::metadata(),
]
@@ -122,6 +124,7 @@ fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>>
"openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
"sagemaker_tgi" => Ok(Arc::new(SageMakerTgiProvider::from_env(model)?)),
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),

View File

@@ -18,6 +18,7 @@ pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod sagemaker_tgi;
pub mod snowflake;
pub mod toolshim;
pub mod utils;

View File

@@ -0,0 +1,365 @@
use std::collections::HashMap;
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use aws_config;
use aws_sdk_bedrockruntime::config::ProvideCredentials;
use aws_sdk_sagemakerruntime::Client as SageMakerClient;
use mcp_core::Tool;
use serde_json::{json, Value};
use tokio::time::sleep;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::emit_debug_trace;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use chrono::Utc;
use mcp_core::content::TextContent;
use mcp_core::role::Role;
/// AWS documentation page for SageMaker real-time inference endpoints.
pub const SAGEMAKER_TGI_DOC_LINK: &str =
    "https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html";

/// Placeholder model name; the actual model is whatever is deployed
/// behind the configured SageMaker endpoint.
pub const SAGEMAKER_TGI_DEFAULT_MODEL: &str = "sagemaker-tgi-endpoint";

/// Provider that talks to a Text Generation Inference (TGI) model hosted
/// on an Amazon SageMaker real-time endpoint.
#[derive(Debug, serde::Serialize)]
pub struct SageMakerTgiProvider {
    /// AWS SDK client; not serializable, hence skipped.
    #[serde(skip)]
    sagemaker_client: SageMakerClient,
    /// SageMaker endpoint *name* (not a URL), read from SAGEMAKER_ENDPOINT_NAME.
    endpoint_name: String,
    /// Model configuration (max_tokens, temperature, ...) used to build requests.
    model: ModelConfig,
}
impl SageMakerTgiProvider {
/// Build the provider from goose config and the AWS environment.
///
/// Reads `SAGEMAKER_ENDPOINT_NAME` (the endpoint *name*, not a URL),
/// promotes AWS_-prefixed keys from goose config/secrets into process
/// environment variables so the AWS SDK default chain can see them,
/// validates that credentials resolve, and builds a client with a long
/// operation timeout to survive endpoint cold starts.
///
/// # Errors
/// Fails when `SAGEMAKER_ENDPOINT_NAME` is missing, when no credentials
/// provider is configured, or when credentials cannot be resolved.
pub fn from_env(model: ModelConfig) -> Result<Self> {
    let config = crate::config::Config::global();

    // Get SageMaker endpoint name (just the name, not full URL)
    let endpoint_name: String = config.get_param("SAGEMAKER_ENDPOINT_NAME").map_err(|_| {
        anyhow::anyhow!("SAGEMAKER_ENDPOINT_NAME is required for SageMaker TGI provider")
    })?;

    // Copy AWS_-prefixed string values from goose config/secrets into the
    // process environment so aws_config::load_from_env() picks them up.
    // NOTE(review): std::env::set_var is process-global — assumes provider
    // construction happens before threads that read the env; confirm.
    let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
        if let Ok(map) = res {
            map.into_iter()
                .filter(|(key, _)| key.starts_with("AWS_"))
                .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
                .for_each(|(key, s)| std::env::set_var(key, s));
        }
    };
    set_aws_env_vars(config.load_values());
    set_aws_env_vars(config.load_secrets());

    let aws_config = futures::executor::block_on(aws_config::load_from_env());

    // Validate credentials up front so a misconfigured environment fails
    // with a recoverable error. The previous `.unwrap()` on
    // `credentials_provider()` would panic the whole process when no
    // credentials provider was configured.
    let credentials_provider = aws_config.credentials_provider().ok_or_else(|| {
        anyhow::anyhow!("No AWS credentials provider configured for SageMaker TGI provider")
    })?;
    futures::executor::block_on(credentials_provider.provide_credentials())?;

    // Create client with longer timeout for model initialization
    let timeout_config = aws_config::timeout::TimeoutConfig::builder()
        .operation_timeout(Duration::from_secs(300)) // 5 minutes for cold starts
        .build();

    let config_with_timeout = aws_config
        .into_builder()
        .timeout_config(timeout_config)
        .build();

    let sagemaker_client = SageMakerClient::new(&config_with_timeout);

    Ok(Self {
        sagemaker_client,
        endpoint_name,
        model,
    })
}
/// Build the JSON payload for a TGI `generate` call from a plain-text
/// prompt assembled out of the system prompt and the last few messages.
///
/// Tool schemas are deliberately omitted: TGI models have no native tool
/// support and including the schemas makes them mimic that format.
fn create_tgi_request(&self, system: &str, messages: &[Message]) -> Result<Value> {
    // Forward the caller's system prompt only when it is short and free of
    // tool/HTML/markdown instructions; otherwise fall back to a minimal one.
    let system_is_usable = !system.is_empty()
        && system.len() < 200
        && !system.contains("Available tools")
        && !system.contains("HTML")
        && !system.contains("markdown");

    let mut prompt = if system_is_usable {
        format!("System: {}\n\n", system)
    } else {
        // Use a minimal system prompt for TGI that explicitly avoids HTML
        "System: You are a helpful AI assistant. Provide responses in plain text only. Do not use HTML tags, markup, or formatting.\n\n".to_string()
    };

    // Only the most recent three messages, kept in chronological order,
    // to avoid overwhelming the model.
    let start = messages.len().saturating_sub(3);
    for message in &messages[start..] {
        match &message.role {
            Role::User => {
                prompt.push_str("User: ");
                for content in &message.content {
                    if let MessageContent::Text(text) = content {
                        prompt.push_str(&text.text);
                    }
                }
                prompt.push_str("\n\n");
            }
            Role::Assistant => {
                prompt.push_str("Assistant: ");
                for content in &message.content {
                    if let MessageContent::Text(text) = content {
                        // Skip prior responses that look like tool
                        // descriptions or contain HTML.
                        let looks_suspect = text.text.contains("__")
                            || text.text.contains("Available tools")
                            || text.text.contains("<");
                        if !looks_suspect {
                            prompt.push_str(&text.text);
                        }
                    }
                }
                prompt.push_str("\n\n");
            }
        }
    }

    // Leave the prompt poised for the model's next turn.
    prompt.push_str("Assistant: ");

    Ok(json!({
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": self.model.max_tokens.unwrap_or(150),
            "temperature": self.model.temperature.unwrap_or(0.7),
            "do_sample": true,
            "return_full_text": false
        }
    }))
}
/// POST `payload` to the configured SageMaker endpoint and parse the
/// response body as JSON.
///
/// # Errors
/// Returns `ProviderError::RequestFailed` when the payload cannot be
/// serialized, the invoke call fails, the response body is empty, or the
/// body is not valid UTF-8 / valid JSON.
async fn invoke_endpoint(&self, payload: Value) -> Result<Value, ProviderError> {
    let body = serde_json::to_string(&payload).map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to serialize request: {}", e))
    })?;

    let response = self
        .sagemaker_client
        .invoke_endpoint()
        .endpoint_name(&self.endpoint_name)
        .content_type("application/json")
        // SDK takes the body as a Blob; Vec<u8> converts via Into.
        .body(body.into_bytes().into())
        .send()
        .await
        .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?;

    // The output's body is optional; treat a missing body as a failure.
    let response_body = response
        .body
        .as_ref()
        .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;

    let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
    })?;

    serde_json::from_str(response_text).map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
    })
}
fn parse_tgi_response(&self, response: Value) -> Result<Message, ProviderError> {
// Handle standard TGI response: [{"generated_text": "..."}]
let response_array = response
.as_array()
.ok_or_else(|| ProviderError::RequestFailed("Expected array response".to_string()))?;
if response_array.is_empty() {
return Err(ProviderError::RequestFailed(
"Empty response array".to_string(),
));
}
let first_result = &response_array[0];
let generated_text = first_result
.get("generated_text")
.and_then(|v| v.as_str())
.ok_or_else(|| {
ProviderError::RequestFailed("No generated_text in response".to_string())
})?;
// Strip any HTML tags that might have been generated
let clean_text = self.strip_html_tags(generated_text);
Ok(Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: clean_text,
annotations: None,
})],
})
}
/// Strip HTML tags from text to ensure clean output.
///
/// Two passes: first drop a fixed list of common formatting tags, then
/// remove every remaining `<...>` span in a single left-to-right sweep.
/// An unmatched `<` (no closing `>`) and everything after it is kept
/// verbatim, matching the original behavior.
///
/// The sweep builds the output string once; the previous implementation
/// called `replace_range` per tag, rescanning and shifting the string on
/// every iteration (quadratic on tag-heavy input).
fn strip_html_tags(&self, text: &str) -> String {
    // Remove common HTML tags like <b>, <i>, <strong>, <em>, etc.
    let tags_to_remove = [
        "<b>", "</b>", "<i>", "</i>", "<strong>", "</strong>", "<em>", "</em>",
        "<u>", "</u>", "<br>", "<br/>", "<p>", "</p>", "<div>", "</div>",
        "<span>", "</span>",
    ];
    let mut result = text.to_string();
    for tag in &tags_to_remove {
        result = result.replace(tag, "");
    }

    // Single pass: copy text outside <...> spans, skip text inside them.
    // This is a basic approach — for production use, consider a proper
    // HTML parser.
    let mut out = String::with_capacity(result.len());
    let mut rest = result.as_str();
    while let Some(start) = rest.find('<') {
        out.push_str(&rest[..start]);
        match rest[start..].find('>') {
            // Drop the "<...>" span (`end` is relative to `start`).
            Some(end) => rest = &rest[start + end + 1..],
            // No closing '>' — keep the remainder as-is.
            None => {
                rest = &rest[start..];
                break;
            }
        }
    }
    out.push_str(rest);
    out.trim().to_string()
}
}
impl Default for SageMakerTgiProvider {
    /// Construct the provider from the environment using the metadata
    /// default model.
    ///
    /// # Panics
    /// Panics if `from_env` fails (missing SAGEMAKER_ENDPOINT_NAME or AWS
    /// credentials) — presumably intended for the test harness rather than
    /// production paths; confirm against callers.
    fn default() -> Self {
        let model = ModelConfig::new(SageMakerTgiProvider::metadata().default_model);
        SageMakerTgiProvider::from_env(model).expect("Failed to initialize SageMaker TGI provider")
    }
}
#[async_trait]
impl Provider for SageMakerTgiProvider {
/// Registry metadata for this provider.
///
/// The user-facing description is fixed to say "endpoint name": the
/// provider reads SAGEMAKER_ENDPOINT_NAME (explicitly "just the name,
/// not full URL" in `from_env`), but the original text claimed a
/// "SageMaker endpoint URL", which would mislead users configuring it.
fn metadata() -> ProviderMetadata {
    ProviderMetadata::new(
        "sagemaker_tgi",
        "Amazon SageMaker TGI",
        "Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint name.",
        SAGEMAKER_TGI_DEFAULT_MODEL,
        vec![SAGEMAKER_TGI_DEFAULT_MODEL],
        SAGEMAKER_TGI_DOC_LINK,
        vec![
            // NOTE(review): ConfigKey flag semantics aren't visible here.
            // SAGEMAKER_ENDPOINT_NAME is mandatory in from_env yet its
            // second argument is `false` while optional AWS_REGION's is
            // `true` — confirm the (required, secret) argument order
            // against ConfigKey::new before changing.
            ConfigKey::new("SAGEMAKER_ENDPOINT_NAME", false, false, None),
            ConfigKey::new("AWS_REGION", true, false, Some("us-east-1")),
            ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
        ],
    )
}
/// Return a copy of this provider's model configuration.
fn get_model_config(&self) -> ModelConfig {
    self.model.clone()
}
#[tracing::instrument(
    skip(self, system, messages, tools),
    fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
/// Run one completion against the SageMaker endpoint with exponential
/// backoff: one initial attempt plus up to MAX_RETRIES retries.
async fn complete(
    &self,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
    // Retry configuration
    const MAX_RETRIES: u32 = 3;
    const INITIAL_BACKOFF_MS: u64 = 1000; // 1 second
    const MAX_BACKOFF_MS: u64 = 30000; // 30 seconds

    let model_name = &self.model.model_name;
    let request_payload = self.create_tgi_request(system, messages).map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to create request: {}", e))
    })?;

    let mut backoff_ms = INITIAL_BACKOFF_MS;
    for attempts in 1..=(MAX_RETRIES + 1) {
        match self.invoke_endpoint(request_payload.clone()).await {
            Ok(response) => {
                let message = self.parse_tgi_response(response)?;

                // TGI doesn't provide usage statistics, so we estimate
                let usage = Usage {
                    input_tokens: Some(0), // Would need to tokenize input to get accurate count
                    output_tokens: Some(0), // Would need to tokenize output to get accurate count
                    total_tokens: Some(0),
                };

                // Add debug trace
                let debug_payload = serde_json::json!({
                    "system": system,
                    "messages": messages,
                    "tools": tools
                });
                emit_debug_trace(
                    &self.model,
                    &debug_payload,
                    &serde_json::to_value(&message).unwrap_or_default(),
                    &usage,
                );

                return Ok((message, ProviderUsage::new(model_name.to_string(), usage)));
            }
            Err(err) => {
                // Final attempt exhausted — surface the last error.
                if attempts > MAX_RETRIES {
                    return Err(err);
                }

                // Log retry attempt
                tracing::warn!(
                    "SageMaker TGI request failed (attempt {}/{}), retrying in {} ms: {:?}",
                    attempts,
                    MAX_RETRIES,
                    backoff_ms,
                    err
                );

                // Wait before retry
                sleep(Duration::from_millis(backoff_ms)).await;
                backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
            }
        }
    }

    // Every loop iteration returns: Ok on success, Err on the final attempt.
    unreachable!("retry loop exits via return")
}
}

View File

@@ -490,6 +490,17 @@ async fn test_snowflake_provider() -> Result<()> {
.await
}
// Smoke-test the SageMaker TGI provider through the shared test_provider
// harness. Presumably gated on SAGEMAKER_ENDPOINT_NAME being present —
// confirm against test_provider's handling of the required-keys slice.
#[tokio::test]
async fn test_sagemaker_tgi_provider() -> Result<()> {
    test_provider(
        "SageMakerTgi",
        &["SAGEMAKER_ENDPOINT_NAME"],
        None,
        goose::providers::sagemaker_tgi::SageMakerTgiProvider::default,
    )
    .await
}
// Print the final test report
#[ctor::dtor]
fn print_test_report() {