From eaba38adbc79dae50f0df39ccf47adcc3924c553 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Thu, 17 Jul 2025 07:32:55 +1000 Subject: [PATCH] chore: implement streaming for anthropic.rs firstparty provider (#3419) --- crates/goose/src/providers/anthropic.rs | 99 +++++- .../goose/src/providers/formats/anthropic.rs | 322 ++++++++++++++++-- 2 files changed, 384 insertions(+), 37 deletions(-) diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 88a71b0f..6a3c9304 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -1,13 +1,21 @@ use anyhow::Result; +use async_stream::try_stream; use async_trait::async_trait; use axum::http::HeaderMap; +use futures::TryStreamExt; use reqwest::{Client, StatusCode}; use serde_json::Value; +use std::io; use std::time::Duration; +use tokio::pin; -use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; +use tokio_util::io::StreamReader; + +use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; -use super::formats::anthropic::{create_request, get_usage, response_to_message}; +use super::formats::anthropic::{ + create_request, get_usage, response_to_message, response_to_streaming_message, +}; use super::utils::{emit_debug_trace, get_model}; use crate::message::Message; use crate::model::ModelConfig; @@ -195,10 +203,17 @@ impl Provider for AnthropicProvider { // Parse response let message = response_to_message(response.clone())?; let usage = get_usage(&response)?; + tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); let model = get_model(&response); emit_debug_trace(&self.model, &payload, &response, &usage); - Ok((message, ProviderUsage::new(model, usage))) + let provider_usage = ProviderUsage::new(model, usage); + tracing::debug!( + "🔍 Anthropic non-streaming returning ProviderUsage: {:?}", + provider_usage + ); + Ok((message, provider_usage)) } /// Fetch supported models from Anthropic; returns Err on failure, Ok(None) if not present @@ -232,4 +247,82 @@ impl Provider for AnthropicProvider { models.sort(); Ok(Some(models)) } + + async fn stream( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result { + let mut payload = create_request(&self.model, system, messages, tools)?; + + // Add stream parameter + payload + .as_object_mut() + .unwrap() + .insert("stream".to_string(), Value::Bool(true)); + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-api-key", self.api_key.parse().unwrap()); + headers.insert("anthropic-version", ANTHROPIC_API_VERSION.parse().unwrap()); + + let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok(); + if self.model.model_name.starts_with("claude-3-7-sonnet-") && is_thinking_enabled { + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#extended-output-capabilities-beta + headers.insert("anthropic-beta", "output-128k-2025-02-19".parse().unwrap()); + } + + if self.model.model_name.starts_with("claude-3-7-sonnet-") { + // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use + headers.insert( + "anthropic-beta", + "token-efficient-tools-2025-02-19".parse().unwrap(), + ); + } + + let base_url = url::Url::parse(&self.host) + .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?; + let url = base_url.join("v1/messages").map_err(|e| { + ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}")) + })?; + + let response = self + .client + .post(url) + .headers(headers) + .json(&payload) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(ProviderError::RequestFailed(format!( + "Streaming request failed with status: {}. Error: {}", + status, error_text + ))); + } + + // Map reqwest error to io::Error + let stream = response.bytes_stream().map_err(io::Error::other); + + let model_config = self.model.clone(); + // Wrap in a line decoder and yield lines inside the stream + Ok(Box::pin(try_stream! { + let stream_reader = StreamReader::new(stream); + let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from); + + let message_stream = response_to_streaming_message(framed); + pin!(message_stream); + while let Some(message) = futures::StreamExt::next(&mut message_stream).await { + let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?; + super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default()); + yield (message, usage); + } + })) + } + + fn supports_streaming(&self) -> bool { + true + } } diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index 661d2bce..c4656058 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -292,40 +292,68 @@ pub fn get_usage(data: &Value) -> Result { .and_then(|v| v.as_u64()) .unwrap_or(0); - // IMPORTANT: Based on the API responses, when caching is used: - // - input_tokens is ONLY the new/fresh tokens (can be very small, like 7) - // - cache_creation_input_tokens and cache_read_input_tokens are the cached content - // - These cached tokens are charged at different rates: - // * Fresh input tokens: 100% of regular price - // * Cache creation tokens: 125% of regular price - // * Cache read tokens: 10% of regular price - // - // Calculate effective input tokens for cost calculation based on Anthropic's pricing: - // - Fresh input tokens: 100% of regular price (1.0x) - // - Cache creation tokens: 125% of regular price (1.25x) - // - Cache read tokens: 10% of regular price (0.10x) - // - // The effective input tokens represent the cost-equivalent tokens when multiplied - // by the regular input price, ensuring accurate cost calculations in the frontend. - let effective_input_tokens = input_tokens as f64 * 1.0 - + cache_creation_tokens as f64 * 1.25 - + cache_read_tokens as f64 * 0.10; + // IMPORTANT: For display purposes, we want to show the ACTUAL total tokens consumed + // The cache pricing should only affect cost calculation, not token count display + let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; - // For token counting purposes, we still want to show the actual total count - let _total_actual_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; - - // Return the effective input tokens for cost calculation - // This ensures the frontend cost calculation is accurate when multiplying by regular prices - let effective_input_i32 = effective_input_tokens.round().clamp(0.0, i32::MAX as f64) as i32; + // Convert to i32 with bounds checking + let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32; let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32; let total_tokens_i32 = - (effective_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; + (total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; Ok(Usage::new( - Some(effective_input_i32), + Some(total_input_i32), Some(output_tokens_i32), Some(total_tokens_i32), )) + } else if data.as_object().is_some() { + // Check if the data itself is the usage object (for message_delta events that might have usage at top level) + let input_tokens = data + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let cache_creation_tokens = data + .get("cache_creation_input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let cache_read_tokens = data + .get("cache_read_input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + let output_tokens = data + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + // If we found any token data, process it + if input_tokens > 0 + || cache_creation_tokens > 0 + || cache_read_tokens > 0 + || output_tokens > 0 + { + let total_input_tokens = input_tokens + cache_creation_tokens + cache_read_tokens; + + let total_input_i32 = total_input_tokens.min(i32::MAX as u64) as i32; + let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32; + let total_tokens_i32 = + (total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; + + tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}", + total_input_i32, output_tokens_i32, total_tokens_i32); + + Ok(Usage::new( + Some(total_input_i32), + Some(output_tokens_i32), + Some(total_tokens_i32), + )) + } else { + tracing::debug!("🔍 Anthropic no token data found in object"); + Ok(Usage::new(None, None, None)) + } } else { tracing::debug!( "Failed to get usage data: {}", @@ -414,6 +442,232 @@ pub fn create_request( Ok(payload) } +/// Process streaming response from Anthropic's API +pub fn response_to_streaming_message( + mut stream: S, +) -> impl futures::Stream< + Item = anyhow::Result<( + Option, + Option, + )>, +> + 'static +where + S: futures::Stream> + Unpin + Send + 'static, +{ + use async_stream::try_stream; + use futures::StreamExt; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize, Debug)] + struct StreamingEvent { + #[serde(rename = "type")] + event_type: String, + #[serde(flatten)] + data: Value, + } + + try_stream! { + let mut accumulated_text = String::new(); + let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); + let mut current_tool_id: Option = None; + let mut final_usage: Option = None; + + while let Some(line_result) = stream.next().await { + let line = line_result?; + + // Skip empty lines and non-data lines + if line.trim().is_empty() || !line.starts_with("data: ") { + continue; + } + + let data_part = line.strip_prefix("data: ").unwrap_or(&line); + + // Handle end of stream + if data_part.trim() == "[DONE]" { + break; + } + + // Parse the JSON event + let event: StreamingEvent = match serde_json::from_str(data_part) { + Ok(event) => event, + Err(e) => { + tracing::debug!("Failed to parse streaming event: {} - Line: {}", e, data_part); + continue; + } + }; + + match event.event_type.as_str() { + "message_start" => { + // Message started, we can extract initial metadata and usage if needed + if let Some(message_data) = event.data.get("message") { + if let Some(usage_data) = message_data.get("usage") { + let usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic message_start parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); + let model = message_data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + } else { + tracing::debug!("🔍 Anthropic message_start has no usage data"); + } + } + continue; + } + "content_block_start" => { + // A new content block started + if let Some(content_block) = event.data.get("content_block") { + if content_block.get("type") == Some(&json!("tool_use")) { + if let Some(id) = content_block.get("id").and_then(|v| v.as_str()) { + current_tool_id = Some(id.to_string()); + if let Some(name) = content_block.get("name").and_then(|v| v.as_str()) { + accumulated_tool_calls.insert(id.to_string(), (name.to_string(), String::new())); + } + } + } + } + continue; + } + "content_block_delta" => { + if let Some(delta) = event.data.get("delta") { + if delta.get("type") == Some(&json!("text_delta")) { + // Text content delta + if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { + accumulated_text.push_str(text); + + // Yield partial text message + let message = Message::new( + mcp_core::role::Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::text(text)], + ); + yield (Some(message), None); + } + } else if delta.get("type") == Some(&json!("input_json_delta")) { + // Tool input delta + if let Some(tool_id) = ¤t_tool_id { + if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) { + if let Some((_name, args)) = accumulated_tool_calls.get_mut(tool_id) { + args.push_str(partial_json); + } + } + } + } + } + continue; + } + "content_block_stop" => { + // Content block finished + if let Some(tool_id) = current_tool_id.take() { + // Tool call finished, yield complete tool call + if let Some((name, args)) = accumulated_tool_calls.remove(&tool_id) { + let parsed_args = if args.is_empty() { + json!({}) + } else { + match serde_json::from_str::(&args) { + Ok(parsed) => parsed, + Err(_) => { + // If parsing fails, create an error tool request + let error = mcp_core::handler::ToolError::InvalidParameters( + format!("Could not parse tool arguments: {}", args) + ); + let message = Message::new( + mcp_core::role::Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::tool_request(tool_id, Err(error))], + ); + yield (Some(message), None); + continue; + } + } + }; + + let tool_call = ToolCall::new(&name, parsed_args); + let message = Message::new( + mcp_core::role::Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::tool_request(tool_id, Ok(tool_call))], + ); + yield (Some(message), None); + } + } + continue; + } + "message_delta" => { + // Message metadata delta (like stop_reason) and cumulative usage + tracing::debug!("🔍 Anthropic message_delta event data: {}", serde_json::to_string_pretty(&event.data).unwrap_or_else(|_| format!("{:?}", event.data))); + if let Some(usage_data) = event.data.get("usage") { + tracing::debug!("🔍 Anthropic message_delta usage data (cumulative): {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data))); + let delta_usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic message_delta parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + delta_usage.input_tokens, delta_usage.output_tokens, delta_usage.total_tokens); + + // IMPORTANT: message_delta usage should be MERGED with existing usage, not replace it + // message_start has input tokens, message_delta has output tokens + if let Some(existing_usage) = &final_usage { + let merged_input = existing_usage.usage.input_tokens.or(delta_usage.input_tokens); + let merged_output = delta_usage.output_tokens.or(existing_usage.usage.output_tokens); + let merged_total = match (merged_input, merged_output) { + (Some(input), Some(output)) => Some(input + output), + (Some(input), None) => Some(input), + (None, Some(output)) => Some(output), + (None, None) => None, + }; + + let merged_usage = crate::providers::base::Usage::new(merged_input, merged_output, merged_total); + final_usage = Some(crate::providers::base::ProviderUsage::new(existing_usage.model.clone(), merged_usage)); + tracing::debug!("🔍 Anthropic MERGED usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + merged_input, merged_output, merged_total); + } else { + // No existing usage, just use delta usage + let model = event.data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, delta_usage)); + tracing::debug!("🔍 Anthropic no existing usage, using delta usage"); + } + } else { + tracing::debug!("🔍 Anthropic message_delta event has no usage field"); + } + continue; + } + "message_stop" => { + // Message finished, extract final usage if available + if let Some(usage_data) = event.data.get("usage") { + tracing::debug!("🔍 Anthropic streaming usage data: {}", serde_json::to_string_pretty(usage_data).unwrap_or_else(|_| format!("{:?}", usage_data))); + let usage = get_usage(usage_data).unwrap_or_default(); + tracing::debug!("🔍 Anthropic parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}", + usage.input_tokens, usage.output_tokens, usage.total_tokens); + let model = event.data.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + tracing::debug!("🔍 Anthropic final_usage created with model: {}", model); + final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + } else { + tracing::debug!("🔍 Anthropic message_stop event has no usage data"); + } + break; + } + _ => { + // Unknown event type, log and continue + tracing::debug!("Unknown streaming event type: {}", event.event_type); + continue; + } + } + } + + // Yield final usage information if available + if let Some(usage) = final_usage { + yield (None, Some(usage)); + } else { + tracing::debug!("🔍 Anthropic no final usage to yield"); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -449,9 +703,9 @@ mod tests { panic!("Expected Text content"); } - assert_eq!(usage.input_tokens, Some(27)); // 12 * 1.0 + 12 * 1.25 = 27 effective tokens + assert_eq!(usage.input_tokens, Some(24)); // 12 + 12 = 24 actual tokens assert_eq!(usage.output_tokens, Some(15)); - assert_eq!(usage.total_tokens, Some(42)); // 27 + 15 + assert_eq!(usage.total_tokens, Some(39)); // 24 + 15 Ok(()) } @@ -492,9 +746,9 @@ mod tests { panic!("Expected ToolRequest content"); } - assert_eq!(usage.input_tokens, Some(34)); // 15 * 1.0 + 15 * 1.25 = 33.75 → 34 effective tokens + assert_eq!(usage.input_tokens, Some(30)); // 15 + 15 = 30 actual tokens assert_eq!(usage.output_tokens, Some(20)); - assert_eq!(usage.total_tokens, Some(54)); // 34 + 20 + assert_eq!(usage.total_tokens, Some(50)); // 30 + 20 Ok(()) } @@ -718,11 +972,11 @@ mod tests { let usage = get_usage(&response)?; - // Effective input tokens should be: - // 7 * 1.0 + 10000 * 1.25 + 5000 * 0.10 = 7 + 12500 + 500 = 13007 - assert_eq!(usage.input_tokens, Some(13007)); + // ACTUAL input tokens should be: + // 7 + 10000 + 5000 = 15007 total actual tokens + assert_eq!(usage.input_tokens, Some(15007)); assert_eq!(usage.output_tokens, Some(50)); - assert_eq!(usage.total_tokens, Some(13057)); // 13007 + 50 + assert_eq!(usage.total_tokens, Some(15057)); // 15007 + 50 Ok(()) }