Migrate from sqlx to rusqlite (#783)

* Migrate from `sqlx` to rusqlite

1. Add rusqlite with rusqlite with a working thread
2. Add wallet without a thread (synchronous)
3. Add custom migration

Co-authored-by: thesimplekid <tsk@thesimplekid.com>
This commit is contained in:
C
2025-06-14 08:49:50 -03:00
committed by GitHub
parent a335b269b7
commit 5a6b28816a
18 changed files with 3300 additions and 2456 deletions

View File

@@ -1,46 +1,135 @@
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::{Error, Pool, Sqlite};
use rusqlite::{params, Connection};
#[inline(always)]
pub async fn create_sqlite_pool(
use crate::pool::{Pool, ResourceManager};
/// The config need to create a new SQLite connection
#[derive(Debug)]
pub struct Config {
path: Option<String>,
password: Option<String>,
}
/// Sqlite connection manager
#[derive(Debug)]
pub struct SqliteConnectionManager;
impl ResourceManager for SqliteConnectionManager {
type Config = Config;
type Resource = Connection;
type Error = rusqlite::Error;
fn new_resource(
config: &Self::Config,
) -> Result<Self::Resource, crate::pool::Error<Self::Error>> {
let conn = if let Some(path) = config.path.as_ref() {
Connection::open(path)?
} else {
Connection::open_in_memory()?
};
if let Some(password) = config.password.as_ref() {
conn.execute_batch(&format!("pragma key = '{password}';"))?;
}
conn.execute_batch(
r#"
pragma busy_timeout = 10000;
pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;
pragma mmap_size = 30000000000;
pragma cache = shared;
"#,
)?;
Ok(conn)
}
}
/// Create a configured rusqlite connection to a SQLite database.
/// For SQLCipher support, enable the "sqlcipher" feature and pass a password.
pub fn create_sqlite_pool(
path: &str,
#[cfg(feature = "sqlcipher")] password: String,
) -> Result<Pool<Sqlite>, Error> {
let db_options = SqliteConnectOptions::from_str(path)?
.busy_timeout(Duration::from_secs(10))
.read_only(false)
.pragma("busy_timeout", "5000")
.pragma("journal_mode", "wal")
.pragma("synchronous", "normal")
.pragma("temp_store", "memory")
.pragma("mmap_size", "30000000000")
.shared_cache(true)
.create_if_missing(true);
) -> Arc<Pool<SqliteConnectionManager>> {
#[cfg(feature = "sqlcipher")]
let db_options = db_options.pragma("key", password);
let password = Some(password);
let is_memory = path.contains(":memory:");
#[cfg(not(feature = "sqlcipher"))]
let password = None;
let options = SqlitePoolOptions::new()
.min_connections(1)
.max_connections(1);
let pool = if is_memory {
// Make sure that the connection is not closed after the first query, or any query, as long
// as the pool is not dropped
options
.idle_timeout(None)
.max_lifetime(None)
.test_before_acquire(false)
let (config, max_size) = if path.contains(":memory:") {
(
Config {
path: None,
password,
},
1,
)
} else {
options
}
.connect_with(db_options)
.await?;
(
Config {
path: Some(path.to_owned()),
password,
},
20,
)
};
Ok(pool)
Pool::new(config, max_size, Duration::from_secs(5))
}
/// Migrates the migration generated by `build.rs`
pub fn migrate(conn: &mut Connection, migrations: &[(&str, &str)]) -> Result<(), rusqlite::Error> {
let tx = conn.transaction()?;
tx.execute(
r#"
CREATE TABLE IF NOT EXISTS migrations (
name TEXT PRIMARY KEY,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"#,
[],
)?;
if tx.query_row(
r#"select count(*) from sqlite_master where name = '_sqlx_migrations'"#,
[],
|row| row.get::<_, i32>(0),
)? == 1
{
tx.execute_batch(
r#"
INSERT INTO migrations
SELECT
version || '_' || REPLACE(description, ' ', '_') || '.sql',
execution_time
FROM _sqlx_migrations;
DROP TABLE _sqlx_migrations;
"#,
)?;
}
// Apply each migration if it hasnt been applied yet
for (name, sql) in migrations {
let already_applied: bool = tx.query_row(
"SELECT EXISTS(SELECT 1 FROM migrations WHERE name = ?1)",
params![name],
|row| row.get(0),
)?;
if !already_applied {
tx.execute_batch(sql)?;
tx.execute("INSERT INTO migrations (name) VALUES (?1)", params![name])?;
}
}
tx.commit()?;
Ok(())
}

View File

@@ -4,6 +4,9 @@
#![warn(rustdoc::bare_urls)]
mod common;
mod macros;
mod pool;
mod stmt;
#[cfg(feature = "mint")]
pub mod mint;

View File

@@ -0,0 +1,171 @@
//! Collection of macros to generate code to digest data from SQLite
/// Unpacks a vector of Column, and consumes it, parsing into individual variables, checking the
/// vector is big enough.
#[macro_export]
macro_rules! unpack_into {
(let ($($var:ident),+) = $array:expr) => {
let ($($var),+) = {
let mut vec = $array.to_vec();
vec.reverse();
let required = 0 $(+ {let _ = stringify!($var); 1})+;
if vec.len() < required {
return Err(Error::MissingColumn(required, vec.len()));
}
Ok::<_, Error>((
$(
vec.pop().expect(&format!("Checked length already for {}", stringify!($var)))
),+
))?
};
};
}
/// Parses a SQLite column as a string or NULL
#[macro_export]
macro_rules! column_as_nullable_string {
($col:expr, $callback_str:expr, $callback_bytes:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(Some(text).and_then($callback_str)),
$crate::stmt::Column::Blob(bytes) => Ok(Some(bytes).and_then($callback_bytes)),
$crate::stmt::Column::Null => Ok(None),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
($col:expr, $callback_str:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(Some(text).and_then($callback_str)),
$crate::stmt::Column::Blob(bytes) => {
Ok(Some(String::from_utf8_lossy(&bytes)).and_then($callback_str))
}
$crate::stmt::Column::Null => Ok(None),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(Some(text.to_owned())),
$crate::stmt::Column::Blob(bytes) => {
Ok(Some(String::from_utf8_lossy(&bytes).to_string()))
}
$crate::stmt::Column::Null => Ok(None),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
}
/// Parses a column as a number or NULL
#[macro_export]
macro_rules! column_as_nullable_number {
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(Some(text.parse().map_err(|_| {
Error::InvalidConversion(stringify!($col).to_owned(), "Number".to_owned())
})?)),
$crate::stmt::Column::Integer(n) => Ok(Some(n.try_into().map_err(|_| {
Error::InvalidConversion(stringify!($col).to_owned(), "Number".to_owned())
})?)),
$crate::stmt::Column::Null => Ok(None),
other => Err(Error::InvalidType(
"Number".to_owned(),
other.data_type().to_string(),
)),
})?
};
}
/// Parses a column as a number
#[macro_export]
macro_rules! column_as_number {
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => text.parse().map_err(|_| {
Error::InvalidConversion(stringify!($col).to_owned(), "Number".to_owned())
}),
$crate::stmt::Column::Integer(n) => n.try_into().map_err(|_| {
Error::InvalidConversion(stringify!($col).to_owned(), "Number".to_owned())
}),
other => Err(Error::InvalidType(
"Number".to_owned(),
other.data_type().to_string(),
)),
})?
};
}
/// Parses a column as a NULL or Binary
#[macro_export]
macro_rules! column_as_nullable_binary {
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(Some(text.as_bytes().to_vec())),
$crate::stmt::Column::Blob(bytes) => Ok(Some(bytes.to_owned())),
$crate::stmt::Column::Null => Ok(None),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
}
/// Parses a SQLite column as a binary
#[macro_export]
macro_rules! column_as_binary {
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(text.as_bytes().to_vec()),
$crate::stmt::Column::Blob(bytes) => Ok(bytes.to_owned()),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
}
/// Parses a SQLite column as a string
#[macro_export]
macro_rules! column_as_string {
($col:expr, $callback_str:expr, $callback_bytes:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => $callback_str(&text).map_err(Error::from),
$crate::stmt::Column::Blob(bytes) => $callback_bytes(&bytes).map_err(Error::from),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
($col:expr, $callback:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => $callback(&text).map_err(Error::from),
$crate::stmt::Column::Blob(bytes) => {
$callback(&String::from_utf8_lossy(&bytes)).map_err(Error::from)
}
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
($col:expr) => {
(match $col {
$crate::stmt::Column::Text(text) => Ok(text.to_owned()),
$crate::stmt::Column::Blob(bytes) => Ok(String::from_utf8_lossy(&bytes).to_string()),
other => Err(Error::InvalidType(
"String".to_owned(),
other.data_type().to_string(),
)),
})?
};
}

View File

@@ -0,0 +1,519 @@
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{mpsc as std_mpsc, Arc, Mutex};
use std::thread::spawn;
use std::time::Instant;
use rusqlite::Connection;
use tokio::sync::{mpsc, oneshot};
use crate::common::SqliteConnectionManager;
use crate::mint::Error;
use crate::pool::{Pool, PooledResource};
use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
/// The number of queued SQL statements before it start failing
const SQL_QUEUE_SIZE: usize = 10_000;
/// How many ms is considered a slow query, and it'd be logged for further debugging
const SLOW_QUERY_THRESHOLD_MS: u128 = 20;
/// How many SQLite parallel connections can be used to read things in parallel
const WORKING_THREAD_POOL_SIZE: usize = 5;
#[derive(Debug, Clone)]
pub struct AsyncRusqlite {
sender: mpsc::Sender<DbRequest>,
inflight_requests: Arc<AtomicUsize>,
}
/// Internal request for the database thread
#[derive(Debug)]
pub enum DbRequest {
Sql(InnerStatement, oneshot::Sender<DbResponse>),
Begin(oneshot::Sender<DbResponse>),
Commit(oneshot::Sender<DbResponse>),
Rollback(oneshot::Sender<DbResponse>),
}
#[derive(Debug)]
pub enum DbResponse {
Transaction(mpsc::Sender<DbRequest>),
AffectedRows(usize),
Pluck(Option<Column>),
Row(Option<Vec<Column>>),
Rows(Vec<Vec<Column>>),
Error(Error),
Unexpected,
Ok,
}
/// Statement for the async_rusqlite wrapper
pub struct Statement(InnerStatement);
impl Statement {
/// Bind a variable
pub fn bind<C, V>(self, name: C, value: V) -> Self
where
C: ToString,
V: Into<Value>,
{
Self(self.0.bind(name, value))
}
/// Bind vec
pub fn bind_vec<C, V>(self, name: C, value: Vec<V>) -> Self
where
C: ToString,
V: Into<Value>,
{
Self(self.0.bind_vec(name, value))
}
/// Executes a query and return the number of affected rows
pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
where
C: DatabaseExecutor + Send + Sync,
{
conn.execute(self.0).await
}
/// Returns the first column of the first row of the query result
pub async fn pluck<C>(self, conn: &C) -> Result<Option<Column>, Error>
where
C: DatabaseExecutor + Send + Sync,
{
conn.pluck(self.0).await
}
/// Returns the first row of the query result
pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
where
C: DatabaseExecutor + Send + Sync,
{
conn.fetch_one(self.0).await
}
/// Returns all rows of the query result
pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
where
C: DatabaseExecutor + Send + Sync,
{
conn.fetch_all(self.0).await
}
}
/// Process a query
#[inline(always)]
fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, Error> {
let start = Instant::now();
let mut args = sql.args;
let mut stmt = conn.prepare_cached(&sql.sql)?;
let total_parameters = stmt.parameter_count();
for index in 1..=total_parameters {
let value = if let Some(value) = stmt.parameter_name(index).map(|name| {
args.remove(name)
.ok_or(Error::MissingParameter(name.to_owned()))
}) {
value?
} else {
continue;
};
stmt.raw_bind_parameter(index, value)?;
}
let columns = stmt.column_count();
let to_return = match sql.expected_response {
ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
ExpectedSqlResponse::ManyRows => {
let mut rows = stmt.raw_query();
let mut results = vec![];
while let Some(row) = rows.next()? {
results.push(
(0..columns)
.map(|i| row.get(i))
.collect::<Result<Vec<_>, _>>()?,
)
}
DbResponse::Rows(results)
}
ExpectedSqlResponse::Pluck => {
let mut rows = stmt.raw_query();
DbResponse::Pluck(rows.next()?.map(|row| row.get(0usize)).transpose()?)
}
ExpectedSqlResponse::SingleRow => {
let mut rows = stmt.raw_query();
let row = rows
.next()?
.map(|row| {
(0..columns)
.map(|i| row.get(i))
.collect::<Result<Vec<_>, _>>()
})
.transpose()?;
DbResponse::Row(row)
}
};
let duration = start.elapsed();
if duration.as_millis() > SLOW_QUERY_THRESHOLD_MS {
tracing::warn!("[SLOW QUERY] Took {} ms: {}", duration.as_millis(), sql.sql);
}
Ok(to_return)
}
/// Spawns N number of threads to execute SQL statements
///
/// Enable parallelism with a pool of threads.
///
/// There is a main thread, which receives SQL requests and routes them to a worker thread from a
/// fixed-size pool.
///
/// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
/// executed. Transactions are executed in the main thread.
fn rusqlite_spawn_worker_threads(
inflight_requests: Arc<AtomicUsize>,
threads: usize,
) -> std_mpsc::Sender<(
PooledResource<SqliteConnectionManager>,
InnerStatement,
oneshot::Sender<DbResponse>,
)> {
let (sender, receiver) = std_mpsc::channel::<(
PooledResource<SqliteConnectionManager>,
InnerStatement,
oneshot::Sender<DbResponse>,
)>();
let receiver = Arc::new(Mutex::new(receiver));
for _ in 0..threads {
let rx = receiver.clone();
let inflight_requests = inflight_requests.clone();
spawn(move || loop {
while let Ok((conn, sql, reply_to)) = rx.lock().expect("failed to acquire").recv() {
tracing::info!("Execute query: {}", sql.sql);
let result = process_query(&conn, sql);
let _ = match result {
Ok(ok) => reply_to.send(ok),
Err(err) => {
tracing::error!("Failed query with error {:?}", err);
reply_to.send(DbResponse::Error(err))
}
};
drop(conn);
inflight_requests.fetch_sub(1, Ordering::Relaxed);
}
});
}
sender
}
/// # Rusqlite main worker
///
/// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
/// returns the results or number of affected rows to the caller. All communications are done
/// through channels. This function is synchronous, but a thread pool exists to execute queries, and
/// SQLite will coordinate data access. Transactions are executed in the main and it takes ownership
/// of the main thread until it is finalized
///
/// This is meant to be called in their thread, as it will not exit the loop until the communication
/// channel is closed.
fn rusqlite_worker_manager(
mut receiver: mpsc::Receiver<DbRequest>,
pool: Arc<Pool<SqliteConnectionManager>>,
inflight_requests: Arc<AtomicUsize>,
) {
let send_sql_to_thread =
rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
let mut tx_id: usize = 0;
while let Some(request) = receiver.blocking_recv() {
inflight_requests.fetch_add(1, Ordering::Relaxed);
match request {
DbRequest::Sql(sql, reply_to) => {
let conn = match pool.get() {
Ok(conn) => conn,
Err(err) => {
tracing::error!("Failed to acquire a pool connection: {:?}", err);
inflight_requests.fetch_sub(1, Ordering::Relaxed);
let _ = reply_to.send(DbResponse::Error(err.into()));
continue;
}
};
let _ = send_sql_to_thread.send((conn, sql, reply_to));
continue;
}
DbRequest::Begin(reply_to) => {
let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
let mut conn = match pool.get() {
Ok(conn) => conn,
Err(err) => {
tracing::error!("Failed to acquire a pool connection: {:?}", err);
inflight_requests.fetch_sub(1, Ordering::Relaxed);
let _ = reply_to.send(DbResponse::Error(err.into()));
continue;
}
};
let tx = match conn.transaction() {
Ok(tx) => tx,
Err(err) => {
tracing::error!("Failed to begin a transaction: {:?}", err);
inflight_requests.fetch_sub(1, Ordering::Relaxed);
let _ = reply_to.send(DbResponse::Error(err.into()));
continue;
}
};
// Transaction has begun successfully, send the `sender` back to the caller
// and wait for statements to execute. On `Drop` the wrapper transaction
// should send a `rollback`.
let _ = reply_to.send(DbResponse::Transaction(sender));
tx_id += 1;
// We intentionally handle the transaction hijacking the main loop, there is
// no point is queueing more operations for SQLite, since transaction have
// exclusive access. In other database implementation this block of code
// should be sent to their own thread to allow concurrency
loop {
let request = if let Some(request) = receiver.blocking_recv() {
request
} else {
// If the receiver loop is broken (i.e no more `senders` are active) and no
// `Commit` statement has been sent, this will trigger a `Rollback`
// automatically
tracing::info!("Tx {}: Transaction rollback on drop", tx_id);
let _ = tx.rollback();
break;
};
match request {
DbRequest::Commit(reply_to) => {
tracing::info!("Tx {}: Commit", tx_id);
let _ = reply_to.send(match tx.commit() {
Ok(()) => DbResponse::Ok,
Err(err) => DbResponse::Error(err.into()),
});
break;
}
DbRequest::Rollback(reply_to) => {
tracing::info!("Tx {}: Rollback", tx_id);
let _ = reply_to.send(match tx.rollback() {
Ok(()) => DbResponse::Ok,
Err(err) => DbResponse::Error(err.into()),
});
break;
}
DbRequest::Begin(reply_to) => {
let _ = reply_to.send(DbResponse::Unexpected);
}
DbRequest::Sql(sql, reply_to) => {
tracing::info!("Tx {}: SQL {}", tx_id, sql.sql);
let _ = match process_query(&tx, sql) {
Ok(ok) => reply_to.send(ok),
Err(err) => reply_to.send(DbResponse::Error(err)),
};
}
}
}
drop(conn);
}
DbRequest::Commit(reply_to) => {
let _ = reply_to.send(DbResponse::Unexpected);
}
DbRequest::Rollback(reply_to) => {
let _ = reply_to.send(DbResponse::Unexpected);
}
}
// If wasn't a `continue` the transaction is done by reaching this code, and we should
// decrease the inflight_request counter
inflight_requests.fetch_sub(1, Ordering::Relaxed);
}
}
#[async_trait::async_trait]
pub trait DatabaseExecutor {
/// Returns the connection to the database thread (or the on-going transaction)
fn get_queue_sender(&self) -> mpsc::Sender<DbRequest>;
/// Executes a query and returns the affected rows
async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
let (sender, receiver) = oneshot::channel();
statement.expected_response = ExpectedSqlResponse::AffectedRows;
self.get_queue_sender()
.send(DbRequest::Sql(statement, sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::AffectedRows(n) => Ok(n),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
/// Runs the query and returns the first row or None
async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
let (sender, receiver) = oneshot::channel();
statement.expected_response = ExpectedSqlResponse::SingleRow;
self.get_queue_sender()
.send(DbRequest::Sql(statement, sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Row(row) => Ok(row),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
/// Runs the query and returns the first row or None
async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
let (sender, receiver) = oneshot::channel();
statement.expected_response = ExpectedSqlResponse::ManyRows;
self.get_queue_sender()
.send(DbRequest::Sql(statement, sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Rows(rows) => Ok(rows),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
let (sender, receiver) = oneshot::channel();
statement.expected_response = ExpectedSqlResponse::Pluck;
self.get_queue_sender()
.send(DbRequest::Sql(statement, sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Pluck(value) => Ok(value),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
}
#[inline(always)]
pub fn query<T>(sql: T) -> Statement
where
T: ToString,
{
Statement(crate::stmt::Statement::new(sql))
}
impl AsyncRusqlite {
/// Creates a new Async Rusqlite wrapper.
pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
let inflight_requests = Arc::new(AtomicUsize::new(0));
let inflight_requests_for_thread = inflight_requests.clone();
spawn(move || {
rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
});
Self {
sender,
inflight_requests,
}
}
/// Show how many inflight requests
#[allow(dead_code)]
pub fn inflight_requests(&self) -> usize {
self.inflight_requests.load(Ordering::Relaxed)
}
/// Begins a transaction
///
/// If the transaction is Drop it will trigger a rollback operation
pub async fn begin(&self) -> Result<Transaction<'_>, Error> {
let (sender, receiver) = oneshot::channel();
self.sender
.send(DbRequest::Begin(sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Transaction(db_sender) => Ok(Transaction {
db_sender,
_marker: PhantomData,
}),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
}
impl DatabaseExecutor for AsyncRusqlite {
#[inline(always)]
fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
self.sender.clone()
}
}
pub struct Transaction<'conn> {
db_sender: mpsc::Sender<DbRequest>,
_marker: PhantomData<&'conn ()>,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
let (sender, _) = oneshot::channel();
let _ = self.db_sender.try_send(DbRequest::Rollback(sender));
}
}
impl Transaction<'_> {
pub async fn commit(self) -> Result<(), Error> {
let (sender, receiver) = oneshot::channel();
self.db_sender
.send(DbRequest::Commit(sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Ok => Ok(()),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
pub async fn rollback(self) -> Result<(), Error> {
let (sender, receiver) = oneshot::channel();
self.db_sender
.send(DbRequest::Rollback(sender))
.await
.map_err(|_| Error::Communication)?;
match receiver.await.map_err(|_| Error::Communication)? {
DbResponse::Ok => Ok(()),
DbResponse::Error(err) => Err(err),
_ => Err(Error::InvalidDbResponse),
}
}
}
impl DatabaseExecutor for Transaction<'_> {
/// Get the internal sender to the SQL queue
#[inline(always)]
fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
self.db_sender.clone()
}
}

View File

@@ -0,0 +1,5 @@
// @generated
// Auto-generated by build.rs
pub static MIGRATIONS: &[(&str, &str)] = &[
("20250109143347_init.sql", include_str!(r#"./migrations/20250109143347_init.sql"#)),
];

View File

@@ -1,52 +1,57 @@
//! SQLite Mint Auth
use std::collections::HashMap;
use std::ops::DerefMut;
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;
use async_trait::async_trait;
use cdk_common::database::{self, MintAuthDatabase};
use cdk_common::mint::MintKeySetInfo;
use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
use cdk_common::{AuthRequired, ProtectedEndpoint};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::Row;
use tracing::instrument;
use super::async_rusqlite::AsyncRusqlite;
use super::{sqlite_row_to_blind_signature, sqlite_row_to_keyset_info};
use crate::column_as_string;
use crate::common::{create_sqlite_pool, migrate};
use crate::mint::async_rusqlite::query;
use crate::mint::Error;
/// Mint SQLite Database
#[derive(Debug, Clone)]
pub struct MintSqliteAuthDatabase {
pool: SqlitePool,
pool: AsyncRusqlite,
}
#[rustfmt::skip]
mod migrations;
impl MintSqliteAuthDatabase {
/// Create new [`MintSqliteAuthDatabase`]
pub async fn new(path: &Path) -> Result<Self, Error> {
let path = path.to_str().ok_or(Error::InvalidDbPath)?;
let db_options = SqliteConnectOptions::from_str(path)?
.busy_timeout(Duration::from_secs(5))
.read_only(false)
.create_if_missing(true)
.auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full);
#[cfg(not(feature = "sqlcipher"))]
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let pool = create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?);
migrate(pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
let pool = SqlitePoolOptions::new()
.max_connections(1)
.connect_with(db_options)
.await?;
Ok(Self { pool })
Ok(Self {
pool: AsyncRusqlite::new(pool),
})
}
/// Migrate [`MintSqliteAuthDatabase`]
pub async fn migrate(&self) {
sqlx::migrate!("./src/mint/auth/migrations")
.run(&self.pool)
.await
.expect("Could not run migrations");
/// Create new [`MintSqliteAuthDatabase`]
#[cfg(feature = "sqlcipher")]
pub async fn new<P: AsRef<Path>>(path: P, password: String) -> Result<Self, Error> {
let pool = create_sqlite_pool(
path.as_ref().to_str().ok_or(Error::InvalidDbPath)?,
password,
);
migrate(pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
Ok(Self {
pool: AsyncRusqlite::new(pool),
})
}
}
@@ -57,230 +62,156 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
#[instrument(skip(self))]
async fn set_active_keyset(&self, id: Id) -> Result<(), Self::Err> {
tracing::info!("Setting auth keyset {id} active");
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let update_res = sqlx::query(
query(
r#"
UPDATE keyset
SET active = CASE
WHEN id = ? THEN TRUE
ELSE FALSE
END;
"#,
UPDATE keyset
SET active = CASE
WHEN id = :id THEN TRUE
ELSE FALSE
END;
"#,
)
.bind(id.to_string())
.execute(&mut *transaction)
.await;
match update_res {
Ok(_) => {
transaction.commit().await.map_err(Error::from)?;
Ok(())
}
Err(err) => {
tracing::error!("SQLite Could not update keyset");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(Error::from(err).into())
}
}
}
async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let rec = sqlx::query(
r#"
SELECT id
FROM keyset
WHERE active = 1;
"#,
)
.fetch_one(&mut *transaction)
.await;
let rec = match rec {
Ok(rec) => {
transaction.commit().await.map_err(Error::from)?;
rec
}
Err(err) => match err {
sqlx::Error::RowNotFound => {
transaction.commit().await.map_err(Error::from)?;
return Ok(None);
}
_ => {
return {
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(Error::SQLX(err).into())
}
}
},
};
Ok(Some(
Id::from_str(rec.try_get("id").map_err(Error::from)?).map_err(Error::from)?,
))
}
async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let res = sqlx::query(
r#"
INSERT OR REPLACE INTO keyset
(id, unit, active, valid_from, valid_to, derivation_path, max_order, derivation_path_index)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
"#,
)
.bind(keyset.id.to_string())
.bind(keyset.unit.to_string())
.bind(keyset.active)
.bind(keyset.valid_from as i64)
.bind(keyset.valid_to.map(|v| v as i64))
.bind(keyset.derivation_path.to_string())
.bind(keyset.max_order)
.bind(keyset.derivation_path_index)
.execute(&mut *transaction)
.await;
match res {
Ok(_) => {
transaction.commit().await.map_err(Error::from)?;
Ok(())
}
Err(err) => {
tracing::error!("SQLite could not add keyset info");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(Error::from(err).into())
}
}
}
async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let rec = sqlx::query(
r#"
SELECT *
FROM keyset
WHERE id=?;
"#,
)
.bind(id.to_string())
.fetch_one(&mut *transaction)
.await;
match rec {
Ok(rec) => {
transaction.commit().await.map_err(Error::from)?;
Ok(Some(sqlite_row_to_keyset_info(rec)?))
}
Err(err) => match err {
sqlx::Error::RowNotFound => {
transaction.commit().await.map_err(Error::from)?;
return Ok(None);
}
_ => {
tracing::error!("SQLite could not get keyset info");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
return Err(Error::SQLX(err).into());
}
},
}
}
async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let recs = sqlx::query(
r#"
SELECT *
FROM keyset;
"#,
)
.fetch_all(&mut *transaction)
.await
.map_err(Error::from);
match recs {
Ok(recs) => {
transaction.commit().await.map_err(Error::from)?;
Ok(recs
.into_iter()
.map(sqlite_row_to_keyset_info)
.collect::<Result<_, _>>()?)
}
Err(err) => {
tracing::error!("SQLite could not get keyset info");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(err.into())
}
}
}
async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
if let Err(err) = sqlx::query(
r#"
INSERT INTO proof
(y, keyset_id, secret, c, state)
VALUES (?, ?, ?, ?, ?);
"#,
)
.bind(proof.y()?.to_bytes().to_vec())
.bind(proof.keyset_id.to_string())
.bind(proof.secret.to_string())
.bind(proof.c.to_bytes().to_vec())
.bind("UNSPENT")
.execute(&mut *transaction)
.await
.map_err(Error::from)
{
tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
}
transaction.commit().await.map_err(Error::from)?;
.bind(":id", id.to_string())
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
Ok(query(
r#"
SELECT
id
FROM
keyset
WHERE
active = 1;
"#,
)
.pluck(&self.pool)
.await?
.map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
.transpose()?)
}
async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err> {
query(
r#"
INSERT INTO
keyset (
id, unit, active, valid_from, valid_to, derivation_path,
max_order, derivation_path_index
)
VALUES (
:id, :unit, :active, :valid_from, :valid_to, :derivation_path,
:max_order, :derivation_path_index
)
ON CONFLICT(id) DO UPDATE SET
unit = excluded.unit,
active = excluded.active,
valid_from = excluded.valid_from,
valid_to = excluded.valid_to,
derivation_path = excluded.derivation_path,
max_order = excluded.max_order,
derivation_path_index = excluded.derivation_path_index
"#,
)
.bind(":id", keyset.id.to_string())
.bind(":unit", keyset.unit.to_string())
.bind(":active", keyset.active)
.bind(":valid_from", keyset.valid_from as i64)
.bind(":valid_to", keyset.valid_to.map(|v| v as i64))
.bind(":derivation_path", keyset.derivation_path.to_string())
.bind(":max_order", keyset.max_order)
.bind(":derivation_path_index", keyset.derivation_path_index)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
Ok(query(
r#"SELECT
id,
unit,
active,
valid_from,
valid_to,
derivation_path,
derivation_path_index,
max_order,
input_fee_ppk
FROM
keyset
WHERE id=:id"#,
)
.bind(":id", id.to_string())
.fetch_one(&self.pool)
.await?
.map(sqlite_row_to_keyset_info)
.transpose()?)
}
async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
Ok(query(
r#"SELECT
id,
unit,
active,
valid_from,
valid_to,
derivation_path,
derivation_path_index,
max_order,
input_fee_ppk
FROM
keyset
WHERE id=:id"#,
)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(sqlite_row_to_keyset_info)
.collect::<Result<Vec<_>, _>>()?)
}
async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err> {
if let Err(err) = query(
r#"
INSERT INTO proof
(y, keyset_id, secret, c, state)
VALUES
(:y, :keyset_id, :secret, :c, :state)
"#,
)
.bind(":y", proof.y()?.to_bytes().to_vec())
.bind(":keyset_id", proof.keyset_id.to_string())
.bind(":secret", proof.secret.to_string())
.bind(":c", proof.c.to_bytes().to_vec())
.bind(":state", "UNSPENT".to_string())
.execute(&self.pool)
.await
{
tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
}
Ok(())
}
async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let sql = format!(
"SELECT y, state FROM proof WHERE y IN ({})",
"?,".repeat(ys.len()).trim_end_matches(',')
);
let mut current_states = ys
.iter()
.fold(sqlx::query(&sql), |query, y| {
query.bind(y.to_bytes().to_vec())
})
.fetch_all(&mut *transaction)
.await
.map_err(|err| {
tracing::error!("SQLite could not get state of proof: {err:?}");
Error::SQLX(err)
})?
let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)
.bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|row| {
PublicKey::from_slice(row.get("y"))
.map_err(Error::from)
.and_then(|y| {
let state: String = row.get("state");
State::from_str(&state)
.map_err(Error::from)
.map(|state| (y, state))
})
Ok((
column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
column_as_string!(&row[1], State::from_str),
))
})
.collect::<Result<HashMap<_, _>, _>>()?;
.collect::<Result<HashMap<_, _>, Error>>()?;
Ok(ys.iter().map(|y| current_states.remove(y)).collect())
}
@@ -290,36 +221,27 @@ VALUES (?, ?, ?, ?, ?);
y: &PublicKey,
proofs_state: State,
) -> Result<Option<State>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let transaction = self.pool.begin().await?;
// Get current state for single y
let current_state = sqlx::query("SELECT state FROM proof WHERE y = ?")
.bind(y.to_bytes().to_vec())
.fetch_optional(&mut *transaction)
.await
.map_err(|err| {
tracing::error!("SQLite could not get state of proof: {err:?}");
Error::SQLX(err)
})?
.map(|row| {
let state: String = row.get("state");
State::from_str(&state).map_err(Error::from)
})
let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)
.bind(":y", y.to_bytes().to_vec())
.pluck(&transaction)
.await?
.map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
.transpose()?;
// Update state for single y
sqlx::query("UPDATE proof SET state = ? WHERE state != ? AND y = ?")
.bind(proofs_state.to_string())
.bind(State::Spent.to_string())
.bind(y.to_bytes().to_vec())
.execute(&mut *transaction)
.await
.map_err(|err| {
tracing::error!("SQLite could not update proof state: {err:?}");
Error::SQLX(err)
})?;
query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)
.bind(":y", y.to_bytes().to_vec())
.bind(
":state",
current_state.as_ref().map(|state| state.to_string()),
)
.bind(":new_state", proofs_state.to_string())
.execute(&transaction)
.await?;
transaction.commit().await?;
transaction.commit().await.map_err(Error::from)?;
Ok(current_state)
}
@@ -328,32 +250,27 @@ VALUES (?, ?, ?, ?, ?);
blinded_messages: &[PublicKey],
blind_signatures: &[BlindSignature],
) -> Result<(), Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
let res = sqlx::query(
r#"
INSERT INTO blind_signature
(y, amount, keyset_id, c)
VALUES (?, ?, ?, ?);
"#,
)
.bind(message.to_bytes().to_vec())
.bind(u64::from(signature.amount) as i64)
.bind(signature.keyset_id.to_string())
.bind(signature.c.to_bytes().to_vec())
.execute(&mut *transaction)
.await;
let transaction = self.pool.begin().await?;
if let Err(err) = res {
tracing::error!("SQLite could not add blind signature");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
return Err(Error::SQLX(err).into());
}
for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
query(
r#"
INSERT
INTO blind_signature
(y, amount, keyset_id, c)
VALUES
(:y, :amount, :keyset_id, :c)
"#,
)
.bind(":y", message.to_bytes().to_vec())
.bind(":amount", u64::from(signature.amount) as i64)
.bind(":keyset_id", signature.keyset_id.to_string())
.bind(":c", signature.c.to_bytes().to_vec())
.execute(&transaction)
.await?;
}
transaction.commit().await.map_err(Error::from)?;
transaction.commit().await?;
Ok(())
}
@@ -362,32 +279,40 @@ VALUES (?, ?, ?, ?);
&self,
blinded_messages: &[PublicKey],
) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let sql = format!(
"SELECT * FROM blind_signature WHERE y IN ({})",
"?,".repeat(blinded_messages.len()).trim_end_matches(',')
);
let mut blinded_signatures = blinded_messages
.iter()
.fold(sqlx::query(&sql), |query, y| {
query.bind(y.to_bytes().to_vec())
})
.fetch_all(&mut *transaction)
.await
.map_err(|err| {
tracing::error!("SQLite could not get state of proof: {err:?}");
Error::SQLX(err)
})?
.into_iter()
.map(|row| {
PublicKey::from_slice(row.get("y"))
.map_err(Error::from)
.and_then(|y| sqlite_row_to_blind_signature(row).map(|blinded| (y, blinded)))
})
.collect::<Result<HashMap<_, _>, _>>()?;
let mut blinded_signatures = query(
r#"SELECT
keyset_id,
amount,
c,
dleq_e,
dleq_s,
y
FROM
blind_signature
WHERE y IN (:y)
"#,
)
.bind_vec(
":y",
blinded_messages
.iter()
.map(|y| y.to_bytes().to_vec())
.collect(),
)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|mut row| {
Ok((
column_as_string!(
&row.pop().ok_or(Error::InvalidDbResponse)?,
PublicKey::from_hex,
PublicKey::from_slice
),
sqlite_row_to_blind_signature(row)?,
))
})
.collect::<Result<HashMap<_, _>, Error>>()?;
Ok(blinded_messages
.iter()
.map(|y| blinded_signatures.remove(y))
@@ -398,21 +323,20 @@ VALUES (?, ?, ?, ?);
&self,
protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
) -> Result<(), Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let transaction = self.pool.begin().await?;
for (endpoint, auth) in protected_endpoints.iter() {
if let Err(err) = sqlx::query(
if let Err(err) = query(
r#"
INSERT OR REPLACE INTO protected_endpoints
(endpoint, auth)
VALUES (?, ?);
"#,
INSERT OR REPLACE INTO protected_endpoints
(endpoint, auth)
VALUES (:endpoint, :auth);
"#,
)
.bind(serde_json::to_string(endpoint)?)
.bind(serde_json::to_string(auth)?)
.execute(&mut *transaction)
.bind(":endpoint", serde_json::to_string(endpoint)?)
.bind(":auth", serde_json::to_string(auth)?)
.execute(&transaction)
.await
.map_err(Error::from)
{
tracing::debug!(
"Attempting to add protected endpoint. Skipping.... {:?}",
@@ -421,7 +345,7 @@ VALUES (?, ?);
}
}
transaction.commit().await.map_err(Error::from)?;
transaction.commit().await?;
Ok(())
}
@@ -429,111 +353,52 @@ VALUES (?, ?);
&self,
protected_endpoints: Vec<ProtectedEndpoint>,
) -> Result<(), Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let sql = format!(
"DELETE FROM protected_endpoints WHERE endpoint IN ({})",
std::iter::repeat("?")
.take(protected_endpoints.len())
.collect::<Vec<_>>()
.join(",")
);
let endpoints = protected_endpoints
.iter()
.map(serde_json::to_string)
.collect::<Result<Vec<_>, _>>()?;
endpoints
.iter()
.fold(sqlx::query(&sql), |query, endpoint| query.bind(endpoint))
.execute(&mut *transaction)
.await
.map_err(Error::from)?;
transaction.commit().await.map_err(Error::from)?;
query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)
.bind_vec(
":endpoints",
protected_endpoints
.iter()
.map(serde_json::to_string)
.collect::<Result<_, _>>()?,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_auth_for_endpoint(
&self,
protected_endpoint: ProtectedEndpoint,
) -> Result<Option<AuthRequired>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let rec = sqlx::query(
r#"
SELECT *
FROM protected_endpoints
WHERE endpoint=?;
"#,
Ok(
query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)
.bind(":endpoint", serde_json::to_string(&protected_endpoint)?)
.pluck(&self.pool)
.await?
.map(|auth| {
Ok::<_, Error>(column_as_string!(
auth,
serde_json::from_str,
serde_json::from_slice
))
})
.transpose()?,
)
.bind(serde_json::to_string(&protected_endpoint)?)
.fetch_one(&mut *transaction)
.await;
match rec {
Ok(rec) => {
transaction.commit().await.map_err(Error::from)?;
let auth: String = rec.try_get("auth").map_err(Error::from)?;
Ok(Some(serde_json::from_str(&auth)?))
}
Err(err) => match err {
sqlx::Error::RowNotFound => {
transaction.commit().await.map_err(Error::from)?;
return Ok(None);
}
_ => {
return {
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(Error::SQLX(err).into())
}
}
},
}
}
async fn get_auth_for_endpoints(
&self,
) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
let mut transaction = self.pool.begin().await.map_err(Error::from)?;
let recs = sqlx::query(
r#"
SELECT *
FROM protected_endpoints
"#,
)
.fetch_all(&mut *transaction)
.await;
match recs {
Ok(recs) => {
transaction.commit().await.map_err(Error::from)?;
let mut endpoints = HashMap::new();
for rec in recs {
let auth: String = rec.try_get("auth").map_err(Error::from)?;
let endpoint: String = rec.try_get("endpoint").map_err(Error::from)?;
let endpoint: ProtectedEndpoint = serde_json::from_str(&endpoint)?;
let auth: AuthRequired = serde_json::from_str(&auth)?;
endpoints.insert(endpoint, Some(auth));
}
Ok(endpoints)
}
Err(err) => {
tracing::error!("SQLite could not get protected endpoints");
if let Err(err) = transaction.rollback().await {
tracing::error!("Could not rollback sql transaction: {}", err);
}
Err(Error::from(err).into())
}
}
Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|row| {
let endpoint =
column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
Ok((endpoint, Some(auth)))
})
.collect::<Result<HashMap<_, _>, Error>>()?)
}
}

