mirror of
https://github.com/aljazceru/cdk.git
synced 2025-12-20 06:05:09 +01:00
346 lines
11 KiB
Rust
346 lines
11 KiB
Rust
use std::fmt::Debug;
|
|
use std::sync::atomic::AtomicBool;
|
|
use std::sync::{Arc, OnceLock};
|
|
use std::time::Duration;
|
|
|
|
use cdk_common::database::Error;
|
|
use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
|
|
use cdk_sql_common::mint::SQLMintAuthDatabase;
|
|
use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
|
|
use cdk_sql_common::stmt::{Column, Statement};
|
|
use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
|
|
use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
|
|
use native_tls::TlsConnector;
|
|
use postgres_native_tls::MakeTlsConnector;
|
|
use tokio::sync::{Mutex, Notify};
|
|
use tokio::time::timeout;
|
|
use tokio_postgres::{connect, Client, Error as PgError, NoTls};
|
|
|
|
mod db;
|
|
mod value;
|
|
|
|
/// Marker type that plugs PostgreSQL into the generic connection pool.
///
/// It holds no state of its own; everything it does comes from its
/// [`DatabasePool`] implementation.
#[derive(Debug)]
pub struct PgConnectionPool;
|
|
|
|
#[derive(Clone)]
|
|
pub enum SslMode {
|
|
NoTls(NoTls),
|
|
NativeTls(postgres_native_tls::MakeTlsConnector),
|
|
}
|
|
/// Connection-string fragment requesting TLS with certificate and hostname verification.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";

/// Connection-string fragment requesting TLS with certificate verification only.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";

/// Connection-string fragment for the `prefer` SSL mode.
const SSLMODE_PREFER: &str = "sslmode=prefer";

/// Connection-string fragment for the `allow` SSL mode.
const SSLMODE_ALLOW: &str = "sslmode=allow";

