fix: Send current state of the subscriptions (#444)

This commit is contained in:
César D. Rodas
2024-11-10 09:08:44 -03:00
committed by GitHub
parent 70ef5a4859
commit cc5b267367
9 changed files with 236 additions and 19 deletions

View File

@@ -11,17 +11,26 @@ use cdk::{
pub struct Method(Params);
#[derive(Debug, Clone, serde::Serialize)]
/// The response to a subscription request
pub struct Response {
/// Status
status: String,
/// Subscription ID
#[serde(rename = "subId")]
sub_id: SubId,
}
#[derive(Debug, Clone, serde::Serialize)]
/// The notification
///
/// This is the notification that is sent to the client when an event matches a
/// subscription
pub struct Notification {
/// The subscription ID
#[serde(rename = "subId")]
pub sub_id: SubId,
/// The notification payload
pub payload: NotificationPayload,
}
@@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
impl WsHandle for Method {
type Response = Response;
/// The `handle` method is called when a client sends a subscription request
async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
let sub_id = self.0.id.clone();
if context.subscriptions.contains_key(&sub_id) {
// Subscription ID already exits. Returns an error instead of
// replacing the other subscription or avoiding it.
return Err(WsError::InvalidParams);
}
let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;
let mut subscription = context
.state
.mint
.pubsub_manager
.subscribe(self.0.clone())
.await;
let publisher = context.publisher.clone();
context.subscriptions.insert(
sub_id.clone(),

View File

@@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
.unwrap();
let mut response: serde_json::Value =
serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
serde_json::from_str(msg.to_text().unwrap()).expect("valid json");
let mut params_raw = response
.as_object_mut()
@@ -112,6 +112,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
assert!(melt_response.preimage.is_some());
assert!(melt_response.state == MeltQuoteState::Paid);
let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
// first message is the current state
assert_eq!("test-sub", sub_id);
let payload = match payload {
NotificationPayload::MeltQuoteBolt11Response(melt) => melt,
_ => panic!("Wrong payload"),
};
assert_eq!(payload.amount + payload.fee_reserve, 100.into());
assert_eq!(payload.quote, melt.id);
assert_eq!(payload.state, MeltQuoteState::Unpaid);
// get current state
let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
assert_eq!("test-sub", sub_id);
let payload = match payload {

View File

@@ -39,7 +39,7 @@ serde_json = "1"
serde_with = "3"
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
thiserror = "1"
futures = { version = "0.3.28", default-features = false, optional = true }
futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] }
url = "2.3"
utoipa = { version = "4", optional = true }
uuid = { version = "1", features = ["v4"] }

View File

@@ -185,7 +185,7 @@ impl Mint {
Ok(Self {
mint_url: MintUrl::from_str(mint_url)?,
keysets: Arc::new(RwLock::new(active_keysets)),
pubsub_manager: Default::default(),
pubsub_manager: Arc::new(localstore.clone().into()),
secp_ctx,
quote_ttl,
xpriv,

View File

@@ -18,6 +18,7 @@ pub mod nut12;
pub mod nut13;
pub mod nut14;
pub mod nut15;
#[cfg(feature = "mint")]
pub mod nut17;
pub mod nut18;
@@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions};
pub use nut12::{BlindSignatureDleq, ProofDleq};
pub use nut14::HTLCWitness;
pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings};
#[cfg(feature = "mint")]
pub use nut17::{NotificationPayload, PubSubManager};
pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport};

View File

@@ -5,7 +5,7 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::nut01::PublicKey;
use super::{nut04, nut05, nut15, nut17, MppMethodSettings};
use super::{nut04, nut05, nut15, MppMethodSettings};
/// Mint Version
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -238,7 +238,8 @@ pub struct Nuts {
/// NUT17 Settings
#[serde(default)]
#[serde(rename = "17")]
pub nut17: nut17::SupportedSettings,
#[cfg(feature = "mint")]
pub nut17: super::nut17::SupportedSettings,
}
impl Nuts {

View File

@@ -1,5 +1,8 @@
//! Specific Subscription for the cdk crate
use super::{BlindSignature, CurrencyUnit, PaymentMethod};
use crate::cdk_database::{self, MintDatabase};
pub use crate::pub_sub::SubId;
use crate::{
nuts::{
MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
@@ -8,7 +11,11 @@ use crate::{
pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
};
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::{ops::Deref, sync::Arc};
mod on_subscription;
pub use on_subscription::OnSubscription;
/// Subscription Parameter according to the standard
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -57,10 +64,6 @@ impl Default for SupportedMethods {
}
}
pub use crate::pub_sub::SubId;
use super::{BlindSignature, CurrencyUnit, PaymentMethod};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
/// Subscription response
@@ -145,15 +148,27 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
}
/// Manager
#[derive(Default)]
/// Publishsubscribe manager
///
/// Nut-17 implementation is system-wide and not only through the WebSocket, so
/// it is possible for another part of the system to subscribe to events.
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>);
#[allow(clippy::default_constructed_unit_structs)]
impl Default for PubSubManager {
fn default() -> Self {
PubSubManager(OnSubscription::default().into())
}
}
impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
PubSubManager(OnSubscription(Some(val)).into())
}
}
impl Deref for PubSubManager {
type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
type Target = pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>;
fn deref(&self) -> &Self::Target {
&self.0

View File

@@ -0,0 +1,110 @@
//! On Subscription
//!
//! This module contains the code that is triggered when a new subscription is created.
use super::{Kind, NotificationPayload};
use crate::{
cdk_database::{self, MintDatabase},
nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey},
pub_sub::OnNewSubscription,
};
use std::{collections::HashMap, sync::Arc};
#[derive(Default)]
/// Subscription Init
///
/// This struct triggers code when a new subscription is created.
///
/// It is used to send the initial state of the subscription to the client.
pub struct OnSubscription(
pub(crate) Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>,
);
#[async_trait::async_trait]
impl OnNewSubscription for OnSubscription {
type Event = NotificationPayload;
type Index = (String, Kind);
async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String> {
let datastore = if let Some(localstore) = self.0.as_ref() {
localstore
} else {
return Ok(vec![]);
};
let mut to_return = vec![];
for (kind, values) in request.iter().fold(
HashMap::new(),
|mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
acc.entry(kind).or_default().push(data);
acc
},
) {
match kind {
Kind::Bolt11MeltQuote => {
let queries = values
.iter()
.map(|id| datastore.get_melt_quote(id))
.collect::<Vec<_>>();
to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MeltQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::Bolt11MintQuote => {
let queries = values
.iter()
.map(|id| datastore.get_mint_quote(id))
.collect::<Vec<_>>();
to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MintQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::ProofState => {
let public_keys = values
.iter()
.map(PublicKey::from_hex)
.collect::<Result<Vec<PublicKey>, _>>()
.map_err(|e| e.to_string())?;
to_return.extend(
datastore
.get_proofs_states(&public_keys)
.await
.map_err(|e| e.to_string())?
.into_iter()
.enumerate()
.filter_map(|(idx, state)| {
state.map(|state| (public_keys[idx], state).into())
})
.map(|state: ProofState| state.into()),
);
}
}
}
Ok(to_return)
}
}

View File

@@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
/// Default channel size for subscription buffering
pub const DEFAULT_CHANNEL_SIZE: usize = 10;
#[async_trait::async_trait]
/// On New Subscription trait
///
/// This trait is optional and it is used to notify the application when a new
/// subscription is created. This is useful when the application needs to send
/// the initial state to the subscriber upon subscription
pub trait OnNewSubscription {
/// Index type
type Index;
/// Subscription event type
type Event;
/// Called when a new subscription is created
async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String>;
}
/// Subscription manager
///
/// This object keep track of all subscription listener and it is also
@@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10;
/// The content of the notification is not relevant to this scope and it is up
/// to the application, therefore the generic T is used instead of a specific
/// type
pub struct Manager<T, I>
pub struct Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
indexes: IndexTree<T, I>,
on_new_subscription: Option<F>,
unsubscription_sender: mpsc::Sender<(SubId, Vec<Index<I>>)>,
active_subscriptions: Arc<AtomicUsize>,
background_subscription_remover: Option<JoinHandle<()>>,
}
impl<T, I> Default for Manager<T, I>
impl<T, I, F> Default for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn default() -> Self {
let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE);
@@ -72,6 +94,7 @@ where
storage.clone(),
active_subscriptions.clone(),
))),
on_new_subscription: None,
unsubscription_sender: sender,
active_subscriptions,
indexes: storage,
@@ -79,10 +102,24 @@ where
}
}
impl<T, I> Manager<T, I>
impl<T, I, F> From<F> for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn from(value: F) -> Self {
let mut manager: Self = Default::default();
manager.on_new_subscription = Some(value);
manager
}
}
impl<T, I, F> Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
#[inline]
/// Broadcast an event to all listeners
@@ -132,8 +169,29 @@ where
) -> ActiveSubscription<T, I> {
let (sender, receiver) = mpsc::channel(10);
let sub_id: SubId = params.as_ref().clone();
let indexes: Vec<Index<I>> = params.into();
if let Some(on_new_subscription) = self.on_new_subscription.as_ref() {
match on_new_subscription
.on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::<Vec<_>>())
.await
{
Ok(events) => {
for event in events {
let _ = sender.try_send((sub_id.clone(), event));
}
}
Err(err) => {
tracing::info!(
"Failed to get initial state for subscription: {:?}, {}",
sub_id,
err
);
}
}
}
let mut index_storage = self.indexes.write().await;
for index in indexes.clone() {
index_storage.insert(index, sender.clone());
@@ -180,10 +238,11 @@ where
}
/// Manager goes out of scope, stop all background tasks
impl<T, I> Drop for Manager<T, I>
impl<T, I, F> Drop for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn drop(&mut self) {
if let Some(handler) = self.background_subscription_remover.take() {