feat: support goose mode in UI (#1434)

Co-authored-by: Lily Delalande <ldelalande@squareup.com>
This commit is contained in:
Yingjie He
2025-02-28 17:00:41 -08:00
committed by GitHub
parent 8f5fba97b8
commit f7f2540287
16 changed files with 320 additions and 19 deletions

View File

@@ -86,7 +86,7 @@ async fn extend_prompt(
return Err(StatusCode::UNAUTHORIZED);
}
let mut agent = state.agent.lock().await;
let mut agent = state.agent.write().await;
if let Some(ref mut agent) = *agent {
agent.extend_system_prompt(payload.extension).await;
Ok(Json(ExtendPromptResponse { success: true }))
@@ -134,7 +134,7 @@ async fn create_agent(
let new_agent = AgentFactory::create(&version, provider).expect("Failed to create agent");
let mut agent = state.agent.lock().await;
let mut agent = state.agent.write().await;
*agent = Some(new_agent);
Ok(Json(CreateAgentResponse { version }))

View File

@@ -1,5 +1,9 @@
use crate::state::AppState;
use axum::{extract::State, routing::delete, routing::post, Json, Router};
use axum::{
extract::{Query, State},
routing::{delete, get, post},
Json, Router,
};
use goose::config::Config;
use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
@@ -140,6 +144,45 @@ async fn check_provider_configs(
Ok(Json(response))
}
#[derive(Deserialize)]
pub struct GetConfigQuery {
key: String,
}
#[derive(Serialize)]
pub struct GetConfigResponse {
value: Option<String>,
}
pub async fn get_config(
State(state): State<AppState>,
headers: HeaderMap,
Query(query): Query<GetConfigQuery>,
) -> Result<Json<GetConfigResponse>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
// Fetch the configuration value. Right now we don't allow get a secret.
let config = Config::global();
let value = if let Ok(config_value) = config.get::<String>(&query.key) {
Some(config_value)
} else if let Ok(env_value) = std::env::var(&query.key) {
Some(env_value)
} else {
None
};
// Return the value
Ok(Json(GetConfigResponse { value }))
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct DeleteConfigRequest {
@@ -178,6 +221,7 @@ async fn delete_config(
pub fn routes(state: AppState) -> Router {
Router::new()
.route("/configs/providers", post(check_provider_configs))
.route("/configs/get", get(get_config))
.route("/configs/store", post(store_config))
.route("/configs/delete", delete(delete_config))
.with_state(state)

View File

@@ -161,7 +161,7 @@ async fn add_extension(
};
// Acquire a lock on the agent and attempt to add the extension.
let mut agent = state.agent.lock().await;
let mut agent = state.agent.write().await;
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
let response = agent.add_extension(extension_config).await;
@@ -201,7 +201,7 @@ async fn remove_extension(
}
// Acquire a lock on the agent and attempt to remove the extension
let mut agent = state.agent.lock().await;
let mut agent = state.agent.write().await;
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
agent.remove_extension(&name).await;

View File

@@ -12,6 +12,7 @@ use goose::message::{Message, MessageContent};
use mcp_core::role::Role;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
convert::Infallible,
pin::Pin,
@@ -113,7 +114,7 @@ async fn handler(
// Spawn task to handle streaming
tokio::spawn(async move {
let agent = agent.lock().await;
let agent = agent.read().await;
let agent = match agent.as_ref() {
Some(agent) => agent,
None => {
@@ -237,7 +238,7 @@ async fn ask_handler(
}
let agent = state.agent.clone();
let agent = agent.lock().await;
let agent = agent.write().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
// Create a single message for the prompt
@@ -277,11 +278,42 @@ async fn ask_handler(
}))
}
#[derive(Debug, Deserialize)]
struct ToolConfirmationRequest {
id: String,
confirmed: bool,
}
async fn confirm_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<ToolConfirmationRequest>,
) -> Result<Json<Value>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
let agent = state.agent.clone();
let agent = agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
agent
.handle_confirmation(request.id.clone(), request.confirmed)
.await;
Ok(Json(Value::Object(serde_json::Map::new())))
}
// Configure routes for this module
pub fn routes(state: AppState) -> Router {
Router::new()
.route("/reply", post(handler))
.route("/ask", post(ask_handler))
.route("/confirm", post(confirm_handler))
.with_state(state)
}
@@ -332,7 +364,7 @@ mod tests {
use axum::{body::Body, http::Request};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{Mutex, RwLock};
use tower::ServiceExt;
// This test requires tokio runtime
@@ -346,7 +378,7 @@ mod tests {
let agent = AgentFactory::create("reference", mock_provider).unwrap();
let state = AppState {
config: Arc::new(Mutex::new(HashMap::new())),
agent: Arc::new(Mutex::new(Some(agent))),
agent: Arc::new(RwLock::new(Some(agent))),
secret_key: "test-secret".to_string(),
};

View File

@@ -3,13 +3,13 @@ use goose::agents::Agent;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{Mutex, RwLock};
/// Shared application state
#[allow(dead_code)]
#[derive(Clone)]
pub struct AppState {
pub agent: Arc<Mutex<Option<Box<dyn Agent>>>>,
pub agent: Arc<RwLock<Option<Box<dyn Agent>>>>,
pub secret_key: String,
pub config: Arc<Mutex<HashMap<String, Value>>>,
}
@@ -17,7 +17,7 @@ pub struct AppState {
impl AppState {
pub async fn new(secret_key: String) -> Result<Self> {
Ok(Self {
agent: Arc::new(Mutex::new(None)),
agent: Arc::new(RwLock::new(None)),
secret_key,
config: Arc::new(Mutex::new(HashMap::new())),
})

View File

@@ -274,7 +274,8 @@ impl Agent for TruncateAgent {
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
if let Some((req_id, confirmed)) = rx.recv().await {
// Loop the recv until we have a matched req_id due to potential duplicate messages.
while let Some((req_id, confirmed)) = rx.recv().await {
if req_id == request.id {
if confirmed {
// User approved - dispatch the tool call
@@ -290,6 +291,7 @@ impl Agent for TruncateAgent {
Ok(vec![Content::text("User declined to run this tool.")]),
);
}
break; // Exit the loop once the matching `req_id` is found
}
}
}

View File

@@ -16,6 +16,7 @@
"@radix-ui/react-avatar": "^1.1.1",
"@radix-ui/react-dialog": "^1.1.4",
"@radix-ui/react-icons": "^1.3.1",
"@radix-ui/react-radio-group": "^1.2.3",
"@radix-ui/react-scroll-area": "^1.2.0",
"@radix-ui/react-select": "^2.1.5",
"@radix-ui/react-slot": "^1.1.1",

View File

@@ -72,6 +72,7 @@
"@radix-ui/react-avatar": "^1.1.1",
"@radix-ui/react-dialog": "^1.1.4",
"@radix-ui/react-icons": "^1.3.1",
"@radix-ui/react-radio-group": "^1.2.3",
"@radix-ui/react-scroll-area": "^1.2.0",
"@radix-ui/react-select": "^2.1.5",
"@radix-ui/react-slot": "^1.1.1",

View File

@@ -167,14 +167,28 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
if (message.role === 'user') {
const hasOnlyToolResponses = message.content.every((c) => c.type === 'toolResponse');
const hasTextContent = message.content.some((c) => c.type === 'text');
const hasToolConfirmation = message.content.every(
(c) => c.type === 'toolConfirmationRequest'
);
// Keep the message if it has text content or is not just tool responses
return hasTextContent || !hasOnlyToolResponses;
// Keep the message if it has text content or tool confirmation or is not just tool responses
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation;
}
return true;
});
const isUserMessage = (message: Message) => {
if (message.role === 'assistant') {
return false;
}
if (message.content.every((c) => c.type === 'toolConfirmationRequest')) {
return false;
}
return true;
};
return (
<div className="flex flex-col w-full h-screen items-center justify-center">
<div className="relative flex items-center h-[36px] w-full bg-bgSubtle border-b border-borderSubtle">
@@ -187,7 +201,7 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
<ScrollArea ref={scrollRef} className="flex-1 px-4" autoScroll>
{filteredMessages.map((message, index) => (
<div key={message.id || index} className="mt-[16px]">
{message.role === 'user' ? (
{isUserMessage(message) ? (
<UserMessage message={message} />
) : (
<GooseMessage

View File

@@ -4,7 +4,14 @@ import GooseResponseForm from './GooseResponseForm';
import { extractUrls } from '../utils/urlUtils';
import MarkdownContent from './MarkdownContent';
import ToolCallWithResponse from './ToolCallWithResponse';
import { Message, getTextContent, getToolRequests, getToolResponses } from '../types/message';
import {
Message,
getTextContent,
getToolRequests,
getToolResponses,
getToolConfirmationRequestId,
} from '../types/message';
import ToolCallConfirmation from './ToolCallConfirmation';
interface GooseMessageProps {
message: Message;
@@ -15,7 +22,7 @@ interface GooseMessageProps {
export default function GooseMessage({ message, metadata, messages, append }: GooseMessageProps) {
// Extract text content from the message
const textContent = getTextContent(message);
let textContent = getTextContent(message);
// Get tool requests from the message
const toolRequests = getToolRequests(message);
@@ -29,6 +36,8 @@ export default function GooseMessage({ message, metadata, messages, append }: Go
const previousUrls = previousMessage ? extractUrls(getTextContent(previousMessage)) : [];
const urls = toolRequests.length === 0 ? extractUrls(textContent, previousUrls) : [];
const [toolConfirmationId, hasToolConfirmation] = getToolConfirmationRequestId(message);
// Find tool responses that correspond to the tool requests in this message
const toolResponsesMap = useMemo(() => {
const responseMap = new Map();
@@ -63,6 +72,8 @@ export default function GooseMessage({ message, metadata, messages, append }: Go
</div>
)}
{hasToolConfirmation && <ToolCallConfirmation toolConfirmationId={toolConfirmationId} />}
{toolRequests.length > 0 && (
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 mt-1">
{toolRequests.map((toolRequest) => (

View File

@@ -0,0 +1,39 @@
import React, { useState } from 'react';
import { ConfirmToolRequest } from '../utils/toolConfirm';
export default function ToolConfirmation({ toolConfirmationId }) {
const [disabled, setDisabled] = useState(false);
const handleButtonClick = (confirmed) => {
setDisabled(true);
ConfirmToolRequest(toolConfirmationId, confirmed);
};
return (
<>
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 rounded-b-none">
Goose would like to call the above tool. Allow?
</div>
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 flex gap-4 mt-1">
<button
className={
'bg-black text-white dark:bg-white dark:text-black rounded-full px-6 py-2 transition'
}
onClick={() => handleButtonClick(true)}
disabled={disabled}
>
Allow tool
</button>
<button
className={
'bg-white text-black dark:bg-black dark:text-white border border-gray-300 dark:border-gray-700 rounded-full px-6 py-2 transition'
}
onClick={() => handleButtonClick(false)}
disabled={disabled}
>
Deny
</button>
</div>
</>
);
}

View File

@@ -80,6 +80,9 @@ function ToolResultView({ result }: ToolResultViewProps) {
// Find results where either audience is not set, or it's set to a list that includes user
const filteredResults = result.filter((item) => {
if (!item.annotations) {
return false;
}
// Check audience (which may not be in the type)
const audience = item.annotations?.audience;

View File

@@ -15,6 +15,8 @@ import BackButton from '../ui/BackButton';
import { RecentModelsRadio } from './models/RecentModels';
import { ExtensionItem } from './extensions/ExtensionItem';
import type { View } from '../../App';
import ModeSelection from './basic/ModeSelection';
import { getApiUrl, getSecretKey } from '../../config';
const EXTENSIONS_DESCRIPTION =
'The Model Context Protocol (MCP) is a system that allows AI models to securely connect with local or remote resources using standard server setups. It works like a client-server setup and expands AI capabilities using three main components: Prompts, Resources, and Tools.';
@@ -60,6 +62,53 @@ export default function SettingsView({
setView: (view: View) => void;
viewOptions: SettingsViewOptions;
}) {
const [mode, setMode] = useState('approve');
const handleModeChange = async (newMode: string) => {
const storeResponse = await fetch(getApiUrl('/configs/store'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: 'GOOSE_MODE',
value: newMode,
isSecret: false,
}),
});
if (!storeResponse.ok) {
const errorText = await storeResponse.text();
console.error('Store response error:', errorText);
throw new Error(`Failed to store new goose mode: ${newMode}`);
}
setMode(newMode);
};
useEffect(() => {
const fetchCurrentMode = async () => {
try {
const response = await fetch(getApiUrl('/configs/get?key=GOOSE_MODE'), {
method: 'GET',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
});
if (response.ok) {
const { value } = await response.json();
setMode(value);
}
} catch (error) {
console.error('Error fetching current mode:', error);
}
};
fetchCurrentMode();
}, []);
const [settings, setSettings] = React.useState<SettingsType>(() => {
const saved = localStorage.getItem('user_settings');
window.electron.logInfo('Settings: ' + saved);
@@ -84,7 +133,7 @@ export default function SettingsView({
const [isManualModalOpen, setIsManualModalOpen] = useState(false);
// Persist settings changes
React.useEffect(() => {
useEffect(() => {
localStorage.setItem('user_settings', JSON.stringify(settings));
}, [settings]);
@@ -255,6 +304,20 @@ export default function SettingsView({
)}
</div>
</section>
<section id="others">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-semibold text-textStandard">Others</h2>
</div>
<div className="px-8">
<p className="text-sm text-textStandard mb-4">
Others setting like Goose Mode, Tool Output, Experiment and more
</p>
<ModeSelection value={mode} onChange={handleModeChange} />
</div>
</section>
</div>
</div>
</div>

View File

@@ -0,0 +1,50 @@
import * as RadioGroup from '@radix-ui/react-radio-group';
import React from 'react';
const ModeSelection = ({ value, onChange }) => {
const modes = [
{
value: 'auto',
label: 'Completely autonomous',
description: 'Full file modification capabilities, edit, create, and delete files freely.',
},
{
value: 'approve',
label: 'Approval needed',
description: 'Editing, creating, and deleting files will require human approval.',
},
{
value: 'chat',
label: 'Chat only',
description: 'Engage with the selected provider without using tools or extensions.',
},
];
return (
<div>
<h4 className="font-medium mb-4">Mode Selection</h4>
<RadioGroup.Root className="flex flex-col space-y-2" value={value} onValueChange={onChange}>
{modes.map((mode) => (
<RadioGroup.Item
key={mode.value}
value={mode.value}
className="flex items-center justify-between p-2 hover:bg-gray-100 rounded transition-all cursor-pointer"
>
<div className="flex flex-col text-left">
<h3 className="text-sm font-semibold text-textStandard">{mode.label}</h3>
<p className="text-xs text-textSubtle mt-[2px]">{mode.description}</p>
</div>
<div className="flex-shrink-0">
<div className="w-4 h-4 flex items-center justify-center rounded-full border border-gray-500">
{value === mode.value && <div className="w-2 h-2 bg-black rounded-full" />}
</div>
</div>
</RadioGroup.Item>
))}
</RadioGroup.Root>
</div>
);
};
export default ModeSelection;

View File

@@ -187,6 +187,22 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[]
);
}
export function getToolConfirmationRequestId(message: Message): [string, boolean] {
const hasToolConfirmationRequest = message.content.some(
(content): content is ToolConfirmationRequestMessageContent =>
content.type === 'toolConfirmationRequest'
);
const contentId = hasToolConfirmationRequest
? message.content.find(
(content): content is ToolConfirmationRequestMessageContent =>
content.type === 'toolConfirmationRequest'
)?.id || ''
: '';
return [contentId, hasToolConfirmationRequest];
}
export function hasCompletedToolCalls(message: Message): boolean {
const toolRequests = getToolRequests(message);
if (toolRequests.length === 0) return false;

View File

@@ -0,0 +1,25 @@
import { getApiUrl, getSecretKey } from '../config';
export async function ConfirmToolRequest(requesyId: string, confirmed: boolean) {
try {
const response = await fetch(getApiUrl('/confirm'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
id: requesyId,
confirmed,
}),
});
if (!response.ok) {
const errorText = await response.text();
console.error('Delete response error: ', errorText);
throw new Error('Failed to confirm tool');
}
} catch (error) {
console.error('Error confirm tool: ', error);
}
}