diff --git a/http-relay/src/lib.rs b/http-relay/src/lib.rs index 08e066d..ef5e0d2 100644 --- a/http-relay/src/lib.rs +++ b/http-relay/src/lib.rs @@ -1,26 +1,36 @@ use std::{ collections::HashMap, net::{SocketAddr, TcpListener}, - sync::{Arc, Mutex}, + sync::Arc, }; use anyhow::Result; use axum::{ - body::{Body, Bytes}, + body::Bytes, extract::{Path, State}, response::IntoResponse, routing::get, Router, }; use axum_server::Handle; -use tokio::sync::Notify; +use tokio::sync::{oneshot, Mutex}; -use futures_util::{stream::StreamExt, TryFutureExt}; +use futures_util::TryFutureExt; use url::Url; // Shared state to store GET requests and their notifications -type SharedState = Arc, Arc)>>>; +type SharedState = Arc>>; + +enum ChannelState { + ProducerWaiting { + body: Bytes, + completion: oneshot::Sender<()>, + }, + ConsumerWaiting { + message_sender: oneshot::Sender, + }, +} #[derive(Debug, Default)] pub struct Config { @@ -110,58 +120,60 @@ impl HttpRelay { } mod link { + use axum::http::StatusCode; + use super::*; pub async fn get( Path(id): Path, State(state): State, ) -> impl IntoResponse { - // Create a notification for this ID - let notify = Arc::new(Notify::new()); + let mut channels = state.lock().await; - { - let mut map = state.lock().unwrap(); + match channels.remove(&id) { + Some(ChannelState::ProducerWaiting { body, completion }) => { + let _ = completion.send(()); - // Store the notification and return it when POST arrives - map.entry(id.clone()) - .or_insert_with(|| (vec![], notify.clone())); - } + (StatusCode::OK, body) + } + _ => { + let (message_sender, message_receiver) = oneshot::channel(); + channels.insert(id, ChannelState::ConsumerWaiting { message_sender }); + drop(channels); - notify.notified().await; - - // Respond with the data stored for this ID - let map = state.lock().unwrap(); - if let Some((data, _)) = map.get(&id) { - Bytes::from(data.clone()).into_response() - } else { - (axum::http::StatusCode::NOT_FOUND, "Not Found").into_response() + match message_receiver.await { + Ok(message) => (StatusCode::OK, message), + Err(_) => (StatusCode::NOT_FOUND, "Not Found".into()), + } + } } } pub async fn post( - Path(id): Path, + Path(channel): Path, State(state): State, - body: Body, + body: Bytes, ) -> impl IntoResponse { - // Aggregate the body into bytes - let mut stream = body.into_data_stream(); - let mut bytes = vec![]; - while let Some(next) = stream.next().await { - let chunk = next.map_err(|e| e.to_string()).unwrap(); - bytes.extend_from_slice(&chunk); - } + let mut channels = state.lock().await; - // Notify any waiting GET request for this ID - let mut map = state.lock().unwrap(); - if let Some((storage, notify)) = map.get_mut(&id) { - *storage = bytes; - notify.notify_one(); - Ok(()) - } else { - Err(( - axum::http::StatusCode::NOT_FOUND, - "No waiting GET request for this ID", - )) + match channels.remove(&channel) { + Some(ChannelState::ConsumerWaiting { message_sender }) => { + let _ = message_sender.send(body); + (StatusCode::OK, ()) + } + _ => { + let (completion_sender, completion_receiver) = oneshot::channel(); + channels.insert( + channel, + ChannelState::ProducerWaiting { + body, + completion: completion_sender, + }, + ); + drop(channels); + let _ = completion_receiver.await; + (StatusCode::OK, ()) + } } } } diff --git a/pubky/README.md b/pubky/README.md index d884443..5ed0b10 100644 --- a/pubky/README.md +++ b/pubky/README.md @@ -39,7 +39,7 @@ async fn main () { assert_eq!(response, bytes::Bytes::from(vec![0, 1, 2, 3, 4])); - // Delet an entry. + // Delete an entry. client.delete(url).await.unwrap(); let response = client.get(url).await.unwrap(); diff --git a/pubky/src/lib.rs b/pubky/src/lib.rs index cceaff8..4cc1d41 100644 --- a/pubky/src/lib.rs +++ b/pubky/src/lib.rs @@ -9,10 +9,12 @@ macro_rules! cross_debug { ($($arg:tt)*) => { - #[cfg(target_arch = "wasm32")] + #[cfg(all(not(test), target_arch = "wasm32"))] log::debug!($($arg)*); - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(test), not(target_arch = "wasm32")))] tracing::debug!($($arg)*); + #[cfg(test)] + println!($($arg)*); }; } diff --git a/pubky/src/native.rs b/pubky/src/native.rs index 2693377..4b1b1c7 100644 --- a/pubky/src/native.rs +++ b/pubky/src/native.rs @@ -13,7 +13,8 @@ pub mod api { use std::fmt::Debug; #[cfg(not(wasm_browser))] -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; +use std::time::Duration; #[cfg(not(wasm_browser))] use mainline::Testnet; @@ -35,6 +36,7 @@ macro_rules! handle_http_error { #[derive(Debug, Default, Clone)] pub struct ClientBuilder { pkarr: pkarr::ClientBuilder, + http_request_timeout: Option, } impl ClientBuilder { @@ -46,7 +48,7 @@ impl ClientBuilder { pub fn testnet(&mut self, testnet: &Testnet) -> &mut Self { let bootstrap = testnet.bootstrap.clone(); - self.pkarr.bootstrap(&bootstrap); + self.pkarr.no_default_network().bootstrap(&bootstrap); if std::env::var("CI").is_err() { self.pkarr.request_timeout(Duration::from_millis(500)); @@ -65,6 +67,13 @@ impl ClientBuilder { self } + /// Set HTTP requests timeout. + pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self { + self.http_request_timeout = Some(timeout); + + self + } + /// Build [Client] pub fn build(&self) -> Result { let pkarr = self.pkarr.build()?; @@ -76,28 +85,34 @@ impl ClientBuilder { let user_agent = DEFAULT_USER_AGENT; #[cfg(not(wasm_browser))] - let http = reqwest::ClientBuilder::from(pkarr.clone()) + let mut http_builder = reqwest::ClientBuilder::from(pkarr.clone()) // TODO: use persistent cookie jar .cookie_provider(cookie_store.clone()) - .user_agent(user_agent) - .build() - .expect("config expected to not error"); + .user_agent(user_agent); #[cfg(wasm_browser)] - let http = reqwest::Client::builder() - .user_agent(user_agent) - .build() - .expect("config expected to not error"); + let mut http_builder = reqwest::Client::builder().user_agent(user_agent); + + #[cfg(not(wasm_browser))] + let mut icann_http_builder = reqwest::Client::builder() + // TODO: use persistent cookie jar + .cookie_provider(cookie_store.clone()) + .user_agent(user_agent); + + // TODO: change this after Reqwest publish a release with timeout in wasm + #[cfg(not(wasm_browser))] + if let Some(timeout) = self.http_request_timeout { + http_builder = http_builder.timeout(timeout); + + icann_http_builder = icann_http_builder.timeout(timeout); + } Ok(Client { - http, pkarr, + http: http_builder.build().expect("config expected to not error"), #[cfg(not(wasm_browser))] - icann_http: reqwest::Client::builder() - // TODO: use persistent cookie jar - .cookie_provider(cookie_store.clone()) - .user_agent(user_agent) + icann_http: icann_http_builder .build() .expect("config expected to not error"), #[cfg(not(wasm_browser))] diff --git a/pubky/src/native/api/auth.rs b/pubky/src/native/api/auth.rs index dcee93c..375570c 100644 --- a/pubky/src/native/api/auth.rs +++ b/pubky/src/native/api/auth.rs @@ -257,7 +257,6 @@ impl Client { } } }?; - cross_debug!("LOOPING xxx {:?}", &response); let encrypted_token = response.bytes().await?; let token_bytes = decrypt(&encrypted_token, client_secret) @@ -283,6 +282,7 @@ impl AuthRequest { &self.url } + // TODO: Return better errors pub async fn response(&self) -> Result { self.rx .recv_async() @@ -293,6 +293,8 @@ impl AuthRequest { #[cfg(test)] mod tests { + use std::time::Duration; + use crate::*; use http_relay::HttpRelay; @@ -453,4 +455,88 @@ mod tests { assert_eq!(session.pubky(), &second_keypair.public_key()); assert!(session.capabilities().contains(&Capability::root())); } + + #[tokio::test] + async fn authz_timeout_reconnect() { + let testnet = Testnet::new(10).unwrap(); + let server = Homeserver::start_test(&testnet).await.unwrap(); + + let http_relay = HttpRelay::builder().build().await.unwrap(); + let http_relay_url = http_relay.local_link_url(); + + let keypair = Keypair::random(); + let pubky = keypair.public_key(); + + // Third party app side + let capabilities: Capabilities = + "/pub/pubky.app/:rw,/pub/foo.bar/file:r".try_into().unwrap(); + + let client = Client::builder() + .pkarr(|builder| builder.no_default_network().bootstrap(&testnet.bootstrap)) + .request_timeout(Duration::from_millis(1000)) + .build() + .unwrap(); + + let pubky_auth_request = client.auth_request(http_relay_url, &capabilities).unwrap(); + + // Authenticator side + { + let url = pubky_auth_request.url().clone(); + + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(400)).await; + + let client = Client::builder() + .pkarr(|builder| builder.no_default_network().bootstrap(&testnet.bootstrap)) + .build() + .unwrap(); + + client.signup(&keypair, &server.public_key()).await.unwrap(); + + client.send_auth_token(&keypair, &url).await.unwrap(); + } + }); + } + + let public_key = pubky_auth_request.response().await.unwrap(); + + assert_eq!(&public_key, &pubky); + + let session = client.session(&pubky).await.unwrap().unwrap(); + assert_eq!(session.capabilities(), &capabilities.0); + + // Test access control enforcement + + client + .put(format!("pubky://{pubky}/pub/pubky.app/foo")) + .body(vec![]) + .send() + .await + .unwrap() + .error_for_status() + .unwrap(); + + assert_eq!( + client + .put(format!("pubky://{pubky}/pub/pubky.app")) + .body(vec![]) + .send() + .await + .unwrap() + .status(), + StatusCode::FORBIDDEN + ); + + assert_eq!( + client + .put(format!("pubky://{pubky}/pub/foo.bar/file")) + .body(vec![]) + .send() + .await + .unwrap() + .status(), + StatusCode::FORBIDDEN + ); + } }