"""
|
|
Database connection and session management
|
|
|
|
This module manages database connections with optimized pool settings:
|
|
- Primary async pool (asyncpg): 30 + 50 overflow = 80 max connections
|
|
- Legacy sync pool (psycopg2): 5 + 10 overflow = 15 max connections
|
|
- Total: 95 max connections (under PostgreSQL default of 100)
|
|
|
|
Pool monitoring is available via get_pool_status() function.
|
|
"""
|
|
|
|
import logging
from typing import AsyncGenerator, Dict, Any

from sqlalchemy import create_engine, MetaData, event
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base

from app.core.config import settings

logger = logging.getLogger(__name__)

# Pool metrics tracking
_pool_metrics = {
    "async_checkouts": 0,
    "async_checkins": 0,
    "async_overflow": 0,
    "sync_checkouts": 0,
    "sync_checkins": 0,
    "sync_overflow": 0,
}

# Create the async engine with optimized connection pooling.
# This is the PRIMARY engine - most operations should use async sessions.
# Pool sizing: 30 base + 50 overflow = 80 max connections.
# Note: the PostgreSQL default max_connections is 100; leave headroom for
# admin/monitoring connections.
engine = create_async_engine(
    settings.DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://"),
    echo=settings.APP_DEBUG,
    future=True,
    pool_pre_ping=True,
    pool_size=30,       # Base pool size for steady-state operations
    max_overflow=50,    # Burst capacity for high load
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,    # Max time to wait for a connection from the pool
    connect_args={
        "timeout": 5,
        "command_timeout": 5,
        "server_settings": {
            "application_name": "enclava_backend",
        },
    },
)

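# A minimal sketch of the driver-scheme rewrite performed above (the DSN
# values are illustrative assumptions, not the real configuration):
#
#     url = "postgresql://user:pass@db:5432/enclava"
#     url.replace("postgresql://", "postgresql+asyncpg://")
#     # -> "postgresql+asyncpg://user:pass@db:5432/enclava"
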
# Create the async session factory
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

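# Usage sketch for the async factory (hedged: `User` is the model imported in
# init_db below; the query itself is illustrative):
#
#     from sqlalchemy import select
#     from app.models.user import User
#
#     async def list_users():
#         async with async_session_factory() as session:
#             result = await session.execute(select(User))
#             return result.scalars().all()
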
# Create a synchronous engine for legacy code paths and startup operations.
# IMPORTANT: this pool should stay MINIMAL - prefer async operations for all
# new code. Budget enforcement, chatbot, and API operations now use async
# sessions. Pool sizing: 5 base + 10 overflow = 15 max connections.
sync_engine = create_engine(
    settings.DATABASE_URL,
    echo=settings.APP_DEBUG,
    future=True,
    pool_pre_ping=True,
    pool_size=5,        # Minimal - only for startup/migrations/legacy paths
    max_overflow=10,    # Small burst for edge cases
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,    # Max time to wait for a connection from the pool
    connect_args={
        "connect_timeout": 5,
        "application_name": "enclava_backend_sync",
    },
)

# Create the sync session factory.
# NOTE: prefer async_session_factory for all new code.
SessionLocal = sessionmaker(
    bind=sync_engine,
    expire_on_commit=False,
)

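# Usage sketch for legacy/synchronous paths (hedged: `Role` is the model
# imported in create_default_roles below):
#
#     from sqlalchemy import select
#     from app.models.role import Role
#
#     with SessionLocal() as session:
#         role = session.execute(select(Role).limit(1)).scalar_one_or_none()
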
# Create the declarative base class for models
Base = declarative_base()

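# Models inherit from Base; a minimal sketch (a hypothetical table, not part
# of this schema):
#
#     from sqlalchemy import Column, Integer, String
#
#     class Example(Base):
#         __tablename__ = "example"
#         id = Column(Integer, primary_key=True)
#         name = Column(String, nullable=False)
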
# ============================================================================
# Pool Monitoring
# ============================================================================


def _setup_pool_monitoring():
    """Set up event listeners for pool monitoring."""

    # Sync engine pool events
    @event.listens_for(sync_engine, "checkout")
    def sync_checkout(dbapi_conn, connection_record, connection_proxy):
        _pool_metrics["sync_checkouts"] += 1
        pool = sync_engine.pool
        if pool.overflow() > 0:
            _pool_metrics["sync_overflow"] = pool.overflow()
        logger.debug(f"Sync pool checkout (overflow: {pool.overflow()})")

    @event.listens_for(sync_engine, "checkin")
    def sync_checkin(dbapi_conn, connection_record):
        _pool_metrics["sync_checkins"] += 1

    # Note: pool events for the async engine are registered on its underlying
    # sync engine, which is exposed as engine.sync_engine.
    try:
        sync_async_engine = engine.sync_engine

        @event.listens_for(sync_async_engine, "checkout")
        def async_checkout(dbapi_conn, connection_record, connection_proxy):
            _pool_metrics["async_checkouts"] += 1
            pool = sync_async_engine.pool
            if pool.overflow() > 0:
                _pool_metrics["async_overflow"] = pool.overflow()
            logger.debug(f"Async pool checkout (overflow: {pool.overflow()})")

        @event.listens_for(sync_async_engine, "checkin")
        def async_checkin(dbapi_conn, connection_record):
            _pool_metrics["async_checkins"] += 1

    except Exception as e:
        logger.warning(f"Could not set up async pool monitoring: {e}")


def get_pool_status() -> Dict[str, Any]:
    """
    Get the current status of the database connection pools.

    Returns:
        Dict containing pool statistics for both the async and sync engines
    """
    try:
        # Get async pool status
        async_pool = engine.sync_engine.pool
        async_status = {
            "size": async_pool.size(),
            "checked_in": async_pool.checkedin(),
            "checked_out": async_pool.checkedout(),
            "overflow": async_pool.overflow(),
            "invalid": async_pool.invalidatedcount() if hasattr(async_pool, "invalidatedcount") else 0,
        }
    except Exception as e:
        async_status = {"error": str(e)}

    try:
        # Get sync pool status
        sync_pool = sync_engine.pool
        sync_status = {
            "size": sync_pool.size(),
            "checked_in": sync_pool.checkedin(),
            "checked_out": sync_pool.checkedout(),
            "overflow": sync_pool.overflow(),
            "invalid": sync_pool.invalidatedcount() if hasattr(sync_pool, "invalidatedcount") else 0,
        }
    except Exception as e:
        sync_status = {"error": str(e)}

    return {
        "async_pool": async_status,
        "sync_pool": sync_status,
        "metrics": _pool_metrics.copy(),
        "config": {
            "async_pool_size": 30,
            "async_max_overflow": 50,
            "async_max_connections": 80,
            "sync_pool_size": 5,
            "sync_max_overflow": 10,
            "sync_max_connections": 15,
            "total_max_connections": 95,
        },
    }

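# Usage sketch (assumption: exposed from a FastAPI router defined elsewhere;
# `router` and the route path are illustrative):
#
#     @router.get("/health/db-pool")
#     async def db_pool_health():
#         return get_pool_status()
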

def log_pool_status():
    """Log the current pool status (useful for debugging)."""
    status = get_pool_status()
    logger.info(f"Database pool status: {status}")

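# Sketch of periodic pool logging (assumption: started as a background
# asyncio task at application startup):
#
#     import asyncio
#
#     async def _log_pool_periodically(interval: float = 60.0):
#         while True:
#             log_pool_status()
#             await asyncio.sleep(interval)
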

# Initialize pool monitoring
_setup_pool_monitoring()

# Metadata for migrations
metadata = MetaData()

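# Sketch of how a migration tool could consume the model metadata
# (assumption: Alembic is used; the import path is illustrative):
#
#     # alembic/env.py
#     from app.db.database import Base
#     target_metadata = Base.metadata
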

async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session (intended for use as a FastAPI dependency)."""
    from fastapi import HTTPException
    from starlette.exceptions import HTTPException as StarletteHTTPException

    async with async_session_factory() as session:
        try:
            yield session
        except (HTTPException, StarletteHTTPException):
            # Don't log HTTP exceptions - these are normal API responses
            # (401, 403, 404, etc.). Just roll back any pending transaction
            # and re-raise.
            await session.rollback()
            raise
        except SQLAlchemyError as e:
            # Log actual database errors
            logger.error(f"Database error during request: {e}")
            await session.rollback()
            raise
        except Exception as e:
            # Log unexpected errors, but don't treat them as database failures
            logger.warning(f"Request error (non-database): {type(e).__name__}")
            await session.rollback()
            raise

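# Usage sketch as a FastAPI dependency (the route, `app`, and the `User`
# query are illustrative):
#
#     from fastapi import Depends
#     from sqlalchemy import select
#
#     @app.get("/users")
#     async def list_users(db: AsyncSession = Depends(get_db)):
#         result = await db.execute(select(User))
#         return result.scalars().all()
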

async def init_db():
    """Initialize the database."""
    try:
        async with engine.begin() as conn:
            # Import all models so they are registered with Base.metadata
            from app.models.user import User
            from app.models.role import Role
            from app.models.api_key import APIKey
            from app.models.usage_tracking import UsageTracking

            # Import additional models if they are available
            try:
                from app.models.budget import Budget
            except ImportError:
                logger.warning("Budget model not available yet")

            try:
                from app.models.audit_log import AuditLog
            except ImportError:
                logger.warning("AuditLog model not available yet")

            try:
                from app.models.module import Module
            except ImportError:
                logger.warning("Module model not available yet")

            # Tables are created via the migration container - no need to
            # create them here.
            # await conn.run_sync(Base.metadata.create_all)  # DISABLED - migrations handle this

            # Create default roles if they don't exist
            await create_default_roles()

            # Create a default admin user if no admin exists
            await create_default_admin()

            logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {e}")
        raise

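# Startup sketch (assumption: init_db is invoked from a FastAPI lifespan
# hook; `app` is illustrative):
#
#     from contextlib import asynccontextmanager
#
#     @asynccontextmanager
#     async def lifespan(app):
#         await init_db()
#         yield
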

async def create_default_roles():
    """Create the default roles if they don't exist."""
    from app.models.role import Role
    from sqlalchemy import select
    from sqlalchemy.exc import SQLAlchemyError

    try:
        async with async_session_factory() as session:
            # Check whether any roles exist
            stmt = select(Role).limit(1)
            result = await session.execute(stmt)
            existing_role = result.scalar_one_or_none()

            if existing_role:
                logger.info("Roles already exist - skipping default role creation")
                return

            # Create the default roles using the Role.create_default_roles class method
            default_roles = Role.create_default_roles()

            for role in default_roles:
                session.add(role)

            await session.commit()

            logger.info("Created default roles: read_only, user, admin, super_admin")

    except SQLAlchemyError as e:
        logger.error(f"Failed to create default roles due to database error: {e}")
        raise


async def create_default_admin():
    """Create a default admin user if no user with ADMIN_EMAIL exists."""
    from app.models.user import User
    from app.models.role import Role
    from app.core.security import get_password_hash
    from app.core.config import settings
    from sqlalchemy import select
    from sqlalchemy.exc import SQLAlchemyError

    try:
        admin_email = settings.ADMIN_EMAIL
        admin_password = settings.ADMIN_PASSWORD

        if not admin_email or not admin_password:
            logger.info("Admin bootstrap skipped: ADMIN_EMAIL or ADMIN_PASSWORD unset")
            return

        async with async_session_factory() as session:
            # Check whether a user with ADMIN_EMAIL exists
            stmt = select(User).where(User.email == admin_email)
            result = await session.execute(stmt)
            existing_user = result.scalar_one_or_none()

            if existing_user:
                logger.info(
                    f"User with email {admin_email} already exists - skipping admin creation"
                )
                return

            # Get the super_admin role
            stmt = select(Role).where(Role.name == "super_admin")
            result = await session.execute(stmt)
            super_admin_role = result.scalar_one_or_none()

            if not super_admin_role:
                logger.error("Super admin role not found - cannot create admin user")
                return

            # Create the admin user from environment variables.
            # Generate the username from the email (the part before the @).
            admin_username = admin_email.split("@")[0]

            admin_user = User.create_default_admin(
                email=admin_email,
                username=admin_username,
                password_hash=get_password_hash(admin_password),
            )

            # Assign the super_admin role
            admin_user.role_id = super_admin_role.id

            session.add(admin_user)
            await session.commit()

            logger.warning("=" * 60)
            logger.warning("ADMIN USER CREATED FROM ENVIRONMENT")
            logger.warning(f"Email: {admin_email}")
            logger.warning(f"Username: {admin_username}")
            logger.warning("Role: Super Administrator")
            logger.warning(
                "Password: [Set via ADMIN_PASSWORD - only used on first creation]"
            )
            logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
            logger.warning("=" * 60)

    except SQLAlchemyError as e:
        logger.error(f"Failed to create default admin user due to database error: {e}")
    except AttributeError:
        logger.error(
            f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'"
        )
    except Exception as e:
        logger.error(f"Failed to create default admin user: {e}")
        # Don't raise here, as this shouldn't block application startup