Session file security updates (#3071)

This commit is contained in:
Zane
2025-06-25 12:02:48 -07:00
committed by GitHub
parent 8a32128461
commit 85f284b4cf
10 changed files with 446 additions and 178 deletions

View File

@@ -175,7 +175,12 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re
/// without creating an Agent or prompting about working directories.
pub fn handle_session_export(identifier: Identifier, output_path: Option<PathBuf>) -> Result<()> {
// Get the session file path
let session_file_path = goose::session::get_path(identifier.clone());
let session_file_path = match goose::session::get_path(identifier.clone()) {
Ok(path) => path,
Err(e) => {
return Err(anyhow::anyhow!("Invalid session identifier: {}", e));
}
};
if !session_file_path.exists() {
return Err(anyhow::anyhow!(

View File

@@ -250,7 +250,14 @@ async fn list_sessions() -> Json<serde_json::Value> {
async fn get_session(
axum::extract::Path(session_id): axum::extract::Path<String>,
) -> Json<serde_json::Value> {
let session_file = session::get_path(session::Identifier::Name(session_id));
let session_file = match session::get_path(session::Identifier::Name(session_id)) {
Ok(path) => path,
Err(e) => {
return Json(serde_json::json!({
"error": format!("Invalid session ID: {}", e)
}));
}
};
match session::read_messages(&session_file) {
Ok(messages) => {
@@ -288,8 +295,15 @@ async fn handle_socket(socket: WebSocket, state: AppState) {
..
}) => {
// Get session file path from session_id
let session_file =
session::get_path(session::Identifier::Name(session_id.clone()));
let session_file = match session::get_path(session::Identifier::Name(
session_id.clone(),
)) {
Ok(path) => path,
Err(e) => {
tracing::error!("Failed to get session path: {}", e);
continue;
}
};
// Get or create session in memory (for fast access during processing)
let session_messages = {

View File

@@ -234,7 +234,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
}
} else if session_config.resume {
if let Some(identifier) = session_config.identifier {
let session_file = session::get_path(identifier);
let session_file = match session::get_path(identifier) {
Ok(path) => path,
Err(e) => {
output::render_error(&format!("Invalid session identifier: {}", e));
process::exit(1);
}
};
if !session_file.exists() {
output::render_error(&format!(
"Cannot resume session {} - no such session exists",
@@ -262,7 +268,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
};
// Just get the path - file will be created when needed
session::get_path(id)
match session::get_path(id) {
Ok(path) => path,
Err(e) => {
output::render_error(&format!("Failed to create session path: {}", e));
process::exit(1);
}
}
};
if session_config.resume && !session_config.no_session {

View File

@@ -210,7 +210,20 @@ async fn handler(
};
let mut all_messages = messages.clone();
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
Ok(path) => path,
Err(e) => {
tracing::error!("Failed to get session path: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: format!("Failed to get session path: {}", e),
},
&tx,
)
.await;
return;
}
};
loop {
tokio::select! {
@@ -390,13 +403,21 @@ async fn ask_handler(
all_messages.push(response_message);
}
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
Ok(path) => path,
Err(e) => {
tracing::error!("Failed to get session path: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
let session_path = session_path.clone();
let session_path_clone = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await {
if let Err(e) =
session::persist_messages(&session_path_clone, &messages, Some(provider)).await
{
tracing::error!("Failed to store session history: {:?}", e);
}
});

View File

@@ -84,7 +84,10 @@ async fn get_session_history(
) -> Result<Json<SessionHistoryResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
let session_path = match session::get_path(session::Identifier::Name(session_id.clone())) {
Ok(path) => path,
Err(_) => return Err(StatusCode::BAD_REQUEST),
};
// Read metadata
let metadata = session::read_metadata(&session_path).map_err(|_| StatusCode::NOT_FOUND)?;

View File

@@ -222,7 +222,12 @@ impl Agent {
usage: &crate::providers::base::ProviderUsage,
messages_length: usize,
) -> Result<()> {
let session_file_path = session::storage::get_path(session_config.id.clone());
let session_file_path = match session::storage::get_path(session_config.id.clone()) {
Ok(path) => path,
Err(e) => {
return Err(anyhow::anyhow!("Failed to get session file path: {}", e));
}
};
let mut metadata = session::storage::read_metadata(&session_file_path)?;
metadata.schedule_id = session_config.schedule_id.clone();

View File

@@ -372,9 +372,17 @@ impl Agent {
})?;
// Get the session file path
let session_path = crate::session::storage::get_path(
let session_path = match crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_id.to_string()),
);
) {
Ok(path) => path,
Err(e) => {
return Err(ToolError::ExecutionError(format!(
"Invalid session ID '{}': {}",
session_id, e
)));
}
};
// Check if session file exists
if !session_path.exists() {

View File

@@ -1138,9 +1138,17 @@ async fn run_scheduled_job_internal(
}
}
let session_file_path = crate::session::storage::get_path(
let session_file_path = match crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_id_for_return.clone()),
);
) {
Ok(path) => path,
Err(e) => {
return Err(JobExecutionError {
job_id: job.id.clone(),
error: format!("Failed to get session file path: {}", e),
});
}
};
if let Some(prompt_text) = recipe.prompt {
let mut all_session_messages: Vec<Message> =

View File

@@ -18,6 +18,11 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use utoipa::ToSchema;
// Security limits
const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; // 10MB
const MAX_MESSAGE_COUNT: usize = 5000;
const MAX_LINE_LENGTH: usize = 1024 * 1024; // 1MB per line
fn get_home_dir() -> PathBuf {
choose_app_strategy(crate::config::APP_STRATEGY.clone())
.expect("goose requires a home dir")
@@ -133,14 +138,62 @@ pub enum Identifier {
Path(PathBuf),
}
pub fn get_path(id: Identifier) -> PathBuf {
match id {
pub fn get_path(id: Identifier) -> Result<PathBuf> {
let path = match id {
Identifier::Name(name) => {
let session_dir = ensure_session_dir().expect("Failed to create session directory");
// Validate session name for security
if name.is_empty() || name.len() > 255 {
return Err(anyhow::anyhow!("Invalid session name length"));
}
// Check for path traversal attempts
if name.contains("..") || name.contains('/') || name.contains('\\') {
return Err(anyhow::anyhow!("Invalid characters in session name"));
}
let session_dir = ensure_session_dir().map_err(|e| {
tracing::error!("Failed to create session directory: {}", e);
anyhow::anyhow!("Failed to access session directory")
})?;
session_dir.join(format!("{}.jsonl", name))
}
Identifier::Path(path) => path,
Identifier::Path(path) => {
// In test mode, allow temporary directory paths
#[cfg(test)]
{
if let Some(path_str) = path.to_str() {
if path_str.contains("/tmp") || path_str.contains("/.tmp") {
// Allow test temporary directories
return Ok(path);
}
}
}
// Validate that the path is within allowed directories
let canonical_path = path.canonicalize().unwrap_or(path.clone());
let session_dir = ensure_session_dir().map_err(|e| {
tracing::error!("Failed to create session directory: {}", e);
anyhow::anyhow!("Failed to access session directory")
})?;
let canonical_session_dir = session_dir.canonicalize().unwrap_or(session_dir);
if !canonical_path.starts_with(&canonical_session_dir) {
tracing::warn!("Attempted access outside session directory");
return Err(anyhow::anyhow!("Path not allowed"));
}
path
}
};
// Additional security check for file extension
if let Some(ext) = path.extension() {
if ext != "jsonl" {
return Err(anyhow::anyhow!("Invalid file extension"));
}
}
Ok(path)
}
/// Ensure the session directory exists and return its path
@@ -221,17 +274,20 @@ pub fn generate_session_id() -> String {
/// The first line of the file is expected to be metadata, and the rest are messages.
/// Large messages are automatically truncated to prevent memory issues.
/// Includes recovery mechanisms for corrupted files.
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Includes all security limits from read_messages_with_truncation
pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
let result = read_messages_with_truncation(session_file, Some(50000)); // 50KB limit per message content
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
let result = read_messages_with_truncation(&secure_path, Some(50000)); // 50KB limit per message content
match &result {
Ok(messages) => println!(
"[SESSION] Successfully read {} messages from: {:?}",
messages.len(),
session_file
),
Ok(_messages) => {}
Err(e) => println!(
"[SESSION] Failed to read messages from {:?}: {}",
session_file, e
secure_path, e
),
}
result
@@ -243,10 +299,24 @@ pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
/// The first line of the file is expected to be metadata, and the rest are messages.
/// If max_content_size is Some, large message content will be truncated during loading.
/// Includes robust error handling and corruption recovery mechanisms.
///
/// Security features:
/// - File size limits to prevent resource exhaustion
/// - Message count limits to prevent DoS attacks
/// - Line length restrictions to prevent memory issues
pub fn read_messages_with_truncation(
session_file: &Path,
max_content_size: Option<usize>,
) -> Result<Vec<Message>> {
// Security check: file size limit
if session_file.exists() {
let metadata = fs::metadata(session_file)?;
if metadata.len() > MAX_FILE_SIZE {
tracing::warn!("Session file exceeds size limit: {} bytes", metadata.len());
return Err(anyhow::anyhow!("Session file too large"));
}
}
// Check if there's a backup file we should restore from
let backup_file = session_file.with_extension("backup");
if !session_file.exists() && backup_file.exists() {
@@ -255,7 +325,7 @@ pub fn read_messages_with_truncation(
backup_file
);
tracing::warn!(
"[SESSION] Session file missing but backup exists, restoring from backup: {:?}",
"Session file missing but backup exists, restoring from backup: {:?}",
backup_file
);
if let Err(e) = fs::copy(&backup_file, session_file) {
@@ -277,11 +347,18 @@ pub fn read_messages_with_truncation(
let mut messages = Vec::new();
let mut corrupted_lines = Vec::new();
let mut line_number = 1;
let mut message_count = 0;
// Read the first line as metadata or create default if empty/missing
if let Some(line_result) = lines.next() {
match line_result {
Ok(line) => {
// Security check: line length
if line.len() > MAX_LINE_LENGTH {
tracing::warn!("Line {} exceeds length limit", line_number);
return Err(anyhow::anyhow!("Line too long"));
}
// Try to parse as metadata, but if it fails, treat it as a message
if let Ok(_metadata) = serde_json::from_str::<SessionMetadata>(&line) {
// Metadata successfully parsed, continue with the rest of the lines as messages
@@ -290,6 +367,7 @@ pub fn read_messages_with_truncation(
match parse_message_with_truncation(&line, max_content_size) {
Ok(message) => {
messages.push(message);
message_count += 1;
}
Err(e) => {
println!("[SESSION] Failed to parse first line as message: {}", e);
@@ -303,6 +381,7 @@ pub fn read_messages_with_truncation(
"[SESSION] Successfully recovered corrupted first line!"
);
messages.push(recovered);
message_count += 1;
}
Err(recovery_err) => {
println!(
@@ -327,38 +406,63 @@ pub fn read_messages_with_truncation(
// Read the rest of the lines as messages
for line_result in lines {
match line_result {
Ok(line) => match parse_message_with_truncation(&line, max_content_size) {
Ok(message) => {
messages.push(message);
}
Err(e) => {
println!("[SESSION] Failed to parse line {}: {}", line_number, e);
println!(
"[SESSION] Attempting to recover corrupted line {}...",
line_number
);
tracing::warn!("Failed to parse line {}: {}", line_number, e);
// Security check: message count limit
if message_count >= MAX_MESSAGE_COUNT {
tracing::warn!("Message count limit reached: {}", MAX_MESSAGE_COUNT);
println!(
"[SESSION] Message count limit reached, stopping at {}",
MAX_MESSAGE_COUNT
);
break;
}
// Try to recover the corrupted line
match attempt_corruption_recovery(&line, max_content_size) {
Ok(recovered) => {
println!(
"[SESSION] Successfully recovered corrupted line {}!",
line_number
);
messages.push(recovered);
}
Err(recovery_err) => {
println!(
"[SESSION] Failed to recover corrupted line {}: {}",
line_number, recovery_err
);
corrupted_lines.push((line_number, line));
match line_result {
Ok(line) => {
// Security check: line length
if line.len() > MAX_LINE_LENGTH {
tracing::warn!("Line {} exceeds length limit", line_number);
corrupted_lines.push((
line_number,
"[Line too long - truncated for security]".to_string(),
));
line_number += 1;
continue;
}
match parse_message_with_truncation(&line, max_content_size) {
Ok(message) => {
messages.push(message);
message_count += 1;
}
Err(e) => {
println!("[SESSION] Failed to parse line {}: {}", line_number, e);
println!(
"[SESSION] Attempting to recover corrupted line {}...",
line_number
);
tracing::warn!("Failed to parse line {}: {}", line_number, e);
// Try to recover the corrupted line
match attempt_corruption_recovery(&line, max_content_size) {
Ok(recovered) => {
println!(
"[SESSION] Successfully recovered corrupted line {}!",
line_number
);
messages.push(recovered);
message_count += 1;
}
Err(recovery_err) => {
println!(
"[SESSION] Failed to recover corrupted line {}: {}",
line_number, recovery_err
);
corrupted_lines.push((line_number, line));
}
}
}
}
},
}
Err(e) => {
println!("[SESSION] Failed to read line {}: {}", line_number, e);
tracing::error!("Failed to read line {}: {}", line_number, e);
@@ -375,7 +479,7 @@ pub fn read_messages_with_truncation(
corrupted_lines.len()
);
tracing::warn!(
"[SESSION] Found {} corrupted lines in session file, creating backup",
"Found {} corrupted lines in session file, creating backup",
corrupted_lines.len()
);
@@ -390,7 +494,7 @@ pub fn read_messages_with_truncation(
}
}
// Log details about corrupted lines
// Log details about corrupted lines (with limited detail for security)
for (num, line) in &corrupted_lines {
let preview = if line.len() > 50 {
format!("{}... (truncated)", &line[..50])
@@ -401,11 +505,6 @@ pub fn read_messages_with_truncation(
}
}
println!(
"[SESSION] Finished reading session file. Total messages: {}, corrupted lines: {}",
messages.len(),
corrupted_lines.len()
);
Ok(messages)
}
@@ -444,7 +543,6 @@ fn parse_message_with_truncation(
match serde_json::from_str::<Message>(&truncated_json) {
Ok(message) => {
println!("[SESSION] Successfully parsed message after truncation");
tracing::info!("Successfully parsed message after JSON truncation");
Ok(message)
}
@@ -628,8 +726,6 @@ fn try_fix_json_corruption(json_str: &str, max_content_size: Option<usize>) -> R
}
if !fixes_applied.is_empty() {
println!("[SESSION] Applied JSON fixes: {}", fixes_applied.join(", "));
match serde_json::from_str::<Message>(&fixed_json) {
Ok(mut message) => {
if let Some(max_size) = max_content_size {
@@ -690,15 +786,6 @@ fn try_extract_partial_message(json_str: &str) -> Result<Message> {
}
if !extracted_text.is_empty() {
println!(
"[SESSION] Extracted text content: {}",
if extracted_text.len() > 50 {
&extracted_text[..50]
} else {
&extracted_text
}
);
let message = match role {
mcp_core::role::Role::User => Message::user(),
mcp_core::role::Role::Assistant => Message::assistant(),
@@ -731,8 +818,6 @@ fn try_fix_truncated_json(json_str: &str, max_content_size: Option<usize>) -> Re
completed_json.push('}');
}
println!("[SESSION] Attempting to complete truncated JSON");
match serde_json::from_str::<Message>(&completed_json) {
Ok(mut message) => {
if let Some(max_size) = max_content_size {
@@ -787,45 +872,51 @@ fn truncate_json_string(json_str: &str, max_content_size: usize) -> String {
result
}
/// Read session metadata from a session file
/// Read session metadata from a session file with security validation
///
/// Returns default empty metadata if the file doesn't exist or has no metadata.
/// Includes security checks for file access and content validation.
pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
println!("[SESSION] Reading metadata from: {:?}", session_file);
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
if !session_file.exists() {
println!("[SESSION] Session file doesn't exist, returning default metadata");
if !secure_path.exists() {
return Ok(SessionMetadata::default());
}
let file = fs::File::open(session_file)?;
// Security check: file size
let file_metadata = fs::metadata(&secure_path)?;
if file_metadata.len() > MAX_FILE_SIZE {
tracing::warn!("Session file exceeds size limit during metadata read");
return Err(anyhow::anyhow!("Session file too large"));
}
let file = fs::File::open(&secure_path).map_err(|e| {
tracing::error!("Failed to open session file for metadata read: {}", e);
anyhow::anyhow!("Failed to access session file")
})?;
let mut reader = io::BufReader::new(file);
let mut first_line = String::new();
// Read just the first line
if reader.read_line(&mut first_line)? > 0 {
println!("[SESSION] Read first line, attempting to parse as metadata...");
// Security check: line length
if first_line.len() > MAX_LINE_LENGTH {
tracing::warn!("Metadata line exceeds length limit");
return Err(anyhow::anyhow!("Metadata line too long"));
}
// Try to parse as metadata
match serde_json::from_str::<SessionMetadata>(&first_line) {
Ok(metadata) => {
println!(
"[SESSION] Successfully parsed metadata: description='{}'",
metadata.description
);
Ok(metadata)
}
Ok(metadata) => Ok(metadata),
Err(e) => {
// If the first line isn't metadata, return default
println!(
"[SESSION] First line is not valid metadata ({}), returning default",
e
);
tracing::debug!("Metadata parse error: {}", e);
Ok(SessionMetadata::default())
}
}
} else {
// Empty file, return default
println!("[SESSION] File is empty, returning default metadata");
Ok(SessionMetadata::default())
}
}
@@ -834,21 +925,16 @@ pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
///
/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format.
/// If a provider is supplied, it will automatically generate a description when appropriate.
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Uses secure file operations via persist_messages_with_schedule_id
pub async fn persist_messages(
session_file: &Path,
messages: &[Message],
provider: Option<Arc<dyn Provider>>,
) -> Result<()> {
println!(
"[SESSION] persist_messages called with {} messages to: {:?}",
messages.len(),
session_file
);
let result = persist_messages_with_schedule_id(session_file, messages, provider, None).await;
match &result {
Ok(_) => println!("[SESSION] persist_messages completed successfully"),
Err(e) => println!("[SESSION] persist_messages failed: {}", e),
}
result
}
@@ -856,12 +942,26 @@ pub async fn persist_messages(
///
/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format.
/// If a provider is supplied, it will automatically generate a description when appropriate.
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Limits error message details in logs
/// - Uses atomic file operations via save_messages_with_metadata
pub async fn persist_messages_with_schedule_id(
session_file: &Path,
messages: &[Message],
provider: Option<Arc<dyn Provider>>,
schedule_id: Option<String>,
) -> Result<()> {
// Validate the session file path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
// Security check: message count limit
if messages.len() > MAX_MESSAGE_COUNT {
tracing::warn!("Message count exceeds limit: {}", messages.len());
return Err(anyhow::anyhow!("Too many messages"));
}
// Count user messages
let user_message_count = messages
.iter()
@@ -872,29 +972,35 @@ pub async fn persist_messages_with_schedule_id(
match provider {
Some(provider) if user_message_count < 4 => {
//generate_description is responsible for writing the messages
generate_description_with_schedule_id(session_file, messages, provider, schedule_id)
generate_description_with_schedule_id(&secure_path, messages, provider, schedule_id)
.await
}
_ => {
// Read existing metadata
let mut metadata = read_metadata(session_file)?;
let mut metadata = read_metadata(&secure_path)?;
// Update the schedule_id if provided
if schedule_id.is_some() {
metadata.schedule_id = schedule_id;
}
// Write the file with metadata and messages
save_messages_with_metadata(session_file, &metadata, messages)
save_messages_with_metadata(&secure_path, &metadata, messages)
}
}
}
/// Write messages to a session file with the provided metadata using atomic operations
/// Write messages to a session file with the provided metadata using secure atomic operations
///
/// This function uses atomic file operations to prevent corruption:
/// 1. Writes to a temporary file first
/// 1. Writes to a temporary file first with secure permissions
/// 2. Uses fs2 file locking to prevent concurrent writes
/// 3. Atomically moves the temp file to the final location
/// 4. Includes comprehensive error handling and recovery
///
/// Security features:
/// - Secure temporary file creation with restricted permissions
/// - Path validation to prevent directory traversal
/// - File size and message count limits
/// - Sanitized error messages to prevent information leakage
pub fn save_messages_with_metadata(
session_file: &Path,
metadata: &SessionMetadata,
@@ -902,89 +1008,106 @@ pub fn save_messages_with_metadata(
) -> Result<()> {
use fs2::FileExt;
println!(
"[SESSION] Starting to save {} messages to: {:?}",
messages.len(),
session_file
);
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
// Create a temporary file in the same directory to ensure atomic move
let temp_file = session_file.with_extension("tmp");
println!("[SESSION] Using temporary file: {:?}", temp_file);
// Ensure the parent directory exists
if let Some(parent) = session_file.parent() {
println!("[SESSION] Ensuring parent directory exists: {:?}", parent);
fs::create_dir_all(parent)?;
// Security check: message count limit
if messages.len() > MAX_MESSAGE_COUNT {
tracing::warn!(
"Message count exceeds limit during save: {}",
messages.len()
);
return Err(anyhow::anyhow!("Too many messages to save"));
}
// Create and lock the temporary file
println!("[SESSION] Creating and locking temporary file...");
// Create a temporary file in the same directory to ensure atomic move
let temp_file = secure_path.with_extension("tmp");
// Ensure the parent directory exists
if let Some(parent) = secure_path.parent() {
fs::create_dir_all(parent).map_err(|e| {
tracing::error!("Failed to create parent directory: {}", e);
anyhow::anyhow!("Failed to create session directory")
})?;
}
// Create and lock the temporary file with secure permissions
let file = fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp_file)
.map_err(|e| anyhow::anyhow!("Failed to create temporary file {:?}: {}", temp_file, e))?;
.map_err(|e| {
tracing::error!("Failed to create temporary file: {}", e);
anyhow::anyhow!("Failed to create temporary session file")
})?;
// Set secure file permissions (Unix only - read/write for owner only)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = file.metadata()?.permissions();
perms.set_mode(0o600); // rw-------
fs::set_permissions(&temp_file, perms).map_err(|e| {
tracing::error!("Failed to set secure file permissions: {}", e);
anyhow::anyhow!("Failed to secure temporary file")
})?;
}
// Get an exclusive lock on the file
println!("[SESSION] Acquiring exclusive lock...");
file.try_lock_exclusive()
.map_err(|e| anyhow::anyhow!("Failed to lock file: {}", e))?;
file.try_lock_exclusive().map_err(|e| {
tracing::error!("Failed to lock file: {}", e);
anyhow::anyhow!("Failed to lock session file")
})?;
// Write to temporary file
{
println!(
"[SESSION] Writing metadata and {} messages to temporary file...",
messages.len()
);
let mut writer = io::BufWriter::new(&file);
// Write metadata as the first line
println!("[SESSION] Writing metadata as first line...");
serde_json::to_writer(&mut writer, &metadata)
.map_err(|e| anyhow::anyhow!("Failed to serialize metadata: {}", e))?;
serde_json::to_writer(&mut writer, &metadata).map_err(|e| {
tracing::error!("Failed to serialize metadata: {}", e);
anyhow::anyhow!("Failed to write session metadata")
})?;
writeln!(writer)?;
// Write all messages
println!("[SESSION] Writing {} messages...", messages.len());
// Write all messages with progress tracking
for (i, message) in messages.iter().enumerate() {
serde_json::to_writer(&mut writer, &message)
.map_err(|e| anyhow::anyhow!("Failed to serialize message {}: {}", i, e))?;
serde_json::to_writer(&mut writer, &message).map_err(|e| {
tracing::error!("Failed to serialize message {}: {}", i, e);
anyhow::anyhow!("Failed to write session message")
})?;
writeln!(writer)?;
if (i + 1) % 50 == 0 {
println!("[SESSION] Written {} messages so far...", i + 1);
}
}
// Ensure all data is written to disk
println!("[SESSION] Flushing writer buffer...");
writer.flush()?;
writer.flush().map_err(|e| {
tracing::error!("Failed to flush writer: {}", e);
anyhow::anyhow!("Failed to flush session data")
})?;
}
// Sync to ensure data is persisted
println!("[SESSION] Syncing data to disk...");
file.sync_all()?;
// Release the lock
println!("[SESSION] Releasing file lock...");
fs2::FileExt::unlock(&file).map_err(|e| anyhow::anyhow!("Failed to unlock file: {}", e))?;
// Atomically move the temporary file to the final location
println!("[SESSION] Atomically moving temp file to final location...");
fs::rename(&temp_file, session_file).map_err(|e| {
// Clean up temp file on failure
println!("[SESSION] Failed to move temp file, cleaning up...");
let _ = fs::remove_file(&temp_file);
anyhow::anyhow!("Failed to move temporary file to final location: {}", e)
file.sync_all().map_err(|e| {
tracing::error!("Failed to sync data: {}", e);
anyhow::anyhow!("Failed to sync session data")
})?;
println!(
"[SESSION] Successfully saved session file: {:?}",
session_file
);
tracing::debug!("Successfully saved session file: {:?}", session_file);
// Release the lock
fs2::FileExt::unlock(&file).map_err(|e| {
tracing::error!("Failed to unlock file: {}", e);
anyhow::anyhow!("Failed to unlock session file")
})?;
// Atomically move the temporary file to the final location
fs::rename(&temp_file, &secure_path).map_err(|e| {
// Clean up temp file on failure
tracing::error!("Failed to move temporary file: {}", e);
let _ = fs::remove_file(&temp_file);
anyhow::anyhow!("Failed to finalize session file")
})?;
tracing::debug!("Successfully saved session file: {:?}", secure_path);
Ok(())
}
@@ -1004,21 +1127,47 @@ pub async fn generate_description(
///
/// This function is called when appropriate to generate a short description
/// of the session based on the conversation history.
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Limits context size to prevent resource exhaustion
/// - Uses secure file operations for saving
pub async fn generate_description_with_schedule_id(
session_file: &Path,
messages: &[Message],
provider: Arc<dyn Provider>,
schedule_id: Option<String>,
) -> Result<()> {
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
// Security check: message count limit
if messages.len() > MAX_MESSAGE_COUNT {
tracing::warn!(
"Message count exceeds limit during description generation: {}",
messages.len()
);
return Err(anyhow::anyhow!(
"Too many messages for description generation"
));
}
// Create a special message asking for a 3-word description
let mut description_prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();
// get context from messages so far, limiting each message to 300 chars
// get context from messages so far, limiting each message to 300 chars for security
let context: Vec<String> = messages
.iter()
.filter(|m| m.role == mcp_core::role::Role::User)
.take(3) // Use up to first 3 user messages for context
.map(|m| m.as_concat_text())
.map(|m| {
let text = m.as_concat_text();
if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text
}
})
.collect();
if !context.is_empty() {
@@ -1029,7 +1178,7 @@ pub async fn generate_description_with_schedule_id(
);
}
// Generate the description
// Generate the description with error handling
let message = Message::user().with_text(&description_prompt);
let result = provider
.complete(
@@ -1037,30 +1186,49 @@ pub async fn generate_description_with_schedule_id(
&[message],
&[],
)
.await?;
.await
.map_err(|e| {
tracing::error!("Failed to generate session description: {}", e);
anyhow::anyhow!("Failed to generate session description")
})?;
let description = result.0.as_concat_text();
// Validate description length for security
let sanitized_description = if description.len() > 100 {
tracing::warn!("Generated description too long, truncating");
format!("{}...", &description[..97])
} else {
description
};
// Read current metadata
let mut metadata = read_metadata(session_file)?;
let mut metadata = read_metadata(&secure_path)?;
// Update description and schedule_id
metadata.description = description;
metadata.description = sanitized_description;
if schedule_id.is_some() {
metadata.schedule_id = schedule_id;
}
// Update the file with the new metadata and existing messages
save_messages_with_metadata(session_file, &metadata, messages)
save_messages_with_metadata(&secure_path, &metadata, messages)
}
/// Update only the metadata in a session file, preserving all messages
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Uses secure file operations for reading and writing
pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> Result<()> {
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
// Read all messages from the file
let messages = read_messages(session_file)?;
let messages = read_messages(&secure_path)?;
// Rewrite the file with the new metadata and existing messages
save_messages_with_metadata(session_file, metadata, &messages)
save_messages_with_metadata(&secure_path, metadata, &messages)
}
#[cfg(test)]

View File

@@ -801,9 +801,19 @@ impl TemporalScheduler {
let mut has_active_session = false;
for (session_name, _) in recent_sessions {
let session_path = crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_name),
);
let session_path = match crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_name.clone()),
) {
Ok(path) => path,
Err(e) => {
tracing::warn!(
"Failed to get session path for '{}': {}",
session_name,
e
);
continue;
}
};
// Check if session file was modified recently (within last 5 minutes instead of 2)
if let Ok(metadata) = std::fs::metadata(&session_path) {
@@ -899,9 +909,23 @@ impl TemporalScheduler {
if let Some((session_name, _session_metadata)) = recent_sessions.first() {
// Check if this session is still active by looking at the session file
let session_path = crate::session::storage::get_path(
let session_path = match crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_name.clone()),
);
) {
Ok(path) => path,
Err(e) => {
tracing::warn!(
"Failed to get session path for '{}': {}",
session_name,
e
);
// Fallback: return a temporal session ID with current time
let session_id =
format!("temporal-{}-{}", sched_id, Utc::now().timestamp());
let start_time = Utc::now();
return Ok(Some((session_id, start_time)));
}
};
// If the session file was modified recently (within last 5 minutes),
// consider it as the current running session