diff --git a/crates/cdk/src/cdk_database/wallet_memory.rs b/crates/cdk/src/cdk_database/wallet_memory.rs index 4bee0967..e5065dee 100644 --- a/crates/cdk/src/cdk_database/wallet_memory.rs +++ b/crates/cdk/src/cdk_database/wallet_memory.rs @@ -4,7 +4,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use async_trait::async_trait; -use tokio::sync::Mutex; +use tokio::sync::RwLock; use super::WalletDatabase; use crate::cdk_database::Error; @@ -14,14 +14,14 @@ use crate::url::UncheckedUrl; #[derive(Default, Debug, Clone)] pub struct WalletMemoryDatabase { - mints: Arc>>>, - mint_keysets: Arc>>>, - mint_quotes: Arc>>, - melt_quotes: Arc>>, - mint_keys: Arc>>, - proofs: Arc>>>, - pending_proofs: Arc>>>, - keyset_counter: Arc>>, + mints: Arc>>>, + mint_keysets: Arc>>>, + mint_quotes: Arc>>, + melt_quotes: Arc>>, + mint_keys: Arc>>, + proofs: Arc>>>, + pending_proofs: Arc>>>, + keyset_counter: Arc>>, } impl WalletMemoryDatabase { @@ -32,20 +32,20 @@ impl WalletMemoryDatabase { keyset_counter: HashMap, ) -> Self { Self { - mints: Arc::new(Mutex::new(HashMap::new())), - mint_keysets: Arc::new(Mutex::new(HashMap::new())), - mint_quotes: Arc::new(Mutex::new( + mints: Arc::new(RwLock::new(HashMap::new())), + mint_keysets: Arc::new(RwLock::new(HashMap::new())), + mint_quotes: Arc::new(RwLock::new( mint_quotes.into_iter().map(|q| (q.id.clone(), q)).collect(), )), - melt_quotes: Arc::new(Mutex::new( + melt_quotes: Arc::new(RwLock::new( melt_quotes.into_iter().map(|q| (q.id.clone(), q)).collect(), )), - mint_keys: Arc::new(Mutex::new( + mint_keys: Arc::new(RwLock::new( mint_keys.into_iter().map(|k| (Id::from(&k), k)).collect(), )), - proofs: Arc::new(Mutex::new(HashMap::new())), - pending_proofs: Arc::new(Mutex::new(HashMap::new())), - keyset_counter: Arc::new(Mutex::new(keyset_counter)), + proofs: Arc::new(RwLock::new(HashMap::new())), + pending_proofs: Arc::new(RwLock::new(HashMap::new())), + keyset_counter: Arc::new(RwLock::new(keyset_counter)), } } } @@ -60,16 +60,16 @@ impl WalletDatabase for WalletMemoryDatabase { mint_url: UncheckedUrl, mint_info: Option, ) -> Result<(), Self::Err> { - self.mints.lock().await.insert(mint_url, mint_info); + self.mints.write().await.insert(mint_url, mint_info); Ok(()) } async fn get_mint(&self, mint_url: UncheckedUrl) -> Result, Self::Err> { - Ok(self.mints.lock().await.get(&mint_url).cloned().flatten()) + Ok(self.mints.read().await.get(&mint_url).cloned().flatten()) } async fn get_mints(&self) -> Result>, Error> { - Ok(self.mints.lock().await.clone()) + Ok(self.mints.read().await.clone()) } async fn add_mint_keysets( @@ -77,7 +77,7 @@ impl WalletDatabase for WalletMemoryDatabase { mint_url: UncheckedUrl, keysets: Vec, ) -> Result<(), Error> { - let mut current_keysets = self.mint_keysets.lock().await; + let mut current_keysets = self.mint_keysets.write().await; let mint_keysets = current_keysets.entry(mint_url).or_insert(HashSet::new()); mint_keysets.extend(keysets); @@ -91,7 +91,7 @@ impl WalletDatabase for WalletMemoryDatabase { ) -> Result>, Error> { Ok(self .mint_keysets - .lock() + .read() .await .get(&mint_url) .map(|ks| ks.iter().cloned().collect())) @@ -99,61 +99,61 @@ impl WalletDatabase for WalletMemoryDatabase { async fn add_mint_quote(&self, quote: MintQuote) -> Result<(), Error> { self.mint_quotes - .lock() + .write() .await .insert(quote.id.clone(), quote); Ok(()) } async fn get_mint_quote(&self, quote_id: &str) -> Result, Error> { - Ok(self.mint_quotes.lock().await.get(quote_id).cloned()) + Ok(self.mint_quotes.read().await.get(quote_id).cloned()) } async fn get_mint_quotes(&self) -> Result, Error> { - let quotes = self.mint_quotes.lock().await; + let quotes = self.mint_quotes.read().await; Ok(quotes.values().cloned().collect()) } async fn remove_mint_quote(&self, quote_id: &str) -> Result<(), Error> { - self.mint_quotes.lock().await.remove(quote_id); + self.mint_quotes.write().await.remove(quote_id); Ok(()) } async fn add_melt_quote(&self, quote: MeltQuote) -> Result<(), Error> { self.melt_quotes - .lock() + .write() .await .insert(quote.id.clone(), quote); Ok(()) } async fn get_melt_quote(&self, quote_id: &str) -> Result, Error> { - Ok(self.melt_quotes.lock().await.get(quote_id).cloned()) + Ok(self.melt_quotes.read().await.get(quote_id).cloned()) } async fn remove_melt_quote(&self, quote_id: &str) -> Result<(), Error> { - self.melt_quotes.lock().await.remove(quote_id); + self.melt_quotes.write().await.remove(quote_id); Ok(()) } async fn add_keys(&self, keys: Keys) -> Result<(), Error> { - self.mint_keys.lock().await.insert(Id::from(&keys), keys); + self.mint_keys.write().await.insert(Id::from(&keys), keys); Ok(()) } async fn get_keys(&self, id: &Id) -> Result, Error> { - Ok(self.mint_keys.lock().await.get(id).cloned()) + Ok(self.mint_keys.read().await.get(id).cloned()) } async fn remove_keys(&self, id: &Id) -> Result<(), Error> { - self.mint_keys.lock().await.remove(id); + self.mint_keys.write().await.remove(id); Ok(()) } async fn add_proofs(&self, mint_url: UncheckedUrl, proofs: Proofs) -> Result<(), Error> { - let mut all_proofs = self.proofs.lock().await; + let mut all_proofs = self.proofs.write().await; let mint_proofs = all_proofs.entry(mint_url).or_insert(HashSet::new()); mint_proofs.extend(proofs); @@ -164,14 +164,14 @@ impl WalletDatabase for WalletMemoryDatabase { async fn get_proofs(&self, mint_url: UncheckedUrl) -> Result, Error> { Ok(self .proofs - .lock() + .read() .await .get(&mint_url) .map(|p| p.iter().cloned().collect())) } async fn remove_proofs(&self, mint_url: UncheckedUrl, proofs: &Proofs) -> Result<(), Error> { - let mut mint_proofs = self.proofs.lock().await; + let mut mint_proofs = self.proofs.write().await; if let Some(mint_proofs) = mint_proofs.get_mut(&mint_url) { for proof in proofs { @@ -187,7 +187,7 @@ impl WalletDatabase for WalletMemoryDatabase { mint_url: UncheckedUrl, proofs: Proofs, ) -> Result<(), Error> { - let mut all_proofs = self.pending_proofs.lock().await; + let mut all_proofs = self.pending_proofs.write().await; let mint_proofs = all_proofs.entry(mint_url).or_insert(HashSet::new()); mint_proofs.extend(proofs); @@ -198,7 +198,7 @@ impl WalletDatabase for WalletMemoryDatabase { async fn get_pending_proofs(&self, mint_url: UncheckedUrl) -> Result, Error> { Ok(self .pending_proofs - .lock() + .read() .await .get(&mint_url) .map(|p| p.iter().cloned().collect())) @@ -209,7 +209,7 @@ impl WalletDatabase for WalletMemoryDatabase { mint_url: UncheckedUrl, proofs: &Proofs, ) -> Result<(), Error> { - let mut mint_proofs = self.pending_proofs.lock().await; + let mut mint_proofs = self.pending_proofs.write().await; if let Some(mint_proofs) = mint_proofs.get_mut(&mint_url) { for proof in proofs { @@ -221,16 +221,16 @@ impl WalletDatabase for WalletMemoryDatabase { } async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Error> { - let keyset_counter = self.keyset_counter.lock().await; + let keyset_counter = self.keyset_counter.read().await; let current_counter = keyset_counter.get(keyset_id).unwrap_or(&0); self.keyset_counter - .lock() + .write() .await .insert(*keyset_id, current_counter + count); Ok(()) } async fn get_keyset_counter(&self, id: &Id) -> Result, Error> { - Ok(self.keyset_counter.lock().await.get(id).cloned()) + Ok(self.keyset_counter.read().await.get(id).cloned()) } }