feat(cdk): add WebSocket authentication support (#1116)

* feat(cdk): add WebSocket authentication support with comprehensive configuration

- Add WebSocket auth token injection for client connections
- Implement server-side WebSocket authentication verification
- Add configuration options for per-endpoint WebSocket auth types
- Include comprehensive documentation and example configuration
- Support clear, blind, and no-auth modes for WebSocket endpoin
This commit is contained in:
thesimplekid
2025-09-25 08:31:57 +01:00
committed by GitHub
parent f41ce0a3fb
commit e07a9c8e1e
8 changed files with 95 additions and 5 deletions

View File

@@ -161,6 +161,10 @@ pub enum RoutePath {
/// Bolt12 Quote
#[serde(rename = "/v1/melt/bolt12")]
MeltBolt12,
/// WebSocket
#[serde(rename = "/v1/ws")]
Ws,
}
/// Returns [`RoutePath`]s that match regex

View File

@@ -219,10 +219,23 @@ pub(crate) async fn get_check_mint_bolt11_quote(
#[instrument(skip_all)]
pub(crate) async fn ws_handler(
#[cfg(feature = "auth")] auth: AuthHeader,
State(state): State<MintState>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
ws.on_upgrade(|ws| main_websocket(ws, state))
) -> Result<impl IntoResponse, Response> {
#[cfg(feature = "auth")]
{
state
.mint
.verify_auth(
auth.into(),
&ProtectedEndpoint::new(Method::Get, RoutePath::Ws),
)
.await
.map_err(into_response)?;
}
Ok(ws.on_upgrade(|ws| main_websocket(ws, state)))
}
/// Mint tokens by paying a BOLT11 Lightning invoice.

View File

@@ -85,6 +85,7 @@ async fn start_fake_auth_mint(
swap: AuthType::Blind,
restore: AuthType::Blind,
check_proof_state: AuthType::Blind,
websocket_auth: AuthType::Blind,
});
// Set description for the mint

View File

@@ -457,6 +457,9 @@ pub struct Auth {
pub restore: AuthType,
#[serde(default)]
pub check_proof_state: AuthType,
/// Enable WebSocket authentication support
#[serde(default = "default_blind")]
pub websocket_auth: AuthType,
}
fn default_blind() -> AuthType {

View File

@@ -17,6 +17,10 @@ pub const ENV_AUTH_CHECK_MELT_QUOTE: &str = "CDK_MINTD_AUTH_CHECK_MELT_QUOTE";
pub const ENV_AUTH_SWAP: &str = "CDK_MINTD_AUTH_SWAP";
pub const ENV_AUTH_RESTORE: &str = "CDK_MINTD_AUTH_RESTORE";
pub const ENV_AUTH_CHECK_PROOF_STATE: &str = "CDK_MINTD_AUTH_CHECK_PROOF_STATE";
pub const ENV_AUTH_WEBSOCKET: &str = "CDK_MINTD_AUTH_WEBSOCKET";
pub const ENV_AUTH_WS_MINT_QUOTE: &str = "CDK_MINTD_AUTH_WS_MINT_QUOTE";
pub const ENV_AUTH_WS_MELT_QUOTE: &str = "CDK_MINTD_AUTH_WS_MELT_QUOTE";
pub const ENV_AUTH_WS_PROOF_STATE: &str = "CDK_MINTD_AUTH_WS_PROOF_STATE";
impl Auth {
pub fn from_env(mut self) -> Self {
@@ -94,6 +98,12 @@ impl Auth {
}
}
if let Ok(ws_auth_str) = env::var(ENV_AUTH_WEBSOCKET) {
if let Ok(auth_type) = ws_auth_str.parse() {
self.websocket_auth = auth_type;
}
}
self
}
}

View File

@@ -792,6 +792,12 @@ async fn setup_authentication(
add_endpoint(state_protected_endpoint, &auth_settings.check_proof_state);
}
// Ws endpoint
{
let ws_protected_endpoint = ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
add_endpoint(ws_protected_endpoint, &auth_settings.websocket_auth);
}
mint_builder = mint_builder.with_auth(
auth_localstore.clone(),
auth_settings.openid_discovery,

View File

@@ -71,7 +71,7 @@ where
/// Get auth token for a protected endpoint
#[cfg(feature = "auth")]
#[instrument(skip(self))]
async fn get_auth_token(
pub async fn get_auth_token(
&self,
method: Method,
path: RoutePath,

View File

@@ -4,9 +4,12 @@ use std::sync::Arc;
use cdk_common::subscription::Params;
use cdk_common::ws::{WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest};
#[cfg(feature = "auth")]
use cdk_common::{Method, RoutePath};
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;
use super::http::http_main;
@@ -37,14 +40,64 @@ pub async fn ws_main(
url.set_scheme("ws").expect("Could not set scheme");
}
let url = url.to_string();
let request = match url.to_string().into_client_request() {
Ok(req) => req,
Err(err) => {
tracing::error!("Failed to create client request: {:?}", err);
// Fallback to HTTP client if we can't create the WebSocket request
return http_main(
std::iter::empty(),
http_client,
subscriptions,
new_subscription_recv,
on_drop,
wallet,
)
.await;
}
};
let mut active_subscriptions = HashMap::<SubId, mpsc::Sender<_>>::new();
let mut failure_count = 0;
loop {
let mut request_clone = request.clone();
#[cfg(feature = "auth")]
{
let auth_wallet = http_client.get_auth_wallet().await;
let token = match auth_wallet.as_ref() {
Some(auth_wallet) => {
let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
match auth_wallet.get_auth_for_request(&endpoint).await {
Ok(token) => token,
Err(err) => {
tracing::warn!("Failed to get auth token: {:?}", err);
None
}
}
}
None => None,
};
if let Some(auth_token) = token {
let header_key = match &auth_token {
cdk_common::AuthToken::ClearAuth(_) => "Clear-auth",
cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
};
match auth_token.to_string().parse() {
Ok(header_value) => {
request_clone.headers_mut().insert(header_key, header_value);
}
Err(err) => {
tracing::warn!("Failed to parse auth token as header value: {:?}", err);
}
}
}
}
tracing::debug!("Connecting to {}", url);
let ws_stream = match connect_async(&url).await {
let ws_stream = match connect_async(request_clone.clone()).await {
Ok((ws_stream, _)) => ws_stream,
Err(err) => {
failure_count += 1;