chore: refactor read-write lock on agent (#2225)

Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
Salman Mohammed
2025-04-23 23:46:22 -03:00
committed by GitHub
parent 85e2ee3984
commit 199fa6adbc
24 changed files with 409 additions and 237 deletions

View File

@@ -983,11 +983,11 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
.get_param("GOOSE_MODEL")
.expect("No model configured. Please set model first");
let model_config = goose::model::ModelConfig::new(model.clone());
let provider =
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
// Create the agent
let mut agent = Agent::new(provider);
let agent = Agent::new();
let new_provider = create(&provider_name, model_config)?;
agent.update_provider(new_provider).await?;
if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) {
agent
.add_extension(config.clone())

View File

@@ -2,6 +2,7 @@ use console::style;
use goose::agents::extension::ExtensionError;
use goose::agents::Agent;
use goose::config::{Config, ExtensionConfig, ExtensionConfigManager};
use goose::providers::create;
use goose::session;
use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError;
@@ -46,11 +47,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
.get_param("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first");
let model_config = goose::model::ModelConfig::new(model.clone());
let provider =
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
// Create the agent
let mut agent = Agent::new(provider);
let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap();
let _ = agent.update_provider(new_provider).await;
// Handle session file resolution and resuming
let session_file = if session_config.resume {

View File

@@ -285,7 +285,7 @@ impl Session {
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));
// Get the provider from the agent for description generation
let provider = self.agent.provider();
let provider = self.agent.provider().await?;
// Persist messages with provider for automatic description generation
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
@@ -357,7 +357,7 @@ impl Session {
self.messages.push(Message::user().with_text(&content));
// Get the provider from the agent for description generation
let provider = self.agent.provider();
let provider = self.agent.provider().await?;
// Persist messages with provider for automatic description generation
session::persist_messages(
@@ -526,7 +526,7 @@ impl Session {
output::render_message(&plan_response, self.debug);
output::hide_thinking();
let planner_response_type =
classify_planner_response(plan_response.as_concat_text(), self.agent.provider())
classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?)
.await?;
match planner_response_type {

View File

@@ -178,7 +178,10 @@ pub unsafe extern "C" fn goose_agent_new(config: *const ProviderConfigFFI) -> Ag
// Create Databricks provider with required parameters
match DatabricksProvider::from_params(host, api_key, model_config) {
Ok(provider) => {
let agent = Agent::new(Arc::new(provider));
let agent = Agent::new();
get_runtime().block_on(async {
let _ = agent.update_provider(Arc::new(provider)).await;
});
Box::into_raw(Box::new(agent))
}
Err(e) => {

View File

@@ -1,6 +1,9 @@
use std::sync::Arc;
use crate::configuration;
use crate::state;
use anyhow::Result;
use goose::agents::Agent;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
@@ -15,8 +18,10 @@ pub async fn run() -> Result<()> {
let secret_key =
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
// Create app state - agent will start as None
let state = state::AppState::new(secret_key.clone()).await?;
let new_agent = Agent::new();
// Create app state with agent
let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await;
// Create router with CORS support
let cors = CorsLayer::new()

View File

@@ -8,14 +8,15 @@ use axum::{
};
use goose::config::Config;
use goose::config::PermissionManager;
use goose::{agents::Agent, model::ModelConfig, providers};
use goose::model::ModelConfig;
use goose::providers::create;
use goose::{
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
config::permission::PermissionLevel,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[derive(Serialize)]
struct VersionsResponse {
@@ -33,17 +34,6 @@ struct ExtendPromptResponse {
success: bool,
}
#[derive(Deserialize)]
struct CreateAgentRequest {
provider: String,
model: Option<String>,
}
#[derive(Serialize)]
struct CreateAgentResponse {
version: String,
}
#[derive(Deserialize)]
struct ProviderFile {
name: String,
@@ -66,6 +56,12 @@ struct ProviderList {
details: ProviderDetails,
}
#[derive(Deserialize)]
struct UpdateProviderRequest {
provider: String,
model: Option<String>,
}
#[derive(Deserialize)]
pub struct GetToolsQuery {
extension_name: Option<String>,
@@ -82,53 +78,18 @@ async fn get_versions() -> Json<VersionsResponse> {
}
async fn extend_prompt(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let mut agent = state.agent.write().await;
if let Some(ref mut agent) = *agent {
agent.extend_system_prompt(payload.extension).await;
Ok(Json(ExtendPromptResponse { success: true }))
} else {
Err(StatusCode::NOT_FOUND)
}
}
#[axum::debug_handler]
async fn create_agent(
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<CreateAgentRequest>,
) -> Result<Json<CreateAgentResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Set the environment variable for the model if provided
if let Some(model) = &payload.model {
let env_var_key = format!("{}_MODEL", payload.provider.to_uppercase());
env::set_var(env_var_key.clone(), model);
println!("Set environment variable: {}={}", env_var_key, model);
}
let config = Config::global();
let model = payload.model.unwrap_or_else(|| {
config
.get_param("GOOSE_MODEL")
.expect("Did not find a model on payload or in env")
});
let model_config = ModelConfig::new(model);
let provider =
providers::create(&payload.provider, model_config).expect("Failed to create provider");
let version = String::from("goose");
let new_agent = Agent::new(provider);
let mut agent = state.agent.write().await;
*agent = Some(new_agent);
Ok(Json(CreateAgentResponse { version }))
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
agent.extend_system_prompt(payload.extension.clone()).await;
Ok(Json(ExtendPromptResponse { success: true }))
}
async fn list_providers() -> Json<Vec<ProviderList>> {
@@ -168,7 +129,7 @@ async fn list_providers() -> Json<Vec<ProviderList>> {
)
)]
async fn get_tools(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(query): Query<GetToolsQuery>,
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
@@ -176,8 +137,10 @@ async fn get_tools(
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
let agent = state.agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let permission_manager = PermissionManager::default();
let mut tools: Vec<ToolInfo> = agent
@@ -210,12 +173,56 @@ async fn get_tools(
Ok(Json(tools))
}
pub fn routes(state: AppState) -> Router {
#[utoipa::path(
post,
path = "/agent/update_provider",
responses(
(status = 200, description = "Update provider completed", body = String),
(status = 500, description = "Internal server error")
)
)]
async fn update_agent_provider(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<UpdateProviderRequest>,
) -> Result<StatusCode, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let config = Config::global();
let model = payload.model.unwrap_or_else(|| {
config
.get_param("GOOSE_MODEL")
.expect("Did not find a model on payload or in env to update provider with")
});
let model_config = ModelConfig::new(model);
let new_provider = create(&payload.provider, model_config).unwrap();
agent
.update_provider(new_provider)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::OK)
}
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent/prompt", post(extend_prompt))
.route("/agent/tools", get(get_tools))
.route("/agent", post(create_agent))
.route("/agent/update_provider", post(update_agent_provider))
.with_state(state)
}

View File

@@ -18,7 +18,7 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_yaml;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use utoipa::ToSchema;
#[derive(Serialize, ToSchema)]
@@ -89,7 +89,7 @@ pub struct UpsertPermissionsQuery {
)
)]
pub async fn upsert_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(query): Json<UpsertConfigQuery>,
) -> Result<Json<Value>, StatusCode> {
@@ -116,7 +116,7 @@ pub async fn upsert_config(
)
)]
pub async fn remove_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<String>, StatusCode> {
@@ -148,7 +148,7 @@ pub async fn remove_config(
)
)]
pub async fn read_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<Value>, StatusCode> {
@@ -180,7 +180,7 @@ pub async fn read_config(
)
)]
pub async fn get_extensions(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<ExtensionResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -213,7 +213,7 @@ pub async fn get_extensions(
)
)]
pub async fn add_extension(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(extension_query): Json<ExtensionQuery>,
) -> Result<Json<String>, StatusCode> {
@@ -251,7 +251,7 @@ pub async fn add_extension(
)
)]
pub async fn remove_extension(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
axum::extract::Path(name): axum::extract::Path<String>,
) -> Result<Json<String>, StatusCode> {
@@ -272,7 +272,7 @@ pub async fn remove_extension(
)
)]
pub async fn read_all_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<ConfigResponse>, StatusCode> {
// Use the helper function to verify the secret key
@@ -297,7 +297,7 @@ pub async fn read_all_config(
)
)]
pub async fn providers(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -332,7 +332,7 @@ pub async fn providers(
)
)]
pub async fn init_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<String>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -402,7 +402,7 @@ pub async fn init_config(
)
)]
pub async fn upsert_permissions(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(query): Json<UpsertPermissionsQuery>,
) -> Result<Json<String>, StatusCode> {
@@ -435,7 +435,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
)
)]
pub async fn backup_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<String>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -466,7 +466,7 @@ pub async fn backup_config(
}
}
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/config", get(read_all_config))
.route("/config/upsert", post(upsert_config))

