mirror of
https://github.com/aljazceru/cdk.git
synced 2025-12-19 13:44:55 +01:00
refactor: Add state check before deleting proofs to prevent removing spent proofs
This commit is contained in:
committed by
thesimplekid
parent
c6200331cf
commit
d41d3a7c94
@@ -1,6 +1,6 @@
|
||||
//! SQLite Mint
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
|
||||
@@ -36,6 +36,37 @@ pub struct MintSqliteDatabase {
|
||||
}
|
||||
|
||||
impl MintSqliteDatabase {
|
||||
/// Check if any proofs are spent
|
||||
async fn check_for_spent_proofs(
|
||||
&self,
|
||||
transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
|
||||
ys: &[PublicKey],
|
||||
) -> Result<bool, database::Error> {
|
||||
if ys.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let check_sql = format!(
|
||||
"SELECT state FROM proof WHERE y IN ({}) AND state = 'SPENT'",
|
||||
std::iter::repeat("?")
|
||||
.take(ys.len())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let spent_count = ys
|
||||
.iter()
|
||||
.fold(sqlx::query(&check_sql), |query, y| {
|
||||
query.bind(y.to_bytes().to_vec())
|
||||
})
|
||||
.fetch_all(&mut *transaction)
|
||||
.await
|
||||
.map_err(Error::from)?
|
||||
.len();
|
||||
|
||||
Ok(spent_count > 0)
|
||||
}
|
||||
|
||||
/// Create new [`MintSqliteDatabase`]
|
||||
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
@@ -858,7 +889,13 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||
) -> Result<(), Self::Err> {
|
||||
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
|
||||
|
||||
let sql = format!(
|
||||
if self.check_for_spent_proofs(&mut transaction, ys).await? {
|
||||
transaction.rollback().await.map_err(Error::from)?;
|
||||
return Err(Self::Err::AttemptRemoveSpentProof);
|
||||
}
|
||||
|
||||
// If no proofs are spent, proceed with deletion
|
||||
let delete_sql = format!(
|
||||
"DELETE FROM proof WHERE y IN ({})",
|
||||
std::iter::repeat("?")
|
||||
.take(ys.len())
|
||||
@@ -867,7 +904,7 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||
);
|
||||
|
||||
ys.iter()
|
||||
.fold(sqlx::query(&sql), |query, y| {
|
||||
.fold(sqlx::query(&delete_sql), |query, y| {
|
||||
query.bind(y.to_bytes().to_vec())
|
||||
})
|
||||
.execute(&mut transaction)
|
||||
@@ -1064,16 +1101,23 @@ WHERE keyset_id=?;
|
||||
})
|
||||
.collect::<Result<HashMap<_, _>, _>>()?;
|
||||
|
||||
let states = current_states.values().collect::<HashSet<_>>();
|
||||
|
||||
if states.contains(&State::Spent) {
|
||||
transaction.rollback().await.map_err(Error::from)?;
|
||||
tracing::warn!("Attempted to update state of spent proof");
|
||||
return Err(database::Error::AttemptUpdateSpentProof);
|
||||
}
|
||||
|
||||
// If no proofs are spent, proceed with update
|
||||
let update_sql = format!(
|
||||
"UPDATE proof SET state = ? WHERE state != ? AND y IN ({})",
|
||||
"UPDATE proof SET state = ? WHERE y IN ({})",
|
||||
"?,".repeat(ys.len()).trim_end_matches(',')
|
||||
);
|
||||
|
||||
ys.iter()
|
||||
.fold(
|
||||
sqlx::query(&update_sql)
|
||||
.bind(proofs_state.to_string())
|
||||
.bind(State::Spent.to_string()),
|
||||
sqlx::query(&update_sql).bind(proofs_state.to_string()),
|
||||
|query, y| query.bind(y.to_bytes().to_vec()),
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
@@ -1647,3 +1691,125 @@ fn sqlite_row_to_melt_request(row: SqliteRow) -> Result<(MeltBolt11Request<Uuid>
|
||||
|
||||
Ok((melt_request, ln_key))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use cdk_common::Amount;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_spent_proofs() {
|
||||
let db = memory::empty().await.unwrap();
|
||||
|
||||
// Create some test proofs
|
||||
let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
|
||||
|
||||
let proofs = vec![
|
||||
Proof {
|
||||
amount: Amount::from(100),
|
||||
keyset_id: keyset_id.clone(),
|
||||
secret: Secret::generate(),
|
||||
c: SecretKey::generate().public_key(),
|
||||
witness: None,
|
||||
dleq: None,
|
||||
},
|
||||
Proof {
|
||||
amount: Amount::from(200),
|
||||
keyset_id: keyset_id.clone(),
|
||||
secret: Secret::generate(),
|
||||
c: SecretKey::generate().public_key(),
|
||||
witness: None,
|
||||
dleq: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Add proofs to database
|
||||
db.add_proofs(proofs.clone(), None).await.unwrap();
|
||||
|
||||
// Mark one proof as spent
|
||||
db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Try to remove both proofs - should fail because one is spent
|
||||
let result = db
|
||||
.remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
database::Error::AttemptRemoveSpentProof
|
||||
));
|
||||
|
||||
// Verify both proofs still exist
|
||||
let states = db
|
||||
.get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(states.len(), 2);
|
||||
assert_eq!(states[0], Some(State::Spent));
|
||||
assert_eq!(states[1], Some(State::Unspent));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_spent_proofs() {
|
||||
let db = memory::empty().await.unwrap();
|
||||
|
||||
// Create some test proofs
|
||||
let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
|
||||
|
||||
let proofs = vec![
|
||||
Proof {
|
||||
amount: Amount::from(100),
|
||||
keyset_id: keyset_id.clone(),
|
||||
secret: Secret::generate(),
|
||||
c: SecretKey::generate().public_key(),
|
||||
witness: None,
|
||||
dleq: None,
|
||||
},
|
||||
Proof {
|
||||
amount: Amount::from(200),
|
||||
keyset_id: keyset_id.clone(),
|
||||
secret: Secret::generate(),
|
||||
c: SecretKey::generate().public_key(),
|
||||
witness: None,
|
||||
dleq: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Add proofs to database
|
||||
db.add_proofs(proofs.clone(), None).await.unwrap();
|
||||
|
||||
// Mark one proof as spent
|
||||
db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Try to update both proofs - should fail because one is spent
|
||||
let result = db
|
||||
.update_proofs_states(
|
||||
&[proofs[0].y().unwrap(), proofs[1].y().unwrap()],
|
||||
State::Reserved,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
database::Error::AttemptUpdateSpentProof
|
||||
));
|
||||
|
||||
// Verify states haven't changed
|
||||
let states = db
|
||||
.get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(states.len(), 2);
|
||||
assert_eq!(states[0], Some(State::Spent));
|
||||
assert_eq!(states[1], Some(State::Unspent));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user