fix: implement atomic keyset counter

- remove get_keyset_counter
- update increment_keyset_counter to atomically increment and return counter value
- replace get+increment pattern with atomic increment everywhere
This commit is contained in:
vnprc
2025-08-08 18:59:16 -04:00
committed by thesimplekid
parent 51b26eae62
commit 5b30ca546d
7 changed files with 152 additions and 114 deletions

View File

@@ -99,10 +99,8 @@ pub trait Database: Debug {
/// Update proofs state in storage
async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err>;
/// Increment Keyset counter
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err>;
/// Get current Keyset counter
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err>;
/// Atomically increment Keyset counter and return new value
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err>;
/// Add transaction to storage
async fn add_transaction(&self, transaction: Transaction) -> Result<(), Self::Err>;

View File

@@ -760,10 +760,11 @@ impl WalletDatabase for WalletRedbDatabase {
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> {
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
let write_txn = self.db.begin_write().map_err(Error::from)?;
let current_counter;
let new_counter;
{
let table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let counter = table
@@ -774,11 +775,12 @@ impl WalletDatabase for WalletRedbDatabase {
Some(c) => c.value(),
None => 0,
};
new_counter = current_counter + count;
}
{
let mut table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let new_counter = current_counter + count;
table
.insert(keyset_id.to_string().as_str(), new_counter)
@@ -786,19 +788,7 @@ impl WalletDatabase for WalletRedbDatabase {
}
write_txn.commit().map_err(Error::from)?;
Ok(())
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
let read_txn = self.db.begin_read().map_err(Error::from)?;
let table = read_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
let counter = table
.get(keyset_id.to_string().as_str())
.map_err(Error::from)?;
Ok(counter.map_or(0, |c| c.value()))
Ok(new_counter)
}
#[instrument(skip(self))]

View File

@@ -839,42 +839,44 @@ ON CONFLICT(id) DO UPDATE SET
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> {
async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
query(
let tx = ConnectionWithTransaction::new(conn).await?;
// Lock the row and get current counter
let current_counter = query(
r#"
UPDATE keyset
SET counter=counter+:count
SELECT counter
FROM keyset
WHERE id=:id
"#,
)?
.bind("count", count)
.bind("id", keyset_id.to_string())
.execute(&*conn)
.await?;
Ok(())
}
#[instrument(skip(self), fields(keyset_id = %keyset_id))]
async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
Ok(query(
r#"
SELECT
counter
FROM
keyset
WHERE
id=:id
FOR UPDATE
"#,
)?
.bind("id", keyset_id.to_string())
.pluck(&*conn)
.pluck(&tx)
.await?
.map(|n| Ok::<_, Error>(column_as_number!(n)))
.transpose()?
.unwrap_or(0))
.unwrap_or(0);
let new_counter = current_counter + count;
// Update with the new counter value
query(
r#"
UPDATE keyset
SET counter=:new_counter
WHERE id=:id
"#,
)?
.bind("new_counter", new_counter)
.bind("id", keyset_id.to_string())
.execute(&tx)
.await?;
tx.commit().await?;
Ok(new_counter)
}
#[instrument(skip(self))]

View File

@@ -229,11 +229,6 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let premint_secrets = match &spending_conditions {
Some(spending_conditions) => PreMintSecrets::with_conditions(
active_keyset_id,
@@ -241,13 +236,33 @@ impl Wallet {
&amount_split_target,
spending_conditions,
)?,
None => PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount_mintable,
&amount_split_target,
)?,
None => {
// Calculate how many secrets we'll need
let amount_split = amount_mintable.split_targeted(&amount_split_target)?;
let num_secrets = amount_split.len() as u32;
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount_mintable,
&amount_split_target,
)?
}
};
let mut request = MintRequest {
@@ -286,19 +301,6 @@ impl Wallet {
// Remove filled quote from store
self.localstore.remove_mint_quote(&quote_info.id).await?;
if spending_conditions.is_none() {
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
proofs.len()
);
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
.await?;
}
let proof_infos = proofs
.iter()
.map(|proof| {

View File

@@ -107,11 +107,6 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let amount = match amount {
Some(amount) => amount,
None => {
@@ -135,13 +130,33 @@ impl Wallet {
&amount_split_target,
spending_conditions,
)?,
None => PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount,
&amount_split_target,
)?,
None => {
// Calculate how many secrets we'll need without generating them
let amount_split = amount.split_targeted(&amount_split_target)?;
let num_secrets = amount_split.len() as u32;
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed(
active_keyset_id,
count,
&self.seed,
amount,
&amount_split_target,
)?
}
};
let mut request = MintRequest {
@@ -190,13 +205,6 @@ impl Wallet {
self.localstore.add_mint_quote(quote_info.clone()).await?;
if spending_conditions.is_none() {
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
.await?;
}
let proof_infos = proofs
.iter()
.map(|proof| {

View File

@@ -15,7 +15,7 @@ use crate::nuts::{
use crate::types::{Melted, ProofInfo};
use crate::util::unix_time;
use crate::wallet::MeltQuote;
use crate::{ensure_cdk, Error, Wallet};
use crate::{ensure_cdk, Amount, Error, Wallet};
impl Wallet {
/// Melt Quote
@@ -148,17 +148,32 @@ impl Wallet {
let active_keyset_id = self.fetch_active_keyset().await?.id;
let count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
let change_amount = proofs_total - quote_info.amount;
let premint_secrets = PreMintSecrets::from_seed_blank(
active_keyset_id,
count,
&self.seed,
proofs_total - quote_info.amount,
)?;
let premint_secrets = if change_amount <= Amount::ZERO {
PreMintSecrets::new(active_keyset_id)
} else {
// TODO: consolidate this calculation with from_seed_blank into a shared function
// Calculate how many secrets will be needed using the same logic as from_seed_blank
let num_secrets =
((u64::from(change_amount) as f64).log2().ceil() as u64).max(1) as u32;
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
num_secrets
);
// Atomically get the counter range we need
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, num_secrets)
.await?;
let count = new_counter - num_secrets;
PreMintSecrets::from_seed_blank(active_keyset_id, count, &self.seed, change_amount)?
};
let request = MeltRequest::new(
quote_id.to_string(),
@@ -226,11 +241,6 @@ impl Wallet {
change_proofs.total_amount()?
);
// Update counter for keyset
self.localstore
.increment_keyset_counter(&active_keyset_id, change_proofs.len() as u32)
.await?;
change_proofs
.into_iter()
.map(|proof| {

View File

@@ -52,10 +52,6 @@ impl Wallet {
&active_keys,
)?;
self.localstore
.increment_keyset_counter(&active_keyset_id, pre_swap.derived_secret_count)
.await?;
let mut added_proofs = Vec::new();
let change_proofs;
let send_proofs;
@@ -248,10 +244,42 @@ impl Wallet {
let derived_secret_count;
let mut count = self
.localstore
.get_keyset_counter(&active_keyset_id)
.await?;
// Calculate total secrets needed and atomically reserve counter range
let total_secrets_needed = match spending_conditions {
Some(_) => {
// For spending conditions, we only need to count change secrets
change_amount.split_targeted(&change_split_target)?.len() as u32
}
None => {
// For no spending conditions, count both send and change secrets
let send_count = send_amount
.unwrap_or(Amount::ZERO)
.split_targeted(&SplitTarget::default())?
.len() as u32;
let change_count = change_amount.split_targeted(&change_split_target)?.len() as u32;
send_count + change_count
}
};
// Atomically get the counter range we need
let starting_counter = if total_secrets_needed > 0 {
tracing::debug!(
"Incrementing keyset {} counter by {}",
active_keyset_id,
total_secrets_needed
);
let new_counter = self
.localstore
.increment_keyset_counter(&active_keyset_id, total_secrets_needed)
.await?;
new_counter - total_secrets_needed
} else {
0 // No secrets needed, don't increment the counter
};
let mut count = starting_counter;
let (mut desired_messages, change_messages) = match spending_conditions {
Some(conditions) => {