fix(homeserver): DDoS due to slow uploads or downloads (#132)

* switch branch

* moved files around

* validate webdav path in requests

* async and sync writer

* implemented file read and write via disk

* cleaning up

* clippy and fmt

* cleanup 1

* last cleanup

* fmt

* test concurrent read/write

* cleanup3

* cleanup3

* added log in case of a join error

* removed pub from max_chunk_size()

* fmt and clippy

* fixed test

* fmt

* fixed main merge errors
This commit is contained in:
Severin Alexander Bühler
2025-05-25 10:51:25 +03:00
committed by GitHub
parent a485d8c2f4
commit 178bae142e
24 changed files with 1714 additions and 769 deletions
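In rough outline, uploads and downloads are now buffered through temporary files on disk: a PUT streams the request body into an AsyncInDbTempFileWriter and only then copies the finished file into LMDB in one short write transaction, while a GET first copies the blob out of LMDB into an InDbTempFile and streams it to the client via tokio-util's ReaderStream. A slow client therefore never holds an LMDB transaction open for the duration of the transfer. A condensed sketch of the upload side (names taken from this diff; quota checks and error mapping omitted):

// Sketch only, not part of the diff: the shape of the new two-phase upload.
async fn upload_sketch(
    state: &mut AppState,
    entry_path: &EntryPath,
    body: axum::body::Body,
) -> anyhow::Result<()> {
    // Phase 1: stream the (possibly very slow) body to a temp file on disk.
    let mut writer = AsyncInDbTempFileWriter::new().await?;
    let mut stream = body.into_data_stream();
    while let Some(chunk) = stream.next().await.transpose()? {
        writer.write_chunk(&chunk).await?;
    }
    let buffer_file = writer.complete().await?;
    // Phase 2: one short LMDB write transaction (run on a blocking thread)
    // copies the finished file into the blobs table and stores the Entry.
    state.db.write_entry(entry_path, &buffer_file).await?;
    Ok(())
}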

Cargo.lock generated
View File

@@ -2468,6 +2468,7 @@ dependencies = [
"hostname-validator",
"httpdate",
"page_size",
"percent-encoding",
"pkarr",
"pkarr-republisher",
"postcard",
@@ -2479,6 +2480,7 @@ dependencies = [
"tempfile",
"thiserror 2.0.12",
"tokio",
"tokio-util",
"toml",
"tower 0.5.2",
"tower-cookies",
@@ -3557,9 +3559,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.13"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",

View File

@@ -80,7 +80,11 @@ async fn disabled_user() {
// Make sure the user can read their own file
let response = client.get(file_url.clone()).send().await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.status(),
StatusCode::OK,
"User should be able to read their own file"
);
let admin_socket = server.admin().listen_socket();
let admin_client = reqwest::Client::new();

View File

@@ -1,7 +1,8 @@
use std::time::Duration;
use pkarr::Keypair;
use pkarr::{Keypair, PublicKey};
use pubky_testnet::{
pubky::Client,
pubky_homeserver::{
quota_config::{GlobPattern, LimitKey, LimitKeyType, PathLimit},
ConfigToml, MockDataDir,
@@ -186,3 +187,127 @@ async fn test_limit_upload() {
assert_eq!(res.status(), StatusCode::CREATED);
assert!(start.elapsed() > Duration::from_secs(2));
}
/// Test that 10 clients can write/read to the server concurrently
/// Upload/download rate is limited to 1kb/s per user.
/// 3kb files are used to make the writes/reads take ~2.5s each.
/// Concurrently writing/reading 10 files, the total time taken should be ~3s.
/// If the concurrent writes/reads are not properly handled, the total time taken will be closer to ~25s.
#[tokio::test]
async fn test_concurrent_write_read() {
// Setup the testnet
let mut testnet = Testnet::new().await.unwrap();
let mut config = ConfigToml::test();
config.drive.rate_limits = vec![
PathLimit::new(
// Limit uploads to 1kb/s per user
GlobPattern::new("/pub/**"),
Method::PUT,
"1kb/s".parse().unwrap(),
LimitKeyType::User,
None,
),
PathLimit::new(
// Limit downloads to 1kb/s per user
GlobPattern::new("/pub/**"),
Method::GET,
"1kb/s".parse().unwrap(),
LimitKeyType::User,
None,
),
];
let mock_dir = MockDataDir::new(config, None).unwrap();
let hs_pubkey = {
let server = testnet
.create_homeserver_suite_with_mock(mock_dir)
.await
.unwrap();
server.public_key()
};
// Create helper struct to handle clients
#[derive(Clone)]
struct TestClient {
pub keypair: Keypair,
pub client: Client,
}
impl TestClient {
fn new(testnet: &mut Testnet) -> Self {
let keypair = Keypair::random();
let client = testnet.pubky_client_builder().build().unwrap();
Self { keypair, client }
}
pub async fn signup(&self, hs_pubkey: &PublicKey) {
self.client
.signup(&self.keypair, &hs_pubkey, None)
.await
.expect("Failed to signup");
}
pub async fn put(&self, url: Url, body: Vec<u8>) {
self.client
.put(url)
.body(body)
.send()
.await
.expect("Failed to put");
}
pub async fn get(&self, url: Url) {
let response = self.client.get(url).send().await.expect("Failed to get");
assert_eq!(response.status(), StatusCode::OK, "Failed to get");
response.bytes().await.expect("Failed to get bytes"); // Download the body
}
}
// Signup with the clients
let user_count: usize = 10;
let mut clients = vec![0; user_count]
.into_iter()
.map(|_| TestClient::new(&mut testnet))
.collect::<Vec<_>>();
for client in clients.iter_mut() {
client.signup(&hs_pubkey).await;
}
// --------------------------------------------------------------------------------------------
// Write to server concurrently
let start = Instant::now();
let mut handles = vec![];
for client in clients.iter() {
let client = client.clone();
let handle = tokio::spawn(async move {
let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
.parse()
.unwrap();
let body = vec![0u8; 3 * 1024]; // 3kb
client.put(url, body).await;
});
handles.push(handle);
}
// Wait for all the writes to finish
for handle in handles {
handle.await.unwrap();
}
let elapsed = start.elapsed();
assert!(elapsed < Duration::from_secs(5));
// --------------------------------------------------------------------------------------------
// Read from server concurrently
let start = Instant::now();
let mut handles = vec![];
for client in clients.iter() {
let client = client.clone();
let handle = tokio::spawn(async move {
let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
.parse()
.unwrap();
client.get(url).await;
});
handles.push(handle);
}
// Wait for all the reads to finish
for handle in handles {
handle.await.unwrap();
}
let elapsed = start.elapsed();
assert!(elapsed < Duration::from_secs(5));
}

View File

@@ -55,6 +55,8 @@ dyn-clone = "1.0.19"
reqwest = "0.12.15"
governor = "0.10.0"
fast-glob = "0.4.5"
tokio-util = "0.7.15"
percent-encoding = "2.3.1"
serde_valid = "1.0.5"

View File

@@ -25,10 +25,7 @@ fn create_protected_router(password: &str) -> Router<AppState> {
get(generate_signup_token::generate_signup_token),
)
.route("/info", get(info::info))
.route(
"/webdav/{pubkey}/{*path}",
delete(delete_entry::delete_entry),
)
.route("/webdav/{*entry_path}", delete(delete_entry::delete_entry))
.route("/users/{pubkey}/disable", post(disable_user))
.route("/users/{pubkey}/enable", post(enable_user))
.layer(AdminAuthLayer::new(password.to_string()))

View File

@@ -1,5 +1,5 @@
use super::super::app_state::AppState;
use crate::shared::{HttpError, HttpResult, Z32Pubkey};
use crate::shared::{webdav::EntryPath, HttpError, HttpResult};
use axum::{
extract::{Path, State},
http::StatusCode,
@@ -9,16 +9,9 @@ use axum::{
/// Delete a single entry from the database.
pub async fn delete_entry(
State(mut state): State<AppState>,
Path((pubkey, path)): Path<(Z32Pubkey, String)>,
Path(entry_path): Path<EntryPath>,
) -> HttpResult<impl IntoResponse> {
let path = format!("/{}", path); // Add missing leading slash
if !path.starts_with("/pub/") {
return Err(HttpError::new(
StatusCode::BAD_REQUEST,
Some("Invalid path"),
));
}
let deleted = state.db.delete_entry(&pubkey.0, &path)?;
let deleted = state.db.delete_entry(&entry_path).await?;
if deleted {
Ok((StatusCode::NO_CONTENT, ()))
} else {
@@ -33,16 +26,14 @@ pub async fn delete_entry(
mod tests {
use super::super::super::app_state::AppState;
use super::*;
use crate::persistence::lmdb::LmDB;
use crate::persistence::lmdb::{tables::files::InDbTempFile, LmDB};
use crate::shared::webdav::{EntryPath, WebDavPath};
use axum::{routing::delete, Router};
use pkarr::{Keypair, PublicKey};
use std::io::Write;
use pkarr::Keypair;
async fn write_test_file(db: &mut LmDB, pubkey: &PublicKey, path: &str) {
let mut entry_writer = db.write_entry(pubkey, path).unwrap();
let content = b"Hello, world!";
entry_writer.write_all(content).unwrap();
let _entry = entry_writer.commit().unwrap();
async fn write_test_file(db: &mut LmDB, entry_path: &EntryPath) {
let file = InDbTempFile::zeros(10).await.unwrap();
let _entry = db.write_entry(&entry_path, &file).await.unwrap();
}
#[tokio::test]
@@ -54,25 +45,25 @@ mod tests {
let mut db = LmDB::test();
let app_state = AppState::new(db.clone());
let router = Router::new()
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
.route("/webdav/{*entry_path}", delete(delete_entry))
.with_state(app_state);
// Write a test file
let entry_path = format!("/pub/{}", file_path);
write_test_file(&mut db, &pubkey, &entry_path).await;
let webdav_path = WebDavPath::new(format!("/pub/{}", file_path).as_str()).unwrap();
let entry_path = EntryPath::new(pubkey.clone(), webdav_path);
write_test_file(&mut db, &entry_path).await;
// Delete the file
let server = axum_test::TestServer::new(router).unwrap();
let response = server
.delete(format!("/webdav/{}{}", pubkey, entry_path).as_str())
.delete(format!("/webdav/{}{}", pubkey, entry_path.path().as_str()).as_str())
.await;
assert_eq!(response.status_code(), StatusCode::NO_CONTENT);
// Check that the file is deleted
let rtx = db.env.read_txn().unwrap();
let entry = db.get_entry(&rtx, &pubkey, &file_path).unwrap();
let entry = db.get_entry(&entry_path).unwrap();
assert!(entry.is_none(), "Entry should be deleted");
rtx.commit().unwrap();
let events = db.list_events(None, None).unwrap();
assert_eq!(
@@ -92,14 +83,13 @@ mod tests {
let file_path = "my_file.txt";
let app_state = AppState::new(LmDB::test());
let router = Router::new()
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
.route("/webdav/{*entry_path}", delete(delete_entry))
.with_state(app_state);
// Delete the file
let url = format!("/webdav/{}/pub/{}", pubkey, file_path);
let server = axum_test::TestServer::new(router).unwrap();
let response = server
.delete(format!("/webdav/{}/pub/{}", pubkey, file_path).as_str())
.await;
let response = server.delete(url.as_str()).await;
assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
}
@@ -109,7 +99,7 @@ mod tests {
let db = LmDB::test();
let app_state = AppState::new(db.clone());
let router = Router::new()
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
.route("/webdav/{*entry_path}", delete(delete_entry))
.with_state(app_state);
// Delete with invalid pubkey

View File

@@ -3,11 +3,7 @@
//! Every route here is relative to a tenant's Pubky host,
//! as opposed to routes relative to the Homeserver's owner.
use axum::{
extract::DefaultBodyLimit,
routing::{delete, get, head, put},
Router,
};
use axum::{extract::DefaultBodyLimit, routing::get, Router};
use crate::core::{layers::authz::AuthorizationLayer, AppState};
@@ -17,16 +13,14 @@ pub mod write;
pub fn router(state: AppState) -> Router<AppState> {
Router::new()
// - Datastore routes
.route("/pub/", get(read::get))
.route("/pub/{*path}", get(read::get))
.route("/pub/{*path}", head(read::head))
.route("/pub/{*path}", put(write::put))
.route("/pub/{*path}", delete(write::delete))
// - Session routes
.route("/session", get(session::session))
.route("/session", delete(session::signout))
// Layers
.route("/session", get(session::session).delete(session::signout))
.route(
"/{*path}",
get(read::get)
.head(read::head)
.put(write::put)
.delete(write::delete),
)
// TODO: different max size for sessions and other routes?
.layer(DefaultBodyLimit::max(100 * 1024 * 1024))
.layer(AuthorizationLayer::new(state.clone()))

View File

@@ -1,90 +1,75 @@
use crate::persistence::lmdb::tables::files::Entry;
use crate::{
core::{
err_if_user_is_invalid::err_if_user_is_invalid,
error::{Error, Result},
extractors::{ListQueryParams, PubkyHost},
AppState,
},
shared::webdav::{EntryPath, WebDavPathAxum},
};
use axum::{
body::Body,
extract::{OriginalUri, State},
extract::{Path, State},
http::{header, HeaderMap, HeaderValue, Response, StatusCode},
response::IntoResponse,
};
use httpdate::HttpDate;
use pkarr::PublicKey;
use std::str::FromStr;
use crate::core::{
error::{Error, Result},
extractors::{ListQueryParams, PubkyHost},
AppState,
};
use crate::persistence::lmdb::tables::entries::Entry;
use tokio_util::io::ReaderStream;
pub async fn head(
State(state): State<AppState>,
pubky: PubkyHost,
headers: HeaderMap,
path: OriginalUri,
Path(path): Path<WebDavPathAxum>,
) -> Result<impl IntoResponse> {
let rtxn = state.db.env.read_txn()?;
get_entry(
headers,
state
.db
.get_entry(&rtxn, pubky.public_key(), path.0.path())?,
None,
)
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
let entry_path = EntryPath::new(pubky.public_key().clone(), path.0);
let entry = state
.db
.get_entry(&entry_path)?
.ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?;
get_entry(headers, entry, None)
}
#[axum::debug_handler]
pub async fn get(
State(state): State<AppState>,
headers: HeaderMap,
pubky: PubkyHost,
path: OriginalUri,
Path(path): Path<WebDavPathAxum>,
params: ListQueryParams,
) -> Result<impl IntoResponse> {
let public_key = pubky.public_key().clone();
let path = path.0.path().to_string();
if path.ends_with('/') {
return list(state, &public_key, &path, params);
let dav_path = path.0;
let entry_path = EntryPath::new(public_key.clone(), dav_path);
if entry_path.path().is_directory() {
return list(state, &entry_path, params);
}
let (entry_tx, entry_rx) = flume::bounded::<Option<Entry>>(1);
let (chunks_tx, chunks_rx) = flume::unbounded::<std::result::Result<Vec<u8>, heed::Error>>();
let entry = state
.db
.get_entry(&entry_path)?
.ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?;
let buffer_file = state.db.read_file(&entry.file_id()).await?;
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let rtxn = state.db.env.read_txn()?;
let option = state.db.get_entry(&rtxn, &public_key, &path)?;
if let Some(entry) = option {
let iter = entry.read_content(&state.db, &rtxn)?;
entry_tx.send(Some(entry))?;
for next in iter {
chunks_tx.send(next.map(|b| b.to_vec()))?;
}
};
entry_tx.send(None)?;
Ok(())
});
get_entry(
headers,
entry_rx.recv_async().await?,
Some(Body::from_stream(chunks_rx.into_stream())),
)
let file_handle = buffer_file.open_file_handle()?;
// Async stream the file
let tokio_file_handle = tokio::fs::File::from_std(file_handle);
let body_stream = Body::from_stream(ReaderStream::new(tokio_file_handle));
get_entry(headers, entry, Some(body_stream))
}
pub fn list(
state: AppState,
public_key: &PublicKey,
path: &str,
entry_path: &EntryPath,
params: ListQueryParams,
) -> Result<Response<Body>> {
let txn = state.db.env.read_txn()?;
let path = format!("{public_key}{path}");
if !state.db.contains_directory(&txn, &path)? {
if !state.db.contains_directory(&txn, entry_path)? {
return Err(Error::new(
StatusCode::NOT_FOUND,
"Directory Not Found".into(),
@@ -92,9 +77,9 @@ pub fn list(
}
// Handle listing
let vec = state.db.list(
let vec = state.db.list_entries(
&txn,
&path,
entry_path,
params.reverse,
params.limit,
params.cursor,
@@ -107,54 +92,44 @@ pub fn list(
.body(Body::from(vec.join("\n")))?)
}
pub fn get_entry(
headers: HeaderMap,
entry: Option<Entry>,
body: Option<Body>,
) -> Result<Response<Body>> {
if let Some(entry) = entry {
// TODO: Enable seek API (range requests)
// TODO: Gzip? or brotli?
pub fn get_entry(headers: HeaderMap, entry: Entry, body: Option<Body>) -> Result<Response<Body>> {
// TODO: Enable seek API (range requests)
// TODO: Gzip? or brotli?
let mut response = HeaderMap::from(&entry).into_response();
let mut response = HeaderMap::from(&entry).into_response();
// Handle IF_MODIFIED_SINCE
if let Some(condition_http_date) = headers
.get(header::IF_MODIFIED_SINCE)
.and_then(|h| h.to_str().ok())
.and_then(|s| HttpDate::from_str(s).ok())
{
let entry_http_date: HttpDate = entry.timestamp().to_owned().into();
// Handle IF_MODIFIED_SINCE
if let Some(condition_http_date) = headers
.get(header::IF_MODIFIED_SINCE)
.and_then(|h| h.to_str().ok())
.and_then(|s| HttpDate::from_str(s).ok())
{
let entry_http_date: HttpDate = entry.timestamp().to_owned().into();
if condition_http_date >= entry_http_date {
*response.status_mut() = StatusCode::NOT_MODIFIED;
}
};
// Handle IF_NONE_MATCH
if let Some(str) = headers
.get(header::IF_NONE_MATCH)
.and_then(|h| h.to_str().ok())
{
let etag = format!("\"{}\"", entry.content_hash());
if str
.trim()
.split(',')
.collect::<Vec<_>>()
.contains(&etag.as_str())
{
*response.status_mut() = StatusCode::NOT_MODIFIED;
};
if condition_http_date >= entry_http_date {
*response.status_mut() = StatusCode::NOT_MODIFIED;
}
};
if let Some(body) = body {
*response.body_mut() = body;
// Handle IF_NONE_MATCH
if let Some(str) = headers
.get(header::IF_NONE_MATCH)
.and_then(|h| h.to_str().ok())
{
let etag = format!("\"{}\"", entry.content_hash());
if str
.trim()
.split(',')
.collect::<Vec<_>>()
.contains(&etag.as_str())
{
*response.status_mut() = StatusCode::NOT_MODIFIED;
};
Ok(response)
} else {
Err(Error::with_status(StatusCode::NOT_FOUND))?
}
if let Some(body) = body {
*response.body_mut() = body;
}
Ok(response)
}
impl From<&Entry> for HeaderMap {

View File

@@ -1,21 +1,23 @@
use std::io::Write;
use axum::{
body::{Body, HttpBody},
extract::{OriginalUri, State},
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
};
use futures_util::stream::StreamExt;
use crate::core::{
err_if_user_is_invalid::err_if_user_is_invalid,
error::{Error, Result},
extractors::PubkyHost,
AppState,
use crate::{
core::{
err_if_user_is_invalid::err_if_user_is_invalid,
error::{Error, Result},
extractors::PubkyHost,
AppState,
},
persistence::lmdb::tables::files::AsyncInDbTempFileWriter,
shared::webdav::{EntryPath, WebDavPathAxum},
};
/// Fail with 507 if `(current + incoming - existing) > quota`.
fn enforce_user_disk_quota(
existing_bytes: u64,
incoming_bytes: u64,
@@ -31,7 +33,7 @@ fn enforce_user_disk_quota(
return Err(Error::new(
StatusCode::INSUFFICIENT_STORAGE,
Some(format!(
"Quota of {:.1} MB exceeded: youve used {:.1} MB, trying to add {:.1} MB",
"Quota of {:.1} MB exceeded: you've used {:.1} MB, trying to add {:.1} MB",
max_mb, current_mb, adding_mb
)),
));
@@ -43,15 +45,15 @@ fn enforce_user_disk_quota(
pub async fn delete(
State(mut state): State<AppState>,
pubky: PubkyHost,
path: OriginalUri,
Path(path): Path<WebDavPathAxum>,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
let public_key = pubky.public_key();
let full_path = path.0.path();
let existing_bytes = state.db.get_entry_content_length(public_key, full_path)?;
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
let entry_path = EntryPath::new(public_key.clone(), path.0);
let existing_bytes = state.db.get_entry_content_length(&entry_path)?;
// Remove entry
if !state.db.delete_entry(public_key, full_path)? {
if !state.db.delete_entry(&entry_path).await? {
return Err(Error::with_status(StatusCode::NOT_FOUND));
}
@@ -66,36 +68,38 @@ pub async fn delete(
pub async fn put(
State(mut state): State<AppState>,
pubky: PubkyHost,
path: OriginalUri,
Path(path): Path<WebDavPathAxum>,
body: Body,
) -> Result<impl IntoResponse> {
err_if_user_is_invalid(pubky.public_key(), &state.db, true)?;
let public_key = pubky.public_key();
let full_path = path.0.path();
let existing_entry_bytes = state.db.get_entry_content_length(public_key, full_path)?;
err_if_user_is_invalid(public_key, &state.db, true)?;
let entry_path = EntryPath::new(public_key.clone(), path.0);
let existing_entry_bytes = state.db.get_entry_content_length(&entry_path)?;
let quota_bytes = state.user_quota_bytes;
let used_bytes = state.db.get_user_data_usage(public_key)?;
// Upfront check when we have an exact ContentLength
// Upfront check when we have an exact Content-Length
let hint = body.size_hint().exact();
if let Some(exact_bytes) = hint {
enforce_user_disk_quota(existing_entry_bytes, exact_bytes, used_bytes, quota_bytes)?;
}
// Stream body
let mut writer = state.db.write_entry(public_key, full_path)?;
// Stream body to disk first.
let mut seen_bytes: u64 = 0;
let mut stream = body.into_data_stream();
let mut buffer_file_writer = AsyncInDbTempFileWriter::new().await?;
while let Some(chunk) = stream.next().await.transpose()? {
seen_bytes += chunk.len() as u64;
enforce_user_disk_quota(existing_entry_bytes, seen_bytes, used_bytes, quota_bytes)?;
writer.write_all(&chunk)?;
buffer_file_writer.write_chunk(&chunk).await?;
}
let buffer_file = buffer_file_writer.complete().await?;
// Commit & bump usage
let entry = writer.commit()?;
let delta = entry.content_length() as i64 - existing_entry_bytes as i64;
// Write file on disk to db
state.db.write_entry(&entry_path, &buffer_file).await?;
let delta = buffer_file.len() as i64 - existing_entry_bytes as i64;
state.db.update_data_usage(public_key, delta)?;
Ok((StatusCode::CREATED, ()))

View File

@@ -11,7 +11,6 @@ pub const DEFAULT_MAP_SIZE: usize = 10995116277760; // 10TB (not = disk-space us
pub struct LmDB {
pub(crate) env: Env,
pub(crate) tables: Tables,
pub(crate) buffers_dir: PathBuf,
pub(crate) max_chunk_size: usize,
// Only used for testing purposes to keep the testdir alive.
#[allow(dead_code)]
@@ -44,7 +43,6 @@ impl LmDB {
let db = LmDB {
env,
tables,
buffers_dir,
max_chunk_size: Self::max_chunk_size(),
test_dir: None,
};

View File

@@ -1,15 +1,15 @@
use heed::{Env, RwTxn};
use crate::persistence::lmdb::tables::{blobs, entries, events, sessions, signup_tokens, users};
use crate::persistence::lmdb::tables::{events, files, sessions, signup_tokens, users};
pub fn run(env: &Env, wtxn: &mut RwTxn) -> anyhow::Result<()> {
let _: users::UsersTable = env.create_database(wtxn, Some(users::USERS_TABLE))?;
let _: sessions::SessionsTable = env.create_database(wtxn, Some(sessions::SESSIONS_TABLE))?;
let _: blobs::BlobsTable = env.create_database(wtxn, Some(blobs::BLOBS_TABLE))?;
let _: files::BlobsTable = env.create_database(wtxn, Some(files::BLOBS_TABLE))?;
let _: entries::EntriesTable = env.create_database(wtxn, Some(entries::ENTRIES_TABLE))?;
let _: files::EntriesTable = env.create_database(wtxn, Some(files::ENTRIES_TABLE))?;
let _: events::EventsTable = env.create_database(wtxn, Some(events::EVENTS_TABLE))?;

View File

@@ -1,14 +1,13 @@
pub mod blobs;
pub mod entries;
pub mod events;
pub mod files;
pub mod sessions;
pub mod signup_tokens;
pub mod users;
use heed::{Env, RwTxn};
use blobs::{BlobsTable, BLOBS_TABLE};
use entries::{EntriesTable, ENTRIES_TABLE};
use files::{BlobsTable, BLOBS_TABLE};
use files::{EntriesTable, ENTRIES_TABLE};
use self::{
events::{EventsTable, EVENTS_TABLE},

View File

@@ -1,24 +0,0 @@
use heed::{types::Bytes, Database, RoTxn};
use super::super::LmDB;
use super::entries::Entry;
/// (entry timestamp | chunk_index BE) => bytes
pub type BlobsTable = Database<Bytes, Bytes>;
pub const BLOBS_TABLE: &str = "blobs";
impl LmDB {
pub fn read_entry_content<'txn>(
&self,
rtxn: &'txn RoTxn,
entry: &Entry,
) -> anyhow::Result<impl Iterator<Item = Result<&'txn [u8], heed::Error>> + 'txn> {
Ok(self
.tables
.blobs
.prefix_iter(rtxn, &entry.timestamp().to_bytes())?
.map(|i| i.map(|(_, bytes)| bytes)))
}
}

View File

@@ -1,560 +0,0 @@
use pkarr::PublicKey;
use postcard::{from_bytes, to_allocvec};
use serde::{Deserialize, Serialize};
use std::{
fs::File,
io::{Read, Write},
path::PathBuf,
};
use tracing::instrument;
use heed::{
types::{Bytes, Str},
Database, RoTxn,
};
use pubky_common::{
crypto::{Hash, Hasher},
timestamp::Timestamp,
};
use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT};
use super::super::LmDB;
use super::events::Event;
/// full_path(pubky/*path) => Entry.
pub type EntriesTable = Database<Str, Bytes>;
pub const ENTRIES_TABLE: &str = "entries";
impl LmDB {
/// Write an entry by an author at a given path.
///
/// The path has to start with a forward slash `/`
pub fn write_entry(
&mut self,
public_key: &PublicKey,
path: &str,
) -> anyhow::Result<EntryWriter> {
EntryWriter::new(self, public_key, path)
}
/// Delete an entry by an author at a given path.
///
/// The path has to start with a forward slash `/`
pub fn delete_entry(&mut self, public_key: &PublicKey, path: &str) -> anyhow::Result<bool> {
let mut wtxn = self.env.write_txn()?;
let key = format!("{public_key}{path}");
let deleted = if let Some(bytes) = self.tables.entries.get(&wtxn, &key)? {
let entry = Entry::deserialize(bytes)?;
let mut deleted_chunks = false;
{
let mut iter = self
.tables
.blobs
.prefix_iter_mut(&mut wtxn, &entry.timestamp.to_bytes())?;
while iter.next().is_some() {
unsafe {
deleted_chunks = iter.del_current()?;
}
}
}
let deleted_entry = self.tables.entries.delete(&mut wtxn, &key)?;
// create DELETE event
if path.starts_with("/pub/") {
let url = format!("pubky://{key}");
let event = Event::delete(&url);
let value = event.serialize();
let key = Timestamp::now().to_string();
self.tables.events.put(&mut wtxn, &key, &value)?;
// TODO: delete events older than a threshold.
// TODO: move to events.rs
}
deleted_entry && deleted_chunks
} else {
false
};
wtxn.commit()?;
Ok(deleted)
}
pub fn get_entry(
&self,
txn: &RoTxn,
public_key: &PublicKey,
path: &str,
) -> anyhow::Result<Option<Entry>> {
let key = format!("{public_key}{path}");
if let Some(bytes) = self.tables.entries.get(txn, &key)? {
return Ok(Some(Entry::deserialize(bytes)?));
}
Ok(None)
}
/// Bytes stored at `path` for this user (0 if none).
pub fn get_entry_content_length(
&self,
public_key: &PublicKey,
path: &str,
) -> anyhow::Result<u64> {
let txn = self.env.read_txn()?;
let content_length = self
.get_entry(&txn, public_key, path)?
.map(|e| e.content_length() as u64)
.unwrap_or(0);
Ok(content_length)
}
pub fn contains_directory(&self, txn: &RoTxn, path: &str) -> anyhow::Result<bool> {
Ok(self.tables.entries.get_greater_than(txn, path)?.is_some())
}
/// Return a list of pubky urls.
///
/// - limit defaults to [crate::config::DEFAULT_LIST_LIMIT] and capped by [crate::config::DEFAULT_MAX_LIST_LIMIT]
pub fn list(
&self,
txn: &RoTxn,
path: &str,
reverse: bool,
limit: Option<u16>,
cursor: Option<String>,
shallow: bool,
) -> anyhow::Result<Vec<String>> {
// Vector to store results
let mut results = Vec::new();
let limit = limit
.unwrap_or(DEFAULT_LIST_LIMIT)
.min(DEFAULT_MAX_LIST_LIMIT);
// TODO: make this more performant than split and allocations?
let mut threshold = cursor
.map(|cursor| {
// Removing leading forward slashes
let mut file_or_directory = cursor.trim_start_matches('/');
if cursor.starts_with("pubky://") {
file_or_directory = cursor.split(path).last().expect("should not be reachable")
};
next_threshold(
path,
file_or_directory,
file_or_directory.ends_with('/'),
reverse,
shallow,
)
})
.unwrap_or(next_threshold(path, "", false, reverse, shallow));
for _ in 0..limit {
if let Some((key, _)) = if reverse {
self.tables.entries.get_lower_than(txn, &threshold)?
} else {
self.tables.entries.get_greater_than(txn, &threshold)?
} {
if !key.starts_with(path) {
break;
}
if shallow {
let mut split = key[path.len()..].split('/');
let file_or_directory = split.next().expect("should not be reachable");
let is_directory = split.next().is_some();
threshold =
next_threshold(path, file_or_directory, is_directory, reverse, shallow);
results.push(format!(
"pubky://{path}{file_or_directory}{}",
if is_directory { "/" } else { "" }
));
} else {
threshold = key.to_string();
results.push(format!("pubky://{}", key))
}
};
}
Ok(results)
}
}
/// Calculate the next threshold
#[instrument]
fn next_threshold(
path: &str,
file_or_directory: &str,
is_directory: bool,
reverse: bool,
shallow: bool,
) -> String {
format!(
"{path}{file_or_directory}{}",
if file_or_directory.is_empty() {
// No file_or_directory, early return
if reverse {
// `path/to/dir/\x7f` to catch all paths under `path/to/dir/`
"\x7f"
} else {
""
}
} else if shallow & is_directory {
if reverse {
// threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/`
"\x2e"
} else {
//threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/`
"\x7f"
}
} else {
""
}
)
}
#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Entry {
/// Encoding version
version: usize,
/// Modified at
timestamp: Timestamp,
content_hash: EntryHash,
content_length: usize,
content_type: String,
// user_metadata: ?
}
#[derive(Clone, Debug, Eq, PartialEq)]
struct EntryHash(Hash);
impl Default for EntryHash {
fn default() -> Self {
Self(Hash::from_bytes([0; 32]))
}
}
impl Serialize for EntryHash {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let bytes = self.0.as_bytes();
bytes.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for EntryHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
Ok(Self(Hash::from_bytes(bytes)))
}
}
impl Entry {
pub fn new() -> Self {
Default::default()
}
// === Setters ===
pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self {
self.timestamp = *timestamp;
self
}
pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self {
EntryHash(content_hash).clone_into(&mut self.content_hash);
self
}
pub fn set_content_length(&mut self, content_length: usize) -> &mut Self {
self.content_length = content_length;
self
}
// === Getters ===
pub fn timestamp(&self) -> &Timestamp {
&self.timestamp
}
pub fn content_hash(&self) -> &Hash {
&self.content_hash.0
}
pub fn content_length(&self) -> usize {
self.content_length
}
pub fn content_type(&self) -> &str {
&self.content_type
}
// === Public Method ===
pub fn read_content<'txn>(
&self,
db: &'txn LmDB,
rtxn: &'txn RoTxn,
) -> anyhow::Result<impl Iterator<Item = Result<&'txn [u8], heed::Error>> + 'txn> {
db.read_entry_content(rtxn, self)
}
pub fn serialize(&self) -> Vec<u8> {
to_allocvec(self).expect("Session::serialize")
}
pub fn deserialize(bytes: &[u8]) -> core::result::Result<Self, postcard::Error> {
if bytes[0] > 0 {
panic!("Unknown Entry version");
}
from_bytes(bytes)
}
}
pub struct EntryWriter<'db> {
db: &'db LmDB,
buffer: File,
hasher: Hasher,
buffer_path: PathBuf,
entry_key: String,
timestamp: Timestamp,
is_public: bool,
}
impl<'db> EntryWriter<'db> {
pub fn new(db: &'db LmDB, public_key: &PublicKey, path: &str) -> anyhow::Result<Self> {
let hasher = Hasher::new();
let timestamp = Timestamp::now();
let buffer_path = db.buffers_dir.join(timestamp.to_string());
let buffer = File::create(&buffer_path)?;
let entry_key = format!("{public_key}{path}");
Ok(Self {
db,
buffer,
hasher,
buffer_path,
entry_key,
timestamp,
is_public: path.starts_with("/pub/"),
})
}
/// Same as [EntryWriter::write_all] but returns a Result with a mutable reference to itself
/// to enable chaining with [Self::commit].
pub fn update(&mut self, chunk: &[u8]) -> Result<&mut Self, std::io::Error> {
self.write_all(chunk)?;
Ok(self)
}
/// Commit blob from the filesystem buffer to LMDB,
/// write the [Entry], and commit the write transaction.
pub fn commit(&self) -> anyhow::Result<Entry> {
let hash = self.hasher.finalize();
let mut buffer = File::open(&self.buffer_path)?;
let mut wtxn = self.db.env.write_txn()?;
let mut chunk_key = [0; 12];
chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes());
let mut chunk_index: u32 = 0;
loop {
let mut chunk = vec![0_u8; self.db.max_chunk_size];
let bytes_read = buffer.read(&mut chunk)?;
if bytes_read == 0 {
break; // EOF reached
}
chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes());
self.db
.tables
.blobs
.put(&mut wtxn, &chunk_key, &chunk[..bytes_read])?;
chunk_index += 1;
}
let mut entry = Entry::new();
entry.set_timestamp(&self.timestamp);
entry.set_content_hash(hash);
let length = buffer.metadata()?.len();
entry.set_content_length(length as usize);
self.db
.tables
.entries
.put(&mut wtxn, &self.entry_key, &entry.serialize())?;
// Write a public [Event].
if self.is_public {
let url = format!("pubky://{}", self.entry_key);
let event = Event::put(&url);
let value = event.serialize();
let key = entry.timestamp.to_string();
self.db.tables.events.put(&mut wtxn, &key, &value)?;
// TODO: delete events older than a threshold.
// TODO: move to events.rs
}
wtxn.commit()?;
std::fs::remove_file(&self.buffer_path)?;
Ok(entry)
}
}
impl std::io::Write for EntryWriter<'_> {
/// Write a chunk to a Filesystem based buffer.
#[inline]
fn write(&mut self, chunk: &[u8]) -> std::io::Result<usize> {
self.hasher.update(chunk);
self.buffer.write_all(chunk)?;
Ok(chunk.len())
}
/// Does not do anything, you need to call [Self::commit]
#[inline]
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use pkarr::Keypair;
use super::LmDB;
#[tokio::test]
async fn entries() -> anyhow::Result<()> {
let mut db = LmDB::test();
let keypair = Keypair::random();
let public_key = keypair.public_key();
let path = "/pub/foo.txt";
let chunk = Bytes::from(vec![1, 2, 3, 4, 5]);
db.write_entry(&public_key, path)?
.update(&chunk)?
.commit()?;
let rtxn = db.env.read_txn().unwrap();
let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap();
assert_eq!(
entry.content_hash(),
&[
2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117,
197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87
]
);
let mut blob = vec![];
{
let mut iter = entry.read_content(&db, &rtxn).unwrap();
while let Some(Ok(chunk)) = iter.next() {
blob.extend_from_slice(chunk);
}
}
assert_eq!(blob, vec![1, 2, 3, 4, 5]);
rtxn.commit().unwrap();
Ok(())
}
#[tokio::test]
async fn chunked_entry() -> anyhow::Result<()> {
let mut db = LmDB::test();
let keypair = Keypair::random();
let public_key = keypair.public_key();
let path = "/pub/foo.txt";
let chunk = Bytes::from(vec![0; 1024 * 1024]);
db.write_entry(&public_key, path)?
.update(&chunk)?
.commit()?;
let rtxn = db.env.read_txn().unwrap();
let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap();
assert_eq!(
entry.content_hash(),
&[
72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109,
134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168
]
);
let mut blob = vec![];
{
let mut iter = entry.read_content(&db, &rtxn).unwrap();
while let Some(Ok(chunk)) = iter.next() {
blob.extend_from_slice(chunk);
}
}
assert_eq!(blob, vec![0; 1024 * 1024]);
let stats = db.tables.blobs.stat(&rtxn).unwrap();
assert_eq!(stats.overflow_pages, 0);
rtxn.commit().unwrap();
Ok(())
}
}

View File

@@ -0,0 +1,160 @@
use super::super::super::LmDB;
use super::{InDbFileId, InDbTempFile, SyncInDbTempFileWriter};
use heed::{types::Bytes, Database};
use std::io::Read;
/// (entry timestamp | chunk_index BE) => bytes
pub type BlobsTable = Database<Bytes, Bytes>;
pub const BLOBS_TABLE: &str = "blobs";
impl LmDB {
/// Read the blobs into a temporary file.
///
/// The file is written to disk to minimize the size/duration of the LMDB transaction.
pub(crate) fn read_file_sync(&self, id: &InDbFileId) -> anyhow::Result<InDbTempFile> {
let mut file_writer = SyncInDbTempFileWriter::new()?;
let rtxn = self.env.read_txn()?;
let blobs_iter = self
.tables
.blobs
.prefix_iter(&rtxn, &id.bytes())?
.map(|i| i.map(|(_, bytes)| bytes));
let mut file_exists = false;
for read_result in blobs_iter {
file_exists = true;
let chunk = read_result?;
file_writer.write_chunk(chunk)?;
}
if !file_exists {
return Ok(InDbTempFile::empty()?);
}
let file = file_writer.complete()?;
rtxn.commit()?;
Ok(file)
}
/// Read the blobs into a temporary file asynchronously.
pub(crate) async fn read_file(&self, id: &InDbFileId) -> anyhow::Result<InDbTempFile> {
let db = self.clone();
let id = *id;
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<InDbTempFile> {
db.read_file_sync(&id)
})
.await;
match join_handle {
Ok(result) => result,
Err(e) => {
tracing::error!("Error reading file. JoinError: {:?}", e);
Err(e.into())
}
}
}
/// Write the blobs from a temporary file to LMDB.
pub(crate) fn write_file_sync<'txn>(
&'txn self,
file: &InDbTempFile,
wtxn: &mut heed::RwTxn<'txn>,
) -> anyhow::Result<InDbFileId> {
let id = InDbFileId::new();
let mut file_handle = file.open_file_handle()?;
let mut blob_index: u32 = 0;
loop {
let mut blob = vec![0_u8; self.max_chunk_size];
let bytes_read = file_handle.read(&mut blob)?;
let is_end_of_file = bytes_read == 0;
if is_end_of_file {
break; // EOF reached
}
let blob_key = id.get_blob_key(blob_index);
self.tables
.blobs
.put(wtxn, &blob_key, &blob[..bytes_read])?;
blob_index += 1;
}
Ok(id)
}
/// Delete the blobs from LMDB.
pub(crate) fn delete_file<'txn>(
&'txn self,
file: &InDbFileId,
wtxn: &mut heed::RwTxn<'txn>,
) -> anyhow::Result<bool> {
let mut deleted_chunks = false;
{
let mut iter = self.tables.blobs.prefix_iter_mut(wtxn, &file.bytes())?;
while iter.next().is_some() {
unsafe {
deleted_chunks = iter.del_current()?;
}
}
}
Ok(deleted_chunks)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_write_read_delete_file() {
let lmdb = LmDB::test();
// Write file to LMDB
let write_file = InDbTempFile::zeros(50).await.unwrap();
let mut wtxn = lmdb.env.write_txn().unwrap();
let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap();
wtxn.commit().unwrap();
// Read file from LMDB
let read_file = lmdb.read_file(&id).await.unwrap();
assert_eq!(read_file.len(), write_file.len());
assert_eq!(read_file.hash(), write_file.hash());
let written_file_content = std::fs::read(write_file.path()).unwrap();
let read_file_content = std::fs::read(read_file.path()).unwrap();
assert_eq!(written_file_content, read_file_content);
// Delete file from LMDB
let mut wtxn = lmdb.env.write_txn().unwrap();
let deleted = lmdb.delete_file(&id, &mut wtxn).unwrap();
wtxn.commit().unwrap();
assert!(deleted);
// Try to read file from LMDB
let read_file = lmdb.read_file(&id).await.unwrap();
assert_eq!(read_file.len(), 0);
}
#[tokio::test]
async fn test_write_empty_file() {
let lmdb = LmDB::test();
// Write file to LMDB
let write_file = InDbTempFile::empty().unwrap();
let mut wtxn = lmdb.env.write_txn().unwrap();
let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap();
wtxn.commit().unwrap();
// Read file from LMDB
let read_file = lmdb.read_file(&id).await.unwrap();
assert_eq!(read_file.len(), write_file.len());
assert_eq!(read_file.hash(), write_file.hash());
let written_file_content = std::fs::read(write_file.path()).unwrap();
let read_file_content = std::fs::read(read_file.path()).unwrap();
assert_eq!(written_file_content, read_file_content);
}
}

View File

@@ -0,0 +1,499 @@
use super::super::events::Event;
use super::{super::super::LmDB, InDbFileId, InDbTempFile};
use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT};
use crate::shared::webdav::EntryPath;
use heed::{
types::{Bytes, Str},
Database, RoTxn,
};
use postcard::{from_bytes, to_allocvec};
use pubky_common::{crypto::Hash, timestamp::Timestamp};
use serde::{Deserialize, Serialize};
use tracing::instrument;
/// full_path(pubky/*path) => Entry.
pub type EntriesTable = Database<Str, Bytes>;
pub const ENTRIES_TABLE: &str = "entries";
impl LmDB {
/// Writes an entry to the database.
///
/// The entry is written to the database and the file is written to the blob store.
/// An event is written to the events table.
/// The entry is returned.
pub async fn write_entry(
&mut self,
path: &EntryPath,
file: &InDbTempFile,
) -> anyhow::Result<Entry> {
let mut db = self.clone();
let path = path.clone();
let file = file.clone();
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<Entry> {
db.write_entry_sync(&path, &file)
})
.await;
match join_handle {
Ok(result) => result,
Err(e) => {
tracing::error!("Error writing entry. JoinError: {:?}", e);
Err(e.into())
}
}
}
/// Writes an entry to the database.
///
/// The entry is written to the database and the file is written to the blob store.
/// An event is written to the events table.
/// The entry is returned.
pub fn write_entry_sync(
&mut self,
path: &EntryPath,
file: &InDbTempFile,
) -> anyhow::Result<Entry> {
let mut wtxn = self.env.write_txn()?;
let mut entry = Entry::new();
entry.set_content_hash(*file.hash());
entry.set_content_length(file.len());
let file_id = self.write_file_sync(file, &mut wtxn)?;
entry.set_timestamp(file_id.timestamp());
let entry_key = path.to_string();
self.tables
.entries
.put(&mut wtxn, entry_key.as_str(), &entry.serialize())?;
// Write a public [Event].
let url = format!("pubky://{}", entry_key);
let event = Event::put(&url);
let value = event.serialize();
self.tables
.events
.put(&mut wtxn, file_id.timestamp().to_string().as_str(), &value)?;
wtxn.commit()?;
Ok(entry)
}
/// Get an entry from the database.
/// This doesn't include the file but only metadata.
pub fn get_entry(&self, path: &EntryPath) -> anyhow::Result<Option<Entry>> {
let txn = self.env.read_txn()?;
let entry = match self.tables.entries.get(&txn, path.as_str())? {
Some(bytes) => Entry::deserialize(bytes)?,
None => return Ok(None),
};
Ok(Some(entry))
}
/// Delete an entry including the associated file from the database.
pub async fn delete_entry(&mut self, path: &EntryPath) -> anyhow::Result<bool> {
let mut db = self.clone();
let path = path.clone();
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
db.delete_entry_sync(&path)
})
.await;
match join_handle {
Ok(result) => result,
Err(e) => {
tracing::error!("Error deleting entry. JoinError: {:?}", e);
Err(e.into())
}
}
}
/// Delete an entry including the associated file from the database.
pub fn delete_entry_sync(&mut self, path: &EntryPath) -> anyhow::Result<bool> {
let entry = match self.get_entry(path)? {
Some(entry) => entry,
None => return Ok(false),
};
let mut wtxn = self.env.write_txn()?;
let deleted = self.delete_file(&entry.file_id(), &mut wtxn)?;
if !deleted {
wtxn.abort();
return Ok(false);
}
let deleted = self.tables.entries.delete(&mut wtxn, path.as_str())?;
if !deleted {
wtxn.abort();
return Ok(false);
}
// create DELETE event
let url = format!("pubky://{}", path.as_str());
let event = Event::delete(&url);
let value = event.serialize();
let key = Timestamp::now().to_string();
self.tables.events.put(&mut wtxn, &key, &value)?;
wtxn.commit()?;
Ok(true)
}
/// Bytes stored at `path` for this user (0 if none).
pub fn get_entry_content_length(&self, path: &EntryPath) -> anyhow::Result<u64> {
let content_length = self
.get_entry(path)?
.map(|e| e.content_length() as u64)
.unwrap_or(0);
Ok(content_length)
}
pub fn contains_directory(&self, txn: &RoTxn, entry_path: &EntryPath) -> anyhow::Result<bool> {
Ok(self
.tables
.entries
.get_greater_than(txn, entry_path.as_str())?
.is_some())
}
/// Return a list of pubky urls.
///
/// - limit defaults to [crate::constants::DEFAULT_LIST_LIMIT] and is capped by [crate::constants::DEFAULT_MAX_LIST_LIMIT]
pub fn list_entries(
&self,
txn: &RoTxn,
entry_path: &EntryPath,
reverse: bool,
limit: Option<u16>,
cursor: Option<String>,
shallow: bool,
) -> anyhow::Result<Vec<String>> {
// Vector to store results
let mut results = Vec::new();
let limit = limit
.unwrap_or(DEFAULT_LIST_LIMIT)
.min(DEFAULT_MAX_LIST_LIMIT);
// TODO: make this more performant than split and allocations?
let mut threshold = cursor
.map(|cursor| {
// Removing leading forward slashes
let mut file_or_directory = cursor.trim_start_matches('/');
if cursor.starts_with("pubky://") {
file_or_directory = cursor
.split(entry_path.as_str())
.last()
.expect("should not be reachable")
};
next_threshold(
entry_path.as_str(),
file_or_directory,
file_or_directory.ends_with('/'),
reverse,
shallow,
)
})
.unwrap_or(next_threshold(
entry_path.as_str(),
"",
false,
reverse,
shallow,
));
for _ in 0..limit {
if let Some((key, _)) = if reverse {
self.tables.entries.get_lower_than(txn, &threshold)?
} else {
self.tables.entries.get_greater_than(txn, &threshold)?
} {
if !key.starts_with(entry_path.as_str()) {
break;
}
if shallow {
let mut split = key[entry_path.as_str().len()..].split('/');
let file_or_directory = split.next().expect("should not be reachable");
let is_directory = split.next().is_some();
threshold = next_threshold(
entry_path.as_str(),
file_or_directory,
is_directory,
reverse,
shallow,
);
results.push(format!(
"pubky://{}{file_or_directory}{}",
entry_path.as_str(),
if is_directory { "/" } else { "" }
));
} else {
threshold = key.to_string();
results.push(format!("pubky://{}", key))
}
};
}
Ok(results)
}
}
/// Calculate the next threshold
#[instrument]
fn next_threshold(
path: &str,
file_or_directory: &str,
is_directory: bool,
reverse: bool,
shallow: bool,
) -> String {
format!(
"{path}{file_or_directory}{}",
if file_or_directory.is_empty() {
// No file_or_directory, early return
if reverse {
// `path/to/dir/\x7f` to catch all paths under `path/to/dir/`
"\x7f"
} else {
""
}
} else if shallow & is_directory {
if reverse {
// threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/`
"\x2e"
} else {
// threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/`
"\x7f"
}
} else {
""
}
)
}
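// Illustrative example (not part of this diff) of how the threshold suffixes
// drive shallow listing. Suppose entries live under the prefix "pk/pub/" and
// the last result returned was the directory "photos/":
//
//   next_threshold("pk/pub/", "photos", true, false, true) == "pk/pub/photos\x7f"
//
// Since '\x7f' sorts after '/', get_greater_than() jumps over every key inside
// "pk/pub/photos/" and lands on the next sibling. In reverse mode the suffix is
// '\x2e' ('.'), which sorts before '/', so get_lower_than() skips the
// directory's contents when iterating downwards.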
#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Entry {
/// Encoding version
version: usize,
/// Modified at
timestamp: Timestamp,
content_hash: EntryHash,
content_length: usize,
content_type: String,
}
#[derive(Clone, Debug, Eq, PartialEq)]
struct EntryHash(Hash);
impl Default for EntryHash {
fn default() -> Self {
Self(Hash::from_bytes([0; 32]))
}
}
impl Serialize for EntryHash {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let bytes = self.0.as_bytes();
bytes.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for EntryHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
Ok(Self(Hash::from_bytes(bytes)))
}
}
impl Entry {
pub fn new() -> Self {
Default::default()
}
pub fn chunk_key(&self, chunk_index: u32) -> [u8; 12] {
let mut chunk_key = [0; 12];
chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes());
chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes());
chunk_key
}
// === Setters ===
pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self {
self.timestamp = *timestamp;
self
}
pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self {
EntryHash(content_hash).clone_into(&mut self.content_hash);
self
}
pub fn set_content_length(&mut self, content_length: usize) -> &mut Self {
self.content_length = content_length;
self
}
// === Getters ===
pub fn timestamp(&self) -> &Timestamp {
&self.timestamp
}
pub fn content_hash(&self) -> &Hash {
&self.content_hash.0
}
pub fn content_length(&self) -> usize {
self.content_length
}
pub fn content_type(&self) -> &str {
&self.content_type
}
pub fn file_id(&self) -> InDbFileId {
InDbFileId::from(self.timestamp)
}
// === Public Method ===
pub fn serialize(&self) -> Vec<u8> {
to_allocvec(self).expect("Session::serialize")
}
pub fn deserialize(bytes: &[u8]) -> core::result::Result<Self, postcard::Error> {
if bytes[0] > 0 {
panic!("Unknown Entry version");
}
from_bytes(bytes)
}
}
#[cfg(test)]
mod tests {
use super::LmDB;
use crate::{
persistence::lmdb::tables::files::{InDbTempFile, SyncInDbTempFileWriter},
shared::webdav::{EntryPath, WebDavPath},
};
use bytes::Bytes;
use pkarr::Keypair;
use std::io::Read;
#[tokio::test]
async fn test_write_read_delete_method() {
let mut db = LmDB::test();
let path = EntryPath::new(
Keypair::random().public_key(),
WebDavPath::new("/pub/foo.txt").unwrap(),
);
let file = InDbTempFile::zeros(5).await.unwrap();
let entry = db.write_entry_sync(&path, &file).unwrap();
let read_entry = db.get_entry(&path).unwrap().expect("Entry doesn't exist");
assert_eq!(entry.content_hash(), read_entry.content_hash());
assert_eq!(entry.content_length(), read_entry.content_length());
assert_eq!(entry.timestamp(), read_entry.timestamp());
let read_file = db.read_file(&entry.file_id()).await.unwrap();
let mut file_handle = read_file.open_file_handle().unwrap();
let mut content = vec![];
file_handle.read_to_end(&mut content).unwrap();
assert_eq!(content, vec![0, 0, 0, 0, 0]);
let deleted = db.delete_entry_sync(&path).unwrap();
assert!(deleted);
// Verify the entry and file are deleted
let read_entry = db.get_entry(&path).unwrap();
assert!(read_entry.is_none());
let read_file = db.read_file(&entry.file_id()).await.unwrap();
assert_eq!(read_file.len(), 0);
}
#[tokio::test]
async fn entries() -> anyhow::Result<()> {
let mut db = LmDB::test();
let keypair = Keypair::random();
let public_key = keypair.public_key();
let path = "/pub/foo.txt";
let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap());
let chunk = Bytes::from(vec![1, 2, 3, 4, 5]);
let mut writer = SyncInDbTempFileWriter::new()?;
writer.write_chunk(&chunk)?;
let file = writer.complete()?;
db.write_entry_sync(&entry_path, &file)?;
let entry = db.get_entry(&entry_path).unwrap().unwrap();
assert_eq!(
entry.content_hash(),
&[
2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117,
197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87
]
);
let read_file = db.read_file(&entry.file_id()).await.unwrap();
let mut file_handle = read_file.open_file_handle().unwrap();
let mut content = vec![];
file_handle.read_to_end(&mut content).unwrap();
assert_eq!(content, vec![1, 2, 3, 4, 5]);
Ok(())
}
#[tokio::test]
async fn chunked_entry() -> anyhow::Result<()> {
let mut db = LmDB::test();
let keypair = Keypair::random();
let public_key = keypair.public_key();
let path = "/pub/foo.txt";
let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap());
let chunk = Bytes::from(vec![0; 1024 * 1024]);
let mut writer = SyncInDbTempFileWriter::new()?;
writer.write_chunk(&chunk)?;
let file = writer.complete()?;
db.write_entry_sync(&entry_path, &file)?;
let entry = db.get_entry(&entry_path).unwrap().unwrap();
assert_eq!(
entry.content_hash(),
&[
72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109,
134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168
]
);
let read_file = db.read_file(&entry.file_id()).await.unwrap();
let mut file_handle = read_file.open_file_handle().unwrap();
let mut content = vec![];
file_handle.read_to_end(&mut content).unwrap();
assert_eq!(content, vec![0; 1024 * 1024]);
Ok(())
}
}

View File

@@ -0,0 +1,236 @@
//!
//! InDbFile is an abstraction over the way we store blobs/chunks in LMDB.
//! Because the value size of LMDB is limited, we need to store multiple blobs for one file.
//!
//! - `InDbFileId` is the identifier of a file that consists of multiple blobs.
//! - `InDbTempFile` is a helper to read/write a file to/from disk.
//!
use pubky_common::crypto::{Hash, Hasher};
use pubky_common::timestamp::Timestamp;
/// A file identifier for a file stored in LMDB.
/// The identifier is basically the timestamp of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InDbFileId(Timestamp);
impl InDbFileId {
pub fn new() -> Self {
Self(Timestamp::now())
}
pub fn timestamp(&self) -> &Timestamp {
&self.0
}
pub fn bytes(&self) -> [u8; 8] {
self.0.to_bytes()
}
/// Create a blob key from a timestamp and a blob index.
/// blob key = (timestamp | blob_index) => bytes.
/// A file can span at most 2^32 blobs.
pub fn get_blob_key(&self, blob_index: u32) -> [u8; 12] {
let mut blob_key = [0; 12];
blob_key[0..8].copy_from_slice(&self.bytes());
blob_key[8..].copy_from_slice(&blob_index.to_be_bytes());
blob_key
}
}
impl From<Timestamp> for InDbFileId {
fn from(timestamp: Timestamp) -> Self {
Self(timestamp)
}
}
impl Default for InDbFileId {
fn default() -> Self {
Self::new()
}
}
use std::sync::Arc;
use std::{fs::File, io::Write, path::PathBuf};
use tokio::fs::File as AsyncFile;
use tokio::io::AsyncWriteExt;
use tokio::task;
/// Writes a temp file to disk asynchronously.
#[derive(Debug)]
pub(crate) struct AsyncInDbTempFileWriter {
// Temp dir is automatically deleted when the InDbTempFile is dropped.
#[allow(dead_code)]
dir: tempfile::TempDir,
writer_file: AsyncFile,
file_path: PathBuf,
hasher: Hasher,
}
impl AsyncInDbTempFileWriter {
pub async fn new() -> Result<Self, std::io::Error> {
let dir = task::spawn_blocking(tempfile::tempdir)
.await
.map_err(|join_error| {
std::io::Error::other(format!(
"Task join error for tempdir creation: {}",
join_error
))
})??; // Handles the Result from tempfile::tempdir()
let file_path = dir.path().join("entry.bin");
let writer_file = AsyncFile::create(file_path.clone()).await?;
let hasher = Hasher::new();
Ok(Self {
dir,
writer_file,
file_path,
hasher,
})
}
/// Create a new InDbTempFile filled with `size_bytes` zero bytes.
/// Convenience method used for testing.
#[cfg(test)]
pub async fn zeros(size_bytes: usize) -> Result<InDbTempFile, std::io::Error> {
let mut file = Self::new().await?;
let buffer = vec![0u8; size_bytes];
file.write_chunk(&buffer).await?;
file.complete().await
}
/// Write a chunk to the file.
/// Chunk writing is done by the axum body stream and by LMDB itself.
pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
self.writer_file.write_all(chunk).await?;
self.hasher.update(chunk);
Ok(())
}
/// Flush the file to disk.
/// This completes the writing of the file.
/// Returns an InDbTempFile that can be used to read the file.
pub async fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
self.writer_file.flush().await?;
let hash = self.hasher.finalize();
let file_size = self.writer_file.metadata().await?.len();
Ok(InDbTempFile {
dir: Arc::new(self.dir),
file_path: self.file_path,
file_size: file_size as usize,
file_hash: hash,
})
}
}
/// Writes a temp file to disk synchronously.
#[derive(Debug)]
pub(crate) struct SyncInDbTempFileWriter {
// Temp dir is automatically deleted when the InDbTempFile is dropped.
#[allow(dead_code)]
dir: tempfile::TempDir,
writer_file: File,
file_path: PathBuf,
hasher: Hasher,
}
impl SyncInDbTempFileWriter {
pub fn new() -> Result<Self, std::io::Error> {
let dir = tempfile::tempdir()?;
let file_path = dir.path().join("entry.bin");
let writer_file = File::create(file_path.clone())?;
let hasher = Hasher::new();
Ok(Self {
dir,
writer_file,
file_path,
hasher,
})
}
/// Write a chunk to the file.
pub fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
self.writer_file.write_all(chunk)?;
self.hasher.update(chunk);
Ok(())
}
/// Flush the file to disk.
/// This completes the writing of the file.
/// Returns an InDbTempFile that can be used to read the file.
pub fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
self.writer_file.flush()?;
let hash = self.hasher.finalize();
let file_size = self.writer_file.metadata()?.len();
Ok(InDbTempFile {
dir: Arc::new(self.dir),
file_path: self.file_path,
file_size: file_size as usize,
file_hash: hash,
})
}
}
/// A temporary file helper for Entry.
///
/// Every file in LMDB is first written to disk before being written to LMDB.
/// The same is true if you read a file from LMDB.
///
/// This is to keep the LMDB transaction small and fast.
///
/// As soon as the InDbTempFile is dropped, the file on disk is deleted.
///
#[derive(Debug, Clone)]
pub struct InDbTempFile {
// Temp dir is automatically deleted when the InDbTempFile is dropped.
#[allow(dead_code)]
dir: Arc<tempfile::TempDir>,
file_path: PathBuf,
file_size: usize,
file_hash: Hash,
}
impl InDbTempFile {
/// Create a new InDbTempFile filled with zeros.
/// Convenience method used for testing.
#[cfg(test)]
pub async fn zeros(size_bytes: usize) -> Result<Self, std::io::Error> {
AsyncInDbTempFileWriter::zeros(size_bytes).await
}
/// Create a new, empty InDbTempFile.
pub fn empty() -> Result<Self, std::io::Error> {
let dir = tempfile::tempdir()?;
let file_path = dir.path().join("entry.bin");
std::fs::File::create(file_path.clone())?;
let file_size = 0;
let hasher = Hasher::new();
let file_hash = hasher.finalize();
Ok(Self {
dir: Arc::new(dir),
file_path,
file_size,
file_hash,
})
}
pub fn len(&self) -> usize {
self.file_size
}
pub fn hash(&self) -> &Hash {
&self.file_hash
}
/// Get the path of the file on disk.
#[cfg(test)]
pub fn path(&self) -> &PathBuf {
&self.file_path
}
/// Open the file on disk.
pub fn open_file_handle(&self) -> Result<File, std::io::Error> {
File::open(self.file_path.as_path())
}
}
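// Illustrative usage sketch, not part of this diff: how an upload handler can
// buffer a body with these helpers before handing it to LMDB. `anyhow::Result`
// is assumed here purely for brevity.
#[allow(dead_code)]
async fn usage_sketch() -> anyhow::Result<()> {
    let mut writer = AsyncInDbTempFileWriter::new().await?;
    writer.write_chunk(b"hello ").await?;
    writer.write_chunk(b"world").await?;
    // complete() flushes the file and returns a readable, cloneable InDbTempFile.
    let file: InDbTempFile = writer.complete().await?;
    assert_eq!(file.len(), 11);
    let _content_hash = file.hash(); // hash accumulated while the chunks were written
    let _handle = file.open_file_handle()?; // std::fs::File for streaming the bytes out
    // Dropping `file` (and every clone of it) deletes the backing temp dir.
    Ok(())
}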

View File

@@ -0,0 +1,7 @@
mod blobs;
mod entries;
mod in_db_file;
pub use blobs::{BlobsTable, BLOBS_TABLE};
pub use entries::{EntriesTable, Entry, ENTRIES_TABLE};
pub use in_db_file::*;

View File

@@ -157,6 +157,7 @@ impl LmDB {
/// # Errors
///
/// - `UserQueryError::DatabaseError` if the database operation fails.
#[cfg(test)]
pub fn create_user(&self, pubkey: &PublicKey, wtxn: &mut RwTxn) -> anyhow::Result<()> {
let user = User::default();
self.tables.users.put(wtxn, pubkey, &user)?;

View File

@@ -1,5 +1,6 @@
mod http_error;
mod pubkey_path_validator;
pub(crate) mod webdav;
pub(crate) use http_error::{HttpError, HttpResult};
pub(crate) use pubkey_path_validator::Z32Pubkey;

View File

@@ -0,0 +1,124 @@
use pkarr::PublicKey;
use std::str::FromStr;
use super::WebDavPath;
#[derive(thiserror::Error, Debug)]
pub enum EntryPathError {
#[error("{0}")]
Invalid(String),
#[error("Failed to parse webdav path: {0}")]
InvalidWebdavPath(anyhow::Error),
#[error("Failed to parse pubkey: {0}")]
InvalidPubkey(pkarr::errors::PublicKeyError),
}
/// A path to an entry.
///
/// The path as a string is used to identify the entry.
#[derive(Debug, Clone)]
pub struct EntryPath {
#[allow(dead_code)]
pubkey: PublicKey,
path: WebDavPath,
/// The key of the entry represented as a string.
/// The key is the pubkey and the path concatenated.
/// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/folder/file.txt`
/// The concatenation is cached here to avoid rebuilding the string on every access.
key: String,
}
impl EntryPath {
pub fn new(pubkey: PublicKey, path: WebDavPath) -> Self {
let key = format!("{}{}", pubkey, path);
Self { pubkey, path, key }
}
#[allow(dead_code)]
pub fn pubkey(&self) -> &PublicKey {
&self.pubkey
}
pub fn path(&self) -> &WebDavPath {
&self.path
}
/// The key of the entry.
///
/// The key is the pubkey and the path concatenated.
///
/// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/folder/file.txt`
pub fn as_str(&self) -> &str {
&self.key
}
}
impl AsRef<str> for EntryPath {
fn as_ref(&self) -> &str {
&self.key
}
}
impl FromStr for EntryPath {
type Err = EntryPathError;
fn from_str(s: &str) -> Result<Self, EntryPathError> {
let first_slash_index = s
.find('/')
.ok_or(EntryPathError::Invalid("Missing '/'".to_string()))?;
let (pubkey, path) = match s.split_at_checked(first_slash_index) {
Some((pubkey, path)) => (pubkey, path),
None => return Err(EntryPathError::Invalid("Missing '/'".to_string())),
};
let pubkey = PublicKey::from_str(pubkey).map_err(EntryPathError::InvalidPubkey)?;
let webdav_path = WebDavPath::new(path).map_err(EntryPathError::InvalidWebdavPath)?;
Ok(Self::new(pubkey, webdav_path))
}
}
impl std::fmt::Display for EntryPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_ref())
}
}
impl serde::Serialize for EntryPath {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.as_ref())
}
}
impl<'de> serde::Deserialize<'de> for EntryPath {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_entry_path_from_str() {
let pubkey =
PublicKey::from_str("8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo").unwrap();
let path = WebDavPath::new("/pub/folder/file.txt").unwrap();
let key = format!("{pubkey}{path}");
let entry_path = EntryPath::new(pubkey, path);
assert_eq!(entry_path.as_ref(), key);
}
#[test]
fn test_entry_path_serde() {
let string = "8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt";
let entry_path = EntryPath::from_str(string).unwrap();
assert_eq!(entry_path.to_string(), string);
}
}
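// Illustrative sketch of building an EntryPath from its parts; only APIs defined above
// are used and the module name is made up for the example.
#[cfg(test)]
mod entry_path_usage_sketch {
    use super::*;

    #[test]
    fn build_from_parts_and_reparse() {
        let pubkey =
            PublicKey::from_str("8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo").unwrap();
        let path = WebDavPath::new("/pub/folder/file.txt").unwrap();
        let built = EntryPath::new(pubkey, path);

        // The cached key (pubkey + webdav path) round-trips through FromStr.
        let parsed = EntryPath::from_str(built.as_str()).unwrap();
        assert_eq!(parsed.as_str(), built.as_str());
        assert_eq!(parsed.path().as_str(), "/pub/folder/file.txt");
    }
}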

View File

@@ -0,0 +1,7 @@
mod entry_path;
mod webdav_path;
mod webdav_path_axum;
pub use entry_path::EntryPath;
pub use webdav_path::WebDavPath;
pub use webdav_path_axum::WebDavPathAxum;

View File

@@ -0,0 +1,343 @@
use std::str::FromStr;
/// A normalized and validated webdav path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WebDavPath {
normalized_path: String,
}
impl WebDavPath {
/// Create a new WebDavPath from an already normalized path.
/// Make sure the path is fully normalized and valid before using this constructor.
///
/// Use `WebDavPath::new` to create a WebDavPath from an unnormalized path.
pub fn new_unchecked(normalized_path: String) -> Self {
Self { normalized_path }
}
/// Create a new WebDavPath from an unnormalized path.
///
/// The path will be normalized and validated.
pub fn new(unnormalized_path: &str) -> anyhow::Result<Self> {
let normalized_path = normalize_and_validate_webdav_path(unnormalized_path)?;
if !normalized_path.starts_with("/pub/") {
return Err(anyhow::anyhow!("Path must start with /pub/"));
}
Ok(Self::new_unchecked(normalized_path))
}
#[allow(dead_code)]
pub fn url_encode(&self) -> String {
percent_encoding::utf8_percent_encode(self.normalized_path.as_str(), PATH_ENCODE_SET)
.to_string()
}
pub fn as_str(&self) -> &str {
self.normalized_path.as_str()
}
pub fn is_directory(&self) -> bool {
self.normalized_path.ends_with('/')
}
/// Check if the path is a file.
#[allow(dead_code)]
pub fn is_file(&self) -> bool {
!self.is_directory()
}
}
impl std::fmt::Display for WebDavPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.normalized_path)
}
}
impl FromStr for WebDavPath {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::new(s)
}
}
// Encode all non-unreserved characters except '/'.
// See RFC 3986 and https://en.wikipedia.org/wiki/Percent-encoding .
const PATH_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::NON_ALPHANUMERIC
.remove(b'-')
.remove(b'_')
.remove(b'.')
.remove(b'~')
.remove(b'/');
/// Maximum length of a single path segment.
const MAX_WEBDAV_PATH_SEGMENT_LENGTH: usize = 255;
/// Maximum total length of a normalized WebDAV path.
const MAX_WEBDAV_PATH_TOTAL_LENGTH: usize = 4096;
/// Takes a path, normalizes and validates it.
/// Make sure to URL-decode the path before calling this function.
/// Inspired by https://github.com/messense/dav-server-rs/blob/740dae05ac2eeda8e2ea11fface3ab6d53b6705e/src/davpath.rs#L101
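///
/// For example, `/pub/a/../b//c/` normalizes to `/pub/b/c/`: the empty segment from `//`
/// is dropped, `..` removes the preceding segment, and the trailing slash is preserved to
/// mark a directory.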
fn normalize_and_validate_webdav_path(path: &str) -> anyhow::Result<String> {
// Ensure the path starts with '/'
if !path.starts_with('/') {
return Err(anyhow::anyhow!("Path must start with '/'"));
}
let is_dir = path.ends_with('/') || path.ends_with("..");
// Split the path into segments, filtering out empty ones
let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
// Build the normalized path
let mut normalized_segments = vec![];
for segment in segments {
// Check for segment length
if segment.len() > MAX_WEBDAV_PATH_SEGMENT_LENGTH {
return Err(anyhow::anyhow!(
"Invalid path: Segment exceeds maximum length of {} characters. Segment: '{}'",
MAX_WEBDAV_PATH_SEGMENT_LENGTH,
segment
));
}
// Check for any ASCII control characters in the decoded segment
if segment.chars().any(|c| c.is_control()) {
return Err(anyhow::anyhow!(
"Invalid path: ASCII control characters are not allowed in segments"
));
}
if segment == "." {
continue;
} else if segment == ".." {
if normalized_segments.len() < 2 {
return Err(anyhow::anyhow!("Failed to normalize path: '..'."));
}
normalized_segments.pop();
normalized_segments.pop();
} else {
normalized_segments.push("/".to_string());
normalized_segments.push(segment.to_string());
}
}
if is_dir {
normalized_segments.push("/".to_string());
}
let full_path = normalized_segments.join("");
// Check for total path length
if full_path.len() > MAX_WEBDAV_PATH_TOTAL_LENGTH {
return Err(anyhow::anyhow!(
"Invalid path: Total path length exceeds maximum of {} characters. Length: {}, Path: '{}'",
MAX_WEBDAV_PATH_TOTAL_LENGTH,
full_path.len(),
full_path
));
}
Ok(full_path)
}
#[cfg(test)]
mod tests {
use super::*;
fn assert_valid_path(path: &str, expected: &str) {
match normalize_and_validate_webdav_path(path) {
Ok(path) => {
assert_eq!(path, expected);
}
Err(e) => {
panic!("Path '{path}' is invalid. Should be '{expected}'. Error: {e}");
}
};
}
fn assert_invalid_path(path: &str) {
if let Ok(normalized_path) = normalize_and_validate_webdav_path(path) {
panic!("Invalid path '{path}' is valid. Normalized result: '{normalized_path}'");
}
}
#[test]
fn test_slash_is_valid() {
assert_valid_path("/", "/");
}
#[test]
fn test_two_dots_is_valid() {
assert_valid_path("/test/..", "/");
}
#[test]
fn test_two_dots_in_the_middle_is_valid() {
assert_valid_path("/test/../test", "/test");
}
#[test]
fn test_two_dots_in_the_middle_with_slash_is_valid() {
assert_valid_path("/test/../test/", "/test/");
}
#[test]
fn test_two_dots_invalid() {
assert_invalid_path("/..");
}
#[test]
fn test_two_dots_twice_invalid() {
assert_invalid_path("/test/../..");
}
#[test]
fn test_two_slashes_is_valid() {
assert_valid_path("//", "/");
}
#[test]
fn test_two_slashes_in_the_middle_is_valid() {
assert_valid_path("/test//test", "/test/test");
}
#[test]
fn test_one_segment_is_valid() {
assert_valid_path("/test", "/test");
}
#[test]
fn test_one_segment_with_trailing_slash_is_valid() {
assert_valid_path("/test/", "/test/");
}
#[test]
fn test_two_segments_is_valid() {
assert_valid_path("/test/test", "/test/test");
}
#[test]
fn test_wildcard_is_valid() {
assert_valid_path("/dav/file*.txt", "/dav/file*.txt");
}
#[test]
fn test_two_slashes_in_the_middle_with_slash_is_valid() {
assert_valid_path("/dav//folder/", "/dav/folder/");
}
#[test]
fn test_script_tag_is_valid() {
assert_valid_path("/dav/<script>", "/dav/<script>");
}
#[test]
fn test_null_is_invalid() {
assert_invalid_path("/dav/file\0");
}
#[test]
fn test_empty_path_is_invalid() {
assert_invalid_path("");
}
#[test]
fn test_missing_root_slash1_is_invalid() {
assert_invalid_path("test");
}
#[test]
fn test_missing_root_slash2_is_invalid() {
assert_invalid_path("test/");
}
#[test]
fn test_invalid_path_test_over_test() {
assert_invalid_path("test/test");
}
#[test]
fn test_invalid_path_http_example_com_test() {
assert_invalid_path("http://example.com/test");
}
#[test]
fn test_invalid_path_backslash_test_backslash() {
assert_invalid_path("\\test\\");
}
#[test]
fn test_invalid_path_dot() {
assert_invalid_path(".");
}
#[test]
fn test_invalid_path_dot_dot() {
assert_invalid_path("..");
}
#[test]
fn test_invalid_windows_path() {
assert_invalid_path("C:\\dav\\file");
}
#[test]
fn test_valid_path_dav_uber() {
assert_valid_path("/dav/über", "/dav/über");
}
#[test]
fn test_webdav_pub_required() {
WebDavPath::from_str("/pub/file.txt").expect("Should be valid");
WebDavPath::from_str("/file.txt").expect_err("Should not be valid. /pub/ required.");
}
#[test]
fn test_url_encode() {
let url_encoded = "/pub/file%25.txt";
let url_decoded = percent_encoding::percent_decode_str(url_encoded)
.decode_utf8()
.unwrap()
.to_string();
let path = WebDavPath::new(url_decoded.as_str()).unwrap();
let normalized = path.to_string();
assert_eq!(normalized, "/pub/file%.txt");
assert_eq!(path.url_encode(), url_encoded);
}
#[test]
fn test_segment_too_long() {
let long_segment = "a".repeat(MAX_WEBDAV_PATH_SEGMENT_LENGTH + 1);
let path = format!("/prefix/{}/suffix", long_segment);
assert_invalid_path(&path);
}
#[test]
fn test_segment_max_length_is_valid() {
let max_segment = "a".repeat(MAX_WEBDAV_PATH_SEGMENT_LENGTH);
let path = format!("/prefix/{}/suffix", max_segment);
let expected_path = path.clone(); // Expected path is the same as input if valid
assert_valid_path(&path, &expected_path);
}
#[test]
fn test_total_path_too_long() {
// 4096 single-character segments: the normalized path "/a/a/.../a" is 2 * 4096 characters,
// which exceeds MAX_WEBDAV_PATH_TOTAL_LENGTH.
let num_segments = MAX_WEBDAV_PATH_TOTAL_LENGTH;
let segments: Vec<String> = std::iter::repeat("a".to_string())
.take(num_segments)
.collect();
let path = format!("/{}", segments.join("/"));
assert_invalid_path(&path);
// A single near-maximum-length segment is also rejected; note that it already exceeds
// MAX_WEBDAV_PATH_SEGMENT_LENGTH, so the segment check fails before the total-length check.
let almost_too_long_segment = "a".repeat(MAX_WEBDAV_PATH_TOTAL_LENGTH - 1);
let path_too_long = format!("/{}/b", almost_too_long_segment);
assert_invalid_path(&path_too_long);
}
}
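// Illustrative sketch of the intended request flow: the raw request path is percent-decoded
// first, then normalized and validated, and `url_encode` produces the canonical encoded form
// again. The module name is made up for the example.
#[cfg(test)]
mod webdav_path_flow_sketch {
    use super::*;

    #[test]
    fn decode_validate_reencode() {
        // e.g. an incoming request path for "/pub/my file.txt"
        let raw = "/pub/my%20file.txt";
        let decoded = percent_encoding::percent_decode_str(raw)
            .decode_utf8()
            .unwrap();
        let path = WebDavPath::new(&decoded).unwrap();

        assert_eq!(path.as_str(), "/pub/my file.txt");
        assert!(path.is_file());
        assert_eq!(path.url_encode(), "/pub/my%20file.txt");
    }
}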

View File

@@ -0,0 +1,61 @@
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use super::WebDavPath;
/// A webdav path that can be used with axum.
///
/// When using `.route("/{*path}", your_handler)` in axum, the path is passed without the leading slash.
/// This struct adds the leading slash back and therefore allows direct validation of the path.
///
/// Usage in handler:
///
/// `Path(path): Path<WebDavPathAxum>`
pub struct WebDavPathAxum(pub WebDavPath);
impl std::fmt::Display for WebDavPathAxum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl FromStr for WebDavPathAxum {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let with_slash = format!("/{}", s);
let inner = WebDavPath::new(&with_slash)?;
Ok(Self(inner))
}
}
impl Serialize for WebDavPathAxum {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.0.as_str())
}
}
impl<'de> Deserialize<'de> for WebDavPathAxum {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_webdav_path_axum() {
let path = WebDavPathAxum::from_str("pub/foo/bar").unwrap();
assert_eq!(path.0.as_str(), "/pub/foo/bar");
}
}
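// Illustrative sketch of extracting a validated WebDavPath in an axum handler via
// WebDavPathAxum. The handler and route below are made up for the example; the `/{*path}`
// wildcard syntax follows the doc comment above, and axum is assumed to be available as a
// dependency of this crate.
#[cfg(test)]
mod axum_extractor_sketch {
    use super::*;
    use axum::{extract::Path, routing::get, Router};

    // `path.0` is already normalized and guaranteed to start with `/pub/`.
    async fn read_entry(Path(path): Path<WebDavPathAxum>) -> String {
        format!("requested {}", path.0.as_str())
    }

    #[test]
    fn router_builds() {
        let _router: Router = Router::new().route("/{*path}", get(read_entry));
    }
}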