feat: stream LLM responses (#2677)

Co-authored-by: Michael Neale <michael.neale@gmail.com>
Author: Jack Amadeo
Date: 2025-07-14 14:57:03 -04:00
Committed by: GitHub
Parent: 99e4deed3e
Commit: fde3a578a5
34 changed files with 943 additions and 564 deletions

Cargo.lock (generated)
View File

@@ -3486,6 +3486,7 @@ dependencies = [
"tokio",
"tokio-cron-scheduler",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"url",
@@ -8604,9 +8605,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.13"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",

View File

@@ -10,6 +10,7 @@ pub use self::export::message_to_markdown;
pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
use console::Color;
use goose::agents::AgentEvent;
use goose::message::push_message;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission;
use goose::permission::PermissionConfirmation;
@@ -356,7 +357,7 @@ impl Session {
/// Process a single message and get the response
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));
self.push_message(Message::user().with_text(&message));
// Get the provider from the agent for description generation
let provider = self.agent.provider().await?;
@@ -462,7 +463,7 @@ impl Session {
RunMode::Normal => {
save_history(&mut editor);
self.messages.push(Message::user().with_text(&content));
self.push_message(Message::user().with_text(&content));
// Track the current directory and last instruction in projects.json
let session_id = self
@@ -785,7 +786,7 @@ impl Session {
self.messages.clear();
// add the plan response as a user message
let plan_message = Message::user().with_text(plan_response.as_concat_text());
self.messages.push(plan_message);
self.push_message(plan_message);
// act on the plan
output::show_thinking();
self.process_agent_response(true).await?;
@@ -800,13 +801,13 @@ impl Session {
} else {
// add the plan response (assistant message) & carry the conversation forward
// in the next round, the user might want to slightly modify the plan
self.messages.push(plan_response);
self.push_message(plan_response);
}
}
PlannerResponseType::ClarifyingQuestions => {
// add the plan response (assistant message) & carry the conversation forward
// in the next round, the user will answer the clarifying questions
self.messages.push(plan_response);
self.push_message(plan_response);
}
}
@@ -878,7 +879,7 @@ impl Session {
confirmation.id.clone(),
Err(ToolError::ExecutionError("Tool call cancelled by user".to_string()))
));
self.messages.push(response_message);
push_message(&mut self.messages, response_message);
if let Some(session_file) = &self.session_file {
session::persist_messages_with_schedule_id(
session_file,
@@ -975,7 +976,7 @@ impl Session {
}
// otherwise we have a model/tool to render
else {
self.messages.push(message.clone());
push_message(&mut self.messages, message.clone());
// No need to update description on assistant messages
if let Some(session_file) = &self.session_file {
@@ -991,7 +992,6 @@ impl Session {
if interactive {output::hide_thinking()};
let _ = progress_bars.hide();
output::render_message(&message, self.debug);
if interactive {output::show_thinking()};
}
}
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
@@ -1139,6 +1139,7 @@ impl Session {
}
}
}
println!();
Ok(())
}
@@ -1182,7 +1183,7 @@ impl Session {
Err(ToolError::ExecutionError(notification.clone())),
));
}
self.messages.push(response_message);
self.push_message(response_message);
// No need for description update here
if let Some(session_file) = &self.session_file {
@@ -1199,7 +1200,7 @@ impl Session {
"The existing call to {} was interrupted. How would you like to proceed?",
last_tool_name
);
self.messages.push(Message::assistant().with_text(&prompt));
self.push_message(Message::assistant().with_text(&prompt));
// No need for description update here
if let Some(session_file) = &self.session_file {
@@ -1221,7 +1222,7 @@ impl Session {
Some(MessageContent::ToolResponse(_)) => {
// Interruption occurred after a tool had completed but not assistant reply
let prompt = "The tool calling loop was interrupted. How would you like to proceed?";
self.messages.push(Message::assistant().with_text(prompt));
self.push_message(Message::assistant().with_text(prompt));
// No need for description update here
if let Some(session_file) = &self.session_file {
@@ -1438,7 +1439,7 @@ impl Session {
if msg.role == mcp_core::Role::User {
output::render_message(&msg, self.debug);
}
self.messages.push(msg);
self.push_message(msg);
}
if valid {
@@ -1496,6 +1497,10 @@ impl Session {
Ok(path)
}
fn push_message(&mut self, message: Message) {
push_message(&mut self.messages, message);
}
}
fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {

View File

@@ -10,7 +10,7 @@ use regex::Regex;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Error;
use std::io::{Error, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
@@ -166,7 +166,8 @@ pub fn render_message(message: &Message, debug: bool) {
}
}
}
println!();
let _ = std::io::stdout().flush();
}
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {

View File

@@ -225,6 +225,7 @@ async fn handler(
return;
}
};
let saved_message_count = all_messages.len();
loop {
tokio::select! {
@@ -242,16 +243,6 @@ async fn handler(
).await;
break;
}
let session_path = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
tracing::error!("Failed to store session history: {:?}", e);
}
});
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
@@ -303,6 +294,17 @@ async fn handler(
}
}
if all_messages.len() > saved_message_count {
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) =
session::persist_messages(&session_path, &all_messages, Some(provider)).await
{
tracing::error!("Failed to store session history: {:?}", e);
}
});
}
let _ = stream_event(
MessageEvent::Finish {
reason: "stop".to_string(),

View File

@@ -81,6 +81,7 @@ fs2 = "0.4.3"
tokio-stream = "0.1.17"
dashmap = "6.1"
ahash = "0.8"
tokio-util = "0.7.15"
# Vector database for tool selection
lancedb = "0.13"

View File

@@ -2,8 +2,12 @@ use anyhow::Result;
use dotenv::dotenv;
use goose::{
message::Message,
providers::{base::Provider, databricks::DatabricksProvider},
providers::{
base::{Provider, Usage},
databricks::DatabricksProvider,
},
};
use tokio_stream::StreamExt;
#[tokio::main]
async fn main() -> Result<()> {
@@ -20,21 +24,24 @@ async fn main() -> Result<()> {
let message = Message::user().with_text("Tell me a short joke about programming.");
// Get a response
let (response, usage) = provider
.complete("You are a helpful assistant.", &[message], &[])
let mut stream = provider
.stream("You are a helpful assistant.", &[message], &[])
.await?;
// Print the response and usage statistics
println!("\nResponse from AI:");
println!("---------------");
for content in response.content {
dbg!(content);
let mut usage = Usage::default();
while let Some(Ok((msg, usage_part))) = stream.next().await {
dbg!(msg);
usage_part.map(|u| {
usage += u.usage;
});
}
println!("\nToken Usage:");
println!("------------");
println!("Input tokens: {:?}", usage.usage.input_tokens);
println!("Output tokens: {:?}", usage.usage.output_tokens);
println!("Total tokens: {:?}", usage.usage.total_tokens);
println!("Input tokens: {:?}", usage.input_tokens);
println!("Output tokens: {:?}", usage.output_tokens);
println!("Total tokens: {:?}", usage.total_tokens);
Ok(())
}

View File

@@ -14,7 +14,7 @@ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
};
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::Message;
use crate::message::{push_message, Message};
use crate::permission::permission_judge::check_tool_permissions;
use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider;
@@ -722,6 +722,16 @@ impl Agent {
});
loop {
// Check for final output before incrementing turns or checking max_turns
// This ensures that if we have a final output ready, we return it immediately
// without being blocked by the max_turns limit - this is needed for streaming cases
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_some() {
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
break;
}
}
turns_taken += 1;
if turns_taken > max_turns {
yield AgentEvent::Message(Message::assistant().with_text(
@@ -752,17 +762,22 @@ impl Agent {
}
}
match Self::generate_response_from_provider(
let mut stream = Self::stream_response_from_provider(
self.provider().await?,
&system_prompt,
&messages,
&tools,
&toolshim_tools,
).await {
).await?;
let mut added_message = false;
while let Some(next) = stream.next().await {
match next {
Ok((response, usage)) => {
// Emit model change event if provider is lead-worker
let provider = self.provider().await?;
if let Some(lead_worker) = provider.as_lead_worker() {
if let Some(ref usage) = usage {
// The actual model used is in the usage
let active_model = usage.model.clone();
let (lead_model, worker_model) = lead_worker.get_model_info();
@@ -779,12 +794,16 @@ impl Agent {
mode: mode.to_string(),
};
}
}
// record usage for the session in the session file
if let Some(session_config) = session.clone() {
Self::update_session_metrics(session_config, &usage, messages.len()).await?;
if let Some(ref usage) = usage {
Self::update_session_metrics(session_config, usage, messages.len()).await?;
}
}
if let Some(response) = response {
// categorize the type of requests we need to handle
let (frontend_requests,
remaining_requests,
@@ -829,9 +848,14 @@ impl Agent {
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
yield AgentEvent::Message(message);
// Set added_message to true and continue to end the current iteration
added_message = true;
push_message(&mut messages, response);
continue;
}
}
break;
// If there's no final output tool and no tool requests, continue the loop
continue;
}
// Process tool requests depending on frontend tools and then goose_mode
@@ -964,8 +988,9 @@ impl Agent {
let final_message_tool_resp = message_tool_response.lock().await.clone();
yield AgentEvent::Message(final_message_tool_resp.clone());
messages.push(response);
messages.push(final_message_tool_resp);
added_message = true;
push_message(&mut messages, response);
push_message(&mut messages, final_message_tool_resp);
// Check for MCP notifications from subagents again before next iteration
// Note: These are already handled as McpNotification events above,
@@ -991,6 +1016,7 @@ impl Agent {
// }
// }
// }
}
},
Err(ProviderError::ContextLengthExceeded(_)) => {
// At this point, the last message should be a user message
@@ -1008,6 +1034,19 @@ impl Agent {
break;
}
}
}
if !added_message {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
continue;
} else {
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
}
}
break;
}
// Yield control back to the scheduler to prevent blocking
tokio::task::yield_now().await;

View File

@@ -2,10 +2,13 @@ use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;
use async_stream::try_stream;
use futures::stream::StreamExt;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::config::Config;
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::{Provider, ProviderUsage};
use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{
augment_message_with_tool_calls, convert_tool_messages_to_text,
@@ -16,6 +19,19 @@ use mcp_core::tool::Tool;
use super::super::agents::Agent;
async fn toolshim_postprocess(
response: Message,
toolshim_tools: &[Tool],
) -> Result<Message, ProviderError> {
let interpreter = OllamaInterpreter::new().map_err(|e| {
ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
})?;
augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
.await
.map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e)))
}
impl Agent {
/// Prepares tools and system prompt for a provider request
pub(crate) async fn prepare_tools_and_prompt(
@@ -128,25 +144,67 @@ impl Agent {
.complete(system_prompt, &messages_for_provider, tools)
.await?;
// Store the model information in the global store
crate::providers::base::set_current_model(&usage.model);
// Post-process / structure the response only if tool interpretation is enabled
if config.toolshim {
let interpreter = OllamaInterpreter::new().map_err(|e| {
ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
})?;
response = augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
.await
.map_err(|e| {
ProviderError::ExecutionError(format!("Failed to augment message: {}", e))
})?;
response = toolshim_postprocess(response, toolshim_tools).await?;
}
Ok((response, usage))
}
/// Stream a response from the LLM provider.
/// Handles toolshim transformations if needed.
pub(crate) async fn stream_response_from_provider(
provider: Arc<dyn Provider>,
system_prompt: &str,
messages: &[Message],
tools: &[Tool],
toolshim_tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let config = provider.get_model_config();
// Convert tool messages to text if toolshim is enabled
let messages_for_provider = if config.toolshim {
convert_tool_messages_to_text(messages)
} else {
messages.to_vec()
};
// Clone owned data to move into the async stream
let system_prompt = system_prompt.to_owned();
let tools = tools.to_owned();
let toolshim_tools = toolshim_tools.to_owned();
let provider = provider.clone();
let mut stream = if provider.supports_streaming() {
provider
.stream(system_prompt.as_str(), &messages_for_provider, &tools)
.await?
} else {
let (message, usage) = provider
.complete(system_prompt.as_str(), &messages_for_provider, &tools)
.await?;
stream_from_single_message(message, usage)
};
Ok(Box::pin(try_stream! {
while let Some(Ok((mut message, usage))) = stream.next().await {
// Store the model information in the global store
if let Some(usage) = usage.as_ref() {
crate::providers::base::set_current_model(&usage.model);
}
// Post-process / structure the response only if tool interpretation is enabled
if message.is_some() && config.toolshim {
message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?);
}
yield (message, usage);
}
}))
}
/// Categorize tool requests from the response into different types
/// Returns:
/// - frontend_requests: Tool requests that should be handled by the frontend
@@ -191,6 +249,7 @@ impl Agent {
}
let filtered_message = Message {
id: response.id.clone(),
role: response.role.clone(),
created: response.created,
content: filtered_content,

View File

@@ -247,14 +247,14 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: "Summarized content".to_string(),
annotations: None,
})],
},
),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
@@ -277,30 +277,26 @@ mod tests {
}
fn set_up_text_message(text: &str, role: Role) -> Message {
Message {
role,
created: 0,
content: vec![MessageContent::text(text.to_string())],
}
Message::new(role, 0, vec![MessageContent::text(text.to_string())])
}
fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
Message {
role: Role::Assistant,
created: 0,
content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
}
Message::new(
Role::Assistant,
0,
vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
)
}
fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::tool_response(
Message::new(
Role::User,
0,
vec![MessageContent::tool_response(
id.to_string(),
Ok(tool_response),
)],
}
)
}
#[tokio::test]
@@ -448,14 +444,14 @@ mod tests {
#[tokio::test]
async fn test_reintegrate_removed_messages() {
let summarized_messages = vec![Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
let summarized_messages = vec![Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: "Summary".to_string(),
annotations: None,
})],
}];
)];
let arguments = json!({
"param1": "value1"
});

View File

@@ -303,15 +303,46 @@ impl From<PromptMessage> for Message {
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub id: Option<String>,
pub role: Role,
pub created: i64,
pub content: Vec<MessageContent>,
}
pub fn push_message(messages: &mut Vec<Message>, message: Message) {
if let Some(last) = messages
.last_mut()
.filter(|m| m.id.is_some() && m.id == message.id)
{
match (last.content.last_mut(), message.content.last()) {
(Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new)))
if message.content.len() == 1 =>
{
last.text.push_str(&new.text);
}
(_, _) => {
last.content.extend(message.content);
}
}
} else {
messages.push(message);
}
}
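
// Sketch of intended usage, with hypothetical ids: streamed deltas that share an
// id coalesce in place, so two text chunks merge into one message.
fn push_message_sketch() {
    let mut messages: Vec<Message> = Vec::new();
    let mut first = Message::assistant().with_text("Hel");
    first.id = Some("msg_1".to_string()); // hypothetical stream id
    let mut second = Message::assistant().with_text("lo");
    second.id = Some("msg_1".to_string());
    push_message(&mut messages, first);
    push_message(&mut messages, second);
    assert_eq!(messages.len(), 1); // a single merged assistant message: "Hello"
}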
impl Message {
pub fn new(role: Role, created: i64, content: Vec<MessageContent>) -> Self {
Message {
id: None,
role,
created,
content,
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Message {
id: None,
role: Role::User,
created: Utc::now().timestamp(),
content: Vec::new(),
@@ -321,6 +352,7 @@ impl Message {
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Message {
id: None,
role: Role::Assistant,
created: Utc::now().timestamp(),
content: Vec::new(),

View File

@@ -81,10 +81,10 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
})
.collect();
let mut check_messages = vec![];
check_messages.push(Message {
role: mcp_core::Role::User,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
check_messages.push(Message::new(
mcp_core::Role::User,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: format!(
"Here are the tool requests: {:?}\n\nAnalyze the tool requests and list the tools that perform read-only operations. \
\n\nGuidelines for Read-Only Operations: \
@@ -96,7 +96,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
),
annotations: None,
})],
});
));
check_messages
}
@@ -296,10 +296,10 @@ mod tests {
_tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::ToolRequest(ToolRequest {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::ToolRequest(ToolRequest {
id: "mock_tool_request".to_string(),
tool_call: ToolResult::Ok(ToolCall {
name: "platform__tool_by_tool_permission".to_string(),
@@ -308,7 +308,7 @@ mod tests {
}),
}),
})],
},
),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
@@ -354,10 +354,10 @@ mod tests {
#[test]
fn test_extract_read_only_tools() {
let message = Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::ToolRequest(ToolRequest {
let message = Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::ToolRequest(ToolRequest {
id: "tool_2".to_string(),
tool_call: ToolResult::Ok(ToolCall {
name: "platform__tool_by_tool_permission".to_string(),
@@ -366,7 +366,7 @@ mod tests {
}),
}),
})],
};
);
let result = extract_read_only_tools(&message);
assert!(result.is_some());

View File

@@ -1,4 +1,5 @@
use anyhow::Result;
use futures::Stream;
use serde::{Deserialize, Serialize};
use super::errors::ProviderError;
@@ -8,6 +9,8 @@ use mcp_core::tool::Tool;
use utoipa::ToSchema;
use once_cell::sync::Lazy;
use std::ops::{Add, AddAssign};
use std::pin::Pin;
use std::sync::Mutex;
/// A global store for the currently active model; we need it because when a provider returns, it reports the real model in use rather than an alias
@@ -184,13 +187,43 @@ impl ProviderUsage {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
pub struct Usage {
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
pub total_tokens: Option<i32>,
}
fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
where
T: Add<Output = T> + Default,
{
match (a, b) {
(Some(x), Some(y)) => Some(x + y),
(Some(x), None) => Some(x + T::default()),
(None, Some(y)) => Some(T::default() + y),
(None, None) => None,
}
}
impl Add for Usage {
type Output = Self;
fn add(self, other: Self) -> Self {
Self {
input_tokens: sum_optionals(self.input_tokens, other.input_tokens),
output_tokens: sum_optionals(self.output_tokens, other.output_tokens),
total_tokens: sum_optionals(self.total_tokens, other.total_tokens),
}
}
}
impl AddAssign for Usage {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
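
// Sketch of intended usage: Add/AddAssign let streaming consumers fold per-chunk
// usage into a running total. When one side is present, a missing counter counts
// as zero; only two Nones stay None.
fn usage_sum_sketch() {
    let mut total = Usage::default();
    total += Usage::new(Some(10), Some(5), Some(15));
    total += Usage::new(Some(3), None, Some(3));
    assert_eq!(total.input_tokens, Some(13));
    assert_eq!(total.output_tokens, Some(5)); // None treated as zero here
    assert_eq!(total.total_tokens, Some(18));
}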
impl Usage {
pub fn new(
input_tokens: Option<i32>,
@@ -270,6 +303,21 @@ pub trait Provider: Send + Sync {
None
}
async fn stream(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
Err(ProviderError::NotImplemented(
"streaming not implemented".to_string(),
))
}
fn supports_streaming(&self) -> bool {
false
}
/// Get the currently active model name
/// For regular providers, this returns the configured model
/// For LeadWorkerProvider, this returns the currently active model (lead or worker)
@@ -282,6 +330,18 @@ pub trait Provider: Send + Sync {
}
}
/// A message stream yields partial text content but complete tool calls, all within the Message object.
/// A message with text may therefore carry just a fragment of a longer response, while tool call
/// messages are only yielded once their arguments have been fully concatenated.
pub type MessageStream = Pin<
Box<dyn Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> + Send>,
>;
pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) });
Box::pin(stream)
}
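
// Sketch, assuming a provider handle: drain a MessageStream while folding usage,
// falling back to the single-message wrapper for non-streaming providers.
async fn drain_stream_sketch(
    provider: std::sync::Arc<dyn Provider>,
) -> Result<Usage, ProviderError> {
    use futures::StreamExt;
    let mut stream = if provider.supports_streaming() {
        provider.stream("You are helpful.", &[], &[]).await?
    } else {
        let (message, usage) = provider.complete("You are helpful.", &[], &[]).await?;
        stream_from_single_message(message, usage)
    };
    let mut total = Usage::default();
    while let Some((message, usage)) = stream.next().await.transpose()? {
        if let Some(provider_usage) = usage {
            total += provider_usage.usage; // chunks may carry partial usage
        }
        let _ = message; // partial text, or a complete tool call
    }
    Ok(total)
}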
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -219,11 +219,11 @@ impl ClaudeCodeProvider {
annotations: None,
})];
let response_message = Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: message_content,
};
let response_message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
message_content,
);
Ok((response_message, usage))
}
@@ -353,14 +353,14 @@ impl ClaudeCodeProvider {
println!("================================");
}
let message = Message {
role: mcp_core::Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::Text(mcp_core::content::TextContent {
let message = Message::new(
mcp_core::Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::Text(mcp_core::content::TextContent {
text: description.clone(),
annotations: None,
})],
};
);
let usage = Usage::default();

View File

@@ -1,4 +1,16 @@
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use anyhow::Result;
use async_stream::try_stream;
use async_trait::async_trait;
use futures::TryStreamExt;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io;
use std::time::Duration;
use tokio::pin;
use tokio_util::io::StreamReader;
use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::formats::databricks::{create_request, get_usage, response_to_message};
@@ -7,17 +19,13 @@ use super::utils::{get_model, ImageFormat};
use crate::config::ConfigError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::formats::databricks::response_to_streaming_message;
use mcp_core::tool::Tool;
use serde_json::json;
use url::Url;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;
use tokio_stream::StreamExt;
use tokio_util::codec::{FramedRead, LinesCodec};
use url::Url;
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
@@ -266,9 +274,6 @@ impl DatabricksProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
// Check if this is an embedding request by looking at the payload structure
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
let path = if is_embedding {
@@ -279,56 +284,71 @@ impl DatabricksProvider {
format!("serving-endpoints/{}/invocations", self.model.model_name)
};
let url = base_url.join(&path).map_err(|e| {
match self.post_with_retry(path.as_str(), &payload).await {
Ok(res) => res.json().await.map_err(|_| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
}),
Err(e) => Err(e),
}
}
async fn post_with_retry(
&self,
path: &str,
payload: &Value,
) -> Result<reqwest::Response, ProviderError> {
let base_url = Url::parse(&self.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join(path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
// Initialize retry counter
let mut attempts = 0;
let mut last_error = None;
loop {
// Check if we've exceeded max retries
if attempts > 0 && attempts > self.retry_config.max_retries {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
self.retry_config.max_retries
);
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url.clone())
.header("Authorization", auth_header)
.json(&payload)
.json(payload)
.send()
.await?;
let status = response.status();
let payload: Option<Value> = response.json().await.ok();
match status {
StatusCode::OK => {
return payload.ok_or_else(|| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
});
break match status {
StatusCode::OK => Ok(response),
StatusCode::TOO_MANY_REQUESTS
| StatusCode::INTERNAL_SERVER_ERROR
| StatusCode::SERVICE_UNAVAILABLE => {
if attempts < self.retry_config.max_retries {
attempts += 1;
tracing::warn!(
"{}: retrying ({}/{})",
status,
attempts,
self.retry_config.max_retries
);
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
continue;
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
return Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)));
Err(match status {
StatusCode::TOO_MANY_REQUESTS => {
ProviderError::RateLimitExceeded("Rate limit exceeded".to_string())
}
_ => ProviderError::ServerError("Server error".to_string()),
})
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message', which is provider-specific.
// We try to extract the error message from the payload and check for phrases that indicate the context length was exceeded.
let payload_str = serde_json::to_string(&payload)
.unwrap_or_default()
.to_lowercase();
let bytes = response.bytes().await?;
let payload_str = String::from_utf8_lossy(&bytes).to_lowercase();
let check_phrases = [
"too long",
"context length",
@@ -347,13 +367,13 @@ impl DatabricksProvider {
}
let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
if let Ok(response_json) = serde_json::from_slice::<Value>(&bytes) {
// try to convert message to string, if that fails use external_model_message
error_msg = payload
error_msg = response_json
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload
response_json
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
@@ -366,7 +386,7 @@ impl DatabricksProvider {
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
status, payload_str
)
);
return Err(ProviderError::RequestFailed(format!(
@@ -374,50 +394,13 @@ impl DatabricksProvider {
status, error_msg
)));
}
StatusCode::TOO_MANY_REQUESTS => {
attempts += 1;
let error_msg = format!(
"Rate limit exceeded (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);
// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::RateLimitExceeded(error_msg));
// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
// Continue to the next retry attempt
continue;
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
attempts += 1;
let error_msg = format!(
"Server error (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);
// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::ServerError(error_msg));
// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
// Continue to the next retry attempt
continue;
}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
status,
response.text().await.ok().unwrap_or_default()
)
);
return Err(ProviderError::RequestFailed(format!(
@@ -425,7 +408,7 @@ impl DatabricksProvider {
status
)));
}
}
};
}
}
}
@@ -472,13 +455,12 @@ impl Provider for DatabricksProvider {
// Parse response
let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage data: {}", e);
let usage = match response.get("usage").map(get_usage) {
Some(usage) => usage,
None => {
tracing::debug!("Failed to get usage data");
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
@@ -486,6 +468,54 @@ impl Provider for DatabricksProvider {
Ok((message, ProviderUsage::new(model, usage)))
}
async fn stream(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?;
// Remove the model key; with Databricks the model is part of the URL
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");
payload
.as_object_mut()
.unwrap()
.insert("stream".to_string(), Value::Bool(true));
let response = self
.post_with_retry(
format!("serving-endpoints/{}/invocations", self.model.model_name).as_str(),
&payload,
)
.await?;
// Map reqwest error to io::Error
let stream = response.bytes_stream().map_err(io::Error::other);
let model_config = self.model.clone();
// Wrap in a line decoder and yield lines inside the stream
Ok(Box::pin(try_stream! {
let stream_reader = StreamReader::new(stream);
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
let message_stream = response_to_streaming_message(framed);
pin!(message_stream);
while let Some(message) = message_stream.next().await {
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
yield (message, usage);
}
}))
}
fn supports_streaming(&self) -> bool {
true
}
fn supports_embeddings(&self) -> bool {
true
}

View File

@@ -23,6 +23,9 @@ pub enum ProviderError {
#[error("Usage data error: {0}")]
UsageError(String),
#[error("Unsupported operation: {0}")]
NotImplemented(String),
}
impl From<anyhow::Error> for ProviderError {

View File

@@ -212,17 +212,17 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: format!(
"Response from {} with model {}",
self.name, self.model_config.model_name
),
annotations: None,
})],
},
),
ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
))
}

View File

@@ -260,11 +260,7 @@ pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
.collect::<Result<Vec<_>>>()?;
let created = Utc::now().timestamp();
Ok(Message {
role,
content,
created,
})
Ok(Message::new(role, created, content))
}
pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {

View File

@@ -1,14 +1,16 @@
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use crate::providers::errors::ProviderError;
use crate::providers::base::{ProviderUsage, Usage};
use crate::providers::utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
};
use anyhow::{anyhow, Error};
use async_stream::try_stream;
use futures::Stream;
use mcp_core::ToolError;
use mcp_core::{Content, Role, Tool, ToolCall};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
/// Convert internal Message format to Databricks' API message specification
@@ -358,18 +360,162 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
}
}
Ok(Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
Ok(Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
content,
})
))
}
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
let usage = data
.get("usage")
.ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?;
#[derive(Serialize, Deserialize, Debug)]
struct DeltaToolCallFunction {
name: Option<String>,
arguments: String, // a chunk of JSON-encoded arguments
}
#[derive(Serialize, Deserialize, Debug)]
struct DeltaToolCall {
id: Option<String>,
function: DeltaToolCallFunction,
index: Option<i32>,
r#type: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
struct Delta {
content: Option<String>,
role: Option<String>,
tool_calls: Option<Vec<DeltaToolCall>>,
}
#[derive(Serialize, Deserialize, Debug)]
struct StreamingChoice {
delta: Delta,
index: Option<i32>,
finish_reason: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
struct StreamingChunk {
choices: Vec<StreamingChoice>,
created: Option<i64>,
id: Option<String>,
usage: Option<Value>,
model: String,
}
fn strip_data_prefix(line: &str) -> Option<&str> {
line.strip_prefix("data: ").map(|s| s.trim())
}
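
// Sketch: SSE frames arrive as `data: {json}` lines, and a literal `data: [DONE]`
// sentinel terminates the stream; non-data lines (e.g. SSE comments) are skipped.
fn strip_data_prefix_sketch() {
    assert_eq!(
        strip_data_prefix(r#"data: {"choices":[]}"#),
        Some(r#"{"choices":[]}"#)
    );
    assert_eq!(strip_data_prefix(": keep-alive"), None); // SSE comment, ignored
}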
pub fn response_to_streaming_message<S>(
mut stream: S,
) -> impl Stream<Item = anyhow::Result<(Option<Message>, Option<ProviderUsage>)>> + 'static
where
S: Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
{
try_stream! {
use futures::StreamExt;
'outer: while let Some(response) = stream.next().await {
if response.as_ref().is_ok_and(|s| s == "data: [DONE]") {
break 'outer;
}
let response_str = response?;
let line = strip_data_prefix(&response_str);
if line.is_none() || line.is_some_and(|l| l.is_empty()) {
continue
}
let chunk: StreamingChunk = serde_json::from_str(line
.ok_or_else(|| anyhow!("unexpected stream format"))?)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
let model = chunk.model.clone();
let usage = chunk.usage.as_ref().map(|u| {
ProviderUsage {
usage: get_usage(u),
model,
}
});
if chunk.choices.is_empty() {
yield (None, usage)
} else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls {
let tool_call = &tool_calls[0];
let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?;
let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?;
let mut arguments = tool_call.function.arguments.clone();
while let Some(response_chunk) = stream.next().await {
if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") {
break 'outer;
}
let response_str = response_chunk?;
if let Some(line) = strip_data_prefix(&response_str) {
let tool_chunk: StreamingChunk = serde_json::from_str(line)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref()
.and_then(|calls| calls.first())
.map(|call| call.function.arguments.as_str());
if let Some(more_args) = more_args {
arguments.push_str(more_args);
} else {
break;
}
}
}
let parsed = if arguments.is_empty() {
Ok(json!({}))
} else {
serde_json::from_str::<Value>(&arguments)
};
let content = match parsed {
Ok(params) => MessageContent::tool_request(
id,
Ok(ToolCall::new(function_name, params)),
),
Err(e) => {
let error = ToolError::InvalidParameters(format!(
"Could not interpret tool use parameters for id {}: {}",
id, e
));
MessageContent::tool_request(id, Err(error))
}
};
yield (
Some(Message {
id: chunk.id,
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![content],
}),
usage,
)
} else if let Some(text) = &chunk.choices[0].delta.content {
yield (
Some(Message {
id: chunk.id,
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::text(text)],
}),
if chunk.choices[0].finish_reason.is_some() {
usage
} else {
None
},
)
}
}
}
}
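
// Sketch with hypothetical frames: decoding two SSE lines end to end. Each yielded
// item is an (Option<Message>, Option<ProviderUsage>) pair; text deltas stream out
// as partial assistant messages, and usage stays None until a chunk reports it.
async fn decode_sketch() -> anyhow::Result<()> {
    use futures::StreamExt;
    let frames: Vec<anyhow::Result<String>> = vec![
        Ok(r#"data: {"id":"m1","model":"demo","choices":[{"delta":{"content":"Hi"}}]}"#.to_string()),
        Ok("data: [DONE]".to_string()),
    ];
    let decoded = response_to_streaming_message(futures::stream::iter(frames));
    futures::pin_mut!(decoded);
    while let Some((message, usage)) = decoded.next().await.transpose()? {
        let _ = (message, usage); // first iteration: partial text "Hi", no usage yet
    }
    Ok(())
}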
pub fn get_usage(usage: &Value) -> Usage {
let input_tokens = usage
.get("prompt_tokens")
.and_then(|v| v.as_i64())
@@ -389,7 +535,7 @@ pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
_ => None,
});
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
Usage::new(input_tokens, output_tokens, total_tokens)
}
/// Validates and fixes tool schemas to ensure they have proper parameter structure.

View File

@@ -209,11 +209,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
let role = Role::Assistant;
let created = chrono::Utc::now().timestamp();
if candidate.is_none() {
return Ok(Message {
role,
created,
content,
});
return Ok(Message::new(role, created, content));
}
let candidate = candidate.unwrap();
let parts = candidate
@@ -252,11 +248,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
}
}
}
Ok(Message {
role,
created,
content,
})
Ok(Message::new(role, created, content))
}
/// Extract usage information from Google's API response
@@ -324,43 +316,39 @@ mod tests {
use serde_json::json;
fn set_up_text_message(text: &str, role: Role) -> Message {
Message {
role,
created: 0,
content: vec![MessageContent::text(text.to_string())],
}
Message::new(role, 0, vec![MessageContent::text(text.to_string())])
}
fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
}
Message::new(
Role::User,
0,
vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
)
}
fn set_up_tool_confirmation_message(id: &str, tool_call: ToolCall) -> Message {
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::tool_confirmation_request(
Message::new(
Role::User,
0,
vec![MessageContent::tool_confirmation_request(
id.to_string(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
)],
}
)
}
fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
Message {
role: Role::Assistant,
created: 0,
content: vec![MessageContent::tool_response(
Message::new(
Role::Assistant,
0,
vec![MessageContent::tool_response(
id.to_string(),
Ok(tool_response),
)],
}
)
}
fn set_up_tool(name: &str, description: &str, params: Value) -> Tool {

View File

@@ -274,11 +274,11 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
}
}
Ok(Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
Ok(Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
content,
})
))
}
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {

View File

@@ -169,14 +169,14 @@ impl GeminiCliProvider {
));
}
let message = Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
let message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: response_text,
annotations: None,
})],
};
);
let usage = Usage::default(); // No usage info available for gemini CLI
@@ -214,14 +214,14 @@ impl GeminiCliProvider {
println!("================================");
}
let message = Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
let message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: description.clone(),
annotations: None,
})],
};
);
let usage = Usage::default();

View File

@@ -480,14 +480,14 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
),
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}
@@ -643,14 +643,14 @@ mod tests {
))
} else {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
),
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}

View File

@@ -203,14 +203,14 @@ impl SageMakerTgiProvider {
// Strip any HTML tags that might have been generated
let clean_text = self.strip_html_tags(generated_text);
Ok(Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Ok(Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: clean_text,
annotations: None,
})],
})
))
}
/// Strip HTML tags from text to ensure clean output

View File

@@ -359,11 +359,7 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec<Message> {
}
if has_tool_content {
Message {
role: message.role.clone(),
content: new_content,
created: message.created,
}
Message::new(message.role.clone(), message.created, new_content)
} else {
message.clone()
}

View File

@@ -319,12 +319,15 @@ pub fn unescape_json_values(value: &Value) -> Value {
}
}
pub fn emit_debug_trace(
pub fn emit_debug_trace<T1, T2>(
model_config: &ModelConfig,
payload: &Value,
response: &Value,
payload: &T1,
response: &T2,
usage: &Usage,
) {
) where
T1: ?Sized + Serialize,
T2: ?Sized + Serialize,
{
tracing::debug!(
model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(),
input = %serde_json::to_string_pretty(payload).unwrap_or_default(),

View File

@@ -557,11 +557,7 @@ impl Provider for VeniceProvider {
};
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content,
},
Message::new(Role::Assistant, Utc::now().timestamp(), content),
ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
))
}

View File

@@ -1370,14 +1370,14 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
Message::new(
Role::Assistant,
Utc::now().timestamp(),
vec![MessageContent::Text(TextContent {
text: "Mocked scheduled response".to_string(),
annotations: None,
})],
},
),
ProviderUsage::new("mock-scheduler-test".to_string(), Usage::default()),
))
}

View File

@@ -407,25 +407,25 @@ mod tests {
"You are a helpful assistant that can answer questions about the weather.";
let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
Message::new(
Role::User,
0,
vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
),
Message::new(
Role::Assistant,
1,
vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
),
Message::new(
Role::User,
2,
vec![MessageContent::text("How about New York?")],
),
];
let tools = vec![Tool {
@@ -505,25 +505,25 @@ mod tests {
"You are a helpful assistant that can answer questions about the weather.";
let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
Message::new(
Role::User,
0,
vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
),
Message::new(
Role::Assistant,
1,
vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
),
Message::new(
Role::User,
2,
vec![MessageContent::text("How about New York?")],
),
];
let tools = vec![Tool {

View File

@@ -786,7 +786,7 @@ function ChatContent({
<SearchView>
{filteredMessages.map((message, index) => (
<div
key={message.id || index}
key={(message.id && `${message.id}-${message.content.length}`) || index}
className="mt-4 px-4"
data-testid="message-container"
>

View File

@@ -130,7 +130,7 @@ export default function GooseMessage({
]);
return (
<div className="goose-message flex w-[90%] justify-start opacity-0 animate-[appear_150ms_ease-in_forwards]">
<div className="goose-message flex w-[90%] justify-start">
<div className="flex flex-col w-full">
{/* Chain-of-Thought (hidden by default) */}
{cotText && (

View File

@@ -1,4 +1,4 @@
import { useState, useCallback, useEffect, useRef, useId } from 'react';
import { useState, useCallback, useEffect, useRef, useId, useReducer } from 'react';
import useSWR from 'swr';
import { getSecretKey } from '../config';
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
@@ -235,6 +235,9 @@ export function useMessageStream({
};
}, [headers, body]);
// TODO: not this?
const [, forceUpdate] = useReducer((x) => x + 1, 0);
// Process the SSE stream from the server
const processMessageStream = useCallback(
async (response: Response, currentMessages: Message[]) => {
@@ -284,8 +287,23 @@ export function useMessageStream({
: parsedEvent.message.sendToLLM,
};
console.log('New message:', JSON.stringify(newMessage, null, 2));
// Update messages with the new message
if (
newMessage.id &&
currentMessages.length > 0 &&
currentMessages[currentMessages.length - 1].id === newMessage.id
) {
// If the last message has the same ID, update it instead of adding a new one
const lastMessage = currentMessages[currentMessages.length - 1];
lastMessage.content = [...lastMessage.content, ...newMessage.content];
forceUpdate();
} else {
currentMessages = [...currentMessages, newMessage];
}
mutate(currentMessages, false);
break;
}
@@ -373,7 +391,7 @@ export function useMessageStream({
return currentMessages;
},
[mutate, onFinish, onError]
[mutate, onFinish, onError, forceUpdate]
);
// Send a request to the server

View File

@@ -201,7 +201,7 @@ export function getTextContent(message: Message): string {
}
return '';
})
.join('\n');
.join('');
}
export function getToolRequests(message: Message): ToolRequestMessageContent[] {