diff --git a/Cargo.toml b/Cargo.toml index 697a44d7..b2b8e627 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ lightning-invoice = { version = "0.32.0", features = ["serde", "std"] } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = { version = "1" } -tokio = { version = "1", default-features = false } +tokio = { version = "1", default-features = false, features = ["rt", "macros", "test-util"] } tokio-util = { version = "0.7.11", default-features = false } tower-http = { version = "0.6.1", features = ["compression-full", "decompression-full", "cors", "trace"] } tokio-tungstenite = { version = "0.26.0", default-features = false } diff --git a/crates/cdk-common/src/database/mod.rs b/crates/cdk-common/src/database/mod.rs index f9a0e5ca..06d86524 100644 --- a/crates/cdk-common/src/database/mod.rs +++ b/crates/cdk-common/src/database/mod.rs @@ -31,4 +31,10 @@ pub enum Error { /// Unknown Quote #[error("Unknown Quote")] UnknownQuote, + /// Attempt to remove spent proof + #[error("Attempt to remove spent proof")] + AttemptRemoveSpentProof, + /// Attempt to update state of spent proof + #[error("Attempt to update state of spent proof")] + AttemptUpdateSpentProof, } diff --git a/crates/cdk-redb/Cargo.toml b/crates/cdk-redb/Cargo.toml index 1f600046..b839b6c5 100644 --- a/crates/cdk-redb/Cargo.toml +++ b/crates/cdk-redb/Cargo.toml @@ -25,3 +25,7 @@ serde.workspace = true serde_json.workspace = true lightning-invoice.workspace = true uuid.workspace = true + +[dev-dependencies] +tempfile = "3.17.1" +tokio.workspace = true diff --git a/crates/cdk-redb/src/mint/mod.rs b/crates/cdk-redb/src/mint/mod.rs index ef1cdf92..86972dc5 100644 --- a/crates/cdk-redb/src/mint/mod.rs +++ b/crates/cdk-redb/src/mint/mod.rs @@ -1,7 +1,7 @@ //! SQLite Storage for CDK use std::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -558,22 +558,36 @@ impl MintDatabase for MintRedbDatabase { ) -> Result<(), Self::Err> { let write_txn = self.db.begin_write().map_err(Error::from)?; - { - let mut proofs_table = write_txn.open_table(PROOFS_TABLE).map_err(Error::from)?; - - for y in ys { - proofs_table.remove(&y.to_bytes()).map_err(Error::from)?; - } - } + let mut states: HashSet = HashSet::new(); { let mut proof_state_table = write_txn .open_table(PROOFS_STATE_TABLE) .map_err(Error::from)?; for y in ys { - proof_state_table + let state = proof_state_table .remove(&y.to_bytes()) .map_err(Error::from)?; + + if let Some(state) = state { + let state: State = serde_json::from_str(state.value()).map_err(Error::from)?; + + states.insert(state); + } + } + } + + if states.contains(&State::Spent) { + tracing::warn!("Db attempted to remove spent proof"); + write_txn.abort().map_err(Error::from)?; + return Err(Self::Err::AttemptRemoveSpentProof); + } + + { + let mut proofs_table = write_txn.open_table(PROOFS_TABLE).map_err(Error::from)?; + + for y in ys { + proofs_table.remove(&y.to_bytes()).map_err(Error::from)?; } } @@ -684,37 +698,44 @@ impl MintDatabase for MintRedbDatabase { let write_txn = self.db.begin_write().map_err(Error::from)?; let mut states = Vec::with_capacity(ys.len()); + { + let table = write_txn + .open_table(PROOFS_STATE_TABLE) + .map_err(Error::from)?; + { + // First collect current states + for y in ys { + let current_state = match table.get(y.to_bytes()).map_err(Error::from)? { + Some(state) => { + Some(serde_json::from_str(state.value()).map_err(Error::from)?) + } + None => None, + }; + states.push(current_state); + } + } + } - let state_str = serde_json::to_string(&proofs_state).map_err(Error::from)?; + // Check if any proofs are spent + if states.iter().any(|state| *state == Some(State::Spent)) { + write_txn.abort().map_err(Error::from)?; + return Err(database::Error::AttemptUpdateSpentProof); + } { let mut table = write_txn .open_table(PROOFS_STATE_TABLE) .map_err(Error::from)?; - - for y in ys { - let current_state; - { - match table.get(y.to_bytes()).map_err(Error::from)? { - Some(state) => { - current_state = - Some(serde_json::from_str(state.value()).map_err(Error::from)?) - } - None => current_state = None, - } - } - states.push(current_state); - } - - for (y, current_state) in ys.iter().zip(&states) { - if current_state != &Some(State::Spent) { + { + // If no proofs are spent, proceed with update + let state_str = serde_json::to_string(&proofs_state).map_err(Error::from)?; + for y in ys { table .insert(y.to_bytes(), state_str.as_str()) .map_err(Error::from)?; } } } - write_txn.commit().map_err(Error::from)?; Ok(states) @@ -924,3 +945,137 @@ impl MintDatabase for MintRedbDatabase { Err(Error::UnknownQuoteTTL.into()) } } + +#[cfg(test)] +mod tests { + use cdk_common::secret::Secret; + use cdk_common::{Amount, SecretKey}; + use tempfile::tempdir; + + use super::*; + + #[tokio::test] + async fn test_remove_spent_proofs() { + let tmp_dir = tempdir().unwrap(); + + let db = MintRedbDatabase::new(&tmp_dir.path().join("mint.redb")).unwrap(); + // Create some test proofs + let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap(); + + let proofs = vec![ + Proof { + amount: Amount::from(100), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + Proof { + amount: Amount::from(200), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + ]; + + // Add proofs to database + db.add_proofs(proofs.clone(), None).await.unwrap(); + + // Mark one proof as spent + db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent) + .await + .unwrap(); + + db.update_proofs_states(&[proofs[1].y().unwrap()], State::Unspent) + .await + .unwrap(); + + // Try to remove both proofs - should fail because one is spent + let result = db + .remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None) + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + database::Error::AttemptRemoveSpentProof + )); + + // Verify both proofs still exist + let states = db + .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()]) + .await + .unwrap(); + + assert_eq!(states.len(), 2); + assert_eq!(states[0], Some(State::Spent)); + assert_eq!(states[1], Some(State::Unspent)); + } + + #[tokio::test] + async fn test_update_spent_proofs() { + let tmp_dir = tempdir().unwrap(); + + let db = MintRedbDatabase::new(&tmp_dir.path().join("mint.redb")).unwrap(); + // Create some test proofs + let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap(); + + let proofs = vec![ + Proof { + amount: Amount::from(100), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + Proof { + amount: Amount::from(200), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + ]; + + // Add proofs to database + db.add_proofs(proofs.clone(), None).await.unwrap(); + + // Mark one proof as spent + db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent) + .await + .unwrap(); + + db.update_proofs_states(&[proofs[1].y().unwrap()], State::Unspent) + .await + .unwrap(); + + // Mark one proof as spent + let result = db + .update_proofs_states( + &[proofs[0].y().unwrap(), proofs[1].y().unwrap()], + State::Unspent, + ) + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + database::Error::AttemptUpdateSpentProof + )); + + // Verify both proofs still exist + let states = db + .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()]) + .await + .unwrap(); + + assert_eq!(states.len(), 2); + assert_eq!(states[0], Some(State::Spent)); + assert_eq!(states[1], Some(State::Unspent)); + } +} diff --git a/crates/cdk-sqlite/src/mint/mod.rs b/crates/cdk-sqlite/src/mint/mod.rs index 366d57e9..fb272d01 100644 --- a/crates/cdk-sqlite/src/mint/mod.rs +++ b/crates/cdk-sqlite/src/mint/mod.rs @@ -1,6 +1,6 @@ //! SQLite Mint -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::path::Path; use std::str::FromStr; @@ -36,6 +36,37 @@ pub struct MintSqliteDatabase { } impl MintSqliteDatabase { + /// Check if any proofs are spent + async fn check_for_spent_proofs( + &self, + transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + ys: &[PublicKey], + ) -> Result { + if ys.is_empty() { + return Ok(false); + } + + let check_sql = format!( + "SELECT state FROM proof WHERE y IN ({}) AND state = 'SPENT'", + std::iter::repeat("?") + .take(ys.len()) + .collect::>() + .join(",") + ); + + let spent_count = ys + .iter() + .fold(sqlx::query(&check_sql), |query, y| { + query.bind(y.to_bytes().to_vec()) + }) + .fetch_all(&mut *transaction) + .await + .map_err(Error::from)? + .len(); + + Ok(spent_count > 0) + } + /// Create new [`MintSqliteDatabase`] pub async fn new>(path: P) -> Result { Ok(Self { @@ -858,7 +889,13 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?); ) -> Result<(), Self::Err> { let mut transaction = self.pool.begin().await.map_err(Error::from)?; - let sql = format!( + if self.check_for_spent_proofs(&mut transaction, ys).await? { + transaction.rollback().await.map_err(Error::from)?; + return Err(Self::Err::AttemptRemoveSpentProof); + } + + // If no proofs are spent, proceed with deletion + let delete_sql = format!( "DELETE FROM proof WHERE y IN ({})", std::iter::repeat("?") .take(ys.len()) @@ -867,7 +904,7 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?); ); ys.iter() - .fold(sqlx::query(&sql), |query, y| { + .fold(sqlx::query(&delete_sql), |query, y| { query.bind(y.to_bytes().to_vec()) }) .execute(&mut transaction) @@ -1064,16 +1101,23 @@ WHERE keyset_id=?; }) .collect::, _>>()?; + let states = current_states.values().collect::>(); + + if states.contains(&State::Spent) { + transaction.rollback().await.map_err(Error::from)?; + tracing::warn!("Attempted to update state of spent proof"); + return Err(database::Error::AttemptUpdateSpentProof); + } + + // If no proofs are spent, proceed with update let update_sql = format!( - "UPDATE proof SET state = ? WHERE state != ? AND y IN ({})", + "UPDATE proof SET state = ? WHERE y IN ({})", "?,".repeat(ys.len()).trim_end_matches(',') ); ys.iter() .fold( - sqlx::query(&update_sql) - .bind(proofs_state.to_string()) - .bind(State::Spent.to_string()), + sqlx::query(&update_sql).bind(proofs_state.to_string()), |query, y| query.bind(y.to_bytes().to_vec()), ) .execute(&mut transaction) @@ -1647,3 +1691,125 @@ fn sqlite_row_to_melt_request(row: SqliteRow) -> Result<(MeltBolt11Request Ok((melt_request, ln_key)) } + +#[cfg(test)] +mod tests { + use cdk_common::Amount; + + use super::*; + + #[tokio::test] + async fn test_remove_spent_proofs() { + let db = memory::empty().await.unwrap(); + + // Create some test proofs + let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap(); + + let proofs = vec![ + Proof { + amount: Amount::from(100), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + Proof { + amount: Amount::from(200), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + ]; + + // Add proofs to database + db.add_proofs(proofs.clone(), None).await.unwrap(); + + // Mark one proof as spent + db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent) + .await + .unwrap(); + + // Try to remove both proofs - should fail because one is spent + let result = db + .remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None) + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + database::Error::AttemptRemoveSpentProof + )); + + // Verify both proofs still exist + let states = db + .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()]) + .await + .unwrap(); + + assert_eq!(states.len(), 2); + assert_eq!(states[0], Some(State::Spent)); + assert_eq!(states[1], Some(State::Unspent)); + } + + #[tokio::test] + async fn test_update_spent_proofs() { + let db = memory::empty().await.unwrap(); + + // Create some test proofs + let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap(); + + let proofs = vec![ + Proof { + amount: Amount::from(100), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + Proof { + amount: Amount::from(200), + keyset_id: keyset_id.clone(), + secret: Secret::generate(), + c: SecretKey::generate().public_key(), + witness: None, + dleq: None, + }, + ]; + + // Add proofs to database + db.add_proofs(proofs.clone(), None).await.unwrap(); + + // Mark one proof as spent + db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent) + .await + .unwrap(); + + // Try to update both proofs - should fail because one is spent + let result = db + .update_proofs_states( + &[proofs[0].y().unwrap(), proofs[1].y().unwrap()], + State::Reserved, + ) + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + database::Error::AttemptUpdateSpentProof + )); + + // Verify states haven't changed + let states = db + .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()]) + .await + .unwrap(); + + assert_eq!(states.len(), 2); + assert_eq!(states[0], Some(State::Spent)); + assert_eq!(states[1], Some(State::Unspent)); + } +} diff --git a/crates/cdk/src/mint/check_spendable.rs b/crates/cdk/src/mint/check_spendable.rs index d93d8759..6f27a23c 100644 --- a/crates/cdk/src/mint/check_spendable.rs +++ b/crates/cdk/src/mint/check_spendable.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use tracing::instrument; use super::{CheckStateRequest, CheckStateResponse, Mint, ProofState, PublicKey, State}; -use crate::Error; +use crate::{cdk_database, Error}; impl Mint { /// Check state @@ -41,10 +41,15 @@ impl Mint { ys: &[PublicKey], proof_state: State, ) -> Result<(), Error> { - let original_proofs_state = self - .localstore - .update_proofs_states(ys, proof_state) - .await?; + let original_proofs_state = + match self.localstore.update_proofs_states(ys, proof_state).await { + Ok(states) => states, + Err(cdk_database::Error::AttemptUpdateSpentProof) + | Err(cdk_database::Error::AttemptRemoveSpentProof) => { + return Err(Error::TokenAlreadySpent) + } + Err(err) => return Err(err.into()), + }; let proofs_state = original_proofs_state .iter()