mirror of
https://github.com/aljazceru/cdk.git
synced 2025-12-18 21:25:09 +01:00
Introduce a generic pubsub mod in cdk-common (#1098)
* pubsub: consolidate into Spec, adopt Arc<SubscriptionId>, and wire through wallet/mint/WS/FFI
Refactor the pub/sub engine to a single Spec trait, move Event alongside it,
and propagate Arc-backed subscription IDs across the stack. This simplifies
generics, clarifies responsibilities, and preserves coalescing +
latest-on-subscribe semantics.
- **Single source of truth:** `Spec` owns `Topic`, `Event`, `SubscriptionId`,
`Context`, new_instance, and fetch_events.
- **Lean & explicit API:** Remove Topic trait split;
`Subscriber::send(Event)` carries sub-ID internally.
- **Performance/ergonomics:** `Arc<SubscriptionId>` avoids heavy clones and
makes channel/task hops trivial.
- Introduce `pub_sub/typ.rs` with:
- trait `Spec`
- trait `Event` colocated with Spec.
- Remove `pub_sub/event.rs` fold `Event` into `typ.rs`.
- Make `Pubsub<S>` generic over `Spec` and store `Arc<S>`.
- The subscriber holds `Arc<SubscriptionId>` and deduplicates the latest
entry per subscription.
- SubscriptionRequest: rename SubscriptionName → SubscriptionId; return
`Arc<...>` from `subscription_name()`.
- Remote consumer (Transport) now parameterized by `Spec`; control types
updated:
- `StreamCtrl<S>`, `SubscribeMessage<S>`, internal caches keyed by
`S::Topic`.
- Mint/wallet:
- Mint: `MintPubSubSpec` (Context = `DynMintDatabase`),
`PubSubManager(Pubsub<MintPubSubSpec>)`.
- Wallet: lightweight MintSubTopics Spec with `Context = ()`.
- IDs go Arc end-to-end:
- cdk-axum WS maps `HashMap<Arc<SubId>, JoinHandle<()>>`, publisher sends
`(Arc<SubId>, NotificationPayload)`.
- `subscription::{Params, WalletParams}` now use `Arc<...>`.
- cdk-ffi conversions & wallet glue updated.
- Integration tests updated for new types.
- Coalescing unchanged: multiple local subs to the same topic are combined
into a single remote sub.
- Backfill via `Spec::fetch_events(topics, Subscriber)`; Subscriber enforces
latest-only dedupe per subscription.
**Result:** a slimmer, more maintainable pub/sub core that’s easier to embed
across mint, wallet, transports, and FFI without sacrificing performance or
semantics.
---------
Co-authored-by: thesimplekid <tsk@thesimplekid.com>
This commit is contained in:
10
.github/workflows/ci.yml
vendored
10
.github/workflows/ci.yml
vendored
@@ -490,6 +490,16 @@ jobs:
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: false
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Install Nix
|
||||
uses: DeterminateSystems/nix-installer-action@v17
|
||||
- name: Nix Cache
|
||||
|
||||
@@ -70,7 +70,7 @@ futures = { version = "0.3.28", default-features = false, features = ["async-awa
|
||||
lightning-invoice = { version = "0.33.0", features = ["serde", "std"] }
|
||||
lightning = { version = "0.1.2", default-features = false, features = ["std"]}
|
||||
ldk-node = "0.6.2"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde = { version = "1", features = ["derive", "rc"] }
|
||||
serde_json = "1"
|
||||
thiserror = { version = "2" }
|
||||
tokio = { version = "1", default-features = false, features = ["rt", "macros", "test-util", "sync"] }
|
||||
|
||||
@@ -13,13 +13,13 @@ readme = "README.md"
|
||||
[features]
|
||||
default = ["mint", "wallet", "auth"]
|
||||
swagger = ["dep:utoipa"]
|
||||
mint = ["dep:uuid"]
|
||||
mint = []
|
||||
wallet = []
|
||||
auth = ["dep:strum", "dep:strum_macros", "dep:regex"]
|
||||
bench = []
|
||||
|
||||
[dependencies]
|
||||
uuid = { workspace = true, optional = true }
|
||||
uuid.workspace = true
|
||||
bitcoin.workspace = true
|
||||
cbor-diag.workspace = true
|
||||
ciborium.workspace = true
|
||||
|
||||
@@ -16,7 +16,6 @@ pub use self::mint_url::MintUrl;
|
||||
pub use self::nuts::*;
|
||||
pub use self::util::SECP256K1;
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
pub mod quote_id;
|
||||
|
||||
#[doc(hidden)]
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
use super::PublicKey;
|
||||
use crate::nuts::{
|
||||
CurrencyUnit, MeltQuoteBolt11Response, MintQuoteBolt11Response, PaymentMethod, ProofState,
|
||||
};
|
||||
#[cfg(feature = "mint")]
|
||||
use crate::quote_id::{QuoteId, QuoteIdError};
|
||||
use crate::quote_id::QuoteIdError;
|
||||
use crate::MintQuoteBolt12Response;
|
||||
|
||||
pub mod ws;
|
||||
@@ -109,7 +107,10 @@ pub enum WsCommand {
|
||||
ProofState,
|
||||
}
|
||||
|
||||
impl<T> From<MintQuoteBolt12Response<T>> for NotificationPayload<T> {
|
||||
impl<T> From<MintQuoteBolt12Response<T>> for NotificationPayload<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
fn from(mint_quote: MintQuoteBolt12Response<T>) -> NotificationPayload<T> {
|
||||
NotificationPayload::MintQuoteBolt12Response(mint_quote)
|
||||
}
|
||||
@@ -119,7 +120,10 @@ impl<T> From<MintQuoteBolt12Response<T>> for NotificationPayload<T> {
|
||||
#[serde(bound = "T: Serialize + DeserializeOwned")]
|
||||
#[serde(untagged)]
|
||||
/// Subscription response
|
||||
pub enum NotificationPayload<T> {
|
||||
pub enum NotificationPayload<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// Proof State
|
||||
ProofState(ProofState),
|
||||
/// Melt Quote Bolt11 Response
|
||||
@@ -130,38 +134,23 @@ pub enum NotificationPayload<T> {
|
||||
MintQuoteBolt12Response(MintQuoteBolt12Response<T>),
|
||||
}
|
||||
|
||||
impl<T> From<ProofState> for NotificationPayload<T> {
|
||||
fn from(proof_state: ProofState) -> NotificationPayload<T> {
|
||||
NotificationPayload::ProofState(proof_state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<MeltQuoteBolt11Response<T>> for NotificationPayload<T> {
|
||||
fn from(melt_quote: MeltQuoteBolt11Response<T>) -> NotificationPayload<T> {
|
||||
NotificationPayload::MeltQuoteBolt11Response(melt_quote)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<MintQuoteBolt11Response<T>> for NotificationPayload<T> {
|
||||
fn from(mint_quote: MintQuoteBolt11Response<T>) -> NotificationPayload<T> {
|
||||
NotificationPayload::MintQuoteBolt11Response(mint_quote)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Hash, Serialize)]
|
||||
#[serde(bound = "T: Serialize + DeserializeOwned")]
|
||||
/// A parsed notification
|
||||
pub enum Notification {
|
||||
pub enum NotificationId<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// ProofState id is a Pubkey
|
||||
ProofState(PublicKey),
|
||||
/// MeltQuote id is an QuoteId
|
||||
MeltQuoteBolt11(QuoteId),
|
||||
MeltQuoteBolt11(T),
|
||||
/// MintQuote id is an QuoteId
|
||||
MintQuoteBolt11(QuoteId),
|
||||
MintQuoteBolt11(T),
|
||||
/// MintQuote id is an QuoteId
|
||||
MintQuoteBolt12(QuoteId),
|
||||
MintQuoteBolt12(T),
|
||||
/// MintQuote id is an QuoteId
|
||||
MeltQuoteBolt12(QuoteId),
|
||||
MeltQuoteBolt12(T),
|
||||
}
|
||||
|
||||
/// Kind
|
||||
@@ -187,7 +176,6 @@ impl<I> AsRef<I> for Params<I> {
|
||||
/// Parsing error
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[cfg(feature = "mint")]
|
||||
#[error("Uuid Error: {0}")]
|
||||
/// Uuid Error
|
||||
QuoteId(#[from] QuoteIdError),
|
||||
|
||||
@@ -36,7 +36,10 @@ pub struct WsUnsubscribeResponse<I> {
|
||||
/// subscription
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(bound = "T: Serialize + DeserializeOwned, I: Serialize + DeserializeOwned")]
|
||||
pub struct NotificationInner<T, I> {
|
||||
pub struct NotificationInner<T, I>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// The subscription ID
|
||||
#[serde(rename = "subId")]
|
||||
pub sub_id: I,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::ws::{CloseFrame, Message, WebSocket};
|
||||
use cdk::mint::QuoteId;
|
||||
use cdk::nuts::nut17::NotificationPayload;
|
||||
use cdk::pub_sub::SubId;
|
||||
use cdk::subscription::SubId;
|
||||
use cdk::ws::{
|
||||
notification_to_ws_message, NotificationInner, WsErrorBody, WsMessageOrResponse,
|
||||
WsMethodRequest, WsRequest,
|
||||
@@ -36,8 +37,8 @@ pub use error::WsError;
|
||||
|
||||
pub struct WsContext {
|
||||
state: MintState,
|
||||
subscriptions: HashMap<SubId, tokio::task::JoinHandle<()>>,
|
||||
publisher: mpsc::Sender<(SubId, NotificationPayload<QuoteId>)>,
|
||||
subscriptions: HashMap<Arc<SubId>, tokio::task::JoinHandle<()>>,
|
||||
publisher: mpsc::Sender<(Arc<SubId>, NotificationPayload<QuoteId>)>,
|
||||
}
|
||||
|
||||
/// Main function for websocket connections
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use cdk::subscription::{IndexableParams, Params};
|
||||
use cdk::subscription::Params;
|
||||
use cdk::ws::{WsResponseResult, WsSubscribeResponse};
|
||||
|
||||
use super::{WsContext, WsError};
|
||||
@@ -15,22 +15,20 @@ pub(crate) async fn handle(
|
||||
return Err(WsError::InvalidParams);
|
||||
}
|
||||
|
||||
let params: IndexableParams = params.into();
|
||||
|
||||
let mut subscription = context
|
||||
.state
|
||||
.mint
|
||||
.pubsub_manager()
|
||||
.try_subscribe(params)
|
||||
.await
|
||||
.subscribe(params)
|
||||
.map_err(|_| WsError::ParseError)?;
|
||||
|
||||
let publisher = context.publisher.clone();
|
||||
let sub_id_for_sender = sub_id.clone();
|
||||
context.subscriptions.insert(
|
||||
sub_id.clone(),
|
||||
tokio::spawn(async move {
|
||||
while let Some(response) = subscription.recv().await {
|
||||
let _ = publisher.send(response).await;
|
||||
let _ = publisher.try_send((sub_id_for_sender.clone(), response.into_inner()));
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -40,10 +40,16 @@ anyhow.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_with.workspace = true
|
||||
web-time.workspace = true
|
||||
tokio.workspace = true
|
||||
parking_lot = "0.12.5"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
uuid = { workspace = true, features = ["js"], optional = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
rand.workspace = true
|
||||
bip39.workspace = true
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
@@ -33,3 +33,5 @@ pub use cashu::nuts::{self, *};
|
||||
pub use cashu::quote_id::{self, *};
|
||||
pub use cashu::{dhke, ensure_cdk, mint_url, secret, util, SECP256K1};
|
||||
pub use error::Error;
|
||||
/// Re-export parking_lot for reuse
|
||||
pub use parking_lot;
|
||||
|
||||
44
crates/cdk-common/src/pub_sub/error.rs
Normal file
44
crates/cdk-common/src/pub_sub/error.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
//! Error types for the pub-sub module.
|
||||
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
/// Error
|
||||
pub enum Error {
|
||||
/// No subscription found
|
||||
#[error("Subscription not found")]
|
||||
NoSubscription,
|
||||
|
||||
/// Parsing error
|
||||
#[error("Parsing Error {0}")]
|
||||
ParsingError(String),
|
||||
|
||||
/// Internal error
|
||||
#[error("Internal")]
|
||||
Internal(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
/// Internal error
|
||||
#[error("Internal error {0}")]
|
||||
InternalStr(String),
|
||||
|
||||
/// Not supported
|
||||
#[error("Not supported")]
|
||||
NotSupported,
|
||||
|
||||
/// Channel is full
|
||||
#[error("Channel is full")]
|
||||
ChannelFull,
|
||||
|
||||
/// Channel is closed
|
||||
#[error("Channel is close")]
|
||||
ChannelClosed,
|
||||
}
|
||||
|
||||
impl<T> From<TrySendError<T>> for Error {
|
||||
fn from(value: TrySendError<T>) -> Self {
|
||||
match value {
|
||||
TrySendError::Closed(_) => Error::ChannelClosed,
|
||||
TrySendError::Full(_) => Error::ChannelFull,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,161 +0,0 @@
|
||||
//! WS Index
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use super::SubId;
|
||||
|
||||
/// Indexable trait
|
||||
pub trait Indexable {
|
||||
/// The type of the index, it is unknown and it is up to the Manager's
|
||||
/// generic type
|
||||
type Type: PartialOrd + Ord + Send + Sync + Debug;
|
||||
|
||||
/// To indexes
|
||||
fn to_indexes(&self) -> Vec<Index<Self::Type>>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone)]
|
||||
/// Index
|
||||
///
|
||||
/// The Index is a sorted structure that is used to quickly find matches
|
||||
///
|
||||
/// The counter is used to make sure each Index is unique, even if the prefix
|
||||
/// are the same, and also to make sure that earlier indexes matches first
|
||||
pub struct Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
prefix: T,
|
||||
counter: SubscriptionGlobalId,
|
||||
id: super::SubId,
|
||||
}
|
||||
|
||||
impl<T> From<&Index<T>> for super::SubId
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
fn from(val: &Index<T>) -> Self {
|
||||
val.id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.prefix
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
/// Compare the
|
||||
pub fn cmp_prefix(&self, other: &Index<T>) -> std::cmp::Ordering {
|
||||
self.prefix.cmp(&other.prefix)
|
||||
}
|
||||
|
||||
/// Returns a globally unique id for the Index
|
||||
pub fn unique_id(&self) -> usize {
|
||||
self.counter.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<(T, SubId, SubscriptionGlobalId)> for Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
fn from((prefix, id, counter): (T, SubId, SubscriptionGlobalId)) -> Self {
|
||||
Self {
|
||||
prefix,
|
||||
id,
|
||||
counter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<(T, SubId)> for Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
fn from((prefix, id): (T, SubId)) -> Self {
|
||||
Self {
|
||||
prefix,
|
||||
id,
|
||||
counter: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Index<T>
|
||||
where
|
||||
T: PartialOrd + Ord + Send + Sync + Debug,
|
||||
{
|
||||
fn from(prefix: T) -> Self {
|
||||
Self {
|
||||
prefix,
|
||||
id: Default::default(),
|
||||
counter: SubscriptionGlobalId(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
/// Dummy type
|
||||
///
|
||||
/// This is only use so each Index is unique, with the same prefix.
|
||||
///
|
||||
/// The prefix is used to leverage the BTree to find things quickly, but each
|
||||
/// entry/key must be unique, so we use this dummy type to make sure each Index
|
||||
/// is unique.
|
||||
///
|
||||
/// Unique is also used to make sure that the indexes are sorted by creation order
|
||||
#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct SubscriptionGlobalId(usize);
|
||||
|
||||
impl Default for SubscriptionGlobalId {
|
||||
fn default() -> Self {
|
||||
Self(COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_index_from_tuple() {
|
||||
let sub_id = SubId::from("test_sub_id");
|
||||
let prefix = "test_prefix";
|
||||
let index: Index<&str> = Index::from((prefix, sub_id.clone()));
|
||||
assert_eq!(index.prefix, "test_prefix");
|
||||
assert_eq!(index.id, sub_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_cmp_prefix() {
|
||||
let sub_id = SubId::from("test_sub_id");
|
||||
let index1: Index<&str> = Index::from(("a", sub_id.clone()));
|
||||
let index2: Index<&str> = Index::from(("b", sub_id.clone()));
|
||||
assert_eq!(index1.cmp_prefix(&index2), std::cmp::Ordering::Less);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_id_from_str() {
|
||||
let sub_id = SubId::from("test_sub_id");
|
||||
assert_eq!(sub_id.0, "test_sub_id");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_id_deref() {
|
||||
let sub_id = SubId::from("test_sub_id");
|
||||
assert_eq!(&*sub_id, "test_sub_id");
|
||||
}
|
||||
}
|
||||
@@ -1,77 +1,180 @@
|
||||
//! Publish–subscribe pattern.
|
||||
//! Publish/Subscribe core
|
||||
//!
|
||||
//! This is a generic implementation for
|
||||
//! [NUT-17](<https://github.com/cashubtc/nuts/blob/main/17.md>) with a type
|
||||
//! agnostic Publish-subscribe manager.
|
||||
//! This module defines the transport-agnostic pub/sub primitives used by both
|
||||
//! mint and wallet components. The design prioritizes:
|
||||
//!
|
||||
//! The manager has a method for subscribers to subscribe to events with a
|
||||
//! generic type that must be converted to a vector of indexes.
|
||||
//! - **Request coalescing**: multiple local subscribers to the same remote topic
|
||||
//! result in a single upstream subscription, with local fan‑out.
|
||||
//! - **Latest-on-subscribe** (NUT-17): on (re)subscription, the most recent event
|
||||
//! is fetched and delivered before streaming new ones.
|
||||
//! - **Backpressure-aware delivery**: bounded channels + drop policies prevent
|
||||
//! a slow consumer from stalling the whole pipeline.
|
||||
//! - **Resilience**: automatic reconnect with exponential backoff; WebSocket
|
||||
//! streaming when available, HTTP long-poll fallback otherwise.
|
||||
//!
|
||||
//! Events are also generic that should implement the `Indexable` trait.
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
//! Terms used throughout the module:
|
||||
//! - **Event**: a domain object that maps to one or more `Topic`s via `Event::get_topics`.
|
||||
//! - **Topic**: an index/type that defines storage and matching semantics.
|
||||
//! - **SubscriptionRequest**: a domain-specific filter that can be converted into
|
||||
//! low-level transport messages (e.g., WebSocket subscribe frames).
|
||||
//! - **Spec**: type bundle tying `Event`, `Topic`, `SubscriptionId`, and serialization.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
mod error;
|
||||
mod pubsub;
|
||||
pub mod remote_consumer;
|
||||
mod subscriber;
|
||||
mod types;
|
||||
|
||||
pub mod index;
|
||||
pub use self::error::Error;
|
||||
pub use self::pubsub::Pubsub;
|
||||
pub use self::subscriber::{Subscriber, SubscriptionRequest};
|
||||
pub use self::types::*;
|
||||
|
||||
/// Default size of the remove channel
|
||||
pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Default channel size for subscription buffering
|
||||
pub const DEFAULT_CHANNEL_SIZE: usize = 10;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
/// On New Subscription trait
|
||||
///
|
||||
/// This trait is optional and it is used to notify the application when a new
|
||||
/// subscription is created. This is useful when the application needs to send
|
||||
/// the initial state to the subscriber upon subscription
|
||||
pub trait OnNewSubscription {
|
||||
/// Index type
|
||||
type Index;
|
||||
/// Subscription event type
|
||||
type Event;
|
||||
use super::subscriber::SubscriptionRequest;
|
||||
use super::{Error, Event, Pubsub, Spec, Subscriber};
|
||||
|
||||
/// Called when a new subscription is created
|
||||
async fn on_new_subscription(
|
||||
&self,
|
||||
request: &[&Self::Index],
|
||||
) -> Result<Vec<Self::Event>, String>;
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Eq, PartialEq, Deserialize)]
|
||||
pub struct Message {
|
||||
pub foo: u64,
|
||||
pub bar: u64,
|
||||
}
|
||||
|
||||
/// Subscription Id wrapper
|
||||
///
|
||||
/// This is the place to add some sane default (like a max length) to the
|
||||
/// subscription ID
|
||||
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub struct SubId(String);
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)]
|
||||
pub enum IndexTest {
|
||||
Foo(u64),
|
||||
Bar(u64),
|
||||
}
|
||||
|
||||
impl From<&str> for SubId {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SubId {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SubId {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Self(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SubId {
|
||||
type Target = String;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
impl Event for Message {
|
||||
type Topic = IndexTest;
|
||||
|
||||
fn get_topics(&self) -> Vec<Self::Topic> {
|
||||
vec![IndexTest::Foo(self.foo), IndexTest::Bar(self.bar)]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CustomPubSub {
|
||||
pub storage: Arc<RwLock<HashMap<IndexTest, Message>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Spec for CustomPubSub {
|
||||
type Topic = IndexTest;
|
||||
|
||||
type Event = Message;
|
||||
|
||||
type SubscriptionId = String;
|
||||
|
||||
type Context = ();
|
||||
|
||||
fn new_instance(_context: Self::Context) -> Arc<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Arc::new(Self {
|
||||
storage: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_events(
|
||||
self: &Arc<Self>,
|
||||
topics: Vec<<Self::Event as Event>::Topic>,
|
||||
reply_to: Subscriber<Self>,
|
||||
) where
|
||||
Self: Sized,
|
||||
{
|
||||
let storage = self.storage.read().unwrap();
|
||||
|
||||
for index in topics {
|
||||
if let Some(value) = storage.get(&index) {
|
||||
let _ = reply_to.send(value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SubscriptionReq {
|
||||
Foo(u64),
|
||||
Bar(u64),
|
||||
}
|
||||
|
||||
impl SubscriptionRequest for SubscriptionReq {
|
||||
type Topic = IndexTest;
|
||||
|
||||
type SubscriptionId = String;
|
||||
|
||||
fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error> {
|
||||
Ok(vec![match self {
|
||||
SubscriptionReq::Bar(n) => IndexTest::Bar(*n),
|
||||
SubscriptionReq::Foo(n) => IndexTest::Foo(*n),
|
||||
}])
|
||||
}
|
||||
|
||||
fn subscription_name(&self) -> Arc<Self::SubscriptionId> {
|
||||
Arc::new("test".to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delivery_twice_realtime() {
|
||||
let pubsub = Pubsub::new(CustomPubSub::new_instance(()));
|
||||
|
||||
assert_eq!(pubsub.active_subscribers(), 0);
|
||||
|
||||
let mut subscriber = pubsub.subscribe(SubscriptionReq::Foo(2)).unwrap();
|
||||
|
||||
assert_eq!(pubsub.active_subscribers(), 1);
|
||||
|
||||
let _ = pubsub.publish_now(Message { foo: 2, bar: 1 });
|
||||
let _ = pubsub.publish_now(Message { foo: 2, bar: 2 });
|
||||
|
||||
assert_eq!(subscriber.recv().await.map(|x| x.bar), Some(1));
|
||||
assert_eq!(subscriber.recv().await.map(|x| x.bar), Some(2));
|
||||
assert!(subscriber.try_recv().is_none());
|
||||
|
||||
drop(subscriber);
|
||||
|
||||
assert_eq!(pubsub.active_subscribers(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_from_storage() {
|
||||
let x = CustomPubSub::new_instance(());
|
||||
let storage = x.storage.clone();
|
||||
|
||||
let pubsub = Pubsub::new(x);
|
||||
|
||||
{
|
||||
// set previous value
|
||||
let mut s = storage.write().unwrap();
|
||||
s.insert(IndexTest::Bar(2), Message { foo: 3, bar: 2 });
|
||||
}
|
||||
|
||||
let mut subscriber = pubsub.subscribe(SubscriptionReq::Bar(2)).unwrap();
|
||||
|
||||
// Just should receive the latest
|
||||
assert_eq!(subscriber.recv().await.map(|x| x.foo), Some(3));
|
||||
|
||||
// realtime delivery test
|
||||
let _ = pubsub.publish_now(Message { foo: 1, bar: 2 });
|
||||
assert_eq!(subscriber.recv().await.map(|x| x.foo), Some(1));
|
||||
|
||||
{
|
||||
// set previous value
|
||||
let mut s = storage.write().unwrap();
|
||||
s.insert(IndexTest::Bar(2), Message { foo: 1, bar: 2 });
|
||||
}
|
||||
|
||||
// new subscription should only get the latest state (it is up to the Topic trait)
|
||||
let mut y = pubsub.subscribe(SubscriptionReq::Bar(2)).unwrap();
|
||||
assert_eq!(y.recv().await.map(|x| x.foo), Some(1));
|
||||
}
|
||||
}
|
||||
|
||||
185
crates/cdk-common/src/pub_sub/pubsub.rs
Normal file
185
crates/cdk-common/src/pub_sub/pubsub.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
//! Pub-sub producer
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::subscriber::{ActiveSubscription, SubscriptionRequest};
|
||||
use super::{Error, Event, Spec, Subscriber};
|
||||
|
||||
/// Default channel size for subscription buffering
|
||||
pub const DEFAULT_CHANNEL_SIZE: usize = 10_000;
|
||||
|
||||
/// Subscriber Receiver
|
||||
pub type SubReceiver<S> = mpsc::Receiver<(Arc<<S as Spec>::SubscriptionId>, <S as Spec>::Event)>;
|
||||
|
||||
/// Internal Index Tree
|
||||
pub type TopicTree<T> = Arc<
|
||||
RwLock<
|
||||
BTreeMap<
|
||||
// Index with a subscription unique ID
|
||||
(<T as Spec>::Topic, usize),
|
||||
Subscriber<T>,
|
||||
>,
|
||||
>,
|
||||
>;
|
||||
|
||||
/// Manager
|
||||
pub struct Pubsub<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
inner: Arc<S>,
|
||||
listeners_topics: TopicTree<S>,
|
||||
unique_subscription_counter: AtomicUsize,
|
||||
active_subscribers: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl<S> Pubsub<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
/// Create a new instance
|
||||
pub fn new(inner: Arc<S>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
listeners_topics: Default::default(),
|
||||
unique_subscription_counter: 0.into(),
|
||||
active_subscribers: Arc::new(0.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Total number of active subscribers, it is not the number of active topics being subscribed
|
||||
pub fn active_subscribers(&self) -> usize {
|
||||
self.active_subscribers
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Publish an event to all listenrs
|
||||
#[inline(always)]
|
||||
fn publish_internal(event: S::Event, listeners_index: &TopicTree<S>) -> Result<(), Error> {
|
||||
let index_storage = listeners_index.read();
|
||||
|
||||
let mut sent = HashSet::new();
|
||||
for topic in event.get_topics() {
|
||||
for ((subscription_index, unique_id), sender) in
|
||||
index_storage.range((topic.clone(), 0)..)
|
||||
{
|
||||
if subscription_index.cmp(&topic) != Ordering::Equal {
|
||||
break;
|
||||
}
|
||||
if sent.contains(&unique_id) {
|
||||
continue;
|
||||
}
|
||||
sent.insert(unique_id);
|
||||
sender.send(event.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Broadcast an event to all listeners
|
||||
#[inline(always)]
|
||||
pub fn publish<E>(&self, event: E)
|
||||
where
|
||||
E: Into<S::Event>,
|
||||
{
|
||||
let topics = self.listeners_topics.clone();
|
||||
let event = event.into();
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
tokio::spawn(async move {
|
||||
let _ = Self::publish_internal(event, &topics);
|
||||
});
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
let _ = Self::publish_internal(event, &topics);
|
||||
});
|
||||
}
|
||||
|
||||
/// Broadcast an event to all listeners right away, blocking the current thread
|
||||
///
|
||||
/// This function takes an Arc to the storage struct, the event_id, the kind
|
||||
/// and the vent to broadcast
|
||||
#[inline(always)]
|
||||
pub fn publish_now<E>(&self, event: E) -> Result<(), Error>
|
||||
where
|
||||
E: Into<S::Event>,
|
||||
{
|
||||
let event = event.into();
|
||||
Self::publish_internal(event, &self.listeners_topics)
|
||||
}
|
||||
|
||||
/// Subscribe proving custom sender/receiver mpsc
|
||||
#[inline(always)]
|
||||
pub fn subscribe_with<I>(
|
||||
&self,
|
||||
request: I,
|
||||
sender: &mpsc::Sender<(Arc<I::SubscriptionId>, S::Event)>,
|
||||
receiver: Option<SubReceiver<S>>,
|
||||
) -> Result<ActiveSubscription<S>, Error>
|
||||
where
|
||||
I: SubscriptionRequest<
|
||||
Topic = <S::Event as Event>::Topic,
|
||||
SubscriptionId = S::SubscriptionId,
|
||||
>,
|
||||
{
|
||||
let subscription_name = request.subscription_name();
|
||||
let sender = Subscriber::new(subscription_name.clone(), sender);
|
||||
let mut index_storage = self.listeners_topics.write();
|
||||
let subscription_internal_id = self
|
||||
.unique_subscription_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
self.active_subscribers
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let subscribed_to = request.try_get_topics()?;
|
||||
|
||||
for index in subscribed_to.iter() {
|
||||
index_storage.insert((index.clone(), subscription_internal_id), sender.clone());
|
||||
}
|
||||
drop(index_storage);
|
||||
|
||||
let inner = self.inner.clone();
|
||||
let subscribed_to_for_spawn = subscribed_to.clone();
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
tokio::spawn(async move {
|
||||
// TODO: Ignore topics broadcasted from fetch_events _if_ any real time has been broadcasted already.
|
||||
inner.fetch_events(subscribed_to_for_spawn, sender).await;
|
||||
});
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
inner.fetch_events(subscribed_to_for_spawn, sender).await;
|
||||
});
|
||||
|
||||
Ok(ActiveSubscription::new(
|
||||
subscription_internal_id,
|
||||
subscription_name,
|
||||
self.active_subscribers.clone(),
|
||||
self.listeners_topics.clone(),
|
||||
subscribed_to,
|
||||
receiver,
|
||||
))
|
||||
}
|
||||
|
||||
/// Subscribe
|
||||
pub fn subscribe<I>(&self, request: I) -> Result<ActiveSubscription<S>, Error>
|
||||
where
|
||||
I: SubscriptionRequest<
|
||||
Topic = <S::Event as Event>::Topic,
|
||||
SubscriptionId = S::SubscriptionId,
|
||||
>,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
|
||||
self.subscribe_with(request, &sender, Some(receiver))
|
||||
}
|
||||
}
|
||||
885
crates/cdk-common/src/pub_sub/remote_consumer.rs
Normal file
885
crates/cdk-common/src/pub_sub/remote_consumer.rs
Normal file
@@ -0,0 +1,885 @@
|
||||
//! Pub-sub consumer
|
||||
//!
|
||||
//! Consumers are designed to connect to a producer, through a transport, and subscribe to events.
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{sleep, Instant};
|
||||
|
||||
use super::subscriber::{ActiveSubscription, SubscriptionRequest};
|
||||
use super::{Error, Event, Pubsub, Spec};
|
||||
|
||||
const STREAM_CONNECTION_BACKOFF: Duration = Duration::from_millis(2_000);
|
||||
|
||||
const STREAM_CONNECTION_MAX_BACKOFF: Duration = Duration::from_millis(30_000);
|
||||
|
||||
const INTERNAL_POLL_SIZE: usize = 1_000;
|
||||
|
||||
const POLL_SLEEP: Duration = Duration::from_millis(2_000);
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_futures;
|
||||
|
||||
struct UniqueSubscription<S>
|
||||
where
|
||||
S: Spec,
|
||||
{
|
||||
name: S::SubscriptionId,
|
||||
total_subscribers: usize,
|
||||
}
|
||||
|
||||
type UniqueSubscriptions<S> = RwLock<HashMap<<S as Spec>::Topic, UniqueSubscription<S>>>;
|
||||
|
||||
type ActiveSubscriptions<S> =
|
||||
RwLock<HashMap<Arc<<S as Spec>::SubscriptionId>, Vec<<S as Spec>::Topic>>>;
|
||||
|
||||
type CacheEvent<S> = HashMap<<<S as Spec>::Event as Event>::Topic, <S as Spec>::Event>;
|
||||
|
||||
/// Subscription consumer
|
||||
pub struct Consumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
transport: T,
|
||||
inner_pubsub: Arc<Pubsub<T::Spec>>,
|
||||
remote_subscriptions: UniqueSubscriptions<T::Spec>,
|
||||
subscriptions: ActiveSubscriptions<T::Spec>,
|
||||
stream_ctrl: RwLock<Option<mpsc::Sender<StreamCtrl<T::Spec>>>>,
|
||||
still_running: AtomicBool,
|
||||
prefer_polling: bool,
|
||||
/// Cached events
|
||||
///
|
||||
/// The cached events are useful to share events. The cache is automatically evicted it is
|
||||
/// disconnected from the remote source, meaning the cache is only active while there is an
|
||||
/// active subscription to the remote source, and it remembers the latest event.
|
||||
cached_events: Arc<RwLock<CacheEvent<T::Spec>>>,
|
||||
}
|
||||
|
||||
/// Remote consumer
|
||||
pub struct RemoteActiveConsumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
inner: ActiveSubscription<T::Spec>,
|
||||
previous_messages: VecDeque<<T::Spec as Spec>::Event>,
|
||||
consumer: Arc<Consumer<T>>,
|
||||
}
|
||||
|
||||
impl<T> RemoteActiveConsumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
/// Receives the next event
|
||||
pub async fn recv(&mut self) -> Option<<T::Spec as Spec>::Event> {
|
||||
if let Some(event) = self.previous_messages.pop_front() {
|
||||
Some(event)
|
||||
} else {
|
||||
self.inner.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
/// Try receive an event or return Noen right away
|
||||
pub fn try_recv(&mut self) -> Option<<T::Spec as Spec>::Event> {
|
||||
if let Some(event) = self.previous_messages.pop_front() {
|
||||
Some(event)
|
||||
} else {
|
||||
self.inner.try_recv()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the subscription name
|
||||
pub fn name(&self) -> &<T::Spec as Spec>::SubscriptionId {
|
||||
self.inner.name()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for RemoteActiveConsumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
let _ = self.consumer.unsubscribe(self.name().clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct to relay events from Poll and Streams from the external subscription to the local
|
||||
/// subscribers
|
||||
pub struct InternalRelay<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
inner: Arc<Pubsub<S>>,
|
||||
cached_events: Arc<RwLock<CacheEvent<S>>>,
|
||||
}
|
||||
|
||||
impl<S> InternalRelay<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
/// Relay a remote event locally
|
||||
pub fn send<X>(&self, event: X)
|
||||
where
|
||||
X: Into<S::Event>,
|
||||
{
|
||||
let event = event.into();
|
||||
let mut cached_events = self.cached_events.write();
|
||||
|
||||
for topic in event.get_topics() {
|
||||
cached_events.insert(topic, event.clone());
|
||||
}
|
||||
|
||||
self.inner.publish(event);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Consumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
/// Creates a new instance
|
||||
pub fn new(
|
||||
transport: T,
|
||||
prefer_polling: bool,
|
||||
context: <T::Spec as Spec>::Context,
|
||||
) -> Arc<Self> {
|
||||
let this = Arc::new(Self {
|
||||
transport,
|
||||
prefer_polling,
|
||||
inner_pubsub: Arc::new(Pubsub::new(T::Spec::new_instance(context))),
|
||||
subscriptions: Default::default(),
|
||||
remote_subscriptions: Default::default(),
|
||||
stream_ctrl: RwLock::new(None),
|
||||
cached_events: Default::default(),
|
||||
still_running: true.into(),
|
||||
});
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
tokio::spawn(Self::stream(this.clone()));
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
wasm_bindgen_futures::spawn_local(Self::stream(this.clone()));
|
||||
|
||||
this
|
||||
}
|
||||
|
||||
async fn stream(instance: Arc<Self>) {
|
||||
let mut stream_supported = true;
|
||||
let mut poll_supported = true;
|
||||
|
||||
let mut backoff = STREAM_CONNECTION_BACKOFF;
|
||||
let mut retry_at = None;
|
||||
|
||||
loop {
|
||||
if (!stream_supported && !poll_supported)
|
||||
|| !instance
|
||||
.still_running
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if instance.remote_subscriptions.read().is_empty() {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if stream_supported
|
||||
&& !instance.prefer_polling
|
||||
&& retry_at
|
||||
.map(|retry_at| retry_at < Instant::now())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel(INTERNAL_POLL_SIZE);
|
||||
|
||||
{
|
||||
*instance.stream_ctrl.write() = Some(sender);
|
||||
}
|
||||
|
||||
let current_subscriptions = {
|
||||
instance
|
||||
.remote_subscriptions
|
||||
.read()
|
||||
.iter()
|
||||
.map(|(key, name)| (name.name.clone(), key.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if let Err(err) = instance
|
||||
.transport
|
||||
.stream(
|
||||
receiver,
|
||||
current_subscriptions,
|
||||
InternalRelay {
|
||||
inner: instance.inner_pubsub.clone(),
|
||||
cached_events: instance.cached_events.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
retry_at = Some(Instant::now() + backoff);
|
||||
backoff =
|
||||
(backoff + STREAM_CONNECTION_BACKOFF).min(STREAM_CONNECTION_MAX_BACKOFF);
|
||||
|
||||
if matches!(err, Error::NotSupported) {
|
||||
stream_supported = false;
|
||||
}
|
||||
tracing::error!("Long connection failed with error {:?}", err);
|
||||
} else {
|
||||
backoff = STREAM_CONNECTION_BACKOFF;
|
||||
}
|
||||
|
||||
// remove sender to stream, as there is no stream
|
||||
let _ = instance.stream_ctrl.write().take();
|
||||
}
|
||||
|
||||
if poll_supported {
|
||||
let current_subscriptions = {
|
||||
instance
|
||||
.remote_subscriptions
|
||||
.read()
|
||||
.iter()
|
||||
.map(|(key, name)| (name.name.clone(), key.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if let Err(err) = instance
|
||||
.transport
|
||||
.poll(
|
||||
current_subscriptions,
|
||||
InternalRelay {
|
||||
inner: instance.inner_pubsub.clone(),
|
||||
cached_events: instance.cached_events.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
if matches!(err, Error::NotSupported) {
|
||||
poll_supported = false;
|
||||
}
|
||||
tracing::error!("Polling failed with error {:?}", err);
|
||||
}
|
||||
|
||||
sleep(POLL_SLEEP).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unsubscribe from a topic, this is called automatically when RemoteActiveSubscription<T> goes
|
||||
/// out of scope
|
||||
fn unsubscribe(
|
||||
self: &Arc<Self>,
|
||||
subscription_name: <T::Spec as Spec>::SubscriptionId,
|
||||
) -> Result<(), Error> {
|
||||
let topics = self
|
||||
.subscriptions
|
||||
.write()
|
||||
.remove(&subscription_name)
|
||||
.ok_or(Error::NoSubscription)?;
|
||||
|
||||
let mut remote_subscriptions = self.remote_subscriptions.write();
|
||||
|
||||
for topic in topics {
|
||||
let mut remote_subscription =
|
||||
if let Some(remote_subscription) = remote_subscriptions.remove(&topic) {
|
||||
remote_subscription
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
remote_subscription.total_subscribers = remote_subscription
|
||||
.total_subscribers
|
||||
.checked_sub(1)
|
||||
.unwrap_or_default();
|
||||
|
||||
if remote_subscription.total_subscribers == 0 {
|
||||
let mut cached_events = self.cached_events.write();
|
||||
|
||||
cached_events.remove(&topic);
|
||||
|
||||
self.message_to_stream(StreamCtrl::Unsubscribe(remote_subscription.name.clone()))?;
|
||||
} else {
|
||||
remote_subscriptions.insert(topic, remote_subscription);
|
||||
}
|
||||
}
|
||||
|
||||
if remote_subscriptions.is_empty() {
|
||||
self.message_to_stream(StreamCtrl::Stop)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn message_to_stream(&self, message: StreamCtrl<T::Spec>) -> Result<(), Error> {
|
||||
let to_stream = self.stream_ctrl.read();
|
||||
|
||||
if let Some(to_stream) = to_stream.as_ref() {
|
||||
Ok(to_stream.try_send(message)?)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a subscription
|
||||
///
|
||||
/// The subscriptions have two parts:
|
||||
///
|
||||
/// 1. Will create the subscription to the remote Pubsub service, Any events will be moved to
|
||||
/// the internal pubsub
|
||||
///
|
||||
/// 2. The internal subscription to the inner Pubsub. Because all subscriptions are going the
|
||||
/// transport, once events matches subscriptions, the inner_pubsub will receive the message and
|
||||
/// broadcasat the event.
|
||||
pub fn subscribe<I>(self: &Arc<Self>, request: I) -> Result<RemoteActiveConsumer<T>, Error>
|
||||
where
|
||||
I: SubscriptionRequest<
|
||||
Topic = <T::Spec as Spec>::Topic,
|
||||
SubscriptionId = <T::Spec as Spec>::SubscriptionId,
|
||||
>,
|
||||
{
|
||||
let subscription_name = request.subscription_name();
|
||||
let topics = request.try_get_topics()?;
|
||||
|
||||
let mut remote_subscriptions = self.remote_subscriptions.write();
|
||||
let mut subscriptions = self.subscriptions.write();
|
||||
|
||||
if subscriptions.get(&subscription_name).is_some() {
|
||||
return Err(Error::NoSubscription);
|
||||
}
|
||||
|
||||
let mut previous_messages = Vec::new();
|
||||
let cached_events = self.cached_events.read();
|
||||
|
||||
for topic in topics.iter() {
|
||||
if let Some(subscription) = remote_subscriptions.get_mut(topic) {
|
||||
subscription.total_subscribers += 1;
|
||||
|
||||
if let Some(v) = cached_events.get(topic).cloned() {
|
||||
previous_messages.push(v);
|
||||
}
|
||||
} else {
|
||||
let internal_sub_name = self.transport.new_name();
|
||||
remote_subscriptions.insert(
|
||||
topic.clone(),
|
||||
UniqueSubscription {
|
||||
total_subscribers: 1,
|
||||
name: internal_sub_name.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
// new subscription is created, so the connection worker should be notified
|
||||
self.message_to_stream(StreamCtrl::Subscribe((internal_sub_name, topic.clone())))?;
|
||||
}
|
||||
}
|
||||
|
||||
subscriptions.insert(subscription_name, topics);
|
||||
drop(subscriptions);
|
||||
|
||||
Ok(RemoteActiveConsumer {
|
||||
inner: self.inner_pubsub.subscribe(request)?,
|
||||
previous_messages: previous_messages.into(),
|
||||
consumer: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Consumer<T>
|
||||
where
|
||||
T: Transport + 'static,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.still_running
|
||||
.store(false, std::sync::atomic::Ordering::Release);
|
||||
if let Some(to_stream) = self.stream_ctrl.read().as_ref() {
|
||||
let _ = to_stream.try_send(StreamCtrl::Stop).inspect_err(|err| {
|
||||
tracing::error!("Failed to send message LongPoll::Stop due to {err:?}")
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe Message
|
||||
pub type SubscribeMessage<S> = (<S as Spec>::SubscriptionId, <S as Spec>::Topic);
|
||||
|
||||
/// Messages sent from the [`Consumer`] to the [`Transport`] background loop.
|
||||
pub enum StreamCtrl<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
/// Add a subscription
|
||||
Subscribe(SubscribeMessage<S>),
|
||||
/// Desuscribe
|
||||
Unsubscribe(S::SubscriptionId),
|
||||
/// Exit the loop
|
||||
Stop,
|
||||
}
|
||||
|
||||
impl<S> Clone for StreamCtrl<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Subscribe(s) => Self::Subscribe(s.clone()),
|
||||
Self::Unsubscribe(u) => Self::Unsubscribe(u.clone()),
|
||||
Self::Stop => Self::Stop,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Transport abstracts how the consumer talks to the remote pubsub.
|
||||
///
|
||||
/// Implement this on your HTTP/WebSocket client. The transport is responsible for:
|
||||
/// - creating unique subscription names,
|
||||
/// - keeping a long connection via `stream` **or** performing on-demand `poll`,
|
||||
/// - forwarding remote events to `InternalRelay`.
|
||||
///
|
||||
/// ```ignore
|
||||
/// struct WsTransport { /* ... */ }
|
||||
/// #[async_trait::async_trait]
|
||||
/// impl Transport for WsTransport {
|
||||
/// type Topic = MyTopic;
|
||||
/// fn new_name(&self) -> <Self::Topic as Topic>::SubscriptionName { 0 }
|
||||
/// async fn stream(/* ... */) -> Result<(), Error> { Ok(()) }
|
||||
/// async fn poll(/* ... */) -> Result<(), Error> { Ok(()) }
|
||||
/// }
|
||||
/// ```
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
|
||||
pub trait Transport: Send + Sync {
|
||||
/// Spec
|
||||
type Spec: Spec;
|
||||
|
||||
/// Create a new subscription name
|
||||
fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId;
|
||||
|
||||
/// Opens a persistent connection and continuously streams events.
|
||||
/// For protocols that support server push (e.g. WebSocket, SSE).
|
||||
async fn stream(
|
||||
&self,
|
||||
subscribe_changes: mpsc::Receiver<StreamCtrl<Self::Spec>>,
|
||||
topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), Error>;
|
||||
|
||||
/// Performs a one-shot fetch of any currently available events.
|
||||
/// Called repeatedly by the consumer when streaming is not available.
|
||||
async fn poll(
|
||||
&self,
|
||||
topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
use super::{
|
||||
InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport,
|
||||
INTERNAL_POLL_SIZE,
|
||||
};
|
||||
use crate::pub_sub::remote_consumer::Consumer;
|
||||
use crate::pub_sub::test::{CustomPubSub, IndexTest, Message};
|
||||
use crate::pub_sub::{Error, Spec, SubscriptionRequest};
|
||||
|
||||
// ===== Test Event/Topic types =====
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum SubscriptionReq {
|
||||
Foo(String, u64),
|
||||
Bar(String, u64),
|
||||
}
|
||||
|
||||
impl SubscriptionRequest for SubscriptionReq {
|
||||
type Topic = IndexTest;
|
||||
|
||||
type SubscriptionId = String;
|
||||
|
||||
fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error> {
|
||||
Ok(vec![match self {
|
||||
SubscriptionReq::Foo(_, n) => IndexTest::Foo(*n),
|
||||
SubscriptionReq::Bar(_, n) => IndexTest::Bar(*n),
|
||||
}])
|
||||
}
|
||||
|
||||
fn subscription_name(&self) -> Arc<Self::SubscriptionId> {
|
||||
Arc::new(match self {
|
||||
SubscriptionReq::Foo(n, _) => n.to_string(),
|
||||
SubscriptionReq::Bar(n, _) => n.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===== A controllable in-memory Transport used by tests =====
|
||||
|
||||
/// TestTransport relays messages from a broadcast channel to the Consumer via `InternalRelay`.
|
||||
/// It also forwards Subscribe/Unsubscribe/Stop signals to an observer channel so tests can assert them.
|
||||
struct TestTransport {
|
||||
name_ctr: AtomicUsize,
|
||||
// We forward all transport-loop control messages here so tests can observe them.
|
||||
observe_ctrl_tx: mpsc::Sender<StreamCtrl<CustomPubSub>>,
|
||||
// Whether stream / poll are supported.
|
||||
support_long: bool,
|
||||
support_poll: bool,
|
||||
rx: Mutex<mpsc::Receiver<Message>>,
|
||||
}
|
||||
|
||||
impl TestTransport {
|
||||
fn new(
|
||||
support_long: bool,
|
||||
support_poll: bool,
|
||||
) -> (
|
||||
Self,
|
||||
mpsc::Sender<Message>,
|
||||
mpsc::Receiver<StreamCtrl<CustomPubSub>>,
|
||||
) {
|
||||
let (events_tx, rx) = mpsc::channel::<Message>(INTERNAL_POLL_SIZE);
|
||||
let (observe_ctrl_tx, observe_ctrl_rx) =
|
||||
mpsc::channel::<StreamCtrl<_>>(INTERNAL_POLL_SIZE);
|
||||
|
||||
let t = TestTransport {
|
||||
name_ctr: AtomicUsize::new(1),
|
||||
rx: Mutex::new(rx),
|
||||
observe_ctrl_tx,
|
||||
support_long,
|
||||
support_poll,
|
||||
};
|
||||
|
||||
(t, events_tx, observe_ctrl_rx)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for TestTransport {
|
||||
type Spec = CustomPubSub;
|
||||
|
||||
fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId {
|
||||
format!("sub-{}", self.name_ctr.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
mut subscribe_changes: mpsc::Receiver<StreamCtrl<Self::Spec>>,
|
||||
topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), Error> {
|
||||
if !self.support_long {
|
||||
return Err(Error::NotSupported);
|
||||
}
|
||||
|
||||
// Each invocation creates a fresh broadcast receiver
|
||||
let mut rx = self.rx.lock().await;
|
||||
let observe = self.observe_ctrl_tx.clone();
|
||||
|
||||
for topic in topics {
|
||||
observe.try_send(StreamCtrl::Subscribe(topic)).unwrap();
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward any control (Subscribe/Unsubscribe/Stop) messages so the test can assert them.
|
||||
Some(ctrl) = subscribe_changes.recv() => {
|
||||
observe.try_send(ctrl.clone()).unwrap();
|
||||
if matches!(ctrl, StreamCtrl::Stop) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Relay external events into the inner pubsub
|
||||
Some(msg) = rx.recv() => {
|
||||
reply_to.send(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn poll(
|
||||
&self,
|
||||
_topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), Error> {
|
||||
if !self.support_poll {
|
||||
return Err(Error::NotSupported);
|
||||
}
|
||||
|
||||
// On each poll call, drain anything currently pending and return.
|
||||
// (The Consumer calls this repeatedly; first call happens immediately.)
|
||||
let mut rx = self.rx.lock().await;
|
||||
// Non-blocking drain pass: try a few times without sleeping to keep tests snappy
|
||||
for _ in 0..32 {
|
||||
match rx.try_recv() {
|
||||
Ok(msg) => reply_to.send(msg),
|
||||
Err(mpsc::error::TryRecvError::Empty) => continue,
|
||||
Err(mpsc::error::TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Helpers =====
|
||||
|
||||
async fn recv_next<T: Transport>(
|
||||
sub: &mut RemoteActiveConsumer<T>,
|
||||
dur_ms: u64,
|
||||
) -> Option<<T::Spec as Spec>::Event> {
|
||||
timeout(Duration::from_millis(dur_ms), sub.recv())
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
|
||||
async fn expect_ctrl(
|
||||
rx: &mut mpsc::Receiver<StreamCtrl<CustomPubSub>>,
|
||||
dur_ms: u64,
|
||||
pred: impl Fn(&StreamCtrl<CustomPubSub>) -> bool,
|
||||
) -> StreamCtrl<CustomPubSub> {
|
||||
timeout(Duration::from_millis(dur_ms), async {
|
||||
loop {
|
||||
if let Some(msg) = rx.recv().await {
|
||||
if pred(&msg) {
|
||||
break msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("timed out waiting for control message")
|
||||
}
|
||||
|
||||
// ===== Tests =====
|
||||
|
||||
#[tokio::test]
|
||||
async fn stream_delivery_and_unsubscribe_on_drop() {
|
||||
// stream supported, poll supported (doesn't matter; prefer long)
|
||||
let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
|
||||
|
||||
// prefer_polling = false so connection loop will try stream first.
|
||||
let consumer = Consumer::new(transport, false, ());
|
||||
|
||||
// Subscribe to Foo(7)
|
||||
let mut sub = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t".to_owned(), 7))
|
||||
.expect("subscribe ok");
|
||||
|
||||
// We should see a Subscribe(name, topic) forwarded to transport
|
||||
let ctrl = expect_ctrl(
|
||||
&mut ctrl_rx,
|
||||
1000,
|
||||
|m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)),
|
||||
)
|
||||
.await;
|
||||
match ctrl {
|
||||
StreamCtrl::Subscribe((name, idx)) => {
|
||||
assert_ne!(name, "t".to_owned());
|
||||
assert_eq!(idx, IndexTest::Foo(7));
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// Send an event that matches Foo(7)
|
||||
events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap();
|
||||
let got = recv_next::<TestTransport>(&mut sub, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 1 });
|
||||
|
||||
// Dropping the RemoteActiveConsumer should trigger an Unsubscribe(name)
|
||||
drop(sub);
|
||||
let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| {
|
||||
matches!(m, StreamCtrl::Unsubscribe(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
// Drop the Consumer -> Stop is sent so the transport loop exits cleanly
|
||||
drop(consumer);
|
||||
let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_and_invalation() {
|
||||
// stream supported, poll supported (doesn't matter; prefer long)
|
||||
let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
|
||||
|
||||
// prefer_polling = false so connection loop will try stream first.
|
||||
let consumer = Consumer::new(transport, false, ());
|
||||
|
||||
// Subscribe to Foo(7)
|
||||
let mut sub_1 = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t".to_owned(), 7))
|
||||
.expect("subscribe ok");
|
||||
|
||||
// We should see a Subscribe(name, topic) forwarded to transport
|
||||
let ctrl = expect_ctrl(
|
||||
&mut ctrl_rx,
|
||||
1000,
|
||||
|m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)),
|
||||
)
|
||||
.await;
|
||||
match ctrl {
|
||||
StreamCtrl::Subscribe((name, idx)) => {
|
||||
assert_ne!(name, "t1".to_owned());
|
||||
assert_eq!(idx, IndexTest::Foo(7));
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// Send an event that matches Foo(7)
|
||||
events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap();
|
||||
let got = recv_next::<TestTransport>(&mut sub_1, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 1 });
|
||||
|
||||
// Subscribe to Foo(7), should receive the latest message and future messages
|
||||
let mut sub_2 = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t2".to_owned(), 7))
|
||||
.expect("subscribe ok");
|
||||
|
||||
let got = recv_next::<TestTransport>(&mut sub_2, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 1 });
|
||||
|
||||
// Dropping the RemoteActiveConsumer but not unsubscribe, since sub_2 is still active
|
||||
drop(sub_1);
|
||||
|
||||
// Subscribe to Foo(7), should receive the latest message and future messages
|
||||
let mut sub_3 = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t3".to_owned(), 7))
|
||||
.expect("subscribe ok");
|
||||
|
||||
// receive cache message
|
||||
let got = recv_next::<TestTransport>(&mut sub_3, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 1 });
|
||||
|
||||
// Send an event that matches Foo(7)
|
||||
events_tx.send(Message { foo: 7, bar: 2 }).await.unwrap();
|
||||
|
||||
// receive new message
|
||||
let got = recv_next::<TestTransport>(&mut sub_2, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 2 });
|
||||
|
||||
let got = recv_next::<TestTransport>(&mut sub_3, 1000)
|
||||
.await
|
||||
.expect("got event");
|
||||
assert_eq!(got, Message { foo: 7, bar: 2 });
|
||||
|
||||
drop(sub_2);
|
||||
drop(sub_3);
|
||||
|
||||
let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| {
|
||||
matches!(m, StreamCtrl::Unsubscribe(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
// The cache should be dropped, so no new messages
|
||||
let mut sub_4 = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t4".to_owned(), 7))
|
||||
.expect("subscribe ok");
|
||||
|
||||
assert!(
|
||||
recv_next::<TestTransport>(&mut sub_4, 1000).await.is_none(),
|
||||
"Should have not receive any update"
|
||||
);
|
||||
|
||||
drop(sub_4);
|
||||
|
||||
// Drop the Consumer -> Stop is sent so the transport loop exits cleanly
|
||||
let _ = expect_ctrl(&mut ctrl_rx, 2000, |m| matches!(m, StreamCtrl::Stop)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn falls_back_to_poll_when_stream_not_supported() {
|
||||
// stream NOT supported, poll supported
|
||||
let (transport, events_tx, _) = TestTransport::new(false, true);
|
||||
// prefer_polling = true nudges the connection loop to poll first, but even if it
|
||||
// tried stream, our transport returns NotSupported and the loop will use poll.
|
||||
let consumer = Consumer::new(transport, true, ());
|
||||
|
||||
// Subscribe to Bar(5)
|
||||
let mut sub = consumer
|
||||
.subscribe(SubscriptionReq::Bar("t".to_owned(), 5))
|
||||
.expect("subscribe ok");
|
||||
|
||||
// Inject an event; the poll path should relay it on the first poll iteration
|
||||
events_tx.send(Message { foo: 9, bar: 5 }).await.unwrap();
|
||||
let got = recv_next::<TestTransport>(&mut sub, 1500)
|
||||
.await
|
||||
.expect("event relayed via polling");
|
||||
assert_eq!(got, Message { foo: 9, bar: 5 });
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiple_subscribers_share_single_remote_subscription() {
|
||||
// This validates the "coalescing" behavior in Consumer::subscribe where multiple local
|
||||
// subscribers to the same Topic should only create one remote subscription.
|
||||
let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
|
||||
let consumer = Consumer::new(transport, false, ());
|
||||
|
||||
// Two local subscriptions to the SAME topic/name pair (different names)
|
||||
let mut a = consumer
|
||||
.subscribe(SubscriptionReq::Foo("t".to_owned(), 1))
|
||||
.expect("subscribe A");
|
||||
let _ = expect_ctrl(
|
||||
&mut ctrl_rx,
|
||||
1000,
|
||||
|m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(1)),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut b = consumer
|
||||
.subscribe(SubscriptionReq::Foo("b".to_owned(), 1))
|
||||
.expect("subscribe B");
|
||||
|
||||
// No second Subscribe should be forwarded for the same topic (coalesced).
|
||||
// Give a little time; if one appears, we'll fail explicitly.
|
||||
if let Ok(Some(StreamCtrl::Subscribe((_, idx)))) =
|
||||
timeout(Duration::from_millis(400), ctrl_rx.recv()).await
|
||||
{
|
||||
assert_ne!(idx, IndexTest::Foo(1), "should not resubscribe same topic");
|
||||
}
|
||||
|
||||
// Send one event and ensure BOTH local subscribers receive it.
|
||||
events_tx.send(Message { foo: 1, bar: 42 }).await.unwrap();
|
||||
let got_a = recv_next::<TestTransport>(&mut a, 1000)
|
||||
.await
|
||||
.expect("A got");
|
||||
let got_b = recv_next::<TestTransport>(&mut b, 1000)
|
||||
.await
|
||||
.expect("B got");
|
||||
assert_eq!(got_a, Message { foo: 1, bar: 42 });
|
||||
assert_eq!(got_b, Message { foo: 1, bar: 42 });
|
||||
|
||||
// Drop B: no Unsubscribe should be sent yet (still one local subscriber).
|
||||
drop(b);
|
||||
if let Ok(Some(StreamCtrl::Unsubscribe(_))) =
|
||||
timeout(Duration::from_millis(400), ctrl_rx.recv()).await
|
||||
{
|
||||
panic!("Should NOT unsubscribe while another local subscriber exists");
|
||||
}
|
||||
|
||||
// Drop A: now remote unsubscribe should occur.
|
||||
drop(a);
|
||||
let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| {
|
||||
matches!(m, StreamCtrl::Unsubscribe(_))
|
||||
})
|
||||
.await;
|
||||
|
||||
let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await;
|
||||
}
|
||||
}
|
||||
159
crates/cdk-common/src/pub_sub/subscriber.rs
Normal file
159
crates/cdk-common/src/pub_sub/subscriber.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
//! Active subscription
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::pubsub::{SubReceiver, TopicTree};
|
||||
use super::{Error, Spec};
|
||||
|
||||
/// Subscription request
|
||||
pub trait SubscriptionRequest {
|
||||
/// Topics
|
||||
type Topic;
|
||||
|
||||
/// Subscription Id
|
||||
type SubscriptionId;
|
||||
|
||||
/// Try to get topics from the request
|
||||
fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error>;
|
||||
|
||||
/// Get the subscription name
|
||||
fn subscription_name(&self) -> Arc<Self::SubscriptionId>;
|
||||
}
|
||||
|
||||
/// Active Subscription
|
||||
pub struct ActiveSubscription<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
id: usize,
|
||||
name: Arc<S::SubscriptionId>,
|
||||
active_subscribers: Arc<AtomicUsize>,
|
||||
topics: TopicTree<S>,
|
||||
subscribed_to: Vec<S::Topic>,
|
||||
receiver: Option<SubReceiver<S>>,
|
||||
}
|
||||
|
||||
impl<S> ActiveSubscription<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
/// Creates a new instance
|
||||
pub fn new(
|
||||
id: usize,
|
||||
name: Arc<S::SubscriptionId>,
|
||||
active_subscribers: Arc<AtomicUsize>,
|
||||
topics: TopicTree<S>,
|
||||
subscribed_to: Vec<S::Topic>,
|
||||
receiver: Option<SubReceiver<S>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name,
|
||||
active_subscribers,
|
||||
subscribed_to,
|
||||
topics,
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
|
||||
/// Receives the next event
|
||||
pub async fn recv(&mut self) -> Option<S::Event> {
|
||||
self.receiver.as_mut()?.recv().await.map(|(_, event)| event)
|
||||
}
|
||||
|
||||
/// Try receive an event or return Noen right away
|
||||
pub fn try_recv(&mut self) -> Option<S::Event> {
|
||||
self.receiver
|
||||
.as_mut()?
|
||||
.try_recv()
|
||||
.ok()
|
||||
.map(|(_, event)| event)
|
||||
}
|
||||
|
||||
/// Get the subscription name
|
||||
pub fn name(&self) -> &S::SubscriptionId {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Drop for ActiveSubscription<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
// remove the listener
|
||||
let mut topics = self.topics.write();
|
||||
for index in self.subscribed_to.drain(..) {
|
||||
topics.remove(&(index, self.id));
|
||||
}
|
||||
|
||||
// decrement the number of active subscribers
|
||||
self.active_subscribers
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Lightweight sink used by producers to send events to subscribers.
|
||||
///
|
||||
/// You usually do not construct a `Subscriber` directly — it is provided to you in
|
||||
/// the [`Spec::fetch_events`] callback so you can backfill a new subscription.
|
||||
#[derive(Debug)]
|
||||
pub struct Subscriber<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
subscription: Arc<S::SubscriptionId>,
|
||||
inner: mpsc::Sender<(Arc<S::SubscriptionId>, S::Event)>,
|
||||
latest: Arc<Mutex<Option<S::Event>>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for Subscriber<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
subscription: self.subscription.clone(),
|
||||
inner: self.inner.clone(),
|
||||
latest: self.latest.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Subscriber<S>
|
||||
where
|
||||
S: Spec + 'static,
|
||||
{
|
||||
/// Create a new instance
|
||||
pub fn new(
|
||||
subscription: Arc<S::SubscriptionId>,
|
||||
inner: &mpsc::Sender<(Arc<S::SubscriptionId>, S::Event)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: inner.clone(),
|
||||
subscription,
|
||||
latest: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message
|
||||
pub fn send(&self, event: S::Event) {
|
||||
let mut latest = if let Ok(reader) = self.latest.lock() {
|
||||
reader
|
||||
} else {
|
||||
let _ = self.inner.try_send((self.subscription.to_owned(), event));
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(last_event) = latest.replace(event.clone()) {
|
||||
if last_event == event {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = self.inner.try_send((self.subscription.to_owned(), event));
|
||||
}
|
||||
}
|
||||
80
crates/cdk-common/src/pub_sub/types.rs
Normal file
80
crates/cdk-common/src/pub_sub/types.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
//! Pubsub Event definition
|
||||
//!
|
||||
//! The Pubsub Event defines the Topic struct and how an event can be converted to Topics.
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::Subscriber;
|
||||
|
||||
/// Pubsub settings
|
||||
#[async_trait::async_trait]
|
||||
pub trait Spec: Send + Sync {
|
||||
/// Topic
|
||||
type Topic: Send
|
||||
+ Sync
|
||||
+ Clone
|
||||
+ Eq
|
||||
+ PartialEq
|
||||
+ Ord
|
||||
+ PartialOrd
|
||||
+ Hash
|
||||
+ Send
|
||||
+ Sync
|
||||
+ DeserializeOwned
|
||||
+ Serialize;
|
||||
|
||||
/// Event
|
||||
type Event: Event<Topic = Self::Topic>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Eq
|
||||
+ PartialEq
|
||||
+ DeserializeOwned
|
||||
+ Serialize;
|
||||
|
||||
/// Subscription Id
|
||||
type SubscriptionId: Clone
|
||||
+ Default
|
||||
+ Eq
|
||||
+ PartialEq
|
||||
+ Ord
|
||||
+ PartialOrd
|
||||
+ Hash
|
||||
+ Send
|
||||
+ Sync
|
||||
+ DeserializeOwned
|
||||
+ Serialize;
|
||||
|
||||
/// Create a new context
|
||||
type Context;
|
||||
|
||||
/// Create a new instance from a given context
|
||||
fn new_instance(context: Self::Context) -> Arc<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Callback function that is called on new subscriptions, to back-fill optionally the previous
|
||||
/// events
|
||||
async fn fetch_events(
|
||||
self: &Arc<Self>,
|
||||
topics: Vec<<Self::Event as Event>::Topic>,
|
||||
reply_to: Subscriber<Self>,
|
||||
) where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Event trait
|
||||
pub trait Event: Clone + Send + Sync + Eq + PartialEq + DeserializeOwned + Serialize {
|
||||
/// Generic Topic
|
||||
///
|
||||
/// It should be serializable/deserializable to be stored in the database layer and it should
|
||||
/// also be sorted in a BTree for in-memory matching
|
||||
type Topic;
|
||||
|
||||
/// To topics
|
||||
fn get_topics(&self) -> Vec<Self::Topic>;
|
||||
}
|
||||
@@ -1,98 +1,115 @@
|
||||
//! Subscription types and traits
|
||||
#[cfg(feature = "mint")]
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cashu::nut17::{self};
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::nut17::{Error, Kind, Notification};
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::nut17::{self, Kind, NotificationId};
|
||||
use cashu::quote_id::QuoteId;
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::{NotificationPayload, PublicKey};
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::PublicKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
use crate::pub_sub::index::{Index, Indexable, SubscriptionGlobalId};
|
||||
use crate::pub_sub::SubId;
|
||||
use crate::pub_sub::{Error, SubscriptionRequest};
|
||||
|
||||
/// Subscription parameters.
|
||||
/// CDK/Mint Subscription parameters.
|
||||
///
|
||||
/// This is a concrete type alias for `nut17::Params<SubId>`.
|
||||
pub type Params = nut17::Params<SubId>;
|
||||
pub type Params = nut17::Params<Arc<SubId>>;
|
||||
|
||||
/// Wrapper around `nut17::Params` to implement `Indexable` for `Notification`.
|
||||
#[cfg(feature = "mint")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IndexableParams(Params);
|
||||
impl SubscriptionRequest for Params {
|
||||
type Topic = NotificationId<QuoteId>;
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
impl From<Params> for IndexableParams {
|
||||
fn from(params: Params) -> Self {
|
||||
Self(params)
|
||||
type SubscriptionId = SubId;
|
||||
|
||||
fn subscription_name(&self) -> Arc<Self::SubscriptionId> {
|
||||
self.id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
impl TryFrom<IndexableParams> for Vec<Index<Notification>> {
|
||||
type Error = Error;
|
||||
fn try_from(params: IndexableParams) -> Result<Self, Self::Error> {
|
||||
let sub_id: SubscriptionGlobalId = Default::default();
|
||||
let params = params.0;
|
||||
params
|
||||
.filters
|
||||
.into_iter()
|
||||
.map(|filter| {
|
||||
let idx = match params.kind {
|
||||
Kind::Bolt11MeltQuote => {
|
||||
Notification::MeltQuoteBolt11(QuoteId::from_str(&filter)?)
|
||||
}
|
||||
Kind::Bolt11MintQuote => {
|
||||
Notification::MintQuoteBolt11(QuoteId::from_str(&filter)?)
|
||||
}
|
||||
Kind::ProofState => Notification::ProofState(PublicKey::from_str(&filter)?),
|
||||
Kind::Bolt12MintQuote => {
|
||||
Notification::MintQuoteBolt12(QuoteId::from_str(&filter)?)
|
||||
}
|
||||
};
|
||||
fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error> {
|
||||
self.filters
|
||||
.iter()
|
||||
.map(|filter| match self.kind {
|
||||
Kind::Bolt11MeltQuote => QuoteId::from_str(filter)
|
||||
.map(NotificationId::MeltQuoteBolt11)
|
||||
.map_err(|_| Error::ParsingError(filter.to_owned())),
|
||||
Kind::Bolt11MintQuote => QuoteId::from_str(filter)
|
||||
.map(NotificationId::MintQuoteBolt11)
|
||||
.map_err(|_| Error::ParsingError(filter.to_owned())),
|
||||
Kind::ProofState => PublicKey::from_str(filter)
|
||||
.map(NotificationId::ProofState)
|
||||
.map_err(|_| Error::ParsingError(filter.to_owned())),
|
||||
|
||||
Ok(Index::from((idx, params.id.clone(), sub_id)))
|
||||
Kind::Bolt12MintQuote => QuoteId::from_str(filter)
|
||||
.map(NotificationId::MintQuoteBolt12)
|
||||
.map_err(|_| Error::ParsingError(filter.to_owned())),
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
impl AsRef<SubId> for IndexableParams {
|
||||
fn as_ref(&self) -> &SubId {
|
||||
&self.0.id
|
||||
/// Subscriptions parameters for the wallet
|
||||
///
|
||||
/// This is because the Wallet can subscribe to non CDK quotes, where IDs are not constraint to
|
||||
/// QuoteId
|
||||
pub type WalletParams = nut17::Params<Arc<String>>;
|
||||
|
||||
impl SubscriptionRequest for WalletParams {
|
||||
type Topic = NotificationId<String>;
|
||||
|
||||
type SubscriptionId = String;
|
||||
|
||||
fn subscription_name(&self) -> Arc<Self::SubscriptionId> {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error> {
|
||||
self.filters
|
||||
.iter()
|
||||
.map(|filter| {
|
||||
Ok(match self.kind {
|
||||
Kind::Bolt11MeltQuote => NotificationId::MeltQuoteBolt11(filter.to_owned()),
|
||||
Kind::Bolt11MintQuote => NotificationId::MintQuoteBolt11(filter.to_owned()),
|
||||
Kind::ProofState => PublicKey::from_str(filter)
|
||||
.map(NotificationId::ProofState)
|
||||
.map_err(|_| Error::ParsingError(filter.to_owned()))?,
|
||||
|
||||
Kind::Bolt12MintQuote => NotificationId::MintQuoteBolt12(filter.to_owned()),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
impl Indexable for NotificationPayload<QuoteId> {
|
||||
type Type = Notification;
|
||||
/// Subscription Id wrapper
|
||||
///
|
||||
/// This is the place to add some sane default (like a max length) to the
|
||||
/// subscription ID
|
||||
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub struct SubId(String);
|
||||
|
||||
fn to_indexes(&self) -> Vec<Index<Self::Type>> {
|
||||
match self {
|
||||
NotificationPayload::ProofState(proof_state) => {
|
||||
vec![Index::from(Notification::ProofState(proof_state.y))]
|
||||
}
|
||||
NotificationPayload::MeltQuoteBolt11Response(melt_quote) => {
|
||||
vec![Index::from(Notification::MeltQuoteBolt11(
|
||||
melt_quote.quote.clone(),
|
||||
))]
|
||||
}
|
||||
NotificationPayload::MintQuoteBolt11Response(mint_quote) => {
|
||||
vec![Index::from(Notification::MintQuoteBolt11(
|
||||
mint_quote.quote.clone(),
|
||||
))]
|
||||
}
|
||||
NotificationPayload::MintQuoteBolt12Response(mint_quote) => {
|
||||
vec![Index::from(Notification::MintQuoteBolt12(
|
||||
mint_quote.quote.clone(),
|
||||
))]
|
||||
}
|
||||
}
|
||||
impl From<&str> for SubId {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SubId {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SubId {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(Self(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SubId {
|
||||
type Target = String;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
//!
|
||||
//! This module extends the `cashu` crate with types and functions for the CDK, using the correct
|
||||
//! expected ID types.
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::nut17::ws::JSON_RPC_VERSION;
|
||||
use cashu::nut17::{self};
|
||||
@@ -10,7 +12,7 @@ use cashu::quote_id::QuoteId;
|
||||
#[cfg(feature = "mint")]
|
||||
use cashu::NotificationPayload;
|
||||
|
||||
use crate::pub_sub::SubId;
|
||||
type SubId = Arc<crate::subscription::SubId>;
|
||||
|
||||
/// Request to unsubscribe from a websocket subscription
|
||||
pub type WsUnsubscribeRequest = nut17::ws::WsUnsubscribeRequest<SubId>;
|
||||
|
||||
@@ -22,7 +22,7 @@ ctor = "0.2"
|
||||
futures = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde = { workspace = true, features = ["derive", "rc"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["sync", "rt", "rt-multi-thread"] }
|
||||
@@ -41,4 +41,3 @@ postgres = ["cdk-postgres"]
|
||||
[[bin]]
|
||||
name = "uniffi-bindgen"
|
||||
path = "src/bin/uniffi-bindgen.rs"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Subscription-related FFI types
|
||||
use std::sync::Arc;
|
||||
|
||||
use cdk::pub_sub::SubId;
|
||||
use cdk::event::MintEvent;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::proof::ProofStateUpdate;
|
||||
@@ -53,21 +54,17 @@ pub struct SubscribeParams {
|
||||
pub id: Option<String>,
|
||||
}
|
||||
|
||||
impl From<SubscribeParams> for cdk::nuts::nut17::Params<cdk::pub_sub::SubId> {
|
||||
impl From<SubscribeParams> for cdk::nuts::nut17::Params<Arc<String>> {
|
||||
fn from(params: SubscribeParams) -> Self {
|
||||
let sub_id = params
|
||||
.id
|
||||
.map(|id| SubId::from(id.as_str()))
|
||||
.unwrap_or_else(|| {
|
||||
// Generate a random ID
|
||||
let uuid = uuid::Uuid::new_v4();
|
||||
SubId::from(uuid.to_string().as_str())
|
||||
});
|
||||
let sub_id = params.id.unwrap_or_else(|| {
|
||||
// Generate a random ID
|
||||
uuid::Uuid::new_v4().to_string()
|
||||
});
|
||||
|
||||
cdk::nuts::nut17::Params {
|
||||
kind: params.kind.into(),
|
||||
filters: params.filters,
|
||||
id: sub_id,
|
||||
id: Arc::new(sub_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -132,12 +129,7 @@ impl ActiveSubscription {
|
||||
/// Try to receive a notification without blocking
|
||||
pub async fn try_recv(&self) -> Result<Option<NotificationPayload>, FfiError> {
|
||||
let mut guard = self.inner.lock().await;
|
||||
guard
|
||||
.try_recv()
|
||||
.map(|opt| opt.map(Into::into))
|
||||
.map_err(|e| FfiError::Generic {
|
||||
msg: format!("Failed to receive notification: {}", e),
|
||||
})
|
||||
Ok(guard.try_recv().map(Into::into))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,9 +148,9 @@ pub enum NotificationPayload {
|
||||
},
|
||||
}
|
||||
|
||||
impl From<cdk::nuts::NotificationPayload<String>> for NotificationPayload {
|
||||
fn from(payload: cdk::nuts::NotificationPayload<String>) -> Self {
|
||||
match payload {
|
||||
impl From<MintEvent<String>> for NotificationPayload {
|
||||
fn from(payload: MintEvent<String>) -> Self {
|
||||
match payload.into() {
|
||||
cdk::nuts::NotificationPayload::ProofState(states) => NotificationPayload::ProofState {
|
||||
proof_states: vec![states.into()],
|
||||
},
|
||||
|
||||
@@ -349,7 +349,7 @@ impl Wallet {
|
||||
&self,
|
||||
params: SubscribeParams,
|
||||
) -> Result<std::sync::Arc<ActiveSubscription>, FfiError> {
|
||||
let cdk_params: cdk::nuts::nut17::Params<cdk::pub_sub::SubId> = params.clone().into();
|
||||
let cdk_params: cdk::nuts::nut17::Params<Arc<String>> = params.clone().into();
|
||||
let sub_id = cdk_params.id.to_string();
|
||||
let active_sub = self.inner.subscribe(cdk_params).await;
|
||||
Ok(std::sync::Arc::new(ActiveSubscription::new(
|
||||
|
||||
@@ -13,6 +13,7 @@ use std::assert_eq;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::RandomState;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use cashu::amount::SplitTarget;
|
||||
@@ -24,7 +25,7 @@ use cashu::{
|
||||
};
|
||||
use cdk::mint::Mint;
|
||||
use cdk::nuts::nut00::ProofsMethods;
|
||||
use cdk::subscription::{IndexableParams, Params};
|
||||
use cdk::subscription::Params;
|
||||
use cdk::wallet::types::{TransactionDirection, TransactionId};
|
||||
use cdk::wallet::{ReceiveOptions, SendMemo, SendOptions};
|
||||
use cdk::Amount;
|
||||
@@ -485,15 +486,11 @@ pub async fn test_p2pk_swap() {
|
||||
|
||||
let mut listener = mint_bob
|
||||
.pubsub_manager()
|
||||
.try_subscribe::<IndexableParams>(
|
||||
Params {
|
||||
kind: cdk::nuts::nut17::Kind::ProofState,
|
||||
filters: public_keys_to_listen.clone(),
|
||||
id: "test".into(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.subscribe(Params {
|
||||
kind: cdk::nuts::nut17::Kind::ProofState,
|
||||
filters: public_keys_to_listen.clone(),
|
||||
id: Arc::new("test".into()),
|
||||
})
|
||||
.expect("valid subscription");
|
||||
|
||||
match mint_bob.process_swap_request(swap_request).await {
|
||||
@@ -520,9 +517,8 @@ pub async fn test_p2pk_swap() {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
let mut msgs = HashMap::new();
|
||||
while let Ok((sub_id, msg)) = listener.try_recv() {
|
||||
assert_eq!(sub_id, "test".into());
|
||||
match msg {
|
||||
while let Some(msg) = listener.try_recv() {
|
||||
match msg.into_inner() {
|
||||
NotificationPayload::ProofState(ProofState { y, state, .. }) => {
|
||||
msgs.entry(y.to_string())
|
||||
.or_insert_with(Vec::new)
|
||||
@@ -544,7 +540,7 @@ pub async fn test_p2pk_swap() {
|
||||
);
|
||||
}
|
||||
|
||||
assert!(listener.try_recv().is_err(), "no other event is happening");
|
||||
assert!(listener.try_recv().is_none(), "no other event is happening");
|
||||
assert!(msgs.is_empty(), "Only expected key events are received");
|
||||
}
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ async fn test_websocket_connection() {
|
||||
.expect("timeout waiting for unpaid notification")
|
||||
.expect("No paid notification received");
|
||||
|
||||
match msg {
|
||||
match msg.into_inner() {
|
||||
NotificationPayload::MintQuoteBolt11Response(response) => {
|
||||
assert_eq!(response.quote.to_string(), mint_quote.id);
|
||||
assert_eq!(response.state, MintQuoteState::Unpaid);
|
||||
@@ -185,7 +185,7 @@ async fn test_websocket_connection() {
|
||||
.expect("timeout waiting for paid notification")
|
||||
.expect("No paid notification received");
|
||||
|
||||
match msg {
|
||||
match msg.into_inner() {
|
||||
NotificationPayload::MintQuoteBolt11Response(response) => {
|
||||
assert_eq!(response.quote.to_string(), mint_quote.id);
|
||||
assert_eq!(response.state, MintQuoteState::Paid);
|
||||
|
||||
@@ -100,8 +100,6 @@ ring = { version = "0.17.14", features = ["wasm32_unknown_unknown_js"] }
|
||||
rustls = { workspace = true, optional = true }
|
||||
|
||||
uuid = { workspace = true, features = ["js"] }
|
||||
wasm-bindgen = "0.2"
|
||||
wasm-bindgen-futures = "0.4"
|
||||
gloo-timers = { version = "0.3", features = ["futures"] }
|
||||
|
||||
[[example]]
|
||||
|
||||
127
crates/cdk/src/event.rs
Normal file
127
crates/cdk/src/event.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
//! Mint event types
|
||||
use std::fmt::Debug;
|
||||
use std::hash::Hash;
|
||||
use std::ops::Deref;
|
||||
|
||||
use cdk_common::nut17::NotificationId;
|
||||
use cdk_common::pub_sub::Event;
|
||||
use cdk_common::{
|
||||
MeltQuoteBolt11Response, MintQuoteBolt11Response, MintQuoteBolt12Response, NotificationPayload,
|
||||
ProofState,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Simple wrapper over `NotificationPayload<QuoteId>` which is a foreign type
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(bound = "T: Serialize + DeserializeOwned")]
|
||||
pub struct MintEvent<T>(NotificationPayload<T>)
|
||||
where
|
||||
T: Clone + Eq + PartialEq;
|
||||
|
||||
impl<T> From<MintEvent<T>> for NotificationPayload<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: MintEvent<T>) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
type Target = NotificationPayload<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<ProofState> for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: ProofState) -> Self {
|
||||
Self(NotificationPayload::ProofState(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
/// New instance
|
||||
pub fn new(t: NotificationPayload<T>) -> Self {
|
||||
Self(t)
|
||||
}
|
||||
|
||||
/// Get inner
|
||||
pub fn inner(&self) -> &NotificationPayload<T> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Into inner
|
||||
pub fn into_inner(self) -> NotificationPayload<T> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<NotificationPayload<T>> for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: NotificationPayload<T>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<MintQuoteBolt11Response<T>> for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: MintQuoteBolt11Response<T>) -> Self {
|
||||
Self(NotificationPayload::MintQuoteBolt11Response(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<MeltQuoteBolt11Response<T>> for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: MeltQuoteBolt11Response<T>) -> Self {
|
||||
Self(NotificationPayload::MeltQuoteBolt11Response(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<MintQuoteBolt12Response<T>> for MintEvent<T>
|
||||
where
|
||||
T: Clone + Eq + PartialEq,
|
||||
{
|
||||
fn from(value: MintQuoteBolt12Response<T>) -> Self {
|
||||
Self(NotificationPayload::MintQuoteBolt12Response(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Event for MintEvent<T>
|
||||
where
|
||||
T: Clone + Serialize + DeserializeOwned + Debug + Ord + Hash + Send + Sync + Eq + PartialEq,
|
||||
{
|
||||
type Topic = NotificationId<T>;
|
||||
|
||||
fn get_topics(&self) -> Vec<Self::Topic> {
|
||||
vec![match &self.0 {
|
||||
NotificationPayload::MeltQuoteBolt11Response(r) => {
|
||||
NotificationId::MeltQuoteBolt11(r.quote.to_owned())
|
||||
}
|
||||
NotificationPayload::MintQuoteBolt11Response(r) => {
|
||||
NotificationId::MintQuoteBolt11(r.quote.to_owned())
|
||||
}
|
||||
NotificationPayload::MintQuoteBolt12Response(r) => {
|
||||
NotificationId::MintQuoteBolt12(r.quote.to_owned())
|
||||
}
|
||||
NotificationPayload::ProofState(p) => NotificationId::ProofState(p.y.to_owned()),
|
||||
}]
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,9 @@ mod bip353;
|
||||
#[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))]
|
||||
mod oidc_client;
|
||||
|
||||
#[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))]
|
||||
pub use oidc_client::OidcClient;
|
||||
|
||||
pub mod pub_sub;
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
#[doc(hidden)]
|
||||
pub use cdk_common::payment as cdk_payment;
|
||||
/// Re-export amount type
|
||||
#[doc(hidden)]
|
||||
pub use cdk_common::{
|
||||
@@ -44,10 +42,11 @@ pub use cdk_common::{
|
||||
error::{self, Error},
|
||||
lightning_invoice, mint_url, nuts, secret, util, ws, Amount, Bolt11Invoice,
|
||||
};
|
||||
#[cfg(feature = "mint")]
|
||||
#[doc(hidden)]
|
||||
pub use cdk_common::{payment as cdk_payment, subscription};
|
||||
#[cfg(all(any(feature = "wallet", feature = "mint"), feature = "auth"))]
|
||||
pub use oidc_client::OidcClient;
|
||||
|
||||
#[cfg(any(feature = "wallet", feature = "mint"))]
|
||||
pub mod event;
|
||||
pub mod fees;
|
||||
|
||||
#[doc(hidden)]
|
||||
@@ -69,6 +68,8 @@ pub use self::wallet::HttpClient;
|
||||
#[doc(hidden)]
|
||||
pub type Result<T, E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
|
||||
|
||||
/// Re-export subscription
|
||||
pub use cdk_common::subscription;
|
||||
/// Re-export futures::Stream
|
||||
#[cfg(any(feature = "wallet", feature = "mint"))]
|
||||
pub use futures::{Stream, StreamExt};
|
||||
|
||||
@@ -322,12 +322,12 @@ impl Mint {
|
||||
PaymentMethod::Bolt11 => {
|
||||
let res: MintQuoteBolt11Response<QuoteId> = quote.clone().into();
|
||||
self.pubsub_manager
|
||||
.broadcast(NotificationPayload::MintQuoteBolt11Response(res));
|
||||
.publish(NotificationPayload::MintQuoteBolt11Response(res));
|
||||
}
|
||||
PaymentMethod::Bolt12 => {
|
||||
let res: MintQuoteBolt12Response<QuoteId> = quote.clone().try_into()?;
|
||||
self.pubsub_manager
|
||||
.broadcast(NotificationPayload::MintQuoteBolt12Response(res));
|
||||
.publish(NotificationPayload::MintQuoteBolt12Response(res));
|
||||
}
|
||||
PaymentMethod::Custom(_) => {}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ mod ln;
|
||||
mod melt;
|
||||
mod proof_writer;
|
||||
mod start_up_check;
|
||||
pub mod subscription;
|
||||
mod subscription;
|
||||
mod swap;
|
||||
mod verification;
|
||||
|
||||
@@ -206,7 +206,7 @@ impl Mint {
|
||||
|
||||
Ok(Self {
|
||||
signatory,
|
||||
pubsub_manager: Arc::new(localstore.clone().into()),
|
||||
pubsub_manager: PubSubManager::new(localstore.clone()),
|
||||
localstore,
|
||||
#[cfg(feature = "auth")]
|
||||
oidc_client: computed_info.nuts.nut21.as_ref().map(|nut21| {
|
||||
|
||||
244
crates/cdk/src/mint/subscription.rs
Normal file
244
crates/cdk/src/mint/subscription.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Specific Subscription for the cdk crate
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cdk_common::database::DynMintDatabase;
|
||||
use cdk_common::mint::MintQuote;
|
||||
use cdk_common::nut17::NotificationId;
|
||||
use cdk_common::pub_sub::{Pubsub, Spec, Subscriber};
|
||||
use cdk_common::subscription::SubId;
|
||||
use cdk_common::{
|
||||
Amount, BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response,
|
||||
MintQuoteBolt12Response, MintQuoteState, PaymentMethod, ProofState, PublicKey, QuoteId,
|
||||
};
|
||||
|
||||
use crate::event::MintEvent;
|
||||
|
||||
/// Mint subtopics
|
||||
#[derive(Clone)]
|
||||
pub struct MintPubSubSpec {
|
||||
db: DynMintDatabase,
|
||||
}
|
||||
|
||||
impl MintPubSubSpec {
|
||||
async fn get_events_from_db(
|
||||
&self,
|
||||
request: &[NotificationId<QuoteId>],
|
||||
) -> Result<Vec<MintEvent<QuoteId>>, String> {
|
||||
let mut to_return = vec![];
|
||||
let mut public_keys: Vec<PublicKey> = Vec::new();
|
||||
let mut melt_queries = Vec::new();
|
||||
let mut mint_queries = Vec::new();
|
||||
|
||||
for idx in request.iter() {
|
||||
match idx {
|
||||
NotificationId::ProofState(pk) => public_keys.push(*pk),
|
||||
NotificationId::MeltQuoteBolt11(uuid) => {
|
||||
melt_queries.push(self.db.get_melt_quote(uuid))
|
||||
}
|
||||
NotificationId::MintQuoteBolt11(uuid) => {
|
||||
mint_queries.push(self.db.get_mint_quote(uuid))
|
||||
}
|
||||
NotificationId::MintQuoteBolt12(uuid) => {
|
||||
mint_queries.push(self.db.get_mint_quote(uuid))
|
||||
}
|
||||
NotificationId::MeltQuoteBolt12(uuid) => {
|
||||
melt_queries.push(self.db.get_melt_quote(uuid))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !melt_queries.is_empty() {
|
||||
to_return.extend(
|
||||
futures::future::try_join_all(melt_queries)
|
||||
.await
|
||||
.map(|quotes| {
|
||||
quotes
|
||||
.into_iter()
|
||||
.filter_map(|quote| quote.map(|x| x.into()))
|
||||
.map(|x: MeltQuoteBolt11Response<QuoteId>| x.into())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.map_err(|e| e.to_string())?,
|
||||
);
|
||||
}
|
||||
|
||||
if !mint_queries.is_empty() {
|
||||
to_return.extend(
|
||||
futures::future::try_join_all(mint_queries)
|
||||
.await
|
||||
.map(|quotes| {
|
||||
quotes
|
||||
.into_iter()
|
||||
.filter_map(|quote| {
|
||||
quote.and_then(|x| match x.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
let response: MintQuoteBolt11Response<QuoteId> = x.into();
|
||||
Some(response.into())
|
||||
}
|
||||
PaymentMethod::Bolt12 => match x.try_into() {
|
||||
Ok(response) => {
|
||||
let response: MintQuoteBolt12Response<QuoteId> =
|
||||
response;
|
||||
Some(response.into())
|
||||
}
|
||||
Err(_) => None,
|
||||
},
|
||||
PaymentMethod::Custom(_) => None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.map_err(|e| e.to_string())?,
|
||||
);
|
||||
}
|
||||
|
||||
if !public_keys.is_empty() {
|
||||
to_return.extend(
|
||||
self.db
|
||||
.get_proofs_states(public_keys.as_slice())
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, state)| state.map(|state| (public_keys[idx], state).into()))
|
||||
.map(|state: ProofState| state.into()),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(to_return)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Spec for MintPubSubSpec {
|
||||
type SubscriptionId = SubId;
|
||||
|
||||
type Topic = NotificationId<QuoteId>;
|
||||
|
||||
type Event = MintEvent<QuoteId>;
|
||||
|
||||
type Context = DynMintDatabase;
|
||||
|
||||
fn new_instance(context: Self::Context) -> Arc<Self> {
|
||||
Arc::new(Self { db: context })
|
||||
}
|
||||
|
||||
async fn fetch_events(self: &Arc<Self>, topics: Vec<Self::Topic>, reply_to: Subscriber<Self>) {
|
||||
for event in self
|
||||
.get_events_from_db(&topics)
|
||||
.await
|
||||
.inspect_err(|err| tracing::error!("Error reading events from db {err:?}"))
|
||||
.unwrap_or_default()
|
||||
{
|
||||
let _ = reply_to.send(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// PubsubManager
|
||||
pub struct PubSubManager(Pubsub<MintPubSubSpec>);
|
||||
|
||||
impl PubSubManager {
|
||||
/// Create a new instance
|
||||
pub fn new(db: DynMintDatabase) -> Arc<Self> {
|
||||
Arc::new(Self(Pubsub::new(MintPubSubSpec::new_instance(db))))
|
||||
}
|
||||
|
||||
/// Helper function to emit a ProofState status
|
||||
pub fn proof_state<E: Into<ProofState>>(&self, event: E) {
|
||||
self.publish(event.into());
|
||||
}
|
||||
|
||||
/// Helper function to publish even of a mint quote being paid
|
||||
pub fn mint_quote_issue(&self, mint_quote: &MintQuote, total_issued: Amount) {
|
||||
match mint_quote.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Issued);
|
||||
}
|
||||
PaymentMethod::Bolt12 => {
|
||||
self.mint_quote_bolt12_status(
|
||||
mint_quote.clone(),
|
||||
mint_quote.amount_paid(),
|
||||
total_issued,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
// We don't send ws updates for unknown methods
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to publish even of a mint quote being paid
|
||||
pub fn mint_quote_payment(&self, mint_quote: &MintQuote, total_paid: Amount) {
|
||||
match mint_quote.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Paid);
|
||||
}
|
||||
PaymentMethod::Bolt12 => {
|
||||
self.mint_quote_bolt12_status(
|
||||
mint_quote.clone(),
|
||||
total_paid,
|
||||
mint_quote.amount_issued(),
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
// We don't send ws updates for unknown methods
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to emit a MintQuoteBolt11Response status
|
||||
pub fn mint_quote_bolt11_status<E: Into<MintQuoteBolt11Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
new_state: MintQuoteState,
|
||||
) {
|
||||
let mut event = quote.into();
|
||||
event.state = new_state;
|
||||
|
||||
self.publish(event);
|
||||
}
|
||||
|
||||
/// Helper function to emit a MintQuoteBolt11Response status
|
||||
pub fn mint_quote_bolt12_status<E: TryInto<MintQuoteBolt12Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
amount_paid: Amount,
|
||||
amount_issued: Amount,
|
||||
) {
|
||||
if let Ok(mut event) = quote.try_into() {
|
||||
event.amount_paid = amount_paid;
|
||||
event.amount_issued = amount_issued;
|
||||
|
||||
self.publish(event);
|
||||
} else {
|
||||
tracing::warn!("Could not convert quote to MintQuoteResponse");
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to emit a MeltQuoteBolt11Response status
|
||||
pub fn melt_quote_status<E: Into<MeltQuoteBolt11Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
payment_preimage: Option<String>,
|
||||
change: Option<Vec<BlindSignature>>,
|
||||
new_state: MeltQuoteState,
|
||||
) {
|
||||
let mut quote = quote.into();
|
||||
quote.state = new_state;
|
||||
quote.paid = Some(new_state == MeltQuoteState::Paid);
|
||||
quote.payment_preimage = payment_preimage;
|
||||
quote.change = change;
|
||||
self.publish(quote);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PubSubManager {
|
||||
type Target = Pubsub<MintPubSubSpec>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
@@ -1,292 +0,0 @@
|
||||
//! Specific Subscription for the cdk crate
|
||||
use std::ops::Deref;
|
||||
|
||||
use cdk_common::database::DynMintDatabase;
|
||||
use cdk_common::mint::MintQuote;
|
||||
use cdk_common::nut17::Notification;
|
||||
use cdk_common::quote_id::QuoteId;
|
||||
use cdk_common::{Amount, MintQuoteBolt12Response, NotificationPayload, PaymentMethod};
|
||||
|
||||
use super::OnSubscription;
|
||||
use crate::nuts::{
|
||||
BlindSignature, MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response,
|
||||
MintQuoteState, ProofState,
|
||||
};
|
||||
use crate::pub_sub;
|
||||
|
||||
/// Manager
|
||||
/// Publish–subscribe manager
|
||||
///
|
||||
/// Nut-17 implementation is system-wide and not only through the WebSocket, so
|
||||
/// it is possible for another part of the system to subscribe to events.
|
||||
pub struct PubSubManager(
|
||||
pub_sub::Manager<NotificationPayload<QuoteId>, Notification, OnSubscription>,
|
||||
);
|
||||
|
||||
#[allow(clippy::default_constructed_unit_structs)]
|
||||
impl Default for PubSubManager {
|
||||
fn default() -> Self {
|
||||
PubSubManager(OnSubscription::default().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DynMintDatabase> for PubSubManager {
|
||||
fn from(val: DynMintDatabase) -> Self {
|
||||
PubSubManager(OnSubscription(Some(val)).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PubSubManager {
|
||||
type Target = pub_sub::Manager<NotificationPayload<QuoteId>, Notification, OnSubscription>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl PubSubManager {
|
||||
/// Helper function to emit a ProofState status
|
||||
pub fn proof_state<E: Into<ProofState>>(&self, event: E) {
|
||||
self.broadcast(event.into().into());
|
||||
}
|
||||
|
||||
/// Helper function to publish even of a mint quote being paid
|
||||
pub fn mint_quote_issue(&self, mint_quote: &MintQuote, total_issued: Amount) {
|
||||
match mint_quote.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Issued);
|
||||
}
|
||||
PaymentMethod::Bolt12 => {
|
||||
self.mint_quote_bolt12_status(
|
||||
mint_quote.clone(),
|
||||
mint_quote.amount_paid(),
|
||||
total_issued,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
// We don't send ws updates for unknown methods
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to publish even of a mint quote being paid
|
||||
pub fn mint_quote_payment(&self, mint_quote: &MintQuote, total_paid: Amount) {
|
||||
match mint_quote.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
self.mint_quote_bolt11_status(mint_quote.clone(), MintQuoteState::Paid);
|
||||
}
|
||||
PaymentMethod::Bolt12 => {
|
||||
self.mint_quote_bolt12_status(
|
||||
mint_quote.clone(),
|
||||
total_paid,
|
||||
mint_quote.amount_issued(),
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
// We don't send ws updates for unknown methods
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to emit a MintQuoteBolt11Response status
|
||||
pub fn mint_quote_bolt11_status<E: Into<MintQuoteBolt11Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
new_state: MintQuoteState,
|
||||
) {
|
||||
let mut event = quote.into();
|
||||
event.state = new_state;
|
||||
|
||||
self.broadcast(event.into());
|
||||
}
|
||||
|
||||
/// Helper function to emit a MintQuoteBolt11Response status
|
||||
pub fn mint_quote_bolt12_status<E: TryInto<MintQuoteBolt12Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
amount_paid: Amount,
|
||||
amount_issued: Amount,
|
||||
) {
|
||||
if let Ok(mut event) = quote.try_into() {
|
||||
event.amount_paid = amount_paid;
|
||||
event.amount_issued = amount_issued;
|
||||
|
||||
self.broadcast(event.into());
|
||||
} else {
|
||||
tracing::warn!("Could not convert quote to MintQuoteResponse");
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to emit a MeltQuoteBolt11Response status
|
||||
pub fn melt_quote_status<E: Into<MeltQuoteBolt11Response<QuoteId>>>(
|
||||
&self,
|
||||
quote: E,
|
||||
payment_preimage: Option<String>,
|
||||
change: Option<Vec<BlindSignature>>,
|
||||
new_state: MeltQuoteState,
|
||||
) {
|
||||
let mut quote = quote.into();
|
||||
quote.state = new_state;
|
||||
quote.paid = Some(new_state == MeltQuoteState::Paid);
|
||||
quote.payment_preimage = payment_preimage;
|
||||
quote.change = change;
|
||||
self.broadcast(quote.into());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::*;
|
||||
use crate::nuts::nut17::Kind;
|
||||
use crate::nuts::{PublicKey, State};
|
||||
use crate::subscription::{IndexableParams, Params};
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_and_drop() {
|
||||
let manager = PubSubManager::default();
|
||||
let params: IndexableParams = Params {
|
||||
kind: Kind::ProofState,
|
||||
filters: vec![
|
||||
"02a9acc1e48c25eeeb9289b5031cc57da9fe72f3fe2861d264bdc074209b107ba2".to_owned(),
|
||||
],
|
||||
id: "uno".into(),
|
||||
}
|
||||
.into();
|
||||
|
||||
// Although the same param is used, two subscriptions are created, that
|
||||
// is because each index is unique, thanks to `Unique`, it is the
|
||||
// responsibility of the implementor to make sure that SubId are unique
|
||||
// either globally or per client
|
||||
let subscriptions = vec![
|
||||
manager
|
||||
.try_subscribe(params.clone())
|
||||
.await
|
||||
.expect("valid subscription"),
|
||||
manager
|
||||
.try_subscribe(params)
|
||||
.await
|
||||
.expect("valid subscription"),
|
||||
];
|
||||
assert_eq!(2, manager.active_subscriptions());
|
||||
drop(subscriptions);
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
assert_eq!(0, manager.active_subscriptions());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn broadcast() {
|
||||
let manager = PubSubManager::default();
|
||||
let mut subscriptions = [
|
||||
manager
|
||||
.try_subscribe::<IndexableParams>(
|
||||
Params {
|
||||
kind: Kind::ProofState,
|
||||
filters: vec![
|
||||
"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"
|
||||
.to_string(),
|
||||
],
|
||||
id: "uno".into(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.expect("valid subscription"),
|
||||
manager
|
||||
.try_subscribe::<IndexableParams>(
|
||||
Params {
|
||||
kind: Kind::ProofState,
|
||||
filters: vec![
|
||||
"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"
|
||||
.to_string(),
|
||||
],
|
||||
id: "dos".into(),
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.expect("valid subscription"),
|
||||
];
|
||||
|
||||
let event = ProofState {
|
||||
y: PublicKey::from_hex(
|
||||
"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104",
|
||||
)
|
||||
.expect("valid pk"),
|
||||
state: State::Pending,
|
||||
witness: None,
|
||||
};
|
||||
|
||||
manager.broadcast(event.into());
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
let (sub1, _) = subscriptions[0].try_recv().expect("valid message");
|
||||
assert_eq!("uno", *sub1);
|
||||
|
||||
let (sub1, _) = subscriptions[1].try_recv().expect("valid message");
|
||||
assert_eq!("dos", *sub1);
|
||||
|
||||
assert!(subscriptions[0].try_recv().is_err());
|
||||
assert!(subscriptions[1].try_recv().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parsing_request() {
|
||||
let json = r#"{"kind":"proof_state","filters":["x"],"subId":"uno"}"#;
|
||||
let params: Params = serde_json::from_str(json).expect("valid json");
|
||||
assert_eq!(params.kind, Kind::ProofState);
|
||||
assert_eq!(params.filters, vec!["x"]);
|
||||
assert_eq!(*params.id, "uno");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn json_test() {
|
||||
let manager = PubSubManager::default();
|
||||
let mut subscription = manager
|
||||
.try_subscribe::<IndexableParams>(
|
||||
serde_json::from_str(r#"{"kind":"proof_state","filters":["02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104"],"subId":"uno"}"#)
|
||||
.expect("valid json"),
|
||||
)
|
||||
.await.expect("valid subscription");
|
||||
|
||||
manager.broadcast(
|
||||
ProofState {
|
||||
y: PublicKey::from_hex(
|
||||
"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104",
|
||||
)
|
||||
.expect("valid pk"),
|
||||
state: State::Pending,
|
||||
witness: None,
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
|
||||
// no one is listening for this event
|
||||
manager.broadcast(
|
||||
ProofState {
|
||||
y: PublicKey::from_hex(
|
||||
"020000000000000000000000000000000000000000000000000000000000000001",
|
||||
)
|
||||
.expect("valid pk"),
|
||||
state: State::Pending,
|
||||
witness: None,
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
let (sub1, msg) = subscription.try_recv().expect("valid message");
|
||||
assert_eq!("uno", *sub1);
|
||||
assert_eq!(
|
||||
r#"{"Y":"02194603ffa36356f4a56b7df9371fc3192472351453ec7398b8da8117e7c3e104","state":"PENDING","witness":null}"#,
|
||||
serde_json::to_string(&msg).expect("valid json")
|
||||
);
|
||||
assert!(subscription.try_recv().is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//! Specific Subscription for the cdk crate
|
||||
|
||||
#[cfg(feature = "mint")]
|
||||
mod manager;
|
||||
#[cfg(feature = "mint")]
|
||||
mod on_subscription;
|
||||
#[cfg(feature = "mint")]
|
||||
pub use manager::PubSubManager;
|
||||
#[cfg(feature = "mint")]
|
||||
pub use on_subscription::OnSubscription;
|
||||
|
||||
pub use crate::pub_sub::SubId;
|
||||
@@ -1,119 +0,0 @@
|
||||
//! On Subscription
|
||||
//!
|
||||
//! This module contains the code that is triggered when a new subscription is created.
|
||||
|
||||
use cdk_common::database::DynMintDatabase;
|
||||
use cdk_common::nut17::Notification;
|
||||
use cdk_common::pub_sub::OnNewSubscription;
|
||||
use cdk_common::quote_id::QuoteId;
|
||||
use cdk_common::{MintQuoteBolt12Response, NotificationPayload, PaymentMethod};
|
||||
|
||||
use crate::nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey};
|
||||
|
||||
#[derive(Default)]
|
||||
/// Subscription Init
|
||||
///
|
||||
/// This struct triggers code when a new subscription is created.
|
||||
///
|
||||
/// It is used to send the initial state of the subscription to the client.
|
||||
pub struct OnSubscription(pub(crate) Option<DynMintDatabase>);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OnNewSubscription for OnSubscription {
|
||||
type Event = NotificationPayload<QuoteId>;
|
||||
type Index = Notification;
|
||||
|
||||
async fn on_new_subscription(
|
||||
&self,
|
||||
request: &[&Self::Index],
|
||||
) -> Result<Vec<Self::Event>, String> {
|
||||
let datastore = if let Some(localstore) = self.0.as_ref() {
|
||||
localstore
|
||||
} else {
|
||||
return Ok(vec![]);
|
||||
};
|
||||
|
||||
let mut to_return = vec![];
|
||||
let mut public_keys: Vec<PublicKey> = Vec::new();
|
||||
let mut melt_queries = Vec::new();
|
||||
let mut mint_queries = Vec::new();
|
||||
|
||||
for idx in request.iter() {
|
||||
match idx {
|
||||
Notification::ProofState(pk) => public_keys.push(*pk),
|
||||
Notification::MeltQuoteBolt11(uuid) => {
|
||||
melt_queries.push(datastore.get_melt_quote(uuid))
|
||||
}
|
||||
Notification::MintQuoteBolt11(uuid) => {
|
||||
mint_queries.push(datastore.get_mint_quote(uuid))
|
||||
}
|
||||
Notification::MintQuoteBolt12(uuid) => {
|
||||
mint_queries.push(datastore.get_mint_quote(uuid))
|
||||
}
|
||||
Notification::MeltQuoteBolt12(uuid) => {
|
||||
melt_queries.push(datastore.get_melt_quote(uuid))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !melt_queries.is_empty() {
|
||||
to_return.extend(
|
||||
futures::future::try_join_all(melt_queries)
|
||||
.await
|
||||
.map(|quotes| {
|
||||
quotes
|
||||
.into_iter()
|
||||
.filter_map(|quote| quote.map(|x| x.into()))
|
||||
.map(|x: MeltQuoteBolt11Response<QuoteId>| x.into())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.map_err(|e| e.to_string())?,
|
||||
);
|
||||
}
|
||||
|
||||
if !mint_queries.is_empty() {
|
||||
to_return.extend(
|
||||
futures::future::try_join_all(mint_queries)
|
||||
.await
|
||||
.map(|quotes| {
|
||||
quotes
|
||||
.into_iter()
|
||||
.filter_map(|quote| {
|
||||
quote.and_then(|x| match x.payment_method {
|
||||
PaymentMethod::Bolt11 => {
|
||||
let response: MintQuoteBolt11Response<QuoteId> = x.into();
|
||||
Some(response.into())
|
||||
}
|
||||
PaymentMethod::Bolt12 => match x.try_into() {
|
||||
Ok(response) => {
|
||||
let response: MintQuoteBolt12Response<QuoteId> =
|
||||
response;
|
||||
Some(response.into())
|
||||
}
|
||||
Err(_) => None,
|
||||
},
|
||||
PaymentMethod::Custom(_) => None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.map_err(|e| e.to_string())?,
|
||||
);
|
||||
}
|
||||
|
||||
if !public_keys.is_empty() {
|
||||
to_return.extend(
|
||||
datastore
|
||||
.get_proofs_states(public_keys.as_slice())
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, state)| state.map(|state| (public_keys[idx], state).into()))
|
||||
.map(|state: ProofState| state.into()),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(to_return)
|
||||
}
|
||||
}
|
||||
@@ -1,339 +0,0 @@
|
||||
//! Publish–subscribe pattern.
|
||||
//!
|
||||
//! This is a generic implementation for
|
||||
//! [NUT-17](<https://github.com/cashubtc/nuts/blob/main/17.md>) with a type
|
||||
//! agnostic Publish-subscribe manager.
|
||||
//!
|
||||
//! The manager has a method for subscribers to subscribe to events with a
|
||||
//! generic type that must be converted to a vector of indexes.
|
||||
//!
|
||||
//! Events are also generic that should implement the `Indexable` trait.
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use cdk_common::pub_sub::index::{Index, Indexable, SubscriptionGlobalId};
|
||||
use cdk_common::pub_sub::OnNewSubscription;
|
||||
pub use cdk_common::pub_sub::SubId;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
type IndexTree<T, I> = Arc<RwLock<BTreeMap<Index<I>, mpsc::Sender<(SubId, T)>>>>;
|
||||
|
||||
/// Default size of the remove channel
|
||||
pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
|
||||
|
||||
/// Default channel size for subscription buffering
|
||||
pub const DEFAULT_CHANNEL_SIZE: usize = 10;
|
||||
|
||||
/// Subscription manager
|
||||
///
|
||||
/// This object keep track of all subscription listener and it is also
|
||||
/// responsible for broadcasting events to all listeners
|
||||
///
|
||||
/// The content of the notification is not relevant to this scope and it is up
|
||||
/// to the application, therefore the generic T is used instead of a specific
|
||||
/// type
|
||||
pub struct Manager<T, I, F>
|
||||
where
|
||||
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
|
||||
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
|
||||
F: OnNewSubscription<Index = I, Event = T> + Send + Sync + 'static,
|
||||
{
|
||||
indexes: IndexTree<T, I>,
|
||||
on_new_subscription: Option<Arc<F>>,
|
||||
unsubscription_sender: mpsc::Sender<(SubId, Vec<Index<I>>)>,
|
||||
active_subscriptions: Arc<AtomicUsize>,
|
||||
background_subscription_remover: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl<T, I, F> Default for Manager<T, I, F>
|
||||
where
|
||||
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
|
||||
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
|
||||
F: OnNewSubscription<Index = I, Event = T> + Send + Sync + 'static,
|
||||
{
|
||||
fn default() -> Self {
|
||||
let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE);
|
||||
let active_subscriptions: Arc<AtomicUsize> = Default::default();
|
||||
let storage: IndexTree<T, I> = Arc::new(Default::default());
|
||||
|
||||
Self {
|
||||
background_subscription_remover: Some(tokio::spawn(Self::remove_subscription(
|
||||
receiver,
|
||||
storage.clone(),
|
||||
active_subscriptions.clone(),
|
||||
))),
|
||||
on_new_subscription: None,
|
||||
unsubscription_sender: sender,
|
||||
active_subscriptions,
|
||||
indexes: storage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I, F> From<F> for Manager<T, I, F>
|
||||
where
|
||||
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
|
||||
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
|
||||
F: OnNewSubscription<Index = I, Event = T> + Send + Sync + 'static,
|
||||
{
|
||||
fn from(value: F) -> Self {
|
||||
let mut manager: Self = Default::default();
|
||||
manager.on_new_subscription = Some(Arc::new(value));
|
||||
manager
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I, F> Manager<T, I, F>
|
||||
where
|
||||
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
|
||||
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
|
||||
F: OnNewSubscription<Index = I, Event = T> + Send + Sync + 'static,
|
||||
{
|
||||
#[inline]
|
||||
/// Broadcast an event to all listeners
|
||||
///
|
||||
/// This function takes an Arc to the storage struct, the event_id, the kind
|
||||
/// and the vent to broadcast
|
||||
async fn broadcast_impl(storage: &IndexTree<T, I>, event: T) {
|
||||
let index_storage = storage.read().await;
|
||||
let mut sent = HashSet::new();
|
||||
for index in event.to_indexes() {
|
||||
for (key, sender) in index_storage.range(index.clone()..) {
|
||||
if index.cmp_prefix(key) != Ordering::Equal {
|
||||
break;
|
||||
}
|
||||
let sub_id = key.unique_id();
|
||||
if sent.contains(&sub_id) {
|
||||
continue;
|
||||
}
|
||||
sent.insert(sub_id);
|
||||
let _ = sender.try_send((key.into(), event.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcasts an event to all listeners
|
||||
///
|
||||
/// This public method will not block the caller, it will spawn a new task
|
||||
/// instead
|
||||
pub fn broadcast(&self, event: T) {
|
||||
let storage = self.indexes.clone();
|
||||
tokio::spawn(async move {
|
||||
Self::broadcast_impl(&storage, event).await;
|
||||
});
|
||||
}
|
||||
|
||||
/// Broadcasts an event to all listeners
|
||||
///
|
||||
/// This method is async and will await for the broadcast to be completed
|
||||
pub async fn broadcast_async(&self, event: T) {
|
||||
Self::broadcast_impl(&self.indexes, event).await;
|
||||
}
|
||||
|
||||
/// Specific of the subscription, this is the abstraction between `subscribe` and `try_subscribe`
|
||||
#[inline(always)]
|
||||
async fn subscribe_inner(
|
||||
&self,
|
||||
sub_id: SubId,
|
||||
indexes: Vec<Index<I>>,
|
||||
) -> ActiveSubscription<T, I> {
|
||||
let (sender, receiver) = mpsc::channel(10);
|
||||
|
||||
let mut index_storage = self.indexes.write().await;
|
||||
// Subscribe to events as soon as possible
|
||||
for index in indexes.clone() {
|
||||
index_storage.insert(index, sender.clone());
|
||||
}
|
||||
drop(index_storage);
|
||||
|
||||
if let Some(on_new_subscription) = self.on_new_subscription.clone() {
|
||||
// After we're subscribed already, fetch the current status of matching events. It is
|
||||
// down in another thread to return right away
|
||||
let indexes_for_worker = indexes.clone();
|
||||
let sub_id_for_worker = sub_id.clone();
|
||||
tokio::spawn(async move {
|
||||
match on_new_subscription
|
||||
.on_new_subscription(
|
||||
&indexes_for_worker
|
||||
.iter()
|
||||
.map(|x| x.deref())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(events) => {
|
||||
for event in events {
|
||||
let _ = sender.try_send((sub_id_for_worker.clone(), event));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::info!(
|
||||
"Failed to get initial state for subscription: {:?}, {}",
|
||||
sub_id_for_worker,
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
self.active_subscriptions
|
||||
.fetch_add(1, atomic::Ordering::Relaxed);
|
||||
|
||||
ActiveSubscription {
|
||||
sub_id,
|
||||
receiver,
|
||||
indexes,
|
||||
drop: self.unsubscription_sender.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to subscribe to a specific event
|
||||
pub async fn try_subscribe<P>(&self, params: P) -> Result<ActiveSubscription<T, I>, P::Error>
|
||||
where
|
||||
P: AsRef<SubId> + TryInto<Vec<Index<I>>>,
|
||||
{
|
||||
Ok(self
|
||||
.subscribe_inner(params.as_ref().clone(), params.try_into()?)
|
||||
.await)
|
||||
}
|
||||
|
||||
/// Subscribe to a specific event
|
||||
pub async fn subscribe<P>(&self, params: P) -> ActiveSubscription<T, I>
|
||||
where
|
||||
P: AsRef<SubId> + Into<Vec<Index<I>>>,
|
||||
{
|
||||
self.subscribe_inner(params.as_ref().clone(), params.into())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Return number of active subscriptions
|
||||
pub fn active_subscriptions(&self) -> usize {
|
||||
self.active_subscriptions.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Task to remove dropped subscriptions from the storage struct
|
||||
///
|
||||
/// This task will run in the background (and will be dropped when the [`Manager`]
|
||||
/// is) and will remove subscriptions from the storage struct it is dropped.
|
||||
async fn remove_subscription(
|
||||
mut receiver: mpsc::Receiver<(SubId, Vec<Index<I>>)>,
|
||||
storage: IndexTree<T, I>,
|
||||
active_subscriptions: Arc<AtomicUsize>,
|
||||
) {
|
||||
while let Some((sub_id, indexes)) = receiver.recv().await {
|
||||
tracing::info!("Removing subscription: {}", *sub_id);
|
||||
|
||||
active_subscriptions.fetch_sub(1, atomic::Ordering::AcqRel);
|
||||
|
||||
let mut index_storage = storage.write().await;
|
||||
for key in indexes {
|
||||
index_storage.remove(&key);
|
||||
}
|
||||
drop(index_storage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manager goes out of scope, stop all background tasks
|
||||
impl<T, I, F> Drop for Manager<T, I, F>
|
||||
where
|
||||
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
|
||||
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
|
||||
F: OnNewSubscription<Index = I, Event = T> + Send + Sync + 'static,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
if let Some(handler) = self.background_subscription_remover.take() {
|
||||
handler.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Active Subscription
|
||||
///
|
||||
/// This struct is a wrapper around the `mpsc::Receiver<Event>` and it also used
|
||||
/// to keep track of the subscription itself. When this struct goes out of
|
||||
/// scope, it will notify the Manager about it, so it can be removed from the
|
||||
/// list of active listeners
|
||||
pub struct ActiveSubscription<T, I>
|
||||
where
|
||||
T: Send + Sync,
|
||||
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
|
||||
{
|
||||
/// The subscription ID
|
||||
pub sub_id: SubId,
|
||||
indexes: Vec<Index<I>>,
|
||||
receiver: mpsc::Receiver<(SubId, T)>,
|
||||
drop: mpsc::Sender<(SubId, Vec<Index<I>>)>,
|
||||
}
|
||||
|
||||
impl<T, I> Deref for ActiveSubscription<T, I>
|
||||
where
|
||||
T: Send + Sync,
|
||||
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
|
||||
{
|
||||
type Target = mpsc::Receiver<(SubId, T)>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.receiver
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I> DerefMut for ActiveSubscription<T, I>
|
||||
where
|
||||
T: Indexable + Clone + Send + Sync + 'static,
|
||||
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
|
||||
{
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.receiver
|
||||
}
|
||||
}
|
||||
|
||||
/// The ActiveSubscription is Drop out of scope, notify the Manager about it, so
|
||||
/// it can be removed from the list of active listeners
|
||||
///
|
||||
/// Having this in place, we can avoid memory leaks and also makes it super
|
||||
/// simple to implement the Unsubscribe method
|
||||
impl<T, I> Drop for ActiveSubscription<T, I>
|
||||
where
|
||||
T: Send + Sync,
|
||||
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
let _ = self
|
||||
.drop
|
||||
.try_send((self.sub_id.clone(), self.indexes.drain(..).collect()));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_active_subscription_drop() {
|
||||
let (tx, rx) = mpsc::channel::<(SubId, ())>(10);
|
||||
let sub_id = SubId::from("test_sub_id");
|
||||
let indexes: Vec<Index<String>> = vec![Index::from(("test".to_string(), sub_id.clone()))];
|
||||
let (drop_tx, mut drop_rx) = mpsc::channel(10);
|
||||
|
||||
{
|
||||
let _active_subscription = ActiveSubscription {
|
||||
sub_id: sub_id.clone(),
|
||||
indexes,
|
||||
receiver: rx,
|
||||
drop: drop_tx,
|
||||
};
|
||||
// When it goes out of scope, it should notify
|
||||
}
|
||||
assert_eq!(drop_rx.try_recv().unwrap().0, sub_id); // it should have notified
|
||||
assert!(tx.try_send(("foo".into(), ())).is_err()); // subscriber is dropped
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use std::sync::Arc;
|
||||
|
||||
use cdk_common::amount::FeeAndAmounts;
|
||||
use cdk_common::database::{self, WalletDatabase};
|
||||
use cdk_common::subscription::Params;
|
||||
use cdk_common::subscription::WalletParams;
|
||||
use getrandom::getrandom;
|
||||
use subscription::{ActiveSubscription, SubscriptionManager};
|
||||
#[cfg(feature = "auth")]
|
||||
@@ -108,40 +108,42 @@ pub enum WalletSubscription {
|
||||
Bolt12MintQuoteState(Vec<String>),
|
||||
}
|
||||
|
||||
impl From<WalletSubscription> for Params {
|
||||
impl From<WalletSubscription> for WalletParams {
|
||||
fn from(val: WalletSubscription) -> Self {
|
||||
let mut buffer = vec![0u8; 10];
|
||||
|
||||
getrandom(&mut buffer).expect("Failed to generate random bytes");
|
||||
|
||||
let id = buffer
|
||||
.iter()
|
||||
.map(|&byte| {
|
||||
let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9)
|
||||
ALPHANUMERIC[index] as char
|
||||
})
|
||||
.collect::<String>();
|
||||
let id = Arc::new(
|
||||
buffer
|
||||
.iter()
|
||||
.map(|&byte| {
|
||||
let index = byte as usize % ALPHANUMERIC.len(); // 62 alphanumeric characters (A-Z, a-z, 0-9)
|
||||
ALPHANUMERIC[index] as char
|
||||
})
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
match val {
|
||||
WalletSubscription::ProofState(filters) => Params {
|
||||
WalletSubscription::ProofState(filters) => WalletParams {
|
||||
filters,
|
||||
kind: Kind::ProofState,
|
||||
id: id.into(),
|
||||
id,
|
||||
},
|
||||
WalletSubscription::Bolt11MintQuoteState(filters) => Params {
|
||||
WalletSubscription::Bolt11MintQuoteState(filters) => WalletParams {
|
||||
filters,
|
||||
kind: Kind::Bolt11MintQuote,
|
||||
id: id.into(),
|
||||
id,
|
||||
},
|
||||
WalletSubscription::Bolt11MeltQuoteState(filters) => Params {
|
||||
WalletSubscription::Bolt11MeltQuoteState(filters) => WalletParams {
|
||||
filters,
|
||||
kind: Kind::Bolt11MeltQuote,
|
||||
id: id.into(),
|
||||
id,
|
||||
},
|
||||
WalletSubscription::Bolt12MintQuoteState(filters) => Params {
|
||||
WalletSubscription::Bolt12MintQuoteState(filters) => WalletParams {
|
||||
filters,
|
||||
kind: Kind::Bolt12MintQuote,
|
||||
id: id.into(),
|
||||
id,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -193,10 +195,10 @@ impl Wallet {
|
||||
}
|
||||
|
||||
/// Subscribe to events
|
||||
pub async fn subscribe<T: Into<Params>>(&self, query: T) -> ActiveSubscription {
|
||||
pub async fn subscribe<T: Into<WalletParams>>(&self, query: T) -> ActiveSubscription {
|
||||
self.subscription
|
||||
.subscribe(self.mint_url.clone(), query.into(), Arc::new(self.clone()))
|
||||
.await
|
||||
.subscribe(self.mint_url.clone(), query.into())
|
||||
.expect("FIXME")
|
||||
}
|
||||
|
||||
/// Fee required for proof set
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//! pairs
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -675,7 +676,7 @@ impl MultiMintWallet {
|
||||
// Check if this is a mint quote response with paid state
|
||||
if let crate::nuts::nut17::NotificationPayload::MintQuoteBolt11Response(
|
||||
quote_response,
|
||||
) = notification
|
||||
) = notification.deref()
|
||||
{
|
||||
if quote_response.state == QuoteState::Paid {
|
||||
// Quote is paid, now mint the tokens
|
||||
@@ -1264,7 +1265,7 @@ impl MultiMintWallet {
|
||||
/// Melt (pay invoice) with automatic wallet selection (deprecated, use specific mint functions for better control)
|
||||
///
|
||||
/// Automatically selects the best wallet to pay from based on:
|
||||
/// - Available balance
|
||||
/// - Available balance
|
||||
/// - Fees
|
||||
///
|
||||
/// # Examples
|
||||
|
||||
@@ -13,10 +13,11 @@ use futures::{FutureExt, Stream, StreamExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::RecvFuture;
|
||||
use crate::event::MintEvent;
|
||||
use crate::wallet::subscription::ActiveSubscription;
|
||||
use crate::{Wallet, WalletSubscription};
|
||||
|
||||
type SubscribeReceived = (Option<NotificationPayload<String>>, Vec<ActiveSubscription>);
|
||||
type SubscribeReceived = (Option<MintEvent<String>>, Vec<ActiveSubscription>);
|
||||
type PaymentValue = (String, Option<Amount>);
|
||||
|
||||
/// PaymentWaiter
|
||||
@@ -145,7 +146,7 @@ impl<'a> PaymentStream<'a> {
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Some(info) => {
|
||||
match info {
|
||||
match info.into_inner() {
|
||||
NotificationPayload::MintQuoteBolt11Response(info) => {
|
||||
if info.state == MintQuoteState::Paid {
|
||||
self.is_finalized = true;
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cdk_common::MintQuoteBolt12Response;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use tokio::time;
|
||||
use web_time::Duration;
|
||||
|
||||
use super::WsSubscriptionBody;
|
||||
use crate::nuts::nut17::Kind;
|
||||
use crate::nuts::{nut01, nut05, nut07, nut23, CheckStateRequest, NotificationPayload};
|
||||
use crate::pub_sub::SubId;
|
||||
use crate::wallet::MintConnector;
|
||||
use crate::Wallet;
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq)]
|
||||
enum UrlType {
|
||||
Mint(String),
|
||||
MintBolt12(String),
|
||||
Melt(String),
|
||||
PublicKey(nut01::PublicKey),
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
enum AnyState {
|
||||
MintQuoteState(nut23::QuoteState),
|
||||
MintBolt12QuoteState(MintQuoteBolt12Response<String>),
|
||||
MeltQuoteState(nut05::QuoteState),
|
||||
PublicKey(nut07::State),
|
||||
Empty,
|
||||
}
|
||||
|
||||
type SubscribedTo = HashMap<UrlType, (mpsc::Sender<NotificationPayload<String>>, SubId, AnyState)>;
|
||||
|
||||
async fn convert_subscription(
|
||||
sub_id: SubId,
|
||||
subscriptions: &Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
subscribed_to: &mut SubscribedTo,
|
||||
) -> Option<()> {
|
||||
let subscription = subscriptions.read().await;
|
||||
let sub = subscription.get(&sub_id)?;
|
||||
tracing::debug!("New subscription: {:?}", sub);
|
||||
match sub.1.kind {
|
||||
Kind::Bolt11MintQuote => {
|
||||
for id in sub.1.filters.iter().map(|id| UrlType::Mint(id.clone())) {
|
||||
subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty));
|
||||
}
|
||||
}
|
||||
Kind::Bolt11MeltQuote => {
|
||||
for id in sub.1.filters.iter().map(|id| UrlType::Melt(id.clone())) {
|
||||
subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty));
|
||||
}
|
||||
}
|
||||
Kind::ProofState => {
|
||||
for id in sub
|
||||
.1
|
||||
.filters
|
||||
.iter()
|
||||
.map(|id| nut01::PublicKey::from_hex(id).map(UrlType::PublicKey))
|
||||
{
|
||||
match id {
|
||||
Ok(id) => {
|
||||
subscribed_to
|
||||
.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty));
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("Error parsing public key: {:?}. Subscription ignored, will never yield any result", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Kind::Bolt12MintQuote => {
|
||||
for id in sub
|
||||
.1
|
||||
.filters
|
||||
.iter()
|
||||
.map(|id| UrlType::MintBolt12(id.clone()))
|
||||
{
|
||||
subscribed_to.insert(id, (sub.0.clone(), sub.1.id.clone(), AnyState::Empty));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[inline]
|
||||
pub async fn http_main<S: IntoIterator<Item = SubId>>(
|
||||
initial_state: S,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
mut new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
mut on_drop: mpsc::Receiver<SubId>,
|
||||
_wallet: Arc<Wallet>,
|
||||
) {
|
||||
let mut interval = time::interval(Duration::from_secs(2));
|
||||
let mut subscribed_to = SubscribedTo::new();
|
||||
|
||||
for sub_id in initial_state {
|
||||
convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await;
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
poll_subscriptions(&http_client, &mut subscribed_to).await;
|
||||
}
|
||||
Some(subid) = new_subscription_recv.recv() => {
|
||||
convert_subscription(subid, &subscriptions, &mut subscribed_to).await;
|
||||
}
|
||||
Some(id) = on_drop.recv() => {
|
||||
subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[inline]
|
||||
pub async fn http_main<S: IntoIterator<Item = SubId>>(
|
||||
initial_state: S,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
mut new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
mut on_drop: mpsc::Receiver<SubId>,
|
||||
_wallet: Arc<Wallet>,
|
||||
) {
|
||||
let mut subscribed_to = SubscribedTo::new();
|
||||
|
||||
for sub_id in initial_state {
|
||||
convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await;
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = gloo_timers::future::sleep(Duration::from_secs(2)) => {
|
||||
poll_subscriptions(&http_client, &mut subscribed_to).await;
|
||||
}
|
||||
subid = new_subscription_recv.recv() => {
|
||||
match subid {
|
||||
Some(subid) => {
|
||||
convert_subscription(subid, &subscriptions, &mut subscribed_to).await;
|
||||
}
|
||||
None => {
|
||||
// New subscription channel closed - SubscriptionClient was dropped, terminate worker
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
id = on_drop.recv() => {
|
||||
match id {
|
||||
Some(id) => {
|
||||
subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id);
|
||||
}
|
||||
None => {
|
||||
// Drop notification channel closed - SubscriptionClient was dropped, terminate worker
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_subscriptions(
|
||||
http_client: &Arc<dyn MintConnector + Send + Sync>,
|
||||
subscribed_to: &mut SubscribedTo,
|
||||
) {
|
||||
for (url, (sender, _, last_state)) in subscribed_to.iter_mut() {
|
||||
tracing::debug!("Polling: {:?}", url);
|
||||
match url {
|
||||
UrlType::MintBolt12(id) => {
|
||||
let response = http_client.get_mint_quote_bolt12_status(id).await;
|
||||
if let Ok(response) = response {
|
||||
if *last_state == AnyState::MintBolt12QuoteState(response.clone()) {
|
||||
continue;
|
||||
}
|
||||
*last_state = AnyState::MintBolt12QuoteState(response.clone());
|
||||
if let Err(err) =
|
||||
sender.try_send(NotificationPayload::MintQuoteBolt12Response(response))
|
||||
{
|
||||
tracing::error!("Error sending mint quote response: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
UrlType::Mint(id) => {
|
||||
let response = http_client.get_mint_quote_status(id).await;
|
||||
if let Ok(response) = response {
|
||||
if *last_state == AnyState::MintQuoteState(response.state) {
|
||||
continue;
|
||||
}
|
||||
*last_state = AnyState::MintQuoteState(response.state);
|
||||
if let Err(err) =
|
||||
sender.try_send(NotificationPayload::MintQuoteBolt11Response(response))
|
||||
{
|
||||
tracing::error!("Error sending mint quote response: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
UrlType::Melt(id) => {
|
||||
let response = http_client.get_melt_quote_status(id).await;
|
||||
if let Ok(response) = response {
|
||||
if *last_state == AnyState::MeltQuoteState(response.state) {
|
||||
continue;
|
||||
}
|
||||
*last_state = AnyState::MeltQuoteState(response.state);
|
||||
if let Err(err) =
|
||||
sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response))
|
||||
{
|
||||
tracing::error!("Error sending melt quote response: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
UrlType::PublicKey(id) => {
|
||||
let responses = http_client
|
||||
.post_check_state(CheckStateRequest { ys: vec![*id] })
|
||||
.await;
|
||||
if let Ok(mut responses) = responses {
|
||||
let response = if let Some(state) = responses.states.pop() {
|
||||
state
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if *last_state == AnyState::PublicKey(response.state) {
|
||||
continue;
|
||||
}
|
||||
*last_state = AnyState::PublicKey(response.state);
|
||||
if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) {
|
||||
tracing::error!("Error sending proof state response: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,28 +7,33 @@
|
||||
//! the HTTP client.
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use cdk_common::subscription::Params;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_futures;
|
||||
use cdk_common::nut17::ws::{WsMethodRequest, WsRequest, WsUnsubscribeRequest};
|
||||
use cdk_common::nut17::{Kind, NotificationId};
|
||||
use cdk_common::parking_lot::RwLock;
|
||||
use cdk_common::pub_sub::remote_consumer::{
|
||||
Consumer, InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport,
|
||||
};
|
||||
use cdk_common::pub_sub::{Error as PubsubError, Spec, Subscriber};
|
||||
use cdk_common::subscription::WalletParams;
|
||||
use cdk_common::CheckStateRequest;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::Wallet;
|
||||
use crate::event::MintEvent;
|
||||
use crate::mint_url::MintUrl;
|
||||
use crate::pub_sub::SubId;
|
||||
use crate::wallet::MintConnector;
|
||||
|
||||
mod http;
|
||||
#[cfg(all(
|
||||
not(feature = "http_subscription"),
|
||||
feature = "mint",
|
||||
not(target_arch = "wasm32")
|
||||
))]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod ws;
|
||||
|
||||
type WsSubscriptionBody = (mpsc::Sender<NotificationPayload>, Params);
|
||||
/// Notification Payload
|
||||
pub type NotificationPayload = crate::nuts::NotificationPayload<String>;
|
||||
|
||||
/// Type alias
|
||||
pub type ActiveSubscription = RemoteActiveConsumer<SubscriptionClient>;
|
||||
|
||||
/// Subscription manager
|
||||
///
|
||||
@@ -45,13 +50,27 @@ type WsSubscriptionBody = (mpsc::Sender<NotificationPayload>, Params);
|
||||
/// The subscribers have a simple-to-use interface, receiving an
|
||||
/// ActiveSubscription struct, which can be used to receive updates and to
|
||||
/// unsubscribe from updates automatically on the drop.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone)]
|
||||
pub struct SubscriptionManager {
|
||||
all_connections: Arc<RwLock<HashMap<MintUrl, SubscriptionClient>>>,
|
||||
all_connections: Arc<RwLock<HashMap<MintUrl, Arc<Consumer<SubscriptionClient>>>>>,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
prefer_http: bool,
|
||||
}
|
||||
|
||||
impl Debug for SubscriptionManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Subscription Manager connected to {:?}",
|
||||
self.all_connections
|
||||
.write()
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubscriptionManager {
|
||||
/// Create a new subscription manager
|
||||
pub fn new(http_client: Arc<dyn MintConnector + Send + Sync>, prefer_http: bool) -> Self {
|
||||
@@ -63,63 +82,54 @@ impl SubscriptionManager {
|
||||
}
|
||||
|
||||
/// Subscribe to updates from a mint server with a given filter
|
||||
pub async fn subscribe(
|
||||
pub fn subscribe(
|
||||
&self,
|
||||
mint_url: MintUrl,
|
||||
filter: Params,
|
||||
wallet: Arc<Wallet>,
|
||||
) -> ActiveSubscription {
|
||||
let subscription_clients = self.all_connections.read().await;
|
||||
let id = filter.id.clone();
|
||||
if let Some(subscription_client) = subscription_clients.get(&mint_url) {
|
||||
let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await;
|
||||
ActiveSubscription::new(receiver, id, on_drop_notif)
|
||||
} else {
|
||||
drop(subscription_clients);
|
||||
filter: WalletParams,
|
||||
) -> Result<RemoteActiveConsumer<SubscriptionClient>, PubsubError> {
|
||||
self.all_connections
|
||||
.write()
|
||||
.entry(mint_url.clone())
|
||||
.or_insert_with(|| {
|
||||
Consumer::new(
|
||||
SubscriptionClient {
|
||||
mint_url,
|
||||
http_client: self.http_client.clone(),
|
||||
req_id: 0.into(),
|
||||
},
|
||||
self.prefer_http,
|
||||
(),
|
||||
)
|
||||
})
|
||||
.subscribe(filter)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(
|
||||
not(feature = "http_subscription"),
|
||||
feature = "mint",
|
||||
not(target_arch = "wasm32")
|
||||
))]
|
||||
let is_ws_support = self
|
||||
.http_client
|
||||
.get_mint_info()
|
||||
.await
|
||||
.map(|info| !info.nuts.nut17.supported.is_empty())
|
||||
.unwrap_or_default();
|
||||
/// MintSubTopics
|
||||
#[derive(Clone, Default)]
|
||||
pub struct MintSubTopics {}
|
||||
|
||||
#[cfg(any(
|
||||
feature = "http_subscription",
|
||||
not(feature = "mint"),
|
||||
target_arch = "wasm32"
|
||||
))]
|
||||
let is_ws_support = false;
|
||||
#[async_trait::async_trait]
|
||||
impl Spec for MintSubTopics {
|
||||
type SubscriptionId = String;
|
||||
|
||||
let is_ws_support = if self.prefer_http {
|
||||
false
|
||||
} else {
|
||||
is_ws_support
|
||||
};
|
||||
type Event = MintEvent<String>;
|
||||
|
||||
tracing::debug!(
|
||||
"Connect to {:?} to subscribe. WebSocket is supported ({})",
|
||||
mint_url,
|
||||
is_ws_support
|
||||
);
|
||||
type Topic = NotificationId<String>;
|
||||
|
||||
let mut subscription_clients = self.all_connections.write().await;
|
||||
let subscription_client = SubscriptionClient::new(
|
||||
mint_url.clone(),
|
||||
self.http_client.clone(),
|
||||
is_ws_support,
|
||||
wallet,
|
||||
);
|
||||
let (on_drop_notif, receiver) = subscription_client.subscribe(filter).await;
|
||||
subscription_clients.insert(mint_url, subscription_client);
|
||||
type Context = ();
|
||||
|
||||
ActiveSubscription::new(receiver, id, on_drop_notif)
|
||||
}
|
||||
fn new_instance(_context: Self::Context) -> Arc<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Arc::new(Self {})
|
||||
}
|
||||
|
||||
async fn fetch_events(self: &Arc<Self>, _topics: Vec<Self::Topic>, _reply_to: Subscriber<Self>)
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,225 +139,178 @@ impl SubscriptionManager {
|
||||
/// otherwise the HTTP pool and pause will be used (which is the less efficient
|
||||
/// method).
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SubscriptionClient {
|
||||
new_subscription_notif: mpsc::Sender<SubId>,
|
||||
on_drop_notif: mpsc::Sender<SubId>,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
worker: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
type NotificationPayload = crate::nuts::NotificationPayload<String>;
|
||||
|
||||
/// Active Subscription
|
||||
pub struct ActiveSubscription {
|
||||
sub_id: Option<SubId>,
|
||||
on_drop_notif: mpsc::Sender<SubId>,
|
||||
receiver: mpsc::Receiver<NotificationPayload>,
|
||||
}
|
||||
|
||||
impl ActiveSubscription {
|
||||
fn new(
|
||||
receiver: mpsc::Receiver<NotificationPayload>,
|
||||
sub_id: SubId,
|
||||
on_drop_notif: mpsc::Sender<SubId>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sub_id: Some(sub_id),
|
||||
on_drop_notif,
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to receive a notification
|
||||
pub fn try_recv(&mut self) -> Result<Option<NotificationPayload>, Error> {
|
||||
match self.receiver.try_recv() {
|
||||
Ok(payload) => Ok(Some(payload)),
|
||||
Err(mpsc::error::TryRecvError::Empty) => Ok(None),
|
||||
Err(mpsc::error::TryRecvError::Disconnected) => Err(Error::Disconnected),
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive a notification asynchronously
|
||||
pub async fn recv(&mut self) -> Option<NotificationPayload> {
|
||||
self.receiver.recv().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ActiveSubscription {
|
||||
fn drop(&mut self) {
|
||||
if let Some(sub_id) = self.sub_id.take() {
|
||||
let _ = self.on_drop_notif.try_send(sub_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscription client error
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
/// Url error
|
||||
#[error("Could not join paths: {0}")]
|
||||
Url(#[from] crate::mint_url::Error),
|
||||
/// Disconnected from the notification channel
|
||||
#[error("Disconnected from the notification channel")]
|
||||
Disconnected,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
mint_url: MintUrl,
|
||||
req_id: AtomicUsize,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl SubscriptionClient {
|
||||
/// Create new [`SubscriptionClient`]
|
||||
pub fn new(
|
||||
url: MintUrl,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
prefer_ws_method: bool,
|
||||
wallet: Arc<Wallet>,
|
||||
) -> Self {
|
||||
let subscriptions = Arc::new(RwLock::new(HashMap::new()));
|
||||
let (new_subscription_notif, new_subscription_recv) = mpsc::channel(100);
|
||||
let (on_drop_notif, on_drop_recv) = mpsc::channel(1000);
|
||||
|
||||
Self {
|
||||
new_subscription_notif,
|
||||
on_drop_notif,
|
||||
subscriptions: subscriptions.clone(),
|
||||
worker: Self::start_worker(
|
||||
prefer_ws_method,
|
||||
http_client,
|
||||
url,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop_recv,
|
||||
wallet,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn start_worker(
|
||||
prefer_ws_method: bool,
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
url: MintUrl,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
on_drop_recv: mpsc::Receiver<SubId>,
|
||||
wallet: Arc<Wallet>,
|
||||
) -> Option<JoinHandle<()>> {
|
||||
#[cfg(any(
|
||||
feature = "http_subscription",
|
||||
not(feature = "mint"),
|
||||
target_arch = "wasm32"
|
||||
))]
|
||||
return Self::http_worker(
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop_recv,
|
||||
wallet,
|
||||
);
|
||||
|
||||
#[cfg(all(
|
||||
not(feature = "http_subscription"),
|
||||
feature = "mint",
|
||||
not(target_arch = "wasm32")
|
||||
))]
|
||||
if prefer_ws_method {
|
||||
Self::ws_worker(
|
||||
http_client,
|
||||
url,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop_recv,
|
||||
wallet,
|
||||
)
|
||||
} else {
|
||||
Self::http_worker(
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop_recv,
|
||||
wallet,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to a WebSocket channel
|
||||
pub async fn subscribe(
|
||||
fn get_sub_request(
|
||||
&self,
|
||||
filter: Params,
|
||||
) -> (mpsc::Sender<SubId>, mpsc::Receiver<NotificationPayload>) {
|
||||
let mut subscriptions = self.subscriptions.write().await;
|
||||
let id = filter.id.clone();
|
||||
id: String,
|
||||
params: NotificationId<String>,
|
||||
) -> Option<(usize, String)> {
|
||||
let (kind, filter) = match params {
|
||||
NotificationId::ProofState(x) => (Kind::ProofState, x.to_string()),
|
||||
NotificationId::MeltQuoteBolt11(q) | NotificationId::MeltQuoteBolt12(q) => {
|
||||
(Kind::Bolt11MeltQuote, q)
|
||||
}
|
||||
NotificationId::MintQuoteBolt11(q) => (Kind::Bolt11MintQuote, q),
|
||||
NotificationId::MintQuoteBolt12(q) => (Kind::Bolt12MintQuote, q),
|
||||
};
|
||||
|
||||
let (sender, receiver) = mpsc::channel(10_000);
|
||||
subscriptions.insert(id.clone(), (sender, filter));
|
||||
drop(subscriptions);
|
||||
let request: WsRequest<_> = (
|
||||
WsMethodRequest::Subscribe(WalletParams {
|
||||
kind,
|
||||
filters: vec![filter],
|
||||
id: id.into(),
|
||||
}),
|
||||
self.req_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
.into();
|
||||
|
||||
let _ = self.new_subscription_notif.send(id).await;
|
||||
(self.on_drop_notif.clone(), receiver)
|
||||
serde_json::to_string(&request)
|
||||
.inspect_err(|err| {
|
||||
tracing::error!("Could not serialize subscribe message: {:?}", err);
|
||||
})
|
||||
.map(|json| (request.id, json))
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// HTTP subscription client
|
||||
///
|
||||
/// This is a poll based subscription, where the client will poll the server
|
||||
/// from time to time to get updates, notifying the subscribers on changes
|
||||
fn http_worker(
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
on_drop: mpsc::Receiver<SubId>,
|
||||
wallet: Arc<Wallet>,
|
||||
) -> Option<JoinHandle<()>> {
|
||||
let http_worker = http::http_main(
|
||||
vec![],
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop,
|
||||
wallet,
|
||||
);
|
||||
fn get_unsub_request(&self, sub_id: String) -> Option<String> {
|
||||
let request: WsRequest<_> = (
|
||||
WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }),
|
||||
self.req_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
.into();
|
||||
|
||||
match serde_json::to_string(&request) {
|
||||
Ok(json) => Some(json),
|
||||
Err(err) => {
|
||||
tracing::error!("Could not serialize unsubscribe message: {:?}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
|
||||
impl Transport for SubscriptionClient {
|
||||
type Spec = MintSubTopics;
|
||||
|
||||
fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId {
|
||||
Uuid::new_v4().to_string()
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_ctrls: mpsc::Receiver<StreamCtrl<Self::Spec>>,
|
||||
_topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
_reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), PubsubError> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let r = ws::stream_client(self, _ctrls, _topics, _reply_to).await;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
wasm_bindgen_futures::spawn_local(http_worker);
|
||||
None
|
||||
}
|
||||
let r = Err(PubsubError::NotSupported);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
{
|
||||
Some(tokio::spawn(http_worker))
|
||||
}
|
||||
r
|
||||
}
|
||||
|
||||
/// WebSocket subscription client
|
||||
///
|
||||
/// This is a WebSocket based subscription, where the client will connect to
|
||||
/// the server and stay there idle waiting for server-side notifications
|
||||
#[cfg(all(
|
||||
not(feature = "http_subscription"),
|
||||
feature = "mint",
|
||||
not(target_arch = "wasm32")
|
||||
))]
|
||||
fn ws_worker(
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
url: MintUrl,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
on_drop: mpsc::Receiver<SubId>,
|
||||
wallet: Arc<Wallet>,
|
||||
) -> Option<JoinHandle<()>> {
|
||||
Some(tokio::spawn(ws::ws_main(
|
||||
http_client,
|
||||
url,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop,
|
||||
wallet,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SubscriptionClient {
|
||||
fn drop(&mut self) {
|
||||
if let Some(handle) = self.worker.take() {
|
||||
handle.abort();
|
||||
}
|
||||
/// Poll on demand
|
||||
async fn poll(
|
||||
&self,
|
||||
topics: Vec<SubscribeMessage<Self::Spec>>,
|
||||
reply_to: InternalRelay<Self::Spec>,
|
||||
) -> Result<(), PubsubError> {
|
||||
let proofs = topics
|
||||
.iter()
|
||||
.filter_map(|(_, x)| match &x {
|
||||
NotificationId::ProofState(p) => Some(*p),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !proofs.is_empty() {
|
||||
for state in self
|
||||
.http_client
|
||||
.post_check_state(CheckStateRequest { ys: proofs })
|
||||
.await
|
||||
.map_err(|e| PubsubError::Internal(Box::new(e)))?
|
||||
.states
|
||||
{
|
||||
reply_to.send(MintEvent::new(NotificationPayload::ProofState(state)));
|
||||
}
|
||||
}
|
||||
|
||||
for topic in topics
|
||||
.into_iter()
|
||||
.map(|(_, x)| x)
|
||||
.filter(|x| !matches!(x, NotificationId::ProofState(_)))
|
||||
{
|
||||
match topic {
|
||||
NotificationId::MintQuoteBolt11(id) => {
|
||||
let response = match self.http_client.get_mint_quote_status(&id).await {
|
||||
Ok(success) => success,
|
||||
Err(err) => {
|
||||
tracing::error!("Error with MintBolt11 {} with {:?}", id, err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
reply_to.send(MintEvent::new(
|
||||
NotificationPayload::MintQuoteBolt11Response(response.clone()),
|
||||
));
|
||||
}
|
||||
NotificationId::MeltQuoteBolt11(id) => {
|
||||
let response = match self.http_client.get_melt_quote_status(&id).await {
|
||||
Ok(success) => success,
|
||||
Err(err) => {
|
||||
tracing::error!("Error with MeltBolt11 {} with {:?}", id, err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
reply_to.send(MintEvent::new(
|
||||
NotificationPayload::MeltQuoteBolt11Response(response),
|
||||
));
|
||||
}
|
||||
NotificationId::MintQuoteBolt12(id) => {
|
||||
let response = match self.http_client.get_mint_quote_bolt12_status(&id).await {
|
||||
Ok(success) => success,
|
||||
Err(err) => {
|
||||
tracing::error!("Error with MintBolt12 {} with {:?}", id, err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
reply_to.send(MintEvent::new(
|
||||
NotificationPayload::MintQuoteBolt12Response(response),
|
||||
));
|
||||
}
|
||||
NotificationId::MeltQuoteBolt12(id) => {
|
||||
let response = match self.http_client.get_melt_bolt12_quote_status(&id).await {
|
||||
Ok(success) => success,
|
||||
Err(err) => {
|
||||
tracing::error!("Error with MeltBolt12 {} with {:?}", id, err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
reply_to.send(MintEvent::new(
|
||||
NotificationPayload::MeltQuoteBolt11Response(response),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,38 +1,25 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use cdk_common::subscription::Params;
|
||||
use cdk_common::ws::{WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest};
|
||||
use cdk_common::nut17::ws::WsMessageOrResponse;
|
||||
use cdk_common::pub_sub::remote_consumer::{InternalRelay, StreamCtrl, SubscribeMessage};
|
||||
use cdk_common::pub_sub::Error as PubsubError;
|
||||
#[cfg(feature = "auth")]
|
||||
use cdk_common::{Method, RoutePath};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::time::sleep;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use super::http::http_main;
|
||||
use super::WsSubscriptionBody;
|
||||
use crate::mint_url::MintUrl;
|
||||
use crate::pub_sub::SubId;
|
||||
use crate::wallet::MintConnector;
|
||||
use crate::Wallet;
|
||||
use super::{MintSubTopics, SubscriptionClient};
|
||||
|
||||
const MAX_ATTEMPT_FALLBACK_HTTP: usize = 10;
|
||||
|
||||
#[inline]
|
||||
pub async fn ws_main(
|
||||
http_client: Arc<dyn MintConnector + Send + Sync>,
|
||||
mint_url: MintUrl,
|
||||
subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
|
||||
mut new_subscription_recv: mpsc::Receiver<SubId>,
|
||||
mut on_drop: mpsc::Receiver<SubId>,
|
||||
wallet: Arc<Wallet>,
|
||||
) {
|
||||
let mut url = mint_url
|
||||
#[inline(always)]
|
||||
pub(crate) async fn stream_client(
|
||||
client: &SubscriptionClient,
|
||||
mut ctrl: mpsc::Receiver<StreamCtrl<MintSubTopics>>,
|
||||
topics: Vec<SubscribeMessage<MintSubTopics>>,
|
||||
reply_to: InternalRelay<MintSubTopics>,
|
||||
) -> Result<(), PubsubError> {
|
||||
let mut url = client
|
||||
.mint_url
|
||||
.join_paths(&["v1", "ws"])
|
||||
.expect("Could not join paths");
|
||||
|
||||
@@ -42,241 +29,140 @@ pub async fn ws_main(
|
||||
url.set_scheme("ws").expect("Could not set scheme");
|
||||
}
|
||||
|
||||
let request = match url.to_string().into_client_request() {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to create client request: {:?}", err);
|
||||
// Fallback to HTTP client if we can't create the WebSocket request
|
||||
return http_main(
|
||||
std::iter::empty(),
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop,
|
||||
wallet,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
};
|
||||
#[cfg(not(feature = "auth"))]
|
||||
let request = url.to_string().into_client_request().map_err(|err| {
|
||||
tracing::error!("Failed to create client request: {:?}", err);
|
||||
// Fallback to HTTP client if we can't create the WebSocket request
|
||||
cdk_common::pub_sub::Error::NotSupported
|
||||
})?;
|
||||
|
||||
let mut active_subscriptions = HashMap::<SubId, mpsc::Sender<_>>::new();
|
||||
let mut failure_count = 0;
|
||||
#[cfg(feature = "auth")]
|
||||
let mut request = url.to_string().into_client_request().map_err(|err| {
|
||||
tracing::error!("Failed to create client request: {:?}", err);
|
||||
// Fallback to HTTP client if we can't create the WebSocket request
|
||||
cdk_common::pub_sub::Error::NotSupported
|
||||
})?;
|
||||
|
||||
loop {
|
||||
if subscriptions.read().await.is_empty() {
|
||||
// No active subscription
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut request_clone = request.clone();
|
||||
#[cfg(feature = "auth")]
|
||||
{
|
||||
let auth_wallet = http_client.get_auth_wallet().await;
|
||||
let token = match auth_wallet.as_ref() {
|
||||
Some(auth_wallet) => {
|
||||
let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
|
||||
match auth_wallet.get_auth_for_request(&endpoint).await {
|
||||
Ok(token) => token,
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to get auth token: {:?}", err);
|
||||
None
|
||||
}
|
||||
#[cfg(feature = "auth")]
|
||||
{
|
||||
let auth_wallet = client.http_client.get_auth_wallet().await;
|
||||
let token = match auth_wallet.as_ref() {
|
||||
Some(auth_wallet) => {
|
||||
let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
|
||||
match auth_wallet.get_auth_for_request(&endpoint).await {
|
||||
Ok(token) => token,
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to get auth token: {:?}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(auth_token) = token {
|
||||
let header_key = match &auth_token {
|
||||
cdk_common::AuthToken::ClearAuth(_) => "Clear-auth",
|
||||
cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
|
||||
};
|
||||
|
||||
if let Some(auth_token) = token {
|
||||
let header_key = match &auth_token {
|
||||
cdk_common::AuthToken::ClearAuth(_) => "Clear-auth",
|
||||
cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
|
||||
};
|
||||
|
||||
match auth_token.to_string().parse() {
|
||||
Ok(header_value) => {
|
||||
request_clone.headers_mut().insert(header_key, header_value);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to parse auth token as header value: {:?}", err);
|
||||
}
|
||||
match auth_token.to_string().parse() {
|
||||
Ok(header_value) => {
|
||||
request.headers_mut().insert(header_key, header_value);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to parse auth token as header value: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!("Connecting to {}", url);
|
||||
let ws_stream = match connect_async(request_clone.clone()).await {
|
||||
Ok((ws_stream, _)) => ws_stream,
|
||||
Err(err) => {
|
||||
failure_count += 1;
|
||||
tracing::error!("Could not connect to server: {:?}", err);
|
||||
if failure_count > MAX_ATTEMPT_FALLBACK_HTTP {
|
||||
tracing::error!(
|
||||
"Could not connect to server after {MAX_ATTEMPT_FALLBACK_HTTP} attempts, falling back to HTTP-subscription client"
|
||||
);
|
||||
tracing::debug!("Connecting to {}", url);
|
||||
let ws_stream = connect_async(request)
|
||||
.await
|
||||
.map(|(ws_stream, _)| ws_stream)
|
||||
.map_err(|err| {
|
||||
tracing::error!("Error connecting: {err:?}");
|
||||
|
||||
return http_main(
|
||||
active_subscriptions.into_keys(),
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop,
|
||||
wallet,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
tracing::debug!("Connected to {}", url);
|
||||
cdk_common::pub_sub::Error::Internal(Box::new(err))
|
||||
})?;
|
||||
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
let req_id = AtomicUsize::new(0);
|
||||
tracing::debug!("Connected to {}", url);
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
let get_sub_request = |params: Params| -> Option<(usize, String)> {
|
||||
let request: WsRequest = (
|
||||
WsMethodRequest::Subscribe(params),
|
||||
req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
.into();
|
||||
|
||||
match serde_json::to_string(&request) {
|
||||
Ok(json) => Some((request.id, json)),
|
||||
Err(err) => {
|
||||
tracing::error!("Could not serialize subscribe message: {:?}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
for (name, index) in topics {
|
||||
let (_, req) = if let Some(req) = client.get_sub_request(name, index) {
|
||||
req
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let get_unsub_request = |sub_id: SubId| -> Option<String> {
|
||||
let request: WsRequest = (
|
||||
WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id }),
|
||||
req_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
)
|
||||
.into();
|
||||
let _ = write.send(Message::Text(req.into())).await;
|
||||
}
|
||||
|
||||
match serde_json::to_string(&request) {
|
||||
Ok(json) => Some(json),
|
||||
Err(err) => {
|
||||
tracing::error!("Could not serialize unsubscribe message: {:?}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Websocket reconnected, restore all subscriptions
|
||||
let mut subscription_requests = HashSet::new();
|
||||
|
||||
let read_subscriptions = subscriptions.read().await;
|
||||
for (sub_id, _) in active_subscriptions.iter() {
|
||||
if let Some(Some((req_id, req))) = read_subscriptions
|
||||
.get(sub_id)
|
||||
.map(|(_, params)| get_sub_request(params.clone()))
|
||||
{
|
||||
let _ = write.send(Message::Text(req.into())).await;
|
||||
subscription_requests.insert(req_id);
|
||||
}
|
||||
}
|
||||
drop(read_subscriptions);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = read.next() => {
|
||||
let msg = match msg {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => {
|
||||
if let Err(err) = write.send(Message::Close(None)).await {
|
||||
tracing::error!("Closing error {err:?}");
|
||||
}
|
||||
break
|
||||
},
|
||||
};
|
||||
let msg = match msg {
|
||||
Message::Text(msg) => msg,
|
||||
_ => continue,
|
||||
};
|
||||
let msg = match serde_json::from_str::<WsMessageOrResponse>(&msg) {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
match msg {
|
||||
WsMessageOrResponse::Notification(payload) => {
|
||||
tracing::debug!("Received notification from server: {:?}", payload);
|
||||
let _ = active_subscriptions.get(&payload.params.sub_id).map(|sender| {
|
||||
let _ = sender.try_send(payload.params.payload);
|
||||
});
|
||||
}
|
||||
WsMessageOrResponse::Response(response) => {
|
||||
tracing::debug!("Received response from server: {:?}", response);
|
||||
subscription_requests.remove(&response.id);
|
||||
// reset connection failure after a successful response from the serer
|
||||
failure_count = 0;
|
||||
}
|
||||
WsMessageOrResponse::ErrorResponse(error) => {
|
||||
tracing::error!("Received error from server: {:?}", error);
|
||||
|
||||
if subscription_requests.contains(&error.id) {
|
||||
failure_count += 1;
|
||||
if failure_count > MAX_ATTEMPT_FALLBACK_HTTP {
|
||||
tracing::error!(
|
||||
"Falling back to HTTP client"
|
||||
);
|
||||
|
||||
return http_main(
|
||||
active_subscriptions.into_keys(),
|
||||
http_client,
|
||||
subscriptions,
|
||||
new_subscription_recv,
|
||||
on_drop,
|
||||
wallet,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Err(err) = write.send(Message::Close(None)).await {
|
||||
tracing::error!("Closing error {err:?}");
|
||||
}
|
||||
|
||||
break; // break connection to force a reconnection, to attempt to recover form this error
|
||||
}
|
||||
}
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = ctrl.recv() => {
|
||||
match msg {
|
||||
StreamCtrl::Subscribe(msg) => {
|
||||
let (_, req) = if let Some(req) = client.get_sub_request(msg.0, msg.1) {
|
||||
req
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let _ = write.send(Message::Text(req.into())).await;
|
||||
}
|
||||
|
||||
}
|
||||
Some(subid) = new_subscription_recv.recv() => {
|
||||
let subscription = subscriptions.read().await;
|
||||
let sub = if let Some(subscription) = subscription.get(&subid) {
|
||||
subscription
|
||||
} else {
|
||||
continue
|
||||
};
|
||||
tracing::debug!("Subscribing to {:?}", sub.1);
|
||||
active_subscriptions.insert(subid, sub.0.clone());
|
||||
if let Some((req_id, json)) = get_sub_request(sub.1.clone()) {
|
||||
let _ = write.send(Message::Text(json.into())).await;
|
||||
subscription_requests.insert(req_id);
|
||||
StreamCtrl::Unsubscribe(msg) => {
|
||||
let req = if let Some(req) = client.get_unsub_request(msg) {
|
||||
req
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let _ = write.send(Message::Text(req.into())).await;
|
||||
}
|
||||
},
|
||||
Some(subid) = on_drop.recv() => {
|
||||
let mut subscription = subscriptions.write().await;
|
||||
if let Some(sub) = subscription.remove(&subid) {
|
||||
drop(sub);
|
||||
}
|
||||
tracing::debug!("Unsubscribing from {:?}", subid);
|
||||
if let Some(json) = get_unsub_request(subid) {
|
||||
let _ = write.send(Message::Text(json.into())).await;
|
||||
}
|
||||
|
||||
if subscription.is_empty() {
|
||||
StreamCtrl::Stop => {
|
||||
if let Err(err) = write.send(Message::Close(None)).await {
|
||||
tracing::error!("Closing error {err:?}");
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
Some(msg) = read.next() => {
|
||||
let msg = match msg {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => {
|
||||
if let Err(err) = write.send(Message::Close(None)).await {
|
||||
tracing::error!("Closing error {err:?}");
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
let msg = match msg {
|
||||
Message::Text(msg) => msg,
|
||||
_ => continue,
|
||||
};
|
||||
let msg = match serde_json::from_str::<WsMessageOrResponse<String>>(&msg) {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
match msg {
|
||||
WsMessageOrResponse::Notification(payload) => {
|
||||
reply_to.send(payload.params.payload);
|
||||
}
|
||||
WsMessageOrResponse::Response(response) => {
|
||||
tracing::debug!("Received response from server: {:?}", response);
|
||||
}
|
||||
WsMessageOrResponse::ErrorResponse(error) => {
|
||||
tracing::debug!("Received an error from server: {:?}", error);
|
||||
return Err(PubsubError::InternalStr(error.error.message));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
4
justfile
4
justfile
@@ -66,14 +66,14 @@ test:
|
||||
|
||||
|
||||
# run doc tests
|
||||
test-pure db="memory": build
|
||||
test-pure db="memory":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
if [ ! -f Cargo.toml ]; then
|
||||
cd {{invocation_directory()}}
|
||||
fi
|
||||
|
||||
# Run pure integration tests
|
||||
# Run pure integration tests (cargo test will only build what's needed for the test)
|
||||
CDK_TEST_DB_TYPE={{db}} cargo test -p cdk-integration-tests --test integration_tests_pure -- --test-threads 1
|
||||
|
||||
test-all db="memory":
|
||||
|
||||
Reference in New Issue
Block a user