diff --git a/Cargo.lock b/Cargo.lock index 9199a20..746cb90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1732,10 +1732,12 @@ dependencies = [ "futures-util", "heed", "hex", + "httpdate", "libc", "pkarr", "postcard", "pubky-common", + "reqwest", "serde", "tokio", "toml", diff --git a/pubky-homeserver/Cargo.toml b/pubky-homeserver/Cargo.toml index bb1a908..7cfaa34 100644 --- a/pubky-homeserver/Cargo.toml +++ b/pubky-homeserver/Cargo.toml @@ -15,6 +15,7 @@ flume = "0.11.0" futures-util = "0.3.30" heed = "0.20.3" hex = "0.4.3" +httpdate = "1.0.3" libc = "0.2.159" pkarr = { workspace = true } postcard = { version = "1.0.8", features = ["alloc"] } @@ -27,3 +28,6 @@ tower-http = { version = "0.5.2", features = ["cors", "trace"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } url = "2.5.2" + +[dev-dependencies] +reqwest = "0.12.8" diff --git a/pubky-homeserver/src/database/tables/entries.rs b/pubky-homeserver/src/database/tables/entries.rs index 1899578..9928c34 100644 --- a/pubky-homeserver/src/database/tables/entries.rs +++ b/pubky-homeserver/src/database/tables/entries.rs @@ -1,7 +1,11 @@ use pkarr::PublicKey; use postcard::{from_bytes, to_allocvec}; use serde::{Deserialize, Serialize}; -use std::{fs::File, io::Read, path::PathBuf}; +use std::{ + fs::File, + io::{Read, Write}, + path::PathBuf, +}; use tracing::instrument; use heed::{ @@ -345,6 +349,14 @@ impl<'db> EntryWriter<'db> { }) } + /// Same ase [EntryWriter::write_all] but returns a Result of a mutable reference of itself + /// to enable chaining with [Self::commit]. + pub fn update(&mut self, chunk: &[u8]) -> Result<&mut Self, std::io::Error> { + self.write_all(chunk)?; + + Ok(self) + } + /// Commit blob from the filesystem buffer to LMDB, /// write the [Entry], and commit the write transaction. pub fn commit(&self) -> anyhow::Result { @@ -432,8 +444,6 @@ impl<'db> std::io::Write for EntryWriter<'db> { #[cfg(test)] mod tests { - use std::io::Write; - use bytes::Bytes; use pkarr::{mainline::Testnet, Keypair}; @@ -442,7 +452,7 @@ mod tests { use super::DB; #[tokio::test] - async fn entries() { + async fn entries() -> anyhow::Result<()> { let mut db = DB::open(Config::test(&Testnet::new(0))).unwrap(); let keypair = Keypair::random(); @@ -451,9 +461,9 @@ mod tests { let chunk = Bytes::from(vec![1, 2, 3, 4, 5]); - let mut entry_writer = db.write_entry(&public_key, path).unwrap(); - entry_writer.write_all(&chunk).unwrap(); - entry_writer.commit().unwrap(); + db.write_entry(&public_key, path)? + .update(&chunk)? + .commit()?; let rtxn = db.env.read_txn().unwrap(); let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap(); @@ -479,10 +489,12 @@ mod tests { assert_eq!(blob, vec![1, 2, 3, 4, 5]); rtxn.commit().unwrap(); + + Ok(()) } #[tokio::test] - async fn chunked_entry() { + async fn chunked_entry() -> anyhow::Result<()> { let mut db = DB::open(Config::test(&Testnet::new(0))).unwrap(); let keypair = Keypair::random(); @@ -491,9 +503,9 @@ mod tests { let chunk = Bytes::from(vec![0; 1024 * 1024]); - let mut entry_writer = db.write_entry(&public_key, path).unwrap(); - entry_writer.write_all(&chunk).unwrap(); - entry_writer.commit().unwrap(); + db.write_entry(&public_key, path)? + .update(&chunk)? + .commit()?; let rtxn = db.env.read_txn().unwrap(); let entry = db.get_entry(&rtxn, &public_key, path).unwrap().unwrap(); @@ -522,5 +534,7 @@ mod tests { assert_eq!(stats.overflow_pages, 0); rtxn.commit().unwrap(); + + Ok(()) } } diff --git a/pubky-homeserver/src/routes/public.rs b/pubky-homeserver/src/routes/public.rs index 0e91ae9..3b9963f 100644 --- a/pubky-homeserver/src/routes/public.rs +++ b/pubky-homeserver/src/routes/public.rs @@ -1,12 +1,14 @@ use axum::{ body::Body, + debug_handler, extract::State, http::{header, HeaderMap, HeaderValue, Response, StatusCode}, response::IntoResponse, }; use futures_util::stream::StreamExt; +use httpdate::HttpDate; use pkarr::PublicKey; -use std::io::Write; +use std::{io::Write, str::FromStr}; use tower_cookies::Cookies; use crate::{ @@ -44,8 +46,10 @@ pub async fn put( Ok(()) } +#[debug_handler] pub async fn get( State(state): State, + headers: HeaderMap, pubky: Pubky, path: EntryPath, params: ListQueryParams, @@ -105,28 +109,16 @@ pub async fn get( Ok(()) }); - if let Some(entry) = entry_rx.recv_async().await? { - // TODO: add HEAD endpoint - // TODO: Enable seek API (range requests) - // TODO: Gzip? or brotli? - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_LENGTH, entry.content_length()) - .header(header::CONTENT_TYPE, entry.content_type()) - .header( - header::ETAG, - format!("\"{}\"", entry.content_hash().to_hex()), - ) - .body(Body::from_stream(chunks_rx.into_stream())) - .unwrap()) - } else { - Err(Error::with_status(StatusCode::NOT_FOUND))? - } + get_entry( + headers, + entry_rx.recv_async().await?, + Some(Body::from_stream(chunks_rx.into_stream())), + ) } pub async fn head( State(state): State, + headers: HeaderMap, pubky: Pubky, path: EntryPath, ) -> Result { @@ -134,14 +126,62 @@ pub async fn head( let rtxn = state.db.env.read_txn()?; - match state - .db - .get_entry(&rtxn, pubky.public_key(), path.as_str())? - .as_ref() - .map(HeaderMap::from) - { - Some(headers) => Ok(headers), - None => Err(Error::with_status(StatusCode::NOT_FOUND)), + get_entry( + headers, + state + .db + .get_entry(&rtxn, pubky.public_key(), path.as_str())?, + None, + ) +} + +pub fn get_entry( + headers: HeaderMap, + entry: Option, + body: Option, +) -> Result> { + if let Some(entry) = entry { + // TODO: Enable seek API (range requests) + // TODO: Gzip? or brotli? + + let mut response = HeaderMap::from(&entry).into_response(); + + // Handle IF_MODIFIED_SINCE + if let Some(condition_http_date) = headers + .get(header::IF_MODIFIED_SINCE) + .and_then(|h| h.to_str().ok()) + .and_then(|s| HttpDate::from_str(s).ok()) + { + let entry_http_date: HttpDate = entry.timestamp().to_owned().into(); + + if condition_http_date >= entry_http_date { + *response.status_mut() = StatusCode::NOT_MODIFIED; + } + }; + + // Handle IF_NONE_MATCH + if let Some(str) = headers + .get(header::IF_NONE_MATCH) + .and_then(|h| h.to_str().ok()) + { + let etag = format!("\"{}\"", entry.content_hash()); + if str + .trim() + .split(',') + .collect::>() + .contains(&etag.as_str()) + { + *response.status_mut() = StatusCode::NOT_MODIFIED; + }; + } + + if let Some(body) = body { + *response.body_mut() = body; + }; + + Ok(response) + } else { + Err(Error::with_status(StatusCode::NOT_FOUND))? } } @@ -237,3 +277,104 @@ impl From<&Entry> for HeaderMap { headers } } + +#[cfg(test)] +mod tests { + use axum::http::header; + use pkarr::{mainline::Testnet, Keypair}; + use reqwest::{self, Method, StatusCode}; + + use crate::Homeserver; + + #[tokio::test] + async fn if_last_modified() -> anyhow::Result<()> { + let testnet = Testnet::new(3); + let mut server = Homeserver::start_test(&testnet).await?; + + let public_key = Keypair::random().public_key(); + + let data = &[1, 2, 3, 4, 5]; + + server + .database_mut() + .write_entry(&public_key, "pub/foo")? + .update(data)? + .commit()?; + + let client = reqwest::Client::builder().build()?; + + let url = format!("http://localhost:{}/{public_key}/pub/foo", server.port()); + + let response = client.request(Method::GET, &url).send().await?; + + let response = client + .request(Method::GET, &url) + .header( + header::IF_MODIFIED_SINCE, + response.headers().get(header::LAST_MODIFIED).unwrap(), + ) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); + + let response = client + .request(Method::HEAD, &url) + .header( + header::IF_MODIFIED_SINCE, + response.headers().get(header::LAST_MODIFIED).unwrap(), + ) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); + + Ok(()) + } + + #[tokio::test] + async fn if_none_match() -> anyhow::Result<()> { + let testnet = Testnet::new(3); + let mut server = Homeserver::start_test(&testnet).await?; + + let public_key = Keypair::random().public_key(); + + let data = &[1, 2, 3, 4, 5]; + + server + .database_mut() + .write_entry(&public_key, "pub/foo")? + .update(data)? + .commit()?; + + let client = reqwest::Client::builder().build()?; + + let url = format!("http://localhost:{}/{public_key}/pub/foo", server.port()); + + let response = client.request(Method::GET, &url).send().await?; + + let response = client + .request(Method::GET, &url) + .header( + header::IF_NONE_MATCH, + response.headers().get(header::ETAG).unwrap(), + ) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); + + let response = client + .request(Method::HEAD, &url) + .header( + header::IF_NONE_MATCH, + response.headers().get(header::ETAG).unwrap(), + ) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); + + Ok(()) + } +} diff --git a/pubky-homeserver/src/server.rs b/pubky-homeserver/src/server.rs index d44d346..c3f8719 100644 --- a/pubky-homeserver/src/server.rs +++ b/pubky-homeserver/src/server.rs @@ -108,6 +108,11 @@ impl Homeserver { self.state.config.keypair().public_key() } + #[cfg(test)] + pub(crate) fn database_mut(&mut self) -> &mut DB { + &mut self.state.db + } + // === Public Methods === /// Shutdown the server and wait for all tasks to complete.