mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 06:04:23 +01:00
chore: refactor read-write lock on agent (#2225)
Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
@@ -983,11 +983,11 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("No model configured. Please set model first");
|
||||
let model_config = goose::model::ModelConfig::new(model.clone());
|
||||
let provider =
|
||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
||||
|
||||
// Create the agent
|
||||
let mut agent = Agent::new(provider);
|
||||
let agent = Agent::new();
|
||||
let new_provider = create(&provider_name, model_config)?;
|
||||
agent.update_provider(new_provider).await?;
|
||||
if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) {
|
||||
agent
|
||||
.add_extension(config.clone())
|
||||
|
||||
@@ -2,6 +2,7 @@ use console::style;
|
||||
use goose::agents::extension::ExtensionError;
|
||||
use goose::agents::Agent;
|
||||
use goose::config::{Config, ExtensionConfig, ExtensionConfigManager};
|
||||
use goose::providers::create;
|
||||
use goose::session;
|
||||
use goose::session::Identifier;
|
||||
use mcp_client::transport::Error as McpClientError;
|
||||
@@ -46,11 +47,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("No model configured. Run 'goose configure' first");
|
||||
let model_config = goose::model::ModelConfig::new(model.clone());
|
||||
let provider =
|
||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
||||
|
||||
// Create the agent
|
||||
let mut agent = Agent::new(provider);
|
||||
let agent: Agent = Agent::new();
|
||||
let new_provider = create(&provider_name, model_config).unwrap();
|
||||
let _ = agent.update_provider(new_provider).await;
|
||||
|
||||
// Handle session file resolution and resuming
|
||||
let session_file = if session_config.resume {
|
||||
|
||||
@@ -285,7 +285,7 @@ impl Session {
|
||||
async fn process_message(&mut self, message: String) -> Result<()> {
|
||||
self.messages.push(Message::user().with_text(&message));
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider();
|
||||
let provider = self.agent.provider().await?;
|
||||
|
||||
// Persist messages with provider for automatic description generation
|
||||
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
|
||||
@@ -357,7 +357,7 @@ impl Session {
|
||||
self.messages.push(Message::user().with_text(&content));
|
||||
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider();
|
||||
let provider = self.agent.provider().await?;
|
||||
|
||||
// Persist messages with provider for automatic description generation
|
||||
session::persist_messages(
|
||||
@@ -526,7 +526,7 @@ impl Session {
|
||||
output::render_message(&plan_response, self.debug);
|
||||
output::hide_thinking();
|
||||
let planner_response_type =
|
||||
classify_planner_response(plan_response.as_concat_text(), self.agent.provider())
|
||||
classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?)
|
||||
.await?;
|
||||
|
||||
match planner_response_type {
|
||||
|
||||
@@ -178,7 +178,10 @@ pub unsafe extern "C" fn goose_agent_new(config: *const ProviderConfigFFI) -> Ag
|
||||
// Create Databricks provider with required parameters
|
||||
match DatabricksProvider::from_params(host, api_key, model_config) {
|
||||
Ok(provider) => {
|
||||
let agent = Agent::new(Arc::new(provider));
|
||||
let agent = Agent::new();
|
||||
get_runtime().block_on(async {
|
||||
let _ = agent.update_provider(Arc::new(provider)).await;
|
||||
});
|
||||
Box::into_raw(Box::new(agent))
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::configuration;
|
||||
use crate::state;
|
||||
use anyhow::Result;
|
||||
use goose::agents::Agent;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::info;
|
||||
|
||||
@@ -15,8 +18,10 @@ pub async fn run() -> Result<()> {
|
||||
let secret_key =
|
||||
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
|
||||
|
||||
// Create app state - agent will start as None
|
||||
let state = state::AppState::new(secret_key.clone()).await?;
|
||||
let new_agent = Agent::new();
|
||||
|
||||
// Create app state with agent
|
||||
let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await;
|
||||
|
||||
// Create router with CORS support
|
||||
let cors = CorsLayer::new()
|
||||
|
||||
@@ -8,14 +8,15 @@ use axum::{
|
||||
};
|
||||
use goose::config::Config;
|
||||
use goose::config::PermissionManager;
|
||||
use goose::{agents::Agent, model::ModelConfig, providers};
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::create;
|
||||
use goose::{
|
||||
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
|
||||
config::permission::PermissionLevel,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct VersionsResponse {
|
||||
@@ -33,17 +34,6 @@ struct ExtendPromptResponse {
|
||||
success: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CreateAgentRequest {
|
||||
provider: String,
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CreateAgentResponse {
|
||||
version: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ProviderFile {
|
||||
name: String,
|
||||
@@ -66,6 +56,12 @@ struct ProviderList {
|
||||
details: ProviderDetails,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UpdateProviderRequest {
|
||||
provider: String,
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetToolsQuery {
|
||||
extension_name: Option<String>,
|
||||
@@ -82,53 +78,18 @@ async fn get_versions() -> Json<VersionsResponse> {
|
||||
}
|
||||
|
||||
async fn extend_prompt(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<ExtendPromptRequest>,
|
||||
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let mut agent = state.agent.write().await;
|
||||
if let Some(ref mut agent) = *agent {
|
||||
agent.extend_system_prompt(payload.extension).await;
|
||||
Ok(Json(ExtendPromptResponse { success: true }))
|
||||
} else {
|
||||
Err(StatusCode::NOT_FOUND)
|
||||
}
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
async fn create_agent(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<CreateAgentRequest>,
|
||||
) -> Result<Json<CreateAgentResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Set the environment variable for the model if provided
|
||||
if let Some(model) = &payload.model {
|
||||
let env_var_key = format!("{}_MODEL", payload.provider.to_uppercase());
|
||||
env::set_var(env_var_key.clone(), model);
|
||||
println!("Set environment variable: {}={}", env_var_key, model);
|
||||
}
|
||||
|
||||
let config = Config::global();
|
||||
let model = payload.model.unwrap_or_else(|| {
|
||||
config
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("Did not find a model on payload or in env")
|
||||
});
|
||||
let model_config = ModelConfig::new(model);
|
||||
let provider =
|
||||
providers::create(&payload.provider, model_config).expect("Failed to create provider");
|
||||
|
||||
let version = String::from("goose");
|
||||
let new_agent = Agent::new(provider);
|
||||
|
||||
let mut agent = state.agent.write().await;
|
||||
*agent = Some(new_agent);
|
||||
|
||||
Ok(Json(CreateAgentResponse { version }))
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
agent.extend_system_prompt(payload.extension.clone()).await;
|
||||
Ok(Json(ExtendPromptResponse { success: true }))
|
||||
}
|
||||
|
||||
async fn list_providers() -> Json<Vec<ProviderList>> {
|
||||
@@ -168,7 +129,7 @@ async fn list_providers() -> Json<Vec<ProviderList>> {
|
||||
)
|
||||
)]
|
||||
async fn get_tools(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<GetToolsQuery>,
|
||||
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
|
||||
@@ -176,8 +137,10 @@ async fn get_tools(
|
||||
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
let permission_manager = PermissionManager::default();
|
||||
|
||||
let mut tools: Vec<ToolInfo> = agent
|
||||
@@ -210,12 +173,56 @@ async fn get_tools(
|
||||
Ok(Json(tools))
|
||||
}
|
||||
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/agent/update_provider",
|
||||
responses(
|
||||
(status = 200, description = "Update provider completed", body = String),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
async fn update_agent_provider(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
// Verify secret key
|
||||
let secret_key = headers
|
||||
.get("X-Secret-Key")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if secret_key != state.secret_key {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
let config = Config::global();
|
||||
let model = payload.model.unwrap_or_else(|| {
|
||||
config
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("Did not find a model on payload or in env to update provider with")
|
||||
});
|
||||
let model_config = ModelConfig::new(model);
|
||||
let new_provider = create(&payload.provider, model_config).unwrap();
|
||||
agent
|
||||
.update_provider(new_provider)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/agent/versions", get(get_versions))
|
||||
.route("/agent/providers", get(list_providers))
|
||||
.route("/agent/prompt", post(extend_prompt))
|
||||
.route("/agent/tools", get(get_tools))
|
||||
.route("/agent", post(create_agent))
|
||||
.route("/agent/update_provider", post(update_agent_provider))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_yaml;
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
@@ -89,7 +89,7 @@ pub struct UpsertPermissionsQuery {
|
||||
)
|
||||
)]
|
||||
pub async fn upsert_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<UpsertConfigQuery>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
@@ -116,7 +116,7 @@ pub async fn upsert_config(
|
||||
)
|
||||
)]
|
||||
pub async fn remove_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<ConfigKeyQuery>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
@@ -148,7 +148,7 @@ pub async fn remove_config(
|
||||
)
|
||||
)]
|
||||
pub async fn read_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<ConfigKeyQuery>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
@@ -180,7 +180,7 @@ pub async fn read_config(
|
||||
)
|
||||
)]
|
||||
pub async fn get_extensions(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
@@ -213,7 +213,7 @@ pub async fn get_extensions(
|
||||
)
|
||||
)]
|
||||
pub async fn add_extension(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(extension_query): Json<ExtensionQuery>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
@@ -251,7 +251,7 @@ pub async fn add_extension(
|
||||
)
|
||||
)]
|
||||
pub async fn remove_extension(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
axum::extract::Path(name): axum::extract::Path<String>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
@@ -272,7 +272,7 @@ pub async fn remove_extension(
|
||||
)
|
||||
)]
|
||||
pub async fn read_all_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||
// Use the helper function to verify the secret key
|
||||
@@ -297,7 +297,7 @@ pub async fn read_all_config(
|
||||
)
|
||||
)]
|
||||
pub async fn providers(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
@@ -332,7 +332,7 @@ pub async fn providers(
|
||||
)
|
||||
)]
|
||||
pub async fn init_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
@@ -402,7 +402,7 @@ pub async fn init_config(
|
||||
)
|
||||
)]
|
||||
pub async fn upsert_permissions(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<UpsertPermissionsQuery>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
@@ -435,7 +435,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
|
||||
)
|
||||
)]
|
||||
pub async fn backup_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
@@ -466,7 +466,7 @@ pub async fn backup_config(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/config", get(read_all_config))
|
||||
.route("/config/upsert", post(upsert_config))
|
||||
|
||||
@@ -10,7 +10,7 @@ use http::{HeaderMap, StatusCode};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ConfigResponse {
|
||||
@@ -26,7 +26,7 @@ struct ConfigRequest {
|
||||
}
|
||||
|
||||
async fn store_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ConfigRequest>,
|
||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||
@@ -148,7 +148,7 @@ pub struct GetConfigResponse {
|
||||
}
|
||||
|
||||
pub async fn get_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<GetConfigQuery>,
|
||||
) -> Result<Json<GetConfigResponse>, StatusCode> {
|
||||
@@ -174,7 +174,7 @@ struct DeleteConfigRequest {
|
||||
}
|
||||
|
||||
async fn delete_config(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<DeleteConfigRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
@@ -193,7 +193,7 @@ async fn delete_config(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/configs/providers", post(check_provider_configs))
|
||||
.route("/configs/get", get(get_config))
|
||||
|
||||
@@ -8,6 +8,7 @@ use axum::{
|
||||
};
|
||||
use goose::message::Message;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
// Direct message serialization for context mgmt request
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -26,15 +27,16 @@ pub struct ContextManageResponse {
|
||||
}
|
||||
|
||||
async fn manage_context(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ContextManageRequest>,
|
||||
) -> Result<Json<ContextManageResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Get a lock on the shared agent
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
let mut processed_messages: Vec<Message> = vec![];
|
||||
let mut token_counts: Vec<usize> = vec![];
|
||||
@@ -57,7 +59,7 @@ async fn manage_context(
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/context/manage", post(manage_context))
|
||||
.with_state(state)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use super::utils::verify_secret_key;
|
||||
@@ -79,7 +80,7 @@ struct ExtensionResponse {
|
||||
|
||||
/// Handler for adding a new extension configuration.
|
||||
async fn add_extension(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
raw: axum::extract::Json<serde_json::Value>,
|
||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||
@@ -228,9 +229,11 @@ async fn add_extension(
|
||||
},
|
||||
};
|
||||
|
||||
// Acquire a lock on the agent and attempt to add the extension.
|
||||
let mut agent = state.agent.write().await;
|
||||
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
// Get a reference to the agent
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
let response = agent.add_extension(extension_config).await;
|
||||
|
||||
// Respond with the result.
|
||||
@@ -254,15 +257,17 @@ async fn add_extension(
|
||||
|
||||
/// Handler for removing an extension by name
|
||||
async fn remove_extension(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(name): Json<String>,
|
||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Acquire a lock on the agent and attempt to remove the extension
|
||||
let mut agent = state.agent.write().await;
|
||||
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
// Get a reference to the agent
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
agent.remove_extension(&name).await;
|
||||
|
||||
Ok(Json(ExtensionResponse {
|
||||
@@ -272,7 +277,7 @@ async fn remove_extension(
|
||||
}
|
||||
|
||||
/// Registers the extension management routes with the Axum router.
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/extensions/add", post(add_extension))
|
||||
.route("/extensions/remove", post(remove_extension))
|
||||
|
||||
@@ -9,10 +9,12 @@ pub mod recipe;
|
||||
pub mod reply;
|
||||
pub mod session;
|
||||
pub mod utils;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
|
||||
// Function to configure all routes
|
||||
pub fn configure(state: crate::state::AppState) -> Router {
|
||||
pub fn configure(state: Arc<crate::state::AppState>) -> Router {
|
||||
Router::new()
|
||||
.merge(health::routes())
|
||||
.merge(reply::routes(state.clone()))
|
||||
@@ -22,5 +24,5 @@ pub fn configure(state: crate::state::AppState) -> Router {
|
||||
.merge(configs::routes(state.clone()))
|
||||
.merge(config_management::routes(state.clone()))
|
||||
.merge(recipe::routes(state.clone()))
|
||||
.merge(session::routes(state))
|
||||
.merge(session::routes(state.clone()))
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
|
||||
use goose::message::Message;
|
||||
use goose::recipe::Recipe;
|
||||
@@ -34,17 +36,17 @@ pub struct CreateRecipeResponse {
|
||||
|
||||
/// Create a Recipe configuration from the current state of an agent
|
||||
async fn create_recipe(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(request): Json<CreateRecipeRequest>,
|
||||
) -> Result<Json<CreateRecipeResponse>, (StatusCode, Json<CreateRecipeResponse>)> {
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or_else(|| {
|
||||
let error_response = CreateRecipeResponse {
|
||||
recipe: None,
|
||||
error: Some("Agent not initialized".to_string()),
|
||||
};
|
||||
(StatusCode::PRECONDITION_REQUIRED, Json(error_response))
|
||||
})?;
|
||||
let error_response = CreateRecipeResponse {
|
||||
recipe: None,
|
||||
error: Some("Missing agent".to_string()),
|
||||
};
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| (StatusCode::PRECONDITION_FAILED, Json(error_response)))?;
|
||||
|
||||
// Create base recipe from agent state and messages
|
||||
let recipe_result = agent.create_recipe(request.messages).await;
|
||||
@@ -82,7 +84,7 @@ async fn create_recipe(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/recipe/create", post(create_recipe))
|
||||
.with_state(state)
|
||||
|
||||
@@ -26,6 +26,7 @@ use std::{
|
||||
convert::Infallible,
|
||||
path::PathBuf,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
@@ -101,7 +102,7 @@ async fn stream_event(
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ChatRequest>,
|
||||
) -> Result<SseResponse, StatusCode> {
|
||||
@@ -119,15 +120,34 @@ async fn handler(
|
||||
.session_id
|
||||
.unwrap_or_else(session::generate_session_id);
|
||||
|
||||
// Get a lock on the shared agent
|
||||
let agent = state.agent.clone();
|
||||
|
||||
// Spawn task to handle streaming
|
||||
tokio::spawn(async move {
|
||||
let agent = agent.read().await;
|
||||
let agent = match agent.as_ref() {
|
||||
Some(agent) => agent,
|
||||
None => {
|
||||
let agent = state.get_agent().await;
|
||||
let agent = match agent {
|
||||
Ok(agent) => {
|
||||
let provider = agent.provider().await;
|
||||
match provider {
|
||||
Ok(_) => agent,
|
||||
Err(_) => {
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: "No provider configured".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "error".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: "No agent configured".to_string(),
|
||||
@@ -147,7 +167,7 @@ async fn handler(
|
||||
};
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider();
|
||||
let provider = agent.provider().await;
|
||||
|
||||
let mut stream = match agent
|
||||
.reply(
|
||||
@@ -204,7 +224,7 @@ async fn handler(
|
||||
// Store messages and generate description in background
|
||||
let session_path = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = provider.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
@@ -262,7 +282,7 @@ struct AskResponse {
|
||||
|
||||
// Simple ask an AI for a response, non streaming
|
||||
async fn ask_handler(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<AskRequest>,
|
||||
) -> Result<Json<AskResponse>, StatusCode> {
|
||||
@@ -275,12 +295,13 @@ async fn ask_handler(
|
||||
.session_id
|
||||
.unwrap_or_else(session::generate_session_id);
|
||||
|
||||
let agent = state.agent.clone();
|
||||
let agent = agent.write().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider();
|
||||
let provider = agent.provider().await;
|
||||
|
||||
// Create a single message for the prompt
|
||||
let messages = vec![Message::user().with_text(request.prompt)];
|
||||
@@ -339,7 +360,7 @@ async fn ask_handler(
|
||||
// Store messages and generate description in background
|
||||
let session_path = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = provider.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
@@ -374,15 +395,16 @@ fn default_principal_type() -> PrincipalType {
|
||||
)
|
||||
)]
|
||||
pub async fn confirm_permission(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<PermissionConfirmationRequest>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let agent = state.agent.clone();
|
||||
let agent = agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
let permission = match request.action.as_str() {
|
||||
"always_allow" => Permission::AlwaysAllow,
|
||||
@@ -410,7 +432,7 @@ struct ToolResultRequest {
|
||||
}
|
||||
|
||||
async fn submit_tool_result(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
raw: axum::extract::Json<serde_json::Value>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
@@ -435,14 +457,16 @@ async fn submit_tool_result(
|
||||
}
|
||||
};
|
||||
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
let agent = state
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
agent.handle_tool_result(payload.id, payload.result).await;
|
||||
Ok(Json(json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/reply", post(handler))
|
||||
.route("/ask", post(ask_handler))
|
||||
@@ -496,9 +520,7 @@ mod tests {
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use axum::{body::Body, http::Request};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tower::ServiceExt;
|
||||
|
||||
// This test requires tokio runtime
|
||||
@@ -509,12 +531,9 @@ mod tests {
|
||||
let mock_provider = Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
});
|
||||
let agent = Agent::new(mock_provider);
|
||||
let state = AppState {
|
||||
config: Arc::new(Mutex::new(HashMap::new())),
|
||||
agent: Arc::new(RwLock::new(Some(agent))),
|
||||
secret_key: "test-secret".to_string(),
|
||||
};
|
||||
let agent = Agent::new();
|
||||
let _ = agent.update_provider(mock_provider).await;
|
||||
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
|
||||
|
||||
// Build router
|
||||
let app = routes(state);
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::utils::verify_secret_key;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
@@ -25,7 +27,7 @@ struct SessionHistoryResponse {
|
||||
|
||||
// List all available sessions
|
||||
async fn list_sessions(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<SessionListResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
@@ -38,7 +40,7 @@ async fn list_sessions(
|
||||
|
||||
// Get a specific session's history
|
||||
async fn get_session_history(
|
||||
State(state): State<AppState>,
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(session_id): Path<String>,
|
||||
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
|
||||
@@ -65,7 +67,7 @@ async fn get_session_history(
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/sessions", get(list_sessions))
|
||||
.route("/sessions/:session_id", get(get_session_history))
|
||||
|
||||
@@ -1,25 +1,34 @@
|
||||
use anyhow::Result;
|
||||
use goose::agents::Agent;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
/// Shared reference to an Agent that can be cloned cheaply
|
||||
/// without cloning the underlying Agent object
|
||||
pub type AgentRef = Arc<Agent>;
|
||||
|
||||
/// Thread-safe container for an optional Agent reference
|
||||
/// Outer Arc: Allows multiple route handlers to access the same Mutex
|
||||
/// - Mutex provides exclusive access for updates
|
||||
/// - Option allows for the case where no agent exists yet
|
||||
///
|
||||
/// Shared application state
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub agent: Arc<RwLock<Option<Agent>>>,
|
||||
// agent: SharedAgentStore,
|
||||
agent: Option<AgentRef>,
|
||||
pub secret_key: String,
|
||||
pub config: Arc<Mutex<HashMap<String, Value>>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(secret_key: String) -> Result<Self> {
|
||||
Ok(Self {
|
||||
agent: Arc::new(RwLock::new(None)),
|
||||
pub async fn new(agent: AgentRef, secret_key: String) -> Arc<AppState> {
|
||||
Arc::new(Self {
|
||||
agent: Some(agent.clone()),
|
||||
secret_key,
|
||||
config: Arc::new(Mutex::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_agent(&self) -> Result<Arc<Agent>, anyhow::Error> {
|
||||
self.agent
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("Agent needs to be created first."))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ async fn main() {
|
||||
let provider = Arc::new(DatabricksProvider::default());
|
||||
|
||||
// Setup an agent with the developer extension
|
||||
let mut agent = Agent::new(provider);
|
||||
let agent = Agent::new();
|
||||
let _ = agent.update_provider(provider).await;
|
||||
|
||||
let config = ExtensionConfig::stdio(
|
||||
"developer",
|
||||
|
||||
@@ -35,12 +35,11 @@ use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINE
|
||||
|
||||
/// The main goose Agent
|
||||
pub struct Agent {
|
||||
pub(super) provider: Arc<dyn Provider>,
|
||||
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
|
||||
pub(super) extension_manager: Mutex<ExtensionManager>,
|
||||
pub(super) frontend_tools: HashMap<String, FrontendTool>,
|
||||
pub(super) frontend_instructions: Option<String>,
|
||||
pub(super) prompt_manager: PromptManager,
|
||||
// Channels for tool results and confirmations
|
||||
pub(super) frontend_tools: Mutex<HashMap<String, FrontendTool>>,
|
||||
pub(super) frontend_instructions: Mutex<Option<String>>,
|
||||
pub(super) prompt_manager: Mutex<PromptManager>,
|
||||
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
@@ -48,41 +47,52 @@ pub struct Agent {
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn new(provider: Arc<dyn Provider>) -> Self {
|
||||
pub fn new() -> Self {
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
provider,
|
||||
provider: Mutex::new(None),
|
||||
extension_manager: Mutex::new(ExtensionManager::new()),
|
||||
frontend_tools: HashMap::new(),
|
||||
frontend_instructions: None,
|
||||
prompt_manager: PromptManager::new(),
|
||||
frontend_tools: Mutex::new(HashMap::new()),
|
||||
frontend_instructions: Mutex::new(None),
|
||||
prompt_manager: Mutex::new(PromptManager::new()),
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Agent {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Get a reference count clone to the provider
|
||||
pub fn provider(&self) -> Arc<dyn Provider> {
|
||||
Arc::clone(&self.provider)
|
||||
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
match &*self.provider.lock().await {
|
||||
Some(provider) => Ok(Arc::clone(provider)),
|
||||
None => Err(anyhow!("Provider not set")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool is a frontend tool
|
||||
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
||||
self.frontend_tools.contains_key(name)
|
||||
pub async fn is_frontend_tool(&self, name: &str) -> bool {
|
||||
self.frontend_tools.lock().await.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get a reference to a frontend tool
|
||||
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
||||
self.frontend_tools.get(name)
|
||||
pub async fn get_frontend_tool(&self, name: &str) -> Option<FrontendTool> {
|
||||
self.frontend_tools.lock().await.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get all tools from all clients with proper prefixing
|
||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
||||
pub async fn get_prefixed_tools(&self) -> ExtensionResult<Vec<Tool>> {
|
||||
let mut tools = self
|
||||
.extension_manager
|
||||
.lock()
|
||||
@@ -91,7 +101,8 @@ impl Agent {
|
||||
.await?;
|
||||
|
||||
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||
for frontend_tool in self.frontend_tools.values() {
|
||||
let frontend_tools = self.frontend_tools.lock().await;
|
||||
for frontend_tool in frontend_tools.values() {
|
||||
tools.push(frontend_tool.tool.clone());
|
||||
}
|
||||
|
||||
@@ -135,7 +146,7 @@ impl Agent {
|
||||
.await
|
||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||
extension_manager.search_available_extensions().await
|
||||
} else if self.is_frontend_tool(&tool_call.name) {
|
||||
} else if self.is_frontend_tool(&tool_call.name).await {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
@@ -212,7 +223,7 @@ impl Agent {
|
||||
(request_id, result)
|
||||
}
|
||||
|
||||
pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||
pub async fn add_extension(&self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||
match &extension {
|
||||
ExtensionConfig::Frontend {
|
||||
name: _,
|
||||
@@ -221,19 +232,21 @@ impl Agent {
|
||||
bundled: _,
|
||||
} => {
|
||||
// For frontend tools, just store them in the frontend_tools map
|
||||
let mut frontend_tools = self.frontend_tools.lock().await;
|
||||
for tool in tools {
|
||||
let frontend_tool = FrontendTool {
|
||||
name: tool.name.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||
frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||
}
|
||||
// Store instructions if provided, using "frontend" as the key
|
||||
let mut frontend_instructions = self.frontend_instructions.lock().await;
|
||||
if let Some(instructions) = instructions {
|
||||
self.frontend_instructions = Some(instructions.clone());
|
||||
*frontend_instructions = Some(instructions.clone());
|
||||
} else {
|
||||
// Default frontend instructions if none provided
|
||||
self.frontend_instructions = Some(
|
||||
*frontend_instructions = Some(
|
||||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
|
||||
);
|
||||
}
|
||||
@@ -269,7 +282,7 @@ impl Agent {
|
||||
prefixed_tools
|
||||
}
|
||||
|
||||
pub async fn remove_extension(&mut self, name: &str) {
|
||||
pub async fn remove_extension(&self, name: &str) {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager
|
||||
.remove_extension(name)
|
||||
@@ -329,7 +342,7 @@ impl Agent {
|
||||
let _ = reply_span.enter();
|
||||
loop {
|
||||
match Self::generate_response_from_provider(
|
||||
self.provider(),
|
||||
self.provider().await?,
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
@@ -345,7 +358,7 @@ impl Agent {
|
||||
let (frontend_requests,
|
||||
remaining_requests,
|
||||
filtered_response) =
|
||||
self.categorize_tool_requests(&response);
|
||||
self.categorize_tool_requests(&response).await;
|
||||
|
||||
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
@@ -396,8 +409,7 @@ impl Agent {
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
&mut permission_manager,
|
||||
self.provider(),
|
||||
).await;
|
||||
self.provider().await?).await;
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
||||
@@ -492,13 +504,21 @@ impl Agent {
|
||||
}
|
||||
|
||||
/// Extend the system prompt with one line of additional instruction
|
||||
pub async fn extend_system_prompt(&mut self, instruction: String) {
|
||||
self.prompt_manager.add_system_prompt_extra(instruction);
|
||||
pub async fn extend_system_prompt(&self, instruction: String) {
|
||||
let mut prompt_manager = self.prompt_manager.lock().await;
|
||||
prompt_manager.add_system_prompt_extra(instruction);
|
||||
}
|
||||
|
||||
/// Update the provider used by this agent
|
||||
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
*self.provider.lock().await = Some(provider);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Override the system prompt with a custom template
|
||||
pub async fn override_system_prompt(&mut self, template: String) {
|
||||
self.prompt_manager.set_system_prompt_override(template);
|
||||
pub async fn override_system_prompt(&self, template: String) {
|
||||
let mut prompt_manager = self.prompt_manager.lock().await;
|
||||
prompt_manager.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
@@ -563,23 +583,29 @@ impl Agent {
|
||||
let extensions_info = extension_manager.get_extensions_info().await;
|
||||
|
||||
// Get model name from provider
|
||||
let model_config = self.provider.get_model_config();
|
||||
let provider = self.provider().await?;
|
||||
let model_config = provider.get_model_config();
|
||||
let model_name = &model_config.model_name;
|
||||
|
||||
let system_prompt = self.prompt_manager.build_system_prompt(
|
||||
let prompt_manager = self.prompt_manager.lock().await;
|
||||
let system_prompt = prompt_manager.build_system_prompt(
|
||||
extensions_info,
|
||||
self.frontend_instructions.clone(),
|
||||
self.frontend_instructions.lock().await.clone(),
|
||||
extension_manager.suggest_disable_extensions_prompt().await,
|
||||
Some(model_name),
|
||||
);
|
||||
|
||||
let recipe_prompt = self.prompt_manager.get_recipe_prompt().await;
|
||||
let recipe_prompt = prompt_manager.get_recipe_prompt().await;
|
||||
let tools = extension_manager.get_prefixed_tools(None).await?;
|
||||
|
||||
messages.push(Message::user().with_text(recipe_prompt));
|
||||
|
||||
let (result, _usage) = self
|
||||
.provider
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.complete(&system_prompt, &messages, &tools)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ impl Agent {
|
||||
&self,
|
||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let provider = self.provider.clone();
|
||||
let provider = self.provider().await?;
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let target_context_limit = estimate_target_context_limit(provider);
|
||||
let token_counts = get_messages_token_counts(&token_counter, messages);
|
||||
@@ -41,7 +41,7 @@ impl Agent {
|
||||
&self,
|
||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let provider = self.provider.clone();
|
||||
let provider = self.provider().await?;
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let target_context_limit = estimate_target_context_limit(provider.clone());
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ impl Agent {
|
||||
let mut tools = self.list_tools(None).await;
|
||||
|
||||
// Add frontend tools
|
||||
for frontend_tool in self.frontend_tools.values() {
|
||||
let frontend_tools = self.frontend_tools.lock().await;
|
||||
for frontend_tool in frontend_tools.values() {
|
||||
tools.push(frontend_tool.tool.clone());
|
||||
}
|
||||
|
||||
@@ -31,19 +32,21 @@ impl Agent {
|
||||
let extensions_info = extension_manager.get_extensions_info().await;
|
||||
|
||||
// Get model name from provider
|
||||
let model_config = self.provider.get_model_config();
|
||||
let provider = self.provider().await?;
|
||||
let model_config = provider.get_model_config();
|
||||
let model_name = &model_config.model_name;
|
||||
|
||||
let mut system_prompt = self.prompt_manager.build_system_prompt(
|
||||
let prompt_manager = self.prompt_manager.lock().await;
|
||||
let mut system_prompt = prompt_manager.build_system_prompt(
|
||||
extensions_info,
|
||||
self.frontend_instructions.clone(),
|
||||
self.frontend_instructions.lock().await.clone(),
|
||||
extension_manager.suggest_disable_extensions_prompt().await,
|
||||
Some(model_name),
|
||||
);
|
||||
|
||||
// Handle toolshim if enabled
|
||||
let mut toolshim_tools = vec![];
|
||||
if self.provider.get_model_config().toolshim {
|
||||
if model_config.toolshim {
|
||||
// If tool interpretation is enabled, modify the system prompt
|
||||
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
|
||||
// Make a copy of tools before emptying
|
||||
@@ -115,7 +118,7 @@ impl Agent {
|
||||
/// - frontend_requests: Tool requests that should be handled by the frontend
|
||||
/// - other_requests: All other tool requests (including requests to enable extensions)
|
||||
/// - filtered_message: The original message with frontend tool requests removed
|
||||
pub(crate) fn categorize_tool_requests(
|
||||
pub(crate) async fn categorize_tool_requests(
|
||||
&self,
|
||||
response: &Message,
|
||||
) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
|
||||
@@ -133,20 +136,25 @@ impl Agent {
|
||||
.collect();
|
||||
|
||||
// Create a filtered message with frontend tool requests removed
|
||||
let filtered_content = response
|
||||
.content
|
||||
.iter()
|
||||
.filter(|c| {
|
||||
if let MessageContent::ToolRequest(req) = c {
|
||||
// Only filter out frontend tool requests
|
||||
let mut filtered_content = Vec::new();
|
||||
|
||||
// Process each content item one by one
|
||||
for content in &response.content {
|
||||
let should_include = match content {
|
||||
MessageContent::ToolRequest(req) => {
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
return !self.is_frontend_tool(&tool_call.name);
|
||||
!self.is_frontend_tool(&tool_call.name).await
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
_ => true,
|
||||
};
|
||||
|
||||
if should_include {
|
||||
filtered_content.push(content.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let filtered_message = Message {
|
||||
role: response.role.clone(),
|
||||
@@ -160,7 +168,7 @@ impl Agent {
|
||||
|
||||
for request in tool_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if self.is_frontend_tool(&tool_call.name) {
|
||||
if self.is_frontend_tool(&tool_call.name).await {
|
||||
frontend_requests.push(request);
|
||||
} else {
|
||||
other_requests.push(request);
|
||||
|
||||
@@ -87,7 +87,7 @@ impl Agent {
|
||||
try_stream! {
|
||||
for request in tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if self.is_frontend_tool(&tool_call.name) {
|
||||
if self.is_frontend_tool(&tool_call.name).await {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
|
||||
@@ -739,10 +739,10 @@ mod tests {
|
||||
thread::sleep(Duration::from_millis(i * 10));
|
||||
|
||||
let extension_key = format!("extension_{}", i);
|
||||
let mut values = config.load_values()?;
|
||||
|
||||
values.insert(
|
||||
extension_key.clone(),
|
||||
// Use set_param which handles concurrent access properly
|
||||
config.set_param(
|
||||
&extension_key,
|
||||
serde_json::json!({
|
||||
"name": format!("test_extension_{}", i),
|
||||
"version": format!("1.0.{}", i),
|
||||
@@ -752,10 +752,7 @@ mod tests {
|
||||
"option2": i
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Write all values atomically
|
||||
config.save_values(values)?;
|
||||
)?;
|
||||
Ok(())
|
||||
});
|
||||
handles.push(handle);
|
||||
|
||||
@@ -110,7 +110,8 @@ async fn run_truncate_test(
|
||||
.with_temperature(Some(0.0));
|
||||
let provider = provider_type.create_provider(model_config)?;
|
||||
|
||||
let agent = Agent::new(provider);
|
||||
let agent = Agent::new();
|
||||
agent.update_provider(provider).await?;
|
||||
let repeat_count = context_window + 10_000;
|
||||
let large_message_content = "hello ".repeat(repeat_count);
|
||||
let messages = vec![
|
||||
|
||||
@@ -33,12 +33,28 @@ CALCULATOR_TOOL = {
|
||||
},
|
||||
}
|
||||
|
||||
# Enable Extension tool definition
|
||||
ENABLE_EXTENSION_TOOL = {
|
||||
"name": "enable_extension",
|
||||
"description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"required": ["extension_name"],
|
||||
"properties": {
|
||||
"extension_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the extension to enable",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Frontend extension configuration
|
||||
FRONTEND_CONFIG = {
|
||||
"name": "pythonclient",
|
||||
"type": "frontend",
|
||||
"tools": [CALCULATOR_TOOL],
|
||||
"instructions": "A calculator extension that can perform basic arithmetic operations.",
|
||||
"tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL],
|
||||
"instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +63,7 @@ async def setup_agent() -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
# First create the agent
|
||||
response = await client.post(
|
||||
f"{GOOSE_URL}/agent",
|
||||
f"{GOOSE_URL}/agent/update_provider",
|
||||
json={"provider": "databricks", "model": "goose"},
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
@@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
}
|
||||
]
|
||||
|
||||
def get_tools() -> Dict[str, Any]:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(
|
||||
f"{GOOSE_URL}/agent/tools",
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Execute the enable_extension tool.
|
||||
This function fetches available extensions, finds the one with the provided extension_name,
|
||||
and posts its configuration to the /extensions/add endpoint.
|
||||
"""
|
||||
extension = args
|
||||
extension_name = extension.get("name")
|
||||
|
||||
# Post the extension configuration to enable it
|
||||
with httpx.Client() as client:
|
||||
payload = {
|
||||
"type": extension.get("type"),
|
||||
"name": extension.get("name"),
|
||||
"cmd": extension.get("cmd"),
|
||||
"args": extension.get("args"),
|
||||
"envs": extension.get("envs", {}),
|
||||
"timeout": extension.get("timeout"),
|
||||
"bundled": extension.get("bundled"),
|
||||
}
|
||||
add_response = client.post(
|
||||
f"{GOOSE_URL}/extensions/add",
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
if add_response.status_code != 200:
|
||||
error_text = add_response.text
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": f"Error: Failed to enable extension: {error_text}",
|
||||
"annotations": None,
|
||||
}]
|
||||
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": f"Successfully enabled extension: {extension_name}",
|
||||
"annotations": None,
|
||||
}]
|
||||
|
||||
|
||||
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
||||
"""Submit the tool execution result back to Goose.
|
||||
@@ -129,7 +194,7 @@ async def chat_loop() -> None:
|
||||
session_id = "test-session"
|
||||
|
||||
# Use a client with a longer timeout for streaming
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
# Get user input
|
||||
user_message = input("\nYou: ")
|
||||
if user_message.lower() in ["exit", "quit"]:
|
||||
@@ -152,7 +217,7 @@ async def chat_loop() -> None:
|
||||
# Process the stream of responses
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{GOOSE_URL}/reply",
|
||||
f"{GOOSE_URL}/reply", # lock
|
||||
json=payload,
|
||||
headers={
|
||||
"X-Secret-Key": SECRET_KEY,
|
||||
@@ -185,9 +250,26 @@ async def chat_loop() -> None:
|
||||
elif content["type"] == "frontendToolRequest":
|
||||
# Execute the tool and submit results
|
||||
tool_call = content["toolCall"]["value"]
|
||||
print(f"Calculator: {tool_call}")
|
||||
# Execute the tool
|
||||
result = execute_calculator(tool_call["arguments"])
|
||||
print(f"\nTool Request: {tool_call}")
|
||||
|
||||
if tool_call['name'] == "calculator":
|
||||
print(f"Calculator: {tool_call}")
|
||||
# Execute the tool
|
||||
result = execute_calculator(tool_call["arguments"])
|
||||
elif tool_call['name'] == "enable_extension":
|
||||
# to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name"
|
||||
print(f"Enabling fetch extension")
|
||||
result = execute_enable_extension(args={
|
||||
"type": "stdio",
|
||||
"name": "fetch",
|
||||
"cmd": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"timeout": 300,
|
||||
"bundled": False
|
||||
})
|
||||
listed_tools = get_tools()
|
||||
print(f"\nTools after enabling extension: {listed_tools}")
|
||||
|
||||
|
||||
# Submit the result
|
||||
submit_tool_result(content["id"], result)
|
||||
|
||||
@@ -6,7 +6,7 @@ interface initializeAgentProps {
|
||||
}
|
||||
|
||||
export async function initializeAgent({ model, provider }: initializeAgentProps) {
|
||||
const response = await fetch(getApiUrl('/agent'), {
|
||||
const response = await fetch(getApiUrl('/agent/update_provider'), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
Reference in New Issue
Block a user