chore: implement streaming for anthropic.rs firstparty provider (#3419)

This commit is contained in:
Michael Neale
2025-07-17 07:32:55 +10:00
committed by GitHub
parent dc0008bd4f
commit eaba38adbc
2 changed files with 384 additions and 37 deletions

View File

@@ -1,13 +1,21 @@
use anyhow::Result;
use async_stream::try_stream;
use async_trait::async_trait;
use axum::http::HeaderMap;
use futures::TryStreamExt;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::io;
use std::time::Duration;
use tokio::pin;
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use tokio_util::io::StreamReader;
use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::anthropic::{create_request, get_usage, response_to_message};
use super::formats::anthropic::{
create_request, get_usage, response_to_message, response_to_streaming_message,
};
use super::utils::{emit_debug_trace, get_model};
use crate::message::Message;
use crate::model::ModelConfig;
@@ -195,10 +203,17 @@ impl Provider for AnthropicProvider {
// Parse response
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
usage.input_tokens, usage.output_tokens, usage.total_tokens);
let model = get_model(&response);
emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
let provider_usage = ProviderUsage::new(model, usage);
tracing::debug!(
"🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
provider_usage
);
Ok((message, provider_usage))
}
/// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present
@@ -232,4 +247,82 @@ impl Provider for AnthropicProvider {
models.sort();
Ok(Some(models))
}
/// Stream a completion from Anthropic's Messages API using server-sent events.
///
/// Builds the same request payload as `complete`, adds `"stream": true`, and
/// returns a `MessageStream` that yields partial assistant messages as they
/// arrive, followed by a final usage record.
///
/// # Errors
/// Returns `ProviderError::RequestFailed` on invalid configuration (host, API
/// key), a non-success HTTP status, or a decode failure mid-stream.
async fn stream(
    &self,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
    let mut payload = create_request(&self.model, system, messages, tools)?;

    // Ask the API to deliver the response as an SSE stream.
    payload
        .as_object_mut()
        .expect("create_request always returns a JSON object")
        .insert("stream".to_string(), Value::Bool(true));

    let mut headers = reqwest::header::HeaderMap::new();
    // An API key containing non-header-safe bytes is a config error, not a panic.
    headers.insert(
        "x-api-key",
        self.api_key.parse().map_err(|_| {
            ProviderError::RequestFailed("Invalid API key header value".to_string())
        })?,
    );
    headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap());

    let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok();
    if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled {
        // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta
        // Use `append`, not `insert`: the token-efficient-tools beta below targets
        // the same `anthropic-beta` header, and `insert` would replace this value.
        headers.append("anthropic-beta", "output-128k-2025-02-19".parse().unwrap());
    }
    if self.model.model_name.starts_with("claude-3-7-sonnet-") {
        // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
        headers.append(
            "anthropic-beta",
            "token-efficient-tools-2025-02-19".parse().unwrap(),
        );
    }

    let base_url = url::Url::parse(&self.host)
        .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
    let url = base_url.join("v1/messages").map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
    })?;

    let response = self
        .client
        .post(url)
        .headers(headers)
        .json(&payload)
        .send()
        .await?;

    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await.unwrap_or_default();
        return Err(ProviderError::RequestFailed(format!(
            "Streaming request failed with status: {}. Error: {}",
            status, error_text
        )));
    }

    // Map reqwest errors to io::Error so the byte stream can back a StreamReader.
    let stream = response.bytes_stream().map_err(io::Error::other);
    let model_config = self.model.clone();

    // Wrap the body in a line decoder and adapt the SSE lines into messages.
    Ok(Box::pin(try_stream! {
        let stream_reader = StreamReader::new(stream);
        let framed = tokio_util::codec::FramedRead::new(
            stream_reader,
            tokio_util::codec::LinesCodec::new(),
        )
        .map_err(anyhow::Error::from);

        let message_stream = response_to_streaming_message(framed);
        pin!(message_stream);

        while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
            let (message, usage) = message.map_err(|e| {
                ProviderError::RequestFailed(format!("Stream decode error: {}", e))
            })?;
            super::utils::emit_debug_trace(
                &model_config,
                &payload,
                &message,
                &usage.as_ref().map(|f| f.usage).unwrap_or_default(),
            );
            yield (message, usage);
        }
    }))
}
/// This provider implements `stream`, so callers may request streaming responses.
fn supports_streaming(&self) -> bool {
    true
}
}

View File

