Files
goose/crates/goose-server/src/routes/reply.rs
Salman Mohammed 199fa6adbc chore: refactor read-write lock on agent (#2225)
Co-authored-by: Alice Hau <ahau@squareup.com>
2025-04-23 22:46:22 -04:00

565 lines
17 KiB
Rust

use super::utils::verify_secret_key;
use crate::state::AppState;
use axum::{
extract::State,
http::{self, HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
Json, Router,
};
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
agents::SessionConfig,
message::{Message, MessageContent},
permission::permission_confirmation::PrincipalType,
};
use goose::{
permission::{Permission, PermissionConfirmation},
session,
};
use mcp_core::{role::Role, Content, ToolResult};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
use std::{
convert::Infallible,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use utoipa::ToSchema;
// Direct message serialization for the chat request
#[derive(Debug, Deserialize)]
struct ChatRequest {
messages: Vec<Message>,
session_id: Option<String>,
session_working_dir: String,
}
// Custom SSE response type for streaming messages
pub struct SseResponse {
rx: ReceiverStream<String>,
}
impl SseResponse {
fn new(rx: ReceiverStream<String>) -> Self {
Self { rx }
}
}
impl Stream for SseResponse {
type Item = Result<Bytes, Infallible>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.rx)
.poll_next(cx)
.map(|opt| opt.map(|s| Ok(Bytes::from(s))))
}
}
impl IntoResponse for SseResponse {
fn into_response(self) -> axum::response::Response {
let stream = self;
let body = axum::body::Body::from_stream(stream);
http::Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(body)
.unwrap()
}
}
// Message event types for SSE streaming
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum MessageEvent {
Message { message: Message },
Error { error: String },
Finish { reason: String },
}
// Stream a message as an SSE event
async fn stream_event(
event: MessageEvent,
tx: &mpsc::Sender<String>,
) -> Result<(), mpsc::error::SendError<String>> {
let json = serde_json::to_string(&event).unwrap_or_else(|e| {
format!(
r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
e
)
});
tx.send(format!("data: {}\n\n", json)).await
}
async fn handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ChatRequest>,
) -> Result<SseResponse, StatusCode> {
verify_secret_key(&headers, &state)?;
// Create channel for streaming
let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);
let messages = request.messages;
let session_working_dir = request.session_working_dir;
// Generate a new session ID if not provided in the request
let session_id = request
.session_id
.unwrap_or_else(session::generate_session_id);
// Spawn task to handle streaming
tokio::spawn(async move {
let agent = state.get_agent().await;
let agent = match agent {
Ok(agent) => {
let provider = agent.provider().await;
match provider {
Ok(_) => agent,
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No provider configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
}
}
Err(_) => {
let _ = stream_event(
MessageEvent::Error {
error: "No agent configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
};
// Get the provider first, before starting the reply stream
let provider = agent.provider().await;
let mut stream = match agent
.reply(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
}),
)
.await
{
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
};
// Collect all messages for storage
let mut all_messages = messages.clone();
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
loop {
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
all_messages.push(message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
// Store messages and generate description in background
let session_path = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
tracing::error!("Failed to store session history: {:?}", e);
}
});
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => { // Heartbeat, used to detect disconnected clients
if tx.is_closed() {
break;
}
continue;
}
}
}
}
}
// Send finish event
let _ = stream_event(
MessageEvent::Finish {
reason: "stop".to_string(),
},
&tx,
)
.await;
});
Ok(SseResponse::new(stream))
}
#[derive(Debug, Deserialize, Serialize)]
struct AskRequest {
prompt: String,
session_id: Option<String>,
session_working_dir: String,
}
#[derive(Debug, Serialize)]
struct AskResponse {
response: String,
}
// Simple ask an AI for a response, non streaming
async fn ask_handler(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<AskRequest>,
) -> Result<Json<AskResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let session_working_dir = request.session_working_dir;
// Generate a new session ID if not provided in the request
let session_id = request
.session_id
.unwrap_or_else(session::generate_session_id);
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
// Get the provider first, before starting the reply stream
let provider = agent.provider().await;
// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];
// Get response from agent
let mut response_text = String::new();
let mut stream = match agent
.reply(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
}),
)
.await
{
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
// Collect all messages for storage
let mut all_messages = messages.clone();
let mut response_message = Message::assistant();
while let Some(response) = stream.next().await {
match response {
Ok(message) => {
if message.role == Role::Assistant {
for content in &message.content {
if let MessageContent::Text(text) = content {
response_text.push_str(&text.text);
response_text.push('\n');
}
response_message.content.push(content.clone());
}
}
}
Err(e) => {
tracing::error!("Error processing as_ai message: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
// Add the complete response message to the conversation history
if !response_message.content.is_empty() {
all_messages.push(response_message);
}
// Get the session path - file will be created when needed
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
// Store messages and generate description in background
let session_path = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
tracing::error!("Failed to store session history: {:?}", e);
}
});
Ok(Json(AskResponse {
response: response_text.trim().to_string(),
}))
}
#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct PermissionConfirmationRequest {
id: String,
#[serde(default = "default_principal_type")]
principal_type: PrincipalType,
action: String,
}
fn default_principal_type() -> PrincipalType {
PrincipalType::Tool
}
#[utoipa::path(
post,
path = "/confirm",
request_body = PermissionConfirmationRequest,
responses(
(status = 200, description = "Permission action is confirmed", body = Value),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 500, description = "Internal server error")
)
)]
pub async fn confirm_permission(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<PermissionConfirmationRequest>,
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
"allow_once" => Permission::AllowOnce,
"deny" => Permission::DenyOnce,
_ => Permission::DenyOnce,
};
agent
.handle_confirmation(
request.id.clone(),
PermissionConfirmation {
principal_type: request.principal_type,
permission,
},
)
.await;
Ok(Json(Value::Object(serde_json::Map::new())))
}
#[derive(Debug, Deserialize)]
struct ToolResultRequest {
id: String,
result: ToolResult<Vec<Content>>,
}
async fn submit_tool_result(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
raw: axum::extract::Json<serde_json::Value>,
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Log the raw request for debugging
tracing::info!(
"Received tool result request: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
// Try to parse into our struct
let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) {
Ok(req) => req,
Err(e) => {
tracing::error!("Failed to parse tool result request: {}", e);
tracing::error!(
"Raw request was: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
return Err(StatusCode::UNPROCESSABLE_ENTITY);
}
};
let agent = state
.get_agent()
.await
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
agent.handle_tool_result(payload.id, payload.result).await;
Ok(Json(json!({"status": "ok"})))
}
// Configure routes for this module
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/reply", post(handler))
.route("/ask", post(ask_handler))
.route("/confirm", post(confirm_permission))
.route("/tool_result", post(submit_tool_result))
.with_state(state)
}
#[cfg(test)]
mod tests {
use super::*;
use goose::{
agents::Agent,
model::ModelConfig,
providers::{
base::{Provider, ProviderUsage, Usage},
errors::ProviderError,
},
};
use mcp_core::tool::Tool;
// Mock Provider implementation for testing
#[derive(Clone)]
struct MockProvider {
model_config: ModelConfig,
}
#[async_trait::async_trait]
impl Provider for MockProvider {
fn metadata() -> goose::providers::base::ProviderMetadata {
goose::providers::base::ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message::assistant().with_text("Mock response"),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
}
mod integration_tests {
use super::*;
use axum::{body::Body, http::Request};
use std::sync::Arc;
use tower::ServiceExt;
// This test requires tokio runtime
#[tokio::test]
async fn test_ask_endpoint() {
// Create a mock app state with mock provider
let mock_model_config = ModelConfig::new("test-model".to_string());
let mock_provider = Arc::new(MockProvider {
model_config: mock_model_config,
});
let agent = Agent::new();
let _ = agent.update_provider(mock_provider).await;
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
// Build router
let app = routes(state);
// Create request
let request = Request::builder()
.uri("/ask")
.method("POST")
.header("content-type", "application/json")
.header("x-secret-key", "test-secret")
.body(Body::from(
serde_json::to_string(&AskRequest {
prompt: "test prompt".to_string(),
session_id: Some("test-session".to_string()),
session_working_dir: "test-working-dir".to_string(),
})
.unwrap(),
))
.unwrap();
// Send request
let response = app.oneshot(request).await.unwrap();
// Assert response status
assert_eq!(response.status(), StatusCode::OK);
}
}
}