Agent loop defensive (#3554)

Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
Douwe Osinga
2025-07-22 18:21:47 +02:00
committed by GitHub
parent b3cd03ef61
commit 9f356e7009
19 changed files with 349 additions and 773 deletions

2
Cargo.lock generated
View File

@@ -3568,6 +3568,7 @@ dependencies = [
"test-case",
"tokio",
"tokio-stream",
"tokio-util",
"tower-http",
"tracing",
"tracing-appender",
@@ -3703,6 +3704,7 @@ dependencies = [
"tokio",
"tokio-cron-scheduler",
"tokio-stream",
"tokio-util",
"tower 0.5.2",
"tower-http",
"tracing",

View File

@@ -69,8 +69,8 @@ tokio-stream = "0.1"
bytes = "1.5"
http = "1.0"
webbrowser = "1.0"
indicatif = "0.17.11"
tokio-util = "0.7.15"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -209,7 +209,7 @@ async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) ->
include_bytes!("../../../../documentation/static/img/logo_light.png").to_vec(),
)
.into_response(),
_ => (axum::http::StatusCode::NOT_FOUND, "Not found").into_response(),
_ => (http::StatusCode::NOT_FOUND, "Not found").into_response(),
}
}
@@ -484,7 +484,6 @@ async fn process_message_streaming(
)
.await?;
// Create a session config
let session_config = SessionConfig {
id: session::Identifier::Path(session_file.clone()),
working_dir: std::env::current_dir()?,
@@ -494,8 +493,7 @@ async fn process_message_streaming(
retry_config: None,
};
// Get response from agent
match agent.reply(&messages, Some(session_config)).await {
match agent.reply(&messages, Some(session_config), None).await {
Ok(mut stream) => {
while let Some(result) = stream.next().await {
match result {

View File

@@ -48,6 +48,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tokio;
use tokio_util::sync::CancellationToken;
pub enum RunMode {
Normal,
@@ -132,13 +133,10 @@ impl Session {
retry_config: Option<RetryConfig>,
) -> Self {
let messages = if let Some(session_file) = &session_file {
match session::read_messages(session_file) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("Warning: Failed to load message history: {}", e);
Vec::new()
}
}
session::read_messages(session_file).unwrap_or_else(|e| {
eprintln!("Warning: Failed to load message history: {}", e);
Vec::new()
})
} else {
// Don't try to read messages if we're not saving sessions
Vec::new()
@@ -180,7 +178,7 @@ impl Session {
/// Format: "ENV1=val1 ENV2=val2 command args..."
pub async fn add_extension(&mut self, extension_command: String) -> Result<()> {
let mut parts: Vec<&str> = extension_command.split_whitespace().collect();
let mut envs = std::collections::HashMap::new();
let mut envs = HashMap::new();
// Parse environment variables (format: KEY=value)
while let Some(part) = parts.first() {
@@ -473,7 +471,7 @@ impl Session {
self.display_context_usage().await?;
match input::get_input(&mut editor)? {
input::InputResult::Message(content) => {
InputResult::Message(content) => {
match self.run_mode {
RunMode::Normal => {
save_history(&mut editor);
@@ -495,15 +493,11 @@ impl Session {
eprintln!("Warning: Failed to update project tracker with instruction: {}", e);
}
// Get the provider from the agent for description generation
let provider = self.agent.provider().await?;
// Persist messages with provider for automatic description generation
if let Some(session_file) = &self.session_file {
let working_dir = Some(
std::env::current_dir()
.expect("failed to get current session working directory"),
);
let working_dir = Some(std::env::current_dir().unwrap_or_default());
session::persist_messages_with_schedule_id(
session_file,
@@ -847,12 +841,14 @@ impl Session {
}
async fn process_agent_response(&mut self, interactive: bool) -> Result<()> {
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
let session_config = self.session_file.as_ref().map(|s| {
let session_id = session::Identifier::Path(s.clone());
SessionConfig {
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
working_dir: std::env::current_dir().unwrap_or_default(),
schedule_id: self.scheduled_job_id.clone(),
execution_mode: None,
max_turns: self.max_turns,
@@ -861,7 +857,7 @@ impl Session {
});
let mut stream = self
.agent
.reply(&self.messages, session_config.clone())
.reply(&self.messages, session_config.clone(), Some(cancel_token))
.await?;
let mut progress_bars = output::McpSpinners::new();
@@ -919,7 +915,7 @@ impl Session {
)
.await?;
}
cancel_token_clone.cancel();
drop(stream);
break;
} else {
@@ -1001,6 +997,7 @@ impl Session {
.reply(
&self.messages,
session_config.clone(),
None
)
.await?;
}
@@ -1157,6 +1154,7 @@ impl Session {
Some(Err(e)) => {
eprintln!("Error: {}", e);
cancel_token_clone.cancel();
drop(stream);
if let Err(e) = self.handle_interrupted_messages(false).await {
eprintln!("Error handling interruption: {}", e);
@@ -1173,6 +1171,7 @@ impl Session {
}
}
_ = tokio::signal::ctrl_c() => {
cancel_token_clone.cancel();
drop(stream);
if let Err(e) = self.handle_interrupted_messages(true).await {
eprintln!("Error handling interruption: {}", e);

View File

@@ -247,7 +247,7 @@ pub unsafe extern "C" fn goose_agent_send_message(
// Block on the async call using our global runtime
let response = get_runtime().block_on(async {
let mut stream = match agent.reply(&messages, None).await {
let mut stream = match agent.reply(&messages, None, None).await {
Ok(stream) => stream,
Err(e) => return format!("Error getting reply from agent: {}", e),
};

View File

@@ -42,6 +42,7 @@ axum-extra = "0.10.0"
utoipa = { version = "4.1", features = ["axum_extras", "chrono"] }
dirs = "6.0.0"
reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking", "multipart"], default-features = false }
tokio-util = "0.7.15"
[[bin]]
name = "goosed"

View File

@@ -11,7 +11,7 @@ use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
agents::{AgentEvent, SessionConfig},
message::{push_message, Message, MessageContent},
message::{push_message, Message},
permission::permission_confirmation::PrincipalType,
};
use goose::{
@@ -19,7 +19,7 @@ use goose::{
session,
};
use mcp_core::{protocol::JsonRpcMessage, ToolResult};
use rmcp::model::{Content, Role};
use rmcp::model::Content;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -34,9 +34,10 @@ use std::{
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use utoipa::ToSchema;
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
struct ChatRequest {
messages: Vec<Message>,
session_id: Option<String>,
@@ -113,7 +114,7 @@ async fn stream_event(
tx.send(format!("data: {}\n\n", json)).await
}
async fn handler(
async fn reply_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ChatRequest>,
@@ -122,6 +123,7 @@ async fn handler(
let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);
let cancel_token = CancellationToken::new();
let messages = request.messages;
let session_working_dir = request.session_working_dir.clone();
@@ -130,65 +132,35 @@ async fn handler(
.session_id
.unwrap_or_else(session::generate_session_id);
tokio::spawn(async move {
let agent = state.get_agent().await;
let agent = match agent {
Ok(agent) => {
let provider = agent.provider().await;
match provider {
Ok(_) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No provider configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
}
}
let task_cancel = cancel_token.clone();
let task_tx = tx.clone();
std::mem::drop(tokio::spawn(async move {
let agent = match state.get_agent().await {
Ok(agent) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No agent configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
&task_tx,
)
.await;
return;
}
};
let provider = agent.provider().await;
let session_config = SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
retry_config: None,
};
let mut stream = match agent
.reply(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
retry_config: None,
}),
)
.reply(&messages, Some(session_config), Some(task_cancel.clone()))
.await
{
Ok(stream) => stream,
@@ -198,14 +170,7 @@ async fn handler(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
&task_tx,
)
.await;
return;
@@ -221,7 +186,7 @@ async fn handler(
MessageEvent::Error {
error: format!("Failed to get session path: {}", e),
},
&tx,
&task_tx,
)
.await;
return;
@@ -231,222 +196,104 @@ async fn handler(
loop {
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(AgentEvent::Message(message)))) => {
push_message(&mut all_messages, message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
_ = task_cancel.cancelled() => {
tracing::info!("Agent task cancelled");
break;
}
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
tracing::error!("Error sending model change through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(AgentEvent::Message(message)))) => {
push_message(&mut all_messages, message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
}
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
tracing::error!("Error sending model change through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => { // Heartbeat, used to detect disconnected clients
if tx.is_closed() {
break;
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => {
if tx.is_closed() {
break;
}
continue;
}
}
}
continue;
}
}
}
}
}
if all_messages.len() > saved_message_count {
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(
&session_path,
&all_messages,
Some(provider),
Some(PathBuf::from(&session_working_dir)),
)
.await
{
tracing::error!("Failed to store session history: {:?}", e);
}
});
if let Ok(provider) = agent.provider().await {
let provider = Arc::clone(&provider);
tokio::spawn(async move {
if let Err(e) = session::persist_messages(
&session_path,
&all_messages,
Some(provider),
Some(PathBuf::from(&session_working_dir)),
)
.await
{
tracing::error!("Failed to store session history: {:?}", e);
}
});
}
}
let _ = stream_event(
MessageEvent::Finish {
reason: "stop".to_string(),
},
&tx,
&task_tx,
)
.await;
});
}));
Ok(SseResponse::new(stream))
}
#[derive(Debug, Deserialize, Serialize)]
struct AskRequest {
prompt: String,
session_id: Option<String>,
session_working_dir: String,
scheduled_job_id: Option<String>,
}
#[derive(Debug, Serialize)]
struct AskResponse {
response: String,
}
async fn ask_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<AskRequest>,
) -> Result<Json<AskResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let session_working_dir = request.session_working_dir.clone();
let session_id = request
.session_id
.unwrap_or_else(session::generate_session_id);
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let provider = agent.provider().await;
let messages = vec![Message::user().with_text(request.prompt)];
let mut response_text = String::new();
let mut stream = match agent
.reply(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
retry_config: None,
}),
)
.await
{
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let mut all_messages = messages.clone();
let mut response_message = Message::assistant();
while let Some(response) = stream.next().await {
match response {
Ok(AgentEvent::Message(message)) => {
if message.role == Role::Assistant {
for content in &message.content {
if let MessageContent::Text(text) = content {
response_text.push_str(&text.text);
response_text.push('\n');
}
response_message.content.push(content.clone());
}
}
}
Ok(AgentEvent::ModelChange { model, mode }) => {
// Log model change for non-streaming
tracing::info!("Model changed to {} in {} mode", model, mode);
}
Ok(AgentEvent::McpNotification(n)) => {
// Handle notifications if needed
tracing::info!("Received notification: {:?}", n);
}
Err(e) => {
tracing::error!("Error processing as_ai message: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
if !response_message.content.is_empty() {
push_message(&mut all_messages, response_message);
}
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
Ok(path) => path,
Err(e) => {
tracing::error!("Failed to get session path: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let session_path_clone = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
let session_working_dir_clone = session_working_dir.clone();
tokio::spawn(async move {
if let Err(e) = session::persist_messages(
&session_path_clone,
&messages,
Some(provider),
Some(PathBuf::from(session_working_dir_clone)),
)
.await
{
tracing::error!("Failed to store session history: {:?}", e);
}
});
Ok(Json(AskResponse {
response: response_text.trim().to_string(),
}))
}
#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct PermissionConfirmationRequest {
id: String,
@@ -509,7 +356,7 @@ struct ToolResultRequest {
async fn submit_tool_result(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
raw: axum::extract::Json<serde_json::Value>,
raw: Json<Value>,
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -540,8 +387,7 @@ async fn submit_tool_result(
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/reply", post(handler))
.route("/ask", post(ask_handler))
.route("/reply", post(reply_handler))
.route("/confirm", post(confirm_permission))
.route("/tool_result", post(submit_tool_result))
.with_state(state)
@@ -571,10 +417,6 @@ mod tests {
goose::providers::base::ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
@@ -586,6 +428,10 @@ mod tests {
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
}
mod integration_tests {
@@ -595,7 +441,7 @@ mod tests {
use tower::ServiceExt;
#[tokio::test]
async fn test_ask_endpoint() {
async fn test_reply_endpoint() {
let mock_model_config = ModelConfig::new("test-model".to_string());
let mock_provider = Arc::new(MockProvider {
model_config: mock_model_config,
@@ -603,24 +449,17 @@ mod tests {
let agent = Agent::new();
let _ = agent.update_provider(mock_provider).await;
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
let scheduler_path = goose::scheduler::get_default_scheduler_storage_path()
.expect("Failed to get default scheduler storage path");
let scheduler =
goose::scheduler_factory::SchedulerFactory::create_legacy(scheduler_path)
.await
.unwrap();
state.set_scheduler(scheduler).await;
let app = routes(state);
let request = Request::builder()
.uri("/ask")
.uri("/reply")
.method("POST")
.header("content-type", "application/json")
.header("x-secret-key", "test-secret")
.body(Body::from(
serde_json::to_string(&AskRequest {
prompt: "test prompt".to_string(),
serde_json::to_string(&ChatRequest {
messages: vec![Message::user().with_text("test message")],
session_id: Some("test-session".to_string()),
session_working_dir: "test-working-dir".to_string(),
scheduled_job_id: None,

View File

@@ -35,7 +35,7 @@ async fn main() {
let messages = vec![Message::user()
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
let mut stream = agent.reply(&messages, None).await.unwrap();
let mut stream = agent.reply(&messages, None, None).await.unwrap();
while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
println!("{}", serde_json::to_string_pretty(&message).unwrap());
println!("\n");

View File

@@ -8,15 +8,32 @@ use futures::stream::BoxStream;
use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
use mcp_core::protocol::JsonRpcMessage;
use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo};
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME};
use crate::agents::platform_tools::{
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
};
use crate::agents::prompt_manager::PromptManager;
use crate::agents::recipe_tools::dynamic_task_tools::{
create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX,
};
use crate::agents::retry::{RetryManager, RetryResult};
use crate::agents::router_tool_selector::{
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
};
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{
self, SUBAGENT_EXECUTE_TASK_TOOL_NAME,
};
use crate::agents::subagent_execution_tool::tasks_manager::TasksManager;
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
use crate::agents::tool_vectordb::generate_table_id;
use crate::agents::types::SessionConfig;
use crate::agents::types::{FrontendTool, ToolResultReceiver};
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::{push_message, Message};
use crate::permission::permission_judge::check_tool_permissions;
@@ -26,31 +43,14 @@ use crate::providers::errors::ProviderError;
use crate::recipe::{Author, Recipe, Response, Settings, SubRecipe};
use crate::scheduler_trait::SchedulerTrait;
use crate::tool_monitor::{ToolCall, ToolMonitor};
use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult};
use regex::Regex;
use rmcp::model::{Content, Prompt};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument};
use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo};
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
use crate::agents::platform_tools::{
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
};
use crate::agents::prompt_manager::PromptManager;
use crate::agents::retry::{RetryManager, RetryResult};
use crate::agents::router_tool_selector::{
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
};
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
use crate::agents::tool_vectordb::generate_table_id;
use crate::agents::types::SessionConfig;
use crate::agents::types::{FrontendTool, ToolResultReceiver};
use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult};
use rmcp::model::{Content, Prompt};
use super::final_output_tool::FinalOutputTool;
use super::platform_tools;
use super::router_tools;
@@ -317,17 +317,17 @@ impl Agent {
}
if tool_call.name == FINAL_OUTPUT_TOOL_NAME {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() {
return if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() {
let result = final_output_tool.execute_tool_call(tool_call.clone()).await;
return (request_id, Ok(result));
(request_id, Ok(result))
} else {
return (
(
request_id,
Err(ToolError::ExecutionError(
"Final output tool not defined".to_string(),
)),
);
}
)
};
}
let extension_manager = self.extension_manager.read().await;
@@ -406,10 +406,9 @@ impl Agent {
let result = extension_manager
.dispatch_tool_call(tool_call.clone())
.await;
match result {
Ok(call_result) => call_result,
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
}
result.unwrap_or_else(|e| {
ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string())))
})
};
(
@@ -719,42 +718,17 @@ impl Agent {
&self,
messages: &[Message],
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
cancel_token: Option<CancellationToken>,
) -> Result<BoxStream<'_, Result<AgentEvent>>> {
let mut messages = messages.to_vec();
let initial_messages = messages.clone();
let reply_span = tracing::Span::current();
self.reset_retry_attempts().await;
// Load settings from config
let config = Config::global();
// Setup tools and prompt
let (mut tools, mut toolshim_tools, mut system_prompt) =
self.prepare_tools_and_prompt().await?;
// Get goose_mode from config, but override with execution_mode if provided in session config
let mut goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
// If this is a scheduled job with an execution_mode, override the goose_mode
if let Some(session_config) = &session {
if let Some(execution_mode) = &session_config.execution_mode {
// Map "foreground" to "auto" and "background" to "chat"
goose_mode = match execution_mode.as_str() {
"foreground" => "auto".to_string(),
"background" => "chat".to_string(),
_ => goose_mode,
};
tracing::info!(
"Using execution_mode '{}' which maps to goose_mode '{}'",
execution_mode,
goose_mode
);
}
}
let (tools_with_readonly_annotation, tools_without_annotation) =
Self::categorize_tools_by_annotation(&tools);
let goose_mode = Self::determine_goose_mode(session.as_ref(), config);
if let Some(content) = messages
.last()
@@ -775,12 +749,16 @@ impl Agent {
});
loop {
// Check for final output before incrementing turns or checking max_turns
// This ensures that if we have a final output ready, we return it immediately
// without being blocked by the max_turns limit - this is needed for streaming cases
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
break;
}
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_some() {
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
let final_event = AgentEvent::Message(
Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()),
);
yield final_event;
break;
}
}
@@ -793,23 +771,19 @@ impl Agent {
break;
}
// Check for MCP notifications from subagents
// Handle MCP notifications from subagents
let mcp_notifications = self.get_mcp_notifications().await;
for notification in mcp_notifications {
// Extract subagent info from the notification data
if let JsonRpcMessage::Notification(ref notif) = notification {
if let Some(params) = &notif.params {
if let Some(data) = params.get("data") {
if let (Some(subagent_id), Some(_message)) = (
data.get("subagent_id").and_then(|v| v.as_str()),
data.get("message").and_then(|v| v.as_str())
) {
// Emit as McpNotification event
yield AgentEvent::McpNotification((
subagent_id.to_string(),
notification.clone()
));
}
if let JsonRpcMessage::Notification(notif) = &notification {
if let Some(data) = notif.params.as_ref().and_then(|p| p.get("data")) {
if let (Some(subagent_id), Some(_message)) = (
data.get("subagent_id").and_then(|v| v.as_str()),
data.get("message").and_then(|v| v.as_str()),
) {
yield AgentEvent::McpNotification((
subagent_id.to_string(),
notification.clone(),
));
}
}
}
@@ -821,17 +795,24 @@ impl Agent {
&messages,
&tools,
&toolshim_tools,
).await?;
)
.await?;
let mut added_message = false;
let mut messages_to_add = Vec::new();
let mut tools_updated = false;
while let Some(next) = stream.next().await {
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
break;
}
match next {
Ok((response, usage)) => {
// Emit model change event if provider is lead-worker
let provider = self.provider().await?;
if let Some(lead_worker) = provider.as_lead_worker() {
if let Some(ref usage) = usage {
// The actual model used is in the usage
let active_model = usage.model.clone();
let (lead_model, worker_model) = lead_worker.get_model_info();
let mode = if active_model == lead_model {
@@ -849,43 +830,44 @@ impl Agent {
}
}
// record usage for the session in the session file
if let Some(session_config) = session.clone() {
// Record usage for the session
if let Some(ref session_config) = &session {
if let Some(ref usage) = usage {
Self::update_session_metrics(session_config, usage, messages.len()).await?;
Self::update_session_metrics(session_config, usage, messages.len())
.await?;
}
}
if let Some(response) = response {
// categorize the type of requests we need to handle
let (frontend_requests,
remaining_requests,
filtered_response) =
let (tools_with_readonly_annotation, tools_without_annotation) =
Self::categorize_tools_by_annotation(&tools);
// Categorize tool requests
let (frontend_requests, remaining_requests, filtered_response) =
self.categorize_tool_requests(&response).await;
// Record tool calls in the router selector
let selector = self.router_tool_selector.lock().await.clone();
if let Some(selector) = selector {
// Record frontend tool calls
for request in &frontend_requests {
if let Ok(tool_call) = &request.tool_call {
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
tracing::error!("Failed to record frontend tool call: {}", e);
if let Err(e) = selector.record_tool_call(&tool_call.name).await
{
error!("Failed to record frontend tool call: {}", e);
}
}
}
// Record remaining tool calls
for request in &remaining_requests {
if let Ok(tool_call) = &request.tool_call {
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
tracing::error!("Failed to record tool call: {}", e);
if let Err(e) = selector.record_tool_call(&tool_call.name).await
{
error!("Failed to record tool call: {}", e);
}
}
}
}
// Yield the assistant's response with frontend tool requests filtered out
yield AgentEvent::Message(filtered_response.clone());
yield AgentEvent::Message(filtered_response.clone());
tokio::task::yield_now().await;
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
@@ -893,23 +875,17 @@ impl Agent {
continue;
}
// Process tool requests depending on frontend tools and then goose_mode
let message_tool_response = Arc::new(Mutex::new(Message::user()));
// First handle any frontend tool requests
let mut frontend_tool_stream = self.handle_frontend_tool_requests(
&frontend_requests,
message_tool_response.clone()
message_tool_response.clone(),
);
// we have a stream of frontend tools to handle, inside the stream
// execution is yeield back to this reply loop, and is of the same Message
// type, so we can yield that back up to be handled
while let Some(msg) = frontend_tool_stream.try_next().await? {
yield AgentEvent::Message(msg);
}
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
if mode.as_str() == "chat" {
// Skip all tool calls in chat mode
@@ -921,36 +897,42 @@ impl Agent {
);
}
} else {
// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
// What remains is handling the remaining tool requests (enable extension,
// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
let mut permission_manager = PermissionManager::default();
let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
&remaining_requests,
&mode,
tools_with_readonly_annotation.clone(),
tools_without_annotation.clone(),
&mut permission_manager,
self.provider().await?).await;
let (permission_check_result, enable_extension_request_ids) =
check_tool_permissions(
&remaining_requests,
&mode,
tools_with_readonly_annotation.clone(),
tools_without_annotation.clone(),
&mut permission_manager,
self.provider().await?,
)
.await;
// Handle pre-approved and read-only tools in parallel
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
// Skip the confirmation for approved tools
// Handle pre-approved and read-only tools
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
let (req_id, tool_result) = self
.dispatch_tool_call(tool_call, request.id.clone())
.await;
tool_futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
tool_futures.push((
req_id,
match tool_result {
Ok(result) => tool_stream(
result
.notification_stream
.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
},
));
}
}
@@ -962,29 +944,22 @@ impl Agent {
);
}
// We need interior mutability in handle_approval_tool_requests
let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
// Process tools requiring approval (enable extension, regular tool calls)
// Process tools requiring approval
let mut tool_approval_stream = self.handle_approval_tool_requests(
&permission_check_result.needs_approval,
tool_futures_arc.clone(),
&mut permission_manager,
message_tool_response.clone()
message_tool_response.clone(),
);
// We have a stream of tool_approval_requests to handle
// Execution is yielded back to this reply loop, and is of the same Message
// type, so we can yield the Message back up to be handled and grab any
// confirmations or denials
while let Some(msg) = tool_approval_stream.try_next().await? {
yield AgentEvent::Message(msg);
}
tool_futures = {
// Lock the mutex asynchronously
let mut futures_lock = tool_futures_arc.lock().await;
// Drain the vector and collect into a new Vec
futures_lock.drain(..).collect::<Vec<_>>()
};
@@ -996,27 +971,33 @@ impl Agent {
.collect::<Vec<_>>();
let mut combined = stream::select_all(with_id);
let mut all_install_successful = true;
while let Some((request_id, item)) = combined.next().await {
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
break;
}
match item {
ToolStreamItem::Result(output) => {
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
if enable_extension_request_ids.contains(&request_id)
&& output.is_err()
{
all_install_successful = false;
}
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
},
*response =
response.clone().with_tool_response(request_id, output);
}
ToolStreamItem::Message(msg) => {
yield AgentEvent::McpNotification((request_id, msg))
yield AgentEvent::McpNotification((
request_id, msg,
));
}
}
}
// Update system prompt and tools if installations were successful
if all_install_successful {
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
tools_updated = true;
}
}
@@ -1024,63 +1005,39 @@ impl Agent {
yield AgentEvent::Message(final_message_tool_resp.clone());
added_message = true;
push_message(&mut messages, response);
push_message(&mut messages, final_message_tool_resp);
// Check for MCP notifications from subagents again before next iteration
// Note: These are already handled as McpNotification events above,
// so we don't need to convert them to assistant messages here.
// This was causing duplicate plain-text notifications.
// let mcp_notifications = self.get_mcp_notifications().await;
// for notification in mcp_notifications {
// // Extract subagent info from the notification data for assistant messages
// if let JsonRpcMessage::Notification(ref notif) = notification {
// if let Some(params) = &notif.params {
// if let Some(data) = params.get("data") {
// if let (Some(subagent_id), Some(message)) = (
// data.get("subagent_id").and_then(|v| v.as_str()),
// data.get("message").and_then(|v| v.as_str())
// ) {
// yield AgentEvent::Message(
// Message::assistant().with_text(
// format!("Subagent {}: {}", subagent_id, message)
// )
// );
// }
// }
// }
// }
// }
push_message(&mut messages_to_add, response);
push_message(&mut messages_to_add, final_message_tool_resp);
}
},
}
Err(ProviderError::ContextLengthExceeded(_)) => {
// At this point, the last message should be a user message
// because call to provider led to context length exceeded error
// Immediately yield a special message and break
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
"The context length of the model has been exceeded. Please start a new session and try again.",
));
"The context length of the model has been exceeded. Please start a new session and try again.",
));
break;
},
}
Err(e) => {
// Create an error message & terminate the stream
error!("Error: {}", e);
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
yield AgentEvent::Message(Message::assistant().with_text(
format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")
));
break;
}
}
}
if tools_updated {
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
}
if !added_message {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
messages.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
continue;
continue
} else {
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
}
}
@@ -1099,16 +1056,28 @@ impl Agent {
));
}
}
break;
}
// Yield control back to the scheduler to prevent blocking
messages.extend(messages_to_add);
tokio::task::yield_now().await;
}
}))
}
fn determine_goose_mode(session: Option<&SessionConfig>, config: &Config) -> String {
let mode = session.and_then(|s| s.execution_mode.as_deref());
match mode {
Some("foreground") => "auto".to_string(),
Some("background") => "chat".to_string(),
_ => config
.get_param("GOOSE_MODE")
.unwrap_or_else(|_| "auto".to_string()),
}
}
/// Extend the system prompt with one line of additional instruction
pub async fn extend_system_prompt(&self, instruction: String) {
let mut prompt_manager = self.prompt_manager.lock().await;
@@ -1127,7 +1096,6 @@ impl Agent {
notifications
}
/// Update the provider
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
let mut current_provider = self.provider.lock().await;
*current_provider = Some(provider.clone());
@@ -1191,10 +1159,9 @@ impl Agent {
)
.await
{
tracing::error!(
error!(
"Failed to index tools for extension {}: {}",
extension_name,
e
extension_name, e
);
}
}
@@ -1243,7 +1210,7 @@ impl Agent {
Err(anyhow!("Prompt '{}' not found", name))
}
pub async fn get_plan_prompt(&self) -> anyhow::Result<String> {
pub async fn get_plan_prompt(&self) -> Result<String> {
let extension_manager = self.extension_manager.read().await;
let tools = extension_manager.get_prefixed_tools(None).await?;
let tools_info = tools
@@ -1265,7 +1232,7 @@ impl Agent {
pub async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
if let Err(e) = self.tool_result_tx.send((id, result)).await {
tracing::error!("Failed to send tool result: {}", e);
error!("Failed to send tool result: {}", e);
}
}
@@ -1356,7 +1323,7 @@ impl Agent {
let activities_text = activities_text.trim();
// Regex to remove bullet markers or numbers with an optional dot.
let bullet_re = Regex::new(r"^[•\-\*\d]+\.?\s*").expect("Invalid regex");
let bullet_re = Regex::new(r"^[•\-*\d]+\.?\s*").expect("Invalid regex");
// Process each line in the activities section.
let activities: Vec<String> = activities_text
@@ -1383,7 +1350,7 @@ impl Agent {
metadata: None,
};
// Ideally we'd get the name of the provider we are using from the provider itself
// Ideally we'd get the name of the provider we are using from the provider itself,
// but it doesn't know and the plumbing looks complicated.
let config = Config::global();
let provider_name: String = config
@@ -1443,7 +1410,7 @@ mod tests {
let prompt_manager = agent.prompt_manager.lock().await;
let system_prompt =
prompt_manager.build_system_prompt(vec![], None, serde_json::Value::Null, None, None);
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, None);
let final_output_tool_ref = agent.final_output_tool.lock().await;
let final_output_tool_system_prompt =

View File

@@ -275,10 +275,9 @@ impl Agent {
(frontend_requests, other_requests, filtered_message)
}
/// Update session metrics after a response
pub(crate) async fn update_session_metrics(
session_config: crate::agents::types::SessionConfig,
usage: &crate::providers::base::ProviderUsage,
session_config: &crate::agents::types::SessionConfig,
usage: &ProviderUsage,
messages_length: usize,
) -> Result<()> {
let session_file_path = match session::storage::get_path(session_config.id.clone()) {

View File

@@ -1208,7 +1208,7 @@ async fn run_scheduled_job_internal(
};
match agent
.reply(&all_session_messages, Some(session_config.clone()))
.reply(&all_session_messages, Some(session_config.clone()), None)
.await
{
Ok(mut stream) => {

View File

@@ -129,7 +129,7 @@ async fn run_truncate_test(
),
];
let reply_stream = agent.reply(&messages, None).await?;
let reply_stream = agent.reply(&messages, None, None).await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();
@@ -619,7 +619,7 @@ mod final_output_tool_tests {
);
// Simulate the reply stream continuing after the final output tool call.
let reply_stream = agent.reply(&vec![], None).await?;
let reply_stream = agent.reply(&vec![], None, None).await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();
@@ -716,7 +716,7 @@ mod final_output_tool_tests {
agent.add_final_output_tool(response).await;
// Simulate the reply stream being called.
let reply_stream = agent.reply(&vec![], None).await?;
let reply_stream = agent.reply(&vec![], None, None).await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();
@@ -850,7 +850,9 @@ mod retry_tests {
let initial_messages = vec![Message::user().with_text("Complete this task")];
let reply_stream = agent.reply(&initial_messages, Some(session_config)).await?;
let reply_stream = agent
.reply(&initial_messages, Some(session_config), None)
.await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();
@@ -1013,7 +1015,7 @@ mod max_turns_tests {
};
let messages = vec![Message::user().with_text("Hello")];
let reply_stream = agent.reply(&messages, Some(session_config)).await?;
let reply_stream = agent.reply(&messages, Some(session_config), None).await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();

View File

@@ -2332,17 +2332,16 @@
}
},
"ResourceContents": {
"oneOf": [
"anyOf": [
{
"type": "object",
"required": [
"uri",
"text"
"text",
"uri"
],
"properties": {
"mime_type": {
"type": "string",
"nullable": true
"type": "string"
},
"text": {
"type": "string"
@@ -2355,16 +2354,15 @@
{
"type": "object",
"required": [
"uri",
"blob"
"blob",
"uri"
],
"properties": {
"blob": {
"type": "string"
},
"mime_type": {
"type": "string",
"nullable": true
"type": "string"
},
"uri": {
"type": "string"

View File

@@ -1174,7 +1174,7 @@ export default function App() {
}, []);
const config = window.electron.getConfig();
const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING === true ? false : true;
const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING !== true;
useEffect(() => {
console.log('Setting up extension handler');

View File

@@ -463,12 +463,12 @@ export type RedactedThinkingContent = {
};
export type ResourceContents = {
mime_type?: string | null;
mime_type?: string;
text: string;
uri: string;
} | {
blob: string;
mime_type?: string | null;
mime_type?: string;
uri: string;
};

View File

@@ -537,7 +537,7 @@ function BaseChatContent({
recipeDetails={{
title: recipeConfig?.title,
description: recipeConfig?.description,
instructions: recipeConfig?.instructions,
instructions: recipeConfig?.instructions || undefined,
}}
/>

View File

@@ -126,10 +126,10 @@ const notificationToProgress = (notification: NotificationEvent): Progress =>
const getExtensionTooltip = (toolCallName: string): string | null => {
const lastIndex = toolCallName.lastIndexOf('__');
if (lastIndex === -1) return null;
const extensionName = toolCallName.substring(0, lastIndex);
if (!extensionName) return null;
return `${extensionName} extension`;
};
@@ -377,7 +377,7 @@ function ToolCallView({
// This ensures any MCP tool works without explicit handling
const toolDisplayName = snakeToTitleCase(toolName);
const entries = Object.entries(args);
if (entries.length === 0) {
return `${toolDisplayName}`;
}
@@ -413,7 +413,7 @@ function ToolCallView({
};
const toolLabel = (
<span className={cn("ml-2", extensionTooltip && "cursor-pointer hover:opacity-80")}>
<span className={cn('ml-2', extensionTooltip && 'cursor-pointer hover:opacity-80')}>
{getToolLabelContent()}
</span>
);

View File

@@ -1,20 +1,20 @@
import type { OpenDialogReturnValue } from 'electron';
import {
app,
session,
App,
BrowserWindow,
dialog,
Event,
globalShortcut,
ipcMain,
Menu,
MenuItem,
Notification,
powerSaveBlocker,
Tray,
App,
globalShortcut,
session,
shell,
Event,
Tray,
} from 'electron';
import type { OpenDialogReturnValue } from 'electron';
import { Buffer } from 'node:buffer';
import fs from 'node:fs/promises';
import fsSync from 'node:fs';
@@ -33,20 +33,20 @@ import {
EnvToggles,
loadSettings,
saveSettings,
SchedulingEngine,
updateEnvironmentVariables,
updateSchedulingEngineEnvironment,
SchedulingEngine,
} from './utils/settings';
import * as crypto from 'crypto';
// import electron from "electron";
import * as yaml from 'yaml';
import windowStateKeeper from 'electron-window-state';
import {
setupAutoUpdater,
getUpdateAvailable,
registerUpdateIpcHandlers,
setTrayRef,
setupAutoUpdater,
updateTrayMenu,
getUpdateAvailable,
} from './utils/autoUpdater';
import { UPDATES_ENABLED } from './updates';
import { Recipe } from './recipe';
@@ -589,7 +589,7 @@ const createChat = async (
titleBarStyle: process.platform === 'darwin' ? 'hidden' : 'default',
trafficLightPosition: process.platform === 'darwin' ? { x: 20, y: 16 } : undefined,
vibrancy: process.platform === 'darwin' ? 'window' : undefined,
frame: process.platform === 'darwin' ? false : true,
frame: process.platform !== 'darwin',
x: mainWindowState.x,
y: mainWindowState.y,
width: mainWindowState.width,
@@ -1542,15 +1542,8 @@ ipcMain.handle('show-message-box', async (_event, options) => {
return result;
});
// Handle allowed extensions list fetching
ipcMain.handle('get-allowed-extensions', async () => {
try {
const allowList = await getAllowList();
return allowList;
} catch (error) {
console.error('Error fetching allowed extensions:', error);
throw error;
}
return await getAllowList();
});
const createNewWindow = async (app: App, dir?: string | null) => {
@@ -2068,57 +2061,33 @@ app.whenReady().then(async () => {
});
});
/**
* Fetches the allowed extensions list from the remote YAML file if GOOSE_ALLOWLIST is set.
* If the ALLOWLIST is not set, any are allowed. If one is set, it will warn if the deeplink
* doesn't match a command from the list.
* If it fails to load, then it will return an empty list.
* If the format is incorrect, it will return an empty list.
* Format of yaml is:
*
```yaml:
extensions:
- id: slack
command: uvx mcp_slack
- id: knowledge_graph_memory
command: npx -y @modelcontextprotocol/server-memory
```
*
* @returns A promise that resolves to an array of extension commands that are allowed.
*/
async function getAllowList(): Promise<string[]> {
if (!process.env.GOOSE_ALLOWLIST) {
return [];
}
try {
// Fetch the YAML file
const response = await fetch(process.env.GOOSE_ALLOWLIST);
const response = await fetch(process.env.GOOSE_ALLOWLIST);
if (!response.ok) {
throw new Error(
`Failed to fetch allowed extensions: ${response.status} ${response.statusText}`
);
}
if (!response.ok) {
throw new Error(
`Failed to fetch allowed extensions: ${response.status} ${response.statusText}`
);
}
// Parse the YAML content
const yamlContent = await response.text();
const parsedYaml = yaml.parse(yamlContent);
// Parse the YAML content
const yamlContent = await response.text();
const parsedYaml = yaml.parse(yamlContent);
// Extract the commands from the extensions array
if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) {
const commands = parsedYaml.extensions.map(
(ext: { id: string; command: string }) => ext.command
);
console.log(`Fetched ${commands.length} allowed extension commands`);
return commands;
} else {
console.error('Invalid YAML structure:', parsedYaml);
return [];
}
} catch (error) {
console.error('Error in getAllowList:', error);
throw error;
// Extract the commands from the extensions array
if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) {
const commands = parsedYaml.extensions.map(
(ext: { id: string; command: string }) => ext.command
);
console.log(`Fetched ${commands.length} allowed extension commands`);
return commands;
} else {
console.error('Invalid YAML structure:', parsedYaml);
return [];
}
}

View File

@@ -1,198 +0,0 @@
import { getApiUrl, getSecretKey } from '../config';
import { safeJsonParse } from './jsonUtils';
const getQuestionClassifierPrompt = (messageContent: string): string => `
You are a simple classifier that takes content and decides if it is asking for input
from a person before continuing if there is more to do, or not. These are questions
on if a course of action should proceeed or not, or approval is needed. If it is CLEARLY a
question asking if it ok to proceed or make a choice or some input is required to proceed, then, and ONLY THEN, return QUESTION, otherwise READY if not 97% sure.
### Examples message content that is classified as READY:
anything else I can do?
Could you please run the application and verify that the headlines are now visible in dark mode? You can use npm start.
Would you like me to make any adjustments to the formatting of these multiline strings?
Would you like me to show you how to ... (do something)?
Listing window titles... Is there anything specific you'd like help with using these tools?
Would you like me to demonstrate any specific capability or help you with a particular task?
Would you like me to run any tests?
Would you like me to make any adjustments or would you like to test?
Would you like me to dive deeper into any aspect?
Would you like me to make any other adjustments to this implementation?
Would you like any further information or assistance?
Would you like to me to make any changes?
Would you like me to make any adjustments to this implementation?
Would you like me to show you how to…
What would you like to do next?
### Examples that are QUESTIONS:
Should I go ahead and make the changes?
Should I Go ahead with this plan?
Should I focus on X or Y?
Provide me with the name of the package and version you would like to install.
### Message Content:
${messageContent}
You must provide a response strictly limited to one of the following two words:
QUESTION, READY. No other words, phrases, or explanations are allowed.
Response:`;
const getOptionsClassifierPrompt = (messageContent: string): string => `
You are a simple classifier that takes content and decides if it a list of options
or plans to choose from, or not a list of options to choose from. It is IMPORTANT
that you really know this is a choice, just not numbered steps. If it is a list
of options and you are 95% sure, return OPTIONS, otherwise return NO.
### Example (text -> response):
Would you like me to proceed with creating this file? Please let me know if you want any changes before I write it. -> NO
Here are some options for you to choose from: -> OPTIONS
which one do you want to choose? -> OPTIONS
Would you like me to dive deeper into any aspects of these components? -> NO
Should I focus on X or Y? -> OPTIONS
### Message Content:
${messageContent}
You must provide a response strictly limited to one of the following two words:
OPTIONS, NO. No other words, phrases, or explanations are allowed.
Response:`;
const getOptionsFormatterPrompt = (messageContent: string): string => `
If the content is list of distinct options or plans of action to choose from, and
not just a list of things, but clearly a list of things to choose one from, taking
into account the Message Content alone, try to format it in a json array, like this
JSON array of objects of the form optionTitle:string, optionDescription:string (markdown).
If is not a list of options or plans to choose from, then return empty list.
### Message Content:
${messageContent}
You must provide a response strictly as json in the format descriribed. No other
words, phrases, or explanations are allowed.
Response:`;
const getFormPrompt = (messageContent: string): string => `
When you see a request for several pieces of information, then provide a well formed JSON object like will be shown below.
The response will have:
* a title, description,
* a list of fields, each field will have a label, type, name, placeholder, and required (boolean).
(type is either text or textarea only).
If it is not requesting clearly several pieces of information, just return an empty object.
If the task could be confirmed without more information, return an empty object.
### Example Message:
I'll help you scaffold out a Python package. To create a well-structured Python package, I'll need to know a few key pieces of information:
Package name - What would you like to call your package? (This should be a valid Python package name - lowercase, no spaces, typically using underscores for separators if needed)
Brief description - What is the main purpose of the package? This helps in setting up appropriate documentation and structure.
Initial modules - Do you have specific functionality in mind that should be split into different modules?
Python version - Which Python version(s) do you want to support?
Dependencies - Are there any known external packages you'll need?
### Example JSON Response:
{
"title": "Python Package Scaffolding Form",
"description": "Provide the details below to scaffold a well-structured Python package.",
"fields": [
{
"label": "Package Name",
"type": "text",
"name": "package_name",
"placeholder": "Enter the package name (lowercase, no spaces, use underscores if needed)",
"required": true
},
{
"label": "Brief Description",
"type": "textarea",
"name": "brief_description",
"placeholder": "Enter a brief description of the package's purpose",
"required": true
},
{
"label": "Initial Modules",
"type": "textarea",
"name": "initial_modules",
"placeholder": "List the specific functionalities or modules (optional)",
"required": false
},
{
"label": "Python Version(s)",
"type": "text",
"name": "python_versions",
"placeholder": "Enter the Python version(s) to support (e.g., 3.8, 3.9, 3.10)",
"required": true
},
{
"label": "Dependencies",
"type": "textarea",
"name": "dependencies",
"placeholder": "List any known external packages you'll need (optional)",
"required": false
}
]
}
### Message Content:
${messageContent}
You must provide a response strictly as json in the format described. No other
words, phrases, or explanations are allowed.
Response:`;
/**
* Core function to ask the AI a single question and get a response
* @param prompt The prompt to send to the AI
* @returns Promise<string> The AI's response
*/
export async function ask(prompt: string): Promise<string> {
const response = await fetch(getApiUrl('/ask'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({ prompt }),
});
if (!response.ok) {
throw new Error('Failed to get response');
}
const data = await safeJsonParse<{ response: string }>(response, 'Failed to get AI response');
return data.response;
}
/**
* Utility to ask the LLM multiple questions to clarify without wider context.
* @param messageContent The content to analyze
* @returns Promise<string[]> Array of responses from the AI for each classifier
*/
export async function askAi(messageContent: string): Promise<string[]> {
// First, check the question classifier
const questionClassification = await ask(getQuestionClassifierPrompt(messageContent));
// If READY, return early with empty responses for options
if (questionClassification === 'READY') {
return [questionClassification, 'NO', '[]', '{}'];
}
// Otherwise, proceed with all classifiers in parallel
const prompts = [
Promise.resolve(questionClassification), // Reuse the result we already have
ask(getOptionsClassifierPrompt(messageContent)),
ask(getOptionsFormatterPrompt(messageContent)),
ask(getFormPrompt(messageContent)),
];
return Promise.all(prompts);
}