From f7f25402873219ce8d2aaf92e9b56ff959f4e964 Mon Sep 17 00:00:00 2001 From: Yingjie He Date: Fri, 28 Feb 2025 17:00:41 -0800 Subject: [PATCH] feat: support goose mode in UI (#1434) Co-authored-by: Lily Delalande --- crates/goose-server/src/routes/agent.rs | 4 +- crates/goose-server/src/routes/configs.rs | 46 ++++++++++++- crates/goose-server/src/routes/extension.rs | 4 +- crates/goose-server/src/routes/reply.rs | 40 ++++++++++-- crates/goose-server/src/state.rs | 6 +- crates/goose/src/agents/truncate.rs | 4 +- ui/desktop/package-lock.json | 1 + ui/desktop/package.json | 1 + ui/desktop/src/components/ChatView.tsx | 20 +++++- ui/desktop/src/components/GooseMessage.tsx | 15 ++++- .../src/components/ToolCallConfirmation.tsx | 39 +++++++++++ .../src/components/ToolCallWithResponse.tsx | 3 + .../src/components/settings/SettingsView.tsx | 65 ++++++++++++++++++- .../settings/basic/ModeSelection.tsx | 50 ++++++++++++++ ui/desktop/src/types/message.ts | 16 +++++ ui/desktop/src/utils/toolConfirm.ts | 25 +++++++ 16 files changed, 320 insertions(+), 19 deletions(-) create mode 100644 ui/desktop/src/components/ToolCallConfirmation.tsx create mode 100644 ui/desktop/src/components/settings/basic/ModeSelection.tsx create mode 100644 ui/desktop/src/utils/toolConfirm.ts diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 204fd82a..f46b548f 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -86,7 +86,7 @@ async fn extend_prompt( return Err(StatusCode::UNAUTHORIZED); } - let mut agent = state.agent.lock().await; + let mut agent = state.agent.write().await; if let Some(ref mut agent) = *agent { agent.extend_system_prompt(payload.extension).await; Ok(Json(ExtendPromptResponse { success: true })) @@ -134,7 +134,7 @@ async fn create_agent( let new_agent = AgentFactory::create(&version, provider).expect("Failed to create agent"); - let mut agent = state.agent.lock().await; + let mut agent = state.agent.write().await; *agent = Some(new_agent); Ok(Json(CreateAgentResponse { version })) diff --git a/crates/goose-server/src/routes/configs.rs b/crates/goose-server/src/routes/configs.rs index 05e599c4..ca557787 100644 --- a/crates/goose-server/src/routes/configs.rs +++ b/crates/goose-server/src/routes/configs.rs @@ -1,5 +1,9 @@ use crate::state::AppState; -use axum::{extract::State, routing::delete, routing::post, Json, Router}; +use axum::{ + extract::{Query, State}, + routing::{delete, get, post}, + Json, Router, +}; use goose::config::Config; use http::{HeaderMap, StatusCode}; use once_cell::sync::Lazy; @@ -140,6 +144,45 @@ async fn check_provider_configs( Ok(Json(response)) } +#[derive(Deserialize)] +pub struct GetConfigQuery { + key: String, +} + +#[derive(Serialize)] +pub struct GetConfigResponse { + value: Option, +} + +pub async fn get_config( + State(state): State, + headers: HeaderMap, + Query(query): Query, +) -> Result, StatusCode> { + // Verify secret key + let secret_key = headers + .get("X-Secret-Key") + .and_then(|value| value.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if secret_key != state.secret_key { + return Err(StatusCode::UNAUTHORIZED); + } + + // Fetch the configuration value. Right now we don't allow get a secret. + let config = Config::global(); + let value = if let Ok(config_value) = config.get::(&query.key) { + Some(config_value) + } else if let Ok(env_value) = std::env::var(&query.key) { + Some(env_value) + } else { + None + }; + + // Return the value + Ok(Json(GetConfigResponse { value })) +} + #[derive(Deserialize)] #[serde(rename_all = "camelCase")] struct DeleteConfigRequest { @@ -178,6 +221,7 @@ async fn delete_config( pub fn routes(state: AppState) -> Router { Router::new() .route("/configs/providers", post(check_provider_configs)) + .route("/configs/get", get(get_config)) .route("/configs/store", post(store_config)) .route("/configs/delete", delete(delete_config)) .with_state(state) diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index 21ab9f5b..60a5b996 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -161,7 +161,7 @@ async fn add_extension( }; // Acquire a lock on the agent and attempt to add the extension. - let mut agent = state.agent.lock().await; + let mut agent = state.agent.write().await; let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?; let response = agent.add_extension(extension_config).await; @@ -201,7 +201,7 @@ async fn remove_extension( } // Acquire a lock on the agent and attempt to remove the extension - let mut agent = state.agent.lock().await; + let mut agent = state.agent.write().await; let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?; agent.remove_extension(&name).await; diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index f1c4b38a..f0220dac 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -12,6 +12,7 @@ use goose::message::{Message, MessageContent}; use mcp_core::role::Role; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::{ convert::Infallible, pin::Pin, @@ -113,7 +114,7 @@ async fn handler( // Spawn task to handle streaming tokio::spawn(async move { - let agent = agent.lock().await; + let agent = agent.read().await; let agent = match agent.as_ref() { Some(agent) => agent, None => { @@ -237,7 +238,7 @@ async fn ask_handler( } let agent = state.agent.clone(); - let agent = agent.lock().await; + let agent = agent.write().await; let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; // Create a single message for the prompt @@ -277,11 +278,42 @@ async fn ask_handler( })) } +#[derive(Debug, Deserialize)] +struct ToolConfirmationRequest { + id: String, + confirmed: bool, +} + +async fn confirm_handler( + State(state): State, + headers: HeaderMap, + Json(request): Json, +) -> Result, StatusCode> { + // Verify secret key + let secret_key = headers + .get("X-Secret-Key") + .and_then(|value| value.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if secret_key != state.secret_key { + return Err(StatusCode::UNAUTHORIZED); + } + + let agent = state.agent.clone(); + let agent = agent.read().await; + let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; + agent + .handle_confirmation(request.id.clone(), request.confirmed) + .await; + Ok(Json(Value::Object(serde_json::Map::new()))) +} + // Configure routes for this module pub fn routes(state: AppState) -> Router { Router::new() .route("/reply", post(handler)) .route("/ask", post(ask_handler)) + .route("/confirm", post(confirm_handler)) .with_state(state) } @@ -332,7 +364,7 @@ mod tests { use axum::{body::Body, http::Request}; use std::collections::HashMap; use std::sync::Arc; - use tokio::sync::Mutex; + use tokio::sync::{Mutex, RwLock}; use tower::ServiceExt; // This test requires tokio runtime @@ -346,7 +378,7 @@ mod tests { let agent = AgentFactory::create("reference", mock_provider).unwrap(); let state = AppState { config: Arc::new(Mutex::new(HashMap::new())), - agent: Arc::new(Mutex::new(Some(agent))), + agent: Arc::new(RwLock::new(Some(agent))), secret_key: "test-secret".to_string(), }; diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 7752889c..0e31ad53 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -3,13 +3,13 @@ use goose::agents::Agent; use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; /// Shared application state #[allow(dead_code)] #[derive(Clone)] pub struct AppState { - pub agent: Arc>>>, + pub agent: Arc>>>, pub secret_key: String, pub config: Arc>>, } @@ -17,7 +17,7 @@ pub struct AppState { impl AppState { pub async fn new(secret_key: String) -> Result { Ok(Self { - agent: Arc::new(Mutex::new(None)), + agent: Arc::new(RwLock::new(None)), secret_key, config: Arc::new(Mutex::new(HashMap::new())), }) diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index c795f93b..f2717789 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -274,7 +274,8 @@ impl Agent for TruncateAgent { // Wait for confirmation response through the channel let mut rx = self.confirmation_rx.lock().await; - if let Some((req_id, confirmed)) = rx.recv().await { + // Loop the recv until we have a matched req_id due to potential duplicate messages. + while let Some((req_id, confirmed)) = rx.recv().await { if req_id == request.id { if confirmed { // User approved - dispatch the tool call @@ -290,6 +291,7 @@ impl Agent for TruncateAgent { Ok(vec![Content::text("User declined to run this tool.")]), ); } + break; // Exit the loop once the matching `req_id` is found } } } diff --git a/ui/desktop/package-lock.json b/ui/desktop/package-lock.json index 6bc44999..ad1cb923 100644 --- a/ui/desktop/package-lock.json +++ b/ui/desktop/package-lock.json @@ -16,6 +16,7 @@ "@radix-ui/react-avatar": "^1.1.1", "@radix-ui/react-dialog": "^1.1.4", "@radix-ui/react-icons": "^1.3.1", + "@radix-ui/react-radio-group": "^1.2.3", "@radix-ui/react-scroll-area": "^1.2.0", "@radix-ui/react-select": "^2.1.5", "@radix-ui/react-slot": "^1.1.1", diff --git a/ui/desktop/package.json b/ui/desktop/package.json index ca8de183..42483098 100644 --- a/ui/desktop/package.json +++ b/ui/desktop/package.json @@ -72,6 +72,7 @@ "@radix-ui/react-avatar": "^1.1.1", "@radix-ui/react-dialog": "^1.1.4", "@radix-ui/react-icons": "^1.3.1", + "@radix-ui/react-radio-group": "^1.2.3", "@radix-ui/react-scroll-area": "^1.2.0", "@radix-ui/react-select": "^2.1.5", "@radix-ui/react-slot": "^1.1.1", diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index b997a7da..577482bb 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -167,14 +167,28 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) if (message.role === 'user') { const hasOnlyToolResponses = message.content.every((c) => c.type === 'toolResponse'); const hasTextContent = message.content.some((c) => c.type === 'text'); + const hasToolConfirmation = message.content.every( + (c) => c.type === 'toolConfirmationRequest' + ); - // Keep the message if it has text content or is not just tool responses - return hasTextContent || !hasOnlyToolResponses; + // Keep the message if it has text content or tool confirmation or is not just tool responses + return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation; } return true; }); + const isUserMessage = (message: Message) => { + if (message.role === 'assistant') { + return false; + } + + if (message.content.every((c) => c.type === 'toolConfirmationRequest')) { + return false; + } + return true; + }; + return (
@@ -187,7 +201,7 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) {filteredMessages.map((message, index) => (
- {message.role === 'user' ? ( + {isUserMessage(message) ? ( ) : ( { const responseMap = new Map(); @@ -63,6 +72,8 @@ export default function GooseMessage({ message, metadata, messages, append }: Go
)} + {hasToolConfirmation && } + {toolRequests.length > 0 && (
{toolRequests.map((toolRequest) => ( diff --git a/ui/desktop/src/components/ToolCallConfirmation.tsx b/ui/desktop/src/components/ToolCallConfirmation.tsx new file mode 100644 index 00000000..32418b83 --- /dev/null +++ b/ui/desktop/src/components/ToolCallConfirmation.tsx @@ -0,0 +1,39 @@ +import React, { useState } from 'react'; +import { ConfirmToolRequest } from '../utils/toolConfirm'; + +export default function ToolConfirmation({ toolConfirmationId }) { + const [disabled, setDisabled] = useState(false); + + const handleButtonClick = (confirmed) => { + setDisabled(true); + ConfirmToolRequest(toolConfirmationId, confirmed); + }; + + return ( + <> +
+ Goose would like to call the above tool. Allow? +
+
+ + +
+ + ); +} diff --git a/ui/desktop/src/components/ToolCallWithResponse.tsx b/ui/desktop/src/components/ToolCallWithResponse.tsx index 2b1b2d72..b6e3e7f5 100644 --- a/ui/desktop/src/components/ToolCallWithResponse.tsx +++ b/ui/desktop/src/components/ToolCallWithResponse.tsx @@ -80,6 +80,9 @@ function ToolResultView({ result }: ToolResultViewProps) { // Find results where either audience is not set, or it's set to a list that includes user const filteredResults = result.filter((item) => { + if (!item.annotations) { + return false; + } // Check audience (which may not be in the type) const audience = item.annotations?.audience; diff --git a/ui/desktop/src/components/settings/SettingsView.tsx b/ui/desktop/src/components/settings/SettingsView.tsx index f3bf4b41..ae4eb5f2 100644 --- a/ui/desktop/src/components/settings/SettingsView.tsx +++ b/ui/desktop/src/components/settings/SettingsView.tsx @@ -15,6 +15,8 @@ import BackButton from '../ui/BackButton'; import { RecentModelsRadio } from './models/RecentModels'; import { ExtensionItem } from './extensions/ExtensionItem'; import type { View } from '../../App'; +import ModeSelection from './basic/ModeSelection'; +import { getApiUrl, getSecretKey } from '../../config'; const EXTENSIONS_DESCRIPTION = 'The Model Context Protocol (MCP) is a system that allows AI models to securely connect with local or remote resources using standard server setups. It works like a client-server setup and expands AI capabilities using three main components: Prompts, Resources, and Tools.'; @@ -60,6 +62,53 @@ export default function SettingsView({ setView: (view: View) => void; viewOptions: SettingsViewOptions; }) { + const [mode, setMode] = useState('approve'); + + const handleModeChange = async (newMode: string) => { + const storeResponse = await fetch(getApiUrl('/configs/store'), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + body: JSON.stringify({ + key: 'GOOSE_MODE', + value: newMode, + isSecret: false, + }), + }); + + if (!storeResponse.ok) { + const errorText = await storeResponse.text(); + console.error('Store response error:', errorText); + throw new Error(`Failed to store new goose mode: ${newMode}`); + } + setMode(newMode); + }; + + useEffect(() => { + const fetchCurrentMode = async () => { + try { + const response = await fetch(getApiUrl('/configs/get?key=GOOSE_MODE'), { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + }); + + if (response.ok) { + const { value } = await response.json(); + setMode(value); + } + } catch (error) { + console.error('Error fetching current mode:', error); + } + }; + + fetchCurrentMode(); + }, []); + const [settings, setSettings] = React.useState(() => { const saved = localStorage.getItem('user_settings'); window.electron.logInfo('Settings: ' + saved); @@ -84,7 +133,7 @@ export default function SettingsView({ const [isManualModalOpen, setIsManualModalOpen] = useState(false); // Persist settings changes - React.useEffect(() => { + useEffect(() => { localStorage.setItem('user_settings', JSON.stringify(settings)); }, [settings]); @@ -255,6 +304,20 @@ export default function SettingsView({ )}
+ +
+
+

Others

+
+ +
+

+ Others setting like Goose Mode, Tool Output, Experiment and more +

+ + +
+
diff --git a/ui/desktop/src/components/settings/basic/ModeSelection.tsx b/ui/desktop/src/components/settings/basic/ModeSelection.tsx new file mode 100644 index 00000000..ceeb6417 --- /dev/null +++ b/ui/desktop/src/components/settings/basic/ModeSelection.tsx @@ -0,0 +1,50 @@ +import * as RadioGroup from '@radix-ui/react-radio-group'; +import React from 'react'; + +const ModeSelection = ({ value, onChange }) => { + const modes = [ + { + value: 'auto', + label: 'Completely autonomous', + description: 'Full file modification capabilities, edit, create, and delete files freely.', + }, + { + value: 'approve', + label: 'Approval needed', + description: 'Editing, creating, and deleting files will require human approval.', + }, + { + value: 'chat', + label: 'Chat only', + description: 'Engage with the selected provider without using tools or extensions.', + }, + ]; + + return ( +
+

Mode Selection

+ + + {modes.map((mode) => ( + +
+

{mode.label}

+

{mode.description}

+
+
+
+ {value === mode.value &&
} +
+
+ + ))} + +
+ ); +}; + +export default ModeSelection; diff --git a/ui/desktop/src/types/message.ts b/ui/desktop/src/types/message.ts index 29b57bb4..d03212bd 100644 --- a/ui/desktop/src/types/message.ts +++ b/ui/desktop/src/types/message.ts @@ -187,6 +187,22 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[] ); } +export function getToolConfirmationRequestId(message: Message): [string, boolean] { + const hasToolConfirmationRequest = message.content.some( + (content): content is ToolConfirmationRequestMessageContent => + content.type === 'toolConfirmationRequest' + ); + + const contentId = hasToolConfirmationRequest + ? message.content.find( + (content): content is ToolConfirmationRequestMessageContent => + content.type === 'toolConfirmationRequest' + )?.id || '' + : ''; + + return [contentId, hasToolConfirmationRequest]; +} + export function hasCompletedToolCalls(message: Message): boolean { const toolRequests = getToolRequests(message); if (toolRequests.length === 0) return false; diff --git a/ui/desktop/src/utils/toolConfirm.ts b/ui/desktop/src/utils/toolConfirm.ts new file mode 100644 index 00000000..19c8bfd9 --- /dev/null +++ b/ui/desktop/src/utils/toolConfirm.ts @@ -0,0 +1,25 @@ +import { getApiUrl, getSecretKey } from '../config'; + +export async function ConfirmToolRequest(requesyId: string, confirmed: boolean) { + try { + const response = await fetch(getApiUrl('/confirm'), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + body: JSON.stringify({ + id: requesyId, + confirmed, + }), + }); + + if (!response.ok) { + const errorText = await response.text(); + console.error('Delete response error: ', errorText); + throw new Error('Failed to confirm tool'); + } + } catch (error) { + console.error('Error confirm tool: ', error); + } +}