From b030f845ce4e6c793686eeb6614add6e5db589e8 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Fri, 11 Apr 2025 13:06:59 -0700 Subject: [PATCH] feat: enable extension from ui (#2117) --- crates/goose-server/src/routes/reply.rs | 19 ++++- crates/goose/src/agents/agent.rs | 8 +- crates/goose/src/prompts/system.md | 2 +- ui/desktop/src/components/ChatView.tsx | 41 ++++++++- .../src/components/ExtensionConfirmation.tsx | 85 +++++++++++++++++++ ui/desktop/src/components/GooseMessage.tsx | 25 ++++++ ui/desktop/src/types/message.ts | 60 ++++++++++++- ui/desktop/src/utils/extensionConfirm.ts | 26 ++++++ ui/desktop/src/utils/toolConfirm.ts | 1 + 9 files changed, 254 insertions(+), 13 deletions(-) create mode 100644 ui/desktop/src/components/ExtensionConfirmation.tsx create mode 100644 ui/desktop/src/utils/extensionConfirm.ts diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 9f115dca..787b8a6a 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -365,16 +365,23 @@ async fn ask_handler( })) } -#[derive(Debug, Deserialize)] -struct ToolConfirmationRequest { +#[derive(Debug, Deserialize, Serialize)] +struct PermissionConfirmationRequest { id: String, + confirmed: bool, + #[serde(default = "default_principal_type")] + principal_type: PrincipalType, action: String, } +fn default_principal_type() -> PrincipalType { + PrincipalType::Tool +} + async fn confirm_handler( State(state): State, headers: HeaderMap, - Json(request): Json, + Json(request): Json, ) -> Result, StatusCode> { // Verify secret key let secret_key = headers @@ -385,6 +392,10 @@ async fn confirm_handler( if secret_key != state.secret_key { return Err(StatusCode::UNAUTHORIZED); } + tracing::info!( + "Received confirmation request: {}", + serde_json::to_string_pretty(&request).unwrap() + ); let agent = state.agent.clone(); let agent = agent.read().await; @@ -401,7 +412,7 @@ async fn confirm_handler( .handle_confirmation( request.id.clone(), PermissionConfirmation { - principal_type: PrincipalType::Tool, + principal_type: request.principal_type, permission, }, ) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 6a35a4d0..0becddcd 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -461,12 +461,12 @@ impl Agent { let mut rx = self.confirmation_rx.lock().await; while let Some((req_id, extension_confirmation)) = rx.recv().await { + let extension_name = tool_call.arguments.get("extension_name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); if req_id == request.id { if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow { - let extension_name = tool_call.arguments.get("extension_name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); let install_result = self.enable_extension(extension_name, request.id.clone()).await; install_results.push(install_result); } else { diff --git a/crates/goose/src/prompts/system.md b/crates/goose/src/prompts/system.md index 65781dee..7127671c 100644 --- a/crates/goose/src/prompts/system.md +++ b/crates/goose/src/prompts/system.md @@ -9,7 +9,7 @@ These models have varying knowledge cut-off dates depending on when they were tr Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools. You are capable of dynamically plugging into new extensions and learning how to use them. You solve higher level problems using the tools in these extensions, and can interact with multiple at once. -Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extension tool. You should only enable extensions found from the search_available_extensions tool. +Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extension tool and provide the extension_name. You should only enable extensions found from the search_available_extensions tool. {% if (extensions is defined) and extensions %} Because you dynamically load extensions, your conversation history may refer diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 4826257a..3c9deee0 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -25,6 +25,7 @@ import { ToolResponseMessageContent, ToolConfirmationRequestMessageContent, getTextContent, + EnableExtensionRequestMessageContent, } from '../types/message'; export interface ChatType { @@ -49,10 +50,12 @@ const isUserMessage = (message: Message): boolean => { if (message.role === 'assistant') { return false; } - if (message.content.every((c) => c.type === 'toolConfirmationRequest')) { return false; } + if (message.content.every((c) => c.type === 'enableExtensionRequest')) { + return false; + } return true; }; @@ -312,6 +315,14 @@ export default function ChatView({ return [content.id, toolCall]; } }); + const enableExtensionRequests = lastMessage.content + .filter( + (content): content is EnableExtensionRequestMessageContent => + content.type === 'enableExtensionRequest' + ) + .map((content) => { + return [content.id, content.extensionName]; + }); if (toolRequests.length !== 0) { // This means we were interrupted during a tool request @@ -338,10 +349,33 @@ export default function ChatView({ responseMessage.content.push(toolResponse); } - // Use an immutable update to add the response message to the messages array setMessages([...messages, responseMessage]); } + + // do the same for enable extension requests + // leverages toolResponse to send the error notification + if (enableExtensionRequests.length !== 0) { + let responseMessage: Message = { + role: 'user', + created: Date.now(), + content: [], + }; + const notification = 'Interrupted by the user to make a correction'; + // generate a response saying it was interrupted for each extension request + for (const [reqId, _] of enableExtensionRequests) { + const toolResponse: ToolResponseMessageContent = { + type: 'toolResponse', + id: reqId, + toolResult: { + status: 'error', + error: notification, + }, + }; + responseMessage.content.push(toolResponse); + } + setMessages([...messages, responseMessage]); + } } }; @@ -359,8 +393,9 @@ export default function ChatView({ (c) => c.type === 'toolConfirmationRequest' ); + const hasEnableExtension = message.content.every((c) => c.type === 'enableExtensionRequest'); // Keep the message if it has text content or tool confirmation or is not just tool responses - return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation; + return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasEnableExtension; } return true; diff --git a/ui/desktop/src/components/ExtensionConfirmation.tsx b/ui/desktop/src/components/ExtensionConfirmation.tsx new file mode 100644 index 00000000..0e9aa1cf --- /dev/null +++ b/ui/desktop/src/components/ExtensionConfirmation.tsx @@ -0,0 +1,85 @@ +import React, { useState } from 'react'; +import { ConfirmExtensionRequest } from '../utils/extensionConfirm'; +import { snakeToTitleCase } from '../utils'; + +export default function ExtensionConfirmation({ + isCancelledMessage, + isClicked, + extensionConfirmationId, + extensionName, +}) { + const [clicked, setClicked] = useState(isClicked); + const [status, setStatus] = useState('unknown'); + + const handleButtonClick = (confirmed) => { + setClicked(true); + setStatus(confirmed ? 'approved' : 'denied'); + ConfirmExtensionRequest(extensionConfirmationId, confirmed); + }; + + return isCancelledMessage ? ( +
+ Extension enablement is cancelled. +
+ ) : ( + <> +
+ Goose would like to enable the above extension. Allow? +
+ {clicked ? ( +
+
+ {status === 'approved' && ( + + + + )} + {status === 'denied' && ( + + + + )} + + {isClicked + ? 'Extension enablement is not available' + : `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() : extensionName)} is ${status}`}{' '} + +
+
+ ) : ( +
+ + +
+ )} + + ); +} diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index b628b714..e2086a25 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -11,9 +11,11 @@ import { getToolResponses, getToolConfirmationContent, createToolErrorResponseMessage, + getEnableExtensionContent, } from '../types/message'; import ToolCallConfirmation from './ToolCallConfirmation'; import MessageCopyLink from './MessageCopyLink'; +import ExtensionConfirmation from './ExtensionConfirmation'; interface GooseMessageProps { messageHistoryIndex: number; @@ -52,6 +54,9 @@ export default function GooseMessage({ const toolConfirmationContent = getToolConfirmationContent(message); const hasToolConfirmation = toolConfirmationContent !== undefined; + const enableExtensionContent = getEnableExtensionContent(message); + const hasEnableExtension = enableExtensionContent !== undefined; + // Find tool responses that correspond to the tool requests in this message const toolResponsesMap = useMemo(() => { const responseMap = new Map(); @@ -86,12 +91,23 @@ export default function GooseMessage({ createToolErrorResponseMessage(toolConfirmationContent.id, 'The tool call is cancelled.') ); } + if (messageIndex == messageHistoryIndex - 1 && hasEnableExtension) { + appendMessage( + createToolErrorResponseMessage( + enableExtensionContent.id, + 'The extension enablement is cancelled.' + ) + ); + } }, [ messageIndex, messageHistoryIndex, hasToolConfirmation, toolConfirmationContent, appendMessage, + hasEnableExtension, + // Only include enableExtensionContent if it exists + enableExtensionContent?.id, ]); return ( @@ -140,6 +156,15 @@ export default function GooseMessage({ toolName={toolConfirmationContent.toolName} /> )} + + {hasEnableExtension && ( + + )} {/* TODO(alexhancock): Re-enable link previews once styled well again */} diff --git a/ui/desktop/src/types/message.ts b/ui/desktop/src/types/message.ts index 31e898aa..752cd5b1 100644 --- a/ui/desktop/src/types/message.ts +++ b/ui/desktop/src/types/message.ts @@ -68,12 +68,44 @@ export interface ToolConfirmationRequestMessageContent { prompt?: string; } +export interface EnableExtensionCall { + name: string; + arguments: Record; + extensionName: string; +} + +export interface EnableExtensionCallResult { + status: 'success' | 'error'; + value?: T; + error?: string; +} + +export interface EnableExtensionRequest { + id: string; + extensionCall: EnableExtensionCallResult; +} + +export interface EnableExtensionConfirmationRequest { + id: string; + extensionName: string; + arguments: Record; + prompt?: string; +} + +export interface EnableExtensionRequestMessageContent { + type: 'enableExtensionRequest'; + id: string; + extensionCall: EnableExtensionCallResult; + extensionName: string; +} + export type MessageContent = | TextContent | ImageContent | ToolRequestMessageContent | ToolResponseMessageContent - | ToolConfirmationRequestMessageContent; + | ToolConfirmationRequestMessageContent + | EnableExtensionRequestMessageContent; export interface Message { id?: string; @@ -187,6 +219,15 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[] ); } +export function getEnableExtensionRequests( + message: Message +): EnableExtensionRequestMessageContent[] { + return message.content.filter( + (content): content is EnableExtensionRequestMessageContent => + content.type === 'enableExtensionRequest' + ); +} + export function getToolConfirmationContent( message: Message ): ToolConfirmationRequestMessageContent { @@ -196,6 +237,13 @@ export function getToolConfirmationContent( ); } +export function getEnableExtensionContent(message: Message): EnableExtensionRequestMessageContent { + return message.content.find( + (content): content is EnableExtensionRequestMessageContent => + content.type === 'enableExtensionRequest' + ); +} + export function hasCompletedToolCalls(message: Message): boolean { const toolRequests = getToolRequests(message); if (toolRequests.length === 0) return false; @@ -205,3 +253,13 @@ export function hasCompletedToolCalls(message: Message): boolean { // by looking through subsequent messages return true; } + +export function hasCompletedEnableExtensionCalls(message: Message): boolean { + const extensionRequests = getEnableExtensionRequests(message); + if (extensionRequests.length === 0) return false; + + // For now, we'll assume all extension calls are completed when this is checked + // In a real implementation, you'd need to check if all extension requests have responses + // by looking through subsequent messages + return true; +} diff --git a/ui/desktop/src/utils/extensionConfirm.ts b/ui/desktop/src/utils/extensionConfirm.ts new file mode 100644 index 00000000..d85cba70 --- /dev/null +++ b/ui/desktop/src/utils/extensionConfirm.ts @@ -0,0 +1,26 @@ +import { getApiUrl, getSecretKey } from '../config'; + +export async function ConfirmExtensionRequest(requestId: string, confirmed: boolean) { + try { + const response = await fetch(getApiUrl('/confirm'), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + body: JSON.stringify({ + id: requestId, + confirmed, + principal_type: 'Extension', + }), + }); + + if (!response.ok) { + const errorText = await response.text(); + console.error('Delete response error: ', errorText); + throw new Error('Failed to confirm extension enablement'); + } + } catch (error) { + console.error('Error confirming extension enablement: ', error); + } +} diff --git a/ui/desktop/src/utils/toolConfirm.ts b/ui/desktop/src/utils/toolConfirm.ts index 3b0f3f55..d858654d 100644 --- a/ui/desktop/src/utils/toolConfirm.ts +++ b/ui/desktop/src/utils/toolConfirm.ts @@ -11,6 +11,7 @@ export async function ConfirmToolRequest(requesyId: string, action: string) { body: JSON.stringify({ id: requesyId, action, + principal_type: 'Tool', }), });