feat(pubky): update http relay to correctly test reconnecting

This commit is contained in:
nazeh
2025-02-10 14:48:28 +03:00
parent 2615461d95
commit bc1960e4e9
5 changed files with 175 additions and 60 deletions

View File

@@ -1,26 +1,36 @@
use std::{
collections::HashMap,
net::{SocketAddr, TcpListener},
sync::{Arc, Mutex},
sync::Arc,
};
use anyhow::Result;
use axum::{
body::{Body, Bytes},
body::Bytes,
extract::{Path, State},
response::IntoResponse,
routing::get,
Router,
};
use axum_server::Handle;
use tokio::sync::Notify;
use tokio::sync::{oneshot, Mutex};
use futures_util::{stream::StreamExt, TryFutureExt};
use futures_util::TryFutureExt;
use url::Url;
// Shared state to store GET requests and their notifications
type SharedState = Arc<Mutex<HashMap<String, (Vec<u8>, Arc<Notify>)>>>;
type SharedState = Arc<Mutex<HashMap<String, ChannelState>>>;
enum ChannelState {
ProducerWaiting {
body: Bytes,
completion: oneshot::Sender<()>,
},
ConsumerWaiting {
message_sender: oneshot::Sender<Bytes>,
},
}
#[derive(Debug, Default)]
pub struct Config {
@@ -110,58 +120,60 @@ impl HttpRelay {
}
mod link {
use axum::http::StatusCode;
use super::*;
pub async fn get(
Path(id): Path<String>,
State(state): State<SharedState>,
) -> impl IntoResponse {
// Create a notification for this ID
let notify = Arc::new(Notify::new());
let mut channels = state.lock().await;
{
let mut map = state.lock().unwrap();
match channels.remove(&id) {
Some(ChannelState::ProducerWaiting { body, completion }) => {
let _ = completion.send(());
// Store the notification and return it when POST arrives
map.entry(id.clone())
.or_insert_with(|| (vec![], notify.clone()));
}
(StatusCode::OK, body)
}
_ => {
let (message_sender, message_receiver) = oneshot::channel();
channels.insert(id, ChannelState::ConsumerWaiting { message_sender });
drop(channels);
notify.notified().await;
// Respond with the data stored for this ID
let map = state.lock().unwrap();
if let Some((data, _)) = map.get(&id) {
Bytes::from(data.clone()).into_response()
} else {
(axum::http::StatusCode::NOT_FOUND, "Not Found").into_response()
match message_receiver.await {
Ok(message) => (StatusCode::OK, message),
Err(_) => (StatusCode::NOT_FOUND, "Not Found".into()),
}
}
}
}
pub async fn post(
Path(id): Path<String>,
Path(channel): Path<String>,
State(state): State<SharedState>,
body: Body,
body: Bytes,
) -> impl IntoResponse {
// Aggregate the body into bytes
let mut stream = body.into_data_stream();
let mut bytes = vec![];
while let Some(next) = stream.next().await {
let chunk = next.map_err(|e| e.to_string()).unwrap();
bytes.extend_from_slice(&chunk);
}
let mut channels = state.lock().await;
// Notify any waiting GET request for this ID
let mut map = state.lock().unwrap();
if let Some((storage, notify)) = map.get_mut(&id) {
*storage = bytes;
notify.notify_one();
Ok(())
} else {
Err((
axum::http::StatusCode::NOT_FOUND,
"No waiting GET request for this ID",
))
match channels.remove(&channel) {
Some(ChannelState::ConsumerWaiting { message_sender }) => {
let _ = message_sender.send(body);
(StatusCode::OK, ())
}
_ => {
let (completion_sender, completion_receiver) = oneshot::channel();
channels.insert(
channel,
ChannelState::ProducerWaiting {
body,
completion: completion_sender,
},
);
drop(channels);
let _ = completion_receiver.await;
(StatusCode::OK, ())
}
}
}
}

View File

@@ -39,7 +39,7 @@ async fn main () {
assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4]));
// Delet an entry.
// Delete an entry.
client.delete(url).await.unwrap();
let response = client.get(url).await.unwrap();

View File

@@ -9,10 +9,12 @@
macro_rules! cross_debug {
($($arg:tt)*) => {
#[cfg(target_arch = "wasm32")]
#[cfg(all(not(test), target_arch = "wasm32"))]
log::debug!($($arg)*);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(test), not(target_arch = "wasm32")))]
tracing::debug!($($arg)*);
#[cfg(test)]
println!($($arg)*);
};
}

View File

