feat(goose): support customizing extension timeout (#1428)

This commit is contained in:
Ariel
2025-03-01 04:17:53 +08:00
committed by GitHub
parent e8212c4005
commit fbc6bb7b90
10 changed files with 138 additions and 17 deletions

View File

@@ -38,6 +38,7 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
enabled: true, enabled: true,
config: ExtensionConfig::Builtin { config: ExtensionConfig::Builtin {
name: "developer".to_string(), name: "developer".to_string(),
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
}, },
})?; })?;
} }
@@ -437,10 +438,19 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
.interact()? .interact()?
.to_string(); .to_string();
let timeout: u64 = cliclack::input("Please set the timeout for this tool (in secs):")
.placeholder(&goose::config::DEFAULT_EXTENSION_TIMEOUT.to_string())
.validate(|input: &String| match input.parse::<u64>() {
Ok(_) => Ok(()),
Err(_) => Err("Please enter a valide timeout"),
})
.interact()?;
ExtensionManager::set(ExtensionEntry { ExtensionManager::set(ExtensionEntry {
enabled: true, enabled: true,
config: ExtensionConfig::Builtin { config: ExtensionConfig::Builtin {
name: extension.clone(), name: extension.clone(),
timeout: Some(timeout),
}, },
})?; })?;
@@ -472,6 +482,14 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}) })
.interact()?; .interact()?;
let timeout: u64 = cliclack::input("Please set the timeout for this tool (in secs):")
.placeholder(&goose::config::DEFAULT_EXTENSION_TIMEOUT.to_string())
.validate(|input: &String| match input.parse::<u64>() {
Ok(_) => Ok(()),
Err(_) => Err("Please enter a valide timeout"),
})
.interact()?;
// Split the command string into command and args // Split the command string into command and args
let mut parts = command_str.split_whitespace(); let mut parts = command_str.split_whitespace();
let cmd = parts.next().unwrap_or("").to_string(); let cmd = parts.next().unwrap_or("").to_string();
@@ -506,6 +524,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
cmd, cmd,
args, args,
envs: Envs::new(envs), envs: Envs::new(envs),
timeout: Some(timeout),
}, },
})?; })?;
@@ -539,6 +558,14 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}) })
.interact()?; .interact()?;
let timeout: u64 = cliclack::input("Please set the timeout for this tool (in secs):")
.placeholder(&goose::config::DEFAULT_EXTENSION_TIMEOUT.to_string())
.validate(|input: &String| match input.parse::<u64>() {
Ok(_) => Ok(()),
Err(_) => Err("Please enter a valide timeout"),
})
.interact()?;
let add_env = let add_env =
cliclack::confirm("Would you like to add environment variables?").interact()?; cliclack::confirm("Would you like to add environment variables?").interact()?;
@@ -567,6 +594,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
name: name.clone(), name: name.clone(),
uri, uri,
envs: Envs::new(envs), envs: Envs::new(envs),
timeout: Some(timeout),
}, },
})?; })?;

View File

