diff --git a/crates/cdk-postgres/Cargo.toml b/crates/cdk-postgres/Cargo.toml index cadcf2bd..5b5b4af2 100644 --- a/crates/cdk-postgres/Cargo.toml +++ b/crates/cdk-postgres/Cargo.toml @@ -32,4 +32,5 @@ uuid.workspace = true tokio-postgres = "0.7.13" futures-util = "0.3.31" postgres-native-tls = "0.5.1" +native-tls = "0.2" once_cell.workspace = true diff --git a/crates/cdk-postgres/src/lib.rs b/crates/cdk-postgres/src/lib.rs index 408e60ec..e32a0f35 100644 --- a/crates/cdk-postgres/src/lib.rs +++ b/crates/cdk-postgres/src/lib.rs @@ -10,6 +10,9 @@ use cdk_sql_common::pool::{DatabaseConfig, DatabasePool}; use cdk_sql_common::stmt::{Column, Statement}; use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase}; use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck}; +use native_tls; +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; use tokio::sync::{Mutex, Notify}; use tokio::time::timeout; use tokio_postgres::{connect, Client, Error as PgError, NoTls}; @@ -25,6 +28,11 @@ pub enum SslMode { NoTls(NoTls), NativeTls(postgres_native_tls::MakeTlsConnector), } +const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full"; +const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca"; +const SSLMODE_PREFER: &str = "sslmode=prefer"; +const SSLMODE_ALLOW: &str = "sslmode=allow"; +const SSLMODE_REQUIRE: &str = "sslmode=require"; impl Default for SslMode { fn default() -> Self { @@ -61,10 +69,44 @@ impl DatabaseConfig for PgConfig { } impl From<&str> for PgConfig { - fn from(value: &str) -> Self { + fn from(conn_str: &str) -> Self { + fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode { + let mut builder = TlsConnector::builder(); + if accept_invalid_certs { + builder.danger_accept_invalid_certs(true); + } + if accept_invalid_hostnames { + builder.danger_accept_invalid_hostnames(true); + } + + match builder.build() { + Ok(connector) => { + let make_tls_connector = MakeTlsConnector::new(connector); + SslMode::NativeTls(make_tls_connector) + } + Err(_) => SslMode::NoTls(NoTls {}), + } + } + + let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) { + // Strict TLS: valid certs and hostnames required + build_tls(false, false) + } else if conn_str.contains(SSLMODE_VERIFY_CA) { + // Verify CA, but allow invalid hostnames + build_tls(false, true) + } else if conn_str.contains(SSLMODE_PREFER) + || conn_str.contains(SSLMODE_ALLOW) + || conn_str.contains(SSLMODE_REQUIRE) + { + // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames + build_tls(true, true) + } else { + SslMode::NoTls(NoTls {}) + }; + PgConfig { - url: value.to_owned(), - tls: Default::default(), + url: conn_str.to_owned(), + tls, } } }