fix: implement atomic keyset counter

- remove get_keyset_counter
- update increment_keyset_counter to atomically increment and return counter value
- replace get+increment pattern with atomic increment everywhere
This commit is contained in:
vnprc
2025-08-08 18:59:16 -04:00
committed by thesimplekid
parent 51b26eae62
commit 5b30ca546d
7 changed files with 152 additions and 114 deletions

View File

@@ -99,10 +99,8 @@ pub trait Database: Debug {
/// Update proofs state in storage /// Update proofs state in storage
async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err>; async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err>;
/// Increment Keyset counter /// Atomically increment Keyset counter and return new value
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err>; async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err>;
/// Get current Keyset counter
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err>;
/// Add transaction to storage /// Add transaction to storage
async fn add_transaction(&self, transaction: Transaction) -> Result<(), Self::Err>; async fn add_transaction(&self, transaction: Transaction) -> Result<(), Self::Err>;

View File

@@ -760,10 +760,11 @@ impl WalletDatabase for WalletRedbDatabase {
} }
#[instrument(skip(self), fields(keyset_id = %keyset_id))] #[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> { async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
let write_txn = self.db.begin_write().map_err(Error::from)?; let write_txn = self.db.begin_write().map_err(Error::from)?;
let current_counter; let current_counter;
let new_counter;
{ {
let table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?; let table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let counter = table let counter = table
@@ -774,11 +775,12 @@ impl WalletDatabase for WalletRedbDatabase {
Some(c) => c.value(), Some(c) => c.value(),
None => 0, None => 0,
}; };
new_counter = current_counter + count;
} }
{ {
let mut table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?; let mut table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let new_counter = current_counter + count;
table table
.insert(keyset_id.to_string().as_str(), new_counter) .insert(keyset_id.to_string().as_str(), new_counter)
@@ -786,19 +788,7 @@ impl WalletDatabase for WalletRedbDatabase {
} }
write_txn.commit().map_err(Error::from)?; write_txn.commit().map_err(Error::from)?;
Ok(()) Ok(new_counter)
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
let read_txn = self.db.begin_read().map_err(Error::from)?;
let table = read_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let counter = table
.get(keyset_id.to_string().as_str())
.map_err(Error::from)?;
Ok(counter.map_or(0, |c| c.value()))
} }
#[instrument(skip(self))] #[instrument(skip(self))]

View File

@@ -839,42 +839,44 @@ ON CONFLICT(id) DO UPDATE SET
} }
#[instrument(skip(self), fields(keyset_id = %keyset_id))] #[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> { async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?; let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
query( let tx = ConnectionWithTransaction::new(conn).await?;
// Lock the row and get current counter
let current_counter = query(
r#" r#"
UPDATE keyset SELECT counter
SET counter=counter+:count FROM keyset
WHERE id=:id WHERE id=:id
"#, FOR UPDATE
)?
.bind("count", count)
.bind("id", keyset_id.to_string())
.execute(&*conn)
.await?;
Ok(())
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
Ok(query(
r#"
SELECT
counter
FROM
keyset
WHERE
id=:id
"#, "#,
)? )?
.bind("id", keyset_id.to_string()) .bind("id", keyset_id.to_string())
.pluck(&*conn) .pluck(&tx)
.await? .await?
.map(|n| Ok::<_, Error>(column_as_number!(n))) .map(|n| Ok::<_, Error>(column_as_number!(n)))
.transpose()? .transpose()?
.unwrap_or(0)) .unwrap_or(0);
let new_counter = current_counter + count;
// Update with the new counter value
query(
r#"
UPDATE keyset
SET counter=:new_counter
WHERE id=:id
"#,
)?
.bind("new_counter", new_counter)
.bind("id", keyset_id.to_string())
.execute(&tx)
.await?;
tx.commit().await?;
Ok(new_counter)
} }
#[instrument(skip(self))] #[instrument(skip(self))]

View File

