diff --git a/Cargo.lock b/Cargo.lock index d317aa2..61968b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2468,6 +2468,7 @@ dependencies = [ "hostname-validator", "httpdate", "page_size", + "percent-encoding", "pkarr", "pkarr-republisher", "postcard", @@ -2479,6 +2480,7 @@ dependencies = [ "tempfile", "thiserror 2.0.12", "tokio", + "tokio-util", "toml", "tower 0.5.2", "tower-cookies", @@ -3557,9 +3559,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", diff --git a/e2e/src/tests/auth.rs b/e2e/src/tests/auth.rs index 1e9ee00..88e05ce 100644 --- a/e2e/src/tests/auth.rs +++ b/e2e/src/tests/auth.rs @@ -80,7 +80,11 @@ async fn disabled_user() { // Make sure the user can read their own file let response = client.get(file_url.clone()).send().await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.status(), + StatusCode::OK, + "User should be able to read their own file" + ); let admin_socket = server.admin().listen_socket(); let admin_client = reqwest::Client::new(); diff --git a/e2e/src/tests/rate_limiting.rs b/e2e/src/tests/rate_limiting.rs index 0df8689..116281b 100644 --- a/e2e/src/tests/rate_limiting.rs +++ b/e2e/src/tests/rate_limiting.rs @@ -1,7 +1,8 @@ use std::time::Duration; -use pkarr::Keypair; +use pkarr::{Keypair, PublicKey}; use pubky_testnet::{ + pubky::Client, pubky_homeserver::{ quota_config::{GlobPattern, LimitKey, LimitKeyType, PathLimit}, ConfigToml, MockDataDir, @@ -186,3 +187,127 @@ async fn test_limit_upload() { assert_eq!(res.status(), StatusCode::CREATED); assert!(start.elapsed() > Duration::from_secs(2)); } + +/// Test that 10 clients can write/read to the server concurrently +/// Upload/download rate is limited to 1kb/s per user. +/// 3kb files are used to make the writes/reads take ~2.5s each. +/// Concurrently writing/reading 10 files, the total time taken should be ~3s. +/// If the concurrent writes/reads are not properly handled, the total time taken will be closer to ~25s. 
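// Editor's note (illustrative arithmetic, not part of the patch): each transfer moves
// 3 * 1024 bytes against the per-user "1kb/s" limit configured below, i.e. roughly 2.5-3 s.
// Run concurrently, the 10 uploads (and later the 10 downloads) should finish in about the
// time of a single transfer, which is what the `elapsed < Duration::from_secs(5)` assertions
// check; if the server handled them serially, the expected wall-clock time would be ~25-30 s.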
+#[tokio::test]
+async fn test_concurrent_write_read() {
+    // Setup the testnet
+    let mut testnet = Testnet::new().await.unwrap();
+    let mut config = ConfigToml::test();
+    config.drive.rate_limits = vec![
+        PathLimit::new(
+            // Limit uploads to 1kb/s per user
+            GlobPattern::new("/pub/**"),
+            Method::PUT,
+            "1kb/s".parse().unwrap(),
+            LimitKeyType::User,
+            None,
+        ),
+        PathLimit::new(
+            // Limit downloads to 1kb/s per user
+            GlobPattern::new("/pub/**"),
+            Method::GET,
+            "1kb/s".parse().unwrap(),
+            LimitKeyType::User,
+            None,
+        ),
+    ];
+    let mock_dir = MockDataDir::new(config, None).unwrap();
+    let hs_pubkey = {
+        let server = testnet
+            .create_homeserver_suite_with_mock(mock_dir)
+            .await
+            .unwrap();
+        server.public_key()
+    };
+
+    // Create helper struct to handle clients
+    #[derive(Clone)]
+    struct TestClient {
+        pub keypair: Keypair,
+        pub client: Client,
+    }
+    impl TestClient {
+        fn new(testnet: &mut Testnet) -> Self {
+            let keypair = Keypair::random();
+            let client = testnet.pubky_client_builder().build().unwrap();
+            Self { keypair, client }
+        }
+        pub async fn signup(&self, hs_pubkey: &PublicKey) {
+            self.client
+                .signup(&self.keypair, &hs_pubkey, None)
+                .await
+                .expect("Failed to signup");
+        }
+        pub async fn put(&self, url: Url, body: Vec<u8>) {
+            self.client
+                .put(url)
+                .body(body)
+                .send()
+                .await
+                .expect("Failed to put");
+        }
+        pub async fn get(&self, url: Url) {
+            let response = self.client.get(url).send().await.expect("Failed to get");
+            assert_eq!(response.status(), StatusCode::OK, "Failed to get");
+            response.bytes().await.expect("Failed to get bytes"); // Download the body
+        }
+    }
+
+    // Signup with the clients
+    let user_count: usize = 10;
+    let mut clients = vec![0; user_count]
+        .into_iter()
+        .map(|_| TestClient::new(&mut testnet))
+        .collect::<Vec<_>>();
+    for client in clients.iter_mut() {
+        client.signup(&hs_pubkey).await;
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Write to server concurrently
+    let start = Instant::now();
+    let mut handles = vec![];
+    for client in clients.iter() {
+        let client = client.clone();
+        let handle = tokio::spawn(async move {
+            let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
+                .parse()
+                .unwrap();
+            let body = vec![0u8; 3 * 1024]; // 3kb
+            client.put(url, body).await;
+        });
+        handles.push(handle);
+    }
+    // Wait for all the writes to finish
+    for handle in handles {
+        handle.await.unwrap();
+    }
+    let elapsed = start.elapsed();
+    assert!(elapsed < Duration::from_secs(5));
+
+    // --------------------------------------------------------------------------------------------
+    // Read from server concurrently
+    let start = Instant::now();
+    let mut handles = vec![];
+    for client in clients.iter() {
+        let client = client.clone();
+        let handle = tokio::spawn(async move {
+            let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
+                .parse()
+                .unwrap();
+            client.get(url).await;
+        });
+        handles.push(handle);
+    }
+    // Wait for all the reads to finish
+    for handle in handles {
+        handle.await.unwrap();
+    }
+    let elapsed = start.elapsed();
+    assert!(elapsed < Duration::from_secs(5));
+}
diff --git a/pubky-homeserver/Cargo.toml b/pubky-homeserver/Cargo.toml
index 53f783a..66a78b1 100644
--- a/pubky-homeserver/Cargo.toml
+++ b/pubky-homeserver/Cargo.toml
@@ -55,6 +55,8 @@ dyn-clone = "1.0.19"
 reqwest = "0.12.15"
 governor = "0.10.0"
 fast-glob = "0.4.5"
+tokio-util = "0.7.15"
+percent-encoding = "2.3.1"
 serde_valid = "1.0.5"
diff --git 
a/pubky-homeserver/src/admin/app.rs b/pubky-homeserver/src/admin/app.rs index 6b7815b..d030ec2 100644 --- a/pubky-homeserver/src/admin/app.rs +++ b/pubky-homeserver/src/admin/app.rs @@ -25,10 +25,7 @@ fn create_protected_router(password: &str) -> Router { get(generate_signup_token::generate_signup_token), ) .route("/info", get(info::info)) - .route( - "/webdav/{pubkey}/{*path}", - delete(delete_entry::delete_entry), - ) + .route("/webdav/{*entry_path}", delete(delete_entry::delete_entry)) .route("/users/{pubkey}/disable", post(disable_user)) .route("/users/{pubkey}/enable", post(enable_user)) .layer(AdminAuthLayer::new(password.to_string())) diff --git a/pubky-homeserver/src/admin/routes/delete_entry.rs b/pubky-homeserver/src/admin/routes/delete_entry.rs index 4124ee2..0954014 100644 --- a/pubky-homeserver/src/admin/routes/delete_entry.rs +++ b/pubky-homeserver/src/admin/routes/delete_entry.rs @@ -1,5 +1,5 @@ use super::super::app_state::AppState; -use crate::shared::{HttpError, HttpResult, Z32Pubkey}; +use crate::shared::{webdav::EntryPath, HttpError, HttpResult}; use axum::{ extract::{Path, State}, http::StatusCode, @@ -9,16 +9,9 @@ use axum::{ /// Delete a single entry from the database. pub async fn delete_entry( State(mut state): State, - Path((pubkey, path)): Path<(Z32Pubkey, String)>, + Path(entry_path): Path, ) -> HttpResult { - let path = format!("/{}", path); // Add missing leading slash - if !path.starts_with("/pub/") { - return Err(HttpError::new( - StatusCode::BAD_REQUEST, - Some("Invalid path"), - )); - } - let deleted = state.db.delete_entry(&pubkey.0, &path)?; + let deleted = state.db.delete_entry(&entry_path).await?; if deleted { Ok((StatusCode::NO_CONTENT, ())) } else { @@ -33,16 +26,14 @@ pub async fn delete_entry( mod tests { use super::super::super::app_state::AppState; use super::*; - use crate::persistence::lmdb::LmDB; + use crate::persistence::lmdb::{tables::files::InDbTempFile, LmDB}; + use crate::shared::webdav::{EntryPath, WebDavPath}; use axum::{routing::delete, Router}; - use pkarr::{Keypair, PublicKey}; - use std::io::Write; + use pkarr::Keypair; - async fn write_test_file(db: &mut LmDB, pubkey: &PublicKey, path: &str) { - let mut entry_writer = db.write_entry(pubkey, path).unwrap(); - let content = b"Hello, world!"; - entry_writer.write_all(content).unwrap(); - let _entry = entry_writer.commit().unwrap(); + async fn write_test_file(db: &mut LmDB, entry_path: &EntryPath) { + let file = InDbTempFile::zeros(10).await.unwrap(); + let _entry = db.write_entry(&entry_path, &file).await.unwrap(); } #[tokio::test] @@ -54,25 +45,25 @@ mod tests { let mut db = LmDB::test(); let app_state = AppState::new(db.clone()); let router = Router::new() - .route("/webdav/{pubkey}/{*path}", delete(delete_entry)) + .route("/webdav/{*entry_path}", delete(delete_entry)) .with_state(app_state); // Write a test file - let entry_path = format!("/pub/{}", file_path); - write_test_file(&mut db, &pubkey, &entry_path).await; + let webdav_path = WebDavPath::new(format!("/pub/{}", file_path).as_str()).unwrap(); + let entry_path = EntryPath::new(pubkey.clone(), webdav_path); + + write_test_file(&mut db, &entry_path).await; // Delete the file let server = axum_test::TestServer::new(router).unwrap(); let response = server - .delete(format!("/webdav/{}{}", pubkey, entry_path).as_str()) + .delete(format!("/webdav/{}{}", pubkey, entry_path.path().as_str()).as_str()) .await; assert_eq!(response.status_code(), StatusCode::NO_CONTENT); // Check that the file is deleted - let rtx = 
db.env.read_txn().unwrap(); - let entry = db.get_entry(&rtx, &pubkey, &file_path).unwrap(); + let entry = db.get_entry(&entry_path).unwrap(); assert!(entry.is_none(), "Entry should be deleted"); - rtx.commit().unwrap(); let events = db.list_events(None, None).unwrap(); assert_eq!( @@ -92,14 +83,13 @@ mod tests { let file_path = "my_file.txt"; let app_state = AppState::new(LmDB::test()); let router = Router::new() - .route("/webdav/{pubkey}/{*path}", delete(delete_entry)) + .route("/webdav/{*entry_path}", delete(delete_entry)) .with_state(app_state); // Delete the file + let url = format!("/webdav/{}/pub/{}", pubkey, file_path); let server = axum_test::TestServer::new(router).unwrap(); - let response = server - .delete(format!("/webdav/{}/pub/{}", pubkey, file_path).as_str()) - .await; + let response = server.delete(url.as_str()).await; assert_eq!(response.status_code(), StatusCode::NOT_FOUND); } @@ -109,7 +99,7 @@ mod tests { let db = LmDB::test(); let app_state = AppState::new(db.clone()); let router = Router::new() - .route("/webdav/{pubkey}/{*path}", delete(delete_entry)) + .route("/webdav/{*entry_path}", delete(delete_entry)) .with_state(app_state); // Delete with invalid pubkey diff --git a/pubky-homeserver/src/core/routes/tenants/mod.rs b/pubky-homeserver/src/core/routes/tenants/mod.rs index 23c4126..64d214a 100644 --- a/pubky-homeserver/src/core/routes/tenants/mod.rs +++ b/pubky-homeserver/src/core/routes/tenants/mod.rs @@ -3,11 +3,7 @@ //! Every route here is relative to a tenant's Pubky host, //! as opposed to routes relative to the Homeserver's owner. -use axum::{ - extract::DefaultBodyLimit, - routing::{delete, get, head, put}, - Router, -}; +use axum::{extract::DefaultBodyLimit, routing::get, Router}; use crate::core::{layers::authz::AuthorizationLayer, AppState}; @@ -17,16 +13,14 @@ pub mod write; pub fn router(state: AppState) -> Router { Router::new() - // - Datastore routes - .route("/pub/", get(read::get)) - .route("/pub/{*path}", get(read::get)) - .route("/pub/{*path}", head(read::head)) - .route("/pub/{*path}", put(write::put)) - .route("/pub/{*path}", delete(write::delete)) - // - Session routes - .route("/session", get(session::session)) - .route("/session", delete(session::signout)) - // Layers + .route("/session", get(session::session).delete(session::signout)) + .route( + "/{*path}", + get(read::get) + .head(read::head) + .put(write::put) + .delete(write::delete), + ) // TODO: different max size for sessions and other routes? 
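// Editor's note (illustrative, not part of the patch): the route consolidation above relies on
// axum's MethodRouter chaining, so
//     .route("/{*path}", get(read::get).head(read::head).put(write::put).delete(write::delete))
// registers the same handlers that the four separate `.route("/pub/{*path}", ...)` calls used to,
// while the broader "/{*path}" wildcard now also serves the directory listings that previously
// matched the dedicated "/pub/" route.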
.layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .layer(AuthorizationLayer::new(state.clone())) diff --git a/pubky-homeserver/src/core/routes/tenants/read.rs b/pubky-homeserver/src/core/routes/tenants/read.rs index 8cfea86..ccbcd7b 100644 --- a/pubky-homeserver/src/core/routes/tenants/read.rs +++ b/pubky-homeserver/src/core/routes/tenants/read.rs @@ -1,90 +1,75 @@ +use crate::persistence::lmdb::tables::files::Entry; +use crate::{ + core::{ + err_if_user_is_invalid::err_if_user_is_invalid, + error::{Error, Result}, + extractors::{ListQueryParams, PubkyHost}, + AppState, + }, + shared::webdav::{EntryPath, WebDavPathAxum}, +}; use axum::{ body::Body, - extract::{OriginalUri, State}, + extract::{Path, State}, http::{header, HeaderMap, HeaderValue, Response, StatusCode}, response::IntoResponse, }; use httpdate::HttpDate; -use pkarr::PublicKey; use std::str::FromStr; - -use crate::core::{ - error::{Error, Result}, - extractors::{ListQueryParams, PubkyHost}, - AppState, -}; -use crate::persistence::lmdb::tables::entries::Entry; +use tokio_util::io::ReaderStream; pub async fn head( State(state): State, pubky: PubkyHost, headers: HeaderMap, - path: OriginalUri, + Path(path): Path, ) -> Result { - let rtxn = state.db.env.read_txn()?; - get_entry( - headers, - state - .db - .get_entry(&rtxn, pubky.public_key(), path.0.path())?, - None, - ) + err_if_user_is_invalid(pubky.public_key(), &state.db, false)?; + let entry_path = EntryPath::new(pubky.public_key().clone(), path.0); + let entry = state + .db + .get_entry(&entry_path)? + .ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?; + + get_entry(headers, entry, None) } +#[axum::debug_handler] pub async fn get( State(state): State, headers: HeaderMap, pubky: PubkyHost, - path: OriginalUri, + Path(path): Path, params: ListQueryParams, ) -> Result { let public_key = pubky.public_key().clone(); - let path = path.0.path().to_string(); - - if path.ends_with('/') { - return list(state, &public_key, &path, params); + let dav_path = path.0; + let entry_path = EntryPath::new(public_key.clone(), dav_path); + if entry_path.path().is_directory() { + return list(state, &entry_path, params); } - let (entry_tx, entry_rx) = flume::bounded::>(1); - let (chunks_tx, chunks_rx) = flume::unbounded::, heed::Error>>(); + let entry = state + .db + .get_entry(&entry_path)? + .ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?; + let buffer_file = state.db.read_file(&entry.file_id()).await?; - tokio::task::spawn_blocking(move || -> anyhow::Result<()> { - let rtxn = state.db.env.read_txn()?; - - let option = state.db.get_entry(&rtxn, &public_key, &path)?; - - if let Some(entry) = option { - let iter = entry.read_content(&state.db, &rtxn)?; - - entry_tx.send(Some(entry))?; - - for next in iter { - chunks_tx.send(next.map(|b| b.to_vec()))?; - } - }; - - entry_tx.send(None)?; - - Ok(()) - }); - - get_entry( - headers, - entry_rx.recv_async().await?, - Some(Body::from_stream(chunks_rx.into_stream())), - ) + let file_handle = buffer_file.open_file_handle()?; + // Async stream the file + let tokio_file_handle = tokio::fs::File::from_std(file_handle); + let body_stream = Body::from_stream(ReaderStream::new(tokio_file_handle)); + get_entry(headers, entry, Some(body_stream)) } pub fn list( state: AppState, - public_key: &PublicKey, - path: &str, + entry_path: &EntryPath, params: ListQueryParams, ) -> Result> { let txn = state.db.env.read_txn()?; - let path = format!("{public_key}{path}"); - if !state.db.contains_directory(&txn, &path)? 
{ + if !state.db.contains_directory(&txn, entry_path)? { return Err(Error::new( StatusCode::NOT_FOUND, "Directory Not Found".into(), @@ -92,9 +77,9 @@ pub fn list( } // Handle listing - let vec = state.db.list( + let vec = state.db.list_entries( &txn, - &path, + entry_path, params.reverse, params.limit, params.cursor, @@ -107,54 +92,44 @@ pub fn list( .body(Body::from(vec.join("\n")))?) } -pub fn get_entry( - headers: HeaderMap, - entry: Option, - body: Option, -) -> Result> { - if let Some(entry) = entry { - // TODO: Enable seek API (range requests) - // TODO: Gzip? or brotli? +pub fn get_entry(headers: HeaderMap, entry: Entry, body: Option) -> Result> { + // TODO: Enable seek API (range requests) + // TODO: Gzip? or brotli? - let mut response = HeaderMap::from(&entry).into_response(); + let mut response = HeaderMap::from(&entry).into_response(); - // Handle IF_MODIFIED_SINCE - if let Some(condition_http_date) = headers - .get(header::IF_MODIFIED_SINCE) - .and_then(|h| h.to_str().ok()) - .and_then(|s| HttpDate::from_str(s).ok()) - { - let entry_http_date: HttpDate = entry.timestamp().to_owned().into(); + // Handle IF_MODIFIED_SINCE + if let Some(condition_http_date) = headers + .get(header::IF_MODIFIED_SINCE) + .and_then(|h| h.to_str().ok()) + .and_then(|s| HttpDate::from_str(s).ok()) + { + let entry_http_date: HttpDate = entry.timestamp().to_owned().into(); - if condition_http_date >= entry_http_date { - *response.status_mut() = StatusCode::NOT_MODIFIED; - } - }; - - // Handle IF_NONE_MATCH - if let Some(str) = headers - .get(header::IF_NONE_MATCH) - .and_then(|h| h.to_str().ok()) - { - let etag = format!("\"{}\"", entry.content_hash()); - if str - .trim() - .split(',') - .collect::>() - .contains(&etag.as_str()) - { - *response.status_mut() = StatusCode::NOT_MODIFIED; - }; + if condition_http_date >= entry_http_date { + *response.status_mut() = StatusCode::NOT_MODIFIED; } + }; - if let Some(body) = body { - *response.body_mut() = body; + // Handle IF_NONE_MATCH + if let Some(str) = headers + .get(header::IF_NONE_MATCH) + .and_then(|h| h.to_str().ok()) + { + let etag = format!("\"{}\"", entry.content_hash()); + if str + .trim() + .split(',') + .collect::>() + .contains(&etag.as_str()) + { + *response.status_mut() = StatusCode::NOT_MODIFIED; }; - - Ok(response) - } else { - Err(Error::with_status(StatusCode::NOT_FOUND))? } + if let Some(body) = body { + *response.body_mut() = body; + } + Ok(response) } impl From<&Entry> for HeaderMap { diff --git a/pubky-homeserver/src/core/routes/tenants/write.rs b/pubky-homeserver/src/core/routes/tenants/write.rs index e364547..01f63f8 100644 --- a/pubky-homeserver/src/core/routes/tenants/write.rs +++ b/pubky-homeserver/src/core/routes/tenants/write.rs @@ -1,21 +1,23 @@ -use std::io::Write; - use axum::{ body::{Body, HttpBody}, - extract::{OriginalUri, State}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, }; use futures_util::stream::StreamExt; -use crate::core::{ - err_if_user_is_invalid::err_if_user_is_invalid, - error::{Error, Result}, - extractors::PubkyHost, - AppState, +use crate::{ + core::{ + err_if_user_is_invalid::err_if_user_is_invalid, + error::{Error, Result}, + extractors::PubkyHost, + AppState, + }, + persistence::lmdb::tables::files::AsyncInDbTempFileWriter, + shared::webdav::{EntryPath, WebDavPathAxum}, }; -/// Fail with 507 if `(current + incoming − existing) > quota`. +/// Fail with 507 if `(current + incoming − existing) > quota`. 
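// Editor's note (worked example of the rule above, with hypothetical numbers and a 100 MB
// quota): a user storing 90 MB who overwrites an existing 10 MB entry with a 15 MB body needs
// 90 + 15 - 10 = 95 MB <= 100 MB, so the PUT is accepted; a 25 MB body would need
// 90 + 25 - 10 = 105 MB > 100 MB and is rejected with 507 Insufficient Storage.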
fn enforce_user_disk_quota( existing_bytes: u64, incoming_bytes: u64, @@ -31,7 +33,7 @@ fn enforce_user_disk_quota( return Err(Error::new( StatusCode::INSUFFICIENT_STORAGE, Some(format!( - "Quota of {:.1} MB exceeded: you’ve used {:.1} MB, trying to add {:.1} MB", + "Quota of {:.1} MB exceeded: you've used {:.1} MB, trying to add {:.1} MB", max_mb, current_mb, adding_mb )), )); @@ -43,15 +45,15 @@ fn enforce_user_disk_quota( pub async fn delete( State(mut state): State, pubky: PubkyHost, - path: OriginalUri, + Path(path): Path, ) -> Result { - err_if_user_is_invalid(pubky.public_key(), &state.db, false)?; let public_key = pubky.public_key(); - let full_path = path.0.path(); - let existing_bytes = state.db.get_entry_content_length(public_key, full_path)?; + err_if_user_is_invalid(pubky.public_key(), &state.db, false)?; + let entry_path = EntryPath::new(public_key.clone(), path.0); + let existing_bytes = state.db.get_entry_content_length(&entry_path)?; // Remove entry - if !state.db.delete_entry(public_key, full_path)? { + if !state.db.delete_entry(&entry_path).await? { return Err(Error::with_status(StatusCode::NOT_FOUND)); } @@ -66,36 +68,38 @@ pub async fn delete( pub async fn put( State(mut state): State, pubky: PubkyHost, - path: OriginalUri, + Path(path): Path, body: Body, ) -> Result { - err_if_user_is_invalid(pubky.public_key(), &state.db, true)?; let public_key = pubky.public_key(); - let full_path = path.0.path(); - let existing_entry_bytes = state.db.get_entry_content_length(public_key, full_path)?; + err_if_user_is_invalid(public_key, &state.db, true)?; + let entry_path = EntryPath::new(public_key.clone(), path.0); + + let existing_entry_bytes = state.db.get_entry_content_length(&entry_path)?; let quota_bytes = state.user_quota_bytes; let used_bytes = state.db.get_user_data_usage(public_key)?; - // Upfront check when we have an exact Content‑Length + // Upfront check when we have an exact Content-Length let hint = body.size_hint().exact(); if let Some(exact_bytes) = hint { enforce_user_disk_quota(existing_entry_bytes, exact_bytes, used_bytes, quota_bytes)?; } - // Stream body - let mut writer = state.db.write_entry(public_key, full_path)?; + // Stream body to disk first. let mut seen_bytes: u64 = 0; let mut stream = body.into_data_stream(); + let mut buffer_file_writer = AsyncInDbTempFileWriter::new().await?; while let Some(chunk) = stream.next().await.transpose()? 
{ seen_bytes += chunk.len() as u64; enforce_user_disk_quota(existing_entry_bytes, seen_bytes, used_bytes, quota_bytes)?; - writer.write_all(&chunk)?; + buffer_file_writer.write_chunk(&chunk).await?; } + let buffer_file = buffer_file_writer.complete().await?; - // Commit & bump usage - let entry = writer.commit()?; - let delta = entry.content_length() as i64 - existing_entry_bytes as i64; + // Write file on disk to db + state.db.write_entry(&entry_path, &buffer_file).await?; + let delta = buffer_file.len() as i64 - existing_entry_bytes as i64; state.db.update_data_usage(public_key, delta)?; Ok((StatusCode::CREATED, ())) diff --git a/pubky-homeserver/src/persistence/lmdb/db.rs b/pubky-homeserver/src/persistence/lmdb/db.rs index e3cdff5..483112d 100644 --- a/pubky-homeserver/src/persistence/lmdb/db.rs +++ b/pubky-homeserver/src/persistence/lmdb/db.rs @@ -11,7 +11,6 @@ pub const DEFAULT_MAP_SIZE: usize = 10995116277760; // 10TB (not = disk-space us pub struct LmDB { pub(crate) env: Env, pub(crate) tables: Tables, - pub(crate) buffers_dir: PathBuf, pub(crate) max_chunk_size: usize, // Only used for testing purposes to keep the testdir alive. #[allow(dead_code)] @@ -44,7 +43,6 @@ impl LmDB { let db = LmDB { env, tables, - buffers_dir, max_chunk_size: Self::max_chunk_size(), test_dir: None, }; diff --git a/pubky-homeserver/src/persistence/lmdb/migrations/m0.rs b/pubky-homeserver/src/persistence/lmdb/migrations/m0.rs index 088e069..4d36f18 100644 --- a/pubky-homeserver/src/persistence/lmdb/migrations/m0.rs +++ b/pubky-homeserver/src/persistence/lmdb/migrations/m0.rs @@ -1,15 +1,15 @@ use heed::{Env, RwTxn}; -use crate::persistence::lmdb::tables::{blobs, entries, events, sessions, signup_tokens, users}; +use crate::persistence::lmdb::tables::{events, files, sessions, signup_tokens, users}; pub fn run(env: &Env, wtxn: &mut RwTxn) -> anyhow::Result<()> { let _: users::UsersTable = env.create_database(wtxn, Some(users::USERS_TABLE))?; let _: sessions::SessionsTable = env.create_database(wtxn, Some(sessions::SESSIONS_TABLE))?; - let _: blobs::BlobsTable = env.create_database(wtxn, Some(blobs::BLOBS_TABLE))?; + let _: files::BlobsTable = env.create_database(wtxn, Some(files::BLOBS_TABLE))?; - let _: entries::EntriesTable = env.create_database(wtxn, Some(entries::ENTRIES_TABLE))?; + let _: files::EntriesTable = env.create_database(wtxn, Some(files::ENTRIES_TABLE))?; let _: events::EventsTable = env.create_database(wtxn, Some(events::EVENTS_TABLE))?; diff --git a/pubky-homeserver/src/persistence/lmdb/tables.rs b/pubky-homeserver/src/persistence/lmdb/tables.rs index 0189c12..56c014c 100644 --- a/pubky-homeserver/src/persistence/lmdb/tables.rs +++ b/pubky-homeserver/src/persistence/lmdb/tables.rs @@ -1,14 +1,13 @@ -pub mod blobs; -pub mod entries; pub mod events; +pub mod files; pub mod sessions; pub mod signup_tokens; pub mod users; use heed::{Env, RwTxn}; -use blobs::{BlobsTable, BLOBS_TABLE}; -use entries::{EntriesTable, ENTRIES_TABLE}; +use files::{BlobsTable, BLOBS_TABLE}; +use files::{EntriesTable, ENTRIES_TABLE}; use self::{ events::{EventsTable, EVENTS_TABLE}, diff --git a/pubky-homeserver/src/persistence/lmdb/tables/blobs.rs b/pubky-homeserver/src/persistence/lmdb/tables/blobs.rs deleted file mode 100644 index 5ec608d..0000000 --- a/pubky-homeserver/src/persistence/lmdb/tables/blobs.rs +++ /dev/null @@ -1,24 +0,0 @@ -use heed::{types::Bytes, Database, RoTxn}; - -use super::super::LmDB; - -use super::entries::Entry; - -/// (entry timestamp | chunk_index BE) => bytes -pub type BlobsTable = Database; 
- -pub const BLOBS_TABLE: &str = "blobs"; - -impl LmDB { - pub fn read_entry_content<'txn>( - &self, - rtxn: &'txn RoTxn, - entry: &Entry, - ) -> anyhow::Result> + 'txn> { - Ok(self - .tables - .blobs - .prefix_iter(rtxn, &entry.timestamp().to_bytes())? - .map(|i| i.map(|(_, bytes)| bytes))) - } -} diff --git a/pubky-homeserver/src/persistence/lmdb/tables/entries.rs b/pubky-homeserver/src/persistence/lmdb/tables/entries.rs deleted file mode 100644 index 5e4c9e0..0000000 --- a/pubky-homeserver/src/persistence/lmdb/tables/entries.rs +++ /dev/null @@ -1,560 +0,0 @@ -use pkarr::PublicKey; -use postcard::{from_bytes, to_allocvec}; -use serde::{Deserialize, Serialize}; -use std::{ - fs::File, - io::{Read, Write}, - path::PathBuf, -}; -use tracing::instrument; - -use heed::{ - types::{Bytes, Str}, - Database, RoTxn, -}; - -use pubky_common::{ - crypto::{Hash, Hasher}, - timestamp::Timestamp, -}; - -use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT}; - -use super::super::LmDB; - -use super::events::Event; - -/// full_path(pubky/*path) => Entry. -pub type EntriesTable = Database; - -pub const ENTRIES_TABLE: &str = "entries"; - -impl LmDB { - /// Write an entry by an author at a given path. - /// - /// The path has to start with a forward slash `/` - pub fn write_entry( - &mut self, - public_key: &PublicKey, - path: &str, - ) -> anyhow::Result { - EntryWriter::new(self, public_key, path) - } - - /// Delete an entry by an author at a given path. - /// - /// The path has to start with a forward slash `/` - pub fn delete_entry(&mut self, public_key: &PublicKey, path: &str) -> anyhow::Result { - let mut wtxn = self.env.write_txn()?; - - let key = format!("{public_key}{path}"); - - let deleted = if let Some(bytes) = self.tables.entries.get(&wtxn, &key)? { - let entry = Entry::deserialize(bytes)?; - - let mut deleted_chunks = false; - - { - let mut iter = self - .tables - .blobs - .prefix_iter_mut(&mut wtxn, &entry.timestamp.to_bytes())?; - - while iter.next().is_some() { - unsafe { - deleted_chunks = iter.del_current()?; - } - } - } - - let deleted_entry = self.tables.entries.delete(&mut wtxn, &key)?; - - // create DELETE event - if path.starts_with("/pub/") { - let url = format!("pubky://{key}"); - - let event = Event::delete(&url); - let value = event.serialize(); - - let key = Timestamp::now().to_string(); - - self.tables.events.put(&mut wtxn, &key, &value)?; - - // TODO: delete events older than a threshold. - // TODO: move to events.rs - } - - deleted_entry && deleted_chunks - } else { - false - }; - - wtxn.commit()?; - - Ok(deleted) - } - - pub fn get_entry( - &self, - txn: &RoTxn, - public_key: &PublicKey, - path: &str, - ) -> anyhow::Result> { - let key = format!("{public_key}{path}"); - - if let Some(bytes) = self.tables.entries.get(txn, &key)? { - return Ok(Some(Entry::deserialize(bytes)?)); - } - - Ok(None) - } - - /// Bytes stored at `path` for this user (0 if none). - pub fn get_entry_content_length( - &self, - public_key: &PublicKey, - path: &str, - ) -> anyhow::Result { - let txn = self.env.read_txn()?; - let content_length = self - .get_entry(&txn, public_key, path)? - .map(|e| e.content_length() as u64) - .unwrap_or(0); - Ok(content_length) - } - - pub fn contains_directory(&self, txn: &RoTxn, path: &str) -> anyhow::Result { - Ok(self.tables.entries.get_greater_than(txn, path)?.is_some()) - } - - /// Return a list of pubky urls. 
- /// - /// - limit defaults to [crate::config::DEFAULT_LIST_LIMIT] and capped by [crate::config::DEFAULT_MAX_LIST_LIMIT] - pub fn list( - &self, - txn: &RoTxn, - path: &str, - reverse: bool, - limit: Option, - cursor: Option, - shallow: bool, - ) -> anyhow::Result> { - // Vector to store results - let mut results = Vec::new(); - - let limit = limit - .unwrap_or(DEFAULT_LIST_LIMIT) - .min(DEFAULT_MAX_LIST_LIMIT); - - // TODO: make this more performant than split and allocations? - - let mut threshold = cursor - .map(|cursor| { - // Removing leading forward slashes - let mut file_or_directory = cursor.trim_start_matches('/'); - - if cursor.starts_with("pubky://") { - file_or_directory = cursor.split(path).last().expect("should not be reachable") - }; - - next_threshold( - path, - file_or_directory, - file_or_directory.ends_with('/'), - reverse, - shallow, - ) - }) - .unwrap_or(next_threshold(path, "", false, reverse, shallow)); - - for _ in 0..limit { - if let Some((key, _)) = if reverse { - self.tables.entries.get_lower_than(txn, &threshold)? - } else { - self.tables.entries.get_greater_than(txn, &threshold)? - } { - if !key.starts_with(path) { - break; - } - - if shallow { - let mut split = key[path.len()..].split('/'); - let file_or_directory = split.next().expect("should not be reachable"); - - let is_directory = split.next().is_some(); - - threshold = - next_threshold(path, file_or_directory, is_directory, reverse, shallow); - - results.push(format!( - "pubky://{path}{file_or_directory}{}", - if is_directory { "/" } else { "" } - )); - } else { - threshold = key.to_string(); - results.push(format!("pubky://{}", key)) - } - }; - } - - Ok(results) - } -} - -/// Calculate the next threshold -#[instrument] -fn next_threshold( - path: &str, - file_or_directory: &str, - is_directory: bool, - reverse: bool, - shallow: bool, -) -> String { - format!( - "{path}{file_or_directory}{}", - if file_or_directory.is_empty() { - // No file_or_directory, early return - if reverse { - // `path/to/dir/\x7f` to catch all paths than `path/to/dir/` - "\x7f" - } else { - "" - } - } else if shallow & is_directory { - if reverse { - // threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/` - "\x2e" - } else { - //threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/` - "\x7f" - } - } else { - "" - } - ) -} - -#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct Entry { - /// Encoding version - version: usize, - /// Modified at - timestamp: Timestamp, - content_hash: EntryHash, - content_length: usize, - content_type: String, - // user_metadata: ? 
-} - -#[derive(Clone, Debug, Eq, PartialEq)] -struct EntryHash(Hash); - -impl Default for EntryHash { - fn default() -> Self { - Self(Hash::from_bytes([0; 32])) - } -} - -impl Serialize for EntryHash { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let bytes = self.0.as_bytes(); - bytes.serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for EntryHash { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?; - Ok(Self(Hash::from_bytes(bytes))) - } -} - -impl Entry { - pub fn new() -> Self { - Default::default() - } - - // === Setters === - - pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self { - self.timestamp = *timestamp; - self - } - - pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self { - EntryHash(content_hash).clone_into(&mut self.content_hash); - self - } - - pub fn set_content_length(&mut self, content_length: usize) -> &mut Self { - self.content_length = content_length; - self - } - - // === Getters === - - pub fn timestamp(&self) -> &Timestamp { - &self.timestamp - } - - pub fn content_hash(&self) -> &Hash { - &self.content_hash.0 - } - - pub fn content_length(&self) -> usize { - self.content_length - } - - pub fn content_type(&self) -> &str { - &self.content_type - } - - // === Public Method === - - pub fn read_content<'txn>( - &self, - db: &'txn LmDB, - rtxn: &'txn RoTxn, - ) -> anyhow::Result> + 'txn> { - db.read_entry_content(rtxn, self) - } - - pub fn serialize(&self) -> Vec { - to_allocvec(self).expect("Session::serialize") - } - - pub fn deserialize(bytes: &[u8]) -> core::result::Result { - if bytes[0] > 0 { - panic!("Unknown Entry version"); - } - - from_bytes(bytes) - } -} - -pub struct EntryWriter<'db> { - db: &'db LmDB, - buffer: File, - hasher: Hasher, - buffer_path: PathBuf, - entry_key: String, - timestamp: Timestamp, - is_public: bool, -} - -impl<'db> EntryWriter<'db> { - pub fn new(db: &'db LmDB, public_key: &PublicKey, path: &str) -> anyhow::Result { - let hasher = Hasher::new(); - - let timestamp = Timestamp::now(); - - let buffer_path = db.buffers_dir.join(timestamp.to_string()); - - let buffer = File::create(&buffer_path)?; - - let entry_key = format!("{public_key}{path}"); - - Ok(Self { - db, - buffer, - hasher, - buffer_path, - entry_key, - timestamp, - is_public: path.starts_with("/pub/"), - }) - } - - /// Same ase [EntryWriter::write_all] but returns a Result of a mutable reference of itself - /// to enable chaining with [Self::commit]. - pub fn update(&mut self, chunk: &[u8]) -> Result<&mut Self, std::io::Error> { - self.write_all(chunk)?; - - Ok(self) - } - - /// Commit blob from the filesystem buffer to LMDB, - /// write the [Entry], and commit the write transaction. 
- pub fn commit(&self) -> anyhow::Result { - let hash = self.hasher.finalize(); - - let mut buffer = File::open(&self.buffer_path)?; - - let mut wtxn = self.db.env.write_txn()?; - - let mut chunk_key = [0; 12]; - chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes()); - - let mut chunk_index: u32 = 0; - - loop { - let mut chunk = vec![0_u8; self.db.max_chunk_size]; - - let bytes_read = buffer.read(&mut chunk)?; - - if bytes_read == 0 { - break; // EOF reached - } - - chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes()); - - self.db - .tables - .blobs - .put(&mut wtxn, &chunk_key, &chunk[..bytes_read])?; - - chunk_index += 1; - } - - let mut entry = Entry::new(); - entry.set_timestamp(&self.timestamp); - - entry.set_content_hash(hash); - - let length = buffer.metadata()?.len(); - entry.set_content_length(length as usize); - - self.db - .tables - .entries - .put(&mut wtxn, &self.entry_key, &entry.serialize())?; - - // Write a public [Event]. - if self.is_public { - let url = format!("pubky://{}", self.entry_key); - let event = Event::put(&url); - let value = event.serialize(); - - let key = entry.timestamp.to_string(); - - self.db.tables.events.put(&mut wtxn, &key, &value)?; - - // TODO: delete events older than a threshold. - // TODO: move to events.rs - } - - wtxn.commit()?; - - std::fs::remove_file(&self.buffer_path)?; - - Ok(entry) - } -} - -impl std::io::Write for EntryWriter<'_> { - /// Write a chunk to a Filesystem based buffer. - #[inline] - fn write(&mut self, chunk: &[u8]) -> std::io::Result { - self.hasher.update(chunk); - self.buffer.write_all(chunk)?; - - Ok(chunk.len()) - } - - /// Does not do anything, you need to call [Self::commit] - #[inline] - fn flush(&mut self) -> std::io::Result<()> { - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use bytes::Bytes; - use pkarr::Keypair; - - use super::LmDB; - - #[tokio::test] - async fn entries() -> anyhow::Result<()> { - let mut db = LmDB::test(); - - let keypair = Keypair::random(); - let public_key = keypair.public_key(); - let path = "/pub/foo.txt"; - - let chunk = Bytes::from(vec![1, 2, 3, 4, 5]); - - db.write_entry(&public_key, path)? - .update(&chunk)? - .commit()?; - - let rtxn = db.env.read_txn().unwrap(); - let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap(); - - assert_eq!( - entry.content_hash(), - &[ - 2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117, - 197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87 - ] - ); - - let mut blob = vec![]; - - { - let mut iter = entry.read_content(&db, &rtxn).unwrap(); - - while let Some(Ok(chunk)) = iter.next() { - blob.extend_from_slice(chunk); - } - } - - assert_eq!(blob, vec![1, 2, 3, 4, 5]); - - rtxn.commit().unwrap(); - - Ok(()) - } - - #[tokio::test] - async fn chunked_entry() -> anyhow::Result<()> { - let mut db = LmDB::test(); - - let keypair = Keypair::random(); - let public_key = keypair.public_key(); - let path = "/pub/foo.txt"; - - let chunk = Bytes::from(vec![0; 1024 * 1024]); - - db.write_entry(&public_key, path)? - .update(&chunk)? 
- .commit()?; - - let rtxn = db.env.read_txn().unwrap(); - let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap(); - - assert_eq!( - entry.content_hash(), - &[ - 72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109, - 134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168 - ] - ); - - let mut blob = vec![]; - - { - let mut iter = entry.read_content(&db, &rtxn).unwrap(); - - while let Some(Ok(chunk)) = iter.next() { - blob.extend_from_slice(chunk); - } - } - - assert_eq!(blob, vec![0; 1024 * 1024]); - - let stats = db.tables.blobs.stat(&rtxn).unwrap(); - assert_eq!(stats.overflow_pages, 0); - - rtxn.commit().unwrap(); - - Ok(()) - } -} diff --git a/pubky-homeserver/src/persistence/lmdb/tables/files/blobs.rs b/pubky-homeserver/src/persistence/lmdb/tables/files/blobs.rs new file mode 100644 index 0000000..c8f008d --- /dev/null +++ b/pubky-homeserver/src/persistence/lmdb/tables/files/blobs.rs @@ -0,0 +1,160 @@ +use super::super::super::LmDB; +use super::{InDbFileId, InDbTempFile, SyncInDbTempFileWriter}; +use heed::{types::Bytes, Database}; +use std::io::Read; + +/// (entry timestamp | chunk_index BE) => bytes +pub type BlobsTable = Database; +pub const BLOBS_TABLE: &str = "blobs"; + +impl LmDB { + /// Read the blobs into a temporary file. + /// + /// The file is written to disk to minimize the size/duration of the LMDB transaction. + pub(crate) fn read_file_sync(&self, id: &InDbFileId) -> anyhow::Result { + let mut file_writer = SyncInDbTempFileWriter::new()?; + let rtxn = self.env.read_txn()?; + let blobs_iter = self + .tables + .blobs + .prefix_iter(&rtxn, &id.bytes())? + .map(|i| i.map(|(_, bytes)| bytes)); + let mut file_exists = false; + for read_result in blobs_iter { + file_exists = true; + let chunk = read_result?; + file_writer.write_chunk(chunk)?; + } + + if !file_exists { + return Ok(InDbTempFile::empty()?); + } + + let file = file_writer.complete()?; + rtxn.commit()?; + Ok(file) + } + + /// Read the blobs into a temporary file asynchronously. + pub(crate) async fn read_file(&self, id: &InDbFileId) -> anyhow::Result { + let db = self.clone(); + let id = *id; + let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result { + db.read_file_sync(&id) + }) + .await; + match join_handle { + Ok(result) => result, + Err(e) => { + tracing::error!("Error reading file. JoinError: {:?}", e); + Err(e.into()) + } + } + } + + /// Write the blobs from a temporary file to LMDB. + pub(crate) fn write_file_sync<'txn>( + &'txn self, + file: &InDbTempFile, + wtxn: &mut heed::RwTxn<'txn>, + ) -> anyhow::Result { + let id = InDbFileId::new(); + let mut file_handle = file.open_file_handle()?; + + let mut blob_index: u32 = 0; + loop { + let mut blob = vec![0_u8; self.max_chunk_size]; + let bytes_read = file_handle.read(&mut blob)?; + let is_end_of_file = bytes_read == 0; + if is_end_of_file { + break; // EOF reached + } + + let blob_key = id.get_blob_key(blob_index); + self.tables + .blobs + .put(wtxn, &blob_key, &blob[..bytes_read])?; + + blob_index += 1; + } + + Ok(id) + } + + /// Delete the blobs from LMDB. 
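// Editor's note (illustrative, not part of the patch): `write_file_sync` above splits each file
// into chunks of at most `max_chunk_size` bytes and stores every chunk under a 12-byte key of
// the form (8-byte file timestamp | 4-byte chunk index, big-endian). Because LMDB compares keys
// bytewise, the big-endian index keeps the chunks in numeric order: for example chunk 1
// (suffix 00 00 00 01) sorts before chunk 256 (suffix 00 00 01 00), so `prefix_iter` over the
// timestamp in `read_file_sync` streams the chunks back in the order they were written.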
+ pub(crate) fn delete_file<'txn>( + &'txn self, + file: &InDbFileId, + wtxn: &mut heed::RwTxn<'txn>, + ) -> anyhow::Result { + let mut deleted_chunks = false; + + { + let mut iter = self.tables.blobs.prefix_iter_mut(wtxn, &file.bytes())?; + + while iter.next().is_some() { + unsafe { + deleted_chunks = iter.del_current()?; + } + } + } + Ok(deleted_chunks) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_read_delete_file() { + let lmdb = LmDB::test(); + + // Write file to LMDB + let write_file = InDbTempFile::zeros(50).await.unwrap(); + let mut wtxn = lmdb.env.write_txn().unwrap(); + let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap(); + wtxn.commit().unwrap(); + + // Read file from LMDB + let read_file = lmdb.read_file(&id).await.unwrap(); + + assert_eq!(read_file.len(), write_file.len()); + assert_eq!(read_file.hash(), write_file.hash()); + + let written_file_content = std::fs::read(write_file.path()).unwrap(); + let read_file_content = std::fs::read(read_file.path()).unwrap(); + assert_eq!(written_file_content, read_file_content); + + // Delete file from LMDB + let mut wtxn = lmdb.env.write_txn().unwrap(); + let deleted = lmdb.delete_file(&id, &mut wtxn).unwrap(); + wtxn.commit().unwrap(); + assert!(deleted); + + // Try to read file from LMDB + let read_file = lmdb.read_file(&id).await.unwrap(); + assert_eq!(read_file.len(), 0); + } + + #[tokio::test] + async fn test_write_empty_file() { + let lmdb = LmDB::test(); + + // Write file to LMDB + let write_file = InDbTempFile::empty().unwrap(); + let mut wtxn = lmdb.env.write_txn().unwrap(); + let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap(); + wtxn.commit().unwrap(); + + // Read file from LMDB + let read_file = lmdb.read_file(&id).await.unwrap(); + + assert_eq!(read_file.len(), write_file.len()); + assert_eq!(read_file.hash(), write_file.hash()); + + let written_file_content = std::fs::read(write_file.path()).unwrap(); + let read_file_content = std::fs::read(read_file.path()).unwrap(); + assert_eq!(written_file_content, read_file_content); + } +} diff --git a/pubky-homeserver/src/persistence/lmdb/tables/files/entries.rs b/pubky-homeserver/src/persistence/lmdb/tables/files/entries.rs new file mode 100644 index 0000000..afd8c8d --- /dev/null +++ b/pubky-homeserver/src/persistence/lmdb/tables/files/entries.rs @@ -0,0 +1,499 @@ +use super::super::events::Event; +use super::{super::super::LmDB, InDbFileId, InDbTempFile}; +use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT}; +use crate::shared::webdav::EntryPath; +use heed::{ + types::{Bytes, Str}, + Database, RoTxn, +}; +use postcard::{from_bytes, to_allocvec}; +use pubky_common::{crypto::Hash, timestamp::Timestamp}; +use serde::{Deserialize, Serialize}; +use tracing::instrument; + +/// full_path(pubky/*path) => Entry. +pub type EntriesTable = Database; + +pub const ENTRIES_TABLE: &str = "entries"; + +impl LmDB { + /// Writes an entry to the database. + /// + /// The entry is written to the database and the file is written to the blob store. + /// An event is written to the events table. + /// The entry is returned. 
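// Editor's note (illustrative, not part of the patch): LMDB transactions are blocking, so the
// async `write_entry`/`delete_entry` below (like `read_file` in blobs.rs) clone what they need
// and run the synchronous variant on tokio's blocking thread pool, roughly:
//     tokio::task::spawn_blocking(move || db.write_entry_sync(&path, &file)).await
// which keeps long-running database work off the async runtime's worker threads.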
+ pub async fn write_entry( + &mut self, + path: &EntryPath, + file: &InDbTempFile, + ) -> anyhow::Result { + let mut db = self.clone(); + let path = path.clone(); + let file = file.clone(); + let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result { + db.write_entry_sync(&path, &file) + }) + .await; + match join_handle { + Ok(result) => result, + Err(e) => { + tracing::error!("Error writing entry. JoinError: {:?}", e); + Err(e.into()) + } + } + } + + /// Writes an entry to the database. + /// + /// The entry is written to the database and the file is written to the blob store. + /// An event is written to the events table. + /// The entry is returned. + pub fn write_entry_sync( + &mut self, + path: &EntryPath, + file: &InDbTempFile, + ) -> anyhow::Result { + let mut wtxn = self.env.write_txn()?; + let mut entry = Entry::new(); + entry.set_content_hash(*file.hash()); + entry.set_content_length(file.len()); + let file_id = self.write_file_sync(file, &mut wtxn)?; + entry.set_timestamp(file_id.timestamp()); + let entry_key = path.to_string(); + self.tables + .entries + .put(&mut wtxn, entry_key.as_str(), &entry.serialize())?; + + // Write a public [Event]. + let url = format!("pubky://{}", entry_key); + let event = Event::put(&url); + let value = event.serialize(); + + self.tables + .events + .put(&mut wtxn, file_id.timestamp().to_string().as_str(), &value)?; + wtxn.commit()?; + + Ok(entry) + } + + /// Get an entry from the database. + /// This doesn't include the file but only metadata. + pub fn get_entry(&self, path: &EntryPath) -> anyhow::Result> { + let txn = self.env.read_txn()?; + let entry = match self.tables.entries.get(&txn, path.as_str())? { + Some(bytes) => Entry::deserialize(bytes)?, + None => return Ok(None), + }; + Ok(Some(entry)) + } + + /// Delete an entry including the associated file from the database. + pub async fn delete_entry(&mut self, path: &EntryPath) -> anyhow::Result { + let mut db = self.clone(); + let path = path.clone(); + let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result { + db.delete_entry_sync(&path) + }) + .await; + match join_handle { + Ok(result) => result, + Err(e) => { + tracing::error!("Error deleting entry. JoinError: {:?}", e); + Err(e.into()) + } + } + } + + /// Delete an entry including the associated file from the database. + pub fn delete_entry_sync(&mut self, path: &EntryPath) -> anyhow::Result { + let entry = match self.get_entry(path)? { + Some(entry) => entry, + None => return Ok(false), + }; + + let mut wtxn = self.env.write_txn()?; + let deleted = self.delete_file(&entry.file_id(), &mut wtxn)?; + if !deleted { + wtxn.abort(); + return Ok(false); + } + + let deleted = self.tables.entries.delete(&mut wtxn, path.as_str())?; + if !deleted { + wtxn.abort(); + return Ok(false); + } + + // create DELETE event + let url = format!("pubky://{}", path.as_str()); + + let event = Event::delete(&url); + let value = event.serialize(); + + let key = Timestamp::now().to_string(); + + self.tables.events.put(&mut wtxn, &key, &value)?; + + wtxn.commit()?; + Ok(true) + } + + /// Bytes stored at `path` for this user (0 if none). + pub fn get_entry_content_length(&self, path: &EntryPath) -> anyhow::Result { + let content_length = self + .get_entry(path)? + .map(|e| e.content_length() as u64) + .unwrap_or(0); + Ok(content_length) + } + + pub fn contains_directory(&self, txn: &RoTxn, entry_path: &EntryPath) -> anyhow::Result { + Ok(self + .tables + .entries + .get_greater_than(txn, entry_path.as_str())? 
+ .is_some()) + } + + /// Return a list of pubky urls. + /// + /// - limit defaults to [crate::config::DEFAULT_LIST_LIMIT] and capped by [crate::config::DEFAULT_MAX_LIST_LIMIT] + pub fn list_entries( + &self, + txn: &RoTxn, + entry_path: &EntryPath, + reverse: bool, + limit: Option, + cursor: Option, + shallow: bool, + ) -> anyhow::Result> { + // Vector to store results + let mut results = Vec::new(); + + let limit = limit + .unwrap_or(DEFAULT_LIST_LIMIT) + .min(DEFAULT_MAX_LIST_LIMIT); + + // TODO: make this more performant than split and allocations? + + let mut threshold = cursor + .map(|cursor| { + // Removing leading forward slashes + let mut file_or_directory = cursor.trim_start_matches('/'); + + if cursor.starts_with("pubky://") { + file_or_directory = cursor + .split(entry_path.as_str()) + .last() + .expect("should not be reachable") + }; + + next_threshold( + entry_path.as_str(), + file_or_directory, + file_or_directory.ends_with('/'), + reverse, + shallow, + ) + }) + .unwrap_or(next_threshold( + entry_path.as_str(), + "", + false, + reverse, + shallow, + )); + + for _ in 0..limit { + if let Some((key, _)) = if reverse { + self.tables.entries.get_lower_than(txn, &threshold)? + } else { + self.tables.entries.get_greater_than(txn, &threshold)? + } { + if !key.starts_with(entry_path.as_str()) { + break; + } + + if shallow { + let mut split = key[entry_path.as_str().len()..].split('/'); + let file_or_directory = split.next().expect("should not be reachable"); + + let is_directory = split.next().is_some(); + + threshold = next_threshold( + entry_path.as_str(), + file_or_directory, + is_directory, + reverse, + shallow, + ); + + results.push(format!( + "pubky://{}{file_or_directory}{}", + entry_path.as_str(), + if is_directory { "/" } else { "" } + )); + } else { + threshold = key.to_string(); + results.push(format!("pubky://{}", key)) + } + }; + } + + Ok(results) + } +} + +/// Calculate the next threshold +#[instrument] +fn next_threshold( + path: &str, + file_or_directory: &str, + is_directory: bool, + reverse: bool, + shallow: bool, +) -> String { + format!( + "{path}{file_or_directory}{}", + if file_or_directory.is_empty() { + // No file_or_directory, early return + if reverse { + // `path/to/dir/\x7f` to catch all paths than `path/to/dir/` + "\x7f" + } else { + "" + } + } else if shallow & is_directory { + if reverse { + // threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/` + "\x2e" + } else { + //threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/` + "\x7f" + } + } else { + "" + } + ) +} + +#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct Entry { + /// Encoding version + version: usize, + /// Modified at + timestamp: Timestamp, + content_hash: EntryHash, + content_length: usize, + content_type: String, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +struct EntryHash(Hash); + +impl Default for EntryHash { + fn default() -> Self { + Self(Hash::from_bytes([0; 32])) + } +} + +impl Serialize for EntryHash { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let bytes = self.0.as_bytes(); + bytes.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for EntryHash { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?; + Ok(Self(Hash::from_bytes(bytes))) + } +} + +impl Entry { + pub fn new() -> Self { + Default::default() + } + + pub fn chunk_key(&self, chunk_index: u32) -> [u8; 12] { + let mut 
chunk_key = [0; 12]; + chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes()); + chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes()); + chunk_key + } + + // === Setters === + + pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self { + self.timestamp = *timestamp; + self + } + + pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self { + EntryHash(content_hash).clone_into(&mut self.content_hash); + self + } + + pub fn set_content_length(&mut self, content_length: usize) -> &mut Self { + self.content_length = content_length; + self + } + + // === Getters === + + pub fn timestamp(&self) -> &Timestamp { + &self.timestamp + } + + pub fn content_hash(&self) -> &Hash { + &self.content_hash.0 + } + + pub fn content_length(&self) -> usize { + self.content_length + } + + pub fn content_type(&self) -> &str { + &self.content_type + } + + pub fn file_id(&self) -> InDbFileId { + InDbFileId::from(self.timestamp) + } + + // === Public Method === + + pub fn serialize(&self) -> Vec { + to_allocvec(self).expect("Session::serialize") + } + + pub fn deserialize(bytes: &[u8]) -> core::result::Result { + if bytes[0] > 0 { + panic!("Unknown Entry version"); + } + + from_bytes(bytes) + } +} + +#[cfg(test)] +mod tests { + use super::LmDB; + use crate::{ + persistence::lmdb::tables::files::{InDbTempFile, SyncInDbTempFileWriter}, + shared::webdav::{EntryPath, WebDavPath}, + }; + use bytes::Bytes; + use pkarr::Keypair; + use std::io::Read; + + #[tokio::test] + async fn test_write_read_delete_method() { + let mut db = LmDB::test(); + + let path = EntryPath::new( + Keypair::random().public_key(), + WebDavPath::new("/pub/foo.txt").unwrap(), + ); + let file = InDbTempFile::zeros(5).await.unwrap(); + let entry = db.write_entry_sync(&path, &file).unwrap(); + + let read_entry = db.get_entry(&path).unwrap().expect("Entry doesn't exist"); + assert_eq!(entry.content_hash(), read_entry.content_hash()); + assert_eq!(entry.content_length(), read_entry.content_length()); + assert_eq!(entry.timestamp(), read_entry.timestamp()); + + let read_file = db.read_file(&entry.file_id()).await.unwrap(); + let mut file_handle = read_file.open_file_handle().unwrap(); + let mut content = vec![]; + file_handle.read_to_end(&mut content).unwrap(); + assert_eq!(content, vec![0, 0, 0, 0, 0]); + + let deleted = db.delete_entry_sync(&path).unwrap(); + assert!(deleted); + + // Verify the entry and file are deleted + let read_entry = db.get_entry(&path).unwrap(); + assert!(read_entry.is_none()); + let read_file = db.read_file(&entry.file_id()).await.unwrap(); + assert_eq!(read_file.len(), 0); + } + + #[tokio::test] + async fn entries() -> anyhow::Result<()> { + let mut db = LmDB::test(); + + let keypair = Keypair::random(); + let public_key = keypair.public_key(); + let path = "/pub/foo.txt"; + + let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap()); + let chunk = Bytes::from(vec![1, 2, 3, 4, 5]); + let mut writer = SyncInDbTempFileWriter::new()?; + writer.write_chunk(&chunk)?; + let file = writer.complete()?; + + db.write_entry_sync(&entry_path, &file)?; + + let entry = db.get_entry(&entry_path).unwrap().unwrap(); + + assert_eq!( + entry.content_hash(), + &[ + 2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117, + 197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87 + ] + ); + + let read_file = db.read_file(&entry.file_id()).await.unwrap(); + let mut file_handle = read_file.open_file_handle().unwrap(); + let mut content = vec![]; + 
file_handle.read_to_end(&mut content).unwrap(); + assert_eq!(content, vec![1, 2, 3, 4, 5]); + Ok(()) + } + + #[tokio::test] + async fn chunked_entry() -> anyhow::Result<()> { + let mut db = LmDB::test(); + + let keypair = Keypair::random(); + let public_key = keypair.public_key(); + let path = "/pub/foo.txt"; + let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap()); + + let chunk = Bytes::from(vec![0; 1024 * 1024]); + + let mut writer = SyncInDbTempFileWriter::new()?; + writer.write_chunk(&chunk)?; + let file = writer.complete()?; + + db.write_entry_sync(&entry_path, &file)?; + + let entry = db.get_entry(&entry_path).unwrap().unwrap(); + + assert_eq!( + entry.content_hash(), + &[ + 72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109, + 134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168 + ] + ); + + let read_file = db.read_file(&entry.file_id()).await.unwrap(); + let mut file_handle = read_file.open_file_handle().unwrap(); + let mut content = vec![]; + file_handle.read_to_end(&mut content).unwrap(); + assert_eq!(content, vec![0; 1024 * 1024]); + + Ok(()) + } +} diff --git a/pubky-homeserver/src/persistence/lmdb/tables/files/in_db_file.rs b/pubky-homeserver/src/persistence/lmdb/tables/files/in_db_file.rs new file mode 100644 index 0000000..90ea267 --- /dev/null +++ b/pubky-homeserver/src/persistence/lmdb/tables/files/in_db_file.rs @@ -0,0 +1,236 @@ +//! +//! InDbFile is an abstraction over the way we store blobs/chunks in LMDB. +//! Because the value size of LMDB is limited, we need to store multiple blobs for one file. +//! +//! - `InDbFileId` is the identifier of a file that consists of multiple blobs. +//! - `InDbTempFile` is a helper to read/write a file to/from disk. +//! +use pubky_common::crypto::{Hash, Hasher}; +use pubky_common::timestamp::Timestamp; + +/// A file identifier for a file stored in LMDB. +/// The identifier is basically the timestamp of the file. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct InDbFileId(Timestamp); + +impl InDbFileId { + pub fn new() -> Self { + Self(Timestamp::now()) + } + + pub fn timestamp(&self) -> &Timestamp { + &self.0 + } + + pub fn bytes(&self) -> [u8; 8] { + self.0.to_bytes() + } + + /// Create a blob key from a timestamp and a blob index. + /// blob key = (timestamp | blob_index) => bytes. + /// Max file size is 2^32 blobs. + pub fn get_blob_key(&self, blob_index: u32) -> [u8; 12] { + let mut blob_key = [0; 12]; + blob_key[0..8].copy_from_slice(&self.bytes()); + blob_key[8..].copy_from_slice(&blob_index.to_be_bytes()); + blob_key + } +} + +impl From for InDbFileId { + fn from(timestamp: Timestamp) -> Self { + Self(timestamp) + } +} + +impl Default for InDbFileId { + fn default() -> Self { + Self::new() + } +} + +use std::sync::Arc; +use std::{fs::File, io::Write, path::PathBuf}; +use tokio::fs::File as AsyncFile; +use tokio::io::AsyncWriteExt; +use tokio::task; + +/// Writes a temp file to disk asynchronously. +#[derive(Debug)] +pub(crate) struct AsyncInDbTempFileWriter { + // Temp dir is automatically deleted when the EntryTempFile is dropped. 
+    #[allow(dead_code)]
+    dir: tempfile::TempDir,
+    writer_file: AsyncFile,
+    file_path: PathBuf,
+    hasher: Hasher,
+}
+
+impl AsyncInDbTempFileWriter {
+    pub async fn new() -> Result<Self, std::io::Error> {
+        let dir = task::spawn_blocking(tempfile::tempdir)
+            .await
+            .map_err(|join_error| {
+                std::io::Error::other(format!(
+                    "Task join error for tempdir creation: {}",
+                    join_error
+                ))
+            })??; // Handles the Result from tempfile::tempdir()
+
+        let file_path = dir.path().join("entry.bin");
+        let writer_file = AsyncFile::create(file_path.clone()).await?;
+        let hasher = Hasher::new();
+
+        Ok(Self {
+            dir,
+            writer_file,
+            file_path,
+            hasher,
+        })
+    }
+
+    /// Create a new InDbTempFile filled with `size_bytes` zero bytes.
+    /// Convenient method used for testing.
+    #[cfg(test)]
+    pub async fn zeros(size_bytes: usize) -> Result<InDbTempFile, std::io::Error> {
+        let mut file = Self::new().await?;
+        let buffer = vec![0u8; size_bytes];
+        file.write_chunk(&buffer).await?;
+        file.complete().await
+    }
+
+    /// Write a chunk to the file.
+    /// Chunk writing is done by the axum body stream and by LMDB itself.
+    pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
+        self.writer_file.write_all(chunk).await?;
+        self.hasher.update(chunk);
+        Ok(())
+    }
+
+    /// Flush the file to disk.
+    /// This completes the writing of the file.
+    /// Returns an InDbTempFile that can be used to read the file.
+    pub async fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
+        self.writer_file.flush().await?;
+        let hash = self.hasher.finalize();
+        let file_size = self.writer_file.metadata().await?.len();
+        Ok(InDbTempFile {
+            dir: Arc::new(self.dir),
+            file_path: self.file_path,
+            file_size: file_size as usize,
+            file_hash: hash,
+        })
+    }
+}
+
+/// Writes a temp file to disk synchronously.
+#[derive(Debug)]
+pub(crate) struct SyncInDbTempFileWriter {
+    // Temp dir is automatically deleted when the InDbTempFile is dropped.
+    #[allow(dead_code)]
+    dir: tempfile::TempDir,
+    writer_file: File,
+    file_path: PathBuf,
+    hasher: Hasher,
+}
+
+impl SyncInDbTempFileWriter {
+    pub fn new() -> Result<Self, std::io::Error> {
+        let dir = tempfile::tempdir()?;
+        let file_path = dir.path().join("entry.bin");
+        let writer_file = File::create(file_path.clone())?;
+        let hasher = Hasher::new();
+
+        Ok(Self {
+            dir,
+            writer_file,
+            file_path,
+            hasher,
+        })
+    }
+
+    /// Write a chunk to the file.
+    pub fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
+        self.writer_file.write_all(chunk)?;
+        self.hasher.update(chunk);
+        Ok(())
+    }
+
+    /// Flush the file to disk.
+    /// This completes the writing of the file.
+    /// Returns an InDbTempFile that can be used to read the file.
+    pub fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
+        self.writer_file.flush()?;
+        let hash = self.hasher.finalize();
+        let file_size = self.writer_file.metadata()?.len();
+        Ok(InDbTempFile {
+            dir: Arc::new(self.dir),
+            file_path: self.file_path,
+            file_size: file_size as usize,
+            file_hash: hash,
+        })
+    }
+}
+
+/// A temporary file helper for Entry.
+///
+/// Every file in LMDB is first written to disk before being written to LMDB.
+/// The same is true if you read a file from LMDB.
+///
+/// This is to keep the LMDB transaction small and fast.
+///
+/// As soon as the InDbTempFile is dropped, the file on disk is deleted.
+///
+#[derive(Debug, Clone)]
+pub struct InDbTempFile {
+    // Temp dir is automatically deleted when the InDbTempFile is dropped.
+    #[allow(dead_code)]
+    dir: Arc<tempfile::TempDir>,
+    file_path: PathBuf,
+    file_size: usize,
+    file_hash: Hash,
+}
+
+impl InDbTempFile {
+    /// Create a new InDbTempFile filled with `size_bytes` zero bytes.
+/// A temporary file helper for an `Entry`. +/// +/// Every file is first written to disk before being written to LMDB. +/// The same is true if you read a file from LMDB. +/// +/// This is done to keep LMDB transactions small and fast. +/// +/// As soon as the InDbTempFile is dropped, the file on disk is deleted. +/// +#[derive(Debug, Clone)] +pub struct InDbTempFile { + // Temp dir is automatically deleted when the last clone of the InDbTempFile is dropped. + #[allow(dead_code)] + dir: Arc<tempfile::TempDir>, + file_path: PathBuf, + file_size: usize, + file_hash: Hash, +} + +impl InDbTempFile { + /// Create a new InDbTempFile filled with zeros. + /// Convenience method used for testing. + #[cfg(test)] + pub async fn zeros(size_bytes: usize) -> Result<Self, std::io::Error> { + AsyncInDbTempFileWriter::zeros(size_bytes).await + } + + /// Create a new, empty (zero-length) InDbTempFile. + pub fn empty() -> Result<Self, std::io::Error> { + let dir = tempfile::tempdir()?; + let file_path = dir.path().join("entry.bin"); + std::fs::File::create(file_path.clone())?; + let file_size = 0; + let hasher = Hasher::new(); + let file_hash = hasher.finalize(); + Ok(Self { + dir: Arc::new(dir), + file_path, + file_size, + file_hash, + }) + } + + pub fn len(&self) -> usize { + self.file_size + } + + pub fn hash(&self) -> &Hash { + &self.file_hash + } + + /// Get the path of the file on disk. + #[cfg(test)] + pub fn path(&self) -> &PathBuf { + &self.file_path + } + + /// Open the file on disk. + pub fn open_file_handle(&self) -> Result<File, std::io::Error> { + File::open(self.file_path.as_path()) + } +} diff --git a/pubky-homeserver/src/persistence/lmdb/tables/files/mod.rs b/pubky-homeserver/src/persistence/lmdb/tables/files/mod.rs new file mode 100644 index 0000000..63e6fde --- /dev/null +++ b/pubky-homeserver/src/persistence/lmdb/tables/files/mod.rs @@ -0,0 +1,7 @@ +mod blobs; +mod entries; +mod in_db_file; + +pub use blobs::{BlobsTable, BLOBS_TABLE}; +pub use entries::{EntriesTable, Entry, ENTRIES_TABLE}; +pub use in_db_file::*; diff --git a/pubky-homeserver/src/persistence/lmdb/tables/users.rs b/pubky-homeserver/src/persistence/lmdb/tables/users.rs index 7ded66d..cbc3f25 100644 --- a/pubky-homeserver/src/persistence/lmdb/tables/users.rs +++ b/pubky-homeserver/src/persistence/lmdb/tables/users.rs @@ -157,6 +157,7 @@ impl LmDB { /// # Errors /// /// - `UserQueryError::DatabaseError` if the database operation fails. + #[cfg(test)] pub fn create_user(&self, pubkey: &PublicKey, wtxn: &mut RwTxn) -> anyhow::Result<()> { let user = User::default(); self.tables.users.put(wtxn, pubkey, &user)?; diff --git a/pubky-homeserver/src/shared/mod.rs b/pubky-homeserver/src/shared/mod.rs index 550de7e..cf02ed6 100644 --- a/pubky-homeserver/src/shared/mod.rs +++ b/pubky-homeserver/src/shared/mod.rs @@ -1,5 +1,6 @@ mod http_error; mod pubkey_path_validator; +pub(crate) mod webdav; pub(crate) use http_error::{HttpError, HttpResult}; pub(crate) use pubkey_path_validator::Z32Pubkey; diff --git a/pubky-homeserver/src/shared/webdav/entry_path.rs b/pubky-homeserver/src/shared/webdav/entry_path.rs new file mode 100644 index 0000000..a7e5a09 --- /dev/null +++ b/pubky-homeserver/src/shared/webdav/entry_path.rs @@ -0,0 +1,124 @@ +use pkarr::PublicKey; +use std::str::FromStr; + +use super::WebDavPath; + +#[derive(thiserror::Error, Debug)] +pub enum EntryPathError { + #[error("{0}")] + Invalid(String), + #[error("Failed to parse WebDAV path: {0}")] + InvalidWebdavPath(anyhow::Error), + #[error("Failed to parse pubkey: {0}")] + InvalidPubkey(pkarr::errors::PublicKeyError), +}
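// Hypothetical usage sketch (illustrative only): an entry key is split at the first '/';
// everything before it must parse as a pkarr `PublicKey`, the rest as a `/pub/` WebDAV path.
fn _entry_path_sketch() -> Result<(), EntryPathError> {
    let ep: EntryPath =
        "8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt".parse()?;
    assert_eq!(ep.path().as_str(), "/pub/folder/file.txt");
    assert_eq!(ep.as_str(), ep.to_string());
    Ok(())
}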
+/// A path to an entry. +/// +/// The path as a string is used to identify the entry. +#[derive(Debug, Clone)] +pub struct EntryPath { + #[allow(dead_code)] + pubkey: PublicKey, + path: WebDavPath, + /// The key of the entry represented as a string. + /// The key is the pubkey and the path concatenated. + /// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt` + /// This is cached/redundant to avoid reallocating the string on every access. + key: String, +} + +impl EntryPath { + pub fn new(pubkey: PublicKey, path: WebDavPath) -> Self { + let key = format!("{}{}", pubkey, path); + Self { pubkey, path, key } + } + + #[allow(dead_code)] + pub fn pubkey(&self) -> &PublicKey { + &self.pubkey + } + + pub fn path(&self) -> &WebDavPath { + &self.path + } + + /// The key of the entry. + /// + /// The key is the pubkey and the path concatenated. + /// + /// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt` + pub fn as_str(&self) -> &str { + &self.key + } +} + +impl AsRef<str> for EntryPath { + fn as_ref(&self) -> &str { + &self.key + } +} + +impl FromStr for EntryPath { + type Err = EntryPathError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let first_slash_index = s + .find('/') + .ok_or(EntryPathError::Invalid("Missing '/'".to_string()))?; + let (pubkey, path) = match s.split_at_checked(first_slash_index) { + Some((pubkey, path)) => (pubkey, path), + None => return Err(EntryPathError::Invalid("Missing '/'".to_string())), + }; + let pubkey = PublicKey::from_str(pubkey).map_err(EntryPathError::InvalidPubkey)?; + let webdav_path = WebDavPath::new(path).map_err(EntryPathError::InvalidWebdavPath)?; + Ok(Self::new(pubkey, webdav_path)) + } +} + +impl std::fmt::Display for EntryPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_ref()) + } +} + +impl serde::Serialize for EntryPath { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_ref()) + } +} + +impl<'de> serde::Deserialize<'de> for EntryPath { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_entry_path_from_str() { + let pubkey = + PublicKey::from_str("8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo").unwrap(); + let path = WebDavPath::new("/pub/folder/file.txt").unwrap(); + let key = format!("{pubkey}{path}"); + let entry_path = EntryPath::new(pubkey, path); + assert_eq!(entry_path.as_ref(), key); + } + + #[test] + fn test_entry_path_serde() { + let string = "8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt"; + let entry_path = EntryPath::from_str(string).unwrap(); + assert_eq!(entry_path.to_string(), string); + } +} diff --git a/pubky-homeserver/src/shared/webdav/mod.rs b/pubky-homeserver/src/shared/webdav/mod.rs new file mode 100644 index 0000000..48316b7 --- /dev/null +++ b/pubky-homeserver/src/shared/webdav/mod.rs @@ -0,0 +1,7 @@ +mod entry_path; +mod webdav_path; +mod webdav_path_axum; + +pub use entry_path::EntryPath; +pub use webdav_path::WebDavPath; +pub use webdav_path_axum::WebDavPathAxum; diff --git a/pubky-homeserver/src/shared/webdav/webdav_path.rs b/pubky-homeserver/src/shared/webdav/webdav_path.rs new file mode 100644 index 0000000..b91ef15 --- /dev/null +++ b/pubky-homeserver/src/shared/webdav/webdav_path.rs @@ -0,0 +1,343 @@ +use std::str::FromStr; + +/// A normalized and validated WebDAV path. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct WebDavPath { + normalized_path: String, +}
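// Hypothetical usage sketch (illustrative only): `new` normalizes and validates the raw
// path, while `new_unchecked` (below) trusts the caller to pass an already normalized one.
fn _webdav_path_sketch() -> anyhow::Result<()> {
    let path = WebDavPath::new("/pub/a//b/../c.txt")?;
    assert_eq!(path.as_str(), "/pub/a/c.txt"); // dot segments and duplicate slashes removed
    assert!(path.is_file());
    assert!(WebDavPath::new("/private/c.txt").is_err()); // only /pub/ paths are accepted
    Ok(())
}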
+impl WebDavPath { + /// Create a new WebDavPath from an already normalized path. + /// Make sure the path is fully normalized and valid before using this constructor. + /// + /// Use `WebDavPath::new` to create a new WebDavPath from an unnormalized path. + pub fn new_unchecked(normalized_path: String) -> Self { + Self { normalized_path } + } + + /// Create a new WebDavPath from an unnormalized path. + /// + /// The path will be normalized and validated. + pub fn new(unnormalized_path: &str) -> anyhow::Result<Self> { + let normalized_path = normalize_and_validate_webdav_path(unnormalized_path)?; + if !normalized_path.starts_with("/pub/") { + return Err(anyhow::anyhow!("Path must start with /pub/")); + } + Ok(Self::new_unchecked(normalized_path)) + } + + #[allow(dead_code)] + pub fn url_encode(&self) -> String { + percent_encoding::utf8_percent_encode(self.normalized_path.as_str(), PATH_ENCODE_SET) + .to_string() + } + + pub fn as_str(&self) -> &str { + self.normalized_path.as_str() + } + + pub fn is_directory(&self) -> bool { + self.normalized_path.ends_with('/') + } + + /// Check if the path is a file. + #[allow(dead_code)] + pub fn is_file(&self) -> bool { + !self.is_directory() + } +} + +impl std::fmt::Display for WebDavPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.normalized_path) + } +} + +impl FromStr for WebDavPath { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Self::new(s) + } +} + +// Encode all non-unreserved characters, except '/'. +// See RFC 3986 and https://en.wikipedia.org/wiki/Percent-encoding . +const PATH_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::NON_ALPHANUMERIC + .remove(b'-') + .remove(b'_') + .remove(b'.') + .remove(b'~') + .remove(b'/'); + +/// Maximum length of a single path segment. +const MAX_WEBDAV_PATH_SEGMENT_LENGTH: usize = 255; +/// Maximum total length of a normalized WebDAV path. +const MAX_WEBDAV_PATH_TOTAL_LENGTH: usize = 4096; + +/// Takes a path, normalizes it, and validates it. +/// Make sure to URL-decode the path before calling this function. +/// Inspired by https://github.com/messense/dav-server-rs/blob/740dae05ac2eeda8e2ea11fface3ab6d53b6705e/src/davpath.rs#L101 +fn normalize_and_validate_webdav_path(path: &str) -> anyhow::Result<String> { + // Ensure the path starts with '/' + if !path.starts_with('/') { + return Err(anyhow::anyhow!("Path must start with '/'")); + } + + let is_dir = path.ends_with('/') || path.ends_with(".."); + // Split the path into segments, filtering out empty ones + let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); + // Build the normalized path + let mut normalized_segments = vec![]; + + for segment in segments { + // Check for segment length + if segment.len() > MAX_WEBDAV_PATH_SEGMENT_LENGTH { + return Err(anyhow::anyhow!( + "Invalid path: Segment exceeds maximum length of {} characters. Segment: '{}'", + MAX_WEBDAV_PATH_SEGMENT_LENGTH, + segment + )); + } + + // Check for any ASCII control characters in the decoded segment + if segment.chars().any(|c| c.is_control()) { + return Err(anyhow::anyhow!( + "Invalid path: ASCII control characters are not allowed in segments" + )); + } + + if segment == "." { + continue; + } else if segment == ".."
{ + if normalized_segments.len() < 2 { + return Err(anyhow::anyhow!("Failed to normalize path: '..'.")); + } + normalized_segments.pop(); + normalized_segments.pop(); + } else { + normalized_segments.push("/".to_string()); + normalized_segments.push(segment.to_string()); + } + } + + if is_dir { + normalized_segments.push("/".to_string()); + } + let full_path = normalized_segments.join(""); + + // Check for total path length + if full_path.len() > MAX_WEBDAV_PATH_TOTAL_LENGTH { + return Err(anyhow::anyhow!( + "Invalid path: Total path length exceeds maximum of {} characters. Length: {}, Path: '{}'", + MAX_WEBDAV_PATH_TOTAL_LENGTH, + full_path.len(), + full_path + )); + } + + Ok(full_path) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_valid_path(path: &str, expected: &str) { + match normalize_and_validate_webdav_path(path) { + Ok(path) => { + assert_eq!(path, expected); + } + Err(e) => { + assert!( + false, + "Path '{path}' is invalid. Should be '{expected}'. Error: {e}" + ); + } + }; + } + + fn assert_invalid_path(path: &str) { + if let Ok(normalized_path) = normalize_and_validate_webdav_path(path) { + assert!( + false, + "Invalid path '{path}' is valid. Normalized result: '{normalized_path}'" + ); + } + } + + #[test] + fn test_slash_is_valid() { + assert_valid_path("/", "/"); + } + + #[test] + fn test_two_dots_is_valid() { + assert_valid_path("/test/..", "/"); + } + + #[test] + fn test_two_dots_in_the_middle_is_valid() { + assert_valid_path("/test/../test", "/test"); + } + + #[test] + fn test_two_dots_in_the_middle_with_slash_is_valid() { + assert_valid_path("/test/../test/", "/test/"); + } + + #[test] + fn test_two_dots_invalid() { + assert_invalid_path("/.."); + } + + #[test] + fn test_two_dots_twice_invalid() { + assert_invalid_path("/test/../.."); + } + + #[test] + fn test_two_slashes_is_valid() { + assert_valid_path("//", "/"); + } + + #[test] + fn test_two_slashes_in_the_middle_is_valid() { + assert_valid_path("/test//test", "/test/test"); + } + + #[test] + fn test_one_segment_is_valid() { + assert_valid_path("/test", "/test"); + } + + #[test] + fn test_one_segment_with_trailing_slash_is_valid() { + assert_valid_path("/test/", "/test/"); + } + + #[test] + fn test_two_segments_is_valid() { + assert_valid_path("/test/test", "/test/test"); + } + + #[test] + fn test_wildcard_is_valid() { + assert_valid_path("/dav/file*.txt", "/dav/file*.txt"); + } + + #[test] + fn test_two_slashes_in_the_middle_with_slash_is_valid() { + assert_valid_path("/dav//folder/", "/dav/folder/"); + } + + #[test] + fn test_script_tag_is_valid() { + assert_valid_path("/dav/