diff --git a/e2e/src/tests/public.rs b/e2e/src/tests/public.rs index 18e1906..68ce811 100644 --- a/e2e/src/tests/public.rs +++ b/e2e/src/tests/public.rs @@ -29,10 +29,33 @@ async fn put_get_delete() { .error_for_status() .unwrap(); + // Use Pubky native method to get data from homeserver let response = client.get(url).send().await.unwrap().bytes().await.unwrap(); assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4])); + // Use regular web method to get data from homeserver (with query pubky-host) + let regular_url = format!( + "{}pub/foo.txt?pubky-host={}", + server.url(), + keypair.public_key() + ); + + // We set `non.pubky.host` header as otherwise he client will use by default + // the homeserver pubky as host and this request will resolve the `/pub/foo.txt` of + // the wrong tenant user + let response = client + .get(regular_url) + .header("Host", "non.pubky.host") + .send() + .await + .unwrap() + .bytes() + .await + .unwrap(); + + assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4])); + client .delete(url) .send() diff --git a/pubky-homeserver/src/core/layers/pubky_host.rs b/pubky-homeserver/src/core/layers/pubky_host.rs index 1af6d1e..8b2ab40 100644 --- a/pubky-homeserver/src/core/layers/pubky_host.rs +++ b/pubky-homeserver/src/core/layers/pubky_host.rs @@ -1,15 +1,12 @@ -use pkarr::PublicKey; - +use crate::core::error::Result; use crate::core::extractors::PubkyHost; - use axum::{body::Body, http::Request}; use futures_util::future::BoxFuture; +use pkarr::PublicKey; use std::{convert::Infallible, task::Poll}; use tower::{Layer, Service}; -use crate::core::error::Result; - -/// A Tower Layer to handle authorization for write operations. +/// A Tower Layer to extract and inject the PubkyHost into request extensions. #[derive(Debug, Clone)] pub struct PubkyHostLayer; @@ -21,7 +18,7 @@ impl Layer for PubkyHostLayer { } } -/// Middleware that performs authorization checks for write operations. +/// Middleware that extracts the public key from headers or query parameters. #[derive(Debug, Clone)] pub struct PubkyHostLayerMiddleware { inner: S, @@ -31,8 +28,8 @@ impl Service> for PubkyHostLayerMiddleware where S: Service, Response = axum::response::Response, Error = Infallible> + Send - + 'static - + Clone, + + Clone + + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -40,25 +37,47 @@ where type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(|_| unreachable!()) // `Infallible` conversion + self.inner.poll_ready(cx).map_err(|_| unreachable!()) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { + if let Some(public_key) = extract_pubky(&req) { + req.extensions_mut().insert(PubkyHost(public_key)); + } let mut inner = self.inner.clone(); - let mut req = req; - - Box::pin(async move { - let headers_to_check = ["host", "pubky-host"]; - - for header in headers_to_check { - if let Some(Ok(pubky_host)) = req.headers().get(header).map(|h| h.to_str()) { - if let Ok(public_key) = PublicKey::try_from(pubky_host) { - req.extensions_mut().insert(PubkyHost(public_key)); - } - } - } - - inner.call(req).await.map_err(|_| unreachable!()) - }) + Box::pin(async move { inner.call(req).await.map_err(|_| unreachable!()) }) } } + +/// Extracts a PublicKey by checking, in order: +/// 1. The "host" header. +/// 2. The "pubky-host" header (which overwrites any previously found key). +/// 3. The query parameter "pubky-host" if none was found in headers. +fn extract_pubky(req: &Request) -> Option { + let mut pubky = None; + // Check headers in order: "host" then "pubky-host". + for header in ["host", "pubky-host"].iter() { + if let Some(val) = req.headers().get(*header) { + if let Ok(s) = val.to_str() { + if let Ok(key) = PublicKey::try_from(s) { + pubky = Some(key); + } + } + } + } + // If still no key, fall back to query parameter. + if pubky.is_none() { + pubky = req.uri().query().and_then(|query| { + query.split('&').find_map(|pair| { + let mut parts = pair.splitn(2, '='); + if let (Some(key), Some(val)) = (parts.next(), parts.next()) { + if key == "pubky-host" { + return PublicKey::try_from(val).ok(); + } + } + None + }) + }); + } + pubky +}