mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-11 17:44:24 +01:00
chore: generalize extension request (#2213)
This commit is contained in:
@@ -17,6 +17,7 @@ use completion::GooseCompleter;
|
||||
use etcetera::choose_app_strategy;
|
||||
use etcetera::AppStrategy;
|
||||
use goose::agents::extension::{Envs, ExtensionConfig};
|
||||
use goose::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
|
||||
use goose::agents::{Agent, SessionConfig};
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent};
|
||||
@@ -620,13 +621,19 @@ impl Session {
|
||||
principal_type: PrincipalType::Tool,
|
||||
permission,
|
||||
},).await;
|
||||
} else if let Some(MessageContent::EnableExtensionRequest(enable_extension_request)) = message.content.first() {
|
||||
} else if let Some(MessageContent::ExtensionRequest(enable_extension_request)) = message.content.first() {
|
||||
output::hide_thinking();
|
||||
|
||||
let prompt = "Goose would like to install the following extension, do you approve?".to_string();
|
||||
let extension_action = if enable_extension_request.tool_name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
|
||||
"enable"
|
||||
} else {
|
||||
"disable"
|
||||
};
|
||||
|
||||
let prompt = format!("Goose would like to {} the following extension, do you approve?", extension_action);
|
||||
let confirmed = cliclack::select(prompt)
|
||||
.item(true, "Yes, for this session", "Enable the extension for this session")
|
||||
.item(false, "No", "Do not enable the extension")
|
||||
.item(true, "Yes, for this session", format!("{} the extension for this session", extension_action))
|
||||
.item(false, "No", format!("Do not {} the extension", extension_action))
|
||||
.interact()?;
|
||||
let permission = if confirmed {
|
||||
Permission::AllowOnce
|
||||
|
||||
@@ -61,9 +61,10 @@ pub struct ToolConfirmationRequest {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EnableExtensionRequest {
|
||||
pub struct ExtensionRequest {
|
||||
pub id: String,
|
||||
pub extension_name: String,
|
||||
pub tool_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
@@ -94,7 +95,7 @@ pub enum MessageContent {
|
||||
ToolRequest(ToolRequest),
|
||||
ToolResponse(ToolResponse),
|
||||
ToolConfirmationRequest(ToolConfirmationRequest),
|
||||
EnableExtensionRequest(EnableExtensionRequest),
|
||||
ExtensionRequest(ExtensionRequest),
|
||||
FrontendToolRequest(FrontendToolRequest),
|
||||
Thinking(ThinkingContent),
|
||||
RedactedThinking(RedactedThinkingContent),
|
||||
@@ -144,10 +145,15 @@ impl MessageContent {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn enable_extension_request<S: Into<String>>(id: S, extension_name: String) -> Self {
|
||||
MessageContent::EnableExtensionRequest(EnableExtensionRequest {
|
||||
pub fn extension_request<S: Into<String>>(
|
||||
id: S,
|
||||
extension_name: String,
|
||||
tool_name: String,
|
||||
) -> Self {
|
||||
MessageContent::ExtensionRequest(ExtensionRequest {
|
||||
id: id.into(),
|
||||
extension_name,
|
||||
tool_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -192,9 +198,9 @@ impl MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_enable_extension_request(&self) -> Option<&EnableExtensionRequest> {
|
||||
if let MessageContent::EnableExtensionRequest(ref enable_extension_request) = self {
|
||||
Some(enable_extension_request)
|
||||
pub fn as_extension_request(&self) -> Option<&ExtensionRequest> {
|
||||
if let MessageContent::ExtensionRequest(ref extension_request) = self {
|
||||
Some(extension_request)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -359,12 +365,17 @@ impl Message {
|
||||
))
|
||||
}
|
||||
|
||||
pub fn with_enable_extension_request<S: Into<String>>(
|
||||
pub fn with_extension_request<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
extension_name: String,
|
||||
tool_name: String,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::enable_extension_request(id, extension_name))
|
||||
self.with_content(MessageContent::extension_request(
|
||||
id,
|
||||
extension_name,
|
||||
tool_name,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn with_frontend_tool_request<S: Into<String>>(
|
||||
|
||||
@@ -160,7 +160,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ
|
||||
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
|
||||
(
|
||||
PrincipalType::Extension,
|
||||
Message::user().with_enable_extension_request(
|
||||
Message::user().with_extension_request(
|
||||
request_id,
|
||||
tool_call
|
||||
.arguments
|
||||
@@ -168,6 +168,7 @@ pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (Princ
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
tool_call.name.clone(),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
|
||||
@@ -60,8 +60,8 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
|
||||
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
|
||||
// Skip tool confirmation requests
|
||||
}
|
||||
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
|
||||
// Skip enable extension requests
|
||||
MessageContent::ExtensionRequest(_extension_request) => {
|
||||
// Skip extension requests
|
||||
}
|
||||
MessageContent::Thinking(thinking) => {
|
||||
content.push(json!({
|
||||
|
||||
@@ -31,7 +31,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
|
||||
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
|
||||
bedrock::ContentBlock::Text("".to_string())
|
||||
}
|
||||
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
|
||||
MessageContent::ExtensionRequest(_extension_request) => {
|
||||
bedrock::ContentBlock::Text("".to_string())
|
||||
}
|
||||
MessageContent::Image(_) => {
|
||||
|
||||
@@ -179,7 +179,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
MessageContent::ToolConfirmationRequest(_) => {
|
||||
// Skip tool confirmation requests
|
||||
}
|
||||
MessageContent::EnableExtensionRequest(_) => {
|
||||
MessageContent::ExtensionRequest(_) => {
|
||||
// Skip enable extension requests
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
|
||||
@@ -147,7 +147,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
MessageContent::ToolConfirmationRequest(_) => {
|
||||
// Skip tool confirmation requests
|
||||
}
|
||||
MessageContent::EnableExtensionRequest(_) => {
|
||||
MessageContent::ExtensionRequest(_) => {
|
||||
// Skip enable extension requests
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
ToolRequestMessageContent,
|
||||
ToolResponseMessageContent,
|
||||
ToolConfirmationRequestMessageContent,
|
||||
EnableExtensionRequestMessageContent,
|
||||
ExtensionRequestMessageContent,
|
||||
} from '../types/message';
|
||||
|
||||
export interface ChatType {
|
||||
@@ -47,7 +47,7 @@ const isUserMessage = (message: Message): boolean => {
|
||||
if (message.content.every((c) => c.type === 'toolConfirmationRequest')) {
|
||||
return false;
|
||||
}
|
||||
if (message.content.every((c) => c.type === 'enableExtensionRequest')) {
|
||||
if (message.content.every((c) => c.type === 'extensionRequest')) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -258,13 +258,13 @@ export default function ChatView({
|
||||
return [content.id, toolCall];
|
||||
}
|
||||
});
|
||||
const enableExtensionRequests = lastMessage.content
|
||||
const extensionRequests = lastMessage.content
|
||||
.filter(
|
||||
(content): content is EnableExtensionRequestMessageContent =>
|
||||
content.type === 'enableExtensionRequest'
|
||||
(content): content is ExtensionRequestMessageContent =>
|
||||
content.type === 'extensionRequest'
|
||||
)
|
||||
.map((content) => {
|
||||
return [content.id, content.extensionName];
|
||||
return [content.id, content.extensionCall];
|
||||
});
|
||||
|
||||
if (toolRequests.length !== 0) {
|
||||
@@ -298,7 +298,7 @@ export default function ChatView({
|
||||
|
||||
// do the same for enable extension requests
|
||||
// leverages toolResponse to send the error notification
|
||||
if (enableExtensionRequests.length !== 0) {
|
||||
if (extensionRequests.length !== 0) {
|
||||
let responseMessage: Message = {
|
||||
role: 'user',
|
||||
created: Date.now(),
|
||||
@@ -306,7 +306,7 @@ export default function ChatView({
|
||||
};
|
||||
const notification = 'Interrupted by the user to make a correction';
|
||||
// generate a response saying it was interrupted for each extension request
|
||||
for (const [reqId, _] of enableExtensionRequests) {
|
||||
for (const [reqId, _] of extensionRequests) {
|
||||
const toolResponse: ToolResponseMessageContent = {
|
||||
type: 'toolResponse',
|
||||
id: reqId,
|
||||
@@ -336,9 +336,9 @@ export default function ChatView({
|
||||
(c) => c.type === 'toolConfirmationRequest'
|
||||
);
|
||||
|
||||
const hasEnableExtension = message.content.every((c) => c.type === 'enableExtensionRequest');
|
||||
const hasExtensionRequest = message.content.every((c) => c.type === 'extensionRequest');
|
||||
// Keep the message if it has text content or tool confirmation or is not just tool responses
|
||||
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasEnableExtension;
|
||||
return hasTextContent || !hasOnlyToolResponses || hasToolConfirmation || hasExtensionRequest;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -7,16 +7,20 @@ interface ExtensionConfirmationProps {
|
||||
isClicked: boolean;
|
||||
extensionConfirmationId: string;
|
||||
extensionName: string;
|
||||
toolName: string;
|
||||
}
|
||||
export default function ExtensionConfirmation({
|
||||
isCancelledMessage,
|
||||
isClicked,
|
||||
extensionConfirmationId,
|
||||
extensionName,
|
||||
toolName,
|
||||
}: ExtensionConfirmationProps) {
|
||||
const [clicked, setClicked] = useState(isClicked);
|
||||
const [status, setStatus] = useState('unknown');
|
||||
|
||||
const extensionAction = toolName.toLowerCase().includes('enable') ? 'enable' : 'disable';
|
||||
|
||||
const handleButtonClick = async (confirmed: boolean) => {
|
||||
setClicked(true);
|
||||
setStatus(confirmed ? 'approved' : 'denied');
|
||||
@@ -38,12 +42,12 @@ export default function ExtensionConfirmation({
|
||||
|
||||
return isCancelledMessage ? (
|
||||
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 text-textStandard">
|
||||
Extension enablement is cancelled.
|
||||
Extension {extensionAction} is cancelled.
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 rounded-b-none text-textStandard">
|
||||
Goose would like to enable the above extension. Allow?
|
||||
Goose would like to {extensionAction} the following extension. Allow?
|
||||
</div>
|
||||
{clicked ? (
|
||||
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 flex gap-4 mt-1">
|
||||
@@ -75,7 +79,7 @@ export default function ExtensionConfirmation({
|
||||
<span className="ml-2 text-textStandard">
|
||||
{isClicked
|
||||
? 'Extension enablement is not available'
|
||||
: `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() : extensionName)} is ${status}`}{' '}
|
||||
: `${snakeToTitleCase(extensionName.includes('__') ? extensionName.split('__').pop() || extensionName : extensionName)} is ${status}`}{' '}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
@@ -87,7 +91,8 @@ export default function ExtensionConfirmation({
|
||||
}
|
||||
onClick={() => handleButtonClick(true)}
|
||||
>
|
||||
Enable extension
|
||||
{extensionAction.charAt(0).toUpperCase() + extensionAction.slice(1).toLowerCase()}{' '}
|
||||
extension
|
||||
</button>
|
||||
<button
|
||||
className={
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
getToolResponses,
|
||||
getToolConfirmationContent,
|
||||
createToolErrorResponseMessage,
|
||||
getEnableExtensionContent,
|
||||
getExtensionContent,
|
||||
} from '../types/message';
|
||||
import ToolCallConfirmation from './ToolCallConfirmation';
|
||||
import MessageCopyLink from './MessageCopyLink';
|
||||
@@ -54,8 +54,8 @@ export default function GooseMessage({
|
||||
const toolConfirmationContent = getToolConfirmationContent(message);
|
||||
const hasToolConfirmation = toolConfirmationContent !== undefined;
|
||||
|
||||
const enableExtensionContent = getEnableExtensionContent(message);
|
||||
const hasEnableExtension = enableExtensionContent !== undefined;
|
||||
const extensionContent = getExtensionContent(message);
|
||||
const hasExtensionRequest = extensionContent !== undefined;
|
||||
|
||||
// Find tool responses that correspond to the tool requests in this message
|
||||
const toolResponsesMap = useMemo(() => {
|
||||
@@ -91,10 +91,10 @@ export default function GooseMessage({
|
||||
createToolErrorResponseMessage(toolConfirmationContent.id, 'The tool call is cancelled.')
|
||||
);
|
||||
}
|
||||
if (messageIndex == messageHistoryIndex - 1 && hasEnableExtension) {
|
||||
if (messageIndex == messageHistoryIndex - 1 && hasExtensionRequest) {
|
||||
appendMessage(
|
||||
createToolErrorResponseMessage(
|
||||
enableExtensionContent.id,
|
||||
extensionContent.id,
|
||||
'The extension enablement is cancelled.'
|
||||
)
|
||||
);
|
||||
@@ -105,9 +105,9 @@ export default function GooseMessage({
|
||||
hasToolConfirmation,
|
||||
toolConfirmationContent,
|
||||
appendMessage,
|
||||
hasEnableExtension,
|
||||
hasExtensionRequest,
|
||||
// Only include enableExtensionContent if it exists
|
||||
enableExtensionContent?.id,
|
||||
extensionContent?.id,
|
||||
]);
|
||||
|
||||
return (
|
||||
@@ -157,12 +157,13 @@ export default function GooseMessage({
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasEnableExtension && (
|
||||
{hasExtensionRequest && (
|
||||
<ExtensionConfirmation
|
||||
isCancelledMessage={messageIndex == messageHistoryIndex - 1}
|
||||
isClicked={messageIndex < messageHistoryIndex - 1}
|
||||
extensionConfirmationId={enableExtensionContent.id}
|
||||
extensionName={enableExtensionContent.extensionName}
|
||||
extensionConfirmationId={extensionContent.id}
|
||||
extensionName={extensionContent.extensionName}
|
||||
toolName={extensionContent.toolName}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -68,35 +68,36 @@ export interface ToolConfirmationRequestMessageContent {
|
||||
prompt?: string;
|
||||
}
|
||||
|
||||
export interface EnableExtensionCall {
|
||||
export interface ExtensionCall {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
extensionName: string;
|
||||
}
|
||||
|
||||
export interface EnableExtensionCallResult<T> {
|
||||
export interface ExtensionCallResult<T> {
|
||||
status: 'success' | 'error';
|
||||
value?: T;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface EnableExtensionRequest {
|
||||
export interface ExtensionRequest {
|
||||
id: string;
|
||||
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
|
||||
extensionCall: ExtensionCallResult<ExtensionCall>;
|
||||
}
|
||||
|
||||
export interface EnableExtensionConfirmationRequest {
|
||||
export interface ExtensionConfirmationRequest {
|
||||
id: string;
|
||||
extensionName: string;
|
||||
arguments: Record<string, unknown>;
|
||||
prompt?: string;
|
||||
}
|
||||
|
||||
export interface EnableExtensionRequestMessageContent {
|
||||
type: 'enableExtensionRequest';
|
||||
export interface ExtensionRequestMessageContent {
|
||||
type: 'extensionRequest';
|
||||
id: string;
|
||||
extensionCall: EnableExtensionCallResult<EnableExtensionCall>;
|
||||
extensionCall: ExtensionCallResult<ExtensionCall>;
|
||||
extensionName: string;
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
export type MessageContent =
|
||||
@@ -105,7 +106,7 @@ export type MessageContent =
|
||||
| ToolRequestMessageContent
|
||||
| ToolResponseMessageContent
|
||||
| ToolConfirmationRequestMessageContent
|
||||
| EnableExtensionRequestMessageContent;
|
||||
| ExtensionRequestMessageContent;
|
||||
|
||||
export interface Message {
|
||||
id?: string;
|
||||
@@ -219,12 +220,11 @@ export function getToolResponses(message: Message): ToolResponseMessageContent[]
|
||||
);
|
||||
}
|
||||
|
||||
export function getEnableExtensionRequests(
|
||||
export function getExtensionRequests(
|
||||
message: Message
|
||||
): EnableExtensionRequestMessageContent[] {
|
||||
): ExtensionRequestMessageContent[] {
|
||||
return message.content.filter(
|
||||
(content): content is EnableExtensionRequestMessageContent =>
|
||||
content.type === 'enableExtensionRequest'
|
||||
(content): content is ExtensionRequestMessageContent => content.type === 'extensionRequest'
|
||||
);
|
||||
}
|
||||
|
||||
@@ -237,10 +237,10 @@ export function getToolConfirmationContent(
|
||||
);
|
||||
}
|
||||
|
||||
export function getEnableExtensionContent(message: Message): EnableExtensionRequestMessageContent {
|
||||
export function getExtensionContent(message: Message): ExtensionRequestMessageContent {
|
||||
return message.content.find(
|
||||
(content): content is EnableExtensionRequestMessageContent =>
|
||||
content.type === 'enableExtensionRequest'
|
||||
(content): content is ExtensionRequestMessageContent =>
|
||||
content.type === 'extensionRequest'
|
||||
);
|
||||
}
|
||||
|
||||
@@ -253,13 +253,3 @@ export function hasCompletedToolCalls(message: Message): boolean {
|
||||
// by looking through subsequent messages
|
||||
return true;
|
||||
}
|
||||
|
||||
export function hasCompletedEnableExtensionCalls(message: Message): boolean {
|
||||
const extensionRequests = getEnableExtensionRequests(message);
|
||||
if (extensionRequests.length === 0) return false;
|
||||
|
||||
// For now, we'll assume all extension calls are completed when this is checked
|
||||
// In a real implementation, you'd need to check if all extension requests have responses
|
||||
// by looking through subsequent messages
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user