feat: enable extension from ui (#2117)

This commit is contained in:
Wendy Tang
2025-04-11 13:06:59 -07:00
committed by GitHub
parent 387432ec86
commit b030f845ce
9 changed files with 254 additions and 13 deletions

View File

@@ -365,16 +365,23 @@ async fn ask_handler(
}))
}
#[derive(Debug, Deserialize)]
struct ToolConfirmationRequest {
#[derive(Debug, Deserialize, Serialize)]
struct PermissionConfirmationRequest {
id: String,
confirmed: bool,
#[serde(default = "default_principal_type")]
principal_type: PrincipalType,
action: String,
}
fn default_principal_type() -> PrincipalType {
PrincipalType::Tool
}
async fn confirm_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<ToolConfirmationRequest>,
Json(request): Json<PermissionConfirmationRequest>,
) -> Result<Json<Value>, StatusCode> {
// Verify secret key
let secret_key = headers
@@ -385,6 +392,10 @@ async fn confirm_handler(
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
tracing::info!(
"Received confirmation request: {}",
serde_json::to_string_pretty(&request).unwrap()
);
let agent = state.agent.clone();
let agent = agent.read().await;
@@ -401,7 +412,7 @@ async fn confirm_handler(
.handle_confirmation(
request.id.clone(),
PermissionConfirmation {
principal_type: PrincipalType::Tool,
principal_type: request.principal_type,
permission,
},
)

View File

@@ -461,12 +461,12 @@ impl Agent {
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, extension_confirmation)) = rx.recv().await {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if req_id == request.id {
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
install_results.push(install_result);
} else {

View File

@@ -9,7 +9,7 @@ These models have varying knowledge cut-off dates depending on when they were tr
Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools.
You are capable of dynamically plugging into new extensions and learning how to use them. You solve higher level problems using the tools in these extensions, and can interact with multiple at once.
Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extension tool. You should only enable extensions found from the search_available_extensions tool.
Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extension tool and provide the extension_name. You should only enable extensions found from the search_available_extensions tool.
{% if (extensions is defined) and extensions %}
Because you dynamically load extensions, your conversation history may refer

View File

@@ -25,6 +25,7 @@ import {
ToolResponseMessageContent,
ToolConfirmationRequestMessageContent,
getTextContent,
EnableExtensionRequestMessageContent,
} from '../types/message';
export interface ChatType {
@@ -49,10 +50,12 @@ const isUserMessage = (message: Message): boolean => {
if (message.role === 'assistant') {
return false;
}
if (message.content.every((c) => c.type === 'toolConfirmationRequest')) {
return false;
}
if (message.content.every((c) => c.type === 'enableExtensionRequest')) {
return false;
}
return true;
};
@@ -312,6 +315,14 @@ export default function ChatView({
return [content.id, toolCall];
}
});
const enableExtensionRequests = lastMessage.content
.filter(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
)
.map((content) => {
return [content.id, content.extensionName];
});
if (toolRequests.length !== 0) {
// This means we were interrupted during a tool request
@@ -338,10 +349,33 @@ export default function ChatView({
responseMessage.content.push(toolResponse);
}
// Use an immutable update to add the response message to the messages array
setMessages([...messages, responseMessage]);
}
// do the same for enable extension requests
// leverages toolResponse to send the error notification
if (enableExtensionRequests.length !== 0) {
let responseMessage: Message = {
role: 'user',
created: Date.now(),
content: [],
};
const notification = 'Interrupted by the user to make a correction';
// generate a response saying it was interrupted for each extension request
for (const [reqId, _] of enableExtensionRequests) {
const toolResponse: ToolResponseMessageContent = {
type: 'toolResponse',
id: reqId,
toolResult: {
status: 'error',
error: notification,
},
};
responseMessage.content.push(toolResponse);
}
setMessages([...messages, responseMessage]);
}
}
};
@@ -359,8 +393,9 @@ export default function ChatView({
(c) => c.type === 'toolConfirmationRequest'
);
const hasEnableExtension = message.content.every((c) => c.type === 'enableExtensionRequest');
// Keep the message if it has text content or tool confirmation or is not just tool responses
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation;
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasEnableExtension;
}
return true;

View File

@@ -0,0 +1,85 @@
import React, { useState } from 'react';
import { ConfirmExtensionRequest } from '../utils/extensionConfirm';
import { snakeToTitleCase } from '../utils';
export default function ExtensionConfirmation({
isCancelledMessage,
isClicked,
extensionConfirmationId,
extensionName,
}) {
const [clicked, setClicked] = useState(isClicked);
const [status, setStatus] = useState('unknown');
const handleButtonClick = (confirmed) => {
setClicked(true);
setStatus(confirmed ? 'approved' : 'denied');
ConfirmExtensionRequest(extensionConfirmationId, confirmed);
};
return isCancelledMessage ? (
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 text-textStandard">
Extension enablement is cancelled.
</div>
) : (
<>
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 rounded-b-none text-textStandard">
Goose would like to enable the above extension. Allow?
</div>
{clicked ? (
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 flex gap-4 mt-1">
<div className="flex items-center">
{status === 'approved' && (
<svg
className="w-5 h-5 text-gray-500"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
strokeWidth={2}
>
<path strokeLinecap="round" strokeLinejoin="round" d="M5 13l4 4L19 7" />
</svg>
)}
{status === 'denied' && (
<svg
className="w-5 h-5 text-gray-500"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
strokeWidth={2}
>
<path strokeLinecap="round" strokeLinejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
)}
<span className="ml-2 text-textStandard">
{isClicked
? 'Extension enablement is not available'
: `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() : extensionName)} is ${status}`}{' '}
</span>
</div>
</div>
) : (
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 flex gap-4 mt-1">
<button
className={
'bg-black text-white dark:bg-white dark:text-black rounded-full px-6 py-2 transition'
}
onClick={() => handleButtonClick(true)}
>
Enable extension
</button>
<button
className={
'bg-white text-black dark:bg-black dark:text-white border border-gray-300 dark:border-gray-700 rounded-full px-6 py-2 transition'
}
onClick={() => handleButtonClick(false)}
>
Deny
</button>
</div>
)}
</>
);
}

View File

@@ -11,9 +11,11 @@ import {
getToolResponses,
getToolConfirmationContent,
createToolErrorResponseMessage,
getEnableExtensionContent,
} from '../types/message';
import ToolCallConfirmation from './ToolCallConfirmation';
import MessageCopyLink from './MessageCopyLink';
import ExtensionConfirmation from './ExtensionConfirmation';
interface GooseMessageProps {
messageHistoryIndex: number;
@@ -52,6 +54,9 @@ export default function GooseMessage({
const toolConfirmationContent = getToolConfirmationContent(message);
const hasToolConfirmation = toolConfirmationContent !== undefined;
const enableExtensionContent = getEnableExtensionContent(message);
const hasEnableExtension = enableExtensionContent !== undefined;
// Find tool responses that correspond to the tool requests in this message
const toolResponsesMap = useMemo(() => {
const responseMap = new Map();
@@ -86,12 +91,23 @@ export default function GooseMessage({
createToolErrorResponseMessage(toolConfirmationContent.id, 'The tool call is cancelled.')
);
}
if (messageIndex == messageHistoryIndex - 1 && hasEnableExtension) {
appendMessage(
createToolErrorResponseMessage(
enableExtensionContent.id,
'The extension enablement is cancelled.'
)
);
}
}, [
messageIndex,
messageHistoryIndex,
hasToolConfirmation,
toolConfirmationContent,
appendMessage,
hasEnableExtension,
// Only include enableExtensionContent if it exists
enableExtensionContent?.id,
]);
return (
@@ -140,6 +156,15 @@ export default function GooseMessage({
toolName={toolConfirmationContent.toolName}
/>
)}
{hasEnableExtension && (
<ExtensionConfirmation
isCancelledMessage={messageIndex == messageHistoryIndex - 1}
isClicked={messageIndex < messageHistoryIndex - 1}
extensionConfirmationId={enableExtensionContent.id}
extensionName={enableExtensionContent.extensionName}
/>
)}
</div>
{/* TODO(alexhancock): Re-enable link previews once styled well again */}

View File

@@ -68,12 +68,44 @@ export interface ToolConfirmationRequestMessageContent {
prompt?: string;
}
export interface EnableExtensionCall {
name: string;
arguments: Record<string, unknown>;
extensionName: string;
}
export interface EnableExtensionCallResult<T> {
status: 'success' | 'error';
value?: T;
error?: string;
}
export interface EnableExtensionRequest {
id: string;
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
}
export interface EnableExtensionConfirmationRequest {
id: string;
extensionName: string;
arguments: Record<string, unknown>;
prompt?: string;
}
export interface EnableExtensionRequestMessageContent {
type: 'enableExtensionRequest';
id: string;
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
extensionName: string;
}
export type MessageContent =
| TextContent
| ImageContent
| ToolRequestMessageContent
| ToolResponseMessageContent
| ToolConfirmationRequestMessageContent;
| ToolConfirmationRequestMessageContent
| EnableExtensionRequestMessageContent;
export interface Message {
id?: string;
@@ -187,6 +219,15 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[]
);
}
export function getEnableExtensionRequests(
message: Message
): EnableExtensionRequestMessageContent[] {
return message.content.filter(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
);
}
export function getToolConfirmationContent(
message: Message
): ToolConfirmationRequestMessageContent {
@@ -196,6 +237,13 @@ export function getToolConfirmationContent(
);
}
export function getEnableExtensionContent(message: Message): EnableExtensionRequestMessageContent {
return message.content.find(
(content): content is EnableExtensionRequestMessageContent =>
content.type === 'enableExtensionRequest'
);
}
export function hasCompletedToolCalls(message: Message): boolean {
const toolRequests = getToolRequests(message);
if (toolRequests.length === 0) return false;
@@ -205,3 +253,13 @@ export function hasCompletedToolCalls(message: Message): boolean {
// by looking through subsequent messages
return true;
}
export function hasCompletedEnableExtensionCalls(message: Message): boolean {
const extensionRequests = getEnableExtensionRequests(message);
if (extensionRequests.length === 0) return false;
// For now, we'll assume all extension calls are completed when this is checked
// In a real implementation, you'd need to check if all extension requests have responses
// by looking through subsequent messages
return true;
}

View File

@@ -0,0 +1,26 @@
import { getApiUrl, getSecretKey } from '../config';
export async function ConfirmExtensionRequest(requestId: string, confirmed: boolean) {
try {
const response = await fetch(getApiUrl('/confirm'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
id: requestId,
confirmed,
principal_type: 'Extension',
}),
});
if (!response.ok) {
const errorText = await response.text();
console.error('Delete response error: ', errorText);
throw new Error('Failed to confirm extension enablement');
}
} catch (error) {
console.error('Error confirming extension enablement: ', error);
}
}

View File

@@ -11,6 +11,7 @@ export async function ConfirmToolRequest(requesyId: string, action: string) {
body: JSON.stringify({
id: requesyId,
action,
principal_type: 'Tool',
}),
});