@@ -107,6 +107,8 @@ impl Session {
cmd, cmd,
args: parts.iter().map(|s| s.to_string()).collect(), args: parts.iter().map(|s| s.to_string()).collect(),
envs: Envs::new(envs), envs: Envs::new(envs),
// TODO: should set timeout
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
}; };
self.agent self.agent
@@ -128,6 +130,8 @@ impl Session {
for name in builtin_name.split(',') { for name in builtin_name.split(',') {
let config = ExtensionConfig::Builtin { let config = ExtensionConfig::Builtin {
name: name.trim().to_string(), name: name.trim().to_string(),
// TODO: should set a timeout
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
}; };
self.agent self.agent
.add_extension(config) .add_extension(config)

View File

@@ -23,6 +23,7 @@ enum ExtensionConfigRequest {
/// List of environment variable keys. The server will fetch their values from the keyring. /// List of environment variable keys. The server will fetch their values from the keyring.
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>,
}, },
/// Standard I/O (stdio) extension. /// Standard I/O (stdio) extension.
#[serde(rename = "stdio")] #[serde(rename = "stdio")]
@@ -37,12 +38,14 @@ enum ExtensionConfigRequest {
/// List of environment variable keys. The server will fetch their values from the keyring. /// List of environment variable keys. The server will fetch their values from the keyring.
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>,
}, },
/// Built-in extension that is part of the goose binary. /// Built-in extension that is part of the goose binary.
#[serde(rename = "builtin")] #[serde(rename = "builtin")]
Builtin { Builtin {
/// The name of the built-in extension. /// The name of the built-in extension.
name: String, name: String,
timeout: Option<u64>,
}, },
} }
@@ -84,6 +87,7 @@ async fn add_extension(
name, name,
uri, uri,
env_keys, env_keys,
timeout,
} => { } => {
let mut env_map = HashMap::new(); let mut env_map = HashMap::new();
for key in env_keys { for key in env_keys {
@@ -111,6 +115,7 @@ async fn add_extension(
name, name,
uri, uri,
envs: Envs::new(env_map), envs: Envs::new(env_map),
timeout,
} }
} }
ExtensionConfigRequest::Stdio { ExtensionConfigRequest::Stdio {
@@ -118,6 +123,7 @@ async fn add_extension(
cmd, cmd,
args, args,
env_keys, env_keys,
timeout,
} => { } => {
let mut env_map = HashMap::new(); let mut env_map = HashMap::new();
for key in env_keys { for key in env_keys {
@@ -146,9 +152,12 @@ async fn add_extension(
cmd, cmd,
args, args,
envs: Envs::new(env_map), envs: Envs::new(env_map),
timeout,
} }
} }
ExtensionConfigRequest::Builtin { name } => ExtensionConfig::Builtin { name }, ExtensionConfigRequest::Builtin { name, timeout } => {
ExtensionConfig::Builtin { name, timeout }
}
}; };
// Acquire a lock on the agent and attempt to add the extension. // Acquire a lock on the agent and attempt to add the extension.

View File

@@ -1,6 +1,7 @@
use dotenv::dotenv; use dotenv::dotenv;
use futures::StreamExt; use futures::StreamExt;
use goose::agents::{AgentFactory, ExtensionConfig}; use goose::agents::{AgentFactory, ExtensionConfig};
use goose::config::DEFAULT_EXTENSION_TIMEOUT;
use goose::message::Message; use goose::message::Message;
use goose::providers::databricks::DatabricksProvider; use goose::providers::databricks::DatabricksProvider;
@@ -14,7 +15,11 @@ async fn main() {
// Setup an agent with the developer extension // Setup an agent with the developer extension
let mut agent = AgentFactory::create("reference", provider).expect("default should exist"); let mut agent = AgentFactory::create("reference", provider).expect("default should exist");
let config = ExtensionConfig::stdio("developer", "./target/debug/developer"); let config = ExtensionConfig::stdio(
"developer",
"./target/debug/developer",
DEFAULT_EXTENSION_TIMEOUT,
);
agent.add_extension(config).await.unwrap(); agent.add_extension(config).await.unwrap();
println!("Extensions:"); println!("Extensions:");

View File

@@ -105,21 +105,37 @@ impl Capabilities {
// TODO IMPORTANT need to ensure this times out if the extension command is broken! // TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let mut client: Box<dyn McpClientTrait> = match &config { let mut client: Box<dyn McpClientTrait> = match &config {
ExtensionConfig::Sse { uri, envs, .. } => { ExtensionConfig::Sse {
uri, envs, timeout, ..
} => {
let transport = SseTransport::new(uri, envs.get_env()); let transport = SseTransport::new(uri, envs.get_env());
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(300)); let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service)) Box::new(McpClient::new(service))
} }
ExtensionConfig::Stdio { ExtensionConfig::Stdio {
cmd, args, envs, .. cmd,
args,
envs,
timeout,
..
} => { } => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env()); let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(300)); let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service)) Box::new(McpClient::new(service))
} }
ExtensionConfig::Builtin { name } => { ExtensionConfig::Builtin { name, timeout } => {
// For builtin extensions, we run the current executable with mcp and extension name // For builtin extensions, we run the current executable with mcp and extension name
let cmd = std::env::current_exe() let cmd = std::env::current_exe()
.expect("should find the current executable") .expect("should find the current executable")
@@ -132,7 +148,12 @@ impl Capabilities {
HashMap::new(), HashMap::new(),
); );
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(300)); let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service)) Box::new(McpClient::new(service))
} }
}; };

