Mirror of https://github.com/aljazceru/goose.git, synced 2026-01-31 12:14:32 +01:00
Agent loop defensive (#3554)
Co-authored-by: Douwe Osinga <douwe@squareup.com>
Cargo.lock | 2 (generated)
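This commit threads an optional CancellationToken through Agent::reply (now a third parameter) and makes the agent loop bail out defensively when the token fires or when a provider or filesystem call fails. The sketch below shows how a caller might drive the new signature; it is not code from this commit, and the goose::agents::Agent path and the helper name are assumptions based on the imports visible in the diff.

use anyhow::Result;
use futures::StreamExt;
use tokio_util::sync::CancellationToken;

// Hypothetical helper; Agent, Message and AgentEvent are assumed to live in the goose crate per the imports in this diff.
async fn run_reply(agent: &goose::agents::Agent, messages: &[goose::message::Message]) -> Result<()> {
    let cancel_token = CancellationToken::new();
    // reply() now takes (&[Message], Option<SessionConfig>, Option<CancellationToken>).
    let mut stream = agent
        .reply(messages, None, Some(cancel_token.clone()))
        .await?;
    while let Some(event) = stream.next().await {
        let _event = event?; // handle AgentEvent::Message / ModelChange / McpNotification here
        if cancel_token.is_cancelled() {
            break; // the loop inside reply() also breaks once the token is cancelled
        }
    }
    Ok(())
}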
@@ -3568,6 +3568,7 @@ dependencies = [
"test-case",
"tokio",
"tokio-stream",
"tokio-util",
"tower-http",
"tracing",
"tracing-appender",
@@ -3703,6 +3704,7 @@ dependencies = [
"tokio",
"tokio-cron-scheduler",
"tokio-stream",
"tokio-util",
"tower 0.5.2",
"tower-http",
"tracing",

@@ -69,8 +69,8 @@ tokio-stream = "0.1"
bytes = "1.5"
http = "1.0"
webbrowser = "1.0"

indicatif = "0.17.11"
tokio-util = "0.7.15"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

@@ -209,7 +209,7 @@ async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) ->
include_bytes!("../../../../documentation/static/img/logo_light.png").to_vec(),
)
.into_response(),
_ => (axum::http::StatusCode::NOT_FOUND, "Not found").into_response(),
_ => (http::StatusCode::NOT_FOUND, "Not found").into_response(),
}
}

@@ -484,7 +484,6 @@ async fn process_message_streaming(
)
.await?;

// Create a session config
let session_config = SessionConfig {
id: session::Identifier::Path(session_file.clone()),
working_dir: std::env::current_dir()?,
@@ -494,8 +493,7 @@ async fn process_message_streaming(
retry_config: None,
};

// Get response from agent
match agent.reply(&messages, Some(session_config)).await {
match agent.reply(&messages, Some(session_config), None).await {
Ok(mut stream) => {
while let Some(result) = stream.next().await {
match result {

@@ -48,6 +48,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tokio;
use tokio_util::sync::CancellationToken;

pub enum RunMode {
Normal,
@@ -132,13 +133,10 @@ impl Session {
retry_config: Option<RetryConfig>,
) -> Self {
let messages = if let Some(session_file) = &session_file {
match session::read_messages(session_file) {
Ok(msgs) => msgs,
Err(e) => {
eprintln!("Warning: Failed to load message history: {}", e);
Vec::new()
}
}
session::read_messages(session_file).unwrap_or_else(|e| {
eprintln!("Warning: Failed to load message history: {}", e);
Vec::new()
})
} else {
// Don't try to read messages if we're not saving sessions
Vec::new()
@@ -180,7 +178,7 @@ impl Session {
/// Format: "ENV1=val1 ENV2=val2 command args..."
pub async fn add_extension(&mut self, extension_command: String) -> Result<()> {
let mut parts: Vec<&str> = extension_command.split_whitespace().collect();
let mut envs = std::collections::HashMap::new();
let mut envs = HashMap::new();

// Parse environment variables (format: KEY=value)
while let Some(part) = parts.first() {
@@ -473,7 +471,7 @@ impl Session {
self.display_context_usage().await?;

match input::get_input(&mut editor)? {
input::InputResult::Message(content) => {
InputResult::Message(content) => {
match self.run_mode {
RunMode::Normal => {
save_history(&mut editor);
@@ -495,15 +493,11 @@ impl Session {
eprintln!("Warning: Failed to update project tracker with instruction: {}", e);
}

// Get the provider from the agent for description generation
let provider = self.agent.provider().await?;

// Persist messages with provider for automatic description generation
if let Some(session_file) = &self.session_file {
let working_dir = Some(
std::env::current_dir()
.expect("failed to get current session working directory"),
);
let working_dir = Some(std::env::current_dir().unwrap_or_default());

session::persist_messages_with_schedule_id(
session_file,
@@ -847,12 +841,14 @@ impl Session {
}

async fn process_agent_response(&mut self, interactive: bool) -> Result<()> {
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();

let session_config = self.session_file.as_ref().map(|s| {
let session_id = session::Identifier::Path(s.clone());
SessionConfig {
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
working_dir: std::env::current_dir().unwrap_or_default(),
schedule_id: self.scheduled_job_id.clone(),
execution_mode: None,
max_turns: self.max_turns,
@@ -861,7 +857,7 @@ impl Session {
});
let mut stream = self
.agent
.reply(&self.messages, session_config.clone())
.reply(&self.messages, session_config.clone(), Some(cancel_token))
.await?;

let mut progress_bars = output::McpSpinners::new();
@@ -919,7 +915,7 @@ impl Session {
)
.await?;
}

cancel_token_clone.cancel();
drop(stream);
break;
} else {
@@ -1001,6 +997,7 @@ impl Session {
.reply(
&self.messages,
session_config.clone(),
None
)
.await?;
}
@@ -1157,6 +1154,7 @@ impl Session {

Some(Err(e)) => {
eprintln!("Error: {}", e);
cancel_token_clone.cancel();
drop(stream);
if let Err(e) = self.handle_interrupted_messages(false).await {
eprintln!("Error handling interruption: {}", e);
@@ -1173,6 +1171,7 @@ impl Session {
}
}
_ = tokio::signal::ctrl_c() => {
cancel_token_clone.cancel();
drop(stream);
if let Err(e) = self.handle_interrupted_messages(true).await {
eprintln!("Error handling interruption: {}", e);

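The CLI hunks above wire a ctrl-c handler to the same CancellationToken that is now passed into reply(). For reference, a minimal self-contained sketch of that pattern, mirroring the ctrl_c select! arm in the Session hunks above (tokio with the signal and time features plus tokio-util; illustrative only, not code from this repository):

use std::time::Duration;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let cancel_token = CancellationToken::new();
    let cancel_on_ctrl_c = cancel_token.clone();

    // Cancel the token when the user presses ctrl-c.
    tokio::spawn(async move {
        if tokio::signal::ctrl_c().await.is_ok() {
            cancel_on_ctrl_c.cancel();
        }
    });

    // A long-running loop that exits promptly once the token is cancelled.
    loop {
        if cancel_token.is_cancelled() {
            println!("cancelled, shutting down");
            break;
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}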
@@ -247,7 +247,7 @@ pub unsafe extern "C" fn goose_agent_send_message(

// Block on the async call using our global runtime
let response = get_runtime().block_on(async {
let mut stream = match agent.reply(&messages, None).await {
let mut stream = match agent.reply(&messages, None, None).await {
Ok(stream) => stream,
Err(e) => return format!("Error getting reply from agent: {}", e),
};

@@ -42,6 +42,7 @@ axum-extra = "0.10.0"
utoipa = { version = "4.1", features = ["axum_extras", "chrono"] }
dirs = "6.0.0"
reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking", "multipart"], default-features = false }
tokio-util = "0.7.15"

[[bin]]
name = "goosed"

@@ -11,7 +11,7 @@ use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
agents::{AgentEvent, SessionConfig},
message::{push_message, Message, MessageContent},
message::{push_message, Message},
permission::permission_confirmation::PrincipalType,
};
use goose::{
@@ -19,7 +19,7 @@ use goose::{
session,
};
use mcp_core::{protocol::JsonRpcMessage, ToolResult};
use rmcp::model::{Content, Role};
use rmcp::model::Content;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -34,9 +34,10 @@ use std::{
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use utoipa::ToSchema;

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
struct ChatRequest {
messages: Vec<Message>,
session_id: Option<String>,
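ChatRequest now derives Serialize as well, which the reworked /reply test later in this diff uses to build a request body straight from the struct. A hedged fragment of that usage (field names are the ones this commit uses):

let body = serde_json::to_string(&ChatRequest {
    messages: vec![Message::user().with_text("test message")],
    session_id: Some("test-session".to_string()),
    session_working_dir: "test-working-dir".to_string(),
    scheduled_job_id: None,
})
.expect("ChatRequest should serialize");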
@@ -113,7 +114,7 @@ async fn stream_event(
tx.send(format!("data: {}\n\n", json)).await
}

async fn handler(
async fn reply_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ChatRequest>,
@@ -122,6 +123,7 @@ async fn handler(

let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);
let cancel_token = CancellationToken::new();

let messages = request.messages;
let session_working_dir = request.session_working_dir.clone();
@@ -130,65 +132,35 @@ async fn handler(
.session_id
.unwrap_or_else(session::generate_session_id);

tokio::spawn(async move {
let agent = state.get_agent().await;
let agent = match agent {
Ok(agent) => {
let provider = agent.provider().await;
match provider {
Ok(_) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No provider configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
}
}
let task_cancel = cancel_token.clone();
let task_tx = tx.clone();

std::mem::drop(tokio::spawn(async move {
let agent = match state.get_agent().await {
Ok(agent) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No agent configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
&task_tx,
)
.await;
return;
}
};

let provider = agent.provider().await;
let session_config = SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
retry_config: None,
};

let mut stream = match agent
.reply(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
retry_config: None,
}),
)
.reply(&messages, Some(session_config), Some(task_cancel.clone()))
.await
{
Ok(stream) => stream,
@@ -198,14 +170,7 @@ async fn handler(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
&task_tx,
)
.await;
return;
@@ -221,7 +186,7 @@ async fn handler(
|
||||
MessageEvent::Error {
|
||||
error: format!("Failed to get session path: {}", e),
|
||||
},
|
||||
&tx,
|
||||
&task_tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
@@ -231,222 +196,104 @@ async fn handler(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
response = timeout(Duration::from_millis(500), stream.next()) => {
|
||||
match response {
|
||||
Ok(Some(Ok(AgentEvent::Message(message)))) => {
|
||||
push_message(&mut all_messages, message.clone());
|
||||
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
_ = task_cancel.cancelled() => {
|
||||
tracing::info!("Agent task cancelled");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
|
||||
tracing::error!("Error sending model change through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::Notification{
|
||||
request_id: request_id.clone(),
|
||||
message: n,
|
||||
}, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
response = timeout(Duration::from_millis(500), stream.next()) => {
|
||||
match response {
|
||||
Ok(Some(Ok(AgentEvent::Message(message)))) => {
|
||||
push_message(&mut all_messages, message.clone());
|
||||
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await {
|
||||
tracing::error!("Error sending model change through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::Notification{
|
||||
request_id: request_id.clone(),
|
||||
message: n,
|
||||
}, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(Err(e))) => {
|
||||
tracing::error!("Error processing message: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
break;
|
||||
}
|
||||
Err(_) => { // Heartbeat, used to detect disconnected clients
|
||||
if tx.is_closed() {
|
||||
break;
|
||||
Ok(Some(Err(e))) => {
|
||||
tracing::error!("Error processing message: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
if tx.is_closed() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if all_messages.len() > saved_message_count {
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(
|
||||
&session_path,
|
||||
&all_messages,
|
||||
Some(provider),
|
||||
Some(PathBuf::from(&session_working_dir)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
if let Ok(provider) = agent.provider().await {
|
||||
let provider = Arc::clone(&provider);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(
|
||||
&session_path,
|
||||
&all_messages,
|
||||
Some(provider),
|
||||
Some(PathBuf::from(&session_working_dir)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "stop".to_string(),
|
||||
},
|
||||
&tx,
|
||||
&task_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
}));
|
||||
Ok(SseResponse::new(stream))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct AskRequest {
|
||||
prompt: String,
|
||||
session_id: Option<String>,
|
||||
session_working_dir: String,
|
||||
scheduled_job_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AskResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
async fn ask_handler(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<AskRequest>,
|
||||
) -> Result<Json<AskResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let session_working_dir = request.session_working_dir.clone();
|
||||
|
||||
let session_id = request
|
||||
.session_id
|
||||
.unwrap_or_else(session::generate_session_id);
|
||||
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
let provider = agent.provider().await;
|
||||
|
||||
let messages = vec![Message::user().with_text(request.prompt)];
|
||||
|
||||
let mut response_text = String::new();
|
||||
let mut stream = match agent
|
||||
.reply(
|
||||
&messages,
|
||||
Some(SessionConfig {
|
||||
id: session::Identifier::Name(session_id.clone()),
|
||||
working_dir: PathBuf::from(&session_working_dir),
|
||||
schedule_id: request.scheduled_job_id.clone(),
|
||||
execution_mode: None,
|
||||
max_turns: None,
|
||||
retry_config: None,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to start reply stream: {:?}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
};
|
||||
|
||||
let mut all_messages = messages.clone();
|
||||
let mut response_message = Message::assistant();
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
match response {
|
||||
Ok(AgentEvent::Message(message)) => {
|
||||
if message.role == Role::Assistant {
|
||||
for content in &message.content {
|
||||
if let MessageContent::Text(text) = content {
|
||||
response_text.push_str(&text.text);
|
||||
response_text.push('\n');
|
||||
}
|
||||
response_message.content.push(content.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AgentEvent::ModelChange { model, mode }) => {
|
||||
// Log model change for non-streaming
|
||||
tracing::info!("Model changed to {} in {} mode", model, mode);
|
||||
}
|
||||
Ok(AgentEvent::McpNotification(n)) => {
|
||||
// Handle notifications if needed
|
||||
tracing::info!("Received notification: {:?}", n);
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
tracing::error!("Error processing as_ai message: {}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !response_message.content.is_empty() {
|
||||
push_message(&mut all_messages, response_message);
|
||||
}
|
||||
|
||||
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get session path: {}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
};
|
||||
|
||||
let session_path_clone = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
let session_working_dir_clone = session_working_dir.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(
|
||||
&session_path_clone,
|
||||
&messages,
|
||||
Some(provider),
|
||||
Some(PathBuf::from(session_working_dir_clone)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(AskResponse {
|
||||
response: response_text.trim().to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub struct PermissionConfirmationRequest {
|
||||
id: String,
|
||||
@@ -509,7 +356,7 @@ struct ToolResultRequest {
|
||||
async fn submit_tool_result(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
raw: axum::extract::Json<serde_json::Value>,
|
||||
raw: Json<Value>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
@@ -540,8 +387,7 @@ async fn submit_tool_result(
|
||||
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/reply", post(handler))
|
||||
.route("/ask", post(ask_handler))
|
||||
.route("/reply", post(reply_handler))
|
||||
.route("/confirm", post(confirm_permission))
|
||||
.route("/tool_result", post(submit_tool_result))
|
||||
.with_state(state)
|
||||
@@ -571,10 +417,6 @@ mod tests {
|
||||
goose::providers::base::ProviderMetadata::empty()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
_system: &str,
|
||||
@@ -586,6 +428,10 @@ mod tests {
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
mod integration_tests {
|
||||
@@ -595,7 +441,7 @@ mod tests {
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ask_endpoint() {
|
||||
async fn test_reply_endpoint() {
|
||||
let mock_model_config = ModelConfig::new("test-model".to_string());
|
||||
let mock_provider = Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
@@ -603,24 +449,17 @@ mod tests {
|
||||
let agent = Agent::new();
|
||||
let _ = agent.update_provider(mock_provider).await;
|
||||
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
|
||||
let scheduler_path = goose::scheduler::get_default_scheduler_storage_path()
|
||||
.expect("Failed to get default scheduler storage path");
|
||||
let scheduler =
|
||||
goose::scheduler_factory::SchedulerFactory::create_legacy(scheduler_path)
|
||||
.await
|
||||
.unwrap();
|
||||
state.set_scheduler(scheduler).await;
|
||||
|
||||
let app = routes(state);
|
||||
|
||||
let request = Request::builder()
|
||||
.uri("/ask")
|
||||
.uri("/reply")
|
||||
.method("POST")
|
||||
.header("content-type", "application/json")
|
||||
.header("x-secret-key", "test-secret")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&AskRequest {
|
||||
prompt: "test prompt".to_string(),
|
||||
serde_json::to_string(&ChatRequest {
|
||||
messages: vec![Message::user().with_text("test message")],
|
||||
session_id: Some("test-session".to_string()),
|
||||
session_working_dir: "test-working-dir".to_string(),
|
||||
scheduled_job_id: None,
|
||||
|
||||
@@ -35,7 +35,7 @@ async fn main() {
|
||||
let messages = vec![Message::user()
|
||||
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
|
||||
|
||||
let mut stream = agent.reply(&messages, None).await.unwrap();
|
||||
let mut stream = agent.reply(&messages, None, None).await.unwrap();
|
||||
while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
|
||||
println!("{}", serde_json::to_string_pretty(&message).unwrap());
|
||||
println!("\n");
|
||||
|
||||
@@ -8,15 +8,32 @@ use futures::stream::BoxStream;
|
||||
use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME};
|
||||
use crate::agents::platform_tools::{
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::prompt_manager::PromptManager;
|
||||
use crate::agents::recipe_tools::dynamic_task_tools::{
|
||||
create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX,
|
||||
};
|
||||
use crate::agents::retry::{RetryManager, RetryResult};
|
||||
use crate::agents::router_tool_selector::{
|
||||
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
|
||||
};
|
||||
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
|
||||
use crate::agents::sub_recipe_manager::SubRecipeManager;
|
||||
use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{
|
||||
self, SUBAGENT_EXECUTE_TASK_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::subagent_execution_tool::tasks_manager::TasksManager;
|
||||
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
|
||||
use crate::agents::tool_vectordb::generate_table_id;
|
||||
use crate::agents::types::SessionConfig;
|
||||
use crate::agents::types::{FrontendTool, ToolResultReceiver};
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::{push_message, Message};
|
||||
use crate::permission::permission_judge::check_tool_permissions;
|
||||
@@ -26,31 +43,14 @@ use crate::providers::errors::ProviderError;
|
||||
use crate::recipe::{Author, Recipe, Response, Settings, SubRecipe};
|
||||
use crate::scheduler_trait::SchedulerTrait;
|
||||
use crate::tool_monitor::{ToolCall, ToolMonitor};
|
||||
use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult};
|
||||
use regex::Regex;
|
||||
use rmcp::model::{Content, Prompt};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, instrument};
|
||||
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
use crate::agents::platform_tools::{
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_MANAGE_SCHEDULE_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::prompt_manager::PromptManager;
|
||||
use crate::agents::retry::{RetryManager, RetryResult};
|
||||
use crate::agents::router_tool_selector::{
|
||||
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
|
||||
};
|
||||
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
|
||||
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
|
||||
use crate::agents::tool_vectordb::generate_table_id;
|
||||
use crate::agents::types::SessionConfig;
|
||||
use crate::agents::types::{FrontendTool, ToolResultReceiver};
|
||||
use mcp_core::{protocol::GetPromptResult, tool::Tool, ToolError, ToolResult};
|
||||
use rmcp::model::{Content, Prompt};
|
||||
|
||||
use super::final_output_tool::FinalOutputTool;
|
||||
use super::platform_tools;
|
||||
use super::router_tools;
|
||||
@@ -317,17 +317,17 @@ impl Agent {
|
||||
}
|
||||
|
||||
if tool_call.name == FINAL_OUTPUT_TOOL_NAME {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() {
|
||||
return if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() {
|
||||
let result = final_output_tool.execute_tool_call(tool_call.clone()).await;
|
||||
return (request_id, Ok(result));
|
||||
(request_id, Ok(result))
|
||||
} else {
|
||||
return (
|
||||
(
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(
|
||||
"Final output tool not defined".to_string(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
let extension_manager = self.extension_manager.read().await;
|
||||
@@ -406,10 +406,9 @@ impl Agent {
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(tool_call.clone())
|
||||
.await;
|
||||
match result {
|
||||
Ok(call_result) => call_result,
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
}
|
||||
result.unwrap_or_else(|e| {
|
||||
ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string())))
|
||||
})
|
||||
};
|
||||
|
||||
(
|
||||
@@ -719,42 +718,17 @@ impl Agent {
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
|
||||
cancel_token: Option<CancellationToken>,
|
||||
) -> Result<BoxStream<'_, Result<AgentEvent>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let initial_messages = messages.clone();
|
||||
let reply_span = tracing::Span::current();
|
||||
|
||||
self.reset_retry_attempts().await;
|
||||
|
||||
// Load settings from config
|
||||
let config = Config::global();
|
||||
|
||||
// Setup tools and prompt
|
||||
let (mut tools, mut toolshim_tools, mut system_prompt) =
|
||||
self.prepare_tools_and_prompt().await?;
|
||||
|
||||
// Get goose_mode from config, but override with execution_mode if provided in session config
|
||||
let mut goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
|
||||
// If this is a scheduled job with an execution_mode, override the goose_mode
|
||||
if let Some(session_config) = &session {
|
||||
if let Some(execution_mode) = &session_config.execution_mode {
|
||||
// Map "foreground" to "auto" and "background" to "chat"
|
||||
goose_mode = match execution_mode.as_str() {
|
||||
"foreground" => "auto".to_string(),
|
||||
"background" => "chat".to_string(),
|
||||
_ => goose_mode,
|
||||
};
|
||||
tracing::info!(
|
||||
"Using execution_mode '{}' which maps to goose_mode '{}'",
|
||||
execution_mode,
|
||||
goose_mode
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let (tools_with_readonly_annotation, tools_without_annotation) =
|
||||
Self::categorize_tools_by_annotation(&tools);
|
||||
let goose_mode = Self::determine_goose_mode(session.as_ref(), config);
|
||||
|
||||
if let Some(content) = messages
|
||||
.last()
|
||||
@@ -775,12 +749,16 @@ impl Agent {
|
||||
});
|
||||
|
||||
loop {
|
||||
// Check for final output before incrementing turns or checking max_turns
|
||||
// This ensures that if we have a final output ready, we return it immediately
|
||||
// without being blocked by the max_turns limit - this is needed for streaming cases
|
||||
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_some() {
|
||||
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
|
||||
let final_event = AgentEvent::Message(
|
||||
Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()),
|
||||
);
|
||||
yield final_event;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -793,23 +771,19 @@ impl Agent {
|
||||
break;
|
||||
}
|
||||
|
||||
// Check for MCP notifications from subagents
|
||||
// Handle MCP notifications from subagents
|
||||
let mcp_notifications = self.get_mcp_notifications().await;
|
||||
for notification in mcp_notifications {
|
||||
// Extract subagent info from the notification data
|
||||
if let JsonRpcMessage::Notification(ref notif) = notification {
|
||||
if let Some(params) = ¬if.params {
|
||||
if let Some(data) = params.get("data") {
|
||||
if let (Some(subagent_id), Some(_message)) = (
|
||||
data.get("subagent_id").and_then(|v| v.as_str()),
|
||||
data.get("message").and_then(|v| v.as_str())
|
||||
) {
|
||||
// Emit as McpNotification event
|
||||
yield AgentEvent::McpNotification((
|
||||
subagent_id.to_string(),
|
||||
notification.clone()
|
||||
));
|
||||
}
|
||||
if let JsonRpcMessage::Notification(notif) = ¬ification {
|
||||
if let Some(data) = notif.params.as_ref().and_then(|p| p.get("data")) {
|
||||
if let (Some(subagent_id), Some(_message)) = (
|
||||
data.get("subagent_id").and_then(|v| v.as_str()),
|
||||
data.get("message").and_then(|v| v.as_str()),
|
||||
) {
|
||||
yield AgentEvent::McpNotification((
|
||||
subagent_id.to_string(),
|
||||
notification.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -821,17 +795,24 @@ impl Agent {
|
||||
&messages,
|
||||
&tools,
|
||||
&toolshim_tools,
|
||||
).await?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut added_message = false;
|
||||
let mut messages_to_add = Vec::new();
|
||||
let mut tools_updated = false;
|
||||
|
||||
while let Some(next) = stream.next().await {
|
||||
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
|
||||
match next {
|
||||
Ok((response, usage)) => {
|
||||
// Emit model change event if provider is lead-worker
|
||||
let provider = self.provider().await?;
|
||||
if let Some(lead_worker) = provider.as_lead_worker() {
|
||||
if let Some(ref usage) = usage {
|
||||
// The actual model used is in the usage
|
||||
let active_model = usage.model.clone();
|
||||
let (lead_model, worker_model) = lead_worker.get_model_info();
|
||||
let mode = if active_model == lead_model {
|
||||
@@ -849,43 +830,44 @@ impl Agent {
|
||||
}
|
||||
}
|
||||
|
||||
// record usage for the session in the session file
|
||||
if let Some(session_config) = session.clone() {
|
||||
// Record usage for the session
|
||||
if let Some(ref session_config) = &session {
|
||||
if let Some(ref usage) = usage {
|
||||
Self::update_session_metrics(session_config, usage, messages.len()).await?;
|
||||
Self::update_session_metrics(session_config, usage, messages.len())
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(response) = response {
|
||||
// categorize the type of requests we need to handle
|
||||
let (frontend_requests,
|
||||
remaining_requests,
|
||||
filtered_response) =
|
||||
let (tools_with_readonly_annotation, tools_without_annotation) =
|
||||
Self::categorize_tools_by_annotation(&tools);
|
||||
|
||||
// Categorize tool requests
|
||||
let (frontend_requests, remaining_requests, filtered_response) =
|
||||
self.categorize_tool_requests(&response).await;
|
||||
|
||||
// Record tool calls in the router selector
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
// Record frontend tool calls
|
||||
for request in &frontend_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record frontend tool call: {}", e);
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await
|
||||
{
|
||||
error!("Failed to record frontend tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Record remaining tool calls
|
||||
for request in &remaining_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record tool call: {}", e);
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await
|
||||
{
|
||||
error!("Failed to record tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||
@@ -893,23 +875,17 @@ impl Agent {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Process tool requests depending on frontend tools and then goose_mode
|
||||
let message_tool_response = Arc::new(Mutex::new(Message::user()));
|
||||
|
||||
// First handle any frontend tool requests
|
||||
let mut frontend_tool_stream = self.handle_frontend_tool_requests(
|
||||
&frontend_requests,
|
||||
message_tool_response.clone()
|
||||
message_tool_response.clone(),
|
||||
);
|
||||
|
||||
// we have a stream of frontend tools to handle, inside the stream
|
||||
// execution is yeield back to this reply loop, and is of the same Message
|
||||
// type, so we can yield that back up to be handled
|
||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
if mode.as_str() == "chat" {
|
||||
// Skip all tool calls in chat mode
|
||||
@@ -921,36 +897,42 @@ impl Agent {
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
|
||||
// What remains is handling the remaining tool requests (enable extension,
|
||||
// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
let (permission_check_result, enable_extension_request_ids) = check_tool_permissions(
|
||||
&remaining_requests,
|
||||
&mode,
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
&mut permission_manager,
|
||||
self.provider().await?).await;
|
||||
let (permission_check_result, enable_extension_request_ids) =
|
||||
check_tool_permissions(
|
||||
&remaining_requests,
|
||||
&mode,
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
&mut permission_manager,
|
||||
self.provider().await?,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
|
||||
|
||||
// Skip the confirmation for approved tools
|
||||
// Handle pre-approved and read-only tools
|
||||
for request in &permission_check_result.approved {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
|
||||
let (req_id, tool_result) = self
|
||||
.dispatch_tool_call(tool_call, request.id.clone())
|
||||
.await;
|
||||
|
||||
tool_futures.push((req_id, match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
}));
|
||||
tool_futures.push((
|
||||
req_id,
|
||||
match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result
|
||||
.notification_stream
|
||||
.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -962,29 +944,22 @@ impl Agent {
|
||||
);
|
||||
}
|
||||
|
||||
// We need interior mutability in handle_approval_tool_requests
|
||||
let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
|
||||
|
||||
// Process tools requiring approval (enable extension, regular tool calls)
|
||||
// Process tools requiring approval
|
||||
let mut tool_approval_stream = self.handle_approval_tool_requests(
|
||||
&permission_check_result.needs_approval,
|
||||
tool_futures_arc.clone(),
|
||||
&mut permission_manager,
|
||||
message_tool_response.clone()
|
||||
message_tool_response.clone(),
|
||||
);
|
||||
|
||||
// We have a stream of tool_approval_requests to handle
|
||||
// Execution is yielded back to this reply loop, and is of the same Message
|
||||
// type, so we can yield the Message back up to be handled and grab any
|
||||
// confirmations or denials
|
||||
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
tool_futures = {
|
||||
// Lock the mutex asynchronously
|
||||
let mut futures_lock = tool_futures_arc.lock().await;
|
||||
// Drain the vector and collect into a new Vec
|
||||
futures_lock.drain(..).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
@@ -996,27 +971,33 @@ impl Agent {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = stream::select_all(with_id);
|
||||
|
||||
let mut all_install_successful = true;
|
||||
|
||||
while let Some((request_id, item)) = combined.next().await {
|
||||
if cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
match item {
|
||||
ToolStreamItem::Result(output) => {
|
||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||
if enable_extension_request_ids.contains(&request_id)
|
||||
&& output.is_err()
|
||||
{
|
||||
all_install_successful = false;
|
||||
}
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(request_id, output);
|
||||
},
|
||||
*response =
|
||||
response.clone().with_tool_response(request_id, output);
|
||||
}
|
||||
ToolStreamItem::Message(msg) => {
|
||||
yield AgentEvent::McpNotification((request_id, msg))
|
||||
yield AgentEvent::McpNotification((
|
||||
request_id, msg,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update system prompt and tools if installations were successful
|
||||
if all_install_successful {
|
||||
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
|
||||
tools_updated = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1024,63 +1005,39 @@ impl Agent {
|
||||
yield AgentEvent::Message(final_message_tool_resp.clone());
|
||||
|
||||
added_message = true;
|
||||
push_message(&mut messages, response);
|
||||
push_message(&mut messages, final_message_tool_resp);
|
||||
|
||||
// Check for MCP notifications from subagents again before next iteration
|
||||
// Note: These are already handled as McpNotification events above,
|
||||
// so we don't need to convert them to assistant messages here.
|
||||
// This was causing duplicate plain-text notifications.
|
||||
// let mcp_notifications = self.get_mcp_notifications().await;
|
||||
// for notification in mcp_notifications {
|
||||
// // Extract subagent info from the notification data for assistant messages
|
||||
// if let JsonRpcMessage::Notification(ref notif) = notification {
|
||||
// if let Some(params) = ¬if.params {
|
||||
// if let Some(data) = params.get("data") {
|
||||
// if let (Some(subagent_id), Some(message)) = (
|
||||
// data.get("subagent_id").and_then(|v| v.as_str()),
|
||||
// data.get("message").and_then(|v| v.as_str())
|
||||
// ) {
|
||||
// yield AgentEvent::Message(
|
||||
// Message::assistant().with_text(
|
||||
// format!("Subagent {}: {}", subagent_id, message)
|
||||
// )
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
push_message(&mut messages_to_add, response);
|
||||
push_message(&mut messages_to_add, final_message_tool_resp);
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
// At this point, the last message should be a user message
|
||||
// because call to provider led to context length exceeded error
|
||||
// Immediately yield a special message and break
|
||||
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
));
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
));
|
||||
break;
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
|
||||
yield AgentEvent::Message(Message::assistant().with_text(
|
||||
format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")
|
||||
));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if tools_updated {
|
||||
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
|
||||
}
|
||||
if !added_message {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
messages.push(message.clone());
|
||||
messages_to_add.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
continue;
|
||||
continue
|
||||
} else {
|
||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||
messages.push(message.clone());
|
||||
messages_to_add.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
}
|
||||
}
|
||||
@@ -1099,16 +1056,28 @@ impl Agent {
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
// Yield control back to the scheduler to prevent blocking
|
||||
messages.extend(messages_to_add);
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn determine_goose_mode(session: Option<&SessionConfig>, config: &Config) -> String {
|
||||
let mode = session.and_then(|s| s.execution_mode.as_deref());
|
||||
|
||||
match mode {
|
||||
Some("foreground") => "auto".to_string(),
|
||||
Some("background") => "chat".to_string(),
|
||||
_ => config
|
||||
.get_param("GOOSE_MODE")
|
||||
.unwrap_or_else(|_| "auto".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extend the system prompt with one line of additional instruction
|
||||
pub async fn extend_system_prompt(&self, instruction: String) {
|
||||
let mut prompt_manager = self.prompt_manager.lock().await;
|
||||
@@ -1127,7 +1096,6 @@ impl Agent {
|
||||
notifications
|
||||
}
|
||||
|
||||
/// Update the provider
|
||||
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
let mut current_provider = self.provider.lock().await;
|
||||
*current_provider = Some(provider.clone());
|
||||
@@ -1191,10 +1159,9 @@ impl Agent {
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
error!(
|
||||
"Failed to index tools for extension {}: {}",
|
||||
extension_name,
|
||||
e
|
||||
extension_name, e
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1243,7 +1210,7 @@ impl Agent {
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
|
||||
pub async fn get_plan_prompt(&self) -> anyhow::Result<String> {
|
||||
pub async fn get_plan_prompt(&self) -> Result<String> {
|
||||
let extension_manager = self.extension_manager.read().await;
|
||||
let tools = extension_manager.get_prefixed_tools(None).await?;
|
||||
let tools_info = tools
|
||||
@@ -1265,7 +1232,7 @@ impl Agent {
|
||||
|
||||
pub async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1356,7 +1323,7 @@ impl Agent {
|
||||
let activities_text = activities_text.trim();
|
||||
|
||||
// Regex to remove bullet markers or numbers with an optional dot.
|
||||
let bullet_re = Regex::new(r"^[•\-\*\d]+\.?\s*").expect("Invalid regex");
|
||||
let bullet_re = Regex::new(r"^[•\-*\d]+\.?\s*").expect("Invalid regex");
|
||||
|
||||
// Process each line in the activities section.
|
||||
let activities: Vec<String> = activities_text
|
||||
@@ -1383,7 +1350,7 @@ impl Agent {
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
// Ideally we'd get the name of the provider we are using from the provider itself
|
||||
// Ideally we'd get the name of the provider we are using from the provider itself,
|
||||
// but it doesn't know and the plumbing looks complicated.
|
||||
let config = Config::global();
|
||||
let provider_name: String = config
|
||||
@@ -1443,7 +1410,7 @@ mod tests {
|
||||
|
||||
let prompt_manager = agent.prompt_manager.lock().await;
|
||||
let system_prompt =
|
||||
prompt_manager.build_system_prompt(vec![], None, serde_json::Value::Null, None, None);
|
||||
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, None);
|
||||
|
||||
let final_output_tool_ref = agent.final_output_tool.lock().await;
|
||||
let final_output_tool_system_prompt =
|
||||
|
||||
@@ -275,10 +275,9 @@ impl Agent {
|
||||
(frontend_requests, other_requests, filtered_message)
|
||||
}
|
||||
|
||||
/// Update session metrics after a response
|
||||
pub(crate) async fn update_session_metrics(
|
||||
session_config: crate::agents::types::SessionConfig,
|
||||
usage: &crate::providers::base::ProviderUsage,
|
||||
session_config: &crate::agents::types::SessionConfig,
|
||||
usage: &ProviderUsage,
|
||||
messages_length: usize,
|
||||
) -> Result<()> {
|
||||
let session_file_path = match session::storage::get_path(session_config.id.clone()) {
|
||||
|
||||
@@ -1208,7 +1208,7 @@ async fn run_scheduled_job_internal(
|
||||
};
|
||||
|
||||
match agent
|
||||
.reply(&all_session_messages, Some(session_config.clone()))
|
||||
.reply(&all_session_messages, Some(session_config.clone()), None)
|
||||
.await
|
||||
{
|
||||
Ok(mut stream) => {
|
||||
|
||||
@@ -129,7 +129,7 @@ async fn run_truncate_test(
|
||||
),
|
||||
];
|
||||
|
||||
let reply_stream = agent.reply(&messages, None).await?;
|
||||
let reply_stream = agent.reply(&messages, None, None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
@@ -619,7 +619,7 @@ mod final_output_tool_tests {
|
||||
);
|
||||
|
||||
// Simulate the reply stream continuing after the final output tool call.
|
||||
let reply_stream = agent.reply(&vec![], None).await?;
|
||||
let reply_stream = agent.reply(&vec![], None, None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
@@ -716,7 +716,7 @@ mod final_output_tool_tests {
|
||||
agent.add_final_output_tool(response).await;
|
||||
|
||||
// Simulate the reply stream being called.
|
||||
let reply_stream = agent.reply(&vec![], None).await?;
|
||||
let reply_stream = agent.reply(&vec![], None, None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
@@ -850,7 +850,9 @@ mod retry_tests {
|
||||
|
||||
let initial_messages = vec![Message::user().with_text("Complete this task")];
|
||||
|
||||
let reply_stream = agent.reply(&initial_messages, Some(session_config)).await?;
|
||||
let reply_stream = agent
|
||||
.reply(&initial_messages, Some(session_config), None)
|
||||
.await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
@@ -1013,7 +1015,7 @@ mod max_turns_tests {
|
||||
};
|
||||
let messages = vec![Message::user().with_text("Hello")];
|
||||
|
||||
let reply_stream = agent.reply(&messages, Some(session_config)).await?;
|
||||
let reply_stream = agent.reply(&messages, Some(session_config), None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
|
||||
@@ -2332,17 +2332,16 @@
|
||||
}
|
||||
},
|
||||
"ResourceContents": {
|
||||
"oneOf": [
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"uri",
|
||||
"text"
|
||||
"text",
|
||||
"uri"
|
||||
],
|
||||
"properties": {
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
"type": "string"
|
||||
},
|
||||
"text": {
|
||||
"type": "string"
|
||||
@@ -2355,16 +2354,15 @@
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"uri",
|
||||
"blob"
|
||||
"blob",
|
||||
"uri"
|
||||
],
|
||||
"properties": {
|
||||
"blob": {
|
||||
"type": "string"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
"type": "string"
|
||||
},
|
||||
"uri": {
|
||||
"type": "string"
|
||||
|
||||
@@ -1174,7 +1174,7 @@ export default function App() {
|
||||
}, []);
|
||||
|
||||
const config = window.electron.getConfig();
|
||||
const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING === true ? false : true;
|
||||
const STRICT_ALLOWLIST = config.GOOSE_ALLOWLIST_WARNING !== true;
|
||||
|
||||
useEffect(() => {
|
||||
console.log('Setting up extension handler');
|
||||
|
||||
@@ -463,12 +463,12 @@ export type RedactedThinkingContent = {
|
||||
};
|
||||
|
||||
export type ResourceContents = {
|
||||
mime_type?: string | null;
|
||||
mime_type?: string;
|
||||
text: string;
|
||||
uri: string;
|
||||
} | {
|
||||
blob: string;
|
||||
mime_type?: string | null;
|
||||
mime_type?: string;
|
||||
uri: string;
|
||||
};
|
||||
|
||||
|
||||
@@ -537,7 +537,7 @@ function BaseChatContent({
|
||||
recipeDetails={{
|
||||
title: recipeConfig?.title,
|
||||
description: recipeConfig?.description,
|
||||
instructions: recipeConfig?.instructions,
|
||||
instructions: recipeConfig?.instructions || undefined,
|
||||
}}
|
||||
/>
|
||||
|
||||
|
||||
@@ -126,10 +126,10 @@ const notificationToProgress = (notification: NotificationEvent): Progress =>
|
||||
const getExtensionTooltip = (toolCallName: string): string | null => {
|
||||
const lastIndex = toolCallName.lastIndexOf('__');
|
||||
if (lastIndex === -1) return null;
|
||||
|
||||
|
||||
const extensionName = toolCallName.substring(0, lastIndex);
|
||||
if (!extensionName) return null;
|
||||
|
||||
|
||||
return `${extensionName} extension`;
|
||||
};
|
||||
|
||||
@@ -377,7 +377,7 @@ function ToolCallView({
|
||||
// This ensures any MCP tool works without explicit handling
|
||||
const toolDisplayName = snakeToTitleCase(toolName);
|
||||
const entries = Object.entries(args);
|
||||
|
||||
|
||||
if (entries.length === 0) {
|
||||
return `${toolDisplayName}`;
|
||||
}
|
||||
@@ -413,7 +413,7 @@ function ToolCallView({
|
||||
};
|
||||
|
||||
const toolLabel = (
|
||||
<span className={cn("ml-2", extensionTooltip && "cursor-pointer hover:opacity-80")}>
|
||||
<span className={cn('ml-2', extensionTooltip && 'cursor-pointer hover:opacity-80')}>
|
||||
{getToolLabelContent()}
|
||||
</span>
|
||||
);
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import type { OpenDialogReturnValue } from 'electron';
|
||||
import {
|
||||
app,
|
||||
session,
|
||||
App,
|
||||
BrowserWindow,
|
||||
dialog,
|
||||
Event,
|
||||
globalShortcut,
|
||||
ipcMain,
|
||||
Menu,
|
||||
MenuItem,
|
||||
Notification,
|
||||
powerSaveBlocker,
|
||||
Tray,
|
||||
App,
|
||||
globalShortcut,
|
||||
session,
|
||||
shell,
|
||||
Event,
|
||||
Tray,
|
||||
} from 'electron';
|
||||
import type { OpenDialogReturnValue } from 'electron';
|
||||
import { Buffer } from 'node:buffer';
|
||||
import fs from 'node:fs/promises';
|
||||
import fsSync from 'node:fs';
|
||||
@@ -33,20 +33,20 @@ import {
|
||||
EnvToggles,
|
||||
loadSettings,
|
||||
saveSettings,
|
||||
SchedulingEngine,
|
||||
updateEnvironmentVariables,
|
||||
updateSchedulingEngineEnvironment,
|
||||
SchedulingEngine,
|
||||
} from './utils/settings';
|
||||
import * as crypto from 'crypto';
|
||||
// import electron from "electron";
|
||||
import * as yaml from 'yaml';
|
||||
import windowStateKeeper from 'electron-window-state';
|
||||
import {
|
||||
setupAutoUpdater,
|
||||
getUpdateAvailable,
|
||||
registerUpdateIpcHandlers,
|
||||
setTrayRef,
|
||||
setupAutoUpdater,
|
||||
updateTrayMenu,
|
||||
getUpdateAvailable,
|
||||
} from './utils/autoUpdater';
|
||||
import { UPDATES_ENABLED } from './updates';
|
||||
import { Recipe } from './recipe';
|
||||
@@ -589,7 +589,7 @@ const createChat = async (
|
||||
titleBarStyle: process.platform === 'darwin' ? 'hidden' : 'default',
|
||||
trafficLightPosition: process.platform === 'darwin' ? { x: 20, y: 16 } : undefined,
|
||||
vibrancy: process.platform === 'darwin' ? 'window' : undefined,
|
||||
frame: process.platform === 'darwin' ? false : true,
|
||||
frame: process.platform !== 'darwin',
|
||||
x: mainWindowState.x,
|
||||
y: mainWindowState.y,
|
||||
width: mainWindowState.width,
|
||||
@@ -1542,15 +1542,8 @@ ipcMain.handle('show-message-box', async (_event, options) => {
return result;
});

// Handle allowed extensions list fetching
ipcMain.handle('get-allowed-extensions', async () => {
try {
const allowList = await getAllowList();
return allowList;
} catch (error) {
console.error('Error fetching allowed extensions:', error);
throw error;
}
return await getAllowList();
});

const createNewWindow = async (app: App, dir?: string | null) => {
@@ -2068,57 +2061,33 @@ app.whenReady().then(async () => {
});
});

/**
* Fetches the allowed extensions list from the remote YAML file if GOOSE_ALLOWLIST is set.
* If the ALLOWLIST is not set, any are allowed. If one is set, it will warn if the deeplink
* doesn't match a command from the list.
* If it fails to load, then it will return an empty list.
* If the format is incorrect, it will return an empty list.
* Format of yaml is:
*
```yaml
extensions:
  - id: slack
    command: uvx mcp_slack
  - id: knowledge_graph_memory
    command: npx -y @modelcontextprotocol/server-memory
```
*
* @returns A promise that resolves to an array of extension commands that are allowed.
*/
async function getAllowList(): Promise<string[]> {
if (!process.env.GOOSE_ALLOWLIST) {
return [];
}

try {
// Fetch the YAML file
const response = await fetch(process.env.GOOSE_ALLOWLIST);
const response = await fetch(process.env.GOOSE_ALLOWLIST);

if (!response.ok) {
throw new Error(
`Failed to fetch allowed extensions: ${response.status} ${response.statusText}`
);
}
if (!response.ok) {
throw new Error(
`Failed to fetch allowed extensions: ${response.status} ${response.statusText}`
);
}

// Parse the YAML content
const yamlContent = await response.text();
const parsedYaml = yaml.parse(yamlContent);
// Parse the YAML content
const yamlContent = await response.text();
const parsedYaml = yaml.parse(yamlContent);

// Extract the commands from the extensions array
if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) {
const commands = parsedYaml.extensions.map(
(ext: { id: string; command: string }) => ext.command
);
console.log(`Fetched ${commands.length} allowed extension commands`);
return commands;
} else {
console.error('Invalid YAML structure:', parsedYaml);
return [];
}
} catch (error) {
console.error('Error in getAllowList:', error);
throw error;
// Extract the commands from the extensions array
if (parsedYaml && parsedYaml.extensions && Array.isArray(parsedYaml.extensions)) {
const commands = parsedYaml.extensions.map(
(ext: { id: string; command: string }) => ext.command
);
console.log(`Fetched ${commands.length} allowed extension commands`);
return commands;
} else {
console.error('Invalid YAML structure:', parsedYaml);
return [];
}
}

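For orientation only, a short sketch of how the list returned by getAllowList could be used to vet a deeplink's extension command, per the behaviour described in the doc comment above. isCommandAllowed is a hypothetical helper and the exact-match comparison is an assumption; this diff does not show the call site.

```typescript
// Hypothetical sketch (not part of this commit): vetting a deeplink command
// against the allowlist returned by getAllowList().
async function isCommandAllowed(command: string): Promise<boolean> {
  const allowList = await getAllowList();
  // An empty list means no allowlist is configured (or it could not be loaded),
  // so every command is treated as allowed.
  if (allowList.length === 0) return true;
  return allowList.includes(command);
}
```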
@@ -1,198 +0,0 @@
import { getApiUrl, getSecretKey } from '../config';
import { safeJsonParse } from './jsonUtils';

const getQuestionClassifierPrompt = (messageContent: string): string => `
You are a simple classifier that takes content and decides if it is asking for input
from a person before continuing if there is more to do, or not. These are questions
on if a course of action should proceed or not, or approval is needed. If it is CLEARLY a
question asking if it is ok to proceed or make a choice or some input is required to proceed, then, and ONLY THEN, return QUESTION, otherwise READY if not 97% sure.

### Example message content that is classified as READY:
anything else I can do?
Could you please run the application and verify that the headlines are now visible in dark mode? You can use npm start.
Would you like me to make any adjustments to the formatting of these multiline strings?
Would you like me to show you how to ... (do something)?
Listing window titles... Is there anything specific you'd like help with using these tools?
Would you like me to demonstrate any specific capability or help you with a particular task?
Would you like me to run any tests?
Would you like me to make any adjustments or would you like to test?
Would you like me to dive deeper into any aspect?
Would you like me to make any other adjustments to this implementation?
Would you like any further information or assistance?
Would you like me to make any changes?
Would you like me to make any adjustments to this implementation?
Would you like me to show you how to…
What would you like to do next?

### Examples that are QUESTIONS:
Should I go ahead and make the changes?
Should I Go ahead with this plan?
Should I focus on X or Y?
Provide me with the name of the package and version you would like to install.

### Message Content:
${messageContent}

You must provide a response strictly limited to one of the following two words:
QUESTION, READY. No other words, phrases, or explanations are allowed.

Response:`;

const getOptionsClassifierPrompt = (messageContent: string): string => `
You are a simple classifier that takes content and decides if it is a list of options
or plans to choose from, or not a list of options to choose from. It is IMPORTANT
that you really know this is a choice, not just numbered steps. If it is a list
of options and you are 95% sure, return OPTIONS, otherwise return NO.

### Example (text -> response):
Would you like me to proceed with creating this file? Please let me know if you want any changes before I write it. -> NO
Here are some options for you to choose from: -> OPTIONS
which one do you want to choose? -> OPTIONS
Would you like me to dive deeper into any aspects of these components? -> NO
Should I focus on X or Y? -> OPTIONS

### Message Content:
${messageContent}

You must provide a response strictly limited to one of the following two words:
OPTIONS, NO. No other words, phrases, or explanations are allowed.

Response:`;

const getOptionsFormatterPrompt = (messageContent: string): string => `
If the content is a list of distinct options or plans of action to choose from, and
not just a list of things, but clearly a list of things to choose one from, taking
into account the Message Content alone, format it as a JSON array of objects of the
form optionTitle:string, optionDescription:string (markdown).

If it is not a list of options or plans to choose from, then return an empty list.

### Message Content:
${messageContent}

You must provide a response strictly as json in the format described. No other
words, phrases, or explanations are allowed.

Response:`;

const getFormPrompt = (messageContent: string): string => `
When you see a request for several pieces of information, then provide a well-formed JSON object like the one shown below.
The response will have:
* a title, description,
* a list of fields, each field will have a label, type, name, placeholder, and required (boolean).
(type is either text or textarea only).
If it is not clearly requesting several pieces of information, just return an empty object.
If the task could be confirmed without more information, return an empty object.

### Example Message:
I'll help you scaffold out a Python package. To create a well-structured Python package, I'll need to know a few key pieces of information:

Package name - What would you like to call your package? (This should be a valid Python package name - lowercase, no spaces, typically using underscores for separators if needed)

Brief description - What is the main purpose of the package? This helps in setting up appropriate documentation and structure.

Initial modules - Do you have specific functionality in mind that should be split into different modules?

Python version - Which Python version(s) do you want to support?

Dependencies - Are there any known external packages you'll need?

### Example JSON Response:
{
"title": "Python Package Scaffolding Form",
"description": "Provide the details below to scaffold a well-structured Python package.",
"fields": [
{
"label": "Package Name",
"type": "text",
"name": "package_name",
"placeholder": "Enter the package name (lowercase, no spaces, use underscores if needed)",
"required": true
},
{
"label": "Brief Description",
"type": "textarea",
"name": "brief_description",
"placeholder": "Enter a brief description of the package's purpose",
"required": true
},
{
"label": "Initial Modules",
"type": "textarea",
"name": "initial_modules",
"placeholder": "List the specific functionalities or modules (optional)",
"required": false
},
{
"label": "Python Version(s)",
"type": "text",
"name": "python_versions",
"placeholder": "Enter the Python version(s) to support (e.g., 3.8, 3.9, 3.10)",
"required": true
},
{
"label": "Dependencies",
"type": "textarea",
"name": "dependencies",
"placeholder": "List any known external packages you'll need (optional)",
"required": false
}
]
}

### Message Content:
${messageContent}

You must provide a response strictly as json in the format described. No other
words, phrases, or explanations are allowed.

Response:`;

/**
* Core function to ask the AI a single question and get a response
* @param prompt The prompt to send to the AI
* @returns Promise<string> The AI's response
*/
export async function ask(prompt: string): Promise<string> {
const response = await fetch(getApiUrl('/ask'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({ prompt }),
});

if (!response.ok) {
throw new Error('Failed to get response');
}

const data = await safeJsonParse<{ response: string }>(response, 'Failed to get AI response');
return data.response;
}

/**
* Utility to ask the LLM multiple questions to clarify without wider context.
* @param messageContent The content to analyze
* @returns Promise<string[]> Array of responses from the AI for each classifier
*/
export async function askAi(messageContent: string): Promise<string[]> {
// First, check the question classifier
const questionClassification = await ask(getQuestionClassifierPrompt(messageContent));

// If READY, return early with empty responses for options
if (questionClassification === 'READY') {
return [questionClassification, 'NO', '[]', '{}'];
}

// Otherwise, proceed with all classifiers in parallel
const prompts = [
Promise.resolve(questionClassification), // Reuse the result we already have
ask(getOptionsClassifierPrompt(messageContent)),
ask(getOptionsFormatterPrompt(messageContent)),
ask(getFormPrompt(messageContent)),
];

return Promise.all(prompts);
}
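To make the shape of the deleted helper's result concrete, here is a hypothetical call site (not from the repository; the import path and sample message are assumptions) showing the fixed order of the four classifier responses:

```typescript
// Hypothetical call site for the deleted helper; the import path and sample
// message are assumptions, not taken from the repository.
import { askAi } from './utils/askAI';

const messageContent = 'Should I focus on X or Y?'; // borrowed from the prompt examples above

// askAi resolves to [question classification, options classification, options JSON, form JSON].
const [questionClassification, optionsClassification, optionsJson] = await askAi(messageContent);

if (questionClassification === 'QUESTION' && optionsClassification === 'OPTIONS') {
  // optionsJson should be a JSON array of { optionTitle, optionDescription } objects.
  console.log(JSON.parse(optionsJson));
}
```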