mirror of https://github.com/aljazceru/goose.git (synced 2025-12-19 15:14:21 +01:00)

feat: stream LLM responses (#2677)

Co-authored-by: Michael Neale <michael.neale@gmail.com>
Diffstat (partial): Cargo.lock (generated) | 5
@@ -3486,6 +3486,7 @@ dependencies = [
 "tokio",
 "tokio-cron-scheduler",
 "tokio-stream",
+ "tokio-util",
 "tracing",
 "tracing-subscriber",
 "url",
@@ -8604,9 +8605,9 @@ dependencies = [

 [[package]]
 name = "tokio-util"
-version = "0.7.13"
+version = "0.7.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
+checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
 dependencies = [
 "bytes",
 "futures-core",
@@ -10,6 +10,7 @@ pub use self::export::message_to_markdown;
 pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
 use console::Color;
 use goose::agents::AgentEvent;
+use goose::message::push_message;
 use goose::permission::permission_confirmation::PrincipalType;
 use goose::permission::Permission;
 use goose::permission::PermissionConfirmation;
@@ -356,7 +357,7 @@ impl Session {

 /// Process a single message and get the response
 async fn process_message(&mut self, message: String) -> Result<()> {
-self.messages.push(Message::user().with_text(&message));
+self.push_message(Message::user().with_text(&message));
 // Get the provider from the agent for description generation
 let provider = self.agent.provider().await?;

@@ -462,7 +463,7 @@ impl Session {
 RunMode::Normal => {
 save_history(&mut editor);

-self.messages.push(Message::user().with_text(&content));
+self.push_message(Message::user().with_text(&content));

 // Track the current directory and last instruction in projects.json
 let session_id = self
@@ -785,7 +786,7 @@ impl Session {
 self.messages.clear();
 // add the plan response as a user message
 let plan_message = Message::user().with_text(plan_response.as_concat_text());
-self.messages.push(plan_message);
+self.push_message(plan_message);
 // act on the plan
 output::show_thinking();
 self.process_agent_response(true).await?;
@@ -800,13 +801,13 @@ impl Session {
 } else {
 // add the plan response (assistant message) & carry the conversation forward
 // in the next round, the user might wanna slightly modify the plan
-self.messages.push(plan_response);
+self.push_message(plan_response);
 }
 }
 PlannerResponseType::ClarifyingQuestions => {
 // add the plan response (assistant message) & carry the conversation forward
 // in the next round, the user will answer the clarifying questions
-self.messages.push(plan_response);
+self.push_message(plan_response);
 }
 }

@@ -878,7 +879,7 @@ impl Session {
 confirmation.id.clone(),
 Err(ToolError::ExecutionError("Tool call cancelled by user".to_string()))
 ));
-self.messages.push(response_message);
+push_message(&mut self.messages, response_message);
 if let Some(session_file) = &self.session_file {
 session::persist_messages_with_schedule_id(
 session_file,
@@ -975,7 +976,7 @@ impl Session {
 }
 // otherwise we have a model/tool to render
 else {
-self.messages.push(message.clone());
+push_message(&mut self.messages, message.clone());

 // No need to update description on assistant messages
 if let Some(session_file) = &self.session_file {
@@ -991,7 +992,6 @@ impl Session {
 if interactive {output::hide_thinking()};
 let _ = progress_bars.hide();
 output::render_message(&message, self.debug);
-if interactive {output::show_thinking()};
 }
 }
 Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
@@ -1139,6 +1139,7 @@ impl Session {
 }
 }
 }
+println!();

 Ok(())
 }
@@ -1182,7 +1183,7 @@ impl Session {
 Err(ToolError::ExecutionError(notification.clone())),
 ));
 }
-self.messages.push(response_message);
+self.push_message(response_message);

 // No need for description update here
 if let Some(session_file) = &self.session_file {
@@ -1199,7 +1200,7 @@ impl Session {
 "The existing call to {} was interrupted. How would you like to proceed?",
 last_tool_name
 );
-self.messages.push(Message::assistant().with_text(&prompt));
+self.push_message(Message::assistant().with_text(&prompt));

 // No need for description update here
 if let Some(session_file) = &self.session_file {
@@ -1221,7 +1222,7 @@ impl Session {
 Some(MessageContent::ToolResponse(_)) => {
 // Interruption occurred after a tool had completed but not assistant reply
 let prompt = "The tool calling loop was interrupted. How would you like to proceed?";
-self.messages.push(Message::assistant().with_text(prompt));
+self.push_message(Message::assistant().with_text(prompt));

 // No need for description update here
 if let Some(session_file) = &self.session_file {
@@ -1438,7 +1439,7 @@ impl Session {
 if msg.role == mcp_core::Role::User {
 output::render_message(&msg, self.debug);
 }
-self.messages.push(msg);
+self.push_message(msg);
 }

 if valid {
@@ -1496,6 +1497,10 @@ impl Session {

 Ok(path)
 }
+
+fn push_message(&mut self, message: Message) {
+push_message(&mut self.messages, message);
+}
 }

 fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
@@ -10,7 +10,7 @@ use regex::Regex;
 use serde_json::Value;
 use std::cell::RefCell;
 use std::collections::HashMap;
-use std::io::Error;
+use std::io::{Error, Write};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::time::Duration;
@@ -166,7 +166,8 @@ pub fn render_message(message: &Message, debug: bool) {
 }
 }
 }
 println!();
+let _ = std::io::stdout().flush();
 }

 pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
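With streaming, render_message can now be called for partial chunks that do not end in a newline, which is why the renderer flushes stdout explicitly after printing. A minimal illustration of the effect (not part of the diff itself):

```rust
use std::io::Write;

// Without a trailing newline, line-buffered stdout may hold the text back;
// flushing makes each streamed chunk visible as soon as it is rendered.
print!("partial chunk");
let _ = std::io::stdout().flush();
```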
@@ -225,6 +225,7 @@ async fn handler(
 return;
 }
 };
+let saved_message_count = all_messages.len();

 loop {
 tokio::select! {
@@ -242,16 +243,6 @@ async fn handler(
 ).await;
 break;
 }
-
-
-let session_path = session_path.clone();
-let messages = all_messages.clone();
-let provider = Arc::clone(provider.as_ref().unwrap());
-tokio::spawn(async move {
-if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
-tracing::error!("Failed to store session history: {:?}", e);
-}
-});
 }
 Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
 if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
@@ -303,6 +294,17 @@ async fn handler(
 }
 }

+if all_messages.len() > saved_message_count {
+let provider = Arc::clone(provider.as_ref().unwrap());
+tokio::spawn(async move {
+if let Err(e) =
+session::persist_messages(&session_path, &all_messages, Some(provider)).await
+{
+tracing::error!("Failed to store session history: {:?}", e);
+}
+});
+}
+
 let _ = stream_event(
 MessageEvent::Finish {
 reason: "stop".to_string(),
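The handler now records how many messages it started with and persists the session once, after the agent stream has finished, instead of spawning a persistence task for every event. A simplified sketch of the resulting flow (error handling and the surrounding tokio::select! are omitted):

```rust
// Sketch of the new persistence flow; names follow the diff above.
let saved_message_count = all_messages.len();

// ... drive the agent event stream, appending to all_messages ...

if all_messages.len() > saved_message_count {
    let provider = Arc::clone(provider.as_ref().unwrap());
    tokio::spawn(async move {
        if let Err(e) =
            session::persist_messages(&session_path, &all_messages, Some(provider)).await
        {
            tracing::error!("Failed to store session history: {:?}", e);
        }
    });
}
```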
@@ -81,6 +81,7 @@ fs2 = "0.4.3"
 tokio-stream = "0.1.17"
 dashmap = "6.1"
 ahash = "0.8"
+tokio-util = "0.7.15"

 # Vector database for tool selection
 lancedb = "0.13"
@@ -2,8 +2,12 @@ use anyhow::Result;
 use dotenv::dotenv;
 use goose::{
 message::Message,
-providers::{base::Provider, databricks::DatabricksProvider},
+providers::{
+base::{Provider, Usage},
+databricks::DatabricksProvider,
+},
 };
+use tokio_stream::StreamExt;

 #[tokio::main]
 async fn main() -> Result<()> {
@@ -20,21 +24,24 @@ async fn main() -> Result<()> {
 let message = Message::user().with_text("Tell me a short joke about programming.");

 // Get a response
-let (response, usage) = provider
-.complete("You are a helpful assistant.", &[message], &[])
+let mut stream = provider
+.stream("You are a helpful assistant.", &[message], &[])
 .await?;

-// Print the response and usage statistics
 println!("\nResponse from AI:");
 println!("---------------");
-for content in response.content {
-dbg!(content);
+let mut usage = Usage::default();
+while let Some(Ok((msg, usage_part))) = stream.next().await {
+dbg!(msg);
+usage_part.map(|u| {
+usage += u.usage;
+});
 }
 println!("\nToken Usage:");
 println!("------------");
-println!("Input tokens: {:?}", usage.usage.input_tokens);
-println!("Output tokens: {:?}", usage.usage.output_tokens);
-println!("Total tokens: {:?}", usage.usage.total_tokens);
+println!("Input tokens: {:?}", usage.input_tokens);
+println!("Output tokens: {:?}", usage.output_tokens);
+println!("Total tokens: {:?}", usage.total_tokens);

 Ok(())
 }
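Each item yielded by the stream is an (Option<Message>, Option<ProviderUsage>) pair: text arrives in partial Message chunks, and a usage part may or may not accompany any given chunk, which is why the example accumulates it with +=. A condensed sketch of the consumption loop (assuming a provider value is already constructed):

```rust
let mut stream = provider
    .stream("You are a helpful assistant.", &[message], &[])
    .await?;

let mut usage = Usage::default();
while let Some(Ok((msg, usage_part))) = stream.next().await {
    if let Some(msg) = msg {
        dbg!(msg); // a partial text chunk or a complete tool call
    }
    if let Some(part) = usage_part {
        usage += part.usage; // Usage now implements AddAssign
    }
}
```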
@@ -14,7 +14,7 @@ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
 };
 use crate::agents::sub_recipe_manager::SubRecipeManager;
 use crate::config::{Config, ExtensionConfigManager, PermissionManager};
-use crate::message::Message;
+use crate::message::{push_message, Message};
 use crate::permission::permission_judge::check_tool_permissions;
 use crate::permission::PermissionConfirmation;
 use crate::providers::base::Provider;
@@ -722,6 +722,16 @@ impl Agent {
 });

 loop {
+// Check for final output before incrementing turns or checking max_turns
+// This ensures that if we have a final output ready, we return it immediately
+// without being blocked by the max_turns limit - this is needed for streaming cases
+if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
+if final_output_tool.final_output.is_some() {
+yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
+break;
+}
+}
+
 turns_taken += 1;
 if turns_taken > max_turns {
 yield AgentEvent::Message(Message::assistant().with_text(
@@ -752,262 +762,291 @@ impl Agent {
 }
 }

-match Self::generate_response_from_provider(
+let mut stream = Self::stream_response_from_provider(
 self.provider().await?,
 &system_prompt,
 &messages,
 &tools,
 &toolshim_tools,
-).await {
-Ok((response, usage)) => {
-// Emit model change event if provider is lead-worker
-let provider = self.provider().await?;
-if let Some(lead_worker) = provider.as_lead_worker() {
-// The actual model used is in the usage
-let active_model = usage.model.clone();
-let (lead_model, worker_model) = lead_worker.get_model_info();
-let mode = if active_model == lead_model {
-"lead"
-} else if active_model == worker_model {
-"worker"
-} else {
-"unknown"
-};
+).await?;

-yield AgentEvent::ModelChange {
-model: active_model,
-mode: mode.to_string(),
-};
-}
+let mut added_message = false;
+while let Some(next) = stream.next().await {
+match next {
+Ok((response, usage)) => {
+// Emit model change event if provider is lead-worker
+let provider = self.provider().await?;
+if let Some(lead_worker) = provider.as_lead_worker() {
+if let Some(ref usage) = usage {
+// The actual model used is in the usage
+let active_model = usage.model.clone();
+let (lead_model, worker_model) = lead_worker.get_model_info();
+let mode = if active_model == lead_model {
+"lead"
+} else if active_model == worker_model {
+"worker"
+} else {
+"unknown"
+};

-// record usage for the session in the session file
-if let Some(session_config) = session.clone() {
-Self::update_session_metrics(session_config, &usage, messages.len()).await?;
-}
+yield AgentEvent::ModelChange {
+model: active_model,
+mode: mode.to_string(),
+};

-// categorize the type of requests we need to handle
-let (frontend_requests,
-remaining_requests,
-filtered_response) =
-self.categorize_tool_requests(&response).await;

-// Record tool calls in the router selector
-let selector = self.router_tool_selector.lock().await.clone();
-if let Some(selector) = selector {
-// Record frontend tool calls
-for request in &frontend_requests {
-if let Ok(tool_call) = &request.tool_call {
-if let Err(e) = selector.record_tool_call(&tool_call.name).await {
-tracing::error!("Failed to record frontend tool call: {}", e);
-}
 }
 }
-// Record remaining tool calls
-for request in &remaining_requests {
-if let Ok(tool_call) = &request.tool_call {
-if let Err(e) = selector.record_tool_call(&tool_call.name).await {
-tracing::error!("Failed to record tool call: {}", e);
-}
+// record usage for the session in the session file
+if let Some(session_config) = session.clone() {
+if let Some(ref usage) = usage {
+Self::update_session_metrics(session_config, usage, messages.len()).await?;
 }
 }
-}
-// Yield the assistant's response with frontend tool requests filtered out
-yield AgentEvent::Message(filtered_response.clone());

-tokio::task::yield_now().await;
+if let Some(response) = response {
+// categorize the type of requests we need to handle
+let (frontend_requests,
+remaining_requests,
+filtered_response) =
+self.categorize_tool_requests(&response).await;

-let num_tool_requests = frontend_requests.len() + remaining_requests.len();
-if num_tool_requests == 0 {
-if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
-if final_output_tool.final_output.is_none() {
-tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
-let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
-messages.push(message.clone());
-yield AgentEvent::Message(message);
+// Record tool calls in the router selector
+let selector = self.router_tool_selector.lock().await.clone();
+if let Some(selector) = selector {
+// Record frontend tool calls
+for request in &frontend_requests {
+if let Ok(tool_call) = &request.tool_call {
+if let Err(e) = selector.record_tool_call(&tool_call.name).await {
+tracing::error!("Failed to record frontend tool call: {}", e);
+}
+}
+}
+// Record remaining tool calls
+for request in &remaining_requests {
+if let Ok(tool_call) = &request.tool_call {
+if let Err(e) = selector.record_tool_call(&tool_call.name).await {
+tracing::error!("Failed to record tool call: {}", e);
+}
+}
+}
+}
+// Yield the assistant's response with frontend tool requests filtered out
+yield AgentEvent::Message(filtered_response.clone());
+
+tokio::task::yield_now().await;
+
+let num_tool_requests = frontend_requests.len() + remaining_requests.len();
+if num_tool_requests == 0 {
+if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
+if final_output_tool.final_output.is_none() {
+tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
+let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
+messages.push(message.clone());
+yield AgentEvent::Message(message);
+continue;
+} else {
+let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
+messages.push(message.clone());
+yield AgentEvent::Message(message);
+// Set added_message to true and continue to end the current iteration
+added_message = true;
+push_message(&mut messages, response);
+continue;
+}
+}
+// If there's no final output tool and no tool requests, continue the loop
 continue;
-} else {
-let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
-messages.push(message.clone());
-yield AgentEvent::Message(message);
 }

+// Process tool requests depending on frontend tools and then goose_mode
+let message_tool_response = Arc::new(Mutex::new(Message::user()));
+
+// First handle any frontend tool requests
+let mut frontend_tool_stream = self.handle_frontend_tool_requests(
+&frontend_requests,
+message_tool_response.clone()
+);
+
+// we have a stream of frontend tools to handle, inside the stream
+// execution is yeield back to this reply loop, and is of the same Message
+// type, so we can yield that back up to be handled
+while let Some(msg) = frontend_tool_stream.try_next().await? {
+yield AgentEvent::Message(msg);
+}
+
+// Clone goose_mode once before the match to avoid move issues
+let mode = goose_mode.clone();
+if mode.as_str() == "chat" {
+// Skip all tool calls in chat mode
+for request in remaining_requests {
+let mut response = message_tool_response.lock().await;
+*response = response.clone().with_tool_response(
+request.id.clone(),
+Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]),
+);
+}
+} else {
+// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
+// What remains is handling the remaining tool requests (enable extension,
+// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
+let mut permission_manager = PermissionManager::default();
+let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
+&remaining_requests,
+&mode,
+tools_with_readonly_annotation.clone(),
+tools_without_annotation.clone(),
+&mut permission_manager,
+self.provider().await?).await;
+
+// Handle pre-approved and read-only tools in parallel
+let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
+
+// Skip the confirmation for approved tools
+for request in &permission_check_result.approved {
+if let Ok(tool_call) = request.tool_call.clone() {
+let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
+
+tool_futures.push((req_id, match tool_result {
+Ok(result) => tool_stream(
+result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
+result.result,
+),
+Err(e) => tool_stream(
+Box::new(stream::empty()),
+futures::future::ready(Err(e)),
+),
+}));
+}
+}
+
+for request in &permission_check_result.denied {
+let mut response = message_tool_response.lock().await;
+*response = response.clone().with_tool_response(
+request.id.clone(),
+Ok(vec![Content::text(DECLINED_RESPONSE)]),
+);
+}
+
+// We need interior mutability in handle_approval_tool_requests
+let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
+
+// Process tools requiring approval (enable extension, regular tool calls)
+let mut tool_approval_stream = self.handle_approval_tool_requests(
+&permission_check_result.needs_approval,
+tool_futures_arc.clone(),
+&mut permission_manager,
+message_tool_response.clone()
+);
+
+// We have a stream of tool_approval_requests to handle
+// Execution is yielded back to this reply loop, and is of the same Message
+// type, so we can yield the Message back up to be handled and grab any
+// confirmations or denials
+while let Some(msg) = tool_approval_stream.try_next().await? {
+yield AgentEvent::Message(msg);
+}
+
+tool_futures = {
+// Lock the mutex asynchronously
+let mut futures_lock = tool_futures_arc.lock().await;
+// Drain the vector and collect into a new Vec
+futures_lock.drain(..).collect::<Vec<_>>()
+};
+
+let with_id = tool_futures
+.into_iter()
+.map(|(request_id, stream)| {
+stream.map(move |item| (request_id.clone(), item))
+})
+.collect::<Vec<_>>();
+
+let mut combined = stream::select_all(with_id);
+
+let mut all_install_successful = true;
+
+while let Some((request_id, item)) = combined.next().await {
+match item {
+ToolStreamItem::Result(output) => {
+if enable_extension_request_ids.contains(&request_id) && output.is_err(){
+all_install_successful = false;
+}
+let mut response = message_tool_response.lock().await;
+*response = response.clone().with_tool_response(request_id, output);
+},
+ToolStreamItem::Message(msg) => {
+yield AgentEvent::McpNotification((request_id, msg))
+}
+}
+}
+
+// Update system prompt and tools if installations were successful
+if all_install_successful {
+(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
+}
+}
+
+let final_message_tool_resp = message_tool_response.lock().await.clone();
+yield AgentEvent::Message(final_message_tool_resp.clone());
+
+added_message = true;
+push_message(&mut messages, response);
+push_message(&mut messages, final_message_tool_resp);
+
+// Check for MCP notifications from subagents again before next iteration
+// Note: These are already handled as McpNotification events above,
+// so we don't need to convert them to assistant messages here.
+// This was causing duplicate plain-text notifications.
+// let mcp_notifications = self.get_mcp_notifications().await;
+// for notification in mcp_notifications {
+// // Extract subagent info from the notification data for assistant messages
+// if let JsonRpcMessage::Notification(ref notif) = notification {
+// if let Some(params) = &notif.params {
+// if let Some(data) = params.get("data") {
+// if let (Some(subagent_id), Some(message)) = (
+// data.get("subagent_id").and_then(|v| v.as_str()),
+// data.get("message").and_then(|v| v.as_str())
+// ) {
+// yield AgentEvent::Message(
+// Message::assistant().with_text(
+// format!("Subagent {}: {}", subagent_id, message)
+// )
+// );
+// }
+// }
+// }
+// }
+// }
 }
+},
+Err(ProviderError::ContextLengthExceeded(_)) => {
+// At this point, the last message should be a user message
+// because call to provider led to context length exceeded error
+// Immediately yield a special message and break
+yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
+"The context length of the model has been exceeded. Please start a new session and try again.",
+));
+break;
+},
+Err(e) => {
+// Create an error message & terminate the stream
+error!("Error: {}", e);
+yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
 break;
 }

-// Process tool requests depending on frontend tools and then goose_mode
-let message_tool_response = Arc::new(Mutex::new(Message::user()));
-
-// First handle any frontend tool requests
-let mut frontend_tool_stream = self.handle_frontend_tool_requests(
-&frontend_requests,
-message_tool_response.clone()
-);
-
-// we have a stream of frontend tools to handle, inside the stream
-// execution is yeield back to this reply loop, and is of the same Message
-// type, so we can yield that back up to be handled
-while let Some(msg) = frontend_tool_stream.try_next().await? {
-yield AgentEvent::Message(msg);
-}
-
-// Clone goose_mode once before the match to avoid move issues
-let mode = goose_mode.clone();
-if mode.as_str() == "chat" {
-// Skip all tool calls in chat mode
-for request in remaining_requests {
-let mut response = message_tool_response.lock().await;
-*response = response.clone().with_tool_response(
-request.id.clone(),
-Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]),
-);
-}
-} else {
-// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
-// What remains is handling the remaining tool requests (enable extension,
-// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
-let mut permission_manager = PermissionManager::default();
-let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
-&remaining_requests,
-&mode,
-tools_with_readonly_annotation.clone(),
-tools_without_annotation.clone(),
-&mut permission_manager,
-self.provider().await?).await;
-
-// Handle pre-approved and read-only tools in parallel
-let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
-
-// Skip the confirmation for approved tools
-for request in &permission_check_result.approved {
-if let Ok(tool_call) = request.tool_call.clone() {
-let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
-
-tool_futures.push((req_id, match tool_result {
-Ok(result) => tool_stream(
-result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
-result.result,
-),
-Err(e) => tool_stream(
-Box::new(stream::empty()),
-futures::future::ready(Err(e)),
-),
-}));
-}
-}
-
-for request in &permission_check_result.denied {
-let mut response = message_tool_response.lock().await;
-*response = response.clone().with_tool_response(
-request.id.clone(),
-Ok(vec![Content::text(DECLINED_RESPONSE)]),
-);
-}
-
-// We need interior mutability in handle_approval_tool_requests
-let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
-
-// Process tools requiring approval (enable extension, regular tool calls)
-let mut tool_approval_stream = self.handle_approval_tool_requests(
-&permission_check_result.needs_approval,
-tool_futures_arc.clone(),
-&mut permission_manager,
-message_tool_response.clone()
-);
-
-// We have a stream of tool_approval_requests to handle
-// Execution is yielded back to this reply loop, and is of the same Message
-// type, so we can yield the Message back up to be handled and grab any
-// confirmations or denials
-while let Some(msg) = tool_approval_stream.try_next().await? {
-yield AgentEvent::Message(msg);
-}
-
-tool_futures = {
-// Lock the mutex asynchronously
-let mut futures_lock = tool_futures_arc.lock().await;
-// Drain the vector and collect into a new Vec
-futures_lock.drain(..).collect::<Vec<_>>()
-};
-
-let with_id = tool_futures
-.into_iter()
-.map(|(request_id, stream)| {
-stream.map(move |item| (request_id.clone(), item))
-})
-.collect::<Vec<_>>();
-
-let mut combined = stream::select_all(with_id);
-
-let mut all_install_successful = true;
-
-while let Some((request_id, item)) = combined.next().await {
-match item {
-ToolStreamItem::Result(output) => {
-if enable_extension_request_ids.contains(&request_id) && output.is_err(){
-all_install_successful = false;
-}
-let mut response = message_tool_response.lock().await;
-*response = response.clone().with_tool_response(request_id, output);
-},
-ToolStreamItem::Message(msg) => {
-yield AgentEvent::McpNotification((request_id, msg))
-}
-}
-}
-
-// Update system prompt and tools if installations were successful
-if all_install_successful {
-(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
-}
-}
-
-let final_message_tool_resp = message_tool_response.lock().await.clone();
-yield AgentEvent::Message(final_message_tool_resp.clone());
-
-messages.push(response);
-messages.push(final_message_tool_resp);
-
-// Check for MCP notifications from subagents again before next iteration
-// Note: These are already handled as McpNotification events above,
-// so we don't need to convert them to assistant messages here.
-// This was causing duplicate plain-text notifications.
-// let mcp_notifications = self.get_mcp_notifications().await;
-// for notification in mcp_notifications {
-// // Extract subagent info from the notification data for assistant messages
-// if let JsonRpcMessage::Notification(ref notif) = notification {
-// if let Some(params) = &notif.params {
-// if let Some(data) = params.get("data") {
-// if let (Some(subagent_id), Some(message)) = (
-// data.get("subagent_id").and_then(|v| v.as_str()),
-// data.get("message").and_then(|v| v.as_str())
-// ) {
-// yield AgentEvent::Message(
-// Message::assistant().with_text(
-// format!("Subagent {}: {}", subagent_id, message)
-// )
-// );
-// }
-// }
-// }
-// }
-// }
-},
-Err(ProviderError::ContextLengthExceeded(_)) => {
-// At this point, the last message should be a user message
-// because call to provider led to context length exceeded error
-// Immediately yield a special message and break
-yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
-"The context length of the model has been exceeded. Please start a new session and try again.",
-));
-break;
-},
-Err(e) => {
-// Create an error message & terminate the stream
-error!("Error: {}", e);
-yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
-break;
 }
 }
+if !added_message {
+if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
+if final_output_tool.final_output.is_none() {
+tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
+yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
+continue;
+} else {
+yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
+}
+}
+break;
+}

 // Yield control back to the scheduler to prevent blocking
 tokio::task::yield_now().await;
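The shape of this hunk is easier to see in outline: the single match on a completed provider response becomes a while loop over a MessageStream, the old tool-handling body moves inside the Ok arm, and an added_message flag decides whether the surrounding agent loop needs another turn. A condensed, non-literal sketch (the real code is an async_stream generator yielding AgentEvent values):

```rust
// Condensed outline of the new reply loop; most details and arms elided.
let mut stream = Self::stream_response_from_provider(
    self.provider().await?, &system_prompt, &messages, &tools, &toolshim_tools,
).await?;

let mut added_message = false;
while let Some(next) = stream.next().await {
    match next {
        Ok((response, usage)) => {
            // emit ModelChange / record usage when present, then categorize
            // and dispatch tool requests exactly as the old Ok arm did
        }
        Err(e) => {
            // yield a context-length or generic error message, then stop
            break;
        }
    }
}
if !added_message {
    // nothing new was appended: emit any final output and leave the agent loop
}
```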
@@ -2,10 +2,13 @@ use anyhow::Result;
 use std::collections::HashSet;
 use std::sync::Arc;

+use async_stream::try_stream;
+use futures::stream::StreamExt;
+
 use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
 use crate::config::Config;
 use crate::message::{Message, MessageContent, ToolRequest};
-use crate::providers::base::{Provider, ProviderUsage};
+use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
 use crate::providers::errors::ProviderError;
 use crate::providers::toolshim::{
 augment_message_with_tool_calls, convert_tool_messages_to_text,
@@ -16,6 +19,19 @@ use mcp_core::tool::Tool;

 use super::super::agents::Agent;

+async fn toolshim_postprocess(
+response: Message,
+toolshim_tools: &[Tool],
+) -> Result<Message, ProviderError> {
+let interpreter = OllamaInterpreter::new().map_err(|e| {
+ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
+})?;
+
+augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
+.await
+.map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e)))
+}
+
 impl Agent {
 /// Prepares tools and system prompt for a provider request
 pub(crate) async fn prepare_tools_and_prompt(
@@ -128,25 +144,67 @@ impl Agent {
 .complete(system_prompt, &messages_for_provider, tools)
 .await?;

-// Store the model information in the global store
 crate::providers::base::set_current_model(&usage.model);

-// Post-process / structure the response only if tool interpretation is enabled
 if config.toolshim {
-let interpreter = OllamaInterpreter::new().map_err(|e| {
-ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
-})?;
-
-response = augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
-.await
-.map_err(|e| {
-ProviderError::ExecutionError(format!("Failed to augment message: {}", e))
-})?;
+response = toolshim_postprocess(response, toolshim_tools).await?;
 }

 Ok((response, usage))
 }

+/// Stream a response from the LLM provider.
+/// Handles toolshim transformations if needed
+pub(crate) async fn stream_response_from_provider(
+provider: Arc<dyn Provider>,
+system_prompt: &str,
+messages: &[Message],
+tools: &[Tool],
+toolshim_tools: &[Tool],
+) -> Result<MessageStream, ProviderError> {
+let config = provider.get_model_config();
+
+// Convert tool messages to text if toolshim is enabled
+let messages_for_provider = if config.toolshim {
+convert_tool_messages_to_text(messages)
+} else {
+messages.to_vec()
+};
+
+// Clone owned data to move into the async stream
+let system_prompt = system_prompt.to_owned();
+let tools = tools.to_owned();
+let toolshim_tools = toolshim_tools.to_owned();
+let provider = provider.clone();
+
+let mut stream = if provider.supports_streaming() {
+provider
+.stream(system_prompt.as_str(), &messages_for_provider, &tools)
+.await?
+} else {
+let (message, usage) = provider
+.complete(system_prompt.as_str(), &messages_for_provider, &tools)
+.await?;
+stream_from_single_message(message, usage)
+};
+
+Ok(Box::pin(try_stream! {
+while let Some(Ok((mut message, usage))) = stream.next().await {
+// Store the model information in the global store
+if let Some(usage) = usage.as_ref() {
+crate::providers::base::set_current_model(&usage.model);
+}
+
+// Post-process / structure the response only if tool interpretation is enabled
+if message.is_some() && config.toolshim {
+message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?);
+}
+
+yield (message, usage);
+}
+}))
+}
+
 /// Categorize tool requests from the response into different types
 /// Returns:
 /// - frontend_requests: Tool requests that should be handled by the frontend
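stream_response_from_provider gives the agent a single code path whether or not the provider can stream: a provider that reports supports_streaming() == false is driven through the existing complete call, and its one response is adapted into a one-item stream. The core of that fallback, restated as a standalone sketch (names as in the diff above):

```rust
let mut stream = if provider.supports_streaming() {
    provider
        .stream(system_prompt, &messages_for_provider, &tools)
        .await?
} else {
    // Fall back to the blocking call and wrap the single result
    // so callers always consume a MessageStream.
    let (message, usage) = provider
        .complete(system_prompt, &messages_for_provider, &tools)
        .await?;
    stream_from_single_message(message, usage)
};
```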
@@ -191,6 +249,7 @@ impl Agent {
 }

 let filtered_message = Message {
+id: response.id.clone(),
 role: response.role.clone(),
 created: response.created,
 content: filtered_content,
@@ -247,14 +247,14 @@ mod tests {
 _tools: &[Tool],
 ) -> Result<(Message, ProviderUsage), ProviderError> {
 Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
 text: "Summarized content".to_string(),
 annotations: None,
 })],
-},
+),
 ProviderUsage::new("mock".to_string(), Usage::default()),
 ))
 }
@@ -277,30 +277,26 @@ mod tests {
 }

 fn set_up_text_message(text: &str, role: Role) -> Message {
-Message {
-role,
-created: 0,
-content: vec![MessageContent::text(text.to_string())],
-}
+Message::new(role, 0, vec![MessageContent::text(text.to_string())])
 }

 fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
-Message {
-role: Role::Assistant,
-created: 0,
-content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
-}
+Message::new(
+Role::Assistant,
+0,
+vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
+)
 }

 fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
-Message {
-role: Role::User,
-created: 0,
-content: vec![MessageContent::tool_response(
+Message::new(
+Role::User,
+0,
+vec![MessageContent::tool_response(
 id.to_string(),
 Ok(tool_response),
 )],
-}
+)
 }

 #[tokio::test]
@@ -448,14 +444,14 @@ mod tests {

 #[tokio::test]
 async fn test_reintegrate_removed_messages() {
-let summarized_messages = vec![Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+let summarized_messages = vec![Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
 text: "Summary".to_string(),
 annotations: None,
 })],
-}];
+)];
 let arguments = json!({
 "param1": "value1"
 });
@@ -303,15 +303,46 @@ impl From<PromptMessage> for Message {
 /// A message to or from an LLM
 #[serde(rename_all = "camelCase")]
 pub struct Message {
+pub id: Option<String>,
 pub role: Role,
 pub created: i64,
 pub content: Vec<MessageContent>,
 }

+pub fn push_message(messages: &mut Vec<Message>, message: Message) {
+if let Some(last) = messages
+.last_mut()
+.filter(|m| m.id.is_some() && m.id == message.id)
+{
+match (last.content.last_mut(), message.content.last()) {
+(Some(MessageContent::Text(ref mut last)), Some(MessageContent::Text(new)))
+if message.content.len() == 1 =>
+{
+last.text.push_str(&new.text);
+}
+(_, _) => {
+last.content.extend(message.content);
+}
+}
+} else {
+messages.push(message);
+}
+}
+
 impl Message {
+pub fn new(role: Role, created: i64, content: Vec<MessageContent>) -> Self {
+Message {
+id: None,
+role,
+created,
+content,
+}
+}
+
 /// Create a new user message with the current timestamp
 pub fn user() -> Self {
 Message {
+id: None,
 role: Role::User,
 created: Utc::now().timestamp(),
 content: Vec::new(),
@@ -321,6 +352,7 @@ impl Message {
 /// Create a new assistant message with the current timestamp
 pub fn assistant() -> Self {
 Message {
+id: None,
 role: Role::Assistant,
 created: Utc::now().timestamp(),
 content: Vec::new(),
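push_message is what turns a run of streamed chunks into a single conversation entry: when the incoming message carries the same id as the last message in the vector, its text is appended (or its content extended) instead of starting a new entry. A small illustration with hypothetical ids:

```rust
// Hypothetical ids; Message::id is the new Option<String> field added above.
let mut first = Message::assistant().with_text("Hello, ");
first.id = Some("msg-1".to_string());
let mut second = Message::assistant().with_text("world!");
second.id = Some("msg-1".to_string());

let mut history: Vec<Message> = Vec::new();
push_message(&mut history, first);
push_message(&mut history, second);

// Both chunks coalesce into one message whose text reads "Hello, world!".
assert_eq!(history.len(), 1);
```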
@@ -81,10 +81,10 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
 })
 .collect();
 let mut check_messages = vec![];
-check_messages.push(Message {
-role: mcp_core::Role::User,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+check_messages.push(Message::new(
+mcp_core::Role::User,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
 text: format!(
 "Here are the tool requests: {:?}\n\nAnalyze the tool requests and list the tools that perform read-only operations. \
 \n\nGuidelines for Read-Only Operations: \
@@ -96,7 +96,7 @@ fn create_check_messages(tool_requests: Vec<&ToolRequest>) -> Vec<Message> {
 ),
 annotations: None,
 })],
-});
+));
 check_messages
 }

@@ -296,10 +296,10 @@ mod tests {
 _tools: &[Tool],
 ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
 Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::ToolRequest(ToolRequest {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::ToolRequest(ToolRequest {
 id: "mock_tool_request".to_string(),
 tool_call: ToolResult::Ok(ToolCall {
 name: "platform__tool_by_tool_permission".to_string(),
@@ -308,7 +308,7 @@ mod tests {
 }),
 }),
 })],
-},
+),
 ProviderUsage::new("mock".to_string(), Usage::default()),
 ))
 }
@@ -354,10 +354,10 @@ mod tests {

 #[test]
 fn test_extract_read_only_tools() {
-let message = Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::ToolRequest(ToolRequest {
+let message = Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::ToolRequest(ToolRequest {
 id: "tool_2".to_string(),
 tool_call: ToolResult::Ok(ToolCall {
 name: "platform__tool_by_tool_permission".to_string(),
@@ -366,7 +366,7 @@ mod tests {
 }),
 }),
 })],
-};
+);

 let result = extract_read_only_tools(&message);
 assert!(result.is_some());
@@ -1,4 +1,5 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use futures::Stream;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::errors::ProviderError;
|
use super::errors::ProviderError;
|
||||||
@@ -8,6 +9,8 @@ use mcp_core::tool::Tool;
|
|||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
use std::ops::{Add, AddAssign};
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
|
||||||
/// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias
|
/// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias
|
||||||
@@ -184,13 +187,43 @@ impl ProviderUsage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
|
||||||
pub struct Usage {
|
pub struct Usage {
|
||||||
pub input_tokens: Option<i32>,
|
pub input_tokens: Option<i32>,
|
||||||
pub output_tokens: Option<i32>,
|
pub output_tokens: Option<i32>,
|
||||||
pub total_tokens: Option<i32>,
|
pub total_tokens: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
|
||||||
|
where
|
||||||
|
T: Add<Output = T> + Default,
|
||||||
|
{
|
||||||
|
match (a, b) {
|
||||||
|
(Some(x), Some(y)) => Some(x + y),
|
||||||
|
(Some(x), None) => Some(x + T::default()),
|
||||||
|
(None, Some(y)) => Some(T::default() + y),
|
||||||
|
(None, None) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Add for Usage {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn add(self, other: Self) -> Self {
|
||||||
|
Self {
|
||||||
|
input_tokens: sum_optionals(self.input_tokens, other.input_tokens),
|
||||||
|
output_tokens: sum_optionals(self.output_tokens, other.output_tokens),
|
||||||
|
total_tokens: sum_optionals(self.total_tokens, other.total_tokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AddAssign for Usage {
|
||||||
|
fn add_assign(&mut self, rhs: Self) {
|
||||||
|
*self = *self + rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Usage {
|
impl Usage {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
input_tokens: Option<i32>,
|
input_tokens: Option<i32>,
|
||||||
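
The new `Add`/`AddAssign` impls above treat a missing counter as zero rather than discarding the other side, which is what lets usage accumulate across streamed chunks. A minimal standalone sketch of that behaviour (the struct and impls are re-declared here so the snippet compiles on its own; it is an illustration, not the crate code):

use std::ops::{Add, AddAssign};

#[derive(Debug, Clone, Copy, Default)]
struct Usage {
    input_tokens: Option<i32>,
    output_tokens: Option<i32>,
    total_tokens: Option<i32>,
}

fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
where
    T: Add<Output = T> + Default,
{
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (Some(x), None) => Some(x + T::default()),
        (None, Some(y)) => Some(T::default() + y),
        (None, None) => None,
    }
}

impl Add for Usage {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: sum_optionals(self.input_tokens, other.input_tokens),
            output_tokens: sum_optionals(self.output_tokens, other.output_tokens),
            total_tokens: sum_optionals(self.total_tokens, other.total_tokens),
        }
    }
}

impl AddAssign for Usage {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

fn main() {
    // accumulate usage across several streamed chunks
    let mut total = Usage::default();
    total += Usage { input_tokens: Some(10), output_tokens: Some(3), total_tokens: Some(13) };
    total += Usage { input_tokens: Some(5), output_tokens: None, total_tokens: Some(5) };
    assert_eq!(total.input_tokens, Some(15));
    assert_eq!(total.output_tokens, Some(3));
    assert_eq!(total.total_tokens, Some(18));
}
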
@@ -270,6 +303,21 @@ pub trait Provider: Send + Sync {
None
}

+async fn stream(
+&self,
+_system: &str,
+_messages: &[Message],
+_tools: &[Tool],
+) -> Result<MessageStream, ProviderError> {
+Err(ProviderError::NotImplemented(
+"streaming not implemented".to_string(),
+))
+}

+fn supports_streaming(&self) -> bool {
+false
+}

/// Get the currently active model name
/// For regular providers, this returns the configured model
/// For LeadWorkerProvider, this returns the currently active model (lead or worker)

@@ -282,6 +330,18 @@ pub trait Provider: Send + Sync {
}
}

+/// A message stream yields partial text content but complete tool calls, all within the Message object
+/// So a message with text will contain potentially just a word of a longer response, but tool calls
+/// messages will only be yielded once concatenated.
+pub type MessageStream = Pin<
+Box<dyn Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> + Send>,
+>;

+pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
+let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) });
+Box::pin(stream)
+}

#[cfg(test)]
mod tests {
use super::*;
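
A sketch of how a consumer might drain the `MessageStream` defined above (assuming `futures::StreamExt` is in scope; partial text arrives as separate `Message` values and the `ProviderUsage` typically only shows up near the end of the stream):

use futures::StreamExt;

async fn drain(mut stream: MessageStream) -> Result<(), ProviderError> {
    while let Some(item) = stream.next().await {
        let (message, usage) = item?;
        if let Some(message) = message {
            // either a fragment of assistant text or a fully assembled tool call
            println!("chunk: {:?}", message);
        }
        if let Some(usage) = usage {
            println!("usage for {}: {:?}", usage.model, usage.usage);
        }
    }
    Ok(())
}
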
@@ -219,11 +219,11 @@ impl ClaudeCodeProvider {
annotations: None,
})];

-let response_message = Message {
-role: Role::Assistant,
-created: chrono::Utc::now().timestamp(),
-content: message_content,
-};
+let response_message = Message::new(
+Role::Assistant,
+chrono::Utc::now().timestamp(),
+message_content,
+);

Ok((response_message, usage))
}

@@ -353,14 +353,14 @@ impl ClaudeCodeProvider {
println!("================================");
}

-let message = Message {
-role: mcp_core::Role::Assistant,
-created: chrono::Utc::now().timestamp(),
-content: vec![MessageContent::Text(mcp_core::content::TextContent {
+let message = Message::new(
+mcp_core::Role::Assistant,
+chrono::Utc::now().timestamp(),
+vec![MessageContent::Text(mcp_core::content::TextContent {
text: description.clone(),
annotations: None,
})],
-};
+);

let usage = Usage::default();

@@ -1,4 +1,16 @@
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use anyhow::Result;
+use async_stream::try_stream;
+use async_trait::async_trait;
+use futures::TryStreamExt;
+use reqwest::{Client, StatusCode};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::io;
+use std::time::Duration;
+use tokio::pin;
+use tokio_util::io::StreamReader;

+use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::embedding::EmbeddingCapable;
use super::errors::ProviderError;
use super::formats::databricks::{create_request, get_usage, response_to_message};

@@ -7,17 +19,13 @@ use super::utils::{get_model, ImageFormat};
use crate::config::ConfigError;
use crate::message::Message;
use crate::model::ModelConfig;
+use crate::providers::formats::databricks::response_to_streaming_message;
use mcp_core::tool::Tool;
use serde_json::json;
-use url::Url;

-use anyhow::Result;
-use async_trait::async_trait;
-use reqwest::{Client, StatusCode};
-use serde::{Deserialize, Serialize};
-use serde_json::Value;
-use std::time::Duration;
use tokio::time::sleep;
+use tokio_stream::StreamExt;
+use tokio_util::codec::{FramedRead, LinesCodec};
+use url::Url;

const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";

@@ -266,9 +274,6 @@ impl DatabricksProvider {
}

async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-let base_url = Url::parse(&self.host)
-.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;

// Check if this is an embedding request by looking at the payload structure
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
let path = if is_embedding {

@@ -279,56 +284,71 @@ impl DatabricksProvider {
format!("serving-endpoints/{}/invocations", self.model.model_name)
};

-let url = base_url.join(&path).map_err(|e| {
+match self.post_with_retry(path.as_str(), &payload).await {
+Ok(res) => res.json().await.map_err(|_| {
+ProviderError::RequestFailed("Response body is not valid JSON".to_string())
+}),
+Err(e) => Err(e),
+}
+}

+async fn post_with_retry(
+&self,
+path: &str,
+payload: &Value,
+) -> Result<reqwest::Response, ProviderError> {
+let base_url = Url::parse(&self.host)
+.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+let url = base_url.join(path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

-// Initialize retry counter
let mut attempts = 0;
-let mut last_error = None;

loop {
-// Check if we've exceeded max retries
-if attempts > 0 && attempts > self.retry_config.max_retries {
-let error_msg = format!(
-"Exceeded maximum retry attempts ({}) for rate limiting (429)",
-self.retry_config.max_retries
-);
-tracing::error!("{}", error_msg);
-return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
-}

let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url.clone())
.header("Authorization", auth_header)
-.json(&payload)
+.json(payload)
.send()
.await?;

let status = response.status();
-let payload: Option<Value> = response.json().await.ok();

-match status {
-StatusCode::OK => {
-return payload.ok_or_else(|| {
-ProviderError::RequestFailed("Response body is not valid JSON".to_string())
-});
-}
-StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
-return Err(ProviderError::Authentication(format!(
-"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
-Status: {}. Response: {:?}",
-status, payload
-)));
+break match status {
+StatusCode::OK => Ok(response),
+StatusCode::TOO_MANY_REQUESTS
+| StatusCode::INTERNAL_SERVER_ERROR
+| StatusCode::SERVICE_UNAVAILABLE => {
+if attempts < self.retry_config.max_retries {
+attempts += 1;
+tracing::warn!(
+"{}: retrying ({}/{})",
+status,
+attempts,
+self.retry_config.max_retries
+);

+let delay = self.retry_config.delay_for_attempt(attempts);
+tracing::info!("Backing off for {:?} before retry", delay);
+sleep(delay).await;

+continue;
+}

+Err(match status {
+StatusCode::TOO_MANY_REQUESTS => {
+ProviderError::RateLimitExceeded("Rate limit exceeded".to_string())
+}
+_ => ProviderError::ServerError("Server error".to_string()),
+})
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
-let payload_str = serde_json::to_string(&payload)
-.unwrap_or_default()
-.to_lowercase();
+let bytes = response.bytes().await?;
+let payload_str = String::from_utf8_lossy(&bytes).to_lowercase();
let check_phrases = [
"too long",
"context length",

@@ -347,13 +367,13 @@ impl DatabricksProvider {
}

let mut error_msg = "Unknown error".to_string();
-if let Some(payload) = &payload {
+if let Ok(response_json) = serde_json::from_slice::<Value>(&bytes) {
// try to convert message to string, if that fails use external_model_message
-error_msg = payload
+error_msg = response_json
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
-payload
+response_json
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())

@@ -366,7 +386,7 @@ impl DatabricksProvider {
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
-status, payload
+status, payload_str
)
);
return Err(ProviderError::RequestFailed(format!(

@@ -374,50 +394,13 @@ impl DatabricksProvider {
status, error_msg
)));
}
-StatusCode::TOO_MANY_REQUESTS => {
-attempts += 1;
-let error_msg = format!(
-"Rate limit exceeded (attempt {}/{}): {:?}",
-attempts, self.retry_config.max_retries, payload
-);
-tracing::warn!("{}. Retrying after backoff...", error_msg);

-// Store the error in case we need to return it after max retries
-last_error = Some(ProviderError::RateLimitExceeded(error_msg));

-// Calculate and apply the backoff delay
-let delay = self.retry_config.delay_for_attempt(attempts);
-tracing::info!("Backing off for {:?} before retry", delay);
-sleep(delay).await;

-// Continue to the next retry attempt
-continue;
-}
-StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
-attempts += 1;
-let error_msg = format!(
-"Server error (attempt {}/{}): {:?}",
-attempts, self.retry_config.max_retries, payload
-);
-tracing::warn!("{}. Retrying after backoff...", error_msg);

-// Store the error in case we need to return it after max retries
-last_error = Some(ProviderError::ServerError(error_msg));

-// Calculate and apply the backoff delay
-let delay = self.retry_config.delay_for_attempt(attempts);
-tracing::info!("Backing off for {:?} before retry", delay);
-sleep(delay).await;

-// Continue to the next retry attempt
-continue;
-}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
-status, payload
+status,
+response.text().await.ok().unwrap_or_default()
)
);
return Err(ProviderError::RequestFailed(format!(

@@ -425,7 +408,7 @@ impl DatabricksProvider {
status
)));
}
-}
+};
}
}
}
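
`post_with_retry` above defers its backoff schedule to `self.retry_config.delay_for_attempt(attempts)`, which is not part of this diff. For orientation, a typical exponential-backoff shape for such a helper might look like the following (an assumption about `RetryConfig`, not the crate's actual implementation):

use std::time::Duration;

struct RetryConfig {
    max_retries: usize,
    initial_delay_ms: u64,
    max_delay_ms: u64,
}

impl RetryConfig {
    fn delay_for_attempt(&self, attempt: usize) -> Duration {
        // double the delay on every retry, capped so a long outage
        // doesn't grow the sleep unboundedly
        let exponent = attempt.saturating_sub(1).min(16) as u32;
        let delay_ms = self
            .initial_delay_ms
            .saturating_mul(1u64 << exponent)
            .min(self.max_delay_ms);
        Duration::from_millis(delay_ms)
    }
}

fn main() {
    let cfg = RetryConfig { max_retries: 3, initial_delay_ms: 250, max_delay_ms: 4_000 };
    for attempt in 1..=cfg.max_retries {
        println!("attempt {attempt}: back off {:?}", cfg.delay_for_attempt(attempt));
    }
}
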
@@ -472,13 +455,12 @@ impl Provider for DatabricksProvider {

// Parse response
let message = response_to_message(response.clone())?;
-let usage = match get_usage(&response) {
-Ok(usage) => usage,
-Err(ProviderError::UsageError(e)) => {
-tracing::debug!("Failed to get usage data: {}", e);
+let usage = match response.get("usage").map(get_usage) {
+Some(usage) => usage,
+None => {
+tracing::debug!("Failed to get usage data");
Usage::default()
}
-Err(e) => return Err(e),
};
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);

@@ -486,6 +468,54 @@ impl Provider for DatabricksProvider {
Ok((message, ProviderUsage::new(model, usage)))
}

+async fn stream(
+&self,
+system: &str,
+messages: &[Message],
+tools: &[Tool],
+) -> Result<MessageStream, ProviderError> {
+let mut payload = create_request(&self.model, system, messages, tools, &self.image_format)?;
+// Remove the model key which is part of the url with databricks
+payload
+.as_object_mut()
+.expect("payload should have model key")
+.remove("model");

+payload
+.as_object_mut()
+.unwrap()
+.insert("stream".to_string(), Value::Bool(true));

+let response = self
+.post_with_retry(
+format!("serving-endpoints/{}/invocations", self.model.model_name).as_str(),
+&payload,
+)
+.await?;

+// Map reqwest error to io::Error
+let stream = response.bytes_stream().map_err(io::Error::other);

+let model_config = self.model.clone();
+// Wrap in a line decoder and yield lines inside the stream
+Ok(Box::pin(try_stream! {
+let stream_reader = StreamReader::new(stream);
+let framed = FramedRead::new(stream_reader, LinesCodec::new()).map_err(anyhow::Error::from);

+let message_stream = response_to_streaming_message(framed);
+pin!(message_stream);
+while let Some(message) = message_stream.next().await {
+let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
+super::utils::emit_debug_trace(&model_config, &payload, &message, &usage.as_ref().map(|f| f.usage).unwrap_or_default());
+yield (message, usage);
+}
+}))
+}

+fn supports_streaming(&self) -> bool {
+true
+}

fn supports_embeddings(&self) -> bool {
true
}
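
With `supports_streaming()` and `stream()` in place, a caller can keep handing a single `MessageStream` to the UI while still supporting providers that only produce one complete message. A sketch of that fallback pattern (assuming the existing non-streaming trait method is `complete` with the `(Message, ProviderUsage)` signature seen in the mock providers elsewhere in this diff):

async fn complete_or_stream(
    provider: &dyn Provider,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
    if provider.supports_streaming() {
        // providers like Databricks yield incremental chunks
        provider.stream(system, messages, tools).await
    } else {
        // everyone else still returns a single message; wrap it so the
        // consumer only ever deals with a MessageStream
        let (message, usage) = provider.complete(system, messages, tools).await?;
        Ok(stream_from_single_message(message, usage))
    }
}
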
@@ -23,6 +23,9 @@ pub enum ProviderError {

#[error("Usage data error: {0}")]
UsageError(String),

+#[error("Unsupported operation: {0}")]
+NotImplemented(String),
}

impl From<anyhow::Error> for ProviderError {

@@ -212,17 +212,17 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: format!(
"Response from {} with model {}",
self.name, self.model_config.model_name
),
annotations: None,
})],
-},
+),
ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
))
}

@@ -260,11 +260,7 @@ pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
.collect::<Result<Vec<_>>>()?;
let created = Utc::now().timestamp();

-Ok(Message {
-role,
-content,
-created,
-})
+Ok(Message::new(role, created, content))
}

pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {

@@ -1,14 +1,16 @@
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
-use crate::providers::base::Usage;
-use crate::providers::errors::ProviderError;
+use crate::providers::base::{ProviderUsage, Usage};
use crate::providers::utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
};
use anyhow::{anyhow, Error};
+use async_stream::try_stream;
+use futures::Stream;
use mcp_core::ToolError;
use mcp_core::{Content, Role, Tool, ToolCall};
+use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

/// Convert internal Message format to Databricks' API message specification

@@ -358,18 +360,162 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
}
}

-Ok(Message {
-role: Role::Assistant,
-created: chrono::Utc::now().timestamp(),
+Ok(Message::new(
+Role::Assistant,
+chrono::Utc::now().timestamp(),
content,
-})
+))
}

-pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
-let usage = data
-.get("usage")
-.ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?;
+#[derive(Serialize, Deserialize, Debug)]
+struct DeltaToolCallFunction {
+name: Option<String>,
+arguments: String, // chunk of encoded JSON,
+}

+#[derive(Serialize, Deserialize, Debug)]
+struct DeltaToolCall {
+id: Option<String>,
+function: DeltaToolCallFunction,
+index: Option<i32>,
+r#type: Option<String>,
+}

+#[derive(Serialize, Deserialize, Debug)]
+struct Delta {
+content: Option<String>,
+role: Option<String>,
+tool_calls: Option<Vec<DeltaToolCall>>,
+}

+#[derive(Serialize, Deserialize, Debug)]
+struct StreamingChoice {
+delta: Delta,
+index: Option<i32>,
+finish_reason: Option<String>,
+}

+#[derive(Serialize, Deserialize, Debug)]
+struct StreamingChunk {
+choices: Vec<StreamingChoice>,
+created: Option<i64>,
+id: Option<String>,
+usage: Option<Value>,
+model: String,
+}

+fn strip_data_prefix(line: &str) -> Option<&str> {
+line.strip_prefix("data: ").map(|s| s.trim())
+}

+pub fn response_to_streaming_message<S>(
+mut stream: S,
+) -> impl Stream<Item = anyhow::Result<(Option<Message>, Option<ProviderUsage>)>> + 'static
+where
+S: Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
+{
+try_stream! {
+use futures::StreamExt;

+'outer: while let Some(response) = stream.next().await {
+if response.as_ref().is_ok_and(|s| s == "data: [DONE]") {
+break 'outer;
+}
+let response_str = response?;
+let line = strip_data_prefix(&response_str);

+if line.is_none() || line.is_some_and(|l| l.is_empty()) {
+continue
+}

+let chunk: StreamingChunk = serde_json::from_str(line
+.ok_or_else(|| anyhow!("unexpected stream format"))?)
+.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
+let model = chunk.model.clone();

+let usage = chunk.usage.as_ref().map(|u| {
+ProviderUsage {
+usage: get_usage(u),
+model,
+}
+});

+if chunk.choices.is_empty() {
+yield (None, usage)
+} else if let Some(tool_calls) = &chunk.choices[0].delta.tool_calls {
+let tool_call = &tool_calls[0];
+let id = tool_call.id.clone().ok_or(anyhow!("No tool call ID"))?;
+let function_name = tool_call.function.name.clone().ok_or(anyhow!("No function name"))?;
+let mut arguments = tool_call.function.arguments.clone();

+while let Some(response_chunk) = stream.next().await {
+if response_chunk.as_ref().is_ok_and(|s| s == "data: [DONE]") {
+break 'outer;
+}
+let response_str = response_chunk?;
+if let Some(line) = strip_data_prefix(&response_str) {
+let tool_chunk: StreamingChunk = serde_json::from_str(line)
+.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
+let more_args = tool_chunk.choices[0].delta.tool_calls.as_ref()
+.and_then(|calls| calls.first())
+.map(|call| call.function.arguments.as_str());
+if let Some(more_args) = more_args {
+arguments.push_str(more_args);
+} else {
+break;
+}
+}
+}

+let parsed = if arguments.is_empty() {
+Ok(json!({}))
+} else {
+serde_json::from_str::<Value>(&arguments)
+};

+let content = match parsed {
+Ok(params) => MessageContent::tool_request(
+id,
+Ok(ToolCall::new(function_name, params)),
+),
+Err(e) => {
+let error = ToolError::InvalidParameters(format!(
+"Could not interpret tool use parameters for id {}: {}",
+id, e
+));
+MessageContent::tool_request(id, Err(error))
+}
+};

+yield (
+Some(Message {
+id: chunk.id,
+role: Role::Assistant,
+created: chrono::Utc::now().timestamp(),
+content: vec![content],
+}),
+usage,
+)
+} else if let Some(text) = &chunk.choices[0].delta.content {
+yield (
+Some(Message {
+id: chunk.id,
+role: Role::Assistant,
+created: chrono::Utc::now().timestamp(),
+content: vec![MessageContent::text(text)],
+}),
+if chunk.choices[0].finish_reason.is_some() {
+usage
+} else {
+None
+},
+)
+}
+}
+}
+}

+pub fn get_usage(usage: &Value) -> Usage {
let input_tokens = usage
.get("prompt_tokens")
.and_then(|v| v.as_i64())

@@ -389,7 +535,7 @@ pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
_ => None,
});

-Ok(Usage::new(input_tokens, output_tokens, total_tokens))
+Usage::new(input_tokens, output_tokens, total_tokens)
}

/// Validates and fixes tool schemas to ensure they have proper parameter structure.
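
A sketch of the kind of SSE payload `response_to_streaming_message` consumes and how it could be driven in isolation (the JSON lines are illustrative OpenAI-style chunks trimmed to the fields `StreamingChunk` deserializes, and the model name is made up; `tokio` and `anyhow` are assumed to be available):

use futures::{stream, StreamExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // two text deltas followed by the terminator the parser looks for
    let lines: Vec<anyhow::Result<String>> = vec![
        Ok(r#"data: {"id":"chunk-1","model":"databricks-dbrx","created":0,"usage":null,"choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant","content":"Hel"}}]}"#.to_string()),
        Ok(r#"data: {"id":"chunk-1","model":"databricks-dbrx","created":0,"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7},"choices":[{"index":0,"finish_reason":"stop","delta":{"content":"lo"}}]}"#.to_string()),
        Ok("data: [DONE]".to_string()),
    ];

    let messages = response_to_streaming_message(stream::iter(lines));
    tokio::pin!(messages);
    while let Some(item) = messages.next().await {
        let (message, usage) = item?;
        println!("partial message: {:?}, usage: {:?}", message, usage);
    }
    Ok(())
}
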
@@ -209,11 +209,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
let role = Role::Assistant;
let created = chrono::Utc::now().timestamp();
if candidate.is_none() {
-return Ok(Message {
-role,
-created,
-content,
-});
+return Ok(Message::new(role, created, content));
}
let candidate = candidate.unwrap();
let parts = candidate

@@ -252,11 +248,7 @@ pub fn response_to_message(response: Value) -> Result<Message> {
}
}
}
-Ok(Message {
-role,
-created,
-content,
-})
+Ok(Message::new(role, created, content))
}

/// Extract usage information from Google's API response

@@ -324,43 +316,39 @@ mod tests {
use serde_json::json;

fn set_up_text_message(text: &str, role: Role) -> Message {
-Message {
-role,
-created: 0,
-content: vec![MessageContent::text(text.to_string())],
-}
+Message::new(role, 0, vec![MessageContent::text(text.to_string())])
}

fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
-Message {
-role: Role::User,
-created: 0,
-content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
-}
+Message::new(
+Role::User,
+0,
+vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
+)
}

fn set_up_tool_confirmation_message(id: &str, tool_call: ToolCall) -> Message {
-Message {
-role: Role::User,
-created: 0,
-content: vec![MessageContent::tool_confirmation_request(
+Message::new(
+Role::User,
+0,
+vec![MessageContent::tool_confirmation_request(
id.to_string(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
)],
-}
+)
}

fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
-Message {
-role: Role::Assistant,
-created: 0,
-content: vec![MessageContent::tool_response(
+Message::new(
+Role::Assistant,
+0,
+vec![MessageContent::tool_response(
id.to_string(),
Ok(tool_response),
)],
-}
+)
}

fn set_up_tool(name: &str, description: &str, params: Value) -> Tool {

@@ -274,11 +274,11 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
}
}

-Ok(Message {
-role: Role::Assistant,
-created: chrono::Utc::now().timestamp(),
+Ok(Message::new(
+Role::Assistant,
+chrono::Utc::now().timestamp(),
content,
-})
+))
}

pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {

@@ -169,14 +169,14 @@ impl GeminiCliProvider {
));
}

-let message = Message {
-role: Role::Assistant,
-created: chrono::Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+let message = Message::new(
+Role::Assistant,
+chrono::Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: response_text,
annotations: None,
})],
-};
+);

let usage = Usage::default(); // No usage info available for gemini CLI

@@ -214,14 +214,14 @@ impl GeminiCliProvider {
println!("================================");
}

-let message = Message {
-role: Role::Assistant,
-created: chrono::Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+let message = Message::new(
+Role::Assistant,
+chrono::Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: description.clone(),
annotations: None,
})],
-};
+);

let usage = Usage::default();

@@ -480,14 +480,14 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
-},
+),
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}

@@ -643,14 +643,14 @@ mod tests {
))
} else {
Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
-},
+),
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}

@@ -203,14 +203,14 @@ impl SageMakerTgiProvider {
// Strip any HTML tags that might have been generated
let clean_text = self.strip_html_tags(generated_text);

-Ok(Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Ok(Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: clean_text,
annotations: None,
})],
-})
+))
}

/// Strip HTML tags from text to ensure clean output

@@ -359,11 +359,7 @@ pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec<Message> {
}

if has_tool_content {
-Message {
-role: message.role.clone(),
-content: new_content,
-created: message.created,
-}
+Message::new(message.role.clone(), message.created, new_content)
} else {
message.clone()
}

@@ -319,12 +319,15 @@ pub fn unescape_json_values(value: &Value) -> Value {
}
}

-pub fn emit_debug_trace(
+pub fn emit_debug_trace<T1, T2>(
model_config: &ModelConfig,
-payload: &Value,
-response: &Value,
+payload: &T1,
+response: &T2,
usage: &Usage,
-) {
+) where
+T1: ?Sized + Serialize,
+T2: ?Sized + Serialize,
+{
tracing::debug!(
model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(),
input = %serde_json::to_string_pretty(payload).unwrap_or_default(),

@@ -557,11 +557,7 @@ impl Provider for VeniceProvider {
};

Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content,
-},
+Message::new(Role::Assistant, Utc::now().timestamp(), content),
ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
))
}

@@ -1370,14 +1370,14 @@ mod tests {
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
-Message {
-role: Role::Assistant,
-created: Utc::now().timestamp(),
-content: vec![MessageContent::Text(TextContent {
+Message::new(
+Role::Assistant,
+Utc::now().timestamp(),
+vec![MessageContent::Text(TextContent {
text: "Mocked scheduled response".to_string(),
annotations: None,
})],
-},
+),
ProviderUsage::new("mock-scheduler-test".to_string(), Usage::default()),
))
}

@@ -407,25 +407,25 @@ mod tests {
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
-Message {
-role: Role::User,
-created: 0,
-content: vec![MessageContent::text(
+Message::new(
+Role::User,
+0,
+vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
-},
+),
-Message {
-role: Role::Assistant,
-created: 1,
-content: vec![MessageContent::text(
+Message::new(
+Role::Assistant,
+1,
+vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
-},
+),
-Message {
-role: Role::User,
-created: 2,
-content: vec![MessageContent::text("How about New York?")],
-},
+Message::new(
+Role::User,
+2,
+vec![MessageContent::text("How about New York?")],
+),
];

let tools = vec![Tool {

@@ -505,25 +505,25 @@ mod tests {
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
-Message {
-role: Role::User,
-created: 0,
-content: vec![MessageContent::text(
+Message::new(
+Role::User,
+0,
+vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
-},
+),
-Message {
-role: Role::Assistant,
-created: 1,
-content: vec![MessageContent::text(
+Message::new(
+Role::Assistant,
+1,
+vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
-},
+),
-Message {
-role: Role::User,
-created: 2,
-content: vec![MessageContent::text("How about New York?")],
-},
+Message::new(
+Role::User,
+2,
+vec![MessageContent::text("How about New York?")],
+),
];

let tools = vec![Tool {

@@ -786,7 +786,7 @@ function ChatContent({
<SearchView>
{filteredMessages.map((message, index) => (
<div
-key={message.id || index}
+key={(message.id && `${message.id}-${message.content.length}`) || index}
className="mt-4 px-4"
data-testid="message-container"
>

@@ -130,7 +130,7 @@ export default function GooseMessage({
]);

return (
-<div className="goose-message flex w-[90%] justify-start opacity-0 animate-[appear_150ms_ease-in_forwards]">
+<div className="goose-message flex w-[90%] justify-start">
<div className="flex flex-col w-full">
{/* Chain-of-Thought (hidden by default) */}
{cotText && (

@@ -1,4 +1,4 @@
-import { useState, useCallback, useEffect, useRef, useId } from 'react';
+import { useState, useCallback, useEffect, useRef, useId, useReducer } from 'react';
import useSWR from 'swr';
import { getSecretKey } from '../config';
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';

@@ -235,6 +235,9 @@ export function useMessageStream({
};
}, [headers, body]);

+// TODO: not this?
+const [, forceUpdate] = useReducer((x) => x + 1, 0);

// Process the SSE stream from the server
const processMessageStream = useCallback(
async (response: Response, currentMessages: Message[]) => {

@@ -284,8 +287,23 @@ export function useMessageStream({
: parsedEvent.message.sendToLLM,
};

+console.log('New message:', JSON.stringify(newMessage, null, 2));

// Update messages with the new message
-currentMessages = [...currentMessages, newMessage];
+if (
+newMessage.id &&
+currentMessages.length > 0 &&
+currentMessages[currentMessages.length - 1].id === newMessage.id
+) {
+// If the last message has the same ID, update it instead of adding a new one
+const lastMessage = currentMessages[currentMessages.length - 1];
+lastMessage.content = [...lastMessage.content, ...newMessage.content];
+forceUpdate();
+} else {
+currentMessages = [...currentMessages, newMessage];
+}

mutate(currentMessages, false);
break;
}

@@ -373,7 +391,7 @@ export function useMessageStream({

return currentMessages;
},
-[mutate, onFinish, onError]
+[mutate, onFinish, onError, forceUpdate]
);

// Send a request to the server

@@ -201,7 +201,7 @@ export function getTextContent(message: Message): string {
}
return '';
})
-.join('\n');
+.join('');
}

export function getToolRequests(message: Message): ToolRequestMessageContent[] {