View File

@@ -4,6 +4,8 @@ use mcp_client::client::Error as ClientError;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use crate::config;
/// Errors from Extension operation /// Errors from Extension operation
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ExtensionError { pub enum ExtensionError {
@@ -52,6 +54,9 @@ pub enum ExtensionConfig {
uri: String, uri: String,
#[serde(default)] #[serde(default)]
envs: Envs, envs: Envs,
// NOTE: set timeout to be optional for compatibility.
// However, new configurations should include this field.
timeout: Option<u64>,
}, },
/// Standard I/O client with command and arguments /// Standard I/O client with command and arguments
#[serde(rename = "stdio")] #[serde(rename = "stdio")]
@@ -62,38 +67,43 @@ pub enum ExtensionConfig {
args: Vec<String>, args: Vec<String>,
#[serde(default)] #[serde(default)]
envs: Envs, envs: Envs,
timeout: Option<u64>,
}, },
/// Built-in extension that is part of the goose binary /// Built-in extension that is part of the goose binary
#[serde(rename = "builtin")] #[serde(rename = "builtin")]
Builtin { Builtin {
/// The name used to identify this extension /// The name used to identify this extension
name: String, name: String,
timeout: Option<u64>,
}, },
} }
impl Default for ExtensionConfig { impl Default for ExtensionConfig {
fn default() -> Self { fn default() -> Self {
Self::Builtin { Self::Builtin {
name: String::from("default"), name: config::DEFAULT_EXTENSION.to_string(),
timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT),
} }
} }
} }
impl ExtensionConfig { impl ExtensionConfig {
pub fn sse<S: Into<String>>(name: S, uri: S) -> Self { pub fn sse<S: Into<String>, T: Into<u64>>(name: S, uri: S, timeout: T) -> Self {
Self::Sse { Self::Sse {
name: name.into(), name: name.into(),
uri: uri.into(), uri: uri.into(),
envs: Envs::default(), envs: Envs::default(),
timeout: Some(timeout.into()),
} }
} }
pub fn stdio<S: Into<String>>(name: S, cmd: S) -> Self { pub fn stdio<S: Into<String>, T: Into<u64>>(name: S, cmd: S, timeout: T) -> Self {
Self::Stdio { Self::Stdio {
name: name.into(), name: name.into(),
cmd: cmd.into(), cmd: cmd.into(),
args: vec![], args: vec![],
envs: Envs::default(), envs: Envs::default(),
timeout: Some(timeout.into()),
} }
} }
@@ -104,12 +114,17 @@ impl ExtensionConfig {
{ {
match self { match self {
Self::Stdio { Self::Stdio {
name, cmd, envs, .. name,
cmd,
envs,
timeout,
..
} => Self::Stdio { } => Self::Stdio {
name, name,
cmd, cmd,
envs, envs,
args: args.into_iter().map(Into::into).collect(), args: args.into_iter().map(Into::into).collect(),
timeout,
}, },
other => other, other => other,
} }
@@ -120,7 +135,7 @@ impl ExtensionConfig {
match self { match self {
Self::Sse { name, .. } => name, Self::Sse { name, .. } => name,
Self::Stdio { name, .. } => name, Self::Stdio { name, .. } => name,
Self::Builtin { name } => name, Self::Builtin { name, .. } => name,
} }
} }
} }
@@ -134,7 +149,7 @@ impl std::fmt::Display for ExtensionConfig {
} => { } => {
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" ")) write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
} }
ExtensionConfig::Builtin { name } => write!(f, "Builtin({})", name), ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
} }
} }
} }

View File

@@ -4,7 +4,8 @@ use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
const DEFAULT_EXTENSION: &str = "developer"; pub const DEFAULT_EXTENSION: &str = "developer";
pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300;
#[derive(Debug, Deserialize, Serialize, Clone)] #[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ExtensionEntry { pub struct ExtensionEntry {
@@ -32,6 +33,7 @@ impl ExtensionManager {
enabled: true, enabled: true,
config: ExtensionConfig::Builtin { config: ExtensionConfig::Builtin {
name: DEFAULT_EXTENSION.to_string(), name: DEFAULT_EXTENSION.to_string(),
timeout: Some(DEFAULT_EXTENSION_TIMEOUT),
}, },
}, },
)]); )]);