View File

@@ -7,7 +7,42 @@ use thiserror::Error;
pub enum Error {
/// SQLX Error
#[error(transparent)]
SQLX(#[from] sqlx::Error),
Sqlite(#[from] rusqlite::Error),
/// Pool error
#[error(transparent)]
Pool(#[from] crate::pool::Error<rusqlite::Error>),
/// Invalid UUID
#[error("Invalid UUID: {0}")]
InvalidUuid(String),
/// QuoteNotFound
#[error("Quote not found")]
QuoteNotFound,
/// Missing named parameter
#[error("Missing named parameter {0}")]
MissingParameter(String),
/// Communication error with the database
#[error("Internal communication error")]
Communication,
/// Invalid response from the database thread
#[error("Unexpected database response")]
InvalidDbResponse,
/// Invalid db type
#[error("Invalid type from db, expected {0} got {1}")]
InvalidType(String, String),
/// Missing columns
#[error("Not enough elements: expected {0}, got {1}")]
MissingColumn(usize, usize),
/// Invalid data conversion in column
#[error("Error converting {0} to {1}")]
InvalidConversion(String, String),
/// NUT00 Error
#[error(transparent)]
CDKNUT00(#[from] cdk_common::nuts::nut00::Error),

View File

@@ -0,0 +1,23 @@
// @generated
// Auto-generated by build.rs
pub static MIGRATIONS: &[(&str, &str)] = &[
("20240612124932_init.sql", include_str!(r#"./migrations/20240612124932_init.sql"#)),
("20240618195700_quote_state.sql", include_str!(r#"./migrations/20240618195700_quote_state.sql"#)),
("20240626092101_nut04_state.sql", include_str!(r#"./migrations/20240626092101_nut04_state.sql"#)),
("20240703122347_request_lookup_id.sql", include_str!(r#"./migrations/20240703122347_request_lookup_id.sql"#)),
("20240710145043_input_fee.sql", include_str!(r#"./migrations/20240710145043_input_fee.sql"#)),
("20240711183109_derivation_path_index.sql", include_str!(r#"./migrations/20240711183109_derivation_path_index.sql"#)),
("20240718203721_allow_unspent.sql", include_str!(r#"./migrations/20240718203721_allow_unspent.sql"#)),
("20240811031111_update_mint_url.sql", include_str!(r#"./migrations/20240811031111_update_mint_url.sql"#)),
("20240919103407_proofs_quote_id.sql", include_str!(r#"./migrations/20240919103407_proofs_quote_id.sql"#)),
("20240923153640_melt_requests.sql", include_str!(r#"./migrations/20240923153640_melt_requests.sql"#)),
("20240930101140_dleq_for_sigs.sql", include_str!(r#"./migrations/20240930101140_dleq_for_sigs.sql"#)),
("20241108093102_mint_mint_quote_pubkey.sql", include_str!(r#"./migrations/20241108093102_mint_mint_quote_pubkey.sql"#)),
("20250103201327_amount_to_pay_msats.sql", include_str!(r#"./migrations/20250103201327_amount_to_pay_msats.sql"#)),
("20250129200912_remove_mint_url.sql", include_str!(r#"./migrations/20250129200912_remove_mint_url.sql"#)),
("20250129230326_add_config_table.sql", include_str!(r#"./migrations/20250129230326_add_config_table.sql"#)),
("20250307213652_keyset_id_as_foreign_key.sql", include_str!(r#"./migrations/20250307213652_keyset_id_as_foreign_key.sql"#)),
("20250406091754_mint_time_of_quotes.sql", include_str!(r#"./migrations/20250406091754_mint_time_of_quotes.sql"#)),
("20250406093755_mint_created_time_signature.sql", include_str!(r#"./migrations/20250406093755_mint_created_time_signature.sql"#)),
("20250415093121_drop_keystore_foreign.sql", include_str!(r#"./migrations/20250415093121_drop_keystore_foreign.sql"#)),
];

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,196 @@
//! Very simple connection pool, to avoid an external dependency on r2d2 and other crates. If this
//! endup work it can be re-used in other parts of the project and may be promoted to its own
//! generic crate
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
/// Pool error
#[derive(thiserror::Error, Debug)]
pub enum Error<E> {
/// Mutex Poison Error
#[error("Internal: PoisonError")]
Poison,
/// Timeout error
#[error("Timed out waiting for a resource")]
Timeout,
/// Internal database error
#[error(transparent)]
Resource(#[from] E),
}
/// Trait to manage resources
pub trait ResourceManager: Debug {
/// The resource to be pooled
type Resource: Debug;
/// The configuration that is needed in order to create the resource
type Config: Debug;
/// The error the resource may return when creating a new instance
type Error: Debug;
/// Creates a new resource with a given config
fn new_resource(config: &Self::Config) -> Result<Self::Resource, Error<Self::Error>>;
/// The object is dropped
fn drop(_resource: Self::Resource) {}
}
/// Generic connection pool of resources R
#[derive(Debug)]
pub struct Pool<RM>
where
RM: ResourceManager,
{
config: RM::Config,
queue: Mutex<Vec<RM::Resource>>,
in_use: AtomicUsize,
max_size: usize,
default_timeout: Duration,
waiter: Condvar,
}
/// The pooled resource
pub struct PooledResource<RM>
where
RM: ResourceManager,
{
resource: Option<RM::Resource>,
pool: Arc<Pool<RM>>,
}
impl<RM> Drop for PooledResource<RM>
where
RM: ResourceManager,
{
fn drop(&mut self) {
if let Some(resource) = self.resource.take() {
let mut active_resource = self.pool.queue.lock().expect("active_resource");
active_resource.push(resource);
self.pool.in_use.fetch_sub(1, Ordering::AcqRel);
// Notify a waiting thread
self.pool.waiter.notify_one();
}
}
}
impl<RM> Deref for PooledResource<RM>
where
RM: ResourceManager,
{
type Target = RM::Resource;
fn deref(&self) -> &Self::Target {
self.resource.as_ref().expect("resource already dropped")
}
}
impl<RM> DerefMut for PooledResource<RM>
where
RM: ResourceManager,
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.resource.as_mut().expect("resource already dropped")
}
}
impl<RM> Pool<RM>
where
RM: ResourceManager,
{
/// Creates a new pool
pub fn new(config: RM::Config, max_size: usize, default_timeout: Duration) -> Arc<Self> {
Arc::new(Self {
config,
queue: Default::default(),
in_use: Default::default(),
waiter: Default::default(),
default_timeout,
max_size,
})
}
/// Similar to get_timeout but uses the default timeout value.
#[inline(always)]
pub fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
self.get_timeout(self.default_timeout)
}
/// Get a new resource or fail after timeout is reached.
///
/// This function will return a free resource or create a new one if there is still room for it;
/// otherwise, it will wait for a resource to be released for reuse.
#[inline(always)]
pub fn get_timeout(
self: &Arc<Self>,
timeout: Duration,
) -> Result<PooledResource<RM>, Error<RM::Error>> {
let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
loop {
if let Some(resource) = resources.pop() {
drop(resources);
self.in_use.fetch_add(1, Ordering::AcqRel);
return Ok(PooledResource {
resource: Some(resource),
pool: self.clone(),
});
}
if self.in_use.load(Ordering::Relaxed) < self.max_size {
drop(resources);
self.in_use.fetch_add(1, Ordering::AcqRel);
return Ok(PooledResource {
resource: Some(RM::new_resource(&self.config)?),
pool: self.clone(),
});
}
resources = self
.waiter
.wait_timeout(resources, timeout)
.map_err(|_| Error::Poison)
.and_then(|(lock, timeout_result)| {
if timeout_result.timed_out() {
Err(Error::Timeout)
} else {
Ok(lock)
}
})?;
}
}
}
impl<RM> Drop for Pool<RM>
where
RM: ResourceManager,
{
fn drop(&mut self) {
if let Ok(mut resources) = self.queue.lock() {
loop {
while let Some(resource) = resources.pop() {
RM::drop(resource);
}
if self.in_use.load(Ordering::Relaxed) == 0 {
break;
}
resources = if let Ok(resources) = self.waiter.wait(resources) {
resources
} else {
break;
};
}
}
}
}

View File

@@ -0,0 +1,184 @@
use std::collections::HashMap;
use rusqlite::{self, CachedStatement};
use crate::common::SqliteConnectionManager;
use crate::pool::PooledResource;
/// The Value coming from SQLite
pub type Value = rusqlite::types::Value;
/// The Column type
pub type Column = Value;
/// Expected response type for a given SQL statement
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
/// A single row
SingleRow,
/// All the rows that matches a query
#[default]
ManyRows,
/// How many rows were affected by the query
AffectedRows,
/// Return the first column of the first row
Pluck,
}
/// Sql message
#[derive(Default, Debug)]
pub struct Statement {
/// The SQL statement
pub sql: String,
/// The list of arguments for the placeholders. It only supports named arguments for simplicity
/// sake
pub args: HashMap<String, Value>,
/// The expected response type
pub expected_response: ExpectedSqlResponse,
}
impl Statement {
/// Creates a new statement
pub fn new<T>(sql: T) -> Self
where
T: ToString,
{
Self {
sql: sql.to_string(),
..Default::default()
}
}
/// Binds a given placeholder to a value.
#[inline]
pub fn bind<C, V>(mut self, name: C, value: V) -> Self
where
C: ToString,
V: Into<Value>,
{
self.args.insert(name.to_string(), value.into());
self
}
/// Binds a single variable with a vector.
///
/// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
/// :foo2` and binds each value from the value vector accordingly.
#[inline]
pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
where
C: ToString,
V: Into<Value>,
{
let mut new_sql = String::with_capacity(self.sql.len());
let target = name.to_string();
let mut i = 0;
let placeholders = value
.into_iter()
.enumerate()
.map(|(key, value)| {
let key = format!("{target}{key}");
self.args.insert(key.clone(), value.into());
key
})
.collect::<Vec<_>>()
.join(",");
while let Some(pos) = self.sql[i..].find(&target) {
let abs_pos = i + pos;
let after = abs_pos + target.len();
let is_word_boundary = self.sql[after..]
.chars()
.next()
.map_or(true, |c| !c.is_alphanumeric() && c != '_');
if is_word_boundary {
new_sql.push_str(&self.sql[i..abs_pos]);
new_sql.push_str(&placeholders);
i = after;
} else {
new_sql.push_str(&self.sql[i..=abs_pos]);
i = abs_pos + 1;
}
}
new_sql.push_str(&self.sql[i..]);
self.sql = new_sql;
self
}
fn get_stmt(
self,
conn: &PooledResource<SqliteConnectionManager>,
) -> rusqlite::Result<CachedStatement<'_>> {
let mut stmt = conn.prepare_cached(&self.sql)?;
for (name, value) in self.args {
let index = stmt
.parameter_index(&name)
.map_err(|_| rusqlite::Error::InvalidColumnName(name.clone()))?
.ok_or(rusqlite::Error::InvalidColumnName(name))?;
stmt.raw_bind_parameter(index, value)?;
}
Ok(stmt)
}
/// Executes a query and returns the affected rows
pub fn plunk(
self,
conn: &PooledResource<SqliteConnectionManager>,
) -> rusqlite::Result<Option<Value>> {
let mut stmt = self.get_stmt(conn)?;
let mut rows = stmt.raw_query();
rows.next()?.map(|row| row.get(0)).transpose()
}
/// Executes a query and returns the affected rows
pub fn execute(
self,
conn: &PooledResource<SqliteConnectionManager>,
) -> rusqlite::Result<usize> {
self.get_stmt(conn)?.raw_execute()
}
/// Runs the query and returns the first row or None
pub fn fetch_one(
self,
conn: &PooledResource<SqliteConnectionManager>,
) -> rusqlite::Result<Option<Vec<Column>>> {
let mut stmt = self.get_stmt(conn)?;
let columns = stmt.column_count();
let mut rows = stmt.raw_query();
rows.next()?
.map(|row| {
(0..columns)
.map(|i| row.get(i))
.collect::<Result<Vec<_>, _>>()
})
.transpose()
}
/// Runs the query and returns the first row or None
pub fn fetch_all(
self,
conn: &PooledResource<SqliteConnectionManager>,
) -> rusqlite::Result<Vec<Vec<Column>>> {
let mut stmt = self.get_stmt(conn)?;
let columns = stmt.column_count();
let mut rows = stmt.raw_query();
let mut results = vec![];
while let Some(row) = rows.next()? {
results.push(
(0..columns)
.map(|i| row.get(i))
.collect::<Result<Vec<_>, _>>()?,
);
}
Ok(results)
}
}

View File

@@ -7,7 +7,23 @@ use thiserror::Error;
pub enum Error {
/// SQLX Error
#[error(transparent)]
SQLX(#[from] sqlx::Error),
Sqlite(#[from] rusqlite::Error),
/// Pool error
#[error(transparent)]
Pool(#[from] crate::pool::Error<rusqlite::Error>),
/// Missing columns
#[error("Not enough elements: expected {0}, got {1}")]
MissingColumn(usize, usize),
/// Invalid db type
#[error("Invalid type from db, expected {0} got {1}")]
InvalidType(String, String),
/// Invalid data conversion in column
#[error("Error converting {0} to {1}")]
InvalidConversion(String, String),
/// Serde Error
#[error(transparent)]
Serde(#[from] serde_json::Error),

View File

@@ -0,0 +1,19 @@
// @generated
// Auto-generated by build.rs
pub static MIGRATIONS: &[(&str, &str)] = &[
("20240612132920_init.sql", include_str!(r#"./migrations/20240612132920_init.sql"#)),
("20240618200350_quote_state.sql", include_str!(r#"./migrations/20240618200350_quote_state.sql"#)),
("20240626091921_nut04_state.sql", include_str!(r#"./migrations/20240626091921_nut04_state.sql"#)),
("20240710144711_input_fee.sql", include_str!(r#"./migrations/20240710144711_input_fee.sql"#)),
("20240810214105_mint_icon_url.sql", include_str!(r#"./migrations/20240810214105_mint_icon_url.sql"#)),
("20240810233905_update_mint_url.sql", include_str!(r#"./migrations/20240810233905_update_mint_url.sql"#)),
("20240902151515_icon_url.sql", include_str!(r#"./migrations/20240902151515_icon_url.sql"#)),
("20240902210905_mint_time.sql", include_str!(r#"./migrations/20240902210905_mint_time.sql"#)),
("20241011125207_mint_urls.sql", include_str!(r#"./migrations/20241011125207_mint_urls.sql"#)),
("20241108092756_wallet_mint_quote_secretkey.sql", include_str!(r#"./migrations/20241108092756_wallet_mint_quote_secretkey.sql"#)),
("20250214135017_mint_tos.sql", include_str!(r#"./migrations/20250214135017_mint_tos.sql"#)),
("20250310111513_drop_nostr_last_checked.sql", include_str!(r#"./migrations/20250310111513_drop_nostr_last_checked.sql"#)),
("20250314082116_allow_pending_spent.sql", include_str!(r#"./migrations/20250314082116_allow_pending_spent.sql"#)),
("20250323152040_wallet_dleq_proofs.sql", include_str!(r#"./migrations/20250323152040_wallet_dleq_proofs.sql"#)),
("20250401120000_add_transactions_table.sql", include_str!(r#"./migrations/20250401120000_add_transactions_table.sql"#)),
];

File diff suppressed because it is too large Load Diff