@@ -13,7 +13,8 @@ pub mod api {
use std::fmt::Debug;
#[cfg(not(wasm_browser))]
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use std::time::Duration;
#[cfg(not(wasm_browser))]
use mainline::Testnet;
@@ -35,6 +36,7 @@ macro_rules! handle_http_error {
#[derive(Debug, Default, Clone)]
pub struct ClientBuilder {
pkarr: pkarr::ClientBuilder,
http_request_timeout: Option<Duration>,
}
impl ClientBuilder {
@@ -46,7 +48,7 @@ impl ClientBuilder {
pub fn testnet(&mut self, testnet: &Testnet) -> &mut Self {
let bootstrap = testnet.bootstrap.clone();
self.pkarr.bootstrap(&bootstrap);
self.pkarr.no_default_network().bootstrap(&bootstrap);
if std::env::var("CI").is_err() {
self.pkarr.request_timeout(Duration::from_millis(500));
@@ -65,6 +67,13 @@ impl ClientBuilder {
self
}
/// Set HTTP requests timeout.
pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self {
self.http_request_timeout = Some(timeout);
self
}
/// Build [Client]
pub fn build(&self) -> Result<Client, BuildError> {
let pkarr = self.pkarr.build()?;
@@ -76,28 +85,34 @@ impl ClientBuilder {
let user_agent = DEFAULT_USER_AGENT;
#[cfg(not(wasm_browser))]
let http = reqwest::ClientBuilder::from(pkarr.clone())
let mut http_builder = reqwest::ClientBuilder::from(pkarr.clone())
// TODO: use persistent cookie jar
.cookie_provider(cookie_store.clone())
.user_agent(user_agent)
.build()
.expect("config expected to not error");
.user_agent(user_agent);
#[cfg(wasm_browser)]
let http = reqwest::Client::builder()
.user_agent(user_agent)
.build()
.expect("config expected to not error");
let mut http_builder = reqwest::Client::builder().user_agent(user_agent);
#[cfg(not(wasm_browser))]
let mut icann_http_builder = reqwest::Client::builder()
// TODO: use persistent cookie jar
.cookie_provider(cookie_store.clone())
.user_agent(user_agent);
// TODO: change this after Reqwest publish a release with timeout in wasm
#[cfg(not(wasm_browser))]
if let Some(timeout) = self.http_request_timeout {
http_builder = http_builder.timeout(timeout);
icann_http_builder = icann_http_builder.timeout(timeout);
}
Ok(Client {
http,
pkarr,
http: http_builder.build().expect("config expected to not error"),
#[cfg(not(wasm_browser))]
icann_http: reqwest::Client::builder()
// TODO: use persistent cookie jar
.cookie_provider(cookie_store.clone())
.user_agent(user_agent)
icann_http: icann_http_builder
.build()
.expect("config expected to not error"),
#[cfg(not(wasm_browser))]

View File

@@ -257,7 +257,6 @@ impl Client {
}
}
}?;
cross_debug!("LOOPING xxx {:?}", &response);
let encrypted_token = response.bytes().await?;
let token_bytes = decrypt(&encrypted_token, client_secret)
@@ -283,6 +282,7 @@ impl AuthRequest {
&self.url
}
// TODO: Return better errors
pub async fn response(&self) -> Result<PublicKey> {
self.rx
.recv_async()
@@ -293,6 +293,8 @@ impl AuthRequest {
#[cfg(test)]
mod tests {
use std::time::Duration;
use crate::*;
use http_relay::HttpRelay;
@@ -453,4 +455,88 @@ mod tests {
assert_eq!(session.pubky(), &second_keypair.public_key());
assert!(session.capabilities().contains(&Capability::root()));
}
#[tokio::test]
async fn authz_timeout_reconnect() {
let testnet = Testnet::new(10).unwrap();
let server = Homeserver::start_test(&testnet).await.unwrap();
let http_relay = HttpRelay::builder().build().await.unwrap();
let http_relay_url = http_relay.local_link_url();
let keypair = Keypair::random();
let pubky = keypair.public_key();
// Third party app side
let capabilities: Capabilities =
"/pub/pubky.app/:rw,/pub/foo.bar/file:r".try_into().unwrap();
let client = Client::builder()
.pkarr(|builder| builder.no_default_network().bootstrap(&testnet.bootstrap))
.request_timeout(Duration::from_millis(1000))
.build()
.unwrap();
let pubky_auth_request = client.auth_request(http_relay_url, &capabilities).unwrap();
// Authenticator side
{
let url = pubky_auth_request.url().clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_millis(400)).await;
let client = Client::builder()
.pkarr(|builder| builder.no_default_network().bootstrap(&testnet.bootstrap))
.build()
.unwrap();
client.signup(&keypair, &server.public_key()).await.unwrap();
client.send_auth_token(&keypair, &url).await.unwrap();
}
});
}
let public_key = pubky_auth_request.response().await.unwrap();
assert_eq!(&public_key, &pubky);
let session = client.session(&pubky).await.unwrap().unwrap();
assert_eq!(session.capabilities(), &capabilities.0);
// Test access control enforcement
client
.put(format!("pubky://{pubky}/pub/pubky.app/foo"))
.body(vec![])
.send()
.await
.unwrap()
.error_for_status()
.unwrap();
assert_eq!(
client
.put(format!("pubky://{pubky}/pub/pubky.app"))
.body(vec![])
.send()
.await
.unwrap()
.status(),
StatusCode::FORBIDDEN
);
assert_eq!(
client
.put(format!("pubky://{pubky}/pub/foo.bar/file"))
.body(vec![])
.send()
.await
.unwrap()
.status(),
StatusCode::FORBIDDEN
);
}
}