chore: generalize extension request (#2213)

This commit is contained in:
Wendy Tang
2025-04-16 08:39:49 -07:00
committed by GitHub
parent 970147f8ad
commit 7e4cfcdaae
11 changed files with 84 additions and 69 deletions

View File

@@ -17,6 +17,7 @@ use completion::GooseCompleter;
use etcetera::choose_app_strategy;
use etcetera::AppStrategy;
use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
use goose::message::{Message, MessageContent};
@@ -620,13 +621,19 @@ impl Session {
principal_type: PrincipalType::Tool,
permission,
},).await;
} else if let Some(MessageContent::EnableExtensionRequest(enable_extension_request)) = message.content.first() {
} else if let Some(MessageContent::ExtensionRequest(enable_extension_request)) = message.content.first() {
output::hide_thinking();
let prompt = "Goose would like to install the following extension, do you approve?".to_string();
let extension_action = if enable_extension_request.tool_name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
"enable"
} else {
"disable"
};
let prompt = format!("Goose would like to {} the following extension, do you approve?", extension_action);
let confirmed = cliclack::select(prompt)
.item(true, "Yes, for this session", "Enable the extension for this session")
.item(false, "No", "Do not enable the extension")
.item(true, "Yes, for this session", format!("{} the extension for this session", extension_action))
.item(false, "No", format!("Do not {} the extension", extension_action))
.interact()?;
let permission = if confirmed {
Permission::AllowOnce

View File

@@ -61,9 +61,10 @@ pub struct ToolConfirmationRequest {
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EnableExtensionRequest {
pub struct ExtensionRequest {
pub id: String,
pub extension_name: String,
pub tool_name: String,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
@@ -94,7 +95,7 @@ pub enum MessageContent {
ToolRequest(ToolRequest),
ToolResponse(ToolResponse),
ToolConfirmationRequest(ToolConfirmationRequest),
EnableExtensionRequest(EnableExtensionRequest),
ExtensionRequest(ExtensionRequest),
FrontendToolRequest(FrontendToolRequest),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
@@ -144,10 +145,15 @@ impl MessageContent {
})
}
pub fn enable_extension_request<S: Into<String>>(id: S, extension_name: String) -> Self {
MessageContent::EnableExtensionRequest(EnableExtensionRequest {
pub fn extension_request<S: Into<String>>(
id: S,
extension_name: String,
tool_name: String,
) -> Self {
MessageContent::ExtensionRequest(ExtensionRequest {
id: id.into(),
extension_name,
tool_name,
})
}
@@ -192,9 +198,9 @@ impl MessageContent {
}
}
pub fn as_enable_extension_request(&self) -> Option<&EnableExtensionRequest> {
if let MessageContent::EnableExtensionRequest(ref enable_extension_request) = self {
Some(enable_extension_request)
pub fn as_extension_request(&self) -> Option<&ExtensionRequest> {
if let MessageContent::ExtensionRequest(ref extension_request) = self {
Some(extension_request)
} else {
None
}
@@ -359,12 +365,17 @@ impl Message {
))
}
pub fn with_enable_extension_request<S: Into<String>>(
pub fn with_extension_request<S: Into<String>>(
self,
id: S,
extension_name: String,
tool_name: String,
) -> Self {
self.with_content(MessageContent::enable_extension_request(id, extension_name))
self.with_content(MessageContent::extension_request(
id,
extension_name,
tool_name,
))
}
pub fn with_frontend_tool_request<S: Into<String>>(

View File

@@ -160,7 +160,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
(
PrincipalType::Extension,
Message::user().with_enable_extension_request(
Message::user().with_extension_request(
request_id,
tool_call
.arguments
@@ -168,6 +168,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
tool_call.name.clone(),
),
)
} else {

View File

@@ -60,8 +60,8 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
// Skip enable extension requests
MessageContent::ExtensionRequest(_extension_request) => {
// Skip extension requests
}
MessageContent::Thinking(thinking) => {
content.push(json!({

View File

@@ -31,7 +31,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
MessageContent::ExtensionRequest(_extension_request) => {
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::Image(_) => {

View File

@@ -179,7 +179,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_) => {
MessageContent::ExtensionRequest(_) => {
// Skip enable extension requests
}
MessageContent::Image(image) => {

View File

@@ -147,7 +147,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_) => {
MessageContent::ExtensionRequest(_) => {
// Skip enable extension requests
}
MessageContent::Image(image) => {

View File

@@ -27,7 +27,7 @@ import {
ToolRequestMessageContent,
ToolResponseMessageContent,
ToolConfirmationRequestMessageContent,
EnableExtensionRequestMessageContent,
ExtensionRequestMessageContent,
} from '../types/message';
export interface ChatType {
@@ -47,7 +47,7 @@ const isUserMessage = (message: Message): boolean => {
if (message.content.every((c) => c.type === 'toolConfirmationRequest')) {
return false;
}
if (message.content.every((c) => c.type === 'enableExtensionRequest')) {
if (message.content.every((c) => c.type === 'extensionRequest')) {
return false;
}
return true;
@@ -258,13 +258,13 @@ export default function ChatView({
return [content.id, toolCall];
}
});
const enableExtensionRequests = lastMessage.content
const extensionRequests = lastMessage.content
.filter(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
(content): content is ExtensionRequestMessageContent =>
content.type === 'extensionRequest'
)
.map((content) => {
return [content.id, content.extensionName];
return [content.id, content.extensionCall];
});
if (toolRequests.length !== 0) {
@@ -298,7 +298,7 @@ export default function ChatView({
// do the same for enable extension requests
// leverages toolResponse to send the error notification
if (enableExtensionRequests.length !== 0) {
if (extensionRequests.length !== 0) {
let responseMessage: Message = {
role: 'user',
created: Date.now(),
@@ -306,7 +306,7 @@ export default function ChatView({
};
const notification = 'Interrupted by the user to make a correction';
// generate a response saying it was interrupted for each extension request
for (const [reqId, _] of enableExtensionRequests) {
for (const [reqId, _] of extensionRequests) {
const toolResponse: ToolResponseMessageContent = {
type: 'toolResponse',
id: reqId,
@@ -336,9 +336,9 @@ export default function ChatView({
(c) => c.type === 'toolConfirmationRequest'
);
const hasEnableExtension = message.content.every((c) => c.type === 'enableExtensionRequest');
const hasExtensionRequest = message.content.every((c) => c.type === 'extensionRequest');
// Keep the message if it has text content or tool confirmation or is not just tool responses
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasEnableExtension;
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasExtensionRequest;
}
return true;

View File

@@ -7,16 +7,20 @@ interface ExtensionConfirmationProps {
isClicked: boolean;
extensionConfirmationId: string;
extensionName: string;
toolName: string;
}
export default function ExtensionConfirmation({
isCancelledMessage,
isClicked,
extensionConfirmationId,
extensionName,
toolName,
}: ExtensionConfirmationProps) {
const [clicked, setClicked] = useState(isClicked);
const [status, setStatus] = useState('unknown');
const extensionAction = toolName.toLowerCase().includes('enable') ? 'enable' : 'disable';
const handleButtonClick = async (confirmed: boolean) => {
setClicked(true);
setStatus(confirmed ? 'approved' : 'denied');
@@ -38,12 +42,12 @@ export default function ExtensionConfirmation({
return isCancelledMessage ? (
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 text-textStandard">
Extension enablement is cancelled.
Extension {extensionAction} is cancelled.
</div>
) : (
<>
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 rounded-b-none text-textStandard">
Goose would like to enable the above extension. Allow?
Goose would like to {extensionAction} the following extension. Allow?
</div>
{clicked ? (
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 flex gap-4 mt-1">
@@ -75,7 +79,7 @@ export default function ExtensionConfirmation({
<span className="ml-2 text-textStandard">
{isClicked
? 'Extension enablement is not available'
: `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() : extensionName)} is ${status}`}{' '}
: `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() || extensionName : extensionName)} is ${status}`}{' '}
</span>
</div>
</div>
@@ -87,7 +91,8 @@ export default function ExtensionConfirmation({
}
onClick={() => handleButtonClick(true)}
>
Enable extension
{extensionAction.charAt(0).toUpperCase() + extensionAction.slice(1).toLowerCase()}{' '}
extension
</button>
<button
className={

View File

@@ -11,7 +11,7 @@ import {
getToolResponses,
getToolConfirmationContent,
createToolErrorResponseMessage,
getEnableExtensionContent,
getExtensionContent,
} from '../types/message';
import ToolCallConfirmation from './ToolCallConfirmation';
import MessageCopyLink from './MessageCopyLink';
@@ -54,8 +54,8 @@ export default function GooseMessage({
const toolConfirmationContent = getToolConfirmationContent(message);
const hasToolConfirmation = toolConfirmationContent !== undefined;
const enableExtensionContent = getEnableExtensionContent(message);
const hasEnableExtension = enableExtensionContent !== undefined;
const extensionContent = getExtensionContent(message);
const hasExtensionRequest = extensionContent !== undefined;
// Find tool responses that correspond to the tool requests in this message
const toolResponsesMap = useMemo(() => {
@@ -91,10 +91,10 @@ export default function GooseMessage({
createToolErrorResponseMessage(toolConfirmationContent.id, 'The tool call is cancelled.')
);
}
if (messageIndex == messageHistoryIndex - 1 && hasEnableExtension) {
if (messageIndex == messageHistoryIndex - 1 && hasExtensionRequest) {
appendMessage(
createToolErrorResponseMessage(
enableExtensionContent.id,
extensionContent.id,
'The extension enablement is cancelled.'
)
);
@@ -105,9 +105,9 @@ export default function GooseMessage({
hasToolConfirmation,
toolConfirmationContent,
appendMessage,
hasEnableExtension,
hasExtensionRequest,
// Only include enableExtensionContent if it exists
enableExtensionContent?.id,
extensionContent?.id,
]);
return (
@@ -157,12 +157,13 @@ export default function GooseMessage({
/>
)}
{hasEnableExtension && (
{hasExtensionRequest && (
<ExtensionConfirmation
isCancelledMessage={messageIndex == messageHistoryIndex - 1}
isClicked={messageIndex < messageHistoryIndex - 1}
extensionConfirmationId={enableExtensionContent.id}
extensionName={enableExtensionContent.extensionName}
extensionConfirmationId={extensionContent.id}
extensionName={extensionContent.extensionName}
toolName={extensionContent.toolName}
/>
)}
</div>

View File

@@ -68,35 +68,36 @@ export interface ToolConfirmationRequestMessageContent {
prompt?: string;
}
export interface EnableExtensionCall {
export interface ExtensionCall {
name: string;
arguments: Record<string, unknown>;
extensionName: string;
}
export interface EnableExtensionCallResult<T> {
export interface ExtensionCallResult<T> {
status: 'success' | 'error';
value?: T;
error?: string;
}
export interface EnableExtensionRequest {
export interface ExtensionRequest {
id: string;
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
extensionCall: ExtensionCallResult<ExtensionCall>;
}
export interface EnableExtensionConfirmationRequest {
export interface ExtensionConfirmationRequest {
id: string;
extensionName: string;
arguments: Record<string, unknown>;
prompt?: string;
}
export interface EnableExtensionRequestMessageContent {
type: 'enableExtensionRequest';
export interface ExtensionRequestMessageContent {
type: 'extensionRequest';
id: string;
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
extensionCall: ExtensionCallResult<ExtensionCall>;
extensionName: string;
toolName: string;
}
export type MessageContent =
@@ -105,7 +106,7 @@ export type MessageContent =
| ToolRequestMessageContent
| ToolResponseMessageContent
| ToolConfirmationRequestMessageContent
| EnableExtensionRequestMessageContent;
| ExtensionRequestMessageContent;
export interface Message {
id?: string;
@@ -219,12 +220,11 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[]
);
}
export function getEnableExtensionRequests(
export function getExtensionRequests(
message: Message
): EnableExtensionRequestMessageContent[] {
): ExtensionRequestMessageContent[] {
return message.content.filter(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
(content): content is ExtensionRequestMessageContent => content.type === 'extensionRequest'
);
}
@@ -237,10 +237,10 @@ export function getToolConfirmationContent(
);
}
export function getEnableExtensionContent(message: Message): EnableExtensionRequestMessageContent {
export function getExtensionContent(message: Message): ExtensionRequestMessageContent {
return message.content.find(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
(content): content is ExtensionRequestMessageContent =>
content.type === 'extensionRequest'
);
}
@@ -253,13 +253,3 @@ export function hasCompletedToolCalls(message: Message): boolean {
// by looking through subsequent messages
return true;
}
export function hasCompletedEnableExtensionCalls(message: Message): boolean {
const extensionRequests = getEnableExtensionRequests(message);
if (extensionRequests.length === 0) return false;
// For now, we'll assume all extension calls are completed when this is checked
// In a real implementation, you'd need to check if all extension requests have responses
// by looking through subsequent messages
return true;
}