View File

@@ -10,7 +10,7 @@ use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
#[derive(Serialize)]
struct ConfigResponse {
@@ -26,7 +26,7 @@ struct ConfigRequest {
}
async fn store_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ConfigRequest>,
) -> Result<Json<ConfigResponse>, StatusCode> {
@@ -148,7 +148,7 @@ pub struct GetConfigResponse {
}
pub async fn get_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(query): Query<GetConfigQuery>,
) -> Result<Json<GetConfigResponse>, StatusCode> {
@@ -174,7 +174,7 @@ struct DeleteConfigRequest {
}
async fn delete_config(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<DeleteConfigRequest>,
) -> Result<StatusCode, StatusCode> {
@@ -193,7 +193,7 @@ async fn delete_config(
}
}
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/configs/providers", post(check_provider_configs))
.route("/configs/get", get(get_config))

View File

@@ -8,6 +8,7 @@ use axum::{
};
use goose::message::Message;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
// Direct message serialization for context mgmt request
#[derive(Debug, Deserialize)]
@@ -26,15 +27,16 @@ pub struct ContextManageResponse {
}
async fn manage_context(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ContextManageRequest>,
) -> Result<Json<ContextManageResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Get a lock on the shared agent
let agent = state.agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let mut processed_messages: Vec<Message> = vec![];
let mut token_counts: Vec<usize> = vec![];
@@ -57,7 +59,7 @@ async fn manage_context(
}
// Configure routes for this module
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/context/manage", post(manage_context))
.with_state(state)

View File

@@ -1,5 +1,6 @@
use std::env;
use std::path::Path;
use std::sync::Arc;
use std::sync::OnceLock;
use super::utils::verify_secret_key;
@@ -79,7 +80,7 @@ struct ExtensionResponse {
/// Handler for adding a new extension configuration.
async fn add_extension(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
raw: axum::extract::Json<serde_json::Value>,
) -> Result<Json<ExtensionResponse>, StatusCode> {
@@ -228,9 +229,11 @@ async fn add_extension(
},
};
// Acquire a lock on the agent and attempt to add the extension.
let mut agent = state.agent.write().await;
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
// Get a reference to the agent
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let response = agent.add_extension(extension_config).await;
// Respond with the result.
@@ -254,15 +257,17 @@ async fn add_extension(
/// Handler for removing an extension by name
async fn remove_extension(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(name): Json<String>,
) -> Result<Json<ExtensionResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Acquire a lock on the agent and attempt to remove the extension
let mut agent = state.agent.write().await;
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
// Get a reference to the agent
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
agent.remove_extension(&name).await;
Ok(Json(ExtensionResponse {
@@ -272,7 +277,7 @@ async fn remove_extension(
}
/// Registers the extension management routes with the Axum router.
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/extensions/add", post(add_extension))
.route("/extensions/remove", post(remove_extension))

View File

@@ -9,10 +9,12 @@ pub mod recipe;
pub mod reply;
pub mod session;
pub mod utils;
use std::sync::Arc;
use axum::Router;
// Function to configure all routes
pub fn configure(state: crate::state::AppState) -> Router {
pub fn configure(state: Arc<crate::state::AppState>) -> Router {
Router::new()
.merge(health::routes())
.merge(reply::routes(state.clone()))
@@ -22,5 +24,5 @@ pub fn configure(state: crate::state::AppState) -> Router {
.merge(configs::routes(state.clone()))
.merge(config_management::routes(state.clone()))
.merge(recipe::routes(state.clone()))
.merge(session::routes(state))
.merge(session::routes(state.clone()))
}

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use goose::message::Message;
use goose::recipe::Recipe;
@@ -34,17 +36,17 @@ pub struct CreateRecipeResponse {
/// Create a Recipe configuration from the current state of an agent
async fn create_recipe(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
Json(request): Json<CreateRecipeRequest>,
) -> Result<Json<CreateRecipeResponse>, (StatusCode, Json<CreateRecipeResponse>)> {
let agent = state.agent.read().await;
let agent = agent.as_ref().ok_or_else(|| {
let error_response = CreateRecipeResponse {
recipe: None,
error: Some("Agent not initialized".to_string()),
};
(StatusCode::PRECONDITION_REQUIRED, Json(error_response))
})?;
let error_response = CreateRecipeResponse {
recipe: None,
error: Some("Missing agent".to_string()),
};
let agent = state
.get_agent()
.await
.map_err(|_| (StatusCode::PRECONDITION_FAILED, Json(error_response)))?;
// Create base recipe from agent state and messages
let recipe_result = agent.create_recipe(request.messages).await;
@@ -82,7 +84,7 @@ async fn create_recipe(
}
}
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/recipe/create", post(create_recipe))
.with_state(state)

View File

@@ -26,6 +26,7 @@ use std::{
convert::Infallible,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
@@ -101,7 +102,7 @@ async fn stream_event(
}
async fn handler(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ChatRequest>,
) -> Result<SseResponse, StatusCode> {
@@ -119,15 +120,34 @@ async fn handler(
.session_id
.unwrap_or_else(session::generate_session_id);
// Get a lock on the shared agent
let agent = state.agent.clone();
// Spawn task to handle streaming
tokio::spawn(async move {
let agent = agent.read().await;
let agent = match agent.as_ref() {
Some(agent) => agent,
None => {
let agent = state.get_agent().await;
let agent = match agent {
Ok(agent) => {
let provider = agent.provider().await;
match provider {
Ok(_) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No provider configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
}
}
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No agent configured".to_string(),
@@ -147,7 +167,7 @@ async fn handler(
};
// Get the provider first, before starting the reply stream
let provider = agent.provider();
let provider = agent.provider().await;
let mut stream = match agent
.reply(
@@ -204,7 +224,7 @@ async fn handler(
// Store messages and generate description in background
let session_path = session_path.clone();
let messages = all_messages.clone();
let provider = provider.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
tracing::error!("Failed to store session history: {:?}", e);
@@ -262,7 +282,7 @@ struct AskResponse {
// Simple ask an AI for a response, non streaming
async fn ask_handler(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<AskRequest>,
) -> Result<Json<AskResponse>, StatusCode> {
@@ -275,12 +295,13 @@ async fn ask_handler(
.session_id
.unwrap_or_else(session::generate_session_id);
let agent = state.agent.clone();
let agent = agent.write().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
// Get the provider first, before starting the reply stream
let provider = agent.provider();
let provider = agent.provider().await;
// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];
@@ -339,7 +360,7 @@ async fn ask_handler(
// Store messages and generate description in background
let session_path = session_path.clone();
let messages = all_messages.clone();
let provider = provider.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
tracing::error!("Failed to store session history: {:?}", e);
@@ -374,15 +395,16 @@ fn default_principal_type() -> PrincipalType {
)
)]
pub async fn confirm_permission(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<PermissionConfirmationRequest>,
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
let agent = state.agent.clone();
let agent = agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
@@ -410,7 +432,7 @@ struct ToolResultRequest {
}
async fn submit_tool_result(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
raw: axum::extract::Json<serde_json::Value>,
) -> Result<Json<Value>, StatusCode> {
@@ -435,14 +457,16 @@ async fn submit_tool_result(
}
};
let agent = state.agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
agent.handle_tool_result(payload.id, payload.result).await;
Ok(Json(json!({"status": "ok"})))
}
// Configure routes for this module
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/reply", post(handler))
.route("/ask", post(ask_handler))
@@ -496,9 +520,7 @@ mod tests {
mod integration_tests {
use super::*;
use axum::{body::Body, http::Request};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tower::ServiceExt;
// This test requires tokio runtime
@@ -509,12 +531,9 @@ mod tests {
let mock_provider = Arc::new(MockProvider {
model_config: mock_model_config,
});
let agent = Agent::new(mock_provider);
let state = AppState {
config: Arc::new(Mutex::new(HashMap::new())),
agent: Arc::new(RwLock::new(Some(agent))),
secret_key: "test-secret".to_string(),
};
let agent = Agent::new();
let _ = agent.update_provider(mock_provider).await;
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
// Build router
let app = routes(state);

View File

@@ -1,4 +1,6 @@
use super::utils::verify_secret_key;
use std::sync::Arc;
use crate::state::AppState;
use axum::{
extract::{Path, State},
@@ -25,7 +27,7 @@ struct SessionHistoryResponse {
// List all available sessions
async fn list_sessions(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<SessionListResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
@@ -38,7 +40,7 @@ async fn list_sessions(
// Get a specific session's history
async fn get_session_history(
State(state): State<AppState>,
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path(session_id): Path<String>,
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
@@ -65,7 +67,7 @@ async fn get_session_history(
}
// Configure routes for this module
pub fn routes(state: AppState) -> Router {
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/sessions", get(list_sessions))
.route("/sessions/:session_id", get(get_session_history))

View File

@@ -1,25 +1,34 @@
use anyhow::Result;
use goose::agents::Agent;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// Shared reference to an Agent that can be cloned cheaply
/// without cloning the underlying Agent object
pub type AgentRef = Arc<Agent>;
/// Thread-safe container for an optional Agent reference
/// Outer Arc: Allows multiple route handlers to access the same Mutex
/// - Mutex provides exclusive access for updates
/// - Option allows for the case where no agent exists yet
///
/// Shared application state
#[allow(dead_code)]
#[derive(Clone)]
pub struct AppState {
pub agent: Arc<RwLock<Option<Agent>>>,
// agent: SharedAgentStore,
agent: Option<AgentRef>,
pub secret_key: String,
pub config: Arc<Mutex<HashMap<String, Value>>>,
}
impl AppState {
pub async fn new(secret_key: String) -> Result<Self> {
Ok(Self {
agent: Arc::new(RwLock::new(None)),
pub async fn new(agent: AgentRef, secret_key: String) -> Arc<AppState> {
Arc::new(Self {
agent: Some(agent.clone()),
secret_key,
config: Arc::new(Mutex::new(HashMap::new())),
})
}
pub async fn get_agent(&self) -> Result<Arc<Agent>, anyhow::Error> {
self.agent
.clone()
.ok_or_else(|| anyhow::anyhow!("Agent needs to be created first."))
}
}

View File

@@ -15,7 +15,8 @@ async fn main() {
let provider = Arc::new(DatabricksProvider::default());
// Setup an agent with the developer extension
let mut agent = Agent::new(provider);
let agent = Agent::new();
let _ = agent.update_provider(provider).await;
let config = ExtensionConfig::stdio(
"developer",

View File

@@ -35,12 +35,11 @@ use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINE
/// The main goose Agent
pub struct Agent {
pub(super) provider: Arc<dyn Provider>,
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
pub(super) extension_manager: Mutex<ExtensionManager>,
pub(super) frontend_tools: HashMap<String, FrontendTool>,
pub(super) frontend_instructions: Option<String>,
pub(super) prompt_manager: PromptManager,
// Channels for tool results and confirmations
pub(super) frontend_tools: Mutex<HashMap<String, FrontendTool>>,
pub(super) frontend_instructions: Mutex<Option<String>>,
pub(super) prompt_manager: Mutex<PromptManager>,
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
@@ -48,41 +47,52 @@ pub struct Agent {
}
impl Agent {
pub fn new(provider: Arc<dyn Provider>) -> Self {
pub fn new() -> Self {
// Create channels with buffer size 32 (adjust if needed)
let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
Self {
provider,
provider: Mutex::new(None),
extension_manager: Mutex::new(ExtensionManager::new()),
frontend_tools: HashMap::new(),
frontend_instructions: None,
prompt_manager: PromptManager::new(),
frontend_tools: Mutex::new(HashMap::new()),
frontend_instructions: Mutex::new(None),
prompt_manager: Mutex::new(PromptManager::new()),
confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
}
}
}
impl Default for Agent {
fn default() -> Self {
Self::new()
}
}
impl Agent {
/// Get a reference count clone to the provider
pub fn provider(&self) -> Arc<dyn Provider> {
Arc::clone(&self.provider)
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
match &*self.provider.lock().await {
Some(provider) => Ok(Arc::clone(provider)),
None => Err(anyhow!("Provider not set")),
}
}
/// Check if a tool is a frontend tool
pub fn is_frontend_tool(&self, name: &str) -> bool {
self.frontend_tools.contains_key(name)
pub async fn is_frontend_tool(&self, name: &str) -> bool {
self.frontend_tools.lock().await.contains_key(name)
}
/// Get a reference to a frontend tool
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
self.frontend_tools.get(name)
pub async fn get_frontend_tool(&self, name: &str) -> Option<FrontendTool> {
self.frontend_tools.lock().await.get(name).cloned()
}
/// Get all tools from all clients with proper prefixing
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
pub async fn get_prefixed_tools(&self) -> ExtensionResult<Vec<Tool>> {
let mut tools = self
.extension_manager
.lock()
@@ -91,7 +101,8 @@ impl Agent {
.await?;
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
for frontend_tool in self.frontend_tools.values() {
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
tools.push(frontend_tool.tool.clone());
}
@@ -135,7 +146,7 @@ impl Agent {
.await
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
extension_manager.search_available_extensions().await
} else if self.is_frontend_tool(&tool_call.name) {
} else if self.is_frontend_tool(&tool_call.name).await {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
@@ -212,7 +223,7 @@ impl Agent {
(request_id, result)
}
pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
pub async fn add_extension(&self, extension: ExtensionConfig) -> ExtensionResult<()> {
match &extension {
ExtensionConfig::Frontend {
name: _,
@@ -221,19 +232,21 @@ impl Agent {
bundled: _,
} => {
// For frontend tools, just store them in the frontend_tools map
let mut frontend_tools = self.frontend_tools.lock().await;
for tool in tools {
let frontend_tool = FrontendTool {
name: tool.name.clone(),
tool: tool.clone(),
};
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
frontend_tools.insert(tool.name.clone(), frontend_tool);
}
// Store instructions if provided, using "frontend" as the key
let mut frontend_instructions = self.frontend_instructions.lock().await;
if let Some(instructions) = instructions {
self.frontend_instructions = Some(instructions.clone());
*frontend_instructions = Some(instructions.clone());
} else {
// Default frontend instructions if none provided
self.frontend_instructions = Some(
*frontend_instructions = Some(
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
);
}
@@ -269,7 +282,7 @@ impl Agent {
prefixed_tools
}
pub async fn remove_extension(&mut self, name: &str) {
pub async fn remove_extension(&self, name: &str) {
let mut extension_manager = self.extension_manager.lock().await;
extension_manager
.remove_extension(name)
@@ -329,7 +342,7 @@ impl Agent {
let _ = reply_span.enter();
loop {
match Self::generate_response_from_provider(
self.provider(),
self.provider().await?,
&system_prompt,
&messages,
&tools,
@@ -345,7 +358,7 @@ impl Agent {
let (frontend_requests,
remaining_requests,
filtered_response) =
self.categorize_tool_requests(&response);
self.categorize_tool_requests(&response).await;
// Yield the assistant's response with frontend tool requests filtered out
@@ -396,8 +409,7 @@ impl Agent {
tools_with_readonly_annotation.clone(),
tools_without_annotation.clone(),
&mut permission_manager,
self.provider(),
).await;
self.provider().await?).await;
// Handle pre-approved and read-only tools in parallel
let mut tool_futures: Vec<ToolFuture> = Vec::new();
@@ -492,13 +504,21 @@ impl Agent {
}
/// Extend the system prompt with one line of additional instruction
pub async fn extend_system_prompt(&mut self, instruction: String) {
self.prompt_manager.add_system_prompt_extra(instruction);
pub async fn extend_system_prompt(&self, instruction: String) {
let mut prompt_manager = self.prompt_manager.lock().await;
prompt_manager.add_system_prompt_extra(instruction);
}
/// Update the provider used by this agent
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
*self.provider.lock().await = Some(provider);
Ok(())
}
/// Override the system prompt with a custom template
pub async fn override_system_prompt(&mut self, template: String) {
self.prompt_manager.set_system_prompt_override(template);
pub async fn override_system_prompt(&self, template: String) {
let mut prompt_manager = self.prompt_manager.lock().await;
prompt_manager.set_system_prompt_override(template);
}
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
@@ -563,23 +583,29 @@ impl Agent {
let extensions_info = extension_manager.get_extensions_info().await;
// Get model name from provider
let model_config = self.provider.get_model_config();
let provider = self.provider().await?;
let model_config = provider.get_model_config();
let model_name = &model_config.model_name;
let system_prompt = self.prompt_manager.build_system_prompt(
let prompt_manager = self.prompt_manager.lock().await;
let system_prompt = prompt_manager.build_system_prompt(
extensions_info,
self.frontend_instructions.clone(),
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
);
let recipe_prompt = self.prompt_manager.get_recipe_prompt().await;
let recipe_prompt = prompt_manager.get_recipe_prompt().await;
let tools = extension_manager.get_prefixed_tools(None).await?;
messages.push(Message::user().with_text(recipe_prompt));
let (result, _usage) = self
.provider
.lock()
.await
.as_ref()
.unwrap()
.complete(&system_prompt, &messages, &tools)
.await?;

View File

@@ -15,7 +15,7 @@ impl Agent {
&self,
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
let provider = self.provider.clone();
let provider = self.provider().await?;
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
let target_context_limit = estimate_target_context_limit(provider);
let token_counts = get_messages_token_counts(&token_counter, messages);
@@ -41,7 +41,7 @@ impl Agent {
&self,
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
let provider = self.provider.clone();
let provider = self.provider().await?;
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
let target_context_limit = estimate_target_context_limit(provider.clone());

View File

@@ -22,7 +22,8 @@ impl Agent {
let mut tools = self.list_tools(None).await;
// Add frontend tools
for frontend_tool in self.frontend_tools.values() {
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
tools.push(frontend_tool.tool.clone());
}
@@ -31,19 +32,21 @@ impl Agent {
let extensions_info = extension_manager.get_extensions_info().await;
// Get model name from provider
let model_config = self.provider.get_model_config();
let provider = self.provider().await?;
let model_config = provider.get_model_config();
let model_name = &model_config.model_name;
let mut system_prompt = self.prompt_manager.build_system_prompt(
let prompt_manager = self.prompt_manager.lock().await;
let mut system_prompt = prompt_manager.build_system_prompt(
extensions_info,
self.frontend_instructions.clone(),
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
);
// Handle toolshim if enabled
let mut toolshim_tools = vec![];
if self.provider.get_model_config().toolshim {
if model_config.toolshim {
// If tool interpretation is enabled, modify the system prompt
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
// Make a copy of tools before emptying
@@ -115,7 +118,7 @@ impl Agent {
/// - frontend_requests: Tool requests that should be handled by the frontend
/// - other_requests: All other tool requests (including requests to enable extensions)
/// - filtered_message: The original message with frontend tool requests removed
pub(crate) fn categorize_tool_requests(
pub(crate) async fn categorize_tool_requests(
&self,
response: &Message,
) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
@@ -133,20 +136,25 @@ impl Agent {
.collect();
// Create a filtered message with frontend tool requests removed
let filtered_content = response
.content
.iter()
.filter(|c| {
if let MessageContent::ToolRequest(req) = c {
// Only filter out frontend tool requests
let mut filtered_content = Vec::new();
// Process each content item one by one
for content in &response.content {
let should_include = match content {
MessageContent::ToolRequest(req) => {
if let Ok(tool_call) = &req.tool_call {
return !self.is_frontend_tool(&tool_call.name);
!self.is_frontend_tool(&tool_call.name).await
} else {
true
}
}
true
})
.cloned()
.collect();
_ => true,
};
if should_include {
filtered_content.push(content.clone());
}
}
let filtered_message = Message {
role: response.role.clone(),
@@ -160,7 +168,7 @@ impl Agent {
for request in tool_requests {
if let Ok(tool_call) = &request.tool_call {
if self.is_frontend_tool(&tool_call.name) {
if self.is_frontend_tool(&tool_call.name).await {
frontend_requests.push(request);
} else {
other_requests.push(request);

View File

@@ -87,7 +87,7 @@ impl Agent {
try_stream! {
for request in tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
if self.is_frontend_tool(&tool_call.name) {
if self.is_frontend_tool(&tool_call.name).await {
// Send frontend tool request and wait for response
yield Message::assistant().with_frontend_tool_request(
request.id.clone(),

View File

@@ -739,10 +739,10 @@ mod tests {
thread::sleep(Duration::from_millis(i * 10));
let extension_key = format!("extension_{}", i);
let mut values = config.load_values()?;
values.insert(
extension_key.clone(),
// Use set_param which handles concurrent access properly
config.set_param(
&extension_key,
serde_json::json!({
"name": format!("test_extension_{}", i),
"version": format!("1.0.{}", i),
@@ -752,10 +752,7 @@ mod tests {
"option2": i
}
}),
);
// Write all values atomically
config.save_values(values)?;
)?;
Ok(())
});
handles.push(handle);

View File

@@ -110,7 +110,8 @@ async fn run_truncate_test(
.with_temperature(Some(0.0));
let provider = provider_type.create_provider(model_config)?;
let agent = Agent::new(provider);
let agent = Agent::new();
agent.update_provider(provider).await?;
let repeat_count = context_window + 10_000;
let large_message_content = "hello ".repeat(repeat_count);
let messages = vec![

View File

@@ -33,12 +33,28 @@ CALCULATOR_TOOL = {
},
}
# Enable Extension tool definition
ENABLE_EXTENSION_TOOL = {
"name": "enable_extension",
"description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.",
"inputSchema": {
"type": "object",
"required": ["extension_name"],
"properties": {
"extension_name": {
"type": "string",
"description": "The name of the extension to enable",
},
},
},
}
# Frontend extension configuration
FRONTEND_CONFIG = {
"name": "pythonclient",
"type": "frontend",
"tools": [CALCULATOR_TOOL],
"instructions": "A calculator extension that can perform basic arithmetic operations.",
"tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL],
"instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.",
}
@@ -47,7 +63,7 @@ async def setup_agent() -> None:
async with httpx.AsyncClient() as client:
# First create the agent
response = await client.post(
f"{GOOSE_URL}/agent",
f"{GOOSE_URL}/agent/update_provider",
json={"provider": "databricks", "model": "goose"},
headers={"X-Secret-Key": SECRET_KEY},
)
@@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
}
]
def get_tools() -> Dict[str, Any]:
with httpx.Client() as client:
response = client.get(
f"{GOOSE_URL}/agent/tools",
headers={"X-Secret-Key": SECRET_KEY},
)
response.raise_for_status()
return response.json()
def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Execute the enable_extension tool.
This function fetches available extensions, finds the one with the provided extension_name,
and posts its configuration to the /extensions/add endpoint.
"""
extension = args
extension_name = extension.get("name")
# Post the extension configuration to enable it
with httpx.Client() as client:
payload = {
"type": extension.get("type"),
"name": extension.get("name"),
"cmd": extension.get("cmd"),
"args": extension.get("args"),
"envs": extension.get("envs", {}),
"timeout": extension.get("timeout"),
"bundled": extension.get("bundled"),
}
add_response = client.post(
f"{GOOSE_URL}/extensions/add",
json=payload,
headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY},
)
if add_response.status_code != 200:
error_text = add_response.text
return [{
"type": "text",
"text": f"Error: Failed to enable extension: {error_text}",
"annotations": None,
}]
return [{
"type": "text",
"text": f"Successfully enabled extension: {extension_name}",
"annotations": None,
}]
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
"""Submit the tool execution result back to Goose.
@@ -129,7 +194,7 @@ async def chat_loop() -> None:
session_id = "test-session"
# Use a client with a longer timeout for streaming
async with httpx.AsyncClient(timeout=30.0) as client:
async with httpx.AsyncClient(timeout=60.0) as client:
# Get user input
user_message = input("\nYou: ")
if user_message.lower() in ["exit", "quit"]:
@@ -152,7 +217,7 @@ async def chat_loop() -> None:
# Process the stream of responses
async with client.stream(
"POST",
f"{GOOSE_URL}/reply",
f"{GOOSE_URL}/reply", # lock
json=payload,
headers={
"X-Secret-Key": SECRET_KEY,
@@ -185,9 +250,26 @@ async def chat_loop() -> None:
elif content["type"] == "frontendToolRequest":
# Execute the tool and submit results
tool_call = content["toolCall"]["value"]
print(f"Calculator: {tool_call}")
# Execute the tool
result = execute_calculator(tool_call["arguments"])
print(f"\nTool Request: {tool_call}")
if tool_call['name'] == "calculator":
print(f"Calculator: {tool_call}")
# Execute the tool
result = execute_calculator(tool_call["arguments"])
elif tool_call['name'] == "enable_extension":
# to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name"
print(f"Enabling fetch extension")
result = execute_enable_extension(args={
"type": "stdio",
"name": "fetch",
"cmd": "uvx",
"args": ["mcp-server-fetch"],
"timeout": 300,
"bundled": False
})
listed_tools = get_tools()
print(f"\nTools after enabling extension: {listed_tools}")
# Submit the result
submit_tool_result(content["id"], result)

View File

@@ -6,7 +6,7 @@ interface initializeAgentProps {
}
export async function initializeAgent({ model, provider }: initializeAgentProps) {
const response = await fetch(getApiUrl('/agent'), {
const response = await fetch(getApiUrl('/agent/update_provider'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',