fix: working dir was not being set correctly (#3477)

This commit is contained in:
Michael Neale
2025-07-18 12:59:19 +10:00
committed by GitHub
parent 8a6329f8ca
commit a1463e6674
4 changed files with 261 additions and 23 deletions

View File

@@ -475,7 +475,14 @@ async fn process_message_streaming(
}
let provider = provider.unwrap();
session::persist_messages(&session_file, &messages, Some(provider.clone())).await?;
let working_dir = Some(std::env::current_dir()?);
session::persist_messages(
&session_file,
&messages,
Some(provider.clone()),
working_dir.clone(),
)
.await?;
// Create a session config
let session_config = SessionConfig {
@@ -503,7 +510,13 @@ async fn process_message_streaming(
let session_msgs = session_messages.lock().await;
session_msgs.clone()
};
session::persist_messages(&session_file, &current_messages, None).await?;
session::persist_messages(
&session_file,
&current_messages,
None,
working_dir.clone(),
)
.await?;
// Handle different message content types
for content in &message.content {
match content {

View File

@@ -370,11 +370,16 @@ impl Session {
// Persist messages with provider for automatic description generation
if let Some(session_file) = &self.session_file {
let working_dir = Some(
std::env::current_dir().expect("failed to get current session working directory"),
);
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
Some(provider),
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -492,11 +497,17 @@ impl Session {
// Persist messages with provider for automatic description generation
if let Some(session_file) = &self.session_file {
let working_dir = Some(
std::env::current_dir()
.expect("failed to get current session working directory"),
);
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
Some(provider),
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -708,11 +719,13 @@ impl Session {
// Persist the summarized messages
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
Some(provider),
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -892,11 +905,13 @@ impl Session {
));
push_message(&mut self.messages, response_message);
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
None,
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -991,11 +1006,13 @@ impl Session {
// No need to update description on assistant messages
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
None,
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -1208,11 +1225,13 @@ impl Session {
// No need for description update here
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
None,
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -1225,11 +1244,13 @@ impl Session {
// No need for description update here
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
None,
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}
@@ -1247,11 +1268,13 @@ impl Session {
// No need for description update here
if let Some(session_file) = &self.session_file {
let working_dir = std::env::current_dir().ok();
session::persist_messages_with_schedule_id(
session_file,
&self.messages,
None,
self.scheduled_job_id.clone(),
working_dir,
)
.await?;
}

View File

@@ -123,7 +123,7 @@ async fn handler(
let stream = ReceiverStream::new(rx);
let messages = request.messages;
let session_working_dir = request.session_working_dir;
let session_working_dir = request.session_working_dir.clone();
let session_id = request
.session_id
@@ -181,7 +181,7 @@ async fn handler(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
@@ -297,8 +297,13 @@ async fn handler(
if all_messages.len() > saved_message_count {
let provider = Arc::clone(provider.as_ref().unwrap());
tokio::spawn(async move {
if let Err(e) =
session::persist_messages(&session_path, &all_messages, Some(provider)).await
if let Err(e) = session::persist_messages(
&session_path,
&all_messages,
Some(provider),
Some(PathBuf::from(&session_working_dir)),
)
.await
{
tracing::error!("Failed to store session history: {:?}", e);
}
@@ -337,7 +342,7 @@ async fn ask_handler(
) -> Result<Json<AskResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let session_working_dir = request.session_working_dir;
let session_working_dir = request.session_working_dir.clone();
let session_id = request
.session_id
@@ -358,7 +363,7 @@ async fn ask_handler(
&messages,
Some(SessionConfig {
id: session::Identifier::Name(session_id.clone()),
working_dir: PathBuf::from(session_working_dir),
working_dir: PathBuf::from(&session_working_dir),
schedule_id: request.scheduled_job_id.clone(),
execution_mode: None,
max_turns: None,
@@ -420,9 +425,15 @@ async fn ask_handler(
let session_path_clone = session_path.clone();
let messages = all_messages.clone();
let provider = Arc::clone(provider.as_ref().unwrap());
let session_working_dir_clone = session_working_dir.clone();
tokio::spawn(async move {
if let Err(e) =
session::persist_messages(&session_path_clone, &messages, Some(provider)).await
if let Err(e) = session::persist_messages(
&session_path_clone,
&messages,
Some(provider),
Some(PathBuf::from(session_working_dir_clone)),
)
.await
{
tracing::error!("Failed to store session history: {:?}", e);
}

View File

@@ -1042,13 +1042,13 @@ pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
///
/// Security features:
/// - Validates file paths to prevent directory traversal
/// - Uses secure file operations via persist_messages_with_schedule_id
pub async fn persist_messages(
session_file: &Path,
messages: &[Message],
provider: Option<Arc<dyn Provider>>,
working_dir: Option<PathBuf>,
) -> Result<()> {
persist_messages_with_schedule_id(session_file, messages, provider, None).await
persist_messages_with_schedule_id(session_file, messages, provider, None, working_dir).await
}
/// Write messages to a session file with metadata, including an optional scheduled job ID
@@ -1065,6 +1065,7 @@ pub async fn persist_messages_with_schedule_id(
messages: &[Message],
provider: Option<Arc<dyn Provider>>,
schedule_id: Option<String>,
working_dir: Option<PathBuf>,
) -> Result<()> {
// Validate the session file path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
@@ -1085,16 +1086,35 @@ pub async fn persist_messages_with_schedule_id(
match provider {
Some(provider) if user_message_count < 4 => {
//generate_description is responsible for writing the messages
generate_description_with_schedule_id(&secure_path, messages, provider, schedule_id)
.await
generate_description_with_schedule_id(
&secure_path,
messages,
provider,
schedule_id,
working_dir,
)
.await
}
_ => {
// Read existing metadata
let mut metadata = read_metadata(&secure_path)?;
// Read existing metadata or create new with proper working_dir
let mut metadata = if secure_path.exists() {
read_metadata(&secure_path)?
} else {
// Create new metadata with the provided working_dir or fall back to home
let work_dir = working_dir.clone().unwrap_or_else(get_home_dir);
SessionMetadata::new(work_dir)
};
// Update the working_dir if provided (even for existing files)
if let Some(work_dir) = working_dir {
metadata.working_dir = work_dir;
}
// Update the schedule_id if provided
if schedule_id.is_some() {
metadata.schedule_id = schedule_id;
}
// Write the file with metadata and messages
save_messages_with_metadata(&secure_path, &metadata, messages)
}
@@ -1232,11 +1252,12 @@ pub async fn generate_description(
session_file: &Path,
messages: &[Message],
provider: Arc<dyn Provider>,
working_dir: Option<PathBuf>,
) -> Result<()> {
generate_description_with_schedule_id(session_file, messages, provider, None).await
generate_description_with_schedule_id(session_file, messages, provider, None, working_dir).await
}
/// Generate a description for the session using the provider, including an optional scheduled job ID
/// Generate a description for the session using the provider, including an optional scheduled job ID and working directory
///
/// This function is called when appropriate to generate a short description
/// of the session based on the conversation history.
@@ -1250,6 +1271,7 @@ pub async fn generate_description_with_schedule_id(
messages: &[Message],
provider: Arc<dyn Provider>,
schedule_id: Option<String>,
working_dir: Option<PathBuf>,
) -> Result<()> {
// Validate the path for security
let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?;
@@ -1311,7 +1333,14 @@ pub async fn generate_description_with_schedule_id(
description
};
let mut metadata = read_metadata(&secure_path)?;
// Create metadata with proper working_dir or read existing and update
let mut metadata = if secure_path.exists() {
read_metadata(&secure_path)?
} else {
// Create new metadata with the provided working_dir or fall back to home
let work_dir = working_dir.clone().unwrap_or_else(get_home_dir);
SessionMetadata::new(work_dir)
};
// Update description and schedule_id
metadata.description = sanitized_description;
@@ -1319,6 +1348,11 @@ pub async fn generate_description_with_schedule_id(
metadata.schedule_id = schedule_id;
}
// Update the working_dir if provided (even for existing files)
if let Some(work_dir) = working_dir {
metadata.working_dir = work_dir;
}
// Update the file with the new metadata and existing messages
save_messages_with_metadata(&secure_path, &metadata, messages)
}
@@ -1430,7 +1464,7 @@ mod tests {
];
// Write messages
persist_messages(&file_path, &messages, None).await?;
persist_messages(&file_path, &messages, None, None).await?;
// Read them back
let read_messages = read_messages(&file_path)?;
@@ -1538,7 +1572,7 @@ mod tests {
}
// Write messages with special characters
persist_messages(&file_path, &messages, None).await?;
persist_messages(&file_path, &messages, None, None).await?;
// Read them back
let read_messages = read_messages(&file_path)?;
@@ -1603,7 +1637,7 @@ mod tests {
];
// Write messages
persist_messages(&file_path, &messages, None).await?;
persist_messages(&file_path, &messages, None, None).await?;
// Read them back - should be truncated
let read_messages = read_messages(&file_path)?;
@@ -1694,6 +1728,162 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_working_dir_preservation() -> Result<()> {
let dir = tempdir()?;
let file_path = dir.path().join("test.jsonl");
// Create a temporary working directory
let working_dir = tempdir()?;
let working_dir_path = working_dir.path().to_path_buf();
// Create messages
let messages = vec![Message::user().with_text("test message")];
// Use persist_messages_with_schedule_id to set working dir
persist_messages_with_schedule_id(
&file_path,
&messages,
None,
None,
Some(working_dir_path.clone()),
)
.await?;
// Read back the metadata and verify working_dir is preserved
let metadata = read_metadata(&file_path)?;
assert_eq!(metadata.working_dir, working_dir_path);
// Verify the messages are also preserved
let read_messages = read_messages(&file_path)?;
assert_eq!(read_messages.len(), 1);
assert_eq!(read_messages[0].role, messages[0].role);
Ok(())
}
#[tokio::test]
async fn test_working_dir_issue_fixed() -> Result<()> {
// This test demonstrates that the working_dir issue in jsonl files is fixed
let dir = tempdir()?;
let file_path = dir.path().join("test.jsonl");
// Create a temporary working directory (this simulates the actual working directory)
let working_dir = tempdir()?;
let working_dir_path = working_dir.path().to_path_buf();
// Create messages
let messages = vec![Message::user().with_text("test message")];
// Get the home directory for comparison
let home_dir = get_home_dir();
// Test 1: Using the old persist_messages function (without working_dir)
// This will fall back to home directory since no working_dir is provided
persist_messages(&file_path, &messages, None, None).await?;
// Read back the metadata - this should now have the home directory as working_dir
let metadata_old = read_metadata(&file_path)?;
assert_eq!(
metadata_old.working_dir, home_dir,
"persist_messages should use home directory when no working_dir is provided"
);
// Test 2: Using persist_messages_with_schedule_id function
// This should properly set the working_dir (this is the main fix)
persist_messages_with_schedule_id(
&file_path,
&messages,
None,
None,
Some(working_dir_path.clone()),
)
.await?;
// Read back the metadata - this should now have the correct working_dir
let metadata_new = read_metadata(&file_path)?;
assert_eq!(
metadata_new.working_dir, working_dir_path,
"persist_messages_with_schedule_id should use provided working_dir"
);
assert_ne!(
metadata_new.working_dir, home_dir,
"working_dir should be different from home directory"
);
// Test 3: Create a new session file without working_dir (should fall back to home)
let file_path_2 = dir.path().join("test2.jsonl");
persist_messages_with_schedule_id(
&file_path_2,
&messages,
None,
None,
None, // No working_dir provided
)
.await?;
let metadata_fallback = read_metadata(&file_path_2)?;
assert_eq!(metadata_fallback.working_dir, home_dir, "persist_messages_with_schedule_id should fall back to home directory when no working_dir is provided");
// Test 4: Test that the fix works for existing files
// Create a session file and then add to it with different working_dir
let file_path_3 = dir.path().join("test3.jsonl");
// First, create with home directory
persist_messages(&file_path_3, &messages, None, None).await?;
let metadata_initial = read_metadata(&file_path_3)?;
assert_eq!(
metadata_initial.working_dir, home_dir,
"Initial session should use home directory"
);
// Then update with a specific working_dir
persist_messages_with_schedule_id(
&file_path_3,
&messages,
None,
None,
Some(working_dir_path.clone()),
)
.await?;
let metadata_updated = read_metadata(&file_path_3)?;
assert_eq!(
metadata_updated.working_dir, working_dir_path,
"Updated session should use new working_dir"
);
// Test 5: Most important test - simulate the real-world scenario where
// CLI and web interfaces pass the current directory instead of None
let file_path_4 = dir.path().join("test4.jsonl");
let current_dir = std::env::current_dir()?;
// This is what web.rs and session/mod.rs do now after the fix
persist_messages_with_schedule_id(
&file_path_4,
&messages,
None,
None,
Some(current_dir.clone()),
)
.await?;
let metadata_current = read_metadata(&file_path_4)?;
assert_eq!(
metadata_current.working_dir, current_dir,
"Session should use current directory when explicitly provided"
);
// This should NOT be the home directory anymore (unless current_dir == home_dir)
if current_dir != home_dir {
assert_ne!(
metadata_current.working_dir, home_dir,
"working_dir should be different from home directory when current_dir is different"
);
}
Ok(())
}
#[test]
fn test_windows_path_validation() -> Result<()> {
// Test the Windows path validation logic
@@ -1781,12 +1971,13 @@ mod tests {
Message::assistant().with_text("Test response"),
];
// Test persist_messages_with_schedule_id with save_session = true
// Test persist_messages_with_schedule_id with working_dir parameter
persist_messages_with_schedule_id(
&file_path,
&messages,
None,
Some("test_schedule".to_string()),
None,
)
.await?;