mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
chore: refactor read-write lock on agent (#2225)
Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
@@ -983,11 +983,11 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
|
|||||||
.get_param("GOOSE_MODEL")
|
.get_param("GOOSE_MODEL")
|
||||||
.expect("No model configured. Please set model first");
|
.expect("No model configured. Please set model first");
|
||||||
let model_config = goose::model::ModelConfig::new(model.clone());
|
let model_config = goose::model::ModelConfig::new(model.clone());
|
||||||
let provider =
|
|
||||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
// Create the agent
|
// Create the agent
|
||||||
let mut agent = Agent::new(provider);
|
let agent = Agent::new();
|
||||||
|
let new_provider = create(&provider_name, model_config)?;
|
||||||
|
agent.update_provider(new_provider).await?;
|
||||||
if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) {
|
if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) {
|
||||||
agent
|
agent
|
||||||
.add_extension(config.clone())
|
.add_extension(config.clone())
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use console::style;
|
|||||||
use goose::agents::extension::ExtensionError;
|
use goose::agents::extension::ExtensionError;
|
||||||
use goose::agents::Agent;
|
use goose::agents::Agent;
|
||||||
use goose::config::{Config, ExtensionConfig, ExtensionConfigManager};
|
use goose::config::{Config, ExtensionConfig, ExtensionConfigManager};
|
||||||
|
use goose::providers::create;
|
||||||
use goose::session;
|
use goose::session;
|
||||||
use goose::session::Identifier;
|
use goose::session::Identifier;
|
||||||
use mcp_client::transport::Error as McpClientError;
|
use mcp_client::transport::Error as McpClientError;
|
||||||
@@ -46,11 +47,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
|||||||
.get_param("GOOSE_MODEL")
|
.get_param("GOOSE_MODEL")
|
||||||
.expect("No model configured. Run 'goose configure' first");
|
.expect("No model configured. Run 'goose configure' first");
|
||||||
let model_config = goose::model::ModelConfig::new(model.clone());
|
let model_config = goose::model::ModelConfig::new(model.clone());
|
||||||
let provider =
|
|
||||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
// Create the agent
|
// Create the agent
|
||||||
let mut agent = Agent::new(provider);
|
let agent: Agent = Agent::new();
|
||||||
|
let new_provider = create(&provider_name, model_config).unwrap();
|
||||||
|
let _ = agent.update_provider(new_provider).await;
|
||||||
|
|
||||||
// Handle session file resolution and resuming
|
// Handle session file resolution and resuming
|
||||||
let session_file = if session_config.resume {
|
let session_file = if session_config.resume {
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ impl Session {
|
|||||||
async fn process_message(&mut self, message: String) -> Result<()> {
|
async fn process_message(&mut self, message: String) -> Result<()> {
|
||||||
self.messages.push(Message::user().with_text(&message));
|
self.messages.push(Message::user().with_text(&message));
|
||||||
// Get the provider from the agent for description generation
|
// Get the provider from the agent for description generation
|
||||||
let provider = self.agent.provider();
|
let provider = self.agent.provider().await?;
|
||||||
|
|
||||||
// Persist messages with provider for automatic description generation
|
// Persist messages with provider for automatic description generation
|
||||||
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
|
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
|
||||||
@@ -357,7 +357,7 @@ impl Session {
|
|||||||
self.messages.push(Message::user().with_text(&content));
|
self.messages.push(Message::user().with_text(&content));
|
||||||
|
|
||||||
// Get the provider from the agent for description generation
|
// Get the provider from the agent for description generation
|
||||||
let provider = self.agent.provider();
|
let provider = self.agent.provider().await?;
|
||||||
|
|
||||||
// Persist messages with provider for automatic description generation
|
// Persist messages with provider for automatic description generation
|
||||||
session::persist_messages(
|
session::persist_messages(
|
||||||
@@ -526,7 +526,7 @@ impl Session {
|
|||||||
output::render_message(&plan_response, self.debug);
|
output::render_message(&plan_response, self.debug);
|
||||||
output::hide_thinking();
|
output::hide_thinking();
|
||||||
let planner_response_type =
|
let planner_response_type =
|
||||||
classify_planner_response(plan_response.as_concat_text(), self.agent.provider())
|
classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
match planner_response_type {
|
match planner_response_type {
|
||||||
|
|||||||
@@ -178,7 +178,10 @@ pub unsafe extern "C" fn goose_agent_new(config: *const ProviderConfigFFI) -> Ag
|
|||||||
// Create Databricks provider with required parameters
|
// Create Databricks provider with required parameters
|
||||||
match DatabricksProvider::from_params(host, api_key, model_config) {
|
match DatabricksProvider::from_params(host, api_key, model_config) {
|
||||||
Ok(provider) => {
|
Ok(provider) => {
|
||||||
let agent = Agent::new(Arc::new(provider));
|
let agent = Agent::new();
|
||||||
|
get_runtime().block_on(async {
|
||||||
|
let _ = agent.update_provider(Arc::new(provider)).await;
|
||||||
|
});
|
||||||
Box::into_raw(Box::new(agent))
|
Box::into_raw(Box::new(agent))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::configuration;
|
use crate::configuration;
|
||||||
use crate::state;
|
use crate::state;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use goose::agents::Agent;
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
@@ -15,8 +18,10 @@ pub async fn run() -> Result<()> {
|
|||||||
let secret_key =
|
let secret_key =
|
||||||
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
|
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
|
||||||
|
|
||||||
// Create app state - agent will start as None
|
let new_agent = Agent::new();
|
||||||
let state = state::AppState::new(secret_key.clone()).await?;
|
|
||||||
|
// Create app state with agent
|
||||||
|
let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await;
|
||||||
|
|
||||||
// Create router with CORS support
|
// Create router with CORS support
|
||||||
let cors = CorsLayer::new()
|
let cors = CorsLayer::new()
|
||||||
|
|||||||
@@ -8,14 +8,15 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use goose::config::Config;
|
use goose::config::Config;
|
||||||
use goose::config::PermissionManager;
|
use goose::config::PermissionManager;
|
||||||
use goose::{agents::Agent, model::ModelConfig, providers};
|
use goose::model::ModelConfig;
|
||||||
|
use goose::providers::create;
|
||||||
use goose::{
|
use goose::{
|
||||||
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
|
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
|
||||||
config::permission::PermissionLevel,
|
config::permission::PermissionLevel,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::env;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct VersionsResponse {
|
struct VersionsResponse {
|
||||||
@@ -33,17 +34,6 @@ struct ExtendPromptResponse {
|
|||||||
success: bool,
|
success: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct CreateAgentRequest {
|
|
||||||
provider: String,
|
|
||||||
model: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct CreateAgentResponse {
|
|
||||||
version: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ProviderFile {
|
struct ProviderFile {
|
||||||
name: String,
|
name: String,
|
||||||
@@ -66,6 +56,12 @@ struct ProviderList {
|
|||||||
details: ProviderDetails,
|
details: ProviderDetails,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UpdateProviderRequest {
|
||||||
|
provider: String,
|
||||||
|
model: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct GetToolsQuery {
|
pub struct GetToolsQuery {
|
||||||
extension_name: Option<String>,
|
extension_name: Option<String>,
|
||||||
@@ -82,53 +78,18 @@ async fn get_versions() -> Json<VersionsResponse> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn extend_prompt(
|
async fn extend_prompt(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(payload): Json<ExtendPromptRequest>,
|
Json(payload): Json<ExtendPromptRequest>,
|
||||||
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
|
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
|
|
||||||
let mut agent = state.agent.write().await;
|
let agent = state
|
||||||
if let Some(ref mut agent) = *agent {
|
.get_agent()
|
||||||
agent.extend_system_prompt(payload.extension).await;
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
|
agent.extend_system_prompt(payload.extension.clone()).await;
|
||||||
Ok(Json(ExtendPromptResponse { success: true }))
|
Ok(Json(ExtendPromptResponse { success: true }))
|
||||||
} else {
|
|
||||||
Err(StatusCode::NOT_FOUND)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[axum::debug_handler]
|
|
||||||
async fn create_agent(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
headers: HeaderMap,
|
|
||||||
Json(payload): Json<CreateAgentRequest>,
|
|
||||||
) -> Result<Json<CreateAgentResponse>, StatusCode> {
|
|
||||||
verify_secret_key(&headers, &state)?;
|
|
||||||
|
|
||||||
// Set the environment variable for the model if provided
|
|
||||||
if let Some(model) = &payload.model {
|
|
||||||
let env_var_key = format!("{}_MODEL", payload.provider.to_uppercase());
|
|
||||||
env::set_var(env_var_key.clone(), model);
|
|
||||||
println!("Set environment variable: {}={}", env_var_key, model);
|
|
||||||
}
|
|
||||||
|
|
||||||
let config = Config::global();
|
|
||||||
let model = payload.model.unwrap_or_else(|| {
|
|
||||||
config
|
|
||||||
.get_param("GOOSE_MODEL")
|
|
||||||
.expect("Did not find a model on payload or in env")
|
|
||||||
});
|
|
||||||
let model_config = ModelConfig::new(model);
|
|
||||||
let provider =
|
|
||||||
providers::create(&payload.provider, model_config).expect("Failed to create provider");
|
|
||||||
|
|
||||||
let version = String::from("goose");
|
|
||||||
let new_agent = Agent::new(provider);
|
|
||||||
|
|
||||||
let mut agent = state.agent.write().await;
|
|
||||||
*agent = Some(new_agent);
|
|
||||||
|
|
||||||
Ok(Json(CreateAgentResponse { version }))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list_providers() -> Json<Vec<ProviderList>> {
|
async fn list_providers() -> Json<Vec<ProviderList>> {
|
||||||
@@ -168,7 +129,7 @@ async fn list_providers() -> Json<Vec<ProviderList>> {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn get_tools(
|
async fn get_tools(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Query(query): Query<GetToolsQuery>,
|
Query(query): Query<GetToolsQuery>,
|
||||||
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
|
) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
|
||||||
@@ -176,8 +137,10 @@ async fn get_tools(
|
|||||||
|
|
||||||
let config = Config::global();
|
let config = Config::global();
|
||||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||||
let agent = state.agent.read().await;
|
let agent = state
|
||||||
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
let permission_manager = PermissionManager::default();
|
let permission_manager = PermissionManager::default();
|
||||||
|
|
||||||
let mut tools: Vec<ToolInfo> = agent
|
let mut tools: Vec<ToolInfo> = agent
|
||||||
@@ -210,12 +173,56 @@ async fn get_tools(
|
|||||||
Ok(Json(tools))
|
Ok(Json(tools))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes(state: AppState) -> Router {
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/agent/update_provider",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Update provider completed", body = String),
|
||||||
|
(status = 500, description = "Internal server error")
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn update_agent_provider(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
Json(payload): Json<UpdateProviderRequest>,
|
||||||
|
) -> Result<StatusCode, StatusCode> {
|
||||||
|
// Verify secret key
|
||||||
|
let secret_key = headers
|
||||||
|
.get("X-Secret-Key")
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||||
|
|
||||||
|
if secret_key != state.secret_key {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
let agent = state
|
||||||
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
|
|
||||||
|
let config = Config::global();
|
||||||
|
let model = payload.model.unwrap_or_else(|| {
|
||||||
|
config
|
||||||
|
.get_param("GOOSE_MODEL")
|
||||||
|
.expect("Did not find a model on payload or in env to update provider with")
|
||||||
|
});
|
||||||
|
let model_config = ModelConfig::new(model);
|
||||||
|
let new_provider = create(&payload.provider, model_config).unwrap();
|
||||||
|
agent
|
||||||
|
.update_provider(new_provider)
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||||
|
|
||||||
|
Ok(StatusCode::OK)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/agent/versions", get(get_versions))
|
.route("/agent/versions", get(get_versions))
|
||||||
.route("/agent/providers", get(list_providers))
|
.route("/agent/providers", get(list_providers))
|
||||||
.route("/agent/prompt", post(extend_prompt))
|
.route("/agent/prompt", post(extend_prompt))
|
||||||
.route("/agent/tools", get(get_tools))
|
.route("/agent/tools", get(get_tools))
|
||||||
.route("/agent", post(create_agent))
|
.route("/agent/update_provider", post(update_agent_provider))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ use once_cell::sync::Lazy;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use serde_yaml;
|
use serde_yaml;
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
@@ -89,7 +89,7 @@ pub struct UpsertPermissionsQuery {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn upsert_config(
|
pub async fn upsert_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(query): Json<UpsertConfigQuery>,
|
Json(query): Json<UpsertConfigQuery>,
|
||||||
) -> Result<Json<Value>, StatusCode> {
|
) -> Result<Json<Value>, StatusCode> {
|
||||||
@@ -116,7 +116,7 @@ pub async fn upsert_config(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn remove_config(
|
pub async fn remove_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(query): Json<ConfigKeyQuery>,
|
Json(query): Json<ConfigKeyQuery>,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
@@ -148,7 +148,7 @@ pub async fn remove_config(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn read_config(
|
pub async fn read_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(query): Json<ConfigKeyQuery>,
|
Json(query): Json<ConfigKeyQuery>,
|
||||||
) -> Result<Json<Value>, StatusCode> {
|
) -> Result<Json<Value>, StatusCode> {
|
||||||
@@ -180,7 +180,7 @@ pub async fn read_config(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn get_extensions(
|
pub async fn get_extensions(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
@@ -213,7 +213,7 @@ pub async fn get_extensions(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn add_extension(
|
pub async fn add_extension(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(extension_query): Json<ExtensionQuery>,
|
Json(extension_query): Json<ExtensionQuery>,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
@@ -251,7 +251,7 @@ pub async fn add_extension(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn remove_extension(
|
pub async fn remove_extension(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
axum::extract::Path(name): axum::extract::Path<String>,
|
axum::extract::Path(name): axum::extract::Path<String>,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
@@ -272,7 +272,7 @@ pub async fn remove_extension(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn read_all_config(
|
pub async fn read_all_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||||
// Use the helper function to verify the secret key
|
// Use the helper function to verify the secret key
|
||||||
@@ -297,7 +297,7 @@ pub async fn read_all_config(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn providers(
|
pub async fn providers(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
|
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
@@ -332,7 +332,7 @@ pub async fn providers(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn init_config(
|
pub async fn init_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
@@ -402,7 +402,7 @@ pub async fn init_config(
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn upsert_permissions(
|
pub async fn upsert_permissions(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(query): Json<UpsertPermissionsQuery>,
|
Json(query): Json<UpsertPermissionsQuery>,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
@@ -435,7 +435,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn backup_config(
|
pub async fn backup_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<String>, StatusCode> {
|
) -> Result<Json<String>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
@@ -466,7 +466,7 @@ pub async fn backup_config(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/config", get(read_all_config))
|
.route("/config", get(read_all_config))
|
||||||
.route("/config/upsert", post(upsert_config))
|
.route("/config/upsert", post(upsert_config))
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ use http::{HeaderMap, StatusCode};
|
|||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct ConfigResponse {
|
struct ConfigResponse {
|
||||||
@@ -26,7 +26,7 @@ struct ConfigRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn store_config(
|
async fn store_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<ConfigRequest>,
|
Json(request): Json<ConfigRequest>,
|
||||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||||
@@ -148,7 +148,7 @@ pub struct GetConfigResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_config(
|
pub async fn get_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Query(query): Query<GetConfigQuery>,
|
Query(query): Query<GetConfigQuery>,
|
||||||
) -> Result<Json<GetConfigResponse>, StatusCode> {
|
) -> Result<Json<GetConfigResponse>, StatusCode> {
|
||||||
@@ -174,7 +174,7 @@ struct DeleteConfigRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn delete_config(
|
async fn delete_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<DeleteConfigRequest>,
|
Json(request): Json<DeleteConfigRequest>,
|
||||||
) -> Result<StatusCode, StatusCode> {
|
) -> Result<StatusCode, StatusCode> {
|
||||||
@@ -193,7 +193,7 @@ async fn delete_config(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/configs/providers", post(check_provider_configs))
|
.route("/configs/providers", post(check_provider_configs))
|
||||||
.route("/configs/get", get(get_config))
|
.route("/configs/get", get(get_config))
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use goose::message::Message;
|
use goose::message::Message;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
// Direct message serialization for context mgmt request
|
// Direct message serialization for context mgmt request
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -26,15 +27,16 @@ pub struct ContextManageResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn manage_context(
|
async fn manage_context(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<ContextManageRequest>,
|
Json(request): Json<ContextManageRequest>,
|
||||||
) -> Result<Json<ContextManageResponse>, StatusCode> {
|
) -> Result<Json<ContextManageResponse>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
|
|
||||||
// Get a lock on the shared agent
|
let agent = state
|
||||||
let agent = state.agent.read().await;
|
.get_agent()
|
||||||
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
|
|
||||||
let mut processed_messages: Vec<Message> = vec![];
|
let mut processed_messages: Vec<Message> = vec![];
|
||||||
let mut token_counts: Vec<usize> = vec![];
|
let mut token_counts: Vec<usize> = vec![];
|
||||||
@@ -57,7 +59,7 @@ async fn manage_context(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Configure routes for this module
|
// Configure routes for this module
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/context/manage", post(manage_context))
|
.route("/context/manage", post(manage_context))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use std::env;
|
use std::env;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::sync::OnceLock;
|
use std::sync::OnceLock;
|
||||||
|
|
||||||
use super::utils::verify_secret_key;
|
use super::utils::verify_secret_key;
|
||||||
@@ -79,7 +80,7 @@ struct ExtensionResponse {
|
|||||||
|
|
||||||
/// Handler for adding a new extension configuration.
|
/// Handler for adding a new extension configuration.
|
||||||
async fn add_extension(
|
async fn add_extension(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
raw: axum::extract::Json<serde_json::Value>,
|
raw: axum::extract::Json<serde_json::Value>,
|
||||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||||
@@ -228,9 +229,11 @@ async fn add_extension(
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Acquire a lock on the agent and attempt to add the extension.
|
// Get a reference to the agent
|
||||||
let mut agent = state.agent.write().await;
|
let agent = state
|
||||||
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
let response = agent.add_extension(extension_config).await;
|
let response = agent.add_extension(extension_config).await;
|
||||||
|
|
||||||
// Respond with the result.
|
// Respond with the result.
|
||||||
@@ -254,15 +257,17 @@ async fn add_extension(
|
|||||||
|
|
||||||
/// Handler for removing an extension by name
|
/// Handler for removing an extension by name
|
||||||
async fn remove_extension(
|
async fn remove_extension(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(name): Json<String>,
|
Json(name): Json<String>,
|
||||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
|
|
||||||
// Acquire a lock on the agent and attempt to remove the extension
|
// Get a reference to the agent
|
||||||
let mut agent = state.agent.write().await;
|
let agent = state
|
||||||
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
agent.remove_extension(&name).await;
|
agent.remove_extension(&name).await;
|
||||||
|
|
||||||
Ok(Json(ExtensionResponse {
|
Ok(Json(ExtensionResponse {
|
||||||
@@ -272,7 +277,7 @@ async fn remove_extension(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Registers the extension management routes with the Axum router.
|
/// Registers the extension management routes with the Axum router.
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/extensions/add", post(add_extension))
|
.route("/extensions/add", post(add_extension))
|
||||||
.route("/extensions/remove", post(remove_extension))
|
.route("/extensions/remove", post(remove_extension))
|
||||||
|
|||||||
@@ -9,10 +9,12 @@ pub mod recipe;
|
|||||||
pub mod reply;
|
pub mod reply;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
|
||||||
// Function to configure all routes
|
// Function to configure all routes
|
||||||
pub fn configure(state: crate::state::AppState) -> Router {
|
pub fn configure(state: Arc<crate::state::AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.merge(health::routes())
|
.merge(health::routes())
|
||||||
.merge(reply::routes(state.clone()))
|
.merge(reply::routes(state.clone()))
|
||||||
@@ -22,5 +24,5 @@ pub fn configure(state: crate::state::AppState) -> Router {
|
|||||||
.merge(configs::routes(state.clone()))
|
.merge(configs::routes(state.clone()))
|
||||||
.merge(config_management::routes(state.clone()))
|
.merge(config_management::routes(state.clone()))
|
||||||
.merge(recipe::routes(state.clone()))
|
.merge(recipe::routes(state.clone()))
|
||||||
.merge(session::routes(state))
|
.merge(session::routes(state.clone()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
|
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
|
||||||
use goose::message::Message;
|
use goose::message::Message;
|
||||||
use goose::recipe::Recipe;
|
use goose::recipe::Recipe;
|
||||||
@@ -34,17 +36,17 @@ pub struct CreateRecipeResponse {
|
|||||||
|
|
||||||
/// Create a Recipe configuration from the current state of an agent
|
/// Create a Recipe configuration from the current state of an agent
|
||||||
async fn create_recipe(
|
async fn create_recipe(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(request): Json<CreateRecipeRequest>,
|
Json(request): Json<CreateRecipeRequest>,
|
||||||
) -> Result<Json<CreateRecipeResponse>, (StatusCode, Json<CreateRecipeResponse>)> {
|
) -> Result<Json<CreateRecipeResponse>, (StatusCode, Json<CreateRecipeResponse>)> {
|
||||||
let agent = state.agent.read().await;
|
|
||||||
let agent = agent.as_ref().ok_or_else(|| {
|
|
||||||
let error_response = CreateRecipeResponse {
|
let error_response = CreateRecipeResponse {
|
||||||
recipe: None,
|
recipe: None,
|
||||||
error: Some("Agent not initialized".to_string()),
|
error: Some("Missing agent".to_string()),
|
||||||
};
|
};
|
||||||
(StatusCode::PRECONDITION_REQUIRED, Json(error_response))
|
let agent = state
|
||||||
})?;
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| (StatusCode::PRECONDITION_FAILED, Json(error_response)))?;
|
||||||
|
|
||||||
// Create base recipe from agent state and messages
|
// Create base recipe from agent state and messages
|
||||||
let recipe_result = agent.create_recipe(request.messages).await;
|
let recipe_result = agent.create_recipe(request.messages).await;
|
||||||
@@ -82,7 +84,7 @@ async fn create_recipe(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/recipe/create", post(create_recipe))
|
.route("/recipe/create", post(create_recipe))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ use std::{
|
|||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
@@ -101,7 +102,7 @@ async fn stream_event(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<ChatRequest>,
|
Json(request): Json<ChatRequest>,
|
||||||
) -> Result<SseResponse, StatusCode> {
|
) -> Result<SseResponse, StatusCode> {
|
||||||
@@ -119,15 +120,34 @@ async fn handler(
|
|||||||
.session_id
|
.session_id
|
||||||
.unwrap_or_else(session::generate_session_id);
|
.unwrap_or_else(session::generate_session_id);
|
||||||
|
|
||||||
// Get a lock on the shared agent
|
|
||||||
let agent = state.agent.clone();
|
|
||||||
|
|
||||||
// Spawn task to handle streaming
|
// Spawn task to handle streaming
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let agent = agent.read().await;
|
let agent = state.get_agent().await;
|
||||||
let agent = match agent.as_ref() {
|
let agent = match agent {
|
||||||
Some(agent) => agent,
|
Ok(agent) => {
|
||||||
None => {
|
let provider = agent.provider().await;
|
||||||
|
match provider {
|
||||||
|
Ok(_) => agent,
|
||||||
|
Err(_) => {
|
||||||
|
let _ = stream_event(
|
||||||
|
MessageEvent::Error {
|
||||||
|
error: "No provider configured".to_string(),
|
||||||
|
},
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let _ = stream_event(
|
||||||
|
MessageEvent::Finish {
|
||||||
|
reason: "error".to_string(),
|
||||||
|
},
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
let _ = stream_event(
|
let _ = stream_event(
|
||||||
MessageEvent::Error {
|
MessageEvent::Error {
|
||||||
error: "No agent configured".to_string(),
|
error: "No agent configured".to_string(),
|
||||||
@@ -147,7 +167,7 @@ async fn handler(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Get the provider first, before starting the reply stream
|
// Get the provider first, before starting the reply stream
|
||||||
let provider = agent.provider();
|
let provider = agent.provider().await;
|
||||||
|
|
||||||
let mut stream = match agent
|
let mut stream = match agent
|
||||||
.reply(
|
.reply(
|
||||||
@@ -204,7 +224,7 @@ async fn handler(
|
|||||||
// Store messages and generate description in background
|
// Store messages and generate description in background
|
||||||
let session_path = session_path.clone();
|
let session_path = session_path.clone();
|
||||||
let messages = all_messages.clone();
|
let messages = all_messages.clone();
|
||||||
let provider = provider.clone();
|
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||||
tracing::error!("Failed to store session history: {:?}", e);
|
tracing::error!("Failed to store session history: {:?}", e);
|
||||||
@@ -262,7 +282,7 @@ struct AskResponse {
|
|||||||
|
|
||||||
// Simple ask an AI for a response, non streaming
|
// Simple ask an AI for a response, non streaming
|
||||||
async fn ask_handler(
|
async fn ask_handler(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<AskRequest>,
|
Json(request): Json<AskRequest>,
|
||||||
) -> Result<Json<AskResponse>, StatusCode> {
|
) -> Result<Json<AskResponse>, StatusCode> {
|
||||||
@@ -275,12 +295,13 @@ async fn ask_handler(
|
|||||||
.session_id
|
.session_id
|
||||||
.unwrap_or_else(session::generate_session_id);
|
.unwrap_or_else(session::generate_session_id);
|
||||||
|
|
||||||
let agent = state.agent.clone();
|
let agent = state
|
||||||
let agent = agent.write().await;
|
.get_agent()
|
||||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
|
|
||||||
// Get the provider first, before starting the reply stream
|
// Get the provider first, before starting the reply stream
|
||||||
let provider = agent.provider();
|
let provider = agent.provider().await;
|
||||||
|
|
||||||
// Create a single message for the prompt
|
// Create a single message for the prompt
|
||||||
let messages = vec![Message::user().with_text(request.prompt)];
|
let messages = vec![Message::user().with_text(request.prompt)];
|
||||||
@@ -339,7 +360,7 @@ async fn ask_handler(
|
|||||||
// Store messages and generate description in background
|
// Store messages and generate description in background
|
||||||
let session_path = session_path.clone();
|
let session_path = session_path.clone();
|
||||||
let messages = all_messages.clone();
|
let messages = all_messages.clone();
|
||||||
let provider = provider.clone();
|
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||||
tracing::error!("Failed to store session history: {:?}", e);
|
tracing::error!("Failed to store session history: {:?}", e);
|
||||||
@@ -374,15 +395,16 @@ fn default_principal_type() -> PrincipalType {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn confirm_permission(
|
pub async fn confirm_permission(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<PermissionConfirmationRequest>,
|
Json(request): Json<PermissionConfirmationRequest>,
|
||||||
) -> Result<Json<Value>, StatusCode> {
|
) -> Result<Json<Value>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
|
|
||||||
let agent = state.agent.clone();
|
let agent = state
|
||||||
let agent = agent.read().await;
|
.get_agent()
|
||||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
|
|
||||||
let permission = match request.action.as_str() {
|
let permission = match request.action.as_str() {
|
||||||
"always_allow" => Permission::AlwaysAllow,
|
"always_allow" => Permission::AlwaysAllow,
|
||||||
@@ -410,7 +432,7 @@ struct ToolResultRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn submit_tool_result(
|
async fn submit_tool_result(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
raw: axum::extract::Json<serde_json::Value>,
|
raw: axum::extract::Json<serde_json::Value>,
|
||||||
) -> Result<Json<Value>, StatusCode> {
|
) -> Result<Json<Value>, StatusCode> {
|
||||||
@@ -435,14 +457,16 @@ async fn submit_tool_result(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let agent = state.agent.read().await;
|
let agent = state
|
||||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
.get_agent()
|
||||||
|
.await
|
||||||
|
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||||
agent.handle_tool_result(payload.id, payload.result).await;
|
agent.handle_tool_result(payload.id, payload.result).await;
|
||||||
Ok(Json(json!({"status": "ok"})))
|
Ok(Json(json!({"status": "ok"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure routes for this module
|
// Configure routes for this module
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/reply", post(handler))
|
.route("/reply", post(handler))
|
||||||
.route("/ask", post(ask_handler))
|
.route("/ask", post(ask_handler))
|
||||||
@@ -496,9 +520,7 @@ mod tests {
|
|||||||
mod integration_tests {
|
mod integration_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use axum::{body::Body, http::Request};
|
use axum::{body::Body, http::Request};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, RwLock};
|
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
|
|
||||||
// This test requires tokio runtime
|
// This test requires tokio runtime
|
||||||
@@ -509,12 +531,9 @@ mod tests {
|
|||||||
let mock_provider = Arc::new(MockProvider {
|
let mock_provider = Arc::new(MockProvider {
|
||||||
model_config: mock_model_config,
|
model_config: mock_model_config,
|
||||||
});
|
});
|
||||||
let agent = Agent::new(mock_provider);
|
let agent = Agent::new();
|
||||||
let state = AppState {
|
let _ = agent.update_provider(mock_provider).await;
|
||||||
config: Arc::new(Mutex::new(HashMap::new())),
|
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
|
||||||
agent: Arc::new(RwLock::new(Some(agent))),
|
|
||||||
secret_key: "test-secret".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build router
|
// Build router
|
||||||
let app = routes(state);
|
let app = routes(state);
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use super::utils::verify_secret_key;
|
use super::utils::verify_secret_key;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{Path, State},
|
||||||
@@ -25,7 +27,7 @@ struct SessionHistoryResponse {
|
|||||||
|
|
||||||
// List all available sessions
|
// List all available sessions
|
||||||
async fn list_sessions(
|
async fn list_sessions(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
) -> Result<Json<SessionListResponse>, StatusCode> {
|
) -> Result<Json<SessionListResponse>, StatusCode> {
|
||||||
verify_secret_key(&headers, &state)?;
|
verify_secret_key(&headers, &state)?;
|
||||||
@@ -38,7 +40,7 @@ async fn list_sessions(
|
|||||||
|
|
||||||
// Get a specific session's history
|
// Get a specific session's history
|
||||||
async fn get_session_history(
|
async fn get_session_history(
|
||||||
State(state): State<AppState>,
|
State(state): State<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Path(session_id): Path<String>,
|
Path(session_id): Path<String>,
|
||||||
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
|
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
|
||||||
@@ -65,7 +67,7 @@ async fn get_session_history(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Configure routes for this module
|
// Configure routes for this module
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: Arc<AppState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/sessions", get(list_sessions))
|
.route("/sessions", get(list_sessions))
|
||||||
.route("/sessions/:session_id", get(get_session_history))
|
.route("/sessions/:session_id", get(get_session_history))
|
||||||
|
|||||||
@@ -1,25 +1,34 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use goose::agents::Agent;
|
use goose::agents::Agent;
|
||||||
use serde_json::Value;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, RwLock};
|
|
||||||
|
|
||||||
|
/// Shared reference to an Agent that can be cloned cheaply
|
||||||
|
/// without cloning the underlying Agent object
|
||||||
|
pub type AgentRef = Arc<Agent>;
|
||||||
|
|
||||||
|
/// Thread-safe container for an optional Agent reference
|
||||||
|
/// Outer Arc: Allows multiple route handlers to access the same Mutex
|
||||||
|
/// - Mutex provides exclusive access for updates
|
||||||
|
/// - Option allows for the case where no agent exists yet
|
||||||
|
///
|
||||||
/// Shared application state
|
/// Shared application state
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub agent: Arc<RwLock<Option<Agent>>>,
|
// agent: SharedAgentStore,
|
||||||
|
agent: Option<AgentRef>,
|
||||||
pub secret_key: String,
|
pub secret_key: String,
|
||||||
pub config: Arc<Mutex<HashMap<String, Value>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
pub async fn new(secret_key: String) -> Result<Self> {
|
pub async fn new(agent: AgentRef, secret_key: String) -> Arc<AppState> {
|
||||||
Ok(Self {
|
Arc::new(Self {
|
||||||
agent: Arc::new(RwLock::new(None)),
|
agent: Some(agent.clone()),
|
||||||
secret_key,
|
secret_key,
|
||||||
config: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_agent(&self) -> Result<Arc<Agent>, anyhow::Error> {
|
||||||
|
self.agent
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Agent needs to be created first."))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ async fn main() {
|
|||||||
let provider = Arc::new(DatabricksProvider::default());
|
let provider = Arc::new(DatabricksProvider::default());
|
||||||
|
|
||||||
// Setup an agent with the developer extension
|
// Setup an agent with the developer extension
|
||||||
let mut agent = Agent::new(provider);
|
let agent = Agent::new();
|
||||||
|
let _ = agent.update_provider(provider).await;
|
||||||
|
|
||||||
let config = ExtensionConfig::stdio(
|
let config = ExtensionConfig::stdio(
|
||||||
"developer",
|
"developer",
|
||||||
|
|||||||
@@ -35,12 +35,11 @@ use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINE
|
|||||||
|
|
||||||
/// The main goose Agent
|
/// The main goose Agent
|
||||||
pub struct Agent {
|
pub struct Agent {
|
||||||
pub(super) provider: Arc<dyn Provider>,
|
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
|
||||||
pub(super) extension_manager: Mutex<ExtensionManager>,
|
pub(super) extension_manager: Mutex<ExtensionManager>,
|
||||||
pub(super) frontend_tools: HashMap<String, FrontendTool>,
|
pub(super) frontend_tools: Mutex<HashMap<String, FrontendTool>>,
|
||||||
pub(super) frontend_instructions: Option<String>,
|
pub(super) frontend_instructions: Mutex<Option<String>>,
|
||||||
pub(super) prompt_manager: PromptManager,
|
pub(super) prompt_manager: Mutex<PromptManager>,
|
||||||
// Channels for tool results and confirmations
|
|
||||||
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||||
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||||
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||||
@@ -48,41 +47,52 @@ pub struct Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
pub fn new(provider: Arc<dyn Provider>) -> Self {
|
pub fn new() -> Self {
|
||||||
// Create channels with buffer size 32 (adjust if needed)
|
// Create channels with buffer size 32 (adjust if needed)
|
||||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
provider,
|
provider: Mutex::new(None),
|
||||||
extension_manager: Mutex::new(ExtensionManager::new()),
|
extension_manager: Mutex::new(ExtensionManager::new()),
|
||||||
frontend_tools: HashMap::new(),
|
frontend_tools: Mutex::new(HashMap::new()),
|
||||||
frontend_instructions: None,
|
frontend_instructions: Mutex::new(None),
|
||||||
prompt_manager: PromptManager::new(),
|
prompt_manager: Mutex::new(PromptManager::new()),
|
||||||
confirmation_tx: confirm_tx,
|
confirmation_tx: confirm_tx,
|
||||||
confirmation_rx: Mutex::new(confirm_rx),
|
confirmation_rx: Mutex::new(confirm_rx),
|
||||||
tool_result_tx: tool_tx,
|
tool_result_tx: tool_tx,
|
||||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Agent {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Agent {
|
||||||
/// Get a reference count clone to the provider
|
/// Get a reference count clone to the provider
|
||||||
pub fn provider(&self) -> Arc<dyn Provider> {
|
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||||
Arc::clone(&self.provider)
|
match &*self.provider.lock().await {
|
||||||
|
Some(provider) => Ok(Arc::clone(provider)),
|
||||||
|
None => Err(anyhow!("Provider not set")),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a tool is a frontend tool
|
/// Check if a tool is a frontend tool
|
||||||
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
pub async fn is_frontend_tool(&self, name: &str) -> bool {
|
||||||
self.frontend_tools.contains_key(name)
|
self.frontend_tools.lock().await.contains_key(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a reference to a frontend tool
|
/// Get a reference to a frontend tool
|
||||||
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
pub async fn get_frontend_tool(&self, name: &str) -> Option<FrontendTool> {
|
||||||
self.frontend_tools.get(name)
|
self.frontend_tools.lock().await.get(name).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all tools from all clients with proper prefixing
|
/// Get all tools from all clients with proper prefixing
|
||||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
pub async fn get_prefixed_tools(&self) -> ExtensionResult<Vec<Tool>> {
|
||||||
let mut tools = self
|
let mut tools = self
|
||||||
.extension_manager
|
.extension_manager
|
||||||
.lock()
|
.lock()
|
||||||
@@ -91,7 +101,8 @@ impl Agent {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||||
for frontend_tool in self.frontend_tools.values() {
|
let frontend_tools = self.frontend_tools.lock().await;
|
||||||
|
for frontend_tool in frontend_tools.values() {
|
||||||
tools.push(frontend_tool.tool.clone());
|
tools.push(frontend_tool.tool.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +146,7 @@ impl Agent {
|
|||||||
.await
|
.await
|
||||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||||
extension_manager.search_available_extensions().await
|
extension_manager.search_available_extensions().await
|
||||||
} else if self.is_frontend_tool(&tool_call.name) {
|
} else if self.is_frontend_tool(&tool_call.name).await {
|
||||||
// For frontend tools, return an error indicating we need frontend execution
|
// For frontend tools, return an error indicating we need frontend execution
|
||||||
Err(ToolError::ExecutionError(
|
Err(ToolError::ExecutionError(
|
||||||
"Frontend tool execution required".to_string(),
|
"Frontend tool execution required".to_string(),
|
||||||
@@ -212,7 +223,7 @@ impl Agent {
|
|||||||
(request_id, result)
|
(request_id, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
pub async fn add_extension(&self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||||
match &extension {
|
match &extension {
|
||||||
ExtensionConfig::Frontend {
|
ExtensionConfig::Frontend {
|
||||||
name: _,
|
name: _,
|
||||||
@@ -221,19 +232,21 @@ impl Agent {
|
|||||||
bundled: _,
|
bundled: _,
|
||||||
} => {
|
} => {
|
||||||
// For frontend tools, just store them in the frontend_tools map
|
// For frontend tools, just store them in the frontend_tools map
|
||||||
|
let mut frontend_tools = self.frontend_tools.lock().await;
|
||||||
for tool in tools {
|
for tool in tools {
|
||||||
let frontend_tool = FrontendTool {
|
let frontend_tool = FrontendTool {
|
||||||
name: tool.name.clone(),
|
name: tool.name.clone(),
|
||||||
tool: tool.clone(),
|
tool: tool.clone(),
|
||||||
};
|
};
|
||||||
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||||
}
|
}
|
||||||
// Store instructions if provided, using "frontend" as the key
|
// Store instructions if provided, using "frontend" as the key
|
||||||
|
let mut frontend_instructions = self.frontend_instructions.lock().await;
|
||||||
if let Some(instructions) = instructions {
|
if let Some(instructions) = instructions {
|
||||||
self.frontend_instructions = Some(instructions.clone());
|
*frontend_instructions = Some(instructions.clone());
|
||||||
} else {
|
} else {
|
||||||
// Default frontend instructions if none provided
|
// Default frontend instructions if none provided
|
||||||
self.frontend_instructions = Some(
|
*frontend_instructions = Some(
|
||||||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
|
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -269,7 +282,7 @@ impl Agent {
|
|||||||
prefixed_tools
|
prefixed_tools
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn remove_extension(&mut self, name: &str) {
|
pub async fn remove_extension(&self, name: &str) {
|
||||||
let mut extension_manager = self.extension_manager.lock().await;
|
let mut extension_manager = self.extension_manager.lock().await;
|
||||||
extension_manager
|
extension_manager
|
||||||
.remove_extension(name)
|
.remove_extension(name)
|
||||||
@@ -329,7 +342,7 @@ impl Agent {
|
|||||||
let _ = reply_span.enter();
|
let _ = reply_span.enter();
|
||||||
loop {
|
loop {
|
||||||
match Self::generate_response_from_provider(
|
match Self::generate_response_from_provider(
|
||||||
self.provider(),
|
self.provider().await?,
|
||||||
&system_prompt,
|
&system_prompt,
|
||||||
&messages,
|
&messages,
|
||||||
&tools,
|
&tools,
|
||||||
@@ -345,7 +358,7 @@ impl Agent {
|
|||||||
let (frontend_requests,
|
let (frontend_requests,
|
||||||
remaining_requests,
|
remaining_requests,
|
||||||
filtered_response) =
|
filtered_response) =
|
||||||
self.categorize_tool_requests(&response);
|
self.categorize_tool_requests(&response).await;
|
||||||
|
|
||||||
|
|
||||||
// Yield the assistant's response with frontend tool requests filtered out
|
// Yield the assistant's response with frontend tool requests filtered out
|
||||||
@@ -396,8 +409,7 @@ impl Agent {
|
|||||||
tools_with_readonly_annotation.clone(),
|
tools_with_readonly_annotation.clone(),
|
||||||
tools_without_annotation.clone(),
|
tools_without_annotation.clone(),
|
||||||
&mut permission_manager,
|
&mut permission_manager,
|
||||||
self.provider(),
|
self.provider().await?).await;
|
||||||
).await;
|
|
||||||
|
|
||||||
// Handle pre-approved and read-only tools in parallel
|
// Handle pre-approved and read-only tools in parallel
|
||||||
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
||||||
@@ -492,13 +504,21 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Extend the system prompt with one line of additional instruction
|
/// Extend the system prompt with one line of additional instruction
|
||||||
pub async fn extend_system_prompt(&mut self, instruction: String) {
|
pub async fn extend_system_prompt(&self, instruction: String) {
|
||||||
self.prompt_manager.add_system_prompt_extra(instruction);
|
let mut prompt_manager = self.prompt_manager.lock().await;
|
||||||
|
prompt_manager.add_system_prompt_extra(instruction);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the provider used by this agent
|
||||||
|
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||||
|
*self.provider.lock().await = Some(provider);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Override the system prompt with a custom template
|
/// Override the system prompt with a custom template
|
||||||
pub async fn override_system_prompt(&mut self, template: String) {
|
pub async fn override_system_prompt(&self, template: String) {
|
||||||
self.prompt_manager.set_system_prompt_override(template);
|
let mut prompt_manager = self.prompt_manager.lock().await;
|
||||||
|
prompt_manager.set_system_prompt_override(template);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||||
@@ -563,23 +583,29 @@ impl Agent {
|
|||||||
let extensions_info = extension_manager.get_extensions_info().await;
|
let extensions_info = extension_manager.get_extensions_info().await;
|
||||||
|
|
||||||
// Get model name from provider
|
// Get model name from provider
|
||||||
let model_config = self.provider.get_model_config();
|
let provider = self.provider().await?;
|
||||||
|
let model_config = provider.get_model_config();
|
||||||
let model_name = &model_config.model_name;
|
let model_name = &model_config.model_name;
|
||||||
|
|
||||||
let system_prompt = self.prompt_manager.build_system_prompt(
|
let prompt_manager = self.prompt_manager.lock().await;
|
||||||
|
let system_prompt = prompt_manager.build_system_prompt(
|
||||||
extensions_info,
|
extensions_info,
|
||||||
self.frontend_instructions.clone(),
|
self.frontend_instructions.lock().await.clone(),
|
||||||
extension_manager.suggest_disable_extensions_prompt().await,
|
extension_manager.suggest_disable_extensions_prompt().await,
|
||||||
Some(model_name),
|
Some(model_name),
|
||||||
);
|
);
|
||||||
|
|
||||||
let recipe_prompt = self.prompt_manager.get_recipe_prompt().await;
|
let recipe_prompt = prompt_manager.get_recipe_prompt().await;
|
||||||
let tools = extension_manager.get_prefixed_tools(None).await?;
|
let tools = extension_manager.get_prefixed_tools(None).await?;
|
||||||
|
|
||||||
messages.push(Message::user().with_text(recipe_prompt));
|
messages.push(Message::user().with_text(recipe_prompt));
|
||||||
|
|
||||||
let (result, _usage) = self
|
let (result, _usage) = self
|
||||||
.provider
|
.provider
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
.complete(&system_prompt, &messages, &tools)
|
.complete(&system_prompt, &messages, &tools)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ impl Agent {
|
|||||||
&self,
|
&self,
|
||||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||||
let provider = self.provider.clone();
|
let provider = self.provider().await?;
|
||||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
let target_context_limit = estimate_target_context_limit(provider);
|
let target_context_limit = estimate_target_context_limit(provider);
|
||||||
let token_counts = get_messages_token_counts(&token_counter, messages);
|
let token_counts = get_messages_token_counts(&token_counter, messages);
|
||||||
@@ -41,7 +41,7 @@ impl Agent {
|
|||||||
&self,
|
&self,
|
||||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||||
let provider = self.provider.clone();
|
let provider = self.provider().await?;
|
||||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
let target_context_limit = estimate_target_context_limit(provider.clone());
|
let target_context_limit = estimate_target_context_limit(provider.clone());
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ impl Agent {
|
|||||||
let mut tools = self.list_tools(None).await;
|
let mut tools = self.list_tools(None).await;
|
||||||
|
|
||||||
// Add frontend tools
|
// Add frontend tools
|
||||||
for frontend_tool in self.frontend_tools.values() {
|
let frontend_tools = self.frontend_tools.lock().await;
|
||||||
|
for frontend_tool in frontend_tools.values() {
|
||||||
tools.push(frontend_tool.tool.clone());
|
tools.push(frontend_tool.tool.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,19 +32,21 @@ impl Agent {
|
|||||||
let extensions_info = extension_manager.get_extensions_info().await;
|
let extensions_info = extension_manager.get_extensions_info().await;
|
||||||
|
|
||||||
// Get model name from provider
|
// Get model name from provider
|
||||||
let model_config = self.provider.get_model_config();
|
let provider = self.provider().await?;
|
||||||
|
let model_config = provider.get_model_config();
|
||||||
let model_name = &model_config.model_name;
|
let model_name = &model_config.model_name;
|
||||||
|
|
||||||
let mut system_prompt = self.prompt_manager.build_system_prompt(
|
let prompt_manager = self.prompt_manager.lock().await;
|
||||||
|
let mut system_prompt = prompt_manager.build_system_prompt(
|
||||||
extensions_info,
|
extensions_info,
|
||||||
self.frontend_instructions.clone(),
|
self.frontend_instructions.lock().await.clone(),
|
||||||
extension_manager.suggest_disable_extensions_prompt().await,
|
extension_manager.suggest_disable_extensions_prompt().await,
|
||||||
Some(model_name),
|
Some(model_name),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Handle toolshim if enabled
|
// Handle toolshim if enabled
|
||||||
let mut toolshim_tools = vec![];
|
let mut toolshim_tools = vec![];
|
||||||
if self.provider.get_model_config().toolshim {
|
if model_config.toolshim {
|
||||||
// If tool interpretation is enabled, modify the system prompt
|
// If tool interpretation is enabled, modify the system prompt
|
||||||
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
|
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
|
||||||
// Make a copy of tools before emptying
|
// Make a copy of tools before emptying
|
||||||
@@ -115,7 +118,7 @@ impl Agent {
|
|||||||
/// - frontend_requests: Tool requests that should be handled by the frontend
|
/// - frontend_requests: Tool requests that should be handled by the frontend
|
||||||
/// - other_requests: All other tool requests (including requests to enable extensions)
|
/// - other_requests: All other tool requests (including requests to enable extensions)
|
||||||
/// - filtered_message: The original message with frontend tool requests removed
|
/// - filtered_message: The original message with frontend tool requests removed
|
||||||
pub(crate) fn categorize_tool_requests(
|
pub(crate) async fn categorize_tool_requests(
|
||||||
&self,
|
&self,
|
||||||
response: &Message,
|
response: &Message,
|
||||||
) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
|
) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
|
||||||
@@ -133,20 +136,25 @@ impl Agent {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Create a filtered message with frontend tool requests removed
|
// Create a filtered message with frontend tool requests removed
|
||||||
let filtered_content = response
|
let mut filtered_content = Vec::new();
|
||||||
.content
|
|
||||||
.iter()
|
// Process each content item one by one
|
||||||
.filter(|c| {
|
for content in &response.content {
|
||||||
if let MessageContent::ToolRequest(req) = c {
|
let should_include = match content {
|
||||||
// Only filter out frontend tool requests
|
MessageContent::ToolRequest(req) => {
|
||||||
if let Ok(tool_call) = &req.tool_call {
|
if let Ok(tool_call) = &req.tool_call {
|
||||||
return !self.is_frontend_tool(&tool_call.name);
|
!self.is_frontend_tool(&tool_call.name).await
|
||||||
}
|
} else {
|
||||||
}
|
|
||||||
true
|
true
|
||||||
})
|
}
|
||||||
.cloned()
|
}
|
||||||
.collect();
|
_ => true,
|
||||||
|
};
|
||||||
|
|
||||||
|
if should_include {
|
||||||
|
filtered_content.push(content.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let filtered_message = Message {
|
let filtered_message = Message {
|
||||||
role: response.role.clone(),
|
role: response.role.clone(),
|
||||||
@@ -160,7 +168,7 @@ impl Agent {
|
|||||||
|
|
||||||
for request in tool_requests {
|
for request in tool_requests {
|
||||||
if let Ok(tool_call) = &request.tool_call {
|
if let Ok(tool_call) = &request.tool_call {
|
||||||
if self.is_frontend_tool(&tool_call.name) {
|
if self.is_frontend_tool(&tool_call.name).await {
|
||||||
frontend_requests.push(request);
|
frontend_requests.push(request);
|
||||||
} else {
|
} else {
|
||||||
other_requests.push(request);
|
other_requests.push(request);
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ impl Agent {
|
|||||||
try_stream! {
|
try_stream! {
|
||||||
for request in tool_requests {
|
for request in tool_requests {
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
if self.is_frontend_tool(&tool_call.name) {
|
if self.is_frontend_tool(&tool_call.name).await {
|
||||||
// Send frontend tool request and wait for response
|
// Send frontend tool request and wait for response
|
||||||
yield Message::assistant().with_frontend_tool_request(
|
yield Message::assistant().with_frontend_tool_request(
|
||||||
request.id.clone(),
|
request.id.clone(),
|
||||||
|
|||||||
@@ -739,10 +739,10 @@ mod tests {
|
|||||||
thread::sleep(Duration::from_millis(i * 10));
|
thread::sleep(Duration::from_millis(i * 10));
|
||||||
|
|
||||||
let extension_key = format!("extension_{}", i);
|
let extension_key = format!("extension_{}", i);
|
||||||
let mut values = config.load_values()?;
|
|
||||||
|
|
||||||
values.insert(
|
// Use set_param which handles concurrent access properly
|
||||||
extension_key.clone(),
|
config.set_param(
|
||||||
|
&extension_key,
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
"name": format!("test_extension_{}", i),
|
"name": format!("test_extension_{}", i),
|
||||||
"version": format!("1.0.{}", i),
|
"version": format!("1.0.{}", i),
|
||||||
@@ -752,10 +752,7 @@ mod tests {
|
|||||||
"option2": i
|
"option2": i
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
);
|
)?;
|
||||||
|
|
||||||
// Write all values atomically
|
|
||||||
config.save_values(values)?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
});
|
});
|
||||||
handles.push(handle);
|
handles.push(handle);
|
||||||
|
|||||||
@@ -110,7 +110,8 @@ async fn run_truncate_test(
|
|||||||
.with_temperature(Some(0.0));
|
.with_temperature(Some(0.0));
|
||||||
let provider = provider_type.create_provider(model_config)?;
|
let provider = provider_type.create_provider(model_config)?;
|
||||||
|
|
||||||
let agent = Agent::new(provider);
|
let agent = Agent::new();
|
||||||
|
agent.update_provider(provider).await?;
|
||||||
let repeat_count = context_window + 10_000;
|
let repeat_count = context_window + 10_000;
|
||||||
let large_message_content = "hello ".repeat(repeat_count);
|
let large_message_content = "hello ".repeat(repeat_count);
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
|
|||||||
@@ -33,12 +33,28 @@ CALCULATOR_TOOL = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Enable Extension tool definition
|
||||||
|
ENABLE_EXTENSION_TOOL = {
|
||||||
|
"name": "enable_extension",
|
||||||
|
"description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["extension_name"],
|
||||||
|
"properties": {
|
||||||
|
"extension_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the extension to enable",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Frontend extension configuration
|
# Frontend extension configuration
|
||||||
FRONTEND_CONFIG = {
|
FRONTEND_CONFIG = {
|
||||||
"name": "pythonclient",
|
"name": "pythonclient",
|
||||||
"type": "frontend",
|
"type": "frontend",
|
||||||
"tools": [CALCULATOR_TOOL],
|
"tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL],
|
||||||
"instructions": "A calculator extension that can perform basic arithmetic operations.",
|
"instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +63,7 @@ async def setup_agent() -> None:
|
|||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
# First create the agent
|
# First create the agent
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{GOOSE_URL}/agent",
|
f"{GOOSE_URL}/agent/update_provider",
|
||||||
json={"provider": "databricks", "model": "goose"},
|
json={"provider": "databricks", "model": "goose"},
|
||||||
headers={"X-Secret-Key": SECRET_KEY},
|
headers={"X-Secret-Key": SECRET_KEY},
|
||||||
)
|
)
|
||||||
@@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_tools() -> Dict[str, Any]:
|
||||||
|
with httpx.Client() as client:
|
||||||
|
response = client.get(
|
||||||
|
f"{GOOSE_URL}/agent/tools",
|
||||||
|
headers={"X-Secret-Key": SECRET_KEY},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Execute the enable_extension tool.
|
||||||
|
This function fetches available extensions, finds the one with the provided extension_name,
|
||||||
|
and posts its configuration to the /extensions/add endpoint.
|
||||||
|
"""
|
||||||
|
extension = args
|
||||||
|
extension_name = extension.get("name")
|
||||||
|
|
||||||
|
# Post the extension configuration to enable it
|
||||||
|
with httpx.Client() as client:
|
||||||
|
payload = {
|
||||||
|
"type": extension.get("type"),
|
||||||
|
"name": extension.get("name"),
|
||||||
|
"cmd": extension.get("cmd"),
|
||||||
|
"args": extension.get("args"),
|
||||||
|
"envs": extension.get("envs", {}),
|
||||||
|
"timeout": extension.get("timeout"),
|
||||||
|
"bundled": extension.get("bundled"),
|
||||||
|
}
|
||||||
|
add_response = client.post(
|
||||||
|
f"{GOOSE_URL}/extensions/add",
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY},
|
||||||
|
)
|
||||||
|
if add_response.status_code != 200:
|
||||||
|
error_text = add_response.text
|
||||||
|
return [{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: Failed to enable extension: {error_text}",
|
||||||
|
"annotations": None,
|
||||||
|
}]
|
||||||
|
|
||||||
|
return [{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Successfully enabled extension: {extension_name}",
|
||||||
|
"annotations": None,
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
||||||
"""Submit the tool execution result back to Goose.
|
"""Submit the tool execution result back to Goose.
|
||||||
@@ -129,7 +194,7 @@ async def chat_loop() -> None:
|
|||||||
session_id = "test-session"
|
session_id = "test-session"
|
||||||
|
|
||||||
# Use a client with a longer timeout for streaming
|
# Use a client with a longer timeout for streaming
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
# Get user input
|
# Get user input
|
||||||
user_message = input("\nYou: ")
|
user_message = input("\nYou: ")
|
||||||
if user_message.lower() in ["exit", "quit"]:
|
if user_message.lower() in ["exit", "quit"]:
|
||||||
@@ -152,7 +217,7 @@ async def chat_loop() -> None:
|
|||||||
# Process the stream of responses
|
# Process the stream of responses
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"{GOOSE_URL}/reply",
|
f"{GOOSE_URL}/reply", # lock
|
||||||
json=payload,
|
json=payload,
|
||||||
headers={
|
headers={
|
||||||
"X-Secret-Key": SECRET_KEY,
|
"X-Secret-Key": SECRET_KEY,
|
||||||
@@ -185,9 +250,26 @@ async def chat_loop() -> None:
|
|||||||
elif content["type"] == "frontendToolRequest":
|
elif content["type"] == "frontendToolRequest":
|
||||||
# Execute the tool and submit results
|
# Execute the tool and submit results
|
||||||
tool_call = content["toolCall"]["value"]
|
tool_call = content["toolCall"]["value"]
|
||||||
|
print(f"\nTool Request: {tool_call}")
|
||||||
|
|
||||||
|
if tool_call['name'] == "calculator":
|
||||||
print(f"Calculator: {tool_call}")
|
print(f"Calculator: {tool_call}")
|
||||||
# Execute the tool
|
# Execute the tool
|
||||||
result = execute_calculator(tool_call["arguments"])
|
result = execute_calculator(tool_call["arguments"])
|
||||||
|
elif tool_call['name'] == "enable_extension":
|
||||||
|
# to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name"
|
||||||
|
print(f"Enabling fetch extension")
|
||||||
|
result = execute_enable_extension(args={
|
||||||
|
"type": "stdio",
|
||||||
|
"name": "fetch",
|
||||||
|
"cmd": "uvx",
|
||||||
|
"args": ["mcp-server-fetch"],
|
||||||
|
"timeout": 300,
|
||||||
|
"bundled": False
|
||||||
|
})
|
||||||
|
listed_tools = get_tools()
|
||||||
|
print(f"\nTools after enabling extension: {listed_tools}")
|
||||||
|
|
||||||
|
|
||||||
# Submit the result
|
# Submit the result
|
||||||
submit_tool_result(content["id"], result)
|
submit_tool_result(content["id"], result)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ interface initializeAgentProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function initializeAgent({ model, provider }: initializeAgentProps) {
|
export async function initializeAgent({ model, provider }: initializeAgentProps) {
|
||||||
const response = await fetch(getApiUrl('/agent'), {
|
const response = await fetch(getApiUrl('/agent/update_provider'), {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|||||||
Reference in New Issue
Block a user