diff --git a/Cargo.lock b/Cargo.lock index 0e6c66c8..86a2b497 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3568,6 +3568,7 @@ dependencies = [ "test-case", "tokio", "tokio-stream", + "tokio-util", "tower-http", "tracing", "tracing-appender", @@ -3703,6 +3704,7 @@ dependencies = [ "tokio", "tokio-cron-scheduler", "tokio-stream", + "tokio-util", "tower 0.5.2", "tower-http", "tracing", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 88aa0266..502e9a09 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -69,8 +69,8 @@ tokio-stream = "0.1" bytes = "1.5" http = "1.0" webbrowser = "1.0" - indicatif = "0.17.11" +tokio-util = "0.7.15" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 2fab9816..ba5206b4 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -209,7 +209,7 @@ async fn serve_static(axum::extract::Path(path): axum::extract::Path) -> include_bytes!("../../../../documentation/static/img/logo_light.png").to_vec(), ) .into_response(), - _ => (axum::http::StatusCode::NOT_FOUND, "Not found").into_response(), + _ => (http::StatusCode::NOT_FOUND, "Not found").into_response(), } } @@ -484,7 +484,6 @@ async fn process_message_streaming( ) .await?; - // Create a session config let session_config = SessionConfig { id: session::Identifier::Path(session_file.clone()), working_dir: std::env::current_dir()?, @@ -494,8 +493,7 @@ async fn process_message_streaming( retry_config: None, }; - // Get response from agent - match agent.reply(&messages, Some(session_config)).await { + match agent.reply(&messages, Some(session_config), None).await { Ok(mut stream) => { while let Some(result) = stream.next().await { match result { diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 1635a65d..52e01d3d 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -48,6 +48,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; use tokio; +use tokio_util::sync::CancellationToken; pub enum RunMode { Normal, @@ -132,13 +133,10 @@ impl Session { retry_config: Option, ) -> Self { let messages = if let Some(session_file) = &session_file { - match session::read_messages(session_file) { - Ok(msgs) => msgs, - Err(e) => { - eprintln!("Warning: Failed to load message history: {}", e); - Vec::new() - } - } + session::read_messages(session_file).unwrap_or_else(|e| { + eprintln!("Warning: Failed to load message history: {}", e); + Vec::new() + }) } else { // Don't try to read messages if we're not saving sessions Vec::new() @@ -180,7 +178,7 @@ impl Session { /// Format: "ENV1=val1 ENV2=val2 command args..." pub async fn add_extension(&mut self, extension_command: String) -> Result<()> { let mut parts: Vec<&str> = extension_command.split_whitespace().collect(); - let mut envs = std::collections::HashMap::new(); + let mut envs = HashMap::new(); // Parse environment variables (format: KEY=value) while let Some(part) = parts.first() { @@ -473,7 +471,7 @@ impl Session { self.display_context_usage().await?; match input::get_input(&mut editor)? { - input::InputResult::Message(content) => { + InputResult::Message(content) => { match self.run_mode { RunMode::Normal => { save_history(&mut editor); @@ -495,15 +493,11 @@ impl Session { eprintln!("Warning: Failed to update project tracker with instruction: {}", e); } - // Get the provider from the agent for description generation let provider = self.agent.provider().await?; // Persist messages with provider for automatic description generation if let Some(session_file) = &self.session_file { - let working_dir = Some( - std::env::current_dir() - .expect("failed to get current session working directory"), - ); + let working_dir = Some(std::env::current_dir().unwrap_or_default()); session::persist_messages_with_schedule_id( session_file, @@ -847,12 +841,14 @@ impl Session { } async fn process_agent_response(&mut self, interactive: bool) -> Result<()> { + let cancel_token = CancellationToken::new(); + let cancel_token_clone = cancel_token.clone(); + let session_config = self.session_file.as_ref().map(|s| { let session_id = session::Identifier::Path(s.clone()); SessionConfig { id: session_id.clone(), - working_dir: std::env::current_dir() - .expect("failed to get current session working directory"), + working_dir: std::env::current_dir().unwrap_or_default(), schedule_id: self.scheduled_job_id.clone(), execution_mode: None, max_turns: self.max_turns, @@ -861,7 +857,7 @@ impl Session { }); let mut stream = self .agent - .reply(&self.messages, session_config.clone()) + .reply(&self.messages, session_config.clone(), Some(cancel_token)) .await?; let mut progress_bars = output::McpSpinners::new(); @@ -919,7 +915,7 @@ impl Session { ) .await?; } - + cancel_token_clone.cancel(); drop(stream); break; } else { @@ -1001,6 +997,7 @@ impl Session { .reply( &self.messages, session_config.clone(), + None ) .await?; } @@ -1157,6 +1154,7 @@ impl Session { Some(Err(e)) => { eprintln!("Error: {}", e); + cancel_token_clone.cancel(); drop(stream); if let Err(e) = self.handle_interrupted_messages(false).await { eprintln!("Error handling interruption: {}", e); @@ -1173,6 +1171,7 @@ impl Session { } } _ = tokio::signal::ctrl_c() => { + cancel_token_clone.cancel(); drop(stream); if let Err(e) = self.handle_interrupted_messages(true).await { eprintln!("Error handling interruption: {}", e); diff --git a/crates/goose-ffi/src/lib.rs b/crates/goose-ffi/src/lib.rs index bf4197c2..6bcf3abe 100644 --- a/crates/goose-ffi/src/lib.rs +++ b/crates/goose-ffi/src/lib.rs @@ -247,7 +247,7 @@ pub unsafe extern "C" fn goose_agent_send_message( // Block on the async call using our global runtime let response = get_runtime().block_on(async { - let mut stream = match agent.reply(&messages, None).await { + let mut stream = match agent.reply(&messages, None, None).await { Ok(stream) => stream, Err(e) => return format!("Error getting reply from agent: {}", e), }; diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index 22f2c711..8260a929 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -42,6 +42,7 @@ axum-extra = "0.10.0" utoipa = { version = "4.1", features = ["axum_extras", "chrono"] } dirs = "6.0.0" reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking", "multipart"], default-features = false } +tokio-util = "0.7.15" [[bin]] name = "goosed" diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index e5804171..36ceed48 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -11,7 +11,7 @@ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::{ agents::{AgentEvent, SessionConfig}, - message::{push_message, Message, MessageContent}, + message::{push_message, Message}, permission::permission_confirmation::PrincipalType, }; use goose::{ @@ -19,7 +19,7 @@ use goose::{ session, }; use mcp_core::{protocol::JsonRpcMessage, ToolResult}; -use rmcp::model::{Content, Role}; +use rmcp::model::Content; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; @@ -34,9 +34,10 @@ use std::{ use tokio::sync::mpsc; use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use utoipa::ToSchema; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] struct ChatRequest { messages: Vec, session_id: Option, @@ -113,7 +114,7 @@ async fn stream_event( tx.send(format!("data: {}\n\n", json)).await } -async fn handler( +async fn reply_handler( State(state): State>, headers: HeaderMap, Json(request): Json, @@ -122,6 +123,7 @@ async fn handler( let (tx, rx) = mpsc::channel(100); let stream = ReceiverStream::new(rx); + let cancel_token = CancellationToken::new(); let messages = request.messages; let session_working_dir = request.session_working_dir.clone(); @@ -130,65 +132,35 @@ async fn handler( .session_id .unwrap_or_else(session::generate_session_id); - tokio::spawn(async move { - let agent = state.get_agent().await; - let agent = match agent { - Ok(agent) => { - let provider = agent.provider().await; - match provider { - Ok(_) => agent, - Err(_) => { - let _ = stream_event( - MessageEvent::Error { - error: "No provider configured".to_string(), - }, - &tx, - ) - .await; - let _ = stream_event( - MessageEvent::Finish { - reason: "error".to_string(), - }, - &tx, - ) - .await; - return; - } - } - } + let task_cancel = cancel_token.clone(); + let task_tx = tx.clone(); + + std::mem::drop(tokio::spawn(async move { + let agent = match state.get_agent().await { + Ok(agent) => agent, Err(_) => { let _ = stream_event( MessageEvent::Error { error: "No agent configured".to_string(), }, - &tx, - ) - .await; - let _ = stream_event( - MessageEvent::Finish { - reason: "error".to_string(), - }, - &tx, + &task_tx, ) .await; return; } }; - let provider = agent.provider().await; + let session_config = SessionConfig { + id: session::Identifier::Name(session_id.clone()), + working_dir: PathBuf::from(&session_working_dir), + schedule_id: request.scheduled_job_id.clone(), + execution_mode: None, + max_turns: None, + retry_config: None, + }; let mut stream = match agent - .reply( - &messages, - Some(SessionConfig { - id: session::Identifier::Name(session_id.clone()), - working_dir: PathBuf::from(&session_working_dir), - schedule_id: request.scheduled_job_id.clone(), - execution_mode: None, - max_turns: None, - retry_config: None, - }), - ) + .reply(&messages, Some(session_config), Some(task_cancel.clone())) .await { Ok(stream) => stream, @@ -198,14 +170,7 @@ async fn handler( MessageEvent::Error { error: e.to_string(), }, - &tx, - ) - .await; - let _ = stream_event( - MessageEvent::Finish { - reason: "error".to_string(), - }, - &tx, + &task_tx, ) .await; return; @@ -221,7 +186,7 @@ async fn handler( MessageEvent::Error { error: format!("Failed to get session path: {}", e), }, - &tx, + &task_tx, ) .await; return; @@ -231,222 +196,104 @@ async fn handler( loop { tokio::select! { - response = timeout(Duration::from_millis(500), stream.next()) => { - match response { - Ok(Some(Ok(AgentEvent::Message(message)))) => { - push_message(&mut all_messages, message.clone()); - if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; + _ = task_cancel.cancelled() => { + tracing::info!("Agent task cancelled"); break; } - } - Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { - if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { - tracing::error!("Error sending model change through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } - Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { - if let Err(e) = stream_event(MessageEvent::Notification{ - request_id: request_id.clone(), - message: n, - }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } + response = timeout(Duration::from_millis(500), stream.next()) => { + match response { + Ok(Some(Ok(AgentEvent::Message(message)))) => { + push_message(&mut all_messages, message.clone()); + if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { + tracing::error!("Error sending message through channel: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; + break; + } + } + Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { + if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { + tracing::error!("Error sending model change through channel: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; + } + } + Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { + if let Err(e) = stream_event(MessageEvent::Notification{ + request_id: request_id.clone(), + message: n, + }, &tx).await { + tracing::error!("Error sending message through channel: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; + } + } - Ok(Some(Err(e))) => { - tracing::error!("Error processing message: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - break; - } - Ok(None) => { - break; - } - Err(_) => { // Heartbeat, used to detect disconnected clients - if tx.is_closed() { - break; + Ok(Some(Err(e))) => { + tracing::error!("Error processing message: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; + break; + } + Ok(None) => { + break; + } + Err(_) => { + if tx.is_closed() { + break; + } + continue; + } + } } - continue; } - } - } - } } if all_messages.len() > saved_message_count { - let provider = Arc::clone(provider.as_ref().unwrap()); - tokio::spawn(async move { - if let Err(e) = session::persist_messages( - &session_path, - &all_messages, - Some(provider), - Some(PathBuf::from(&session_working_dir)), - ) - .await - { - tracing::error!("Failed to store session history: {:?}", e); - } - }); + if let Ok(provider) = agent.provider().await { + let provider = Arc::clone(&provider); + tokio::spawn(async move { + if let Err(e) = session::persist_messages( + &session_path, + &all_messages, + Some(provider), + Some(PathBuf::from(&session_working_dir)), + ) + .await + { + tracing::error!("Failed to store session history: {:?}", e); + } + }); + } } let _ = stream_event( MessageEvent::Finish { reason: "stop".to_string(), }, - &tx, + &task_tx, ) .await; - }); - + })); Ok(SseResponse::new(stream)) } -#[derive(Debug, Deserialize, Serialize)] -struct AskRequest { - prompt: String, - session_id: Option, - session_working_dir: String, - scheduled_job_id: Option, -} - -#[derive(Debug, Serialize)] -struct AskResponse { - response: String, -} - -async fn ask_handler( - State(state): State>, - headers: HeaderMap, - Json(request): Json, -) -> Result, StatusCode> { - verify_secret_key(&headers, &state)?; - - let session_working_dir = request.session_working_dir.clone(); - - let session_id = request - .session_id - .unwrap_or_else(session::generate_session_id); - - let agent = state - .get_agent() - .await - .map_err(|_| StatusCode::PRECONDITION_FAILED)?; - - let provider = agent.provider().await; - - let messages = vec![Message::user().with_text(request.prompt)]; - - let mut response_text = String::new(); - let mut stream = match agent - .reply( - &messages, - Some(SessionConfig { - id: session::Identifier::Name(session_id.clone()), - working_dir: PathBuf::from(&session_working_dir), - schedule_id: request.scheduled_job_id.clone(), - execution_mode: None, - max_turns: None, - retry_config: None, - }), - ) - .await - { - Ok(stream) => stream, - Err(e) => { - tracing::error!("Failed to start reply stream: {:?}", e); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - }; - - let mut all_messages = messages.clone(); - let mut response_message = Message::assistant(); - - while let Some(response) = stream.next().await { - match response { - Ok(AgentEvent::Message(message)) => { - if message.role == Role::Assistant { - for content in &message.content { - if let MessageContent::Text(text) = content { - response_text.push_str(&text.text); - response_text.push('\n'); - } - response_message.content.push(content.clone()); - } - } - } - Ok(AgentEvent::ModelChange { model, mode }) => { - // Log model change for non-streaming - tracing::info!("Model changed to {} in {} mode", model, mode); - } - Ok(AgentEvent::McpNotification(n)) => { - // Handle notifications if needed - tracing::info!("Received notification: {:?}", n); - } - - Err(e) => { - tracing::error!("Error processing as_ai message: {}", e); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - } - } - - if !response_message.content.is_empty() { - push_message(&mut all_messages, response_message); - } - - let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) { - Ok(path) => path, - Err(e) => { - tracing::error!("Failed to get session path: {}", e); - return Err(StatusCode::INTERNAL_SERVER_ERROR); - } - }; - - let session_path_clone = session_path.clone(); - let messages = all_messages.clone(); - let provider = Arc::clone(provider.as_ref().unwrap()); - let session_working_dir_clone = session_working_dir.clone(); - tokio::spawn(async move { - if let Err(e) = session::persist_messages( - &session_path_clone, - &messages, - Some(provider), - Some(PathBuf::from(session_working_dir_clone)), - ) - .await - { - tracing::error!("Failed to store session history: {:?}", e); - } - }); - - Ok(Json(AskResponse { - response: response_text.trim().to_string(), - })) -} - #[derive(Debug, Deserialize, Serialize, ToSchema)] pub struct PermissionConfirmationRequest { id: String, @@ -509,7 +356,7 @@ struct ToolResultRequest { async fn submit_tool_result( State(state): State>, headers: HeaderMap, - raw: axum::extract::Json, + raw: Json, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; @@ -540,8 +387,7 @@ async fn submit_tool_result( pub fn routes(state: Arc) -> Router { Router::new() - .route("/reply", post(handler)) - .route("/ask", post(ask_handler)) + .route("/reply", post(reply_handler)) .route("/confirm", post(confirm_permission)) .route("/tool_result", post(submit_tool_result)) .with_state(state) @@ -571,10 +417,6 @@ mod tests { goose::providers::base::ProviderMetadata::empty() } - fn get_model_config(&self) -> ModelConfig { - self.model_config.clone() - } - async fn complete( &self, _system: &str, @@ -586,6 +428,10 @@ mod tests { ProviderUsage::new("mock".to_string(), Usage::default()), )) } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } } mod integration_tests { @@ -595,7 +441,7 @@ mod tests { use tower::ServiceExt; #[tokio::test] - async fn test_ask_endpoint() { + async fn test_reply_endpoint() { let mock_model_config = ModelConfig::new("test-model".to_string()); let mock_provider = Arc::new(MockProvider { model_config: mock_model_config, @@ -603,24 +449,17 @@ mod tests { let agent = Agent::new(); let _ = agent.update_provider(mock_provider).await; let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await; - let scheduler_path = goose::scheduler::get_default_scheduler_storage_path() - .expect("Failed to get default scheduler storage path"); - let scheduler = - goose::scheduler_factory::SchedulerFactory::create_legacy(scheduler_path) - .await - .unwrap(); - state.set_scheduler(scheduler).await; let app = routes(state); let request = Request::builder() - .uri("/ask") + .uri("/reply") .method("POST") .header("content-type", "application/json") .header("x-secret-key", "test-secret") .body(Body::from( - serde_json::to_string(&AskRequest { - prompt: "test prompt".to_string(), + serde_json::to_string(&ChatRequest { + messages: vec![Message::user().with_text("test message")], session_id: Some("test-session".to_string()), session_working_dir: "test-working-dir".to_string(), scheduled_job_id: None, diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index c5ac11cb..5ed6e90c 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -35,7 +35,7 @@ async fn main() { let messages = vec![Message::user() .with_text("can you summarize the readme.md in this dir using just a haiku?")]; - let mut stream = agent.reply(&messages, None).await.unwrap(); + let mut stream = agent.reply(&messages, None, None).await.unwrap(); while let Some(Ok(AgentEvent::Message(message))) = stream.next().await { println!("{}", serde_json::to_string_pretty(&message).unwrap()); println!("\n"); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1fbc70e8..ed755a7a 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -8,15 +8,32 @@ use futures::stream::BoxStream; use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; +use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo}; +use crate::agents::extension_manager::{get_parameter_names, ExtensionManager}; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; +use crate::agents::platform_tools::{ + PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME, + PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME, + PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, +}; +use crate::agents::prompt_manager::PromptManager; use crate::agents::recipe_tools::dynamic_task_tools::{ create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, }; +use crate::agents::retry::{RetryManager, RetryResult}; +use crate::agents::router_tool_selector::{ + create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, +}; +use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME}; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, }; use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; +use crate::agents::tool_router_index_manager::ToolRouterIndexManager; +use crate::agents::tool_vectordb::generate_table_id; +use crate::agents::types::SessionConfig; +use crate::agents::types::{FrontendTool, ToolResultReceiver}; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{push_message, Message}; use crate::permission::permission_judge::check_tool_permissions; @@ -26,31 +43,14 @@ use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe, Response, Settings, SubRecipe}; use crate::scheduler_trait::SchedulerTrait; use crate::tool_monitor::{ToolCall, ToolMonitor}; +use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult}; use regex::Regex; +use rmcp::model::{Content, Prompt}; use serde_json::Value; use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, instrument}; -use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo}; -use crate::agents::extension_manager::{get_parameter_names, ExtensionManager}; -use crate::agents::platform_tools::{ - PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME, - PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME, - PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, -}; -use crate::agents::prompt_manager::PromptManager; -use crate::agents::retry::{RetryManager, RetryResult}; -use crate::agents::router_tool_selector::{ - create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, -}; -use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME}; -use crate::agents::tool_router_index_manager::ToolRouterIndexManager; -use crate::agents::tool_vectordb::generate_table_id; -use crate::agents::types::SessionConfig; -use crate::agents::types::{FrontendTool, ToolResultReceiver}; -use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult}; -use rmcp::model::{Content, Prompt}; - use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; @@ -317,17 +317,17 @@ impl Agent { } if tool_call.name == FINAL_OUTPUT_TOOL_NAME { - if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() { + return if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() { let result = final_output_tool.execute_tool_call(tool_call.clone()).await; - return (request_id, Ok(result)); + (request_id, Ok(result)) } else { - return ( + ( request_id, Err(ToolError::ExecutionError( "Final output tool not defined".to_string(), )), - ); - } + ) + }; } let extension_manager = self.extension_manager.read().await; @@ -406,10 +406,9 @@ impl Agent { let result = extension_manager .dispatch_tool_call(tool_call.clone()) .await; - match result { - Ok(call_result) => call_result, - Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), - } + result.unwrap_or_else(|e| { + ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))) + }) }; ( @@ -719,42 +718,17 @@ impl Agent { &self, messages: &[Message], session: Option, - ) -> anyhow::Result>> { + cancel_token: Option, + ) -> Result>> { let mut messages = messages.to_vec(); let initial_messages = messages.clone(); let reply_span = tracing::Span::current(); - self.reset_retry_attempts().await; - - // Load settings from config let config = Config::global(); - // Setup tools and prompt let (mut tools, mut toolshim_tools, mut system_prompt) = self.prepare_tools_and_prompt().await?; - - // Get goose_mode from config, but override with execution_mode if provided in session config - let mut goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string()); - - // If this is a scheduled job with an execution_mode, override the goose_mode - if let Some(session_config) = &session { - if let Some(execution_mode) = &session_config.execution_mode { - // Map "foreground" to "auto" and "background" to "chat" - goose_mode = match execution_mode.as_str() { - "foreground" => "auto".to_string(), - "background" => "chat".to_string(), - _ => goose_mode, - }; - tracing::info!( - "Using execution_mode '{}' which maps to goose_mode '{}'", - execution_mode, - goose_mode - ); - } - } - - let (tools_with_readonly_annotation, tools_without_annotation) = - Self::categorize_tools_by_annotation(&tools); + let goose_mode = Self::determine_goose_mode(session.as_ref(), config); if let Some(content) = messages .last() @@ -775,12 +749,16 @@ impl Agent { }); loop { - // Check for final output before incrementing turns or checking max_turns - // This ensures that if we have a final output ready, we return it immediately - // without being blocked by the max_turns limit - this is needed for streaming cases + if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) { + break; + } + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_some() { - yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + let final_event = AgentEvent::Message( + Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()), + ); + yield final_event; break; } } @@ -793,23 +771,19 @@ impl Agent { break; } - // Check for MCP notifications from subagents + // Handle MCP notifications from subagents let mcp_notifications = self.get_mcp_notifications().await; for notification in mcp_notifications { - // Extract subagent info from the notification data - if let JsonRpcMessage::Notification(ref notif) = notification { - if let Some(params) = ¬if.params { - if let Some(data) = params.get("data") { - if let (Some(subagent_id), Some(_message)) = ( - data.get("subagent_id").and_then(|v| v.as_str()), - data.get("message").and_then(|v| v.as_str()) - ) { - // Emit as McpNotification event - yield AgentEvent::McpNotification(( - subagent_id.to_string(), - notification.clone() - )); - } + if let JsonRpcMessage::Notification(notif) = ¬ification { + if let Some(data) = notif.params.as_ref().and_then(|p| p.get("data")) { + if let (Some(subagent_id), Some(_message)) = ( + data.get("subagent_id").and_then(|v| v.as_str()), + data.get("message").and_then(|v| v.as_str()), + ) { + yield AgentEvent::McpNotification(( + subagent_id.to_string(), + notification.clone(), + )); } } } @@ -821,17 +795,24 @@ impl Agent { &messages, &tools, &toolshim_tools, - ).await?; + ) + .await?; let mut added_message = false; + let mut messages_to_add = Vec::new(); + let mut tools_updated = false; + while let Some(next) = stream.next().await { + if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) { + break; + } + match next { Ok((response, usage)) => { // Emit model change event if provider is lead-worker let provider = self.provider().await?; if let Some(lead_worker) = provider.as_lead_worker() { if let Some(ref usage) = usage { - // The actual model used is in the usage let active_model = usage.model.clone(); let (lead_model, worker_model) = lead_worker.get_model_info(); let mode = if active_model == lead_model { @@ -849,43 +830,44 @@ impl Agent { } } - // record usage for the session in the session file - if let Some(session_config) = session.clone() { + // Record usage for the session + if let Some(ref session_config) = &session { if let Some(ref usage) = usage { - Self::update_session_metrics(session_config, usage, messages.len()).await?; + Self::update_session_metrics(session_config, usage, messages.len()) + .await?; } } if let Some(response) = response { - // categorize the type of requests we need to handle - let (frontend_requests, - remaining_requests, - filtered_response) = + let (tools_with_readonly_annotation, tools_without_annotation) = + Self::categorize_tools_by_annotation(&tools); + + // Categorize tool requests + let (frontend_requests, remaining_requests, filtered_response) = self.categorize_tool_requests(&response).await; // Record tool calls in the router selector let selector = self.router_tool_selector.lock().await.clone(); if let Some(selector) = selector { - // Record frontend tool calls for request in &frontend_requests { if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - tracing::error!("Failed to record frontend tool call: {}", e); + if let Err(e) = selector.record_tool_call(&tool_call.name).await + { + error!("Failed to record frontend tool call: {}", e); } } } - // Record remaining tool calls for request in &remaining_requests { if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - tracing::error!("Failed to record tool call: {}", e); + if let Err(e) = selector.record_tool_call(&tool_call.name).await + { + error!("Failed to record tool call: {}", e); } } } } - // Yield the assistant's response with frontend tool requests filtered out - yield AgentEvent::Message(filtered_response.clone()); + yield AgentEvent::Message(filtered_response.clone()); tokio::task::yield_now().await; let num_tool_requests = frontend_requests.len() + remaining_requests.len(); @@ -893,23 +875,17 @@ impl Agent { continue; } - // Process tool requests depending on frontend tools and then goose_mode let message_tool_response = Arc::new(Mutex::new(Message::user())); - // First handle any frontend tool requests let mut frontend_tool_stream = self.handle_frontend_tool_requests( &frontend_requests, - message_tool_response.clone() + message_tool_response.clone(), ); - // we have a stream of frontend tools to handle, inside the stream - // execution is yeield back to this reply loop, and is of the same Message - // type, so we can yield that back up to be handled while let Some(msg) = frontend_tool_stream.try_next().await? { yield AgentEvent::Message(msg); } - // Clone goose_mode once before the match to avoid move issues let mode = goose_mode.clone(); if mode.as_str() == "chat" { // Skip all tool calls in chat mode @@ -921,36 +897,42 @@ impl Agent { ); } } else { - // At this point, we have handled the frontend tool requests and know goose_mode != "chat" - // What remains is handling the remaining tool requests (enable extension, - // regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"] let mut permission_manager = PermissionManager::default(); - let (permission_check_result, enable_extension_request_ids) = check_tool_permissions( - &remaining_requests, - &mode, - tools_with_readonly_annotation.clone(), - tools_without_annotation.clone(), - &mut permission_manager, - self.provider().await?).await; + let (permission_check_result, enable_extension_request_ids) = + check_tool_permissions( + &remaining_requests, + &mode, + tools_with_readonly_annotation.clone(), + tools_without_annotation.clone(), + &mut permission_manager, + self.provider().await?, + ) + .await; - // Handle pre-approved and read-only tools in parallel let mut tool_futures: Vec<(String, ToolStream)> = Vec::new(); - // Skip the confirmation for approved tools + // Handle pre-approved and read-only tools for request in &permission_check_result.approved { if let Ok(tool_call) = request.tool_call.clone() { - let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await; + let (req_id, tool_result) = self + .dispatch_tool_call(tool_call, request.id.clone()) + .await; - tool_futures.push((req_id, match tool_result { - Ok(result) => tool_stream( - result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())), - result.result, - ), - Err(e) => tool_stream( - Box::new(stream::empty()), - futures::future::ready(Err(e)), - ), - })); + tool_futures.push(( + req_id, + match tool_result { + Ok(result) => tool_stream( + result + .notification_stream + .unwrap_or_else(|| Box::new(stream::empty())), + result.result, + ), + Err(e) => tool_stream( + Box::new(stream::empty()), + futures::future::ready(Err(e)), + ), + }, + )); } } @@ -962,29 +944,22 @@ impl Agent { ); } - // We need interior mutability in handle_approval_tool_requests let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); - // Process tools requiring approval (enable extension, regular tool calls) + // Process tools requiring approval let mut tool_approval_stream = self.handle_approval_tool_requests( &permission_check_result.needs_approval, tool_futures_arc.clone(), &mut permission_manager, - message_tool_response.clone() + message_tool_response.clone(), ); - // We have a stream of tool_approval_requests to handle - // Execution is yielded back to this reply loop, and is of the same Message - // type, so we can yield the Message back up to be handled and grab any - // confirmations or denials while let Some(msg) = tool_approval_stream.try_next().await? { yield AgentEvent::Message(msg); } tool_futures = { - // Lock the mutex asynchronously let mut futures_lock = tool_futures_arc.lock().await; - // Drain the vector and collect into a new Vec futures_lock.drain(..).collect::>() }; @@ -996,27 +971,33 @@ impl Agent { .collect::>(); let mut combined = stream::select_all(with_id); - let mut all_install_successful = true; while let Some((request_id, item)) = combined.next().await { + if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) { + break; + } match item { ToolStreamItem::Result(output) => { - if enable_extension_request_ids.contains(&request_id) && output.is_err(){ + if enable_extension_request_ids.contains(&request_id) + && output.is_err() + { all_install_successful = false; } let mut response = message_tool_response.lock().await; - *response = response.clone().with_tool_response(request_id, output); - }, + *response = + response.clone().with_tool_response(request_id, output); + } ToolStreamItem::Message(msg) => { - yield AgentEvent::McpNotification((request_id, msg)) + yield AgentEvent::McpNotification(( + request_id, msg, + )); } } } - // Update system prompt and tools if installations were successful if all_install_successful { - (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; + tools_updated = true; } } @@ -1024,63 +1005,39 @@ impl Agent { yield AgentEvent::Message(final_message_tool_resp.clone()); added_message = true; - push_message(&mut messages, response); - push_message(&mut messages, final_message_tool_resp); - - // Check for MCP notifications from subagents again before next iteration - // Note: These are already handled as McpNotification events above, - // so we don't need to convert them to assistant messages here. - // This was causing duplicate plain-text notifications. - // let mcp_notifications = self.get_mcp_notifications().await; - // for notification in mcp_notifications { - // // Extract subagent info from the notification data for assistant messages - // if let JsonRpcMessage::Notification(ref notif) = notification { - // if let Some(params) = ¬if.params { - // if let Some(data) = params.get("data") { - // if let (Some(subagent_id), Some(message)) = ( - // data.get("subagent_id").and_then(|v| v.as_str()), - // data.get("message").and_then(|v| v.as_str()) - // ) { - // yield AgentEvent::Message( - // Message::assistant().with_text( - // format!("Subagent {}: {}", subagent_id, message) - // ) - // ); - // } - // } - // } - // } - // } + push_message(&mut messages_to_add, response); + push_message(&mut messages_to_add, final_message_tool_resp); } - }, + } Err(ProviderError::ContextLengthExceeded(_)) => { - // At this point, the last message should be a user message - // because call to provider led to context length exceeded error - // Immediately yield a special message and break yield AgentEvent::Message(Message::assistant().with_context_length_exceeded( - "The context length of the model has been exceeded. Please start a new session and try again.", - )); + "The context length of the model has been exceeded. Please start a new session and try again.", + )); break; - }, + } Err(e) => { - // Create an error message & terminate the stream error!("Error: {}", e); - yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."))); + yield AgentEvent::Message(Message::assistant().with_text( + format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.") + )); break; } } } + if tools_updated { + (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; + } if !added_message { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_none() { tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); - messages.push(message.clone()); + messages_to_add.push(message.clone()); yield AgentEvent::Message(message); - continue; + continue } else { let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); - messages.push(message.clone()); + messages_to_add.push(message.clone()); yield AgentEvent::Message(message); } } @@ -1099,16 +1056,28 @@ impl Agent { )); } } - break; } - // Yield control back to the scheduler to prevent blocking + messages.extend(messages_to_add); + tokio::task::yield_now().await; } })) } + fn determine_goose_mode(session: Option<&SessionConfig>, config: &Config) -> String { + let mode = session.and_then(|s| s.execution_mode.as_deref()); + + match mode { + Some("foreground") => "auto".to_string(), + Some("background") => "chat".to_string(), + _ => config + .get_param("GOOSE_MODE") + .unwrap_or_else(|_| "auto".to_string()), + } + } + /// Extend the system prompt with one line of additional instruction pub async fn extend_system_prompt(&self, instruction: String) { let mut prompt_manager = self.prompt_manager.lock().await; @@ -1127,7 +1096,6 @@ impl Agent { notifications } - /// Update the provider pub async fn update_provider(&self, provider: Arc) -> Result<()> { let mut current_provider = self.provider.lock().await; *current_provider = Some(provider.clone()); @@ -1191,10 +1159,9 @@ impl Agent { ) .await { - tracing::error!( + error!( "Failed to index tools for extension {}: {}", - extension_name, - e + extension_name, e ); } } @@ -1243,7 +1210,7 @@ impl Agent { Err(anyhow!("Prompt '{}' not found", name)) } - pub async fn get_plan_prompt(&self) -> anyhow::Result { + pub async fn get_plan_prompt(&self) -> Result { let extension_manager = self.extension_manager.read().await; let tools = extension_manager.get_prefixed_tools(None).await?; let tools_info = tools @@ -1265,7 +1232,7 @@ impl Agent { pub async fn handle_tool_result(&self, id: String, result: ToolResult>) { if let Err(e) = self.tool_result_tx.send((id, result)).await { - tracing::error!("Failed to send tool result: {}", e); + error!("Failed to send tool result: {}", e); } } @@ -1356,7 +1323,7 @@ impl Agent { let activities_text = activities_text.trim(); // Regex to remove bullet markers or numbers with an optional dot. - let bullet_re = Regex::new(r"^[•\-\*\d]+\.?\s*").expect("Invalid regex"); + let bullet_re = Regex::new(r"^[•\-*\d]+\.?\s*").expect("Invalid regex"); // Process each line in the activities section. let activities: Vec = activities_text @@ -1383,7 +1350,7 @@ impl Agent { metadata: None, }; - // Ideally we'd get the name of the provider we are using from the provider itself + // Ideally we'd get the name of the provider we are using from the provider itself, // but it doesn't know and the plumbing looks complicated. let config = Config::global(); let provider_name: String = config @@ -1443,7 +1410,7 @@ mod tests { let prompt_manager = agent.prompt_manager.lock().await; let system_prompt = - prompt_manager.build_system_prompt(vec![], None, serde_json::Value::Null, None, None); + prompt_manager.build_system_prompt(vec![], None, Value::Null, None, None); let final_output_tool_ref = agent.final_output_tool.lock().await; let final_output_tool_system_prompt = diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 486b09fb..64bd196d 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -275,10 +275,9 @@ impl Agent { (frontend_requests, other_requests, filtered_message) } - /// Update session metrics after a response pub(crate) async fn update_session_metrics( - session_config: crate::agents::types::SessionConfig, - usage: &crate::providers::base::ProviderUsage, + session_config: &crate::agents::types::SessionConfig, + usage: &ProviderUsage, messages_length: usize, ) -> Result<()> { let session_file_path = match session::storage::get_path(session_config.id.clone()) { diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 15205857..50e06c24 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1208,7 +1208,7 @@ async fn run_scheduled_job_internal( }; match agent - .reply(&all_session_messages, Some(session_config.clone())) + .reply(&all_session_messages, Some(session_config.clone()), None) .await { Ok(mut stream) => { diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 6f21bd48..ab8b8cb1 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -129,7 +129,7 @@ async fn run_truncate_test( ), ]; - let reply_stream = agent.reply(&messages, None).await?; + let reply_stream = agent.reply(&messages, None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -619,7 +619,7 @@ mod final_output_tool_tests { ); // Simulate the reply stream continuing after the final output tool call. - let reply_stream = agent.reply(&vec![], None).await?; + let reply_stream = agent.reply(&vec![], None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -716,7 +716,7 @@ mod final_output_tool_tests { agent.add_final_output_tool(response).await; // Simulate the reply stream being called. - let reply_stream = agent.reply(&vec![], None).await?; + let reply_stream = agent.reply(&vec![], None, None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -850,7 +850,9 @@ mod retry_tests { let initial_messages = vec![Message::user().with_text("Complete this task")]; - let reply_stream = agent.reply(&initial_messages, Some(session_config)).await?; + let reply_stream = agent + .reply(&initial_messages, Some(session_config), None) + .await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); @@ -1013,7 +1015,7 @@ mod max_turns_tests { }; let messages = vec![Message::user().with_text("Hello")]; - let reply_stream = agent.reply(&messages, Some(session_config)).await?; + let reply_stream = agent.reply(&messages, Some(session_config), None).await?; tokio::pin!(reply_stream); let mut responses = Vec::new(); diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 971f05db..6814ace4 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -2332,17 +2332,16 @@ } }, "ResourceContents": { - "oneOf": [ + "anyOf": [ { "type": "object", "required": [ - "uri", - "text" + "text", + "uri" ], "properties": { "mime_type": { - "type": "string", - "nullable": true + "type": "string" }, "text": { "type": "string" @@ -2355,16 +2354,15 @@ { "type": "object", "required": [ - "uri", - "blob" + "blob", + "uri" ], "properties": { "blob": { "type": "string" }, "mime_type": { - "type": "string", - "nullable": true + "type": "string" }, "uri": { "type": "string" diff --git a/ui/desktop/src/App.tsx b/ui/desktop/src/App.tsx index e4b2881a..968f2547 100644 --- a/ui/desktop/src/App.tsx +++ b/ui/desktop/src/App.tsx @@ -1174,7 +1174,7 @@ export default function App() { }, []); const config = window.electron.getConfig(); - const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING === true ? false : true; + const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING !== true; useEffect(() => { console.log('Setting up extension handler'); diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 53230d06..6df01c78 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -463,12 +463,12 @@ export type RedactedThinkingContent = { }; export type ResourceContents = { - mime_type?: string | null; + mime_type?: string; text: string; uri: string; } | { blob: string; - mime_type?: string | null; + mime_type?: string; uri: string; }; diff --git a/ui/desktop/src/components/BaseChat.tsx b/ui/desktop/src/components/BaseChat.tsx index c20c29de..2c90613f 100644 --- a/ui/desktop/src/components/BaseChat.tsx +++ b/ui/desktop/src/components/BaseChat.tsx @@ -537,7 +537,7 @@ function BaseChatContent({ recipeDetails={{ title: recipeConfig?.title, description: recipeConfig?.description, - instructions: recipeConfig?.instructions, + instructions: recipeConfig?.instructions || undefined, }} /> diff --git a/ui/desktop/src/components/ToolCallWithResponse.tsx b/ui/desktop/src/components/ToolCallWithResponse.tsx index a75a7065..edf896a8 100644 --- a/ui/desktop/src/components/ToolCallWithResponse.tsx +++ b/ui/desktop/src/components/ToolCallWithResponse.tsx @@ -126,10 +126,10 @@ const notificationToProgress = (notification: NotificationEvent): Progress => const getExtensionTooltip = (toolCallName: string): string | null => { const lastIndex = toolCallName.lastIndexOf('__'); if (lastIndex === -1) return null; - + const extensionName = toolCallName.substring(0, lastIndex); if (!extensionName) return null; - + return `${extensionName} extension`; }; @@ -377,7 +377,7 @@ function ToolCallView({ // This ensures any MCP tool works without explicit handling const toolDisplayName = snakeToTitleCase(toolName); const entries = Object.entries(args); - + if (entries.length === 0) { return `${toolDisplayName}`; } @@ -413,7 +413,7 @@ function ToolCallView({ }; const toolLabel = ( - + {getToolLabelContent()} ); diff --git a/ui/desktop/src/main.ts b/ui/desktop/src/main.ts index 28a8e25b..84ac0b52 100644 --- a/ui/desktop/src/main.ts +++ b/ui/desktop/src/main.ts @@ -1,20 +1,20 @@ +import type { OpenDialogReturnValue } from 'electron'; import { app, - session, + App, BrowserWindow, dialog, + Event, + globalShortcut, ipcMain, Menu, MenuItem, Notification, powerSaveBlocker, - Tray, - App, - globalShortcut, + session, shell, - Event, + Tray, } from 'electron'; -import type { OpenDialogReturnValue } from 'electron'; import { Buffer } from 'node:buffer'; import fs from 'node:fs/promises'; import fsSync from 'node:fs'; @@ -33,20 +33,20 @@ import { EnvToggles, loadSettings, saveSettings, + SchedulingEngine, updateEnvironmentVariables, updateSchedulingEngineEnvironment, - SchedulingEngine, } from './utils/settings'; import * as crypto from 'crypto'; // import electron from "electron"; import * as yaml from 'yaml'; import windowStateKeeper from 'electron-window-state'; import { - setupAutoUpdater, + getUpdateAvailable, registerUpdateIpcHandlers, setTrayRef, + setupAutoUpdater, updateTrayMenu, - getUpdateAvailable, } from './utils/autoUpdater'; import { UPDATES_ENABLED } from './updates'; import { Recipe } from './recipe'; @@ -589,7 +589,7 @@ const createChat = async ( titleBarStyle: process.platform === 'darwin' ? 'hidden' : 'default', trafficLightPosition: process.platform === 'darwin' ? { x: 20, y: 16 } : undefined, vibrancy: process.platform === 'darwin' ? 'window' : undefined, - frame: process.platform === 'darwin' ? false : true, + frame: process.platform !== 'darwin', x: mainWindowState.x, y: mainWindowState.y, width: mainWindowState.width, @@ -1542,15 +1542,8 @@ ipcMain.handle('show-message-box', async (_event, options) => { return result; }); -// Handle allowed extensions list fetching ipcMain.handle('get-allowed-extensions', async () => { - try { - const allowList = await getAllowList(); - return allowList; - } catch (error) { - console.error('Error fetching allowed extensions:', error); - throw error; - } + return await getAllowList(); }); const createNewWindow = async (app: App, dir?: string | null) => { @@ -2068,57 +2061,33 @@ app.whenReady().then(async () => { }); }); -/** - * Fetches the allowed extensions list from the remote YAML file if GOOSE_ALLOWLIST is set. - * If the ALLOWLIST is not set, any are allowed. If one is set, it will warn if the deeplink - * doesn't match a command from the list. - * If it fails to load, then it will return an empty list. - * If the format is incorrect, it will return an empty list. - * Format of yaml is: - * - ```yaml: - extensions: - - id: slack - command: uvx mcp_slack - - id: knowledge_graph_memory - command: npx -y @modelcontextprotocol/server-memory - ``` - * - * @returns A promise that resolves to an array of extension commands that are allowed. - */ async function getAllowList(): Promise { if (!process.env.GOOSE_ALLOWLIST) { return []; } - try { - // Fetch the YAML file - const response = await fetch(process.env.GOOSE_ALLOWLIST); + const response = await fetch(process.env.GOOSE_ALLOWLIST); - if (!response.ok) { - throw new Error( - `Failed to fetch allowed extensions: ${response.status} ${response.statusText}` - ); - } + if (!response.ok) { + throw new Error( + `Failed to fetch allowed extensions: ${response.status} ${response.statusText}` + ); + } - // Parse the YAML content - const yamlContent = await response.text(); - const parsedYaml = yaml.parse(yamlContent); + // Parse the YAML content + const yamlContent = await response.text(); + const parsedYaml = yaml.parse(yamlContent); - // Extract the commands from the extensions array - if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) { - const commands = parsedYaml.extensions.map( - (ext: { id: string; command: string }) => ext.command - ); - console.log(`Fetched ${commands.length} allowed extension commands`); - return commands; - } else { - console.error('Invalid YAML structure:', parsedYaml); - return []; - } - } catch (error) { - console.error('Error in getAllowList:', error); - throw error; + // Extract the commands from the extensions array + if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) { + const commands = parsedYaml.extensions.map( + (ext: { id: string; command: string }) => ext.command + ); + console.log(`Fetched ${commands.length} allowed extension commands`); + return commands; + } else { + console.error('Invalid YAML structure:', parsedYaml); + return []; } } diff --git a/ui/desktop/src/utils/askAI.ts b/ui/desktop/src/utils/askAI.ts deleted file mode 100644 index bdf7a0d2..00000000 --- a/ui/desktop/src/utils/askAI.ts +++ /dev/null @@ -1,198 +0,0 @@ -import { getApiUrl, getSecretKey } from '../config'; -import { safeJsonParse } from './jsonUtils'; - -const getQuestionClassifierPrompt = (messageContent: string): string => ` -You are a simple classifier that takes content and decides if it is asking for input -from a person before continuing if there is more to do, or not. These are questions -on if a course of action should proceeed or not, or approval is needed. If it is CLEARLY a -question asking if it ok to proceed or make a choice or some input is required to proceed, then, and ONLY THEN, return QUESTION, otherwise READY if not 97% sure. - -### Examples message content that is classified as READY: -anything else I can do? -Could you please run the application and verify that the headlines are now visible in dark mode? You can use npm start. -Would you like me to make any adjustments to the formatting of these multiline strings? -Would you like me to show you how to ... (do something)? -Listing window titles... Is there anything specific you'd like help with using these tools? -Would you like me to demonstrate any specific capability or help you with a particular task? -Would you like me to run any tests? -Would you like me to make any adjustments or would you like to test? -Would you like me to dive deeper into any aspect? -Would you like me to make any other adjustments to this implementation? -Would you like any further information or assistance? -Would you like to me to make any changes? -Would you like me to make any adjustments to this implementation? -Would you like me to show you how to… -What would you like to do next? - -### Examples that are QUESTIONS: -Should I go ahead and make the changes? -Should I Go ahead with this plan? -Should I focus on X or Y? -Provide me with the name of the package and version you would like to install. - - -### Message Content: -${messageContent} - -You must provide a response strictly limited to one of the following two words: -QUESTION, READY. No other words, phrases, or explanations are allowed. - -Response:`; - -const getOptionsClassifierPrompt = (messageContent: string): string => ` -You are a simple classifier that takes content and decides if it a list of options -or plans to choose from, or not a list of options to choose from. It is IMPORTANT -that you really know this is a choice, just not numbered steps. If it is a list -of options and you are 95% sure, return OPTIONS, otherwise return NO. - -### Example (text -> response): -Would you like me to proceed with creating this file? Please let me know if you want any changes before I write it. -> NO -Here are some options for you to choose from: -> OPTIONS -which one do you want to choose? -> OPTIONS -Would you like me to dive deeper into any aspects of these components? -> NO -Should I focus on X or Y? -> OPTIONS - -### Message Content: -${messageContent} - -You must provide a response strictly limited to one of the following two words: -OPTIONS, NO. No other words, phrases, or explanations are allowed. - -Response:`; - -const getOptionsFormatterPrompt = (messageContent: string): string => ` -If the content is list of distinct options or plans of action to choose from, and -not just a list of things, but clearly a list of things to choose one from, taking -into account the Message Content alone, try to format it in a json array, like this -JSON array of objects of the form optionTitle:string, optionDescription:string (markdown). - -If is not a list of options or plans to choose from, then return empty list. - -### Message Content: -${messageContent} - -You must provide a response strictly as json in the format descriribed. No other -words, phrases, or explanations are allowed. - -Response:`; - -const getFormPrompt = (messageContent: string): string => ` -When you see a request for several pieces of information, then provide a well formed JSON object like will be shown below. -The response will have: -* a title, description, -* a list of fields, each field will have a label, type, name, placeholder, and required (boolean). -(type is either text or textarea only). -If it is not requesting clearly several pieces of information, just return an empty object. -If the task could be confirmed without more information, return an empty object. - -### Example Message: -I'll help you scaffold out a Python package. To create a well-structured Python package, I'll need to know a few key pieces of information: - -Package name - What would you like to call your package? (This should be a valid Python package name - lowercase, no spaces, typically using underscores for separators if needed) - -Brief description - What is the main purpose of the package? This helps in setting up appropriate documentation and structure. - -Initial modules - Do you have specific functionality in mind that should be split into different modules? - -Python version - Which Python version(s) do you want to support? - -Dependencies - Are there any known external packages you'll need? - -### Example JSON Response: -{ - "title": "Python Package Scaffolding Form", - "description": "Provide the details below to scaffold a well-structured Python package.", - "fields": [ - { - "label": "Package Name", - "type": "text", - "name": "package_name", - "placeholder": "Enter the package name (lowercase, no spaces, use underscores if needed)", - "required": true - }, - { - "label": "Brief Description", - "type": "textarea", - "name": "brief_description", - "placeholder": "Enter a brief description of the package's purpose", - "required": true - }, - { - "label": "Initial Modules", - "type": "textarea", - "name": "initial_modules", - "placeholder": "List the specific functionalities or modules (optional)", - "required": false - }, - { - "label": "Python Version(s)", - "type": "text", - "name": "python_versions", - "placeholder": "Enter the Python version(s) to support (e.g., 3.8, 3.9, 3.10)", - "required": true - }, - { - "label": "Dependencies", - "type": "textarea", - "name": "dependencies", - "placeholder": "List any known external packages you'll need (optional)", - "required": false - } - ] -} - -### Message Content: -${messageContent} - -You must provide a response strictly as json in the format described. No other -words, phrases, or explanations are allowed. - -Response:`; - -/** - * Core function to ask the AI a single question and get a response - * @param prompt The prompt to send to the AI - * @returns Promise The AI's response - */ -export async function ask(prompt: string): Promise { - const response = await fetch(getApiUrl('/ask'), { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-Secret-Key': getSecretKey(), - }, - body: JSON.stringify({ prompt }), - }); - - if (!response.ok) { - throw new Error('Failed to get response'); - } - - const data = await safeJsonParse<{ response: string }>(response, 'Failed to get AI response'); - return data.response; -} - -/** - * Utility to ask the LLM multiple questions to clarify without wider context. - * @param messageContent The content to analyze - * @returns Promise Array of responses from the AI for each classifier - */ -export async function askAi(messageContent: string): Promise { - // First, check the question classifier - const questionClassification = await ask(getQuestionClassifierPrompt(messageContent)); - - // If READY, return early with empty responses for options - if (questionClassification === 'READY') { - return [questionClassification, 'NO', '[]', '{}']; - } - - // Otherwise, proceed with all classifiers in parallel - const prompts = [ - Promise.resolve(questionClassification), // Reuse the result we already have - ask(getOptionsClassifierPrompt(messageContent)), - ask(getOptionsFormatterPrompt(messageContent)), - ask(getFormPrompt(messageContent)), - ]; - - return Promise.all(prompts); -}