@@ -292,40 +292,68 @@ pub fn get_usage(data: &Value) -> Result<Usage> {
.and_then(|v| v.as_u64())
.unwrap_or(0);
// IMPORTANT: Based on the API responses, when caching is used:
// - input_tokens is ONLY the new/fresh tokens (can be very small, like 7)
// - cache_creation_input_tokens and cache_read_input_tokens are the cached content
// - These cached tokens are charged at different rates:
// * Fresh input tokens: 100% of regular price
// * Cache creation tokens: 125% of regular price
// * Cache read tokens: 10% of regular price
//
// Calculate effective input tokens for cost calculation based on Anthropic's pricing:
// - Fresh input tokens: 100% of regular price (1.0x)
// - Cache creation tokens: 125% of regular price (1.25x)
// - Cache read tokens: 10% of regular price (0.10x)
//
// The effective input tokens represent the cost-equivalent tokens when multiplied
// by the regular input price, ensuring accurate cost calculations in the frontend.
let effective_input_tokens = input_tokens as f64 * 1.0
+ cache_creation_tokens as f64 * 1.25
+ cache_read_tokens as f64 * 0.10;
// IMPORTANT: For display purposes, we want to show the ACTUAL total tokens consumed
// The cache pricing should only affect cost calculation, not token count display
let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
// For token counting purposes, we still want to show the actual total count
let _total_actual_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
// Return the effective input tokens for cost calculation
// This ensures the frontend cost calculation is accurate when multiplying by regular prices
let effective_input_i32 = effective_input_tokens.round().clamp(0.0, i32::MAX as f64) as i32;
// Convert to i32 with bounds checking
let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32;
let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32;
let total_tokens_i32 =
(effective_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
(total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
Ok(Usage::new(
Some(effective_input_i32),
Some(total_input_i32),
Some(output_tokens_i32),
Some(total_tokens_i32),
))
} else if data.as_object().is_some() {
// Check if the data itself is the usage object (for message_delta events that might have usage at top level)
let input_tokens = data
.get("input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let cache_creation_tokens = data
.get("cache_creation_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let cache_read_tokens = data
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let output_tokens = data
.get("output_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0);
// If we found any token data, process it
if input_tokens > 0
|| cache_creation_tokens > 0
|| cache_read_tokens > 0
|| output_tokens > 0
{
let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;
let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32;
let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32;
let total_tokens_i32 =
(total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;
tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}",
total_input_i32, output_tokens_i32, total_tokens_i32);
Ok(Usage::new(
Some(total_input_i32),
Some(output_tokens_i32),
Some(total_tokens_i32),
))
} else {
tracing::debug!("🔍 Anthropic no token data found in object");
Ok(Usage::new(None, None, None))
}
} else {
tracing::debug!(
"Failed to get usage data: {}",
@@ -414,6 +442,232 @@ pub fn create_request(
Ok(payload)
}
/// Process streaming response from Anthropic's API.
///
/// Adapts a stream of raw SSE lines (as produced by a line codec over the HTTP
/// body) into a stream of `(Option<Message>, Option<ProviderUsage>)` pairs:
/// - partial text deltas are yielded immediately as assistant text messages,
/// - completed tool calls are yielded once their JSON arguments are complete,
/// - after the event stream ends, a final `(None, Some(usage))` item reports
///   the usage merged from `message_start` / `message_delta` / `message_stop`.
///
/// Lines that are not `data: ` events, and events that fail to parse, are
/// logged at debug level and skipped rather than terminating the stream.
pub fn response_to_streaming_message<S>(
    mut stream: S,
) -> impl futures::Stream<
    Item = anyhow::Result<(
        Option<Message>,
        Option<crate::providers::base::ProviderUsage>,
    )>,
> + 'static
where
    S: futures::Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
{
    use async_stream::try_stream;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};

    // One SSE event: the `type` tag plus all remaining payload fields.
    #[derive(Serialize, Deserialize, Debug)]
    struct StreamingEvent {
        #[serde(rename = "type")]
        event_type: String,
        #[serde(flatten)]
        data: Value,
    }

    try_stream! {
        // In-flight tool calls: id -> (tool name, accumulated argument JSON).
        // NOTE: the original also accumulated the full text transcript here, but it
        // was never read — that dead (and unboundedly growing) buffer is removed.
        let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
        // Id of the tool_use content block currently receiving input_json deltas.
        let mut current_tool_id: Option<String> = None;
        // Usage merged across message lifecycle events; yielded once at the end.
        let mut final_usage: Option<crate::providers::base::ProviderUsage> = None;

        while let Some(line_result) = stream.next().await {
            let line = line_result?;

            // Skip empty lines and non-data lines
            if line.trim().is_empty() || !line.starts_with("data: ") {
                continue;
            }

            let data_part = line.strip_prefix("data: ").unwrap_or(&line);

            // Handle end of stream
            if data_part.trim() == "[DONE]" {
                break;
            }

            // Parse the JSON event
            let event: StreamingEvent = match serde_json::from_str(data_part) {
                Ok(event) => event,
                Err(e) => {
                    tracing::debug!("Failed to parse streaming event: {} - Line: {}", e, data_part);
                    continue;
                }
            };

            match event.event_type.as_str() {
                "message_start" => {
                    // Message started, we can extract initial metadata and usage if needed
                    if let Some(message_data) = event.data.get("message") {
                        if let Some(usage_data) = message_data.get("usage") {
                            let usage = get_usage(usage_data).unwrap_or_default();
                            tracing::debug!("🔍 Anthropic message_start parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
                                usage.input_tokens, usage.output_tokens, usage.total_tokens);
                            let model = message_data.get("model")
                                .and_then(|v| v.as_str())
                                .unwrap_or("unknown")
                                .to_string();
                            final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage));
                        } else {
                            tracing::debug!("🔍 Anthropic message_start has no usage data");
                        }
                    }
                    continue;
                }
                "content_block_start" => {
                    // A new content block started; remember tool_use blocks so their
                    // argument deltas can be attributed to the right call.
                    if let Some(content_block) = event.data.get("content_block") {
                        if content_block.get("type") == Some(&json!("tool_use")) {
                            if let Some(id) = content_block.get("id").and_then(|v| v.as_str()) {
                                current_tool_id = Some(id.to_string());
                                if let Some(name) = content_block.get("name").and_then(|v| v.as_str()) {
                                    accumulated_tool_calls.insert(id.to_string(), (name.to_string(), String::new()));
                                }
                            }
                        }
                    }
                    continue;
                }
                "content_block_delta" => {
                    if let Some(delta) = event.data.get("delta") {
                        if delta.get("type") == Some(&json!("text_delta")) {
                            // Text content delta
                            if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
                                // Yield partial text message
                                let message = Message::new(
                                    mcp_core::role::Role::Assistant,
                                    chrono::Utc::now().timestamp(),
                                    vec![MessageContent::text(text)],
                                );
                                yield (Some(message), None);
                            }
                        } else if delta.get("type") == Some(&json!("input_json_delta")) {
                            // Tool input delta: append the partial JSON to the current call.
                            if let Some(tool_id) = &current_tool_id {
                                if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) {
                                    if let Some((_name, args)) = accumulated_tool_calls.get_mut(tool_id) {
                                        args.push_str(partial_json);
                                    }
                                }
                            }
                        }
                    }
                    continue;
                }
                "content_block_stop" => {
                    // Content block finished
                    if let Some(tool_id) = current_tool_id.take() {
                        // Tool call finished, yield complete tool call
                        if let Some((name, args)) = accumulated_tool_calls.remove(&tool_id) {
                            let parsed_args = if args.is_empty() {
                                json!({})
                            } else {
                                match serde_json::from_str::<Value>(&args) {
                                    Ok(parsed) => parsed,
                                    Err(_) => {
                                        // If parsing fails, create an error tool request
                                        let error = mcp_core::handler::ToolError::InvalidParameters(
                                            format!("Could not parse tool arguments: {}", args)
                                        );
                                        let message = Message::new(
                                            mcp_core::role::Role::Assistant,
                                            chrono::Utc::now().timestamp(),
                                            vec![MessageContent::tool_request(tool_id, Err(error))],
                                        );
                                        yield (Some(message), None);
                                        continue;
                                    }
                                }
                            };

                            let tool_call = ToolCall::new(&name, parsed_args);
                            let message = Message::new(
                                mcp_core::role::Role::Assistant,
                                chrono::Utc::now().timestamp(),
                                vec![MessageContent::tool_request(tool_id, Ok(tool_call))],
                            );
                            yield (Some(message), None);
                        }
                    }
                    continue;
                }
                "message_delta" => {
                    // Message metadata delta (like stop_reason) and cumulative usage
                    tracing::debug!("🔍 Anthropic message_delta event data: {}", serde_json::to_string_pretty(&event.data).unwrap_or_else(|_| format!("{:?}", event.data)));
                    if let Some(usage_data) = event.data.get("usage") {
                        tracing::debug!("🔍 Anthropic message_delta usage data (cumulative): {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data)));
                        let delta_usage = get_usage(usage_data).unwrap_or_default();
                        tracing::debug!("🔍 Anthropic message_delta parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
                            delta_usage.input_tokens, delta_usage.output_tokens, delta_usage.total_tokens);

                        // IMPORTANT: message_delta usage should be MERGED with existing usage, not replace it
                        // message_start has input tokens, message_delta has output tokens
                        if let Some(existing_usage) = &final_usage {
                            let merged_input = existing_usage.usage.input_tokens.or(delta_usage.input_tokens);
                            let merged_output = delta_usage.output_tokens.or(existing_usage.usage.output_tokens);
                            let merged_total = match (merged_input, merged_output) {
                                (Some(input), Some(output)) => Some(input + output),
                                (Some(input), None) => Some(input),
                                (None, Some(output)) => Some(output),
                                (None, None) => None,
                            };
                            let merged_usage = crate::providers::base::Usage::new(merged_input, merged_output, merged_total);
                            final_usage = Some(crate::providers::base::ProviderUsage::new(existing_usage.model.clone(), merged_usage));
                            tracing::debug!("🔍 Anthropic MERGED usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
                                merged_input, merged_output, merged_total);
                        } else {
                            // No existing usage, just use delta usage
                            // NOTE(review): message_delta events generally carry no "model"
                            // field, so this falls back to "unknown" — confirm against the API.
                            let model = event.data.get("model")
                                .and_then(|v| v.as_str())
                                .unwrap_or("unknown")
                                .to_string();
                            final_usage = Some(crate::providers::base::ProviderUsage::new(model, delta_usage));
                            tracing::debug!("🔍 Anthropic no existing usage, using delta usage");
                        }
                    } else {
                        tracing::debug!("🔍 Anthropic message_delta event has no usage field");
                    }
                    continue;
                }
                "message_stop" => {
                    // Message finished, extract final usage if available
                    if let Some(usage_data) = event.data.get("usage") {
                        tracing::debug!("🔍 Anthropic streaming usage data: {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data)));
                        let usage = get_usage(usage_data).unwrap_or_default();
                        tracing::debug!("🔍 Anthropic parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
                            usage.input_tokens, usage.output_tokens, usage.total_tokens);
                        let model = event.data.get("model")
                            .and_then(|v| v.as_str())
                            .unwrap_or("unknown")
                            .to_string();
                        tracing::debug!("🔍 Anthropic final_usage created with model: {}", model);
                        final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage));
                    } else {
                        tracing::debug!("🔍 Anthropic message_stop event has no usage data");
                    }
                    break;
                }
                _ => {
                    // Unknown event type, log and continue
                    tracing::debug!("Unknown streaming event type: {}", event.event_type);
                    continue;
                }
            }
        }

        // Yield final usage information if available
        if let Some(usage) = final_usage {
            yield (None, Some(usage));
        } else {
            tracing::debug!("🔍 Anthropic no final usage to yield");
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
@@ -449,9 +703,9 @@ mod tests {
panic!("Expected Text content");
}
assert_eq!(usage.input_tokens, Some(27)); // 12 * 1.0 + 12 * 1.25 = 27 effective tokens
assert_eq!(usage.input_tokens, Some(24)); // 12 + 12 = 24 actual tokens
assert_eq!(usage.output_tokens, Some(15));
assert_eq!(usage.total_tokens, Some(42)); // 27 + 15
assert_eq!(usage.total_tokens, Some(39)); // 24 + 15
Ok(())
}
@@ -492,9 +746,9 @@ mod tests {
panic!("Expected ToolRequest content");
}
assert_eq!(usage.input_tokens, Some(34)); // 15 * 1.0 + 15 * 1.25 = 33.75 → 34 effective tokens
assert_eq!(usage.input_tokens, Some(30)); // 15 + 15 = 30 actual tokens
assert_eq!(usage.output_tokens, Some(20));
assert_eq!(usage.total_tokens, Some(54)); // 34 + 20
assert_eq!(usage.total_tokens, Some(50)); // 30 + 20
Ok(())
}
@@ -718,11 +972,11 @@ mod tests {
let usage = get_usage(&response)?;
// Effective input tokens should be:
// 7 * 1.0 + 10000 * 1.25 + 5000 * 0.10 = 7 + 12500 + 500 = 13007
assert_eq!(usage.input_tokens, Some(13007));
// ACTUAL input tokens should be:
// 7 + 10000 + 5000 = 15007 total actual tokens
assert_eq!(usage.input_tokens, Some(15007));
assert_eq!(usage.output_tokens, Some(50));
assert_eq!(usage.total_tokens, Some(13057)); // 13007 + 50
assert_eq!(usage.total_tokens, Some(15057)); // 15007 + 50
Ok(())
}