diff --git a/crates/cdk-common/src/database/wallet.rs b/crates/cdk-common/src/database/wallet.rs index 0619ad38..8889b471 100644 --- a/crates/cdk-common/src/database/wallet.rs +++ b/crates/cdk-common/src/database/wallet.rs @@ -96,6 +96,13 @@ pub trait Database: Debug { state: Option>, spending_conditions: Option>, ) -> Result, Self::Err>; + /// Get balance + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result; /// Update proofs state in storage async fn update_proofs_state(&self, ys: Vec, state: State) -> Result<(), Self::Err>; diff --git a/crates/cdk-ffi/src/database.rs b/crates/cdk-ffi/src/database.rs index 578681b6..cdf6eb57 100644 --- a/crates/cdk-ffi/src/database.rs +++ b/crates/cdk-ffi/src/database.rs @@ -109,6 +109,14 @@ pub trait WalletDatabase: Send + Sync { spending_conditions: Option>, ) -> Result, FfiError>; + /// Get balance efficiently using SQL aggregation + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result; + /// Update proofs state in storage async fn update_proofs_state( &self, @@ -465,6 +473,22 @@ impl CdkWalletDatabase for WalletDatabaseBridge { cdk_result } + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result { + let ffi_mint_url = mint_url.map(Into::into); + let ffi_unit = unit.map(Into::into); + let ffi_state = state.map(|s| s.into_iter().map(Into::into).collect()); + + self.ffi_db + .get_balance(ffi_mint_url, ffi_unit, ffi_state) + .await + .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into())) + } + async fn update_proofs_state( &self, ys: Vec, diff --git a/crates/cdk-ffi/src/postgres.rs b/crates/cdk-ffi/src/postgres.rs index f128b53f..4c46a69b 100644 --- a/crates/cdk-ffi/src/postgres.rs +++ b/crates/cdk-ffi/src/postgres.rs @@ -289,6 +289,22 @@ impl WalletDatabase for WalletPostgresDatabase { Ok(result.into_iter().map(Into::into).collect()) } + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result { + let cdk_mint_url = mint_url.map(|u| u.try_into()).transpose()?; + let cdk_unit = unit.map(Into::into); + let cdk_state = state.map(|s| s.into_iter().map(Into::into).collect()); + + self.inner + .get_balance(cdk_mint_url, cdk_unit, cdk_state) + .await + .map_err(|e| FfiError::Database { msg: e.to_string() }) + } + async fn update_proofs_state( &self, ys: Vec, diff --git a/crates/cdk-ffi/src/sqlite.rs b/crates/cdk-ffi/src/sqlite.rs index 8feae6e5..b48bb36a 100644 --- a/crates/cdk-ffi/src/sqlite.rs +++ b/crates/cdk-ffi/src/sqlite.rs @@ -324,6 +324,22 @@ impl WalletDatabase for WalletSqliteDatabase { Ok(result.into_iter().map(Into::into).collect()) } + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result { + let cdk_mint_url = mint_url.map(|u| u.try_into()).transpose()?; + let cdk_unit = unit.map(Into::into); + let cdk_state = state.map(|s| s.into_iter().map(Into::into).collect()); + + self.inner + .get_balance(cdk_mint_url, cdk_unit, cdk_state) + .await + .map_err(|e| FfiError::Database { msg: e.to_string() }) + } + async fn update_proofs_state( &self, ys: Vec, diff --git a/crates/cdk-redb/src/wallet/mod.rs b/crates/cdk-redb/src/wallet/mod.rs index 37862fc2..f25c712e 100644 --- a/crates/cdk-redb/src/wallet/mod.rs +++ b/crates/cdk-redb/src/wallet/mod.rs @@ -721,6 +721,18 @@ impl WalletDatabase for WalletRedbDatabase { Ok(proofs) } + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + state: Option>, + ) -> Result { + // For redb, we still need to fetch all proofs and sum them + // since redb doesn't have SQL aggregation + let proofs = self.get_proofs(mint_url, unit, state, None).await?; + Ok(proofs.iter().map(|p| u64::from(p.proof.amount)).sum()) + } + async fn update_proofs_state( &self, ys: Vec, diff --git a/crates/cdk-sql-common/src/wallet/mod.rs b/crates/cdk-sql-common/src/wallet/mod.rs index 8e2817d2..64371613 100644 --- a/crates/cdk-sql-common/src/wallet/mod.rs +++ b/crates/cdk-sql-common/src/wallet/mod.rs @@ -836,6 +836,70 @@ ON CONFLICT(id) DO UPDATE SET .collect::>()) } + async fn get_balance( + &self, + mint_url: Option, + unit: Option, + states: Option>, + ) -> Result { + let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?; + + let mut query_str = "SELECT COALESCE(SUM(amount), 0) as total FROM proof".to_string(); + let mut where_clauses = Vec::new(); + let states = states + .unwrap_or_default() + .into_iter() + .map(|x| x.to_string()) + .collect::>(); + + if mint_url.is_some() { + where_clauses.push("mint_url = :mint_url"); + } + if unit.is_some() { + where_clauses.push("unit = :unit"); + } + if !states.is_empty() { + where_clauses.push("state IN (:states)"); + } + + if !where_clauses.is_empty() { + query_str.push_str(" WHERE "); + query_str.push_str(&where_clauses.join(" AND ")); + } + + let mut q = query(&query_str)?; + + if let Some(ref mint_url) = mint_url { + q = q.bind("mint_url", mint_url.to_string()); + } + if let Some(ref unit) = unit { + q = q.bind("unit", unit.to_string()); + } + + if !states.is_empty() { + q = q.bind_vec("states", states); + } + + let balance = q + .pluck(&*conn) + .await? + .map(|n| { + // SQLite SUM returns INTEGER which we need to convert to u64 + match n { + crate::stmt::Column::Integer(i) => Ok(i as u64), + crate::stmt::Column::Real(f) => Ok(f as u64), + _ => Err(Error::Database(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid balance type", + )))), + } + }) + .transpose()? + .unwrap_or(0); + + Ok(balance) + } + async fn update_proofs_state(&self, ys: Vec, state: State) -> Result<(), Self::Err> { let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?; query("UPDATE proof SET state = :state WHERE y IN (:ys)")? diff --git a/crates/cdk/src/wallet/balance.rs b/crates/cdk/src/wallet/balance.rs index ce095162..1f83002d 100644 --- a/crates/cdk/src/wallet/balance.rs +++ b/crates/cdk/src/wallet/balance.rs @@ -1,13 +1,23 @@ use tracing::instrument; use crate::nuts::nut00::ProofsMethods; +use crate::nuts::State; use crate::{Amount, Error, Wallet}; impl Wallet { /// Total unspent balance of wallet #[instrument(skip(self))] pub async fn total_balance(&self) -> Result { - Ok(self.get_unspent_proofs().await?.total_amount()?) + // Use the efficient balance query instead of fetching all proofs + let balance = self + .localstore + .get_balance( + Some(self.mint_url.clone()), + Some(self.unit.clone()), + Some(vec![State::Unspent]), + ) + .await?; + Ok(Amount::from(balance)) } /// Total pending balance diff --git a/crates/cdk/src/wallet/subscription/mod.rs b/crates/cdk/src/wallet/subscription/mod.rs index 692a6ad1..12143c38 100644 --- a/crates/cdk/src/wallet/subscription/mod.rs +++ b/crates/cdk/src/wallet/subscription/mod.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use cdk_common::subscription::Params; use tokio::sync::{mpsc, RwLock}; use tokio::task::JoinHandle; -use tracing::error; #[cfg(target_arch = "wasm32")] use wasm_bindgen_futures;