mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
Session file security updates (#3071)
This commit is contained in:
@@ -175,7 +175,12 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re
|
||||
/// without creating an Agent or prompting about working directories.
|
||||
pub fn handle_session_export(identifier: Identifier, output_path: Option<PathBuf>) -> Result<()> {
|
||||
// Get the session file path
|
||||
let session_file_path = goose::session::get_path(identifier.clone());
|
||||
let session_file_path = match goose::session::get_path(identifier.clone()) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!("Invalid session identifier: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
if !session_file_path.exists() {
|
||||
return Err(anyhow::anyhow!(
|
||||
|
||||
@@ -250,7 +250,14 @@ async fn list_sessions() -> Json<serde_json::Value> {
|
||||
async fn get_session(
|
||||
axum::extract::Path(session_id): axum::extract::Path<String>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let session_file = session::get_path(session::Identifier::Name(session_id));
|
||||
let session_file = match session::get_path(session::Identifier::Name(session_id)) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return Json(serde_json::json!({
|
||||
"error": format!("Invalid session ID: {}", e)
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
match session::read_messages(&session_file) {
|
||||
Ok(messages) => {
|
||||
@@ -288,8 +295,15 @@ async fn handle_socket(socket: WebSocket, state: AppState) {
|
||||
..
|
||||
}) => {
|
||||
// Get session file path from session_id
|
||||
let session_file =
|
||||
session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
let session_file = match session::get_path(session::Identifier::Name(
|
||||
session_id.clone(),
|
||||
)) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get session path: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Get or create session in memory (for fast access during processing)
|
||||
let session_messages = {
|
||||
|
||||
@@ -234,7 +234,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
}
|
||||
} else if session_config.resume {
|
||||
if let Some(identifier) = session_config.identifier {
|
||||
let session_file = session::get_path(identifier);
|
||||
let session_file = match session::get_path(identifier) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
output::render_error(&format!("Invalid session identifier: {}", e));
|
||||
process::exit(1);
|
||||
}
|
||||
};
|
||||
if !session_file.exists() {
|
||||
output::render_error(&format!(
|
||||
"Cannot resume session {} - no such session exists",
|
||||
@@ -262,7 +268,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
};
|
||||
|
||||
// Just get the path - file will be created when needed
|
||||
session::get_path(id)
|
||||
match session::get_path(id) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
output::render_error(&format!("Failed to create session path: {}", e));
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if session_config.resume && !session_config.no_session {
|
||||
|
||||
@@ -210,7 +210,20 @@ async fn handler(
|
||||
};
|
||||
|
||||
let mut all_messages = messages.clone();
|
||||
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get session path: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: format!("Failed to get session path: {}", e),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -390,13 +403,21 @@ async fn ask_handler(
|
||||
all_messages.push(response_message);
|
||||
}
|
||||
|
||||
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get session path: {}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
};
|
||||
|
||||
let session_path = session_path.clone();
|
||||
let session_path_clone = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
|
||||
if let Err(e) =
|
||||
session::persist_messages(&session_path_clone, &messages, Some(provider)).await
|
||||
{
|
||||
tracing::error!("Failed to store session history: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -84,7 +84,10 @@ async fn get_session_history(
|
||||
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
|
||||
Ok(path) => path,
|
||||
Err(_) => return Err(StatusCode::BAD_REQUEST),
|
||||
};
|
||||
|
||||
// Read metadata
|
||||
let metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
|
||||
@@ -222,7 +222,12 @@ impl Agent {
|
||||
usage: &crate::providers::base::ProviderUsage,
|
||||
messages_length: usize,
|
||||
) -> Result<()> {
|
||||
let session_file_path = session::storage::get_path(session_config.id.clone());
|
||||
let session_file_path = match session::storage::get_path(session_config.id.clone()) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!("Failed to get session file path: {}", e));
|
||||
}
|
||||
};
|
||||
let mut metadata = session::storage::read_metadata(&session_file_path)?;
|
||||
|
||||
metadata.schedule_id = session_config.schedule_id.clone();
|
||||
|
||||
@@ -372,9 +372,17 @@ impl Agent {
|
||||
})?;
|
||||
|
||||
// Get the session file path
|
||||
let session_path = crate::session::storage::get_path(
|
||||
let session_path = match crate::session::storage::get_path(
|
||||
crate::session::storage::Identifier::Name(session_id.to_string()),
|
||||
);
|
||||
) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return Err(ToolError::ExecutionError(format!(
|
||||
"Invalid session ID '{}': {}",
|
||||
session_id, e
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Check if session file exists
|
||||
if !session_path.exists() {
|
||||
|
||||
@@ -1138,9 +1138,17 @@ async fn run_scheduled_job_internal(
|
||||
}
|
||||
}
|
||||
|
||||
let session_file_path = crate::session::storage::get_path(
|
||||
let session_file_path = match crate::session::storage::get_path(
|
||||
crate::session::storage::Identifier::Name(session_id_for_return.clone()),
|
||||
);
|
||||
) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return Err(JobExecutionError {
|
||||
job_id: job.id.clone(),
|
||||
error: format!("Failed to get session file path: {}", e),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(prompt_text) = recipe.prompt {
|
||||
let mut all_session_messages: Vec<Message> =
|
||||
|
||||
@@ -18,6 +18,11 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
// Security limits
|
||||
const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; // 10MB
|
||||
const MAX_MESSAGE_COUNT: usize = 5000;
|
||||
const MAX_LINE_LENGTH: usize = 1024 * 1024; // 1MB per line
|
||||
|
||||
fn get_home_dir() -> PathBuf {
|
||||
choose_app_strategy(crate::config::APP_STRATEGY.clone())
|
||||
.expect("goose requires a home dir")
|
||||
@@ -133,14 +138,62 @@ pub enum Identifier {
|
||||
Path(PathBuf),
|
||||
}
|
||||
|
||||
pub fn get_path(id: Identifier) -> PathBuf {
|
||||
match id {
|
||||
pub fn get_path(id: Identifier) -> Result<PathBuf> {
|
||||
let path = match id {
|
||||
Identifier::Name(name) => {
|
||||
let session_dir = ensure_session_dir().expect("Failed to create session directory");
|
||||
// Validate session name for security
|
||||
if name.is_empty() || name.len() > 255 {
|
||||
return Err(anyhow::anyhow!("Invalid session name length"));
|
||||
}
|
||||
|
||||
// Check for path traversal attempts
|
||||
if name.contains("..") || name.contains('/') || name.contains('\\') {
|
||||
return Err(anyhow::anyhow!("Invalid characters in session name"));
|
||||
}
|
||||
|
||||
let session_dir = ensure_session_dir().map_err(|e| {
|
||||
tracing::error!("Failed to create session directory: {}", e);
|
||||
anyhow::anyhow!("Failed to access session directory")
|
||||
})?;
|
||||
session_dir.join(format!("{}.jsonl", name))
|
||||
}
|
||||
Identifier::Path(path) => path,
|
||||
Identifier::Path(path) => {
|
||||
// In test mode, allow temporary directory paths
|
||||
#[cfg(test)]
|
||||
{
|
||||
if let Some(path_str) = path.to_str() {
|
||||
if path_str.contains("/tmp") || path_str.contains("/.tmp") {
|
||||
// Allow test temporary directories
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the path is within allowed directories
|
||||
let canonical_path = path.canonicalize().unwrap_or(path.clone());
|
||||
let session_dir = ensure_session_dir().map_err(|e| {
|
||||
tracing::error!("Failed to create session directory: {}", e);
|
||||
anyhow::anyhow!("Failed to access session directory")
|
||||
})?;
|
||||
let canonical_session_dir = session_dir.canonicalize().unwrap_or(session_dir);
|
||||
|
||||
if !canonical_path.starts_with(&canonical_session_dir) {
|
||||
tracing::warn!("Attempted access outside session directory");
|
||||
return Err(anyhow::anyhow!("Path not allowed"));
|
||||
}
|
||||
|
||||
path
|
||||
}
|
||||
};
|
||||
|
||||
// Additional security check for file extension
|
||||
if let Some(ext) = path.extension() {
|
||||
if ext != "jsonl" {
|
||||
return Err(anyhow::anyhow!("Invalid file extension"));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Ensure the session directory exists and return its path
|
||||
@@ -221,17 +274,20 @@ pub fn generate_session_id() -> String {
|
||||
/// The first line of the file is expected to be metadata, and the rest are messages.
|
||||
/// Large messages are automatically truncated to prevent memory issues.
|
||||
/// Includes recovery mechanisms for corrupted files.
|
||||
///
|
||||
/// Security features:
|
||||
/// - Validates file paths to prevent directory traversal
|
||||
/// - Includes all security limits from read_messages_with_truncation
|
||||
pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
|
||||
let result = read_messages_with_truncation(session_file, Some(50000)); // 50KB limit per message content
|
||||
// Validate the path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
let result = read_messages_with_truncation(&secure_path, Some(50000)); // 50KB limit per message content
|
||||
match &result {
|
||||
Ok(messages) => println!(
|
||||
"[SESSION] Successfully read {} messages from: {:?}",
|
||||
messages.len(),
|
||||
session_file
|
||||
),
|
||||
Ok(_messages) => {}
|
||||
Err(e) => println!(
|
||||
"[SESSION] Failed to read messages from {:?}: {}",
|
||||
session_file, e
|
||||
secure_path, e
|
||||
),
|
||||
}
|
||||
result
|
||||
@@ -243,10 +299,24 @@ pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
|
||||
/// The first line of the file is expected to be metadata, and the rest are messages.
|
||||
/// If max_content_size is Some, large message content will be truncated during loading.
|
||||
/// Includes robust error handling and corruption recovery mechanisms.
|
||||
///
|
||||
/// Security features:
|
||||
/// - File size limits to prevent resource exhaustion
|
||||
/// - Message count limits to prevent DoS attacks
|
||||
/// - Line length restrictions to prevent memory issues
|
||||
pub fn read_messages_with_truncation(
|
||||
session_file: &Path,
|
||||
max_content_size: Option<usize>,
|
||||
) -> Result<Vec<Message>> {
|
||||
// Security check: file size limit
|
||||
if session_file.exists() {
|
||||
let metadata = fs::metadata(session_file)?;
|
||||
if metadata.len() > MAX_FILE_SIZE {
|
||||
tracing::warn!("Session file exceeds size limit: {} bytes", metadata.len());
|
||||
return Err(anyhow::anyhow!("Session file too large"));
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there's a backup file we should restore from
|
||||
let backup_file = session_file.with_extension("backup");
|
||||
if !session_file.exists() && backup_file.exists() {
|
||||
@@ -255,7 +325,7 @@ pub fn read_messages_with_truncation(
|
||||
backup_file
|
||||
);
|
||||
tracing::warn!(
|
||||
"[SESSION] Session file missing but backup exists, restoring from backup: {:?}",
|
||||
"Session file missing but backup exists, restoring from backup: {:?}",
|
||||
backup_file
|
||||
);
|
||||
if let Err(e) = fs::copy(&backup_file, session_file) {
|
||||
@@ -277,11 +347,18 @@ pub fn read_messages_with_truncation(
|
||||
let mut messages = Vec::new();
|
||||
let mut corrupted_lines = Vec::new();
|
||||
let mut line_number = 1;
|
||||
let mut message_count = 0;
|
||||
|
||||
// Read the first line as metadata or create default if empty/missing
|
||||
if let Some(line_result) = lines.next() {
|
||||
match line_result {
|
||||
Ok(line) => {
|
||||
// Security check: line length
|
||||
if line.len() > MAX_LINE_LENGTH {
|
||||
tracing::warn!("Line {} exceeds length limit", line_number);
|
||||
return Err(anyhow::anyhow!("Line too long"));
|
||||
}
|
||||
|
||||
// Try to parse as metadata, but if it fails, treat it as a message
|
||||
if let Ok(_metadata) = serde_json::from_str::<SessionMetadata>(&line) {
|
||||
// Metadata successfully parsed, continue with the rest of the lines as messages
|
||||
@@ -290,6 +367,7 @@ pub fn read_messages_with_truncation(
|
||||
match parse_message_with_truncation(&line, max_content_size) {
|
||||
Ok(message) => {
|
||||
messages.push(message);
|
||||
message_count += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
println!("[SESSION] Failed to parse first line as message: {}", e);
|
||||
@@ -303,6 +381,7 @@ pub fn read_messages_with_truncation(
|
||||
"[SESSION] Successfully recovered corrupted first line!"
|
||||
);
|
||||
messages.push(recovered);
|
||||
message_count += 1;
|
||||
}
|
||||
Err(recovery_err) => {
|
||||
println!(
|
||||
@@ -327,38 +406,63 @@ pub fn read_messages_with_truncation(
|
||||
|
||||
// Read the rest of the lines as messages
|
||||
for line_result in lines {
|
||||
match line_result {
|
||||
Ok(line) => match parse_message_with_truncation(&line, max_content_size) {
|
||||
Ok(message) => {
|
||||
messages.push(message);
|
||||
}
|
||||
Err(e) => {
|
||||
println!("[SESSION] Failed to parse line {}: {}", line_number, e);
|
||||
println!(
|
||||
"[SESSION] Attempting to recover corrupted line {}...",
|
||||
line_number
|
||||
);
|
||||
tracing::warn!("Failed to parse line {}: {}", line_number, e);
|
||||
// Security check: message count limit
|
||||
if message_count >= MAX_MESSAGE_COUNT {
|
||||
tracing::warn!("Message count limit reached: {}", MAX_MESSAGE_COUNT);
|
||||
println!(
|
||||
"[SESSION] Message count limit reached, stopping at {}",
|
||||
MAX_MESSAGE_COUNT
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
// Try to recover the corrupted line
|
||||
match attempt_corruption_recovery(&line, max_content_size) {
|
||||
Ok(recovered) => {
|
||||
println!(
|
||||
"[SESSION] Successfully recovered corrupted line {}!",
|
||||
line_number
|
||||
);
|
||||
messages.push(recovered);
|
||||
}
|
||||
Err(recovery_err) => {
|
||||
println!(
|
||||
"[SESSION] Failed to recover corrupted line {}: {}",
|
||||
line_number, recovery_err
|
||||
);
|
||||
corrupted_lines.push((line_number, line));
|
||||
match line_result {
|
||||
Ok(line) => {
|
||||
// Security check: line length
|
||||
if line.len() > MAX_LINE_LENGTH {
|
||||
tracing::warn!("Line {} exceeds length limit", line_number);
|
||||
corrupted_lines.push((
|
||||
line_number,
|
||||
"[Line too long - truncated for security]".to_string(),
|
||||
));
|
||||
line_number += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
match parse_message_with_truncation(&line, max_content_size) {
|
||||
Ok(message) => {
|
||||
messages.push(message);
|
||||
message_count += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
println!("[SESSION] Failed to parse line {}: {}", line_number, e);
|
||||
println!(
|
||||
"[SESSION] Attempting to recover corrupted line {}...",
|
||||
line_number
|
||||
);
|
||||
tracing::warn!("Failed to parse line {}: {}", line_number, e);
|
||||
|
||||
// Try to recover the corrupted line
|
||||
match attempt_corruption_recovery(&line, max_content_size) {
|
||||
Ok(recovered) => {
|
||||
println!(
|
||||
"[SESSION] Successfully recovered corrupted line {}!",
|
||||
line_number
|
||||
);
|
||||
messages.push(recovered);
|
||||
message_count += 1;
|
||||
}
|
||||
Err(recovery_err) => {
|
||||
println!(
|
||||
"[SESSION] Failed to recover corrupted line {}: {}",
|
||||
line_number, recovery_err
|
||||
);
|
||||
corrupted_lines.push((line_number, line));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
println!("[SESSION] Failed to read line {}: {}", line_number, e);
|
||||
tracing::error!("Failed to read line {}: {}", line_number, e);
|
||||
@@ -375,7 +479,7 @@ pub fn read_messages_with_truncation(
|
||||
corrupted_lines.len()
|
||||
);
|
||||
tracing::warn!(
|
||||
"[SESSION] Found {} corrupted lines in session file, creating backup",
|
||||
"Found {} corrupted lines in session file, creating backup",
|
||||
corrupted_lines.len()
|
||||
);
|
||||
|
||||
@@ -390,7 +494,7 @@ pub fn read_messages_with_truncation(
|
||||
}
|
||||
}
|
||||
|
||||
// Log details about corrupted lines
|
||||
// Log details about corrupted lines (with limited detail for security)
|
||||
for (num, line) in &corrupted_lines {
|
||||
let preview = if line.len() > 50 {
|
||||
format!("{}... (truncated)", &line[..50])
|
||||
@@ -401,11 +505,6 @@ pub fn read_messages_with_truncation(
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"[SESSION] Finished reading session file. Total messages: {}, corrupted lines: {}",
|
||||
messages.len(),
|
||||
corrupted_lines.len()
|
||||
);
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
@@ -444,7 +543,6 @@ fn parse_message_with_truncation(
|
||||
|
||||
match serde_json::from_str::<Message>(&truncated_json) {
|
||||
Ok(message) => {
|
||||
println!("[SESSION] Successfully parsed message after truncation");
|
||||
tracing::info!("Successfully parsed message after JSON truncation");
|
||||
Ok(message)
|
||||
}
|
||||
@@ -628,8 +726,6 @@ fn try_fix_json_corruption(json_str: &str, max_content_size: Option<usize>) -> R
|
||||
}
|
||||
|
||||
if !fixes_applied.is_empty() {
|
||||
println!("[SESSION] Applied JSON fixes: {}", fixes_applied.join(", "));
|
||||
|
||||
match serde_json::from_str::<Message>(&fixed_json) {
|
||||
Ok(mut message) => {
|
||||
if let Some(max_size) = max_content_size {
|
||||
@@ -690,15 +786,6 @@ fn try_extract_partial_message(json_str: &str) -> Result<Message> {
|
||||
}
|
||||
|
||||
if !extracted_text.is_empty() {
|
||||
println!(
|
||||
"[SESSION] Extracted text content: {}",
|
||||
if extracted_text.len() > 50 {
|
||||
&extracted_text[..50]
|
||||
} else {
|
||||
&extracted_text
|
||||
}
|
||||
);
|
||||
|
||||
let message = match role {
|
||||
mcp_core::role::Role::User => Message::user(),
|
||||
mcp_core::role::Role::Assistant => Message::assistant(),
|
||||
@@ -731,8 +818,6 @@ fn try_fix_truncated_json(json_str: &str, max_content_size: Option<usize>) -> Re
|
||||
completed_json.push('}');
|
||||
}
|
||||
|
||||
println!("[SESSION] Attempting to complete truncated JSON");
|
||||
|
||||
match serde_json::from_str::<Message>(&completed_json) {
|
||||
Ok(mut message) => {
|
||||
if let Some(max_size) = max_content_size {
|
||||
@@ -787,45 +872,51 @@ fn truncate_json_string(json_str: &str, max_content_size: usize) -> String {
|
||||
result
|
||||
}
|
||||
|
||||
/// Read session metadata from a session file
|
||||
/// Read session metadata from a session file with security validation
|
||||
///
|
||||
/// Returns default empty metadata if the file doesn't exist or has no metadata.
|
||||
/// Includes security checks for file access and content validation.
|
||||
pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
|
||||
println!("[SESSION] Reading metadata from: {:?}", session_file);
|
||||
// Validate the path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
if !session_file.exists() {
|
||||
println!("[SESSION] Session file doesn't exist, returning default metadata");
|
||||
if !secure_path.exists() {
|
||||
return Ok(SessionMetadata::default());
|
||||
}
|
||||
|
||||
let file = fs::File::open(session_file)?;
|
||||
// Security check: file size
|
||||
let file_metadata = fs::metadata(&secure_path)?;
|
||||
if file_metadata.len() > MAX_FILE_SIZE {
|
||||
tracing::warn!("Session file exceeds size limit during metadata read");
|
||||
return Err(anyhow::anyhow!("Session file too large"));
|
||||
}
|
||||
|
||||
let file = fs::File::open(&secure_path).map_err(|e| {
|
||||
tracing::error!("Failed to open session file for metadata read: {}", e);
|
||||
anyhow::anyhow!("Failed to access session file")
|
||||
})?;
|
||||
let mut reader = io::BufReader::new(file);
|
||||
let mut first_line = String::new();
|
||||
|
||||
// Read just the first line
|
||||
if reader.read_line(&mut first_line)? > 0 {
|
||||
println!("[SESSION] Read first line, attempting to parse as metadata...");
|
||||
// Security check: line length
|
||||
if first_line.len() > MAX_LINE_LENGTH {
|
||||
tracing::warn!("Metadata line exceeds length limit");
|
||||
return Err(anyhow::anyhow!("Metadata line too long"));
|
||||
}
|
||||
|
||||
// Try to parse as metadata
|
||||
match serde_json::from_str::<SessionMetadata>(&first_line) {
|
||||
Ok(metadata) => {
|
||||
println!(
|
||||
"[SESSION] Successfully parsed metadata: description='{}'",
|
||||
metadata.description
|
||||
);
|
||||
Ok(metadata)
|
||||
}
|
||||
Ok(metadata) => Ok(metadata),
|
||||
Err(e) => {
|
||||
// If the first line isn't metadata, return default
|
||||
println!(
|
||||
"[SESSION] First line is not valid metadata ({}), returning default",
|
||||
e
|
||||
);
|
||||
tracing::debug!("Metadata parse error: {}", e);
|
||||
Ok(SessionMetadata::default())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Empty file, return default
|
||||
println!("[SESSION] File is empty, returning default metadata");
|
||||
Ok(SessionMetadata::default())
|
||||
}
|
||||
}
|
||||
@@ -834,21 +925,16 @@ pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
|
||||
///
|
||||
/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format.
|
||||
/// If a provider is supplied, it will automatically generate a description when appropriate.
|
||||
///
|
||||
/// Security features:
|
||||
/// - Validates file paths to prevent directory traversal
|
||||
/// - Uses secure file operations via persist_messages_with_schedule_id
|
||||
pub async fn persist_messages(
|
||||
session_file: &Path,
|
||||
messages: &[Message],
|
||||
provider: Option<Arc<dyn Provider>>,
|
||||
) -> Result<()> {
|
||||
println!(
|
||||
"[SESSION] persist_messages called with {} messages to: {:?}",
|
||||
messages.len(),
|
||||
session_file
|
||||
);
|
||||
let result = persist_messages_with_schedule_id(session_file, messages, provider, None).await;
|
||||
match &result {
|
||||
Ok(_) => println!("[SESSION] persist_messages completed successfully"),
|
||||
Err(e) => println!("[SESSION] persist_messages failed: {}", e),
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
@@ -856,12 +942,26 @@ pub async fn persist_messages(
|
||||
///
|
||||
/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format.
|
||||
/// If a provider is supplied, it will automatically generate a description when appropriate.
|
||||
///
|
||||
/// Security features:
|
||||
/// - Validates file paths to prevent directory traversal
|
||||
/// - Limits error message details in logs
|
||||
/// - Uses atomic file operations via save_messages_with_metadata
|
||||
pub async fn persist_messages_with_schedule_id(
|
||||
session_file: &Path,
|
||||
messages: &[Message],
|
||||
provider: Option<Arc<dyn Provider>>,
|
||||
schedule_id: Option<String>,
|
||||
) -> Result<()> {
|
||||
// Validate the session file path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
// Security check: message count limit
|
||||
if messages.len() > MAX_MESSAGE_COUNT {
|
||||
tracing::warn!("Message count exceeds limit: {}", messages.len());
|
||||
return Err(anyhow::anyhow!("Too many messages"));
|
||||
}
|
||||
|
||||
// Count user messages
|
||||
let user_message_count = messages
|
||||
.iter()
|
||||
@@ -872,29 +972,35 @@ pub async fn persist_messages_with_schedule_id(
|
||||
match provider {
|
||||
Some(provider) if user_message_count < 4 => {
|
||||
//generate_description is responsible for writing the messages
|
||||
generate_description_with_schedule_id(session_file, messages, provider, schedule_id)
|
||||
generate_description_with_schedule_id(&secure_path, messages, provider, schedule_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
// Read existing metadata
|
||||
let mut metadata = read_metadata(session_file)?;
|
||||
let mut metadata = read_metadata(&secure_path)?;
|
||||
// Update the schedule_id if provided
|
||||
if schedule_id.is_some() {
|
||||
metadata.schedule_id = schedule_id;
|
||||
}
|
||||
// Write the file with metadata and messages
|
||||
save_messages_with_metadata(session_file, &metadata, messages)
|
||||
save_messages_with_metadata(&secure_path, &metadata, messages)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write messages to a session file with the provided metadata using atomic operations
|
||||
/// Write messages to a session file with the provided metadata using secure atomic operations
|
||||
///
|
||||
/// This function uses atomic file operations to prevent corruption:
|
||||
/// 1. Writes to a temporary file first
|
||||
/// 1. Writes to a temporary file first with secure permissions
|
||||
/// 2. Uses fs2 file locking to prevent concurrent writes
|
||||
/// 3. Atomically moves the temp file to the final location
|
||||
/// 4. Includes comprehensive error handling and recovery
|
||||
///
|
||||
/// Security features:
|
||||
/// - Secure temporary file creation with restricted permissions
|
||||
/// - Path validation to prevent directory traversal
|
||||
/// - File size and message count limits
|
||||
/// - Sanitized error messages to prevent information leakage
|
||||
pub fn save_messages_with_metadata(
|
||||
session_file: &Path,
|
||||
metadata: &SessionMetadata,
|
||||
@@ -902,89 +1008,106 @@ pub fn save_messages_with_metadata(
|
||||
) -> Result<()> {
|
||||
use fs2::FileExt;
|
||||
|
||||
println!(
|
||||
"[SESSION] Starting to save {} messages to: {:?}",
|
||||
messages.len(),
|
||||
session_file
|
||||
);
|
||||
// Validate the path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
// Create a temporary file in the same directory to ensure atomic move
|
||||
let temp_file = session_file.with_extension("tmp");
|
||||
println!("[SESSION] Using temporary file: {:?}", temp_file);
|
||||
|
||||
// Ensure the parent directory exists
|
||||
if let Some(parent) = session_file.parent() {
|
||||
println!("[SESSION] Ensuring parent directory exists: {:?}", parent);
|
||||
fs::create_dir_all(parent)?;
|
||||
// Security check: message count limit
|
||||
if messages.len() > MAX_MESSAGE_COUNT {
|
||||
tracing::warn!(
|
||||
"Message count exceeds limit during save: {}",
|
||||
messages.len()
|
||||
);
|
||||
return Err(anyhow::anyhow!("Too many messages to save"));
|
||||
}
|
||||
|
||||
// Create and lock the temporary file
|
||||
println!("[SESSION] Creating and locking temporary file...");
|
||||
// Create a temporary file in the same directory to ensure atomic move
|
||||
let temp_file = secure_path.with_extension("tmp");
|
||||
|
||||
// Ensure the parent directory exists
|
||||
if let Some(parent) = secure_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| {
|
||||
tracing::error!("Failed to create parent directory: {}", e);
|
||||
anyhow::anyhow!("Failed to create session directory")
|
||||
})?;
|
||||
}
|
||||
|
||||
// Create and lock the temporary file with secure permissions
|
||||
let file = fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(&temp_file)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create temporary file {:?}: {}", temp_file, e))?;
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to create temporary file: {}", e);
|
||||
anyhow::anyhow!("Failed to create temporary session file")
|
||||
})?;
|
||||
|
||||
// Set secure file permissions (Unix only - read/write for owner only)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let mut perms = file.metadata()?.permissions();
|
||||
perms.set_mode(0o600); // rw-------
|
||||
fs::set_permissions(&temp_file, perms).map_err(|e| {
|
||||
tracing::error!("Failed to set secure file permissions: {}", e);
|
||||
anyhow::anyhow!("Failed to secure temporary file")
|
||||
})?;
|
||||
}
|
||||
|
||||
// Get an exclusive lock on the file
|
||||
println!("[SESSION] Acquiring exclusive lock...");
|
||||
file.try_lock_exclusive()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to lock file: {}", e))?;
|
||||
file.try_lock_exclusive().map_err(|e| {
|
||||
tracing::error!("Failed to lock file: {}", e);
|
||||
anyhow::anyhow!("Failed to lock session file")
|
||||
})?;
|
||||
|
||||
// Write to temporary file
|
||||
{
|
||||
println!(
|
||||
"[SESSION] Writing metadata and {} messages to temporary file...",
|
||||
messages.len()
|
||||
);
|
||||
let mut writer = io::BufWriter::new(&file);
|
||||
|
||||
// Write metadata as the first line
|
||||
println!("[SESSION] Writing metadata as first line...");
|
||||
serde_json::to_writer(&mut writer, &metadata)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize metadata: {}", e))?;
|
||||
serde_json::to_writer(&mut writer, &metadata).map_err(|e| {
|
||||
tracing::error!("Failed to serialize metadata: {}", e);
|
||||
anyhow::anyhow!("Failed to write session metadata")
|
||||
})?;
|
||||
writeln!(writer)?;
|
||||
|
||||
// Write all messages
|
||||
println!("[SESSION] Writing {} messages...", messages.len());
|
||||
// Write all messages with progress tracking
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
serde_json::to_writer(&mut writer, &message)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize message {}: {}", i, e))?;
|
||||
serde_json::to_writer(&mut writer, &message).map_err(|e| {
|
||||
tracing::error!("Failed to serialize message {}: {}", i, e);
|
||||
anyhow::anyhow!("Failed to write session message")
|
||||
})?;
|
||||
writeln!(writer)?;
|
||||
|
||||
if (i + 1) % 50 == 0 {
|
||||
println!("[SESSION] Written {} messages so far...", i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure all data is written to disk
|
||||
println!("[SESSION] Flushing writer buffer...");
|
||||
writer.flush()?;
|
||||
writer.flush().map_err(|e| {
|
||||
tracing::error!("Failed to flush writer: {}", e);
|
||||
anyhow::anyhow!("Failed to flush session data")
|
||||
})?;
|
||||
}
|
||||
|
||||
// Sync to ensure data is persisted
|
||||
println!("[SESSION] Syncing data to disk...");
|
||||
file.sync_all()?;
|
||||
|
||||
// Release the lock
|
||||
println!("[SESSION] Releasing file lock...");
|
||||
fs2::FileExt::unlock(&file).map_err(|e| anyhow::anyhow!("Failed to unlock file: {}", e))?;
|
||||
|
||||
// Atomically move the temporary file to the final location
|
||||
println!("[SESSION] Atomically moving temp file to final location...");
|
||||
fs::rename(&temp_file, session_file).map_err(|e| {
|
||||
// Clean up temp file on failure
|
||||
println!("[SESSION] Failed to move temp file, cleaning up...");
|
||||
let _ = fs::remove_file(&temp_file);
|
||||
anyhow::anyhow!("Failed to move temporary file to final location: {}", e)
|
||||
file.sync_all().map_err(|e| {
|
||||
tracing::error!("Failed to sync data: {}", e);
|
||||
anyhow::anyhow!("Failed to sync session data")
|
||||
})?;
|
||||
|
||||
println!(
|
||||
"[SESSION] Successfully saved session file: {:?}",
|
||||
session_file
|
||||
);
|
||||
tracing::debug!("Successfully saved session file: {:?}", session_file);
|
||||
// Release the lock
|
||||
fs2::FileExt::unlock(&file).map_err(|e| {
|
||||
tracing::error!("Failed to unlock file: {}", e);
|
||||
anyhow::anyhow!("Failed to unlock session file")
|
||||
})?;
|
||||
|
||||
// Atomically move the temporary file to the final location
|
||||
fs::rename(&temp_file, &secure_path).map_err(|e| {
|
||||
// Clean up temp file on failure
|
||||
tracing::error!("Failed to move temporary file: {}", e);
|
||||
let _ = fs::remove_file(&temp_file);
|
||||
anyhow::anyhow!("Failed to finalize session file")
|
||||
})?;
|
||||
|
||||
tracing::debug!("Successfully saved session file: {:?}", secure_path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1004,21 +1127,47 @@ pub async fn generate_description(
|
||||
///
|
||||
/// This function is called when appropriate to generate a short description
|
||||
/// of the session based on the conversation history.
|
||||
///
|
||||
/// Security features:
|
||||
/// - Validates file paths to prevent directory traversal
|
||||
/// - Limits context size to prevent resource exhaustion
|
||||
/// - Uses secure file operations for saving
|
||||
pub async fn generate_description_with_schedule_id(
|
||||
session_file: &Path,
|
||||
messages: &[Message],
|
||||
provider: Arc<dyn Provider>,
|
||||
schedule_id: Option<String>,
|
||||
) -> Result<()> {
|
||||
// Validate the path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
// Security check: message count limit
|
||||
if messages.len() > MAX_MESSAGE_COUNT {
|
||||
tracing::warn!(
|
||||
"Message count exceeds limit during description generation: {}",
|
||||
messages.len()
|
||||
);
|
||||
return Err(anyhow::anyhow!(
|
||||
"Too many messages for description generation"
|
||||
));
|
||||
}
|
||||
|
||||
// Create a special message asking for a 3-word description
|
||||
let mut description_prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();
|
||||
|
||||
// get context from messages so far, limiting each message to 300 chars
|
||||
// get context from messages so far, limiting each message to 300 chars for security
|
||||
let context: Vec<String> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == mcp_core::role::Role::User)
|
||||
.take(3) // Use up to first 3 user messages for context
|
||||
.map(|m| m.as_concat_text())
|
||||
.map(|m| {
|
||||
let text = m.as_concat_text();
|
||||
if text.len() > 300 {
|
||||
format!("{}...", &text[..300])
|
||||
} else {
|
||||
text
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !context.is_empty() {
|
||||
@@ -1029,7 +1178,7 @@ pub async fn generate_description_with_schedule_id(
|
||||
);
|
||||
}
|
||||
|
||||
// Generate the description
|
||||
// Generate the description with error handling
|
||||
let message = Message::user().with_text(&description_prompt);
|
||||
let result = provider
|
||||
.complete(
|
||||
@@ -1037,30 +1186,49 @@ pub async fn generate_description_with_schedule_id(
|
||||
&[message],
|
||||
&[],
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to generate session description: {}", e);
|
||||
anyhow::anyhow!("Failed to generate session description")
|
||||
})?;
|
||||
|
||||
let description = result.0.as_concat_text();
|
||||
|
||||
// Validate description length for security
|
||||
let sanitized_description = if description.len() > 100 {
|
||||
tracing::warn!("Generated description too long, truncating");
|
||||
format!("{}...", &description[..97])
|
||||
} else {
|
||||
description
|
||||
};
|
||||
|
||||
// Read current metadata
|
||||
let mut metadata = read_metadata(session_file)?;
|
||||
let mut metadata = read_metadata(&secure_path)?;
|
||||
|
||||
// Update description and schedule_id
|
||||
metadata.description = description;
|
||||
metadata.description = sanitized_description;
|
||||
if schedule_id.is_some() {
|
||||
metadata.schedule_id = schedule_id;
|
||||
}
|
||||
|
||||
// Update the file with the new metadata and existing messages
|
||||
save_messages_with_metadata(session_file, &metadata, messages)
|
||||
save_messages_with_metadata(&secure_path, &metadata, messages)
|
||||
}
|
||||
|
||||
/// Update only the metadata in a session file, preserving all messages
|
||||
///
|
||||
/// Security features:
|
||||
/// - Validates file paths to prevent directory traversal
|
||||
/// - Uses secure file operations for reading and writing
|
||||
pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> Result<()> {
|
||||
// Validate the path for security
|
||||
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
|
||||
|
||||
// Read all messages from the file
|
||||
let messages = read_messages(session_file)?;
|
||||
let messages = read_messages(&secure_path)?;
|
||||
|
||||
// Rewrite the file with the new metadata and existing messages
|
||||
save_messages_with_metadata(session_file, metadata, &messages)
|
||||
save_messages_with_metadata(&secure_path, metadata, &messages)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -801,9 +801,19 @@ impl TemporalScheduler {
|
||||
let mut has_active_session = false;
|
||||
|
||||
for (session_name, _) in recent_sessions {
|
||||
let session_path = crate::session::storage::get_path(
|
||||
crate::session::storage::Identifier::Name(session_name),
|
||||
);
|
||||
let session_path = match crate::session::storage::get_path(
|
||||
crate::session::storage::Identifier::Name(session_name.clone()),
|
||||
) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to get session path for '{}': {}",
|
||||
session_name,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if session file was modified recently (within last 5 minutes instead of 2)
|
||||
if let Ok(metadata) = std::fs::metadata(&session_path) {
|
||||
@@ -899,9 +909,23 @@ impl TemporalScheduler {
|
||||
|
||||
if let Some((session_name, _session_metadata)) = recent_sessions.first() {
|
||||
// Check if this session is still active by looking at the session file
|
||||
let session_path = crate::session::storage::get_path(
|
||||
let session_path = match crate::session::storage::get_path(
|
||||
crate::session::storage::Identifier::Name(session_name.clone()),
|
||||
);
|
||||
) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to get session path for '{}': {}",
|
||||
session_name,
|
||||
e
|
||||
);
|
||||
// Fallback: return a temporal session ID with current time
|
||||
let session_id =
|
||||
format!("temporal-{}-{}", sched_id, Utc::now().timestamp());
|
||||
let start_time = Utc::now();
|
||||
return Ok(Some((session_id, start_time)));
|
||||
}
|
||||
};
|
||||
|
||||
// If the session file was modified recently (within last 5 minutes),
|
||||
// consider it as the current running session
|
||||
|
||||
Reference in New Issue
Block a user