mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-07 07:34:29 +01:00
chore: implement streaming for anthropic.rs firstparty provider (#3419)
This commit is contained in:
@@ -1,13 +1,21 @@
|
||||
use anyhow::Result;
|
||||
use async_stream::try_stream;
|
||||
use async_trait::async_trait;
|
||||
use axum::http::HeaderMap;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde_json::Value;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use tokio::pin;
|
||||
|
||||
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
|
||||
use super::errors::ProviderError;
|
||||
use super::formats::anthropic::{create_request, get_usage, response_to_message};
|
||||
use super::formats::anthropic::{
|
||||
create_request, get_usage, response_to_message, response_to_streaming_message,
|
||||
};
|
||||
use super::utils::{emit_debug_trace, get_model};
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
@@ -195,10 +203,17 @@ impl Provider for AnthropicProvider {
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let usage = get_usage(&response)?;
|
||||
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
usage.input_tokens, usage.output_tokens, usage.total_tokens);
|
||||
|
||||
let model = get_model(&response);
|
||||
emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||
Ok((message, ProviderUsage::new(model, usage)))
|
||||
let provider_usage = ProviderUsage::new(model, usage);
|
||||
tracing::debug!(
|
||||
"🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
|
||||
provider_usage
|
||||
);
|
||||
Ok((message, provider_usage))
|
||||
}
|
||||
|
||||
/// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present
|
||||
@@ -232,4 +247,82 @@ impl Provider for AnthropicProvider {
|
||||
models.sort();
|
||||
Ok(Some(models))
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let mut payload = create_request(&self.model, system, messages, tools)?;
|
||||
|
||||
// Add stream parameter
|
||||
payload
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("stream".to_string(), Value::Bool(true));
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("x-api-key", self.api_key.parse().unwrap());
|
||||
headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap());
|
||||
|
||||
let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok();
|
||||
if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled {
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta
|
||||
headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap());
|
||||
}
|
||||
|
||||
if self.model.model_name.starts_with("claude-3-7-sonnet-") {
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
|
||||
headers.insert(
|
||||
"anthropic-beta",
|
||||
"token-efficient-tools-2025-02-19".parse().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
let base_url = url::Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join("v1/messages").map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
|
||||
})?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.headers(headers)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(ProviderError::RequestFailed(format!(
|
||||
"Streaming request failed with status: {}. Error: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
// Map reqwest error to io::Error
|
||||
let stream = response.bytes_stream().map_err(io::Error::other);
|
||||
|
||||
let model_config = self.model.clone();
|
||||
// Wrap in a line decoder and yield lines inside the stream
|
||||
Ok(Box::pin(try_stream! {
|
||||
let stream_reader = StreamReader::new(stream);
|
||||
let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from);
|
||||
|
||||
let message_stream = response_to_streaming_message(framed);
|
||||
pin!(message_stream);
|
||||
while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
|
||||
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
|
||||
super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
|
||||
yield (message, usage);
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,40 +292,68 @@ pub fn get_usage(data: &Value) -> Result<Usage> {
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// IMPORTANT: Based on the API responses, when caching is used:
|
||||
// - input_tokens is ONLY the new/fresh tokens (can be very small, like 7)
|
||||
// - cache_creation_input_tokens and cache_read_input_tokens are the cached content
|
||||
// - These cached tokens are charged at different rates:
|
||||
// * Fresh input tokens: 100% of regular price
|
||||
// * Cache creation tokens: 125% of regular price
|
||||
// * Cache read tokens: 10% of regular price
|
||||
//
|
||||
// Calculate effective input tokens for cost calculation based on Anthropic's pricing:
|
||||
// - Fresh input tokens: 100% of regular price (1.0x)
|
||||
// - Cache creation tokens: 125% of regular price (1.25x)
|
||||
// - Cache read tokens: 10% of regular price (0.10x)
|
||||
//
|
||||
// The effective input tokens represent the cost-equivalent tokens when multiplied
|
||||
// by the regular input price, ensuring accurate cost calculations in the frontend.
|
||||
let effective_input_tokens = input_tokens as f64 * 1.0
|
||||
+ cache_creation_tokens as f64 * 1.25
|
||||
+ cache_read_tokens as f64 * 0.10;
|
||||
// IMPORTANT: For display purposes, we want to show the ACTUAL total tokens consumed
|
||||
// The cache pricing should only affect cost calculation, not token count display
|
||||
let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
|
||||
|
||||
// For token counting purposes, we still want to show the actual total count
|
||||
let _total_actual_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
|
||||
|
||||
// Return the effective input tokens for cost calculation
|
||||
// This ensures the frontend cost calculation is accurate when multiplying by regular prices
|
||||
let effective_input_i32 = effective_input_tokens.round().clamp(0.0, i32::MAX as f64) as i32;
|
||||
// Convert to i32 with bounds checking
|
||||
let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32;
|
||||
let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32;
|
||||
let total_tokens_i32 =
|
||||
(effective_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
|
||||
(total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
|
||||
|
||||
Ok(Usage::new(
|
||||
Some(effective_input_i32),
|
||||
Some(total_input_i32),
|
||||
Some(output_tokens_i32),
|
||||
Some(total_tokens_i32),
|
||||
))
|
||||
} else if data.as_object().is_some() {
|
||||
// Check if the data itself is the usage object (for message_delta events that might have usage at top level)
|
||||
let input_tokens = data
|
||||
.get("input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let cache_creation_tokens = data
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let cache_read_tokens = data
|
||||
.get("cache_read_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let output_tokens = data
|
||||
.get("output_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// If we found any token data, process it
|
||||
if input_tokens > 0
|
||||
|| cache_creation_tokens > 0
|
||||
|| cache_read_tokens > 0
|
||||
|| output_tokens > 0
|
||||
{
|
||||
let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
|
||||
|
||||
let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32;
|
||||
let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32;
|
||||
let total_tokens_i32 =
|
||||
(total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
|
||||
|
||||
tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}",
|
||||
total_input_i32, output_tokens_i32, total_tokens_i32);
|
||||
|
||||
Ok(Usage::new(
|
||||
Some(total_input_i32),
|
||||
Some(output_tokens_i32),
|
||||
Some(total_tokens_i32),
|
||||
))
|
||||
} else {
|
||||
tracing::debug!("🔍 Anthropic no token data found in object");
|
||||
Ok(Usage::new(None, None, None))
|
||||
}
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"Failed to get usage data: {}",
|
||||
@@ -414,6 +442,232 @@ pub fn create_request(
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
/// Process streaming response from Anthropic's API
|
||||
pub fn response_to_streaming_message<S>(
|
||||
mut stream: S,
|
||||
) -> impl futures::Stream<
|
||||
Item = anyhow::Result<(
|
||||
Option<Message>,
|
||||
Option<crate::providers::base::ProviderUsage>,
|
||||
)>,
|
||||
> + 'static
|
||||
where
|
||||
S: futures::Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
|
||||
{
|
||||
use async_stream::try_stream;
|
||||
use futures::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamingEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
#[serde(flatten)]
|
||||
data: Value,
|
||||
}
|
||||
|
||||
try_stream! {
|
||||
let mut accumulated_text = String::new();
|
||||
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
||||
let mut current_tool_id: Option<String> = None;
|
||||
let mut final_usage: Option<crate::providers::base::ProviderUsage> = None;
|
||||
|
||||
while let Some(line_result) = stream.next().await {
|
||||
let line = line_result?;
|
||||
|
||||
// Skip empty lines and non-data lines
|
||||
if line.trim().is_empty() || !line.starts_with("data: ") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let data_part = line.strip_prefix("data: ").unwrap_or(&line);
|
||||
|
||||
// Handle end of stream
|
||||
if data_part.trim() == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
// Parse the JSON event
|
||||
let event: StreamingEvent = match serde_json::from_str(data_part) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
tracing::debug!("Failed to parse streaming event: {} - Line: {}", e, data_part);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Message started, we can extract initial metadata and usage if needed
|
||||
if let Some(message_data) = event.data.get("message") {
|
||||
if let Some(usage_data) = message_data.get("usage") {
|
||||
let usage = get_usage(usage_data).unwrap_or_default();
|
||||
tracing::debug!("🔍 Anthropic message_start parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
usage.input_tokens, usage.output_tokens, usage.total_tokens);
|
||||
let model = message_data.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage));
|
||||
} else {
|
||||
tracing::debug!("🔍 Anthropic message_start has no usage data");
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"content_block_start" => {
|
||||
// A new content block started
|
||||
if let Some(content_block) = event.data.get("content_block") {
|
||||
if content_block.get("type") == Some(&json!("tool_use")) {
|
||||
if let Some(id) = content_block.get("id").and_then(|v| v.as_str()) {
|
||||
current_tool_id = Some(id.to_string());
|
||||
if let Some(name) = content_block.get("name").and_then(|v| v.as_str()) {
|
||||
accumulated_tool_calls.insert(id.to_string(), (name.to_string(), String::new()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.data.get("delta") {
|
||||
if delta.get("type") == Some(&json!("text_delta")) {
|
||||
// Text content delta
|
||||
if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
|
||||
accumulated_text.push_str(text);
|
||||
|
||||
// Yield partial text message
|
||||
let message = Message::new(
|
||||
mcp_core::role::Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::text(text)],
|
||||
);
|
||||
yield (Some(message), None);
|
||||
}
|
||||
} else if delta.get("type") == Some(&json!("input_json_delta")) {
|
||||
// Tool input delta
|
||||
if let Some(tool_id) = ¤t_tool_id {
|
||||
if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) {
|
||||
if let Some((_name, args)) = accumulated_tool_calls.get_mut(tool_id) {
|
||||
args.push_str(partial_json);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"content_block_stop" => {
|
||||
// Content block finished
|
||||
if let Some(tool_id) = current_tool_id.take() {
|
||||
// Tool call finished, yield complete tool call
|
||||
if let Some((name, args)) = accumulated_tool_calls.remove(&tool_id) {
|
||||
let parsed_args = if args.is_empty() {
|
||||
json!({})
|
||||
} else {
|
||||
match serde_json::from_str::<Value>(&args) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(_) => {
|
||||
// If parsing fails, create an error tool request
|
||||
let error = mcp_core::handler::ToolError::InvalidParameters(
|
||||
format!("Could not parse tool arguments: {}", args)
|
||||
);
|
||||
let message = Message::new(
|
||||
mcp_core::role::Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::tool_request(tool_id, Err(error))],
|
||||
);
|
||||
yield (Some(message), None);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let tool_call = ToolCall::new(&name, parsed_args);
|
||||
let message = Message::new(
|
||||
mcp_core::role::Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::tool_request(tool_id, Ok(tool_call))],
|
||||
);
|
||||
yield (Some(message), None);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"message_delta" => {
|
||||
// Message metadata delta (like stop_reason) and cumulative usage
|
||||
tracing::debug!("🔍 Anthropic message_delta event data: {}", serde_json::to_string_pretty(&event.data).unwrap_or_else(|_| format!("{:?}", event.data)));
|
||||
if let Some(usage_data) = event.data.get("usage") {
|
||||
tracing::debug!("🔍 Anthropic message_delta usage data (cumulative): {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data)));
|
||||
let delta_usage = get_usage(usage_data).unwrap_or_default();
|
||||
tracing::debug!("🔍 Anthropic message_delta parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
delta_usage.input_tokens, delta_usage.output_tokens, delta_usage.total_tokens);
|
||||
|
||||
// IMPORTANT: message_delta usage should be MERGED with existing usage, not replace it
|
||||
// message_start has input tokens, message_delta has output tokens
|
||||
if let Some(existing_usage) = &final_usage {
|
||||
let merged_input = existing_usage.usage.input_tokens.or(delta_usage.input_tokens);
|
||||
let merged_output = delta_usage.output_tokens.or(existing_usage.usage.output_tokens);
|
||||
let merged_total = match (merged_input, merged_output) {
|
||||
(Some(input), Some(output)) => Some(input + output),
|
||||
(Some(input), None) => Some(input),
|
||||
(None, Some(output)) => Some(output),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
let merged_usage = crate::providers::base::Usage::new(merged_input, merged_output, merged_total);
|
||||
final_usage = Some(crate::providers::base::ProviderUsage::new(existing_usage.model.clone(), merged_usage));
|
||||
tracing::debug!("🔍 Anthropic MERGED usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
merged_input, merged_output, merged_total);
|
||||
} else {
|
||||
// No existing usage, just use delta usage
|
||||
let model = event.data.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
final_usage = Some(crate::providers::base::ProviderUsage::new(model, delta_usage));
|
||||
tracing::debug!("🔍 Anthropic no existing usage, using delta usage");
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("🔍 Anthropic message_delta event has no usage field");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"message_stop" => {
|
||||
// Message finished, extract final usage if available
|
||||
if let Some(usage_data) = event.data.get("usage") {
|
||||
tracing::debug!("🔍 Anthropic streaming usage data: {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data)));
|
||||
let usage = get_usage(usage_data).unwrap_or_default();
|
||||
tracing::debug!("🔍 Anthropic parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
usage.input_tokens, usage.output_tokens, usage.total_tokens);
|
||||
let model = event.data.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
tracing::debug!("🔍 Anthropic final_usage created with model: {}", model);
|
||||
final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage));
|
||||
} else {
|
||||
tracing::debug!("🔍 Anthropic message_stop event has no usage data");
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
// Unknown event type, log and continue
|
||||
tracing::debug!("Unknown streaming event type: {}", event.event_type);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Yield final usage information if available
|
||||
if let Some(usage) = final_usage {
|
||||
yield (None, Some(usage));
|
||||
} else {
|
||||
tracing::debug!("🔍 Anthropic no final usage to yield");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -449,9 +703,9 @@ mod tests {
|
||||
panic!("Expected Text content");
|
||||
}
|
||||
|
||||
assert_eq!(usage.input_tokens, Some(27)); // 12 * 1.0 + 12 * 1.25 = 27 effective tokens
|
||||
assert_eq!(usage.input_tokens, Some(24)); // 12 + 12 = 24 actual tokens
|
||||
assert_eq!(usage.output_tokens, Some(15));
|
||||
assert_eq!(usage.total_tokens, Some(42)); // 27 + 15
|
||||
assert_eq!(usage.total_tokens, Some(39)); // 24 + 15
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -492,9 +746,9 @@ mod tests {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
|
||||
assert_eq!(usage.input_tokens, Some(34)); // 15 * 1.0 + 15 * 1.25 = 33.75 → 34 effective tokens
|
||||
assert_eq!(usage.input_tokens, Some(30)); // 15 + 15 = 30 actual tokens
|
||||
assert_eq!(usage.output_tokens, Some(20));
|
||||
assert_eq!(usage.total_tokens, Some(54)); // 34 + 20
|
||||
assert_eq!(usage.total_tokens, Some(50)); // 30 + 20
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -718,11 +972,11 @@ mod tests {
|
||||
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
// Effective input tokens should be:
|
||||
// 7 * 1.0 + 10000 * 1.25 + 5000 * 0.10 = 7 + 12500 + 500 = 13007
|
||||
assert_eq!(usage.input_tokens, Some(13007));
|
||||
// ACTUAL input tokens should be:
|
||||
// 7 + 10000 + 5000 = 15007 total actual tokens
|
||||
assert_eq!(usage.input_tokens, Some(15007));
|
||||
assert_eq!(usage.output_tokens, Some(50));
|
||||
assert_eq!(usage.total_tokens, Some(13057)); // 13007 + 50
|
||||
assert_eq!(usage.total_tokens, Some(15057)); // 15007 + 50
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user