From 7e4cfcdaaef4272db415c572cd39ff6e3f825990 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Wed, 16 Apr 2025 08:39:49 -0700 Subject: [PATCH] chore: generalize extension request (#2213) --- crates/goose-cli/src/session/mod.rs | 15 +++++-- crates/goose/src/message.rs | 29 +++++++++---- .../goose/src/permission/permission_judge.rs | 3 +- .../goose/src/providers/formats/anthropic.rs | 4 +- crates/goose/src/providers/formats/bedrock.rs | 2 +- .../goose/src/providers/formats/databricks.rs | 2 +- crates/goose/src/providers/formats/openai.rs | 2 +- ui/desktop/src/components/ChatView.tsx | 20 ++++----- .../src/components/ExtensionConfirmation.tsx | 13 ++++-- ui/desktop/src/components/GooseMessage.tsx | 21 +++++----- ui/desktop/src/types/message.ts | 42 +++++++------------ 11 files changed, 84 insertions(+), 69 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index a0238af4..5431fc81 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -17,6 +17,7 @@ use completion::GooseCompleter; use etcetera::choose_app_strategy; use etcetera::AppStrategy; use goose::agents::extension::{Envs, ExtensionConfig}; +use goose::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::message::{Message, MessageContent}; @@ -620,13 +621,19 @@ impl Session { principal_type: PrincipalType::Tool, permission, },).await; - } else if let Some(MessageContent::EnableExtensionRequest(enable_extension_request)) = message.content.first() { + } else if let Some(MessageContent::ExtensionRequest(enable_extension_request)) = message.content.first() { output::hide_thinking(); - let prompt = "Goose would like to install the following extension, do you approve?".to_string(); + let extension_action = if enable_extension_request.tool_name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME { + "enable" + } else { + "disable" + }; + + let prompt = format!("Goose would like to {} the following extension, do you approve?", extension_action); let confirmed = cliclack::select(prompt) - .item(true, "Yes, for this session", "Enable the extension for this session") - .item(false, "No", "Do not enable the extension") + .item(true, "Yes, for this session", format!("{} the extension for this session", extension_action)) + .item(false, "No", format!("Do not {} the extension", extension_action)) .interact()?; let permission = if confirmed { Permission::AllowOnce diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 5dbd305d..9d5b2356 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -61,9 +61,10 @@ pub struct ToolConfirmationRequest { #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "camelCase")] -pub struct EnableExtensionRequest { +pub struct ExtensionRequest { pub id: String, pub extension_name: String, + pub tool_name: String, } #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -94,7 +95,7 @@ pub enum MessageContent { ToolRequest(ToolRequest), ToolResponse(ToolResponse), ToolConfirmationRequest(ToolConfirmationRequest), - EnableExtensionRequest(EnableExtensionRequest), + ExtensionRequest(ExtensionRequest), FrontendToolRequest(FrontendToolRequest), Thinking(ThinkingContent), RedactedThinking(RedactedThinkingContent), @@ -144,10 +145,15 @@ impl MessageContent { }) } - pub fn enable_extension_request>(id: S, extension_name: String) -> Self { - MessageContent::EnableExtensionRequest(EnableExtensionRequest { + pub fn extension_request>( + id: S, + extension_name: String, + tool_name: String, + ) -> Self { + MessageContent::ExtensionRequest(ExtensionRequest { id: id.into(), extension_name, + tool_name, }) } @@ -192,9 +198,9 @@ impl MessageContent { } } - pub fn as_enable_extension_request(&self) -> Option<&EnableExtensionRequest> { - if let MessageContent::EnableExtensionRequest(ref enable_extension_request) = self { - Some(enable_extension_request) + pub fn as_extension_request(&self) -> Option<&ExtensionRequest> { + if let MessageContent::ExtensionRequest(ref extension_request) = self { + Some(extension_request) } else { None } @@ -359,12 +365,17 @@ impl Message { )) } - pub fn with_enable_extension_request>( + pub fn with_extension_request>( self, id: S, extension_name: String, + tool_name: String, ) -> Self { - self.with_content(MessageContent::enable_extension_request(id, extension_name)) + self.with_content(MessageContent::extension_request( + id, + extension_name, + tool_name, + )) } pub fn with_frontend_tool_request>( diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 1664a5fd..cb831f1c 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -160,7 +160,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME { ( PrincipalType::Extension, - Message::user().with_enable_extension_request( + Message::user().with_extension_request( request_id, tool_call .arguments @@ -168,6 +168,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ .and_then(|v| v.as_str()) .unwrap_or("") .to_string(), + tool_call.name.clone(), ), ) } else { diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index d8dd07d7..171102c6 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -60,8 +60,8 @@ pub fn format_messages(messages: &[Message]) -> Vec { MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => { // Skip tool confirmation requests } - MessageContent::EnableExtensionRequest(_enable_extension_request) => { - // Skip enable extension requests + MessageContent::ExtensionRequest(_extension_request) => { + // Skip extension requests } MessageContent::Thinking(thinking) => { content.push(json!({ diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index d114ef59..1a7b59c7 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -31,7 +31,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result { bedrock::ContentBlock::Text("".to_string()) } - MessageContent::EnableExtensionRequest(_enable_extension_request) => { + MessageContent::ExtensionRequest(_extension_request) => { bedrock::ContentBlock::Text("".to_string()) } MessageContent::Image(_) => { diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 69c4ac77..072da7c4 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -179,7 +179,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< MessageContent::ToolConfirmationRequest(_) => { // Skip tool confirmation requests } - MessageContent::EnableExtensionRequest(_) => { + MessageContent::ExtensionRequest(_) => { // Skip enable extension requests } MessageContent::Image(image) => { diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index c6c34e9b..6629b17d 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -147,7 +147,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< MessageContent::ToolConfirmationRequest(_) => { // Skip tool confirmation requests } - MessageContent::EnableExtensionRequest(_) => { + MessageContent::ExtensionRequest(_) => { // Skip enable extension requests } MessageContent::Image(image) => { diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 6bd272e3..933286ae 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -27,7 +27,7 @@ import { ToolRequestMessageContent, ToolResponseMessageContent, ToolConfirmationRequestMessageContent, - EnableExtensionRequestMessageContent, + ExtensionRequestMessageContent, } from '../types/message'; export interface ChatType { @@ -47,7 +47,7 @@ const isUserMessage = (message: Message): boolean => { if (message.content.every((c) => c.type === 'toolConfirmationRequest')) { return false; } - if (message.content.every((c) => c.type === 'enableExtensionRequest')) { + if (message.content.every((c) => c.type === 'extensionRequest')) { return false; } return true; @@ -258,13 +258,13 @@ export default function ChatView({ return [content.id, toolCall]; } }); - const enableExtensionRequests = lastMessage.content + const extensionRequests = lastMessage.content .filter( - (content): content is EnableExtensionRequestMessageContent => - content.type === 'enableExtensionRequest' + (content): content is ExtensionRequestMessageContent => + content.type === 'extensionRequest' ) .map((content) => { - return [content.id, content.extensionName]; + return [content.id, content.extensionCall]; }); if (toolRequests.length !== 0) { @@ -298,7 +298,7 @@ export default function ChatView({ // do the same for enable extension requests // leverages toolResponse to send the error notification - if (enableExtensionRequests.length !== 0) { + if (extensionRequests.length !== 0) { let responseMessage: Message = { role: 'user', created: Date.now(), @@ -306,7 +306,7 @@ export default function ChatView({ }; const notification = 'Interrupted by the user to make a correction'; // generate a response saying it was interrupted for each extension request - for (const [reqId, _] of enableExtensionRequests) { + for (const [reqId, _] of extensionRequests) { const toolResponse: ToolResponseMessageContent = { type: 'toolResponse', id: reqId, @@ -336,9 +336,9 @@ export default function ChatView({ (c) => c.type === 'toolConfirmationRequest' ); - const hasEnableExtension = message.content.every((c) => c.type === 'enableExtensionRequest'); + const hasExtensionRequest = message.content.every((c) => c.type === 'extensionRequest'); // Keep the message if it has text content or tool confirmation or is not just tool responses - return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasEnableExtension; + return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasExtensionRequest; } return true; diff --git a/ui/desktop/src/components/ExtensionConfirmation.tsx b/ui/desktop/src/components/ExtensionConfirmation.tsx index 3f4376ad..114c5f13 100644 --- a/ui/desktop/src/components/ExtensionConfirmation.tsx +++ b/ui/desktop/src/components/ExtensionConfirmation.tsx @@ -7,16 +7,20 @@ interface ExtensionConfirmationProps { isClicked: boolean; extensionConfirmationId: string; extensionName: string; + toolName: string; } export default function ExtensionConfirmation({ isCancelledMessage, isClicked, extensionConfirmationId, extensionName, + toolName, }: ExtensionConfirmationProps) { const [clicked, setClicked] = useState(isClicked); const [status, setStatus] = useState('unknown'); + const extensionAction = toolName.toLowerCase().includes('enable') ? 'enable' : 'disable'; + const handleButtonClick = async (confirmed: boolean) => { setClicked(true); setStatus(confirmed ? 'approved' : 'denied'); @@ -38,12 +42,12 @@ export default function ExtensionConfirmation({ return isCancelledMessage ? (
- Extension enablement is cancelled. + Extension {extensionAction} is cancelled.
) : ( <>
- Goose would like to enable the above extension. Allow? + Goose would like to {extensionAction} the following extension. Allow?
{clicked ? (
@@ -75,7 +79,7 @@ export default function ExtensionConfirmation({ {isClicked ? 'Extension enablement is not available' - : `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() : extensionName)} is ${status}`}{' '} + : `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() || extensionName : extensionName)} is ${status}`}{' '}
@@ -87,7 +91,8 @@ export default function ExtensionConfirmation({ } onClick={() => handleButtonClick(true)} > - Enable extension + {extensionAction.charAt(0).toUpperCase() + extensionAction.slice(1).toLowerCase()}{' '} + extension