Mirror of https://github.com/aljazceru/pubky-core.git (synced 2026-01-31 20:04:19 +01:00)
fix(homeserver): DDoS due to slow uploads or downloads (#132)
* switch branch
* moved files around
* validate webdav path in requests
* async and sync writer
* implemented file read and write via disk
* cleaning up
* clippy and fmt
* cleanup 1
* last cleanup
* fmt
* test concurrent read/write
* cleanup3
* cleanup3
* added log in case of a join error
* removed pub from max_chunk_size()
* fmt and clippy
* fixed test
* fmt
* fixed main merge errors
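In outline, the fix streams a slow client's body to a temporary file on disk and only afterwards touches LMDB, so uploads and downloads no longer pin a long-lived database transaction. A condensed sketch of the new upload path, using the names that appear in the handler diff below (quota checks, imports, and error mapping omitted):

pub async fn put_sketch(
    mut state: AppState,
    entry_path: EntryPath,
    body: axum::body::Body,
) -> anyhow::Result<()> {
    // 1. Stream the (possibly slow) request body to a temp file on disk;
    //    no LMDB transaction is held while waiting on the network.
    let mut stream = body.into_data_stream();
    let mut writer = AsyncInDbTempFileWriter::new().await?;
    while let Some(chunk) = stream.next().await.transpose()? {
        writer.write_chunk(&chunk).await?;
    }
    let buffer_file = writer.complete().await?;
    // 2. Only now copy the finished file into LMDB; the write transaction is
    //    short because it no longer depends on client speed.
    state.db.write_entry(&entry_path, &buffer_file).await?;
    Ok(())
}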
Committed by GitHub
parent a485d8c2f4
commit 178bae142e
Cargo.lock (generated): 6 changed lines
@@ -2468,6 +2468,7 @@ dependencies = [
"hostname-validator",
"httpdate",
"page_size",
"percent-encoding",
"pkarr",
"pkarr-republisher",
"postcard",
@@ -2479,6 +2480,7 @@ dependencies = [
"tempfile",
"thiserror 2.0.12",
"tokio",
"tokio-util",
"toml",
"tower 0.5.2",
"tower-cookies",
@@ -3557,9 +3559,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.13"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
|
||||
|
||||
@@ -80,7 +80,11 @@ async fn disabled_user() {

// Make sure the user can read their own file
let response = client.get(file_url.clone()).send().await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.status(),
StatusCode::OK,
"User should be able to read their own file"
);

let admin_socket = server.admin().listen_socket();
let admin_client = reqwest::Client::new();
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use pkarr::Keypair;
|
||||
use pkarr::{Keypair, PublicKey};
|
||||
use pubky_testnet::{
|
||||
pubky::Client,
|
||||
pubky_homeserver::{
|
||||
quota_config::{GlobPattern, LimitKey, LimitKeyType, PathLimit},
|
||||
ConfigToml, MockDataDir,
|
||||
@@ -186,3 +187,127 @@ async fn test_limit_upload() {
|
||||
assert_eq!(res.status(), StatusCode::CREATED);
|
||||
assert!(start.elapsed() > Duration::from_secs(2));
|
||||
}
|
||||
|
||||
/// Test that 10 clients can write/read to the server concurrently
|
||||
/// Upload/download rate is limited to 1kb/s per user.
|
||||
/// 3kb files are used to make the writes/reads take ~2.5s each.
|
||||
/// Concurrently writing/reading 10 files, the total time taken should be ~3s.
|
||||
/// If the concurrent writes/reads are not properly handled, the total time taken will be closer to ~25s.
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_write_read() {
|
||||
// Setup the testnet
|
||||
let mut testnet = Testnet::new().await.unwrap();
|
||||
let mut config = ConfigToml::test();
|
||||
config.drive.rate_limits = vec![
|
||||
PathLimit::new(
|
||||
// Limit uploads to 1kb/s per user
|
||||
GlobPattern::new("/pub/**"),
|
||||
Method::PUT,
|
||||
"1kb/s".parse().unwrap(),
|
||||
LimitKeyType::User,
|
||||
None,
|
||||
),
|
||||
PathLimit::new(
|
||||
// Limit downloads to 1kb/s per user
|
||||
GlobPattern::new("/pub/**"),
|
||||
Method::GET,
|
||||
"1kb/s".parse().unwrap(),
|
||||
LimitKeyType::User,
|
||||
None,
|
||||
),
|
||||
];
|
||||
let mock_dir = MockDataDir::new(config, None).unwrap();
|
||||
let hs_pubkey = {
|
||||
let server = testnet
|
||||
.create_homeserver_suite_with_mock(mock_dir)
|
||||
.await
|
||||
.unwrap();
|
||||
server.public_key()
|
||||
};
|
||||
|
||||
// Create helper struct to handle clients
|
||||
#[derive(Clone)]
|
||||
struct TestClient {
|
||||
pub keypair: Keypair,
|
||||
pub client: Client,
|
||||
}
|
||||
impl TestClient {
|
||||
fn new(testnet: &mut Testnet) -> Self {
|
||||
let keypair = Keypair::random();
|
||||
let client = testnet.pubky_client_builder().build().unwrap();
|
||||
Self { keypair, client }
|
||||
}
|
||||
pub async fn signup(&self, hs_pubkey: &PublicKey) {
|
||||
self.client
|
||||
.signup(&self.keypair, &hs_pubkey, None)
|
||||
.await
|
||||
.expect("Failed to signup");
|
||||
}
|
||||
pub async fn put(&self, url: Url, body: Vec<u8>) {
|
||||
self.client
|
||||
.put(url)
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to put");
|
||||
}
|
||||
pub async fn get(&self, url: Url) {
|
||||
let response = self.client.get(url).send().await.expect("Failed to get");
|
||||
assert_eq!(response.status(), StatusCode::OK, "Failed to get");
|
||||
response.bytes().await.expect("Failed to get bytes"); // Download the body
|
||||
}
|
||||
}
|
||||
|
||||
// Signup with the clients
|
||||
let user_count: usize = 10;
|
||||
let mut clients = vec![0; user_count]
|
||||
.into_iter()
|
||||
.map(|_| TestClient::new(&mut testnet))
|
||||
.collect::<Vec<_>>();
|
||||
for client in clients.iter_mut() {
|
||||
client.signup(&hs_pubkey).await;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------------------------
|
||||
// Write to server concurrently
|
||||
let start = Instant::now();
|
||||
let mut handles = vec![];
|
||||
for client in clients.iter() {
|
||||
let client = client.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
|
||||
.parse()
|
||||
.unwrap();
|
||||
let body = vec![0u8; 3 * 1024]; // 3kb
|
||||
client.put(url, body).await;
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
// Wait for all the writes to finish
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
assert!(elapsed < Duration::from_secs(5));
|
||||
|
||||
// --------------------------------------------------------------------------------------------
|
||||
// Read from server concurrently
|
||||
let start = Instant::now();
|
||||
let mut handles = vec![];
|
||||
for client in clients.iter() {
|
||||
let client = client.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let url: Url = format!("pubky://{}/pub/test.txt", client.keypair.public_key())
|
||||
.parse()
|
||||
.unwrap();
|
||||
client.get(url).await;
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
// Wait for all the reads to finish
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
assert!(elapsed < Duration::from_secs(5));
|
||||
}
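A back-of-the-envelope check of the timing bounds asserted above (assuming the limiter meters close to the configured rate; its burst allowance is why the comment above estimates ~25s rather than a strict 30s for the serialized case):

// Illustrative arithmetic only, not part of the change.
const FILE_BYTES: u64 = 3 * 1024; // 3kb per file
const RATE_BYTES_PER_SEC: u64 = 1024; // 1kb/s per user
const USERS: u64 = 10;

fn expected_seconds() -> (u64, u64) {
    let per_transfer = FILE_BYTES / RATE_BYTES_PER_SEC; // ~3s per file
    // Limits are keyed per user, so concurrent transfers finish in roughly one transfer time...
    let concurrent = per_transfer;
    // ...while transfers handled one after another would take roughly the sum.
    let serialized = per_transfer * USERS;
    (concurrent, serialized)
}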
|
||||
|
||||
@@ -55,6 +55,8 @@ dyn-clone = "1.0.19"
|
||||
reqwest = "0.12.15"
|
||||
governor = "0.10.0"
|
||||
fast-glob = "0.4.5"
|
||||
tokio-util = "0.7.15"
|
||||
percent-encoding = "2.3.1"
|
||||
serde_valid = "1.0.5"
|
||||
|
||||
|
||||
|
||||
@@ -25,10 +25,7 @@ fn create_protected_router(password: &str) -> Router<AppState> {
|
||||
get(generate_signup_token::generate_signup_token),
|
||||
)
|
||||
.route("/info", get(info::info))
|
||||
.route(
|
||||
"/webdav/{pubkey}/{*path}",
|
||||
delete(delete_entry::delete_entry),
|
||||
)
|
||||
.route("/webdav/{*entry_path}", delete(delete_entry::delete_entry))
|
||||
.route("/users/{pubkey}/disable", post(disable_user))
|
||||
.route("/users/{pubkey}/enable", post(enable_user))
|
||||
.layer(AdminAuthLayer::new(password.to_string()))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::super::app_state::AppState;
|
||||
use crate::shared::{HttpError, HttpResult, Z32Pubkey};
|
||||
use crate::shared::{webdav::EntryPath, HttpError, HttpResult};
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
@@ -9,16 +9,9 @@ use axum::{
|
||||
/// Delete a single entry from the database.
|
||||
pub async fn delete_entry(
|
||||
State(mut state): State<AppState>,
|
||||
Path((pubkey, path)): Path<(Z32Pubkey, String)>,
|
||||
Path(entry_path): Path<EntryPath>,
|
||||
) -> HttpResult<impl IntoResponse> {
|
||||
let path = format!("/{}", path); // Add missing leading slash
|
||||
if !path.starts_with("/pub/") {
|
||||
return Err(HttpError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Some("Invalid path"),
|
||||
));
|
||||
}
|
||||
let deleted = state.db.delete_entry(&pubkey.0, &path)?;
|
||||
let deleted = state.db.delete_entry(&entry_path).await?;
|
||||
if deleted {
|
||||
Ok((StatusCode::NO_CONTENT, ()))
|
||||
} else {
|
||||
@@ -33,16 +26,14 @@ pub async fn delete_entry(
|
||||
mod tests {
|
||||
use super::super::super::app_state::AppState;
|
||||
use super::*;
|
||||
use crate::persistence::lmdb::LmDB;
|
||||
use crate::persistence::lmdb::{tables::files::InDbTempFile, LmDB};
|
||||
use crate::shared::webdav::{EntryPath, WebDavPath};
|
||||
use axum::{routing::delete, Router};
|
||||
use pkarr::{Keypair, PublicKey};
|
||||
use std::io::Write;
|
||||
use pkarr::Keypair;
|
||||
|
||||
async fn write_test_file(db: &mut LmDB, pubkey: &PublicKey, path: &str) {
|
||||
let mut entry_writer = db.write_entry(pubkey, path).unwrap();
|
||||
let content = b"Hello, world!";
|
||||
entry_writer.write_all(content).unwrap();
|
||||
let _entry = entry_writer.commit().unwrap();
|
||||
async fn write_test_file(db: &mut LmDB, entry_path: &EntryPath) {
|
||||
let file = InDbTempFile::zeros(10).await.unwrap();
|
||||
let _entry = db.write_entry(&entry_path, &file).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -54,25 +45,25 @@ mod tests {
|
||||
let mut db = LmDB::test();
|
||||
let app_state = AppState::new(db.clone());
|
||||
let router = Router::new()
|
||||
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
|
||||
.route("/webdav/{*entry_path}", delete(delete_entry))
|
||||
.with_state(app_state);
|
||||
|
||||
// Write a test file
|
||||
let entry_path = format!("/pub/{}", file_path);
|
||||
write_test_file(&mut db, &pubkey, &entry_path).await;
|
||||
let webdav_path = WebDavPath::new(format!("/pub/{}", file_path).as_str()).unwrap();
|
||||
let entry_path = EntryPath::new(pubkey.clone(), webdav_path);
|
||||
|
||||
write_test_file(&mut db, &entry_path).await;
|
||||
|
||||
// Delete the file
|
||||
let server = axum_test::TestServer::new(router).unwrap();
|
||||
let response = server
|
||||
.delete(format!("/webdav/{}{}", pubkey, entry_path).as_str())
|
||||
.delete(format!("/webdav/{}{}", pubkey, entry_path.path().as_str()).as_str())
|
||||
.await;
|
||||
assert_eq!(response.status_code(), StatusCode::NO_CONTENT);
|
||||
|
||||
// Check that the file is deleted
|
||||
let rtx = db.env.read_txn().unwrap();
|
||||
let entry = db.get_entry(&rtx, &pubkey, &file_path).unwrap();
|
||||
let entry = db.get_entry(&entry_path).unwrap();
|
||||
assert!(entry.is_none(), "Entry should be deleted");
|
||||
rtx.commit().unwrap();
|
||||
|
||||
let events = db.list_events(None, None).unwrap();
|
||||
assert_eq!(
|
||||
@@ -92,14 +83,13 @@ mod tests {
|
||||
let file_path = "my_file.txt";
|
||||
let app_state = AppState::new(LmDB::test());
|
||||
let router = Router::new()
|
||||
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
|
||||
.route("/webdav/{*entry_path}", delete(delete_entry))
|
||||
.with_state(app_state);
|
||||
|
||||
// Delete the file
|
||||
let url = format!("/webdav/{}/pub/{}", pubkey, file_path);
|
||||
let server = axum_test::TestServer::new(router).unwrap();
|
||||
let response = server
|
||||
.delete(format!("/webdav/{}/pub/{}", pubkey, file_path).as_str())
|
||||
.await;
|
||||
let response = server.delete(url.as_str()).await;
|
||||
assert_eq!(response.status_code(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
@@ -109,7 +99,7 @@ mod tests {
|
||||
let db = LmDB::test();
|
||||
let app_state = AppState::new(db.clone());
|
||||
let router = Router::new()
|
||||
.route("/webdav/{pubkey}/{*path}", delete(delete_entry))
|
||||
.route("/webdav/{*entry_path}", delete(delete_entry))
|
||||
.with_state(app_state);
|
||||
|
||||
// Delete with invalid pubkey
|
||||
|
||||
@@ -3,11 +3,7 @@
|
||||
//! Every route here is relative to a tenant's Pubky host,
|
||||
//! as opposed to routes relative to the Homeserver's owner.
|
||||
|
||||
use axum::{
|
||||
extract::DefaultBodyLimit,
|
||||
routing::{delete, get, head, put},
|
||||
Router,
|
||||
};
|
||||
use axum::{extract::DefaultBodyLimit, routing::get, Router};
|
||||
|
||||
use crate::core::{layers::authz::AuthorizationLayer, AppState};
|
||||
|
||||
@@ -17,16 +13,14 @@ pub mod write;
|
||||
|
||||
pub fn router(state: AppState) -> Router<AppState> {
|
||||
Router::new()
|
||||
// - Datastore routes
|
||||
.route("/pub/", get(read::get))
|
||||
.route("/pub/{*path}", get(read::get))
|
||||
.route("/pub/{*path}", head(read::head))
|
||||
.route("/pub/{*path}", put(write::put))
|
||||
.route("/pub/{*path}", delete(write::delete))
|
||||
// - Session routes
|
||||
.route("/session", get(session::session))
|
||||
.route("/session", delete(session::signout))
|
||||
// Layers
|
||||
.route("/session", get(session::session).delete(session::signout))
|
||||
.route(
|
||||
"/{*path}",
|
||||
get(read::get)
|
||||
.head(read::head)
|
||||
.put(write::put)
|
||||
.delete(write::delete),
|
||||
)
|
||||
// TODO: different max size for sessions and other routes?
|
||||
.layer(DefaultBodyLimit::max(100 * 1024 * 1024))
|
||||
.layer(AuthorizationLayer::new(state.clone()))
|
||||
|
||||
@@ -1,90 +1,75 @@
|
||||
use crate::persistence::lmdb::tables::files::Entry;
|
||||
use crate::{
|
||||
core::{
|
||||
err_if_user_is_invalid::err_if_user_is_invalid,
|
||||
error::{Error, Result},
|
||||
extractors::{ListQueryParams, PubkyHost},
|
||||
AppState,
|
||||
},
|
||||
shared::webdav::{EntryPath, WebDavPathAxum},
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{OriginalUri, State},
|
||||
extract::{Path, State},
|
||||
http::{header, HeaderMap, HeaderValue, Response, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use httpdate::HttpDate;
|
||||
use pkarr::PublicKey;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::core::{
|
||||
error::{Error, Result},
|
||||
extractors::{ListQueryParams, PubkyHost},
|
||||
AppState,
|
||||
};
|
||||
use crate::persistence::lmdb::tables::entries::Entry;
|
||||
use tokio_util::io::ReaderStream;
|
||||
|
||||
pub async fn head(
|
||||
State(state): State<AppState>,
|
||||
pubky: PubkyHost,
|
||||
headers: HeaderMap,
|
||||
path: OriginalUri,
|
||||
Path(path): Path<WebDavPathAxum>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let rtxn = state.db.env.read_txn()?;
|
||||
get_entry(
|
||||
headers,
|
||||
state
|
||||
.db
|
||||
.get_entry(&rtxn, pubky.public_key(), path.0.path())?,
|
||||
None,
|
||||
)
|
||||
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
|
||||
let entry_path = EntryPath::new(pubky.public_key().clone(), path.0);
|
||||
let entry = state
|
||||
.db
|
||||
.get_entry(&entry_path)?
|
||||
.ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?;
|
||||
|
||||
get_entry(headers, entry, None)
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn get(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
pubky: PubkyHost,
|
||||
path: OriginalUri,
|
||||
Path(path): Path<WebDavPathAxum>,
|
||||
params: ListQueryParams,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let public_key = pubky.public_key().clone();
|
||||
let path = path.0.path().to_string();
|
||||
|
||||
if path.ends_with('/') {
|
||||
return list(state, &public_key, &path, params);
|
||||
let dav_path = path.0;
|
||||
let entry_path = EntryPath::new(public_key.clone(), dav_path);
|
||||
if entry_path.path().is_directory() {
|
||||
return list(state, &entry_path, params);
|
||||
}
|
||||
|
||||
let (entry_tx, entry_rx) = flume::bounded::<Option<Entry>>(1);
|
||||
let (chunks_tx, chunks_rx) = flume::unbounded::<std::result::Result<Vec<u8>, heed::Error>>();
|
||||
let entry = state
|
||||
.db
|
||||
.get_entry(&entry_path)?
|
||||
.ok_or_else(|| Error::with_status(StatusCode::NOT_FOUND))?;
|
||||
let buffer_file = state.db.read_file(&entry.file_id()).await?;
|
||||
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let rtxn = state.db.env.read_txn()?;
|
||||
|
||||
let option = state.db.get_entry(&rtxn, &public_key, &path)?;
|
||||
|
||||
if let Some(entry) = option {
|
||||
let iter = entry.read_content(&state.db, &rtxn)?;
|
||||
|
||||
entry_tx.send(Some(entry))?;
|
||||
|
||||
for next in iter {
|
||||
chunks_tx.send(next.map(|b| b.to_vec()))?;
|
||||
}
|
||||
};
|
||||
|
||||
entry_tx.send(None)?;
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
get_entry(
|
||||
headers,
|
||||
entry_rx.recv_async().await?,
|
||||
Some(Body::from_stream(chunks_rx.into_stream())),
|
||||
)
|
||||
let file_handle = buffer_file.open_file_handle()?;
|
||||
// Async stream the file
|
||||
let tokio_file_handle = tokio::fs::File::from_std(file_handle);
|
||||
let body_stream = Body::from_stream(ReaderStream::new(tokio_file_handle));
|
||||
get_entry(headers, entry, Some(body_stream))
|
||||
}
|
||||
|
||||
pub fn list(
|
||||
state: AppState,
|
||||
public_key: &PublicKey,
|
||||
path: &str,
|
||||
entry_path: &EntryPath,
|
||||
params: ListQueryParams,
|
||||
) -> Result<Response<Body>> {
|
||||
let txn = state.db.env.read_txn()?;
|
||||
let path = format!("{public_key}{path}");
|
||||
|
||||
if !state.db.contains_directory(&txn, &path)? {
|
||||
if !state.db.contains_directory(&txn, entry_path)? {
|
||||
return Err(Error::new(
|
||||
StatusCode::NOT_FOUND,
|
||||
"Directory Not Found".into(),
|
||||
@@ -92,9 +77,9 @@ pub fn list(
|
||||
}
|
||||
|
||||
// Handle listing
|
||||
let vec = state.db.list(
|
||||
let vec = state.db.list_entries(
|
||||
&txn,
|
||||
&path,
|
||||
entry_path,
|
||||
params.reverse,
|
||||
params.limit,
|
||||
params.cursor,
|
||||
@@ -107,54 +92,44 @@ pub fn list(
|
||||
.body(Body::from(vec.join("\n")))?)
|
||||
}
|
||||
|
||||
pub fn get_entry(
|
||||
headers: HeaderMap,
|
||||
entry: Option<Entry>,
|
||||
body: Option<Body>,
|
||||
) -> Result<Response<Body>> {
|
||||
if let Some(entry) = entry {
|
||||
// TODO: Enable seek API (range requests)
|
||||
// TODO: Gzip? or brotli?
|
||||
pub fn get_entry(headers: HeaderMap, entry: Entry, body: Option<Body>) -> Result<Response<Body>> {
|
||||
// TODO: Enable seek API (range requests)
|
||||
// TODO: Gzip? or brotli?
|
||||
|
||||
let mut response = HeaderMap::from(&entry).into_response();
|
||||
let mut response = HeaderMap::from(&entry).into_response();
|
||||
|
||||
// Handle IF_MODIFIED_SINCE
|
||||
if let Some(condition_http_date) = headers
|
||||
.get(header::IF_MODIFIED_SINCE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|s| HttpDate::from_str(s).ok())
|
||||
{
|
||||
let entry_http_date: HttpDate = entry.timestamp().to_owned().into();
|
||||
// Handle IF_MODIFIED_SINCE
|
||||
if let Some(condition_http_date) = headers
|
||||
.get(header::IF_MODIFIED_SINCE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|s| HttpDate::from_str(s).ok())
|
||||
{
|
||||
let entry_http_date: HttpDate = entry.timestamp().to_owned().into();
|
||||
|
||||
if condition_http_date >= entry_http_date {
|
||||
*response.status_mut() = StatusCode::NOT_MODIFIED;
|
||||
}
|
||||
};
|
||||
|
||||
// Handle IF_NONE_MATCH
|
||||
if let Some(str) = headers
|
||||
.get(header::IF_NONE_MATCH)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
let etag = format!("\"{}\"", entry.content_hash());
|
||||
if str
|
||||
.trim()
|
||||
.split(',')
|
||||
.collect::<Vec<_>>()
|
||||
.contains(&etag.as_str())
|
||||
{
|
||||
*response.status_mut() = StatusCode::NOT_MODIFIED;
|
||||
};
|
||||
if condition_http_date >= entry_http_date {
|
||||
*response.status_mut() = StatusCode::NOT_MODIFIED;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(body) = body {
|
||||
*response.body_mut() = body;
|
||||
// Handle IF_NONE_MATCH
|
||||
if let Some(str) = headers
|
||||
.get(header::IF_NONE_MATCH)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
{
|
||||
let etag = format!("\"{}\"", entry.content_hash());
|
||||
if str
|
||||
.trim()
|
||||
.split(',')
|
||||
.collect::<Vec<_>>()
|
||||
.contains(&etag.as_str())
|
||||
{
|
||||
*response.status_mut() = StatusCode::NOT_MODIFIED;
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
Err(Error::with_status(StatusCode::NOT_FOUND))?
|
||||
}
|
||||
if let Some(body) = body {
|
||||
*response.body_mut() = body;
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
impl From<&Entry> for HeaderMap {
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
use std::io::Write;
|
||||
|
||||
use axum::{
|
||||
body::{Body, HttpBody},
|
||||
extract::{OriginalUri, State},
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use futures_util::stream::StreamExt;
|
||||
|
||||
use crate::core::{
|
||||
err_if_user_is_invalid::err_if_user_is_invalid,
|
||||
error::{Error, Result},
|
||||
extractors::PubkyHost,
|
||||
AppState,
|
||||
use crate::{
|
||||
core::{
|
||||
err_if_user_is_invalid::err_if_user_is_invalid,
|
||||
error::{Error, Result},
|
||||
extractors::PubkyHost,
|
||||
AppState,
|
||||
},
|
||||
persistence::lmdb::tables::files::AsyncInDbTempFileWriter,
|
||||
shared::webdav::{EntryPath, WebDavPathAxum},
|
||||
};
|
||||
|
||||
/// Fail with 507 if `(current + incoming − existing) > quota`.
|
||||
/// Fail with 507 if `(current + incoming − existing) > quota`.
|
||||
fn enforce_user_disk_quota(
|
||||
existing_bytes: u64,
|
||||
incoming_bytes: u64,
|
||||
@@ -31,7 +33,7 @@ fn enforce_user_disk_quota(
|
||||
return Err(Error::new(
|
||||
StatusCode::INSUFFICIENT_STORAGE,
|
||||
Some(format!(
|
||||
"Quota of {:.1} MB exceeded: you’ve used {:.1} MB, trying to add {:.1} MB",
|
||||
"Quota of {:.1} MB exceeded: you've used {:.1} MB, trying to add {:.1} MB",
|
||||
max_mb, current_mb, adding_mb
|
||||
)),
|
||||
));
|
||||
@@ -43,15 +45,15 @@ fn enforce_user_disk_quota(
|
||||
pub async fn delete(
|
||||
State(mut state): State<AppState>,
|
||||
pubky: PubkyHost,
|
||||
path: OriginalUri,
|
||||
Path(path): Path<WebDavPathAxum>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
|
||||
let public_key = pubky.public_key();
|
||||
let full_path = path.0.path();
|
||||
let existing_bytes = state.db.get_entry_content_length(public_key, full_path)?;
|
||||
err_if_user_is_invalid(pubky.public_key(), &state.db, false)?;
|
||||
let entry_path = EntryPath::new(public_key.clone(), path.0);
|
||||
let existing_bytes = state.db.get_entry_content_length(&entry_path)?;
|
||||
|
||||
// Remove entry
|
||||
if !state.db.delete_entry(public_key, full_path)? {
|
||||
if !state.db.delete_entry(&entry_path).await? {
|
||||
return Err(Error::with_status(StatusCode::NOT_FOUND));
|
||||
}
|
||||
|
||||
@@ -66,36 +68,38 @@ pub async fn delete(
|
||||
pub async fn put(
|
||||
State(mut state): State<AppState>,
|
||||
pubky: PubkyHost,
|
||||
path: OriginalUri,
|
||||
Path(path): Path<WebDavPathAxum>,
|
||||
body: Body,
|
||||
) -> Result<impl IntoResponse> {
|
||||
err_if_user_is_invalid(pubky.public_key(), &state.db, true)?;
|
||||
let public_key = pubky.public_key();
|
||||
let full_path = path.0.path();
|
||||
let existing_entry_bytes = state.db.get_entry_content_length(public_key, full_path)?;
|
||||
err_if_user_is_invalid(public_key, &state.db, true)?;
|
||||
let entry_path = EntryPath::new(public_key.clone(), path.0);
|
||||
|
||||
let existing_entry_bytes = state.db.get_entry_content_length(&entry_path)?;
|
||||
let quota_bytes = state.user_quota_bytes;
|
||||
let used_bytes = state.db.get_user_data_usage(public_key)?;
|
||||
|
||||
// Upfront check when we have an exact Content‑Length
|
||||
// Upfront check when we have an exact Content-Length
|
||||
let hint = body.size_hint().exact();
|
||||
if let Some(exact_bytes) = hint {
|
||||
enforce_user_disk_quota(existing_entry_bytes, exact_bytes, used_bytes, quota_bytes)?;
|
||||
}
|
||||
|
||||
// Stream body
|
||||
let mut writer = state.db.write_entry(public_key, full_path)?;
|
||||
// Stream body to disk first.
|
||||
let mut seen_bytes: u64 = 0;
|
||||
let mut stream = body.into_data_stream();
|
||||
let mut buffer_file_writer = AsyncInDbTempFileWriter::new().await?;
|
||||
|
||||
while let Some(chunk) = stream.next().await.transpose()? {
|
||||
seen_bytes += chunk.len() as u64;
|
||||
enforce_user_disk_quota(existing_entry_bytes, seen_bytes, used_bytes, quota_bytes)?;
|
||||
writer.write_all(&chunk)?;
|
||||
buffer_file_writer.write_chunk(&chunk).await?;
|
||||
}
|
||||
let buffer_file = buffer_file_writer.complete().await?;
|
||||
|
||||
// Commit & bump usage
|
||||
let entry = writer.commit()?;
|
||||
let delta = entry.content_length() as i64 - existing_entry_bytes as i64;
|
||||
// Write file on disk to db
|
||||
state.db.write_entry(&entry_path, &buffer_file).await?;
|
||||
let delta = buffer_file.len() as i64 - existing_entry_bytes as i64;
|
||||
state.db.update_data_usage(public_key, delta)?;
|
||||
|
||||
Ok((StatusCode::CREATED, ()))
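For reference, the 507 rule documented above, `(current + incoming - existing) > quota`, charges an overwrite only for its size delta. A small illustrative check with invented numbers (the real enforcement lives in enforce_user_disk_quota):

// Illustrative only; mirrors the documented quota rule.
fn would_exceed(current: u64, incoming: u64, existing: u64, quota: u64) -> bool {
    current.saturating_add(incoming).saturating_sub(existing) > quota
}

// 9 MB used, 3 MB incoming, replacing an existing 2 MB entry, 10 MB quota:
// 9 + 3 - 2 = 10 MB, which does not exceed the quota, so the PUT is allowed.
// The same upload of a brand-new file (existing = 0) would reach 12 MB and fail with 507.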
|
||||
|
||||
@@ -11,7 +11,6 @@ pub const DEFAULT_MAP_SIZE: usize = 10995116277760; // 10TB (not = disk-space us
|
||||
pub struct LmDB {
|
||||
pub(crate) env: Env,
|
||||
pub(crate) tables: Tables,
|
||||
pub(crate) buffers_dir: PathBuf,
|
||||
pub(crate) max_chunk_size: usize,
|
||||
// Only used for testing purposes to keep the testdir alive.
|
||||
#[allow(dead_code)]
|
||||
@@ -44,7 +43,6 @@ impl LmDB {
|
||||
let db = LmDB {
|
||||
env,
|
||||
tables,
|
||||
buffers_dir,
|
||||
max_chunk_size: Self::max_chunk_size(),
|
||||
test_dir: None,
|
||||
};
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
use heed::{Env, RwTxn};
|
||||
|
||||
use crate::persistence::lmdb::tables::{blobs, entries, events, sessions, signup_tokens, users};
|
||||
use crate::persistence::lmdb::tables::{events, files, sessions, signup_tokens, users};
|
||||
|
||||
pub fn run(env: &Env, wtxn: &mut RwTxn) -> anyhow::Result<()> {
|
||||
let _: users::UsersTable = env.create_database(wtxn, Some(users::USERS_TABLE))?;
|
||||
|
||||
let _: sessions::SessionsTable = env.create_database(wtxn, Some(sessions::SESSIONS_TABLE))?;
|
||||
|
||||
let _: blobs::BlobsTable = env.create_database(wtxn, Some(blobs::BLOBS_TABLE))?;
|
||||
let _: files::BlobsTable = env.create_database(wtxn, Some(files::BLOBS_TABLE))?;
|
||||
|
||||
let _: entries::EntriesTable = env.create_database(wtxn, Some(entries::ENTRIES_TABLE))?;
|
||||
let _: files::EntriesTable = env.create_database(wtxn, Some(files::ENTRIES_TABLE))?;
|
||||
|
||||
let _: events::EventsTable = env.create_database(wtxn, Some(events::EVENTS_TABLE))?;
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
pub mod blobs;
|
||||
pub mod entries;
|
||||
pub mod events;
|
||||
pub mod files;
|
||||
pub mod sessions;
|
||||
pub mod signup_tokens;
|
||||
pub mod users;
|
||||
|
||||
use heed::{Env, RwTxn};
|
||||
|
||||
use blobs::{BlobsTable, BLOBS_TABLE};
|
||||
use entries::{EntriesTable, ENTRIES_TABLE};
|
||||
use files::{BlobsTable, BLOBS_TABLE};
|
||||
use files::{EntriesTable, ENTRIES_TABLE};
|
||||
|
||||
use self::{
|
||||
events::{EventsTable, EVENTS_TABLE},
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
use heed::{types::Bytes, Database, RoTxn};
|
||||
|
||||
use super::super::LmDB;
|
||||
|
||||
use super::entries::Entry;
|
||||
|
||||
/// (entry timestamp | chunk_index BE) => bytes
|
||||
pub type BlobsTable = Database<Bytes, Bytes>;
|
||||
|
||||
pub const BLOBS_TABLE: &str = "blobs";
|
||||
|
||||
impl LmDB {
|
||||
pub fn read_entry_content<'txn>(
|
||||
&self,
|
||||
rtxn: &'txn RoTxn,
|
||||
entry: &Entry,
|
||||
) -> anyhow::Result<impl Iterator<Item = Result<&'txn [u8], heed::Error>> + 'txn> {
|
||||
Ok(self
|
||||
.tables
|
||||
.blobs
|
||||
.prefix_iter(rtxn, &entry.timestamp().to_bytes())?
|
||||
.map(|i| i.map(|(_, bytes)| bytes)))
|
||||
}
|
||||
}
|
||||
@@ -1,560 +0,0 @@
|
||||
use pkarr::PublicKey;
|
||||
use postcard::{from_bytes, to_allocvec};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Write},
|
||||
path::PathBuf,
|
||||
};
|
||||
use tracing::instrument;
|
||||
|
||||
use heed::{
|
||||
types::{Bytes, Str},
|
||||
Database, RoTxn,
|
||||
};
|
||||
|
||||
use pubky_common::{
|
||||
crypto::{Hash, Hasher},
|
||||
timestamp::Timestamp,
|
||||
};
|
||||
|
||||
use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT};
|
||||
|
||||
use super::super::LmDB;
|
||||
|
||||
use super::events::Event;
|
||||
|
||||
/// full_path(pubky/*path) => Entry.
|
||||
pub type EntriesTable = Database<Str, Bytes>;
|
||||
|
||||
pub const ENTRIES_TABLE: &str = "entries";
|
||||
|
||||
impl LmDB {
|
||||
/// Write an entry by an author at a given path.
|
||||
///
|
||||
/// The path has to start with a forward slash `/`
|
||||
pub fn write_entry(
|
||||
&mut self,
|
||||
public_key: &PublicKey,
|
||||
path: &str,
|
||||
) -> anyhow::Result<EntryWriter> {
|
||||
EntryWriter::new(self, public_key, path)
|
||||
}
|
||||
|
||||
/// Delete an entry by an author at a given path.
|
||||
///
|
||||
/// The path has to start with a forward slash `/`
|
||||
pub fn delete_entry(&mut self, public_key: &PublicKey, path: &str) -> anyhow::Result<bool> {
|
||||
let mut wtxn = self.env.write_txn()?;
|
||||
|
||||
let key = format!("{public_key}{path}");
|
||||
|
||||
let deleted = if let Some(bytes) = self.tables.entries.get(&wtxn, &key)? {
|
||||
let entry = Entry::deserialize(bytes)?;
|
||||
|
||||
let mut deleted_chunks = false;
|
||||
|
||||
{
|
||||
let mut iter = self
|
||||
.tables
|
||||
.blobs
|
||||
.prefix_iter_mut(&mut wtxn, &entry.timestamp.to_bytes())?;
|
||||
|
||||
while iter.next().is_some() {
|
||||
unsafe {
|
||||
deleted_chunks = iter.del_current()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let deleted_entry = self.tables.entries.delete(&mut wtxn, &key)?;
|
||||
|
||||
// create DELETE event
|
||||
if path.starts_with("/pub/") {
|
||||
let url = format!("pubky://{key}");
|
||||
|
||||
let event = Event::delete(&url);
|
||||
let value = event.serialize();
|
||||
|
||||
let key = Timestamp::now().to_string();
|
||||
|
||||
self.tables.events.put(&mut wtxn, &key, &value)?;
|
||||
|
||||
// TODO: delete events older than a threshold.
|
||||
// TODO: move to events.rs
|
||||
}
|
||||
|
||||
deleted_entry && deleted_chunks
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
wtxn.commit()?;
|
||||
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
pub fn get_entry(
|
||||
&self,
|
||||
txn: &RoTxn,
|
||||
public_key: &PublicKey,
|
||||
path: &str,
|
||||
) -> anyhow::Result<Option<Entry>> {
|
||||
let key = format!("{public_key}{path}");
|
||||
|
||||
if let Some(bytes) = self.tables.entries.get(txn, &key)? {
|
||||
return Ok(Some(Entry::deserialize(bytes)?));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Bytes stored at `path` for this user (0 if none).
|
||||
pub fn get_entry_content_length(
|
||||
&self,
|
||||
public_key: &PublicKey,
|
||||
path: &str,
|
||||
) -> anyhow::Result<u64> {
|
||||
let txn = self.env.read_txn()?;
|
||||
let content_length = self
|
||||
.get_entry(&txn, public_key, path)?
|
||||
.map(|e| e.content_length() as u64)
|
||||
.unwrap_or(0);
|
||||
Ok(content_length)
|
||||
}
|
||||
|
||||
pub fn contains_directory(&self, txn: &RoTxn, path: &str) -> anyhow::Result<bool> {
|
||||
Ok(self.tables.entries.get_greater_than(txn, path)?.is_some())
|
||||
}
|
||||
|
||||
/// Return a list of pubky urls.
|
||||
///
|
||||
/// - limit defaults to [crate::config::DEFAULT_LIST_LIMIT] and capped by [crate::config::DEFAULT_MAX_LIST_LIMIT]
|
||||
pub fn list(
|
||||
&self,
|
||||
txn: &RoTxn,
|
||||
path: &str,
|
||||
reverse: bool,
|
||||
limit: Option<u16>,
|
||||
cursor: Option<String>,
|
||||
shallow: bool,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
// Vector to store results
|
||||
let mut results = Vec::new();
|
||||
|
||||
let limit = limit
|
||||
.unwrap_or(DEFAULT_LIST_LIMIT)
|
||||
.min(DEFAULT_MAX_LIST_LIMIT);
|
||||
|
||||
// TODO: make this more performant than split and allocations?
|
||||
|
||||
let mut threshold = cursor
|
||||
.map(|cursor| {
|
||||
// Removing leading forward slashes
|
||||
let mut file_or_directory = cursor.trim_start_matches('/');
|
||||
|
||||
if cursor.starts_with("pubky://") {
|
||||
file_or_directory = cursor.split(path).last().expect("should not be reachable")
|
||||
};
|
||||
|
||||
next_threshold(
|
||||
path,
|
||||
file_or_directory,
|
||||
file_or_directory.ends_with('/'),
|
||||
reverse,
|
||||
shallow,
|
||||
)
|
||||
})
|
||||
.unwrap_or(next_threshold(path, "", false, reverse, shallow));
|
||||
|
||||
for _ in 0..limit {
|
||||
if let Some((key, _)) = if reverse {
|
||||
self.tables.entries.get_lower_than(txn, &threshold)?
|
||||
} else {
|
||||
self.tables.entries.get_greater_than(txn, &threshold)?
|
||||
} {
|
||||
if !key.starts_with(path) {
|
||||
break;
|
||||
}
|
||||
|
||||
if shallow {
|
||||
let mut split = key[path.len()..].split('/');
|
||||
let file_or_directory = split.next().expect("should not be reachable");
|
||||
|
||||
let is_directory = split.next().is_some();
|
||||
|
||||
threshold =
|
||||
next_threshold(path, file_or_directory, is_directory, reverse, shallow);
|
||||
|
||||
results.push(format!(
|
||||
"pubky://{path}{file_or_directory}{}",
|
||||
if is_directory { "/" } else { "" }
|
||||
));
|
||||
} else {
|
||||
threshold = key.to_string();
|
||||
results.push(format!("pubky://{}", key))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the next threshold
|
||||
#[instrument]
|
||||
fn next_threshold(
|
||||
path: &str,
|
||||
file_or_directory: &str,
|
||||
is_directory: bool,
|
||||
reverse: bool,
|
||||
shallow: bool,
|
||||
) -> String {
|
||||
format!(
|
||||
"{path}{file_or_directory}{}",
|
||||
if file_or_directory.is_empty() {
|
||||
// No file_or_directory, early return
|
||||
if reverse {
|
||||
// `path/to/dir/\x7f` to catch all paths than `path/to/dir/`
|
||||
"\x7f"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
} else if shallow & is_directory {
|
||||
if reverse {
|
||||
// threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/`
|
||||
"\x2e"
|
||||
} else {
|
||||
//threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/`
|
||||
"\x7f"
|
||||
}
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct Entry {
|
||||
/// Encoding version
|
||||
version: usize,
|
||||
/// Modified at
|
||||
timestamp: Timestamp,
|
||||
content_hash: EntryHash,
|
||||
content_length: usize,
|
||||
content_type: String,
|
||||
// user_metadata: ?
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
struct EntryHash(Hash);
|
||||
|
||||
impl Default for EntryHash {
|
||||
fn default() -> Self {
|
||||
Self(Hash::from_bytes([0; 32]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for EntryHash {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let bytes = self.0.as_bytes();
|
||||
bytes.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for EntryHash {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
|
||||
Ok(Self(Hash::from_bytes(bytes)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
// === Setters ===
|
||||
|
||||
pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self {
|
||||
self.timestamp = *timestamp;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self {
|
||||
EntryHash(content_hash).clone_into(&mut self.content_hash);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_content_length(&mut self, content_length: usize) -> &mut Self {
|
||||
self.content_length = content_length;
|
||||
self
|
||||
}
|
||||
|
||||
// === Getters ===
|
||||
|
||||
pub fn timestamp(&self) -> &Timestamp {
|
||||
&self.timestamp
|
||||
}
|
||||
|
||||
pub fn content_hash(&self) -> &Hash {
|
||||
&self.content_hash.0
|
||||
}
|
||||
|
||||
pub fn content_length(&self) -> usize {
|
||||
self.content_length
|
||||
}
|
||||
|
||||
pub fn content_type(&self) -> &str {
|
||||
&self.content_type
|
||||
}
|
||||
|
||||
// === Public Method ===
|
||||
|
||||
pub fn read_content<'txn>(
|
||||
&self,
|
||||
db: &'txn LmDB,
|
||||
rtxn: &'txn RoTxn,
|
||||
) -> anyhow::Result<impl Iterator<Item = Result<&'txn [u8], heed::Error>> + 'txn> {
|
||||
db.read_entry_content(rtxn, self)
|
||||
}
|
||||
|
||||
pub fn serialize(&self) -> Vec<u8> {
|
||||
to_allocvec(self).expect("Session::serialize")
|
||||
}
|
||||
|
||||
pub fn deserialize(bytes: &[u8]) -> core::result::Result<Self, postcard::Error> {
|
||||
if bytes[0] > 0 {
|
||||
panic!("Unknown Entry version");
|
||||
}
|
||||
|
||||
from_bytes(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EntryWriter<'db> {
|
||||
db: &'db LmDB,
|
||||
buffer: File,
|
||||
hasher: Hasher,
|
||||
buffer_path: PathBuf,
|
||||
entry_key: String,
|
||||
timestamp: Timestamp,
|
||||
is_public: bool,
|
||||
}
|
||||
|
||||
impl<'db> EntryWriter<'db> {
|
||||
pub fn new(db: &'db LmDB, public_key: &PublicKey, path: &str) -> anyhow::Result<Self> {
|
||||
let hasher = Hasher::new();
|
||||
|
||||
let timestamp = Timestamp::now();
|
||||
|
||||
let buffer_path = db.buffers_dir.join(timestamp.to_string());
|
||||
|
||||
let buffer = File::create(&buffer_path)?;
|
||||
|
||||
let entry_key = format!("{public_key}{path}");
|
||||
|
||||
Ok(Self {
|
||||
db,
|
||||
buffer,
|
||||
hasher,
|
||||
buffer_path,
|
||||
entry_key,
|
||||
timestamp,
|
||||
is_public: path.starts_with("/pub/"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Same as [EntryWriter::write_all] but returns a Result of a mutable reference of itself
|
||||
/// to enable chaining with [Self::commit].
|
||||
pub fn update(&mut self, chunk: &[u8]) -> Result<&mut Self, std::io::Error> {
|
||||
self.write_all(chunk)?;
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Commit blob from the filesystem buffer to LMDB,
|
||||
/// write the [Entry], and commit the write transaction.
|
||||
pub fn commit(&self) -> anyhow::Result<Entry> {
|
||||
let hash = self.hasher.finalize();
|
||||
|
||||
let mut buffer = File::open(&self.buffer_path)?;
|
||||
|
||||
let mut wtxn = self.db.env.write_txn()?;
|
||||
|
||||
let mut chunk_key = [0; 12];
|
||||
chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes());
|
||||
|
||||
let mut chunk_index: u32 = 0;
|
||||
|
||||
loop {
|
||||
let mut chunk = vec![0_u8; self.db.max_chunk_size];
|
||||
|
||||
let bytes_read = buffer.read(&mut chunk)?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
break; // EOF reached
|
||||
}
|
||||
|
||||
chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes());
|
||||
|
||||
self.db
|
||||
.tables
|
||||
.blobs
|
||||
.put(&mut wtxn, &chunk_key, &chunk[..bytes_read])?;
|
||||
|
||||
chunk_index += 1;
|
||||
}
|
||||
|
||||
let mut entry = Entry::new();
|
||||
entry.set_timestamp(&self.timestamp);
|
||||
|
||||
entry.set_content_hash(hash);
|
||||
|
||||
let length = buffer.metadata()?.len();
|
||||
entry.set_content_length(length as usize);
|
||||
|
||||
self.db
|
||||
.tables
|
||||
.entries
|
||||
.put(&mut wtxn, &self.entry_key, &entry.serialize())?;
|
||||
|
||||
// Write a public [Event].
|
||||
if self.is_public {
|
||||
let url = format!("pubky://{}", self.entry_key);
|
||||
let event = Event::put(&url);
|
||||
let value = event.serialize();
|
||||
|
||||
let key = entry.timestamp.to_string();
|
||||
|
||||
self.db.tables.events.put(&mut wtxn, &key, &value)?;
|
||||
|
||||
// TODO: delete events older than a threshold.
|
||||
// TODO: move to events.rs
|
||||
}
|
||||
|
||||
wtxn.commit()?;
|
||||
|
||||
std::fs::remove_file(&self.buffer_path)?;
|
||||
|
||||
Ok(entry)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for EntryWriter<'_> {
|
||||
/// Write a chunk to a Filesystem based buffer.
|
||||
#[inline]
|
||||
fn write(&mut self, chunk: &[u8]) -> std::io::Result<usize> {
|
||||
self.hasher.update(chunk);
|
||||
self.buffer.write_all(chunk)?;
|
||||
|
||||
Ok(chunk.len())
|
||||
}
|
||||
|
||||
/// Does not do anything, you need to call [Self::commit]
|
||||
#[inline]
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use pkarr::Keypair;
|
||||
|
||||
use super::LmDB;
|
||||
|
||||
#[tokio::test]
|
||||
async fn entries() -> anyhow::Result<()> {
|
||||
let mut db = LmDB::test();
|
||||
|
||||
let keypair = Keypair::random();
|
||||
let public_key = keypair.public_key();
|
||||
let path = "/pub/foo.txt";
|
||||
|
||||
let chunk = Bytes::from(vec![1, 2, 3, 4, 5]);
|
||||
|
||||
db.write_entry(&public_key, path)?
|
||||
.update(&chunk)?
|
||||
.commit()?;
|
||||
|
||||
let rtxn = db.env.read_txn().unwrap();
|
||||
let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
entry.content_hash(),
|
||||
&[
|
||||
2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117,
|
||||
197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87
|
||||
]
|
||||
);
|
||||
|
||||
let mut blob = vec![];
|
||||
|
||||
{
|
||||
let mut iter = entry.read_content(&db, &rtxn).unwrap();
|
||||
|
||||
while let Some(Ok(chunk)) = iter.next() {
|
||||
blob.extend_from_slice(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(blob, vec![1, 2, 3, 4, 5]);
|
||||
|
||||
rtxn.commit().unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chunked_entry() -> anyhow::Result<()> {
|
||||
let mut db = LmDB::test();
|
||||
|
||||
let keypair = Keypair::random();
|
||||
let public_key = keypair.public_key();
|
||||
let path = "/pub/foo.txt";
|
||||
|
||||
let chunk = Bytes::from(vec![0; 1024 * 1024]);
|
||||
|
||||
db.write_entry(&public_key, path)?
|
||||
.update(&chunk)?
|
||||
.commit()?;
|
||||
|
||||
let rtxn = db.env.read_txn().unwrap();
|
||||
let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
entry.content_hash(),
|
||||
&[
|
||||
72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109,
|
||||
134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168
|
||||
]
|
||||
);
|
||||
|
||||
let mut blob = vec![];
|
||||
|
||||
{
|
||||
let mut iter = entry.read_content(&db, &rtxn).unwrap();
|
||||
|
||||
while let Some(Ok(chunk)) = iter.next() {
|
||||
blob.extend_from_slice(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(blob, vec![0; 1024 * 1024]);
|
||||
|
||||
let stats = db.tables.blobs.stat(&rtxn).unwrap();
|
||||
assert_eq!(stats.overflow_pages, 0);
|
||||
|
||||
rtxn.commit().unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
pubky-homeserver/src/persistence/lmdb/tables/files/blobs.rs (new file, 160 lines)
@@ -0,0 +1,160 @@
|
||||
use super::super::super::LmDB;
|
||||
use super::{InDbFileId, InDbTempFile, SyncInDbTempFileWriter};
|
||||
use heed::{types::Bytes, Database};
|
||||
use std::io::Read;
|
||||
|
||||
/// (entry timestamp | chunk_index BE) => bytes
|
||||
pub type BlobsTable = Database<Bytes, Bytes>;
|
||||
pub const BLOBS_TABLE: &str = "blobs";
|
||||
|
||||
impl LmDB {
|
||||
/// Read the blobs into a temporary file.
|
||||
///
|
||||
/// The file is written to disk to minimize the size/duration of the LMDB transaction.
|
||||
pub(crate) fn read_file_sync(&self, id: &InDbFileId) -> anyhow::Result<InDbTempFile> {
|
||||
let mut file_writer = SyncInDbTempFileWriter::new()?;
|
||||
let rtxn = self.env.read_txn()?;
|
||||
let blobs_iter = self
|
||||
.tables
|
||||
.blobs
|
||||
.prefix_iter(&rtxn, &id.bytes())?
|
||||
.map(|i| i.map(|(_, bytes)| bytes));
|
||||
let mut file_exists = false;
|
||||
for read_result in blobs_iter {
|
||||
file_exists = true;
|
||||
let chunk = read_result?;
|
||||
file_writer.write_chunk(chunk)?;
|
||||
}
|
||||
|
||||
if !file_exists {
|
||||
return Ok(InDbTempFile::empty()?);
|
||||
}
|
||||
|
||||
let file = file_writer.complete()?;
|
||||
rtxn.commit()?;
|
||||
Ok(file)
|
||||
}
|
||||
|
||||
/// Read the blobs into a temporary file asynchronously.
|
||||
pub(crate) async fn read_file(&self, id: &InDbFileId) -> anyhow::Result<InDbTempFile> {
|
||||
let db = self.clone();
|
||||
let id = *id;
|
||||
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<InDbTempFile> {
|
||||
db.read_file_sync(&id)
|
||||
})
|
||||
.await;
|
||||
match join_handle {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading file. JoinError: {:?}", e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
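The async wrappers in this module (read_file here, and write_entry / delete_entry in entries.rs) all share the same shape: run the synchronous LMDB work on the blocking thread pool and log a JoinError instead of unwrapping it. A generic sketch of that pattern, with a hypothetical helper name:

// Hypothetical helper illustrating the spawn_blocking pattern used above:
// run a blocking LMDB operation off the async runtime and surface a JoinError
// (panic or cancellation) as an error instead of unwrapping it.
async fn run_blocking<T, F>(op: F) -> anyhow::Result<T>
where
    T: Send + 'static,
    F: FnOnce() -> anyhow::Result<T> + Send + 'static,
{
    match tokio::task::spawn_blocking(op).await {
        Ok(result) => result,
        Err(e) => {
            tracing::error!("blocking LMDB task failed to join: {:?}", e);
            Err(e.into())
        }
    }
}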
|
||||
|
||||
/// Write the blobs from a temporary file to LMDB.
|
||||
pub(crate) fn write_file_sync<'txn>(
|
||||
&'txn self,
|
||||
file: &InDbTempFile,
|
||||
wtxn: &mut heed::RwTxn<'txn>,
|
||||
) -> anyhow::Result<InDbFileId> {
|
||||
let id = InDbFileId::new();
|
||||
let mut file_handle = file.open_file_handle()?;
|
||||
|
||||
let mut blob_index: u32 = 0;
|
||||
loop {
|
||||
let mut blob = vec![0_u8; self.max_chunk_size];
|
||||
let bytes_read = file_handle.read(&mut blob)?;
|
||||
let is_end_of_file = bytes_read == 0;
|
||||
if is_end_of_file {
|
||||
break; // EOF reached
|
||||
}
|
||||
|
||||
let blob_key = id.get_blob_key(blob_index);
|
||||
self.tables
|
||||
.blobs
|
||||
.put(wtxn, &blob_key, &blob[..bytes_read])?;
|
||||
|
||||
blob_index += 1;
|
||||
}
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Delete the blobs from LMDB.
|
||||
pub(crate) fn delete_file<'txn>(
|
||||
&'txn self,
|
||||
file: &InDbFileId,
|
||||
wtxn: &mut heed::RwTxn<'txn>,
|
||||
) -> anyhow::Result<bool> {
|
||||
let mut deleted_chunks = false;
|
||||
|
||||
{
|
||||
let mut iter = self.tables.blobs.prefix_iter_mut(wtxn, &file.bytes())?;
|
||||
|
||||
while iter.next().is_some() {
|
||||
unsafe {
|
||||
deleted_chunks = iter.del_current()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(deleted_chunks)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_read_delete_file() {
|
||||
let lmdb = LmDB::test();
|
||||
|
||||
// Write file to LMDB
|
||||
let write_file = InDbTempFile::zeros(50).await.unwrap();
|
||||
let mut wtxn = lmdb.env.write_txn().unwrap();
|
||||
let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap();
|
||||
wtxn.commit().unwrap();
|
||||
|
||||
// Read file from LMDB
|
||||
let read_file = lmdb.read_file(&id).await.unwrap();
|
||||
|
||||
assert_eq!(read_file.len(), write_file.len());
|
||||
assert_eq!(read_file.hash(), write_file.hash());
|
||||
|
||||
let written_file_content = std::fs::read(write_file.path()).unwrap();
|
||||
let read_file_content = std::fs::read(read_file.path()).unwrap();
|
||||
assert_eq!(written_file_content, read_file_content);
|
||||
|
||||
// Delete file from LMDB
|
||||
let mut wtxn = lmdb.env.write_txn().unwrap();
|
||||
let deleted = lmdb.delete_file(&id, &mut wtxn).unwrap();
|
||||
wtxn.commit().unwrap();
|
||||
assert!(deleted);
|
||||
|
||||
// Try to read file from LMDB
|
||||
let read_file = lmdb.read_file(&id).await.unwrap();
|
||||
assert_eq!(read_file.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_empty_file() {
|
||||
let lmdb = LmDB::test();
|
||||
|
||||
// Write file to LMDB
|
||||
let write_file = InDbTempFile::empty().unwrap();
|
||||
let mut wtxn = lmdb.env.write_txn().unwrap();
|
||||
let id = lmdb.write_file_sync(&write_file, &mut wtxn).unwrap();
|
||||
wtxn.commit().unwrap();
|
||||
|
||||
// Read file from LMDB
|
||||
let read_file = lmdb.read_file(&id).await.unwrap();
|
||||
|
||||
assert_eq!(read_file.len(), write_file.len());
|
||||
assert_eq!(read_file.hash(), write_file.hash());
|
||||
|
||||
let written_file_content = std::fs::read(write_file.path()).unwrap();
|
||||
let read_file_content = std::fs::read(read_file.path()).unwrap();
|
||||
assert_eq!(written_file_content, read_file_content);
|
||||
}
|
||||
}
|
||||
pubky-homeserver/src/persistence/lmdb/tables/files/entries.rs (new file, 499 lines)
@@ -0,0 +1,499 @@
|
||||
use super::super::events::Event;
|
||||
use super::{super::super::LmDB, InDbFileId, InDbTempFile};
|
||||
use crate::constants::{DEFAULT_LIST_LIMIT, DEFAULT_MAX_LIST_LIMIT};
|
||||
use crate::shared::webdav::EntryPath;
|
||||
use heed::{
|
||||
types::{Bytes, Str},
|
||||
Database, RoTxn,
|
||||
};
|
||||
use postcard::{from_bytes, to_allocvec};
|
||||
use pubky_common::{crypto::Hash, timestamp::Timestamp};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::instrument;
|
||||
|
||||
/// full_path(pubky/*path) => Entry.
|
||||
pub type EntriesTable = Database<Str, Bytes>;
|
||||
|
||||
pub const ENTRIES_TABLE: &str = "entries";
|
||||
|
||||
impl LmDB {
|
||||
/// Writes an entry to the database.
|
||||
///
|
||||
/// The entry is written to the database and the file is written to the blob store.
|
||||
/// An event is written to the events table.
|
||||
/// The entry is returned.
|
||||
pub async fn write_entry(
|
||||
&mut self,
|
||||
path: &EntryPath,
|
||||
file: &InDbTempFile,
|
||||
) -> anyhow::Result<Entry> {
|
||||
let mut db = self.clone();
|
||||
let path = path.clone();
|
||||
let file = file.clone();
|
||||
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<Entry> {
|
||||
db.write_entry_sync(&path, &file)
|
||||
})
|
||||
.await;
|
||||
match join_handle {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
tracing::error!("Error writing entry. JoinError: {:?}", e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes an entry to the database.
|
||||
///
|
||||
/// The entry is written to the database and the file is written to the blob store.
|
||||
/// An event is written to the events table.
|
||||
/// The entry is returned.
|
||||
pub fn write_entry_sync(
|
||||
&mut self,
|
||||
path: &EntryPath,
|
||||
file: &InDbTempFile,
|
||||
) -> anyhow::Result<Entry> {
|
||||
let mut wtxn = self.env.write_txn()?;
|
||||
let mut entry = Entry::new();
|
||||
entry.set_content_hash(*file.hash());
|
||||
entry.set_content_length(file.len());
|
||||
let file_id = self.write_file_sync(file, &mut wtxn)?;
|
||||
entry.set_timestamp(file_id.timestamp());
|
||||
let entry_key = path.to_string();
|
||||
self.tables
|
||||
.entries
|
||||
.put(&mut wtxn, entry_key.as_str(), &entry.serialize())?;
|
||||
|
||||
// Write a public [Event].
|
||||
let url = format!("pubky://{}", entry_key);
|
||||
let event = Event::put(&url);
|
||||
let value = event.serialize();
|
||||
|
||||
self.tables
|
||||
.events
|
||||
.put(&mut wtxn, file_id.timestamp().to_string().as_str(), &value)?;
|
||||
wtxn.commit()?;
|
||||
|
||||
Ok(entry)
|
||||
}
|
||||
|
||||
/// Get an entry from the database.
|
||||
/// This doesn't include the file but only metadata.
|
||||
pub fn get_entry(&self, path: &EntryPath) -> anyhow::Result<Option<Entry>> {
|
||||
let txn = self.env.read_txn()?;
|
||||
let entry = match self.tables.entries.get(&txn, path.as_str())? {
|
||||
Some(bytes) => Entry::deserialize(bytes)?,
|
||||
None => return Ok(None),
|
||||
};
|
||||
Ok(Some(entry))
|
||||
}
|
||||
|
||||
/// Delete an entry including the associated file from the database.
|
||||
pub async fn delete_entry(&mut self, path: &EntryPath) -> anyhow::Result<bool> {
|
||||
let mut db = self.clone();
|
||||
let path = path.clone();
|
||||
let join_handle = tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
|
||||
db.delete_entry_sync(&path)
|
||||
})
|
||||
.await;
|
||||
match join_handle {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
tracing::error!("Error deleting entry. JoinError: {:?}", e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete an entry including the associated file from the database.
|
||||
pub fn delete_entry_sync(&mut self, path: &EntryPath) -> anyhow::Result<bool> {
|
||||
let entry = match self.get_entry(path)? {
|
||||
Some(entry) => entry,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let mut wtxn = self.env.write_txn()?;
|
||||
let deleted = self.delete_file(&entry.file_id(), &mut wtxn)?;
|
||||
if !deleted {
|
||||
wtxn.abort();
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let deleted = self.tables.entries.delete(&mut wtxn, path.as_str())?;
|
||||
if !deleted {
|
||||
wtxn.abort();
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// create DELETE event
|
||||
let url = format!("pubky://{}", path.as_str());
|
||||
|
||||
let event = Event::delete(&url);
|
||||
let value = event.serialize();
|
||||
|
||||
let key = Timestamp::now().to_string();
|
||||
|
||||
self.tables.events.put(&mut wtxn, &key, &value)?;
|
||||
|
||||
wtxn.commit()?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Bytes stored at `path` for this user (0 if none).
|
||||
pub fn get_entry_content_length(&self, path: &EntryPath) -> anyhow::Result<u64> {
|
||||
let content_length = self
|
||||
.get_entry(path)?
|
||||
.map(|e| e.content_length() as u64)
|
||||
.unwrap_or(0);
|
||||
Ok(content_length)
|
||||
}
|
||||
|
||||
pub fn contains_directory(&self, txn: &RoTxn, entry_path: &EntryPath) -> anyhow::Result<bool> {
|
||||
Ok(self
|
||||
.tables
|
||||
.entries
|
||||
.get_greater_than(txn, entry_path.as_str())?
|
||||
.is_some())
|
||||
}
|
||||
|
||||
/// Return a list of pubky urls.
|
||||
///
|
||||
/// - limit defaults to [crate::config::DEFAULT_LIST_LIMIT] and is capped by [crate::config::DEFAULT_MAX_LIST_LIMIT]
|
||||
pub fn list_entries(
|
||||
&self,
|
||||
txn: &RoTxn,
|
||||
entry_path: &EntryPath,
|
||||
reverse: bool,
|
||||
limit: Option<u16>,
|
||||
cursor: Option<String>,
|
||||
shallow: bool,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
// Vector to store results
|
||||
let mut results = Vec::new();
|
||||
|
||||
let limit = limit
|
||||
.unwrap_or(DEFAULT_LIST_LIMIT)
|
||||
.min(DEFAULT_MAX_LIST_LIMIT);
|
||||
|
||||
// TODO: make this more performant than split and allocations?
|
||||
|
||||
let mut threshold = cursor
|
||||
.map(|cursor| {
|
||||
// Removing leading forward slashes
|
||||
let mut file_or_directory = cursor.trim_start_matches('/');
|
||||
|
||||
if cursor.starts_with("pubky://") {
|
||||
file_or_directory = cursor
|
||||
.split(entry_path.as_str())
|
||||
.last()
|
||||
.expect("should not be reachable")
|
||||
};
|
||||
|
||||
next_threshold(
|
||||
entry_path.as_str(),
|
||||
file_or_directory,
|
||||
file_or_directory.ends_with('/'),
|
||||
reverse,
|
||||
shallow,
|
||||
)
|
||||
})
|
||||
.unwrap_or(next_threshold(
|
||||
entry_path.as_str(),
|
||||
"",
|
||||
false,
|
||||
reverse,
|
||||
shallow,
|
||||
));
|
||||
|
||||
for _ in 0..limit {
|
||||
if let Some((key, _)) = if reverse {
|
||||
self.tables.entries.get_lower_than(txn, &threshold)?
|
||||
} else {
|
||||
self.tables.entries.get_greater_than(txn, &threshold)?
|
||||
} {
|
||||
if !key.starts_with(entry_path.as_str()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if shallow {
|
||||
let mut split = key[entry_path.as_str().len()..].split('/');
|
||||
let file_or_directory = split.next().expect("should not be reachable");
|
||||
|
||||
let is_directory = split.next().is_some();
|
||||
|
||||
threshold = next_threshold(
|
||||
entry_path.as_str(),
|
||||
file_or_directory,
|
||||
is_directory,
|
||||
reverse,
|
||||
shallow,
|
||||
);
|
||||
|
||||
results.push(format!(
|
||||
"pubky://{}{file_or_directory}{}",
|
||||
entry_path.as_str(),
|
||||
if is_directory { "/" } else { "" }
|
||||
));
|
||||
} else {
|
||||
threshold = key.to_string();
|
||||
results.push(format!("pubky://{}", key))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the next threshold
|
||||
#[instrument]
|
||||
fn next_threshold(
|
||||
path: &str,
|
||||
file_or_directory: &str,
|
||||
is_directory: bool,
|
||||
reverse: bool,
|
||||
shallow: bool,
|
||||
) -> String {
|
||||
format!(
|
||||
"{path}{file_or_directory}{}",
|
||||
if file_or_directory.is_empty() {
|
||||
// No file_or_directory, early return
|
||||
if reverse {
|
||||
// `path/to/dir/\x7f` to catch all paths under `path/to/dir/`
|
||||
"\x7f"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
} else if shallow & is_directory {
|
||||
if reverse {
|
||||
// threshold = `path/to/dir\x2e`, since `\x2e` is lower than `/`
|
||||
"\x2e"
|
||||
} else {
|
||||
//threshold = `path/to/dir\x7f`, since `\x7f` is greater than `/`
|
||||
"\x7f"
|
||||
}
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
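
// Illustrative sketch (not part of this diff): how `next_threshold` drives a shallow,
// forward listing. The prefix value below is a made-up example key; the expected
// strings are what the `format!` above produces for these inputs.
#[cfg(test)]
mod next_threshold_example {
    use super::next_threshold;

    #[test]
    fn shallow_forward_listing_skips_directory_contents() {
        let prefix = "userpubkey/pub/";
        // Empty cursor: the threshold is the listing prefix itself.
        assert_eq!(
            next_threshold(prefix, "", false, false, true),
            prefix.to_string()
        );
        // After yielding the directory `images/`, appending `\x7f` (> '/') jumps
        // past all of its children in lexicographic order.
        assert_eq!(
            next_threshold(prefix, "images", true, false, true),
            format!("{prefix}images\x7f")
        );
        // A plain file does not need the jump.
        assert_eq!(
            next_threshold(prefix, "a.txt", false, false, true),
            format!("{prefix}a.txt")
        );
    }
}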

#[derive(Clone, Default, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Entry {
    /// Encoding version
    version: usize,
    /// Modified at
    timestamp: Timestamp,
    content_hash: EntryHash,
    content_length: usize,
    content_type: String,
}

#[derive(Clone, Debug, Eq, PartialEq)]
struct EntryHash(Hash);

impl Default for EntryHash {
    fn default() -> Self {
        Self(Hash::from_bytes([0; 32]))
    }
}

impl Serialize for EntryHash {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let bytes = self.0.as_bytes();
        bytes.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for EntryHash {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
        Ok(Self(Hash::from_bytes(bytes)))
    }
}

impl Entry {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn chunk_key(&self, chunk_index: u32) -> [u8; 12] {
        let mut chunk_key = [0; 12];
        chunk_key[0..8].copy_from_slice(&self.timestamp.to_bytes());
        chunk_key[8..].copy_from_slice(&chunk_index.to_be_bytes());
        chunk_key
    }

    // === Setters ===

    pub fn set_timestamp(&mut self, timestamp: &Timestamp) -> &mut Self {
        self.timestamp = *timestamp;
        self
    }

    pub fn set_content_hash(&mut self, content_hash: Hash) -> &mut Self {
        EntryHash(content_hash).clone_into(&mut self.content_hash);
        self
    }

    pub fn set_content_length(&mut self, content_length: usize) -> &mut Self {
        self.content_length = content_length;
        self
    }

    // === Getters ===

    pub fn timestamp(&self) -> &Timestamp {
        &self.timestamp
    }

    pub fn content_hash(&self) -> &Hash {
        &self.content_hash.0
    }

    pub fn content_length(&self) -> usize {
        self.content_length
    }

    pub fn content_type(&self) -> &str {
        &self.content_type
    }

    pub fn file_id(&self) -> InDbFileId {
        InDbFileId::from(self.timestamp)
    }

    // === Public Methods ===

    pub fn serialize(&self) -> Vec<u8> {
        to_allocvec(self).expect("Entry::serialize")
    }

    pub fn deserialize(bytes: &[u8]) -> core::result::Result<Self, postcard::Error> {
        if bytes[0] > 0 {
            panic!("Unknown Entry version");
        }

        from_bytes(bytes)
    }
}
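
// Illustrative sketch (not part of this diff): an `Entry` round-trips through its
// postcard encoding, and the version byte checked in `deserialize` stays 0 for the
// current layout. The field values used here are arbitrary.
#[cfg(test)]
mod entry_roundtrip_example {
    use super::Entry;
    use pubky_common::timestamp::Timestamp;

    #[test]
    fn serialize_then_deserialize_is_identity() {
        let mut entry = Entry::new();
        entry
            .set_timestamp(&Timestamp::now())
            .set_content_length(42);
        let bytes = entry.serialize();
        let decoded = Entry::deserialize(&bytes).expect("valid version byte");
        assert_eq!(entry, decoded);
    }
}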

#[cfg(test)]
mod tests {
    use super::LmDB;
    use crate::{
        persistence::lmdb::tables::files::{InDbTempFile, SyncInDbTempFileWriter},
        shared::webdav::{EntryPath, WebDavPath},
    };
    use bytes::Bytes;
    use pkarr::Keypair;
    use std::io::Read;

    #[tokio::test]
    async fn test_write_read_delete_method() {
        let mut db = LmDB::test();

        let path = EntryPath::new(
            Keypair::random().public_key(),
            WebDavPath::new("/pub/foo.txt").unwrap(),
        );
        let file = InDbTempFile::zeros(5).await.unwrap();
        let entry = db.write_entry_sync(&path, &file).unwrap();

        let read_entry = db.get_entry(&path).unwrap().expect("Entry doesn't exist");
        assert_eq!(entry.content_hash(), read_entry.content_hash());
        assert_eq!(entry.content_length(), read_entry.content_length());
        assert_eq!(entry.timestamp(), read_entry.timestamp());

        let read_file = db.read_file(&entry.file_id()).await.unwrap();
        let mut file_handle = read_file.open_file_handle().unwrap();
        let mut content = vec![];
        file_handle.read_to_end(&mut content).unwrap();
        assert_eq!(content, vec![0, 0, 0, 0, 0]);

        let deleted = db.delete_entry_sync(&path).unwrap();
        assert!(deleted);

        // Verify the entry and file are deleted
        let read_entry = db.get_entry(&path).unwrap();
        assert!(read_entry.is_none());
        let read_file = db.read_file(&entry.file_id()).await.unwrap();
        assert_eq!(read_file.len(), 0);
    }

    #[tokio::test]
    async fn entries() -> anyhow::Result<()> {
        let mut db = LmDB::test();

        let keypair = Keypair::random();
        let public_key = keypair.public_key();
        let path = "/pub/foo.txt";

        let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap());
        let chunk = Bytes::from(vec![1, 2, 3, 4, 5]);
        let mut writer = SyncInDbTempFileWriter::new()?;
        writer.write_chunk(&chunk)?;
        let file = writer.complete()?;

        db.write_entry_sync(&entry_path, &file)?;

        let entry = db.get_entry(&entry_path).unwrap().unwrap();

        assert_eq!(
            entry.content_hash(),
            &[
                2, 79, 103, 192, 66, 90, 61, 192, 47, 186, 245, 140, 185, 61, 229, 19, 46, 61, 117,
                197, 25, 250, 160, 186, 218, 33, 73, 29, 136, 201, 112, 87
            ]
        );

        let read_file = db.read_file(&entry.file_id()).await.unwrap();
        let mut file_handle = read_file.open_file_handle().unwrap();
        let mut content = vec![];
        file_handle.read_to_end(&mut content).unwrap();
        assert_eq!(content, vec![1, 2, 3, 4, 5]);
        Ok(())
    }

    #[tokio::test]
    async fn chunked_entry() -> anyhow::Result<()> {
        let mut db = LmDB::test();

        let keypair = Keypair::random();
        let public_key = keypair.public_key();
        let path = "/pub/foo.txt";
        let entry_path = EntryPath::new(public_key, WebDavPath::new(path).unwrap());

        let chunk = Bytes::from(vec![0; 1024 * 1024]);

        let mut writer = SyncInDbTempFileWriter::new()?;
        writer.write_chunk(&chunk)?;
        let file = writer.complete()?;

        db.write_entry_sync(&entry_path, &file)?;

        let entry = db.get_entry(&entry_path).unwrap().unwrap();

        assert_eq!(
            entry.content_hash(),
            &[
                72, 141, 226, 2, 247, 59, 217, 118, 222, 78, 112, 72, 244, 225, 243, 154, 119, 109,
                134, 213, 130, 183, 52, 143, 245, 59, 244, 50, 185, 135, 252, 168
            ]
        );

        let read_file = db.read_file(&entry.file_id()).await.unwrap();
        let mut file_handle = read_file.open_file_handle().unwrap();
        let mut content = vec![];
        file_handle.read_to_end(&mut content).unwrap();
        assert_eq!(content, vec![0; 1024 * 1024]);

        Ok(())
    }
}
236
pubky-homeserver/src/persistence/lmdb/tables/files/in_db_file.rs
Normal file
@@ -0,0 +1,236 @@
//!
//! InDbFile is an abstraction over the way we store blobs/chunks in LMDB.
//! Because the value size of LMDB is limited, we need to store multiple blobs for one file.
//!
//! - `InDbFileId` is the identifier of a file that consists of multiple blobs.
//! - `InDbTempFile` is a helper to read/write a file to/from disk.
//!
use pubky_common::crypto::{Hash, Hasher};
use pubky_common::timestamp::Timestamp;

/// A file identifier for a file stored in LMDB.
/// The identifier is simply the timestamp at which the file was written.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InDbFileId(Timestamp);

impl InDbFileId {
    pub fn new() -> Self {
        Self(Timestamp::now())
    }

    pub fn timestamp(&self) -> &Timestamp {
        &self.0
    }

    pub fn bytes(&self) -> [u8; 8] {
        self.0.to_bytes()
    }

    /// Create a blob key from a timestamp and a blob index.
    /// blob key = (timestamp | blob_index) => bytes.
    /// A file can consist of at most 2^32 blobs.
    pub fn get_blob_key(&self, blob_index: u32) -> [u8; 12] {
        let mut blob_key = [0; 12];
        blob_key[0..8].copy_from_slice(&self.bytes());
        blob_key[8..].copy_from_slice(&blob_index.to_be_bytes());
        blob_key
    }
}
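
// Illustrative sketch (not part of this diff): the 12-byte blob key is the 8-byte
// timestamp followed by the 4-byte big-endian blob index, so all blobs of one file
// share a prefix and sort in blob order.
#[cfg(test)]
mod blob_key_example {
    use super::InDbFileId;

    #[test]
    fn blob_keys_share_a_prefix_and_sort_by_index() {
        let id = InDbFileId::new();
        let first = id.get_blob_key(0);
        let second = id.get_blob_key(1);
        let ts = id.bytes();
        assert_eq!(first[0..8], ts[..]);
        assert_eq!(first[8..], 0u32.to_be_bytes()[..]);
        assert!(first < second, "keys of the same file sort by blob index");
    }
}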

impl From<Timestamp> for InDbFileId {
    fn from(timestamp: Timestamp) -> Self {
        Self(timestamp)
    }
}

impl Default for InDbFileId {
    fn default() -> Self {
        Self::new()
    }
}

use std::sync::Arc;
use std::{fs::File, io::Write, path::PathBuf};
use tokio::fs::File as AsyncFile;
use tokio::io::AsyncWriteExt;
use tokio::task;

/// Writes a temp file to disk asynchronously.
#[derive(Debug)]
pub(crate) struct AsyncInDbTempFileWriter {
    // Temp dir is automatically deleted when it is dropped.
    #[allow(dead_code)]
    dir: tempfile::TempDir,
    writer_file: AsyncFile,
    file_path: PathBuf,
    hasher: Hasher,
}

impl AsyncInDbTempFileWriter {
    pub async fn new() -> Result<Self, std::io::Error> {
        let dir = task::spawn_blocking(tempfile::tempdir)
            .await
            .map_err(|join_error| {
                std::io::Error::other(format!(
                    "Task join error for tempdir creation: {}",
                    join_error
                ))
            })??; // Handles the Result from tempfile::tempdir()

        let file_path = dir.path().join("entry.bin");
        let writer_file = AsyncFile::create(file_path.clone()).await?;
        let hasher = Hasher::new();

        Ok(Self {
            dir,
            writer_file,
            file_path,
            hasher,
        })
    }

    /// Create a new InDbTempFile filled with zeros.
    /// Convenience method used for testing.
    #[cfg(test)]
    pub async fn zeros(size_bytes: usize) -> Result<InDbTempFile, std::io::Error> {
        let mut file = Self::new().await?;
        let buffer = vec![0u8; size_bytes];
        file.write_chunk(&buffer).await?;
        file.complete().await
    }

    /// Write a chunk to the file.
    /// Chunk writing is done by the axum body stream and by LMDB itself.
    pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
        self.writer_file.write_all(chunk).await?;
        self.hasher.update(chunk);
        Ok(())
    }

    /// Flush the file to disk.
    /// This completes the writing of the file.
    /// Returns an InDbTempFile that can be used to read the file.
    pub async fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
        self.writer_file.flush().await?;
        let hash = self.hasher.finalize();
        let file_size = self.writer_file.metadata().await?.len();
        Ok(InDbTempFile {
            dir: Arc::new(self.dir),
            file_path: self.file_path,
            file_size: file_size as usize,
            file_hash: hash,
        })
    }
}
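
// Illustrative sketch (not part of this diff): streaming an upload body into an
// AsyncInDbTempFileWriter chunk by chunk, then completing it to obtain an InDbTempFile
// whose hash and length are known before anything touches LMDB. The chunk values are
// arbitrary example data.
#[cfg(test)]
mod async_writer_example {
    use super::AsyncInDbTempFileWriter;

    #[tokio::test]
    async fn write_chunks_then_complete() -> Result<(), std::io::Error> {
        let mut writer = AsyncInDbTempFileWriter::new().await?;
        for chunk in [&b"hello "[..], &b"world"[..]] {
            writer.write_chunk(chunk).await?;
        }
        let file = writer.complete().await?;
        assert_eq!(file.len(), 11);
        Ok(())
    }
}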

/// Writes a temp file to disk synchronously.
#[derive(Debug)]
pub(crate) struct SyncInDbTempFileWriter {
    // Temp dir is automatically deleted when it is dropped.
    #[allow(dead_code)]
    dir: tempfile::TempDir,
    writer_file: File,
    file_path: PathBuf,
    hasher: Hasher,
}

impl SyncInDbTempFileWriter {
    pub fn new() -> Result<Self, std::io::Error> {
        let dir = tempfile::tempdir()?;
        let file_path = dir.path().join("entry.bin");
        let writer_file = File::create(file_path.clone())?;
        let hasher = Hasher::new();

        Ok(Self {
            dir,
            writer_file,
            file_path,
            hasher,
        })
    }

    /// Write a chunk to the file.
    pub fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), std::io::Error> {
        self.writer_file.write_all(chunk)?;
        self.hasher.update(chunk);
        Ok(())
    }

    /// Flush the file to disk.
    /// This completes the writing of the file.
    /// Returns an InDbTempFile that can be used to read the file.
    pub fn complete(mut self) -> Result<InDbTempFile, std::io::Error> {
        self.writer_file.flush()?;
        let hash = self.hasher.finalize();
        let file_size = self.writer_file.metadata()?.len();
        Ok(InDbTempFile {
            dir: Arc::new(self.dir),
            file_path: self.file_path,
            file_size: file_size as usize,
            file_hash: hash,
        })
    }
}

/// A temporary file helper for Entry.
///
/// Every file in LMDB is first written to disk before being written to LMDB.
/// The same is true if you read a file from LMDB.
///
/// This is to keep the LMDB transaction small and fast.
///
/// As soon as the InDbTempFile is dropped, the file on disk is deleted.
///
#[derive(Debug, Clone)]
pub struct InDbTempFile {
    // Temp dir is automatically deleted when the InDbTempFile is dropped.
    #[allow(dead_code)]
    dir: Arc<tempfile::TempDir>,
    file_path: PathBuf,
    file_size: usize,
    file_hash: Hash,
}

impl InDbTempFile {
    /// Create a new InDbTempFile filled with zeros.
    /// Convenience method used for testing.
    #[cfg(test)]
    pub async fn zeros(size_bytes: usize) -> Result<Self, std::io::Error> {
        AsyncInDbTempFileWriter::zeros(size_bytes).await
    }

    /// Create a new, empty InDbTempFile.
    pub fn empty() -> Result<Self, std::io::Error> {
        let dir = tempfile::tempdir()?;
        let file_path = dir.path().join("entry.bin");
        std::fs::File::create(file_path.clone())?;
        let file_size = 0;
        let hasher = Hasher::new();
        let file_hash = hasher.finalize();
        Ok(Self {
            dir: Arc::new(dir),
            file_path,
            file_size,
            file_hash,
        })
    }

    pub fn len(&self) -> usize {
        self.file_size
    }

    pub fn hash(&self) -> &Hash {
        &self.file_hash
    }

    /// Get the path of the file on disk.
    #[cfg(test)]
    pub fn path(&self) -> &PathBuf {
        &self.file_path
    }

    /// Open the file on disk.
    pub fn open_file_handle(&self) -> Result<File, std::io::Error> {
        File::open(self.file_path.as_path())
    }
}
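
// Illustrative sketch (not part of this diff): even an empty InDbTempFile exposes a
// stable hash (the hash of zero bytes) and a readable file handle on disk.
#[cfg(test)]
mod in_db_temp_file_example {
    use super::InDbTempFile;
    use std::io::Read;

    #[test]
    fn empty_file_is_readable_and_hashed() -> Result<(), std::io::Error> {
        let file = InDbTempFile::empty()?;
        assert_eq!(file.len(), 0);
        let mut content = Vec::new();
        file.open_file_handle()?.read_to_end(&mut content)?;
        assert!(content.is_empty());
        Ok(())
    }
}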
@@ -0,0 +1,7 @@
mod blobs;
mod entries;
mod in_db_file;

pub use blobs::{BlobsTable, BLOBS_TABLE};
pub use entries::{EntriesTable, Entry, ENTRIES_TABLE};
pub use in_db_file::*;
@@ -157,6 +157,7 @@ impl LmDB {
    /// # Errors
    ///
    /// - `UserQueryError::DatabaseError` if the database operation fails.
    #[cfg(test)]
    pub fn create_user(&self, pubkey: &PublicKey, wtxn: &mut RwTxn) -> anyhow::Result<()> {
        let user = User::default();
        self.tables.users.put(wtxn, pubkey, &user)?;
@@ -1,5 +1,6 @@
mod http_error;
mod pubkey_path_validator;
pub(crate) mod webdav;

pub(crate) use http_error::{HttpError, HttpResult};
pub(crate) use pubkey_path_validator::Z32Pubkey;
124
pubky-homeserver/src/shared/webdav/entry_path.rs
Normal file
@@ -0,0 +1,124 @@
use pkarr::PublicKey;
use std::str::FromStr;

use super::WebDavPath;

#[derive(thiserror::Error, Debug)]
pub enum EntryPathError {
    #[error("{0}")]
    Invalid(String),
    #[error("Failed to parse webdav path: {0}")]
    InvalidWebdavPath(anyhow::Error),
    #[error("Failed to parse pubkey: {0}")]
    InvalidPubkey(pkarr::errors::PublicKeyError),
}

/// A path to an entry.
///
/// The path as a string is used to identify the entry.
#[derive(Debug, Clone)]
pub struct EntryPath {
    #[allow(dead_code)]
    pubkey: PublicKey,
    path: WebDavPath,
    /// The key of the entry represented as a string.
    /// The key is the pubkey and the path concatenated.
    /// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/folder/file.txt`
    /// This is cached/redundant to avoid reallocating the string on every access.
    key: String,
}

impl EntryPath {
    pub fn new(pubkey: PublicKey, path: WebDavPath) -> Self {
        let key = format!("{}{}", pubkey, path);
        Self { pubkey, path, key }
    }

    #[allow(dead_code)]
    pub fn pubkey(&self) -> &PublicKey {
        &self.pubkey
    }

    pub fn path(&self) -> &WebDavPath {
        &self.path
    }

    /// The key of the entry.
    ///
    /// The key is the pubkey and the path concatenated.
    ///
    /// Example: `8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/folder/file.txt`
    pub fn as_str(&self) -> &str {
        &self.key
    }
}

impl AsRef<str> for EntryPath {
    fn as_ref(&self) -> &str {
        &self.key
    }
}

impl FromStr for EntryPath {
    type Err = EntryPathError;

    fn from_str(s: &str) -> Result<Self, EntryPathError> {
        let first_slash_index = s
            .find('/')
            .ok_or(EntryPathError::Invalid("Missing '/'".to_string()))?;
        let (pubkey, path) = match s.split_at_checked(first_slash_index) {
            Some((pubkey, path)) => (pubkey, path),
            None => return Err(EntryPathError::Invalid("Missing '/'".to_string())),
        };
        let pubkey = PublicKey::from_str(pubkey).map_err(EntryPathError::InvalidPubkey)?;
        let webdav_path = WebDavPath::new(path).map_err(EntryPathError::InvalidWebdavPath)?;
        Ok(Self::new(pubkey, webdav_path))
    }
}
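
// Illustrative sketch (not part of this diff): an EntryPath round-trips between its
// string key and its (pubkey, webdav path) parts. The pubkey string below is the same
// example key already used in the doc comment and tests of this file.
#[cfg(test)]
mod entry_path_example {
    use super::EntryPath;
    use std::str::FromStr;

    #[test]
    fn key_is_pubkey_plus_path() {
        let key = "8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt";
        let entry_path = EntryPath::from_str(key).unwrap();
        assert_eq!(entry_path.path().as_str(), "/pub/folder/file.txt");
        assert_eq!(entry_path.as_str(), key);
    }
}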

impl std::fmt::Display for EntryPath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_ref())
    }
}

impl serde::Serialize for EntryPath {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.as_ref())
    }
}

impl<'de> serde::Deserialize<'de> for EntryPath {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        Self::from_str(&s).map_err(serde::de::Error::custom)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entry_path_from_str() {
        let pubkey =
            PublicKey::from_str("8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo").unwrap();
        let path = WebDavPath::new("/pub/folder/file.txt").unwrap();
        let key = format!("{pubkey}{path}");
        let entry_path = EntryPath::new(pubkey, path);
        assert_eq!(entry_path.as_ref(), key);
    }

    #[test]
    fn test_entry_path_serde() {
        let string = "8pinxxgqs41n4aididenw5apqp1urfmzdztr8jt4abrkdn435ewo/pub/folder/file.txt";
        let entry_path = EntryPath::from_str(string).unwrap();
        assert_eq!(entry_path.to_string(), string);
    }
}
7
pubky-homeserver/src/shared/webdav/mod.rs
Normal file
@@ -0,0 +1,7 @@
mod entry_path;
mod webdav_path;
mod webdav_path_axum;

pub use entry_path::EntryPath;
pub use webdav_path::WebDavPath;
pub use webdav_path_axum::WebDavPathAxum;
343
pubky-homeserver/src/shared/webdav/webdav_path.rs
Normal file
@@ -0,0 +1,343 @@
use std::str::FromStr;

/// A normalized and validated webdav path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WebDavPath {
    normalized_path: String,
}

impl WebDavPath {
    /// Create a new WebDavPath from an already normalized path.
    /// Make sure the path is fully normalized and valid before using this constructor.
    ///
    /// Use `WebDavPath::new` to create a new WebDavPath from an unnormalized path.
    pub fn new_unchecked(normalized_path: String) -> Self {
        Self { normalized_path }
    }

    /// Create a new WebDavPath from an unnormalized path.
    ///
    /// The path will be normalized and validated.
    pub fn new(unnormalized_path: &str) -> anyhow::Result<Self> {
        let normalized_path = normalize_and_validate_webdav_path(unnormalized_path)?;
        if !normalized_path.starts_with("/pub/") {
            return Err(anyhow::anyhow!("Path must start with /pub/"));
        }
        Ok(Self::new_unchecked(normalized_path))
    }

    #[allow(dead_code)]
    pub fn url_encode(&self) -> String {
        percent_encoding::utf8_percent_encode(self.normalized_path.as_str(), PATH_ENCODE_SET)
            .to_string()
    }

    pub fn as_str(&self) -> &str {
        self.normalized_path.as_str()
    }

    pub fn is_directory(&self) -> bool {
        self.normalized_path.ends_with('/')
    }

    /// Check if the path is a file.
    #[allow(dead_code)]
    pub fn is_file(&self) -> bool {
        !self.is_directory()
    }
}
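
// Illustrative sketch (not part of this diff): WebDavPath collapses '//' and resolves
// '..' segments, and rejects anything that does not land under /pub/. The paths used
// here are arbitrary examples.
#[cfg(test)]
mod webdav_path_example {
    use super::WebDavPath;

    #[test]
    fn normalizes_and_enforces_pub_prefix() {
        let path = WebDavPath::new("/pub//docs/../notes/readme.md").unwrap();
        assert_eq!(path.as_str(), "/pub/notes/readme.md");
        assert!(path.is_file());
        assert!(WebDavPath::new("/private/readme.md").is_err());
    }
}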

impl std::fmt::Display for WebDavPath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.normalized_path)
    }
}

impl FromStr for WebDavPath {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::new(s)
    }
}

// Encode all non-unreserved characters, except '/'.
// See RFC3986, and https://en.wikipedia.org/wiki/Percent-encoding .
const PATH_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'_')
    .remove(b'.')
    .remove(b'~')
    .remove(b'/');

/// Maximum length of a single path segment.
const MAX_WEBDAV_PATH_SEGMENT_LENGTH: usize = 255;
/// Maximum total length of a normalized WebDAV path.
const MAX_WEBDAV_PATH_TOTAL_LENGTH: usize = 4096;

/// Takes a path, normalizes and validates it.
/// Make sure to url decode the path before calling this function.
/// Inspired by https://github.com/messense/dav-server-rs/blob/740dae05ac2eeda8e2ea11fface3ab6d53b6705e/src/davpath.rs#L101
fn normalize_and_validate_webdav_path(path: &str) -> anyhow::Result<String> {
    // Ensure the path starts with '/'
    if !path.starts_with('/') {
        return Err(anyhow::anyhow!("Path must start with '/'"));
    }

    let is_dir = path.ends_with('/') || path.ends_with("..");
    // Split the path into segments, filtering out empty ones
    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    // Build the normalized path
    let mut normalized_segments = vec![];

    for segment in segments {
        // Check for segment length
        if segment.len() > MAX_WEBDAV_PATH_SEGMENT_LENGTH {
            return Err(anyhow::anyhow!(
                "Invalid path: Segment exceeds maximum length of {} characters. Segment: '{}'",
                MAX_WEBDAV_PATH_SEGMENT_LENGTH,
                segment
            ));
        }

        // Check for any ASCII control characters in the decoded segment
        if segment.chars().any(|c| c.is_control()) {
            return Err(anyhow::anyhow!(
                "Invalid path: ASCII control characters are not allowed in segments"
            ));
        }

        if segment == "." {
            continue;
        } else if segment == ".." {
            if normalized_segments.len() < 2 {
                return Err(anyhow::anyhow!("Failed to normalize path: '..'."));
            }
            normalized_segments.pop();
            normalized_segments.pop();
        } else {
            normalized_segments.push("/".to_string());
            normalized_segments.push(segment.to_string());
        }
    }

    if is_dir {
        normalized_segments.push("/".to_string());
    }
    let full_path = normalized_segments.join("");

    // Check for total path length
    if full_path.len() > MAX_WEBDAV_PATH_TOTAL_LENGTH {
        return Err(anyhow::anyhow!(
            "Invalid path: Total path length exceeds maximum of {} characters. Length: {}, Path: '{}'",
            MAX_WEBDAV_PATH_TOTAL_LENGTH,
            full_path.len(),
            full_path
        ));
    }

    Ok(full_path)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_valid_path(path: &str, expected: &str) {
        match normalize_and_validate_webdav_path(path) {
            Ok(path) => {
                assert_eq!(path, expected);
            }
            Err(e) => {
                assert!(
                    false,
                    "Path '{path}' is invalid. Should be '{expected}'. Error: {e}"
                );
            }
        };
    }

    fn assert_invalid_path(path: &str) {
        if let Ok(normalized_path) = normalize_and_validate_webdav_path(path) {
            assert!(
                false,
                "Invalid path '{path}' is valid. Normalized result: '{normalized_path}'"
            );
        }
    }

    #[test]
    fn test_slash_is_valid() {
        assert_valid_path("/", "/");
    }

    #[test]
    fn test_two_dots_is_valid() {
        assert_valid_path("/test/..", "/");
    }

    #[test]
    fn test_two_dots_in_the_middle_is_valid() {
        assert_valid_path("/test/../test", "/test");
    }

    #[test]
    fn test_two_dots_in_the_middle_with_slash_is_valid() {
        assert_valid_path("/test/../test/", "/test/");
    }

    #[test]
    fn test_two_dots_invalid() {
        assert_invalid_path("/..");
    }

    #[test]
    fn test_two_dots_twice_invalid() {
        assert_invalid_path("/test/../..");
    }

    #[test]
    fn test_two_slashes_is_valid() {
        assert_valid_path("//", "/");
    }

    #[test]
    fn test_two_slashes_in_the_middle_is_valid() {
        assert_valid_path("/test//test", "/test/test");
    }

    #[test]
    fn test_one_segment_is_valid() {
        assert_valid_path("/test", "/test");
    }

    #[test]
    fn test_one_segment_with_trailing_slash_is_valid() {
        assert_valid_path("/test/", "/test/");
    }

    #[test]
    fn test_two_segments_is_valid() {
        assert_valid_path("/test/test", "/test/test");
    }

    #[test]
    fn test_wildcard_is_valid() {
        assert_valid_path("/dav/file*.txt", "/dav/file*.txt");
    }

    #[test]
    fn test_two_slashes_in_the_middle_with_slash_is_valid() {
        assert_valid_path("/dav//folder/", "/dav/folder/");
    }

    #[test]
    fn test_script_tag_is_valid() {
        assert_valid_path("/dav/<script>", "/dav/<script>");
    }

    #[test]
    fn test_null_is_invalid() {
        assert_invalid_path("/dav/file\0");
    }

    #[test]
    fn test_empty_path_is_invalid() {
        assert_invalid_path("");
    }

    #[test]
    fn test_missing_root_slash1_is_invalid() {
        assert_invalid_path("test");
    }

    #[test]
    fn test_missing_root_slash2_is_invalid() {
        assert_invalid_path("test/");
    }

    #[test]
    fn test_invalid_path_test_over_test() {
        assert_invalid_path("test/test");
    }

    #[test]
    fn test_invalid_path_http_example_com_test() {
        assert_invalid_path("http://example.com/test");
    }

    #[test]
    fn test_invalid_path_backslash_test_backslash() {
        assert_invalid_path("\\test\\");
    }

    #[test]
    fn test_invalid_path_dot() {
        assert_invalid_path(".");
    }

    #[test]
    fn test_invalid_path_dot_dot() {
        assert_invalid_path("..");
    }

    #[test]
    fn test_invalid_windows_path() {
        assert_invalid_path("C:\\dav\\file");
    }

    #[test]
    fn test_valid_path_dav_uber() {
        assert_valid_path("/dav/über", "/dav/über");
    }

    #[test]
    fn test_webdav_pub_required() {
        WebDavPath::from_str("/pub/file.txt").expect("Should be valid");
        WebDavPath::from_str("/file.txt").expect_err("Should not be valid. /pub/ required.");
    }

    #[test]
    fn test_url_encode() {
        let url_encoded = "/pub/file%25.txt";
        let url_decoded = percent_encoding::percent_decode_str(url_encoded)
            .decode_utf8()
            .unwrap()
            .to_string();
        let path = WebDavPath::new(url_decoded.as_str()).unwrap();
        let normalized = path.to_string();
        assert_eq!(normalized, "/pub/file%.txt");
        assert_eq!(path.url_encode(), url_encoded);
    }

    #[test]
    fn test_segment_too_long() {
        let long_segment = "a".repeat(MAX_WEBDAV_PATH_SEGMENT_LENGTH + 1);
        let path = format!("/prefix/{}/suffix", long_segment);
        assert_invalid_path(&path);
    }

    #[test]
    fn test_segment_max_length_is_valid() {
        let max_segment = "a".repeat(MAX_WEBDAV_PATH_SEGMENT_LENGTH);
        let path = format!("/prefix/{}/suffix", max_segment);
        let expected_path = path.clone(); // Expected path is the same as input if valid
        assert_valid_path(&path, &expected_path);
    }

    #[test]
    fn test_total_path_too_long() {
        let num_segments = MAX_WEBDAV_PATH_TOTAL_LENGTH; // This will create path like "/a/a/.../a"
        let segments: Vec<String> = std::iter::repeat("a".to_string())
            .take(num_segments)
            .collect();
        let path = format!("/{}", segments.join("/"));
        assert_invalid_path(&path);

        let almost_too_long_segment = "a".repeat(MAX_WEBDAV_PATH_TOTAL_LENGTH - 1); // e.g., if max is 10, this is 9 'a's
        let path_too_long = format!("/{}/b", almost_too_long_segment); // "/aaaaaaaaa/b" -> 1 + 9 + 1 + 1 = 12 > 10
        assert_invalid_path(&path_too_long);
    }
}
61
pubky-homeserver/src/shared/webdav/webdav_path_axum.rs
Normal file
@@ -0,0 +1,61 @@
use std::str::FromStr;

use serde::{Deserialize, Serialize};

use super::WebDavPath;

/// A webdav path that can be used with axum.
///
/// When using `.route("/{*path}", your_handler)` in axum, the path is passed without the leading slash.
/// This struct adds the leading slash back and therefore allows direct validation of the path.
///
/// Usage in handler:
///
/// `Path(path): Path<WebDavPathAxum>`
pub struct WebDavPathAxum(pub WebDavPath);

impl std::fmt::Display for WebDavPathAxum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0.as_str())
    }
}

impl FromStr for WebDavPathAxum {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let with_slash = format!("/{}", s);
        let inner = WebDavPath::new(&with_slash)?;
        Ok(Self(inner))
    }
}

impl Serialize for WebDavPathAxum {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.0.as_str())
    }
}

impl<'de> Deserialize<'de> for WebDavPathAxum {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        Self::from_str(&s).map_err(serde::de::Error::custom)
    }
}
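
// Illustrative sketch (not part of this diff): a hypothetical axum handler using the
// extractor described in the doc comment above. The handler and router names are made
// up for the example; only WebDavPathAxum comes from this module, and the route
// pattern mirrors the `"/{*path}"` form mentioned in the doc comment.
#[cfg(test)]
mod axum_extractor_example {
    use super::WebDavPathAxum;
    use axum::{extract::Path, routing::get, Router};

    async fn show_path(Path(path): Path<WebDavPathAxum>) -> String {
        // `path.0` is the validated, normalized WebDavPath with the leading slash restored.
        path.0.as_str().to_string()
    }

    #[allow(dead_code)]
    fn router() -> Router {
        Router::new().route("/{*path}", get(show_path))
    }
}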

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_webdav_path_axum() {
        let path = WebDavPathAxum::from_str("pub/foo/bar").unwrap();
        assert_eq!(path.0.as_str(), "/pub/foo/bar");
    }
}