diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6d1e848c..b2ea5438 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -490,6 +490,16 @@ jobs: steps: - name: checkout uses: actions/checkout@v4 + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true - name: Install Nix uses: DeterminateSystems/nix-installer-action@v17 - name: Nix Cache diff --git a/Cargo.toml b/Cargo.toml index 9b8a6afb..d325d851 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ futures = { version = "0.3.28", default-features = false, features = ["async-awa lightning-invoice = { version = "0.33.0", features = ["serde", "std"] } lightning = { version = "0.1.2", default-features = false, features = ["std"]} ldk-node = "0.6.2" -serde = { version = "1", features = ["derive"] } +serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" thiserror = { version = "2" } tokio = { version = "1", default-features = false, features = ["rt", "macros", "test-util", "sync"] } diff --git a/crates/cashu/Cargo.toml b/crates/cashu/Cargo.toml index affd6f52..830b3957 100644 --- a/crates/cashu/Cargo.toml +++ b/crates/cashu/Cargo.toml @@ -13,13 +13,13 @@ readme = "README.md" [features] default = ["mint", "wallet", "auth"] swagger = ["dep:utoipa"] -mint = ["dep:uuid"] +mint = [] wallet = [] auth = ["dep:strum", "dep:strum_macros", "dep:regex"] bench = [] [dependencies] -uuid = { workspace = true, optional = true } +uuid.workspace = true bitcoin.workspace = true cbor-diag.workspace = true ciborium.workspace = true diff --git a/crates/cashu/src/lib.rs b/crates/cashu/src/lib.rs index 4b1877c7..07733e52 100644 --- a/crates/cashu/src/lib.rs +++ b/crates/cashu/src/lib.rs @@ -16,7 +16,6 @@ pub use self::mint_url::MintUrl; pub use self::nuts::*; pub use self::util::SECP256K1; -#[cfg(feature = "mint")] pub mod quote_id; #[doc(hidden)] diff --git a/crates/cashu/src/nuts/nut17/mod.rs b/crates/cashu/src/nuts/nut17/mod.rs index d12960bc..eb05926c 100644 --- a/crates/cashu/src/nuts/nut17/mod.rs +++ b/crates/cashu/src/nuts/nut17/mod.rs @@ -2,13 +2,11 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -#[cfg(feature = "mint")] use super::PublicKey; use crate::nuts::{ CurrencyUnit, MeltQuoteBolt11Response, MintQuoteBolt11Response, PaymentMethod, ProofState, }; -#[cfg(feature = "mint")] -use crate::quote_id::{QuoteId, QuoteIdError}; +use crate::quote_id::QuoteIdError; use crate::MintQuoteBolt12Response; pub mod ws; @@ -109,7 +107,10 @@ pub enum WsCommand { ProofState, } -impl From> for NotificationPayload { +impl From> for NotificationPayload +where + T: Clone, +{ fn from(mint_quote: MintQuoteBolt12Response) -> NotificationPayload { NotificationPayload::MintQuoteBolt12Response(mint_quote) } @@ -119,7 +120,10 @@ impl From> for NotificationPayload { #[serde(bound = "T: Serialize + DeserializeOwned")] #[serde(untagged)] /// Subscription response -pub enum NotificationPayload { +pub enum NotificationPayload +where + T: Clone, +{ /// Proof State ProofState(ProofState), /// Melt Quote Bolt11 Response @@ -130,38 +134,23 @@ pub enum NotificationPayload { MintQuoteBolt12Response(MintQuoteBolt12Response), } -impl From for NotificationPayload { - fn from(proof_state: ProofState) -> NotificationPayload { - NotificationPayload::ProofState(proof_state) - } -} - -impl From> for NotificationPayload { - fn from(melt_quote: MeltQuoteBolt11Response) -> NotificationPayload { - NotificationPayload::MeltQuoteBolt11Response(melt_quote) - } -} - -impl From> for NotificationPayload { - fn from(mint_quote: MintQuoteBolt11Response) -> NotificationPayload { - NotificationPayload::MintQuoteBolt11Response(mint_quote) - } -} - -#[cfg(feature = "mint")] -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Hash, Serialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] /// A parsed notification -pub enum Notification { +pub enum NotificationId +where + T: Clone, +{ /// ProofState id is a Pubkey ProofState(PublicKey), /// MeltQuote id is an QuoteId - MeltQuoteBolt11(QuoteId), + MeltQuoteBolt11(T), /// MintQuote id is an QuoteId - MintQuoteBolt11(QuoteId), + MintQuoteBolt11(T), /// MintQuote id is an QuoteId - MintQuoteBolt12(QuoteId), + MintQuoteBolt12(T), /// MintQuote id is an QuoteId - MeltQuoteBolt12(QuoteId), + MeltQuoteBolt12(T), } /// Kind @@ -187,7 +176,6 @@ impl AsRef for Params { /// Parsing error #[derive(thiserror::Error, Debug)] pub enum Error { - #[cfg(feature = "mint")] #[error("Uuid Error: {0}")] /// Uuid Error QuoteId(#[from] QuoteIdError), diff --git a/crates/cashu/src/nuts/nut17/ws.rs b/crates/cashu/src/nuts/nut17/ws.rs index a15c52ff..f248454c 100644 --- a/crates/cashu/src/nuts/nut17/ws.rs +++ b/crates/cashu/src/nuts/nut17/ws.rs @@ -36,7 +36,10 @@ pub struct WsUnsubscribeResponse { /// subscription #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound = "T: Serialize + DeserializeOwned, I: Serialize + DeserializeOwned")] -pub struct NotificationInner { +pub struct NotificationInner +where + T: Clone, +{ /// The subscription ID #[serde(rename = "subId")] pub sub_id: I, diff --git a/crates/cdk-axum/src/ws/mod.rs b/crates/cdk-axum/src/ws/mod.rs index 9ea30dc5..b5b0060c 100644 --- a/crates/cdk-axum/src/ws/mod.rs +++ b/crates/cdk-axum/src/ws/mod.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; +use std::sync::Arc; use axum::extract::ws::{CloseFrame, Message, WebSocket}; use cdk::mint::QuoteId; use cdk::nuts::nut17::NotificationPayload; -use cdk::pub_sub::SubId; +use cdk::subscription::SubId; use cdk::ws::{ notification_to_ws_message, NotificationInner, WsErrorBody, WsMessageOrResponse, WsMethodRequest, WsRequest, @@ -36,8 +37,8 @@ pub use error::WsError; pub struct WsContext { state: MintState, - subscriptions: HashMap>, - publisher: mpsc::Sender<(SubId, NotificationPayload)>, + subscriptions: HashMap, tokio::task::JoinHandle<()>>, + publisher: mpsc::Sender<(Arc, NotificationPayload)>, } /// Main function for websocket connections diff --git a/crates/cdk-axum/src/ws/subscribe.rs b/crates/cdk-axum/src/ws/subscribe.rs index fb8c9201..94a0b285 100644 --- a/crates/cdk-axum/src/ws/subscribe.rs +++ b/crates/cdk-axum/src/ws/subscribe.rs @@ -1,4 +1,4 @@ -use cdk::subscription::{IndexableParams, Params}; +use cdk::subscription::Params; use cdk::ws::{WsResponseResult, WsSubscribeResponse}; use super::{WsContext, WsError}; @@ -15,22 +15,20 @@ pub(crate) async fn handle( return Err(WsError::InvalidParams); } - let params: IndexableParams = params.into(); - let mut subscription = context .state .mint .pubsub_manager() - .try_subscribe(params) - .await + .subscribe(params) .map_err(|_| WsError::ParseError)?; let publisher = context.publisher.clone(); + let sub_id_for_sender = sub_id.clone(); context.subscriptions.insert( sub_id.clone(), tokio::spawn(async move { while let Some(response) = subscription.recv().await { - let _ = publisher.send(response).await; + let _ = publisher.try_send((sub_id_for_sender.clone(), response.into_inner())); } }), ); diff --git a/crates/cdk-common/Cargo.toml b/crates/cdk-common/Cargo.toml index 2ca99661..5fb81bb9 100644 --- a/crates/cdk-common/Cargo.toml +++ b/crates/cdk-common/Cargo.toml @@ -40,10 +40,16 @@ anyhow.workspace = true serde_json.workspace = true serde_with.workspace = true web-time.workspace = true +tokio.workspace = true +parking_lot = "0.12.5" [target.'cfg(target_arch = "wasm32")'.dependencies] uuid = { workspace = true, features = ["js"], optional = true } +getrandom = { version = "0.2", features = ["js"] } +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" [dev-dependencies] rand.workspace = true bip39.workspace = true +wasm-bindgen-test = "0.3" diff --git a/crates/cdk-common/src/lib.rs b/crates/cdk-common/src/lib.rs index f30e082a..9db301f5 100644 --- a/crates/cdk-common/src/lib.rs +++ b/crates/cdk-common/src/lib.rs @@ -33,3 +33,5 @@ pub use cashu::nuts::{self, *}; pub use cashu::quote_id::{self, *}; pub use cashu::{dhke, ensure_cdk, mint_url, secret, util, SECP256K1}; pub use error::Error; +/// Re-export parking_lot for reuse +pub use parking_lot; diff --git a/crates/cdk-common/src/pub_sub/error.rs b/crates/cdk-common/src/pub_sub/error.rs new file mode 100644 index 00000000..c4845d3f --- /dev/null +++ b/crates/cdk-common/src/pub_sub/error.rs @@ -0,0 +1,44 @@ +//! Error types for the pub-sub module. + +use tokio::sync::mpsc::error::TrySendError; + +#[derive(thiserror::Error, Debug)] +/// Error +pub enum Error { + /// No subscription found + #[error("Subscription not found")] + NoSubscription, + + /// Parsing error + #[error("Parsing Error {0}")] + ParsingError(String), + + /// Internal error + #[error("Internal")] + Internal(Box), + + /// Internal error + #[error("Internal error {0}")] + InternalStr(String), + + /// Not supported + #[error("Not supported")] + NotSupported, + + /// Channel is full + #[error("Channel is full")] + ChannelFull, + + /// Channel is closed + #[error("Channel is close")] + ChannelClosed, +} + +impl From> for Error { + fn from(value: TrySendError) -> Self { + match value { + TrySendError::Closed(_) => Error::ChannelClosed, + TrySendError::Full(_) => Error::ChannelFull, + } + } +} diff --git a/crates/cdk-common/src/pub_sub/index.rs b/crates/cdk-common/src/pub_sub/index.rs deleted file mode 100644 index 15b11b44..00000000 --- a/crates/cdk-common/src/pub_sub/index.rs +++ /dev/null @@ -1,161 +0,0 @@ -//! WS Index - -use std::fmt::Debug; -use std::ops::Deref; -use std::sync::atomic::{AtomicUsize, Ordering}; - -use super::SubId; - -/// Indexable trait -pub trait Indexable { - /// The type of the index, it is unknown and it is up to the Manager's - /// generic type - type Type: PartialOrd + Ord + Send + Sync + Debug; - - /// To indexes - fn to_indexes(&self) -> Vec>; -} - -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone)] -/// Index -/// -/// The Index is a sorted structure that is used to quickly find matches -/// -/// The counter is used to make sure each Index is unique, even if the prefix -/// are the same, and also to make sure that earlier indexes matches first -pub struct Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - prefix: T, - counter: SubscriptionGlobalId, - id: super::SubId, -} - -impl From<&Index> for super::SubId -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - fn from(val: &Index) -> Self { - val.id.clone() - } -} - -impl Deref for Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.prefix - } -} - -impl Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - /// Compare the - pub fn cmp_prefix(&self, other: &Index) -> std::cmp::Ordering { - self.prefix.cmp(&other.prefix) - } - - /// Returns a globally unique id for the Index - pub fn unique_id(&self) -> usize { - self.counter.0 - } -} - -impl From<(T, SubId, SubscriptionGlobalId)> for Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - fn from((prefix, id, counter): (T, SubId, SubscriptionGlobalId)) -> Self { - Self { - prefix, - id, - counter, - } - } -} - -impl From<(T, SubId)> for Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - fn from((prefix, id): (T, SubId)) -> Self { - Self { - prefix, - id, - counter: Default::default(), - } - } -} - -impl From for Index -where - T: PartialOrd + Ord + Send + Sync + Debug, -{ - fn from(prefix: T) -> Self { - Self { - prefix, - id: Default::default(), - counter: SubscriptionGlobalId(0), - } - } -} - -static COUNTER: AtomicUsize = AtomicUsize::new(0); - -/// Dummy type -/// -/// This is only use so each Index is unique, with the same prefix. -/// -/// The prefix is used to leverage the BTree to find things quickly, but each -/// entry/key must be unique, so we use this dummy type to make sure each Index -/// is unique. -/// -/// Unique is also used to make sure that the indexes are sorted by creation order -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Copy)] -pub struct SubscriptionGlobalId(usize); - -impl Default for SubscriptionGlobalId { - fn default() -> Self { - Self(COUNTER.fetch_add(1, Ordering::Relaxed)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_index_from_tuple() { - let sub_id = SubId::from("test_sub_id"); - let prefix = "test_prefix"; - let index: Index<&str> = Index::from((prefix, sub_id.clone())); - assert_eq!(index.prefix, "test_prefix"); - assert_eq!(index.id, sub_id); - } - - #[test] - fn test_index_cmp_prefix() { - let sub_id = SubId::from("test_sub_id"); - let index1: Index<&str> = Index::from(("a", sub_id.clone())); - let index2: Index<&str> = Index::from(("b", sub_id.clone())); - assert_eq!(index1.cmp_prefix(&index2), std::cmp::Ordering::Less); - } - - #[test] - fn test_sub_id_from_str() { - let sub_id = SubId::from("test_sub_id"); - assert_eq!(sub_id.0, "test_sub_id"); - } - - #[test] - fn test_sub_id_deref() { - let sub_id = SubId::from("test_sub_id"); - assert_eq!(&*sub_id, "test_sub_id"); - } -} diff --git a/crates/cdk-common/src/pub_sub/mod.rs b/crates/cdk-common/src/pub_sub/mod.rs index aef062b5..7a0b0e08 100644 --- a/crates/cdk-common/src/pub_sub/mod.rs +++ b/crates/cdk-common/src/pub_sub/mod.rs @@ -1,77 +1,180 @@ -//! Publish–subscribe pattern. +//! Publish/Subscribe core //! -//! This is a generic implementation for -//! [NUT-17]() with a type -//! agnostic Publish-subscribe manager. +//! This module defines the transport-agnostic pub/sub primitives used by both +//! mint and wallet components. The design prioritizes: //! -//! The manager has a method for subscribers to subscribe to events with a -//! generic type that must be converted to a vector of indexes. +//! - **Request coalescing**: multiple local subscribers to the same remote topic +//! result in a single upstream subscription, with local fan‑out. +//! - **Latest-on-subscribe** (NUT-17): on (re)subscription, the most recent event +//! is fetched and delivered before streaming new ones. +//! - **Backpressure-aware delivery**: bounded channels + drop policies prevent +//! a slow consumer from stalling the whole pipeline. +//! - **Resilience**: automatic reconnect with exponential backoff; WebSocket +//! streaming when available, HTTP long-poll fallback otherwise. //! -//! Events are also generic that should implement the `Indexable` trait. -use std::fmt::Debug; -use std::ops::Deref; -use std::str::FromStr; +//! Terms used throughout the module: +//! - **Event**: a domain object that maps to one or more `Topic`s via `Event::get_topics`. +//! - **Topic**: an index/type that defines storage and matching semantics. +//! - **SubscriptionRequest**: a domain-specific filter that can be converted into +//! low-level transport messages (e.g., WebSocket subscribe frames). +//! - **Spec**: type bundle tying `Event`, `Topic`, `SubscriptionId`, and serialization. -use serde::{Deserialize, Serialize}; +mod error; +mod pubsub; +pub mod remote_consumer; +mod subscriber; +mod types; -pub mod index; +pub use self::error::Error; +pub use self::pubsub::Pubsub; +pub use self::subscriber::{Subscriber, SubscriptionRequest}; +pub use self::types::*; -/// Default size of the remove channel -pub const DEFAULT_REMOVE_SIZE: usize = 10_000; +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::sync::{Arc, RwLock}; -/// Default channel size for subscription buffering -pub const DEFAULT_CHANNEL_SIZE: usize = 10; + use serde::{Deserialize, Serialize}; -#[async_trait::async_trait] -/// On New Subscription trait -/// -/// This trait is optional and it is used to notify the application when a new -/// subscription is created. This is useful when the application needs to send -/// the initial state to the subscriber upon subscription -pub trait OnNewSubscription { - /// Index type - type Index; - /// Subscription event type - type Event; + use super::subscriber::SubscriptionRequest; + use super::{Error, Event, Pubsub, Spec, Subscriber}; - /// Called when a new subscription is created - async fn on_new_subscription( - &self, - request: &[&Self::Index], - ) -> Result, String>; -} + #[derive(Clone, Debug, Serialize, Eq, PartialEq, Deserialize)] + pub struct Message { + pub foo: u64, + pub bar: u64, + } -/// Subscription Id wrapper -/// -/// This is the place to add some sane default (like a max length) to the -/// subscription ID -#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] -pub struct SubId(String); + #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)] + pub enum IndexTest { + Foo(u64), + Bar(u64), + } -impl From<&str> for SubId { - fn from(s: &str) -> Self { - Self(s.to_string()) - } -} - -impl From for SubId { - fn from(s: String) -> Self { - Self(s) - } -} - -impl FromStr for SubId { - type Err = (); - - fn from_str(s: &str) -> Result { - Ok(Self(s.to_string())) - } -} - -impl Deref for SubId { - type Target = String; - - fn deref(&self) -> &Self::Target { - &self.0 + impl Event for Message { + type Topic = IndexTest; + + fn get_topics(&self) -> Vec { + vec![IndexTest::Foo(self.foo), IndexTest::Bar(self.bar)] + } + } + + pub struct CustomPubSub { + pub storage: Arc>>, + } + + #[async_trait::async_trait] + impl Spec for CustomPubSub { + type Topic = IndexTest; + + type Event = Message; + + type SubscriptionId = String; + + type Context = (); + + fn new_instance(_context: Self::Context) -> Arc + where + Self: Sized, + { + Arc::new(Self { + storage: Default::default(), + }) + } + + async fn fetch_events( + self: &Arc, + topics: Vec<::Topic>, + reply_to: Subscriber, + ) where + Self: Sized, + { + let storage = self.storage.read().unwrap(); + + for index in topics { + if let Some(value) = storage.get(&index) { + let _ = reply_to.send(value.clone()); + } + } + } + } + + #[derive(Debug, Clone)] + pub enum SubscriptionReq { + Foo(u64), + Bar(u64), + } + + impl SubscriptionRequest for SubscriptionReq { + type Topic = IndexTest; + + type SubscriptionId = String; + + fn try_get_topics(&self) -> Result, Error> { + Ok(vec![match self { + SubscriptionReq::Bar(n) => IndexTest::Bar(*n), + SubscriptionReq::Foo(n) => IndexTest::Foo(*n), + }]) + } + + fn subscription_name(&self) -> Arc { + Arc::new("test".to_owned()) + } + } + + #[tokio::test] + async fn delivery_twice_realtime() { + let pubsub = Pubsub::new(CustomPubSub::new_instance(())); + + assert_eq!(pubsub.active_subscribers(), 0); + + let mut subscriber = pubsub.subscribe(SubscriptionReq::Foo(2)).unwrap(); + + assert_eq!(pubsub.active_subscribers(), 1); + + let _ = pubsub.publish_now(Message { foo: 2, bar: 1 }); + let _ = pubsub.publish_now(Message { foo: 2, bar: 2 }); + + assert_eq!(subscriber.recv().await.map(|x| x.bar), Some(1)); + assert_eq!(subscriber.recv().await.map(|x| x.bar), Some(2)); + assert!(subscriber.try_recv().is_none()); + + drop(subscriber); + + assert_eq!(pubsub.active_subscribers(), 0); + } + + #[tokio::test] + async fn read_from_storage() { + let x = CustomPubSub::new_instance(()); + let storage = x.storage.clone(); + + let pubsub = Pubsub::new(x); + + { + // set previous value + let mut s = storage.write().unwrap(); + s.insert(IndexTest::Bar(2), Message { foo: 3, bar: 2 }); + } + + let mut subscriber = pubsub.subscribe(SubscriptionReq::Bar(2)).unwrap(); + + // Just should receive the latest + assert_eq!(subscriber.recv().await.map(|x| x.foo), Some(3)); + + // realtime delivery test + let _ = pubsub.publish_now(Message { foo: 1, bar: 2 }); + assert_eq!(subscriber.recv().await.map(|x| x.foo), Some(1)); + + { + // set previous value + let mut s = storage.write().unwrap(); + s.insert(IndexTest::Bar(2), Message { foo: 1, bar: 2 }); + } + + // new subscription should only get the latest state (it is up to the Topic trait) + let mut y = pubsub.subscribe(SubscriptionReq::Bar(2)).unwrap(); + assert_eq!(y.recv().await.map(|x| x.foo), Some(1)); } } diff --git a/crates/cdk-common/src/pub_sub/pubsub.rs b/crates/cdk-common/src/pub_sub/pubsub.rs new file mode 100644 index 00000000..fff71ef1 --- /dev/null +++ b/crates/cdk-common/src/pub_sub/pubsub.rs @@ -0,0 +1,185 @@ +//! Pub-sub producer + +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashSet}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use parking_lot::RwLock; +use tokio::sync::mpsc; + +use super::subscriber::{ActiveSubscription, SubscriptionRequest}; +use super::{Error, Event, Spec, Subscriber}; + +/// Default channel size for subscription buffering +pub const DEFAULT_CHANNEL_SIZE: usize = 10_000; + +/// Subscriber Receiver +pub type SubReceiver = mpsc::Receiver<(Arc<::SubscriptionId>, ::Event)>; + +/// Internal Index Tree +pub type TopicTree = Arc< + RwLock< + BTreeMap< + // Index with a subscription unique ID + (::Topic, usize), + Subscriber, + >, + >, +>; + +/// Manager +pub struct Pubsub +where + S: Spec + 'static, +{ + inner: Arc, + listeners_topics: TopicTree, + unique_subscription_counter: AtomicUsize, + active_subscribers: Arc, +} + +impl Pubsub +where + S: Spec + 'static, +{ + /// Create a new instance + pub fn new(inner: Arc) -> Self { + Self { + inner, + listeners_topics: Default::default(), + unique_subscription_counter: 0.into(), + active_subscribers: Arc::new(0.into()), + } + } + + /// Total number of active subscribers, it is not the number of active topics being subscribed + pub fn active_subscribers(&self) -> usize { + self.active_subscribers + .load(std::sync::atomic::Ordering::Relaxed) + } + + /// Publish an event to all listenrs + #[inline(always)] + fn publish_internal(event: S::Event, listeners_index: &TopicTree) -> Result<(), Error> { + let index_storage = listeners_index.read(); + + let mut sent = HashSet::new(); + for topic in event.get_topics() { + for ((subscription_index, unique_id), sender) in + index_storage.range((topic.clone(), 0)..) + { + if subscription_index.cmp(&topic) != Ordering::Equal { + break; + } + if sent.contains(&unique_id) { + continue; + } + sent.insert(unique_id); + sender.send(event.clone()); + } + } + + Ok(()) + } + + /// Broadcast an event to all listeners + #[inline(always)] + pub fn publish(&self, event: E) + where + E: Into, + { + let topics = self.listeners_topics.clone(); + let event = event.into(); + + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(async move { + let _ = Self::publish_internal(event, &topics); + }); + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_futures::spawn_local(async move { + let _ = Self::publish_internal(event, &topics); + }); + } + + /// Broadcast an event to all listeners right away, blocking the current thread + /// + /// This function takes an Arc to the storage struct, the event_id, the kind + /// and the vent to broadcast + #[inline(always)] + pub fn publish_now(&self, event: E) -> Result<(), Error> + where + E: Into, + { + let event = event.into(); + Self::publish_internal(event, &self.listeners_topics) + } + + /// Subscribe proving custom sender/receiver mpsc + #[inline(always)] + pub fn subscribe_with( + &self, + request: I, + sender: &mpsc::Sender<(Arc, S::Event)>, + receiver: Option>, + ) -> Result, Error> + where + I: SubscriptionRequest< + Topic = ::Topic, + SubscriptionId = S::SubscriptionId, + >, + { + let subscription_name = request.subscription_name(); + let sender = Subscriber::new(subscription_name.clone(), sender); + let mut index_storage = self.listeners_topics.write(); + let subscription_internal_id = self + .unique_subscription_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + self.active_subscribers + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let subscribed_to = request.try_get_topics()?; + + for index in subscribed_to.iter() { + index_storage.insert((index.clone(), subscription_internal_id), sender.clone()); + } + drop(index_storage); + + let inner = self.inner.clone(); + let subscribed_to_for_spawn = subscribed_to.clone(); + + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(async move { + // TODO: Ignore topics broadcasted from fetch_events _if_ any real time has been broadcasted already. + inner.fetch_events(subscribed_to_for_spawn, sender).await; + }); + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_futures::spawn_local(async move { + inner.fetch_events(subscribed_to_for_spawn, sender).await; + }); + + Ok(ActiveSubscription::new( + subscription_internal_id, + subscription_name, + self.active_subscribers.clone(), + self.listeners_topics.clone(), + subscribed_to, + receiver, + )) + } + + /// Subscribe + pub fn subscribe(&self, request: I) -> Result, Error> + where + I: SubscriptionRequest< + Topic = ::Topic, + SubscriptionId = S::SubscriptionId, + >, + { + let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + self.subscribe_with(request, &sender, Some(receiver)) + } +} diff --git a/crates/cdk-common/src/pub_sub/remote_consumer.rs b/crates/cdk-common/src/pub_sub/remote_consumer.rs new file mode 100644 index 00000000..b7ad9710 --- /dev/null +++ b/crates/cdk-common/src/pub_sub/remote_consumer.rs @@ -0,0 +1,885 @@ +//! Pub-sub consumer +//! +//! Consumers are designed to connect to a producer, through a transport, and subscribe to events. +use std::collections::{HashMap, VecDeque}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::Duration; + +use parking_lot::RwLock; +use tokio::sync::mpsc; +use tokio::time::{sleep, Instant}; + +use super::subscriber::{ActiveSubscription, SubscriptionRequest}; +use super::{Error, Event, Pubsub, Spec}; + +const STREAM_CONNECTION_BACKOFF: Duration = Duration::from_millis(2_000); + +const STREAM_CONNECTION_MAX_BACKOFF: Duration = Duration::from_millis(30_000); + +const INTERNAL_POLL_SIZE: usize = 1_000; + +const POLL_SLEEP: Duration = Duration::from_millis(2_000); + +#[cfg(target_arch = "wasm32")] +use wasm_bindgen_futures; + +struct UniqueSubscription +where + S: Spec, +{ + name: S::SubscriptionId, + total_subscribers: usize, +} + +type UniqueSubscriptions = RwLock::Topic, UniqueSubscription>>; + +type ActiveSubscriptions = + RwLock::SubscriptionId>, Vec<::Topic>>>; + +type CacheEvent = HashMap<<::Event as Event>::Topic, ::Event>; + +/// Subscription consumer +pub struct Consumer +where + T: Transport + 'static, +{ + transport: T, + inner_pubsub: Arc>, + remote_subscriptions: UniqueSubscriptions, + subscriptions: ActiveSubscriptions, + stream_ctrl: RwLock>>>, + still_running: AtomicBool, + prefer_polling: bool, + /// Cached events + /// + /// The cached events are useful to share events. The cache is automatically evicted it is + /// disconnected from the remote source, meaning the cache is only active while there is an + /// active subscription to the remote source, and it remembers the latest event. + cached_events: Arc>>, +} + +/// Remote consumer +pub struct RemoteActiveConsumer +where + T: Transport + 'static, +{ + inner: ActiveSubscription, + previous_messages: VecDeque<::Event>, + consumer: Arc>, +} + +impl RemoteActiveConsumer +where + T: Transport + 'static, +{ + /// Receives the next event + pub async fn recv(&mut self) -> Option<::Event> { + if let Some(event) = self.previous_messages.pop_front() { + Some(event) + } else { + self.inner.recv().await + } + } + + /// Try receive an event or return Noen right away + pub fn try_recv(&mut self) -> Option<::Event> { + if let Some(event) = self.previous_messages.pop_front() { + Some(event) + } else { + self.inner.try_recv() + } + } + + /// Get the subscription name + pub fn name(&self) -> &::SubscriptionId { + self.inner.name() + } +} + +impl Drop for RemoteActiveConsumer +where + T: Transport + 'static, +{ + fn drop(&mut self) { + let _ = self.consumer.unsubscribe(self.name().clone()); + } +} + +/// Struct to relay events from Poll and Streams from the external subscription to the local +/// subscribers +pub struct InternalRelay +where + S: Spec + 'static, +{ + inner: Arc>, + cached_events: Arc>>, +} + +impl InternalRelay +where + S: Spec + 'static, +{ + /// Relay a remote event locally + pub fn send(&self, event: X) + where + X: Into, + { + let event = event.into(); + let mut cached_events = self.cached_events.write(); + + for topic in event.get_topics() { + cached_events.insert(topic, event.clone()); + } + + self.inner.publish(event); + } +} + +impl Consumer +where + T: Transport + 'static, +{ + /// Creates a new instance + pub fn new( + transport: T, + prefer_polling: bool, + context: ::Context, + ) -> Arc { + let this = Arc::new(Self { + transport, + prefer_polling, + inner_pubsub: Arc::new(Pubsub::new(T::Spec::new_instance(context))), + subscriptions: Default::default(), + remote_subscriptions: Default::default(), + stream_ctrl: RwLock::new(None), + cached_events: Default::default(), + still_running: true.into(), + }); + + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(Self::stream(this.clone())); + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_futures::spawn_local(Self::stream(this.clone())); + + this + } + + async fn stream(instance: Arc) { + let mut stream_supported = true; + let mut poll_supported = true; + + let mut backoff = STREAM_CONNECTION_BACKOFF; + let mut retry_at = None; + + loop { + if (!stream_supported && !poll_supported) + || !instance + .still_running + .load(std::sync::atomic::Ordering::Relaxed) + { + break; + } + + if instance.remote_subscriptions.read().is_empty() { + sleep(Duration::from_millis(100)).await; + continue; + } + + if stream_supported + && !instance.prefer_polling + && retry_at + .map(|retry_at| retry_at < Instant::now()) + .unwrap_or(true) + { + let (sender, receiver) = mpsc::channel(INTERNAL_POLL_SIZE); + + { + *instance.stream_ctrl.write() = Some(sender); + } + + let current_subscriptions = { + instance + .remote_subscriptions + .read() + .iter() + .map(|(key, name)| (name.name.clone(), key.clone())) + .collect::>() + }; + + if let Err(err) = instance + .transport + .stream( + receiver, + current_subscriptions, + InternalRelay { + inner: instance.inner_pubsub.clone(), + cached_events: instance.cached_events.clone(), + }, + ) + .await + { + retry_at = Some(Instant::now() + backoff); + backoff = + (backoff + STREAM_CONNECTION_BACKOFF).min(STREAM_CONNECTION_MAX_BACKOFF); + + if matches!(err, Error::NotSupported) { + stream_supported = false; + } + tracing::error!("Long connection failed with error {:?}", err); + } else { + backoff = STREAM_CONNECTION_BACKOFF; + } + + // remove sender to stream, as there is no stream + let _ = instance.stream_ctrl.write().take(); + } + + if poll_supported { + let current_subscriptions = { + instance + .remote_subscriptions + .read() + .iter() + .map(|(key, name)| (name.name.clone(), key.clone())) + .collect::>() + }; + + if let Err(err) = instance + .transport + .poll( + current_subscriptions, + InternalRelay { + inner: instance.inner_pubsub.clone(), + cached_events: instance.cached_events.clone(), + }, + ) + .await + { + if matches!(err, Error::NotSupported) { + poll_supported = false; + } + tracing::error!("Polling failed with error {:?}", err); + } + + sleep(POLL_SLEEP).await; + } + } + } + + /// Unsubscribe from a topic, this is called automatically when RemoteActiveSubscription goes + /// out of scope + fn unsubscribe( + self: &Arc, + subscription_name: ::SubscriptionId, + ) -> Result<(), Error> { + let topics = self + .subscriptions + .write() + .remove(&subscription_name) + .ok_or(Error::NoSubscription)?; + + let mut remote_subscriptions = self.remote_subscriptions.write(); + + for topic in topics { + let mut remote_subscription = + if let Some(remote_subscription) = remote_subscriptions.remove(&topic) { + remote_subscription + } else { + continue; + }; + + remote_subscription.total_subscribers = remote_subscription + .total_subscribers + .checked_sub(1) + .unwrap_or_default(); + + if remote_subscription.total_subscribers == 0 { + let mut cached_events = self.cached_events.write(); + + cached_events.remove(&topic); + + self.message_to_stream(StreamCtrl::Unsubscribe(remote_subscription.name.clone()))?; + } else { + remote_subscriptions.insert(topic, remote_subscription); + } + } + + if remote_subscriptions.is_empty() { + self.message_to_stream(StreamCtrl::Stop)?; + } + + Ok(()) + } + + #[inline(always)] + fn message_to_stream(&self, message: StreamCtrl) -> Result<(), Error> { + let to_stream = self.stream_ctrl.read(); + + if let Some(to_stream) = to_stream.as_ref() { + Ok(to_stream.try_send(message)?) + } else { + Ok(()) + } + } + + /// Creates a subscription + /// + /// The subscriptions have two parts: + /// + /// 1. Will create the subscription to the remote Pubsub service, Any events will be moved to + /// the internal pubsub + /// + /// 2. The internal subscription to the inner Pubsub. Because all subscriptions are going the + /// transport, once events matches subscriptions, the inner_pubsub will receive the message and + /// broadcasat the event. + pub fn subscribe(self: &Arc, request: I) -> Result, Error> + where + I: SubscriptionRequest< + Topic = ::Topic, + SubscriptionId = ::SubscriptionId, + >, + { + let subscription_name = request.subscription_name(); + let topics = request.try_get_topics()?; + + let mut remote_subscriptions = self.remote_subscriptions.write(); + let mut subscriptions = self.subscriptions.write(); + + if subscriptions.get(&subscription_name).is_some() { + return Err(Error::NoSubscription); + } + + let mut previous_messages = Vec::new(); + let cached_events = self.cached_events.read(); + + for topic in topics.iter() { + if let Some(subscription) = remote_subscriptions.get_mut(topic) { + subscription.total_subscribers += 1; + + if let Some(v) = cached_events.get(topic).cloned() { + previous_messages.push(v); + } + } else { + let internal_sub_name = self.transport.new_name(); + remote_subscriptions.insert( + topic.clone(), + UniqueSubscription { + total_subscribers: 1, + name: internal_sub_name.clone(), + }, + ); + + // new subscription is created, so the connection worker should be notified + self.message_to_stream(StreamCtrl::Subscribe((internal_sub_name, topic.clone())))?; + } + } + + subscriptions.insert(subscription_name, topics); + drop(subscriptions); + + Ok(RemoteActiveConsumer { + inner: self.inner_pubsub.subscribe(request)?, + previous_messages: previous_messages.into(), + consumer: self.clone(), + }) + } +} + +impl Drop for Consumer +where + T: Transport + 'static, +{ + fn drop(&mut self) { + self.still_running + .store(false, std::sync::atomic::Ordering::Release); + if let Some(to_stream) = self.stream_ctrl.read().as_ref() { + let _ = to_stream.try_send(StreamCtrl::Stop).inspect_err(|err| { + tracing::error!("Failed to send message LongPoll::Stop due to {err:?}") + }); + } + } +} + +/// Subscribe Message +pub type SubscribeMessage = (::SubscriptionId, ::Topic); + +/// Messages sent from the [`Consumer`] to the [`Transport`] background loop. +pub enum StreamCtrl +where + S: Spec + 'static, +{ + /// Add a subscription + Subscribe(SubscribeMessage), + /// Desuscribe + Unsubscribe(S::SubscriptionId), + /// Exit the loop + Stop, +} + +impl Clone for StreamCtrl +where + S: Spec + 'static, +{ + fn clone(&self) -> Self { + match self { + Self::Subscribe(s) => Self::Subscribe(s.clone()), + Self::Unsubscribe(u) => Self::Unsubscribe(u.clone()), + Self::Stop => Self::Stop, + } + } +} + +/// Transport abstracts how the consumer talks to the remote pubsub. +/// +/// Implement this on your HTTP/WebSocket client. The transport is responsible for: +/// - creating unique subscription names, +/// - keeping a long connection via `stream` **or** performing on-demand `poll`, +/// - forwarding remote events to `InternalRelay`. +/// +/// ```ignore +/// struct WsTransport { /* ... */ } +/// #[async_trait::async_trait] +/// impl Transport for WsTransport { +/// type Topic = MyTopic; +/// fn new_name(&self) -> ::SubscriptionName { 0 } +/// async fn stream(/* ... */) -> Result<(), Error> { Ok(()) } +/// async fn poll(/* ... */) -> Result<(), Error> { Ok(()) } +/// } +/// ``` +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +pub trait Transport: Send + Sync { + /// Spec + type Spec: Spec; + + /// Create a new subscription name + fn new_name(&self) -> ::SubscriptionId; + + /// Opens a persistent connection and continuously streams events. + /// For protocols that support server push (e.g. WebSocket, SSE). + async fn stream( + &self, + subscribe_changes: mpsc::Receiver>, + topics: Vec>, + reply_to: InternalRelay, + ) -> Result<(), Error>; + + /// Performs a one-shot fetch of any currently available events. + /// Called repeatedly by the consumer when streaming is not available. + async fn poll( + &self, + topics: Vec>, + reply_to: InternalRelay, + ) -> Result<(), Error>; +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use tokio::sync::{mpsc, Mutex}; + use tokio::time::{timeout, Duration}; + + use super::{ + InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport, + INTERNAL_POLL_SIZE, + }; + use crate::pub_sub::remote_consumer::Consumer; + use crate::pub_sub::test::{CustomPubSub, IndexTest, Message}; + use crate::pub_sub::{Error, Spec, SubscriptionRequest}; + + // ===== Test Event/Topic types ===== + + #[derive(Clone, Debug)] + enum SubscriptionReq { + Foo(String, u64), + Bar(String, u64), + } + + impl SubscriptionRequest for SubscriptionReq { + type Topic = IndexTest; + + type SubscriptionId = String; + + fn try_get_topics(&self) -> Result, Error> { + Ok(vec![match self { + SubscriptionReq::Foo(_, n) => IndexTest::Foo(*n), + SubscriptionReq::Bar(_, n) => IndexTest::Bar(*n), + }]) + } + + fn subscription_name(&self) -> Arc { + Arc::new(match self { + SubscriptionReq::Foo(n, _) => n.to_string(), + SubscriptionReq::Bar(n, _) => n.to_string(), + }) + } + } + + // ===== A controllable in-memory Transport used by tests ===== + + /// TestTransport relays messages from a broadcast channel to the Consumer via `InternalRelay`. + /// It also forwards Subscribe/Unsubscribe/Stop signals to an observer channel so tests can assert them. + struct TestTransport { + name_ctr: AtomicUsize, + // We forward all transport-loop control messages here so tests can observe them. + observe_ctrl_tx: mpsc::Sender>, + // Whether stream / poll are supported. + support_long: bool, + support_poll: bool, + rx: Mutex>, + } + + impl TestTransport { + fn new( + support_long: bool, + support_poll: bool, + ) -> ( + Self, + mpsc::Sender, + mpsc::Receiver>, + ) { + let (events_tx, rx) = mpsc::channel::(INTERNAL_POLL_SIZE); + let (observe_ctrl_tx, observe_ctrl_rx) = + mpsc::channel::>(INTERNAL_POLL_SIZE); + + let t = TestTransport { + name_ctr: AtomicUsize::new(1), + rx: Mutex::new(rx), + observe_ctrl_tx, + support_long, + support_poll, + }; + + (t, events_tx, observe_ctrl_rx) + } + } + + #[async_trait::async_trait] + impl Transport for TestTransport { + type Spec = CustomPubSub; + + fn new_name(&self) -> ::SubscriptionId { + format!("sub-{}", self.name_ctr.fetch_add(1, Ordering::Relaxed)) + } + + async fn stream( + &self, + mut subscribe_changes: mpsc::Receiver>, + topics: Vec>, + reply_to: InternalRelay, + ) -> Result<(), Error> { + if !self.support_long { + return Err(Error::NotSupported); + } + + // Each invocation creates a fresh broadcast receiver + let mut rx = self.rx.lock().await; + let observe = self.observe_ctrl_tx.clone(); + + for topic in topics { + observe.try_send(StreamCtrl::Subscribe(topic)).unwrap(); + } + + loop { + tokio::select! { + // Forward any control (Subscribe/Unsubscribe/Stop) messages so the test can assert them. + Some(ctrl) = subscribe_changes.recv() => { + observe.try_send(ctrl.clone()).unwrap(); + if matches!(ctrl, StreamCtrl::Stop) { + break; + } + } + // Relay external events into the inner pubsub + Some(msg) = rx.recv() => { + reply_to.send(msg); + } + } + } + + Ok(()) + } + + async fn poll( + &self, + _topics: Vec>, + reply_to: InternalRelay, + ) -> Result<(), Error> { + if !self.support_poll { + return Err(Error::NotSupported); + } + + // On each poll call, drain anything currently pending and return. + // (The Consumer calls this repeatedly; first call happens immediately.) + let mut rx = self.rx.lock().await; + // Non-blocking drain pass: try a few times without sleeping to keep tests snappy + for _ in 0..32 { + match rx.try_recv() { + Ok(msg) => reply_to.send(msg), + Err(mpsc::error::TryRecvError::Empty) => continue, + Err(mpsc::error::TryRecvError::Disconnected) => break, + } + } + Ok(()) + } + } + + // ===== Helpers ===== + + async fn recv_next( + sub: &mut RemoteActiveConsumer, + dur_ms: u64, + ) -> Option<::Event> { + timeout(Duration::from_millis(dur_ms), sub.recv()) + .await + .ok() + .flatten() + } + + async fn expect_ctrl( + rx: &mut mpsc::Receiver>, + dur_ms: u64, + pred: impl Fn(&StreamCtrl) -> bool, + ) -> StreamCtrl { + timeout(Duration::from_millis(dur_ms), async { + loop { + if let Some(msg) = rx.recv().await { + if pred(&msg) { + break msg; + } + } + } + }) + .await + .expect("timed out waiting for control message") + } + + // ===== Tests ===== + + #[tokio::test] + async fn stream_delivery_and_unsubscribe_on_drop() { + // stream supported, poll supported (doesn't matter; prefer long) + let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true); + + // prefer_polling = false so connection loop will try stream first. + let consumer = Consumer::new(transport, false, ()); + + // Subscribe to Foo(7) + let mut sub = consumer + .subscribe(SubscriptionReq::Foo("t".to_owned(), 7)) + .expect("subscribe ok"); + + // We should see a Subscribe(name, topic) forwarded to transport + let ctrl = expect_ctrl( + &mut ctrl_rx, + 1000, + |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)), + ) + .await; + match ctrl { + StreamCtrl::Subscribe((name, idx)) => { + assert_ne!(name, "t".to_owned()); + assert_eq!(idx, IndexTest::Foo(7)); + } + _ => unreachable!(), + } + + // Send an event that matches Foo(7) + events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap(); + let got = recv_next::(&mut sub, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 1 }); + + // Dropping the RemoteActiveConsumer should trigger an Unsubscribe(name) + drop(sub); + let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| { + matches!(m, StreamCtrl::Unsubscribe(_)) + }) + .await; + + // Drop the Consumer -> Stop is sent so the transport loop exits cleanly + drop(consumer); + let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await; + } + + #[tokio::test] + async fn test_cache_and_invalation() { + // stream supported, poll supported (doesn't matter; prefer long) + let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true); + + // prefer_polling = false so connection loop will try stream first. + let consumer = Consumer::new(transport, false, ()); + + // Subscribe to Foo(7) + let mut sub_1 = consumer + .subscribe(SubscriptionReq::Foo("t".to_owned(), 7)) + .expect("subscribe ok"); + + // We should see a Subscribe(name, topic) forwarded to transport + let ctrl = expect_ctrl( + &mut ctrl_rx, + 1000, + |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)), + ) + .await; + match ctrl { + StreamCtrl::Subscribe((name, idx)) => { + assert_ne!(name, "t1".to_owned()); + assert_eq!(idx, IndexTest::Foo(7)); + } + _ => unreachable!(), + } + + // Send an event that matches Foo(7) + events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap(); + let got = recv_next::(&mut sub_1, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 1 }); + + // Subscribe to Foo(7), should receive the latest message and future messages + let mut sub_2 = consumer + .subscribe(SubscriptionReq::Foo("t2".to_owned(), 7)) + .expect("subscribe ok"); + + let got = recv_next::(&mut sub_2, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 1 }); + + // Dropping the RemoteActiveConsumer but not unsubscribe, since sub_2 is still active + drop(sub_1); + + // Subscribe to Foo(7), should receive the latest message and future messages + let mut sub_3 = consumer + .subscribe(SubscriptionReq::Foo("t3".to_owned(), 7)) + .expect("subscribe ok"); + + // receive cache message + let got = recv_next::(&mut sub_3, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 1 }); + + // Send an event that matches Foo(7) + events_tx.send(Message { foo: 7, bar: 2 }).await.unwrap(); + + // receive new message + let got = recv_next::(&mut sub_2, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 2 }); + + let got = recv_next::(&mut sub_3, 1000) + .await + .expect("got event"); + assert_eq!(got, Message { foo: 7, bar: 2 }); + + drop(sub_2); + drop(sub_3); + + let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| { + matches!(m, StreamCtrl::Unsubscribe(_)) + }) + .await; + + // The cache should be dropped, so no new messages + let mut sub_4 = consumer + .subscribe(SubscriptionReq::Foo("t4".to_owned(), 7)) + .expect("subscribe ok"); + + assert!( + recv_next::(&mut sub_4, 1000).await.is_none(), + "Should have not receive any update" + ); + + drop(sub_4); + + // Drop the Consumer -> Stop is sent so the transport loop exits cleanly + let _ = expect_ctrl(&mut ctrl_rx, 2000, |m| matches!(m, StreamCtrl::Stop)).await; + } + + #[tokio::test] + async fn falls_back_to_poll_when_stream_not_supported() { + // stream NOT supported, poll supported + let (transport, events_tx, _) = TestTransport::new(false, true); + // prefer_polling = true nudges the connection loop to poll first, but even if it + // tried stream, our transport returns NotSupported and the loop will use poll. + let consumer = Consumer::new(transport, true, ()); + + // Subscribe to Bar(5) + let mut sub = consumer + .subscribe(SubscriptionReq::Bar("t".to_owned(), 5)) + .expect("subscribe ok"); + + // Inject an event; the poll path should relay it on the first poll iteration + events_tx.send(Message { foo: 9, bar: 5 }).await.unwrap(); + let got = recv_next::(&mut sub, 1500) + .await + .expect("event relayed via polling"); + assert_eq!(got, Message { foo: 9, bar: 5 }); + } + + #[tokio::test] + async fn multiple_subscribers_share_single_remote_subscription() { + // This validates the "coalescing" behavior in Consumer::subscribe where multiple local + // subscribers to the same Topic should only create one remote subscription. + let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true); + let consumer = Consumer::new(transport, false, ()); + + // Two local subscriptions to the SAME topic/name pair (different names) + let mut a = consumer + .subscribe(SubscriptionReq::Foo("t".to_owned(), 1)) + .expect("subscribe A"); + let _ = expect_ctrl( + &mut ctrl_rx, + 1000, + |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(1)), + ) + .await; + + let mut b = consumer + .subscribe(SubscriptionReq::Foo("b".to_owned(), 1)) + .expect("subscribe B"); + + // No second Subscribe should be forwarded for the same topic (coalesced). + // Give a little time; if one appears, we'll fail explicitly. + if let Ok(Some(StreamCtrl::Subscribe((_, idx)))) = + timeout(Duration::from_millis(400), ctrl_rx.recv()).await + { + assert_ne!(idx, IndexTest::Foo(1), "should not resubscribe same topic"); + } + + // Send one event and ensure BOTH local subscribers receive it. + events_tx.send(Message { foo: 1, bar: 42 }).await.unwrap(); + let got_a = recv_next::(&mut a, 1000) + .await + .expect("A got"); + let got_b = recv_next::(&mut b, 1000) + .await + .expect("B got"); + assert_eq!(got_a, Message { foo: 1, bar: 42 }); + assert_eq!(got_b, Message { foo: 1, bar: 42 }); + + // Drop B: no Unsubscribe should be sent yet (still one local subscriber). + drop(b); + if let Ok(Some(StreamCtrl::Unsubscribe(_))) = + timeout(Duration::from_millis(400), ctrl_rx.recv()).await + { + panic!("Should NOT unsubscribe while another local subscriber exists"); + } + + // Drop A: now remote unsubscribe should occur. + drop(a); + let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| { + matches!(m, StreamCtrl::Unsubscribe(_)) + }) + .await; + + let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await; + } +} diff --git a/crates/cdk-common/src/pub_sub/subscriber.rs b/crates/cdk-common/src/pub_sub/subscriber.rs new file mode 100644 index 00000000..9c46e6ae --- /dev/null +++ b/crates/cdk-common/src/pub_sub/subscriber.rs @@ -0,0 +1,159 @@ +//! Active subscription +use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, Mutex}; + +use tokio::sync::mpsc; + +use super::pubsub::{SubReceiver, TopicTree}; +use super::{Error, Spec}; + +/// Subscription request +pub trait SubscriptionRequest { + /// Topics + type Topic; + + /// Subscription Id + type SubscriptionId; + + /// Try to get topics from the request + fn try_get_topics(&self) -> Result, Error>; + + /// Get the subscription name + fn subscription_name(&self) -> Arc; +} + +/// Active Subscription +pub struct ActiveSubscription +where + S: Spec + 'static, +{ + id: usize, + name: Arc, + active_subscribers: Arc, + topics: TopicTree, + subscribed_to: Vec, + receiver: Option>, +} + +impl ActiveSubscription +where + S: Spec + 'static, +{ + /// Creates a new instance + pub fn new( + id: usize, + name: Arc, + active_subscribers: Arc, + topics: TopicTree, + subscribed_to: Vec, + receiver: Option>, + ) -> Self { + Self { + id, + name, + active_subscribers, + subscribed_to, + topics, + receiver, + } + } + + /// Receives the next event + pub async fn recv(&mut self) -> Option { + self.receiver.as_mut()?.recv().await.map(|(_, event)| event) + } + + /// Try receive an event or return Noen right away + pub fn try_recv(&mut self) -> Option { + self.receiver + .as_mut()? + .try_recv() + .ok() + .map(|(_, event)| event) + } + + /// Get the subscription name + pub fn name(&self) -> &S::SubscriptionId { + &self.name + } +} + +impl Drop for ActiveSubscription +where + S: Spec + 'static, +{ + fn drop(&mut self) { + // remove the listener + let mut topics = self.topics.write(); + for index in self.subscribed_to.drain(..) { + topics.remove(&(index, self.id)); + } + + // decrement the number of active subscribers + self.active_subscribers + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} + +/// Lightweight sink used by producers to send events to subscribers. +/// +/// You usually do not construct a `Subscriber` directly — it is provided to you in +/// the [`Spec::fetch_events`] callback so you can backfill a new subscription. +#[derive(Debug)] +pub struct Subscriber +where + S: Spec + 'static, +{ + subscription: Arc, + inner: mpsc::Sender<(Arc, S::Event)>, + latest: Arc>>, +} + +impl Clone for Subscriber +where + S: Spec + 'static, +{ + fn clone(&self) -> Self { + Self { + subscription: self.subscription.clone(), + inner: self.inner.clone(), + latest: self.latest.clone(), + } + } +} + +impl Subscriber +where + S: Spec + 'static, +{ + /// Create a new instance + pub fn new( + subscription: Arc, + inner: &mpsc::Sender<(Arc, S::Event)>, + ) -> Self { + Self { + inner: inner.clone(), + subscription, + latest: Arc::new(Mutex::new(None)), + } + } + + /// Send a message + pub fn send(&self, event: S::Event) { + let mut latest = if let Ok(reader) = self.latest.lock() { + reader + } else { + let _ = self.inner.try_send((self.subscription.to_owned(), event)); + return; + }; + + if let Some(last_event) = latest.replace(event.clone()) { + if last_event == event { + return; + } + } + + let _ = self.inner.try_send((self.subscription.to_owned(), event)); + } +} diff --git a/crates/cdk-common/src/pub_sub/types.rs b/crates/cdk-common/src/pub_sub/types.rs new file mode 100644 index 00000000..7ceb169a --- /dev/null +++ b/crates/cdk-common/src/pub_sub/types.rs @@ -0,0 +1,80 @@ +//! Pubsub Event definition +//! +//! The Pubsub Event defines the Topic struct and how an event can be converted to Topics. + +use std::hash::Hash; +use std::sync::Arc; + +use serde::de::DeserializeOwned; +use serde::Serialize; + +use super::Subscriber; + +/// Pubsub settings +#[async_trait::async_trait] +pub trait Spec: Send + Sync { + /// Topic + type Topic: Send + + Sync + + Clone + + Eq + + PartialEq + + Ord + + PartialOrd + + Hash + + Send + + Sync + + DeserializeOwned + + Serialize; + + /// Event + type Event: Event + + Send + + Sync + + Eq + + PartialEq + + DeserializeOwned + + Serialize; + + /// Subscription Id + type SubscriptionId: Clone + + Default + + Eq + + PartialEq + + Ord + + PartialOrd + + Hash + + Send + + Sync + + DeserializeOwned + + Serialize; + + /// Create a new context + type Context; + + /// Create a new instance from a given context + fn new_instance(context: Self::Context) -> Arc + where + Self: Sized; + + /// Callback function that is called on new subscriptions, to back-fill optionally the previous + /// events + async fn fetch_events( + self: &Arc, + topics: Vec<::Topic>, + reply_to: Subscriber, + ) where + Self: Sized; +} + +/// Event trait +pub trait Event: Clone + Send + Sync + Eq + PartialEq + DeserializeOwned + Serialize { + /// Generic Topic + /// + /// It should be serializable/deserializable to be stored in the database layer and it should + /// also be sorted in a BTree for in-memory matching + type Topic; + + /// To topics + fn get_topics(&self) -> Vec; +} diff --git a/crates/cdk-common/src/subscription.rs b/crates/cdk-common/src/subscription.rs index 01407de6..3588beb2 100644 --- a/crates/cdk-common/src/subscription.rs +++ b/crates/cdk-common/src/subscription.rs @@ -1,98 +1,115 @@ //! Subscription types and traits -#[cfg(feature = "mint")] +use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; -use cashu::nut17::{self}; -#[cfg(feature = "mint")] -use cashu::nut17::{Error, Kind, Notification}; -#[cfg(feature = "mint")] +use cashu::nut17::{self, Kind, NotificationId}; use cashu::quote_id::QuoteId; -#[cfg(feature = "mint")] -use cashu::{NotificationPayload, PublicKey}; -#[cfg(feature = "mint")] +use cashu::PublicKey; use serde::{Deserialize, Serialize}; -#[cfg(feature = "mint")] -use crate::pub_sub::index::{Index, Indexable, SubscriptionGlobalId}; -use crate::pub_sub::SubId; +use crate::pub_sub::{Error, SubscriptionRequest}; -/// Subscription parameters. +/// CDK/Mint Subscription parameters. /// /// This is a concrete type alias for `nut17::Params`. -pub type Params = nut17::Params; +pub type Params = nut17::Params>; -/// Wrapper around `nut17::Params` to implement `Indexable` for `Notification`. -#[cfg(feature = "mint")] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IndexableParams(Params); +impl SubscriptionRequest for Params { + type Topic = NotificationId; -#[cfg(feature = "mint")] -impl From for IndexableParams { - fn from(params: Params) -> Self { - Self(params) + type SubscriptionId = SubId; + + fn subscription_name(&self) -> Arc { + self.id.clone() } -} -#[cfg(feature = "mint")] -impl TryFrom for Vec> { - type Error = Error; - fn try_from(params: IndexableParams) -> Result { - let sub_id: SubscriptionGlobalId = Default::default(); - let params = params.0; - params - .filters - .into_iter() - .map(|filter| { - let idx = match params.kind { - Kind::Bolt11MeltQuote => { - Notification::MeltQuoteBolt11(QuoteId::from_str(&filter)?) - } - Kind::Bolt11MintQuote => { - Notification::MintQuoteBolt11(QuoteId::from_str(&filter)?) - } - Kind::ProofState => Notification::ProofState(PublicKey::from_str(&filter)?), - Kind::Bolt12MintQuote => { - Notification::MintQuoteBolt12(QuoteId::from_str(&filter)?) - } - }; + fn try_get_topics(&self) -> Result, Error> { + self.filters + .iter() + .map(|filter| match self.kind { + Kind::Bolt11MeltQuote => QuoteId::from_str(filter) + .map(NotificationId::MeltQuoteBolt11) + .map_err(|_| Error::ParsingError(filter.to_owned())), + Kind::Bolt11MintQuote => QuoteId::from_str(filter) + .map(NotificationId::MintQuoteBolt11) + .map_err(|_| Error::ParsingError(filter.to_owned())), + Kind::ProofState => PublicKey::from_str(filter) + .map(NotificationId::ProofState) + .map_err(|_| Error::ParsingError(filter.to_owned())), - Ok(Index::from((idx, params.id.clone(), sub_id))) + Kind::Bolt12MintQuote => QuoteId::from_str(filter) + .map(NotificationId::MintQuoteBolt12) + .map_err(|_| Error::ParsingError(filter.to_owned())), }) - .collect::>() + .collect::, _>>() } } -#[cfg(feature = "mint")] -impl AsRef for IndexableParams { - fn as_ref(&self) -> &SubId { - &self.0.id +/// Subscriptions parameters for the wallet +/// +/// This is because the Wallet can subscribe to non CDK quotes, where IDs are not constraint to +/// QuoteId +pub type WalletParams = nut17::Params>; + +impl SubscriptionRequest for WalletParams { + type Topic = NotificationId; + + type SubscriptionId = String; + + fn subscription_name(&self) -> Arc { + self.id.clone() + } + + fn try_get_topics(&self) -> Result, Error> { + self.filters + .iter() + .map(|filter| { + Ok(match self.kind { + Kind::Bolt11MeltQuote => NotificationId::MeltQuoteBolt11(filter.to_owned()), + Kind::Bolt11MintQuote => NotificationId::MintQuoteBolt11(filter.to_owned()), + Kind::ProofState => PublicKey::from_str(filter) + .map(NotificationId::ProofState) + .map_err(|_| Error::ParsingError(filter.to_owned()))?, + + Kind::Bolt12MintQuote => NotificationId::MintQuoteBolt12(filter.to_owned()), + }) + }) + .collect::, _>>() } } -#[cfg(feature = "mint")] -impl Indexable for NotificationPayload { - type Type = Notification; +/// Subscription Id wrapper +/// +/// This is the place to add some sane default (like a max length) to the +/// subscription ID +#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub struct SubId(String); - fn to_indexes(&self) -> Vec> { - match self { - NotificationPayload::ProofState(proof_state) => { - vec![Index::from(Notification::ProofState(proof_state.y))] - } - NotificationPayload::MeltQuoteBolt11Response(melt_quote) => { - vec![Index::from(Notification::MeltQuoteBolt11( - melt_quote.quote.clone(), - ))] - } - NotificationPayload::MintQuoteBolt11Response(mint_quote) => { - vec![Index::from(Notification::MintQuoteBolt11( - mint_quote.quote.clone(), - ))] - } - NotificationPayload::MintQuoteBolt12Response(mint_quote) => { - vec![Index::from(Notification::MintQuoteBolt12( - mint_quote.quote.clone(), - ))] - } - } +impl From<&str> for SubId { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From for SubId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl FromStr for SubId { + type Err = (); + + fn from_str(s: &str) -> Result { + Ok(Self(s.to_string())) + } +} + +impl Deref for SubId { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 } } diff --git a/crates/cdk-common/src/ws.rs b/crates/cdk-common/src/ws.rs index c471cc88..6c01f327 100644 --- a/crates/cdk-common/src/ws.rs +++ b/crates/cdk-common/src/ws.rs @@ -2,6 +2,8 @@ //! //! This module extends the `cashu` crate with types and functions for the CDK, using the correct //! expected ID types. +use std::sync::Arc; + #[cfg(feature = "mint")] use cashu::nut17::ws::JSON_RPC_VERSION; use cashu::nut17::{self}; @@ -10,7 +12,7 @@ use cashu::quote_id::QuoteId; #[cfg(feature = "mint")] use cashu::NotificationPayload; -use crate::pub_sub::SubId; +type SubId = Arc; /// Request to unsubscribe from a websocket subscription pub type WsUnsubscribeRequest = nut17::ws::WsUnsubscribeRequest; diff --git a/crates/cdk-ffi/Cargo.toml b/crates/cdk-ffi/Cargo.toml index b37b47b4..f71bf592 100644 --- a/crates/cdk-ffi/Cargo.toml +++ b/crates/cdk-ffi/Cargo.toml @@ -22,7 +22,7 @@ ctor = "0.2" futures = { workspace = true } once_cell = { workspace = true } rand = { workspace = true } -serde = { workspace = true, features = ["derive"] } +serde = { workspace = true, features = ["derive", "rc"] } serde_json = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["sync", "rt", "rt-multi-thread"] } @@ -41,4 +41,3 @@ postgres = ["cdk-postgres"] [[bin]] name = "uniffi-bindgen" path = "src/bin/uniffi-bindgen.rs" - diff --git a/crates/cdk-ffi/src/types/subscription.rs b/crates/cdk-ffi/src/types/subscription.rs index fb07389d..ccd036bf 100644 --- a/crates/cdk-ffi/src/types/subscription.rs +++ b/crates/cdk-ffi/src/types/subscription.rs @@ -1,6 +1,7 @@ //! Subscription-related FFI types +use std::sync::Arc; -use cdk::pub_sub::SubId; +use cdk::event::MintEvent; use serde::{Deserialize, Serialize}; use super::proof::ProofStateUpdate; @@ -53,21 +54,17 @@ pub struct SubscribeParams { pub id: Option, } -impl From for cdk::nuts::nut17::Params { +impl From for cdk::nuts::nut17::Params> { fn from(params: SubscribeParams) -> Self { - let sub_id = params - .id - .map(|id| SubId::from(id.as_str())) - .unwrap_or_else(|| { - // Generate a random ID - let uuid = uuid::Uuid::new_v4(); - SubId::from(uuid.to_string().as_str()) - }); + let sub_id = params.id.unwrap_or_else(|| { + // Generate a random ID + uuid::Uuid::new_v4().to_string() + }); cdk::nuts::nut17::Params { kind: params.kind.into(), filters: params.filters, - id: sub_id, + id: Arc::new(sub_id), } } } @@ -132,12 +129,7 @@ impl ActiveSubscription { /// Try to receive a notification without blocking pub async fn try_recv(&self) -> Result, FfiError> { let mut guard = self.inner.lock().await; - guard - .try_recv() - .map(|opt| opt.map(Into::into)) - .map_err(|e| FfiError::Generic { - msg: format!("Failed to receive notification: {}", e), - }) + Ok(guard.try_recv().map(Into::into)) } } @@ -156,9 +148,9 @@ pub enum NotificationPayload { }, } -impl From> for NotificationPayload { - fn from(payload: cdk::nuts::NotificationPayload) -> Self { - match payload { +impl From> for NotificationPayload { + fn from(payload: MintEvent) -> Self { + match payload.into() { cdk::nuts::NotificationPayload::ProofState(states) => NotificationPayload::ProofState { proof_states: vec![states.into()], }, diff --git a/crates/cdk-ffi/src/wallet.rs b/crates/cdk-ffi/src/wallet.rs index 802a873e..c144450f 100644 --- a/crates/cdk-ffi/src/wallet.rs +++ b/crates/cdk-ffi/src/wallet.rs @@ -349,7 +349,7 @@ impl Wallet { &self, params: SubscribeParams, ) -> Result, FfiError> { - let cdk_params: cdk::nuts::nut17::Params = params.clone().into(); + let cdk_params: cdk::nuts::nut17::Params> = params.clone().into(); let sub_id = cdk_params.id.to_string(); let active_sub = self.inner.subscribe(cdk_params).await; Ok(std::sync::Arc::new(ActiveSubscription::new( diff --git a/crates/cdk-integration-tests/tests/integration_tests_pure.rs b/crates/cdk-integration-tests/tests/integration_tests_pure.rs index 9ab1fb00..599f8688 100644 --- a/crates/cdk-integration-tests/tests/integration_tests_pure.rs +++ b/crates/cdk-integration-tests/tests/integration_tests_pure.rs @@ -13,6 +13,7 @@ use std::assert_eq; use std::collections::{HashMap, HashSet}; use std::hash::RandomState; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use cashu::amount::SplitTarget; @@ -24,7 +25,7 @@ use cashu::{ }; use cdk::mint::Mint; use cdk::nuts::nut00::ProofsMethods; -use cdk::subscription::{IndexableParams, Params}; +use cdk::subscription::Params; use cdk::wallet::types::{TransactionDirection, TransactionId}; use cdk::wallet::{ReceiveOptions, SendMemo, SendOptions}; use cdk::Amount; @@ -485,15 +486,11 @@ pub async fn test_p2pk_swap() { let mut listener = mint_bob .pubsub_manager() - .try_subscribe::( - Params { - kind: cdk::nuts::nut17::Kind::ProofState, - filters: public_keys_to_listen.clone(), - id: "test".into(), - } - .into(), - ) - .await + .subscribe(Params { + kind: cdk::nuts::nut17::Kind::ProofState, + filters: public_keys_to_listen.clone(), + id: Arc::new("test".into()), + }) .expect("valid subscription"); match mint_bob.process_swap_request(swap_request).await { @@ -520,9 +517,8 @@ pub async fn test_p2pk_swap() { sleep(Duration::from_secs(1)).await; let mut msgs = HashMap::new(); - while let Ok((sub_id, msg)) = listener.try_recv() { - assert_eq!(sub_id, "test".into()); - match msg { + while let Some(msg) = listener.try_recv() { + match msg.into_inner() { NotificationPayload::ProofState(ProofState { y, state, .. }) => { msgs.entry(y.to_string()) .or_insert_with(Vec::new) @@ -544,7 +540,7 @@ pub async fn test_p2pk_swap() { ); } - assert!(listener.try_recv().is_err(), "no other event is happening"); + assert!(listener.try_recv().is_none(), "no other event is happening"); assert!(msgs.is_empty(), "Only expected key events are received"); } diff --git a/crates/cdk-integration-tests/tests/regtest.rs b/crates/cdk-integration-tests/tests/regtest.rs index ce0a73cd..98a76496 100644 --- a/crates/cdk-integration-tests/tests/regtest.rs +++ b/crates/cdk-integration-tests/tests/regtest.rs @@ -165,7 +165,7 @@ async fn test_websocket_connection() { .expect("timeout waiting for unpaid notification") .expect("No paid notification received"); - match msg { + match msg.into_inner() { NotificationPayload::MintQuoteBolt11Response(response) => { assert_eq!(response.quote.to_string(), mint_quote.id); assert_eq!(response.state, MintQuoteState::Unpaid); @@ -185,7 +185,7 @@ async fn test_websocket_connection() { .expect("timeout waiting for paid notification") .expect("No paid notification received"); - match msg { + match msg.into_inner() { NotificationPayload::MintQuoteBolt11Response(response) => { assert_eq!(response.quote.to_string(), mint_quote.id); assert_eq!(response.state, MintQuoteState::Paid); diff --git a/crates/cdk/Cargo.toml b/crates/cdk/Cargo.toml index c21dc59d..c80ab32f 100644 --- a/crates/cdk/Cargo.toml +++ b/crates/cdk/Cargo.toml @@ -100,8 +100,6 @@ ring = { version = "0.17.14", features = ["wasm32_unknown_unknown_js"] } rustls = { workspace = true, optional = true } uuid = { workspace = true, features = ["js"] } -wasm-bindgen = "0.2" -wasm-bindgen-futures = "0.4" gloo-timers = { version = "0.3", features = ["futures"] } [[example]] diff --git a/crates/cdk/src/event.rs b/crates/cdk/src/event.rs new file mode 100644 index 00000000..1c4fee67 --- /dev/null +++ b/crates/cdk/src/event.rs @@ -0,0 +1,127 @@ +//! Mint event types +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::Deref; + +use cdk_common::nut17::NotificationId; +use cdk_common::pub_sub::Event; +use cdk_common::{ + MeltQuoteBolt11Response, MintQuoteBolt11Response, MintQuoteBolt12Response, NotificationPayload, + ProofState, +}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +/// Simple wrapper over `NotificationPayload` which is a foreign type +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + DeserializeOwned")] +pub struct MintEvent(NotificationPayload) +where + T: Clone + Eq + PartialEq; + +impl From> for NotificationPayload +where + T: Clone + Eq + PartialEq, +{ + fn from(value: MintEvent) -> Self { + value.0 + } +} + +impl Deref for MintEvent +where + T: Clone + Eq + PartialEq, +{ + type Target = NotificationPayload; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for MintEvent +where + T: Clone + Eq + PartialEq, +{ + fn from(value: ProofState) -> Self { + Self(NotificationPayload::ProofState(value)) + } +} + +impl MintEvent +where + T: Clone + Eq + PartialEq, +{ + /// New instance + pub fn new(t: NotificationPayload) -> Self { + Self(t) + } + + /// Get inner + pub fn inner(&self) -> &NotificationPayload { + &self.0 + } + + /// Into inner + pub fn into_inner(self) -> NotificationPayload { + self.0 + } +} + +impl From> for MintEvent +where + T: Clone + Eq + PartialEq, +{ + fn from(value: NotificationPayload) -> Self { + Self(value) + } +} + +impl From> for MintEvent +where + T: Clone + Eq + PartialEq, +{ + fn from(value: MintQuoteBolt11Response) -> Self { + Self(NotificationPayload::MintQuoteBolt11Response(value)) + } +} + +impl From> for MintEvent +where + T: Clone + Eq + PartialEq, +{ + fn from(value: MeltQuoteBolt11Response) -> Self { + Self(NotificationPayload::MeltQuoteBolt11Response(value)) + } +} + +impl From> for MintEvent +where + T: Clone + Eq + PartialEq, +{ + fn from(value: MintQuoteBolt12Response) -> Self { + Self(NotificationPayload::MintQuoteBolt12Response(value)) + } +} + +impl Event for MintEvent +where + T: Clone + Serialize + DeserializeOwned + Debug + Ord + Hash + Send + Sync + Eq + PartialEq, +{ + type Topic = NotificationId; + + fn get_topics(&self) -> Vec { + vec![match &self.0 { + NotificationPayload::MeltQuoteBolt11Response(r) => { + NotificationId::MeltQuoteBolt11(r.quote.to_owned()) + } + NotificationPayload::MintQuoteBolt11Response(r) => { + NotificationId::MintQuoteBolt11(r.quote.to_owned()) + } + NotificationPayload::MintQuoteBolt12Response(r) => { + NotificationId::MintQuoteBolt12(r.quote.to_owned()) + } + NotificationPayload::ProofState(p) => NotificationId::ProofState(p.y.to_owned()), + }] + } +} diff --git a/crates/cdk/src/lib.rs b/crates/cdk/src/lib.rs index d26076eb..58653d78 100644 --- a/crates/cdk/src/lib.rs +++ b/crates/cdk/src/lib.rs @@ -32,11 +32,9 @@ mod bip353; #[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))] mod oidc_client; -#[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))] -pub use oidc_client::OidcClient; - -pub mod pub_sub; - +#[cfg(feature = "mint")] +#[doc(hidden)] +pub use cdk_common::payment as cdk_payment; /// Re-export amount type #[doc(hidden)] pub use cdk_common::{ @@ -44,10 +42,11 @@ pub use cdk_common::{ error::{self, Error}, lightning_invoice, mint_url, nuts, secret, util, ws, Amount, Bolt11Invoice, }; -#[cfg(feature = "mint")] -#[doc(hidden)] -pub use cdk_common::{payment as cdk_payment, subscription}; +#[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))] +pub use oidc_client::OidcClient; +#[cfg(any(feature = "wallet", feature = "mint"))] +pub mod event; pub mod fees; #[doc(hidden)] @@ -69,6 +68,8 @@ pub use self::wallet::HttpClient; #[doc(hidden)] pub type Result> = std::result::Result; +/// Re-export subscription +pub use cdk_common::subscription; /// Re-export futures::Stream #[cfg(any(feature = "wallet", feature = "mint"))] pub use futures::{Stream, StreamExt}; diff --git a/crates/cdk/src/mint/issue/mod.rs b/crates/cdk/src/mint/issue/mod.rs index 8104b2e3..419abd3c 100644 --- a/crates/cdk/src/mint/issue/mod.rs +++ b/crates/cdk/src/mint/issue/mod.rs @@ -322,12 +322,12 @@ impl Mint { PaymentMethod::Bolt11 => { let res: MintQuoteBolt11Response = quote.clone().into(); self.pubsub_manager - .broadcast(NotificationPayload::MintQuoteBolt11Response(res)); + .publish(NotificationPayload::MintQuoteBolt11Response(res)); } PaymentMethod::Bolt12 => { let res: MintQuoteBolt12Response = quote.clone().try_into()?; self.pubsub_manager - .broadcast(NotificationPayload::MintQuoteBolt12Response(res)); + .publish(NotificationPayload::MintQuoteBolt12Response(res)); } PaymentMethod::Custom(_) => {} } diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index f93be3b6..fbe7d689 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -43,7 +43,7 @@ mod ln; mod melt; mod proof_writer; mod start_up_check; -pub mod subscription; +mod subscription; mod swap; mod verification; @@ -206,7 +206,7 @@ impl Mint { Ok(Self { signatory, - pubsub_manager: Arc::new(localstore.clone().into()), + pubsub_manager: PubSubManager::new(localstore.clone()), localstore, #[cfg(feature = "auth")] oidc_client: computed_info.nuts.nut21.as_ref().map(|nut21| { diff --git a/crates/cdk/src/mint/subscription.rs b/crates/cdk/src/mint/subscription.rs new file mode 100644 index 00000000..f60ebde9 --- /dev/null +++ b/crates/cdk/src/mint/subscription.rs @@ -0,0 +1,244 @@ +//! Specific Subscription for the cdk crate + +use std::ops::Deref; +use std::sync::Arc; + +use cdk_common::database::DynMintDatabase; +use cdk_common::mint::MintQuote; +use cdk_common::nut17::NotificationId; +use cdk_common::pub_sub::{Pubsub, Spec, Subscriber}; +use cdk_common::subscription::SubId; +use cdk_common::{ + Amount, BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, + MintQuoteBolt12Response, MintQuoteState, PaymentMethod, ProofState, PublicKey, QuoteId, +}; + +use crate::event::MintEvent; + +/// Mint subtopics +#[derive(Clone)] +pub struct MintPubSubSpec { + db: DynMintDatabase, +} + +impl MintPubSubSpec { + async fn get_events_from_db( + &self, + request: &[NotificationId], + ) -> Result>, String> { + let mut to_return = vec![]; + let mut public_keys: Vec = Vec::new(); + let mut melt_queries = Vec::new(); + let mut mint_queries = Vec::new(); + + for idx in request.iter() { + match idx { + NotificationId::ProofState(pk) => public_keys.push(*pk), + NotificationId::MeltQuoteBolt11(uuid) => { + melt_queries.push(self.db.get_melt_quote(uuid)) + } + NotificationId::MintQuoteBolt11(uuid) => { + mint_queries.push(self.db.get_mint_quote(uuid)) + } + NotificationId::MintQuoteBolt12(uuid) => { + mint_queries.push(self.db.get_mint_quote(uuid)) + } + NotificationId::MeltQuoteBolt12(uuid) => { + melt_queries.push(self.db.get_melt_quote(uuid)) + } + } + } + + if !melt_queries.is_empty() { + to_return.extend( + futures::future::try_join_all(melt_queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| quote.map(|x| x.into())) + .map(|x: MeltQuoteBolt11Response| x.into()) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + + if !mint_queries.is_empty() { + to_return.extend( + futures::future::try_join_all(mint_queries) + .await + .map(|quotes| { + quotes + .into_iter() + .filter_map(|quote| { + quote.and_then(|x| match x.payment_method { + PaymentMethod::Bolt11 => { + let response: MintQuoteBolt11Response = x.into(); + Some(response.into()) + } + PaymentMethod::Bolt12 => match x.try_into() { + Ok(response) => { + let response: MintQuoteBolt12Response = + response; + Some(response.into()) + } + Err(_) => None, + }, + PaymentMethod::Custom(_) => None, + }) + }) + .collect::>() + }) + .map_err(|e| e.to_string())?, + ); + } + + if !public_keys.is_empty() { + to_return.extend( + self.db + .get_proofs_states(public_keys.as_slice()) + .await + .map_err(|e| e.to_string())? + .into_iter() + .enumerate() + .filter_map(|(idx, state)| state.map(|state| (public_keys[idx], state).into())) + .map(|state: ProofState| state.into()), + ); + } + + Ok(to_return) + } +} + +#[async_trait::async_trait] +impl Spec for MintPubSubSpec { + type SubscriptionId = SubId; + + type Topic = NotificationId; + + type Event = MintEvent; + + type Context = DynMintDatabase; + + fn new_instance(context: Self::Context) -> Arc { + Arc::new(Self { db: context }) + } + + async fn fetch_events(self: &Arc, topics: Vec, reply_to: Subscriber) { + for event in self + .get_events_from_db(&topics) + .await + .inspect_err(|err| tracing::error!("Error reading events from db {err:?}")) + .unwrap_or_default() + { + let _ = reply_to.send(event); + } + } +} + +/// PubsubManager +pub struct PubSubManager(Pubsub); + +impl PubSubManager { + /// Create a new instance + pub fn new(db: DynMintDatabase) -> Arc { + Arc::new(Self(Pubsub::new(MintPubSubSpec::new_instance(db)))) + } + + /// Helper function to emit a ProofState status + pub fn proof_state>(&self, event: E) { + self.publish(event.into()); + } + + /// Helper function to publish even of a mint quote being paid + pub fn mint_quote_issue(&self, mint_quote: &MintQuote, total_issued: Amount) { + match mint_quote.payment_method { + PaymentMethod::Bolt11 => { + self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Issued); + } + PaymentMethod::Bolt12 => { + self.mint_quote_bolt12_status( + mint_quote.clone(), + mint_quote.amount_paid(), + total_issued, + ); + } + _ => { + // We don't send ws updates for unknown methods + } + } + } + + /// Helper function to publish even of a mint quote being paid + pub fn mint_quote_payment(&self, mint_quote: &MintQuote, total_paid: Amount) { + match mint_quote.payment_method { + PaymentMethod::Bolt11 => { + self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Paid); + } + PaymentMethod::Bolt12 => { + self.mint_quote_bolt12_status( + mint_quote.clone(), + total_paid, + mint_quote.amount_issued(), + ); + } + _ => { + // We don't send ws updates for unknown methods + } + } + } + + /// Helper function to emit a MintQuoteBolt11Response status + pub fn mint_quote_bolt11_status>>( + &self, + quote: E, + new_state: MintQuoteState, + ) { + let mut event = quote.into(); + event.state = new_state; + + self.publish(event); + } + + /// Helper function to emit a MintQuoteBolt11Response status + pub fn mint_quote_bolt12_status>>( + &self, + quote: E, + amount_paid: Amount, + amount_issued: Amount, + ) { + if let Ok(mut event) = quote.try_into() { + event.amount_paid = amount_paid; + event.amount_issued = amount_issued; + + self.publish(event); + } else { + tracing::warn!("Could not convert quote to MintQuoteResponse"); + } + } + + /// Helper function to emit a MeltQuoteBolt11Response status + pub fn melt_quote_status>>( + &self, + quote: E, + payment_preimage: Option, + change: Option>, + new_state: MeltQuoteState, + ) { + let mut quote = quote.into(); + quote.state = new_state; + quote.paid = Some(new_state == MeltQuoteState::Paid); + quote.payment_preimage = payment_preimage; + quote.change = change; + self.publish(quote); + } +} + +impl Deref for PubSubManager { + type Target = Pubsub; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/crates/cdk/src/mint/subscription/manager.rs b/crates/cdk/src/mint/subscription/manager.rs deleted file mode 100644 index d4c3b974..00000000 --- a/crates/cdk/src/mint/subscription/manager.rs +++ /dev/null @@ -1,292 +0,0 @@ -//! Specific Subscription for the cdk crate -use std::ops::Deref; - -use cdk_common::database::DynMintDatabase; -use cdk_common::mint::MintQuote; -use cdk_common::nut17::Notification; -use cdk_common::quote_id::QuoteId; -use cdk_common::{Amount, MintQuoteBolt12Response, NotificationPayload, PaymentMethod}; - -use super::OnSubscription; -use crate::nuts::{ - BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, - MintQuoteState, ProofState, -}; -use crate::pub_sub; - -/// Manager -/// Publish–subscribe manager -/// -/// Nut-17 implementation is system-wide and not only through the WebSocket, so -/// it is possible for another part of the system to subscribe to events. -pub struct PubSubManager( - pub_sub::Manager, Notification, OnSubscription>, -); - -#[allow(clippy::default_constructed_unit_structs)] -impl Default for PubSubManager { - fn default() -> Self { - PubSubManager(OnSubscription::default().into()) - } -} - -impl From for PubSubManager { - fn from(val: DynMintDatabase) -> Self { - PubSubManager(OnSubscription(Some(val)).into()) - } -} - -impl Deref for PubSubManager { - type Target = pub_sub::Manager, Notification, OnSubscription>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PubSubManager { - /// Helper function to emit a ProofState status - pub fn proof_state>(&self, event: E) { - self.broadcast(event.into().into()); - } - - /// Helper function to publish even of a mint quote being paid - pub fn mint_quote_issue(&self, mint_quote: &MintQuote, total_issued: Amount) { - match mint_quote.payment_method { - PaymentMethod::Bolt11 => { - self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Issued); - } - PaymentMethod::Bolt12 => { - self.mint_quote_bolt12_status( - mint_quote.clone(), - mint_quote.amount_paid(), - total_issued, - ); - } - _ => { - // We don't send ws updates for unknown methods - } - } - } - - /// Helper function to publish even of a mint quote being paid - pub fn mint_quote_payment(&self, mint_quote: &MintQuote, total_paid: Amount) { - match mint_quote.payment_method { - PaymentMethod::Bolt11 => { - self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Paid); - } - PaymentMethod::Bolt12 => { - self.mint_quote_bolt12_status( - mint_quote.clone(), - total_paid, - mint_quote.amount_issued(), - ); - } - _ => { - // We don't send ws updates for unknown methods - } - } - } - - /// Helper function to emit a MintQuoteBolt11Response status - pub fn mint_quote_bolt11_status>>( - &self, - quote: E, - new_state: MintQuoteState, - ) { - let mut event = quote.into(); - event.state = new_state; - - self.broadcast(event.into()); - } - - /// Helper function to emit a MintQuoteBolt11Response status - pub fn mint_quote_bolt12_status>>( - &self, - quote: E, - amount_paid: Amount, - amount_issued: Amount, - ) { - if let Ok(mut event) = quote.try_into() { - event.amount_paid = amount_paid; - event.amount_issued = amount_issued; - - self.broadcast(event.into()); - } else { - tracing::warn!("Could not convert quote to MintQuoteResponse"); - } - } - - /// Helper function to emit a MeltQuoteBolt11Response status - pub fn melt_quote_status>>( - &self, - quote: E, - payment_preimage: Option, - change: Option>, - new_state: MeltQuoteState, - ) { - let mut quote = quote.into(); - quote.state = new_state; - quote.paid = Some(new_state == MeltQuoteState::Paid); - quote.payment_preimage = payment_preimage; - quote.change = change; - self.broadcast(quote.into()); - } -} - -#[cfg(test)] -mod test { - use std::time::Duration; - - use tokio::time::sleep; - - use super::*; - use crate::nuts::nut17::Kind; - use crate::nuts::{PublicKey, State}; - use crate::subscription::{IndexableParams, Params}; - - #[tokio::test] - async fn active_and_drop() { - let manager = PubSubManager::default(); - let params: IndexableParams = Params { - kind: Kind::ProofState, - filters: vec![ - "02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2".to_owned(), - ], - id: "uno".into(), - } - .into(); - - // Although the same param is used, two subscriptions are created, that - // is because each index is unique, thanks to `Unique`, it is the - // responsibility of the implementor to make sure that SubId are unique - // either globally or per client - let subscriptions = vec![ - manager - .try_subscribe(params.clone()) - .await - .expect("valid subscription"), - manager - .try_subscribe(params) - .await - .expect("valid subscription"), - ]; - assert_eq!(2, manager.active_subscriptions()); - drop(subscriptions); - - sleep(Duration::from_millis(10)).await; - - assert_eq!(0, manager.active_subscriptions()); - } - - #[tokio::test] - async fn broadcast() { - let manager = PubSubManager::default(); - let mut subscriptions = [ - manager - .try_subscribe::( - Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "uno".into(), - } - .into(), - ) - .await - .expect("valid subscription"), - manager - .try_subscribe::( - Params { - kind: Kind::ProofState, - filters: vec![ - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104" - .to_string(), - ], - id: "dos".into(), - } - .into(), - ) - .await - .expect("valid subscription"), - ]; - - let event = ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - }; - - manager.broadcast(event.into()); - - sleep(Duration::from_millis(10)).await; - - let (sub1, _) = subscriptions[0].try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - - let (sub1, _) = subscriptions[1].try_recv().expect("valid message"); - assert_eq!("dos", *sub1); - - assert!(subscriptions[0].try_recv().is_err()); - assert!(subscriptions[1].try_recv().is_err()); - } - - #[test] - fn parsing_request() { - let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#; - let params: Params = serde_json::from_str(json).expect("valid json"); - assert_eq!(params.kind, Kind::ProofState); - assert_eq!(params.filters, vec!["x"]); - assert_eq!(*params.id, "uno"); - } - - #[tokio::test] - async fn json_test() { - let manager = PubSubManager::default(); - let mut subscription = manager - .try_subscribe::( - serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#) - .expect("valid json"), - ) - .await.expect("valid subscription"); - - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - // no one is listening for this event - manager.broadcast( - ProofState { - y: PublicKey::from_hex( - "020000000000000000000000000000000000000000000000000000000000000001", - ) - .expect("valid pk"), - state: State::Pending, - witness: None, - } - .into(), - ); - - sleep(Duration::from_millis(10)).await; - let (sub1, msg) = subscription.try_recv().expect("valid message"); - assert_eq!("uno", *sub1); - assert_eq!( - r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#, - serde_json::to_string(&msg).expect("valid json") - ); - assert!(subscription.try_recv().is_err()); - } -} diff --git a/crates/cdk/src/mint/subscription/mod.rs b/crates/cdk/src/mint/subscription/mod.rs deleted file mode 100644 index 20216fab..00000000 --- a/crates/cdk/src/mint/subscription/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -//! Specific Subscription for the cdk crate - -#[cfg(feature = "mint")] -mod manager; -#[cfg(feature = "mint")] -mod on_subscription; -#[cfg(feature = "mint")] -pub use manager::PubSubManager; -#[cfg(feature = "mint")] -pub use on_subscription::OnSubscription; - -pub use crate::pub_sub::SubId; diff --git a/crates/cdk/src/mint/subscription/on_subscription.rs b/crates/cdk/src/mint/subscription/on_subscription.rs deleted file mode 100644 index 1e331db4..00000000 --- a/crates/cdk/src/mint/subscription/on_subscription.rs +++ /dev/null @@ -1,119 +0,0 @@ -//! On Subscription -//! -//! This module contains the code that is triggered when a new subscription is created. - -use cdk_common::database::DynMintDatabase; -use cdk_common::nut17::Notification; -use cdk_common::pub_sub::OnNewSubscription; -use cdk_common::quote_id::QuoteId; -use cdk_common::{MintQuoteBolt12Response, NotificationPayload, PaymentMethod}; - -use crate::nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey}; - -#[derive(Default)] -/// Subscription Init -/// -/// This struct triggers code when a new subscription is created. -/// -/// It is used to send the initial state of the subscription to the client. -pub struct OnSubscription(pub(crate) Option); - -#[async_trait::async_trait] -impl OnNewSubscription for OnSubscription { - type Event = NotificationPayload; - type Index = Notification; - - async fn on_new_subscription( - &self, - request: &[&Self::Index], - ) -> Result, String> { - let datastore = if let Some(localstore) = self.0.as_ref() { - localstore - } else { - return Ok(vec![]); - }; - - let mut to_return = vec![]; - let mut public_keys: Vec = Vec::new(); - let mut melt_queries = Vec::new(); - let mut mint_queries = Vec::new(); - - for idx in request.iter() { - match idx { - Notification::ProofState(pk) => public_keys.push(*pk), - Notification::MeltQuoteBolt11(uuid) => { - melt_queries.push(datastore.get_melt_quote(uuid)) - } - Notification::MintQuoteBolt11(uuid) => { - mint_queries.push(datastore.get_mint_quote(uuid)) - } - Notification::MintQuoteBolt12(uuid) => { - mint_queries.push(datastore.get_mint_quote(uuid)) - } - Notification::MeltQuoteBolt12(uuid) => { - melt_queries.push(datastore.get_melt_quote(uuid)) - } - } - } - - if !melt_queries.is_empty() { - to_return.extend( - futures::future::try_join_all(melt_queries) - .await - .map(|quotes| { - quotes - .into_iter() - .filter_map(|quote| quote.map(|x| x.into())) - .map(|x: MeltQuoteBolt11Response| x.into()) - .collect::>() - }) - .map_err(|e| e.to_string())?, - ); - } - - if !mint_queries.is_empty() { - to_return.extend( - futures::future::try_join_all(mint_queries) - .await - .map(|quotes| { - quotes - .into_iter() - .filter_map(|quote| { - quote.and_then(|x| match x.payment_method { - PaymentMethod::Bolt11 => { - let response: MintQuoteBolt11Response = x.into(); - Some(response.into()) - } - PaymentMethod::Bolt12 => match x.try_into() { - Ok(response) => { - let response: MintQuoteBolt12Response = - response; - Some(response.into()) - } - Err(_) => None, - }, - PaymentMethod::Custom(_) => None, - }) - }) - .collect::>() - }) - .map_err(|e| e.to_string())?, - ); - } - - if !public_keys.is_empty() { - to_return.extend( - datastore - .get_proofs_states(public_keys.as_slice()) - .await - .map_err(|e| e.to_string())? - .into_iter() - .enumerate() - .filter_map(|(idx, state)| state.map(|state| (public_keys[idx], state).into())) - .map(|state: ProofState| state.into()), - ); - } - - Ok(to_return) - } -} diff --git a/crates/cdk/src/pub_sub.rs b/crates/cdk/src/pub_sub.rs deleted file mode 100644 index f8365552..00000000 --- a/crates/cdk/src/pub_sub.rs +++ /dev/null @@ -1,339 +0,0 @@ -//! Publish–subscribe pattern. -//! -//! This is a generic implementation for -//! [NUT-17]() with a type -//! agnostic Publish-subscribe manager. -//! -//! The manager has a method for subscribers to subscribe to events with a -//! generic type that must be converted to a vector of indexes. -//! -//! Events are also generic that should implement the `Indexable` trait. -use std::cmp::Ordering; -use std::collections::{BTreeMap, HashSet}; -use std::fmt::Debug; -use std::ops::{Deref, DerefMut}; -use std::sync::atomic::{self, AtomicUsize}; -use std::sync::Arc; - -pub use cdk_common::pub_sub::index::{Index, Indexable, SubscriptionGlobalId}; -use cdk_common::pub_sub::OnNewSubscription; -pub use cdk_common::pub_sub::SubId; -use tokio::sync::{mpsc, RwLock}; -use tokio::task::JoinHandle; - -type IndexTree = Arc, mpsc::Sender<(SubId, T)>>>>; - -/// Default size of the remove channel -pub const DEFAULT_REMOVE_SIZE: usize = 10_000; - -/// Default channel size for subscription buffering -pub const DEFAULT_CHANNEL_SIZE: usize = 10; - -/// Subscription manager -/// -/// This object keep track of all subscription listener and it is also -/// responsible for broadcasting events to all listeners -/// -/// The content of the notification is not relevant to this scope and it is up -/// to the application, therefore the generic T is used instead of a specific -/// type -pub struct Manager -where - T: Indexable + Clone + Send + Sync + 'static, - I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, - F: OnNewSubscription + Send + Sync + 'static, -{ - indexes: IndexTree, - on_new_subscription: Option>, - unsubscription_sender: mpsc::Sender<(SubId, Vec>)>, - active_subscriptions: Arc, - background_subscription_remover: Option>, -} - -impl Default for Manager -where - T: Indexable + Clone + Send + Sync + 'static, - I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, - F: OnNewSubscription + Send + Sync + 'static, -{ - fn default() -> Self { - let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE); - let active_subscriptions: Arc = Default::default(); - let storage: IndexTree = Arc::new(Default::default()); - - Self { - background_subscription_remover: Some(tokio::spawn(Self::remove_subscription( - receiver, - storage.clone(), - active_subscriptions.clone(), - ))), - on_new_subscription: None, - unsubscription_sender: sender, - active_subscriptions, - indexes: storage, - } - } -} - -impl From for Manager -where - T: Indexable + Clone + Send + Sync + 'static, - I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, - F: OnNewSubscription + Send + Sync + 'static, -{ - fn from(value: F) -> Self { - let mut manager: Self = Default::default(); - manager.on_new_subscription = Some(Arc::new(value)); - manager - } -} - -impl Manager -where - T: Indexable + Clone + Send + Sync + 'static, - I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static, - F: OnNewSubscription + Send + Sync + 'static, -{ - #[inline] - /// Broadcast an event to all listeners - /// - /// This function takes an Arc to the storage struct, the event_id, the kind - /// and the vent to broadcast - async fn broadcast_impl(storage: &IndexTree, event: T) { - let index_storage = storage.read().await; - let mut sent = HashSet::new(); - for index in event.to_indexes() { - for (key, sender) in index_storage.range(index.clone()..) { - if index.cmp_prefix(key) != Ordering::Equal { - break; - } - let sub_id = key.unique_id(); - if sent.contains(&sub_id) { - continue; - } - sent.insert(sub_id); - let _ = sender.try_send((key.into(), event.clone())); - } - } - } - - /// Broadcasts an event to all listeners - /// - /// This public method will not block the caller, it will spawn a new task - /// instead - pub fn broadcast(&self, event: T) { - let storage = self.indexes.clone(); - tokio::spawn(async move { - Self::broadcast_impl(&storage, event).await; - }); - } - - /// Broadcasts an event to all listeners - /// - /// This method is async and will await for the broadcast to be completed - pub async fn broadcast_async(&self, event: T) { - Self::broadcast_impl(&self.indexes, event).await; - } - - /// Specific of the subscription, this is the abstraction between `subscribe` and `try_subscribe` - #[inline(always)] - async fn subscribe_inner( - &self, - sub_id: SubId, - indexes: Vec>, - ) -> ActiveSubscription { - let (sender, receiver) = mpsc::channel(10); - - let mut index_storage = self.indexes.write().await; - // Subscribe to events as soon as possible - for index in indexes.clone() { - index_storage.insert(index, sender.clone()); - } - drop(index_storage); - - if let Some(on_new_subscription) = self.on_new_subscription.clone() { - // After we're subscribed already, fetch the current status of matching events. It is - // down in another thread to return right away - let indexes_for_worker = indexes.clone(); - let sub_id_for_worker = sub_id.clone(); - tokio::spawn(async move { - match on_new_subscription - .on_new_subscription( - &indexes_for_worker - .iter() - .map(|x| x.deref()) - .collect::>(), - ) - .await - { - Ok(events) => { - for event in events { - let _ = sender.try_send((sub_id_for_worker.clone(), event)); - } - } - Err(err) => { - tracing::info!( - "Failed to get initial state for subscription: {:?}, {}", - sub_id_for_worker, - err - ); - } - } - }); - } - - self.active_subscriptions - .fetch_add(1, atomic::Ordering::Relaxed); - - ActiveSubscription { - sub_id, - receiver, - indexes, - drop: self.unsubscription_sender.clone(), - } - } - - /// Try to subscribe to a specific event - pub async fn try_subscribe

(&self, params: P) -> Result, P::Error> - where - P: AsRef + TryInto>>, - { - Ok(self - .subscribe_inner(params.as_ref().clone(), params.try_into()?) - .await) - } - - /// Subscribe to a specific event - pub async fn subscribe

(&self, params: P) -> ActiveSubscription - where - P: AsRef + Into>>, - { - self.subscribe_inner(params.as_ref().clone(), params.into()) - .await - } - - /// Return number of active subscriptions - pub fn active_subscriptions(&self) -> usize { - self.active_subscriptions.load(atomic::Ordering::SeqCst) - } - - /// Task to remove dropped subscriptions from the storage struct - /// - /// This task will run in the background (and will be dropped when the [`Manager`] - /// is) and will remove subscriptions from the storage struct it is dropped. - async fn remove_subscription( - mut receiver: mpsc::Receiver<(SubId, Vec>)>, - storage: IndexTree, - active_subscriptions: Arc, - ) { - while let Some((sub_id, indexes)) = receiver.recv().await { - tracing::info!("Removing subscription: {}", *sub_id); - - active_subscriptions.fetch_sub(1, atomic::Ordering::AcqRel); - - let mut index_storage = storage.write().await; - for key in indexes { - index_storage.remove(&key); - } - drop(index_storage); - } - } -} - -/// Manager goes out of scope, stop all background tasks -impl Drop for Manager -where - T: Indexable + Clone + Send + Sync + 'static, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, - F: OnNewSubscription + Send + Sync + 'static, -{ - fn drop(&mut self) { - if let Some(handler) = self.background_subscription_remover.take() { - handler.abort(); - } - } -} - -/// Active Subscription -/// -/// This struct is a wrapper around the `mpsc::Receiver` and it also used -/// to keep track of the subscription itself. When this struct goes out of -/// scope, it will notify the Manager about it, so it can be removed from the -/// list of active listeners -pub struct ActiveSubscription -where - T: Send + Sync, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, -{ - /// The subscription ID - pub sub_id: SubId, - indexes: Vec>, - receiver: mpsc::Receiver<(SubId, T)>, - drop: mpsc::Sender<(SubId, Vec>)>, -} - -impl Deref for ActiveSubscription -where - T: Send + Sync, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, -{ - type Target = mpsc::Receiver<(SubId, T)>; - - fn deref(&self) -> &Self::Target { - &self.receiver - } -} - -impl DerefMut for ActiveSubscription -where - T: Indexable + Clone + Send + Sync + 'static, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.receiver - } -} - -/// The ActiveSubscription is Drop out of scope, notify the Manager about it, so -/// it can be removed from the list of active listeners -/// -/// Having this in place, we can avoid memory leaks and also makes it super -/// simple to implement the Unsubscribe method -impl Drop for ActiveSubscription -where - T: Send + Sync, - I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static, -{ - fn drop(&mut self) { - let _ = self - .drop - .try_send((self.sub_id.clone(), self.indexes.drain(..).collect())); - } -} - -#[cfg(test)] -mod test { - use tokio::sync::mpsc; - - use super::*; - - #[test] - fn test_active_subscription_drop() { - let (tx, rx) = mpsc::channel::<(SubId, ())>(10); - let sub_id = SubId::from("test_sub_id"); - let indexes: Vec> = vec![Index::from(("test".to_string(), sub_id.clone()))]; - let (drop_tx, mut drop_rx) = mpsc::channel(10); - - { - let _active_subscription = ActiveSubscription { - sub_id: sub_id.clone(), - indexes, - receiver: rx, - drop: drop_tx, - }; - // When it goes out of scope, it should notify - } - assert_eq!(drop_rx.try_recv().unwrap().0, sub_id); // it should have notified - assert!(tx.try_send(("foo".into(), ())).is_err()); // subscriber is dropped - } -} diff --git a/crates/cdk/src/wallet/mod.rs b/crates/cdk/src/wallet/mod.rs index 5f2319f2..6ac0f14d 100644 --- a/crates/cdk/src/wallet/mod.rs +++ b/crates/cdk/src/wallet/mod.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use cdk_common::amount::FeeAndAmounts; use cdk_common::database::{self, WalletDatabase}; -use cdk_common::subscription::Params; +use cdk_common::subscription::WalletParams; use getrandom::getrandom; use subscription::{ActiveSubscription, SubscriptionManager}; #[cfg(feature = "auth")] @@ -108,40 +108,42 @@ pub enum WalletSubscription { Bolt12MintQuoteState(Vec), } -impl From for Params { +impl From for WalletParams { fn from(val: WalletSubscription) -> Self { let mut buffer = vec![0u8; 10]; getrandom(&mut buffer).expect("Failed to generate random bytes"); - let id = buffer - .iter() - .map(|&byte| { - let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9) - ALPHANUMERIC[index] as char - }) - .collect::(); + let id = Arc::new( + buffer + .iter() + .map(|&byte| { + let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9) + ALPHANUMERIC[index] as char + }) + .collect::(), + ); match val { - WalletSubscription::ProofState(filters) => Params { + WalletSubscription::ProofState(filters) => WalletParams { filters, kind: Kind::ProofState, - id: id.into(), + id, }, - WalletSubscription::Bolt11MintQuoteState(filters) => Params { + WalletSubscription::Bolt11MintQuoteState(filters) => WalletParams { filters, kind: Kind::Bolt11MintQuote, - id: id.into(), + id, }, - WalletSubscription::Bolt11MeltQuoteState(filters) => Params { + WalletSubscription::Bolt11MeltQuoteState(filters) => WalletParams { filters, kind: Kind::Bolt11MeltQuote, - id: id.into(), + id, }, - WalletSubscription::Bolt12MintQuoteState(filters) => Params { + WalletSubscription::Bolt12MintQuoteState(filters) => WalletParams { filters, kind: Kind::Bolt12MintQuote, - id: id.into(), + id, }, } } @@ -193,10 +195,10 @@ impl Wallet { } /// Subscribe to events - pub async fn subscribe>(&self, query: T) -> ActiveSubscription { + pub async fn subscribe>(&self, query: T) -> ActiveSubscription { self.subscription - .subscribe(self.mint_url.clone(), query.into(), Arc::new(self.clone())) - .await + .subscribe(self.mint_url.clone(), query.into()) + .expect("FIXME") } /// Fee required for proof set diff --git a/crates/cdk/src/wallet/multi_mint_wallet.rs b/crates/cdk/src/wallet/multi_mint_wallet.rs index 9534f1a1..1b5ec3f3 100644 --- a/crates/cdk/src/wallet/multi_mint_wallet.rs +++ b/crates/cdk/src/wallet/multi_mint_wallet.rs @@ -4,6 +4,7 @@ //! pairs use std::collections::BTreeMap; +use std::ops::Deref; use std::str::FromStr; use std::sync::Arc; @@ -675,7 +676,7 @@ impl MultiMintWallet { // Check if this is a mint quote response with paid state if let crate::nuts::nut17::NotificationPayload::MintQuoteBolt11Response( quote_response, - ) = notification + ) = notification.deref() { if quote_response.state == QuoteState::Paid { // Quote is paid, now mint the tokens @@ -1264,7 +1265,7 @@ impl MultiMintWallet { /// Melt (pay invoice) with automatic wallet selection (deprecated, use specific mint functions for better control) /// /// Automatically selects the best wallet to pay from based on: - /// - Available balance + /// - Available balance /// - Fees /// /// # Examples diff --git a/crates/cdk/src/wallet/streams/payment.rs b/crates/cdk/src/wallet/streams/payment.rs index 8139570a..354e0611 100644 --- a/crates/cdk/src/wallet/streams/payment.rs +++ b/crates/cdk/src/wallet/streams/payment.rs @@ -13,10 +13,11 @@ use futures::{FutureExt, Stream, StreamExt}; use tokio_util::sync::CancellationToken; use super::RecvFuture; +use crate::event::MintEvent; use crate::wallet::subscription::ActiveSubscription; use crate::{Wallet, WalletSubscription}; -type SubscribeReceived = (Option>, Vec); +type SubscribeReceived = (Option>, Vec); type PaymentValue = (String, Option); /// PaymentWaiter @@ -145,7 +146,7 @@ impl<'a> PaymentStream<'a> { Poll::Ready(None) } Some(info) => { - match info { + match info.into_inner() { NotificationPayload::MintQuoteBolt11Response(info) => { if info.state == MintQuoteState::Paid { self.is_finalized = true; diff --git a/crates/cdk/src/wallet/subscription/http.rs b/crates/cdk/src/wallet/subscription/http.rs deleted file mode 100644 index bbf2021c..00000000 --- a/crates/cdk/src/wallet/subscription/http.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use cdk_common::MintQuoteBolt12Response; -use tokio::sync::{mpsc, RwLock}; -#[cfg(not(target_arch = "wasm32"))] -use tokio::time; -use web_time::Duration; - -use super::WsSubscriptionBody; -use crate::nuts::nut17::Kind; -use crate::nuts::{nut01, nut05, nut07, nut23, CheckStateRequest, NotificationPayload}; -use crate::pub_sub::SubId; -use crate::wallet::MintConnector; -use crate::Wallet; - -#[derive(Debug, Hash, PartialEq, Eq)] -enum UrlType { - Mint(String), - MintBolt12(String), - Melt(String), - PublicKey(nut01::PublicKey), -} - -#[derive(Debug, Eq, PartialEq)] -enum AnyState { - MintQuoteState(nut23::QuoteState), - MintBolt12QuoteState(MintQuoteBolt12Response), - MeltQuoteState(nut05::QuoteState), - PublicKey(nut07::State), - Empty, -} - -type SubscribedTo = HashMap>, SubId, AnyState)>; - -async fn convert_subscription( - sub_id: SubId, - subscriptions: &Arc>>, - subscribed_to: &mut SubscribedTo, -) -> Option<()> { - let subscription = subscriptions.read().await; - let sub = subscription.get(&sub_id)?; - tracing::debug!("New subscription: {:?}", sub); - match sub.1.kind { - Kind::Bolt11MintQuote => { - for id in sub.1.filters.iter().map(|id| UrlType::Mint(id.clone())) { - subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); - } - } - Kind::Bolt11MeltQuote => { - for id in sub.1.filters.iter().map(|id| UrlType::Melt(id.clone())) { - subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); - } - } - Kind::ProofState => { - for id in sub - .1 - .filters - .iter() - .map(|id| nut01::PublicKey::from_hex(id).map(UrlType::PublicKey)) - { - match id { - Ok(id) => { - subscribed_to - .insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); - } - Err(err) => { - tracing::error!("Error parsing public key: {:?}. Subscription ignored, will never yield any result", err); - } - } - } - } - Kind::Bolt12MintQuote => { - for id in sub - .1 - .filters - .iter() - .map(|id| UrlType::MintBolt12(id.clone())) - { - subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty)); - } - } - } - - Some(()) -} - -#[cfg(not(target_arch = "wasm32"))] -#[inline] -pub async fn http_main>( - initial_state: S, - http_client: Arc, - subscriptions: Arc>>, - mut new_subscription_recv: mpsc::Receiver, - mut on_drop: mpsc::Receiver, - _wallet: Arc, -) { - let mut interval = time::interval(Duration::from_secs(2)); - let mut subscribed_to = SubscribedTo::new(); - - for sub_id in initial_state { - convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await; - } - - loop { - tokio::select! { - _ = interval.tick() => { - poll_subscriptions(&http_client, &mut subscribed_to).await; - } - Some(subid) = new_subscription_recv.recv() => { - convert_subscription(subid, &subscriptions, &mut subscribed_to).await; - } - Some(id) = on_drop.recv() => { - subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id); - } - } - } -} - -#[cfg(target_arch = "wasm32")] -#[inline] -pub async fn http_main>( - initial_state: S, - http_client: Arc, - subscriptions: Arc>>, - mut new_subscription_recv: mpsc::Receiver, - mut on_drop: mpsc::Receiver, - _wallet: Arc, -) { - let mut subscribed_to = SubscribedTo::new(); - - for sub_id in initial_state { - convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await; - } - - loop { - tokio::select! { - _ = gloo_timers::future::sleep(Duration::from_secs(2)) => { - poll_subscriptions(&http_client, &mut subscribed_to).await; - } - subid = new_subscription_recv.recv() => { - match subid { - Some(subid) => { - convert_subscription(subid, &subscriptions, &mut subscribed_to).await; - } - None => { - // New subscription channel closed - SubscriptionClient was dropped, terminate worker - break; - } - } - } - id = on_drop.recv() => { - match id { - Some(id) => { - subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id); - } - None => { - // Drop notification channel closed - SubscriptionClient was dropped, terminate worker - break; - } - } - } - } - } -} - -async fn poll_subscriptions( - http_client: &Arc, - subscribed_to: &mut SubscribedTo, -) { - for (url, (sender, _, last_state)) in subscribed_to.iter_mut() { - tracing::debug!("Polling: {:?}", url); - match url { - UrlType::MintBolt12(id) => { - let response = http_client.get_mint_quote_bolt12_status(id).await; - if let Ok(response) = response { - if *last_state == AnyState::MintBolt12QuoteState(response.clone()) { - continue; - } - *last_state = AnyState::MintBolt12QuoteState(response.clone()); - if let Err(err) = - sender.try_send(NotificationPayload::MintQuoteBolt12Response(response)) - { - tracing::error!("Error sending mint quote response: {:?}", err); - } - } - } - UrlType::Mint(id) => { - let response = http_client.get_mint_quote_status(id).await; - if let Ok(response) = response { - if *last_state == AnyState::MintQuoteState(response.state) { - continue; - } - *last_state = AnyState::MintQuoteState(response.state); - if let Err(err) = - sender.try_send(NotificationPayload::MintQuoteBolt11Response(response)) - { - tracing::error!("Error sending mint quote response: {:?}", err); - } - } - } - UrlType::Melt(id) => { - let response = http_client.get_melt_quote_status(id).await; - if let Ok(response) = response { - if *last_state == AnyState::MeltQuoteState(response.state) { - continue; - } - *last_state = AnyState::MeltQuoteState(response.state); - if let Err(err) = - sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response)) - { - tracing::error!("Error sending melt quote response: {:?}", err); - } - } - } - UrlType::PublicKey(id) => { - let responses = http_client - .post_check_state(CheckStateRequest { ys: vec![*id] }) - .await; - if let Ok(mut responses) = responses { - let response = if let Some(state) = responses.states.pop() { - state - } else { - continue; - }; - - if *last_state == AnyState::PublicKey(response.state) { - continue; - } - *last_state = AnyState::PublicKey(response.state); - if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) { - tracing::error!("Error sending proof state response: {:?}", err); - } - } - } - } - } -} diff --git a/crates/cdk/src/wallet/subscription/mod.rs b/crates/cdk/src/wallet/subscription/mod.rs index 12143c38..a367b05d 100644 --- a/crates/cdk/src/wallet/subscription/mod.rs +++ b/crates/cdk/src/wallet/subscription/mod.rs @@ -7,28 +7,33 @@ //! the HTTP client. use std::collections::HashMap; use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use cdk_common::subscription::Params; -use tokio::sync::{mpsc, RwLock}; -use tokio::task::JoinHandle; -#[cfg(target_arch = "wasm32")] -use wasm_bindgen_futures; +use cdk_common::nut17::ws::{WsMethodRequest, WsRequest, WsUnsubscribeRequest}; +use cdk_common::nut17::{Kind, NotificationId}; +use cdk_common::parking_lot::RwLock; +use cdk_common::pub_sub::remote_consumer::{ + Consumer, InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport, +}; +use cdk_common::pub_sub::{Error as PubsubError, Spec, Subscriber}; +use cdk_common::subscription::WalletParams; +use cdk_common::CheckStateRequest; +use tokio::sync::mpsc; +use uuid::Uuid; -use super::Wallet; +use crate::event::MintEvent; use crate::mint_url::MintUrl; -use crate::pub_sub::SubId; use crate::wallet::MintConnector; -mod http; -#[cfg(all( - not(feature = "http_subscription"), - feature = "mint", - not(target_arch = "wasm32") -))] +#[cfg(not(target_arch = "wasm32"))] mod ws; -type WsSubscriptionBody = (mpsc::Sender, Params); +/// Notification Payload +pub type NotificationPayload = crate::nuts::NotificationPayload; + +/// Type alias +pub type ActiveSubscription = RemoteActiveConsumer; /// Subscription manager /// @@ -45,13 +50,27 @@ type WsSubscriptionBody = (mpsc::Sender, Params); /// The subscribers have a simple-to-use interface, receiving an /// ActiveSubscription struct, which can be used to receive updates and to /// unsubscribe from updates automatically on the drop. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SubscriptionManager { - all_connections: Arc>>, + all_connections: Arc>>>>, http_client: Arc, prefer_http: bool, } +impl Debug for SubscriptionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Subscription Manager connected to {:?}", + self.all_connections + .write() + .keys() + .cloned() + .collect::>() + ) + } +} + impl SubscriptionManager { /// Create a new subscription manager pub fn new(http_client: Arc, prefer_http: bool) -> Self { @@ -63,63 +82,54 @@ impl SubscriptionManager { } /// Subscribe to updates from a mint server with a given filter - pub async fn subscribe( + pub fn subscribe( &self, mint_url: MintUrl, - filter: Params, - wallet: Arc, - ) -> ActiveSubscription { - let subscription_clients = self.all_connections.read().await; - let id = filter.id.clone(); - if let Some(subscription_client) = subscription_clients.get(&mint_url) { - let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; - ActiveSubscription::new(receiver, id, on_drop_notif) - } else { - drop(subscription_clients); + filter: WalletParams, + ) -> Result, PubsubError> { + self.all_connections + .write() + .entry(mint_url.clone()) + .or_insert_with(|| { + Consumer::new( + SubscriptionClient { + mint_url, + http_client: self.http_client.clone(), + req_id: 0.into(), + }, + self.prefer_http, + (), + ) + }) + .subscribe(filter) + } +} - #[cfg(all( - not(feature = "http_subscription"), - feature = "mint", - not(target_arch = "wasm32") - ))] - let is_ws_support = self - .http_client - .get_mint_info() - .await - .map(|info| !info.nuts.nut17.supported.is_empty()) - .unwrap_or_default(); +/// MintSubTopics +#[derive(Clone, Default)] +pub struct MintSubTopics {} - #[cfg(any( - feature = "http_subscription", - not(feature = "mint"), - target_arch = "wasm32" - ))] - let is_ws_support = false; +#[async_trait::async_trait] +impl Spec for MintSubTopics { + type SubscriptionId = String; - let is_ws_support = if self.prefer_http { - false - } else { - is_ws_support - }; + type Event = MintEvent; - tracing::debug!( - "Connect to {:?} to subscribe. WebSocket is supported ({})", - mint_url, - is_ws_support - ); + type Topic = NotificationId; - let mut subscription_clients = self.all_connections.write().await; - let subscription_client = SubscriptionClient::new( - mint_url.clone(), - self.http_client.clone(), - is_ws_support, - wallet, - ); - let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await; - subscription_clients.insert(mint_url, subscription_client); + type Context = (); - ActiveSubscription::new(receiver, id, on_drop_notif) - } + fn new_instance(_context: Self::Context) -> Arc + where + Self: Sized, + { + Arc::new(Self {}) + } + + async fn fetch_events(self: &Arc, _topics: Vec, _reply_to: Subscriber) + where + Self: Sized, + { } } @@ -129,225 +139,178 @@ impl SubscriptionManager { /// otherwise the HTTP pool and pause will be used (which is the less efficient /// method). #[derive(Debug)] +#[allow(dead_code)] pub struct SubscriptionClient { - new_subscription_notif: mpsc::Sender, - on_drop_notif: mpsc::Sender, - subscriptions: Arc>>, - worker: Option>, -} - -type NotificationPayload = crate::nuts::NotificationPayload; - -/// Active Subscription -pub struct ActiveSubscription { - sub_id: Option, - on_drop_notif: mpsc::Sender, - receiver: mpsc::Receiver, -} - -impl ActiveSubscription { - fn new( - receiver: mpsc::Receiver, - sub_id: SubId, - on_drop_notif: mpsc::Sender, - ) -> Self { - Self { - sub_id: Some(sub_id), - on_drop_notif, - receiver, - } - } - - /// Try to receive a notification - pub fn try_recv(&mut self) -> Result, Error> { - match self.receiver.try_recv() { - Ok(payload) => Ok(Some(payload)), - Err(mpsc::error::TryRecvError::Empty) => Ok(None), - Err(mpsc::error::TryRecvError::Disconnected) => Err(Error::Disconnected), - } - } - - /// Receive a notification asynchronously - pub async fn recv(&mut self) -> Option { - self.receiver.recv().await - } -} - -impl Drop for ActiveSubscription { - fn drop(&mut self) { - if let Some(sub_id) = self.sub_id.take() { - let _ = self.on_drop_notif.try_send(sub_id); - } - } -} - -/// Subscription client error -#[derive(thiserror::Error, Debug)] -pub enum Error { - /// Url error - #[error("Could not join paths: {0}")] - Url(#[from] crate::mint_url::Error), - /// Disconnected from the notification channel - #[error("Disconnected from the notification channel")] - Disconnected, + http_client: Arc, + mint_url: MintUrl, + req_id: AtomicUsize, } +#[allow(dead_code)] impl SubscriptionClient { - /// Create new [`SubscriptionClient`] - pub fn new( - url: MintUrl, - http_client: Arc, - prefer_ws_method: bool, - wallet: Arc, - ) -> Self { - let subscriptions = Arc::new(RwLock::new(HashMap::new())); - let (new_subscription_notif, new_subscription_recv) = mpsc::channel(100); - let (on_drop_notif, on_drop_recv) = mpsc::channel(1000); - - Self { - new_subscription_notif, - on_drop_notif, - subscriptions: subscriptions.clone(), - worker: Self::start_worker( - prefer_ws_method, - http_client, - url, - subscriptions, - new_subscription_recv, - on_drop_recv, - wallet, - ), - } - } - - #[allow(unused_variables)] - fn start_worker( - prefer_ws_method: bool, - http_client: Arc, - url: MintUrl, - subscriptions: Arc>>, - new_subscription_recv: mpsc::Receiver, - on_drop_recv: mpsc::Receiver, - wallet: Arc, - ) -> Option> { - #[cfg(any( - feature = "http_subscription", - not(feature = "mint"), - target_arch = "wasm32" - ))] - return Self::http_worker( - http_client, - subscriptions, - new_subscription_recv, - on_drop_recv, - wallet, - ); - - #[cfg(all( - not(feature = "http_subscription"), - feature = "mint", - not(target_arch = "wasm32") - ))] - if prefer_ws_method { - Self::ws_worker( - http_client, - url, - subscriptions, - new_subscription_recv, - on_drop_recv, - wallet, - ) - } else { - Self::http_worker( - http_client, - subscriptions, - new_subscription_recv, - on_drop_recv, - wallet, - ) - } - } - - /// Subscribe to a WebSocket channel - pub async fn subscribe( + fn get_sub_request( &self, - filter: Params, - ) -> (mpsc::Sender, mpsc::Receiver) { - let mut subscriptions = self.subscriptions.write().await; - let id = filter.id.clone(); + id: String, + params: NotificationId, + ) -> Option<(usize, String)> { + let (kind, filter) = match params { + NotificationId::ProofState(x) => (Kind::ProofState, x.to_string()), + NotificationId::MeltQuoteBolt11(q) | NotificationId::MeltQuoteBolt12(q) => { + (Kind::Bolt11MeltQuote, q) + } + NotificationId::MintQuoteBolt11(q) => (Kind::Bolt11MintQuote, q), + NotificationId::MintQuoteBolt12(q) => (Kind::Bolt12MintQuote, q), + }; - let (sender, receiver) = mpsc::channel(10_000); - subscriptions.insert(id.clone(), (sender, filter)); - drop(subscriptions); + let request: WsRequest<_> = ( + WsMethodRequest::Subscribe(WalletParams { + kind, + filters: vec![filter], + id: id.into(), + }), + self.req_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); - let _ = self.new_subscription_notif.send(id).await; - (self.on_drop_notif.clone(), receiver) + serde_json::to_string(&request) + .inspect_err(|err| { + tracing::error!("Could not serialize subscribe message: {:?}", err); + }) + .map(|json| (request.id, json)) + .ok() } - /// HTTP subscription client - /// - /// This is a poll based subscription, where the client will poll the server - /// from time to time to get updates, notifying the subscribers on changes - fn http_worker( - http_client: Arc, - subscriptions: Arc>>, - new_subscription_recv: mpsc::Receiver, - on_drop: mpsc::Receiver, - wallet: Arc, - ) -> Option> { - let http_worker = http::http_main( - vec![], - http_client, - subscriptions, - new_subscription_recv, - on_drop, - wallet, - ); + fn get_unsub_request(&self, sub_id: String) -> Option { + let request: WsRequest<_> = ( + WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }), + self.req_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + .into(); + + match serde_json::to_string(&request) { + Ok(json) => Some(json), + Err(err) => { + tracing::error!("Could not serialize unsubscribe message: {:?}", err); + None + } + } + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl Transport for SubscriptionClient { + type Spec = MintSubTopics; + + fn new_name(&self) -> ::SubscriptionId { + Uuid::new_v4().to_string() + } + + async fn stream( + &self, + _ctrls: mpsc::Receiver>, + _topics: Vec>, + _reply_to: InternalRelay, + ) -> Result<(), PubsubError> { + #[cfg(not(target_arch = "wasm32"))] + let r = ws::stream_client(self, _ctrls, _topics, _reply_to).await; #[cfg(target_arch = "wasm32")] - { - wasm_bindgen_futures::spawn_local(http_worker); - None - } + let r = Err(PubsubError::NotSupported); - #[cfg(not(target_arch = "wasm32"))] - { - Some(tokio::spawn(http_worker)) - } + r } - /// WebSocket subscription client - /// - /// This is a WebSocket based subscription, where the client will connect to - /// the server and stay there idle waiting for server-side notifications - #[cfg(all( - not(feature = "http_subscription"), - feature = "mint", - not(target_arch = "wasm32") - ))] - fn ws_worker( - http_client: Arc, - url: MintUrl, - subscriptions: Arc>>, - new_subscription_recv: mpsc::Receiver, - on_drop: mpsc::Receiver, - wallet: Arc, - ) -> Option> { - Some(tokio::spawn(ws::ws_main( - http_client, - url, - subscriptions, - new_subscription_recv, - on_drop, - wallet, - ))) - } -} - -impl Drop for SubscriptionClient { - fn drop(&mut self) { - if let Some(handle) = self.worker.take() { - handle.abort(); - } + /// Poll on demand + async fn poll( + &self, + topics: Vec>, + reply_to: InternalRelay, + ) -> Result<(), PubsubError> { + let proofs = topics + .iter() + .filter_map(|(_, x)| match &x { + NotificationId::ProofState(p) => Some(*p), + _ => None, + }) + .collect::>(); + + if !proofs.is_empty() { + for state in self + .http_client + .post_check_state(CheckStateRequest { ys: proofs }) + .await + .map_err(|e| PubsubError::Internal(Box::new(e)))? + .states + { + reply_to.send(MintEvent::new(NotificationPayload::ProofState(state))); + } + } + + for topic in topics + .into_iter() + .map(|(_, x)| x) + .filter(|x| !matches!(x, NotificationId::ProofState(_))) + { + match topic { + NotificationId::MintQuoteBolt11(id) => { + let response = match self.http_client.get_mint_quote_status(&id).await { + Ok(success) => success, + Err(err) => { + tracing::error!("Error with MintBolt11 {} with {:?}", id, err); + continue; + } + }; + + reply_to.send(MintEvent::new( + NotificationPayload::MintQuoteBolt11Response(response.clone()), + )); + } + NotificationId::MeltQuoteBolt11(id) => { + let response = match self.http_client.get_melt_quote_status(&id).await { + Ok(success) => success, + Err(err) => { + tracing::error!("Error with MeltBolt11 {} with {:?}", id, err); + continue; + } + }; + + reply_to.send(MintEvent::new( + NotificationPayload::MeltQuoteBolt11Response(response), + )); + } + NotificationId::MintQuoteBolt12(id) => { + let response = match self.http_client.get_mint_quote_bolt12_status(&id).await { + Ok(success) => success, + Err(err) => { + tracing::error!("Error with MintBolt12 {} with {:?}", id, err); + continue; + } + }; + + reply_to.send(MintEvent::new( + NotificationPayload::MintQuoteBolt12Response(response), + )); + } + NotificationId::MeltQuoteBolt12(id) => { + let response = match self.http_client.get_melt_bolt12_quote_status(&id).await { + Ok(success) => success, + Err(err) => { + tracing::error!("Error with MeltBolt12 {} with {:?}", id, err); + continue; + } + }; + + reply_to.send(MintEvent::new( + NotificationPayload::MeltQuoteBolt11Response(response), + )); + } + _ => {} + } + } + + Ok(()) } } diff --git a/crates/cdk/src/wallet/subscription/ws.rs b/crates/cdk/src/wallet/subscription/ws.rs index 0fbed754..1c319829 100644 --- a/crates/cdk/src/wallet/subscription/ws.rs +++ b/crates/cdk/src/wallet/subscription/ws.rs @@ -1,38 +1,25 @@ -use std::collections::{HashMap, HashSet}; -use std::sync::atomic::AtomicUsize; -use std::sync::Arc; -use std::time::Duration; - -use cdk_common::subscription::Params; -use cdk_common::ws::{WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest}; +use cdk_common::nut17::ws::WsMessageOrResponse; +use cdk_common::pub_sub::remote_consumer::{InternalRelay, StreamCtrl, SubscribeMessage}; +use cdk_common::pub_sub::Error as PubsubError; #[cfg(feature = "auth")] use cdk_common::{Method, RoutePath}; use futures::{SinkExt, StreamExt}; -use tokio::sync::{mpsc, RwLock}; -use tokio::time::sleep; +use tokio::sync::mpsc; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::tungstenite::Message; -use super::http::http_main; -use super::WsSubscriptionBody; -use crate::mint_url::MintUrl; -use crate::pub_sub::SubId; -use crate::wallet::MintConnector; -use crate::Wallet; +use super::{MintSubTopics, SubscriptionClient}; -const MAX_ATTEMPT_FALLBACK_HTTP: usize = 10; - -#[inline] -pub async fn ws_main( - http_client: Arc, - mint_url: MintUrl, - subscriptions: Arc>>, - mut new_subscription_recv: mpsc::Receiver, - mut on_drop: mpsc::Receiver, - wallet: Arc, -) { - let mut url = mint_url +#[inline(always)] +pub(crate) async fn stream_client( + client: &SubscriptionClient, + mut ctrl: mpsc::Receiver>, + topics: Vec>, + reply_to: InternalRelay, +) -> Result<(), PubsubError> { + let mut url = client + .mint_url .join_paths(&["v1", "ws"]) .expect("Could not join paths"); @@ -42,241 +29,140 @@ pub async fn ws_main( url.set_scheme("ws").expect("Could not set scheme"); } - let request = match url.to_string().into_client_request() { - Ok(req) => req, - Err(err) => { - tracing::error!("Failed to create client request: {:?}", err); - // Fallback to HTTP client if we can't create the WebSocket request - return http_main( - std::iter::empty(), - http_client, - subscriptions, - new_subscription_recv, - on_drop, - wallet, - ) - .await; - } - }; + #[cfg(not(feature = "auth"))] + let request = url.to_string().into_client_request().map_err(|err| { + tracing::error!("Failed to create client request: {:?}", err); + // Fallback to HTTP client if we can't create the WebSocket request + cdk_common::pub_sub::Error::NotSupported + })?; - let mut active_subscriptions = HashMap::>::new(); - let mut failure_count = 0; + #[cfg(feature = "auth")] + let mut request = url.to_string().into_client_request().map_err(|err| { + tracing::error!("Failed to create client request: {:?}", err); + // Fallback to HTTP client if we can't create the WebSocket request + cdk_common::pub_sub::Error::NotSupported + })?; - loop { - if subscriptions.read().await.is_empty() { - // No active subscription - sleep(Duration::from_millis(100)).await; - continue; - } - - let mut request_clone = request.clone(); - #[cfg(feature = "auth")] - { - let auth_wallet = http_client.get_auth_wallet().await; - let token = match auth_wallet.as_ref() { - Some(auth_wallet) => { - let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws); - match auth_wallet.get_auth_for_request(&endpoint).await { - Ok(token) => token, - Err(err) => { - tracing::warn!("Failed to get auth token: {:?}", err); - None - } + #[cfg(feature = "auth")] + { + let auth_wallet = client.http_client.get_auth_wallet().await; + let token = match auth_wallet.as_ref() { + Some(auth_wallet) => { + let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws); + match auth_wallet.get_auth_for_request(&endpoint).await { + Ok(token) => token, + Err(err) => { + tracing::warn!("Failed to get auth token: {:?}", err); + None } } - None => None, + } + None => None, + }; + + if let Some(auth_token) = token { + let header_key = match &auth_token { + cdk_common::AuthToken::ClearAuth(_) => "Clear-auth", + cdk_common::AuthToken::BlindAuth(_) => "Blind-auth", }; - if let Some(auth_token) = token { - let header_key = match &auth_token { - cdk_common::AuthToken::ClearAuth(_) => "Clear-auth", - cdk_common::AuthToken::BlindAuth(_) => "Blind-auth", - }; - - match auth_token.to_string().parse() { - Ok(header_value) => { - request_clone.headers_mut().insert(header_key, header_value); - } - Err(err) => { - tracing::warn!("Failed to parse auth token as header value: {:?}", err); - } + match auth_token.to_string().parse() { + Ok(header_value) => { + request.headers_mut().insert(header_key, header_value); + } + Err(err) => { + tracing::warn!("Failed to parse auth token as header value: {:?}", err); } } } + } - tracing::debug!("Connecting to {}", url); - let ws_stream = match connect_async(request_clone.clone()).await { - Ok((ws_stream, _)) => ws_stream, - Err(err) => { - failure_count += 1; - tracing::error!("Could not connect to server: {:?}", err); - if failure_count > MAX_ATTEMPT_FALLBACK_HTTP { - tracing::error!( - "Could not connect to server after {MAX_ATTEMPT_FALLBACK_HTTP} attempts, falling back to HTTP-subscription client" - ); + tracing::debug!("Connecting to {}", url); + let ws_stream = connect_async(request) + .await + .map(|(ws_stream, _)| ws_stream) + .map_err(|err| { + tracing::error!("Error connecting: {err:?}"); - return http_main( - active_subscriptions.into_keys(), - http_client, - subscriptions, - new_subscription_recv, - on_drop, - wallet, - ) - .await; - } - continue; - } - }; - tracing::debug!("Connected to {}", url); + cdk_common::pub_sub::Error::Internal(Box::new(err)) + })?; - let (mut write, mut read) = ws_stream.split(); - let req_id = AtomicUsize::new(0); + tracing::debug!("Connected to {}", url); + let (mut write, mut read) = ws_stream.split(); - let get_sub_request = |params: Params| -> Option<(usize, String)> { - let request: WsRequest = ( - WsMethodRequest::Subscribe(params), - req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) - .into(); - - match serde_json::to_string(&request) { - Ok(json) => Some((request.id, json)), - Err(err) => { - tracing::error!("Could not serialize subscribe message: {:?}", err); - None - } - } + for (name, index) in topics { + let (_, req) = if let Some(req) = client.get_sub_request(name, index) { + req + } else { + continue; }; - let get_unsub_request = |sub_id: SubId| -> Option { - let request: WsRequest = ( - WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }), - req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) - .into(); + let _ = write.send(Message::Text(req.into())).await; + } - match serde_json::to_string(&request) { - Ok(json) => Some(json), - Err(err) => { - tracing::error!("Could not serialize unsubscribe message: {:?}", err); - None - } - } - }; - - // Websocket reconnected, restore all subscriptions - let mut subscription_requests = HashSet::new(); - - let read_subscriptions = subscriptions.read().await; - for (sub_id, _) in active_subscriptions.iter() { - if let Some(Some((req_id, req))) = read_subscriptions - .get(sub_id) - .map(|(_, params)| get_sub_request(params.clone())) - { - let _ = write.send(Message::Text(req.into())).await; - subscription_requests.insert(req_id); - } - } - drop(read_subscriptions); - - loop { - tokio::select! { - Some(msg) = read.next() => { - let msg = match msg { - Ok(msg) => msg, - Err(_) => { - if let Err(err) = write.send(Message::Close(None)).await { - tracing::error!("Closing error {err:?}"); - } - break - }, - }; - let msg = match msg { - Message::Text(msg) => msg, - _ => continue, - }; - let msg = match serde_json::from_str::(&msg) { - Ok(msg) => msg, - Err(_) => continue, - }; - - match msg { - WsMessageOrResponse::Notification(payload) => { - tracing::debug!("Received notification from server: {:?}", payload); - let _ = active_subscriptions.get(&payload.params.sub_id).map(|sender| { - let _ = sender.try_send(payload.params.payload); - }); - } - WsMessageOrResponse::Response(response) => { - tracing::debug!("Received response from server: {:?}", response); - subscription_requests.remove(&response.id); - // reset connection failure after a successful response from the serer - failure_count = 0; - } - WsMessageOrResponse::ErrorResponse(error) => { - tracing::error!("Received error from server: {:?}", error); - - if subscription_requests.contains(&error.id) { - failure_count += 1; - if failure_count > MAX_ATTEMPT_FALLBACK_HTTP { - tracing::error!( - "Falling back to HTTP client" - ); - - return http_main( - active_subscriptions.into_keys(), - http_client, - subscriptions, - new_subscription_recv, - on_drop, - wallet, - ) - .await; - } - - if let Err(err) = write.send(Message::Close(None)).await { - tracing::error!("Closing error {err:?}"); - } - - break; // break connection to force a reconnection, to attempt to recover form this error - } - } + loop { + tokio::select! { + Some(msg) = ctrl.recv() => { + match msg { + StreamCtrl::Subscribe(msg) => { + let (_, req) = if let Some(req) = client.get_sub_request(msg.0, msg.1) { + req + } else { + continue; + }; + let _ = write.send(Message::Text(req.into())).await; } - - } - Some(subid) = new_subscription_recv.recv() => { - let subscription = subscriptions.read().await; - let sub = if let Some(subscription) = subscription.get(&subid) { - subscription - } else { - continue - }; - tracing::debug!("Subscribing to {:?}", sub.1); - active_subscriptions.insert(subid, sub.0.clone()); - if let Some((req_id, json)) = get_sub_request(sub.1.clone()) { - let _ = write.send(Message::Text(json.into())).await; - subscription_requests.insert(req_id); + StreamCtrl::Unsubscribe(msg) => { + let req = if let Some(req) = client.get_unsub_request(msg) { + req + } else { + continue; + }; + let _ = write.send(Message::Text(req.into())).await; } - }, - Some(subid) = on_drop.recv() => { - let mut subscription = subscriptions.write().await; - if let Some(sub) = subscription.remove(&subid) { - drop(sub); - } - tracing::debug!("Unsubscribing from {:?}", subid); - if let Some(json) = get_unsub_request(subid) { - let _ = write.send(Message::Text(json.into())).await; - } - - if subscription.is_empty() { + StreamCtrl::Stop => { if let Err(err) = write.send(Message::Close(None)).await { tracing::error!("Closing error {err:?}"); } break; } + }; + } + Some(msg) = read.next() => { + let msg = match msg { + Ok(msg) => msg, + Err(_) => { + if let Err(err) = write.send(Message::Close(None)).await { + tracing::error!("Closing error {err:?}"); + } + break; + } + }; + let msg = match msg { + Message::Text(msg) => msg, + _ => continue, + }; + let msg = match serde_json::from_str::>(&msg) { + Ok(msg) => msg, + Err(_) => continue, + }; + + match msg { + WsMessageOrResponse::Notification(payload) => { + reply_to.send(payload.params.payload); + } + WsMessageOrResponse::Response(response) => { + tracing::debug!("Received response from server: {:?}", response); + } + WsMessageOrResponse::ErrorResponse(error) => { + tracing::debug!("Received an error from server: {:?}", error); + return Err(PubsubError::InternalStr(error.error.message)); + } } + } } } + + Ok(()) } diff --git a/justfile b/justfile index 2ed4316f..c250fa58 100644 --- a/justfile +++ b/justfile @@ -66,14 +66,14 @@ test: # run doc tests -test-pure db="memory": build +test-pure db="memory": #!/usr/bin/env bash set -euo pipefail if [ ! -f Cargo.toml ]; then cd {{invocation_directory()}} fi - # Run pure integration tests + # Run pure integration tests (cargo test will only build what's needed for the test) CDK_TEST_DB_TYPE={{db}} cargo test -p cdk-integration-tests --test integration_tests_pure -- --test-threads 1 test-all db="memory":