feat(homeserver): fallback missing pubky-host header to query param (#94)

* feat(homeserver): fallback missing pubky-host header to query param

* fix: host and overwrite with pubky-host

* add query pubky-host test
This commit is contained in:
SHAcollision
2025-03-26 09:33:40 -04:00
committed by GitHub
parent 8b0094ae6a
commit 283a38e9cc
2 changed files with 68 additions and 26 deletions

View File

@@ -29,10 +29,33 @@ async fn put_get_delete() {
.error_for_status()
.unwrap();
// Use Pubky native method to get data from homeserver
let response = client.get(url).send().await.unwrap().bytes().await.unwrap();
assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4]));
// Use regular web method to get data from homeserver (with query pubky-host)
let regular_url = format!(
"{}pub/foo.txt?pubky-host={}",
server.url(),
keypair.public_key()
);
// We set `non.pubky.host` header as otherwise he client will use by default
// the homeserver pubky as host and this request will resolve the `/pub/foo.txt` of
// the wrong tenant user
let response = client
.get(regular_url)
.header("Host", "non.pubky.host")
.send()
.await
.unwrap()
.bytes()
.await
.unwrap();
assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4]));
client
.delete(url)
.send()

View File

@@ -1,15 +1,12 @@
use pkarr::PublicKey;
use crate::core::error::Result;
use crate::core::extractors::PubkyHost;
use axum::{body::Body, http::Request};
use futures_util::future::BoxFuture;
use pkarr::PublicKey;
use std::{convert::Infallible, task::Poll};
use tower::{Layer, Service};
use crate::core::error::Result;
/// A Tower Layer to handle authorization for write operations.
/// A Tower Layer to extract and inject the PubkyHost into request extensions.
#[derive(Debug, Clone)]
pub struct PubkyHostLayer;
@@ -21,7 +18,7 @@ impl<S> Layer<S> for PubkyHostLayer {
}
}
/// Middleware that performs authorization checks for write operations.
/// Middleware that extracts the public key from headers or query parameters.
#[derive(Debug, Clone)]
pub struct PubkyHostLayerMiddleware<S> {
inner: S,
@@ -31,8 +28,8 @@ impl<S> Service<Request<Body>> for PubkyHostLayerMiddleware<S>
where
S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
+ Send
+ 'static
+ Clone,
+ Clone
+ 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
@@ -40,25 +37,47 @@ where
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|_| unreachable!()) // `Infallible` conversion
self.inner.poll_ready(cx).map_err(|_| unreachable!())
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if let Some(public_key) = extract_pubky(&req) {
req.extensions_mut().insert(PubkyHost(public_key));
}
let mut inner = self.inner.clone();
let mut req = req;
Box::pin(async move {
let headers_to_check = ["host", "pubky-host"];
for header in headers_to_check {
if let Some(Ok(pubky_host)) = req.headers().get(header).map(|h| h.to_str()) {
if let Ok(public_key) = PublicKey::try_from(pubky_host) {
req.extensions_mut().insert(PubkyHost(public_key));
}
}
}
inner.call(req).await.map_err(|_| unreachable!())
})
Box::pin(async move { inner.call(req).await.map_err(|_| unreachable!()) })
}
}
/// Extracts a PublicKey by checking, in order:
/// 1. The "host" header.
/// 2. The "pubky-host" header (which overwrites any previously found key).
/// 3. The query parameter "pubky-host" if none was found in headers.
fn extract_pubky(req: &Request<Body>) -> Option<PublicKey> {
let mut pubky = None;
// Check headers in order: "host" then "pubky-host".
for header in ["host", "pubky-host"].iter() {
if let Some(val) = req.headers().get(*header) {
if let Ok(s) = val.to_str() {
if let Ok(key) = PublicKey::try_from(s) {
pubky = Some(key);
}
}
}
}
// If still no key, fall back to query parameter.
if pubky.is_none() {
pubky = req.uri().query().and_then(|query| {
query.split('&').find_map(|pair| {
let mut parts = pair.splitn(2, '=');
if let (Some(key), Some(val)) = (parts.next(), parts.next()) {
if key == "pubky-host" {
return PublicKey::try_from(val).ok();
}
}
None
})
});
}
pubky
}