mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat: stream LLM responses (#2677)
Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -3486,6 +3486,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-cron-scheduler",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
@@ -8604,9 +8605,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.13"
|
||||
version = "0.7.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
|
||||
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
|
||||
@@ -15,4 +15,4 @@ uninlined_format_args = "allow"
|
||||
|
||||
# Patch for Windows cross-compilation issue with crunchy
|
||||
[patch.crates-io]
|
||||
crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" }
|
||||
crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" }
|
||||
|
||||
@@ -10,6 +10,7 @@ pub use self::export::message_to_markdown;
|
||||
pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
|
||||
use console::Color;
|
||||
use goose::agents::AgentEvent;
|
||||
use goose::message::push_message;
|
||||
use goose::permission::permission_confirmation::PrincipalType;
|
||||
use goose::permission::Permission;
|
||||
use goose::permission::PermissionConfirmation;
|
||||
@@ -356,7 +357,7 @@ impl Session {
|
||||
|
||||
/// Process a single message and get the response
|
||||
async fn process_message(&mut self, message: String) -> Result<()> {
|
||||
self.messages.push(Message::user().with_text(&message));
|
||||
self.push_message(Message::user().with_text(&message));
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider().await?;
|
||||
|
||||
@@ -462,7 +463,7 @@ impl Session {
|
||||
RunMode::Normal => {
|
||||
save_history(&mut editor);
|
||||
|
||||
self.messages.push(Message::user().with_text(&content));
|
||||
self.push_message(Message::user().with_text(&content));
|
||||
|
||||
// Track the current directory and last instruction in projects.json
|
||||
let session_id = self
|
||||
@@ -785,7 +786,7 @@ impl Session {
|
||||
self.messages.clear();
|
||||
// add the plan response as a user message
|
||||
let plan_message = Message::user().with_text(plan_response.as_concat_text());
|
||||
self.messages.push(plan_message);
|
||||
self.push_message(plan_message);
|
||||
// act on the plan
|
||||
output::show_thinking();
|
||||
self.process_agent_response(true).await?;
|
||||
@@ -800,13 +801,13 @@ impl Session {
|
||||
} else {
|
||||
// add the plan response (assistant message) & carry the conversation forward
|
||||
// in the next round, the user might wanna slightly modify the plan
|
||||
self.messages.push(plan_response);
|
||||
self.push_message(plan_response);
|
||||
}
|
||||
}
|
||||
PlannerResponseType::ClarifyingQuestions => {
|
||||
// add the plan response (assistant message) & carry the conversation forward
|
||||
// in the next round, the user will answer the clarifying questions
|
||||
self.messages.push(plan_response);
|
||||
self.push_message(plan_response);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -878,7 +879,7 @@ impl Session {
|
||||
confirmation.id.clone(),
|
||||
Err(ToolError::ExecutionError("Tool call cancelled by user".to_string()))
|
||||
));
|
||||
self.messages.push(response_message);
|
||||
push_message(&mut self.messages, response_message);
|
||||
if let Some(session_file) = &self.session_file {
|
||||
session::persist_messages_with_schedule_id(
|
||||
session_file,
|
||||
@@ -975,7 +976,7 @@ impl Session {
|
||||
}
|
||||
// otherwise we have a model/tool to render
|
||||
else {
|
||||
self.messages.push(message.clone());
|
||||
push_message(&mut self.messages, message.clone());
|
||||
|
||||
// No need to update description on assistant messages
|
||||
if let Some(session_file) = &self.session_file {
|
||||
@@ -991,7 +992,6 @@ impl Session {
|
||||
if interactive {output::hide_thinking()};
|
||||
let _ = progress_bars.hide();
|
||||
output::render_message(&message, self.debug);
|
||||
if interactive {output::show_thinking()};
|
||||
}
|
||||
}
|
||||
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
|
||||
@@ -1139,6 +1139,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1182,7 +1183,7 @@ impl Session {
|
||||
Err(ToolError::ExecutionError(notification.clone())),
|
||||
));
|
||||
}
|
||||
self.messages.push(response_message);
|
||||
self.push_message(response_message);
|
||||
|
||||
// No need for description update here
|
||||
if let Some(session_file) = &self.session_file {
|
||||
@@ -1199,7 +1200,7 @@ impl Session {
|
||||
"The existing call to {} was interrupted. How would you like to proceed?",
|
||||
last_tool_name
|
||||
);
|
||||
self.messages.push(Message::assistant().with_text(&prompt));
|
||||
self.push_message(Message::assistant().with_text(&prompt));
|
||||
|
||||
// No need for description update here
|
||||
if let Some(session_file) = &self.session_file {
|
||||
@@ -1221,7 +1222,7 @@ impl Session {
|
||||
Some(MessageContent::ToolResponse(_)) => {
|
||||
// Interruption occurred after a tool had completed but not assistant reply
|
||||
let prompt = "The tool calling loop was interrupted. How would you like to proceed?";
|
||||
self.messages.push(Message::assistant().with_text(prompt));
|
||||
self.push_message(Message::assistant().with_text(prompt));
|
||||
|
||||
// No need for description update here
|
||||
if let Some(session_file) = &self.session_file {
|
||||
@@ -1438,7 +1439,7 @@ impl Session {
|
||||
if msg.role == mcp_core::Role::User {
|
||||
output::render_message(&msg, self.debug);
|
||||
}
|
||||
self.messages.push(msg);
|
||||
self.push_message(msg);
|
||||
}
|
||||
|
||||
if valid {
|
||||
@@ -1496,6 +1497,10 @@ impl Session {
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
fn push_message(&mut self, message: Message) {
|
||||
push_message(&mut self.messages, message);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
|
||||
@@ -10,7 +10,7 @@ use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Error;
|
||||
use std::io::{Error, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -166,7 +166,8 @@ pub fn render_message(message: &Message, debug: bool) {
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
|
||||
let _ = std::io::stdout().flush();
|
||||
}
|
||||
|
||||
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
||||
|
||||
@@ -225,6 +225,7 @@ async fn handler(
|
||||
return;
|
||||
}
|
||||
};
|
||||
let saved_message_count = all_messages.len();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -242,16 +243,6 @@ async fn handler(
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
let session_path = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
|
||||
@@ -303,6 +294,17 @@ async fn handler(
|
||||
}
|
||||
}
|
||||
|
||||
if all_messages.len() > saved_message_count {
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
session::persist_messages(&session_path, &all_messages, Some(provider)).await
|
||||
{
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "stop".to_string(),
|
||||
|
||||
@@ -81,6 +81,7 @@ fs2 = "0.4.3"
|
||||
tokio-stream = "0.1.17"
|
||||
dashmap = "6.1"
|
||||
ahash = "0.8"
|
||||
tokio-util = "0.7.15"
|
||||
|
||||
# Vector database for tool selection
|
||||
lancedb = "0.13"
|
||||
|
||||
@@ -2,8 +2,12 @@ use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose::{
|
||||
message::Message,
|
||||
providers::{base::Provider, databricks::DatabricksProvider},
|
||||
providers::{
|
||||
base::{Provider, Usage},
|
||||
databricks::DatabricksProvider,
|
||||
},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -20,21 +24,24 @@ async fn main() -> Result<()> {
|
||||
let message = Message::user().with_text("Tell me a short joke about programming.");
|
||||
|
||||
// Get a response
|
||||
let (response, usage) = provider
|
||||
.complete("You are a helpful assistant.", &[message], &[])
|
||||
let mut stream = provider
|
||||
.stream("You are a helpful assistant.", &[message], &[])
|
||||
.await?;
|
||||
|
||||
// Print the response and usage statistics
|
||||
println!("\nResponse from AI:");
|
||||
println!("---------------");
|
||||
for content in response.content {
|
||||
dbg!(content);
|
||||
let mut usage = Usage::default();
|
||||
while let Some(Ok((msg, usage_part))) = stream.next().await {
|
||||
dbg!(msg);
|
||||
usage_part.map(|u| {
|
||||
usage += u.usage;
|
||||
});
|
||||
}
|
||||
println!("\nToken Usage:");
|
||||
println!("------------");
|
||||
println!("Input tokens: {:?}", usage.usage.input_tokens);
|
||||
println!("Output tokens: {:?}", usage.usage.output_tokens);
|
||||
println!("Total tokens: {:?}", usage.usage.total_tokens);
|
||||
println!("Input tokens: {:?}", usage.input_tokens);
|
||||
println!("Output tokens: {:?}", usage.output_tokens);
|
||||
println!("Total tokens: {:?}", usage.total_tokens);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
|
||||
};
|
||||
use crate::agents::sub_recipe_manager::SubRecipeManager;
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::Message;
|
||||
use crate::message::{push_message, Message};
|
||||
use crate::permission::permission_judge::check_tool_permissions;
|
||||
use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
@@ -722,6 +722,16 @@ impl Agent {
|
||||
});
|
||||
|
||||
loop {
|
||||
// Check for final output before incrementing turns or checking max_turns
|
||||
// This ensures that if we have a final output ready, we return it immediately
|
||||
// without being blocked by the max_turns limit - this is needed for streaming cases
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_some() {
|
||||
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
turns_taken += 1;
|
||||
if turns_taken > max_turns {
|
||||
yield AgentEvent::Message(Message::assistant().with_text(
|
||||
@@ -752,262 +762,291 @@ impl Agent {
|
||||
}
|
||||
}
|
||||
|
||||
match Self::generate_response_from_provider(
|
||||
let mut stream = Self::stream_response_from_provider(
|
||||
self.provider().await?,
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
&toolshim_tools,
|
||||
).await {
|
||||
Ok((response, usage)) => {
|
||||
// Emit model change event if provider is lead-worker
|
||||
let provider = self.provider().await?;
|
||||
if let Some(lead_worker) = provider.as_lead_worker() {
|
||||
// The actual model used is in the usage
|
||||
let active_model = usage.model.clone();
|
||||
let (lead_model, worker_model) = lead_worker.get_model_info();
|
||||
let mode = if active_model == lead_model {
|
||||
"lead"
|
||||
} else if active_model == worker_model {
|
||||
"worker"
|
||||
} else {
|
||||
"unknown"
|
||||
};
|
||||
).await?;
|
||||
|
||||
yield AgentEvent::ModelChange {
|
||||
model: active_model,
|
||||
mode: mode.to_string(),
|
||||
};
|
||||
}
|
||||
let mut added_message = false;
|
||||
while let Some(next) = stream.next().await {
|
||||
match next {
|
||||
Ok((response, usage)) => {
|
||||
// Emit model change event if provider is lead-worker
|
||||
let provider = self.provider().await?;
|
||||
if let Some(lead_worker) = provider.as_lead_worker() {
|
||||
if let Some(ref usage) = usage {
|
||||
// The actual model used is in the usage
|
||||
let active_model = usage.model.clone();
|
||||
let (lead_model, worker_model) = lead_worker.get_model_info();
|
||||
let mode = if active_model == lead_model {
|
||||
"lead"
|
||||
} else if active_model == worker_model {
|
||||
"worker"
|
||||
} else {
|
||||
"unknown"
|
||||
};
|
||||
|
||||
// record usage for the session in the session file
|
||||
if let Some(session_config) = session.clone() {
|
||||
Self::update_session_metrics(session_config, &usage, messages.len()).await?;
|
||||
}
|
||||
|
||||
// categorize the type of requests we need to handle
|
||||
let (frontend_requests,
|
||||
remaining_requests,
|
||||
filtered_response) =
|
||||
self.categorize_tool_requests(&response).await;
|
||||
|
||||
// Record tool calls in the router selector
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
// Record frontend tool calls
|
||||
for request in &frontend_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record frontend tool call: {}", e);
|
||||
}
|
||||
yield AgentEvent::ModelChange {
|
||||
model: active_model,
|
||||
mode: mode.to_string(),
|
||||
};
|
||||
}
|
||||
}
|
||||
// Record remaining tool calls
|
||||
for request in &remaining_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record tool call: {}", e);
|
||||
}
|
||||
|
||||
// record usage for the session in the session file
|
||||
if let Some(session_config) = session.clone() {
|
||||
if let Some(ref usage) = usage {
|
||||
Self::update_session_metrics(session_config, usage, messages.len()).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
if let Some(response) = response {
|
||||
// categorize the type of requests we need to handle
|
||||
let (frontend_requests,
|
||||
remaining_requests,
|
||||
filtered_response) =
|
||||
self.categorize_tool_requests(&response).await;
|
||||
|
||||
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||
if num_tool_requests == 0 {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
// Record tool calls in the router selector
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
// Record frontend tool calls
|
||||
for request in &frontend_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record frontend tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Record remaining tool calls
|
||||
for request in &remaining_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||
if num_tool_requests == 0 {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
continue;
|
||||
} else {
|
||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
// Set added_message to true and continue to end the current iteration
|
||||
added_message = true;
|
||||
push_message(&mut messages, response);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// If there's no final output tool and no tool requests, continue the loop
|
||||
continue;
|
||||
} else {
|
||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
}
|
||||
|
||||
// Process tool requests depending on frontend tools and then goose_mode
|
||||
let message_tool_response = Arc::new(Mutex::new(Message::user()));
|
||||
|
||||
// First handle any frontend tool requests
|
||||
let mut frontend_tool_stream = self.handle_frontend_tool_requests(
|
||||
&frontend_requests,
|
||||
message_tool_response.clone()
|
||||
);
|
||||
|
||||
// we have a stream of frontend tools to handle, inside the stream
|
||||
// execution is yielded back to this reply loop, and is of the same Message
|
||||
// type, so we can yield that back up to be handled
|
||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
if mode.as_str() == "chat" {
|
||||
// Skip all tool calls in chat mode
|
||||
for request in remaining_requests {
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
|
||||
// What remains is handling the remaining tool requests (enable extension,
|
||||
// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
|
||||
&remaining_requests,
|
||||
&mode,
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
&mut permission_manager,
|
||||
self.provider().await?).await;
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
|
||||
|
||||
// Skip the confirmation for approved tools
|
||||
for request in &permission_check_result.approved {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
|
||||
|
||||
tool_futures.push((req_id, match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
for request in &permission_check_result.denied {
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(DECLINED_RESPONSE)]),
|
||||
);
|
||||
}
|
||||
|
||||
// We need interior mutability in handle_approval_tool_requests
|
||||
let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
|
||||
|
||||
// Process tools requiring approval (enable extension, regular tool calls)
|
||||
let mut tool_approval_stream = self.handle_approval_tool_requests(
|
||||
&permission_check_result.needs_approval,
|
||||
tool_futures_arc.clone(),
|
||||
&mut permission_manager,
|
||||
message_tool_response.clone()
|
||||
);
|
||||
|
||||
// We have a stream of tool_approval_requests to handle
|
||||
// Execution is yielded back to this reply loop, and is of the same Message
|
||||
// type, so we can yield the Message back up to be handled and grab any
|
||||
// confirmations or denials
|
||||
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
tool_futures = {
|
||||
// Lock the mutex asynchronously
|
||||
let mut futures_lock = tool_futures_arc.lock().await;
|
||||
// Drain the vector and collect into a new Vec
|
||||
futures_lock.drain(..).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let with_id = tool_futures
|
||||
.into_iter()
|
||||
.map(|(request_id, stream)| {
|
||||
stream.map(move |item| (request_id.clone(), item))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = stream::select_all(with_id);
|
||||
|
||||
let mut all_install_successful = true;
|
||||
|
||||
while let Some((request_id, item)) = combined.next().await {
|
||||
match item {
|
||||
ToolStreamItem::Result(output) => {
|
||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||
all_install_successful = false;
|
||||
}
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(request_id, output);
|
||||
},
|
||||
ToolStreamItem::Message(msg) => {
|
||||
yield AgentEvent::McpNotification((request_id, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update system prompt and tools if installations were successful
|
||||
if all_install_successful {
|
||||
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
|
||||
}
|
||||
}
|
||||
|
||||
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
||||
yield AgentEvent::Message(final_message_tool_resp.clone());
|
||||
|
||||
added_message = true;
|
||||
push_message(&mut messages, response);
|
||||
push_message(&mut messages, final_message_tool_resp);
|
||||
|
||||
// Check for MCP notifications from subagents again before next iteration
|
||||
// Note: These are already handled as McpNotification events above,
|
||||
// so we don't need to convert them to assistant messages here.
|
||||
// This was causing duplicate plain-text notifications.
|
||||
// let mcp_notifications = self.get_mcp_notifications().await;
|
||||
// for notification in mcp_notifications {
|
||||
// // Extract subagent info from the notification data for assistant messages
|
||||
// if let JsonRpcMessage::Notification(ref notif) = notification {
|
||||
// if let Some(params) = &notif.params {
|
||||
// if let Some(data) = params.get("data") {
|
||||
// if let (Some(subagent_id), Some(message)) = (
|
||||
// data.get("subagent_id").and_then(|v| v.as_str()),
|
||||
// data.get("message").and_then(|v| v.as_str())
|
||||
// ) {
|
||||
// yield AgentEvent::Message(
|
||||
// Message::assistant().with_text(
|
||||
// format!("Subagent {}: {}", subagent_id, message)
|
||||
// )
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
// At this point, the last message should be a user message
|
||||
// because call to provider led to context length exceeded error
|
||||
// Immediately yield a special message and break
|
||||
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
));
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
|
||||
break;
|
||||
}
|
||||
|
||||
// Process tool requests depending on frontend tools and then goose_mode
|
||||
let message_tool_response = Arc::new(Mutex::new(Message::user()));
|
||||
|
||||
// First handle any frontend tool requests
|
||||
let mut frontend_tool_stream = self.handle_frontend_tool_requests(
|
||||
&frontend_requests,
|
||||
message_tool_response.clone()
|
||||
);
|
||||
|
||||
// we have a stream of frontend tools to handle, inside the stream
|
||||
// execution is yielded back to this reply loop, and is of the same Message
|
||||
// type, so we can yield that back up to be handled
|
||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
if mode.as_str() == "chat" {
|
||||
// Skip all tool calls in chat mode
|
||||
for request in remaining_requests {
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
|
||||
// What remains is handling the remaining tool requests (enable extension,
|
||||
// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
|
||||
&remaining_requests,
|
||||
&mode,
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
&mut permission_manager,
|
||||
self.provider().await?).await;
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
|
||||
|
||||
// Skip the confirmation for approved tools
|
||||
for request in &permission_check_result.approved {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
|
||||
|
||||
tool_futures.push((req_id, match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
for request in &permission_check_result.denied {
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(DECLINED_RESPONSE)]),
|
||||
);
|
||||
}
|
||||
|
||||
// We need interior mutability in handle_approval_tool_requests
|
||||
let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
|
||||
|
||||
// Process tools requiring approval (enable extension, regular tool calls)
|
||||
let mut tool_approval_stream = self.handle_approval_tool_requests(
|
||||
&permission_check_result.needs_approval,
|
||||
tool_futures_arc.clone(),
|
||||
&mut permission_manager,
|
||||
message_tool_response.clone()
|
||||
);
|
||||
|
||||
// We have a stream of tool_approval_requests to handle
|
||||
// Execution is yielded back to this reply loop, and is of the same Message
|
||||
// type, so we can yield the Message back up to be handled and grab any
|
||||
// confirmations or denials
|
||||
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
tool_futures = {
|
||||
// Lock the mutex asynchronously
|
||||
let mut futures_lock = tool_futures_arc.lock().await;
|
||||
// Drain the vector and collect into a new Vec
|
||||
futures_lock.drain(..).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let with_id = tool_futures
|
||||
.into_iter()
|
||||
.map(|(request_id, stream)| {
|
||||
stream.map(move |item| (request_id.clone(), item))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = stream::select_all(with_id);
|
||||
|
||||
let mut all_install_successful = true;
|
||||
|
||||
while let Some((request_id, item)) = combined.next().await {
|
||||
match item {
|
||||
ToolStreamItem::Result(output) => {
|
||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||
all_install_successful = false;
|
||||
}
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(request_id, output);
|
||||
},
|
||||
ToolStreamItem::Message(msg) => {
|
||||
yield AgentEvent::McpNotification((request_id, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update system prompt and tools if installations were successful
|
||||
if all_install_successful {
|
||||
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
|
||||
}
|
||||
}
|
||||
|
||||
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
||||
yield AgentEvent::Message(final_message_tool_resp.clone());
|
||||
|
||||
messages.push(response);
|
||||
messages.push(final_message_tool_resp);
|
||||
|
||||
// Check for MCP notifications from subagents again before next iteration
|
||||
// Note: These are already handled as McpNotification events above,
|
||||
// so we don't need to convert them to assistant messages here.
|
||||
// This was causing duplicate plain-text notifications.
|
||||
// let mcp_notifications = self.get_mcp_notifications().await;
|
||||
// for notification in mcp_notifications {
|
||||
// // Extract subagent info from the notification data for assistant messages
|
||||
// if let JsonRpcMessage::Notification(ref notif) = notification {
|
||||
// if let Some(params) = &notif.params {
|
||||
// if let Some(data) = params.get("data") {
|
||||
// if let (Some(subagent_id), Some(message)) = (
|
||||
// data.get("subagent_id").and_then(|v| v.as_str()),
|
||||
// data.get("message").and_then(|v| v.as_str())
|
||||
// ) {
|
||||
// yield AgentEvent::Message(
|
||||
// Message::assistant().with_text(
|
||||
// format!("Subagent {}: {}", subagent_id, message)
|
||||
// )
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
// At this point, the last message should be a user message
|
||||
// because call to provider led to context length exceeded error
|
||||
// Immediately yield a special message and break
|
||||
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
));
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !added_message {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
|
||||
continue;
|
||||
} else {
|
||||
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Yield control back to the scheduler to prevent blocking
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
@@ -2,10 +2,13 @@ use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::try_stream;
|
||||
use futures::stream::StreamExt;
|
||||
|
||||
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
|
||||
use crate::config::Config;
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::providers::base::{Provider, ProviderUsage};
|
||||
use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::toolshim::{
|
||||
augment_message_with_tool_calls, convert_tool_messages_to_text,
|
||||
@@ -16,6 +19,19 @@ use mcp_core::tool::Tool;
|
||||
|
||||
use super::super::agents::Agent;
|
||||
|
||||
async fn toolshim_postprocess(
|
||||
response: Message,
|
||||
toolshim_tools: &[Tool],
|
||||
) -> Result<Message, ProviderError> {
|
||||
let interpreter = OllamaInterpreter::new().map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
|
||||
})?;
|
||||
|
||||
augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
|
||||
.await
|
||||
.map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e)))
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Prepares tools and system prompt for a provider request
|
||||
pub(crate) async fn prepare_tools_and_prompt(
|
||||
@@ -128,25 +144,67 @@ impl Agent {
|
||||
.complete(system_prompt, &messages_for_provider, tools)
|
||||
.await?;
|
||||
|
||||
// Store the model information in the global store
|
||||
crate::providers::base::set_current_model(&usage.model);
|
||||
|
||||
// Post-process / structure the response only if tool interpretation is enabled
|
||||
if config.toolshim {
|
||||
let interpreter = OllamaInterpreter::new().map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
|
||||
})?;
|
||||
|
||||
response = augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Failed to augment message: {}", e))
|
||||
})?;
|
||||
response = toolshim_postprocess(response, toolshim_tools).await?;
|
||||
}
|
||||
|
||||
Ok((response, usage))
|
||||
}
|
||||
|
||||
/// Stream a response from the LLM provider.
|
||||
/// Handles toolshim transformations if needed
|
||||
pub(crate) async fn stream_response_from_provider(
|
||||
provider: Arc<dyn Provider>,
|
||||
system_prompt: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
toolshim_tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let config = provider.get_model_config();
|
||||
|
||||
// Convert tool messages to text if toolshim is enabled
|
||||
let messages_for_provider = if config.toolshim {
|
||||
convert_tool_messages_to_text(messages)
|
||||
} else {
|
||||
messages.to_vec()
|
||||
};
|
||||
|
||||
// Clone owned data to move into the async stream
|
||||
let system_prompt = system_prompt.to_owned();
|
||||
let tools = tools.to_owned();
|
||||
let toolshim_tools = toolshim_tools.to_owned();
|
||||
let provider = provider.clone();
|
||||
|
||||
let mut stream = if provider.supports_streaming() {
|
||||
provider
|
||||
.stream(system_prompt.as_str(), &messages_for_provider, &tools)
|
||||
.await?
|
||||
} else {
|
||||
let (message, usage) = provider
|
||||
.complete(system_prompt.as_str(), &messages_for_provider, &tools)
|
||||
.await?;
|
||||
stream_from_single_message(message, usage)
|
||||
};
|
||||
|
||||
Ok(Box::pin(try_stream! {
|
||||
while let Some(Ok((mut message, usage))) = stream.next().await {
|
||||
// Store the model information in the global store
|
||||
if let Some(usage) = usage.as_ref() {
|
||||
crate::providers::base::set_current_model(&usage.model);
|
||||
}
|
||||
|
||||
// Post-process / structure the response only if tool interpretation is enabled
|
||||
if message.is_some() && config.toolshim {
|
||||
message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?);
|
||||
}
|
||||
|
||||
yield (message, usage);
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Categorize tool requests from the response into different types
|
||||
/// Returns:
|
||||
/// - frontend_requests: Tool requests that should be handled by the frontend
|
||||
@@ -191,6 +249,7 @@ impl Agent {
|
||||
}
|
||||
|
||||
let filtered_message = Message {
|
||||
id: response.id.clone(),
|
||||
role: response.role.clone(),
|
||||
created: response.created,
|
||||
content: filtered_content,
|
||||
|
||||
@@ -247,14 +247,14 @@ mod tests {
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: "Summarized content".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
@@ -277,30 +277,26 @@ mod tests {
|
||||
}
|
||||
|
||||
fn set_up_text_message(text: &str, role: Role) -> Message {
|
||||
Message {
|
||||
role,
|
||||
created: 0,
|
||||
content: vec![MessageContent::text(text.to_string())],
|
||||
}
|
||||
Message::new(role, 0, vec![MessageContent::text(text.to_string())])
|
||||
}
|
||||
|
||||
fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
|
||||
}
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
0,
|
||||
vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
|
||||
)
|
||||
}
|
||||
|
||||
fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_response(
|
||||
Message::new(
|
||||
Role::User,
|
||||
0,
|
||||
vec![MessageContent::tool_response(
|
||||
id.to_string(),
|
||||
Ok(tool_response),
|
||||
)],
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -448,14 +444,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reintegrate_removed_messages() {
|
||||
let summarized_messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
let summarized_messages = vec![Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: "Summary".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
}];
|
||||
)];
|
||||
let arguments = json!({
|
||||
"param1": "value1"
|
||||
});
|
||||
|
||||
@@ -303,15 +303,46 @@ impl From<PromptMessage> for Message {
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub id: Option<String>,
|
||||
pub role: Role,
|
||||
pub created: i64,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub fn push_message(messages: &mut Vec<Message>, message: Message) {
|
||||
if let Some(last) = messages
|
||||
.last_mut()
|
||||
.filter(|m| m.id.is_some() && m.id == message.id)
|
||||
{
|
||||
match (last.content.last_mut(), message.content.last()) {
|
||||
(Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new)))
|
||||
if message.content.len() == 1 =>
|
||||
{
|
||||
last.text.push_str(&new.text);
|
||||
}
|
||||
(_, _) => {
|
||||
last.content.extend(message.content);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
messages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role, created: i64, content: Vec<MessageContent>) -> Self {
|
||||
Message {
|
||||
id: None,
|
||||
role,
|
||||
created,
|
||||
content,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new user message with the current timestamp
|
||||
pub fn user() -> Self {
|
||||
Message {
|
||||
id: None,
|
||||
role: Role::User,
|
||||
created: Utc::now().timestamp(),
|
||||
content: Vec::new(),
|
||||
@@ -321,6 +352,7 @@ impl Message {
|
||||
/// Create a new assistant message with the current timestamp
|
||||
pub fn assistant() -> Self {
|
||||
Message {
|
||||
id: None,
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: Vec::new(),
|
||||
|
||||
@@ -81,10 +81,10 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
|
||||
})
|
||||
.collect();
|
||||
let mut check_messages = vec![];
|
||||
check_messages.push(Message {
|
||||
role: mcp_core::Role::User,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
check_messages.push(Message::new(
|
||||
mcp_core::Role::User,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: format!(
|
||||
"Here are the tool requests: {:?}\n\nAnalyze the tool requests and list the tools that perform read-only operations. \
|
||||
\n\nGuidelines for Read-Only Operations: \
|
||||
@@ -96,7 +96,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
|
||||
),
|
||||
annotations: None,
|
||||
})],
|
||||
});
|
||||
));
|
||||
check_messages
|
||||
}
|
||||
|
||||
@@ -296,10 +296,10 @@ mod tests {
|
||||
_tools: &[Tool],
|
||||
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::ToolRequest(ToolRequest {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::ToolRequest(ToolRequest {
|
||||
id: "mock_tool_request".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
name: "platform__tool_by_tool_permission".to_string(),
|
||||
@@ -308,7 +308,7 @@ mod tests {
|
||||
}),
|
||||
}),
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
@@ -354,10 +354,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_extract_read_only_tools() {
|
||||
let message = Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::ToolRequest(ToolRequest {
|
||||
let message = Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::ToolRequest(ToolRequest {
|
||||
id: "tool_2".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
name: "platform__tool_by_tool_permission".to_string(),
|
||||
@@ -366,7 +366,7 @@ mod tests {
|
||||
}),
|
||||
}),
|
||||
})],
|
||||
};
|
||||
);
|
||||
|
||||
let result = extract_read_only_tools(&message);
|
||||
assert!(result.is_some());
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::errors::ProviderError;
|
||||
@@ -8,6 +9,8 @@ use mcp_core::tool::Tool;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use std::ops::{Add, AddAssign};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias
|
||||
@@ -184,13 +187,43 @@ impl ProviderUsage {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<i32>,
|
||||
pub output_tokens: Option<i32>,
|
||||
pub total_tokens: Option<i32>,
|
||||
}
|
||||
|
||||
/// Adds two optional values.
///
/// Returns `None` only when both inputs are `None`; otherwise a missing side
/// is replaced by `T::default()` before the addition.
fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
where
    T: Add<Output = T> + Default,
{
    if a.is_none() && b.is_none() {
        return None;
    }
    Some(a.unwrap_or_default() + b.unwrap_or_default())
}
|
||||
|
||||
impl Add for Usage {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Self) -> Self {
|
||||
Self {
|
||||
input_tokens: sum_optionals(self.input_tokens, other.input_tokens),
|
||||
output_tokens: sum_optionals(self.output_tokens, other.output_tokens),
|
||||
total_tokens: sum_optionals(self.total_tokens, other.total_tokens),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign for Usage {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Usage {
|
||||
pub fn new(
|
||||
input_tokens: Option<i32>,
|
||||
@@ -270,6 +303,21 @@ pub trait Provider: Send + Sync {
|
||||
None
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
Err(ProviderError::NotImplemented(
|
||||
"streaming not implemented".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Get the currently active model name
|
||||
/// For regular providers, this returns the configured model
|
||||
/// For LeadWorkerProvider, this returns the currently active model (lead or worker)
|
||||
@@ -282,6 +330,18 @@ pub trait Provider: Send + Sync {
|
||||
}
|
||||
}
|
||||
|
||||
/// A message stream yields partial text content but complete tool calls, all wrapped in `Message` objects.
/// A text message may therefore carry as little as a single word of a longer response, whereas a
/// tool-call message is only yielded once all of its fragments have been concatenated.
|
||||
pub type MessageStream = Pin<
|
||||
Box<dyn Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> + Send>,
|
||||
>;
|
||||
|
||||
pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
|
||||
let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) });
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -219,11 +219,11 @@ impl ClaudeCodeProvider {
|
||||
annotations: None,
|
||||
})];
|
||||
|
||||
let response_message = Message {
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: message_content,
|
||||
};
|
||||
let response_message = Message::new(
|
||||
Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
message_content,
|
||||
);
|
||||
|
||||
Ok((response_message, usage))
|
||||
}
|
||||
@@ -353,14 +353,14 @@ impl ClaudeCodeProvider {
|
||||
println!("================================");
|
||||
}
|
||||
|
||||
let message = Message {
|
||||
role: mcp_core::Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(mcp_core::content::TextContent {
|
||||
let message = Message::new(
|
||||
mcp_core::Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(mcp_core::content::TextContent {
|
||||
text: description.clone(),
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
);
|
||||
|
||||
let usage = Usage::default();
|
||||
|
||||
|
||||
@@ -1,4 +1,16 @@
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use anyhow::Result;
|
||||
use async_stream::try_stream;
|
||||
use async_trait::async_trait;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use tokio::pin;
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use super::embedding::EmbeddingCapable;
|
||||
use super::errors::ProviderError;
|
||||
use super::formats::databricks::{create_request, get_usage, response_to_message};
|
||||
@@ -7,17 +19,13 @@ use super::utils::{get_model, ImageFormat};
|
||||
use crate::config::ConfigError;
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::formats::databricks::response_to_streaming_message;
|
||||
use mcp_core::tool::Tool;
|
||||
use serde_json::json;
|
||||
use url::Url;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::codec::{FramedRead, LinesCodec};
|
||||
use url::Url;
|
||||
|
||||
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
|
||||
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
|
||||
@@ -266,9 +274,6 @@ impl DatabricksProvider {
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
// Check if this is an embedding request by looking at the payload structure
|
||||
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
|
||||
let path = if is_embedding {
|
||||
@@ -279,56 +284,71 @@ impl DatabricksProvider {
|
||||
format!("serving-endpoints/{}/invocations", self.model.model_name)
|
||||
};
|
||||
|
||||
let url = base_url.join(&path).map_err(|e| {
|
||||
match self.post_with_retry(path.as_str(), &payload).await {
|
||||
Ok(res) => res.json().await.map_err(|_| {
|
||||
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
|
||||
}),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
async fn post_with_retry(
|
||||
&self,
|
||||
path: &str,
|
||||
payload: &Value,
|
||||
) -> Result<reqwest::Response, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join(path).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
|
||||
})?;
|
||||
|
||||
// Initialize retry counter
|
||||
let mut attempts = 0;
|
||||
let mut last_error = None;
|
||||
|
||||
loop {
|
||||
// Check if we've exceeded max retries
|
||||
if attempts > 0 && attempts > self.retry_config.max_retries {
|
||||
let error_msg = format!(
|
||||
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
|
||||
self.retry_config.max_retries
|
||||
);
|
||||
tracing::error!("{}", error_msg);
|
||||
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
|
||||
}
|
||||
|
||||
let auth_header = self.ensure_auth_header().await?;
|
||||
let response = self
|
||||
.client
|
||||
.post(url.clone())
|
||||
.header("Authorization", auth_header)
|
||||
.json(&payload)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let payload: Option<Value> = response.json().await.ok();
|
||||
|
||||
match status {
|
||||
StatusCode::OK => {
|
||||
return payload.ok_or_else(|| {
|
||||
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
|
||||
});
|
||||
}
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
|
||||
return Err(ProviderError::Authentication(format!(
|
||||
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
|
||||
Status: {}. Response: {:?}",
|
||||
status, payload
|
||||
)));
|
||||
break match status {
|
||||
StatusCode::OK => Ok(response),
|
||||
StatusCode::TOO_MANY_REQUESTS
|
||||
| StatusCode::INTERNAL_SERVER_ERROR
|
||||
| StatusCode::SERVICE_UNAVAILABLE => {
|
||||
if attempts < self.retry_config.max_retries {
|
||||
attempts += 1;
|
||||
tracing::warn!(
|
||||
"{}: retrying ({}/{})",
|
||||
status,
|
||||
attempts,
|
||||
self.retry_config.max_retries
|
||||
);
|
||||
|
||||
let delay = self.retry_config.delay_for_attempt(attempts);
|
||||
tracing::info!("Backing off for {:?} before retry", delay);
|
||||
sleep(delay).await;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
Err(match status {
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
ProviderError::RateLimitExceeded("Rate limit exceeded".to_string())
|
||||
}
|
||||
_ => ProviderError::ServerError("Server error".to_string()),
|
||||
})
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
|
||||
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
|
||||
let payload_str = serde_json::to_string(&payload)
|
||||
.unwrap_or_default()
|
||||
.to_lowercase();
|
||||
let bytes = response.bytes().await?;
|
||||
let payload_str = String::from_utf8_lossy(&bytes).to_lowercase();
|
||||
let check_phrases = [
|
||||
"too long",
|
||||
"context length",
|
||||
@@ -347,13 +367,13 @@ impl DatabricksProvider {
|
||||
}
|
||||
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Ok(response_json) = serde_json::from_slice::<Value>(&bytes) {
|
||||
// try to convert message to string, if that fails use external_model_message
|
||||
error_msg = payload
|
||||
error_msg = response_json
|
||||
.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.or_else(|| {
|
||||
payload
|
||||
response_json
|
||||
.get("external_model_message")
|
||||
.and_then(|ext| ext.get("message"))
|
||||
.and_then(|m| m.as_str())
|
||||
@@ -366,7 +386,7 @@ impl DatabricksProvider {
|
||||
"{}",
|
||||
format!(
|
||||
"Provider request failed with status: {}. Payload: {:?}",
|
||||
status, payload
|
||||
status, payload_str
|
||||
)
|
||||
);
|
||||
return Err(ProviderError::RequestFailed(format!(
|
||||
@@ -374,50 +394,13 @@ impl DatabricksProvider {
|
||||
status, error_msg
|
||||
)));
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
attempts += 1;
|
||||
let error_msg = format!(
|
||||
"Rate limit exceeded (attempt {}/{}): {:?}",
|
||||
attempts, self.retry_config.max_retries, payload
|
||||
);
|
||||
tracing::warn!("{}. Retrying after backoff...", error_msg);
|
||||
|
||||
// Store the error in case we need to return it after max retries
|
||||
last_error = Some(ProviderError::RateLimitExceeded(error_msg));
|
||||
|
||||
// Calculate and apply the backoff delay
|
||||
let delay = self.retry_config.delay_for_attempt(attempts);
|
||||
tracing::info!("Backing off for {:?} before retry", delay);
|
||||
sleep(delay).await;
|
||||
|
||||
// Continue to the next retry attempt
|
||||
continue;
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
attempts += 1;
|
||||
let error_msg = format!(
|
||||
"Server error (attempt {}/{}): {:?}",
|
||||
attempts, self.retry_config.max_retries, payload
|
||||
);
|
||||
tracing::warn!("{}. Retrying after backoff...", error_msg);
|
||||
|
||||
// Store the error in case we need to return it after max retries
|
||||
last_error = Some(ProviderError::ServerError(error_msg));
|
||||
|
||||
// Calculate and apply the backoff delay
|
||||
let delay = self.retry_config.delay_for_attempt(attempts);
|
||||
tracing::info!("Backing off for {:?} before retry", delay);
|
||||
sleep(delay).await;
|
||||
|
||||
// Continue to the next retry attempt
|
||||
continue;
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(
|
||||
"{}",
|
||||
format!(
|
||||
"Provider request failed with status: {}. Payload: {:?}",
|
||||
status, payload
|
||||
status,
|
||||
response.text().await.ok().unwrap_or_default()
|
||||
)
|
||||
);
|
||||
return Err(ProviderError::RequestFailed(format!(
|
||||
@@ -425,7 +408,7 @@ impl DatabricksProvider {
|
||||
status
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -472,13 +455,12 @@ impl Provider for DatabricksProvider {
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let usage = match get_usage(&response) {
|
||||
Ok(usage) => usage,
|
||||
Err(ProviderError::UsageError(e)) => {
|
||||
tracing::debug!("Failed to get usage data: {}", e);
|
||||
let usage = match response.get("usage").map(get_usage) {
|
||||
Some(usage) => usage,
|
||||
None => {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
let model = get_model(&response);
|
||||
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||
@@ -486,6 +468,54 @@ impl Provider for DatabricksProvider {
|
||||
Ok((message, ProviderUsage::new(model, usage)))
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?;
|
||||
// Remove the model key which is part of the url with databricks
|
||||
payload
|
||||
.as_object_mut()
|
||||
.expect("payload should have model key")
|
||||
.remove("model");
|
||||
|
||||
payload
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.insert("stream".to_string(), Value::Bool(true));
|
||||
|
||||
let response = self
|
||||
.post_with_retry(
|
||||
format!("serving-endpoints/{}/invocations", self.model.model_name).as_str(),
|
||||
&payload,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Map reqwest error to io::Error
|
||||
let stream = response.bytes_stream().map_err(io::Error::other);
|
||||
|
||||
let model_config = self.model.clone();
|
||||
// Wrap in a line decoder and yield lines inside the stream
|
||||
Ok(Box::pin(try_stream! {
|
||||
let stream_reader = StreamReader::new(stream);
|
||||
let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);
|
||||
|
||||
let message_stream = response_to_streaming_message(framed);
|
||||
pin!(message_stream);
|
||||
while let Some(message) = message_stream.next().await {
|
||||
let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
|
||||
super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
|
||||
yield (message, usage);
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_embeddings(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ pub enum ProviderError {
|
||||
|
||||
#[error("Usage data error: {0}")]
|
||||
UsageError(String),
|
||||
|
||||
#[error("Unsupported operation: {0}")]
|
||||
NotImplemented(String),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ProviderError {
|
||||
|
||||
@@ -212,17 +212,17 @@ mod tests {
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: format!(
|
||||
"Response from {} with model {}",
|
||||
self.name, self.model_config.model_name
|
||||
),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -260,11 +260,7 @@ pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let created = Utc::now().timestamp();
|
||||
|
||||
Ok(Message {
|
||||
role,
|
||||
content,
|
||||
created,
|
||||
})
|
||||
Ok(Message::new(role, created, content))
|
||||
}
|
||||
|
||||
pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::Usage;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::base::{ProviderUsage, Usage};
|
||||
use crate::providers::utils::{
|
||||
convert_image, detect_image_path, is_valid_function_name, load_image_file,
|
||||
sanitize_function_name, ImageFormat,
|
||||
};
|
||||
use anyhow::{anyhow, Error};
|
||||
use async_stream::try_stream;
|
||||
use futures::Stream;
|
||||
use mcp_core::ToolError;
|
||||
use mcp_core::{Content, Role, Tool, ToolCall};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Convert internal Message format to Databricks' API message specification
|
||||
@@ -358,18 +360,162 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Message {
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
Ok(Message::new(
|
||||
Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
content,
|
||||
})
|
||||
))
|
||||
}
|
||||
|
||||
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
|
||||
let usage = data
|
||||
.get("usage")
|
||||
.ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?;
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DeltaToolCallFunction {
|
||||
name: Option<String>,
|
||||
arguments: String, // chunk of encoded JSON,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct DeltaToolCall {
|
||||
id: Option<String>,
|
||||
function: DeltaToolCallFunction,
|
||||
index: Option<i32>,
|
||||
r#type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct Delta {
|
||||
content: Option<String>,
|
||||
role: Option<String>,
|
||||
tool_calls: Option<Vec<DeltaToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamingChoice {
|
||||
delta: Delta,
|
||||
index: Option<i32>,
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct StreamingChunk {
|
||||
choices: Vec<StreamingChoice>,
|
||||
created: Option<i64>,
|
||||
id: Option<String>,
|
||||
usage: Option<Value>,
|
||||
model: String,
|
||||
}
|
||||
|
||||
/// Returns the trimmed remainder of `line` after a leading `"data: "`, or
/// `None` when the line does not start with that prefix.
fn strip_data_prefix(line: &str) -> Option<&str> {
    let payload = line.strip_prefix("data: ")?;
    Some(payload.trim())
}
|
||||
|
||||
pub fn response_to_streaming_message<S>(
|
||||
mut stream: S,
|
||||
) -> impl Stream<Item = anyhow::Result<(Option<Message>, Option<ProviderUsage>)>> + 'static
|
||||
where
|
||||
S: Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
|
||||
{
|
||||
try_stream! {
|
||||
use futures::StreamExt;
|
||||
|
||||
'outer: while let Some(response) = stream.next().await {
|
||||
if response.as_ref().is_ok_and(|s| s == "data: [DONE]") {
|
||||
break 'outer;
|
||||
}
|
||||
let response_str = response?;
|
||||
let line = strip_data_prefix(&response_str);
|
||||
|
||||
if line.is_none() || line.is_some_and(|l| l.is_empty()) {
|
||||
continue
|
||||
}
|
||||
|
||||
let chunk: StreamingChunk = serde_json::from_str(line
|
||||
.ok_or_else(|| anyhow!("unexpected stream format"))?)
|
||||
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
|
||||
let model = chunk.model.clone();
|
||||
|
||||
let usage = chunk.usage.as_ref().map(|u| {
|
||||
ProviderUsage {
|
||||
usage: get_usage(u),
|
||||
model,
|
||||
}
|
||||
});
|
||||
|
||||
if chunk.choices.is_empty() {
|
||||
yield (None, usage)
|
||||
} else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls {
|
||||
let tool_call = &tool_calls[0];
|
||||
let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?;
|
||||
let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?;
|
||||
let mut arguments = tool_call.function.arguments.clone();
|
||||
|
||||
while let Some(response_chunk) = stream.next().await {
|
||||
if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") {
|
||||
break 'outer;
|
||||
}
|
||||
let response_str = response_chunk?;
|
||||
if let Some(line) = strip_data_prefix(&response_str) {
|
||||
let tool_chunk: StreamingChunk = serde_json::from_str(line)
|
||||
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
|
||||
let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref()
|
||||
.and_then(|calls| calls.first())
|
||||
.map(|call| call.function.arguments.as_str());
|
||||
if let Some(more_args) = more_args {
|
||||
arguments.push_str(more_args);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let parsed = if arguments.is_empty() {
|
||||
Ok(json!({}))
|
||||
} else {
|
||||
serde_json::from_str::<Value>(&arguments)
|
||||
};
|
||||
|
||||
let content = match parsed {
|
||||
Ok(params) => MessageContent::tool_request(
|
||||
id,
|
||||
Ok(ToolCall::new(function_name, params)),
|
||||
),
|
||||
Err(e) => {
|
||||
let error = ToolError::InvalidParameters(format!(
|
||||
"Could not interpret tool use parameters for id {}: {}",
|
||||
id, e
|
||||
));
|
||||
MessageContent::tool_request(id, Err(error))
|
||||
}
|
||||
};
|
||||
|
||||
yield (
|
||||
Some(Message {
|
||||
id: chunk.id,
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: vec![content],
|
||||
}),
|
||||
usage,
|
||||
)
|
||||
} else if let Some(text) = &chunk.choices[0].delta.content {
|
||||
yield (
|
||||
Some(Message {
|
||||
id: chunk.id,
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: vec![MessageContent::text(text)],
|
||||
}),
|
||||
if chunk.choices[0].finish_reason.is_some() {
|
||||
usage
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_usage(usage: &Value) -> Usage {
|
||||
let input_tokens = usage
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
@@ -389,7 +535,7 @@ pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
|
||||
_ => None,
|
||||
});
|
||||
|
||||
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
|
||||
Usage::new(input_tokens, output_tokens, total_tokens)
|
||||
}
|
||||
|
||||
/// Validates and fixes tool schemas to ensure they have proper parameter structure.
|
||||
|
||||
@@ -209,11 +209,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
|
||||
let role = Role::Assistant;
|
||||
let created = chrono::Utc::now().timestamp();
|
||||
if candidate.is_none() {
|
||||
return Ok(Message {
|
||||
role,
|
||||
created,
|
||||
content,
|
||||
});
|
||||
return Ok(Message::new(role, created, content));
|
||||
}
|
||||
let candidate = candidate.unwrap();
|
||||
let parts = candidate
|
||||
@@ -252,11 +248,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Message {
|
||||
role,
|
||||
created,
|
||||
content,
|
||||
})
|
||||
Ok(Message::new(role, created, content))
|
||||
}
|
||||
|
||||
/// Extract usage information from Google's API response
|
||||
@@ -324,43 +316,39 @@ mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
fn set_up_text_message(text: &str, role: Role) -> Message {
|
||||
Message {
|
||||
role,
|
||||
created: 0,
|
||||
content: vec![MessageContent::text(text.to_string())],
|
||||
}
|
||||
Message::new(role, 0, vec![MessageContent::text(text.to_string())])
|
||||
}
|
||||
|
||||
fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
|
||||
}
|
||||
Message::new(
|
||||
Role::User,
|
||||
0,
|
||||
vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
|
||||
)
|
||||
}
|
||||
|
||||
fn set_up_tool_confirmation_message(id: &str, tool_call: ToolCall) -> Message {
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_confirmation_request(
|
||||
Message::new(
|
||||
Role::User,
|
||||
0,
|
||||
vec![MessageContent::tool_confirmation_request(
|
||||
id.to_string(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
)],
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_response(
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
0,
|
||||
vec![MessageContent::tool_response(
|
||||
id.to_string(),
|
||||
Ok(tool_response),
|
||||
)],
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn set_up_tool(name: &str, description: &str, params: Value) -> Tool {
|
||||
|
||||
@@ -274,11 +274,11 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Message {
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
Ok(Message::new(
|
||||
Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
content,
|
||||
})
|
||||
))
|
||||
}
|
||||
|
||||
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
|
||||
|
||||
@@ -169,14 +169,14 @@ impl GeminiCliProvider {
|
||||
));
|
||||
}
|
||||
|
||||
let message = Message {
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
let message = Message::new(
|
||||
Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: response_text,
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
);
|
||||
|
||||
let usage = Usage::default(); // No usage info available for gemini CLI
|
||||
|
||||
@@ -214,14 +214,14 @@ impl GeminiCliProvider {
|
||||
println!("================================");
|
||||
}
|
||||
|
||||
let message = Message {
|
||||
role: Role::Assistant,
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
let message = Message::new(
|
||||
Role::Assistant,
|
||||
chrono::Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: description.clone(),
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
);
|
||||
|
||||
let usage = Usage::default();
|
||||
|
||||
|
||||
@@ -480,14 +480,14 @@ mod tests {
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: format!("Response from {}", self.name),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new(self.name.clone(), Usage::default()),
|
||||
))
|
||||
}
|
||||
@@ -643,14 +643,14 @@ mod tests {
|
||||
))
|
||||
} else {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: format!("Response from {}", self.name),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new(self.name.clone(), Usage::default()),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -203,14 +203,14 @@ impl SageMakerTgiProvider {
|
||||
// Strip any HTML tags that might have been generated
|
||||
let clean_text = self.strip_html_tags(generated_text);
|
||||
|
||||
Ok(Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Ok(Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: clean_text,
|
||||
annotations: None,
|
||||
})],
|
||||
})
|
||||
))
|
||||
}
|
||||
|
||||
/// Strip HTML tags from text to ensure clean output
|
||||
|
||||
@@ -359,11 +359,7 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec<Message> {
|
||||
}
|
||||
|
||||
if has_tool_content {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
content: new_content,
|
||||
created: message.created,
|
||||
}
|
||||
Message::new(message.role.clone(), message.created, new_content)
|
||||
} else {
|
||||
message.clone()
|
||||
}
|
||||
|
||||
@@ -319,12 +319,15 @@ pub fn unescape_json_values(value: &Value) -> Value {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_debug_trace(
|
||||
pub fn emit_debug_trace<T1, T2>(
|
||||
model_config: &ModelConfig,
|
||||
payload: &Value,
|
||||
response: &Value,
|
||||
payload: &T1,
|
||||
response: &T2,
|
||||
usage: &Usage,
|
||||
) {
|
||||
) where
|
||||
T1: ?Sized + Serialize,
|
||||
T2: ?Sized + Serialize,
|
||||
{
|
||||
tracing::debug!(
|
||||
model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(),
|
||||
input = %serde_json::to_string_pretty(payload).unwrap_or_default(),
|
||||
|
||||
@@ -557,11 +557,7 @@ impl Provider for VeniceProvider {
|
||||
};
|
||||
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content,
|
||||
},
|
||||
Message::new(Role::Assistant, Utc::now().timestamp(), content),
|
||||
ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1370,14 +1370,14 @@ mod tests {
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
Utc::now().timestamp(),
|
||||
vec![MessageContent::Text(TextContent {
|
||||
text: "Mocked scheduled response".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
),
|
||||
ProviderUsage::new("mock-scheduler-test".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1526,7 +1526,7 @@ mod tests {
|
||||
"]}}\"\\n\\\"{[",
|
||||
"Edge case: } ] some text",
|
||||
"{\"foo\": \"} ]\"}",
|
||||
"}]",
|
||||
"}]",
|
||||
];
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
@@ -407,25 +407,25 @@ mod tests {
|
||||
"You are a helpful assistant that can answer questions about the weather.";
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::text(
|
||||
Message::new(
|
||||
Role::User,
|
||||
0,
|
||||
vec![MessageContent::text(
|
||||
"What's the weather like in San Francisco?",
|
||||
)],
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: 1,
|
||||
content: vec![MessageContent::text(
|
||||
),
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
1,
|
||||
vec![MessageContent::text(
|
||||
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
|
||||
)],
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 2,
|
||||
content: vec![MessageContent::text("How about New York?")],
|
||||
},
|
||||
),
|
||||
Message::new(
|
||||
Role::User,
|
||||
2,
|
||||
vec![MessageContent::text("How about New York?")],
|
||||
),
|
||||
];
|
||||
|
||||
let tools = vec![Tool {
|
||||
@@ -505,25 +505,25 @@ mod tests {
|
||||
"You are a helpful assistant that can answer questions about the weather.";
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::text(
|
||||
Message::new(
|
||||
Role::User,
|
||||
0,
|
||||
vec![MessageContent::text(
|
||||
"What's the weather like in San Francisco?",
|
||||
)],
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: 1,
|
||||
content: vec![MessageContent::text(
|
||||
),
|
||||
Message::new(
|
||||
Role::Assistant,
|
||||
1,
|
||||
vec![MessageContent::text(
|
||||
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
|
||||
)],
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 2,
|
||||
content: vec![MessageContent::text("How about New York?")],
|
||||
},
|
||||
),
|
||||
Message::new(
|
||||
Role::User,
|
||||
2,
|
||||
vec![MessageContent::text("How about New York?")],
|
||||
),
|
||||
];
|
||||
|
||||
let tools = vec![Tool {
|
||||
|
||||
@@ -786,7 +786,7 @@ function ChatContent({
|
||||
<SearchView>
|
||||
{filteredMessages.map((message, index) => (
|
||||
<div
|
||||
key={message.id || index}
|
||||
key={(message.id && `${message.id}-${message.content.length}`) || index}
|
||||
className="mt-4 px-4"
|
||||
data-testid="message-container"
|
||||
>
|
||||
|
||||
@@ -130,7 +130,7 @@ export default function GooseMessage({
|
||||
]);
|
||||
|
||||
return (
|
||||
<div className="goose-message flex w-[90%] justify-start opacity-0 animate-[appear_150ms_ease-in_forwards]">
|
||||
<div className="goose-message flex w-[90%] justify-start">
|
||||
<div className="flex flex-col w-full">
|
||||
{/* Chain-of-Thought (hidden by default) */}
|
||||
{cotText && (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useCallback, useEffect, useRef, useId } from 'react';
|
||||
import { useState, useCallback, useEffect, useRef, useId, useReducer } from 'react';
|
||||
import useSWR from 'swr';
|
||||
import { getSecretKey } from '../config';
|
||||
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
|
||||
@@ -235,6 +235,9 @@ export function useMessageStream({
|
||||
};
|
||||
}, [headers, body]);
|
||||
|
||||
// TODO: not this?
|
||||
const [, forceUpdate] = useReducer((x) => x + 1, 0);
|
||||
|
||||
// Process the SSE stream from the server
|
||||
const processMessageStream = useCallback(
|
||||
async (response: Response, currentMessages: Message[]) => {
|
||||
@@ -284,8 +287,23 @@ export function useMessageStream({
|
||||
: parsedEvent.message.sendToLLM,
|
||||
};
|
||||
|
||||
console.log('New message:', JSON.stringify(newMessage, null, 2));
|
||||
|
||||
// Update messages with the new message
|
||||
currentMessages = [...currentMessages, newMessage];
|
||||
|
||||
if (
|
||||
newMessage.id &&
|
||||
currentMessages.length > 0 &&
|
||||
currentMessages[currentMessages.length - 1].id === newMessage.id
|
||||
) {
|
||||
// If the last message has the same ID, update it instead of adding a new one
|
||||
const lastMessage = currentMessages[currentMessages.length - 1];
|
||||
lastMessage.content = [...lastMessage.content, ...newMessage.content];
|
||||
forceUpdate();
|
||||
} else {
|
||||
currentMessages = [...currentMessages, newMessage];
|
||||
}
|
||||
|
||||
mutate(currentMessages, false);
|
||||
break;
|
||||
}
|
||||
@@ -373,7 +391,7 @@ export function useMessageStream({
|
||||
|
||||
return currentMessages;
|
||||
},
|
||||
[mutate, onFinish, onError]
|
||||
[mutate, onFinish, onError, forceUpdate]
|
||||
);
|
||||
|
||||
// Send a request to the server
|
||||
|
||||
@@ -201,7 +201,7 @@ export function getTextContent(message: Message): string {
|
||||
}
|
||||
return '';
|
||||
})
|
||||
.join('\n');
|
||||
.join('');
|
||||
}
|
||||
|
||||
export function getToolRequests(message: Message): ToolRequestMessageContent[] {
|
||||
|
||||
Reference in New Issue
Block a user