View File

@@ -6,3 +6,6 @@ pub use crate::agents::ExtensionConfig;
pub use base::{Config, ConfigError, APP_STRATEGY}; pub use base::{Config, ConfigError, APP_STRATEGY};
pub use experiments::ExperimentManager; pub use experiments::ExperimentManager;
pub use extensions::{ExtensionEntry, ExtensionManager}; pub use extensions::{ExtensionEntry, ExtensionManager};
pub use extensions::DEFAULT_EXTENSION;
pub use extensions::DEFAULT_EXTENSION_TIMEOUT;

View File

@@ -2,7 +2,7 @@ import React, { useState } from 'react';
import { Card } from '../../ui/card'; import { Card } from '../../ui/card';
import { Button } from '../../ui/button'; import { Button } from '../../ui/button';
import { Input } from '../../ui/input'; import { Input } from '../../ui/input';
import { FullExtensionConfig } from '../../../extensions'; import { FullExtensionConfig, DEFAULT_EXTENSION_TIMEOUT } from '../../../extensions';
import { toast } from 'react-toastify'; import { toast } from 'react-toastify';
import Select from 'react-select'; import Select from 'react-select';
import { createDarkSelectStyles, darkSelectTheme } from '../../ui/select-styles'; import { createDarkSelectStyles, darkSelectTheme } from '../../ui/select-styles';
@@ -22,6 +22,7 @@ export function ManualExtensionModal({ isOpen, onClose, onSubmit }: ManualExtens
enabled: true, enabled: true,
args: [], args: [],
commandInput: '', commandInput: '',
timeout: DEFAULT_EXTENSION_TIMEOUT,
}); });
const [envKey, setEnvKey] = useState(''); const [envKey, setEnvKey] = useState('');
const [envValue, setEnvValue] = useState(''); const [envValue, setEnvValue] = useState('');
@@ -267,8 +268,20 @@ export function ManualExtensionModal({ isOpen, onClose, onSubmit }: ManualExtens
</div> </div>
)} )}
</div> </div>
</div>
<div>
<label className="block text-sm font-medium text-textStandard mb-2">
Timeout (secs)*
</label>
<Input
type="number"
value={formData.timeout || DEFAULT_EXTENSION_TIMEOUT}
onChange={(e) => setFormData({ ...formData, timeout: parseInt(e.target.value) })}
className="w-full"
required
/>
</div>
</div>
<div className="mt-[8px] -ml-8 -mr-8 pt-8"> <div className="mt-[8px] -ml-8 -mr-8 pt-8">
<Button <Button
type="submit" type="submit"

View File

