diff --git a/crates/cashu/src/nuts/auth/nut21.rs b/crates/cashu/src/nuts/auth/nut21.rs index 6392bd4a..7ecc798b 100644 --- a/crates/cashu/src/nuts/auth/nut21.rs +++ b/crates/cashu/src/nuts/auth/nut21.rs @@ -161,6 +161,10 @@ pub enum RoutePath { /// Bolt12 Quote #[serde(rename = "/v1/melt/bolt12")] MeltBolt12, + + /// WebSocket + #[serde(rename = "/v1/ws")] + Ws, } /// Returns [`RoutePath`]s that match regex diff --git a/crates/cdk-axum/src/router_handlers.rs b/crates/cdk-axum/src/router_handlers.rs index 207f34c7..ab70e977 100644 --- a/crates/cdk-axum/src/router_handlers.rs +++ b/crates/cdk-axum/src/router_handlers.rs @@ -219,10 +219,23 @@ pub(crate) async fn get_check_mint_bolt11_quote( #[instrument(skip_all)] pub(crate) async fn ws_handler( + #[cfg(feature = "auth")] auth: AuthHeader, State(state): State, ws: WebSocketUpgrade, -) -> impl IntoResponse { - ws.on_upgrade(|ws| main_websocket(ws, state)) +) -> Result { + #[cfg(feature = "auth")] + { + state + .mint + .verify_auth( + auth.into(), + &ProtectedEndpoint::new(Method::Get, RoutePath::Ws), + ) + .await + .map_err(into_response)?; + } + + Ok(ws.on_upgrade(|ws| main_websocket(ws, state))) } /// Mint tokens by paying a BOLT11 Lightning invoice. diff --git a/crates/cdk-integration-tests/src/bin/start_fake_auth_mint.rs b/crates/cdk-integration-tests/src/bin/start_fake_auth_mint.rs index 43e5e4a8..7d7dc8e1 100644 --- a/crates/cdk-integration-tests/src/bin/start_fake_auth_mint.rs +++ b/crates/cdk-integration-tests/src/bin/start_fake_auth_mint.rs @@ -85,6 +85,7 @@ async fn start_fake_auth_mint( swap: AuthType::Blind, restore: AuthType::Blind, check_proof_state: AuthType::Blind, + websocket_auth: AuthType::Blind, }); // Set description for the mint diff --git a/crates/cdk-mintd/src/config.rs b/crates/cdk-mintd/src/config.rs index fa5bc6ff..a5b2ed9e 100644 --- a/crates/cdk-mintd/src/config.rs +++ b/crates/cdk-mintd/src/config.rs @@ -457,6 +457,9 @@ pub struct Auth { pub restore: AuthType, #[serde(default)] pub check_proof_state: AuthType, + /// Enable WebSocket authentication support + #[serde(default = "default_blind")] + pub websocket_auth: AuthType, } fn default_blind() -> AuthType { diff --git a/crates/cdk-mintd/src/env_vars/auth.rs b/crates/cdk-mintd/src/env_vars/auth.rs index 62d32597..c80e6475 100644 --- a/crates/cdk-mintd/src/env_vars/auth.rs +++ b/crates/cdk-mintd/src/env_vars/auth.rs @@ -17,6 +17,10 @@ pub const ENV_AUTH_CHECK_MELT_QUOTE: &str = "CDK_MINTD_AUTH_CHECK_MELT_QUOTE"; pub const ENV_AUTH_SWAP: &str = "CDK_MINTD_AUTH_SWAP"; pub const ENV_AUTH_RESTORE: &str = "CDK_MINTD_AUTH_RESTORE"; pub const ENV_AUTH_CHECK_PROOF_STATE: &str = "CDK_MINTD_AUTH_CHECK_PROOF_STATE"; +pub const ENV_AUTH_WEBSOCKET: &str = "CDK_MINTD_AUTH_WEBSOCKET"; +pub const ENV_AUTH_WS_MINT_QUOTE: &str = "CDK_MINTD_AUTH_WS_MINT_QUOTE"; +pub const ENV_AUTH_WS_MELT_QUOTE: &str = "CDK_MINTD_AUTH_WS_MELT_QUOTE"; +pub const ENV_AUTH_WS_PROOF_STATE: &str = "CDK_MINTD_AUTH_WS_PROOF_STATE"; impl Auth { pub fn from_env(mut self) -> Self { @@ -94,6 +98,12 @@ impl Auth { } } + if let Ok(ws_auth_str) = env::var(ENV_AUTH_WEBSOCKET) { + if let Ok(auth_type) = ws_auth_str.parse() { + self.websocket_auth = auth_type; + } + } + self } } diff --git a/crates/cdk-mintd/src/lib.rs b/crates/cdk-mintd/src/lib.rs index 2a058e95..34ed4833 100644 --- a/crates/cdk-mintd/src/lib.rs +++ b/crates/cdk-mintd/src/lib.rs @@ -792,6 +792,12 @@ async fn setup_authentication( add_endpoint(state_protected_endpoint, &auth_settings.check_proof_state); } + // Ws endpoint + { + let ws_protected_endpoint = ProtectedEndpoint::new(Method::Get, RoutePath::Ws); + add_endpoint(ws_protected_endpoint, &auth_settings.websocket_auth); + } + mint_builder = mint_builder.with_auth( auth_localstore.clone(), auth_settings.openid_discovery, diff --git a/crates/cdk/src/wallet/mint_connector/http_client.rs b/crates/cdk/src/wallet/mint_connector/http_client.rs index c2f29bc9..1a954a72 100644 --- a/crates/cdk/src/wallet/mint_connector/http_client.rs +++ b/crates/cdk/src/wallet/mint_connector/http_client.rs @@ -71,7 +71,7 @@ where /// Get auth token for a protected endpoint #[cfg(feature = "auth")] #[instrument(skip(self))] - async fn get_auth_token( + pub async fn get_auth_token( &self, method: Method, path: RoutePath, diff --git a/crates/cdk/src/wallet/subscription/ws.rs b/crates/cdk/src/wallet/subscription/ws.rs index d39fe238..f0204a90 100644 --- a/crates/cdk/src/wallet/subscription/ws.rs +++ b/crates/cdk/src/wallet/subscription/ws.rs @@ -4,9 +4,12 @@ use std::sync::Arc; use cdk_common::subscription::Params; use cdk_common::ws::{WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest}; +#[cfg(feature = "auth")] +use cdk_common::{Method, RoutePath}; use futures::{SinkExt, StreamExt}; use tokio::sync::{mpsc, RwLock}; use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::tungstenite::Message; use super::http::http_main; @@ -37,14 +40,64 @@ pub async fn ws_main( url.set_scheme("ws").expect("Could not set scheme"); } - let url = url.to_string(); + let request = match url.to_string().into_client_request() { + Ok(req) => req, + Err(err) => { + tracing::error!("Failed to create client request: {:?}", err); + // Fallback to HTTP client if we can't create the WebSocket request + return http_main( + std::iter::empty(), + http_client, + subscriptions, + new_subscription_recv, + on_drop, + wallet, + ) + .await; + } + }; let mut active_subscriptions = HashMap::>::new(); let mut failure_count = 0; loop { + let mut request_clone = request.clone(); + #[cfg(feature = "auth")] + { + let auth_wallet = http_client.get_auth_wallet().await; + let token = match auth_wallet.as_ref() { + Some(auth_wallet) => { + let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws); + match auth_wallet.get_auth_for_request(&endpoint).await { + Ok(token) => token, + Err(err) => { + tracing::warn!("Failed to get auth token: {:?}", err); + None + } + } + } + None => None, + }; + + if let Some(auth_token) = token { + let header_key = match &auth_token { + cdk_common::AuthToken::ClearAuth(_) => "Clear-auth", + cdk_common::AuthToken::BlindAuth(_) => "Blind-auth", + }; + + match auth_token.to_string().parse() { + Ok(header_value) => { + request_clone.headers_mut().insert(header_key, header_value); + } + Err(err) => { + tracing::warn!("Failed to parse auth token as header value: {:?}", err); + } + } + } + } + tracing::debug!("Connecting to {}", url); - let ws_stream = match connect_async(&url).await { + let ws_stream = match connect_async(request_clone.clone()).await { Ok((ws_stream, _)) => ws_stream, Err(err) => { failure_count += 1;