@@ -229,11 +229,6 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id; let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let premint_secrets = match &spending_conditions { let premint_secrets = match &spending_conditions {
Some(spending_conditions) => PreMintSecrets::with_conditions( Some(spending_conditions) => PreMintSecrets::with_conditions(
active_keyset_id, active_keyset_id,
@@ -241,13 +236,33 @@ impl Wallet {
&amount_split_target, &amount_split_target,
spending_conditions, spending_conditions,
)?, )?,
None => PreMintSecrets::from_seed( None => {
active_keyset_id, // Calculate how many secrets we'll need
count, let amount_split = amount_mintable.split_targeted(&amount_split_target)?;
&self.seed, let num_secrets = amount_split.len() as u32;
amount_mintable,
&amount_split_target, tracing::debug!(
)?, "Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount_mintable,
&amount_split_target,
)?
}
}; };
let mut request = MintRequest { let mut request = MintRequest {
@@ -286,19 +301,6 @@ impl Wallet {
// Remove filled quote from store // Remove filled quote from store
self.localstore.remove_mint_quote(&quote_info.id).await?; self.localstore.remove_mint_quote(&quote_info.id).await?;
if spending_conditions.is_none() {
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
proofs.len()
);
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
.await?;
}
let proof_infos = proofs let proof_infos = proofs
.iter() .iter()
.map(|proof| { .map(|proof| {

View File

@@ -107,11 +107,6 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id; let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let amount = match amount { let amount = match amount {
Some(amount) => amount, Some(amount) => amount,
None => { None => {
@@ -135,13 +130,33 @@ impl Wallet {
&amount_split_target, &amount_split_target,
spending_conditions, spending_conditions,
)?, )?,
None => PreMintSecrets::from_seed( None => {
active_keyset_id, // Calculate how many secrets we'll need without generating them
count, let amount_split = amount.split_targeted(&amount_split_target)?;
&self.seed, let num_secrets = amount_split.len() as u32;
amount,
&amount_split_target, tracing::debug!(
)?, "Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount,
&amount_split_target,
)?
}
}; };
let mut request = MintRequest { let mut request = MintRequest {
@@ -190,13 +205,6 @@ impl Wallet {
self.localstore.add_mint_quote(quote_info.clone()).await?; self.localstore.add_mint_quote(quote_info.clone()).await?;
if spending_conditions.is_none() {
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
.await?;
}
let proof_infos = proofs let proof_infos = proofs
.iter() .iter()
.map(|proof| { .map(|proof| {

View File

@@ -15,7 +15,7 @@ use crate::nuts::{
use crate::types::{Melted, ProofInfo}; use crate::types::{Melted, ProofInfo};
use crate::util::unix_time; use crate::util::unix_time;
use crate::wallet::MeltQuote; use crate::wallet::MeltQuote;
use crate::{ensure_cdk, Error, Wallet}; use crate::{ensure_cdk, Amount, Error, Wallet};
impl Wallet { impl Wallet {
/// Melt Quote /// Melt Quote
@@ -148,17 +148,32 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id; let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self let change_amount = proofs_total - quote_info.amount;
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let premint_secrets = PreMintSecrets::from_seed_blank( let premint_secrets = if change_amount <= Amount::ZERO {
active_keyset_id, PreMintSecrets::new(active_keyset_id)
count, } else {
&self.seed, // TODO: consolidate this calculation with from_seed_blank into a shared function
proofs_total - quote_info.amount, // Calculate how many secrets will be needed using the same logic as from_seed_blank
)?; let num_secrets =
((u64::from(change_amount) as f64).log2().ceil() as u64).max(1) as u32;
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed_blank(active_keyset_id, count, &self.seed, change_amount)?
};
let request = MeltRequest::new( let request = MeltRequest::new(
quote_id.to_string(), quote_id.to_string(),
@@ -226,11 +241,6 @@ impl Wallet {
change_proofs.total_amount()? change_proofs.total_amount()?
); );
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, change_proofs.len() as u32)
.await?;
change_proofs change_proofs
.into_iter() .into_iter()
.map(|proof| { .map(|proof| {

View File

@@ -52,10 +52,6 @@ impl Wallet {
&active_keys, &active_keys,
)?; )?;
self.localstore
.increment_keyset_counter(&active_keyset_id, pre_swap.derived_secret_count)
.await?;
let mut added_proofs = Vec::new(); let mut added_proofs = Vec::new();
let change_proofs; let change_proofs;
let send_proofs; let send_proofs;
@@ -248,10 +244,42 @@ impl Wallet {
let derived_secret_count; let derived_secret_count;
let mut count = self // Calculate total secrets needed and atomically reserve counter range
.localstore let total_secrets_needed = match spending_conditions {
.get_keyset_counter(&active_keyset_id) Some(_) => {
.await?; // For spending conditions, we only need to count change secrets
change_amount.split_targeted(&change_split_target)?.len() as u32
}
None => {
// For no spending conditions, count both send and change secrets
let send_count = send_amount
.unwrap_or(Amount::ZERO)
.split_targeted(&SplitTarget::default())?
.len() as u32;
let change_count = change_amount.split_targeted(&change_split_target)?.len() as u32;
send_count + change_count
}
};
// Atomically get the counter range we need
let starting_counter = if total_secrets_needed > 0 {
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
total_secrets_needed
);
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, total_secrets_needed)
.await?;
new_counter - total_secrets_needed
} else {
0 // No secrets needed, don't increment the counter
};
let mut count = starting_counter;
let (mut desired_messages, change_messages) = match spending_conditions { let (mut desired_messages, change_messages) = match spending_conditions {
Some(conditions) => { Some(conditions) => {