@@ -3,13 +3,17 @@ import { type View } from './App';
import { type SettingsViewOptions } from './components/settings/SettingsView'; import { type SettingsViewOptions } from './components/settings/SettingsView';
import { toast } from 'react-toastify'; import { toast } from 'react-toastify';
export const DEFAULT_EXTENSION_TIMEOUT: number = 300;
// ExtensionConfig type matching the Rust version // ExtensionConfig type matching the Rust version
// TODO: refactor this
export type ExtensionConfig = export type ExtensionConfig =
| { | {
type: 'sse'; type: 'sse';
name: string; name: string;
uri: string; uri: string;
env_keys?: string[]; env_keys?: string[];
timeout?: number;
} }
| { | {
type: 'stdio'; type: 'stdio';
@@ -17,11 +21,13 @@ export type ExtensionConfig =
cmd: string; cmd: string;
args: string[]; args: string[];
env_keys?: string[]; env_keys?: string[];
timeout?: number;
} }
| { | {
type: 'builtin'; type: 'builtin';
name: string; name: string;
env_keys?: string[]; env_keys?: string[];
timeout?: number;
}; };
// FullExtensionConfig type matching all the fields that come in deep links and are stored in local storage // FullExtensionConfig type matching all the fields that come in deep links and are stored in local storage
@@ -38,6 +44,7 @@ export interface ExtensionPayload {
args?: string[]; args?: string[];
uri?: string; uri?: string;
env_keys?: string[]; env_keys?: string[];
timeout?: number;
} }
export const BUILT_IN_EXTENSIONS = [ export const BUILT_IN_EXTENSIONS = [
@@ -48,6 +55,7 @@ export const BUILT_IN_EXTENSIONS = [
enabled: true, enabled: true,
type: 'builtin', type: 'builtin',
env_keys: [], env_keys: [],
timeout: DEFAULT_EXTENSION_TIMEOUT,
}, },
{ {
id: 'computercontroller', id: 'computercontroller',
@@ -57,6 +65,7 @@ export const BUILT_IN_EXTENSIONS = [
enabled: false, enabled: false,
type: 'builtin', type: 'builtin',
env_keys: [], env_keys: [],
timeout: DEFAULT_EXTENSION_TIMEOUT,
}, },
{ {
id: 'memory', id: 'memory',
@@ -65,6 +74,7 @@ export const BUILT_IN_EXTENSIONS = [
enabled: false, enabled: false,
type: 'builtin', type: 'builtin',
env_keys: [], env_keys: [],
timeout: DEFAULT_EXTENSION_TIMEOUT,
}, },
{ {
id: 'jetbrains', id: 'jetbrains',
@@ -73,6 +83,7 @@ export const BUILT_IN_EXTENSIONS = [
enabled: false, enabled: false,
type: 'builtin', type: 'builtin',
env_keys: [], env_keys: [],
timeout: DEFAULT_EXTENSION_TIMEOUT,
}, },
{ {
id: 'tutorial', id: 'tutorial',
@@ -93,6 +104,7 @@ export const BUILT_IN_EXTENSIONS = [
'GOOGLE_DRIVE_CREDENTIALS_PATH', 'GOOGLE_DRIVE_CREDENTIALS_PATH',
'GOOGLE_DRIVE_OAUTH_CONFIG', 'GOOGLE_DRIVE_OAUTH_CONFIG',
], ],
timeout: DEFAULT_EXTENSION_TIMEOUT,
},*/ },*/
]; ];
@@ -121,6 +133,7 @@ export async function addExtension(
name: sanitizeName(extension.name), name: sanitizeName(extension.name),
}), }),
env_keys: extension.env_keys, env_keys: extension.env_keys,
timeout: extension.timeout,
}; };
const response = await fetch(getApiUrl('/extensions/add'), { const response = await fetch(getApiUrl('/extensions/add'), {
@@ -327,6 +340,7 @@ export async function addExtensionFromDeepLink(
const id = parsedUrl.searchParams.get('id'); const id = parsedUrl.searchParams.get('id');
const name = parsedUrl.searchParams.get('name'); const name = parsedUrl.searchParams.get('name');
const description = parsedUrl.searchParams.get('description'); const description = parsedUrl.searchParams.get('description');
const timeout = parsedUrl.searchParams.get('timeout');
// split env based on delimiter to a map // split env based on delimiter to a map
const envs = envList.reduce( const envs = envList.reduce(
@@ -339,6 +353,9 @@ export async function addExtensionFromDeepLink(
); );
// Create a ExtensionConfig from the URL parameters // Create a ExtensionConfig from the URL parameters
// Parse timeout if provided, otherwise use default
const parsedTimeout = timeout ? parseInt(timeout, 10) : null;
const config: FullExtensionConfig = { const config: FullExtensionConfig = {
id, id,
name, name,
@@ -348,6 +365,10 @@ export async function addExtensionFromDeepLink(
description, description,
enabled: true, enabled: true,
env_keys: Object.keys(envs).length > 0 ? Object.keys(envs) : [], env_keys: Object.keys(envs).length > 0 ? Object.keys(envs) : [],
timeout:
parsedTimeout !== null && !isNaN(parsedTimeout) && Number.isInteger(parsedTimeout)
? parsedTimeout
: DEFAULT_EXTENSION_TIMEOUT,
}; };
// Store the extension config regardless of env vars status // Store the extension config regardless of env vars status