/// Connection-string fragment for the `require` SSL mode.
const SSLMODE_REQUIRE: &str = "sslmode=require";
|
|
|
|
impl Default for SslMode {
|
|
fn default() -> Self {
|
|
SslMode::NoTls(NoTls {})
|
|
}
|
|
}
|
|
|
|
impl Debug for SslMode {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
let debug_text = match self {
|
|
Self::NoTls(_) => "NoTls",
|
|
Self::NativeTls(_) => "NativeTls",
|
|
};
|
|
|
|
write!(f, "SslMode::{debug_text}")
|
|
}
|
|
}
|
|
|
|
/// Postgres configuration
|
|
#[derive(Clone, Debug)]
|
|
pub struct PgConfig {
|
|
url: String,
|
|
schema: Option<String>,
|
|
tls: SslMode,
|
|
}
|
|
|
|
impl DatabaseConfig for PgConfig {
|
|
fn default_timeout(&self) -> Duration {
|
|
Duration::from_secs(10)
|
|
}
|
|
|
|
fn max_size(&self) -> usize {
|
|
20
|
|
}
|
|
}
|
|
|
|
impl PgConfig {
|
|
/// strip schema from the connection string
|
|
fn strip_schema(input: &str) -> (Option<String>, String) {
|
|
let mut schema: Option<String> = None;
|
|
|
|
// Split by whitespace
|
|
let mut parts = Vec::new();
|
|
for token in input.split_whitespace() {
|
|
if let Some(rest) = token.strip_prefix("schema=") {
|
|
schema = Some(rest.to_string());
|
|
} else {
|
|
parts.push(token);
|
|
}
|
|
}
|
|
|
|
let cleaned = parts.join(" ");
|
|
(schema, cleaned)
|
|
}
|
|
}
|
|
|
|
impl From<&str> for PgConfig {
|
|
fn from(conn_str: &str) -> Self {
|
|
let (schema, conn_str) = Self::strip_schema(conn_str);
|
|
fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
|
|
let mut builder = TlsConnector::builder();
|
|
if accept_invalid_certs {
|
|
builder.danger_accept_invalid_certs(true);
|
|
}
|
|
if accept_invalid_hostnames {
|
|
builder.danger_accept_invalid_hostnames(true);
|
|
}
|
|
|
|
match builder.build() {
|
|
Ok(connector) => {
|
|
let make_tls_connector = MakeTlsConnector::new(connector);
|
|
SslMode::NativeTls(make_tls_connector)
|
|
}
|
|
Err(_) => SslMode::NoTls(NoTls {}),
|
|
}
|
|
}
|
|
|
|
let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
|
|
// Strict TLS: valid certs and hostnames required
|
|
build_tls(false, false)
|
|
} else if conn_str.contains(SSLMODE_VERIFY_CA) {
|
|
// Verify CA, but allow invalid hostnames
|
|
build_tls(false, true)
|
|
} else if conn_str.contains(SSLMODE_PREFER)
|
|
|| conn_str.contains(SSLMODE_ALLOW)
|
|
|| conn_str.contains(SSLMODE_REQUIRE)
|
|
{
|
|
// Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
|
|
build_tls(true, true)
|
|
} else {
|
|
SslMode::NoTls(NoTls {})
|
|
};
|
|
|
|
PgConfig {
|
|
url: conn_str.to_owned(),
|
|
schema,
|
|
tls,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl DatabasePool for PgConnectionPool {
|
|
type Config = PgConfig;
|
|
|
|
type Connection = PostgresConnection;
|
|
|
|
type Error = PgError;
|
|
|
|
fn new_resource(
|
|
config: &Self::Config,
|
|
stale: Arc<AtomicBool>,
|
|
timeout: Duration,
|
|
) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
|
|
Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
|
|
}
|
|
}
|
|
|
|
/// A postgres connection
|
|
#[derive(Debug)]
|
|
pub struct PostgresConnection {
|
|
timeout: Duration,
|
|
error: Arc<Mutex<Option<cdk_common::database::Error>>>,
|
|
result: Arc<OnceLock<Client>>,
|
|
notify: Arc<Notify>,
|
|
}
|
|
|
|
impl PostgresConnection {
|
|
/// Creates a new instance
|
|
pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
|
|
let failed = Arc::new(Mutex::new(None));
|
|
let result = Arc::new(OnceLock::new());
|
|
let notify = Arc::new(Notify::new());
|
|
let error_clone = failed.clone();
|
|
let result_clone = result.clone();
|
|
let notify_clone = notify.clone();
|
|
|
|
async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
|
|
conn.batch_execute(&format!(
|
|
r#"
|
|
CREATE SCHEMA IF NOT EXISTS "{schema}";
|
|
SET search_path TO "{schema}"
|
|
"#
|
|
))
|
|
.await
|
|
.map_err(|e| Error::Database(Box::new(e)))
|
|
}
|
|
|
|
tokio::spawn(async move {
|
|
match config.tls {
|
|
SslMode::NoTls(tls) => {
|
|
let (client, connection) = match connect(&config.url, tls).await {
|
|
Ok((client, connection)) => (client, connection),
|
|
Err(err) => {
|
|
*error_clone.lock().await =
|
|
Some(cdk_common::database::Error::Database(Box::new(err)));
|
|
stale.store(false, std::sync::atomic::Ordering::Release);
|
|
notify_clone.notify_waiters();
|
|
return;
|
|
}
|
|
};
|
|
|
|
let stale_for_spawn = stale.clone();
|
|
tokio::spawn(async move {
|
|
let _ = connection.await;
|
|
stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
|
|
});
|
|
|
|
if let Some(schema) = config.schema.as_ref() {
|
|
if let Err(err) = select_schema(&client, schema).await {
|
|
*error_clone.lock().await = Some(err);
|
|
stale.store(false, std::sync::atomic::Ordering::Release);
|
|
notify_clone.notify_waiters();
|
|
return;
|
|
}
|
|
}
|
|
|
|
let _ = result_clone.set(client);
|
|
notify_clone.notify_waiters();
|
|
}
|
|
SslMode::NativeTls(tls) => {
|
|
let (client, connection) = match connect(&config.url, tls).await {
|
|
Ok((client, connection)) => (client, connection),
|
|
Err(err) => {
|
|
*error_clone.lock().await =
|
|
Some(cdk_common::database::Error::Database(Box::new(err)));
|
|
stale.store(false, std::sync::atomic::Ordering::Release);
|
|
notify_clone.notify_waiters();
|
|
return;
|
|
}
|
|
};
|
|
|
|
let stale_for_spawn = stale.clone();
|
|
tokio::spawn(async move {
|
|
let _ = connection.await;
|
|
stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
|
|
});
|
|
|
|
if let Some(schema) = config.schema.as_ref() {
|
|
if let Err(err) = select_schema(&client, schema).await {
|
|
*error_clone.lock().await = Some(err);
|
|
stale.store(true, std::sync::atomic::Ordering::Release);
|
|
notify_clone.notify_waiters();
|
|
return;
|
|
}
|
|
}
|
|
|
|
let _ = result_clone.set(client);
|
|
notify_clone.notify_waiters();
|
|
}
|
|
}
|
|
});
|
|
|
|
Self {
|
|
error: failed,
|
|
timeout,
|
|
result,
|
|
notify,
|
|
}
|
|
}
|
|
|
|
/// Gets the wrapped instance or the connection error. The connection is returned as reference,
|
|
/// and the actual error is returned once, next times a generic error would be returned
|
|
async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
|
|
if let Some(client) = self.result.get() {
|
|
return Ok(client);
|
|
}
|
|
|
|
if let Some(error) = self.error.lock().await.take() {
|
|
return Err(error);
|
|
}
|
|
|
|
if timeout(self.timeout, self.notify.notified()).await.is_err() {
|
|
return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
|
|
}
|
|
|
|
// Check result again
|
|
if let Some(client) = self.result.get() {
|
|
Ok(client)
|
|
} else if let Some(error) = self.error.lock().await.take() {
|
|
Err(error)
|
|
} else {
|
|
Err(cdk_common::database::Error::Internal(
|
|
"Failed connection".to_owned(),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl DatabaseConnector for PostgresConnection {
|
|
type Transaction = GenericTransactionHandler<Self>;
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl DatabaseExecutor for PostgresConnection {
|
|
fn name() -> &'static str {
|
|
"postgres"
|
|
}
|
|
|
|
async fn execute(&self, statement: Statement) -> Result<usize, Error> {
|
|
pg_execute(self.inner().await?, statement).await
|
|
}
|
|
|
|
async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
|
|
pg_fetch_one(self.inner().await?, statement).await
|
|
}
|
|
|
|
async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
|
|
pg_fetch_all(self.inner().await?, statement).await
|
|
}
|
|
|
|
async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
|
|
pg_pluck(self.inner().await?, statement).await
|
|
}
|
|
|
|
async fn batch(&self, statement: Statement) -> Result<(), Error> {
|
|
pg_batch(self.inner().await?, statement).await
|
|
}
|
|
}
|
|
|
|
/// Mint DB implementation with PostgreSQL
|
|
pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
|
|
|
|
/// Mint Auth database with Postgres
|
|
#[cfg(feature = "auth")]
|
|
pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
|
|
|
|
/// Mint DB implementation with PostgresSQL
|
|
pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
|
|
|
|
#[cfg(test)]
mod test {
    use cdk_common::mint_db_test;

    use super::*;

    /// Opens a mint database for one test, using `test_id` as the schema name.
    ///
    /// The base connection string comes from `CDK_MINTD_DATABASE_URL`, then
    /// `PG_DB_URL`, and finally a built-in localhost default.
    async fn provide_db(test_id: String) -> MintPgDatabase {
        let base_url = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
            .unwrap_or_else(|_| {
                "host=localhost user=test password=test dbname=testdb port=5433".to_owned()
            });

        MintPgDatabase::new(format!("{base_url} schema={test_id}").as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_db);
}
|