mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
Refactor goose-api into modules
This commit is contained in:
97
crates/goose-api/src/config.rs
Normal file
97
crates/goose-api/src/config.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
use crate::handlers::{AGENT, EXTENSION_MANAGER};
|
||||||
|
use goose::config::{Config, ExtensionEntry};
|
||||||
|
use goose::agents::ExtensionConfig;
|
||||||
|
use goose::providers::{create, providers};
|
||||||
|
use goose::model::ModelConfig;
|
||||||
|
use tracing::{info, warn, error};
|
||||||
|
use config::{builder::DefaultState, ConfigBuilder, Environment, File};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
pub fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
|
||||||
|
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
|
||||||
|
let builder = ConfigBuilder::<DefaultState>::default()
|
||||||
|
.add_source(File::with_name(&config_path).required(false))
|
||||||
|
.add_source(Environment::with_prefix("GOOSE_API"));
|
||||||
|
builder.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn initialize_provider_config() -> Result<(), anyhow::Error> {
|
||||||
|
let api_config = load_configuration()?;
|
||||||
|
|
||||||
|
let provider_name = std::env::var("GOOSE_API_PROVIDER")
|
||||||
|
.or_else(|_| api_config.get_string("provider"))
|
||||||
|
.unwrap_or_else(|_| "openai".to_string());
|
||||||
|
|
||||||
|
let model_name = std::env::var("GOOSE_API_MODEL")
|
||||||
|
.or_else(|_| api_config.get_string("model"))
|
||||||
|
.unwrap_or_else(|_| "gpt-4o".to_string());
|
||||||
|
|
||||||
|
info!("Initializing with provider: {}, model: {}", provider_name, model_name);
|
||||||
|
|
||||||
|
let config = Config::global();
|
||||||
|
config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?;
|
||||||
|
config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?;
|
||||||
|
|
||||||
|
let available_providers = providers();
|
||||||
|
if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) {
|
||||||
|
for key in &provider_meta.config_keys {
|
||||||
|
let env_name = key.name.clone();
|
||||||
|
if let Ok(value) = std::env::var(&env_name) {
|
||||||
|
if key.secret {
|
||||||
|
config.set_secret(&key.name, Value::String(value))?;
|
||||||
|
info!("Set secret key: {}", key.name);
|
||||||
|
} else {
|
||||||
|
config.set_param(&key.name, Value::String(value))?;
|
||||||
|
info!("Set parameter: {}", key.name);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Environment variable not set for key: {}", key.name);
|
||||||
|
if key.required {
|
||||||
|
error!("Required key {} not provided", key.name);
|
||||||
|
return Err(anyhow::anyhow!("Required key {} not provided", key.name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_config = ModelConfig::new(model_name);
|
||||||
|
let provider = create(&provider_name, model_config)?;
|
||||||
|
|
||||||
|
let agent = AGENT.lock().await;
|
||||||
|
agent.update_provider(provider).await?;
|
||||||
|
|
||||||
|
info!("Provider configuration successful");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> {
|
||||||
|
if let Ok(ext_table) = config.get_table("extensions") {
|
||||||
|
for (name, ext_config) in ext_table {
|
||||||
|
let entry: ExtensionEntry = ext_config.clone().try_deserialize()
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?;
|
||||||
|
|
||||||
|
if entry.enabled {
|
||||||
|
let extension_config: ExtensionConfig = entry.config;
|
||||||
|
let mut agent = AGENT.lock().await;
|
||||||
|
if let Err(e) = agent.add_extension(extension_config).await {
|
||||||
|
error!("Failed to add extension {}: {}", name, e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info!("Skipping disabled extension: {}", name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("No extensions configured in config file.");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_init_tests() -> Result<(), anyhow::Error> {
|
||||||
|
info!("Running initialization tests");
|
||||||
|
{
|
||||||
|
let _agent = AGENT.lock().await;
|
||||||
|
info!("Agent initialization test passed");
|
||||||
|
}
|
||||||
|
info!("Initialization tests completed");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
449
crates/goose-api/src/handlers.rs
Normal file
449
crates/goose-api/src/handlers.rs
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
use warp::{http::HeaderValue, Filter, Rejection};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use uuid::Uuid;
|
||||||
|
use futures_util::TryStreamExt;
|
||||||
|
use tracing::{info, warn, error};
|
||||||
|
use mcp_core::tool::Tool;
|
||||||
|
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
|
||||||
|
use goose::message::Message;
|
||||||
|
use goose::session::{self, Identifier};
|
||||||
|
use goose::config::{Config, ExtensionEntry};
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
|
pub static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
||||||
|
pub static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new()));
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct SessionRequest {
|
||||||
|
pub prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ApiResponse {
|
||||||
|
pub message: String,
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct StartSessionResponse {
|
||||||
|
pub message: String,
|
||||||
|
pub status: String,
|
||||||
|
pub session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct SessionReplyRequest {
|
||||||
|
pub session_id: Uuid,
|
||||||
|
pub prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct EndSessionRequest {
|
||||||
|
pub session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ExtensionsResponse {
|
||||||
|
pub extensions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub provider: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ExtensionResponse {
|
||||||
|
pub error: bool,
|
||||||
|
pub message: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ExtensionConfigRequest {
|
||||||
|
#[serde(rename = "sse")]
|
||||||
|
Sse {
|
||||||
|
name: String,
|
||||||
|
uri: String,
|
||||||
|
#[serde(default)]
|
||||||
|
envs: Envs,
|
||||||
|
#[serde(default)]
|
||||||
|
env_keys: Vec<String>,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "stdio")]
|
||||||
|
Stdio {
|
||||||
|
name: String,
|
||||||
|
cmd: String,
|
||||||
|
#[serde(default)]
|
||||||
|
args: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
envs: Envs,
|
||||||
|
#[serde(default)]
|
||||||
|
env_keys: Vec<String>,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "builtin")]
|
||||||
|
Builtin {
|
||||||
|
name: String,
|
||||||
|
display_name: Option<String>,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "frontend")]
|
||||||
|
Frontend {
|
||||||
|
name: String,
|
||||||
|
tools: Vec<Tool>,
|
||||||
|
instructions: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_session_handler(
|
||||||
|
req: SessionRequest,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Starting session with prompt: {}", req.prompt);
|
||||||
|
|
||||||
|
let mut agent = AGENT.lock().await;
|
||||||
|
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
||||||
|
let session_id = Uuid::new_v4();
|
||||||
|
let session_name = session_id.to_string();
|
||||||
|
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||||
|
|
||||||
|
let provider = agent.provider().await.ok();
|
||||||
|
|
||||||
|
let result = agent
|
||||||
|
.reply(
|
||||||
|
&messages,
|
||||||
|
Some(SessionConfig {
|
||||||
|
id: Identifier::Name(session_name.clone()),
|
||||||
|
working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(mut stream) => {
|
||||||
|
if let Ok(Some(response)) = stream.try_next().await {
|
||||||
|
let response_text = response.as_concat_text();
|
||||||
|
messages.push(response);
|
||||||
|
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||||
|
warn!("Failed to persist session {}: {}", session_name, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_response = StartSessionResponse {
|
||||||
|
message: response_text,
|
||||||
|
status: "success".to_string(),
|
||||||
|
session_id,
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&api_response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||||
|
warn!("Failed to persist session {}: {}", session_name, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_response = StartSessionResponse {
|
||||||
|
message: "Session started but no response generated".to_string(),
|
||||||
|
status: "warning".to_string(),
|
||||||
|
session_id,
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&api_response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to start session: {}", e);
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: format!("Failed to start session: {}", e),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn reply_session_handler(
|
||||||
|
req: SessionReplyRequest,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Replying to session with prompt: {}", req.prompt);
|
||||||
|
|
||||||
|
let mut agent = AGENT.lock().await;
|
||||||
|
|
||||||
|
let session_name = req.session_id.to_string();
|
||||||
|
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||||
|
|
||||||
|
let mut messages = match session::read_messages(&session_path) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(_) => {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session not found".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
return Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::NOT_FOUND,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
messages.push(Message::user().with_text(&req.prompt));
|
||||||
|
|
||||||
|
let provider = agent.provider().await.ok();
|
||||||
|
|
||||||
|
let result = agent
|
||||||
|
.reply(
|
||||||
|
&messages,
|
||||||
|
Some(SessionConfig {
|
||||||
|
id: Identifier::Name(session_name.clone()),
|
||||||
|
working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(mut stream) => {
|
||||||
|
if let Ok(Some(response)) = stream.try_next().await {
|
||||||
|
let response_text = response.as_concat_text();
|
||||||
|
messages.push(response);
|
||||||
|
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||||
|
warn!("Failed to persist session {}: {}", session_name, e);
|
||||||
|
}
|
||||||
|
let api_response = ApiResponse {
|
||||||
|
message: format!("Reply: {}", response_text),
|
||||||
|
status: "success".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&api_response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||||
|
warn!("Failed to persist session {}: {}", session_name, e);
|
||||||
|
}
|
||||||
|
let api_response = ApiResponse {
|
||||||
|
message: "Reply processed but no response generated".to_string(),
|
||||||
|
status: "warning".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&api_response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to reply to session: {}", e);
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: format!("Failed to reply to session: {}", e),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn end_session_handler(
|
||||||
|
req: EndSessionRequest,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
let session_name = req.session_id.to_string();
|
||||||
|
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||||
|
|
||||||
|
if std::fs::remove_file(&session_path).is_ok() {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session ended".to_string(),
|
||||||
|
status: "success".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session not found".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::NOT_FOUND,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Listing extensions");
|
||||||
|
|
||||||
|
match EXTENSION_MANAGER.list_extensions().await {
|
||||||
|
Ok(exts) => {
|
||||||
|
let response = ExtensionsResponse { extensions: exts };
|
||||||
|
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to list extensions: {}", e);
|
||||||
|
let response = ExtensionsResponse {
|
||||||
|
extensions: vec!["Failed to list extensions".to_string()],
|
||||||
|
};
|
||||||
|
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Getting provider configuration");
|
||||||
|
|
||||||
|
let config = Config::global();
|
||||||
|
let provider = config
|
||||||
|
.get_param::<String>("GOOSE_PROVIDER")
|
||||||
|
.unwrap_or_else(|_| "Not configured".to_string());
|
||||||
|
let model = config
|
||||||
|
.get_param::<String>("GOOSE_MODEL")
|
||||||
|
.unwrap_or_else(|_| "Not configured".to_string());
|
||||||
|
|
||||||
|
let response = ProviderConfig { provider, model };
|
||||||
|
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_extension_handler(
|
||||||
|
req: ExtensionConfigRequest,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Adding extension: {:?}", req);
|
||||||
|
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
if let ExtensionConfigRequest::Stdio { cmd, .. } = &req {
|
||||||
|
if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") {
|
||||||
|
let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists()
|
||||||
|
|| std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists();
|
||||||
|
|
||||||
|
if !node_exists {
|
||||||
|
let cmd_path = std::path::Path::new(cmd);
|
||||||
|
let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?;
|
||||||
|
|
||||||
|
let install_script = script_dir.join("install-node.cmd");
|
||||||
|
|
||||||
|
if install_script.exists() {
|
||||||
|
eprintln!("Installing Node.js...");
|
||||||
|
let output = std::process::Command::new(&install_script)
|
||||||
|
.arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi")
|
||||||
|
.output()
|
||||||
|
.map_err(|_| warp::reject())?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
eprintln!(
|
||||||
|
"Failed to install Node.js: {}",
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
);
|
||||||
|
let resp = ExtensionResponse {
|
||||||
|
error: true,
|
||||||
|
message: Some(format!(
|
||||||
|
"Failed to install Node.js: {}",
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
return Ok(warp::reply::json(&resp));
|
||||||
|
}
|
||||||
|
eprintln!("Node.js installation completed");
|
||||||
|
} else {
|
||||||
|
eprintln!("Node.js installer script not found at: {}", install_script.display());
|
||||||
|
let resp = ExtensionResponse {
|
||||||
|
error: true,
|
||||||
|
message: Some("Node.js installer script not found".to_string()),
|
||||||
|
};
|
||||||
|
return Ok(warp::reply::json(&resp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let extension = match req {
|
||||||
|
ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => {
|
||||||
|
ExtensionConfig::Sse {
|
||||||
|
name,
|
||||||
|
uri,
|
||||||
|
envs,
|
||||||
|
env_keys,
|
||||||
|
description: None,
|
||||||
|
timeout,
|
||||||
|
bundled: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => {
|
||||||
|
ExtensionConfig::Stdio {
|
||||||
|
name,
|
||||||
|
cmd,
|
||||||
|
args,
|
||||||
|
envs,
|
||||||
|
env_keys,
|
||||||
|
timeout,
|
||||||
|
description: None,
|
||||||
|
bundled: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExtensionConfigRequest::Builtin { name, display_name, timeout } => {
|
||||||
|
ExtensionConfig::Builtin {
|
||||||
|
name,
|
||||||
|
display_name,
|
||||||
|
timeout,
|
||||||
|
bundled: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExtensionConfigRequest::Frontend { name, tools, instructions } => {
|
||||||
|
ExtensionConfig::Frontend {
|
||||||
|
name,
|
||||||
|
tools,
|
||||||
|
instructions,
|
||||||
|
bundled: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = AGENT.lock().await;
|
||||||
|
let result = agent.add_extension(extension).await;
|
||||||
|
|
||||||
|
let resp = match result {
|
||||||
|
Ok(_) => ExtensionResponse { error: false, message: None },
|
||||||
|
Err(e) => ExtensionResponse {
|
||||||
|
error: true,
|
||||||
|
message: Some(format!("Failed to add extension configuration, error: {:?}", e)),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
Ok(warp::reply::json(&resp))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_extension_handler(
|
||||||
|
name: String,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
info!("Removing extension: {}", name);
|
||||||
|
let agent = AGENT.lock().await;
|
||||||
|
agent.remove_extension(&name).await;
|
||||||
|
|
||||||
|
let resp = ExtensionResponse { error: false, message: None };
|
||||||
|
Ok(warp::reply::json(&resp))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||||
|
warp::header::value("x-api-key")
|
||||||
|
.and_then(move |header_api_key: HeaderValue| {
|
||||||
|
let api_key = api_key.clone();
|
||||||
|
async move {
|
||||||
|
if header_api_key == api_key {
|
||||||
|
Ok(api_key)
|
||||||
|
} else {
|
||||||
|
Err(warp::reject::not_found())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
5
crates/goose-api/src/lib.rs
Normal file
5
crates/goose-api/src/lib.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
mod handlers;
|
||||||
|
mod config;
|
||||||
|
mod routes;
|
||||||
|
|
||||||
|
pub use routes::{build_routes, run_server};
|
||||||
@@ -1,706 +1,6 @@
|
|||||||
use warp::{Filter, Rejection};
|
use goose_api::run_server;
|
||||||
use warp::http::HeaderValue;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::sync::LazyLock;
|
|
||||||
use goose::config::{Config, ExtensionEntry};
|
|
||||||
use goose::agents::{
|
|
||||||
extension::Envs,
|
|
||||||
Agent,
|
|
||||||
extension_manager::ExtensionManager,
|
|
||||||
ExtensionConfig,
|
|
||||||
};
|
|
||||||
use mcp_core::tool::Tool;
|
|
||||||
use uuid::Uuid;
|
|
||||||
use goose::session::{self, Identifier};
|
|
||||||
use goose::agents::SessionConfig;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use goose::providers::{create, providers};
|
|
||||||
use goose::model::ModelConfig;
|
|
||||||
use goose::message::Message;
|
|
||||||
use tracing::{info, warn, error};
|
|
||||||
use config::{builder::DefaultState, ConfigBuilder, Environment, File};
|
|
||||||
use serde_json::Value; // Import the correct Value type
|
|
||||||
use futures_util::TryStreamExt;
|
|
||||||
|
|
||||||
// Global extension manager for extension listing
|
|
||||||
static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
|
||||||
|
|
||||||
// Global agent for handling sessions
|
|
||||||
static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| {
|
|
||||||
tokio::sync::Mutex::new(Agent::new())
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct SessionRequest {
|
|
||||||
prompt: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ApiResponse {
|
|
||||||
message: String,
|
|
||||||
status: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct StartSessionResponse {
|
|
||||||
message: String,
|
|
||||||
status: String,
|
|
||||||
session_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct SessionReplyRequest {
|
|
||||||
session_id: Uuid,
|
|
||||||
prompt: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct EndSessionRequest {
|
|
||||||
session_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ExtensionsResponse {
|
|
||||||
extensions: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ProviderConfig {
|
|
||||||
provider: String,
|
|
||||||
model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ExtensionResponse {
|
|
||||||
error: bool,
|
|
||||||
message: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
enum ExtensionConfigRequest {
|
|
||||||
#[serde(rename = "sse")]
|
|
||||||
Sse {
|
|
||||||
name: String,
|
|
||||||
uri: String,
|
|
||||||
#[serde(default)]
|
|
||||||
envs: Envs,
|
|
||||||
#[serde(default)]
|
|
||||||
env_keys: Vec<String>,
|
|
||||||
timeout: Option<u64>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "stdio")]
|
|
||||||
Stdio {
|
|
||||||
name: String,
|
|
||||||
cmd: String,
|
|
||||||
#[serde(default)]
|
|
||||||
args: Vec<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
envs: Envs,
|
|
||||||
#[serde(default)]
|
|
||||||
env_keys: Vec<String>,
|
|
||||||
timeout: Option<u64>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "builtin")]
|
|
||||||
Builtin {
|
|
||||||
name: String,
|
|
||||||
display_name: Option<String>,
|
|
||||||
timeout: Option<u64>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "frontend")]
|
|
||||||
Frontend {
|
|
||||||
name: String,
|
|
||||||
tools: Vec<Tool>,
|
|
||||||
instructions: Option<String>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_session_handler(
|
|
||||||
req: SessionRequest,
|
|
||||||
_api_key: String,
|
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Starting session with prompt: {}", req.prompt);
|
|
||||||
|
|
||||||
let mut agent = AGENT.lock().await;
|
|
||||||
|
|
||||||
// Create a user message with the prompt
|
|
||||||
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
|
||||||
|
|
||||||
// Generate a new session ID and process the messages
|
|
||||||
let session_id = Uuid::new_v4();
|
|
||||||
let session_name = session_id.to_string();
|
|
||||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
|
||||||
|
|
||||||
let provider = agent.provider().await.ok();
|
|
||||||
|
|
||||||
let result = agent
|
|
||||||
.reply(
|
|
||||||
&messages,
|
|
||||||
Some(SessionConfig {
|
|
||||||
id: Identifier::Name(session_name.clone()),
|
|
||||||
working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(mut stream) => {
|
|
||||||
// Process the stream to get the first response
|
|
||||||
if let Ok(Some(response)) = stream.try_next().await {
|
|
||||||
let response_text = response.as_concat_text();
|
|
||||||
messages.push(response);
|
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
|
||||||
warn!("Failed to persist session {}: {}", session_name, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
let api_response = StartSessionResponse {
|
|
||||||
message: response_text,
|
|
||||||
status: "success".to_string(),
|
|
||||||
session_id,
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&api_response),
|
|
||||||
warp::http::StatusCode::OK,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
|
||||||
warn!("Failed to persist session {}: {}", session_name, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
let api_response = StartSessionResponse {
|
|
||||||
message: "Session started but no response generated".to_string(),
|
|
||||||
status: "warning".to_string(),
|
|
||||||
session_id,
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&api_response),
|
|
||||||
warp::http::StatusCode::OK,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to start session: {}", e);
|
|
||||||
let response = ApiResponse {
|
|
||||||
message: format!("Failed to start session: {}", e),
|
|
||||||
status: "error".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&response),
|
|
||||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn reply_session_handler(
|
|
||||||
req: SessionReplyRequest,
|
|
||||||
_api_key: String,
|
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Replying to session with prompt: {}", req.prompt);
|
|
||||||
|
|
||||||
let mut agent = AGENT.lock().await;
|
|
||||||
|
|
||||||
let session_name = req.session_id.to_string();
|
|
||||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
|
||||||
|
|
||||||
// Retrieve existing session history from disk
|
|
||||||
let mut messages = match session::read_messages(&session_path) {
|
|
||||||
Ok(m) => m,
|
|
||||||
Err(_) => {
|
|
||||||
let response = ApiResponse {
|
|
||||||
message: "Session not found".to_string(),
|
|
||||||
status: "error".to_string(),
|
|
||||||
};
|
|
||||||
return Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&response),
|
|
||||||
warp::http::StatusCode::NOT_FOUND,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Append the new user message
|
|
||||||
messages.push(Message::user().with_text(&req.prompt));
|
|
||||||
|
|
||||||
let provider = agent.provider().await.ok();
|
|
||||||
|
|
||||||
// Process the messages through the agent
|
|
||||||
let result = agent
|
|
||||||
.reply(
|
|
||||||
&messages,
|
|
||||||
Some(SessionConfig {
|
|
||||||
id: Identifier::Name(session_name.clone()),
|
|
||||||
working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(mut stream) => {
|
|
||||||
// Process the stream to get the first response
|
|
||||||
if let Ok(Some(response)) = stream.try_next().await {
|
|
||||||
let response_text = response.as_concat_text();
|
|
||||||
messages.push(response);
|
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
|
||||||
warn!("Failed to persist session {}: {}", session_name, e);
|
|
||||||
}
|
|
||||||
let api_response = ApiResponse {
|
|
||||||
message: format!("Reply: {}", response_text),
|
|
||||||
status: "success".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&api_response),
|
|
||||||
warp::http::StatusCode::OK,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
|
||||||
warn!("Failed to persist session {}: {}", session_name, e);
|
|
||||||
}
|
|
||||||
let api_response = ApiResponse {
|
|
||||||
message: "Reply processed but no response generated".to_string(),
|
|
||||||
status: "warning".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&api_response),
|
|
||||||
warp::http::StatusCode::OK,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to reply to session: {}", e);
|
|
||||||
let response = ApiResponse {
|
|
||||||
message: format!("Failed to reply to session: {}", e),
|
|
||||||
status: "error".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&response),
|
|
||||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn end_session_handler(
|
|
||||||
req: EndSessionRequest,
|
|
||||||
_api_key: String,
|
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
|
||||||
let session_name = req.session_id.to_string();
|
|
||||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
|
||||||
|
|
||||||
if std::fs::remove_file(&session_path).is_ok() {
|
|
||||||
let response = ApiResponse {
|
|
||||||
message: "Session ended".to_string(),
|
|
||||||
status: "success".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&response),
|
|
||||||
warp::http::StatusCode::OK,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
let response = ApiResponse {
|
|
||||||
message: "Session not found".to_string(),
|
|
||||||
status: "error".to_string(),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(
|
|
||||||
warp::reply::json(&response),
|
|
||||||
warp::http::StatusCode::NOT_FOUND,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Listing extensions");
|
|
||||||
|
|
||||||
match EXTENSION_MANAGER.list_extensions().await {
|
|
||||||
Ok(exts) => {
|
|
||||||
let response = ExtensionsResponse { extensions: exts };
|
|
||||||
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to list extensions: {}", e);
|
|
||||||
let response = ExtensionsResponse {
|
|
||||||
extensions: vec!["Failed to list extensions".to_string()]
|
|
||||||
};
|
|
||||||
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Getting provider configuration");
|
|
||||||
|
|
||||||
let config = Config::global();
|
|
||||||
let provider = config.get_param::<String>("GOOSE_PROVIDER")
|
|
||||||
.unwrap_or_else(|_| "Not configured".to_string());
|
|
||||||
let model = config.get_param::<String>("GOOSE_MODEL")
|
|
||||||
.unwrap_or_else(|_| "Not configured".to_string());
|
|
||||||
|
|
||||||
let response = ProviderConfig { provider, model };
|
|
||||||
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn add_extension_handler(
|
|
||||||
req: ExtensionConfigRequest,
|
|
||||||
_api_key: String,
|
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Adding extension: {:?}", req);
|
|
||||||
|
|
||||||
#[cfg(target_os = "windows")]
|
|
||||||
if let ExtensionConfigRequest::Stdio { cmd, .. } = &req {
|
|
||||||
if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") {
|
|
||||||
let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists()
|
|
||||||
|| std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists();
|
|
||||||
|
|
||||||
if !node_exists {
|
|
||||||
let cmd_path = std::path::Path::new(cmd);
|
|
||||||
let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?;
|
|
||||||
|
|
||||||
let install_script = script_dir.join("install-node.cmd");
|
|
||||||
|
|
||||||
if install_script.exists() {
|
|
||||||
eprintln!("Installing Node.js...");
|
|
||||||
let output = std::process::Command::new(&install_script)
|
|
||||||
.arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi")
|
|
||||||
.output()
|
|
||||||
.map_err(|_| warp::reject())?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
eprintln!(
|
|
||||||
"Failed to install Node.js: {}",
|
|
||||||
String::from_utf8_lossy(&output.stderr)
|
|
||||||
);
|
|
||||||
let resp = ExtensionResponse {
|
|
||||||
error: true,
|
|
||||||
message: Some(format!(
|
|
||||||
"Failed to install Node.js: {}",
|
|
||||||
String::from_utf8_lossy(&output.stderr)
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
return Ok(warp::reply::json(&resp));
|
|
||||||
}
|
|
||||||
eprintln!("Node.js installation completed");
|
|
||||||
} else {
|
|
||||||
eprintln!(
|
|
||||||
"Node.js installer script not found at: {}",
|
|
||||||
install_script.display()
|
|
||||||
);
|
|
||||||
let resp = ExtensionResponse {
|
|
||||||
error: true,
|
|
||||||
message: Some("Node.js installer script not found".to_string()),
|
|
||||||
};
|
|
||||||
return Ok(warp::reply::json(&resp));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let extension = match req {
|
|
||||||
ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => {
|
|
||||||
ExtensionConfig::Sse {
|
|
||||||
name,
|
|
||||||
uri,
|
|
||||||
envs,
|
|
||||||
env_keys,
|
|
||||||
description: None,
|
|
||||||
timeout,
|
|
||||||
bundled: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => {
|
|
||||||
ExtensionConfig::Stdio {
|
|
||||||
name,
|
|
||||||
cmd,
|
|
||||||
args,
|
|
||||||
envs,
|
|
||||||
env_keys,
|
|
||||||
timeout,
|
|
||||||
description: None,
|
|
||||||
bundled: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExtensionConfigRequest::Builtin { name, display_name, timeout } => {
|
|
||||||
ExtensionConfig::Builtin {
|
|
||||||
name,
|
|
||||||
display_name,
|
|
||||||
timeout,
|
|
||||||
bundled: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExtensionConfigRequest::Frontend { name, tools, instructions } => {
|
|
||||||
ExtensionConfig::Frontend {
|
|
||||||
name,
|
|
||||||
tools,
|
|
||||||
instructions,
|
|
||||||
bundled: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let agent = AGENT.lock().await;
|
|
||||||
let result = agent.add_extension(extension).await;
|
|
||||||
|
|
||||||
let resp = match result {
|
|
||||||
Ok(_) => ExtensionResponse { error: false, message: None },
|
|
||||||
Err(e) => ExtensionResponse {
|
|
||||||
error: true,
|
|
||||||
message: Some(format!("Failed to add extension configuration, error: {:?}", e)),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
Ok(warp::reply::json(&resp))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_extension_handler(
|
|
||||||
name: String,
|
|
||||||
_api_key: String,
|
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
|
||||||
info!("Removing extension: {}", name);
|
|
||||||
let agent = AGENT.lock().await;
|
|
||||||
agent.remove_extension(&name).await;
|
|
||||||
|
|
||||||
let resp = ExtensionResponse { error: false, message: None };
|
|
||||||
Ok(warp::reply::json(&resp))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
|
||||||
warp::header::value("x-api-key")
|
|
||||||
.and_then(move |header_api_key: HeaderValue| {
|
|
||||||
let api_key = api_key.clone();
|
|
||||||
async move {
|
|
||||||
if header_api_key == api_key {
|
|
||||||
Ok(api_key)
|
|
||||||
} else {
|
|
||||||
Err(warp::reject::not_found())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load configuration from file and environment variables
|
|
||||||
fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
|
|
||||||
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
|
|
||||||
let builder = ConfigBuilder::<DefaultState>::default()
|
|
||||||
.add_source(File::with_name(&config_path).required(false))
|
|
||||||
.add_source(Environment::with_prefix("GOOSE_API"));
|
|
||||||
|
|
||||||
builder.build()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize global provider configuration
|
|
||||||
async fn initialize_provider_config() -> Result<(), anyhow::Error> {
|
|
||||||
// Get configuration
|
|
||||||
let api_config = load_configuration()?;
|
|
||||||
|
|
||||||
// Get provider settings from configuration or environment variables
|
|
||||||
let provider_name = std::env::var("GOOSE_API_PROVIDER")
|
|
||||||
.or_else(|_| api_config.get_string("provider"))
|
|
||||||
.unwrap_or_else(|_| "openai".to_string());
|
|
||||||
|
|
||||||
let model_name = std::env::var("GOOSE_API_MODEL")
|
|
||||||
.or_else(|_| api_config.get_string("model"))
|
|
||||||
.unwrap_or_else(|_| "gpt-4o".to_string());
|
|
||||||
|
|
||||||
info!("Initializing with provider: {}, model: {}", provider_name, model_name);
|
|
||||||
|
|
||||||
// Initialize the global Config object
|
|
||||||
let config = Config::global();
|
|
||||||
config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?;
|
|
||||||
config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?;
|
|
||||||
|
|
||||||
// Set up API keys from environment variables
|
|
||||||
let available_providers = providers();
|
|
||||||
if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) {
|
|
||||||
for key in &provider_meta.config_keys {
|
|
||||||
let env_name = key.name.clone();
|
|
||||||
if let Ok(value) = std::env::var(&env_name) {
|
|
||||||
if key.secret {
|
|
||||||
config.set_secret(&key.name, Value::String(value))?;
|
|
||||||
info!("Set secret key: {}", key.name);
|
|
||||||
} else {
|
|
||||||
config.set_param(&key.name, Value::String(value))?;
|
|
||||||
info!("Set parameter: {}", key.name);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
warn!("Environment variable not set for key: {}", key.name);
|
|
||||||
if key.required {
|
|
||||||
error!("Required key {} not provided", key.name);
|
|
||||||
return Err(anyhow::anyhow!("Required key {} not provided", key.name));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize agent with provider
|
|
||||||
let model_config = ModelConfig::new(model_name);
|
|
||||||
let provider = create(&provider_name, model_config)?;
|
|
||||||
|
|
||||||
let agent = AGENT.lock().await;
|
|
||||||
agent.update_provider(provider).await?;
|
|
||||||
|
|
||||||
info!("Provider configuration successful");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
/// Initialize extensions from the configuration.
|
|
||||||
async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> {
|
|
||||||
if let Ok(ext_table) = config.get_table("extensions") {
|
|
||||||
for (name, ext_config) in ext_table {
|
|
||||||
// Deserialize into ExtensionEntry to get enabled flag and config
|
|
||||||
let entry: ExtensionEntry = ext_config.clone().try_deserialize()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?;
|
|
||||||
|
|
||||||
if entry.enabled {
|
|
||||||
let extension_config: ExtensionConfig = entry.config;
|
|
||||||
// Acquire the global agent lock and try to add the extension
|
|
||||||
let mut agent = AGENT.lock().await;
|
|
||||||
if let Err(e) = agent.add_extension(extension_config).await {
|
|
||||||
error!("Failed to add extension {}: {}", name, e);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
info!("Skipping disabled extension: {}", name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
warn!("No extensions configured in config file.");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn run_init_tests() -> Result<(), anyhow::Error> {
|
|
||||||
info!("Running initialization tests");
|
|
||||||
{
|
|
||||||
let _agent = AGENT.lock().await;
|
|
||||||
info!("Agent initialization test passed");
|
|
||||||
}
|
|
||||||
info!("Initialization tests completed");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), anyhow::Error> {
|
async fn main() -> Result<(), anyhow::Error> {
|
||||||
// Initialize tracing
|
run_server().await
|
||||||
tracing_subscriber::fmt()
|
|
||||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
|
||||||
.init();
|
|
||||||
|
|
||||||
info!("Starting goose-api server");
|
|
||||||
|
|
||||||
// Load configuration
|
|
||||||
let api_config = load_configuration()?;
|
|
||||||
|
|
||||||
// Get API key from configuration or environment
|
|
||||||
let api_key: String = std::env::var("GOOSE_API_KEY")
|
|
||||||
.or_else(|_| api_config.get_string("api_key"))
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
warn!("No API key configured, using default");
|
|
||||||
"default_api_key".to_string()
|
|
||||||
});
|
|
||||||
|
|
||||||
// Initialize provider configuration
|
|
||||||
if let Err(e) = initialize_provider_config().await {
|
|
||||||
error!("Failed to initialize provider: {}", e);
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize extensions from configuration
|
|
||||||
if let Err(e) = initialize_extensions(&api_config).await {
|
|
||||||
error!("Failed to initialize extensions: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = run_init_tests().await {
|
|
||||||
error!("Initialization tests failed: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session start endpoint
|
|
||||||
let start_session = warp::path("session")
|
|
||||||
.and(warp::path("start"))
|
|
||||||
.and(warp::post())
|
|
||||||
.and(warp::body::json())
|
|
||||||
.and(with_api_key(api_key.clone()))
|
|
||||||
.and_then(start_session_handler);
|
|
||||||
|
|
||||||
// Session reply endpoint
|
|
||||||
let reply_session = warp::path("session")
|
|
||||||
.and(warp::path("reply"))
|
|
||||||
.and(warp::post())
|
|
||||||
.and(warp::body::json())
|
|
||||||
.and(with_api_key(api_key.clone()))
|
|
||||||
.and_then(reply_session_handler);
|
|
||||||
|
|
||||||
// Session end endpoint
|
|
||||||
let end_session = warp::path("session")
|
|
||||||
.and(warp::path("end"))
|
|
||||||
.and(warp::post())
|
|
||||||
.and(warp::body::json())
|
|
||||||
.and(with_api_key(api_key.clone()))
|
|
||||||
.and_then(end_session_handler);
|
|
||||||
|
|
||||||
// List extensions endpoint
|
|
||||||
let list_extensions = warp::path("extensions")
|
|
||||||
.and(warp::path("list"))
|
|
||||||
.and(warp::get())
|
|
||||||
.and_then(list_extensions_handler);
|
|
||||||
|
|
||||||
// Add extension endpoint
|
|
||||||
let add_extension = warp::path("extensions")
|
|
||||||
.and(warp::path("add"))
|
|
||||||
.and(warp::post())
|
|
||||||
.and(warp::body::json())
|
|
||||||
.and(with_api_key(api_key.clone()))
|
|
||||||
.and_then(add_extension_handler);
|
|
||||||
|
|
||||||
// Remove extension endpoint
|
|
||||||
let remove_extension = warp::path("extensions")
|
|
||||||
.and(warp::path("remove"))
|
|
||||||
.and(warp::post())
|
|
||||||
.and(warp::body::json())
|
|
||||||
.and(with_api_key(api_key.clone()))
|
|
||||||
.and_then(remove_extension_handler);
|
|
||||||
|
|
||||||
// Get provider configuration endpoint
|
|
||||||
let get_provider_config = warp::path("provider")
|
|
||||||
.and(warp::path("config"))
|
|
||||||
.and(warp::get())
|
|
||||||
.and_then(get_provider_config_handler);
|
|
||||||
|
|
||||||
// Combine all routes
|
|
||||||
let routes = start_session
|
|
||||||
.or(reply_session)
|
|
||||||
.or(end_session)
|
|
||||||
.or(list_extensions)
|
|
||||||
.or(add_extension)
|
|
||||||
.or(remove_extension)
|
|
||||||
.or(get_provider_config);
|
|
||||||
|
|
||||||
// Get bind address from configuration or use default
|
|
||||||
let host = std::env::var("GOOSE_API_HOST")
|
|
||||||
.or_else(|_| api_config.get_string("host"))
|
|
||||||
.unwrap_or_else(|_| "127.0.0.1".to_string());
|
|
||||||
|
|
||||||
let port = std::env::var("GOOSE_API_PORT")
|
|
||||||
.or_else(|_| api_config.get_string("port"))
|
|
||||||
.unwrap_or_else(|_| "8080".to_string())
|
|
||||||
.parse::<u16>()
|
|
||||||
.unwrap_or(8080);
|
|
||||||
|
|
||||||
info!("Starting server on {}:{}", host, port);
|
|
||||||
|
|
||||||
// Parse host string
|
|
||||||
let host_parts: Vec<u8> = host.split('.')
|
|
||||||
.map(|part| part.parse::<u8>().unwrap_or(127))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let addr = if host_parts.len() == 4 {
|
|
||||||
[host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
|
|
||||||
} else {
|
|
||||||
[127, 0, 0, 1]
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
warp::serve(routes)
|
|
||||||
.run((addr, port))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|||||||
123
crates/goose-api/src/routes.rs
Normal file
123
crates/goose-api/src/routes.rs
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
use warp::Filter;
|
||||||
|
use tracing::{info, warn, error};
|
||||||
|
|
||||||
|
use crate::handlers::{
|
||||||
|
add_extension_handler, end_session_handler, get_provider_config_handler,
|
||||||
|
list_extensions_handler, remove_extension_handler, reply_session_handler,
|
||||||
|
start_session_handler, with_api_key,
|
||||||
|
};
|
||||||
|
use crate::config::{
|
||||||
|
initialize_extensions, initialize_provider_config, load_configuration,
|
||||||
|
run_init_tests,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||||
|
let start_session = warp::path("session")
|
||||||
|
.and(warp::path("start"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(start_session_handler);
|
||||||
|
|
||||||
|
let reply_session = warp::path("session")
|
||||||
|
.and(warp::path("reply"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(reply_session_handler);
|
||||||
|
|
||||||
|
let end_session = warp::path("session")
|
||||||
|
.and(warp::path("end"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(end_session_handler);
|
||||||
|
|
||||||
|
let list_extensions = warp::path("extensions")
|
||||||
|
.and(warp::path("list"))
|
||||||
|
.and(warp::get())
|
||||||
|
.and_then(list_extensions_handler);
|
||||||
|
|
||||||
|
let add_extension = warp::path("extensions")
|
||||||
|
.and(warp::path("add"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(add_extension_handler);
|
||||||
|
|
||||||
|
let remove_extension = warp::path("extensions")
|
||||||
|
.and(warp::path("remove"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(remove_extension_handler);
|
||||||
|
|
||||||
|
let get_provider_config = warp::path("provider")
|
||||||
|
.and(warp::path("config"))
|
||||||
|
.and(warp::get())
|
||||||
|
.and_then(get_provider_config_handler);
|
||||||
|
|
||||||
|
start_session
|
||||||
|
.or(reply_session)
|
||||||
|
.or(end_session)
|
||||||
|
.or(list_extensions)
|
||||||
|
.or(add_extension)
|
||||||
|
.or(remove_extension)
|
||||||
|
.or(get_provider_config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
info!("Starting goose-api server");
|
||||||
|
|
||||||
|
let api_config = load_configuration()?;
|
||||||
|
|
||||||
|
let api_key: String = std::env::var("GOOSE_API_KEY")
|
||||||
|
.or_else(|_| api_config.get_string("api_key"))
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
warn!("No API key configured, using default");
|
||||||
|
"default_api_key".to_string()
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(e) = initialize_provider_config().await {
|
||||||
|
error!("Failed to initialize provider: {}", e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = initialize_extensions(&api_config).await {
|
||||||
|
error!("Failed to initialize extensions: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = run_init_tests().await {
|
||||||
|
error!("Initialization tests failed: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
let routes = build_routes(api_key.clone());
|
||||||
|
|
||||||
|
let host = std::env::var("GOOSE_API_HOST")
|
||||||
|
.or_else(|_| api_config.get_string("host"))
|
||||||
|
.unwrap_or_else(|_| "127.0.0.1".to_string());
|
||||||
|
let port = std::env::var("GOOSE_API_PORT")
|
||||||
|
.or_else(|_| api_config.get_string("port"))
|
||||||
|
.unwrap_or_else(|_| "8080".to_string())
|
||||||
|
.parse::<u16>()
|
||||||
|
.unwrap_or(8080);
|
||||||
|
|
||||||
|
info!("Starting server on {}:{}", host, port);
|
||||||
|
|
||||||
|
let host_parts: Vec<u8> = host
|
||||||
|
.split('.')
|
||||||
|
.map(|part| part.parse::<u8>().unwrap_or(127))
|
||||||
|
.collect();
|
||||||
|
let addr = if host_parts.len() == 4 {
|
||||||
|
[host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
|
||||||
|
} else {
|
||||||
|
[127, 0, 0, 1]
|
||||||
|
};
|
||||||
|
|
||||||
|
warp::serve(routes).run((addr, port)).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
10
crates/goose-api/src/tests.rs
Normal file
10
crates/goose-api/src/tests.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn build_routes_compiles() {
|
||||||
|
let _routes = build_routes("test-key".to_string());
|
||||||
|
// Just ensure building routes doesn't panic
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user