clean commit

2025-08-19 09:50:15 +02:00
parent 3c5cca407d
commit 69a947fa0b
249 changed files with 65688 additions and 0 deletions

backend/Dockerfile

@@ -0,0 +1,40 @@
FROM python:3.11-slim
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONPATH=/app
# Set work directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
libpq-dev \
curl \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Download spaCy English model for NLP processing
RUN python -m spacy download en_core_web_sm
# Copy application code
COPY . .
# Create logs directory
RUN mkdir -p logs
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run the application
# NOTE: --reload enables hot-reloading and is meant for development;
# drop the flag for production images.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

backend/alembic.ini

@@ -0,0 +1,98 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses
# os.pathsep. If this key is omitted entirely, it falls back to the legacy
# behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

backend/alembic/env.py

@@ -0,0 +1,99 @@
import asyncio
import os
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from app.db.database import Base
from app.models.user import User
from app.models.api_key import APIKey
from app.models.budget import Budget
from app.models.usage_tracking import UsageTracking
from app.models.audit_log import AuditLog
from app.models.module import Module
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
# Get database URL from environment
database_url = os.getenv("DATABASE_URL", "postgresql://empire:empire123@localhost:5432/empire")
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = database_url
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = database_url.replace("postgresql://", "postgresql+asyncpg://")
connectable = AsyncEngine(
engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
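Migrations are normally driven by the alembic CLI, but the same env.py can also be exercised programmatically. A minimal sketch using Alembic's command API, assuming it runs from the backend/ directory where alembic.ini lives:

from alembic import command
from alembic.config import Config

# Load the alembic.ini shown above; env.py then substitutes DATABASE_URL at runtime.
alembic_cfg = Config("alembic.ini")

# Apply all pending revisions up to the latest head.
command.upgrade(alembic_cfg, "head")

# Roll back a single revision if needed:
# command.downgrade(alembic_cfg, "-1")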

backend/alembic/script.py.mako

@@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}


@@ -0,0 +1,251 @@
"""Initial schema
Revision ID: 001
Revises:
Create Date: 2025-01-01 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '001'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create users table
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('username', sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.Column('full_name', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('is_superuser', sa.Boolean(), nullable=True),
sa.Column('is_verified', sa.Boolean(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column('avatar_url', sa.String(), nullable=True),
sa.Column('bio', sa.Text(), nullable=True),
sa.Column('company', sa.String(), nullable=True),
sa.Column('website', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('last_login', sa.DateTime(), nullable=True),
sa.Column('preferences', sa.JSON(), nullable=True),
sa.Column('notification_settings', sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
# Create api_keys table
op.create_table('api_keys',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('key_hash', sa.String(), nullable=False),
sa.Column('key_prefix', sa.String(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column('scopes', sa.JSON(), nullable=True),
sa.Column('rate_limit_per_minute', sa.Integer(), nullable=True),
sa.Column('rate_limit_per_hour', sa.Integer(), nullable=True),
sa.Column('rate_limit_per_day', sa.Integer(), nullable=True),
sa.Column('allowed_models', sa.JSON(), nullable=True),
sa.Column('allowed_endpoints', sa.JSON(), nullable=True),
sa.Column('allowed_ips', sa.JSON(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('total_requests', sa.Integer(), nullable=True),
sa.Column('total_tokens', sa.Integer(), nullable=True),
sa.Column('total_cost', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_api_keys_id'), 'api_keys', ['id'], unique=False)
op.create_index(op.f('ix_api_keys_key_hash'), 'api_keys', ['key_hash'], unique=True)
op.create_index(op.f('ix_api_keys_key_prefix'), 'api_keys', ['key_prefix'], unique=False)
# Create budgets table
op.create_table('budgets',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.String(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('limit_amount', sa.Float(), nullable=False),
sa.Column('currency', sa.String(), nullable=True),
sa.Column('period', sa.String(), nullable=True),
sa.Column('current_usage', sa.Float(), nullable=True),
sa.Column('remaining_amount', sa.Float(), nullable=True),
sa.Column('status', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('alert_thresholds', sa.JSON(), nullable=True),
sa.Column('alerts_sent', sa.JSON(), nullable=True),
sa.Column('auto_suspend_on_exceed', sa.Boolean(), nullable=True),
sa.Column('auto_notify_on_exceed', sa.Boolean(), nullable=True),
sa.Column('period_start', sa.DateTime(), nullable=False),
sa.Column('period_end', sa.DateTime(), nullable=False),
sa.Column('allowed_models', sa.JSON(), nullable=True),
sa.Column('allowed_endpoints', sa.JSON(), nullable=True),
sa.Column('user_groups', sa.JSON(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('last_reset_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_budgets_id'), 'budgets', ['id'], unique=False)
# Create usage_tracking table
op.create_table('usage_tracking',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('request_id', sa.String(), nullable=False),
sa.Column('session_id', sa.String(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('api_key_id', sa.Integer(), nullable=True),
sa.Column('endpoint', sa.String(), nullable=False),
sa.Column('method', sa.String(), nullable=False),
sa.Column('user_agent', sa.String(), nullable=True),
sa.Column('ip_address', sa.String(), nullable=True),
sa.Column('model_name', sa.String(), nullable=True),
sa.Column('provider', sa.String(), nullable=True),
sa.Column('model_version', sa.String(), nullable=True),
sa.Column('request_data', sa.JSON(), nullable=True),
sa.Column('response_data', sa.JSON(), nullable=True),
sa.Column('prompt_tokens', sa.Integer(), nullable=True),
sa.Column('completion_tokens', sa.Integer(), nullable=True),
sa.Column('total_tokens', sa.Integer(), nullable=True),
sa.Column('cost_per_token', sa.Float(), nullable=True),
sa.Column('total_cost', sa.Float(), nullable=True),
sa.Column('currency', sa.String(), nullable=True),
sa.Column('response_time', sa.Float(), nullable=True),
sa.Column('queue_time', sa.Float(), nullable=True),
sa.Column('processing_time', sa.Float(), nullable=True),
sa.Column('status', sa.String(), nullable=True),
sa.Column('status_code', sa.Integer(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('error_type', sa.String(), nullable=True),
sa.Column('modules_used', sa.JSON(), nullable=True),
sa.Column('interceptor_chain', sa.JSON(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('cache_hit', sa.Boolean(), nullable=True),
sa.Column('cache_key', sa.String(), nullable=True),
sa.Column('rate_limit_remaining', sa.Integer(), nullable=True),
sa.Column('rate_limit_reset', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['api_key_id'], ['api_keys.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_usage_tracking_id'), 'usage_tracking', ['id'], unique=False)
op.create_index(op.f('ix_usage_tracking_request_id'), 'usage_tracking', ['request_id'], unique=True)
op.create_index(op.f('ix_usage_tracking_session_id'), 'usage_tracking', ['session_id'], unique=False)
# Create audit_logs table
op.create_table('audit_logs',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('action', sa.String(), nullable=False),
sa.Column('resource_type', sa.String(), nullable=False),
sa.Column('resource_id', sa.String(), nullable=True),
sa.Column('description', sa.Text(), nullable=False),
sa.Column('details', sa.JSON(), nullable=True),
sa.Column('ip_address', sa.String(), nullable=True),
sa.Column('user_agent', sa.String(), nullable=True),
sa.Column('session_id', sa.String(), nullable=True),
sa.Column('request_id', sa.String(), nullable=True),
sa.Column('severity', sa.String(), nullable=True),
sa.Column('category', sa.String(), nullable=True),
sa.Column('success', sa.Boolean(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('old_values', sa.JSON(), nullable=True),
sa.Column('new_values', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_audit_logs_created_at'), 'audit_logs', ['created_at'], unique=False)
op.create_index(op.f('ix_audit_logs_id'), 'audit_logs', ['id'], unique=False)
# Create modules table
op.create_table('modules',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('display_name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('module_type', sa.String(), nullable=True),
sa.Column('category', sa.String(), nullable=True),
sa.Column('version', sa.String(), nullable=False),
sa.Column('author', sa.String(), nullable=True),
sa.Column('license', sa.String(), nullable=True),
sa.Column('status', sa.String(), nullable=True),
sa.Column('is_enabled', sa.Boolean(), nullable=True),
sa.Column('is_core', sa.Boolean(), nullable=True),
sa.Column('config_schema', sa.JSON(), nullable=True),
sa.Column('config_values', sa.JSON(), nullable=True),
sa.Column('default_config', sa.JSON(), nullable=True),
sa.Column('dependencies', sa.JSON(), nullable=True),
sa.Column('conflicts', sa.JSON(), nullable=True),
sa.Column('install_path', sa.String(), nullable=True),
sa.Column('entry_point', sa.String(), nullable=True),
sa.Column('interceptor_chains', sa.JSON(), nullable=True),
sa.Column('execution_order', sa.Integer(), nullable=True),
sa.Column('api_endpoints', sa.JSON(), nullable=True),
sa.Column('required_permissions', sa.JSON(), nullable=True),
sa.Column('security_level', sa.String(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('last_error', sa.Text(), nullable=True),
sa.Column('error_count', sa.Integer(), nullable=True),
sa.Column('last_started', sa.DateTime(), nullable=True),
sa.Column('last_stopped', sa.DateTime(), nullable=True),
sa.Column('request_count', sa.Integer(), nullable=True),
sa.Column('success_count', sa.Integer(), nullable=True),
sa.Column('error_count_runtime', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('installed_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_modules_id'), 'modules', ['id'], unique=False)
op.create_index(op.f('ix_modules_name'), 'modules', ['name'], unique=True)
def downgrade() -> None:
op.drop_index(op.f('ix_modules_name'), table_name='modules')
op.drop_index(op.f('ix_modules_id'), table_name='modules')
op.drop_table('modules')
op.drop_index(op.f('ix_audit_logs_id'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_created_at'), table_name='audit_logs')
op.drop_table('audit_logs')
op.drop_index(op.f('ix_usage_tracking_session_id'), table_name='usage_tracking')
op.drop_index(op.f('ix_usage_tracking_request_id'), table_name='usage_tracking')
op.drop_index(op.f('ix_usage_tracking_id'), table_name='usage_tracking')
op.drop_table('usage_tracking')
op.drop_index(op.f('ix_budgets_id'), table_name='budgets')
op.drop_table('budgets')
op.drop_index(op.f('ix_api_keys_key_prefix'), table_name='api_keys')
op.drop_index(op.f('ix_api_keys_key_hash'), table_name='api_keys')
op.drop_index(op.f('ix_api_keys_id'), table_name='api_keys')
op.drop_table('api_keys')
op.drop_index(op.f('ix_users_username'), table_name='users')
op.drop_index(op.f('ix_users_id'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')


@@ -0,0 +1,84 @@
"""Add RAG collections and documents tables
Revision ID: 002
Revises: 001
Create Date: 2025-07-23 19:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create rag_collections table
op.create_table('rag_collections',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('qdrant_collection_name', sa.String(255), nullable=False),
sa.Column('document_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('size_bytes', sa.BigInteger(), nullable=False, server_default='0'),
sa.Column('vector_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('status', sa.String(50), nullable=False, server_default='active'),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_rag_collections_id'), 'rag_collections', ['id'], unique=False)
op.create_index(op.f('ix_rag_collections_name'), 'rag_collections', ['name'], unique=False)
op.create_index(op.f('ix_rag_collections_qdrant_collection_name'), 'rag_collections', ['qdrant_collection_name'], unique=True)
# Create rag_documents table
op.create_table('rag_documents',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('collection_id', sa.Integer(), nullable=False),
sa.Column('filename', sa.String(255), nullable=False),
sa.Column('original_filename', sa.String(255), nullable=False),
sa.Column('file_path', sa.String(500), nullable=False),
sa.Column('file_type', sa.String(50), nullable=False),
sa.Column('file_size', sa.BigInteger(), nullable=False),
sa.Column('mime_type', sa.String(100), nullable=True),
sa.Column('status', sa.String(50), nullable=False, server_default='processing'),
sa.Column('processing_error', sa.Text(), nullable=True),
sa.Column('converted_content', sa.Text(), nullable=True),
sa.Column('word_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('character_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('vector_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('chunk_size', sa.Integer(), nullable=False, server_default='1000'),
sa.Column('document_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('indexed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('is_deleted', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['collection_id'], ['rag_collections.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_rag_documents_id'), 'rag_documents', ['id'], unique=False)
op.create_index(op.f('ix_rag_documents_collection_id'), 'rag_documents', ['collection_id'], unique=False)
op.create_index(op.f('ix_rag_documents_filename'), 'rag_documents', ['filename'], unique=False)
op.create_index(op.f('ix_rag_documents_status'), 'rag_documents', ['status'], unique=False)
op.create_index(op.f('ix_rag_documents_created_at'), 'rag_documents', ['created_at'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_rag_documents_created_at'), table_name='rag_documents')
op.drop_index(op.f('ix_rag_documents_status'), table_name='rag_documents')
op.drop_index(op.f('ix_rag_documents_filename'), table_name='rag_documents')
op.drop_index(op.f('ix_rag_documents_collection_id'), table_name='rag_documents')
op.drop_index(op.f('ix_rag_documents_id'), table_name='rag_documents')
op.drop_table('rag_documents')
op.drop_index(op.f('ix_rag_collections_qdrant_collection_name'), table_name='rag_collections')
op.drop_index(op.f('ix_rag_collections_name'), table_name='rag_collections')
op.drop_index(op.f('ix_rag_collections_id'), table_name='rag_collections')
op.drop_table('rag_collections')


@@ -0,0 +1,82 @@
"""Fix budget and usage_tracking columns
Revision ID: 003
Revises: 002
Create Date: 2025-07-24 09:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '003'
down_revision = '002'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add missing columns to budgets table
op.add_column('budgets', sa.Column('api_key_id', sa.Integer(), nullable=True))
op.add_column('budgets', sa.Column('limit_cents', sa.Integer(), nullable=False, server_default='0'))
op.add_column('budgets', sa.Column('warning_threshold_cents', sa.Integer(), nullable=True))
op.add_column('budgets', sa.Column('period_type', sa.String(), nullable=False, server_default='monthly'))
op.add_column('budgets', sa.Column('current_usage_cents', sa.Integer(), nullable=True, server_default='0'))
op.add_column('budgets', sa.Column('is_exceeded', sa.Boolean(), nullable=True, server_default='false'))
op.add_column('budgets', sa.Column('is_warning_sent', sa.Boolean(), nullable=True, server_default='false'))
op.add_column('budgets', sa.Column('enforce_hard_limit', sa.Boolean(), nullable=True, server_default='true'))
op.add_column('budgets', sa.Column('enforce_warning', sa.Boolean(), nullable=True, server_default='true'))
op.add_column('budgets', sa.Column('auto_renew', sa.Boolean(), nullable=True, server_default='true'))
op.add_column('budgets', sa.Column('rollover_unused', sa.Boolean(), nullable=True, server_default='false'))
op.add_column('budgets', sa.Column('notification_settings', sa.JSON(), nullable=True))
# Create foreign key for api_key_id
op.create_foreign_key('fk_budgets_api_key_id', 'budgets', 'api_keys', ['api_key_id'], ['id'])
# Update usage_tracking table
op.add_column('usage_tracking', sa.Column('budget_id', sa.Integer(), nullable=True))
op.add_column('usage_tracking', sa.Column('model', sa.String(), nullable=True))
op.add_column('usage_tracking', sa.Column('request_tokens', sa.Integer(), nullable=True))
op.add_column('usage_tracking', sa.Column('response_tokens', sa.Integer(), nullable=True))
op.add_column('usage_tracking', sa.Column('cost_cents', sa.Integer(), nullable=True))
op.add_column('usage_tracking', sa.Column('cost_currency', sa.String(), nullable=True, server_default='USD'))
op.add_column('usage_tracking', sa.Column('response_time_ms', sa.Integer(), nullable=True))
op.add_column('usage_tracking', sa.Column('request_metadata', sa.JSON(), nullable=True))
# Create foreign key for budget_id
op.create_foreign_key('fk_usage_tracking_budget_id', 'usage_tracking', 'budgets', ['budget_id'], ['id'])
# Update modules table
op.add_column('modules', sa.Column('module_metadata', sa.JSON(), nullable=True))
def downgrade() -> None:
# Remove added columns from modules
op.drop_column('modules', 'module_metadata')
# Remove added columns and constraints from usage_tracking
op.drop_constraint('fk_usage_tracking_budget_id', 'usage_tracking', type_='foreignkey')
op.drop_column('usage_tracking', 'request_metadata')
op.drop_column('usage_tracking', 'response_time_ms')
op.drop_column('usage_tracking', 'cost_currency')
op.drop_column('usage_tracking', 'cost_cents')
op.drop_column('usage_tracking', 'response_tokens')
op.drop_column('usage_tracking', 'request_tokens')
op.drop_column('usage_tracking', 'model')
op.drop_column('usage_tracking', 'budget_id')
# Remove added columns and constraints from budgets
op.drop_constraint('fk_budgets_api_key_id', 'budgets', type_='foreignkey')
op.drop_column('budgets', 'notification_settings')
op.drop_column('budgets', 'rollover_unused')
op.drop_column('budgets', 'auto_renew')
op.drop_column('budgets', 'enforce_warning')
op.drop_column('budgets', 'enforce_hard_limit')
op.drop_column('budgets', 'is_warning_sent')
op.drop_column('budgets', 'is_exceeded')
op.drop_column('budgets', 'current_usage_cents')
op.drop_column('budgets', 'period_type')
op.drop_column('budgets', 'warning_threshold_cents')
op.drop_column('budgets', 'limit_cents')
op.drop_column('budgets', 'api_key_id')


@@ -0,0 +1,34 @@
"""Add budget fields to API keys
Revision ID: 004_add_api_key_budget_fields
Revises: 8bf097417ff0
Create Date: 2024-07-25 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '004_add_api_key_budget_fields'
down_revision = '8bf097417ff0'
branch_labels = None
depends_on = None
def upgrade():
"""Add budget-related fields to api_keys table"""
# Add budget configuration columns
    # Use server_default (not the Python-side default) so existing rows receive a
    # value when this NOT NULL column is added; otherwise the ALTER TABLE fails
    # on a non-empty table.
    op.add_column('api_keys', sa.Column('is_unlimited', sa.Boolean(), server_default='true', nullable=False))
op.add_column('api_keys', sa.Column('budget_limit_cents', sa.Integer(), nullable=True))
op.add_column('api_keys', sa.Column('budget_type', sa.String(), nullable=True))
# Set default values for existing records
op.execute("UPDATE api_keys SET is_unlimited = true WHERE is_unlimited IS NULL")
def downgrade():
"""Remove budget-related fields from api_keys table"""
op.drop_column('api_keys', 'budget_type')
op.drop_column('api_keys', 'budget_limit_cents')
op.drop_column('api_keys', 'is_unlimited')


@@ -0,0 +1,192 @@
"""Add prompt templates for editable chatbot prompts
Revision ID: 005_add_prompt_templates
Revises: 004_add_api_key_budget_fields
Create Date: 2025-08-07 17:50:00.000000
"""
from alembic import op
import sqlalchemy as sa
from datetime import datetime
import uuid
# revision identifiers, used by Alembic.
revision = '005_add_prompt_templates'
down_revision = '004_add_api_key_budget_fields'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create prompt_templates table
op.create_table(
'prompt_templates',
sa.Column('id', sa.String(), primary_key=True, index=True),
sa.Column('name', sa.String(255), nullable=False, index=True),
sa.Column('type_key', sa.String(100), nullable=False, unique=True, index=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('system_prompt', sa.Text(), nullable=False),
sa.Column('is_default', sa.Boolean(), default=True, nullable=False),
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
sa.Column('version', sa.Integer(), default=1, nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now())
)
# Create prompt_variables table
op.create_table(
'prompt_variables',
sa.Column('id', sa.String(), primary_key=True, index=True),
sa.Column('variable_name', sa.String(100), nullable=False, unique=True, index=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('example_value', sa.String(500), nullable=True),
sa.Column('is_active', sa.Boolean(), default=True, nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now())
)
# Insert default prompt templates
prompt_templates_table = sa.table(
'prompt_templates',
sa.column('id', sa.String),
sa.column('name', sa.String),
sa.column('type_key', sa.String),
sa.column('description', sa.Text),
sa.column('system_prompt', sa.Text),
sa.column('is_default', sa.Boolean),
sa.column('is_active', sa.Boolean),
sa.column('version', sa.Integer),
sa.column('created_at', sa.DateTime),
sa.column('updated_at', sa.DateTime)
)
current_time = datetime.utcnow()
default_prompts = [
{
'id': str(uuid.uuid4()),
'name': 'General Assistant',
'type_key': 'assistant',
'description': 'Helpful AI assistant for general questions and tasks',
'system_prompt': 'You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don\'t know something, say so clearly. Be professional but approachable in your communication style.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
},
{
'id': str(uuid.uuid4()),
'name': 'Customer Support',
'type_key': 'customer_support',
'description': 'Professional customer service chatbot',
'system_prompt': 'You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer\'s issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
},
{
'id': str(uuid.uuid4()),
'name': 'Educational Tutor',
'type_key': 'teacher',
'description': 'Educational tutor and learning assistant',
'system_prompt': 'You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
},
{
'id': str(uuid.uuid4()),
'name': 'Research Assistant',
'type_key': 'researcher',
'description': 'Research assistant with fact-checking focus',
'system_prompt': 'You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
},
{
'id': str(uuid.uuid4()),
'name': 'Creative Writing Assistant',
'type_key': 'creative_writer',
'description': 'Creative writing and storytelling assistant',
'system_prompt': 'You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
},
{
'id': str(uuid.uuid4()),
'name': 'Custom Chatbot',
'type_key': 'custom',
'description': 'Fully customizable chatbot with user-defined personality',
'system_prompt': 'You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user\'s guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.',
'is_default': True,
'is_active': True,
'version': 1,
'created_at': current_time,
'updated_at': current_time
}
]
op.bulk_insert(prompt_templates_table, default_prompts)
# Insert default prompt variables
variables_table = sa.table(
'prompt_variables',
sa.column('id', sa.String),
sa.column('variable_name', sa.String),
sa.column('description', sa.Text),
sa.column('example_value', sa.String),
sa.column('is_active', sa.Boolean),
sa.column('created_at', sa.DateTime)
)
default_variables = [
{
'id': str(uuid.uuid4()),
'variable_name': '{user_name}',
'description': 'The name of the user chatting with the bot',
'example_value': 'John Smith',
'is_active': True,
'created_at': current_time
},
{
'id': str(uuid.uuid4()),
'variable_name': '{context}',
'description': 'Additional context from RAG or previous conversation',
'example_value': 'Based on the uploaded documents...',
'is_active': True,
'created_at': current_time
},
{
'id': str(uuid.uuid4()),
'variable_name': '{company_name}',
'description': 'Your company or organization name',
'example_value': 'Acme Corporation',
'is_active': True,
'created_at': current_time
},
{
'id': str(uuid.uuid4()),
'variable_name': '{current_date}',
'description': 'Current date and time',
'example_value': '2025-08-07 17:50:00',
'is_active': True,
'created_at': current_time
}
]
op.bulk_insert(variables_table, default_variables)
def downgrade() -> None:
op.drop_table('prompt_variables')
op.drop_table('prompt_templates')
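The seeded prompt_variables use {user_name}-style placeholders; how they are expanded is not part of this commit. A minimal sketch of the substitution the templates imply (render_prompt is a hypothetical helper; note that variable_name is stored with the braces included):

def render_prompt(system_prompt: str, values: dict[str, str]) -> str:
    """Replace {variable} placeholders with concrete values, leaving unknown ones intact."""
    for name, value in values.items():
        system_prompt = system_prompt.replace(name, value)
    return system_prompt

rendered = render_prompt(
    "Hello {user_name}, today is {current_date}.",
    {"{user_name}": "John Smith", "{current_date}": "2025-08-07 17:50:00"},
)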


@@ -0,0 +1,33 @@
"""Add chatbot API key support
Revision ID: 009_add_chatbot_api_key_support
Revises: 004_add_api_key_budget_fields
Create Date: 2025-01-08 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers
revision = '009_add_chatbot_api_key_support'
down_revision = '004_add_api_key_budget_fields'
branch_labels = None
depends_on = None
def upgrade():
"""Add allowed_chatbots column to api_keys table"""
# Add the allowed_chatbots column
op.add_column('api_keys', sa.Column('allowed_chatbots', sa.JSON(), nullable=True))
# Update existing records to have empty allowed_chatbots list
op.execute("UPDATE api_keys SET allowed_chatbots = '[]' WHERE allowed_chatbots IS NULL")
# Make the column non-nullable with a default
op.alter_column('api_keys', 'allowed_chatbots', nullable=False, server_default='[]')
def downgrade():
"""Remove allowed_chatbots column from api_keys table"""
op.drop_column('api_keys', 'allowed_chatbots')


@@ -0,0 +1,81 @@
"""add workflow tables only
Revision ID: 010_add_workflow_tables_only
Revises: f7af0923d38b
Create Date: 2025-08-18 09:03:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '010_add_workflow_tables_only'
down_revision = 'f7af0923d38b'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create workflow_definitions table
op.create_table('workflow_definitions',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('version', sa.String(length=50), nullable=True),
sa.Column('steps', sa.JSON(), nullable=False),
sa.Column('variables', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('timeout', sa.Integer(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('created_by', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# Create workflow_executions table
op.create_table('workflow_executions',
sa.Column('id', sa.String(), nullable=False),
sa.Column('workflow_id', sa.String(), nullable=False),
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='workflowstatus'), nullable=True),
sa.Column('current_step', sa.String(), nullable=True),
sa.Column('input_data', sa.JSON(), nullable=True),
sa.Column('context', sa.JSON(), nullable=True),
sa.Column('results', sa.JSON(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('executed_by', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['workflow_id'], ['workflow_definitions.id'], ),
sa.PrimaryKeyConstraint('id')
)
# Create workflow_step_logs table
op.create_table('workflow_step_logs',
sa.Column('id', sa.String(), nullable=False),
sa.Column('execution_id', sa.String(), nullable=False),
sa.Column('step_id', sa.String(), nullable=False),
sa.Column('step_name', sa.String(length=255), nullable=False),
sa.Column('step_type', sa.String(length=50), nullable=False),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('input_data', sa.JSON(), nullable=True),
sa.Column('output_data', sa.JSON(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('duration_ms', sa.Integer(), nullable=True),
sa.Column('retry_count', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['execution_id'], ['workflow_executions.id'], ),
sa.PrimaryKeyConstraint('id')
)
def downgrade() -> None:
op.drop_table('workflow_step_logs')
op.drop_table('workflow_executions')
op.drop_table('workflow_definitions')
op.execute('DROP TYPE IF EXISTS workflowstatus')


@@ -0,0 +1,79 @@
"""Add chatbot tables
Revision ID: 8bf097417ff0
Revises: 003
Create Date: 2025-07-25 03:58:39.403887
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8bf097417ff0'
down_revision = '003'
branch_labels = None
depends_on = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
op.create_table('chatbot_instances',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('config', sa.JSON(), nullable=False),
sa.Column('created_by', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chatbot_conversations',
sa.Column('id', sa.String(), nullable=False),
sa.Column('chatbot_id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('title', sa.String(length=255), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('context_data', sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(['chatbot_id'], ['chatbot_instances.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chatbot_analytics',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chatbot_id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('event_type', sa.String(length=50), nullable=False),
sa.Column('event_data', sa.JSON(), nullable=True),
sa.Column('response_time_ms', sa.Integer(), nullable=True),
sa.Column('token_count', sa.Integer(), nullable=True),
sa.Column('cost_cents', sa.Integer(), nullable=True),
sa.Column('model_used', sa.String(length=100), nullable=True),
sa.Column('rag_used', sa.Boolean(), nullable=True),
sa.Column('timestamp', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['chatbot_id'], ['chatbot_instances.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chatbot_messages',
sa.Column('id', sa.String(), nullable=False),
sa.Column('conversation_id', sa.String(), nullable=False),
sa.Column('role', sa.String(length=20), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('timestamp', sa.DateTime(), nullable=True),
sa.Column('message_metadata', sa.JSON(), nullable=True),
sa.Column('sources', sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(['conversation_id'], ['chatbot_conversations.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('chatbot_messages')
op.drop_table('chatbot_analytics')
op.drop_table('chatbot_conversations')
op.drop_table('chatbot_instances')
# ### end Alembic commands ###


@@ -0,0 +1,24 @@
"""merge prompt templates and chatbot api key support
Revision ID: f7af0923d38b
Revises: 005_add_prompt_templates, 009_add_chatbot_api_key_support
Create Date: 2025-08-18 06:51:17.515233
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'f7af0923d38b'
down_revision = ('005_add_prompt_templates', '009_add_chatbot_api_key_support')
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass
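This empty merge revision exists only to join the two heads created above (005 and 009). Such a file is typically generated rather than hand-written; a sketch using Alembic's command API (the CLI equivalent is alembic merge):

from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
# Join the two divergent heads into the single merge revision shown above.
command.merge(
    alembic_cfg,
    revisions=["005_add_prompt_templates", "009_add_chatbot_api_key_support"],
    message="merge prompt templates and chatbot api key support",
)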

backend/app/__init__.py

@@ -0,0 +1,7 @@
"""
Confidential Empire - Modular AI Gateway Platform
"""
__version__ = "1.0.0"
__author__ = "Confidential Empire Team"
__description__ = "Modular AI Gateway Platform with confidential AI processing"

backend/app/api/__init__.py

@@ -0,0 +1,3 @@
"""
API package
"""

backend/app/api/v1/__init__.py

@@ -0,0 +1,68 @@
"""
API v1 package
"""
from fastapi import APIRouter
from .auth import router as auth_router
from .llm import router as llm_router
from .tee import router as tee_router
from .modules import router as modules_router
from .platform import router as platform_router
from .users import router as users_router
from .api_keys import router as api_keys_router
from .budgets import router as budgets_router
from .audit import router as audit_router
from .settings import router as settings_router
from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
# Create main API router
api_router = APIRouter()
# Include authentication routes
api_router.include_router(auth_router, prefix="/auth", tags=["authentication"])
# Include LLM proxy routes
api_router.include_router(llm_router, prefix="/llm", tags=["llm"])
# Include TEE routes
api_router.include_router(tee_router, prefix="/tee", tags=["tee"])
# Include modules routes
api_router.include_router(modules_router, prefix="/modules", tags=["modules"])
# Include platform routes
api_router.include_router(platform_router, prefix="/platform", tags=["platform"])
# Include user management routes
api_router.include_router(users_router, prefix="/users", tags=["users"])
# Include API key management routes
api_router.include_router(api_keys_router, prefix="/api-keys", tags=["api-keys"])
# Include budget management routes
api_router.include_router(budgets_router, prefix="/budgets", tags=["budgets"])
# Include audit log routes
api_router.include_router(audit_router, prefix="/audit", tags=["audit"])
# Include settings management routes
api_router.include_router(settings_router, prefix="/settings", tags=["settings"])
# Include analytics routes
api_router.include_router(analytics_router, prefix="/analytics", tags=["analytics"])
# Include RAG routes
api_router.include_router(rag_router, prefix="/rag", tags=["rag"])
# Include chatbot routes
api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])
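The aggregated api_router is presumably mounted on the FastAPI application in app/main.py, which is not part of this excerpt. A minimal sketch, assuming an /api/v1 prefix:

from fastapi import FastAPI
from app.api.v1 import api_router

app = FastAPI(title="Confidential Empire")
app.include_router(api_router, prefix="/api/v1")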

backend/app/api/v1/analytics.py

@@ -0,0 +1,257 @@
"""
Analytics API endpoints for usage metrics, cost analysis, and system health,
integrated with the core analytics service for comprehensive tracking.
"""
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.security import get_current_user
from app.db.database import get_db
from app.models.user import User
from app.services.analytics import get_analytics_service
from app.services.module_manager import module_manager
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter()
@router.get("/metrics")
async def get_usage_metrics(
hours: int = Query(24, ge=1, le=168, description="Hours to analyze (1-168)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get comprehensive usage metrics including costs and budgets"""
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours, user_id=current_user['id'])
return {
"success": True,
"data": metrics,
"period_hours": hours
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage metrics: {str(e)}")
@router.get("/metrics/system")
async def get_system_metrics(
hours: int = Query(24, ge=1, le=168),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get system-wide metrics (admin only)"""
if not current_user['is_superuser']:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours)
return {
"success": True,
"data": metrics,
"period_hours": hours
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system metrics: {str(e)}")
@router.get("/health")
async def get_system_health(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get system health status including budget and performance analysis"""
try:
analytics = get_analytics_service()
health = await analytics.get_system_health()
return {
"success": True,
"data": health
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system health: {str(e)}")
@router.get("/costs")
async def get_cost_analysis(
days: int = Query(30, ge=1, le=365, description="Days to analyze (1-365)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get detailed cost analysis and trends"""
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days, user_id=current_user['id'])
return {
"success": True,
"data": analysis,
"period_days": days
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting cost analysis: {str(e)}")
@router.get("/costs/system")
async def get_system_cost_analysis(
days: int = Query(30, ge=1, le=365),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get system-wide cost analysis (admin only)"""
if not current_user['is_superuser']:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days)
return {
"success": True,
"data": analysis,
"period_days": days
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system cost analysis: {str(e)}")
@router.get("/endpoints")
async def get_endpoint_stats(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get endpoint usage statistics"""
try:
analytics = get_analytics_service()
        # For now, return the in-memory stats;
        # in the future this could be enhanced with database queries.
return {
"success": True,
"data": {
"endpoint_stats": dict(analytics.endpoint_stats),
"status_codes": dict(analytics.status_codes),
"model_stats": dict(analytics.model_stats)
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting endpoint stats: {str(e)}")
@router.get("/usage-trends")
async def get_usage_trends(
days: int = Query(7, ge=1, le=30, description="Days for trend analysis"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get usage trends over time"""
try:
from datetime import datetime, timedelta
from sqlalchemy import func
from app.models.usage_tracking import UsageTracking
cutoff_time = datetime.utcnow() - timedelta(days=days)
# Daily usage trends
daily_usage = db.query(
func.date(UsageTracking.created_at).label('date'),
func.count(UsageTracking.id).label('requests'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents')
).filter(
UsageTracking.created_at >= cutoff_time,
UsageTracking.user_id == current_user['id']
).group_by(
func.date(UsageTracking.created_at)
).order_by('date').all()
trends = []
for date, requests, tokens, cost_cents in daily_usage:
trends.append({
"date": date.isoformat(),
"requests": requests,
"tokens": tokens or 0,
"cost_cents": cost_cents or 0,
"cost_dollars": (cost_cents or 0) / 100
})
return {
"success": True,
"data": {
"trends": trends,
"period_days": days
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage trends: {str(e)}")
@router.get("/overview")
async def get_analytics_overview(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get analytics overview data"""
try:
analytics = get_analytics_service()
# Get basic metrics
metrics = await analytics.get_usage_metrics(hours=24, user_id=current_user['id'])
health = await analytics.get_system_health()
return {
"success": True,
"data": {
"total_requests": metrics.total_requests,
"total_cost_dollars": metrics.total_cost_cents / 100,
"avg_response_time": metrics.avg_response_time,
"error_rate": metrics.error_rate,
"budget_usage_percentage": metrics.budget_usage_percentage,
"system_health": health.status,
"health_score": health.score
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting overview: {str(e)}")
@router.get("/modules")
async def get_module_analytics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get analytics data for all modules"""
try:
module_stats = []
for name, module in module_manager.modules.items():
stats = {"name": name, "initialized": getattr(module, "initialized", False)}
# Get module statistics if available
if hasattr(module, "get_stats"):
try:
module_data = module.get_stats()
if hasattr(module_data, "__dict__"):
stats.update(module_data.__dict__)
elif isinstance(module_data, dict):
stats.update(module_data)
except Exception as e:
logger.warning(f"Failed to get stats for module {name}: {e}")
stats["error"] = str(e)
module_stats.append(stats)
return {
"success": True,
"data": {
"modules": module_stats,
"total_modules": len(module_stats),
"system_health": "healthy" if all(m.get("initialized", False) for m in module_stats) else "warning"
}
}
except Exception as e:
logger.error(f"Failed to get module analytics: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve module analytics")

backend/app/api/v1/api_keys.py

@@ -0,0 +1,645 @@
"""
API Key management endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from datetime import datetime, timedelta
import asyncio
import secrets
import string
from app.db.database import get_db
from app.models.api_key import APIKey
from app.models.user import User
from app.core.security import get_current_user
from app.services.permission_manager import require_permission
from app.services.audit_service import log_audit_event, log_audit_event_async
from app.core.logging import get_logger
from app.core.config import settings
logger = get_logger(__name__)
router = APIRouter()
# Pydantic models
class APIKeyCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
scopes: List[str] = Field(default_factory=list)
expires_at: Optional[datetime] = None
rate_limit_per_minute: Optional[int] = Field(None, ge=1, le=10000)
rate_limit_per_hour: Optional[int] = Field(None, ge=1, le=100000)
rate_limit_per_day: Optional[int] = Field(None, ge=1, le=1000000)
allowed_ips: List[str] = Field(default_factory=list)
allowed_models: List[str] = Field(default_factory=list) # Model restrictions
allowed_chatbots: List[str] = Field(default_factory=list) # Chatbot restrictions
is_unlimited: bool = True # Unlimited budget flag
budget_limit_cents: Optional[int] = Field(None, ge=0) # Budget limit in cents
budget_type: Optional[str] = Field(None, pattern="^(total|monthly)$") # Budget type
tags: List[str] = Field(default_factory=list)
class APIKeyUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
scopes: Optional[List[str]] = None
expires_at: Optional[datetime] = None
is_active: Optional[bool] = None
rate_limit_per_minute: Optional[int] = Field(None, ge=1, le=10000)
rate_limit_per_hour: Optional[int] = Field(None, ge=1, le=100000)
rate_limit_per_day: Optional[int] = Field(None, ge=1, le=1000000)
allowed_ips: Optional[List[str]] = None
allowed_models: Optional[List[str]] = None # Model restrictions
allowed_chatbots: Optional[List[str]] = None # Chatbot restrictions
is_unlimited: Optional[bool] = None # Unlimited budget flag
budget_limit_cents: Optional[int] = Field(None, ge=0) # Budget limit in cents
budget_type: Optional[str] = Field(None, pattern="^(total|monthly)$") # Budget type
tags: Optional[List[str]] = None
class APIKeyResponse(BaseModel):
id: int
name: str
description: Optional[str] = None
key_prefix: str
scopes: List[str]
is_active: bool
expires_at: Optional[datetime] = None
created_at: datetime
last_used_at: Optional[datetime] = None
total_requests: int
total_tokens: int
total_cost_cents: int = Field(alias='total_cost')
rate_limit_per_minute: Optional[int] = None
rate_limit_per_hour: Optional[int] = None
rate_limit_per_day: Optional[int] = None
allowed_ips: List[str]
allowed_models: List[str] # Model restrictions
allowed_chatbots: List[str] # Chatbot restrictions
budget_limit: Optional[int] = Field(None, alias='budget_limit_cents') # Budget limit in cents
budget_type: Optional[str] = None # Budget type
is_unlimited: bool = True # Unlimited budget flag
tags: List[str]
class Config:
from_attributes = True
@classmethod
def from_api_key(cls, api_key):
"""Create response from APIKey model with formatted key prefix"""
data = {
'id': api_key.id,
'name': api_key.name,
'description': api_key.description,
'key_prefix': api_key.key_prefix + "..." if api_key.key_prefix else "",
'scopes': api_key.scopes,
'is_active': api_key.is_active,
'expires_at': api_key.expires_at,
'created_at': api_key.created_at,
'last_used_at': api_key.last_used_at,
'total_requests': api_key.total_requests,
'total_tokens': api_key.total_tokens,
'total_cost': api_key.total_cost,
'rate_limit_per_minute': api_key.rate_limit_per_minute,
'rate_limit_per_hour': api_key.rate_limit_per_hour,
'rate_limit_per_day': api_key.rate_limit_per_day,
'allowed_ips': api_key.allowed_ips,
'allowed_models': api_key.allowed_models,
'allowed_chatbots': api_key.allowed_chatbots,
'budget_limit_cents': api_key.budget_limit_cents,
'budget_type': api_key.budget_type,
'is_unlimited': api_key.is_unlimited,
'tags': api_key.tags
}
return cls(**data)
class APIKeyCreateResponse(BaseModel):
api_key: APIKeyResponse
secret_key: str # Only returned on creation
class APIKeyListResponse(BaseModel):
api_keys: List[APIKeyResponse]
total: int
page: int
size: int
class APIKeyUsageResponse(BaseModel):
api_key_id: str
total_requests: int
total_tokens: int
total_cost_cents: int
requests_today: int
tokens_today: int
cost_today_cents: int
requests_this_hour: int
tokens_this_hour: int
cost_this_hour_cents: int
last_used_at: Optional[datetime] = None
def generate_api_key() -> tuple[str, str]:
"""Generate a new API key and return (full_key, key_hash)"""
# Generate random key part (32 characters)
key_part = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))
# Create full key with prefix
full_key = f"{settings.API_KEY_PREFIX}{key_part}"
# Create hash for storage
from app.core.security import get_api_key_hash
key_hash = get_api_key_hash(full_key)
return full_key, key_hash
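
# Illustrative sketch, not part of the original file: one way the auth path could
# verify a presented key, reusing the prefix/hash split from generate_api_key().
# The helper name, the prefix lookup, and the assumption that get_api_key_hash is
# deterministic (equal inputs hash equally) are all assumptions here; the real
# logic lives in app.services.api_key_auth.
async def _verify_presented_key(db: AsyncSession, presented_key: str) -> Optional[APIKey]:
    """Find the candidate row by its stored 8-character prefix, then compare hashes."""
    from app.core.security import get_api_key_hash
    result = await db.execute(
        select(APIKey).where(
            APIKey.key_prefix == presented_key[:8],
            APIKey.is_active == True
        )
    )
    candidate = result.scalar_one_or_none()
    # Constant-time comparison avoids leaking hash prefixes through timing
    if candidate and secrets.compare_digest(candidate.key_hash, get_api_key_hash(presented_key)):
        return candidate
    return None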
# API Key CRUD endpoints
@router.get("/", response_model=APIKeyListResponse)
async def list_api_keys(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
user_id: Optional[str] = Query(None),
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List API keys with pagination and filtering"""
    # Check permissions - users can always view their own API keys;
    # viewing another user's keys requires the admin read permission
    if user_id and int(user_id) != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")
    elif not user_id and "platform:api-keys:read" not in current_user.get("permissions", []):
        # No user_id specified and no admin permission: scope the listing to the caller
        user_id = current_user['id']
# Build query
query = select(APIKey)
# Apply filters
if user_id:
query = query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
if is_active is not None:
query = query.where(APIKey.is_active == is_active)
if search:
query = query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
)
# Get total count using func.count()
total_query = select(func.count(APIKey.id))
# Apply same filters for count
if user_id:
total_query = total_query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
if is_active is not None:
total_query = total_query.where(APIKey.is_active == is_active)
if search:
total_query = total_query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
)
total_result = await db.execute(total_query)
total = total_result.scalar()
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(APIKey.created_at.desc())
# Execute query
result = await db.execute(query)
api_keys = result.scalars().all()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="list_api_keys",
resource_type="api_key",
details={"page": page, "size": size, "filters": {"user_id": user_id, "is_active": is_active, "search": search}}
)
return APIKeyListResponse(
api_keys=[APIKeyResponse.model_validate(key) for key in api_keys],
total=total,
page=page,
size=size
)
@router.get("/{api_key_id}", response_model=APIKeyResponse)
async def get_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get API key by ID"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can view their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_api_key",
resource_type="api_key",
resource_id=api_key_id
)
return APIKeyResponse.model_validate(api_key)
@router.post("/", response_model=APIKeyCreateResponse)
async def create_api_key(
api_key_data: APIKeyCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create a new API key"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:api-keys:create")
# Generate API key
full_key, key_hash = generate_api_key()
key_prefix = full_key[:8] # Store only first 8 characters for lookup
# Create API key
new_api_key = APIKey(
name=api_key_data.name,
description=api_key_data.description,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=current_user['id'],
scopes=api_key_data.scopes,
expires_at=api_key_data.expires_at,
rate_limit_per_minute=api_key_data.rate_limit_per_minute,
rate_limit_per_hour=api_key_data.rate_limit_per_hour,
rate_limit_per_day=api_key_data.rate_limit_per_day,
allowed_ips=api_key_data.allowed_ips,
allowed_models=api_key_data.allowed_models,
allowed_chatbots=api_key_data.allowed_chatbots,
is_unlimited=api_key_data.is_unlimited,
budget_limit_cents=api_key_data.budget_limit_cents if not api_key_data.is_unlimited else None,
budget_type=api_key_data.budget_type if not api_key_data.is_unlimited else None,
tags=api_key_data.tags
)
db.add(new_api_key)
await db.commit()
await db.refresh(new_api_key)
# Log audit event asynchronously (non-blocking)
asyncio.create_task(log_audit_event_async(
user_id=str(current_user['id']),
action="create_api_key",
resource_type="api_key",
resource_id=str(new_api_key.id),
details={"name": api_key_data.name, "scopes": api_key_data.scopes}
))
logger.info(f"API key created: {new_api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(new_api_key),
secret_key=full_key
)
@router.put("/{api_key_id}", response_model=APIKeyResponse)
async def update_api_key(
api_key_id: str,
api_key_data: APIKeyUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can update their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
# Store original values for audit
original_values = {
"name": api_key.name,
"scopes": api_key.scopes,
"is_active": api_key.is_active
}
# Update API key fields
update_data = api_key_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(api_key, field, value)
await db.commit()
await db.refresh(api_key)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="update_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(api_key, k) for k in update_data.keys()}
}
)
logger.info(f"API key updated: {api_key.name} by {current_user['username']}")
return APIKeyResponse.model_validate(api_key)
@router.delete("/{api_key_id}")
async def delete_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Delete API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can delete their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:delete")
# Delete API key
await db.delete(api_key)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="delete_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
)
logger.info(f"API key deleted: {api_key.name} by {current_user['username']}")
return {"message": "API key deleted successfully"}
@router.post("/{api_key_id}/regenerate", response_model=APIKeyCreateResponse)
async def regenerate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Regenerate API key secret"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can regenerate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
# Generate new API key
full_key, key_hash = generate_api_key()
key_prefix = full_key[:8] # Store only first 8 characters for lookup
# Update API key
api_key.key_hash = key_hash
api_key.key_prefix = key_prefix
await db.commit()
await db.refresh(api_key)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="regenerate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
)
logger.info(f"API key regenerated: {api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(api_key),
secret_key=full_key
)
@router.get("/{api_key_id}/usage", response_model=APIKeyUsageResponse)
async def get_api_key_usage(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get API key usage statistics"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can view their own API key usage
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
# Calculate usage statistics
from app.models.usage_tracking import UsageTracking
now = datetime.utcnow()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
hour_start = now.replace(minute=0, second=0, microsecond=0)
# Today's usage
today_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
).where(
        UsageTracking.api_key_id == int(api_key_id),
UsageTracking.created_at >= today_start
)
today_result = await db.execute(today_query)
today_stats = today_result.first()
# This hour's usage
hour_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
).where(
        UsageTracking.api_key_id == int(api_key_id),
UsageTracking.created_at >= hour_start
)
hour_result = await db.execute(hour_query)
hour_stats = hour_result.first()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_api_key_usage",
resource_type="api_key",
resource_id=api_key_id
)
return APIKeyUsageResponse(
api_key_id=api_key_id,
total_requests=api_key.total_requests,
total_tokens=api_key.total_tokens,
total_cost_cents=api_key.total_cost_cents,
requests_today=today_stats[0] or 0,
tokens_today=today_stats[1] or 0,
cost_today_cents=today_stats[2] or 0,
requests_this_hour=hour_stats[0] or 0,
tokens_this_hour=hour_stats[1] or 0,
cost_this_hour_cents=hour_stats[2] or 0,
last_used_at=api_key.last_used_at
)
@router.post("/{api_key_id}/activate")
async def activate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Activate API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can activate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
# Activate API key
api_key.is_active = True
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="activate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
)
logger.info(f"API key activated: {api_key.name} by {current_user['username']}")
return {"message": "API key activated successfully"}
@router.post("/{api_key_id}/deactivate")
async def deactivate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Deactivate API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)
# Check permissions - users can deactivate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
# Deactivate API key
api_key.is_active = False
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="deactivate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
)
logger.info(f"API key deactivated: {api_key.name} by {current_user['username']}")
return {"message": "API key deactivated successfully"}

598
backend/app/api/v1/audit.py Normal file
View File

@@ -0,0 +1,598 @@
"""
Audit log query endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_
from datetime import datetime, timedelta
from app.db.database import get_db
from app.models.audit_log import AuditLog
from app.models.user import User
from app.core.security import get_current_user
from app.services.permission_manager import require_permission
from app.services.audit_service import log_audit_event, get_audit_logs, get_audit_stats
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter()
# Pydantic models
class AuditLogResponse(BaseModel):
id: str
user_id: Optional[str] = None
api_key_id: Optional[str] = None
action: str
resource_type: str
resource_id: Optional[str] = None
details: dict
ip_address: Optional[str] = None
user_agent: Optional[str] = None
success: bool
severity: str
created_at: datetime
class Config:
from_attributes = True
class AuditLogListResponse(BaseModel):
logs: List[AuditLogResponse]
total: int
page: int
size: int
class AuditStatsResponse(BaseModel):
total_events: int
events_by_action: dict
events_by_resource_type: dict
events_by_severity: dict
success_rate: float
failure_rate: float
events_by_user: dict
events_by_hour: dict
top_actions: List[dict]
top_resources: List[dict]
class AuditSearchRequest(BaseModel):
user_id: Optional[str] = None
action: Optional[str] = None
resource_type: Optional[str] = None
resource_id: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
success: Optional[bool] = None
severity: Optional[str] = None
ip_address: Optional[str] = None
search_text: Optional[str] = None
class SecurityEventsResponse(BaseModel):
suspicious_activities: List[dict]
failed_logins: List[dict]
unusual_access_patterns: List[dict]
high_severity_events: List[dict]
# Audit log query endpoints
@router.get("/", response_model=AuditLogListResponse)
async def list_audit_logs(
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=1000),
user_id: Optional[str] = Query(None),
action: Optional[str] = Query(None),
resource_type: Optional[str] = Query(None),
resource_id: Optional[str] = Query(None),
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
success: Optional[bool] = Query(None),
severity: Optional[str] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List audit logs with filtering and pagination"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Build query
query = select(AuditLog)
conditions = []
# Apply filters
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action:
conditions.append(AuditLog.action == action)
if resource_type:
conditions.append(AuditLog.resource_type == resource_type)
if resource_id:
conditions.append(AuditLog.resource_id == resource_id)
if start_date:
conditions.append(AuditLog.created_at >= start_date)
if end_date:
conditions.append(AuditLog.created_at <= end_date)
if success is not None:
conditions.append(AuditLog.success == success)
if severity:
conditions.append(AuditLog.severity == severity)
if search:
search_conditions = [
AuditLog.action.ilike(f"%{search}%"),
AuditLog.resource_type.ilike(f"%{search}%"),
AuditLog.details.astext.ilike(f"%{search}%")
]
conditions.append(or_(*search_conditions))
if conditions:
query = query.where(and_(*conditions))
# Get total count
count_query = select(func.count(AuditLog.id))
if conditions:
count_query = count_query.where(and_(*conditions))
total_result = await db.execute(count_query)
total = total_result.scalar()
# Apply pagination and ordering
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(AuditLog.created_at.desc())
# Execute query
result = await db.execute(query)
logs = result.scalars().all()
# Log audit event for this query
await log_audit_event(
db=db,
user_id=current_user["id"],
action="query_audit_logs",
resource_type="audit_log",
details={
"filters": {
"user_id": user_id,
"action": action,
"resource_type": resource_type,
"start_date": start_date.isoformat() if start_date else None,
"end_date": end_date.isoformat() if end_date else None,
"success": success,
"severity": severity,
"search": search
},
"page": page,
"size": size,
"total_results": total
}
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
)
@router.post("/search", response_model=AuditLogListResponse)
async def search_audit_logs(
search_request: AuditSearchRequest,
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=1000),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Advanced search for audit logs"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Use the audit service function
logs = await get_audit_logs(
db=db,
user_id=search_request.user_id,
action=search_request.action,
resource_type=search_request.resource_type,
resource_id=search_request.resource_id,
start_date=search_request.start_date,
end_date=search_request.end_date,
limit=size,
offset=(page - 1) * size
)
# Get total count for the search
total_query = select(func.count(AuditLog.id))
conditions = []
if search_request.user_id:
conditions.append(AuditLog.user_id == search_request.user_id)
if search_request.action:
conditions.append(AuditLog.action == search_request.action)
if search_request.resource_type:
conditions.append(AuditLog.resource_type == search_request.resource_type)
if search_request.resource_id:
conditions.append(AuditLog.resource_id == search_request.resource_id)
if search_request.start_date:
conditions.append(AuditLog.created_at >= search_request.start_date)
if search_request.end_date:
conditions.append(AuditLog.created_at <= search_request.end_date)
if search_request.success is not None:
conditions.append(AuditLog.success == search_request.success)
if search_request.severity:
conditions.append(AuditLog.severity == search_request.severity)
if search_request.ip_address:
conditions.append(AuditLog.ip_address == search_request.ip_address)
if search_request.search_text:
search_conditions = [
AuditLog.action.ilike(f"%{search_request.search_text}%"),
AuditLog.resource_type.ilike(f"%{search_request.search_text}%"),
AuditLog.details.astext.ilike(f"%{search_request.search_text}%")
]
conditions.append(or_(*search_conditions))
if conditions:
total_query = total_query.where(and_(*conditions))
total_result = await db.execute(total_query)
total = total_result.scalar()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="advanced_search_audit_logs",
resource_type="audit_log",
details={
"search_criteria": search_request.model_dump(exclude_unset=True),
"results_count": len(logs),
"total_matches": total
}
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
)
@router.get("/stats", response_model=AuditStatsResponse)
async def get_audit_statistics(
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get audit log statistics"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Default to last 30 days if no dates provided
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Get basic stats using audit service
basic_stats = await get_audit_stats(db, start_date, end_date)
# Get additional statistics
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
# Events by user
user_query = select(
AuditLog.user_id,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.user_id).order_by(func.count(AuditLog.id).desc()).limit(10)
user_result = await db.execute(user_query)
events_by_user = dict(user_result.fetchall())
# Events by hour of day
hour_query = select(
func.extract('hour', AuditLog.created_at).label('hour'),
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(func.extract('hour', AuditLog.created_at)).order_by('hour')
hour_result = await db.execute(hour_query)
events_by_hour = dict(hour_result.fetchall())
# Top actions
top_actions_query = select(
AuditLog.action,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.action).order_by(func.count(AuditLog.id).desc()).limit(10)
top_actions_result = await db.execute(top_actions_query)
top_actions = [{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()]
# Top resources
top_resources_query = select(
AuditLog.resource_type,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.resource_type).order_by(func.count(AuditLog.id).desc()).limit(10)
top_resources_result = await db.execute(top_resources_query)
top_resources = [{"resource_type": row[0], "count": row[1]} for row in top_resources_result.fetchall()]
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="get_audit_statistics",
resource_type="audit_log",
details={
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"total_events": basic_stats["total_events"]
}
)
return AuditStatsResponse(
**basic_stats,
events_by_user=events_by_user,
events_by_hour=events_by_hour,
top_actions=top_actions,
top_resources=top_resources
)
@router.get("/security-events", response_model=SecurityEventsResponse)
async def get_security_events(
hours: int = Query(24, ge=1, le=168), # Last 24 hours by default, max 1 week
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get security-related events and anomalies"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
end_time = datetime.utcnow()
start_time = end_time - timedelta(hours=hours)
# Failed logins
failed_logins_query = select(AuditLog).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.action == "login",
AuditLog.success == False
)
).order_by(AuditLog.created_at.desc()).limit(50)
failed_logins_result = await db.execute(failed_logins_query)
failed_logins = [
{
"timestamp": log.created_at.isoformat(),
"user_id": log.user_id,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"details": log.details
}
for log in failed_logins_result.scalars().all()
]
# High severity events
high_severity_query = select(AuditLog).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.severity.in_(["error", "critical"])
)
).order_by(AuditLog.created_at.desc()).limit(50)
high_severity_result = await db.execute(high_severity_query)
high_severity_events = [
{
"timestamp": log.created_at.isoformat(),
"action": log.action,
"resource_type": log.resource_type,
"severity": log.severity,
"user_id": log.user_id,
"ip_address": log.ip_address,
"success": log.success,
"details": log.details
}
for log in high_severity_result.scalars().all()
]
# Suspicious activities (multiple failed attempts from same IP)
suspicious_ips_query = select(
AuditLog.ip_address,
func.count(AuditLog.id).label('failed_count')
).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.success == False,
AuditLog.ip_address.isnot(None)
)
).group_by(AuditLog.ip_address).having(func.count(AuditLog.id) >= 5).order_by(func.count(AuditLog.id).desc())
suspicious_ips_result = await db.execute(suspicious_ips_query)
suspicious_activities = [
{
"ip_address": row[0],
"failed_attempts": row[1],
"risk_level": "high" if row[1] >= 10 else "medium"
}
for row in suspicious_ips_result.fetchall()
]
# Unusual access patterns (users accessing from multiple IPs)
unusual_access_query = select(
AuditLog.user_id,
func.count(func.distinct(AuditLog.ip_address)).label('ip_count'),
func.array_agg(func.distinct(AuditLog.ip_address)).label('ip_addresses')
).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.user_id.isnot(None),
AuditLog.ip_address.isnot(None)
)
).group_by(AuditLog.user_id).having(func.count(func.distinct(AuditLog.ip_address)) >= 3).order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
unusual_access_result = await db.execute(unusual_access_query)
unusual_access_patterns = [
{
"user_id": row[0],
"unique_ips": row[1],
"ip_addresses": row[2] if row[2] else []
}
for row in unusual_access_result.fetchall()
]
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="get_security_events",
resource_type="audit_log",
details={
"time_range_hours": hours,
"failed_logins_count": len(failed_logins),
"high_severity_count": len(high_severity_events),
"suspicious_ips_count": len(suspicious_activities),
"unusual_access_patterns_count": len(unusual_access_patterns)
}
)
return SecurityEventsResponse(
suspicious_activities=suspicious_activities,
failed_logins=failed_logins,
unusual_access_patterns=unusual_access_patterns,
high_severity_events=high_severity_events
)
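
# Example shape of one suspicious_activities entry produced above (illustrative
# values): {"ip_address": "203.0.113.7", "failed_attempts": 12, "risk_level": "high"}
# risk_level is "high" at 10 or more failures in the window, "medium" at 5-9.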
@router.get("/export")
async def export_audit_logs(
format: str = Query("csv", pattern="^(csv|json)$"),
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
user_id: Optional[str] = Query(None),
action: Optional[str] = Query(None),
resource_type: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Export audit logs in CSV or JSON format"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:export")
# Default to last 30 days if no dates provided
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Limit export size
max_records = 10000
# Build query
query = select(AuditLog)
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action:
conditions.append(AuditLog.action == action)
if resource_type:
conditions.append(AuditLog.resource_type == resource_type)
query = query.where(and_(*conditions)).order_by(AuditLog.created_at.desc()).limit(max_records)
# Execute query
result = await db.execute(query)
logs = result.scalars().all()
# Log export event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="export_audit_logs",
resource_type="audit_log",
details={
"format": format,
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"records_exported": len(logs),
"filters": {
"user_id": user_id,
"action": action,
"resource_type": resource_type
}
}
)
if format == "json":
from fastapi.responses import JSONResponse
export_data = [
{
"id": str(log.id),
"user_id": log.user_id,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"details": log.details,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"success": log.success,
"severity": log.severity,
"created_at": log.created_at.isoformat()
}
for log in logs
]
return JSONResponse(content=export_data)
else: # CSV format
import csv
import io
from fastapi.responses import StreamingResponse
output = io.StringIO()
writer = csv.writer(output)
# Write header
writer.writerow([
"ID", "User ID", "Action", "Resource Type", "Resource ID",
"IP Address", "Success", "Severity", "Created At", "Details"
])
# Write data
for log in logs:
writer.writerow([
str(log.id),
log.user_id or "",
log.action,
log.resource_type,
log.resource_id or "",
log.ip_address or "",
log.success,
log.severity,
log.created_at.isoformat(),
str(log.details)
])
output.seek(0)
return StreamingResponse(
io.BytesIO(output.getvalue().encode()),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"}
)
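
# Hypothetical client call for the export endpoint above (mount prefix assumed):
#   httpx.get("http://localhost:8000/api/v1/audit/export",
#             params={"format": "json", "action": "login"},
#             headers={"Authorization": f"Bearer {jwt}"})
# Returns at most max_records (10,000) rows; format="csv" streams an attachment.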

279
backend/app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,279 @@
"""
Authentication API endpoints
"""
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, field_validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.config import settings
from app.core.security import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
verify_token,
get_current_user,
get_current_active_user,
)
from app.db.database import get_db
from app.models.user import User
from app.utils.exceptions import AuthenticationError, ValidationError
router = APIRouter()
security = HTTPBearer()
# Request/Response Models
class UserRegisterRequest(BaseModel):
email: EmailStr
username: str
password: str
first_name: Optional[str] = None
last_name: Optional[str] = None
    @field_validator('password')
    @classmethod
    def validate_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters long')
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
return v
    @field_validator('username')
    @classmethod
    def validate_username(cls, v):
if len(v) < 3:
raise ValueError('Username must be at least 3 characters long')
if not v.isalnum():
raise ValueError('Username must contain only alphanumeric characters')
return v
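
    # Quick illustration of the rules above (comments only, not executed):
    #   "Secret1x"  -> accepted: 8+ chars with upper, lower, and a digit
    #   "secret1x"  -> rejected: no uppercase letter
    #   "Ab1"       -> rejected: shorter than 8 characters
    #   username "dev_ops" -> rejected: isalnum() excludes underscores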
class UserLoginRequest(BaseModel):
email: EmailStr
password: str
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
class UserResponse(BaseModel):
id: int
email: str
username: str
full_name: Optional[str]
is_active: bool
is_verified: bool
role: str
created_at: datetime
class Config:
from_attributes = True
class RefreshTokenRequest(BaseModel):
refresh_token: str
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_data: UserRegisterRequest,
db: AsyncSession = Depends(get_db)
):
"""Register a new user"""
# Check if user already exists
stmt = select(User).where(User.email == user_data.email)
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
# Check if username already exists
stmt = select(User).where(User.username == user_data.username)
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already taken"
)
# Create new user
full_name = None
if user_data.first_name or user_data.last_name:
full_name = f"{user_data.first_name or ''} {user_data.last_name or ''}".strip()
user = User(
email=user_data.email,
username=user_data.username,
hashed_password=get_password_hash(user_data.password),
full_name=full_name,
is_active=True,
is_verified=False,
role="user"
)
db.add(user)
await db.commit()
await db.refresh(user)
    return UserResponse.model_validate(user)
@router.post("/login", response_model=TokenResponse)
async def login(
user_data: UserLoginRequest,
db: AsyncSession = Depends(get_db)
):
"""Login user and return access tokens"""
# Get user by email
stmt = select(User).where(User.email == user_data.email)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user or not verify_password(user_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is disabled"
)
# Update last login
user.update_last_login()
await db.commit()
# Create tokens
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
},
expires_delta=access_token_expires
)
refresh_token = create_refresh_token(
data={"sub": str(user.id), "type": "refresh"}
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
token_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db)
):
"""Refresh access token using refresh token"""
try:
payload = verify_token(token_data.refresh_token)
user_id = payload.get("sub")
token_type = payload.get("type")
if not user_id or token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
)
# Get user from database
stmt = select(User).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
)
# Create new access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
},
expires_delta=access_token_expires
)
return TokenResponse(
access_token=access_token,
refresh_token=token_data.refresh_token, # Keep same refresh token
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token"
        )
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db)
):
"""Get current user information"""
# Get full user details from database
stmt = select(User).where(User.id == int(current_user["id"]))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
    return UserResponse.model_validate(user)
@router.post("/logout")
async def logout():
"""Logout user (client should discard tokens)"""
return {"message": "Successfully logged out"}
@router.post("/verify-token")
async def verify_user_token(
current_user: dict = Depends(get_current_user)
):
"""Verify if the current token is valid"""
return {
"valid": True,
"user_id": current_user["id"],
"email": current_user["email"]
}
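
# Hypothetical end-to-end client flow for the endpoints above. The /api/v1/auth
# mount prefix, the example credentials, and the httpx dependency are assumptions
# for illustration only, not taken from this file.
def _demo_auth_flow(base: str = "http://localhost:8000/api/v1/auth") -> None:
    import httpx  # illustrative only; not a dependency of this module
    httpx.post(f"{base}/register", json={
        "email": "dev@example.com", "username": "dev123", "password": "Secret1x",
    })
    tokens = httpx.post(f"{base}/login", json={
        "email": "dev@example.com", "password": "Secret1x",
    }).json()
    # The access token goes into the Authorization header for protected endpoints
    me = httpx.get(f"{base}/me", headers={"Authorization": f"Bearer {tokens['access_token']}"})
    print(me.json())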

675
backend/app/api/v1/budgets.py Normal file
View File

@@ -0,0 +1,675 @@
"""
Budget management endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from datetime import datetime, timedelta
from enum import Enum
from app.db.database import get_db
from app.models.budget import Budget
from app.models.user import User
from app.models.usage_tracking import UsageTracking
from app.core.security import get_current_user
from app.services.permission_manager import require_permission
from app.services.audit_service import log_audit_event
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter()
# Enums
class BudgetType(str, Enum):
TOKENS = "tokens"
DOLLARS = "dollars"
REQUESTS = "requests"
class PeriodType(str, Enum):
HOURLY = "hourly"
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
YEARLY = "yearly"
# Pydantic models
class BudgetCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
budget_type: BudgetType
limit_amount: float = Field(..., gt=0)
period_type: PeriodType
user_id: Optional[str] = None # Admin can set budgets for other users
api_key_id: Optional[str] = None # Budget can be linked to specific API key
is_enabled: bool = True
alert_threshold_percent: float = Field(80.0, ge=0, le=100)
allowed_resources: List[str] = Field(default_factory=list)
metadata: dict = Field(default_factory=dict)
class BudgetUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
limit_amount: Optional[float] = Field(None, gt=0)
period_type: Optional[PeriodType] = None
is_enabled: Optional[bool] = None
alert_threshold_percent: Optional[float] = Field(None, ge=0, le=100)
allowed_resources: Optional[List[str]] = None
metadata: Optional[dict] = None
class BudgetResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
budget_type: str
limit_amount: float
period_type: str
period_start: datetime
period_end: datetime
current_usage: float
usage_percentage: float
is_enabled: bool
alert_threshold_percent: float
user_id: Optional[str] = None
api_key_id: Optional[str] = None
allowed_resources: List[str]
metadata: dict
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class BudgetListResponse(BaseModel):
budgets: List[BudgetResponse]
total: int
page: int
size: int
class BudgetUsageResponse(BaseModel):
budget_id: str
current_usage: float
limit_amount: float
usage_percentage: float
remaining_amount: float
period_start: datetime
period_end: datetime
is_exceeded: bool
days_remaining: int
projected_usage: Optional[float] = None
usage_history: List[dict] = Field(default_factory=list)
class BudgetAlertResponse(BaseModel):
budget_id: str
budget_name: str
alert_type: str # "warning", "critical", "exceeded"
current_usage: float
limit_amount: float
usage_percentage: float
message: str
# Budget CRUD endpoints
@router.get("/", response_model=BudgetListResponse)
async def list_budgets(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
user_id: Optional[str] = Query(None),
budget_type: Optional[BudgetType] = Query(None),
is_enabled: Optional[bool] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List budgets with pagination and filtering"""
    # Check permissions - users can always view their own budgets;
    # viewing another user's budgets requires the admin read permission
    if user_id and int(user_id) != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")
    elif not user_id and "platform:budgets:read" not in current_user.get("permissions", []):
        # No user_id specified and no admin permission: scope the listing to the caller
        user_id = current_user['id']
# Build query
query = select(Budget)
# Apply filters
if user_id:
query = query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
if budget_type:
query = query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
query = query.where(Budget.is_enabled == is_enabled)
# Get total count
count_query = select(func.count(Budget.id))
# Apply same filters to count query
if user_id:
count_query = count_query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
if budget_type:
count_query = count_query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
count_query = count_query.where(Budget.is_enabled == is_enabled)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(Budget.created_at.desc())
# Execute query
result = await db.execute(query)
budgets = result.scalars().all()
# Calculate current usage for each budget
budget_responses = []
for budget in budgets:
usage = await _calculate_budget_usage(db, budget)
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_responses.append(budget_data)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="list_budgets",
resource_type="budget",
details={"page": page, "size": size, "filters": {"user_id": user_id, "budget_type": budget_type, "is_enabled": is_enabled}}
)
return BudgetListResponse(
budgets=budget_responses,
total=total,
page=page,
size=size
)
@router.get("/{budget_id}", response_model=BudgetResponse)
async def get_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get budget by ID"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
)
# Check permissions - users can view their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate current usage
usage = await _calculate_budget_usage(db, budget)
# Build response
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_budget",
resource_type="budget",
resource_id=budget_id
)
return budget_data
@router.post("/", response_model=BudgetResponse)
async def create_budget(
budget_data: BudgetCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create a new budget"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:budgets:create")
    # If user_id not specified, use current user
    target_user_id = budget_data.user_id or current_user['id']
    if isinstance(target_user_id, str):
        target_user_id = int(target_user_id)
    # Setting a budget for another user requires admin permissions
    if target_user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:budgets:admin")
# Calculate period start and end
now = datetime.utcnow()
period_start, period_end = _calculate_period_bounds(now, budget_data.period_type)
# Create budget
new_budget = Budget(
name=budget_data.name,
description=budget_data.description,
budget_type=budget_data.budget_type.value,
limit_amount=budget_data.limit_amount,
period_type=budget_data.period_type.value,
period_start=period_start,
period_end=period_end,
user_id=target_user_id,
api_key_id=budget_data.api_key_id,
is_enabled=budget_data.is_enabled,
alert_threshold_percent=budget_data.alert_threshold_percent,
allowed_resources=budget_data.allowed_resources,
metadata=budget_data.metadata
)
db.add(new_budget)
await db.commit()
await db.refresh(new_budget)
# Build response
budget_response = BudgetResponse.model_validate(new_budget)
budget_response.current_usage = 0.0
budget_response.usage_percentage = 0.0
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="create_budget",
resource_type="budget",
resource_id=str(new_budget.id),
details={"name": budget_data.name, "budget_type": budget_data.budget_type, "limit_amount": budget_data.limit_amount}
)
logger.info(f"Budget created: {new_budget.name} by {current_user['username']}")
return budget_response
@router.put("/{budget_id}", response_model=BudgetResponse)
async def update_budget(
budget_id: str,
budget_data: BudgetUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update budget"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
)
# Check permissions - users can update their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:update")
# Store original values for audit
original_values = {
"name": budget.name,
"limit_amount": budget.limit_amount,
"is_enabled": budget.is_enabled
}
# Update budget fields
update_data = budget_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(budget, field, value)
# Recalculate period if period_type changed
if "period_type" in update_data:
period_start, period_end = _calculate_period_bounds(datetime.utcnow(), budget.period_type)
budget.period_start = period_start
budget.period_end = period_end
await db.commit()
await db.refresh(budget)
# Calculate current usage
usage = await _calculate_budget_usage(db, budget)
# Build response
budget_response = BudgetResponse.model_validate(budget)
budget_response.current_usage = usage
budget_response.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="update_budget",
resource_type="budget",
resource_id=budget_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(budget, k) for k in update_data.keys()}
}
)
logger.info(f"Budget updated: {budget.name} by {current_user['username']}")
return budget_response
@router.delete("/{budget_id}")
async def delete_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Delete budget"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
)
# Check permissions - users can delete their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:delete")
# Delete budget
await db.delete(budget)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="delete_budget",
resource_type="budget",
resource_id=budget_id,
details={"name": budget.name}
)
logger.info(f"Budget deleted: {budget.name} by {current_user['username']}")
return {"message": "Budget deleted successfully"}
@router.get("/{budget_id}/usage", response_model=BudgetUsageResponse)
async def get_budget_usage(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get detailed budget usage information"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
)
# Check permissions - users can view their own budget usage
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
remaining_amount = max(0, budget.limit_amount - current_usage)
is_exceeded = current_usage > budget.limit_amount
# Calculate days remaining in period
now = datetime.utcnow()
days_remaining = max(0, (budget.period_end - now).days)
# Calculate projected usage
projected_usage = None
if days_remaining > 0 and current_usage > 0:
days_elapsed = (now - budget.period_start).days + 1
if days_elapsed > 0:
daily_rate = current_usage / days_elapsed
total_days = (budget.period_end - budget.period_start).days + 1
projected_usage = daily_rate * total_days
# Get usage history (last 30 days)
usage_history = await _get_usage_history(db, budget, days=30)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_budget_usage",
resource_type="budget",
resource_id=budget_id
)
return BudgetUsageResponse(
budget_id=budget_id,
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
remaining_amount=remaining_amount,
period_start=budget.period_start,
period_end=budget.period_end,
is_exceeded=is_exceeded,
days_remaining=days_remaining,
projected_usage=projected_usage,
usage_history=usage_history
)
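
# Worked example for the projection above: with current_usage=30.0 after
# days_elapsed=10 of a 30-day period, daily_rate=3.0, so projected_usage=90.0.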
@router.get("/{budget_id}/alerts", response_model=List[BudgetAlertResponse])
async def get_budget_alerts(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get budget alerts"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
)
# Check permissions - users can view their own budget alerts
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
alerts = []
# Check for alerts
if usage_percentage >= 100:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="exceeded",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)"
))
elif usage_percentage >= 90:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="critical",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)"
))
elif usage_percentage >= budget.alert_threshold_percent:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="warning",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)"
))
return alerts
# Helper functions
async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
"""Calculate current usage for a budget"""
# Build base query
query = select(UsageTracking)
# Filter by time period
query = query.where(
UsageTracking.created_at >= budget.period_start,
UsageTracking.created_at <= budget.period_end
)
# Filter by user or API key
if budget.api_key_id:
query = query.where(UsageTracking.api_key_id == budget.api_key_id)
elif budget.user_id:
query = query.where(UsageTracking.user_id == budget.user_id)
# Calculate usage based on budget type
if budget.budget_type == "tokens":
usage_query = query.with_only_columns(func.sum(UsageTracking.total_tokens))
elif budget.budget_type == "dollars":
usage_query = query.with_only_columns(func.sum(UsageTracking.cost_cents))
elif budget.budget_type == "requests":
usage_query = query.with_only_columns(func.count(UsageTracking.id))
else:
return 0.0
result = await db.execute(usage_query)
usage = result.scalar() or 0
# Convert cents to dollars for dollar budgets
if budget.budget_type == "dollars":
usage = usage / 100.0
return float(usage)
def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[datetime, datetime]:
"""Calculate period start and end dates"""
if period_type == "hourly":
start = current_time.replace(minute=0, second=0, microsecond=0)
end = start + timedelta(hours=1) - timedelta(microseconds=1)
elif period_type == "daily":
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) - timedelta(microseconds=1)
elif period_type == "weekly":
# Start of week (Monday)
days_since_monday = current_time.weekday()
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=days_since_monday)
end = start + timedelta(weeks=1) - timedelta(microseconds=1)
elif period_type == "monthly":
start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
if start.month == 12:
next_month = start.replace(year=start.year + 1, month=1)
else:
next_month = start.replace(month=start.month + 1)
end = next_month - timedelta(microseconds=1)
elif period_type == "yearly":
start = current_time.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
end = start.replace(year=start.year + 1) - timedelta(microseconds=1)
else:
# Default to daily
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) - timedelta(microseconds=1)
return start, end
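
# Worked example for the monthly branch above (values follow directly from the code):
#   _calculate_period_bounds(datetime(2024, 3, 15, 12, 34), "monthly")
#   -> (datetime(2024, 3, 1, 0, 0), datetime(2024, 3, 31, 23, 59, 59, 999999))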
async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -> List[dict]:
"""Get usage history for the budget"""
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days)
# Build query
query = select(
func.date(UsageTracking.created_at).label('date'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents'),
func.count(UsageTracking.id).label('requests')
).where(
UsageTracking.created_at >= start_date,
UsageTracking.created_at <= end_date
)
# Filter by user or API key
if budget.api_key_id:
query = query.where(UsageTracking.api_key_id == budget.api_key_id)
elif budget.user_id:
query = query.where(UsageTracking.user_id == budget.user_id)
query = query.group_by(func.date(UsageTracking.created_at)).order_by(func.date(UsageTracking.created_at))
result = await db.execute(query)
rows = result.fetchall()
history = []
for row in rows:
usage_value = 0
if budget.budget_type == "tokens":
usage_value = row.tokens or 0
elif budget.budget_type == "dollars":
usage_value = (row.cost_cents or 0) / 100.0
elif budget.budget_type == "requests":
usage_value = row.requests or 0
history.append({
"date": row.date.isoformat(),
"usage": usage_value,
"tokens": row.tokens or 0,
"cost_dollars": (row.cost_cents or 0) / 100.0,
"requests": row.requests or 0
})
return history
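# Each history entry has the shape (illustrative values; "usage" mirrors the
# budget type, here a dollars budget):
#   {"date": "2025-01-15", "usage": 1.23, "tokens": 1200,
#    "cost_dollars": 1.23, "requests": 4}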

View File

@@ -0,0 +1,772 @@
"""
Chatbot API endpoints
"""
import asyncio
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from datetime import datetime
from app.db.database import get_db
from app.models.chatbot import ChatbotInstance, ChatbotConversation, ChatbotMessage, ChatbotAnalytics
from app.core.logging import log_api_request
from app.services.module_manager import module_manager
from app.core.security import get_current_user
from app.models.user import User
from app.services.api_key_auth import get_api_key_auth
from app.models.api_key import APIKey
router = APIRouter()
class ChatbotCreateRequest(BaseModel):
name: str
chatbot_type: str = "assistant"
model: str = "gpt-3.5-turbo"
system_prompt: str = ""
use_rag: bool = False
rag_collection: Optional[str] = None
rag_top_k: int = 5
temperature: float = 0.7
max_tokens: int = 1000
memory_length: int = 10
fallback_responses: List[str] = []
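# Example create payload (illustrative values; fields map 1:1 to the model above):
#   {"name": "Support Bot", "chatbot_type": "assistant", "model": "gpt-3.5-turbo",
#    "system_prompt": "You are a helpful assistant.", "use_rag": false,
#    "temperature": 0.7, "max_tokens": 1000}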
class ChatRequest(BaseModel):
message: str
conversation_id: Optional[str] = None
@router.get("/list")
async def list_chatbots(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get list of all chatbots for the current user"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("list_chatbots", {"user_id": user_id})
try:
# Query chatbots created by the current user
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.created_by == str(user_id))
.order_by(ChatbotInstance.created_at.desc())
)
chatbots = result.scalars().all()
chatbot_list = []
for chatbot in chatbots:
chatbot_dict = {
"id": chatbot.id,
"name": chatbot.name,
"description": chatbot.description,
"config": chatbot.config,
"created_by": chatbot.created_by,
"created_at": chatbot.created_at.isoformat() if chatbot.created_at else None,
"updated_at": chatbot.updated_at.isoformat() if chatbot.updated_at else None,
"is_active": chatbot.is_active
}
chatbot_list.append(chatbot_dict)
return chatbot_list
except Exception as e:
log_api_request("list_chatbots_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch chatbots: {str(e)}")
@router.post("/create")
async def create_chatbot(
request: ChatbotCreateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create a new chatbot instance"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("create_chatbot", {
"user_id": user_id,
"chatbot_name": request.name,
"chatbot_type": request.chatbot_type
})
try:
# Get the chatbot module
chatbot_module = module_manager.get_module("chatbot")
if not chatbot_module:
raise HTTPException(status_code=500, detail="Chatbot module not available")
# Import needed types
from modules.chatbot.main import ChatbotConfig
# Create chatbot config object
config = ChatbotConfig(
name=request.name,
chatbot_type=request.chatbot_type,
model=request.model,
system_prompt=request.system_prompt,
use_rag=request.use_rag,
rag_collection=request.rag_collection,
rag_top_k=request.rag_top_k,
temperature=request.temperature,
max_tokens=request.max_tokens,
memory_length=request.memory_length,
fallback_responses=request.fallback_responses
)
# Use sync database session for module compatibility
from app.db.database import SessionLocal
sync_db = SessionLocal()
try:
# Use the chatbot module's create method (which handles default prompts)
chatbot = await chatbot_module.create_chatbot(config, str(user_id), sync_db)
finally:
sync_db.close()
# Return the created chatbot
return {
"id": chatbot.id,
"name": chatbot.name,
"description": f"AI chatbot of type {request.chatbot_type}",
"config": chatbot.config.__dict__,
"created_by": chatbot.created_by,
"created_at": chatbot.created_at.isoformat() if chatbot.created_at else None,
"updated_at": chatbot.updated_at.isoformat() if chatbot.updated_at else None,
"is_active": chatbot.is_active
}
except Exception as e:
await db.rollback()
log_api_request("create_chatbot_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to create chatbot: {str(e)}")
@router.put("/update/{chatbot_id}")
async def update_chatbot(
chatbot_id: str,
request: ChatbotCreateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update an existing chatbot instance"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("update_chatbot", {
"user_id": user_id,
"chatbot_id": chatbot_id,
"chatbot_name": request.name
})
try:
# Get existing chatbot and verify ownership
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found or access denied")
# Update chatbot configuration
config = {
"name": request.name,
"chatbot_type": request.chatbot_type,
"model": request.model,
"system_prompt": request.system_prompt,
"use_rag": request.use_rag,
"rag_collection": request.rag_collection,
"rag_top_k": request.rag_top_k,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"memory_length": request.memory_length,
"fallback_responses": request.fallback_responses
}
# Update the chatbot
await db.execute(
update(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.values(
name=request.name,
config=config,
updated_at=datetime.utcnow()
)
)
await db.commit()
# Return updated chatbot
updated_result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
)
updated_chatbot = updated_result.scalar_one()
return {
"id": updated_chatbot.id,
"name": updated_chatbot.name,
"description": updated_chatbot.description,
"config": updated_chatbot.config,
"created_by": updated_chatbot.created_by,
"created_at": updated_chatbot.created_at.isoformat() if updated_chatbot.created_at else None,
"updated_at": updated_chatbot.updated_at.isoformat() if updated_chatbot.updated_at else None,
"is_active": updated_chatbot.is_active
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("update_chatbot_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to update chatbot: {str(e)}")
@router.post("/chat/{chatbot_id}")
async def chat_with_chatbot(
chatbot_id: str,
request: ChatRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Send a message to a chatbot and get a response"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("chat_with_chatbot", {
"user_id": user_id,
"chatbot_id": chatbot_id,
"message_length": len(request.message)
})
try:
# Get the chatbot instance
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
if not chatbot.is_active:
raise HTTPException(status_code=400, detail="Chatbot is not active")
# Get or create conversation
conversation = None
if request.conversation_id:
conv_result = await db.execute(
select(ChatbotConversation)
.where(ChatbotConversation.id == request.conversation_id)
.where(ChatbotConversation.chatbot_id == chatbot_id)
.where(ChatbotConversation.user_id == str(user_id))
)
conversation = conv_result.scalar_one_or_none()
if not conversation:
# Create new conversation
conversation = ChatbotConversation(
chatbot_id=chatbot_id,
user_id=str(user_id),
title=f"Chat {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
is_active=True,
context_data={}
)
db.add(conversation)
await db.commit()
await db.refresh(conversation)
# Save user message
user_message = ChatbotMessage(
conversation_id=conversation.id,
role="user",
content=request.message,
timestamp=datetime.utcnow(),
message_metadata={},
sources=None
)
db.add(user_message)
# Get chatbot module and generate response
try:
chatbot_module = module_manager.modules.get("chatbot")
if not chatbot_module:
raise HTTPException(status_code=500, detail="Chatbot module not available")
# Use the chatbot module to generate a response
response_data = await chatbot_module.chat(
chatbot_config=chatbot.config,
message=request.message,
conversation_history=[], # TODO: Load conversation history
user_id=str(user_id)
)
response_content = response_data.get("response", "I'm sorry, I couldn't generate a response.")
except Exception as e:
# Log the failure and fall back to a canned response
log_api_request("chatbot_response_error", {"error": str(e), "chatbot_id": chatbot_id})
fallback_responses = chatbot.config.get("fallback_responses", [
"I'm sorry, I'm having trouble processing your request right now."
])
response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request."
# Save assistant message
assistant_message = ChatbotMessage(
conversation_id=conversation.id,
role="assistant",
content=response_content,
timestamp=datetime.utcnow(),
message_metadata={},
sources=None
)
db.add(assistant_message)
# Update conversation timestamp
conversation.updated_at = datetime.utcnow()
await db.commit()
return {
"conversation_id": conversation.id,
"response": response_content,
"timestamp": assistant_message.timestamp.isoformat()
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("chat_with_chatbot_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to process chat: {str(e)}")
@router.get("/conversations/{chatbot_id}")
async def get_chatbot_conversations(
chatbot_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get conversations for a chatbot"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("get_chatbot_conversations", {
"user_id": user_id,
"chatbot_id": chatbot_id
})
try:
# Verify chatbot ownership
chatbot_result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = chatbot_result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
# Get conversations
result = await db.execute(
select(ChatbotConversation)
.where(ChatbotConversation.chatbot_id == chatbot_id)
.where(ChatbotConversation.user_id == str(user_id))
.order_by(ChatbotConversation.updated_at.desc())
)
conversations = result.scalars().all()
conversation_list = []
for conv in conversations:
conversation_list.append({
"id": conv.id,
"title": conv.title,
"created_at": conv.created_at.isoformat() if conv.created_at else None,
"updated_at": conv.updated_at.isoformat() if conv.updated_at else None,
"is_active": conv.is_active
})
return conversation_list
except HTTPException:
raise
except Exception as e:
log_api_request("get_chatbot_conversations_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch conversations: {str(e)}")
@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get messages for a conversation"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("get_conversation_messages", {
"user_id": user_id,
"conversation_id": conversation_id
})
try:
# Verify conversation ownership
conv_result = await db.execute(
select(ChatbotConversation)
.where(ChatbotConversation.id == conversation_id)
.where(ChatbotConversation.user_id == str(user_id))
)
conversation = conv_result.scalar_one_or_none()
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
# Get messages
result = await db.execute(
select(ChatbotMessage)
.where(ChatbotMessage.conversation_id == conversation_id)
.order_by(ChatbotMessage.timestamp.asc())
)
messages = result.scalars().all()
message_list = []
for msg in messages:
message_list.append({
"id": msg.id,
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
"metadata": msg.message_metadata,
"sources": msg.sources
})
return message_list
except HTTPException:
raise
except Exception as e:
log_api_request("get_conversation_messages_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
@router.delete("/delete/{chatbot_id}")
async def delete_chatbot(
chatbot_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Delete a chatbot and all associated conversations/messages"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("delete_chatbot", {
"user_id": user_id,
"chatbot_id": chatbot_id
})
try:
# Get existing chatbot and verify ownership
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found or access denied")
# Delete all messages associated with this chatbot's conversations
await db.execute(
delete(ChatbotMessage)
.where(ChatbotMessage.conversation_id.in_(
select(ChatbotConversation.id)
.where(ChatbotConversation.chatbot_id == chatbot_id)
))
)
# Delete all conversations associated with this chatbot
await db.execute(
delete(ChatbotConversation)
.where(ChatbotConversation.chatbot_id == chatbot_id)
)
# Delete any analytics data
await db.execute(
delete(ChatbotAnalytics)
.where(ChatbotAnalytics.chatbot_id == chatbot_id)
)
# Finally, delete the chatbot itself
await db.execute(
delete(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
)
await db.commit()
return {"message": "Chatbot deleted successfully", "chatbot_id": chatbot_id}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("delete_chatbot_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to delete chatbot: {str(e)}")
@router.post("/external/{chatbot_id}/chat")
async def external_chat_with_chatbot(
chatbot_id: str,
request: ChatRequest,
api_key: APIKey = Depends(get_api_key_auth),
db: AsyncSession = Depends(get_db)
):
"""External API endpoint for chatbot access with API key authentication"""
log_api_request("external_chat_with_chatbot", {
"chatbot_id": chatbot_id,
"api_key_id": api_key.id,
"message_length": len(request.message)
})
try:
# Check if API key can access this chatbot
if not api_key.can_access_chatbot(chatbot_id):
raise HTTPException(status_code=403, detail="API key not authorized for this chatbot")
# Get the chatbot instance
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
if not chatbot.is_active:
raise HTTPException(status_code=400, detail="Chatbot is not active")
# Get or create conversation
conversation = None
if request.conversation_id:
conv_result = await db.execute(
select(ChatbotConversation)
.where(ChatbotConversation.id == request.conversation_id)
.where(ChatbotConversation.chatbot_id == chatbot_id)
)
conversation = conv_result.scalar_one_or_none()
if not conversation:
# Create new conversation with API key as the user context
conversation = ChatbotConversation(
chatbot_id=chatbot_id,
user_id=f"api_key_{api_key.id}",
title=f"API Chat {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
is_active=True,
context_data={"api_key_id": api_key.id}
)
db.add(conversation)
await db.commit()
await db.refresh(conversation)
# Save user message
user_message = ChatbotMessage(
conversation_id=conversation.id,
role="user",
content=request.message,
timestamp=datetime.utcnow(),
message_metadata={"api_key_id": api_key.id},
sources=None
)
db.add(user_message)
# Get chatbot module and generate response
try:
chatbot_module = module_manager.modules.get("chatbot")
if not chatbot_module:
raise HTTPException(status_code=500, detail="Chatbot module not available")
# Use the chatbot module to generate a response
response_data = await chatbot_module.chat(
chatbot_config=chatbot.config,
message=request.message,
conversation_history=[], # TODO: Load conversation history
user_id=f"api_key_{api_key.id}"
)
response_content = response_data.get("response", "I'm sorry, I couldn't generate a response.")
sources = response_data.get("sources")
except Exception as e:
# Log the failure and fall back to a canned response
log_api_request("external_chatbot_response_error", {"error": str(e), "chatbot_id": chatbot_id})
fallback_responses = chatbot.config.get("fallback_responses", [
"I'm sorry, I'm having trouble processing your request right now."
])
response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request."
sources = None
# Save assistant message
assistant_message = ChatbotMessage(
conversation_id=conversation.id,
role="assistant",
content=response_content,
timestamp=datetime.utcnow(),
message_metadata={"api_key_id": api_key.id},
sources=sources
)
db.add(assistant_message)
# Update conversation timestamp
conversation.updated_at = datetime.utcnow()
# Update API key usage stats (rough estimate: character counts stand in for tokens here)
api_key.update_usage(tokens_used=len(request.message) + len(response_content), cost_cents=0)
await db.commit()
return {
"conversation_id": conversation.id,
"response": response_content,
"sources": sources,
"timestamp": assistant_message.timestamp.isoformat(),
"chatbot_id": chatbot_id
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("external_chat_with_chatbot_error", {"error": str(e), "chatbot_id": chatbot_id})
raise HTTPException(status_code=500, detail=f"Failed to process chat: {str(e)}")
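# Example call against this endpoint (illustrative; host and key are placeholders,
# the "ce_" prefix matches the platform's API key format):
#   curl -X POST http://localhost:8000/api/v1/chatbot/external/<chatbot_id>/chat \
#     -H "Authorization: Bearer ce_..." \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello!"}'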
@router.post("/{chatbot_id}/api-key")
async def create_chatbot_api_key(
chatbot_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create an API key for a specific chatbot"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("create_chatbot_api_key", {
"user_id": user_id,
"chatbot_id": chatbot_id
})
try:
# Get existing chatbot and verify ownership
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found or access denied")
# Generate API key
from app.api.v1.api_keys import generate_api_key
full_key, key_hash = generate_api_key()
key_prefix = full_key[:8]
# Create chatbot-specific API key
new_api_key = APIKey.create_chatbot_key(
user_id=user_id,
name=f"{chatbot.name} API Key",
key_hash=key_hash,
key_prefix=key_prefix,
chatbot_id=chatbot_id,
chatbot_name=chatbot.name
)
db.add(new_api_key)
await db.commit()
await db.refresh(new_api_key)
return {
"api_key_id": new_api_key.id,
"name": new_api_key.name,
"key_prefix": new_api_key.key_prefix + "...",
"secret_key": full_key, # Only returned on creation
"chatbot_id": chatbot_id,
"chatbot_name": chatbot.name,
"endpoint": f"/api/v1/chatbot/external/{chatbot_id}/chat",
"scopes": new_api_key.scopes,
"rate_limit_per_minute": new_api_key.rate_limit_per_minute,
"created_at": new_api_key.created_at.isoformat()
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("create_chatbot_api_key_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to create chatbot API key: {str(e)}")
@router.get("/{chatbot_id}/api-keys")
async def list_chatbot_api_keys(
chatbot_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List API keys for a specific chatbot"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("list_chatbot_api_keys", {
"user_id": user_id,
"chatbot_id": chatbot_id
})
try:
# Get existing chatbot and verify ownership
result = await db.execute(
select(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.where(ChatbotInstance.created_by == str(user_id))
)
chatbot = result.scalar_one_or_none()
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found or access denied")
# Get API keys that can access this chatbot
api_keys_result = await db.execute(
select(APIKey)
.where(APIKey.user_id == user_id)
.where(APIKey.allowed_chatbots.contains([chatbot_id]))
.order_by(APIKey.created_at.desc())
)
api_keys = api_keys_result.scalars().all()
api_key_list = []
for api_key in api_keys:
api_key_list.append({
"id": api_key.id,
"name": api_key.name,
"key_prefix": api_key.key_prefix + "...",
"is_active": api_key.is_active,
"created_at": api_key.created_at.isoformat(),
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"total_requests": api_key.total_requests,
"rate_limit_per_minute": api_key.rate_limit_per_minute,
"scopes": api_key.scopes
})
return {
"chatbot_id": chatbot_id,
"chatbot_name": chatbot.name,
"api_keys": api_key_list,
"total": len(api_key_list)
}
except HTTPException:
raise
except Exception as e:
log_api_request("list_chatbot_api_keys_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to list chatbot API keys: {str(e)}")

675
backend/app/api/v1/llm.py Normal file
View File

@@ -0,0 +1,675 @@
"""
LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
"""
import logging
import time
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.litellm_client import litellm_client
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
)
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.utils.exceptions import AuthenticationError, AuthorizationError
from app.middleware.analytics import set_analytics_data
logger = logging.getLogger(__name__)
# Models response cache - simple in-memory cache for performance
_models_cache = {
"data": None,
"cached_at": 0,
"cache_ttl": 900 # 15 minutes cache TTL
}
router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LiteLLM if cache is stale"""
current_time = time.time()
# Check if cache is still valid
if (_models_cache["data"] is not None and
current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]):
logger.debug("Returning cached models list")
return _models_cache["data"]
# Cache miss or stale - fetch from LiteLLM
try:
logger.debug("Fetching fresh models list from LiteLLM")
models = await litellm_client.get_models()
# Update cache
_models_cache["data"] = models
_models_cache["cached_at"] = current_time
return models
except Exception as e:
logger.error(f"Failed to fetch models from LiteLLM: {e}")
# Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None:
logger.warning("Returning stale cached models due to fetch error")
return _models_cache["data"]
return []
def invalidate_models_cache():
"""Invalidate the models cache (useful for admin operations)"""
_models_cache["data"] = None
_models_cache["cached_at"] = 0
logger.info("Models cache invalidated")
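# Note: the cache above is per-process and unsynchronized; under concurrent
# requests two coroutines may refresh it at once, which is harmless here (last
# write wins), but invalidation only affects this worker process.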
# Request/Response Models
class ChatMessage(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
class ChatCompletionRequest(BaseModel):
model: str = Field(..., description="Model name")
messages: List[ChatMessage] = Field(..., description="List of messages")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
temperature: Optional[float] = Field(None, description="Temperature for sampling")
top_p: Optional[float] = Field(None, description="Top-p sampling parameter")
frequency_penalty: Optional[float] = Field(None, description="Frequency penalty")
presence_penalty: Optional[float] = Field(None, description="Presence penalty")
stop: Optional[List[str]] = Field(None, description="Stop sequences")
stream: Optional[bool] = Field(False, description="Stream response")
class EmbeddingRequest(BaseModel):
model: str = Field(..., description="Model name")
input: str = Field(..., description="Input text to embed")
encoding_format: Optional[str] = Field("float", description="Encoding format")
class ModelInfo(BaseModel):
id: str
object: str = "model"
created: int
owned_by: str
class ModelsResponse(BaseModel):
object: str = "list"
data: List[ModelInfo]
# Hybrid authentication function
async def get_auth_context(
request: Request,
db: AsyncSession = Depends(get_db)
) -> Dict[str, Any]:
"""Get authentication context from either API key or JWT token"""
# Try API key authentication first
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[7:]
# Check if it's an API key (starts with ce_ prefix)
if token.startswith(settings.API_KEY_PREFIX):
try:
context = await get_api_key_context(request, db)
if context:
return context
except Exception as e:
logger.warning(f"API key authentication failed: {e}")
else:
# Try JWT token authentication
try:
from app.core.security import get_current_user
# Create a fake credentials object for JWT validation
from fastapi.security import HTTPAuthorizationCredentials
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
user = await get_current_user(credentials, db)
if user:
return {
"user": user,
"auth_type": "jwt",
"api_key": None
}
except Exception as e:
logger.warning(f"JWT authentication failed: {e}")
# Try X-API-Key header
api_key = request.headers.get("X-API-Key")
if api_key:
try:
context = await get_api_key_context(request, db)
if context:
return context
except Exception as e:
logger.warning(f"X-API-Key authentication failed: {e}")
# No valid authentication found
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Valid API key or authentication token required"
)
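# Resolution order (summary of the logic above): a Bearer token carrying the API
# key prefix goes through API key auth; any other Bearer token is validated as a
# JWT; then the X-API-Key header is tried; otherwise the request gets a 401.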
# Endpoints
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""List available models"""
try:
# For JWT users, allow access to list models
if context.get("auth_type") == "jwt":
pass # JWT users can list models
else:
# For API key users, check permissions
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "models.list"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to list models"
)
# Get models from cache or LiteLLM
models = await get_cached_models()
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
return ModelsResponse(data=models)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error listing models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to list models"
)
@router.post("/models/invalidate-cache")
async def invalidate_models_cache_endpoint(
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""Invalidate models cache (admin only)"""
# Check for admin permissions
if context.get("auth_type") == "jwt":
user = context.get("user")
if not user or not user.get("is_superuser"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
)
else:
# For API key users, check admin permissions
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.cache"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to invalidate cache"
)
invalidate_models_cache()
return {"message": "Models cache invalidated successfully"}
@router.post("/chat/completions")
async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""Create chat completion with budget enforcement"""
try:
auth_type = context.get("auth_type", "api_key")
# Handle different authentication types
if auth_type == "api_key":
auth_service = APIKeyAuthService(db)
# Check permissions
if not await auth_service.check_scope_permission(context, "chat.completions"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for chat completions"
)
if not await auth_service.check_model_permission(context, chat_request.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{chat_request.model}' not allowed"
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
)
elif auth_type == "jwt":
# For JWT authentication, skip the per-key permission checks;
# JWT users are not subject to the API-key budget enforcement below
user = context.get("user")
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
)
api_key = None # JWT users don't have API keys
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
)
# Estimate token usage for budget checking
messages_text = " ".join([msg.content for msg in chat_request.messages])
estimated_tokens = len(messages_text.split()) * 1.3 # Rough token estimation
if chat_request.max_tokens:
estimated_tokens += chat_request.max_tokens
else:
estimated_tokens += 150 # Default response length estimate
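# e.g. a 100-word prompt with max_tokens=500 reserves roughly 100 * 1.3 + 500 = 630 tokens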
# Get a synchronous session for budget enforcement
from app.db.database import SessionLocal
sync_db = SessionLocal()
try:
# Atomic budget check and reservation (only for API key users)
warnings = []
reserved_budget_ids = []
if auth_type == "api_key" and api_key:
is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget(
sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions"
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
)
warnings = budget_warnings
reserved_budget_ids = budget_ids
# Convert messages to dict format
messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
# Prepare additional parameters
kwargs = {}
if chat_request.max_tokens is not None:
kwargs["max_tokens"] = chat_request.max_tokens
if chat_request.temperature is not None:
kwargs["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
kwargs["top_p"] = chat_request.top_p
if chat_request.frequency_penalty is not None:
kwargs["frequency_penalty"] = chat_request.frequency_penalty
if chat_request.presence_penalty is not None:
kwargs["presence_penalty"] = chat_request.presence_penalty
if chat_request.stop is not None:
kwargs["stop"] = chat_request.stop
if chat_request.stream is not None:
kwargs["stream"] = chat_request.stream
# Make request to LiteLLM
response = await litellm_client.create_chat_completion(
model=chat_request.model,
messages=messages,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", "jwt_user"),
**kwargs
)
# Calculate actual cost and update usage
usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
# Calculate accurate cost
actual_cost_cents = CostCalculator.calculate_cost_cents(
chat_request.model, input_tokens, output_tokens
)
# Finalize actual usage in budgets (only for API key users)
if auth_type == "api_key" and api_key:
atomic_finalize_usage(
sync_db, reserved_budget_ids, api_key, chat_request.model,
input_tokens, output_tokens, "chat/completions"
)
# Update API key usage statistics
auth_service = APIKeyAuthService(db)
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
# Set analytics data for middleware
set_analytics_data(
model=chat_request.model,
request_tokens=input_tokens,
response_tokens=output_tokens,
total_tokens=total_tokens,
cost_cents=actual_cost_cents,
budget_ids=reserved_budget_ids,
budget_warnings=warnings
)
# Add budget warnings to response if any
if warnings:
response["budget_warnings"] = warnings
return response
finally:
sync_db.close()
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
)
@router.post("/embeddings")
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Create embedding with budget enforcement"""
try:
auth_service = APIKeyAuthService(db)
# Check permissions
if not await auth_service.check_scope_permission(context, "embeddings.create"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for embeddings"
)
if not await auth_service.check_model_permission(context, request.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{request.model}' not allowed"
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
)
# Estimate token usage for budget checking
estimated_tokens = len(request.input.split()) * 1.3 # Rough token estimation
# Convert AsyncSession to Session for budget enforcement
sync_db = Session(bind=db.bind.sync_engine)
try:
# Check budget compliance before making request
is_allowed, error_message, warnings = check_budget_for_request(
sync_db, api_key, request.model, int(estimated_tokens), "embeddings"
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
)
# Make request to LiteLLM
response = await litellm_client.create_embedding(
model=request.model,
input_text=request.input,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"],
encoding_format=request.encoding_format
)
# Calculate actual cost and update usage
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens))
# Calculate accurate cost (embeddings typically use input tokens only)
actual_cost_cents = CostCalculator.calculate_cost_cents(
request.model, total_tokens, 0
)
# Record actual usage in budgets
record_request_usage(
sync_db, api_key, request.model, total_tokens, 0, "embeddings"
)
# Update API key usage statistics
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
# Add budget warnings to response if any
if warnings:
response["budget_warnings"] = warnings
return response
finally:
sync_db.close()
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
)
@router.get("/health")
async def llm_health_check(
context: Dict[str, Any] = Depends(require_api_key)
):
"""Health check for LLM service"""
try:
health_status = await litellm_client.health_check()
return {
"status": "healthy",
"service": "LLM Proxy",
"litellm_status": health_status,
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
}
except Exception as e:
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Proxy",
"error": str(e)
}
@router.get("/usage")
async def get_usage_stats(
context: Dict[str, Any] = Depends(require_api_key)
):
"""Get usage statistics for the API key"""
try:
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
)
return {
"api_key_id": api_key.id,
"api_key_name": api_key.name,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost_cents": api_key.total_cost,
"created_at": api_key.created_at.isoformat(),
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"rate_limits": {
"per_minute": api_key.rate_limit_per_minute,
"per_hour": api_key.rate_limit_per_hour,
"per_day": api_key.rate_limit_per_day
},
"permissions": api_key.permissions,
"scopes": api_key.scopes,
"allowed_models": api_key.allowed_models
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting usage stats: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get usage statistics"
)
@router.get("/budget/status")
async def get_budget_status(
request: Request,
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""Get current budget status and usage analytics"""
try:
auth_type = context.get("auth_type", "api_key")
# Check permissions based on auth type
if auth_type == "api_key":
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "budget.read"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to read budget information"
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
)
# Convert AsyncSession to Session for budget enforcement
sync_db = Session(bind=db.bind.sync_engine)
try:
budget_service = BudgetEnforcementService(sync_db)
budget_status = budget_service.get_budget_status(api_key)
return {
"object": "budget_status",
"data": budget_status
}
finally:
sync_db.close()
elif auth_type == "jwt":
# For JWT authentication, return user-level budget information
user = context.get("user")
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
)
# Return basic budget info for JWT users
return {
"object": "budget_status",
"data": {
"budgets": [],
"total_usage": 0.0,
"warnings": [],
"projections": {
"daily_burn_rate": 0.0,
"projected_monthly": 0.0,
"days_remaining": 30
}
}
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting budget status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get budget status"
)
# Generic proxy endpoint for other LiteLLM endpoints
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_endpoint(
endpoint: str,
request: Request,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Generic proxy endpoint for LiteLLM requests"""
try:
auth_service = APIKeyAuthService(db)
# Check endpoint permission
if not await auth_service.check_endpoint_permission(context, endpoint):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Endpoint '{endpoint}' not allowed"
)
# Get request body
if request.method in ["POST", "PUT", "PATCH"]:
try:
payload = await request.json()
except Exception:
payload = {}
else:
payload = dict(request.query_params)
# Make request to LiteLLM
response = await litellm_client.proxy_request(
method=request.method,
endpoint=endpoint,
payload=payload,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
)
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error proxying request to {endpoint}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to proxy request"
)

View File

@@ -0,0 +1,478 @@
"""
Modules API endpoints
"""
from typing import Dict, Any, List
from fastapi import APIRouter, Depends, HTTPException
from app.services.module_manager import module_manager, ModuleConfig
from app.core.logging import get_logger, log_api_request
logger = get_logger(__name__)
router = APIRouter()
@router.get("/")
async def list_modules():
"""Get list of all discovered modules with their status (enabled and disabled)"""
log_api_request("list_modules", {})
# Get all discovered modules including disabled ones
all_modules = module_manager.list_all_modules()
modules = []
for module_info in all_modules:
# Convert module_info to API format
api_module = {
"name": module_info["name"],
"version": module_info["version"],
"description": module_info["description"],
"initialized": module_info["loaded"],
"enabled": module_info["enabled"]
}
# Get module statistics if available and module is loaded
if module_info["loaded"] and module_info["name"] in module_manager.modules:
module_instance = module_manager.modules[module_info["name"]]
if hasattr(module_instance, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module_instance.get_stats):
stats = await module_instance.get_stats()
else:
stats = module_instance.get_stats()
api_module["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
except Exception:
api_module["stats"] = {}
modules.append(api_module)
# Calculate stats
loaded_count = sum(1 for m in modules if m["initialized"] and m["enabled"])
return {
"total": len(modules),
"modules": modules,
"module_count": loaded_count,
"initialized": module_manager.initialized
}
@router.get("/status")
async def get_modules_status():
"""Get summary status of all modules"""
log_api_request("get_modules_status", {})
total_modules = len(module_manager.modules)
running_modules = 0
standby_modules = 0
failed_modules = 0
for name, module in module_manager.modules.items():
config = module_manager.module_configs.get(name)
is_initialized = getattr(module, "initialized", False)
is_enabled = config.enabled if config else True
if is_initialized and is_enabled:
running_modules += 1
elif not is_initialized:
failed_modules += 1
else:
standby_modules += 1
return {
"total": total_modules,
"running": running_modules,
"standby": standby_modules,
"failed": failed_modules,
"system_initialized": module_manager.initialized
}
@router.get("/{module_name}")
async def get_module_info(module_name: str):
"""Get detailed information about a specific module"""
log_api_request("get_module_info", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
module_info = {
"name": module_name,
"version": getattr(module, "version", "1.0.0"),
"description": getattr(module, "description", ""),
"initialized": getattr(module, "initialized", False),
"enabled": module_manager.module_configs.get(module_name, ModuleConfig(module_name)).enabled,
"capabilities": []
}
# Get module capabilities
if hasattr(module, "get_module_info"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_module_info):
info = await module.get_module_info()
else:
info = module.get_module_info()
module_info.update(info)
except Exception:
pass
# Get module statistics
if hasattr(module, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
module_info["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
except Exception:
module_info["stats"] = {}
# List available methods
methods = []
for attr_name in dir(module):
attr = getattr(module, attr_name)
if callable(attr) and not attr_name.startswith("_"):
methods.append(attr_name)
module_info["methods"] = methods
return module_info
@router.post("/{module_name}/enable")
async def enable_module(module_name: str):
"""Enable a module"""
log_api_request("enable_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Enable the module in config
config = module_manager.module_configs[module_name]
config.enabled = True
# Load the module if not already loaded
if module_name not in module_manager.modules:
try:
await module_manager._load_module(module_name, config)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to enable module '{module_name}': {str(e)}")
return {
"message": f"Module '{module_name}' enabled successfully",
"enabled": True
}
@router.post("/{module_name}/disable")
async def disable_module(module_name: str):
"""Disable a module"""
log_api_request("disable_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Disable the module in config
config = module_manager.module_configs[module_name]
config.enabled = False
# Unload the module if loaded
if module_name in module_manager.modules:
try:
await module_manager.unload_module(module_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to disable module '{module_name}': {str(e)}")
return {
"message": f"Module '{module_name}' disabled successfully",
"enabled": False
}
@router.post("/all/reload")
async def reload_all_modules():
"""Reload all modules"""
log_api_request("reload_all_modules", {})
results = {}
failed_modules = []
for module_name in list(module_manager.modules.keys()):
try:
success = await module_manager.reload_module(module_name)
results[module_name] = {"success": success, "error": None}
if not success:
failed_modules.append(module_name)
except Exception as e:
results[module_name] = {"success": False, "error": str(e)}
failed_modules.append(module_name)
if failed_modules:
return {
"message": f"Reloaded {len(results) - len(failed_modules)}/{len(results)} modules successfully",
"success": False,
"results": results,
"failed_modules": failed_modules
}
else:
return {
"message": f"All {len(results)} modules reloaded successfully",
"success": True,
"results": results,
"failed_modules": []
}
@router.post("/{module_name}/reload")
async def reload_module(module_name: str):
"""Reload a specific module"""
log_api_request("reload_module", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to reload module '{module_name}'")
return {
"message": f"Module '{module_name}' reloaded successfully",
"reloaded": True
}
@router.post("/{module_name}/restart")
async def restart_module(module_name: str):
"""Restart a specific module (alias for reload)"""
log_api_request("restart_module", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to restart module '{module_name}'")
return {
"message": f"Module '{module_name}' restarted successfully",
"restarted": True
}
@router.post("/{module_name}/start")
async def start_module(module_name: str):
"""Start a specific module (enable and load)"""
log_api_request("start_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Enable the module
config = module_manager.module_configs[module_name]
config.enabled = True
# Load the module if not already loaded
if module_name not in module_manager.modules:
await module_manager._load_module(module_name, config)
return {
"message": f"Module '{module_name}' started successfully",
"started": True
}
@router.post("/{module_name}/stop")
async def stop_module(module_name: str):
"""Stop a specific module (disable and unload)"""
log_api_request("stop_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Disable the module
config = module_manager.module_configs[module_name]
config.enabled = False
# Unload the module if loaded
if module_name in module_manager.modules:
await module_manager.unload_module(module_name)
return {
"message": f"Module '{module_name}' stopped successfully",
"stopped": True
}
@router.get("/{module_name}/stats")
async def get_module_stats(module_name: str):
"""Get module statistics"""
log_api_request("get_module_stats", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
if not hasattr(module, "get_stats"):
raise HTTPException(status_code=404, detail=f"Module '{module_name}' does not provide statistics")
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
return {
"module": module_name,
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
@router.post("/{module_name}/execute")
async def execute_module_action(module_name: str, request_data: Dict[str, Any]):
"""Execute a module action through the interceptor pattern"""
log_api_request("execute_module_action", {"module_name": module_name, "action": request_data.get("action")})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
# Check if module supports the new interceptor pattern
if hasattr(module, 'execute_with_interceptors'):
try:
# Prepare context (would normally come from authentication middleware)
context = {
"user_id": "test_user", # Would come from authentication
"api_key_id": "test_api_key", # Would come from API key auth
"ip_address": "127.0.0.1", # Would come from request
"user_permissions": [f"modules:{module_name}:*"] # Would come from user/API key permissions
}
# Execute through interceptor chain
response = await module.execute_with_interceptors(request_data, context)
return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": True
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
# Fallback for legacy modules
else:
action = request_data.get("action", "execute")
if hasattr(module, action):
try:
method = getattr(module, action)
if callable(method):
import asyncio
if asyncio.iscoroutinefunction(method):
response = await method(request_data)
else:
response = method(request_data)
return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": False
}
else:
raise HTTPException(status_code=400, detail=f"'{action}' is not callable")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
else:
raise HTTPException(status_code=400, detail=f"Action '{action}' not supported by module '{module_name}'")
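# Example request body for this endpoint (illustrative; the available actions
# and payload keys depend on the target module):
#   {"action": "execute", "payload": {"text": "hello"}}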
@router.get("/{module_name}/config")
async def get_module_config(module_name: str):
"""Get module configuration schema and current values"""
log_api_request("get_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager
from app.services.litellm_client import litellm_client
import copy
# Get module manifest and schema
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
schema = module_config_manager.get_module_schema(module_name)
current_config = module_config_manager.get_module_config(module_name)
# For Signal module, populate model options dynamically
if module_name == "signal" and schema:
try:
# Get available models from LiteLLM
models_data = await litellm_client.get_models()
model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
if model_ids:
# Create a copy of the schema to avoid modifying the original
dynamic_schema = copy.deepcopy(schema)
# Add enum options for the model field
if "properties" in dynamic_schema and "model" in dynamic_schema["properties"]:
dynamic_schema["properties"]["model"]["enum"] = model_ids
# Set a sensible default if the current default isn't in the list
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
if current_default not in model_ids and model_ids:
dynamic_schema["properties"]["model"]["default"] = model_ids[0]
schema = dynamic_schema
except Exception as e:
# If we can't get models, log warning but continue with original schema
logger.warning(f"Failed to get dynamic models for Signal config: {e}")
return {
"module": module_name,
"description": manifest.description,
"schema": schema,
"current_config": current_config,
"has_schema": schema is not None
}
@router.post("/{module_name}/config")
async def update_module_config(module_name: str, config: dict):
"""Update module configuration"""
log_api_request("update_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager
# Validate module exists
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
try:
# Save configuration
success = await module_config_manager.save_module_config(module_name, config)
if not success:
raise HTTPException(status_code=500, detail="Failed to save configuration")
# Update module manager with new config
success = await module_manager.update_module_config(module_name, config)
if not success:
raise HTTPException(status_code=500, detail="Failed to apply configuration")
return {
"message": f"Configuration updated for module '{module_name}'",
"config": config
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -0,0 +1,160 @@
"""
OpenAI-compatible API endpoints
Following the exact OpenAI API specification for compatibility with OpenAI clients
"""
import logging
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.api.v1.llm import (
get_auth_context, get_cached_models, ModelsResponse, ModelInfo,
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
create_embedding as llm_create_embedding
)
logger = logging.getLogger(__name__)
router = APIRouter()
def openai_error_response(message: str, error_type: str = "invalid_request_error",
status_code: int = 400, code: Optional[str] = None):
"""Create OpenAI-compatible error response"""
error_data = {
"error": {
"message": message,
"type": error_type,
}
}
if code:
error_data["error"]["code"] = code
return JSONResponse(
status_code=status_code,
content=error_data
)
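# Produces OpenAI-style error bodies, e.g. (illustrative):
#   {"error": {"message": "Invalid authentication credentials",
#              "type": "authentication_error", "code": "invalid_api_key"}}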
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""
Lists the currently available models, and provides basic information about each one
such as the owner and availability.
This endpoint follows the exact OpenAI API specification:
GET /v1/models
"""
try:
# Delegate to the existing LLM models endpoint
from app.api.v1.llm import list_models as llm_list_models
return await llm_list_models(context, db)
except HTTPException as e:
# Convert FastAPI HTTPException to OpenAI format
if e.status_code == 401:
return openai_error_response(
"Invalid authentication credentials",
"authentication_error",
401
)
elif e.status_code == 403:
return openai_error_response(
"Insufficient permissions",
"permission_error",
403
)
else:
return openai_error_response(str(e.detail), "api_error", e.status_code)
except Exception as e:
logger.error(f"Error in OpenAI models endpoint: {e}")
return openai_error_response(
"Internal server error",
"api_error",
500
)
@router.post("/chat/completions")
async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""
Create chat completion - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
POST /v1/chat/completions
"""
# Delegate to the existing LLM chat completions endpoint
return await llm_chat_completion(request_body, chat_request, context, db)
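# Minimal client-side sketch (kept out of the router; the base URL and key are
# assumptions about the deployment): because this endpoint mirrors the OpenAI
# spec, the official `openai` SDK should be able to talk to it unchanged.
def _example_chat_completion_client() -> None:
    """Demo only; runs solely when called explicitly."""
    from openai import OpenAI  # assumed available in client environments

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="ce-example-key")
    resp = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(resp.choices[0].message.content)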
@router.post("/embeddings")
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""
Create embedding - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
POST /v1/embeddings
"""
# Delegate to the existing LLM embeddings endpoint
return await llm_create_embedding(request, context, db)
@router.get("/models/{model_id}")
async def retrieve_model(
model_id: str,
context: Dict[str, Any] = Depends(get_auth_context),
db: AsyncSession = Depends(get_db)
):
"""
Retrieve model information - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
GET /v1/models/{model}
"""
try:
# Get all models and find the specific one
models = await get_cached_models()
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
# Find the specific model
model = next((m for m in models if m.get("id") == model_id), None)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model '{model_id}' not found"
)
return ModelInfo(
id=model.get("id", model_id),
object="model",
created=model.get("created", 0),
owned_by=model.get("owned_by", "system")
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model information"
)
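# Illustrative response shape (timestamp is an assumption): GET /v1/models/gpt-4
# would return
#   {"id": "gpt-4", "object": "model", "created": 0, "owned_by": "system"}
# matching the ModelInfo schema used by the list endpoint.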

View File

@@ -0,0 +1,343 @@
"""
Platform API routes for core platform operations
Includes permissions, users, API keys, budgets, audit, etc.
"""
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from app.services.permission_manager import permission_registry, Permission, PermissionScope
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter()
# Pydantic models for API
class PermissionResponse(BaseModel):
resource: str
action: str
description: str
conditions: Optional[Dict[str, Any]] = None
class PermissionHierarchyResponse(BaseModel):
hierarchy: Dict[str, Any]
class PermissionValidationRequest(BaseModel):
permissions: List[str]
class PermissionValidationResponse(BaseModel):
valid: List[str]
invalid: List[str]
is_valid: bool
class PermissionCheckRequest(BaseModel):
user_permissions: List[str]
required_permission: str
context: Optional[Dict[str, Any]] = None
class PermissionCheckResponse(BaseModel):
has_permission: bool
matching_permissions: List[str]
class RoleRequest(BaseModel):
role_name: str
permissions: List[str]
class RoleResponse(BaseModel):
role_name: str
permissions: List[str]
created: bool = True
class UserPermissionsRequest(BaseModel):
roles: List[str]
custom_permissions: Optional[List[str]] = None
class UserPermissionsResponse(BaseModel):
effective_permissions: List[str]
roles: List[str]
custom_permissions: List[str]
# Permission management endpoints
@router.get("/permissions", response_model=Dict[str, List[PermissionResponse]])
async def get_available_permissions(namespace: Optional[str] = None):
"""Get all available permissions, optionally filtered by namespace"""
try:
permissions = permission_registry.get_available_permissions(namespace)
# Convert to response format
result = {}
for ns, perms in permissions.items():
result[ns] = [
PermissionResponse(
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=perm.conditions
)
for perm in perms
]
return result
except Exception as e:
logger.error(f"Error getting permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permissions: {str(e)}"
)
@router.get("/permissions/hierarchy", response_model=PermissionHierarchyResponse)
async def get_permission_hierarchy():
"""Get the permission hierarchy tree structure"""
try:
hierarchy = permission_registry.get_permission_hierarchy()
return PermissionHierarchyResponse(hierarchy=hierarchy)
except Exception as e:
logger.error(f"Error getting permission hierarchy: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permission hierarchy: {str(e)}"
)
@router.post("/permissions/validate", response_model=PermissionValidationResponse)
async def validate_permissions(request: PermissionValidationRequest):
"""Validate a list of permissions"""
try:
validation_result = permission_registry.validate_permissions(request.permissions)
return PermissionValidationResponse(**validation_result)
except Exception as e:
logger.error(f"Error validating permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate permissions: {str(e)}"
)
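# Illustrative example (permission names are assumptions): posting
#   {"permissions": ["platform:users:read", "bogus:perm"]}
# would return
#   {"valid": ["platform:users:read"], "invalid": ["bogus:perm"], "is_valid": false}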
@router.post("/permissions/check", response_model=PermissionCheckResponse)
async def check_permission(request: PermissionCheckRequest):
"""Check if user has a specific permission"""
try:
has_permission = permission_registry.check_permission(
request.user_permissions,
request.required_permission,
request.context
)
matching_permissions = list(permission_registry.tree.get_matching_permissions(
request.user_permissions
))
return PermissionCheckResponse(
has_permission=has_permission,
matching_permissions=matching_permissions
)
except Exception as e:
logger.error(f"Error checking permission: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check permission: {str(e)}"
)
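# Illustrative example (assumes the registry expands wildcard grants): posting
#   {"user_permissions": ["platform:*"], "required_permission": "platform:users:read"}
# would return has_permission=true, with the wildcard grant listed among
# matching_permissions.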
@router.get("/permissions/modules/{module_id}", response_model=List[PermissionResponse])
async def get_module_permissions(module_id: str):
"""Get permissions for a specific module"""
try:
permissions = permission_registry.get_module_permissions(module_id)
return [
PermissionResponse(
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=perm.conditions
)
for perm in permissions
]
except Exception as e:
logger.error(f"Error getting module permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get module permissions: {str(e)}"
)
# Role management endpoints
@router.post("/roles", response_model=RoleResponse)
async def create_role(request: RoleRequest):
"""Create a custom role with specific permissions"""
try:
# Validate permissions first
validation_result = permission_registry.validate_permissions(request.permissions)
if not validation_result["is_valid"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid permissions: {validation_result['invalid']}"
)
permission_registry.create_role(request.role_name, request.permissions)
return RoleResponse(
role_name=request.role_name,
permissions=request.permissions
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create role: {str(e)}"
)
@router.get("/roles", response_model=Dict[str, List[str]])
async def get_roles():
"""Get all available roles and their permissions"""
try:
# Combine default roles and custom roles
all_roles = {**permission_registry.default_roles, **permission_registry.role_permissions}
return all_roles
except Exception as e:
logger.error(f"Error getting roles: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get roles: {str(e)}"
)
@router.get("/roles/{role_name}", response_model=RoleResponse)
async def get_role(role_name: str):
"""Get a specific role and its permissions"""
try:
# Check custom roles first, then fall back to default roles
permissions = (permission_registry.role_permissions.get(role_name) or
permission_registry.default_roles.get(role_name))
if permissions is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Role '{role_name}' not found"
)
return RoleResponse(
role_name=role_name,
permissions=permissions,
created=True
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get role: {str(e)}"
)
# User permission calculation endpoints
@router.post("/users/permissions", response_model=UserPermissionsResponse)
async def calculate_user_permissions(request: UserPermissionsRequest):
"""Calculate effective permissions for a user based on roles and custom permissions"""
try:
effective_permissions = permission_registry.get_user_permissions(
request.roles,
request.custom_permissions
)
return UserPermissionsResponse(
effective_permissions=effective_permissions,
roles=request.roles,
custom_permissions=request.custom_permissions or []
)
except Exception as e:
logger.error(f"Error calculating user permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to calculate user permissions: {str(e)}"
)
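# Illustrative example (role contents are assumptions): for
#   {"roles": ["viewer"], "custom_permissions": ["platform:settings:read"]}
# the response unions the role's grants with the custom ones, e.g.
#   {"effective_permissions": ["platform:settings:read", "platform:users:read"],
#    "roles": ["viewer"], "custom_permissions": ["platform:settings:read"]}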
# Health and status endpoints
@router.get("/health")
async def platform_health():
"""Platform health check endpoint"""
try:
# Get permission system status
total_permissions = len(permission_registry.tree.permissions)
total_modules = len(permission_registry.module_permissions)
total_roles = len(permission_registry.default_roles) + len(permission_registry.role_permissions)
return {
"status": "healthy",
"service": "Confidential Empire Platform API",
"version": "1.0.0",
"permission_system": {
"total_permissions": total_permissions,
"registered_modules": total_modules,
"available_roles": total_roles
}
}
except Exception as e:
logger.error(f"Error checking platform health: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
@router.get("/metrics")
async def platform_metrics():
"""Get platform metrics"""
try:
# Get permission system metrics
namespaces = permission_registry.get_available_permissions()
metrics = {
"permissions": {
"total": len(permission_registry.tree.permissions),
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()}
},
"modules": {
"registered": len(permission_registry.module_permissions),
"names": list(permission_registry.module_permissions.keys())
},
"roles": {
"default": len(permission_registry.default_roles),
"custom": len(permission_registry.role_permissions),
"total": len(permission_registry.default_roles) + len(permission_registry.role_permissions)
}
}
return metrics
except Exception as e:
logger.error(f"Error getting platform metrics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get platform metrics: {str(e)}"
)

View File

@@ -0,0 +1,427 @@
"""
Prompt Template API endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from datetime import datetime
import uuid
from app.db.database import get_db
from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.litellm_client import litellm_client
router = APIRouter()
class PromptTemplateRequest(BaseModel):
name: str
type_key: str
description: Optional[str] = None
system_prompt: str
is_active: bool = True
class PromptTemplateResponse(BaseModel):
id: str
name: str
type_key: str
description: Optional[str]
system_prompt: str
is_default: bool
is_active: bool
version: int
created_at: str
updated_at: str
class PromptVariableResponse(BaseModel):
id: str
variable_name: str
description: Optional[str]
example_value: Optional[str]
is_active: bool
class ImprovePromptRequest(BaseModel):
current_prompt: str
chatbot_type: str
improvement_instructions: Optional[str] = None
@router.get("/templates", response_model=List[PromptTemplateResponse])
async def list_prompt_templates(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get all prompt templates"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("list_prompt_templates", {"user_id": user_id})
try:
result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.is_active == True)
.order_by(PromptTemplate.name)
)
templates = result.scalars().all()
template_list = []
for template in templates:
template_dict = {
"id": template.id,
"name": template.name,
"type_key": template.type_key,
"description": template.description,
"system_prompt": template.system_prompt,
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
}
template_list.append(template_dict)
return template_list
except Exception as e:
log_api_request("list_prompt_templates_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}")
@router.get("/templates/{type_key}", response_model=PromptTemplateResponse)
async def get_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get a specific prompt template by type key"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("get_prompt_template", {"user_id": user_id, "type_key": type_key})
try:
result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
.where(PromptTemplate.is_active == True)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="Prompt template not found")
return {
"id": template.id,
"name": template.name,
"type_key": template.type_key,
"description": template.description,
"system_prompt": template.system_prompt,
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
}
except HTTPException:
raise
except Exception as e:
log_api_request("get_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt template: {str(e)}")
@router.put("/templates/{type_key}")
async def update_prompt_template(
type_key: str,
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update a prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("update_prompt_template", {
"user_id": user_id,
"type_key": type_key,
"name": request.name
})
try:
# Get existing template
result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
.where(PromptTemplate.is_active == True)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="Prompt template not found")
# Update the template
await db.execute(
update(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
.values(
name=request.name,
description=request.description,
system_prompt=request.system_prompt,
is_active=request.is_active,
version=template.version + 1,
updated_at=datetime.utcnow()
)
)
await db.commit()
# Return updated template
updated_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
)
updated_template = updated_result.scalar_one()
return {
"id": updated_template.id,
"name": updated_template.name,
"type_key": updated_template.type_key,
"description": updated_template.description,
"system_prompt": updated_template.system_prompt,
"is_default": updated_template.is_default,
"is_active": updated_template.is_active,
"version": updated_template.version,
"created_at": updated_template.created_at.isoformat() if updated_template.created_at else None,
"updated_at": updated_template.updated_at.isoformat() if updated_template.updated_at else None
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("update_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to update prompt template: {str(e)}")
@router.post("/templates/create")
async def create_prompt_template(
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create a new prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("create_prompt_template", {
"user_id": user_id,
"type_key": request.type_key,
"name": request.name
})
try:
# Check if template already exists
existing_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == request.type_key)
)
if existing_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Prompt template with this type key already exists")
# Create new template
template = PromptTemplate(
id=str(uuid.uuid4()),
name=request.name,
type_key=request.type_key,
description=request.description,
system_prompt=request.system_prompt,
is_default=False,
is_active=request.is_active,
version=1,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
db.add(template)
await db.commit()
await db.refresh(template)
return {
"id": template.id,
"name": template.name,
"type_key": template.type_key,
"description": template.description,
"system_prompt": template.system_prompt,
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("create_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to create prompt template: {str(e)}")
@router.get("/variables", response_model=List[PromptVariableResponse])
async def list_prompt_variables(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get all available prompt variables"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("list_prompt_variables", {"user_id": user_id})
try:
result = await db.execute(
select(ChatbotPromptVariable)
.where(ChatbotPromptVariable.is_active == True)
.order_by(ChatbotPromptVariable.variable_name)
)
variables = result.scalars().all()
variable_list = []
for variable in variables:
variable_dict = {
"id": variable.id,
"variable_name": variable.variable_name,
"description": variable.description,
"example_value": variable.example_value,
"is_active": variable.is_active
}
variable_list.append(variable_dict)
return variable_list
except Exception as e:
log_api_request("list_prompt_variables_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}")
@router.post("/templates/{type_key}/reset")
async def reset_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Reset a prompt template to its default"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("reset_prompt_template", {"user_id": user_id, "type_key": type_key})
# Define default prompts (same as in migration)
default_prompts = {
"assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
"customer_support": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.",
"teacher": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
"creative_writer": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
}
if type_key not in default_prompts:
raise HTTPException(status_code=404, detail="Unknown prompt template type")
try:
# Update the template to default
await db.execute(
update(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
.values(
system_prompt=default_prompts[type_key],
version=PromptTemplate.version + 1,
updated_at=datetime.utcnow()
)
)
await db.commit()
return {"message": "Prompt template reset to default successfully"}
except Exception as e:
await db.rollback()
log_api_request("reset_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to reset prompt template: {str(e)}")
@router.post("/improve")
async def improve_prompt_with_ai(
request: ImprovePromptRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Improve a prompt using AI"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("improve_prompt_with_ai", {
"user_id": user_id,
"chatbot_type": request.chatbot_type
})
try:
# Create system message for improvement
system_message = """You are an expert prompt engineer. Your task is to improve the given prompt to make it more effective, clear, and specific for the intended chatbot type.
Guidelines for improvement:
1. Make the prompt more specific and actionable
2. Add relevant context and constraints
3. Improve clarity and reduce ambiguity
4. Include appropriate tone and personality instructions
5. Add specific behavior examples when helpful
6. Ensure the prompt aligns with the chatbot type
7. Keep the prompt professional and ethical
8. Make it concise but comprehensive
Return ONLY the improved prompt text without any additional explanation or formatting."""
# Create user message with current prompt and context
user_message = f"""Chatbot Type: {request.chatbot_type}
Current Prompt:
{request.current_prompt}
{f"Additional Instructions: {request.improvement_instructions}" if request.improvement_instructions else ""}
Please improve this prompt to make it more effective for a {request.chatbot_type} chatbot."""
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
]
# Get available models to use a default model
models = await litellm_client.get_models()
if not models:
raise HTTPException(status_code=503, detail="No LLM models available")
# Use the first available model (you might want to make this configurable)
default_model = models[0]["id"]
# Make the AI call
response = await litellm_client.create_chat_completion(
model=default_model,
messages=messages,
user_id=str(user_id),
api_key_id=1, # Using default API key, you might want to make this dynamic
temperature=0.3,
max_tokens=1000
)
# Extract the improved prompt from the response
improved_prompt = response["choices"][0]["message"]["content"].strip()
return {
"improved_prompt": improved_prompt,
"original_prompt": request.current_prompt,
"model_used": default_model
}
except HTTPException:
raise
except Exception as e:
log_api_request("improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to improve prompt: {str(e)}")

363
backend/app/api/v1/rag.py Normal file
View File

@@ -0,0 +1,363 @@
"""
RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
from app.db.database import get_db
from app.core.security import get_current_user
from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException
router = APIRouter(tags=["RAG"])
# Request/Response Models
class CollectionCreate(BaseModel):
name: str
description: Optional[str] = None
class CollectionResponse(BaseModel):
id: str
name: str
description: str
document_count: int
size_bytes: int
vector_count: int
status: str
created_at: str
updated_at: str
is_active: bool
class DocumentResponse(BaseModel):
id: str
collection_id: str
collection_name: Optional[str]
filename: str
original_filename: str
file_type: str
size: int
mime_type: Optional[str]
status: str
processing_error: Optional[str]
converted_content: Optional[str]
word_count: int
character_count: int
vector_count: int
metadata: dict
created_at: str
processed_at: Optional[str]
indexed_at: Optional[str]
updated_at: str
class StatsResponse(BaseModel):
collections: dict
documents: dict
storage: dict
vectors: dict
# Collection Endpoints
@router.get("/collections", response_model=dict)
async def get_collections(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
try:
rag_service = RAGService(db)
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
return {
"success": True,
"collections": collections_data,
"total": len(collections_data)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/collections", response_model=dict)
async def create_collection(
collection_data: CollectionCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Create a new RAG collection"""
try:
rag_service = RAGService(db)
collection = await rag_service.create_collection(
name=collection_data.name,
description=collection_data.description
)
return {
"success": True,
"collection": collection.to_dict(),
"message": "Collection created successfully"
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
collection_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get a specific collection"""
try:
rag_service = RAGService(db)
collection = await rag_service.get_collection(collection_id)
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
return {
"success": True,
"collection": collection.to_dict()
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/collections/{collection_id}", response_model=dict)
async def delete_collection(
collection_id: int,
cascade: bool = True, # Default to cascade deletion for better UX
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Delete a collection and optionally all its documents"""
try:
rag_service = RAGService(db)
success = await rag_service.delete_collection(collection_id, cascade=cascade)
if not success:
raise HTTPException(status_code=404, detail="Collection not found")
return {
"success": True,
"message": "Collection deleted successfully" + (" (with documents)" if cascade else "")
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Document Endpoints
@router.get("/documents", response_model=dict)
async def get_documents(
collection_id: Optional[int] = None,
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get documents, optionally filtered by collection"""
try:
rag_service = RAGService(db)
documents = await rag_service.get_documents(
collection_id=collection_id,
skip=skip,
limit=limit
)
return {
"success": True,
"documents": [doc.to_dict() for doc in documents],
"total": len(documents)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/documents", response_model=dict)
async def upload_document(
collection_id: int = Form(...),
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Upload and process a document"""
try:
# Read file content
file_content = await file.read()
if len(file_content) == 0:
raise HTTPException(status_code=400, detail="Empty file uploaded")
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
rag_service = RAGService(db)
document = await rag_service.upload_document(
collection_id=collection_id,
file_content=file_content,
filename=file.filename or "unknown",
content_type=file.content_type
)
return {
"success": True,
"document": document.to_dict(),
"message": "Document uploaded and processing started"
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
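# Client-side sketch (URL, token, and collection id are assumptions about the
# deployment): a multipart upload against this endpoint using httpx.
def _example_upload_document() -> None:
    """Demo only; not wired into the API."""
    import httpx  # assumed available in client environments

    with open("manual.pdf", "rb") as fh:
        resp = httpx.post(
            "http://localhost:8000/api/v1/rag/documents",
            data={"collection_id": "1"},
            files={"file": ("manual.pdf", fh, "application/pdf")},
            headers={"Authorization": "Bearer ce-example-token"},
        )
    resp.raise_for_status()
    print(resp.json()["document"]["status"])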
@router.get("/documents/{document_id}", response_model=dict)
async def get_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get a specific document"""
try:
rag_service = RAGService(db)
document = await rag_service.get_document(document_id)
if not document:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"document": document.to_dict()
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/documents/{document_id}", response_model=dict)
async def delete_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Delete a document"""
try:
rag_service = RAGService(db)
success = await rag_service.delete_document(document_id)
if not success:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"message": "Document deleted successfully"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/documents/{document_id}/reprocess", response_model=dict)
async def reprocess_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Restart processing for a stuck or failed document"""
try:
rag_service = RAGService(db)
success = await rag_service.reprocess_document(document_id)
if not success:
# Get document to check if it exists and its current status
document = await rag_service.get_document(document_id)
if not document:
raise HTTPException(status_code=404, detail="Document not found")
else:
raise HTTPException(
status_code=400,
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed."
)
return {
"success": True,
"message": "Document reprocessing started successfully"
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{document_id}/download")
async def download_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Download the original document file"""
try:
rag_service = RAGService(db)
result = await rag_service.download_document(document_id)
if not result:
raise HTTPException(status_code=404, detail="Document not found or file not available")
content, filename, mime_type = result
return StreamingResponse(
io.BytesIO(content),
media_type=mime_type,
headers={"Content-Disposition": f"attachment; filename={filename}"}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Stats Endpoint
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get RAG system statistics"""
try:
rag_service = RAGService(db)
stats = await rag_service.get_stats()
return {
"success": True,
"stats": stats
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,251 @@
"""
Security API endpoints for monitoring and configuration
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(tags=["security"])
# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
"""Security statistics response model"""
total_requests_analyzed: int
threats_detected: int
threats_blocked: int
anomalies_detected: int
rate_limits_exceeded: int
avg_analysis_time: float
threat_types: Dict[str, int]
threat_levels: Dict[str, int]
top_attacking_ips: List[tuple]
security_enabled: bool
threat_detection_enabled: bool
rate_limiting_enabled: bool
class SecurityConfigResponse(BaseModel):
"""Security configuration response model"""
security_enabled: bool = Field(description="Overall security system enabled")
threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
security_headers_enabled: bool = Field(description="Security headers enabled")
# Rate limiting settings
unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
premium_per_minute: int = Field(description="Rate limit for premium users per minute")
# Security thresholds
risk_threshold: float = Field(description="Risk score threshold for blocking requests")
warning_threshold: float = Field(description="Risk score threshold for warnings")
anomaly_threshold: float = Field(description="Anomaly severity threshold")
# IP settings
blocked_ips: List[str] = Field(description="List of blocked IP addresses")
allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
class RateLimitInfoResponse(BaseModel):
"""Rate limit information for current request"""
auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Get security system statistics
Requires admin role. Returns comprehensive statistics about:
- Request analysis counts
- Threat detection results
- Rate limiting enforcement
- Top attacking IPs
- Performance metrics
"""
try:
stats = get_security_stats()
return SecurityStatsResponse(**stats)
except Exception as e:
logger.error(f"Error getting security stats: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve security statistics"
)
@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Get current security configuration
Requires admin role. Returns current security settings including:
- Feature enablement flags
- Rate limiting thresholds
- Security thresholds
- IP allowlists/blocklists
"""
return SecurityConfigResponse(
security_enabled=settings.API_SECURITY_ENABLED,
threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,
unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,
blocked_ips=settings.API_BLOCKED_IPS,
allowed_ips=settings.API_ALLOWED_IPS
)
@router.get("/status")
async def get_security_status(
request: Request,
current_user: Dict[str, Any] = Depends(get_current_active_user)
):
"""
Get security status for current request
Returns information about the security analysis of the current request:
- Authentication level
- Risk score (if available)
- Rate limiting status
"""
auth_level = get_request_auth_level(request)
risk_score = get_request_risk_score(request)
# Get rate limits for current auth level
from app.core.threat_detection import AuthLevel
try:
auth_enum = AuthLevel(auth_level)
from app.core.threat_detection import threat_detection_service
minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)
rate_limit_info = RateLimitInfoResponse(
auth_level=auth_level,
current_limits={
"per_minute": minute_limit,
"per_hour": hour_limit
},
remaining_requests=None # Remaining requests aren't tracked in the current implementation
)
except ValueError:
rate_limit_info = RateLimitInfoResponse(
auth_level=auth_level,
current_limits={},
remaining_requests=None
)
return {
"security_enabled": settings.API_SECURITY_ENABLED,
"auth_level": auth_level,
"risk_score": round(risk_score, 3) if risk_score > 0 else None,
"rate_limit_info": rate_limit_info.dict(),
"security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
}
@router.post("/test")
async def test_security_analysis(
request: Request,
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Test security analysis on current request
Requires admin role. Manually triggers security analysis on the current request
and returns detailed results. Useful for testing security rules and thresholds.
"""
try:
from app.middleware.security import analyze_request_security
analysis = await analyze_request_security(request, current_user)
return {
"analysis_complete": True,
"is_threat": analysis.is_threat,
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"should_block": analysis.should_block,
"rate_limit_exceeded": analysis.rate_limit_exceeded,
"threat_count": len(analysis.threats),
"threats": [
{
"type": threat.threat_type,
"level": threat.level.value,
"confidence": round(threat.confidence, 3),
"description": threat.description,
"mitigation": threat.mitigation
}
for threat in analysis.threats
],
"recommendations": analysis.recommendations
}
except Exception as e:
logger.error(f"Error in security analysis test: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to perform security analysis test"
)
@router.get("/health")
async def security_health_check():
"""
Security system health check
Public endpoint that returns the health status of the security system.
Does not require authentication.
"""
try:
stats = get_security_stats()
# Basic health checks
is_healthy = (
settings.API_SECURITY_ENABLED and
stats.get("total_requests_analyzed", 0) >= 0 and
stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second
)
return {
"status": "healthy" if is_healthy else "degraded",
"security_enabled": settings.API_SECURITY_ENABLED,
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
"avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
"total_requests_analyzed": stats.get("total_requests_analyzed", 0)
}
except Exception as e:
logger.error(f"Security health check failed: {e}")
return {
"status": "unhealthy",
"error": "Security system error",
"security_enabled": settings.API_SECURITY_ENABLED
}

View File

@@ -0,0 +1,677 @@
"""
Settings management endpoints
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from app.db.database import get_db
from app.models.user import User
from app.core.security import get_current_user
from app.services.permission_manager import require_permission
from app.services.audit_service import log_audit_event
from app.core.logging import get_logger
from app.core.config import settings as app_settings
logger = get_logger(__name__)
router = APIRouter()
# Pydantic models
class SettingValue(BaseModel):
value: Any
value_type: str = Field(..., pattern="^(string|integer|float|boolean|json|list)$")
description: Optional[str] = None
is_secret: bool = False
class SettingResponse(BaseModel):
key: str
value: Any
value_type: str
description: Optional[str] = None
is_secret: bool = False
category: str
is_system: bool = False
created_at: str
updated_at: Optional[str] = None
class SettingUpdate(BaseModel):
value: Any
description: Optional[str] = None
class SystemInfoResponse(BaseModel):
version: str
environment: str
database_status: str
redis_status: str
litellm_status: str
modules_loaded: int
active_users: int
total_api_keys: int
uptime_seconds: int
class PlatformConfigResponse(BaseModel):
app_name: str
debug_mode: bool
log_level: str
cors_origins: List[str]
rate_limiting_enabled: bool
max_upload_size: int
session_timeout_minutes: int
api_key_prefix: str
features: Dict[str, bool]
maintenance_mode: bool = False
maintenance_message: Optional[str] = None
class SecurityConfigResponse(BaseModel):
password_min_length: int
password_require_special: bool
password_require_numbers: bool
password_require_uppercase: bool
session_timeout_minutes: int
max_login_attempts: int
lockout_duration_minutes: int
require_2fa: bool = False
allowed_domains: List[str] = Field(default_factory=list)
ip_whitelist_enabled: bool = False
# Global settings storage (in a real app, this would be stored in the database)
SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"platform": {
"app_name": {"value": "Confidential Empire", "type": "string", "description": "Application name"},
"maintenance_mode": {"value": False, "type": "boolean", "description": "Enable maintenance mode"},
"maintenance_message": {"value": None, "type": "string", "description": "Maintenance mode message"},
"debug_mode": {"value": False, "type": "boolean", "description": "Enable debug mode"},
"max_upload_size": {"value": 10485760, "type": "integer", "description": "Maximum upload size in bytes"},
},
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
"security_headers_enabled": {"value": True, "type": "boolean", "description": "Enable security headers"},
# Rate Limiting by Authentication Level
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer", "description": "Rate limit for authenticated users per minute"},
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer", "description": "Rate limit for authenticated users per hour"},
"rate_limit_api_key_per_minute": {"value": 1000, "type": "integer", "description": "Rate limit for API key users per minute"},
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer", "description": "Rate limit for API key users per hour"},
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer", "description": "Rate limit for premium users per minute"},
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
# Request Settings
"max_request_size_mb": {"value": 10, "type": "integer", "description": "Maximum request size in MB for standard users"},
"max_request_size_premium_mb": {"value": 50, "type": "integer", "description": "Maximum request size in MB for premium users"},
"enable_cors": {"value": True, "type": "boolean", "description": "Enable CORS headers"},
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list", "description": "Allowed CORS origins"},
"api_key_expiry_days": {"value": 90, "type": "integer", "description": "Default API key expiry in days"},
# IP Security
"blocked_ips": {"value": [], "type": "list", "description": "List of blocked IP addresses"},
"allowed_ips": {"value": [], "type": "list", "description": "List of allowed IP addresses (empty = allow all)"},
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string", "description": "Content Security Policy header"},
},
"security": {
"password_min_length": {"value": 8, "type": "integer", "description": "Minimum password length"},
"password_require_special": {"value": True, "type": "boolean", "description": "Require special characters in passwords"},
"password_require_numbers": {"value": True, "type": "boolean", "description": "Require numbers in passwords"},
"password_require_uppercase": {"value": True, "type": "boolean", "description": "Require uppercase letters in passwords"},
"max_login_attempts": {"value": 5, "type": "integer", "description": "Maximum login attempts before lockout"},
"lockout_duration_minutes": {"value": 15, "type": "integer", "description": "Account lockout duration in minutes"},
"require_2fa": {"value": False, "type": "boolean", "description": "Require two-factor authentication"},
"ip_whitelist_enabled": {"value": False, "type": "boolean", "description": "Enable IP whitelist"},
"allowed_domains": {"value": [], "type": "list", "description": "Allowed email domains for registration"},
},
"features": {
"user_registration": {"value": True, "type": "boolean", "description": "Allow user registration"},
"api_key_creation": {"value": True, "type": "boolean", "description": "Allow API key creation"},
"budget_enforcement": {"value": True, "type": "boolean", "description": "Enable budget enforcement"},
"audit_logging": {"value": True, "type": "boolean", "description": "Enable audit logging"},
"module_hot_reload": {"value": True, "type": "boolean", "description": "Enable module hot reload"},
"tee_support": {"value": True, "type": "boolean", "description": "Enable TEE (Trusted Execution Environment) support"},
"advanced_analytics": {"value": True, "type": "boolean", "description": "Enable advanced analytics"},
},
"notifications": {
"email_enabled": {"value": False, "type": "boolean", "description": "Enable email notifications"},
"slack_enabled": {"value": False, "type": "boolean", "description": "Enable Slack notifications"},
"webhook_enabled": {"value": False, "type": "boolean", "description": "Enable webhook notifications"},
"budget_alerts": {"value": True, "type": "boolean", "description": "Enable budget alert notifications"},
"security_alerts": {"value": True, "type": "boolean", "description": "Enable security alert notifications"},
}
}
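# Illustrative helper (assumption: the in-memory store remains the source of
# truth until settings move to the database): a safe lookup with a fallback.
def _get_setting_value(category: str, key: str, default: Any = None) -> Any:
    """Return the raw value stored for category/key, or `default` if absent."""
    return SETTINGS_STORE.get(category, {}).get(key, {}).get("value", default)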
# Settings management endpoints
@router.get("/")
async def list_settings(
category: Optional[str] = None,
include_secrets: bool = False,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List all settings or settings in a specific category"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:read")
result = {}
for cat, settings in SETTINGS_STORE.items():
if category and cat != category:
continue
result[cat] = {}
for key, setting in settings.items():
# Hide secret values unless specifically requested and user has permission
if setting.get("is_secret", False) and not include_secrets:
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
continue
result[cat][key] = {
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
}
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="list_settings",
resource_type="setting",
details={"category": category, "include_secrets": include_secrets}
)
return result
@router.get("/system-info", response_model=SystemInfoResponse)
async def get_system_info(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get system information and status"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:read")
import psutil
import time
from app.models.api_key import APIKey
# Get database status
try:
await db.execute(select(1))
database_status = "healthy"
except Exception:
database_status = "error"
# Get Redis status (simplified check)
redis_status = "healthy" # Would implement actual Redis check
# Get LiteLLM status (simplified check)
litellm_status = "healthy" # Would implement actual LiteLLM check
# Get modules loaded (from module manager)
modules_loaded = 8 # Would get from actual module manager
# Get active users count (last 24 hours)
from datetime import datetime, timedelta
yesterday = datetime.utcnow() - timedelta(days=1)
active_users_query = select(User.id).where(User.last_login >= yesterday)
active_users_result = await db.execute(active_users_query)
active_users = len(active_users_result.fetchall())
# Get total API keys
total_api_keys_query = select(APIKey.id)
total_api_keys_result = await db.execute(total_api_keys_query)
total_api_keys = len(total_api_keys_result.fetchall())
# Get uptime (simplified - would track actual start time)
uptime_seconds = int(time.time()) % 86400 # Placeholder
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_system_info",
resource_type="system"
)
return SystemInfoResponse(
version="1.0.0",
environment="production",
database_status=database_status,
redis_status=redis_status,
litellm_status=litellm_status,
modules_loaded=modules_loaded,
active_users=active_users,
total_api_keys=total_api_keys,
uptime_seconds=uptime_seconds
)
@router.get("/platform-config", response_model=PlatformConfigResponse)
async def get_platform_config(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get platform configuration"""
# Basic users can see non-sensitive platform config
platform_settings = SETTINGS_STORE.get("platform", {})
feature_settings = SETTINGS_STORE.get("features", {})
features = {key: setting["value"] for key, setting in feature_settings.items()}
# Get API settings for rate limiting
api_settings = SETTINGS_STORE.get("api", {})
return PlatformConfigResponse(
app_name=platform_settings.get("app_name", {}).get("value", "Confidential Empire"),
debug_mode=platform_settings.get("debug_mode", {}).get("value", False),
log_level=app_settings.LOG_LEVEL,
cors_origins=app_settings.CORS_ORIGINS,
rate_limiting_enabled=api_settings.get("rate_limiting_enabled", {}).get("value", True),
max_upload_size=platform_settings.get("max_upload_size", {}).get("value", 10485760),
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
api_key_prefix=app_settings.API_KEY_PREFIX,
features=features,
maintenance_mode=platform_settings.get("maintenance_mode", {}).get("value", False),
maintenance_message=platform_settings.get("maintenance_message", {}).get("value")
)
@router.get("/security-config", response_model=SecurityConfigResponse)
async def get_security_config(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get security configuration"""
# Check permissions for sensitive security settings
require_permission(current_user.get("permissions", []), "platform:settings:read")
security_settings = SETTINGS_STORE.get("security", {})
return SecurityConfigResponse(
password_min_length=security_settings.get("password_min_length", {}).get("value", 8),
password_require_special=security_settings.get("password_require_special", {}).get("value", True),
password_require_numbers=security_settings.get("password_require_numbers", {}).get("value", True),
password_require_uppercase=security_settings.get("password_require_uppercase", {}).get("value", True),
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
max_login_attempts=security_settings.get("max_login_attempts", {}).get("value", 5),
lockout_duration_minutes=security_settings.get("lockout_duration_minutes", {}).get("value", 15),
require_2fa=security_settings.get("require_2fa", {}).get("value", False),
allowed_domains=security_settings.get("allowed_domains", {}).get("value", []),
ip_whitelist_enabled=security_settings.get("ip_whitelist_enabled", {}).get("value", False)
)
@router.get("/{category}/{key}")
async def get_setting(
category: str,
key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get a specific setting value"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:read")
if category not in SETTINGS_STORE:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Settings category '{category}' not found"
)
if key not in SETTINGS_STORE[category]:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Setting '{key}' not found in category '{category}'"
)
setting = SETTINGS_STORE[category][key]
# Check if it's a secret setting
if setting.get("is_secret", False):
require_permission(current_user.get("permissions", []), "platform:settings:admin")
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_setting",
resource_type="setting",
resource_id=f"{category}.{key}"
)
return {
"category": category,
"key": key,
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
}
@router.put("/{category}/{key}")
async def update_setting(
category: str,
key: str,
setting_update: SettingUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update a specific setting"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:update")
if category not in SETTINGS_STORE:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Settings category '{category}' not found"
)
if key not in SETTINGS_STORE[category]:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Setting '{key}' not found in category '{category}'"
)
setting = SETTINGS_STORE[category][key]
# Check if it's a secret setting
if setting.get("is_secret", False):
require_permission(current_user.get("permissions", []), "platform:settings:admin")
# Store original value for audit
original_value = setting["value"]
# Validate value type
expected_type = setting["type"]
new_value = setting_update.value
if expected_type == "integer" and not isinstance(new_value, int):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects an integer value"
)
elif expected_type == "boolean" and not isinstance(new_value, bool):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a boolean value"
)
elif expected_type == "float" and not isinstance(new_value, (int, float)):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a numeric value"
)
elif expected_type == "list" and not isinstance(new_value, list):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a list value"
)
# Update setting
SETTINGS_STORE[category][key]["value"] = new_value
if setting_update.description is not None:
SETTINGS_STORE[category][key]["description"] = setting_update.description
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="update_setting",
resource_type="setting",
resource_id=f"{category}.{key}",
details={
"original_value": original_value,
"new_value": new_value,
"description_updated": setting_update.description is not None
}
)
logger.info(f"Setting updated: {category}.{key} by {current_user['username']}")
return {
"category": category,
"key": key,
"value": new_value,
"type": expected_type,
"description": SETTINGS_STORE[category][key].get("description", ""),
"message": "Setting updated successfully"
}
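# Illustrative usage sketch (not part of the module): updating a setting over
# HTTP, assuming this router is mounted under /api/v1/settings and the caller
# holds a token with platform:settings:update (both assumptions):
#
#     import httpx
#     resp = httpx.put(
#         "http://localhost:8000/api/v1/settings/platform/maintenance_mode",
#         json={"value": True},
#         headers={"Authorization": "Bearer <token>"},
#     )
#     resp.raise_for_status()  # response echoes the updated setting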
@router.post("/reset-defaults")
async def reset_to_defaults(
category: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Reset settings to default values"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:admin")
# Define default values
defaults = {
"platform": {
"app_name": {"value": "Confidential Empire", "type": "string"},
"maintenance_mode": {"value": False, "type": "boolean"},
"debug_mode": {"value": False, "type": "boolean"},
"max_upload_size": {"value": 10485760, "type": "integer"},
},
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean"},
"threat_detection_enabled": {"value": True, "type": "boolean"},
"rate_limiting_enabled": {"value": True, "type": "boolean"},
"ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
"security_headers_enabled": {"value": True, "type": "boolean"},
# Rate Limiting by Authentication Level
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer"},
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer"},
"rate_limit_api_key_per_minute": {"value": 1000, "type": "integer"},
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer"},
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer"},
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float"},
"security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"},
# Request Settings
"max_request_size_mb": {"value": 10, "type": "integer"},
"max_request_size_premium_mb": {"value": 50, "type": "integer"},
"enable_cors": {"value": True, "type": "boolean"},
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list"},
"api_key_expiry_days": {"value": 90, "type": "integer"},
# IP Security
"blocked_ips": {"value": [], "type": "list"},
"allowed_ips": {"value": [], "type": "list"},
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string"},
},
"security": {
"password_min_length": {"value": 8, "type": "integer"},
"password_require_special": {"value": True, "type": "boolean"},
"password_require_numbers": {"value": True, "type": "boolean"},
"password_require_uppercase": {"value": True, "type": "boolean"},
"max_login_attempts": {"value": 5, "type": "integer"},
"lockout_duration_minutes": {"value": 15, "type": "integer"},
"require_2fa": {"value": False, "type": "boolean"},
"ip_whitelist_enabled": {"value": False, "type": "boolean"},
"allowed_domains": {"value": [], "type": "list"},
},
"features": {
"user_registration": {"value": True, "type": "boolean"},
"api_key_creation": {"value": True, "type": "boolean"},
"budget_enforcement": {"value": True, "type": "boolean"},
"audit_logging": {"value": True, "type": "boolean"},
"module_hot_reload": {"value": True, "type": "boolean"},
"tee_support": {"value": True, "type": "boolean"},
"advanced_analytics": {"value": True, "type": "boolean"},
}
}
reset_categories = [category] if category else list(defaults.keys())
for cat in reset_categories:
if cat in defaults and cat in SETTINGS_STORE:
for key, default_setting in defaults[cat].items():
if key in SETTINGS_STORE[cat]:
SETTINGS_STORE[cat][key]["value"] = default_setting["value"]
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="reset_settings_to_defaults",
resource_type="setting",
details={"categories_reset": reset_categories}
)
logger.info(f"Settings reset to defaults: {reset_categories} by {current_user['username']}")
return {
"message": f"Settings reset to defaults for categories: {reset_categories}",
"categories_reset": reset_categories
}
@router.post("/export")
async def export_settings(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Export all settings to JSON"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:export")
# Export all settings (excluding secrets for non-admin users)
export_data = {}
for category, settings in SETTINGS_STORE.items():
export_data[category] = {}
for key, setting in settings.items():
# Skip secret settings for non-admin users
if setting.get("is_secret", False):
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
continue
export_data[category][key] = {
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
}
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="export_settings",
resource_type="setting",
details={"categories_exported": list(export_data.keys())}
)
return {
"settings": export_data,
"exported_at": datetime.utcnow().isoformat(),
"exported_by": current_user['username']
}
@router.post("/import")
async def import_settings(
settings_data: Dict[str, Dict[str, Dict[str, Any]]],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Import settings from JSON"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:settings:admin")
imported_count = 0
errors = []
for category, settings in settings_data.items():
if category not in SETTINGS_STORE:
errors.append(f"Unknown category: {category}")
continue
for key, setting_data in settings.items():
if key not in SETTINGS_STORE[category]:
errors.append(f"Unknown setting: {category}.{key}")
continue
try:
# Validate and import
expected_type = SETTINGS_STORE[category][key]["type"]
new_value = setting_data.get("value")
# Basic type validation
if expected_type == "integer" and not isinstance(new_value, int):
errors.append(f"Invalid type for {category}.{key}: expected integer")
continue
elif expected_type == "boolean" and not isinstance(new_value, bool):
errors.append(f"Invalid type for {category}.{key}: expected boolean")
continue
SETTINGS_STORE[category][key]["value"] = new_value
if "description" in setting_data:
SETTINGS_STORE[category][key]["description"] = setting_data["description"]
imported_count += 1
except Exception as e:
errors.append(f"Error importing {category}.{key}: {str(e)}")
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="import_settings",
resource_type="setting",
details={
"imported_count": imported_count,
"errors_count": len(errors),
"errors": errors
}
)
logger.info(f"Settings imported: {imported_count} settings by {current_user['username']}")
return {
"message": f"Import completed. {imported_count} settings imported.",
"imported_count": imported_count,
"errors": errors
}
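A minimal round-trip sketch for the two endpoints above, assuming the router is mounted under /api/v1/settings and the token carries platform:settings:admin (both assumptions, not shown in this commit):

import httpx

BASE = "http://localhost:8000/api/v1/settings"  # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder credential

# /export returns {"settings": {...}, "exported_at": ..., "exported_by": ...};
# the "settings" mapping is exactly the body shape /import expects.
exported = httpx.post(f"{BASE}/export", headers=HEADERS).json()
result = httpx.post(f"{BASE}/import", json=exported["settings"], headers=HEADERS).json()
print(result["imported_count"], result["errors"])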

backend/app/api/v1/tee.py Normal file

@@ -0,0 +1,334 @@
"""
TEE (Trusted Execution Environment) API endpoints
Handles Privatemode.ai TEE integration endpoints
"""
import logging
from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from app.services.tee_service import tee_service
from app.services.api_key_auth import get_current_api_key_user
from app.models.user import User
from app.models.api_key import APIKey
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tee", tags=["tee"])
security = HTTPBearer()
class AttestationRequest(BaseModel):
"""Request model for attestation"""
nonce: Optional[str] = Field(None, description="Optional nonce for attestation")
class AttestationVerificationRequest(BaseModel):
"""Request model for attestation verification"""
report: str = Field(..., description="Attestation report")
signature: str = Field(..., description="Attestation signature")
certificate_chain: str = Field(..., description="Certificate chain")
nonce: Optional[str] = Field(None, description="Optional nonce")
class SecureSessionRequest(BaseModel):
"""Request model for secure session creation"""
capabilities: Optional[list] = Field(
default=["confidential_inference", "secure_memory", "attestation"],
description="Requested TEE capabilities"
)
@router.get("/health")
async def get_tee_health():
"""
Get TEE environment health status
Returns comprehensive health information about the TEE environment
including capabilities, status, and availability.
"""
try:
health_data = await tee_service.health_check()
return {
"success": True,
"data": health_data
}
except Exception as e:
logger.error(f"TEE health check failed: {e}")
raise HTTPException(
status_code=500,
detail="Failed to get TEE health status"
)
@router.get("/capabilities")
async def get_tee_capabilities(
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Get TEE environment capabilities
Returns detailed information about TEE capabilities including
supported features, encryption algorithms, and security properties.
Requires authentication.
"""
try:
user, api_key = current_user
capabilities = await tee_service.get_tee_capabilities()
return {
"success": True,
"data": capabilities
}
except Exception as e:
logger.error(f"Failed to get TEE capabilities: {e}")
raise HTTPException(
status_code=500,
detail="Failed to get TEE capabilities"
)
@router.post("/attestation")
async def get_attestation(
request: AttestationRequest,
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Get TEE attestation report
Generates a cryptographic attestation report that proves the integrity
and authenticity of the TEE environment. The report can be used to
verify that code is running in a genuine TEE.
Requires authentication.
"""
try:
user, api_key = current_user
attestation_data = await tee_service.get_attestation(request.nonce)
return {
"success": True,
"data": attestation_data
}
except Exception as e:
logger.error(f"Failed to get attestation: {e}")
raise HTTPException(
status_code=500,
detail="Failed to get TEE attestation"
)
@router.post("/attestation/verify")
async def verify_attestation(
request: AttestationVerificationRequest,
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Verify TEE attestation report
Verifies the authenticity and integrity of a TEE attestation report.
This includes validating the certificate chain, signature, and
measurements against known good values.
Requires authentication.
"""
try:
user, api_key = current_user
attestation_data = {
"report": request.report,
"signature": request.signature,
"certificate_chain": request.certificate_chain,
"nonce": request.nonce
}
verification_result = await tee_service.verify_attestation(attestation_data)
return {
"success": True,
"data": verification_result
}
except Exception as e:
logger.error(f"Failed to verify attestation: {e}")
raise HTTPException(
status_code=500,
detail="Failed to verify TEE attestation"
)
@router.post("/session")
async def create_secure_session(
request: SecureSessionRequest,
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Create a secure TEE session
Creates a secure session within the TEE environment with requested
capabilities. The session provides isolated execution context with
enhanced security properties.
Requires authentication.
"""
try:
user, api_key = current_user
session_data = await tee_service.create_secure_session(
user_id=str(user.id),
api_key_id=api_key.id
)
return {
"success": True,
"data": session_data
}
except Exception as e:
logger.error(f"Failed to create secure session: {e}")
raise HTTPException(
status_code=500,
detail="Failed to create TEE secure session"
)
@router.get("/metrics")
async def get_privacy_metrics(
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Get privacy and security metrics
Returns comprehensive metrics about TEE usage, privacy protection,
and security status including request counts, data encrypted,
and performance statistics.
Requires authentication.
"""
try:
user, api_key = current_user
metrics = await tee_service.get_privacy_metrics()
return {
"success": True,
"data": metrics
}
except Exception as e:
logger.error(f"Failed to get privacy metrics: {e}")
raise HTTPException(
status_code=500,
detail="Failed to get privacy metrics"
)
@router.get("/models")
async def list_tee_models(
current_user: tuple = Depends(get_current_api_key_user)
):
"""
List available TEE models
Returns a list of AI models available through the TEE environment.
These models provide confidential inference capabilities with
enhanced privacy and security properties.
Requires authentication.
"""
try:
user, api_key = current_user
models = await tee_service.list_tee_models()
return {
"success": True,
"data": models,
"count": len(models)
}
except Exception as e:
logger.error(f"Failed to list TEE models: {e}")
raise HTTPException(
status_code=500,
detail="Failed to list TEE models"
)
@router.get("/status")
async def get_tee_status(
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Get comprehensive TEE status
Returns combined status information including health, capabilities,
and metrics for a complete overview of the TEE environment.
Requires authentication.
"""
try:
user, api_key = current_user
# Get all status information
health_data = await tee_service.health_check()
capabilities = await tee_service.get_tee_capabilities()
metrics = await tee_service.get_privacy_metrics()
models = await tee_service.list_tee_models()
status_data = {
"health": health_data,
"capabilities": capabilities,
"metrics": metrics,
"models": {
"available": len(models),
"list": models
},
"summary": {
"tee_enabled": health_data.get("tee_enabled", False),
"secure_inference_available": len(models) > 0,
"attestation_available": health_data.get("attestation_available", False),
"privacy_score": metrics.get("privacy_score", 0)
}
}
return {
"success": True,
"data": status_data
}
except Exception as e:
logger.error(f"Failed to get TEE status: {e}")
raise HTTPException(
status_code=500,
detail="Failed to get TEE status"
)
@router.delete("/cache")
async def clear_attestation_cache(
current_user: tuple = Depends(get_current_api_key_user)
):
"""
Clear attestation cache
Manually clears the attestation cache to force fresh attestation
reports. This can be useful for debugging or when attestation
requirements change.
Requires authentication.
"""
try:
user, api_key = current_user
# Clear the cache
await tee_service.cleanup_expired_cache()
tee_service.attestation_cache.clear()
return {
"success": True,
"message": "Attestation cache cleared successfully"
}
except Exception as e:
logger.error(f"Failed to clear attestation cache: {e}")
raise HTTPException(
status_code=500,
detail="Failed to clear attestation cache"
)
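A minimal client sketch for the attestation round trip above, assuming the router is mounted under /api/v1 (so the prefix resolves to /api/v1/tee) and that the attestation response carries report, signature, and certificate_chain fields matching AttestationVerificationRequest (assumptions, not confirmed by tee_service itself):

import httpx

BASE = "http://localhost:8000/api/v1/tee"  # assumed mount point
HEADERS = {"Authorization": "Bearer ce_<api-key>"}  # placeholder API key

att = httpx.post(f"{BASE}/attestation", json={"nonce": "abc123"}, headers=HEADERS).json()["data"]
verified = httpx.post(
    f"{BASE}/attestation/verify",
    json={
        "report": att["report"],
        "signature": att["signature"],
        "certificate_chain": att["certificate_chain"],
        "nonce": "abc123",
    },
    headers=HEADERS,
).json()
print(verified["data"])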

backend/app/api/v1/users.py Normal file

@@ -0,0 +1,472 @@
"""
User management API endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, EmailStr
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func
from sqlalchemy.orm import selectinload
from app.db.database import get_db
from app.models.user import User
from app.models.api_key import APIKey
from app.models.budget import Budget
from app.core.security import get_current_user, get_password_hash, verify_password
from app.services.permission_manager import require_permission
from app.services.audit_service import log_audit_event
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter()
# Pydantic models
class UserCreate(BaseModel):
username: str
email: EmailStr
full_name: Optional[str] = None
password: str
role: str = "user"
is_active: bool = True
class UserUpdate(BaseModel):
username: Optional[str] = None
email: Optional[EmailStr] = None
full_name: Optional[str] = None
role: Optional[str] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
class UserResponse(BaseModel):
id: str
username: str
email: str
full_name: Optional[str] = None
role: str
is_active: bool
is_verified: bool
is_superuser: bool
created_at: str
updated_at: Optional[str] = None
last_login: Optional[str] = None
model_config = {"from_attributes": True}
class UserListResponse(BaseModel):
users: List[UserResponse]
total: int
page: int
size: int
class PasswordChangeRequest(BaseModel):
current_password: str
new_password: str
class PasswordResetRequest(BaseModel):
new_password: str
# User CRUD endpoints
@router.get("/", response_model=UserListResponse)
async def list_users(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
role: Optional[str] = Query(None),
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""List all users with pagination and filtering"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:read")
# Build query
query = select(User)
# Apply filters
if role:
query = query.where(User.role == role)
if is_active is not None:
query = query.where(User.is_active == is_active)
if search:
query = query.where(
(User.username.ilike(f"%{search}%")) |
(User.email.ilike(f"%{search}%")) |
(User.full_name.ilike(f"%{search}%"))
)
# Get total count with a COUNT aggregate instead of fetching all matching ids
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar_one()
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size)
# Execute query
result = await db.execute(query)
users = result.scalars().all()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="list_users",
resource_type="user",
details={"page": page, "size": size, "filters": {"role": role, "is_active": is_active, "search": search}}
)
return UserListResponse(
users=[UserResponse.model_validate(user) for user in users],
total=total,
page=page,
size=size
)
@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get user by ID"""
# Check permissions (users can view their own profile)
if int(user_id) != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:users:read")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="get_user",
resource_type="user",
resource_id=user_id
)
return UserResponse.model_validate(user)
@router.post("/", response_model=UserResponse)
async def create_user(
user_data: UserCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Create a new user"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:create")
# Check if user already exists
query = select(User).where(
(User.username == user_data.username) | (User.email == user_data.email)
)
result = await db.execute(query)
existing_user = result.scalar_one_or_none()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User with this username or email already exists"
)
# Create user
hashed_password = get_password_hash(user_data.password)
new_user = User(
username=user_data.username,
email=user_data.email,
full_name=user_data.full_name,
hashed_password=hashed_password,
role=user_data.role,
is_active=user_data.is_active
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="create_user",
resource_type="user",
resource_id=str(new_user.id),
details={"username": user_data.username, "email": user_data.email, "role": user_data.role}
)
logger.info(f"User created: {new_user.username} by {current_user['username']}")
return UserResponse.model_validate(new_user)
@router.put("/{user_id}", response_model=UserResponse)
async def update_user(
user_id: str,
user_data: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Update user"""
# Check permissions (users can update their own profile with limited fields)
is_self_update = int(user_id) == current_user['id']
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# For self-updates, restrict what can be changed
if is_self_update:
allowed_fields = {"username", "email", "full_name"}
update_data = user_data.model_dump(exclude_unset=True)
restricted_fields = set(update_data.keys()) - allowed_fields
if restricted_fields:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Cannot update fields: {restricted_fields}"
)
# Store original values for audit
original_values = {
"username": user.username,
"email": user.email,
"role": user.role,
"is_active": user.is_active
}
# Update user fields
update_data = user_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(user, field, value)
await db.commit()
await db.refresh(user)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="update_user",
resource_type="user",
resource_id=user_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(user, k) for k in update_data.keys()}
}
)
logger.info(f"User updated: {user.username} by {current_user['username']}")
return UserResponse.model_validate(user)
@router.delete("/{user_id}")
async def delete_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Delete user (soft delete by deactivating)"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:delete")
# Prevent self-deletion
if int(user_id) == current_user['id']:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete your own account"
)
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# Soft delete by deactivating
user.is_active = False
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="delete_user",
resource_type="user",
resource_id=user_id,
details={"username": user.username, "email": user.email}
)
logger.info(f"User deleted: {user.username} by {current_user['username']}")
return {"message": "User deleted successfully"}
@router.post("/{user_id}/change-password")
async def change_password(
user_id: str,
password_data: PasswordChangeRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Change user password"""
# Users can only change their own password, or admins can change any password
is_self_update = int(user_id) == current_user['id']
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# For self-updates, verify current password
if is_self_update:
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect"
)
# Update password
user.hashed_password = get_password_hash(password_data.new_password)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="change_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
)
logger.info(f"Password changed for user: {user.username} by {current_user['username']}")
return {"message": "Password changed successfully"}
@router.post("/{user_id}/reset-password")
async def reset_password(
user_id: str,
password_data: PasswordResetRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Reset user password (admin only)"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# Reset password
user.hashed_password = get_password_hash(password_data.new_password)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
action="reset_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
)
logger.info(f"Password reset for user: {user.username} by {current_user['username']}")
return {"message": "Password reset successfully"}
@router.get("/{user_id}/api-keys", response_model=List[dict])
async def get_user_api_keys(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Get API keys for a user"""
# Check permissions (users can view their own API keys)
is_self_request = int(user_id) == current_user['id']
if not is_self_request:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
# Get API keys
query = select(APIKey).where(APIKey.user_id == int(user_id))
result = await db.execute(query)
api_keys = result.scalars().all()
# Return safe representation (no key values)
return [
{
"id": str(api_key.id),
"name": api_key.name,
"key_prefix": api_key.key_prefix,
"scopes": api_key.scopes,
"is_active": api_key.is_active,
"created_at": api_key.created_at.isoformat(),
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None
}
for api_key in api_keys
]
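A short client sketch for the API key listing above, assuming the router is mounted under /api/v1/users (assumption) and the caller requests their own keys:

import httpx

BASE = "http://localhost:8000/api/v1/users"  # assumed mount point
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credential

keys = httpx.get(f"{BASE}/42/api-keys", headers=HEADERS).json()
for key in keys:
    print(key["name"], key["key_prefix"], key["is_active"])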


@@ -0,0 +1,3 @@
"""
Core package
"""

backend/app/core/config.py Normal file

@@ -0,0 +1,130 @@
"""
Configuration settings for the application
"""
import os
from typing import List, Optional, Union
from pydantic import field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings"""
# Application
APP_NAME: str = "Shifra"
APP_DEBUG: bool = False
APP_LOG_LEVEL: str = "INFO"
APP_HOST: str = "0.0.0.0"
APP_PORT: int = 8000
# Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = False # Set to True to log prompts and context sent to LLM
# Database
DATABASE_URL: str = "postgresql://empire_user:empire_pass@localhost:5432/empire_db"
# Redis
REDIS_URL: str = "redis://localhost:6379"
# Security
JWT_SECRET: str = "your-super-secret-jwt-key-here"
JWT_ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 # 7 days
SESSION_EXPIRE_MINUTES: int = 60 * 24 # 24 hours
API_KEY_PREFIX: str = "ce_"
# Admin user provisioning
ADMIN_USER: str = "admin"
ADMIN_PASSWORD: str = "admin123"
ADMIN_EMAIL: Optional[str] = None
# CORS
CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
# LiteLLM
LITELLM_BASE_URL: str = "http://localhost:4000"
LITELLM_MASTER_KEY: str = "empire-master-key"
# API Keys for LLM providers
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
PRIVATEMODE_API_KEY: Optional[str] = None
# Qdrant
QDRANT_HOST: str = "localhost"
QDRANT_PORT: int = 6333
QDRANT_API_KEY: Optional[str] = None
# API & Security Settings
API_SECURITY_ENABLED: bool = True
API_THREAT_DETECTION_ENABLED: bool = True
API_IP_REPUTATION_ENABLED: bool = True
API_ANOMALY_DETECTION_ENABLED: bool = True
# Rate Limiting Configuration
API_RATE_LIMITING_ENABLED: bool = True
# Authenticated users (JWT token)
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = 300
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = 5000
# API key users (programmatic access)
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = 1000
API_RATE_LIMIT_API_KEY_PER_HOUR: int = 20000
# Premium/Enterprise API keys
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = 5000
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = 100000
# Security Thresholds
API_SECURITY_RISK_THRESHOLD: float = 0.8 # Block requests above this risk score
API_SECURITY_WARNING_THRESHOLD: float = 0.6 # Log warnings above this threshold
API_SECURITY_ANOMALY_THRESHOLD: float = 0.7 # Flag anomalies above this threshold
# Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = 10 * 1024 * 1024 # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = 50 * 1024 * 1024 # 50MB for premium
# IP Security
API_BLOCKED_IPS: List[str] = [] # IPs to always block
API_ALLOWED_IPS: List[str] = [] # IPs to always allow (empty = allow all)
API_IP_REPUTATION_CACHE_TTL: int = 3600 # 1 hour
# Security Headers
API_SECURITY_HEADERS_ENABLED: bool = True
API_CSP_HEADER: str = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'"
# Monitoring
PROMETHEUS_ENABLED: bool = True
PROMETHEUS_PORT: int = 9090
# File uploads
MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10MB
# Module configuration
MODULES_CONFIG_PATH: str = "config/modules.yaml"
# Logging
LOG_FORMAT: str = "json"
LOG_LEVEL: str = "INFO"
@field_validator("CORS_ORIGINS", mode="before")
@classmethod
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
elif isinstance(v, (list, str)):
return v
raise ValueError(v)
model_config = {
"env_file": ".env",
"case_sensitive": True
}
# Global settings instance
settings = Settings()
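A minimal sketch of the CORS_ORIGINS normalization above: the mode="before" validator accepts a plain comma-separated string (as one might write in .env) and splits it, while list values pass through unchanged:

s = Settings(CORS_ORIGINS="http://localhost:3000, https://app.example.com")
print(s.CORS_ORIGINS)  # ['http://localhost:3000', 'https://app.example.com']
print(Settings().CORS_ORIGINS)  # defaults: ['http://localhost:3000', 'http://localhost:8000']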

backend/app/core/logging.py Normal file

@@ -0,0 +1,153 @@
"""
Logging configuration
"""
import logging
import sys
from typing import Any, Dict
import structlog
from structlog.stdlib import LoggerFactory
from app.core.config import settings
def setup_logging() -> None:
"""Setup structured logging"""
# Configure structlog
structlog.configure(
processors=[
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer() if settings.LOG_FORMAT == "json" else structlog.dev.ConsoleRenderer(),
],
context_class=dict,
logger_factory=LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
# Configure standard logging
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=getattr(logging, settings.LOG_LEVEL.upper()),
)
# Set specific loggers
logging.getLogger("uvicorn").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
"""Get a structured logger"""
return structlog.get_logger(name)
# Module-level context variables: set by request middleware, read by the filter.
# Creating them inside filter() would make fresh variables on every call, so the
# defaults would always be returned.
from contextvars import ContextVar
request_id_var: ContextVar[str] = ContextVar("request_id", default="")
user_id_var: ContextVar[str] = ContextVar("user_id", default="")
class RequestContextFilter(logging.Filter):
"""Add request context to log records"""
def filter(self, record: logging.LogRecord) -> bool:
# Attach the current request context (if any) to the log record
record.request_id = request_id_var.get()
record.user_id = user_id_var.get()
return True
def log_request(
method: str,
path: str,
status_code: int,
processing_time: float,
user_id: str = None,
request_id: str = None,
**kwargs: Any,
) -> None:
"""Log HTTP request"""
logger = get_logger("api.request")
log_data = {
"method": method,
"path": path,
"status_code": status_code,
"processing_time": processing_time,
"user_id": user_id,
"request_id": request_id,
**kwargs,
}
if status_code >= 500:
logger.error("Request failed", **log_data)
elif status_code >= 400:
logger.warning("Request error", **log_data)
else:
logger.info("Request completed", **log_data)
def log_security_event(
event_type: str,
user_id: str = None,
ip_address: str = None,
details: Dict[str, Any] = None,
**kwargs: Any,
) -> None:
"""Log security event"""
logger = get_logger("security")
log_data = {
"event_type": event_type,
"user_id": user_id,
"ip_address": ip_address,
"details": details or {},
**kwargs,
}
logger.warning("Security event", **log_data)
def log_module_event(
module_id: str,
event_type: str,
details: Dict[str, Any] = None,
**kwargs: Any,
) -> None:
"""Log module event"""
logger = get_logger("module")
log_data = {
"module_id": module_id,
"event_type": event_type,
"details": details or {},
**kwargs,
}
logger.info("Module event", **log_data)
def log_api_request(
endpoint: str,
params: Dict[str, Any] = None,
**kwargs: Any,
) -> None:
"""Log API request for modules endpoints"""
logger = get_logger("api.module")
log_data = {
"endpoint": endpoint,
"params": params or {},
**kwargs,
}
logger.info("API request", **log_data)


@@ -0,0 +1,333 @@
"""
Security utilities for authentication and authorization
"""
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from uuid import UUID
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.db.database import get_db
from app.utils.exceptions import AuthenticationError, AuthorizationError
logger = logging.getLogger(__name__)
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT token handling
security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash"""
return pwd_context.hash(password)
def get_api_key_hash(api_key: str) -> str:
"""Generate API key hash"""
return pwd_context.hash(api_key)
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT access token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def create_refresh_token(data: Dict[str, Any]) -> str:
"""Create JWT refresh token"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def verify_token(token: str) -> Dict[str, Any]:
"""Verify JWT token and return payload"""
try:
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
return payload
except JWTError as e:
logger.warning(f"Token verification failed: {e}")
raise AuthenticationError("Invalid token")
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db)
) -> Dict[str, Any]:
"""Get current user from JWT token"""
try:
payload = verify_token(credentials.credentials)
user_id: str = payload.get("sub")
if user_id is None:
raise AuthenticationError("Invalid token payload")
# Load user from database
from app.models.user import User
from sqlalchemy import select
# Query user from database
stmt = select(User).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
# If user doesn't exist in DB but token is valid, create basic user info from token
return {
"id": int(user_id),
"email": payload.get("email"),
"is_superuser": payload.get("is_superuser", False),
"role": payload.get("role", "user"),
"is_active": True,
"permissions": [] # Default to empty list for permissions
}
# Update last login
user.update_last_login()
await db.commit()
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role string to list for permission calculation
user_roles = [user.role] if user.role else []
# For super admin users, use only role-based permissions and ignore custom permissions:
# custom permissions might contain legacy formats like ['*'] that don't work with the new system
custom_permissions = []
if not user.is_superuser and user.permissions and isinstance(user.permissions, list):
# Only use custom permissions for non-superuser accounts
custom_permissions = user.permissions
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles,
custom_permissions=custom_permissions
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"is_superuser": user.is_superuser,
"is_active": user.is_active,
"role": user.role,
"permissions": effective_permissions, # Use calculated permissions
"user_obj": user # Include full user object for other operations
}
except Exception as e:
logger.error(f"Authentication error: {e}")
raise AuthenticationError("Could not validate credentials")
async def get_current_active_user(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get current active user"""
# Check if user is active in database
if not current_user.get("is_active", False):
raise AuthenticationError("User account is inactive")
return current_user
async def get_current_superuser(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get current superuser"""
if not current_user.get("is_superuser"):
raise AuthorizationError("Insufficient privileges")
return current_user
def generate_api_key() -> str:
"""Generate a new API key"""
import secrets
import string
# Generate random string
alphabet = string.ascii_letters + string.digits
api_key = ''.join(secrets.choice(alphabet) for _ in range(32))
return f"{settings.API_KEY_PREFIX}{api_key}"
def hash_api_key(api_key: str) -> str:
"""Hash API key for storage"""
return get_password_hash(api_key)
def verify_api_key(api_key: str, hashed_key: str) -> bool:
"""Verify API key against hash"""
return verify_password(api_key, hashed_key)
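# Illustrative key lifecycle (not part of the module): only the bcrypt hash and
# a short prefix are stored; the raw key is shown to the caller once.
#
#     raw_key = generate_api_key()      # "ce_" + 32 random characters
#     key_prefix = raw_key[:8]          # persisted for indexed lookup
#     key_hash = hash_api_key(raw_key)  # persisted instead of the raw key
#     assert verify_api_key(raw_key, key_hash)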
async def get_api_key_user(
request: Request,
db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Get user from API key"""
api_key = request.headers.get("X-API-Key")
if not api_key:
return None
# Implement API key lookup in database
from app.models.api_key import APIKey
from app.models.user import User
from sqlalchemy import select
try:
# Extract key prefix for lookup
if len(api_key) < 8:
return None
key_prefix = api_key[:8]
# Query API key from database
stmt = select(APIKey).join(User).where(
APIKey.key_prefix == key_prefix,
APIKey.is_active == True,
User.is_active == True
)
result = await db.execute(stmt)
db_api_key = result.scalar_one_or_none()
if not db_api_key:
return None
# Verify the API key hash
if not verify_api_key(api_key, db_api_key.key_hash):
return None
# Check if key is valid (not expired)
if not db_api_key.is_valid():
return None
# Update last used timestamp
db_api_key.last_used_at = datetime.utcnow()
await db.commit()
# Load associated user
user_stmt = select(User).where(User.id == db_api_key.user_id)
user_result = await db.execute(user_stmt)
user = user_result.scalar_one_or_none()
if not user or not user.is_active:
return None
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role string to list for permission calculation
user_roles = [user.role] if user.role else []
# Start from the API key's own permissions (empty if none were set)
api_key_permissions = db_api_key.permissions if db_api_key.permissions else []
# Merge in the user's custom list permissions; copy first so extend() does not
# mutate the ORM-backed permissions list on the APIKey row
custom_permissions = list(api_key_permissions)
if user.permissions and isinstance(user.permissions, list):
custom_permissions.extend(user.permissions)
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles,
custom_permissions=custom_permissions
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"is_superuser": user.is_superuser,
"is_active": user.is_active,
"role": user.role,
"permissions": effective_permissions,
"api_key": db_api_key,
"user_obj": user,
"auth_type": "api_key"
}
except Exception as e:
logger.error(f"API key lookup error: {e}")
return None
class RequiresPermission:
"""Dependency class for permission checking"""
def __init__(self, permission: str):
self.permission = permission
def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
# Implement permission checking
# Check if user is superuser (has all permissions)
if current_user.get("is_superuser", False):
return current_user
# Check role-based permissions
role = current_user.get("role", "user")
role_permissions = {
"user": ["read_own", "create_own", "update_own"],
"admin": ["read_all", "create_all", "update_all", "delete_own"],
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
}
if role in role_permissions and self.permission in role_permissions[role]:
return current_user
# Check custom permissions (a list of permission strings)
user_permissions = current_user.get("permissions", [])
if self.permission in user_permissions:
return current_user
# If user has access to full user object, use the model's has_permission method
user_obj = current_user.get("user_obj")
if user_obj and hasattr(user_obj, "has_permission"):
if user_obj.has_permission(self.permission):
return current_user
raise AuthorizationError(f"Permission '{self.permission}' required")
class RequiresRole:
"""Dependency class for role checking"""
def __init__(self, role: str):
self.role = role
def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
# Implement role checking
# Superusers have access to everything
if current_user.get("is_superuser", False):
return current_user
user_role = current_user.get("role", "user")
# Define role hierarchy
role_hierarchy = {
"user": 1,
"admin": 2,
"super_admin": 3
}
required_level = role_hierarchy.get(self.role, 0)
user_level = role_hierarchy.get(user_role, 0)
if user_level >= required_level:
return current_user
raise AuthorizationError(f"Role '{self.role}' required, but user has role '{user_role}'")


@@ -0,0 +1,744 @@
"""
Core threat detection and security analysis for the platform
"""
import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any, Union
from urllib.parse import unquote
from fastapi import Request
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class ThreatLevel(Enum):
"""Threat severity levels"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class AuthLevel(Enum):
"""Authentication levels for rate limiting"""
AUTHENTICATED = "authenticated"
API_KEY = "api_key"
PREMIUM = "premium"
@dataclass
class SecurityThreat:
"""Security threat detection result"""
threat_type: str
level: ThreatLevel
confidence: float
description: str
source_ip: str
user_agent: Optional[str] = None
request_path: Optional[str] = None
payload: Optional[str] = None
timestamp: datetime = field(default_factory=datetime.utcnow)
mitigation: Optional[str] = None
@dataclass
class SecurityAnalysis:
"""Comprehensive security analysis result"""
is_threat: bool
threats: List[SecurityThreat]
risk_score: float
recommendations: List[str]
auth_level: AuthLevel
rate_limit_exceeded: bool
should_block: bool
timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class RateLimitInfo:
"""Rate limiting information"""
auth_level: AuthLevel
requests_per_minute: int
requests_per_hour: int
minute_limit: int
hour_limit: int
exceeded: bool
@dataclass
class AnomalyDetection:
"""Anomaly detection result"""
is_anomaly: bool
anomaly_type: str
severity: float
details: Dict[str, Any]
baseline_value: Optional[float] = None
current_value: Optional[float] = None
class ThreatDetectionService:
"""Core threat detection and security analysis service"""
def __init__(self):
self.name = "threat_detection"
# Statistics
self.stats = {
'total_requests_analyzed': 0,
'threats_detected': 0,
'threats_blocked': 0,
'anomalies_detected': 0,
'rate_limits_exceeded': 0,
'total_analysis_time': 0,
'threat_types': defaultdict(int),
'threat_levels': defaultdict(int),
'attacking_ips': defaultdict(int)
}
# Threat detection patterns
self.sql_injection_patterns = [
r"(\bunion\b.*\bselect\b)",
r"(\bselect\b.*\bfrom\b)",
r"(\binsert\b.*\binto\b)",
r"(\bupdate\b.*\bset\b)",
r"(\bdelete\b.*\bfrom\b)",
r"(\bdrop\b.*\btable\b)",
r"(\bor\b.*\b1\s*=\s*1\b)",
r"(\band\b.*\b1\s*=\s*1\b)",
r"(\bexec\b.*\bxp_\w+)",
r"(\bsp_\w+)",
r"(\bsleep\b\s*\(\s*\d+\s*\))",
r"(\bwaitfor\b.*\bdelay\b)",
r"(\bbenchmark\b\s*\(\s*\d+)",
r"(\bload_file\b\s*\()",
r"(\binto\b.*\boutfile\b)"
]
self.xss_patterns = [
r"<script[^>]*>.*?</script>",
r"<iframe[^>]*>.*?</iframe>",
r"<object[^>]*>.*?</object>",
r"<embed[^>]*>.*?</embed>",
r"<link[^>]*>",
r"<meta[^>]*>",
r"javascript:",
r"vbscript:",
r"on\w+\s*=",
r"style\s*=.*expression",
r"style\s*=.*javascript"
]
self.path_traversal_patterns = [
r"\.\.\/",
r"\.\.\\",
r"%2e%2e%2f",
r"%2e%2e%5c",
r"..%2f",
r"..%5c",
r"%252e%252e%252f",
r"%252e%252e%255c"
]
self.command_injection_patterns = [
r";\s*cat\s+",
r";\s*ls\s+",
r";\s*pwd\s*",
r";\s*whoami\s*",
r";\s*id\s*",
r";\s*uname\s*",
r";\s*ps\s+",
r";\s*netstat\s+",
r";\s*wget\s+",
r";\s*curl\s+",
r"\|\s*cat\s+",
r"\|\s*ls\s+",
r"&&\s*cat\s+",
r"&&\s*ls\s+"
]
self.suspicious_ua_patterns = [
r"sqlmap",
r"nikto",
r"nmap",
r"masscan",
r"zap",
r"burp",
r"w3af",
r"acunetix",
r"nessus",
r"openvas",
r"metasploit"
]
# Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked).
# The deques are unbounded: stale timestamps are pruned in check_rate_limit, and a
# maxlen smaller than the configured limit would cap the observed count and
# silently disable enforcement for higher limits.
self.rate_limits = {
AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(), 'hour': deque()}),
AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(), 'hour': deque()}),
AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(), 'hour': deque()})
}
# Anomaly detection
self.request_history = deque(maxlen=1000)
self.ip_history = defaultdict(lambda: deque(maxlen=100))
self.endpoint_history = defaultdict(lambda: deque(maxlen=100))
# Blocked and allowed IPs
self.blocked_ips = set(settings.API_BLOCKED_IPS)
self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None
# IP reputation cache
self.ip_reputation_cache = {}
self.cache_expiry = {}
# Compile patterns for performance
self._compile_patterns()
logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")
def _compile_patterns(self):
"""Compile regex patterns for better performance"""
try:
self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
except re.error as e:
logger.error(f"Failed to compile security patterns: {e}")
# Fallback to empty lists to prevent crashes
self.compiled_sql_patterns = []
self.compiled_xss_patterns = []
self.compiled_path_patterns = []
self.compiled_cmd_patterns = []
self.compiled_ua_patterns = []
def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
"""Determine authentication level for rate limiting"""
# Check if request has API key authentication
if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
api_key = request.state.api_key_context.get('api_key')
if api_key and hasattr(api_key, 'tier'):
# Check for premium tier
if api_key.tier in ['premium', 'enterprise']:
return AuthLevel.PREMIUM
return AuthLevel.API_KEY
# Check for JWT authentication
if user_context or hasattr(request.state, 'user'):
return AuthLevel.AUTHENTICATED
# Check Authorization header for API key
auth_header = request.headers.get("Authorization", "")
api_key_header = request.headers.get("X-API-Key", "")
if auth_header.startswith("Bearer ") or api_key_header:
return AuthLevel.API_KEY
# Default to authenticated since unauthenticated requests are blocked at middleware
return AuthLevel.AUTHENTICATED
def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
"""Get rate limits for authentication level"""
if not settings.API_RATE_LIMITING_ENABLED:
return float('inf'), float('inf')
if auth_level == AuthLevel.AUTHENTICATED:
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
elif auth_level == AuthLevel.API_KEY:
return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
elif auth_level == AuthLevel.PREMIUM:
return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
else:
# Fallback to authenticated limits
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
"""Check if request exceeds rate limits"""
minute_limit, hour_limit = self.get_rate_limits(auth_level)
current_time = time.time()
# Get or create tracking for this auth level
if auth_level not in self.rate_limits:
# This shouldn't happen, but handle gracefully
return RateLimitInfo(
auth_level=auth_level,
requests_per_minute=0,
requests_per_hour=0,
minute_limit=minute_limit,
hour_limit=hour_limit,
exceeded=False
)
ip_limits = self.rate_limits[auth_level][client_ip]
# Clean old entries
minute_ago = current_time - 60
hour_ago = current_time - 3600
while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
ip_limits['minute'].popleft()
while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
ip_limits['hour'].popleft()
# Check current counts
requests_per_minute = len(ip_limits['minute'])
requests_per_hour = len(ip_limits['hour'])
# Check if limits exceeded
exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)
# Add current request to tracking
if not exceeded:
ip_limits['minute'].append(current_time)
ip_limits['hour'].append(current_time)
return RateLimitInfo(
auth_level=auth_level,
requests_per_minute=requests_per_minute,
requests_per_hour=requests_per_hour,
minute_limit=minute_limit,
hour_limit=hour_limit,
exceeded=exceeded
)
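# Worked example (sketch): with the default API-key limits (1000/min, 20000/hr),
# the 1000th request from one IP inside a rolling minute is still admitted and
# recorded; the 1001st sees 1000 >= 1000, so exceeded=True and it is not recorded.
#
#     svc = ThreatDetectionService()
#     for _ in range(1000):
#         info = svc.check_rate_limit("10.0.0.1", AuthLevel.API_KEY)
#     assert not info.exceeded
#     assert svc.check_rate_limit("10.0.0.1", AuthLevel.API_KEY).exceeded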
async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
"""Perform comprehensive security analysis on a request"""
start_time = time.time()
try:
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
path = str(request.url.path)
method = request.method
# Determine authentication level
auth_level = self.determine_auth_level(request, user_context)
# Check IP allowlist/blocklist first
if self.allowed_ips and client_ip not in self.allowed_ips:
threat = SecurityThreat(
threat_type="ip_not_allowed",
level=ThreatLevel.HIGH,
confidence=1.0,
description=f"IP {client_ip} not in allowlist",
source_ip=client_ip,
mitigation="Add IP to allowlist or remove IP restrictions"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=1.0,
recommendations=["Block request immediately"],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=True
)
if client_ip in self.blocked_ips:
threat = SecurityThreat(
threat_type="ip_blocked",
level=ThreatLevel.CRITICAL,
confidence=1.0,
description=f"IP {client_ip} is blocked",
source_ip=client_ip,
mitigation="Remove IP from blocklist if legitimate"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=1.0,
recommendations=["Block request immediately"],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=True
)
# Check rate limiting
rate_limit_info = self.check_rate_limit(client_ip, auth_level)
if rate_limit_info.exceeded:
self.stats['rate_limits_exceeded'] += 1
threat = SecurityThreat(
threat_type="rate_limit_exceeded",
level=ThreatLevel.MEDIUM,
confidence=0.9,
description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
source_ip=client_ip,
mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=0.7,
recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
auth_level=auth_level,
rate_limit_exceeded=True,
should_block=True
)
# Skip threat detection if disabled
if not settings.API_THREAT_DETECTION_ENABLED:
return SecurityAnalysis(
is_threat=False,
threats=[],
risk_score=0.0,
recommendations=[],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=False
)
# Collect request data for threat analysis
query_params = str(request.query_params)
headers = dict(request.headers)
# Try to get body content safely
body_content = ""
try:
if hasattr(request, '_body') and request._body:
body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
except Exception:
pass  # body may not be cached on the request; skip body inspection rather than consume the stream
threats = []
# Analyze for various threats
threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))
# Anomaly detection if enabled
if settings.API_ANOMALY_DETECTION_ENABLED:
anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
threat = SecurityThreat(
threat_type=f"anomaly_{anomaly.anomaly_type}",
level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
confidence=anomaly.severity,
description=f"Anomalous behavior detected: {anomaly.details}",
source_ip=client_ip,
user_agent=user_agent,
request_path=path
)
threats.append(threat)
# Calculate risk score
risk_score = self._calculate_risk_score(threats)
# Determine if request should be blocked
should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD
# Generate recommendations
recommendations = self._generate_recommendations(threats, risk_score, auth_level)
# Update statistics
self._update_stats(threats, time.time() - start_time)
return SecurityAnalysis(
is_threat=len(threats) > 0,
threats=threats,
risk_score=risk_score,
recommendations=recommendations,
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=should_block
)
except Exception as e:
logger.error(f"Error in threat analysis: {e}")
return SecurityAnalysis(
is_threat=False,
threats=[],
risk_score=0.0,
recommendations=["Error occurred during security analysis"],
auth_level=AuthLevel.AUTHENTICATED,
rate_limit_exceeded=False,
should_block=False
)
async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
"""Detect SQL injection attempts"""
threats = []
content_to_check = f"{query_params} {body_content} {path}".lower()
for pattern in self.compiled_sql_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="sql_injection",
level=ThreatLevel.HIGH,
confidence=0.85,
description="Potential SQL injection attempt detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request, sanitize input, use parameterized queries"
)
threats.append(threat)
break # Don't duplicate for multiple patterns
return threats
async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
"""Detect XSS attempts"""
threats = []
content_to_check = f"{query_params} {body_content}".lower()
# Check headers for XSS
for header_value in headers.values():
content_to_check += f" {header_value}".lower()
for pattern in self.compiled_xss_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="xss",
level=ThreatLevel.HIGH,
confidence=0.80,
description="Potential XSS attack detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request, sanitize input, implement CSP headers"
)
threats.append(threat)
break
return threats
async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
"""Detect path traversal attempts"""
threats = []
content_to_check = f"{path} {query_params}".lower()
decoded_content = unquote(content_to_check)
for pattern in self.compiled_path_patterns:
if pattern.search(content_to_check) or pattern.search(decoded_content):
threat = SecurityThreat(
threat_type="path_traversal",
level=ThreatLevel.HIGH,
confidence=0.90,
description="Path traversal attempt detected",
source_ip=client_ip,
request_path=path,
mitigation="Block request, validate file paths, implement access controls"
)
threats.append(threat)
break
return threats
async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
"""Detect command injection attempts"""
threats = []
content_to_check = f"{query_params} {body_content}".lower()
for pattern in self.compiled_cmd_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="command_injection",
level=ThreatLevel.CRITICAL,
confidence=0.95,
description="Command injection attempt detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request immediately, sanitize input, disable shell execution"
)
threats.append(threat)
break
return threats
async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
"""Detect suspicious patterns in headers and user agent"""
threats = []
# Check for suspicious user agents
ua_lower = user_agent.lower()
for pattern in self.compiled_ua_patterns:
if pattern.search(ua_lower):
threat = SecurityThreat(
threat_type="suspicious_user_agent",
level=ThreatLevel.HIGH,
confidence=0.85,
description=f"Suspicious user agent detected: {pattern.pattern}",
source_ip=client_ip,
user_agent=user_agent,
mitigation="Block request, monitor IP for further activity"
)
threats.append(threat)
break
# Check for suspicious headers
if "x-forwarded-for" in headers and "x-real-ip" in headers:
# Potential header manipulation
threat = SecurityThreat(
threat_type="header_manipulation",
level=ThreatLevel.LOW,
confidence=0.30,
description="Potential IP header manipulation detected",
source_ip=client_ip,
mitigation="Validate proxy headers, implement IP whitelisting"
)
threats.append(threat)
return threats
async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
"""Detect anomalous behavior patterns"""
try:
# Request size anomaly
max_size = settings.API_MAX_REQUEST_BODY_SIZE
if body_size > max_size:
return AnomalyDetection(
is_anomaly=True,
anomaly_type="request_size",
severity=0.8,
details={"body_size": body_size, "threshold": max_size},
current_value=body_size,
baseline_value=max_size // 10
)
# Unusual endpoint access
if path.startswith("/admin") or path.startswith("/api/admin"):
return AnomalyDetection(
is_anomaly=True,
anomaly_type="sensitive_endpoint",
severity=0.6,
details={"path": path, "reason": "admin endpoint access"},
current_value=1.0,
baseline_value=0.0
)
# IP request frequency anomaly
current_time = time.time()
ip_requests = self.ip_history[client_ip]
# Clean old entries (last 5 minutes)
five_minutes_ago = current_time - 300
while ip_requests and ip_requests[0] < five_minutes_ago:
ip_requests.popleft()
ip_requests.append(current_time)
if len(ip_requests) > 100: # More than 100 requests in 5 minutes
return AnomalyDetection(
is_anomaly=True,
anomaly_type="request_frequency",
severity=0.7,
details={"requests_5min": len(ip_requests), "threshold": 100},
current_value=len(ip_requests),
baseline_value=10 # 10 requests baseline
)
return AnomalyDetection(
is_anomaly=False,
anomaly_type="none",
severity=0.0,
details={}
)
except Exception as e:
logger.error(f"Error in anomaly detection: {e}")
return AnomalyDetection(
is_anomaly=False,
anomaly_type="error",
severity=0.0,
details={"error": str(e)}
)
def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
"""Calculate overall risk score based on threats"""
if not threats:
return 0.0
score = 0.0
for threat in threats:
level_multiplier = {
ThreatLevel.LOW: 0.25,
ThreatLevel.MEDIUM: 0.5,
ThreatLevel.HIGH: 0.75,
ThreatLevel.CRITICAL: 1.0
}
score += threat.confidence * level_multiplier.get(threat.level, 0.5)
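# Note: dividing by len(threats) below yields a confidence-weighted mean, so one
# critical threat can be diluted by several low-level hits; max() would be stricter.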
# Normalize to 0-1 range
return min(score / len(threats), 1.0)
def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
"""Generate security recommendations based on analysis"""
recommendations = []
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
recommendations.append("CRITICAL: Block this request immediately")
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
recommendations.append("HIGH: Consider blocking or rate limiting this IP")
elif risk_score > 0.4:
recommendations.append("MEDIUM: Monitor this IP closely")
threat_types = {threat.threat_type for threat in threats}
if "sql_injection" in threat_types:
recommendations.append("Implement parameterized queries and input validation")
if "xss" in threat_types:
recommendations.append("Implement Content Security Policy (CSP) headers")
if "command_injection" in threat_types:
recommendations.append("Disable shell execution and validate all inputs")
if "path_traversal" in threat_types:
recommendations.append("Implement proper file path validation and access controls")
if "rate_limit_exceeded" in threat_types:
recommendations.append(f"Rate limiting active for {auth_level.value} user")
if not recommendations:
recommendations.append("No immediate action required, continue monitoring")
return recommendations
def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
"""Update service statistics"""
self.stats['total_requests_analyzed'] += 1
self.stats['total_analysis_time'] += analysis_time
if threats:
self.stats['threats_detected'] += len(threats)
for threat in threats:
self.stats['threat_types'][threat.threat_type] += 1
self.stats['threat_levels'][threat.level.value] += 1
if threat.source_ip:
self.stats['attacking_ips'][threat.source_ip] += 1
def get_stats(self) -> Dict[str, Any]:
"""Get service statistics"""
avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
if self.stats['total_requests_analyzed'] > 0 else 0)
# Get top attacking IPs
top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]
return {
"total_requests_analyzed": self.stats['total_requests_analyzed'],
"threats_detected": self.stats['threats_detected'],
"threats_blocked": self.stats['threats_blocked'],
"anomalies_detected": self.stats['anomalies_detected'],
"rate_limits_exceeded": self.stats['rate_limits_exceeded'],
"avg_analysis_time": avg_time,
"threat_types": dict(self.stats['threat_types']),
"threat_levels": dict(self.stats['threat_levels']),
"top_attacking_ips": top_ips,
"security_enabled": settings.API_SECURITY_ENABLED,
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
}
# Global threat detection service instance
threat_detection_service = ThreatDetectionService()
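For ad-hoc inspection, the global instance can back a small admin endpoint. A minimal sketch, assuming a hypothetical route path and that admin authentication is enforced elsewhere:

from fastapi import APIRouter

router = APIRouter()

@router.get("/admin/security/stats")  # hypothetical path
async def security_stats():
    # surfaces the counters maintained by _update_stats and check_rate_limit
    return threat_detection_service.get_stats()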

3
backend/app/db/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
Database package
"""

160
backend/app/db/database.py Normal file
View File

@@ -0,0 +1,160 @@
"""
Database connection and session management
"""
import logging
from typing import AsyncGenerator
from sqlalchemy import create_engine, MetaData
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import StaticPool
from app.core.config import settings
logger = logging.getLogger(__name__)
# Create async engine with optimized connection pooling
engine = create_async_engine(
settings.DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://"),
echo=settings.APP_DEBUG,
future=True,
pool_pre_ping=True,
pool_size=50, # Increased from 20 for better concurrency
max_overflow=100, # Increased from 30 for burst capacity
pool_recycle=3600, # Recycle connections every hour
pool_timeout=30, # Max time to get connection from pool
connect_args={
"command_timeout": 5,
"server_settings": {
"application_name": "shifra_backend",
},
},
)
# Create async session factory
async_session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
# Create synchronous engine and session for budget enforcement (optimized)
sync_engine = create_engine(
settings.DATABASE_URL,
echo=settings.APP_DEBUG,
future=True,
pool_pre_ping=True,
pool_size=25, # Increased from 10 for better performance
max_overflow=50, # Increased from 20 for burst capacity
pool_recycle=3600, # Recycle connections every hour
pool_timeout=30, # Max time to get connection from pool
connect_args={
"application_name": "shifra_backend_sync",
},
)
# Create sync session factory
SessionLocal = sessionmaker(
bind=sync_engine,
expire_on_commit=False,
)
# Create base class for models
Base = declarative_base()
# Metadata for migrations
metadata = MetaData()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Get database session"""
async with async_session_factory() as session:
try:
yield session
except Exception as e:
logger.error(f"Database session error: {e}")
await session.rollback()
raise
finally:
await session.close()
async def init_db():
"""Initialize database"""
try:
async with engine.begin() as conn:
# Import all models to ensure they're registered
from app.models.user import User
from app.models.api_key import APIKey
from app.models.usage_tracking import UsageTracking
# Import additional models - these are available
try:
from app.models.budget import Budget
except ImportError:
logger.warning("Budget model not available yet")
try:
from app.models.audit_log import AuditLog
except ImportError:
logger.warning("AuditLog model not available yet")
try:
from app.models.module import Module
except ImportError:
logger.warning("Module model not available yet")
# Create all tables
await conn.run_sync(Base.metadata.create_all)
# Create default admin user if no admin exists
await create_default_admin()
logger.info("Database initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
async def create_default_admin():
"""Create default admin user if none exists"""
from app.models.user import User
from app.core.security import get_password_hash
from app.core.config import settings
from sqlalchemy import select
try:
async with async_session_factory() as session:
# Check if any admin user exists
stmt = select(User).where(User.role == "super_admin")
result = await session.execute(stmt)
existing_admin = result.scalar_one_or_none()
if existing_admin:
logger.info("Admin user already exists")
return
# Create default admin user from environment variables
admin_username = settings.ADMIN_USER
admin_password = settings.ADMIN_PASSWORD
admin_email = settings.ADMIN_EMAIL or f"{admin_username}@example.com"
admin_user = User.create_default_admin(
email=admin_email,
username=admin_username,
password_hash=get_password_hash(admin_password)
)
session.add(admin_user)
await session.commit()
logger.warning("=" * 60)
logger.warning("DEFAULT ADMIN USER CREATED")
logger.warning(f"Email: {admin_email}")
logger.warning(f"Username: {admin_username}")
logger.warning("Password: [Set via ADMIN_PASSWORD environment variable]")
logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
logger.warning("=" * 60)
except Exception as e:
logger.error(f"Failed to create default admin user: {e}")
# Don't raise here as this shouldn't block the application startup
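A minimal usage sketch for the get_db dependency; the route and the lookup are illustrative, not part of this commit:

from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.user import User

router = APIRouter()

@router.get("/users/{user_id}")  # hypothetical route
async def read_user(user_id: int, db: AsyncSession = Depends(get_db)):
    # get_db yields the session, rolls back on error, and closes it afterwards
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one_or_none()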

219
backend/app/main.py Normal file
View File

@@ -0,0 +1,219 @@
"""
Main FastAPI application entry point
"""
import logging
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException
from starlette.middleware.sessions import SessionMiddleware
from app.core.config import settings
from app.core.logging import setup_logging
from app.db.database import init_db
from app.api.v1 import api_router
from app.utils.exceptions import CustomHTTPException
from app.services.module_manager import module_manager
from app.services.metrics import setup_metrics
from app.services.analytics import init_analytics_service
from app.middleware.analytics import setup_analytics_middleware
from app.services.config_manager import init_config_manager
# Setup logging
setup_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Application lifespan handler
"""
logger.info("Starting Confidential Empire platform...")
# Initialize database
await init_db()
# Initialize config manager
await init_config_manager()
# Initialize analytics service
init_analytics_service()
# Initialize module manager with FastAPI app for router registration
await module_manager.initialize(app)
app.state.module_manager = module_manager
# Initialize document processor
from app.services.document_processor import document_processor
await document_processor.start()
app.state.document_processor = document_processor
# Setup metrics
setup_metrics(app)
# Start background audit worker
from app.services.audit_service import start_audit_worker
start_audit_worker()
logger.info("Platform started successfully")
yield
# Cleanup
logger.info("Shutting down platform...")
# Close Redis connection for cached API key service
from app.services.cached_api_key import cached_api_key_service
await cached_api_key_service.close()
# Stop document processor
if hasattr(app.state, 'document_processor'):
await app.state.document_processor.stop()
await module_manager.cleanup()
logger.info("Platform shutdown complete")
# Create FastAPI application
app = FastAPI(
title=settings.APP_NAME,
description="Modular AI Gateway Platform with confidential AI processing",
version="1.0.0",
openapi_url="/api/v1/openapi.json",
docs_url="/api/v1/docs",
redoc_url="/api/v1/redoc",
lifespan=lifespan,
)
# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(
SessionMiddleware,
secret_key=settings.JWT_SECRET,
max_age=settings.SESSION_EXPIRE_MINUTES * 60,
)
# Add analytics middleware
setup_analytics_middleware(app)
# Add security middleware
from app.middleware.security import setup_security_middleware
setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
# Exception handlers
@app.exception_handler(CustomHTTPException)
async def custom_http_exception_handler(request, exc: CustomHTTPException):
return JSONResponse(
status_code=exc.status_code,
content={
"error": exc.error_code,
"message": exc.detail,
"details": exc.details,
},
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content={
"error": "HTTP_ERROR",
"message": exc.detail,
},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc: RequestValidationError):
# Convert validation errors to JSON-serializable format
errors = []
for error in exc.errors():
error_dict = {
"type": error.get("type", ""),
"location": error.get("loc", []),
"message": error.get("msg", ""),
"input": str(error.get("input", "")) if error.get("input") is not None else None
}
errors.append(error_dict)
return JSONResponse(
status_code=422,
content={
"error": "VALIDATION_ERROR",
"message": "Invalid request data",
"details": errors,
},
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc: Exception):
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": "INTERNAL_SERVER_ERROR",
"message": "An unexpected error occurred",
},
)
# Include API routes
app.include_router(api_router, prefix="/api/v1")
# Include OpenAI-compatible routes
from app.api.v1.openai_compat import router as openai_router
app.include_router(openai_router, prefix="/v1", tags=["openai-compat"])
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"app": settings.APP_NAME,
"version": "1.0.0",
}
# Root endpoint
@app.get("/")
async def root():
"""Root endpoint"""
return {
"message": "Confidential Empire - Modular AI Gateway Platform",
"version": "1.0.0",
"docs": "/api/v1/docs",
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.APP_HOST,
port=settings.APP_PORT,
reload=settings.APP_DEBUG,
log_level=settings.APP_LOG_LEVEL.lower(),
)

143
backend/app/middleware/analytics.py Normal file
View File

@@ -0,0 +1,143 @@
"""
Analytics middleware for automatic request tracking
"""
import time
from datetime import datetime
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from contextvars import ContextVar
from app.core.logging import get_logger
from app.services.analytics import RequestEvent
logger = get_logger(__name__)
# Context variable to pass analytics data from endpoints to middleware
analytics_context: ContextVar[dict] = ContextVar('analytics_context', default={})
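# NOTE: under Starlette's BaseHTTPMiddleware the endpoint runs in a separate task, so
# context set via set_analytics_data() may not propagate back to this middleware on
# all Starlette versions; verify against the pinned version if token/cost fields
# arrive empty.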
class AnalyticsMiddleware(BaseHTTPMiddleware):
"""Middleware to automatically track all requests for analytics"""
async def dispatch(self, request: Request, call_next):
# Start timing
start_time = time.time()
# Skip analytics for health checks and static files
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
return await call_next(request)
# Get user info if available from token
user_id = None
api_key_id = None
try:
authorization = request.headers.get("Authorization")
if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1]
# Try to extract user info from token without full validation
# This is a lightweight check for analytics purposes
from app.core.security import verify_token
try:
payload = verify_token(token)
user_id = int(payload.get("sub"))
except Exception:
# Token might be invalid, but we still want to track the request
pass
except Exception:
# Don't let analytics break the request
pass
# Get client IP
client_ip = request.client.host if request.client else None
if not client_ip:
# Check for forwarded headers
client_ip = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
if not client_ip:
client_ip = request.headers.get("X-Real-IP", "unknown")
# Get user agent
user_agent = request.headers.get("User-Agent", "")
# Get request size
request_size = int(request.headers.get("Content-Length", 0))
# Process the request
response = None
error_message = None
try:
response = await call_next(request)
except Exception as e:
logger.error(f"Request failed: {e}")
error_message = str(e)
response = JSONResponse(
status_code=500,
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
)
# Calculate timing
end_time = time.time()
response_time = (end_time - start_time) * 1000 # Convert to milliseconds
# Get response size
response_size = 0
if hasattr(response, 'body'):
response_size = len(response.body) if response.body else 0
# Get analytics data from context (set by endpoints)
context_data = analytics_context.get({})
# Create analytics event
event = RequestEvent(
timestamp=datetime.utcnow(),
method=request.method,
path=request.url.path,
status_code=response.status_code if response else 500,
response_time=response_time,
user_id=user_id,
api_key_id=api_key_id,
ip_address=client_ip,
user_agent=user_agent,
request_size=request_size,
response_size=response_size,
error_message=error_message,
# Token/cost info populated by LLM endpoints via context
model=context_data.get('model'),
request_tokens=context_data.get('request_tokens', 0),
response_tokens=context_data.get('response_tokens', 0),
total_tokens=context_data.get('total_tokens', 0),
cost_cents=context_data.get('cost_cents', 0),
budget_ids=context_data.get('budget_ids', []),
budget_warnings=context_data.get('budget_warnings', [])
)
# Track the event
try:
from app.services.analytics import analytics_service
if analytics_service is not None:
await analytics_service.track_request(event)
else:
logger.warning("Analytics service not initialized, skipping event tracking")
except Exception as e:
logger.error(f"Failed to track analytics event: {e}")
# Don't let analytics failures break the request
return response
def set_analytics_data(**kwargs):
"""Helper function for endpoints to set analytics data"""
current_context = analytics_context.get({})
current_context.update(kwargs)
analytics_context.set(current_context)
def setup_analytics_middleware(app):
"""Add analytics middleware to the FastAPI app"""
app.add_middleware(AnalyticsMiddleware)
logger.info("Analytics middleware configured")

View File

@@ -0,0 +1,313 @@
"""
Rate limiting middleware
"""
import time
from functools import wraps
import redis
from typing import Dict
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class RateLimiter:
"""Rate limiting implementation using Redis"""
def __init__(self):
try:
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
self.redis_client.ping() # Test connection
logger.info("Rate limiter initialized with Redis backend")
except Exception as e:
logger.warning(f"Redis not available for rate limiting: {e}")
self.redis_client = None
# Fall back to in-memory rate limiting
self.memory_store: Dict[str, Dict[str, float]] = {}
async def check_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
identifier: str = "default"
) -> tuple[bool, Dict[str, int]]:
"""
Check if request is within rate limit
Args:
key: Rate limiting key (e.g., IP address, API key)
limit: Maximum number of requests allowed
window_seconds: Time window in seconds
identifier: Additional identifier for the rate limit
Returns:
Tuple of (is_allowed, headers_dict)
"""
full_key = f"rate_limit:{identifier}:{key}"
current_time = int(time.time())
window_start = current_time - window_seconds
if self.redis_client:
return await self._check_redis_rate_limit(
full_key, limit, window_seconds, current_time, window_start
)
else:
return self._check_memory_rate_limit(
full_key, limit, window_seconds, current_time, window_start
)
async def _check_redis_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
current_time: int,
window_start: int
) -> tuple[bool, Dict[str, int]]:
"""Check rate limit using Redis"""
pipe = self.redis_client.pipeline()
# Remove old entries
pipe.zremrangebyscore(key, 0, window_start)
# Count current requests in window
pipe.zcard(key)
# Add current request (unique member so same-second requests are not collapsed)
pipe.zadd(key, {f"{current_time}:{time.time_ns()}": current_time})
# Set expiration
pipe.expire(key, window_seconds + 1)
results = pipe.execute()
current_requests = results[1]
# Calculate remaining requests and reset time
remaining = max(0, limit - current_requests - 1)
reset_time = current_time + window_seconds
headers = {
"X-RateLimit-Limit": limit,
"X-RateLimit-Remaining": remaining,
"X-RateLimit-Reset": reset_time,
"X-RateLimit-Window": window_seconds
}
is_allowed = current_requests < limit
if not is_allowed:
logger.warning(f"Rate limit exceeded for key: {key}")
return is_allowed, headers
def _check_memory_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
current_time: int,
window_start: int
) -> tuple[bool, Dict[str, int]]:
"""Check rate limit using in-memory storage"""
if key not in self.memory_store:
self.memory_store[key] = {}
# Clean old entries
store = self.memory_store[key]
keys_to_remove = [k for k, v in store.items() if v < window_start]
for k in keys_to_remove:
del store[k]
current_requests = len(store)
# Calculate remaining requests and reset time
remaining = max(0, limit - current_requests - 1)
reset_time = current_time + window_seconds
headers = {
"X-RateLimit-Limit": limit,
"X-RateLimit-Remaining": remaining,
"X-RateLimit-Reset": reset_time,
"X-RateLimit-Window": window_seconds
}
is_allowed = current_requests < limit
if is_allowed:
# Add current request (unique key so same-second requests are not collapsed)
store[f"{current_time}:{time.time_ns()}"] = current_time
else:
logger.warning(f"Rate limit exceeded for key: {key}")
return is_allowed, headers
# Global rate limiter instance
rate_limiter = RateLimiter()
async def rate_limit_middleware(request: Request, call_next):
"""
Rate limiting middleware for FastAPI
"""
# Skip rate limiting for health checks and static files
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
response = await call_next(request)
return response
# Get client IP
client_ip = request.client.host if request.client else "unknown"
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
# Check for API key in headers
api_key = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
api_key = auth_header[7:]
elif request.headers.get("X-API-Key"):
api_key = request.headers.get("X-API-Key")
# Determine rate limiting strategy
if api_key:
# API key-based rate limiting
rate_limit_key = f"api_key:{api_key}"
# Get API key limits from database (simplified - would implement proper lookup)
limit_per_minute = 100 # Default limit
limit_per_hour = 1000 # Default limit
# Check per-minute limit
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
else:
# IP-based rate limiting for unauthenticated requests
rate_limit_key = f"ip:{client_ip}"
# More restrictive limits for unauthenticated requests
limit_per_minute = 20
limit_per_hour = 100
# Check per-minute limit
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
# If rate limit exceeded, return 429
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"error": "RATE_LIMIT_EXCEEDED",
"message": "Rate limit exceeded. Please try again later.",
"details": {
"limit": headers["X-RateLimit-Limit"],
"reset_time": headers["X-RateLimit-Reset"]
}
},
headers={k: str(v) for k, v in headers.items()}
)
# Continue with request
response = await call_next(request)
# Add rate limit headers to response
for key, value in headers.items():
response.headers[key] = str(value)
return response
class RateLimitExceeded(HTTPException):
"""Exception raised when rate limit is exceeded"""
def __init__(self, limit: int, reset_time: int):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
)
# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
"""
Decorator to apply rate limiting to specific endpoints
Args:
requests_per_minute: Maximum requests per minute
requests_per_hour: Maximum requests per hour
"""
def decorator(func):
@wraps(func)  # preserve the endpoint's signature for FastAPI's introspection
async def wrapper(*args, **kwargs):
# This would be implemented to work with FastAPI dependencies
# For now, this is a placeholder for endpoint-specific rate limiting
return await func(*args, **kwargs)
return wrapper
return decorator
# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
"""Check rate limit for specific API key and endpoint"""
# This would lookup API key specific limits from database
# For now, using default limits
key = f"api_key:{api_key}:endpoint:{endpoint}"
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=100, window_seconds=60, identifier="endpoint"
)
return is_allowed
async def check_user_rate_limit(user_id: str, action: str) -> bool:
"""Check rate limit for specific user and action"""
key = f"user:{user_id}:action:{action}"
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=50, window_seconds=60, identifier="user_action"
)
return is_allowed
async def apply_burst_protection(key: str) -> bool:
"""Apply burst protection for high-frequency actions"""
# Allow burst of 10 requests in 10 seconds
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=10, window_seconds=10, identifier="burst"
)
return is_allowed
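A sketch of combining these helpers in a route module; the endpoint path and limits are illustrative, not part of this commit:

import time

from fastapi import APIRouter, Request

router = APIRouter()

@router.post("/api/v1/generate")  # hypothetical endpoint
async def generate(request: Request):
    api_key = request.headers.get("X-API-Key", "")
    # burst gate first, then the per-endpoint budget
    if not await apply_burst_protection(f"api_key:{api_key}"):
        raise RateLimitExceeded(limit=10, reset_time=int(time.time()) + 10)
    if not await check_api_key_rate_limit(api_key, "generate"):
        raise RateLimitExceeded(limit=100, reset_time=int(time.time()) + 60)
    return {"status": "accepted"}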

278
backend/app/middleware/security.py Normal file
View File

@@ -0,0 +1,278 @@
"""
Security middleware for request/response processing
"""
import json
import time
from typing import Callable, Optional, Dict, Any
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.config import settings
from app.core.logging import get_logger
from app.core.threat_detection import threat_detection_service, SecurityAnalysis
logger = get_logger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
"""Security middleware for threat detection and request filtering"""
def __init__(self, app, enabled: bool = True):
super().__init__(app)
self.enabled = enabled and settings.API_SECURITY_ENABLED
logger.info(f"SecurityMiddleware initialized, enabled: {self.enabled}")
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""Process request through security analysis"""
if not self.enabled:
# Security disabled, pass through
return await call_next(request)
# Skip security analysis for certain endpoints
if self._should_skip_security(request):
response = await call_next(request)
return self._add_security_headers(response, request)
# Simple authentication check - drop requests without valid auth
if not self._has_valid_auth(request):
return JSONResponse(
content={"error": "Authentication required", "message": "Valid API key or authentication token required"},
status_code=401,
headers={"WWW-Authenticate": "Bearer"}
)
try:
# Get user context if available
user_context = getattr(request.state, 'user', None)
# Perform security analysis
start_time = time.time()
analysis = await threat_detection_service.analyze_request(request, user_context)
analysis_time = time.time() - start_time
# Store analysis in request state for later use
request.state.security_analysis = analysis
# Log security events
if analysis.is_threat:
await self._log_security_event(request, analysis)
# Check if request should be blocked
if analysis.should_block:
threat_detection_service.stats['threats_blocked'] += 1
logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
# Return security block response
return self._create_block_response(analysis)
# Log warnings for medium-risk requests
if analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
logger.warning(f"High-risk request detected from {request.client.host if request.client else 'unknown'}: "
f"risk_score={analysis.risk_score:.3f}, auth_level={analysis.auth_level.value}")
# Continue with request processing
response = await call_next(request)
# Add security headers and metrics
response = self._add_security_headers(response, request)
response = self._add_security_metrics(response, analysis, analysis_time)
return response
except Exception as e:
logger.error(f"Security middleware error: {e}")
# Continue with request on security middleware errors to avoid breaking the app
response = await call_next(request)
return self._add_security_headers(response, request)
def _should_skip_security(self, request: Request) -> bool:
"""Determine if security analysis should be skipped for this request"""
path = request.url.path
# Skip for health checks and static assets
skip_paths = [
"/health",
"/metrics",
"/api/v1/docs",
"/api/v1/openapi.json",
"/api/v1/redoc",
"/favicon.ico"
]
# Skip for static file extensions
static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]
return (
path in skip_paths or
any(path.endswith(ext) for ext in static_extensions) or
path.startswith("/static/")
)
def _has_valid_auth(self, request: Request) -> bool:
"""Check if request has valid authentication"""
# Check Authorization header
auth_header = request.headers.get("Authorization", "")
api_key_header = request.headers.get("X-API-Key", "")
# Has some form of auth token/key
return (
(auth_header.startswith("Bearer ") and len(auth_header) > 7)
or len(api_key_header.strip()) > 0
)
def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
"""Create response for blocked requests"""
# Determine status code based on threat type
status_code = 403 # Forbidden by default
# Rate limiting gets 429
if analysis.rate_limit_exceeded:
status_code = 429
# Critical threats get 403
for threat in analysis.threats:
if threat.threat_type in ["command_injection", "sql_injection"]:
status_code = 403
break
response_data = {
"error": "Security Policy Violation",
"message": "Request blocked due to security policy violation",
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"threat_count": len(analysis.threats),
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
}
# Add rate limiting info if applicable
if analysis.rate_limit_exceeded:
response_data["error"] = "Rate Limit Exceeded"
response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
response_data["retry_after"] = "60" # Suggest retry after 60 seconds
response = JSONResponse(
content=response_data,
status_code=status_code
)
# Add rate limiting headers
if analysis.rate_limit_exceeded:
response.headers["Retry-After"] = "60"
response.headers["X-RateLimit-Limit"] = "See API documentation"
response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
return response
def _add_security_headers(self, response: Response, request: Optional[Request] = None) -> Response:
"""Add security headers to response"""
if not settings.API_SECURITY_HEADERS_ENABLED:
return response
# Standard security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Only add HSTS for HTTPS; X-Forwarded-Proto is a request header, so the scheme
# must be read from the incoming request rather than the response
if request is not None and (
request.url.scheme == "https"
or request.headers.get("x-forwarded-proto") == "https"
):
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
# Content Security Policy
if settings.API_CSP_HEADER:
response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
return response
def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
"""Add security metrics to response headers (for debugging/monitoring)"""
# Only add in debug mode or for admin users
if settings.APP_DEBUG:
response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
response.headers["X-Security-Threats"] = str(len(analysis.threats))
response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
return response
async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
"""Log security events for audit and monitoring"""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
# Create security event log
event_data = {
"timestamp": analysis.timestamp.isoformat(),
"client_ip": client_ip,
"user_agent": user_agent,
"path": str(request.url.path),
"method": request.method,
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"threat_count": len(analysis.threats),
"rate_limit_exceeded": analysis.rate_limit_exceeded,
"should_block": analysis.should_block,
"threats": [
{
"type": threat.threat_type,
"level": threat.level.value,
"confidence": round(threat.confidence, 3),
"description": threat.description
}
for threat in analysis.threats[:5] # Limit to first 5 threats
],
"recommendations": analysis.recommendations
}
# Log at appropriate level based on risk
if analysis.should_block:
logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
else:
logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
def setup_security_middleware(app, enabled: bool = True) -> None:
"""Setup security middleware on FastAPI app"""
if enabled and settings.API_SECURITY_ENABLED:
app.add_middleware(SecurityMiddleware, enabled=enabled)
logger.info("Security middleware enabled")
else:
logger.info("Security middleware disabled")
# Helper functions for manual security checks
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
"""Manually analyze request security (for use in route handlers)"""
return await threat_detection_service.analyze_request(request, user_context)
def get_security_stats() -> Dict[str, Any]:
"""Get security statistics"""
return threat_detection_service.get_stats()
def is_request_blocked(request: Request) -> bool:
"""Check if request was blocked by security analysis"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.should_block
return False
def get_request_risk_score(request: Request) -> float:
"""Get risk score for request"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.risk_score
return 0.0
def get_request_auth_level(request: Request) -> str:
"""Get authentication level for request"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.auth_level.value
return "unknown"

33
backend/app/models/__init__.py Normal file
View File

@@ -0,0 +1,33 @@
"""
Database models package
"""
from .user import User
from .api_key import APIKey
from .usage_tracking import UsageTracking
from .budget import Budget
from .audit_log import AuditLog
from .rag_collection import RagCollection
from .rag_document import RagDocument
from .chatbot import ChatbotInstance, ChatbotConversation, ChatbotMessage, ChatbotAnalytics
from .prompt_template import PromptTemplate, ChatbotPromptVariable
from .workflow import WorkflowDefinition, WorkflowExecution, WorkflowStepLog
__all__ = [
"User",
"APIKey",
"UsageTracking",
"Budget",
"AuditLog",
"RagCollection",
"RagDocument",
"ChatbotInstance",
"ChatbotConversation",
"ChatbotMessage",
"ChatbotAnalytics",
"PromptTemplate",
"ChatbotPromptVariable",
"WorkflowDefinition",
"WorkflowExecution",
"WorkflowStepLog"
]

307
backend/app/models/api_key.py Normal file
View File

@@ -0,0 +1,307 @@
"""
API Key model
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
from sqlalchemy.orm import relationship
from app.db.database import Base
class APIKey(Base):
"""API Key model for authentication and access control"""
__tablename__ = "api_keys"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) # Human-readable name for the API key
key_hash = Column(String, unique=True, index=True, nullable=False) # Hashed API key
key_prefix = Column(String, index=True, nullable=False) # First 8 characters for identification
# User relationship
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="api_keys")
# Related data relationships
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
# Key status and permissions
is_active = Column(Boolean, default=True)
permissions = Column(JSON, default=dict) # Specific permissions for this key
scopes = Column(JSON, default=list) # OAuth-like scopes
# Usage limits
rate_limit_per_minute = Column(Integer, default=60) # Requests per minute
rate_limit_per_hour = Column(Integer, default=3600) # Requests per hour
rate_limit_per_day = Column(Integer, default=86400) # Requests per day
# Allowed resources
allowed_models = Column(JSON, default=list) # List of allowed LLM models
allowed_endpoints = Column(JSON, default=list) # List of allowed API endpoints
allowed_ips = Column(JSON, default=list) # IP whitelist
allowed_chatbots = Column(JSON, default=list) # List of allowed chatbot IDs for chatbot-specific keys
# Budget configuration
is_unlimited = Column(Boolean, default=True) # Unlimited budget flag
budget_limit_cents = Column(Integer, nullable=True) # Budget limit in cents
budget_type = Column(String, nullable=True) # "total" or "monthly"
# Metadata
description = Column(Text, nullable=True)
tags = Column(JSON, default=list) # For organizing keys
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_used_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True) # Optional expiration
# Usage tracking
total_requests = Column(Integer, default=0)
total_tokens = Column(Integer, default=0)
total_cost = Column(Integer, default=0) # In cents
def __repr__(self):
return f"<APIKey(id={self.id}, name='{self.name}', user_id={self.user_id})>"
def to_dict(self, include_sensitive: bool = False):
"""Convert API key to dictionary for API responses"""
data = {
"id": self.id,
"name": self.name,
"key_prefix": self.key_prefix,
"user_id": self.user_id,
"is_active": self.is_active,
"permissions": self.permissions,
"scopes": self.scopes,
"rate_limit_per_minute": self.rate_limit_per_minute,
"rate_limit_per_hour": self.rate_limit_per_hour,
"rate_limit_per_day": self.rate_limit_per_day,
"allowed_models": self.allowed_models,
"allowed_endpoints": self.allowed_endpoints,
"allowed_ips": self.allowed_ips,
"allowed_chatbots": self.allowed_chatbots,
"description": self.description,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"total_requests": self.total_requests,
"total_tokens": self.total_tokens,
"total_cost_cents": self.total_cost,
"is_unlimited": self.is_unlimited,
"budget_limit": self.budget_limit_cents, # Map to budget_limit for API response
"budget_type": self.budget_type
}
if include_sensitive:
data["key_hash"] = self.key_hash
return data
def is_expired(self) -> bool:
"""Check if the API key has expired"""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def is_valid(self) -> bool:
"""Check if the API key is valid and active"""
return self.is_active and not self.is_expired()
def has_permission(self, permission: str) -> bool:
"""Check if the API key has a specific permission"""
return permission in self.permissions
def has_scope(self, scope: str) -> bool:
"""Check if the API key has a specific scope"""
return scope in self.scopes
def can_access_model(self, model_name: str) -> bool:
"""Check if the API key can access a specific model"""
if not self.allowed_models: # Empty list means all models allowed
return True
return model_name in self.allowed_models
def can_access_endpoint(self, endpoint: str) -> bool:
"""Check if the API key can access a specific endpoint"""
if not self.allowed_endpoints: # Empty list means all endpoints allowed
return True
return endpoint in self.allowed_endpoints
def can_access_from_ip(self, ip_address: str) -> bool:
"""Check if the API key can be used from a specific IP"""
if not self.allowed_ips: # Empty list means all IPs allowed
return True
return ip_address in self.allowed_ips
def can_access_chatbot(self, chatbot_id: str) -> bool:
"""Check if the API key can access a specific chatbot"""
if not self.allowed_chatbots: # Empty list means all chatbots allowed
return True
return chatbot_id in self.allowed_chatbots
def update_usage(self, tokens_used: int = 0, cost_cents: int = 0):
"""Update usage statistics"""
self.total_requests += 1
self.total_tokens += tokens_used
self.total_cost += cost_cents
self.last_used_at = datetime.utcnow()
def set_expiration(self, days: int):
"""Set expiration date in days from now"""
self.expires_at = datetime.utcnow() + timedelta(days=days)
def extend_expiration(self, days: int):
"""Extend expiration date by specified days"""
if self.expires_at is None:
self.expires_at = datetime.utcnow() + timedelta(days=days)
else:
self.expires_at = self.expires_at + timedelta(days=days)
def revoke(self):
"""Revoke the API key"""
self.is_active = False
self.updated_at = datetime.utcnow()
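# NOTE: the helpers below mutate JSON-column lists in place; plain JSON columns are
# not change-tracked by SQLAlchemy, so wrap them in sqlalchemy.ext.mutable.MutableList
# (or reassign the attribute) for these updates to be flushed.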
def add_scope(self, scope: str):
"""Add a scope to the API key"""
if scope not in self.scopes:
self.scopes.append(scope)
def remove_scope(self, scope: str):
"""Remove a scope from the API key"""
if scope in self.scopes:
self.scopes.remove(scope)
def add_allowed_model(self, model_name: str):
"""Add an allowed model"""
if model_name not in self.allowed_models:
self.allowed_models.append(model_name)
def remove_allowed_model(self, model_name: str):
"""Remove an allowed model"""
if model_name in self.allowed_models:
self.allowed_models.remove(model_name)
def add_allowed_endpoint(self, endpoint: str):
"""Add an allowed endpoint"""
if endpoint not in self.allowed_endpoints:
self.allowed_endpoints.append(endpoint)
def remove_allowed_endpoint(self, endpoint: str):
"""Remove an allowed endpoint"""
if endpoint in self.allowed_endpoints:
self.allowed_endpoints.remove(endpoint)
def add_allowed_ip(self, ip_address: str):
"""Add an allowed IP address"""
if ip_address not in self.allowed_ips:
self.allowed_ips.append(ip_address)
def remove_allowed_ip(self, ip_address: str):
"""Remove an allowed IP address"""
if ip_address in self.allowed_ips:
self.allowed_ips.remove(ip_address)
def add_allowed_chatbot(self, chatbot_id: str):
"""Add an allowed chatbot"""
if chatbot_id not in self.allowed_chatbots:
self.allowed_chatbots.append(chatbot_id)
def remove_allowed_chatbot(self, chatbot_id: str):
"""Remove an allowed chatbot"""
if chatbot_id in self.allowed_chatbots:
self.allowed_chatbots.remove(chatbot_id)
@classmethod
def create_default_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str) -> "APIKey":
"""Create a default API key with standard permissions"""
return cls(
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"read": True,
"write": True,
"chat": True,
"embeddings": True
},
scopes=[
"chat.completions",
"embeddings.create",
"models.list"
],
rate_limit_per_minute=60,
rate_limit_per_hour=3600,
rate_limit_per_day=86400,
allowed_models=[], # All models allowed by default
allowed_endpoints=[], # All endpoints allowed by default
allowed_ips=[], # All IPs allowed by default
description="Default API key with standard permissions",
tags=["default"]
)
@classmethod
def create_restricted_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
models: List[str], endpoints: List[str]) -> "APIKey":
"""Create a restricted API key with limited permissions"""
return cls(
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"read": True,
"chat": True
},
scopes=[
"chat.completions"
],
rate_limit_per_minute=30,
rate_limit_per_hour=1800,
rate_limit_per_day=43200,
allowed_models=models,
allowed_endpoints=endpoints,
allowed_ips=[],
description="Restricted API key with limited permissions",
tags=["restricted"]
)
@classmethod
def create_chatbot_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
chatbot_id: str, chatbot_name: str) -> "APIKey":
"""Create a chatbot-specific API key"""
return cls(
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"chatbot": True
},
scopes=[
"chatbot.chat"
],
rate_limit_per_minute=100,
rate_limit_per_hour=6000,
rate_limit_per_day=144000,
allowed_models=[], # Will use chatbot's configured model
allowed_endpoints=[
f"/api/v1/chatbot/external/{chatbot_id}/chat"
],
allowed_ips=[],
allowed_chatbots=[chatbot_id],
description=f"API key for chatbot: {chatbot_name}",
tags=["chatbot", f"chatbot-{chatbot_id}"]
)
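A sketch of minting and persisting a key with these factories; the SHA-256 at-rest hash and the synchronous session are assumptions (the project's actual hashing lives in app.core.security):

import hashlib
import secrets

def mint_default_key(session, user_id: int, name: str) -> str:
    raw = "sk-" + secrets.token_urlsafe(32)
    key = APIKey.create_default_key(
        user_id=user_id,
        name=name,
        key_hash=hashlib.sha256(raw.encode()).hexdigest(),  # assumed hashing scheme
        key_prefix=raw[:8],
    )
    session.add(key)
    session.commit()
    return raw  # return the plaintext once; only the hash is stored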

346
backend/app/models/audit_log.py Normal file
View File

@@ -0,0 +1,346 @@
"""
Audit log model for tracking system events and user actions
"""
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey, Text, Boolean
from sqlalchemy.orm import relationship
from app.db.database import Base
from enum import Enum
class AuditAction(str, Enum):
"""Audit action types"""
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
LOGIN = "login"
LOGOUT = "logout"
API_KEY_CREATE = "api_key_create"
API_KEY_DELETE = "api_key_delete"
BUDGET_CREATE = "budget_create"
BUDGET_UPDATE = "budget_update"
BUDGET_EXCEED = "budget_exceed"
MODULE_ENABLE = "module_enable"
MODULE_DISABLE = "module_disable"
PERMISSION_GRANT = "permission_grant"
PERMISSION_REVOKE = "permission_revoke"
SYSTEM_CONFIG = "system_config"
SECURITY_EVENT = "security_event"
class AuditSeverity(str, Enum):
"""Audit severity levels"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class AuditLog(Base):
"""Audit log model for tracking system events and user actions"""
__tablename__ = "audit_logs"
id = Column(Integer, primary_key=True, index=True)
# User relationship (nullable for system events)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
user = relationship("User", back_populates="audit_logs")
# Event details
action = Column(String, nullable=False)
resource_type = Column(String, nullable=False) # user, api_key, budget, module, etc.
resource_id = Column(String, nullable=True) # ID of the affected resource
# Event description and details
description = Column(Text, nullable=False)
details = Column(JSON, default=dict) # Additional event details
# Request context
ip_address = Column(String, nullable=True)
user_agent = Column(String, nullable=True)
session_id = Column(String, nullable=True)
request_id = Column(String, nullable=True)
# Event classification
severity = Column(String, default=AuditSeverity.LOW)
category = Column(String, nullable=True) # security, access, data, system
# Success/failure tracking
success = Column(Boolean, default=True)
error_message = Column(Text, nullable=True)
# Additional metadata
tags = Column(JSON, default=list)
audit_metadata = Column("metadata", JSON, default=dict) # Map to 'metadata' column in DB
# Before/after values for data changes
old_values = Column(JSON, nullable=True)
new_values = Column(JSON, nullable=True)
# Timestamp
created_at = Column(DateTime, default=datetime.utcnow, index=True)
def __repr__(self):
return f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
def to_dict(self):
"""Convert audit log to dictionary for API responses"""
return {
"id": self.id,
"user_id": self.user_id,
"action": self.action,
"resource_type": self.resource_type,
"resource_id": self.resource_id,
"description": self.description,
"details": self.details,
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"session_id": self.session_id,
"request_id": self.request_id,
"severity": self.severity,
"category": self.category,
"success": self.success,
"error_message": self.error_message,
"tags": self.tags,
"metadata": self.audit_metadata,
"old_values": self.old_values,
"new_values": self.new_values,
"created_at": self.created_at.isoformat() if self.created_at else None
}
def is_security_event(self) -> bool:
"""Check if this is a security-related event"""
security_actions = [
AuditAction.LOGIN,
AuditAction.LOGOUT,
AuditAction.API_KEY_CREATE,
AuditAction.API_KEY_DELETE,
AuditAction.PERMISSION_GRANT,
AuditAction.PERMISSION_REVOKE,
AuditAction.SECURITY_EVENT
]
return self.action in security_actions or self.category == "security"
def is_high_severity(self) -> bool:
"""Check if this is a high severity event"""
return self.severity in [AuditSeverity.HIGH, AuditSeverity.CRITICAL]
def add_tag(self, tag: str):
"""Add a tag to the audit log"""
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag: str):
"""Remove a tag from the audit log"""
if tag in self.tags:
self.tags.remove(tag)
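# Note (persistence caveat): SQLAlchemy does not change-track in-place
# mutations of plain JSON columns, so edits made via add_tag/remove_tag may
# not be flushed. If that bites, wrap the columns with
# sqlalchemy.ext.mutable.MutableList/MutableDict, or call
# sqlalchemy.orm.attributes.flag_modified(log, "tags") before committing.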
def update_metadata(self, key: str, value: Any):
"""Update metadata"""
if self.audit_metadata is None:
self.audit_metadata = {}
self.audit_metadata[key] = value
def set_before_after(self, old_values: Dict[str, Any], new_values: Dict[str, Any]):
"""Set before and after values for data changes"""
self.old_values = old_values
self.new_values = new_values
@classmethod
def create_login_event(cls, user_id: int, success: bool = True,
ip_address: Optional[str] = None, user_agent: Optional[str] = None,
session_id: Optional[str] = None, error_message: Optional[str] = None) -> "AuditLog":
"""Create a login audit event"""
return cls(
user_id=user_id,
action=AuditAction.LOGIN,
resource_type="user",
resource_id=str(user_id),
description=f"User login {'successful' if success else 'failed'}",
details={
"login_method": "password",
"success": success
},
ip_address=ip_address,
user_agent=user_agent,
session_id=session_id,
severity=AuditSeverity.LOW if success else AuditSeverity.MEDIUM,
category="security",
success=success,
error_message=error_message,
tags=["authentication", "login"]
)
@classmethod
def create_logout_event(cls, user_id: int, session_id: Optional[str] = None) -> "AuditLog":
"""Create a logout audit event"""
return cls(
user_id=user_id,
action=AuditAction.LOGOUT,
resource_type="user",
resource_id=str(user_id),
description="User logout",
details={
"logout_method": "manual"
},
session_id=session_id,
severity=AuditSeverity.LOW,
category="security",
success=True,
tags=["authentication", "logout"]
)
@classmethod
def create_api_key_event(cls, user_id: int, action: str, api_key_id: int,
api_key_name: str, success: bool = True,
error_message: Optional[str] = None) -> "AuditLog":
"""Create an API key audit event"""
return cls(
user_id=user_id,
action=action,
resource_type="api_key",
resource_id=str(api_key_id),
description=f"API key {action}: {api_key_name}",
details={
"api_key_name": api_key_name,
"action": action
},
severity=AuditSeverity.MEDIUM,
category="security",
success=success,
error_message=error_message,
tags=["api_key", action]
)
@classmethod
def create_budget_event(cls, user_id: int, action: str, budget_id: int,
budget_name: str, details: Optional[Dict[str, Any]] = None,
success: bool = True) -> "AuditLog":
"""Create a budget audit event"""
return cls(
user_id=user_id,
action=action,
resource_type="budget",
resource_id=str(budget_id),
description=f"Budget {action}: {budget_name}",
details=details or {},
severity=AuditSeverity.MEDIUM if action == AuditAction.BUDGET_EXCEED else AuditSeverity.LOW,
category="financial",
success=success,
tags=["budget", action]
)
@classmethod
def create_module_event(cls, user_id: int, action: str, module_name: str,
success: bool = True, error_message: Optional[str] = None,
details: Optional[Dict[str, Any]] = None) -> "AuditLog":
"""Create a module audit event"""
return cls(
user_id=user_id,
action=action,
resource_type="module",
resource_id=module_name,
description=f"Module {action}: {module_name}",
details=details or {},
severity=AuditSeverity.MEDIUM,
category="system",
success=success,
error_message=error_message,
tags=["module", action]
)
@classmethod
def create_permission_event(cls, user_id: int, action: str, target_user_id: int,
permission: str, success: bool = True) -> "AuditLog":
"""Create a permission audit event"""
return cls(
user_id=user_id,
action=action,
resource_type="permission",
resource_id=str(target_user_id),
description=f"Permission {action}: {permission} for user {target_user_id}",
details={
"permission": permission,
"target_user_id": target_user_id
},
severity=AuditSeverity.HIGH,
category="security",
success=success,
tags=["permission", action]
)
@classmethod
def create_security_event(cls, user_id: int, event_type: str, description: str,
severity: str = AuditSeverity.HIGH,
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None) -> "AuditLog":
"""Create a security audit event"""
return cls(
user_id=user_id,
action=AuditAction.SECURITY_EVENT,
resource_type="security",
resource_id=event_type,
description=description,
details=details or {},
ip_address=ip_address,
severity=severity,
category="security",
success=False,  # security events default to failed; override after creation if needed
tags=["security", event_type]
)
@classmethod
def create_system_event(cls, action: str, description: str,
resource_type: str = "system",
resource_id: Optional[str] = None,
severity: str = AuditSeverity.LOW,
details: Optional[Dict[str, Any]] = None) -> "AuditLog":
"""Create a system audit event"""
return cls(
user_id=None, # System events don't have a user
action=action,
resource_type=resource_type,
resource_id=resource_id,
description=description,
details=details or {},
severity=severity,
category="system",
success=True,
tags=["system", action]
)
@classmethod
def create_data_change_event(cls, user_id: int, action: str, resource_type: str,
resource_id: str, description: str,
old_values: Dict[str, Any],
new_values: Dict[str, Any]) -> "AuditLog":
"""Create a data change audit event"""
return cls(
user_id=user_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
description=description,
old_values=old_values,
new_values=new_values,
severity=AuditSeverity.LOW,
category="data",
success=True,
tags=["data_change", action]
)
def get_summary(self) -> Dict[str, Any]:
"""Get a summary of the audit log"""
return {
"id": self.id,
"action": self.action,
"resource_type": self.resource_type,
"description": self.description,
"severity": self.severity,
"success": self.success,
"created_at": self.created_at.isoformat() if self.created_at else None,
"user_id": self.user_id
}
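# Illustrative sketch (assumes a SQLAlchemy Session `db` and request-context
# values supplied by the web layer): recording a failed login attempt.
#
#     event = AuditLog.create_login_event(
#         user_id=42,
#         success=False,
#         ip_address="203.0.113.7",
#         user_agent="curl/8.0",
#         error_message="invalid password",
#     )
#     db.add(event)
#     db.commit()
#     assert event.is_security_event() and not event.is_high_severity()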

View File

@@ -0,0 +1,296 @@
"""
Budget model for managing spending limits and cost control
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
from sqlalchemy.orm import relationship
from app.db.database import Base
class BudgetType(str, Enum):
"""Budget type enumeration"""
USER = "user"
API_KEY = "api_key"
GLOBAL = "global"
class BudgetPeriod(str, Enum):
"""Budget period types"""
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
YEARLY = "yearly"
CUSTOM = "custom"
class Budget(Base):
"""Budget model for setting and managing spending limits"""
__tablename__ = "budgets"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) # Human-readable name for the budget
# User and API Key relationships
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="budgets")
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True) # Optional: specific to an API key
api_key = relationship("APIKey", back_populates="budgets")
# Usage tracking relationship
usage_tracking = relationship("UsageTracking", back_populates="budget", cascade="all, delete-orphan")
# Budget limits (in cents)
limit_cents = Column(Integer, nullable=False) # Maximum spend limit
warning_threshold_cents = Column(Integer, nullable=True) # Warning threshold (e.g., 80% of limit)
# Time period settings
period_type = Column(String, nullable=False, default="monthly") # daily, weekly, monthly, yearly, custom
period_start = Column(DateTime, nullable=False) # Start of current period
period_end = Column(DateTime, nullable=False) # End of current period
# Current usage (in cents)
current_usage_cents = Column(Integer, default=0) # Spent in current period
# Budget status
is_active = Column(Boolean, default=True)
is_exceeded = Column(Boolean, default=False)
is_warning_sent = Column(Boolean, default=False)
# Enforcement settings
enforce_hard_limit = Column(Boolean, default=True) # Block requests when limit exceeded
enforce_warning = Column(Boolean, default=True) # Send warnings at threshold
# Allowed resources (optional filters)
allowed_models = Column(JSON, default=list) # Specific models this budget applies to
allowed_endpoints = Column(JSON, default=list) # Specific endpoints this budget applies to
# Metadata
description = Column(Text, nullable=True)
tags = Column(JSON, default=list)
currency = Column(String, default="USD")
# Auto-renewal settings
auto_renew = Column(Boolean, default=True) # Automatically renew budget for next period
rollover_unused = Column(Boolean, default=False) # Rollover unused budget to next period
# Notification settings
notification_settings = Column(JSON, default=dict) # Email, webhook, etc.
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_reset_at = Column(DateTime, nullable=True) # Last time budget was reset
def __repr__(self):
return f"<Budget(id={self.id}, name='{self.name}', user_id={self.user_id}, limit=${self.limit_cents/100:.2f})>"
def to_dict(self):
"""Convert budget to dictionary for API responses"""
return {
"id": self.id,
"name": self.name,
"user_id": self.user_id,
"api_key_id": self.api_key_id,
"limit_cents": self.limit_cents,
"limit_dollars": self.limit_cents / 100,
"warning_threshold_cents": self.warning_threshold_cents,
"warning_threshold_dollars": self.warning_threshold_cents / 100 if self.warning_threshold_cents else None,
"period_type": self.period_type,
"period_start": self.period_start.isoformat() if self.period_start else None,
"period_end": self.period_end.isoformat() if self.period_end else None,
"current_usage_cents": self.current_usage_cents,
"current_usage_dollars": self.current_usage_cents / 100,
"remaining_cents": max(0, self.limit_cents - self.current_usage_cents),
"remaining_dollars": max(0, (self.limit_cents - self.current_usage_cents) / 100),
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100) if self.limit_cents > 0 else 0,
"is_active": self.is_active,
"is_exceeded": self.is_exceeded,
"is_warning_sent": self.is_warning_sent,
"enforce_hard_limit": self.enforce_hard_limit,
"enforce_warning": self.enforce_warning,
"allowed_models": self.allowed_models,
"allowed_endpoints": self.allowed_endpoints,
"description": self.description,
"tags": self.tags,
"currency": self.currency,
"auto_renew": self.auto_renew,
"rollover_unused": self.rollover_unused,
"notification_settings": self.notification_settings,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_reset_at": self.last_reset_at.isoformat() if self.last_reset_at else None
}
def is_in_period(self) -> bool:
"""Check if current time is within budget period"""
now = datetime.utcnow()
return self.period_start <= now <= self.period_end
def is_expired(self) -> bool:
"""Check if budget period has expired"""
return datetime.utcnow() > self.period_end
def can_spend(self, amount_cents: int) -> bool:
"""Check if spending amount is within budget"""
if not self.is_active or not self.is_in_period():
return False
if not self.enforce_hard_limit:
return True
return (self.current_usage_cents + amount_cents) <= self.limit_cents
def would_exceed_warning(self, amount_cents: int) -> bool:
"""Check if spending amount would exceed warning threshold"""
if not self.warning_threshold_cents:
return False
return (self.current_usage_cents + amount_cents) >= self.warning_threshold_cents
def add_usage(self, amount_cents: int):
"""Add usage to current budget"""
self.current_usage_cents += amount_cents
# Check if budget is exceeded
if self.current_usage_cents >= self.limit_cents:
self.is_exceeded = True
# Check if warning threshold is reached
if self.warning_threshold_cents and self.current_usage_cents >= self.warning_threshold_cents:
if not self.is_warning_sent:
self.is_warning_sent = True
self.updated_at = datetime.utcnow()
def reset_period(self):
"""Reset budget for new period"""
if self.rollover_unused and self.current_usage_cents < self.limit_cents:
# Rollover unused budget
unused_amount = self.limit_cents - self.current_usage_cents
self.limit_cents += unused_amount
self.current_usage_cents = 0
self.is_exceeded = False
self.is_warning_sent = False
self.last_reset_at = datetime.utcnow()
# Calculate next period
if self.period_type == "daily":
self.period_start = self.period_end
self.period_end = self.period_start + timedelta(days=1)
elif self.period_type == "weekly":
self.period_start = self.period_end
self.period_end = self.period_start + timedelta(weeks=1)
elif self.period_type == "monthly":
self.period_start = self.period_end
# Handle month boundaries properly
if self.period_start.month == 12:
next_month = self.period_start.replace(year=self.period_start.year + 1, month=1)
else:
next_month = self.period_start.replace(month=self.period_start.month + 1)
self.period_end = next_month
elif self.period_type == "yearly":
self.period_start = self.period_end
self.period_end = self.period_start.replace(year=self.period_start.year + 1)
self.updated_at = datetime.utcnow()
def get_period_days_remaining(self) -> int:
"""Get number of days remaining in current period"""
if self.is_expired():
return 0
return (self.period_end - datetime.utcnow()).days
def get_daily_burn_rate(self) -> float:
"""Get average daily spend rate in current period"""
if not self.is_in_period():
return 0.0
days_elapsed = (datetime.utcnow() - self.period_start).days
if days_elapsed == 0:
days_elapsed = 1 # Avoid division by zero
return self.current_usage_cents / days_elapsed / 100 # Return in dollars
def get_projected_spend(self) -> float:
"""Get projected spend for entire period based on current burn rate"""
daily_burn = self.get_daily_burn_rate()
total_period_days = (self.period_end - self.period_start).days
return daily_burn * total_period_days
@classmethod
def create_monthly_budget(
cls,
user_id: int,
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None,
warning_threshold_percentage: float = 0.8
) -> "Budget":
"""Create a monthly budget"""
now = datetime.utcnow()
# Start of current month
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
# Start of next month
if now.month == 12:
period_end = period_start.replace(year=now.year + 1, month=1)
else:
period_end = period_start.replace(month=now.month + 1)
limit_cents = int(limit_dollars * 100)
warning_threshold_cents = int(limit_cents * warning_threshold_percentage)
return cls(
name=name,
user_id=user_id,
api_key_id=api_key_id,
limit_cents=limit_cents,
warning_threshold_cents=warning_threshold_cents,
period_type="monthly",
period_start=period_start,
period_end=period_end,
is_active=True,
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True,
notification_settings={
"email_on_warning": True,
"email_on_exceeded": True
}
)
@classmethod
def create_daily_budget(
cls,
user_id: int,
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None
) -> "Budget":
"""Create a daily budget"""
now = datetime.utcnow()
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + timedelta(days=1)
limit_cents = int(limit_dollars * 100)
warning_threshold_cents = int(limit_cents * 0.8) # 80% warning threshold
return cls(
name=name,
user_id=user_id,
api_key_id=api_key_id,
limit_cents=limit_cents,
warning_threshold_cents=warning_threshold_cents,
period_type="daily",
period_start=period_start,
period_end=period_end,
is_active=True,
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True
)
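# Illustrative enforcement-flow sketch; `db` and `estimated_cost_cents` are
# assumptions from the calling code:
#
#     budget = Budget.create_monthly_budget(user_id=1, name="Monthly cap",
#                                           limit_dollars=50.0)
#     db.add(budget); db.commit()
#     if budget.can_spend(estimated_cost_cents):
#         budget.add_usage(estimated_cost_cents)  # record after the request succeeds
#         db.commit()
#     else:
#         raise BudgetExceededError()  # hypothetical error type; reject the request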

View File

@@ -0,0 +1,110 @@
"""
Database models for chatbot module
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey
from sqlalchemy.orm import relationship
from datetime import datetime
import uuid
from app.db.database import Base
class ChatbotInstance(Base):
"""Configured chatbot instance"""
__tablename__ = "chatbot_instances"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String(255), nullable=False)
description = Column(Text)
# Configuration stored as JSON
config = Column(JSON, nullable=False)
# Metadata
created_by = Column(String, nullable=False) # User ID
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Relationships
conversations = relationship("ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan")
def __repr__(self):
return f"<ChatbotInstance(id='{self.id}', name='{self.name}')>"
class ChatbotConversation(Base):
"""Conversation state and history"""
__tablename__ = "chatbot_conversations"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
user_id = Column(String, nullable=False) # User ID
# Conversation metadata
title = Column(String(255)) # Auto-generated or user-defined title
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Conversation context and settings
context_data = Column(JSON, default=dict) # Additional context
# Relationships
chatbot = relationship("ChatbotInstance", back_populates="conversations")
messages = relationship("ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan")
def __repr__(self):
return f"<ChatbotConversation(id='{self.id}', chatbot_id='{self.chatbot_id}')>"
class ChatbotMessage(Base):
"""Individual chat messages in conversations"""
__tablename__ = "chatbot_messages"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
conversation_id = Column(String, ForeignKey("chatbot_conversations.id"), nullable=False)
# Message content
role = Column(String(20), nullable=False) # 'user', 'assistant', 'system'
content = Column(Text, nullable=False)
# Metadata
timestamp = Column(DateTime, default=datetime.utcnow)
message_metadata = Column(JSON, default=dict) # Token counts, model used, etc.
# RAG sources if applicable
sources = Column(JSON) # RAG sources used for this message
# Relationships
conversation = relationship("ChatbotConversation", back_populates="messages")
def __repr__(self):
return f"<ChatbotMessage(id='{self.id}', role='{self.role}')>"
class ChatbotAnalytics(Base):
"""Analytics and metrics for chatbot usage"""
__tablename__ = "chatbot_analytics"
id = Column(Integer, primary_key=True, autoincrement=True)
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
user_id = Column(String, nullable=False)
# Event tracking
event_type = Column(String(50), nullable=False) # 'message_sent', 'response_generated', etc.
event_data = Column(JSON, default=dict)
# Performance metrics
response_time_ms = Column(Integer)
token_count = Column(Integer)
cost_cents = Column(Integer)
# Context
model_used = Column(String(100))
rag_used = Column(Boolean, default=False)
timestamp = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<ChatbotAnalytics(id={self.id}, event_type='{self.event_type}')>"

View File

@@ -0,0 +1,499 @@
"""
Module model for tracking installed modules and their configurations
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from sqlalchemy import Column, Integer, String, DateTime, Boolean, JSON, Text
from app.db.database import Base
from enum import Enum
class ModuleStatus(str, Enum):
"""Module status types"""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
LOADING = "loading"
DISABLED = "disabled"
class ModuleType(str, Enum):
"""Module type categories"""
CORE = "core"
INTERCEPTOR = "interceptor"
ANALYTICS = "analytics"
SECURITY = "security"
STORAGE = "storage"
INTEGRATION = "integration"
CUSTOM = "custom"
class Module(Base):
"""Module model for tracking installed modules and their configurations"""
__tablename__ = "modules"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, index=True, nullable=False)
display_name = Column(String, nullable=False)
description = Column(Text, nullable=True)
# Module classification
module_type = Column(String, default=ModuleType.CUSTOM)
category = Column(String, nullable=True) # cache, rag, analytics, etc.
# Module details
version = Column(String, nullable=False)
author = Column(String, nullable=True)
license = Column(String, nullable=True)
# Module status
status = Column(String, default=ModuleStatus.INACTIVE)
is_enabled = Column(Boolean, default=False)
is_core = Column(Boolean, default=False) # Core modules cannot be disabled
# Configuration
config_schema = Column(JSON, default=dict) # JSON schema for configuration
config_values = Column(JSON, default=dict) # Current configuration values
default_config = Column(JSON, default=dict) # Default configuration
# Dependencies
dependencies = Column(JSON, default=list) # List of module dependencies
conflicts = Column(JSON, default=list) # List of conflicting modules
# Installation details
install_path = Column(String, nullable=True)
entry_point = Column(String, nullable=True) # Main module entry point
# Interceptor configuration
interceptor_chains = Column(JSON, default=list) # Which chains this module hooks into
execution_order = Column(Integer, default=100) # Order in interceptor chain
# API endpoints
api_endpoints = Column(JSON, default=list) # List of API endpoints this module provides
# Permissions and security
required_permissions = Column(JSON, default=list) # Permissions required to use this module
security_level = Column(String, default="low") # low, medium, high, critical
# Metadata
tags = Column(JSON, default=list)
module_metadata = Column(JSON, default=dict)
# Runtime information
last_error = Column(Text, nullable=True)
error_count = Column(Integer, default=0)
last_started = Column(DateTime, nullable=True)
last_stopped = Column(DateTime, nullable=True)
# Statistics
request_count = Column(Integer, default=0)
success_count = Column(Integer, default=0)
error_count_runtime = Column(Integer, default=0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
installed_at = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<Module(id={self.id}, name='{self.name}', status='{self.status}')>"
def to_dict(self):
"""Convert module to dictionary for API responses"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"module_type": self.module_type,
"category": self.category,
"version": self.version,
"author": self.author,
"license": self.license,
"status": self.status,
"is_enabled": self.is_enabled,
"is_core": self.is_core,
"config_schema": self.config_schema,
"config_values": self.config_values,
"default_config": self.default_config,
"dependencies": self.dependencies,
"conflicts": self.conflicts,
"install_path": self.install_path,
"entry_point": self.entry_point,
"interceptor_chains": self.interceptor_chains,
"execution_order": self.execution_order,
"api_endpoints": self.api_endpoints,
"required_permissions": self.required_permissions,
"security_level": self.security_level,
"tags": self.tags,
"metadata": self.module_metadata,
"last_error": self.last_error,
"error_count": self.error_count,
"last_started": self.last_started.isoformat() if self.last_started else None,
"last_stopped": self.last_stopped.isoformat() if self.last_stopped else None,
"request_count": self.request_count,
"success_count": self.success_count,
"error_count_runtime": self.error_count_runtime,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"installed_at": self.installed_at.isoformat() if self.installed_at else None,
"success_rate": self.get_success_rate(),
"uptime": self.get_uptime_seconds() if self.is_running() else 0
}
def is_running(self) -> bool:
"""Check if module is currently running"""
return self.status == ModuleStatus.ACTIVE
def is_healthy(self) -> bool:
"""Check if module is healthy (running without recent errors)"""
return self.is_running() and self.error_count_runtime == 0
def get_success_rate(self) -> float:
"""Get success rate as percentage"""
if self.request_count == 0:
return 100.0
return (self.success_count / self.request_count) * 100
def get_uptime_seconds(self) -> int:
"""Get uptime in seconds"""
if not self.last_started:
return 0
return int((datetime.utcnow() - self.last_started).total_seconds())
def can_be_disabled(self) -> bool:
"""Check if module can be disabled"""
return not self.is_core
def has_dependency(self, module_name: str) -> bool:
"""Check if module has a specific dependency"""
return module_name in self.dependencies
def conflicts_with(self, module_name: str) -> bool:
"""Check if module conflicts with another module"""
return module_name in self.conflicts
def requires_permission(self, permission: str) -> bool:
"""Check if module requires a specific permission"""
return permission in self.required_permissions
def hooks_into_chain(self, chain_name: str) -> bool:
"""Check if module hooks into a specific interceptor chain"""
return chain_name in self.interceptor_chains
def provides_endpoint(self, endpoint: str) -> bool:
"""Check if module provides a specific API endpoint"""
return endpoint in self.api_endpoints
def update_config(self, config_updates: Dict[str, Any]):
"""Update module configuration"""
if self.config_values is None:
self.config_values = {}
self.config_values.update(config_updates)
self.updated_at = datetime.utcnow()
def reset_config(self):
"""Reset configuration to default values"""
self.config_values = self.default_config.copy() if self.default_config else {}
self.updated_at = datetime.utcnow()
def enable(self):
"""Enable the module"""
if self.status != ModuleStatus.ERROR:
self.is_enabled = True
self.status = ModuleStatus.LOADING
self.updated_at = datetime.utcnow()
def disable(self):
"""Disable the module"""
if self.can_be_disabled():
self.is_enabled = False
self.status = ModuleStatus.DISABLED
self.last_stopped = datetime.utcnow()
self.updated_at = datetime.utcnow()
def start(self):
"""Start the module"""
self.status = ModuleStatus.ACTIVE
self.last_started = datetime.utcnow()
self.last_error = None
self.updated_at = datetime.utcnow()
def stop(self):
"""Stop the module"""
self.status = ModuleStatus.INACTIVE
self.last_stopped = datetime.utcnow()
self.updated_at = datetime.utcnow()
def set_error(self, error_message: str):
"""Set module error status"""
self.status = ModuleStatus.ERROR
self.last_error = error_message
self.error_count += 1
self.error_count_runtime += 1
self.updated_at = datetime.utcnow()
def clear_error(self):
"""Clear error status"""
self.last_error = None
self.error_count_runtime = 0
self.updated_at = datetime.utcnow()
def record_request(self, success: bool = True):
"""Record a request to this module"""
self.request_count += 1
if success:
self.success_count += 1
else:
self.error_count_runtime += 1
def add_tag(self, tag: str):
"""Add a tag to the module"""
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag: str):
"""Remove a tag from the module"""
if tag in self.tags:
self.tags.remove(tag)
def update_metadata(self, key: str, value: Any):
"""Update metadata"""
if self.module_metadata is None:
self.module_metadata = {}
self.module_metadata[key] = value
def add_dependency(self, module_name: str):
"""Add a dependency"""
if module_name not in self.dependencies:
self.dependencies.append(module_name)
def remove_dependency(self, module_name: str):
"""Remove a dependency"""
if module_name in self.dependencies:
self.dependencies.remove(module_name)
def add_conflict(self, module_name: str):
"""Add a conflict"""
if module_name not in self.conflicts:
self.conflicts.append(module_name)
def remove_conflict(self, module_name: str):
"""Remove a conflict"""
if module_name in self.conflicts:
self.conflicts.remove(module_name)
def add_interceptor_chain(self, chain_name: str):
"""Add an interceptor chain"""
if chain_name not in self.interceptor_chains:
self.interceptor_chains.append(chain_name)
def remove_interceptor_chain(self, chain_name: str):
"""Remove an interceptor chain"""
if chain_name in self.interceptor_chains:
self.interceptor_chains.remove(chain_name)
def add_api_endpoint(self, endpoint: str):
"""Add an API endpoint"""
if endpoint not in self.api_endpoints:
self.api_endpoints.append(endpoint)
def remove_api_endpoint(self, endpoint: str):
"""Remove an API endpoint"""
if endpoint in self.api_endpoints:
self.api_endpoints.remove(endpoint)
def add_required_permission(self, permission: str):
"""Add a required permission"""
if permission not in self.required_permissions:
self.required_permissions.append(permission)
def remove_required_permission(self, permission: str):
"""Remove a required permission"""
if permission in self.required_permissions:
self.required_permissions.remove(permission)
@classmethod
def create_core_module(cls, name: str, display_name: str, description: str,
version: str, entry_point: str) -> "Module":
"""Create a core module"""
return cls(
name=name,
display_name=display_name,
description=description,
module_type=ModuleType.CORE,
version=version,
author="Confidential Empire",
license="Proprietary",
status=ModuleStatus.ACTIVE,
is_enabled=True,
is_core=True,
entry_point=entry_point,
config_schema={},
config_values={},
default_config={},
dependencies=[],
conflicts=[],
interceptor_chains=[],
execution_order=10, # Core modules run first
api_endpoints=[],
required_permissions=[],
security_level="high",
tags=["core"],
module_metadata={}
)
@classmethod
def create_cache_module(cls) -> "Module":
"""Create the cache module"""
return cls(
name="cache",
display_name="Cache Module",
description="Redis-based caching for improved performance",
module_type=ModuleType.INTERCEPTOR,
category="cache",
version="1.0.0",
author="Confidential Empire",
license="Proprietary",
status=ModuleStatus.INACTIVE,
is_enabled=True,
is_core=False,
entry_point="app.modules.cache.main",
config_schema={
"type": "object",
"properties": {
"provider": {"type": "string", "enum": ["redis"]},
"ttl": {"type": "integer", "minimum": 60},
"max_size": {"type": "integer", "minimum": 1000}
},
"required": ["provider", "ttl"]
},
config_values={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
},
default_config={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
},
dependencies=[],
conflicts=[],
interceptor_chains=["pre_request", "post_response"],
execution_order=20,
api_endpoints=["/api/v1/cache/stats", "/api/v1/cache/clear"],
required_permissions=["cache.read", "cache.write"],
security_level="low",
tags=["cache", "performance"],
module_metadata={}
)
@classmethod
def create_rag_module(cls) -> "Module":
"""Create the RAG module"""
return cls(
name="rag",
display_name="RAG Module",
description="Retrieval Augmented Generation with vector database",
module_type=ModuleType.INTERCEPTOR,
category="rag",
version="1.0.0",
author="Confidential Empire",
license="Proprietary",
status=ModuleStatus.INACTIVE,
is_enabled=True,
is_core=False,
entry_point="app.modules.rag.main",
config_schema={
"type": "object",
"properties": {
"vector_db": {"type": "string", "enum": ["qdrant"]},
"embedding_model": {"type": "string"},
"chunk_size": {"type": "integer", "minimum": 100},
"max_results": {"type": "integer", "minimum": 1}
},
"required": ["vector_db", "embedding_model"]
},
config_values={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
},
default_config={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
},
dependencies=[],
conflicts=[],
interceptor_chains=["pre_request"],
execution_order=30,
api_endpoints=["/api/v1/rag/documents", "/api/v1/rag/search"],
required_permissions=["rag.read", "rag.write"],
security_level="medium",
tags=["rag", "ai", "search"],
module_metadata={}
)
@classmethod
def create_analytics_module(cls) -> "Module":
"""Create the analytics module"""
return cls(
name="analytics",
display_name="Analytics Module",
description="Request and response analytics and monitoring",
module_type=ModuleType.ANALYTICS,
category="analytics",
version="1.0.0",
author="Confidential Empire",
license="Proprietary",
status=ModuleStatus.INACTIVE,
is_enabled=True,
is_core=False,
entry_point="app.modules.analytics.main",
config_schema={
"type": "object",
"properties": {
"track_requests": {"type": "boolean"},
"track_responses": {"type": "boolean"},
"retention_days": {"type": "integer", "minimum": 1}
},
"required": ["track_requests", "track_responses"]
},
config_values={
"track_requests": True,
"track_responses": True,
"retention_days": 30
},
default_config={
"track_requests": True,
"track_responses": True,
"retention_days": 30
},
dependencies=[],
conflicts=[],
interceptor_chains=["pre_request", "post_response"],
execution_order=90, # Analytics runs last
api_endpoints=["/api/v1/analytics/stats", "/api/v1/analytics/reports"],
required_permissions=["analytics.read"],
security_level="low",
tags=["analytics", "monitoring"],
module_metadata={}
)
def get_health_status(self) -> Dict[str, Any]:
"""Get health status of the module"""
return {
"name": self.name,
"status": self.status,
"is_healthy": self.is_healthy(),
"success_rate": self.get_success_rate(),
"uptime_seconds": self.get_uptime_seconds() if self.is_running() else 0,
"last_error": self.last_error,
"error_count": self.error_count_runtime,
"last_started": self.last_started.isoformat() if self.last_started else None
}
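# Illustrative lifecycle sketch (session handling omitted):
#
#     module = Module.create_cache_module()
#     module.enable()                       # -> status LOADING
#     module.start()                        # -> status ACTIVE, uptime begins
#     module.record_request(success=True)
#     module.record_request(success=False)  # increments error_count_runtime
#     module.get_health_status()            # success_rate 50.0, is_healthy False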

View File

@@ -0,0 +1,42 @@
"""
Prompt Template Models for customizable chatbot prompts
"""
from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer
from sqlalchemy.sql import func
from app.db.database import Base
from datetime import datetime
class PromptTemplate(Base):
"""Editable prompt templates for different chatbot types"""
__tablename__ = "prompt_templates"
id = Column(String, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True) # Human readable name
type_key = Column(String(100), nullable=False, unique=True, index=True) # assistant, customer_support, etc.
description = Column(Text, nullable=True)
system_prompt = Column(Text, nullable=False)
is_default = Column(Boolean, default=True, nullable=False)
is_active = Column(Boolean, default=True, nullable=False)
version = Column(Integer, default=1, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
def __repr__(self):
return f"<PromptTemplate(type_key='{self.type_key}', name='{self.name}')>"
class ChatbotPromptVariable(Base):
"""Available variables that can be used in prompts"""
__tablename__ = "prompt_variables"
id = Column(String, primary_key=True, index=True)
variable_name = Column(String(100), nullable=False, unique=True, index=True) # {user_name}, {context}, etc.
description = Column(Text, nullable=True)
example_value = Column(String(500), nullable=True)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<ChatbotPromptVariable(name='{self.variable_name}')>"

View File

@@ -0,0 +1,52 @@
"""
RAG Collection Model
Represents document collections for the RAG system
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, BigInteger
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
class RagCollection(Base):
__tablename__ = "rag_collections"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True)
description = Column(Text, nullable=True)
qdrant_collection_name = Column(String(255), nullable=False, unique=True, index=True)
# Metadata
document_count = Column(Integer, default=0, nullable=False)
size_bytes = Column(BigInteger, default=0, nullable=False)
vector_count = Column(Integer, default=0, nullable=False)
# Status tracking
status = Column(String(50), default='active', nullable=False) # 'active', 'indexing', 'error'
is_active = Column(Boolean, default=True, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationships
documents = relationship("RagDocument", back_populates="collection", cascade="all, delete-orphan")
def to_dict(self):
"""Convert model to dictionary for API responses"""
return {
"id": str(self.id),
"name": self.name,
"description": self.description or "",
"document_count": self.document_count,
"size_bytes": self.size_bytes,
"vector_count": self.vector_count,
"status": self.status,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_active": self.is_active
}
def __repr__(self):
return f"<RagCollection(id={self.id}, name='{self.name}', documents={self.document_count})>"

View File

@@ -0,0 +1,82 @@
"""
RAG Document Model
Represents documents within RAG collections
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, BigInteger, ForeignKey, JSON
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
class RagDocument(Base):
__tablename__ = "rag_documents"
id = Column(Integer, primary_key=True, index=True)
# Collection relationship
collection_id = Column(Integer, ForeignKey("rag_collections.id", ondelete="CASCADE"), nullable=False, index=True)
collection = relationship("RagCollection", back_populates="documents")
# File information
filename = Column(String(255), nullable=False) # sanitized filename for storage
original_filename = Column(String(255), nullable=False) # user's original filename
file_path = Column(String(500), nullable=False) # path to stored file
file_type = Column(String(50), nullable=False) # pdf, docx, txt, etc.
file_size = Column(BigInteger, nullable=False) # file size in bytes
mime_type = Column(String(100), nullable=True)
# Processing status
status = Column(String(50), default='processing', nullable=False) # 'processing', 'processed', 'error', 'indexed'
processing_error = Column(Text, nullable=True)
# Content information
converted_content = Column(Text, nullable=True) # markdown converted content
word_count = Column(Integer, default=0, nullable=False)
character_count = Column(Integer, default=0, nullable=False)
# Vector information
vector_count = Column(Integer, default=0, nullable=False) # number of chunks/vectors created
chunk_size = Column(Integer, default=1000, nullable=False) # chunk size used for vectorization
# Metadata extracted from document
document_metadata = Column(JSON, nullable=True) # language, entities, keywords, etc.
# Processing timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
processed_at = Column(DateTime(timezone=True), nullable=True)
indexed_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Soft delete
is_deleted = Column(Boolean, default=False, nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True)
def to_dict(self):
"""Convert model to dictionary for API responses"""
return {
"id": str(self.id),
"collection_id": str(self.collection_id),
"collection_name": self.collection.name if self.collection else None,
"filename": self.filename,
"original_filename": self.original_filename,
"file_type": self.file_type,
"size": self.file_size,
"mime_type": self.mime_type,
"status": self.status,
"processing_error": self.processing_error,
"converted_content": self.converted_content,
"word_count": self.word_count,
"character_count": self.character_count,
"vector_count": self.vector_count,
"chunk_size": self.chunk_size,
"metadata": self.document_metadata or {},
"created_at": self.created_at.isoformat() if self.created_at else None,
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
"indexed_at": self.indexed_at.isoformat() if self.indexed_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_deleted": self.is_deleted
}
def __repr__(self):
return f"<RagDocument(id={self.id}, filename='{self.original_filename}', status='{self.status}')>"

View File

@@ -0,0 +1,125 @@
"""
Usage Tracking model for API key usage statistics
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey
from sqlalchemy.orm import relationship
from app.db.database import Base
class UsageTracking(Base):
"""Usage tracking model for detailed API key usage statistics"""
__tablename__ = "usage_tracking"
id = Column(Integer, primary_key=True, index=True)
# API Key relationship
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=False)
api_key = relationship("APIKey", back_populates="usage_tracking")
# User relationship
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="usage_tracking")
# Budget relationship (optional)
budget_id = Column(Integer, ForeignKey("budgets.id"), nullable=True)
budget = relationship("Budget", back_populates="usage_tracking")
# Request information
endpoint = Column(String, nullable=False) # API endpoint used
method = Column(String, nullable=False) # HTTP method
model = Column(String, nullable=True) # Model used (if applicable)
# Usage metrics
request_tokens = Column(Integer, default=0) # Input tokens
response_tokens = Column(Integer, default=0) # Output tokens
total_tokens = Column(Integer, default=0) # Total tokens used
# Cost tracking
cost_cents = Column(Integer, default=0) # Cost in cents
cost_currency = Column(String, default="USD") # Currency
# Response information
response_time_ms = Column(Integer, nullable=True) # Response time in milliseconds
status_code = Column(Integer, nullable=True) # HTTP status code
# Request metadata
request_id = Column(String, nullable=True) # Unique request identifier
session_id = Column(String, nullable=True) # Session identifier
ip_address = Column(String, nullable=True) # Client IP address
user_agent = Column(String, nullable=True) # User agent
# Additional metadata
request_metadata = Column(JSON, default=dict) # Additional request metadata
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<UsageTracking(id={self.id}, api_key_id={self.api_key_id}, endpoint='{self.endpoint}')>"
def to_dict(self):
"""Convert usage tracking to dictionary for API responses"""
return {
"id": self.id,
"api_key_id": self.api_key_id,
"user_id": self.user_id,
"endpoint": self.endpoint,
"method": self.method,
"model": self.model,
"request_tokens": self.request_tokens,
"response_tokens": self.response_tokens,
"total_tokens": self.total_tokens,
"cost_cents": self.cost_cents,
"cost_currency": self.cost_currency,
"response_time_ms": self.response_time_ms,
"status_code": self.status_code,
"request_id": self.request_id,
"session_id": self.session_id,
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"request_metadata": self.request_metadata,
"created_at": self.created_at.isoformat() if self.created_at else None
}
@classmethod
def create_tracking_record(
cls,
api_key_id: int,
user_id: int,
endpoint: str,
method: str,
model: Optional[str] = None,
request_tokens: int = 0,
response_tokens: int = 0,
cost_cents: int = 0,
response_time_ms: Optional[int] = None,
status_code: Optional[int] = None,
request_id: Optional[str] = None,
session_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_metadata: Optional[dict] = None
) -> "UsageTracking":
"""Create a new usage tracking record"""
return cls(
api_key_id=api_key_id,
user_id=user_id,
endpoint=endpoint,
method=method,
model=model,
request_tokens=request_tokens,
response_tokens=response_tokens,
total_tokens=request_tokens + response_tokens,
cost_cents=cost_cents,
response_time_ms=response_time_ms,
status_code=status_code,
request_id=request_id,
session_id=session_id,
ip_address=ip_address,
user_agent=user_agent,
request_metadata=request_metadata or {}
)
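# Illustrative sketch of recording one completed request (all values are
# examples; `key` and `db` come from the calling code):
#
#     record = UsageTracking.create_tracking_record(
#         api_key_id=key.id,
#         user_id=key.user_id,
#         endpoint="/api/v1/chat/completions",
#         method="POST",
#         model="gpt-4o",
#         request_tokens=250,
#         response_tokens=600,   # total_tokens is derived: 850
#         cost_cents=3,
#         response_time_ms=820,
#         status_code=200,
#     )
#     db.add(record); db.commit()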

158
backend/app/models/user.py Normal file
View File

@@ -0,0 +1,158 @@
"""
User model
"""
from datetime import datetime
from typing import Optional, List
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
from sqlalchemy.orm import relationship
from app.db.database import Base
class UserRole(str, Enum):
"""User role enumeration"""
USER = "user"
ADMIN = "admin"
SUPER_ADMIN = "super_admin"
class User(Base):
"""User model for authentication and user management"""
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True, nullable=False)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
full_name = Column(String, nullable=True)
# User status and permissions
is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False)
is_verified = Column(Boolean, default=False)
# Role-based access control
role = Column(String, default=UserRole.USER.value) # user, admin, super_admin
permissions = Column(JSON, default=dict) # Custom permissions
# Profile information
avatar_url = Column(String, nullable=True)
bio = Column(Text, nullable=True)
company = Column(String, nullable=True)
website = Column(String, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_login = Column(DateTime, nullable=True)
# Settings
preferences = Column(JSON, default=dict)
notification_settings = Column(JSON, default=dict)
# Relationships
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="user", cascade="all, delete-orphan")
budgets = relationship("Budget", back_populates="user", cascade="all, delete-orphan")
audit_logs = relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
def __repr__(self):
return f"<User(id={self.id}, email='{self.email}', username='{self.username}')>"
def to_dict(self):
"""Convert user to dictionary for API responses"""
return {
"id": self.id,
"email": self.email,
"username": self.username,
"full_name": self.full_name,
"is_active": self.is_active,
"is_superuser": self.is_superuser,
"is_verified": self.is_verified,
"role": self.role,
"permissions": self.permissions,
"avatar_url": self.avatar_url,
"bio": self.bio,
"company": self.company,
"website": self.website,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
"preferences": self.preferences,
"notification_settings": self.notification_settings
}
def has_permission(self, permission: str) -> bool:
"""Check if user has a specific permission"""
if self.is_superuser:
return True
# Check role-based permissions
role_permissions = {
"user": ["read_own", "create_own", "update_own"],
"admin": ["read_all", "create_all", "update_all", "delete_own"],
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
}
if self.role in role_permissions and permission in role_permissions[self.role]:
return True
# Check custom permissions
return permission in self.permissions
def can_access_module(self, module_name: str) -> bool:
"""Check if user can access a specific module"""
if self.is_superuser:
return True
# Check module-specific permissions
module_permissions = self.permissions.get("modules", {})
return module_permissions.get(module_name, False)
def update_last_login(self):
"""Update the last login timestamp"""
self.last_login = datetime.utcnow()
def update_preferences(self, preferences: dict):
"""Update user preferences"""
if self.preferences is None:
self.preferences = {}
self.preferences.update(preferences)
def update_notification_settings(self, settings: dict):
"""Update notification settings"""
if self.notification_settings is None:
self.notification_settings = {}
self.notification_settings.update(settings)
@classmethod
def create_default_admin(cls, email: str, username: str, password_hash: str) -> "User":
"""Create a default admin user"""
return cls(
email=email,
username=username,
hashed_password=password_hash,
full_name="System Administrator",
is_active=True,
is_superuser=True,
is_verified=True,
role="super_admin",
permissions={
"modules": {
"cache": True,
"analytics": True,
"rag": True
}
},
preferences={
"theme": "dark",
"language": "en",
"timezone": "UTC"
},
notification_settings={
"email_notifications": True,
"security_alerts": True,
"system_updates": True
}
)
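# Illustrative permission-check sketch (`pw_hash` is an assumption; note that
# column defaults such as role="user" only apply once a row is persisted, so
# the role is set explicitly here):
#
#     admin = User.create_default_admin("admin@example.com", "admin", pw_hash)
#     admin.has_permission("manage_users")   # True: superusers short-circuit
#     user = User(email="u@example.com", username="u",
#                 hashed_password=pw_hash, role="user", permissions={})
#     user.has_permission("read_own")        # True via the "user" role defaults
#     user.can_access_module("rag")          # False until granted under permissions["modules"]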

View File

@@ -0,0 +1,118 @@
"""
Database models for workflow module
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey, Enum as SQLEnum
from sqlalchemy.orm import relationship
from datetime import datetime
import uuid
import enum
from app.db.database import Base
class WorkflowStatus(enum.Enum):
"""Workflow execution statuses"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class WorkflowDefinition(Base):
"""Workflow definition/template"""
__tablename__ = "workflow_definitions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String(255), nullable=False)
description = Column(Text)
version = Column(String(50), default="1.0.0")
# Workflow definition stored as JSON
steps = Column(JSON, nullable=False)
variables = Column(JSON, default=dict)
workflow_metadata = Column("metadata", JSON, default=dict)
# Configuration
timeout = Column(Integer) # Timeout in seconds
is_active = Column(Boolean, default=True)
# Metadata
created_by = Column(String, nullable=False) # User ID
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")
def __repr__(self):
return f"<WorkflowDefinition(id='{self.id}', name='{self.name}')>"
class WorkflowExecution(Base):
"""Workflow execution instance"""
__tablename__ = "workflow_executions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
workflow_id = Column(String, ForeignKey("workflow_definitions.id"), nullable=False)
# Execution state
status = Column(SQLEnum(WorkflowStatus), default=WorkflowStatus.PENDING)
current_step = Column(String) # Current step ID
# Execution data
input_data = Column(JSON, default=dict)
context = Column(JSON, default=dict)
results = Column(JSON, default=dict)
error = Column(Text)
# Timing
started_at = Column(DateTime)
completed_at = Column(DateTime)
# Metadata
executed_by = Column(String, nullable=False) # User ID or system
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
workflow = relationship("WorkflowDefinition", back_populates="executions")
step_logs = relationship("WorkflowStepLog", back_populates="execution", cascade="all, delete-orphan")
def __repr__(self):
return f"<WorkflowExecution(id='{self.id}', workflow_id='{self.workflow_id}', status='{self.status}')>"
class WorkflowStepLog(Base):
"""Individual step execution log"""
__tablename__ = "workflow_step_logs"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
execution_id = Column(String, ForeignKey("workflow_executions.id"), nullable=False)
# Step information
step_id = Column(String, nullable=False)
step_name = Column(String(255), nullable=False)
step_type = Column(String(50), nullable=False)
# Execution details
status = Column(String(50), nullable=False) # started, completed, failed
input_data = Column(JSON, default=dict)
output_data = Column(JSON, default=dict)
error = Column(Text)
# Timing
started_at = Column(DateTime, default=datetime.utcnow)
completed_at = Column(DateTime)
duration_ms = Column(Integer) # Duration in milliseconds
# Metadata
retry_count = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
execution = relationship("WorkflowExecution", back_populates="step_logs")
def __repr__(self):
return f"<WorkflowStepLog(id='{self.id}', step_name='{self.step_name}', status='{self.status}')>"

View File

@@ -0,0 +1,3 @@
"""
Services package
"""

View File

@@ -0,0 +1,896 @@
"""
Analytics service for request tracking, usage metrics, and performance monitoring
Integrated with the core app for budget tracking and token usage analysis.
"""
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from collections import defaultdict, deque
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc
from app.core.config import settings
from app.core.logging import get_logger
from app.models.usage_tracking import UsageTracking
from app.models.api_key import APIKey
from app.models.budget import Budget
from app.models.user import User
logger = get_logger(__name__)
@dataclass
class RequestEvent:
"""Enhanced request event data structure with budget integration"""
timestamp: datetime
method: str
path: str
status_code: int
response_time: float
user_id: Optional[int] = None
api_key_id: Optional[int] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
request_size: int = 0
response_size: int = 0
error_message: Optional[str] = None
# Token and cost tracking
model: Optional[str] = None
request_tokens: int = 0
response_tokens: int = 0
total_tokens: int = 0
cost_cents: int = 0
# Budget information
budget_ids: Optional[List[int]] = None
budget_warnings: Optional[List[str]] = None
@dataclass
class UsageMetrics:
"""Usage metrics including costs and tokens"""
total_requests: int
successful_requests: int
failed_requests: int
avg_response_time: float
requests_per_minute: float
error_rate: float
# Token and cost metrics
total_tokens: int
total_cost_cents: int
avg_tokens_per_request: float
avg_cost_per_request_cents: float
# Budget metrics
total_budget_cents: int
used_budget_cents: int
budget_usage_percentage: float
active_budgets: int
# Time-based metrics
top_endpoints: List[Dict[str, Any]]
status_codes: Dict[str, int]
top_models: List[Dict[str, Any]]
timestamp: datetime
@dataclass
class SystemHealth:
"""System health including budget and usage analysis"""
status: str # healthy, warning, critical
score: int # 0-100
issues: List[str]
recommendations: List[str]
# Performance metrics
avg_response_time: float
error_rate: float
requests_per_minute: float
# Budget health
budget_usage_percentage: float
budgets_near_limit: int
budgets_exceeded: int
timestamp: datetime
class AnalyticsService:
"""Analytics service for comprehensive request and usage tracking"""
def __init__(self, db: Session):
self.db = db
self.enabled = True
self.events: deque = deque(maxlen=10000) # Keep last 10k events in memory
self.metrics_cache = {}
self.cache_ttl = 300 # 5 minutes cache TTL
# Statistics counters
self.endpoint_stats = defaultdict(lambda: {
"count": 0,
"total_time": 0,
"errors": 0,
"avg_time": 0,
"total_tokens": 0,
"total_cost_cents": 0
})
self.status_codes = defaultdict(int)
self.model_stats = defaultdict(lambda: {
"count": 0,
"total_tokens": 0,
"total_cost_cents": 0
})
# Start background cleanup task (requires a running event loop, so construct this service during app startup)
asyncio.create_task(self._cleanup_old_events())
async def track_request(self, event: RequestEvent):
"""Track a request event with comprehensive metrics"""
if not self.enabled:
return
try:
# Add to events queue
self.events.append(event)
# Update endpoint stats
endpoint = f"{event.method} {event.path}"
stats = self.endpoint_stats[endpoint]
stats["count"] += 1
stats["total_time"] += event.response_time
stats["avg_time"] = stats["total_time"] / stats["count"]
stats["total_tokens"] += event.total_tokens
stats["total_cost_cents"] += event.cost_cents
if event.status_code >= 400:
stats["errors"] += 1
# Update status code stats
self.status_codes[str(event.status_code)] += 1
# Update model stats
if event.model:
model_stats = self.model_stats[event.model]
model_stats["count"] += 1
model_stats["total_tokens"] += event.total_tokens
model_stats["total_cost_cents"] += event.cost_cents
# Clear metrics cache to force recalculation
self.metrics_cache.clear()
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
except Exception as e:
logger.error(f"Error tracking request: {e}")
async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
api_key_id: Optional[int] = None) -> UsageMetrics:
"""Get comprehensive usage metrics including costs and budgets"""
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
# Check cache
if cache_key in self.metrics_cache:
cached_time, cached_data = self.metrics_cache[cache_key]
if datetime.utcnow() - cached_time < timedelta(seconds=self.cache_ttl):
return cached_data
try:
cutoff_time = datetime.utcnow() - timedelta(hours=hours)
# Build query filters
filters = [UsageTracking.created_at >= cutoff_time]
if user_id:
filters.append(UsageTracking.user_id == user_id)
if api_key_id:
filters.append(UsageTracking.api_key_id == api_key_id)
# Get usage tracking records
usage_records = self.db.query(UsageTracking).filter(and_(*filters)).all()
# Get recent events from memory
recent_events = [e for e in self.events if e.timestamp >= cutoff_time]
if user_id:
recent_events = [e for e in recent_events if e.user_id == user_id]
if api_key_id:
recent_events = [e for e in recent_events if e.api_key_id == api_key_id]
# Calculate basic request metrics
total_requests = len(recent_events)
successful_requests = sum(1 for e in recent_events if e.status_code < 400)
failed_requests = total_requests - successful_requests
if total_requests > 0:
avg_response_time = sum(e.response_time for e in recent_events) / total_requests
requests_per_minute = total_requests / (hours * 60)
error_rate = (failed_requests / total_requests) * 100
else:
avg_response_time = 0
requests_per_minute = 0
error_rate = 0
# Calculate token and cost metrics from database
total_tokens = sum(r.total_tokens for r in usage_records)
total_cost_cents = sum(r.cost_cents for r in usage_records)
if total_requests > 0:
avg_tokens_per_request = total_tokens / total_requests
avg_cost_per_request_cents = total_cost_cents / total_requests
else:
avg_tokens_per_request = 0
avg_cost_per_request_cents = 0
# Get budget information
budget_query = self.db.query(Budget).filter(Budget.is_active == True)
if user_id:
budget_query = budget_query.filter(
or_(Budget.user_id == user_id, Budget.api_key_id.in_(
self.db.query(APIKey.id).filter(APIKey.user_id == user_id).subquery()
))
)
budgets = budget_query.all()
active_budgets = len(budgets)
total_budget_cents = sum(b.limit_cents for b in budgets)
used_budget_cents = sum(b.current_usage_cents for b in budgets)
if total_budget_cents > 0:
budget_usage_percentage = (used_budget_cents / total_budget_cents) * 100
else:
budget_usage_percentage = 0
# Top endpoints from memory
endpoint_counts = defaultdict(int)
for event in recent_events:
endpoint = f"{event.method} {event.path}"
endpoint_counts[endpoint] += 1
top_endpoints = [
{"endpoint": endpoint, "count": count}
for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
]
# Status codes from memory
status_counts = defaultdict(int)
for event in recent_events:
status_counts[str(event.status_code)] += 1
# Top models from database
model_usage = self.db.query(
UsageTracking.model,
func.count(UsageTracking.id).label('count'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost')
).filter(and_(*filters)).filter(
UsageTracking.model.is_not(None)
).group_by(UsageTracking.model).order_by(desc('count')).limit(10).all()
top_models = [
{
"model": model,
"count": count,
"total_tokens": tokens or 0,
"total_cost_cents": cost or 0
}
for model, count, tokens, cost in model_usage
]
# Create metrics object
metrics = UsageMetrics(
total_requests=total_requests,
successful_requests=successful_requests,
failed_requests=failed_requests,
avg_response_time=round(avg_response_time, 3),
requests_per_minute=round(requests_per_minute, 2),
error_rate=round(error_rate, 2),
total_tokens=total_tokens,
total_cost_cents=total_cost_cents,
avg_tokens_per_request=round(avg_tokens_per_request, 1),
avg_cost_per_request_cents=round(avg_cost_per_request_cents, 2),
total_budget_cents=total_budget_cents,
used_budget_cents=used_budget_cents,
budget_usage_percentage=round(budget_usage_percentage, 2),
active_budgets=active_budgets,
top_endpoints=top_endpoints,
status_codes=dict(status_counts),
top_models=top_models,
timestamp=datetime.utcnow()
)
# Cache the result
self.metrics_cache[cache_key] = (datetime.utcnow(), metrics)
return metrics
except Exception as e:
logger.error(f"Error getting usage metrics: {e}")
return UsageMetrics(
total_requests=0, successful_requests=0, failed_requests=0,
avg_response_time=0, requests_per_minute=0, error_rate=0,
total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
avg_cost_per_request_cents=0, total_budget_cents=0,
used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
top_endpoints=[], status_codes={}, top_models=[],
timestamp=datetime.utcnow()
)
async def get_system_health(self) -> SystemHealth:
"""Get comprehensive system health including budget status"""
try:
# Get recent metrics
metrics = await self.get_usage_metrics(hours=1)
# Calculate health score
health_score = 100
issues = []
recommendations = []
# Check error rate
if metrics.error_rate > 10:
health_score -= 30
issues.append(f"High error rate: {metrics.error_rate:.1f}%")
recommendations.append("Investigate error patterns and root causes")
elif metrics.error_rate > 5:
health_score -= 15
issues.append(f"Elevated error rate: {metrics.error_rate:.1f}%")
recommendations.append("Monitor error trends")
# Check response time
if metrics.avg_response_time > 5.0:
health_score -= 25
issues.append(f"High response time: {metrics.avg_response_time:.2f}s")
recommendations.append("Optimize slow endpoints and database queries")
elif metrics.avg_response_time > 2.0:
health_score -= 10
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
recommendations.append("Monitor performance trends")
# Check budget usage
if metrics.budget_usage_percentage > 90:
health_score -= 20
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
recommendations.append("Review budget limits and usage patterns")
elif metrics.budget_usage_percentage > 75:
health_score -= 10
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
recommendations.append("Monitor spending trends")
# Check for budgets near or over limit
budgets = self.db.query(Budget).filter(Budget.is_active == True).all()
budgets_near_limit = sum(1 for b in budgets if b.current_usage_cents >= b.limit_cents * 0.8)
budgets_exceeded = sum(1 for b in budgets if b.is_exceeded)
if budgets_exceeded > 0:
health_score -= 25
issues.append(f"{budgets_exceeded} budgets exceeded")
recommendations.append("Address budget overruns immediately")
elif budgets_near_limit > 0:
health_score -= 10
issues.append(f"{budgets_near_limit} budgets near limit")
recommendations.append("Review budget allocations")
# Determine overall status
if health_score >= 90:
status = "healthy"
elif health_score >= 70:
status = "warning"
else:
status = "critical"
return SystemHealth(
status=status,
score=max(0, health_score),
issues=issues,
recommendations=recommendations,
avg_response_time=metrics.avg_response_time,
error_rate=metrics.error_rate,
requests_per_minute=metrics.requests_per_minute,
budget_usage_percentage=metrics.budget_usage_percentage,
budgets_near_limit=budgets_near_limit,
budgets_exceeded=budgets_exceeded,
timestamp=datetime.utcnow()
)
except Exception as e:
logger.error(f"Error getting system health: {e}")
return SystemHealth(
status="error", score=0,
issues=[f"Health check failed: {str(e)}"],
recommendations=["Check system logs and restart services"],
avg_response_time=0, error_rate=0, requests_per_minute=0,
budget_usage_percentage=0, budgets_near_limit=0,
budgets_exceeded=0, timestamp=datetime.utcnow()
)
async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
"""Get detailed cost analysis and trends"""
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
# Build query filters
filters = [UsageTracking.created_at >= cutoff_time]
if user_id:
filters.append(UsageTracking.user_id == user_id)
# Get usage records
usage_records = self.db.query(UsageTracking).filter(and_(*filters)).all()
# Cost by model
cost_by_model = defaultdict(int)
tokens_by_model = defaultdict(int)
requests_by_model = defaultdict(int)
for record in usage_records:
if record.model:
cost_by_model[record.model] += record.cost_cents
tokens_by_model[record.model] += record.total_tokens
requests_by_model[record.model] += 1
# Daily cost trends
daily_costs = defaultdict(int)
for record in usage_records:
day = record.created_at.date().isoformat()
daily_costs[day] += record.cost_cents
# Cost by endpoint
cost_by_endpoint = defaultdict(int)
for record in usage_records:
cost_by_endpoint[record.endpoint] += record.cost_cents
# Calculate efficiency metrics
total_cost = sum(cost_by_model.values())
total_tokens = sum(tokens_by_model.values())
total_requests = len(usage_records)
efficiency_metrics = {
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
}
return {
"period_days": days,
"total_cost_cents": total_cost,
"total_cost_dollars": total_cost / 100,
"total_tokens": total_tokens,
"total_requests": total_requests,
"efficiency_metrics": efficiency_metrics,
"cost_by_model": dict(cost_by_model),
"tokens_by_model": dict(tokens_by_model),
"requests_by_model": dict(requests_by_model),
"daily_costs": dict(daily_costs),
"cost_by_endpoint": dict(cost_by_endpoint),
"analysis_timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting cost analysis: {e}")
return {"error": str(e)}
async def _cleanup_old_events(self):
"""Cleanup old events from memory"""
while self.enabled:
try:
cutoff_time = datetime.utcnow() - timedelta(hours=24)
# Remove old events
while self.events and self.events[0].timestamp < cutoff_time:
self.events.popleft()
# Clear old cache entries
current_time = datetime.utcnow()
expired_keys = []
for key, (cached_time, _) in self.metrics_cache.items():
if current_time - cached_time > timedelta(seconds=self.cache_ttl):
expired_keys.append(key)
for key in expired_keys:
del self.metrics_cache[key]
# Sleep for 1 hour before next cleanup
await asyncio.sleep(3600)
except Exception as e:
logger.error(f"Error in analytics cleanup: {e}")
await asyncio.sleep(300) # Wait 5 minutes on error
def cleanup(self):
"""Cleanup analytics resources"""
self.enabled = False
self.events.clear()
self.metrics_cache.clear()
self.endpoint_stats.clear()
self.status_codes.clear()
self.model_stats.clear()
# Global analytics service will be initialized in main.py
analytics_service = None
def get_analytics_service():
"""Get the global analytics service instance"""
if analytics_service is None:
raise RuntimeError("Analytics service not initialized")
return analytics_service
def init_analytics_service():
"""Initialize the global analytics service"""
global analytics_service
# Initialize without database session - will be provided per request
analytics_service = InMemoryAnalyticsService()
logger.info("Analytics service initialized")
class InMemoryAnalyticsService:
"""Analytics service that works without a persistent database session"""
def __init__(self):
self.enabled = True
self.events: deque = deque(maxlen=10000) # Keep last 10k events in memory
self.metrics_cache = {}
self.cache_ttl = 300 # 5 minutes cache TTL
# Statistics counters
self.endpoint_stats = defaultdict(lambda: {
"count": 0,
"total_time": 0,
"errors": 0,
"avg_time": 0,
"total_tokens": 0,
"total_cost_cents": 0
})
self.status_codes = defaultdict(int)
self.model_stats = defaultdict(lambda: {
"count": 0,
"total_tokens": 0,
"total_cost_cents": 0
})
# Start background cleanup task (requires a running event loop, so construct this service during app startup)
asyncio.create_task(self._cleanup_old_events())
async def track_request(self, event: RequestEvent):
"""Track a request event with comprehensive metrics"""
if not self.enabled:
return
try:
# Add to events queue
self.events.append(event)
# Update endpoint stats
endpoint = f"{event.method} {event.path}"
stats = self.endpoint_stats[endpoint]
stats["count"] += 1
stats["total_time"] += event.response_time
stats["avg_time"] = stats["total_time"] / stats["count"]
stats["total_tokens"] += event.total_tokens
stats["total_cost_cents"] += event.cost_cents
if event.status_code >= 400:
stats["errors"] += 1
# Update status code stats
self.status_codes[str(event.status_code)] += 1
# Update model stats
if event.model:
model_stats = self.model_stats[event.model]
model_stats["count"] += 1
model_stats["total_tokens"] += event.total_tokens
model_stats["total_cost_cents"] += event.cost_cents
# Clear metrics cache to force recalculation
self.metrics_cache.clear()
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
except Exception as e:
logger.error(f"Error tracking request: {e}")
async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
api_key_id: Optional[int] = None) -> UsageMetrics:
"""Get comprehensive usage metrics including costs and budgets"""
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
# Check cache
if cache_key in self.metrics_cache:
cached_time, cached_data = self.metrics_cache[cache_key]
if datetime.utcnow() - cached_time < timedelta(seconds=self.cache_ttl):
return cached_data
try:
cutoff_time = datetime.utcnow() - timedelta(hours=hours)
# Get recent events from memory
recent_events = [e for e in self.events if e.timestamp >= cutoff_time]
if user_id:
recent_events = [e for e in recent_events if e.user_id == user_id]
if api_key_id:
recent_events = [e for e in recent_events if e.api_key_id == api_key_id]
# Calculate basic request metrics
total_requests = len(recent_events)
successful_requests = sum(1 for e in recent_events if e.status_code < 400)
failed_requests = total_requests - successful_requests
if total_requests > 0:
avg_response_time = sum(e.response_time for e in recent_events) / total_requests
requests_per_minute = total_requests / (hours * 60)
error_rate = (failed_requests / total_requests) * 100
else:
avg_response_time = 0
requests_per_minute = 0
error_rate = 0
# Calculate token and cost metrics from events
total_tokens = sum(e.total_tokens for e in recent_events)
total_cost_cents = sum(e.cost_cents for e in recent_events)
if total_requests > 0:
avg_tokens_per_request = total_tokens / total_requests
avg_cost_per_request_cents = total_cost_cents / total_requests
else:
avg_tokens_per_request = 0
avg_cost_per_request_cents = 0
# Mock budget information (since we don't have DB access here)
total_budget_cents = 100000 # $1000 default
used_budget_cents = total_cost_cents
if total_budget_cents > 0:
budget_usage_percentage = (used_budget_cents / total_budget_cents) * 100
else:
budget_usage_percentage = 0
# Top endpoints from memory
endpoint_counts = defaultdict(int)
for event in recent_events:
endpoint = f"{event.method} {event.path}"
endpoint_counts[endpoint] += 1
top_endpoints = [
{"endpoint": endpoint, "count": count}
for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
]
# Status codes from memory
status_counts = defaultdict(int)
for event in recent_events:
status_counts[str(event.status_code)] += 1
# Top models from events
model_usage = defaultdict(lambda: {"count": 0, "tokens": 0, "cost": 0})
for event in recent_events:
if event.model:
model_usage[event.model]["count"] += 1
model_usage[event.model]["tokens"] += event.total_tokens
model_usage[event.model]["cost"] += event.cost_cents
top_models = [
{
"model": model,
"count": data["count"],
"total_tokens": data["tokens"],
"total_cost_cents": data["cost"]
}
for model, data in sorted(model_usage.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
]
# Create metrics object
metrics = UsageMetrics(
total_requests=total_requests,
successful_requests=successful_requests,
failed_requests=failed_requests,
avg_response_time=round(avg_response_time, 3),
requests_per_minute=round(requests_per_minute, 2),
error_rate=round(error_rate, 2),
total_tokens=total_tokens,
total_cost_cents=total_cost_cents,
avg_tokens_per_request=round(avg_tokens_per_request, 1),
avg_cost_per_request_cents=round(avg_cost_per_request_cents, 2),
total_budget_cents=total_budget_cents,
used_budget_cents=used_budget_cents,
budget_usage_percentage=round(budget_usage_percentage, 2),
active_budgets=1, # Mock value
top_endpoints=top_endpoints,
status_codes=dict(status_counts),
top_models=top_models,
timestamp=datetime.utcnow()
)
# Cache the result
self.metrics_cache[cache_key] = (datetime.utcnow(), metrics)
return metrics
except Exception as e:
logger.error(f"Error getting usage metrics: {e}")
return UsageMetrics(
total_requests=0, successful_requests=0, failed_requests=0,
avg_response_time=0, requests_per_minute=0, error_rate=0,
total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
avg_cost_per_request_cents=0, total_budget_cents=0,
used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
top_endpoints=[], status_codes={}, top_models=[],
timestamp=datetime.utcnow()
)
async def get_system_health(self) -> SystemHealth:
"""Get comprehensive system health including budget status"""
try:
# Get recent metrics
metrics = await self.get_usage_metrics(hours=1)
# Calculate health score
health_score = 100
issues = []
recommendations = []
# Check error rate
if metrics.error_rate > 10:
health_score -= 30
issues.append(f"High error rate: {metrics.error_rate:.1f}%")
recommendations.append("Investigate error patterns and root causes")
elif metrics.error_rate > 5:
health_score -= 15
issues.append(f"Elevated error rate: {metrics.error_rate:.1f}%")
recommendations.append("Monitor error trends")
# Check response time
if metrics.avg_response_time > 5.0:
health_score -= 25
issues.append(f"High response time: {metrics.avg_response_time:.2f}s")
recommendations.append("Optimize slow endpoints and database queries")
elif metrics.avg_response_time > 2.0:
health_score -= 10
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
recommendations.append("Monitor performance trends")
# Check budget usage
if metrics.budget_usage_percentage > 90:
health_score -= 20
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
recommendations.append("Review budget limits and usage patterns")
elif metrics.budget_usage_percentage > 75:
health_score -= 10
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
recommendations.append("Monitor spending trends")
# Determine overall status
if health_score >= 90:
status = "healthy"
elif health_score >= 70:
status = "warning"
else:
status = "critical"
return SystemHealth(
status=status,
score=max(0, health_score),
issues=issues,
recommendations=recommendations,
avg_response_time=metrics.avg_response_time,
error_rate=metrics.error_rate,
requests_per_minute=metrics.requests_per_minute,
budget_usage_percentage=metrics.budget_usage_percentage,
budgets_near_limit=0, # Mock values since no DB access
budgets_exceeded=0,
timestamp=datetime.utcnow()
)
except Exception as e:
logger.error(f"Error getting system health: {e}")
return SystemHealth(
status="error", score=0,
issues=[f"Health check failed: {str(e)}"],
recommendations=["Check system logs and restart services"],
avg_response_time=0, error_rate=0, requests_per_minute=0,
budget_usage_percentage=0, budgets_near_limit=0,
budgets_exceeded=0, timestamp=datetime.utcnow()
)
async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
"""Get detailed cost analysis and trends"""
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
# Get events from memory
events = [e for e in self.events if e.timestamp >= cutoff_time]
if user_id:
events = [e for e in events if e.user_id == user_id]
# Cost by model
cost_by_model = defaultdict(int)
tokens_by_model = defaultdict(int)
requests_by_model = defaultdict(int)
for event in events:
if event.model:
cost_by_model[event.model] += event.cost_cents
tokens_by_model[event.model] += event.total_tokens
requests_by_model[event.model] += 1
# Daily cost trends
daily_costs = defaultdict(int)
for event in events:
day = event.timestamp.date().isoformat()
daily_costs[day] += event.cost_cents
# Cost by endpoint
cost_by_endpoint = defaultdict(int)
for event in events:
endpoint = f"{event.method} {event.path}"
cost_by_endpoint[endpoint] += event.cost_cents
# Calculate efficiency metrics
total_cost = sum(cost_by_model.values())
total_tokens = sum(tokens_by_model.values())
total_requests = len(events)
efficiency_metrics = {
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
}
return {
"period_days": days,
"total_cost_cents": total_cost,
"total_cost_dollars": total_cost / 100,
"total_tokens": total_tokens,
"total_requests": total_requests,
"efficiency_metrics": efficiency_metrics,
"cost_by_model": dict(cost_by_model),
"tokens_by_model": dict(tokens_by_model),
"requests_by_model": dict(requests_by_model),
"daily_costs": dict(daily_costs),
"cost_by_endpoint": dict(cost_by_endpoint),
"analysis_timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting cost analysis: {e}")
return {"error": str(e)}
async def _cleanup_old_events(self):
"""Cleanup old events from memory"""
while self.enabled:
try:
cutoff_time = datetime.utcnow() - timedelta(hours=24)
# Remove old events
while self.events and self.events[0].timestamp < cutoff_time:
self.events.popleft()
# Clear old cache entries
current_time = datetime.utcnow()
expired_keys = []
for key, (cached_time, _) in self.metrics_cache.items():
if current_time - cached_time > timedelta(seconds=self.cache_ttl):
expired_keys.append(key)
for key in expired_keys:
del self.metrics_cache[key]
# Sleep for 1 hour before next cleanup
await asyncio.sleep(3600)
except Exception as e:
logger.error(f"Error in analytics cleanup: {e}")
await asyncio.sleep(300) # Wait 5 minutes on error
def cleanup(self):
"""Cleanup analytics resources"""
self.enabled = False
self.events.clear()
self.metrics_cache.clear()
self.endpoint_stats.clear()
self.status_codes.clear()
self.model_stats.clear()
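A minimal usage sketch for the analytics service above (illustrative only; the module path app.services.analytics and the startup hook are assumptions, not part of this commit):

from datetime import datetime

from app.services.analytics import (  # assumed module path for this file
    RequestEvent,
    get_analytics_service,
    init_analytics_service,
)

async def on_startup():
    # Create the global in-memory service once the event loop is running,
    # so the asyncio.create_task() call in its __init__ succeeds.
    init_analytics_service()

async def on_request_finished(method: str, path: str, status: int, seconds: float):
    # Record one finished request; token, cost, and budget fields stay at their defaults here.
    event = RequestEvent(
        timestamp=datetime.utcnow(),
        method=method,
        path=path,
        status_code=status,
        response_time=seconds,
    )
    await get_analytics_service().track_request(event)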

View File

@@ -0,0 +1,248 @@
"""
API Key Authentication Service
Handles API key validation and user authentication with Redis caching for performance
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import HTTPException, Request, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.security import verify_api_key
from app.db.database import get_db
from app.models.api_key import APIKey
from app.models.user import User
from app.utils.exceptions import AuthenticationError, AuthorizationError
from app.services.cached_api_key import cached_api_key_service
logger = logging.getLogger(__name__)
class APIKeyAuthService:
"""Service for API key authentication and validation"""
def __init__(self, db: AsyncSession):
self.db = db
async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:
"""Validate API key and return user context using Redis cache for performance"""
try:
if not api_key:
return None
# Extract key prefix for lookup
if len(api_key) < 8:
logger.warning("Invalid API key format: too short")
return None
key_prefix = api_key[:8]
# Try cached verification first
cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)
# Get API key data from cache or database
context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)
if not context:
logger.warning(f"API key not found: {key_prefix}")
return None
api_key_obj = context["api_key"]
# If not in verification cache, verify and cache the result
if not cached_verification:
# Get the actual key hash for verification (this should be in the cached context)
db_api_key = None
if not hasattr(api_key_obj, 'key_hash'):
# Fallback: fetch full API key from database for hash
stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
result = await self.db.execute(stmt)
db_api_key = result.scalar_one_or_none()
if not db_api_key:
return None
key_hash = db_api_key.key_hash
else:
key_hash = api_key_obj.key_hash
# Verify the API key hash
if not verify_api_key(api_key, key_hash):
logger.warning(f"Invalid API key hash: {key_prefix}")
return None
# Cache successful verification
await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)
# Check if key is valid (expiry, active status)
if not api_key_obj.is_valid():
logger.warning(f"API key expired or inactive: {key_prefix}")
# Invalidate cache for expired keys
await cached_api_key_service.invalidate_api_key_cache(key_prefix)
return None
# Check IP restrictions
client_ip = request.client.host if request.client else "unknown"
if not api_key_obj.can_access_from_ip(client_ip):
logger.warning(f"IP not allowed for API key {key_prefix}: {client_ip}")
return None
# Update last used timestamp asynchronously (performance optimization)
await cached_api_key_service.update_last_used(context["api_key_id"], self.db)
return context
except Exception as e:
logger.error(f"API key validation error: {e}")
return None
async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:
"""Check if API key has permission to access endpoint"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.can_access_endpoint(endpoint)
async def check_model_permission(self, context: Dict[str, Any], model: str) -> bool:
"""Check if API key has permission to access model"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.can_access_model(model)
async def check_scope_permission(self, context: Dict[str, Any], scope: str) -> bool:
"""Check if API key has required scope"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.has_scope(scope)
async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0):
"""Update API key usage statistics"""
try:
api_key: APIKey = context.get("api_key")
if api_key:
api_key.update_usage(tokens_used, cost_cents)
await self.db.commit()
logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
except Exception as e:
logger.error(f"Failed to update usage stats: {e}")
async def get_api_key_context(
request: Request,
db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Dependency to get API key context from request"""
auth_service = APIKeyAuthService(db)
# Try different auth methods
api_key = None
# 1. Check Authorization header (Bearer token)
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
api_key = auth_header[7:]
# 2. Check X-API-Key header
if not api_key:
api_key = request.headers.get("X-API-Key")
# 3. Check query parameter
if not api_key:
api_key = request.query_params.get("api_key")
if not api_key:
return None
return await auth_service.validate_api_key(api_key, request)
async def require_api_key(
context: Dict[str, Any] = Depends(get_api_key_context)
) -> Dict[str, Any]:
"""Dependency that requires valid API key"""
if not context:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Valid API key required",
headers={"WWW-Authenticate": "Bearer"}
)
return context
async def get_current_api_key_user(
context: Dict[str, Any] = Depends(require_api_key)
) -> tuple:
"""
Dependency that returns current user and API key as a tuple
Returns:
tuple: (user, api_key)
"""
user = context.get("user")
api_key = context.get("api_key")
if not user or not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User or API key not found in context"
)
return user, api_key
async def get_api_key_auth(
context: Dict[str, Any] = Depends(require_api_key)
) -> APIKey:
"""
Dependency that returns the authenticated API key object
Returns:
APIKey: The authenticated API key object
"""
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key not found in context"
)
return api_key
class RequireScope:
"""Dependency class for scope checking"""
def __init__(self, scope: str):
self.scope = scope
async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
auth_service = APIKeyAuthService(context.get("db"))
if not await auth_service.check_scope_permission(context, self.scope):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Scope '{self.scope}' required"
)
return context
class RequireModel:
"""Dependency class for model access checking"""
def __init__(self, model: str):
self.model = model
async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
auth_service = APIKeyAuthService(context.get("db"))
if not await auth_service.check_model_permission(context, self.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{self.model}' not allowed"
)
return context
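A short sketch of how these dependencies might be attached to a route (hypothetical endpoint and scope name, not part of the commit):

from fastapi import APIRouter, Depends

from app.services.api_key_auth import (  # assumed module path for this file
    RequireScope,
    get_current_api_key_user,
)

router = APIRouter()

@router.get("/v1/keys/me")
async def whoami(
    user_and_key: tuple = Depends(get_current_api_key_user),
    _: dict = Depends(RequireScope("keys:read")),  # hypothetical scope
):
    # Only reached when the key is valid, active, IP-allowed, and carries the scope.
    user, api_key = user_and_key
    return {"user_id": user.id, "key_prefix": api_key.key_prefix}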

View File

@@ -0,0 +1,606 @@
"""
API Proxy with comprehensive security interceptors
"""
import json
import time
import re
from typing import Dict, List, Any, Optional
from fastapi import Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
import httpx
import yaml
from pathlib import Path
from app.core.config import settings
from app.core.logging import get_logger
from app.services.api_key_auth import get_api_key_info
from app.services.budget_enforcement import check_budget_and_record_usage
from app.middleware.rate_limiting import rate_limiter
from app.utils.exceptions import ValidationError, AuthenticationError, RateLimitExceeded
from app.services.audit_service import create_audit_log
logger = get_logger(__name__)
class SecurityConfiguration:
"""Security configuration for API proxy"""
def __init__(self):
self.config = self._load_security_config()
def _load_security_config(self) -> Dict[str, Any]:
"""Load security configuration"""
return {
"rate_limits": {
"global": 10000, # per hour
"per_key": 1000, # per hour
"per_endpoint": {
"/api/llm/v1/chat/completions": 100, # per minute
"/api/modules/v1/rag/search": 500, # per hour
}
},
"max_request_size": 10 * 1024 * 1024, # 10MB
"max_string_length": 50000,
"timeout": 30, # seconds
"required_headers": ["X-API-Key"],
"ip_whitelist_enabled": False,
"ip_whitelist": [],
"ip_blacklist": [],
"forbidden_patterns": [
"<script", "javascript:", "data:text/html", "vbscript:",
"union select", "drop table", "insert into", "delete from"
],
"audit": {
"enabled": True,
"include_request_body": False,
"include_response_body": False,
"sensitive_paths": ["/api/platform/v1/auth"]
}
}
class RequestValidator:
"""Validates API requests against schemas and security policies"""
def __init__(self, config: SecurityConfiguration):
self.config = config
self.schemas = self._load_openapi_schemas()
def _load_openapi_schemas(self) -> Dict[str, Any]:
"""Load OpenAPI schemas for validation"""
# Would load actual OpenAPI schemas in production
return {
"POST /api/llm/v1/chat/completions": {
"requestBody": {
"type": "object",
"required": ["model", "messages"],
"properties": {
"model": {"type": "string"},
"messages": {"type": "array"},
"temperature": {"type": "number", "minimum": 0, "maximum": 2},
"max_tokens": {"type": "integer", "minimum": 1, "maximum": 32000}
}
}
},
"POST /api/modules/v1/rag/search": {
"requestBody": {
"type": "object",
"required": ["query"],
"properties": {
"query": {"type": "string", "maxLength": 1000},
"limit": {"type": "integer", "minimum": 1, "maximum": 100}
}
}
}
}
async def validate(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
"""Validate request against schema and security policies"""
# Check request size
body_bytes = len(json.dumps(body).encode())
if body_bytes > self.config.config["max_request_size"]:
raise ValidationError(f"Request size {body_bytes} bytes exceeds the configured maximum")
# Check required headers
for header in self.config.config["required_headers"]:
if header not in headers:
raise ValidationError(f"Missing required header: {header}")
# Validate against schema if available
schema_key = f"{method.upper()} {path}"
if schema_key in self.schemas:
await self._validate_against_schema(body, self.schemas[schema_key])
# Security validation
self._validate_security_patterns(body)
return body
async def _validate_against_schema(self, body: Dict, schema: Dict):
"""Validate request body against OpenAPI schema"""
request_schema = schema.get("requestBody", {})
# Basic validation (would use proper JSON schema validator in production)
if "required" in request_schema:
for field in request_schema["required"]:
if field not in body:
raise ValidationError(f"Missing required field: {field}")
if "properties" in request_schema:
for field, constraints in request_schema["properties"].items():
if field in body:
await self._validate_field(field, body[field], constraints)
async def _validate_field(self, field_name: str, value: Any, constraints: Dict):
"""Validate individual field against constraints"""
field_type = constraints.get("type")
if field_type == "string":
if not isinstance(value, str):
raise ValidationError(f"Field {field_name} must be a string")
if "maxLength" in constraints and len(value) > constraints["maxLength"]:
raise ValidationError(f"Field {field_name} exceeds maximum length")
elif field_type == "integer":
if not isinstance(value, int):
raise ValidationError(f"Field {field_name} must be an integer")
if "minimum" in constraints and value < constraints["minimum"]:
raise ValidationError(f"Field {field_name} below minimum value")
if "maximum" in constraints and value > constraints["maximum"]:
raise ValidationError(f"Field {field_name} exceeds maximum value")
elif field_type == "number":
if not isinstance(value, (int, float)):
raise ValidationError(f"Field {field_name} must be a number")
if "minimum" in constraints and value < constraints["minimum"]:
raise ValidationError(f"Field {field_name} below minimum value")
if "maximum" in constraints and value > constraints["maximum"]:
raise ValidationError(f"Field {field_name} exceeds maximum value")
def _validate_security_patterns(self, body: Dict):
"""Check for forbidden security patterns"""
body_str = json.dumps(body).lower()
for pattern in self.config.config["forbidden_patterns"]:
if pattern.lower() in body_str:
raise ValidationError(f"Request contains forbidden pattern: {pattern}")
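# Example (illustrative): a body such as {"query": "1 UNION SELECT password"}
# serializes to '..."1 union select password"...' once lowercased, matches the
# "union select" entry in forbidden_patterns, and is rejected with ValidationError.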
class APISecurityProxy:
"""Main API security proxy with interceptor pattern"""
def __init__(self):
self.config = SecurityConfiguration()
self.request_validator = RequestValidator(self.config)
async def proxy_request(self, request: Request, path: str) -> Response:
"""
Main proxy method that implements the full interceptor pattern
"""
start_time = time.time()
api_key_info = None
user_permissions = []
try:
# 1. Extract and validate API key
api_key_info = await self._extract_and_validate_api_key(request)
if api_key_info:
user_permissions = api_key_info.get("permissions", [])
# 2. IP validation (if enabled)
await self._validate_ip_address(request)
# 3. Rate limiting
await self._check_rate_limits(request, path, api_key_info)
# 4. Request validation and sanitization
request_body = await self._get_request_body(request)
validated_body = await self.request_validator.validate(
path=path,
method=request.method,
body=request_body,
headers=dict(request.headers)
)
# 5. Sanitize request
sanitized_body = self._sanitize_request(validated_body)
# 6. Budget checking (for LLM endpoints)
if path.startswith("/api/llm/"):
await self._check_budget_constraints(api_key_info, sanitized_body)
# 7. Build proxy headers
proxy_headers = self._build_proxy_headers(request, api_key_info)
# 8. Log security event
await self._log_security_event(
request=request,
path=path,
api_key_info=api_key_info,
sanitized_body=sanitized_body
)
# 9. Forward request to appropriate backend
response = await self._forward_request(
path=path,
method=request.method,
body=sanitized_body,
headers=proxy_headers
)
# 10. Validate and sanitize response
validated_response = await self._process_response(path, response)
# 11. Record usage metrics
await self._record_usage_metrics(
api_key_info=api_key_info,
path=path,
duration=time.time() - start_time,
success=True
)
return validated_response
except Exception as e:
# Error handling and logging
await self._handle_error(
request=request,
path=path,
api_key_info=api_key_info,
error=e,
duration=time.time() - start_time
)
# Return appropriate error response
return await self._create_error_response(e)
async def _extract_and_validate_api_key(self, request: Request) -> Optional[Dict[str, Any]]:
"""Extract and validate API key from request"""
# Try different auth methods
api_key = None
# Bearer token
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
api_key = auth_header[7:]
# X-API-Key header
elif request.headers.get("X-API-Key"):
api_key = request.headers.get("X-API-Key")
if not api_key:
raise AuthenticationError("Missing API key")
# Validate API key
api_key_info = await get_api_key_info(api_key)
if not api_key_info:
raise AuthenticationError("Invalid API key")
if not api_key_info.get("is_active", False):
raise AuthenticationError("API key is disabled")
return api_key_info
async def _validate_ip_address(self, request: Request):
"""Validate client IP address against whitelist/blacklist"""
client_ip = request.client.host
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
config = self.config.config
# Check blacklist
if client_ip in config["ip_blacklist"]:
raise AuthenticationError(f"IP address {client_ip} is blacklisted")
# Check whitelist (if enabled)
if config["ip_whitelist_enabled"] and client_ip not in config["ip_whitelist"]:
raise AuthenticationError(f"IP address {client_ip} is not whitelisted")
async def _check_rate_limits(self, request: Request, path: str, api_key_info: Optional[Dict]):
"""Check rate limits for the request"""
client_ip = request.client.host
api_key = api_key_info.get("key_prefix", "") if api_key_info else None
# Use existing rate limiter
if api_key:
# API key-based rate limiting
rate_limit_key = f"api_key:{api_key}"
limit_per_minute = api_key_info.get("rate_limit_per_minute", 100)
limit_per_hour = api_key_info.get("rate_limit_per_hour", 1000)
# Check per-minute limit
is_allowed_minute, _ = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, _ = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
if not (is_allowed_minute and is_allowed_hour):
raise RateLimitExceeded("API key rate limit exceeded")
else:
# IP-based rate limiting for unauthenticated requests
rate_limit_key = f"ip:{client_ip}"
is_allowed_minute, _ = await rate_limiter.check_rate_limit(
rate_limit_key, 20, 60, "minute"
)
if not is_allowed_minute:
raise RateLimitExceeded("IP rate limit exceeded")
async def _get_request_body(self, request: Request) -> Dict[str, Any]:
"""Extract request body"""
try:
if request.method in ["POST", "PUT", "PATCH"]:
return await request.json()
else:
return {}
except Exception:
return {}
def _sanitize_request(self, body: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize request data"""
def sanitize_value(value):
if isinstance(value, str):
# Remove forbidden patterns
for pattern in self.config.config["forbidden_patterns"]:
value = re.sub(re.escape(pattern), "", value, flags=re.IGNORECASE)
# Limit string length
max_length = self.config.config["max_string_length"]
if len(value) > max_length:
logger.warning(f"Truncating long string in request: {len(value)} chars")
value = value[:max_length]
return value
elif isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [sanitize_value(item) for item in value]
else:
return value
return sanitize_value(body)
async def _check_budget_constraints(self, api_key_info: Dict, body: Dict):
"""Check budget constraints for LLM requests"""
if not api_key_info:
return
# Estimate cost based on request
estimated_cost = self._estimate_request_cost(body)
# Check budget
user_id = api_key_info.get("user_id")
api_key_id = api_key_info.get("id")
budget_ok = await check_budget_and_record_usage(
user_id=user_id,
api_key_id=api_key_id,
estimated_cost=estimated_cost,
actual_cost=0, # Will be updated after response
metadata={"endpoint": "llm_proxy", "model": body.get("model", "unknown")}
)
if not budget_ok:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail="Budget limit exceeded"
)
def _estimate_request_cost(self, body: Dict) -> float:
"""Estimate cost of LLM request"""
# Rough estimation based on model and tokens
model = body.get("model", "gpt-3.5-turbo")
messages = body.get("messages", [])
max_tokens = body.get("max_tokens", 1000)
# Estimate input tokens
input_text = " ".join([msg.get("content", "") for msg in messages if isinstance(msg, dict)])
input_tokens = len(input_text.split()) * 1.3 # Rough approximation
# Model pricing (simplified)
pricing = {
"gpt-4": {"input": 0.03, "output": 0.06}, # per 1K tokens
"gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
"claude-3-haiku": {"input": 0.00025, "output": 0.00125}
}
model_pricing = pricing.get(model, pricing["gpt-3.5-turbo"])
estimated_cost = (
(input_tokens / 1000) * model_pricing["input"] +
(max_tokens / 1000) * model_pricing["output"]
)
return estimated_cost
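# Worked example (illustrative): a gpt-3.5-turbo request whose messages total
# 100 words with max_tokens=1000:
#   input_tokens = 100 * 1.3 = 130
#   estimated    = (130 / 1000) * 0.001 + (1000 / 1000) * 0.002
#                = 0.00013 + 0.002 ≈ 0.00213  (same unit as the pricing table, i.e. dollars)
# Callers that track budgets in cents would need to convert this value.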
def _build_proxy_headers(self, request: Request, api_key_info: Optional[Dict]) -> Dict[str, str]:
"""Build headers for proxy request"""
headers = {
"Content-Type": "application/json",
"User-Agent": "ConfidentialEmpire-Proxy/1.0",
"X-Forwarded-For": request.client.host,
"X-Request-ID": f"req_{int(time.time() * 1000)}"
}
if api_key_info:
headers["X-User-ID"] = str(api_key_info.get("user_id", ""))
headers["X-API-Key-ID"] = str(api_key_info.get("id", ""))
return headers
async def _log_security_event(self, request: Request, path: str, api_key_info: Optional[Dict], sanitized_body: Dict):
"""Log security event for audit trail"""
await create_audit_log(
action=f"api_proxy_{request.method.lower()}",
resource_type="api_endpoint",
resource_id=path,
user_id=api_key_info.get("user_id") if api_key_info else None,
success=True,
ip_address=request.client.host,
user_agent=request.headers.get("User-Agent", ""),
metadata={
"endpoint": path,
"method": request.method,
"api_key_id": api_key_info.get("id") if api_key_info else None,
"request_size": len(json.dumps(sanitized_body))
}
)
async def _forward_request(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
"""Forward request to appropriate backend service"""
# Determine target service based on path
if path.startswith("/api/llm/"):
target_url = f"{settings.LITELLM_BASE_URL}{path}"
target_headers = {**headers, "Authorization": f"Bearer {settings.LITELLM_MASTER_KEY}"}
elif path.startswith("/api/modules/"):
# Route to module system
return await self._route_to_module(path, method, body, headers)
else:
raise ValidationError(f"Unknown endpoint: {path}")
# Make HTTP request to target service
timeout = self.config.config["timeout"]
async with httpx.AsyncClient(timeout=timeout) as client:
if method == "GET":
response = await client.get(target_url, headers=target_headers)
elif method == "POST":
response = await client.post(target_url, json=body, headers=target_headers)
elif method == "PUT":
response = await client.put(target_url, json=body, headers=target_headers)
elif method == "DELETE":
response = await client.delete(target_url, headers=target_headers)
else:
raise ValidationError(f"Unsupported HTTP method: {method}")
if response.status_code >= 400:
raise HTTPException(status_code=response.status_code, detail=response.text)
return response.json()
async def _route_to_module(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
"""Route request to module system"""
# Extract module name from path
# e.g., /api/modules/v1/rag/search -> module: rag, action: search
path_parts = path.strip("/").split("/")
if len(path_parts) >= 4:
module_name = path_parts[3]
action = path_parts[4] if len(path_parts) > 4 else "execute"
else:
raise ValidationError("Invalid module path")
# Import module manager
from app.services.module_manager import module_manager
if module_name not in module_manager.modules:
raise ValidationError(f"Module not found: {module_name}")
module = module_manager.modules[module_name]
# Prepare context
context = {
"user_id": headers.get("X-User-ID"),
"api_key_id": headers.get("X-API-Key-ID"),
"ip_address": headers.get("X-Forwarded-For"),
"user_permissions": [] # Would be populated from API key info
}
# Prepare request
module_request = {
"action": action,
"method": method,
**body
}
# Execute through module's interceptor chain
if hasattr(module, 'execute_with_interceptors'):
return await module.execute_with_interceptors(module_request, context)
else:
# Fallback for legacy modules
if hasattr(module, action):
return await getattr(module, action)(module_request)
else:
raise ValidationError(f"Action not supported: {action}")
async def _process_response(self, path: str, response: Dict) -> JSONResponse:
"""Process and validate response"""
# Add security headers
headers = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
}
return JSONResponse(content=response, headers=headers)
async def _record_usage_metrics(self, api_key_info: Optional[Dict], path: str, duration: float, success: bool):
"""Record usage metrics"""
if api_key_info:
# Record API key usage
# This would update database metrics
pass
async def _handle_error(self, request: Request, path: str, api_key_info: Optional[Dict], error: Exception, duration: float):
"""Handle and log errors"""
await create_audit_log(
action=f"api_proxy_{request.method.lower()}",
resource_type="api_endpoint",
resource_id=path,
user_id=api_key_info.get("user_id") if api_key_info else None,
success=False,
error_message=str(error),
ip_address=request.client.host,
user_agent=request.headers.get("User-Agent", ""),
metadata={
"endpoint": path,
"method": request.method,
"duration_ms": int(duration * 1000),
"error_type": type(error).__name__
}
)
async def _create_error_response(self, error: Exception) -> JSONResponse:
"""Create appropriate error response"""
if isinstance(error, AuthenticationError):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"error": "AUTHENTICATION_ERROR", "message": str(error)}
)
elif isinstance(error, ValidationError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"error": "VALIDATION_ERROR", "message": str(error)}
)
elif isinstance(error, RateLimitExceeded):
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"error": "RATE_LIMIT_EXCEEDED", "message": str(error)}
)
elif isinstance(error, HTTPException):
return JSONResponse(
status_code=error.status_code,
content={"error": "HTTP_ERROR", "message": error.detail}
)
else:
logger.error(f"Unexpected error in API proxy: {error}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "INTERNAL_ERROR", "message": "An unexpected error occurred"}
)
# Global proxy instance
api_security_proxy = APISecurityProxy()
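One plausible way to mount the proxy (hypothetical catch-all route; the app object and prefix are assumptions):

from fastapi import FastAPI, Request

app = FastAPI()

@app.api_route("/api/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxied(path: str, request: Request):
    # proxy_request runs the full interceptor chain: API key auth, IP checks,
    # rate limits, schema validation, sanitization, budget checks, forwarding.
    return await api_security_proxy.proxy_request(request, f"/api/{path}")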

View File

@@ -0,0 +1,297 @@
"""
Audit logging service with async/non-blocking capabilities
"""
import asyncio
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime
from app.models.audit_log import AuditLog
from app.core.logging import get_logger
logger = get_logger(__name__)
# Background audit logging queue (on Python 3.10+, asyncio.Queue binds to the running loop lazily)
_audit_queue = asyncio.Queue(maxsize=1000)
_audit_worker_started = False
async def _audit_worker():
"""Background worker to process audit events"""
from app.db.database import async_session_factory
logger.info("Audit worker started")
while True:
try:
# Get audit event from queue
audit_data = await _audit_queue.get()
if audit_data is None: # Shutdown signal
break
# Process the audit event in a separate database session
async with async_session_factory() as db:
try:
audit_log = AuditLog(**audit_data)
db.add(audit_log)
await db.commit()
logger.debug(f"Background audit logged: {audit_data.get('action')}")
except Exception as e:
logger.error(f"Failed to write audit log in background: {e}")
await db.rollback()
_audit_queue.task_done()
except Exception as e:
logger.error(f"Audit worker error: {e}")
await asyncio.sleep(1) # Brief pause before retrying
def start_audit_worker():
"""Start the background audit worker"""
global _audit_worker_started
if not _audit_worker_started:
asyncio.create_task(_audit_worker())
_audit_worker_started = True
logger.info("Audit worker task created")
async def log_audit_event_async(
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
action: str = "",
resource_type: str = "",
resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
):
"""
Log an audit event asynchronously (non-blocking)
This function queues the audit event for background processing,
so it doesn't block the main request flow.
"""
try:
# Ensure audit worker is started
if not _audit_worker_started:
start_audit_worker()
audit_details = details or {}
if api_key_id:
audit_details["api_key_id"] = api_key_id
audit_data = {
"user_id": user_id,
"action": action,
"resource_type": resource_type,
"resource_id": resource_id,
"description": f"{action} on {resource_type}",
"details": audit_details,
"ip_address": ip_address,
"user_agent": user_agent,
"success": success,
"severity": severity,
"created_at": datetime.utcnow()
}
# Queue the audit event (non-blocking)
try:
_audit_queue.put_nowait(audit_data)
logger.debug(f"Audit event queued: {action} on {resource_type}")
except asyncio.QueueFull:
logger.warning("Audit queue full, dropping audit event")
except Exception as e:
logger.error(f"Failed to queue audit event: {e}")
# Don't raise - audit failures shouldn't break main operations
async def log_audit_event(
db: AsyncSession,
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
action: str = "",
resource_type: str = "",
resource_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
):
"""
Log an audit event to the database
Args:
db: Database session
user_id: ID of the user performing the action
api_key_id: ID of the API key used (if applicable)
action: Action being performed (e.g., "create_user", "login", "delete_resource")
resource_type: Type of resource being acted upon (e.g., "user", "api_key", "budget")
resource_id: ID of the specific resource
details: Additional details about the action
ip_address: IP address of the request
user_agent: User agent string
success: Whether the action was successful
severity: Severity level (info, warning, error, critical)
"""
try:
audit_details = details or {}
if api_key_id:
audit_details["api_key_id"] = api_key_id
audit_log = AuditLog(
user_id=user_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
description=f"{action} on {resource_type}",
details=audit_details,
ip_address=ip_address,
user_agent=user_agent,
success=success,
severity=severity,
created_at=datetime.utcnow()
)
db.add(audit_log)
await db.flush() # Don't commit here, let the caller control the transaction
logger.debug(f"Audit event logged: {action} on {resource_type} by user {user_id}")
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
# Don't raise here as audit logging shouldn't break the main operation
async def get_audit_logs(
db: AsyncSession,
user_id: Optional[str] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
offset: int = 0
):
"""
Query audit logs with filtering
Args:
db: Database session
user_id: Filter by user ID
action: Filter by action
resource_type: Filter by resource type
resource_id: Filter by resource ID
start_date: Filter by start date
end_date: Filter by end date
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of audit log entries
"""
from sqlalchemy import select, and_
query = select(AuditLog)
conditions = []
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action:
conditions.append(AuditLog.action == action)
if resource_type:
conditions.append(AuditLog.resource_type == resource_type)
if resource_id:
conditions.append(AuditLog.resource_id == resource_id)
if start_date:
conditions.append(AuditLog.created_at >= start_date)
if end_date:
conditions.append(AuditLog.created_at <= end_date)
if conditions:
query = query.where(and_(*conditions))
query = query.order_by(AuditLog.created_at.desc())
query = query.offset(offset).limit(limit)
result = await db.execute(query)
return result.scalars().all()
async def get_audit_stats(
db: AsyncSession,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
):
"""
Get audit statistics
Args:
db: Database session
start_date: Start date for statistics
end_date: End date for statistics
Returns:
Dictionary with audit statistics
"""
from sqlalchemy import select, func, and_
conditions = []
if start_date:
conditions.append(AuditLog.created_at >= start_date)
if end_date:
conditions.append(AuditLog.created_at <= end_date)
# Total events
total_query = select(func.count(AuditLog.id))
if conditions:
total_query = total_query.where(and_(*conditions))
total_result = await db.execute(total_query)
total_events = total_result.scalar()
# Events by action
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action)
if conditions:
action_query = action_query.where(and_(*conditions))
action_result = await db.execute(action_query)
events_by_action = dict(action_result.fetchall())
# Events by resource type
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(AuditLog.resource_type)
if conditions:
resource_query = resource_query.where(and_(*conditions))
resource_result = await db.execute(resource_query)
events_by_resource = dict(resource_result.fetchall())
# Events by severity
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(AuditLog.severity)
if conditions:
severity_query = severity_query.where(and_(*conditions))
severity_result = await db.execute(severity_query)
events_by_severity = dict(severity_result.fetchall())
# Success rate
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(AuditLog.success)
if conditions:
success_query = success_query.where(and_(*conditions))
success_result = await db.execute(success_query)
success_stats = dict(success_result.fetchall())
return {
"total_events": total_events,
"events_by_action": events_by_action,
"events_by_resource_type": events_by_resource,
"events_by_severity": events_by_severity,
"success_rate": success_stats.get(True, 0) / total_events if total_events > 0 else 0,
"failure_rate": success_stats.get(False, 0) / total_events if total_events > 0 else 0
}
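A small sketch of the fire-and-forget path (illustrative; the handler and its arguments are made up):

async def delete_budget_handler(budget_id: str, user_id: str, client_ip: str):
    ...  # perform the deletion, then queue the audit event without blocking
    await log_audit_event_async(
        user_id=user_id,
        action="delete_budget",
        resource_type="budget",
        resource_id=budget_id,
        ip_address=client_ip,
        severity="warning",
    )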

View File

@@ -0,0 +1,423 @@
"""
Base module interface and interceptor pattern implementation
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
from fastapi import Request, Response
import json
import re
import copy
import time
import hashlib
from urllib.parse import urlparse
from app.core.logging import get_logger
from app.utils.exceptions import ValidationError, AuthenticationError, RateLimitExceeded
from app.services.permission_manager import permission_registry
logger = get_logger(__name__)
@dataclass
class Permission:
"""Represents a module permission"""
resource: str
action: str
description: str
def __str__(self) -> str:
return f"{self.resource}:{self.action}"
@dataclass
class ModuleMetrics:
"""Module performance metrics"""
requests_processed: int = 0
average_response_time: float = 0.0
error_rate: float = 0.0
last_activity: Optional[str] = None
total_errors: int = 0
uptime_start: float = 0.0
def __post_init__(self):
if self.uptime_start == 0.0:
self.uptime_start = time.time()
@dataclass
class ModuleHealth:
"""Module health status"""
status: str = "healthy" # healthy, warning, error
message: str = "Module is functioning normally"
uptime: float = 0.0
last_check: float = 0.0
def __post_init__(self):
if self.last_check == 0.0:
self.last_check = time.time()
class BaseModule(ABC):
"""Base class for all modules with interceptor pattern support"""
def __init__(self, module_id: str, config: Dict[str, Any] = None):
self.module_id = module_id
self.config = config or {}
self.metrics = ModuleMetrics()
self.health = ModuleHealth()
self.initialized = False
self.interceptors: List['ModuleInterceptor'] = []
# Register default interceptors
self._register_default_interceptors()
@abstractmethod
async def initialize(self) -> None:
"""Initialize the module"""
pass
@abstractmethod
async def cleanup(self) -> None:
"""Cleanup module resources"""
pass
@abstractmethod
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
return []
@abstractmethod
async def process_request(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""Process a module request"""
pass
def get_health(self) -> ModuleHealth:
"""Get current module health status"""
self.health.uptime = time.time() - self.metrics.uptime_start
self.health.last_check = time.time()
return self.health
def get_metrics(self) -> ModuleMetrics:
"""Get current module metrics"""
return self.metrics
def check_access(self, user_permissions: List[str], action: str) -> bool:
"""Check if user can perform action on this module"""
required = f"modules:{self.module_id}:{action}"
return permission_registry.check_permission(user_permissions, required)
def _register_default_interceptors(self):
"""Register default interceptors for all modules"""
self.interceptors = [
AuthenticationInterceptor(),
PermissionInterceptor(self),
ValidationInterceptor(),
MetricsInterceptor(self),
SecurityInterceptor(),
AuditInterceptor(self)
]
async def execute_with_interceptors(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""Execute request through interceptor chain"""
start_time = time.time()
try:
# Pre-processing interceptors
for interceptor in self.interceptors:
request, context = await interceptor.pre_process(request, context)
# Execute main module logic
response = await self.process_request(request, context)
# Post-processing interceptors (in reverse order)
for interceptor in reversed(self.interceptors):
response = await interceptor.post_process(request, context, response)
# Update metrics
self._update_metrics(start_time, success=True)
return response
except Exception as e:
# Update error metrics
self._update_metrics(start_time, success=False, error=str(e))
# Error handling interceptors
for interceptor in reversed(self.interceptors):
if hasattr(interceptor, 'handle_error'):
await interceptor.handle_error(request, context, e)
raise
def _update_metrics(self, start_time: float, success: bool, error: str = None):
"""Update module metrics"""
duration = time.time() - start_time
self.metrics.requests_processed += 1
# Update average response time
if self.metrics.requests_processed == 1:
self.metrics.average_response_time = duration
else:
# Exponential moving average
alpha = 0.1
self.metrics.average_response_time = (
alpha * duration + (1 - alpha) * self.metrics.average_response_time
)
if not success:
self.metrics.total_errors += 1
# Recompute the error rate on every request and derive health status from it
self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
if self.metrics.error_rate > 0.1:  # More than 10% error rate
self.health.status = "error"
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
elif self.metrics.error_rate > 0.05:  # More than 5% error rate
self.health.status = "warning"
self.health.message = f"Elevated error rate: {self.metrics.error_rate:.2%}"
else:
self.health.status = "healthy"
self.health.message = "Module is functioning normally"
self.metrics.last_activity = time.strftime("%Y-%m-%d %H:%M:%S")
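# Illustrative note (editor's addition): with alpha = 0.1 the moving average
# weights the newest sample at 10%, so durations of 1.00s, 0.10s, 0.10s give
# averages 1.000 -> 0.910 -> 0.829, decaying toward recent latency rather than
# averaging the whole request history.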
class ModuleInterceptor(ABC):
"""Base class for module interceptors"""
@abstractmethod
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Pre-process the request"""
return request, context
@abstractmethod
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
"""Post-process the response"""
return response
class AuthenticationInterceptor(ModuleInterceptor):
"""Handles authentication for module requests"""
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Check if user is authenticated (context should contain user info from API auth)
if not context.get("user_id") and not context.get("api_key_id"):
raise AuthenticationError("Authentication required for module access")
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
return response
class PermissionInterceptor(ModuleInterceptor):
"""Handles permission checking for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
action = request.get("action", "execute")
user_permissions = context.get("user_permissions", [])
if not self.module.check_access(user_permissions, action):
raise AuthenticationError(f"Insufficient permissions for module action: {action}")
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
return response
class ValidationInterceptor(ModuleInterceptor):
"""Handles request validation and sanitization"""
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Sanitize request data
sanitized_request = self._sanitize_request(request)
# Validate request structure
self._validate_request(sanitized_request)
return sanitized_request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
# Sanitize response data
return self._sanitize_response(response)
def _sanitize_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Remove potentially dangerous content from request"""
sanitized = copy.deepcopy(request)
# Define dangerous patterns
dangerous_patterns = [
r"<script[^>]*>.*?</script>",
r"javascript:",
r"data:text/html",
r"vbscript:",
r"onload\s*=",
r"onerror\s*=",
r"eval\s*\(",
r"Function\s*\(",
]
def sanitize_value(value):
if isinstance(value, str):
# Remove dangerous patterns
for pattern in dangerous_patterns:
value = re.sub(pattern, "", value, flags=re.IGNORECASE)
# Limit string length
max_length = 10000
if len(value) > max_length:
logger.warning(f"Truncating long string in request: {len(value)} chars")
value = value[:max_length]
return value
elif isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [sanitize_value(item) for item in value]
else:
return value
return sanitize_value(sanitized)
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize response data"""
# Similar sanitization for responses
return self._sanitize_request(response)
def _validate_request(self, request: Dict[str, Any]):
"""Validate request structure"""
# Check for required fields
if not isinstance(request, dict):
raise ValidationError("Request must be a dictionary")
# Check request size
request_str = json.dumps(request)
max_size = 10 * 1024 * 1024 # 10MB
if len(request_str.encode()) > max_size:
raise ValidationError(f"Request size exceeds maximum allowed ({max_size} bytes)")
class MetricsInterceptor(ModuleInterceptor):
"""Handles metrics collection for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_metrics_start_time"] = time.time()
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
# Metrics are updated in the base module execute_with_interceptors method
return response
class SecurityInterceptor(ModuleInterceptor):
"""Handles security-related processing"""
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Add security headers to context
context["security_headers"] = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
}
# Check for suspicious patterns
self._check_security_patterns(request)
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
# Remove any sensitive information from response
return self._remove_sensitive_data(response)
def _check_security_patterns(self, request: Dict[str, Any]):
"""Check for suspicious security patterns"""
request_str = json.dumps(request).lower()
suspicious_patterns = [
"union select", "drop table", "insert into", "delete from",
"script>", "javascript:", "eval(", "expression(",
"../", "..\\", "file://", "ftp://",
]
for pattern in suspicious_patterns:
if pattern in request_str:
logger.warning(f"Suspicious pattern detected in request: {pattern}")
# Could implement additional security measures here
def _remove_sensitive_data(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Remove sensitive data from response"""
sensitive_keys = ["password", "secret", "token", "key", "private"]
def clean_dict(obj):
if isinstance(obj, dict):
return {
k: "***REDACTED***" if any(sk in k.lower() for sk in sensitive_keys) else clean_dict(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [clean_dict(item) for item in obj]
else:
return obj
return clean_dict(response)
class AuditInterceptor(ModuleInterceptor):
"""Handles audit logging for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_audit_start_time"] = time.time()
context["_audit_request_hash"] = self._hash_request(request)
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
await self._log_audit_event(request, context, response, success=True)
return response
async def handle_error(self, request: Dict[str, Any], context: Dict[str, Any], error: Exception):
"""Handle error logging"""
await self._log_audit_event(request, context, {"error": str(error)}, success=False)
def _hash_request(self, request: Dict[str, Any]) -> str:
"""Create a hash of the request for audit purposes"""
request_str = json.dumps(request, sort_keys=True)
return hashlib.sha256(request_str.encode()).hexdigest()[:16]
async def _log_audit_event(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any], success: bool):
"""Log audit event"""
duration = time.time() - context.get("_audit_start_time", time.time())
audit_data = {
"module_id": self.module.module_id,
"action": request.get("action", "execute"),
"user_id": context.get("user_id"),
"api_key_id": context.get("api_key_id"),
"ip_address": context.get("ip_address"),
"request_hash": context.get("_audit_request_hash"),
"success": success,
"duration_ms": int(duration * 1000),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
}
if not success:
audit_data["error"] = response.get("error", "Unknown error")
# Log the audit event
logger.info(f"Module audit: {json.dumps(audit_data)}")
# Could also store in database for persistent audit trail
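# --- Illustrative usage sketch (editor's addition; `EchoModule` and the context
# values are assumptions, not part of the codebase) ---
#
# class EchoModule(BaseModule):
#     async def initialize(self) -> None:
#         self.initialized = True
#
#     async def cleanup(self) -> None:
#         self.initialized = False
#
#     def get_required_permissions(self) -> List[Permission]:
#         return [Permission("modules:echo", "execute", "Run the echo module")]
#
#     async def process_request(self, request, context):
#         return {"echo": request.get("payload")}
#
# The interceptor chain authenticates, checks "modules:echo:execute", sanitizes,
# records metrics, and audits around process_request:
#
# module = EchoModule("echo")
# response = await module.execute_with_interceptors(
#     {"action": "execute", "payload": "hi"},
#     {"user_id": 1, "user_permissions": ["modules:echo:execute"]},
# )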

View File

@@ -0,0 +1,649 @@
"""
Budget enforcement service for managing spending limits and cost control
"""
from typing import Optional, List, Tuple, Dict, Any
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, text, select, update
from sqlalchemy.exc import IntegrityError
import time
import random
from app.models.budget import Budget
from app.models.api_key import APIKey
from app.models.user import User
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.core.logging import get_logger
logger = get_logger(__name__)
class BudgetEnforcementError(Exception):
"""Custom exception for budget enforcement failures"""
pass
class BudgetExceededError(BudgetEnforcementError):
"""Exception raised when budget would be exceeded"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
self.requested_cost = requested_cost
class BudgetWarningError(BudgetEnforcementError):
"""Exception raised when budget warning threshold is reached"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
self.requested_cost = requested_cost
class BudgetConcurrencyError(BudgetEnforcementError):
"""Exception raised when budget update fails due to concurrency"""
def __init__(self, message: str, retry_count: int = 0):
super().__init__(message)
self.retry_count = retry_count
class BudgetAtomicError(BudgetEnforcementError):
"""Exception raised when atomic budget operation fails"""
def __init__(self, message: str, budget_id: int, requested_amount: int):
super().__init__(message)
self.budget_id = budget_id
self.requested_amount = requested_amount
class BudgetEnforcementService:
"""Service for enforcing budget limits and tracking usage"""
def __init__(self, db: Session):
self.db = db
self.max_retries = 3
self.retry_delay_base = 0.1 # Base delay in seconds
def atomic_check_and_reserve_budget(
self,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""
Atomically check budget compliance and reserve spending
Returns:
Tuple of (is_allowed, error_message, warnings, reserved_budget_ids)
"""
estimated_cost = estimate_request_cost(model_name, estimated_tokens)
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
if not budgets:
logger.debug(f"No applicable budgets found for API key {api_key.id}")
return True, None, [], []
# Try atomic reservation with retries
for attempt in range(self.max_retries):
try:
return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
except BudgetConcurrencyError as e:
if attempt == self.max_retries - 1:
logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
return False, f"Budget check temporarily unavailable (concurrency limit)", [], []
# Exponential backoff with jitter
delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
time.sleep(delay)
logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
except Exception as e:
logger.error(f"Unexpected error in atomic budget reservation: {e}")
return False, f"Budget check failed: {str(e)}", [], []
return False, "Budget check failed after maximum retries", [], []
def _attempt_atomic_reservation(
self,
budgets: List[Budget],
estimated_cost: int,
api_key_id: int,
attempt: int
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Attempt to atomically reserve budget across all applicable budgets"""
warnings = []
reserved_budget_ids = []
try:
# Begin transaction
self.db.begin()
for budget in budgets:
# Lock budget row for update to prevent concurrent modifications
locked_budget = self.db.query(Budget).filter(
Budget.id == budget.id
).with_for_update().first()
if not locked_budget:
raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)
# Reset budget if expired and auto-renew enabled
if locked_budget.is_expired() and locked_budget.auto_renew:
self._reset_expired_budget(locked_budget)
self.db.flush() # Ensure reset is applied before checking
# Skip inactive or expired budgets
if not locked_budget.is_active or locked_budget.is_expired():
continue
# Check if request would exceed budget using atomic operation
if not self._atomic_can_spend(locked_budget, estimated_cost):
error_msg = (
f"Request would exceed budget '{locked_budget.name}' "
f"(${locked_budget.limit_cents/100:.2f}). "
f"Current usage: ${locked_budget.current_usage_cents/100:.2f}, "
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(locked_budget.limit_cents - locked_budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
self.db.rollback()
return False, error_msg, warnings, []
# Check warning threshold
if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
warning_msg = (
f"Budget '{locked_budget.name}' approaching limit. "
f"Usage will be ${(locked_budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${locked_budget.limit_cents/100:.2f} "
f"({((locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100):.1f}%)"
)
warnings.append({
"type": "budget_warning",
"budget_id": locked_budget.id,
"budget_name": locked_budget.name,
"message": warning_msg,
"current_usage_cents": locked_budget.current_usage_cents + estimated_cost,
"limit_cents": locked_budget.limit_cents,
"usage_percentage": (locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100
})
logger.info(f"Budget warning for API key {api_key_id}: {warning_msg}")
# Reserve the budget (temporarily add estimated cost)
self._atomic_reserve_usage(locked_budget, estimated_cost)
reserved_budget_ids.append(locked_budget.id)
# Commit the reservation
self.db.commit()
logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
return True, None, warnings, reserved_budget_ids
except IntegrityError as e:
self.db.rollback()
raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt)
except Exception as e:
self.db.rollback()
logger.error(f"Error in atomic budget reservation: {e}")
raise
def _atomic_can_spend(self, budget: Budget, amount_cents: int) -> bool:
"""Atomically check if budget can accommodate spending"""
if not budget.is_active or not budget.is_in_period():
return False
if not budget.enforce_hard_limit:
return True
return (budget.current_usage_cents + amount_cents) <= budget.limit_cents
def _atomic_reserve_usage(self, budget: Budget, amount_cents: int):
"""Atomically reserve usage in budget (add to current usage)"""
# Use database-level atomic update
result = self.db.execute(
update(Budget)
.where(Budget.id == budget.id)
.values(
current_usage_cents=Budget.current_usage_cents + amount_cents,
updated_at=datetime.utcnow(),
is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
is_warning_sent=(
Budget.is_warning_sent |
((Budget.warning_threshold_cents.isnot(None)) &
(Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
)
)
)
if result.rowcount != 1:
raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)
# Update the in-memory object to reflect changes
budget.current_usage_cents += amount_cents
budget.updated_at = datetime.utcnow()
if budget.current_usage_cents >= budget.limit_cents:
budget.is_exceeded = True
if budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
budget.is_warning_sent = True
def atomic_finalize_usage(
self,
reserved_budget_ids: List[int],
api_key: APIKey,
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
) -> List[Budget]:
"""
Finalize actual usage and adjust reservations
Args:
reserved_budget_ids: Budget IDs that had usage reserved
api_key: API key that made the request
model_name: Model that was used
input_tokens: Actual input tokens used
output_tokens: Actual output tokens used
endpoint: API endpoint that was accessed
Returns:
List of budgets that were updated
"""
if not reserved_budget_ids:
return []
try:
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
updated_budgets = []
# Begin transaction for finalization
self.db.begin()
for budget_id in reserved_budget_ids:
# Lock budget for update
budget = self.db.query(Budget).filter(
Budget.id == budget_id
).with_for_update().first()
if not budget:
logger.warning(f"Budget {budget_id} not found during finalization")
continue
if budget.is_active and budget.is_in_period():
# Calculate adjustment (actual cost - estimated cost already reserved)
# Note: We don't know the exact estimated cost that was reserved
# So we'll just set to actual cost (this is safe as we already reserved)
self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
updated_budgets.append(budget)
logger.debug(
f"Finalized usage for budget {budget.id}: "
f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
)
# Commit finalization
self.db.commit()
return updated_budgets
except Exception as e:
logger.error(f"Error finalizing budget usage: {e}")
self.db.rollback()
return []
def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):
"""Finalize actual usage for a budget whose estimated cost was already reserved"""
# Intentionally a no-op: the reservation step already added the estimated
# cost, which is expected to track the actual cost closely. A more
# sophisticated system would track reservations separately and adjust here
# by (actual_cost - estimated_cost).
pass
def check_budget_compliance(
self,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""
Check if a request complies with budget limits
Args:
api_key: API key making the request
model_name: Model being used
estimated_tokens: Estimated token usage
endpoint: API endpoint being accessed
Returns:
Tuple of (is_allowed, error_message, warnings)
"""
try:
# Calculate estimated cost
estimated_cost = estimate_request_cost(model_name, estimated_tokens)
# Get applicable budgets
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
if not budgets:
logger.debug(f"No applicable budgets found for API key {api_key.id}")
return True, None, []
warnings = []
# Check each budget
for budget in budgets:
# Reset budget if period expired and auto-renew is enabled
if budget.is_expired() and budget.auto_renew:
self._reset_expired_budget(budget)
# Skip inactive or expired budgets
if not budget.is_active or budget.is_expired():
continue
# Check if request would exceed budget
if not budget.can_spend(estimated_cost):
error_msg = (
f"Request would exceed budget '{budget.name}' "
f"(${budget.limit_cents/100:.2f}). "
f"Current usage: ${budget.current_usage_cents/100:.2f}, "
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
return False, error_msg, warnings
# Check if request would trigger warning
if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
warning_msg = (
f"Budget '{budget.name}' approaching limit. "
f"Usage will be ${(budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${budget.limit_cents/100:.2f} "
f"({((budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100):.1f}%)"
)
warnings.append({
"type": "budget_warning",
"budget_id": budget.id,
"budget_name": budget.name,
"message": warning_msg,
"current_usage_cents": budget.current_usage_cents + estimated_cost,
"limit_cents": budget.limit_cents,
"usage_percentage": (budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100
})
logger.info(f"Budget warning for API key {api_key.id}: {warning_msg}")
return True, None, warnings
except Exception as e:
logger.error(f"Error checking budget compliance: {e}")
# Allow request on error to avoid blocking legitimate usage
return True, None, []
def record_usage(
self,
api_key: APIKey,
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
) -> List[Budget]:
"""
Record actual usage against applicable budgets
Args:
api_key: API key that made the request
model_name: Model that was used
input_tokens: Actual input tokens used
output_tokens: Actual output tokens used
endpoint: API endpoint that was accessed
Returns:
List of budgets that were updated
"""
try:
# Calculate actual cost
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
# Get applicable budgets
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
updated_budgets = []
for budget in budgets:
if budget.is_active and budget.is_in_period():
# Add usage to budget
budget.add_usage(actual_cost)
updated_budgets.append(budget)
logger.debug(
f"Recorded usage for budget {budget.id}: "
f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
)
# Commit changes
self.db.commit()
return updated_budgets
except Exception as e:
logger.error(f"Error recording budget usage: {e}")
self.db.rollback()
return []
def _get_applicable_budgets(
self,
api_key: APIKey,
model_name: str = None,
endpoint: str = None
) -> List[Budget]:
"""Get budgets that apply to the given request"""
# Build query conditions
conditions = [
Budget.is_active == True,
or_(
and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)), # User budget
Budget.api_key_id == api_key.id # API key specific budget
)
]
# Query budgets
query = self.db.query(Budget).filter(and_(*conditions))
budgets = query.all()
# Filter budgets based on allowed models/endpoints
applicable_budgets = []
for budget in budgets:
# Check model restrictions
if model_name and budget.allowed_models:
if model_name not in budget.allowed_models:
continue
# Check endpoint restrictions
if endpoint and budget.allowed_endpoints:
if endpoint not in budget.allowed_endpoints:
continue
applicable_budgets.append(budget)
return applicable_budgets
def _reset_expired_budget(self, budget: Budget):
"""Reset an expired budget for the next period"""
try:
budget.reset_period()
self.db.commit()
logger.info(
f"Reset expired budget {budget.id} for new period: "
f"{budget.period_start} to {budget.period_end}"
)
except Exception as e:
logger.error(f"Error resetting expired budget {budget.id}: {e}")
self.db.rollback()
def get_budget_status(self, api_key: APIKey) -> Dict[str, Any]:
"""Get comprehensive budget status for an API key"""
try:
budgets = self._get_applicable_budgets(api_key)
status = {
"total_budgets": len(budgets),
"active_budgets": 0,
"exceeded_budgets": 0,
"warning_budgets": 0,
"total_limit_cents": 0,
"total_usage_cents": 0,
"budgets": []
}
for budget in budgets:
if not budget.is_active:
continue
budget_info = budget.to_dict()
budget_info.update({
"is_expired": budget.is_expired(),
"days_remaining": budget.get_period_days_remaining(),
"daily_burn_rate": budget.get_daily_burn_rate(),
"projected_spend": budget.get_projected_spend()
})
status["budgets"].append(budget_info)
status["active_budgets"] += 1
status["total_limit_cents"] += budget.limit_cents
status["total_usage_cents"] += budget.current_usage_cents
if budget.is_exceeded:
status["exceeded_budgets"] += 1
elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
status["warning_budgets"] += 1
# Calculate overall percentages
if status["total_limit_cents"] > 0:
status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
else:
status["overall_usage_percentage"] = 0
status["total_limit_dollars"] = status["total_limit_cents"] / 100
status["total_usage_dollars"] = status["total_usage_cents"] / 100
status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
status["total_remaining_dollars"] = status["total_remaining_cents"] / 100
return status
except Exception as e:
logger.error(f"Error getting budget status: {e}")
return {
"error": str(e),
"total_budgets": 0,
"active_budgets": 0,
"exceeded_budgets": 0,
"warning_budgets": 0,
"budgets": []
}
def create_default_user_budget(
self,
user_id: int,
limit_dollars: float = 10.0,
period_type: str = "monthly"
) -> Budget:
"""Create a default budget for a new user"""
try:
if period_type == "monthly":
budget = Budget.create_monthly_budget(
user_id=user_id,
name="Default Monthly Budget",
limit_dollars=limit_dollars
)
else:
budget = Budget.create_daily_budget(
user_id=user_id,
name="Default Daily Budget",
limit_dollars=limit_dollars
)
self.db.add(budget)
self.db.commit()
logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")
return budget
except Exception as e:
logger.error(f"Error creating default budget: {e}")
self.db.rollback()
raise
def check_and_reset_expired_budgets(self):
"""Background task to check and reset expired budgets"""
try:
expired_budgets = self.db.query(Budget).filter(
and_(
Budget.is_active == True,
Budget.auto_renew == True,
Budget.period_end < datetime.utcnow()
)
).all()
for budget in expired_budgets:
self._reset_expired_budget(budget)
logger.info(f"Reset {len(expired_budgets)} expired budgets")
except Exception as e:
logger.error(f"Error in budget reset task: {e}")
# Convenience functions
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
db: Session,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""DEPRECATED: Convenience function to check budget compliance (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.check_budget_compliance(api_key, model_name, estimated_tokens, endpoint)
def record_request_usage(
db: Session,
api_key: APIKey,
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
) -> List[Budget]:
"""DEPRECATED: Convenience function to record actual usage (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.record_usage(api_key, model_name, input_tokens, output_tokens, endpoint)
# ATOMIC VERSIONS: Race-condition-free budget enforcement
def atomic_check_and_reserve_budget(
db: Session,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Atomic convenience function to check budget compliance and reserve spending"""
service = BudgetEnforcementService(db)
return service.atomic_check_and_reserve_budget(api_key, model_name, estimated_tokens, endpoint)
def atomic_finalize_usage(
db: Session,
reserved_budget_ids: List[int],
api_key: APIKey,
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
) -> List[Budget]:
"""Atomic convenience function to finalize actual usage after request completion"""
service = BudgetEnforcementService(db)
return service.atomic_finalize_usage(reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint)
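# --- Illustrative request flow (editor's addition; `db`, `api_key`, and the
# token counts are assumptions) ---
#
# allowed, error, warnings, reserved_ids = atomic_check_and_reserve_budget(
#     db, api_key, model_name="gpt-4", estimated_tokens=1500, endpoint="/v1/chat"
# )
# if not allowed:
#     ...  # reject the request (e.g., HTTP 402/429) with `error`
# # ... perform the LLM call and observe actual token usage ...
# atomic_finalize_usage(
#     db, reserved_ids, api_key,
#     model_name="gpt-4", input_tokens=900, output_tokens=600, endpoint="/v1/chat"
# )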

View File

@@ -0,0 +1,428 @@
"""
Cached API Key Service
High-performance Redis-based API key caching to reduce authentication overhead
from ~60ms to ~5ms by avoiding expensive bcrypt operations
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Tuple
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.security import verify_api_key
from app.models.api_key import APIKey
from app.models.user import User
# Check Redis availability at runtime, not import time
aioredis = None
REDIS_AVAILABLE = False
def _import_aioredis():
"""Import aioredis at runtime"""
global aioredis, REDIS_AVAILABLE
if aioredis is None:
try:
import aioredis as _aioredis
aioredis = _aioredis
REDIS_AVAILABLE = True
return True
except ImportError as e:
REDIS_AVAILABLE = False
return False
except Exception as e:
# Handle the Python 3.11 + aioredis 2.0.1 compatibility issue
REDIS_AVAILABLE = False
return False
return REDIS_AVAILABLE
logger = logging.getLogger(__name__)
class CachedAPIKeyService:
"""Redis-backed API key caching service for performance optimization with fallback to optimized database queries"""
def __init__(self):
self.redis = None
self.cache_ttl = 300 # 5 minutes cache TTL
self.verification_cache_ttl = 3600 # 1 hour for verification results
self.redis_enabled = _import_aioredis()
if not self.redis_enabled:
logger.warning("Redis not available, falling back to optimized database queries only")
async def get_redis(self):
"""Get Redis connection, create if doesn't exist"""
if not self.redis_enabled or not REDIS_AVAILABLE:
return None
if not self.redis and aioredis:
try:
self.redis = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
retry_on_timeout=True,
health_check_interval=30
)
# Test the connection
await self.redis.ping()
logger.info("Redis connection established for API key caching")
except Exception as e:
logger.warning(f"Redis connection failed, disabling cache: {e}")
self.redis_enabled = False
self.redis = None
return self.redis
async def close(self):
"""Close Redis connection"""
if self.redis and self.redis_enabled:
try:
await self.redis.close()
except Exception as e:
logger.warning(f"Error closing Redis connection: {e}")
def _get_cache_key(self, key_prefix: str) -> str:
"""Generate cache key for API key data"""
return f"api_key:data:{key_prefix}"
def _get_verification_cache_key(self, key_prefix: str, key_suffix_hash: str) -> str:
"""Generate cache key for API key verification results"""
return f"api_key:verified:{key_prefix}:{key_suffix_hash}"
def _get_last_used_cache_key(self, api_key_id: int) -> str:
"""Generate cache key for last used timestamp"""
return f"api_key:last_used:{api_key_id}"
async def _serialize_api_key_data(self, api_key: APIKey, user: User) -> str:
"""Serialize API key and user data for caching"""
data = {
# API Key data
"api_key_id": api_key.id,
"api_key_name": api_key.name,
"key_hash": api_key.key_hash,
"key_prefix": api_key.key_prefix,
"is_active": api_key.is_active,
"permissions": api_key.permissions,
"scopes": api_key.scopes,
"rate_limit_per_minute": api_key.rate_limit_per_minute,
"rate_limit_per_hour": api_key.rate_limit_per_hour,
"rate_limit_per_day": api_key.rate_limit_per_day,
"allowed_models": api_key.allowed_models,
"allowed_endpoints": api_key.allowed_endpoints,
"allowed_ips": api_key.allowed_ips,
"is_unlimited": api_key.is_unlimited,
"budget_limit_cents": api_key.budget_limit_cents,
"budget_type": api_key.budget_type,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost": api_key.total_cost,
# User data
"user_id": user.id,
"user_email": user.email,
"user_role": user.role,
"user_is_active": user.is_active,
# Cache metadata
"cached_at": datetime.utcnow().isoformat()
}
return json.dumps(data, default=str)
async def _deserialize_api_key_data(self, cached_data: str) -> Optional[Dict[str, Any]]:
"""Deserialize cached API key data"""
try:
data = json.loads(cached_data)
# Check if cached data is still valid
if data.get("expires_at"):
expires_at = datetime.fromisoformat(data["expires_at"])
if datetime.utcnow() > expires_at:
return None
# Reconstruct the context object expected by the rest of the system
context = {
"user_id": data["user_id"],
"user_email": data["user_email"],
"user_role": data["user_role"],
"api_key_id": data["api_key_id"],
"api_key_name": data["api_key_name"],
"permissions": data["permissions"],
"scopes": data["scopes"],
"rate_limits": {
"per_minute": data["rate_limit_per_minute"],
"per_hour": data["rate_limit_per_hour"],
"per_day": data["rate_limit_per_day"]
},
# Create minimal API key object with necessary attributes
"api_key": type("APIKey", (), {
"id": data["api_key_id"],
"name": data["api_key_name"],
"key_prefix": data["key_prefix"],
"is_active": data["is_active"],
"permissions": data["permissions"],
"scopes": data["scopes"],
"allowed_models": data["allowed_models"],
"allowed_endpoints": data["allowed_endpoints"],
"allowed_ips": data["allowed_ips"],
"is_unlimited": data["is_unlimited"],
"budget_limit_cents": data["budget_limit_cents"],
"budget_type": data["budget_type"],
"total_requests": data["total_requests"],
"total_tokens": data["total_tokens"],
"total_cost": data["total_cost"],
"expires_at": datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None,
"can_access_model": lambda model: not data["allowed_models"] or model in data["allowed_models"],
"can_access_endpoint": lambda endpoint: not data["allowed_endpoints"] or endpoint in data["allowed_endpoints"],
"can_access_from_ip": lambda ip: not data["allowed_ips"] or ip in data["allowed_ips"],
"has_scope": lambda scope: scope in data["scopes"],
"is_valid": lambda: data["is_active"] and (not data.get("expires_at") or datetime.utcnow() <= datetime.fromisoformat(data["expires_at"])),
"update_usage": lambda tokens, cost: None # Handled separately for cache consistency
})(),
# Create minimal user object
"user": type("User", (), {
"id": data["user_id"],
"email": data["user_email"],
"role": data["user_role"],
"is_active": data["user_is_active"]
})()
}
return context
except Exception as e:
logger.warning(f"Failed to deserialize cached API key data: {e}")
return None
async def get_cached_api_key(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
"""Get API key data from cache or database with optimized queries"""
try:
redis = await self.get_redis()
# If Redis is available, try cache first
if redis:
cache_key = self._get_cache_key(key_prefix)
# Try to get from cache first
cached_data = await redis.get(cache_key)
if cached_data:
logger.debug(f"API key cache hit for {key_prefix}")
context = await self._deserialize_api_key_data(cached_data)
if context:
return context
else:
# Invalid cached data, remove it
await redis.delete(cache_key)
logger.debug(f"API key cache miss for {key_prefix}, fetching from database")
else:
logger.debug(f"Redis not available, fetching API key {key_prefix} from database with optimized query")
# Cache miss or Redis not available - fetch from database with optimized query
context = await self._fetch_from_database(key_prefix, db)
# If Redis is available and we have data, cache it
if context and redis:
try:
api_key = context["api_key"]
user = context["user"]
# Reconstruct full objects for serialization
full_api_key = await self._get_full_api_key_from_db(key_prefix, db)
if full_api_key:
cached_data = await self._serialize_api_key_data(full_api_key, user)
await redis.setex(cache_key, self.cache_ttl, cached_data)
logger.debug(f"Cached API key data for {key_prefix}")
except Exception as cache_error:
logger.warning(f"Failed to cache API key data: {cache_error}")
# Don't fail the request if caching fails
return context
except Exception as e:
logger.error(f"Error in cached API key lookup for {key_prefix}: {e}")
# Fallback to database
return await self._fetch_from_database(key_prefix, db)
async def _get_full_api_key_from_db(self, key_prefix: str, db: AsyncSession) -> Optional[APIKey]:
"""Helper to get full API key object from database"""
stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def _fetch_from_database(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
"""Fetch API key and user data from database with optimized query"""
try:
# Optimized query with joinedload to eliminate N+1 query problem
stmt = select(APIKey).options(
joinedload(APIKey.user)
).where(APIKey.key_prefix == key_prefix)
result = await db.execute(stmt)
api_key = result.scalar_one_or_none()
if not api_key:
logger.warning(f"API key not found: {key_prefix}")
return None
user = api_key.user
if not user or not user.is_active:
logger.warning(f"User not found or inactive for API key: {key_prefix}")
return None
# Return the same structure as the original service
return {
"user_id": user.id,
"user_email": user.email,
"user_role": user.role,
"api_key_id": api_key.id,
"api_key_name": api_key.name,
"api_key": api_key,
"user": user,
"permissions": api_key.permissions,
"scopes": api_key.scopes,
"rate_limits": {
"per_minute": api_key.rate_limit_per_minute,
"per_hour": api_key.rate_limit_per_hour,
"per_day": api_key.rate_limit_per_day
}
}
except Exception as e:
logger.error(f"Database error fetching API key {key_prefix}: {e}")
return None
async def verify_api_key_cached(self, api_key: str, key_prefix: str) -> bool:
"""Cache API key verification results to avoid repeated bcrypt operations"""
try:
redis = await self.get_redis()
# If Redis is not available, skip caching
if not redis:
logger.debug(f"Redis not available, skipping verification cache for {key_prefix}")
return False # Caller should handle full verification
# Create a hash of the key suffix for cache key (never store the actual key)
import hashlib
key_suffix = api_key[8:] if len(api_key) > 8 else api_key
key_suffix_hash = hashlib.sha256(key_suffix.encode()).hexdigest()[:16]
verification_cache_key = self._get_verification_cache_key(key_prefix, key_suffix_hash)
# Check verification cache
cached_result = await redis.get(verification_cache_key)
if cached_result:
logger.debug(f"API key verification cache hit for {key_prefix}")
return cached_result == "valid"
# Need to do actual verification - get the hash from database
# This should be called only after we've confirmed the key exists
logger.debug(f"API key verification cache miss for {key_prefix}")
return False # Caller should handle full verification
except Exception as e:
logger.warning(f"Error in verification cache for {key_prefix}: {e}")
return False
async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool):
"""Cache the verification result to avoid future bcrypt operations"""
try:
# Only cache successful verifications and do actual verification
actual_valid = verify_api_key(api_key, key_hash)
if actual_valid != is_valid:
logger.warning(f"Verification mismatch for {key_prefix}")
return
if actual_valid:
redis = await self.get_redis()
# If Redis is not available, skip caching
if not redis:
logger.debug(f"Redis not available, skipping verification result cache for {key_prefix}")
return
# Create a hash of the key suffix for cache key
import hashlib
key_suffix = api_key[8:] if len(api_key) > 8 else api_key
key_suffix_hash = hashlib.sha256(key_suffix.encode()).hexdigest()[:16]
verification_cache_key = self._get_verification_cache_key(key_prefix, key_suffix_hash)
# Cache successful verification
await redis.setex(verification_cache_key, self.verification_cache_ttl, "valid")
logger.debug(f"Cached verification result for {key_prefix}")
except Exception as e:
logger.warning(f"Error caching verification result for {key_prefix}: {e}")
async def invalidate_api_key_cache(self, key_prefix: str):
"""Invalidate cached data for an API key"""
try:
redis = await self.get_redis()
# If Redis is not available, skip invalidation
if not redis:
logger.debug(f"Redis not available, skipping cache invalidation for {key_prefix}")
return
cache_key = self._get_cache_key(key_prefix)
await redis.delete(cache_key)
# Also invalidate verification cache - get all verification keys for this prefix
pattern = f"api_key:verified:{key_prefix}:*"
keys = await redis.keys(pattern)
if keys:
await redis.delete(*keys)
logger.debug(f"Invalidated cache for API key {key_prefix}")
except Exception as e:
logger.warning(f"Error invalidating cache for {key_prefix}: {e}")
async def update_last_used(self, api_key_id: int, db: AsyncSession):
"""Update last used timestamp with write-through cache"""
try:
redis = await self.get_redis()
current_time = datetime.utcnow()
should_update = True
# If Redis is available, check if we've updated recently (avoid too frequent DB writes)
if redis:
cache_key = self._get_last_used_cache_key(api_key_id)
last_update = await redis.get(cache_key)
if last_update:
last_update_time = datetime.fromisoformat(last_update)
if current_time - last_update_time < timedelta(minutes=1):
# Skip update if last update was less than 1 minute ago
should_update = False
if should_update:
# Update database
stmt = select(APIKey).where(APIKey.id == api_key_id)
result = await db.execute(stmt)
api_key = result.scalar_one_or_none()
if api_key:
api_key.last_used_at = current_time
await db.commit()
# Update cache if Redis is available
if redis:
cache_key = self._get_last_used_cache_key(api_key_id)
await redis.setex(cache_key, 300, current_time.isoformat())
logger.debug(f"Updated last used timestamp for API key {api_key_id}")
except Exception as e:
logger.warning(f"Error updating last used timestamp for API key {api_key_id}: {e}")
# Global cached service instance
cached_api_key_service = CachedAPIKeyService()
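# --- Illustrative auth-path sketch (editor's addition; `db` is an assumed
# AsyncSession, and `key_hash` must come from the database-backed record, since
# the minimal cached object does not carry it) ---
#
# async def authenticate(raw_key: str, db) -> Optional[Dict[str, Any]]:
#     key_prefix = raw_key[:8]
#     context = await cached_api_key_service.get_cached_api_key(key_prefix, db)
#     if context is None:
#         return None
#     if not await cached_api_key_service.verify_api_key_cached(raw_key, key_prefix):
#         key_hash = context["api_key"].key_hash  # available on the DB-backed object
#         is_valid = verify_api_key(raw_key, key_hash)  # bcrypt check
#         await cached_api_key_service.cache_verification_result(raw_key, key_prefix, key_hash, is_valid)
#         if not is_valid:
#             return None
#     await cached_api_key_service.update_last_used(context["api_key_id"], db)
#     return context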

View File

@@ -0,0 +1,451 @@
"""
Configuration Management Service - Core App Integration
Provides centralized configuration management with hot-reloading and encryption.
"""
import asyncio
import json
import os
import hashlib
import time
import threading
from typing import Dict, Any, Optional, List, Union, Callable
from pathlib import Path
from dataclasses import dataclass, asdict
from datetime import datetime
from cryptography.fernet import Fernet
import yaml
import logging
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from app.core.logging import get_logger
logger = get_logger(__name__)
@dataclass
class ConfigVersion:
"""Configuration version metadata"""
version: str
timestamp: datetime
checksum: str
author: str
description: str
config_data: Dict[str, Any]
@dataclass
class ConfigSchema:
"""Configuration schema definition"""
name: str
required_fields: List[str]
optional_fields: List[str]
field_types: Dict[str, type]
validators: Dict[str, Callable]
@dataclass
class ConfigStats:
"""Configuration manager statistics"""
total_configs: int
active_watchers: int
config_versions: int
encrypted_configs: int
hot_reloads_performed: int
validation_errors: int
last_reload_time: datetime
uptime: float
class ConfigWatcher(FileSystemEventHandler):
"""File system watcher for configuration changes"""
def __init__(self, config_manager):
self.config_manager = config_manager
self.debounce_time = 1.0 # 1 second debounce
self.last_modified = {}
def on_modified(self, event):
if event.is_directory:
return
path = event.src_path
current_time = time.time()
# Debounce rapid file changes
if path in self.last_modified:
if current_time - self.last_modified[path] < self.debounce_time:
return
self.last_modified[path] = current_time
# Trigger hot reload for config files
if path.endswith(('.json', '.yaml', '.yml', '.toml')):
# Schedule the coroutine thread-safely; watchdog fires this handler on its
# own thread, so a running loop is only found if one is attached to it
try:
loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(
lambda: asyncio.create_task(self.config_manager.reload_config_file(path))
)
except RuntimeError:
# No running loop, schedule for later
threading.Thread(
target=self._schedule_reload,
args=(path,),
daemon=True
).start()
def _schedule_reload(self, path: str):
"""Fallback when no event loop is available; currently only logs the pending reload"""
try:
logger.info(f"Scheduling config reload for {path}")
except Exception as e:
logger.error(f"Error scheduling config reload for {path}: {str(e)}")
class ConfigManager:
"""Core configuration management system"""
def __init__(self):
self.configs: Dict[str, Dict[str, Any]] = {}
self.schemas: Dict[str, ConfigSchema] = {}
self.versions: Dict[str, List[ConfigVersion]] = {}
self.watchers: Dict[str, Observer] = {}
self.encrypted_configs: set = set()
self.config_paths: Dict[str, Path] = {}
self.environment = os.getenv('ENVIRONMENT', 'development')
self.start_time = time.time()
self.stats = ConfigStats(
total_configs=0,
active_watchers=0,
config_versions=0,
encrypted_configs=0,
hot_reloads_performed=0,
validation_errors=0,
last_reload_time=datetime.now(),
uptime=0
)
# Initialize encryption key
self.encryption_key = self._get_or_create_encryption_key()
self.cipher = Fernet(self.encryption_key)
# Base configuration directories
self.config_base_dir = Path("configs")
self.config_base_dir.mkdir(exist_ok=True)
# Environment-specific directory
self.env_config_dir = self.config_base_dir / self.environment
self.env_config_dir.mkdir(exist_ok=True)
logger.info(f"ConfigManager initialized for environment: {self.environment}")
def _get_or_create_encryption_key(self) -> bytes:
"""Get or create encryption key for sensitive configurations"""
key_file = Path(".config_encryption_key")
if key_file.exists():
return key_file.read_bytes()
else:
key = Fernet.generate_key()
key_file.write_bytes(key)
key_file.chmod(0o600) # Restrict permissions
logger.info("Generated new encryption key for configuration management")
return key
def register_schema(self, name: str, schema: ConfigSchema):
"""Register a configuration schema for validation"""
self.schemas[name] = schema
logger.info(f"Registered configuration schema: {name}")
def validate_config(self, name: str, config_data: Dict[str, Any]) -> bool:
"""Validate configuration against registered schema"""
if name not in self.schemas:
logger.debug(f"No schema registered for config: {name}")
return True
schema = self.schemas[name]
try:
# Check required fields
for field in schema.required_fields:
if field not in config_data:
logger.error(f"Missing required field '{field}' in config '{name}'")
self.stats.validation_errors += 1
return False
# Validate field types
for field, expected_type in schema.field_types.items():
if field in config_data:
if not isinstance(config_data[field], expected_type):
logger.error(f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}")
self.stats.validation_errors += 1
return False
# Run custom validators
for field, validator in schema.validators.items():
if field in config_data:
if not validator(config_data[field]):
logger.error(f"Validation failed for field '{field}' in config '{name}'")
self.stats.validation_errors += 1
return False
logger.debug(f"Configuration '{name}' passed validation")
return True
except Exception as e:
logger.error(f"Error validating config '{name}': {str(e)}")
self.stats.validation_errors += 1
return False
def _calculate_checksum(self, data: Dict[str, Any]) -> str:
"""Calculate checksum for configuration data"""
json_str = json.dumps(data, sort_keys=True)
return hashlib.sha256(json_str.encode()).hexdigest()
def _create_version(self, name: str, config_data: Dict[str, Any], description: str = "Auto-save") -> ConfigVersion:
"""Create a new configuration version"""
version_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checksum = self._calculate_checksum(config_data)
version = ConfigVersion(
version=version_id,
timestamp=datetime.now(),
checksum=checksum,
author=os.getenv('USER', 'system'),
description=description,
config_data=config_data.copy()
)
if name not in self.versions:
self.versions[name] = []
self.versions[name].append(version)
# Keep only last 10 versions
if len(self.versions[name]) > 10:
self.versions[name] = self.versions[name][-10:]
self.stats.config_versions += 1
logger.debug(f"Created version {version_id} for config '{name}'")
return version
async def set_config(self, name: str, config_data: Dict[str, Any],
encrypted: bool = False, description: str = "Manual update") -> bool:
"""Set configuration with validation and versioning"""
try:
# Validate configuration
if not self.validate_config(name, config_data):
return False
# Create version before updating
self._create_version(name, config_data, description)
# Handle encryption if requested
if encrypted:
self.encrypted_configs.add(name)
# Store configuration
self.configs[name] = config_data.copy()
self.stats.total_configs = len(self.configs)
# Save to file
await self._save_config_to_file(name, config_data, encrypted)
logger.info(f"Configuration '{name}' updated successfully")
return True
except Exception as e:
logger.error(f"Error setting config '{name}': {str(e)}")
return False
async def get_config(self, name: str, default: Any = None) -> Any:
"""Get configuration value"""
if name in self.configs:
return self.configs[name]
# Try to load from file if not in memory
config_data = await self._load_config_from_file(name)
if config_data is not None:
self.configs[name] = config_data
return config_data
return default
async def get_config_value(self, config_name: str, key: str, default: Any = None) -> Any:
"""Get specific value from configuration"""
config = await self.get_config(config_name)
if config is None:
return default
keys = key.split('.')
value = config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
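# Example (editor's addition): with a stored config {"pool": {"size": 5}},
# await get_config_value("database", "pool.size", default=1) walks the dotted
# path and returns 5; any missing segment falls back to the default.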
async def _save_config_to_file(self, name: str, config_data: Dict[str, Any], encrypted: bool = False):
"""Save configuration to file"""
file_path = self.env_config_dir / f"{name}.json"
try:
if encrypted:
# Encrypt sensitive data
json_str = json.dumps(config_data, indent=2)
encrypted_data = self.cipher.encrypt(json_str.encode())
file_path.write_bytes(encrypted_data)
logger.debug(f"Saved encrypted config '{name}' to {file_path}")
else:
# Save as regular JSON
with open(file_path, 'w') as f:
json.dump(config_data, f, indent=2)
logger.debug(f"Saved config '{name}' to {file_path}")
self.config_paths[name] = file_path
except Exception as e:
logger.error(f"Error saving config '{name}' to file: {str(e)}")
raise
async def _load_config_from_file(self, name: str) -> Optional[Dict[str, Any]]:
"""Load configuration from file"""
file_path = self.env_config_dir / f"{name}.json"
if not file_path.exists():
return None
try:
if name in self.encrypted_configs:
# Decrypt sensitive data
encrypted_data = file_path.read_bytes()
decrypted_data = self.cipher.decrypt(encrypted_data)
return json.loads(decrypted_data.decode())
else:
# Load regular JSON
with open(file_path, 'r') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading config '{name}' from file: {str(e)}")
return None
async def reload_config_file(self, file_path: str):
"""Hot reload configuration from file change"""
try:
path = Path(file_path)
config_name = path.stem
# Load updated configuration
if path.suffix == '.json':
with open(path, 'r') as f:
new_config = json.load(f)
elif path.suffix in ['.yaml', '.yml']:
with open(path, 'r') as f:
new_config = yaml.safe_load(f)
else:
logger.warning(f"Unsupported config file format: {path.suffix}")
return
# Validate and update
if self.validate_config(config_name, new_config):
self.configs[config_name] = new_config
self.stats.hot_reloads_performed += 1
self.stats.last_reload_time = datetime.now()
logger.info(f"Hot reloaded configuration '{config_name}' from {file_path}")
else:
logger.error(f"Failed to hot reload '{config_name}' - validation failed")
except Exception as e:
logger.error(f"Error hot reloading config from {file_path}: {str(e)}")
def get_stats(self) -> Dict[str, Any]:
"""Get configuration management statistics"""
self.stats.uptime = time.time() - self.start_time
return asdict(self.stats)
async def cleanup(self):
"""Cleanup resources"""
for watcher in self.watchers.values():
watcher.stop()
watcher.join()
self.watchers.clear()
logger.info("Configuration management cleanup completed")
# Global config manager
config_manager: Optional[ConfigManager] = None
def get_config_manager() -> ConfigManager:
"""Get the global config manager instance"""
global config_manager
if config_manager is None:
config_manager = ConfigManager()
return config_manager
async def init_config_manager():
"""Initialize the global config manager"""
global config_manager
config_manager = ConfigManager()
# Register default schemas
await _register_default_schemas()
# Load default configurations
await _load_default_configs()
logger.info("Configuration manager initialized")
async def _register_default_schemas():
"""Register default configuration schemas"""
manager = get_config_manager()
# Database schema
db_schema = ConfigSchema(
name="database",
required_fields=["host", "port", "name"],
optional_fields=["username", "password", "ssl"],
field_types={"host": str, "port": int, "name": str, "ssl": bool},
validators={"port": lambda x: 1 <= x <= 65535}
)
manager.register_schema("database", db_schema)
# Cache schema
cache_schema = ConfigSchema(
name="cache",
required_fields=["redis_url"],
optional_fields=["timeout", "max_connections"],
field_types={"redis_url": str, "timeout": int, "max_connections": int},
validators={"timeout": lambda x: x > 0}
)
manager.register_schema("cache", cache_schema)
async def _load_default_configs():
"""Load default configurations"""
manager = get_config_manager()
default_configs = {
"app": {
"name": "Confidential Empire",
"version": "1.0.0",
"debug": manager.environment == "development",
"log_level": "INFO",
"timezone": "UTC"
},
"cache": {
"redis_url": "redis://empire-redis:6379/0",
"timeout": 30,
"max_connections": 10
}
}
for name, config in default_configs.items():
await manager.set_config(name, config, description="Default configuration")
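# Illustrative usage sketch (not part of the original service): initialize the
# global manager at startup, then read whole configs or nested keys via dotted
# paths. Assumes an event loop and writable config storage are available.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        await init_config_manager()
        manager = get_config_manager()
        # Whole config dict, with a fallback if it has never been set
        cache_cfg = await manager.get_config("cache", default={})
        # Nested lookup: dotted keys walk the config dict
        timeout = await manager.get_config_value("cache", "timeout", default=30)
        print(cache_cfg, timeout)

    asyncio.run(_demo())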

View File

@@ -0,0 +1,187 @@
"""
Cost calculation service for LLM model pricing
"""
from typing import Any, Dict, Optional
from app.core.logging import get_logger
logger = get_logger(__name__)
class CostCalculator:
"""Service for calculating costs based on model usage and token consumption"""
# Model pricing in 1/10000ths of a dollar per 1000 tokens (input/output)
MODEL_PRICING = {
# OpenAI Models
"gpt-4": {"input": 300, "output": 600}, # $0.03/$0.06 per 1K tokens
"gpt-4-turbo": {"input": 100, "output": 300}, # $0.01/$0.03 per 1K tokens
"gpt-3.5-turbo": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
# Anthropic Models
"claude-3-opus": {"input": 150, "output": 750}, # $0.015/$0.075 per 1K tokens
"claude-3-sonnet": {"input": 30, "output": 150}, # $0.003/$0.015 per 1K tokens
"claude-3-haiku": {"input": 25, "output": 125}, # $0.00025/$0.00125 per 1K tokens
# Google Models
"gemini-pro": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
# Privatemode.ai Models (estimated pricing)
"privatemode-llama-70b": {"input": 40, "output": 80}, # Estimated pricing
"privatemode-mixtral": {"input": 20, "output": 40}, # Estimated pricing
# Embedding Models
"text-embedding-ada-002": {"input": 1, "output": 0}, # $0.0001 per 1K tokens
"text-embedding-3-small": {"input": 2, "output": 0}, # $0.00002 per 1K tokens
"text-embedding-3-large": {"input": 13, "output": 0}, # $0.00013 per 1K tokens
}
# Default pricing for unknown models
DEFAULT_PRICING = {"input": 10, "output": 20} # $0.001/$0.002 per 1K tokens
@classmethod
def get_model_pricing(cls, model_name: str) -> Dict[str, int]:
"""Get pricing for a specific model"""
# Normalize model name (remove provider prefixes)
normalized_name = cls._normalize_model_name(model_name)
# Look up pricing
pricing = cls.MODEL_PRICING.get(normalized_name, cls.DEFAULT_PRICING)
logger.debug(f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}")
return pricing
@classmethod
def _normalize_model_name(cls, model_name: str) -> str:
"""Normalize model name by removing provider prefixes"""
# Remove common provider prefixes
prefixes = ["openai/", "anthropic/", "google/", "gemini/", "privatemode/"]
normalized = model_name.lower()
for prefix in prefixes:
if normalized.startswith(prefix):
normalized = normalized[len(prefix):]
break
# Handle special cases
if "claude-3-opus-20240229" in normalized:
return "claude-3-opus"
elif "claude-3-sonnet-20240229" in normalized:
return "claude-3-sonnet"
elif "claude-3-haiku-20240307" in normalized:
return "claude-3-haiku"
elif "meta-llama/llama-3.1-70b-instruct" in normalized:
return "privatemode-llama-70b"
elif "mistralai/mixtral-8x7b-instruct" in normalized:
return "privatemode-mixtral"
return normalized
@classmethod
def calculate_cost_cents(
cls,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0
) -> int:
"""
Calculate cost in pricing units (1/10000ths of a dollar) for given token usage
Args:
model_name: Name of the LLM model
input_tokens: Number of input tokens used
output_tokens: Number of output tokens generated
Returns:
Total cost in 1/10000ths of a dollar
"""
pricing = cls.get_model_pricing(model_name)
# Calculate cost per token type
input_cost_cents = (input_tokens * pricing["input"]) // 1000
output_cost_cents = (output_tokens * pricing["output"]) // 1000
total_cost_cents = input_cost_cents + output_cost_cents
logger.debug(
f"Cost calculation for {model_name}: "
f"input_tokens={input_tokens} (${input_cost_cents/100:.4f}), "
f"output_tokens={output_tokens} (${output_cost_cents/100:.4f}), "
f"total=${total_cost_cents/100:.4f}"
)
return total_cost_cents
@classmethod
def estimate_cost_cents(cls, model_name: str, estimated_tokens: int) -> int:
"""
Estimate cost for a request based on estimated total tokens
Assumes 70% input, 30% output token distribution
Args:
model_name: Name of the LLM model
estimated_tokens: Estimated total tokens for the request
Returns:
Estimated cost in 1/10000ths of a dollar
"""
input_tokens = int(estimated_tokens * 0.7) # 70% input
output_tokens = int(estimated_tokens * 0.3) # 30% output
return cls.calculate_cost_cents(model_name, input_tokens, output_tokens)
@classmethod
def get_cost_per_1k_tokens(cls, model_name: str) -> Dict[str, Any]:
"""
Get cost per 1000 tokens in dollars for display purposes
Args:
model_name: Name of the LLM model
Returns:
Dictionary with input and output costs in dollars per 1K tokens
"""
pricing_cents = cls.get_model_pricing(model_name)
return {
"input": pricing_cents["input"] / 10000, # Convert 1/10000ths to dollars
"output": pricing_cents["output"] / 10000,
"currency": "USD"
}
@classmethod
def get_all_model_pricing(cls) -> Dict[str, Dict[str, Any]]:
"""Get pricing for all supported models in dollars"""
pricing_data = {}
for model_name in cls.MODEL_PRICING.keys():
pricing_data[model_name] = cls.get_cost_per_1k_tokens(model_name)
return pricing_data
@classmethod
def format_cost_display(cls, cost_cents: int) -> str:
"""Format cost in 1/1000ths of a dollar for display"""
if cost_cents == 0:
return "$0.00"
elif cost_cents < 1000:
return f"${cost_cents/1000:.4f}"
else:
return f"${cost_cents/1000:.2f}"
# Convenience functions for common operations
def calculate_request_cost(model_name: str, input_tokens: int, output_tokens: int) -> int:
"""Calculate cost for a single request"""
return CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
def estimate_request_cost(model_name: str, estimated_tokens: int) -> int:
"""Estimate cost for a request"""
return CostCalculator.estimate_cost_cents(model_name, estimated_tokens)
def get_model_pricing_display(model_name: str) -> Dict[str, Any]:
"""Get model pricing for display"""
return CostCalculator.get_cost_per_1k_tokens(model_name)
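# Worked example (illustrative): gpt-4 is stored as 300/600 units per 1K
# tokens, where one unit is $0.0001. For 1,000 input and 500 output tokens:
# (1000 * 300) // 1000 + (500 * 600) // 1000 = 300 + 300 = 600 units = $0.06.
if __name__ == "__main__":
    cost = calculate_request_cost("gpt-4", input_tokens=1000, output_tokens=500)
    assert cost == 600
    print(CostCalculator.format_cost_display(cost))  # -> $0.0600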

View File

@@ -0,0 +1,311 @@
"""
Document Processor Service
Handles async document processing with queue management
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List
from datetime import datetime
from enum import Enum
from dataclasses import dataclass
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from app.db.database import get_db
from app.models.rag_document import RagDocument
from app.models.rag_collection import RagCollection
logger = logging.getLogger(__name__)
class ProcessingStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
PROCESSED = "processed"
INDEXED = "indexed"
ERROR = "error"
@dataclass
class ProcessingTask:
"""Document processing task"""
document_id: int
priority: int = 1
retry_count: int = 0
max_retries: int = 3
created_at: Optional[datetime] = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
class DocumentProcessor:
"""Async document processor with queue management"""
def __init__(self, max_workers: int = 3, max_queue_size: int = 100):
self.max_workers = max_workers
self.max_queue_size = max_queue_size
self.processing_queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size)
self.workers: List[asyncio.Task] = []
self.running = False
self.stats = {
"processed_count": 0,
"error_count": 0,
"queue_size": 0,
"active_workers": 0
}
async def start(self):
"""Start the document processor"""
if self.running:
return
self.running = True
logger.info(f"Starting document processor with {self.max_workers} workers")
# Start worker tasks
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.workers.append(worker)
logger.info("Document processor started")
async def stop(self):
"""Stop the document processor"""
if not self.running:
return
self.running = False
logger.info("Stopping document processor...")
# Cancel all workers
for worker in self.workers:
worker.cancel()
# Wait for workers to finish
await asyncio.gather(*self.workers, return_exceptions=True)
self.workers.clear()
logger.info("Document processor stopped")
async def add_task(self, document_id: int, priority: int = 1) -> bool:
"""Add a document processing task to the queue"""
try:
task = ProcessingTask(document_id=document_id, priority=priority)
# Check if queue is full
if self.processing_queue.full():
logger.warning(f"Processing queue is full, dropping task for document {document_id}")
return False
await self.processing_queue.put(task)
self.stats["queue_size"] = self.processing_queue.qsize()
logger.info(f"Added processing task for document {document_id} (priority: {priority})")
return True
except Exception as e:
logger.error(f"Failed to add processing task for document {document_id}: {e}")
return False
async def _worker(self, worker_name: str):
"""Worker coroutine that processes documents"""
logger.info(f"Started worker: {worker_name}")
while self.running:
try:
# Get task from queue (wait up to 1 second)
task = await asyncio.wait_for(
self.processing_queue.get(),
timeout=1.0
)
self.stats["active_workers"] += 1
self.stats["queue_size"] = self.processing_queue.qsize()
logger.info(f"{worker_name}: Processing document {task.document_id}")
# Process the document
success = await self._process_document(task)
if success:
self.stats["processed_count"] += 1
logger.info(f"{worker_name}: Successfully processed document {task.document_id}")
else:
# Retry logic
if task.retry_count < task.max_retries:
task.retry_count += 1
await asyncio.sleep(2 ** task.retry_count) # Exponential backoff
await self.processing_queue.put(task)
logger.warning(f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})")
else:
self.stats["error_count"] += 1
logger.error(f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries")
self.stats["active_workers"] -= 1
except asyncio.TimeoutError:
# No tasks in queue, continue
continue
except asyncio.CancelledError:
# Worker cancelled, exit
break
except Exception as e:
self.stats["active_workers"] -= 1
logger.error(f"{worker_name}: Unexpected error: {e}")
await asyncio.sleep(1) # Brief pause before continuing
logger.info(f"Worker stopped: {worker_name}")
async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document"""
from app.db.database import async_session_factory
async with async_session_factory() as session:
try:
# Get document from database
stmt = (
select(RagDocument)
.options(selectinload(RagDocument.collection))
.where(RagDocument.id == task.document_id)
)
result = await session.execute(stmt)
document = result.scalar_one_or_none()
if not document:
logger.error(f"Document {task.document_id} not found")
return False
# Update status to processing
document.status = ProcessingStatus.PROCESSING
await session.commit()
# Get RAG module for processing (now includes content processing)
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
except Exception as e:
logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}")
if not rag_module:
raise Exception("RAG module not available")
logger.info(f"RAG module loaded successfully for document {task.document_id}")
# Read file content
logger.info(f"Reading file content for document {task.document_id}: {document.file_path}")
with open(document.file_path, 'rb') as f:
file_content = f.read()
logger.info(f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes")
# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
try:
# Add timeout to prevent hanging
processed_doc = await asyncio.wait_for(
rag_module.process_document(
file_content,
document.original_filename,
{}
),
timeout=300.0 # 5 minute timeout
)
logger.info(f"Document processing completed for document {task.document_id}")
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
raise
# Update document with processed content
document.converted_content = processed_doc.content
document.word_count = processed_doc.word_count
document.character_count = len(processed_doc.content)
document.document_metadata = processed_doc.metadata
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
# Index in RAG system using same RAG module
if rag_module and document.converted_content:
try:
logger.info(f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}")
# Index the document content in the correct Qdrant collection
doc_metadata = {
"collection_id": document.collection_id,
"document_id": document.id,
"filename": document.original_filename,
"file_type": document.file_type,
**document.document_metadata
}
# Use the correct Qdrant collection name for this document
await asyncio.wait_for(
rag_module.index_document(
content=document.converted_content,
metadata=doc_metadata,
collection_name=document.collection.qdrant_collection_name
),
timeout=120.0 # 2 minute timeout for indexing
)
logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")
# Update vector count (approximate)
document.vector_count = max(1, len(document.converted_content) // 1000)
document.status = ProcessingStatus.INDEXED
document.indexed_at = datetime.utcnow()
# Update collection stats
collection = document.collection
if collection and document.status == ProcessingStatus.INDEXED:
collection.document_count += 1
collection.size_bytes += document.file_size
collection.vector_count += document.vector_count
collection.updated_at = datetime.utcnow()
except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
# Keep as processed even if indexing fails
# Don't raise the exception to avoid retries on indexing failures
await session.commit()
return True
except Exception as e:
# Mark document as error
if 'document' in locals() and document:
document.status = ProcessingStatus.ERROR
document.processing_error = str(e)
await session.commit()
logger.error(f"Error processing document {task.document_id}: {e}")
return False
async def get_stats(self) -> Dict[str, Any]:
"""Get processor statistics"""
return {
**self.stats,
"running": self.running,
"worker_count": len(self.workers),
"queue_size": self.processing_queue.qsize()
}
async def get_queue_status(self) -> Dict[str, Any]:
"""Get detailed queue status"""
return {
"queue_size": self.processing_queue.qsize(),
"max_queue_size": self.max_queue_size,
"queue_full": self.processing_queue.full(),
"active_workers": self.stats["active_workers"],
"max_workers": self.max_workers
}
# Global document processor instance
document_processor = DocumentProcessor()
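# Illustrative usage sketch: start the global processor at app startup, enqueue
# documents by id, and stop it on shutdown. The document id below is a
# placeholder; real ids come from the rag_documents table.
if __name__ == "__main__":
    async def _demo():
        await document_processor.start()
        queued = await document_processor.add_task(document_id=42, priority=2)
        print("queued:", queued, await document_processor.get_queue_status())
        await document_processor.stop()

    asyncio.run(_demo())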

View File

@@ -0,0 +1,160 @@
"""
Embedding Service
Provides text embedding functionality using LiteLLM proxy
"""
import logging
from typing import List, Dict, Any, Optional
import numpy as np
logger = logging.getLogger(__name__)
class EmbeddingService:
"""Service for generating text embeddings using LiteLLM"""
def __init__(self, model_name: str = "privatemode-embeddings"):
self.model_name = model_name
self.litellm_client = None
self.dimension = 1024 # Actual dimension for privatemode-embeddings
self.initialized = False
async def initialize(self):
"""Initialize the embedding service with LiteLLM"""
try:
from app.services.litellm_client import litellm_client
self.litellm_client = litellm_client
# Test connection to LiteLLM
health = await self.litellm_client.health_check()
if health.get("status") == "unhealthy":
logger.error(f"LiteLLM service unhealthy: {health.get('error')}")
return False
self.initialized = True
logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})")
return True
except Exception as e:
logger.error(f"Failed to initialize LiteLLM embedding service: {e}")
logger.warning("Using fallback random embeddings")
return False
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding for a single text"""
embeddings = await self.get_embeddings([text])
return embeddings[0]
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using LiteLLM"""
if not self.initialized or not self.litellm_client:
# Fallback to random embeddings if not initialized
logger.warning("LiteLLM not available, using random embeddings")
return self._generate_fallback_embeddings(texts)
try:
embeddings = []
# Process texts in batches for efficiency
batch_size = 10
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
# Process each text in the batch
batch_embeddings = []
for text in batch:
try:
# Truncate text if it's too long for the model's context window
# privatemode-embeddings has a 512 token limit, truncate to ~400 tokens worth of chars
# Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
max_chars = 1600
if len(text) > max_chars:
truncated_text = text[:max_chars]
logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
else:
truncated_text = text
# Call LiteLLM embedding endpoint
response = await self.litellm_client.create_embedding(
model=self.model_name,
input_text=truncated_text,
user_id="rag_system",
api_key_id=0 # System API key
)
# Extract embedding from response
if "data" in response and len(response["data"]) > 0:
embedding = response["data"][0].get("embedding", [])
if embedding:
batch_embeddings.append(embedding)
# Update dimension based on actual embedding size
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
logger.info(f"Confirmed embedding dimension: {self.dimension}")
else:
logger.warning(f"No embedding in response for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
else:
logger.warning(f"Invalid response structure for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
logger.error(f"Error getting embedding for text: {e}")
batch_embeddings.append(self._generate_fallback_embedding(text))
embeddings.extend(batch_embeddings)
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings with LiteLLM: {e}")
# Fallback to random embeddings
return self._generate_fallback_embeddings(texts)
def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
embeddings.append(self._generate_fallback_embedding(text))
return embeddings
def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension or 1024 # Default dimension for privatemode-embeddings
# Seed from the text hash so identical text maps to the same vector.
# Note: str hashes are salted per process, so vectors are stable only
# within a single process unless PYTHONHASHSEED is fixed.
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
async def similarity(self, text1: str, text2: str) -> float:
"""Calculate cosine similarity between two texts"""
embeddings = await self.get_embeddings([text1, text2])
# Calculate cosine similarity
vec1 = np.array(embeddings[0])
vec2 = np.array(embeddings[1])
# Normalize vectors
vec1_norm = vec1 / np.linalg.norm(vec1)
vec2_norm = vec2 / np.linalg.norm(vec2)
# Calculate cosine similarity
similarity = np.dot(vec1_norm, vec2_norm)
return float(similarity)
async def get_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics"""
return {
"model_name": self.model_name,
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": "LiteLLM",
"initialized": self.initialized
}
async def cleanup(self):
"""Cleanup resources"""
self.initialized = False
self.litellm_client = None
# Global embedding service instance
embedding_service = EmbeddingService()
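# Illustrative usage sketch: initialize once, then embed and compare texts.
# If LiteLLM is unreachable, the service silently falls back to hash-seeded
# random vectors, so similarity scores are only meaningful once initialized.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        await embedding_service.initialize()
        vector = await embedding_service.get_embedding("hello world")
        score = await embedding_service.similarity("hello world", "hi there")
        print(f"dimension={len(vector)} cosine={score:.3f}")

    asyncio.run(_demo())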

View File

@@ -0,0 +1,304 @@
"""
LiteLLM Client Service
Handles communication with the LiteLLM proxy service
"""
import asyncio
import json
import logging
from typing import Dict, Any, Optional, List
from datetime import datetime
import aiohttp
from fastapi import HTTPException, status
from app.core.config import settings
logger = logging.getLogger(__name__)
class LiteLLMClient:
"""Client for communicating with LiteLLM proxy service"""
def __init__(self):
self.base_url = settings.LITELLM_BASE_URL
self.master_key = settings.LITELLM_MASTER_KEY
self.session: Optional[aiohttp.ClientSession] = None
self.timeout = aiohttp.ClientTimeout(total=600)  # 10-minute total timeout
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create aiohttp session"""
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession(
timeout=self.timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.master_key}"
}
)
return self.session
async def close(self):
"""Close the HTTP session"""
if self.session and not self.session.closed:
await self.session.close()
async def health_check(self) -> Dict[str, Any]:
"""Check LiteLLM proxy health"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/health") as response:
if response.status == 200:
return await response.json()
else:
logger.error(f"LiteLLM health check failed: {response.status}")
return {"status": "unhealthy", "error": f"HTTP {response.status}"}
except Exception as e:
logger.error(f"LiteLLM health check error: {e}")
return {"status": "unhealthy", "error": str(e)}
async def create_chat_completion(
self,
model: str,
messages: List[Dict[str, str]],
user_id: str,
api_key_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion via LiteLLM proxy"""
try:
# Prepare request payload
payload = {
"model": model,
"messages": messages,
"user": f"user_{user_id}", # User identifier for tracking
"metadata": {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
},
**kwargs
}
session = await self._get_session()
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")
# Handle specific error cases
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
elif response.status == 400:
    # Parse the upstream error first; raising inside the try block
    # would be swallowed by the except clause
    try:
        error_data = await response.json()
        detail = error_data.get("error", {}).get("message", "Bad request")
    except Exception:
        detail = "Invalid request"
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=detail
    )
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service error"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM chat completion request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def create_embedding(
self,
model: str,
input_text: str,
user_id: str,
api_key_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create embedding via LiteLLM proxy"""
try:
payload = {
"model": model,
"input": input_text,
"user": f"user_{user_id}",
"metadata": {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
},
**kwargs
}
session = await self._get_session()
async with session.post(
f"{self.base_url}/embeddings",
json=payload
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service error"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM embedding request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def get_models(self) -> List[Dict[str, Any]]:
"""Get available models from LiteLLM proxy"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
# Return models with exact names from upstream providers
models = data.get("data", [])
# Pass through model names exactly as they come from upstream
# Don't modify model IDs - keep them as the original provider names
processed_models = []
for model in models:
# Keep the exact model ID from upstream provider
processed_models.append({
"id": model.get("id", ""), # Exact model name from provider
"object": model.get("object", "model"),
"created": model.get("created", 1677610602),
"owned_by": model.get("owned_by", "openai")
})
return processed_models
else:
error_text = await response.text()
logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
return []
except aiohttp.ClientError as e:
logger.error(f"LiteLLM models request error: {e}")
return []
async def proxy_request(
self,
method: str,
endpoint: str,
payload: Dict[str, Any],
user_id: str,
api_key_id: int
) -> Dict[str, Any]:
"""Generic proxy request to LiteLLM"""
try:
# Add metadata to payload
if isinstance(payload, dict):
payload["metadata"] = {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
}
if "user" not in payload:
payload["user"] = f"user_{user_id}"
session = await self._get_session()
# Make the request
async with session.request(
method,
f"{self.base_url}/{endpoint.lstrip('/')}",
json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None,
params=payload if method.upper() == 'GET' else None
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=response.status,
detail=f"LiteLLM service error: {error_text}"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM proxy request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def list_models(self) -> List[str]:
"""Get list of available model names/IDs"""
try:
models_data = await self.get_models()
return [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
except Exception as e:
logger.error(f"Error listing model names: {str(e)}")
return []
# Global LiteLLM client instance
litellm_client = LiteLLMClient()
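# Illustrative usage sketch: the client is a process-wide singleton backed by a
# lazily created aiohttp session; close it on shutdown. Assumes
# LITELLM_BASE_URL and LITELLM_MASTER_KEY are configured in settings.
if __name__ == "__main__":
    async def _demo():
        print(await litellm_client.health_check())
        print(await litellm_client.list_models())
        await litellm_client.close()

    asyncio.run(_demo())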

View File

@@ -0,0 +1,308 @@
"""
Metrics and monitoring service
"""
import time
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncio
from collections import defaultdict, deque
from app.core.config import settings
from app.core.logging import log_module_event, log_security_event
@dataclass
class MetricData:
"""Individual metric data point"""
timestamp: datetime
value: float
labels: Dict[str, str] = field(default_factory=dict)
@dataclass
class RequestMetrics:
"""Request-related metrics"""
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
average_response_time: float = 0.0
total_tokens_used: int = 0
total_cost: float = 0.0
requests_by_model: Dict[str, int] = field(default_factory=dict)
requests_by_user: Dict[str, int] = field(default_factory=dict)
requests_by_endpoint: Dict[str, int] = field(default_factory=dict)
@dataclass
class SystemMetrics:
"""System-related metrics"""
uptime: float = 0.0
memory_usage: float = 0.0
cpu_usage: float = 0.0
active_connections: int = 0
module_status: Dict[str, bool] = field(default_factory=dict)
class MetricsService:
"""Service for collecting and managing metrics"""
def __init__(self):
self.request_metrics = RequestMetrics()
self.system_metrics = SystemMetrics()
self.metric_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
self.start_time = time.time()
self.response_times: deque = deque(maxlen=100) # Keep last 100 response times
self.active_requests: Dict[str, float] = {} # Track active requests
async def initialize(self):
"""Initialize the metrics service"""
log_module_event("metrics_service", "initializing", {})
self.start_time = time.time()
# Start background tasks
asyncio.create_task(self._collect_system_metrics())
asyncio.create_task(self._cleanup_old_metrics())
log_module_event("metrics_service", "initialized", {"success": True})
async def _collect_system_metrics(self):
"""Collect system metrics periodically"""
while True:
try:
# Update uptime
self.system_metrics.uptime = time.time() - self.start_time
# Update active connections
self.system_metrics.active_connections = len(self.active_requests)
# Store historical data
self._store_metric("uptime", self.system_metrics.uptime)
self._store_metric("active_connections", self.system_metrics.active_connections)
await asyncio.sleep(60) # Collect every minute
except Exception as e:
log_module_event("metrics_service", "system_metrics_error", {"error": str(e)})
await asyncio.sleep(60)
async def _cleanup_old_metrics(self):
"""Clean up old metric data"""
while True:
try:
cutoff_time = datetime.now() - timedelta(hours=24)
for metric_name, metric_data in self.metric_history.items():
# Remove old data points
while metric_data and metric_data[0].timestamp < cutoff_time:
metric_data.popleft()
await asyncio.sleep(3600) # Clean up every hour
except Exception as e:
log_module_event("metrics_service", "cleanup_error", {"error": str(e)})
await asyncio.sleep(3600)
def _store_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
"""Store a metric data point"""
if labels is None:
labels = {}
metric_data = MetricData(
timestamp=datetime.now(),
value=value,
labels=labels
)
self.metric_history[name].append(metric_data)
def start_request(self, request_id: str, endpoint: str, user_id: Optional[str] = None):
"""Start tracking a request"""
self.active_requests[request_id] = time.time()
# Update request metrics
self.request_metrics.total_requests += 1
# Track by endpoint
self.request_metrics.requests_by_endpoint[endpoint] = \
self.request_metrics.requests_by_endpoint.get(endpoint, 0) + 1
# Track by user
if user_id:
self.request_metrics.requests_by_user[user_id] = \
self.request_metrics.requests_by_user.get(user_id, 0) + 1
# Store metric
self._store_metric("requests_total", self.request_metrics.total_requests)
self._store_metric("requests_by_endpoint", 1, {"endpoint": endpoint})
if user_id:
self._store_metric("requests_by_user", 1, {"user_id": user_id})
def end_request(self, request_id: str, success: bool = True,
model: Optional[str] = None, tokens_used: int = 0,
cost: float = 0.0):
"""End tracking a request"""
if request_id not in self.active_requests:
return
# Calculate response time
response_time = time.time() - self.active_requests[request_id]
self.response_times.append(response_time)
# Update metrics
if success:
self.request_metrics.successful_requests += 1
else:
self.request_metrics.failed_requests += 1
# Update average response time
if self.response_times:
self.request_metrics.average_response_time = sum(self.response_times) / len(self.response_times)
# Update token and cost metrics
self.request_metrics.total_tokens_used += tokens_used
self.request_metrics.total_cost += cost
# Track by model
if model:
self.request_metrics.requests_by_model[model] = \
self.request_metrics.requests_by_model.get(model, 0) + 1
# Store metrics
self._store_metric("response_time", response_time)
self._store_metric("tokens_used", tokens_used)
self._store_metric("cost", cost)
if model:
self._store_metric("requests_by_model", 1, {"model": model})
# Clean up
del self.active_requests[request_id]
def record_error(self, error_type: str, error_message: str,
endpoint: Optional[str] = None, user_id: Optional[str] = None):
"""Record an error occurrence"""
labels = {"error_type": error_type}
if endpoint:
labels["endpoint"] = endpoint
if user_id:
labels["user_id"] = user_id
self._store_metric("errors_total", 1, labels)
# Log security events for authentication/authorization errors
if error_type in ["authentication_failed", "authorization_failed", "invalid_api_key"]:
log_security_event(error_type, user_id or "anonymous", {
"error": error_message,
"endpoint": endpoint
})
def record_module_status(self, module_name: str, is_healthy: bool):
"""Record module health status"""
self.system_metrics.module_status[module_name] = is_healthy
self._store_metric("module_health", 1 if is_healthy else 0, {"module": module_name})
def get_current_metrics(self) -> Dict[str, Any]:
"""Get current metrics snapshot"""
return {
"request_metrics": {
"total_requests": self.request_metrics.total_requests,
"successful_requests": self.request_metrics.successful_requests,
"failed_requests": self.request_metrics.failed_requests,
"success_rate": (
self.request_metrics.successful_requests / self.request_metrics.total_requests
if self.request_metrics.total_requests > 0 else 0
),
"average_response_time": self.request_metrics.average_response_time,
"total_tokens_used": self.request_metrics.total_tokens_used,
"total_cost": self.request_metrics.total_cost,
"requests_by_model": dict(self.request_metrics.requests_by_model),
"requests_by_user": dict(self.request_metrics.requests_by_user),
"requests_by_endpoint": dict(self.request_metrics.requests_by_endpoint)
},
"system_metrics": {
"uptime": self.system_metrics.uptime,
"active_connections": self.system_metrics.active_connections,
"module_status": dict(self.system_metrics.module_status)
}
}
def get_metrics_history(self, metric_name: str,
hours: int = 1) -> List[Dict[str, Any]]:
"""Get historical metrics data"""
if metric_name not in self.metric_history:
return []
cutoff_time = datetime.now() - timedelta(hours=hours)
return [
{
"timestamp": data.timestamp.isoformat(),
"value": data.value,
"labels": data.labels
}
for data in self.metric_history[metric_name]
if data.timestamp > cutoff_time
]
def get_top_metrics(self, metric_type: str, limit: int = 10) -> Dict[str, Any]:
"""Get top metrics by type"""
if metric_type == "models":
return dict(
sorted(self.request_metrics.requests_by_model.items(),
key=lambda x: x[1], reverse=True)[:limit]
)
elif metric_type == "users":
return dict(
sorted(self.request_metrics.requests_by_user.items(),
key=lambda x: x[1], reverse=True)[:limit]
)
elif metric_type == "endpoints":
return dict(
sorted(self.request_metrics.requests_by_endpoint.items(),
key=lambda x: x[1], reverse=True)[:limit]
)
else:
return {}
def get_health_check(self) -> Dict[str, Any]:
"""Get health check information"""
return {
"status": "healthy",
"uptime": self.system_metrics.uptime,
"active_connections": self.system_metrics.active_connections,
"total_requests": self.request_metrics.total_requests,
"success_rate": (
self.request_metrics.successful_requests / self.request_metrics.total_requests
if self.request_metrics.total_requests > 0 else 1.0
),
"modules": self.system_metrics.module_status,
"timestamp": datetime.now().isoformat()
}
async def reset_metrics(self):
"""Reset all metrics (for testing purposes)"""
self.request_metrics = RequestMetrics()
self.system_metrics = SystemMetrics()
self.metric_history.clear()
self.response_times.clear()
self.active_requests.clear()
self.start_time = time.time()
log_module_event("metrics_service", "metrics_reset", {"success": True})
# Global metrics service instance
metrics_service = MetricsService()
def setup_metrics(app):
"""Setup metrics service with FastAPI app"""
# Store metrics service in app state
app.state.metrics_service = metrics_service
# Schedule async initialization; asyncio.create_task requires a running
# event loop, so call setup_metrics from an async startup hook
asyncio.create_task(metrics_service.initialize())
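# Illustrative usage sketch: bracket each request with start_request() and
# end_request() so response-time, token, and cost aggregates stay current.
# The request id, user id, and numbers below are placeholders.
if __name__ == "__main__":
    metrics_service.start_request("req-1", "/chat/completions", user_id="u1")
    metrics_service.end_request("req-1", success=True, model="gpt-4",
                                tokens_used=1500, cost=0.06)
    print(metrics_service.get_current_metrics()["request_metrics"])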

View File

@@ -0,0 +1,308 @@
"""
Module-specific configuration management service
Works alongside the general ConfigManager for module discovery and schema validation
"""
import json
import yaml
from typing import Dict, List, Any, Optional
from pathlib import Path
from jsonschema import validate, ValidationError, draft7_format_checker
from dataclasses import dataclass, asdict
from app.core.logging import get_logger
from app.utils.exceptions import ConfigurationError
logger = get_logger(__name__)
@dataclass
class ModuleManifest:
"""Module manifest loaded from module.yaml"""
name: str
version: str
description: str
author: str
category: str = "general"
enabled: bool = True
auto_start: bool = True
dependencies: Optional[List[str]] = None
optional_dependencies: Optional[List[str]] = None
config_schema: Optional[str] = None
ui_components: Optional[str] = None
provides: Optional[List[str]] = None
consumes: Optional[List[str]] = None
endpoints: Optional[List[Dict]] = None
workflow_steps: Optional[List[Dict]] = None
permissions: Optional[List[Dict]] = None
analytics_events: Optional[List[Dict]] = None
health_checks: Optional[List[Dict]] = None
ui_config: Optional[Dict] = None
documentation: Optional[Dict] = None
def __post_init__(self):
if self.dependencies is None:
self.dependencies = []
if self.optional_dependencies is None:
self.optional_dependencies = []
if self.provides is None:
self.provides = []
if self.consumes is None:
self.consumes = []
if self.endpoints is None:
self.endpoints = []
if self.workflow_steps is None:
self.workflow_steps = []
if self.permissions is None:
self.permissions = []
if self.analytics_events is None:
self.analytics_events = []
if self.health_checks is None:
self.health_checks = []
if self.ui_config is None:
self.ui_config = {}
if self.documentation is None:
self.documentation = {}
class ModuleConfigManager:
"""Manages module configurations and JSON schema validation"""
def __init__(self):
self.manifests: Dict[str, ModuleManifest] = {}
self.schemas: Dict[str, Dict] = {}
self.configs: Dict[str, Dict] = {}
async def discover_modules(self, modules_path: str = "modules") -> Dict[str, ModuleManifest]:
"""Discover modules from filesystem using module.yaml manifests"""
discovered_modules = {}
modules_dir = Path(modules_path)
if not modules_dir.exists():
logger.warning(f"Modules directory not found: {modules_path}")
return discovered_modules
logger.info(f"Discovering modules in: {modules_dir.absolute()}")
for module_dir in modules_dir.iterdir():
if not module_dir.is_dir():
continue
manifest_path = module_dir / "module.yaml"
if not manifest_path.exists():
# Try module.yml as fallback
manifest_path = module_dir / "module.yml"
if not manifest_path.exists():
# Check if it's a legacy module (has main.py but no manifest)
if (module_dir / "main.py").exists():
logger.info(f"Legacy module found (no manifest): {module_dir.name}")
# Create a basic manifest for legacy modules
manifest = ModuleManifest(
name=module_dir.name,
version="1.0.0",
description=f"Legacy {module_dir.name} module",
author="System",
category="legacy"
)
discovered_modules[manifest.name] = manifest
continue
try:
manifest = await self._load_module_manifest(manifest_path)
discovered_modules[manifest.name] = manifest
logger.info(f"Discovered module: {manifest.name} v{manifest.version}")
except Exception as e:
logger.error(f"Failed to load manifest for {module_dir.name}: {e}")
continue
self.manifests = discovered_modules
return discovered_modules
async def _load_module_manifest(self, manifest_path: Path) -> ModuleManifest:
"""Load and validate a module manifest file"""
try:
with open(manifest_path, 'r', encoding='utf-8') as f:
manifest_data = yaml.safe_load(f)
# Validate required fields
required_fields = ['name', 'version', 'description', 'author']
for field in required_fields:
if field not in manifest_data:
raise ConfigurationError(f"Missing required field '{field}' in {manifest_path}")
manifest = ModuleManifest(**manifest_data)
# Load configuration schema if specified
if manifest.config_schema:
schema_path = manifest_path.parent / manifest.config_schema
if schema_path.exists():
await self._load_module_schema(manifest.name, schema_path)
else:
logger.warning(f"Config schema not found: {schema_path}")
return manifest
except yaml.YAMLError as e:
raise ConfigurationError(f"Invalid YAML in {manifest_path}: {e}")
except Exception as e:
raise ConfigurationError(f"Failed to load manifest {manifest_path}: {e}")
async def _load_module_schema(self, module_name: str, schema_path: Path):
"""Load JSON schema for module configuration"""
try:
with open(schema_path, 'r', encoding='utf-8') as f:
schema = json.load(f)
self.schemas[module_name] = schema
logger.info(f"Loaded configuration schema for module: {module_name}")
except json.JSONDecodeError as e:
raise ConfigurationError(f"Invalid JSON schema in {schema_path}: {e}")
except Exception as e:
raise ConfigurationError(f"Failed to load schema {schema_path}: {e}")
def get_module_manifest(self, module_name: str) -> Optional[ModuleManifest]:
"""Get module manifest by name"""
return self.manifests.get(module_name)
def get_module_schema(self, module_name: str) -> Optional[Dict]:
"""Get configuration schema for a module"""
return self.schemas.get(module_name)
def get_module_config(self, module_name: str) -> Dict:
"""Get current configuration for a module"""
return self.configs.get(module_name, {})
async def validate_config(self, module_name: str, config: Dict) -> Dict:
"""Validate module configuration against its schema"""
schema = self.schemas.get(module_name)
if not schema:
logger.info(f"No schema found for module {module_name}, skipping validation")
return {"valid": True, "errors": []}
try:
validate(instance=config, schema=schema, format_checker=draft7_format_checker)
return {"valid": True, "errors": []}
except ValidationError as e:
return {
"valid": False,
"errors": [{
"path": list(e.path),
"message": e.message,
"invalid_value": e.instance
}]
}
except Exception as e:
return {
"valid": False,
"errors": [{"message": f"Schema validation failed: {str(e)}"}]
}
async def save_module_config(self, module_name: str, config: Dict) -> bool:
"""Save module configuration"""
# Validate configuration first
validation_result = await self.validate_config(module_name, config)
if not validation_result["valid"]:
error_messages = [error["message"] for error in validation_result["errors"]]
raise ConfigurationError(f"Invalid configuration for {module_name}: {', '.join(error_messages)}")
# Save configuration
self.configs[module_name] = config
# In production, this would persist to database
# For now, we'll save to a local JSON file
config_dir = Path("backend/storage/module_configs")
config_dir.mkdir(parents=True, exist_ok=True)
config_file = config_dir / f"{module_name}.json"
try:
with open(config_file, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Saved configuration for module: {module_name}")
return True
except Exception as e:
logger.error(f"Failed to save config for {module_name}: {e}")
return False
async def load_saved_configs(self):
"""Load previously saved module configurations"""
config_dir = Path("backend/storage/module_configs")
if not config_dir.exists():
return
for config_file in config_dir.glob("*.json"):
module_name = config_file.stem
try:
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
self.configs[module_name] = config
logger.info(f"Loaded saved configuration for module: {module_name}")
except Exception as e:
logger.error(f"Failed to load saved config for {module_name}: {e}")
def list_available_modules(self) -> List[Dict]:
"""List all discovered modules with their metadata"""
modules = []
for name, manifest in self.manifests.items():
modules.append({
"name": manifest.name,
"version": manifest.version,
"description": manifest.description,
"author": manifest.author,
"category": manifest.category,
"enabled": manifest.enabled,
"dependencies": manifest.dependencies,
"provides": manifest.provides,
"consumes": manifest.consumes,
"has_schema": name in self.schemas,
"has_config": name in self.configs,
"ui_config": manifest.ui_config
})
return modules
def get_workflow_steps(self) -> Dict[str, List[Dict]]:
"""Get all available workflow steps from modules"""
workflow_steps = {}
for name, manifest in self.manifests.items():
if manifest.workflow_steps:
workflow_steps[name] = manifest.workflow_steps
return workflow_steps
async def update_module_status(self, module_name: str, enabled: bool) -> bool:
"""Update module enabled status"""
manifest = self.manifests.get(module_name)
if not manifest:
return False
manifest.enabled = enabled
# Update the manifest file
modules_dir = Path("modules")
manifest_path = modules_dir / module_name / "module.yaml"
if manifest_path.exists():
try:
manifest_dict = asdict(manifest)
with open(manifest_path, 'w', encoding='utf-8') as f:
yaml.dump(manifest_dict, f, default_flow_style=False)
logger.info(f"Updated module status: {module_name} enabled={enabled}")
return True
except Exception as e:
logger.error(f"Failed to update manifest for {module_name}: {e}")
return False
return False
# Global module config manager instance
module_config_manager = ModuleConfigManager()
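# Illustrative usage sketch: discover manifests under ./modules, then validate
# a candidate config against a discovered module's JSON schema. The module
# name "rag" and the config payload are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        manifests = await module_config_manager.discover_modules("modules")
        print("discovered:", sorted(manifests))
        result = await module_config_manager.validate_config("rag", {"enabled": True})
        print(result)

    asyncio.run(_demo())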

View File

@@ -0,0 +1,672 @@
"""
Module management service with dynamic discovery
"""
import asyncio
import importlib
import os
import sys
from typing import Dict, List, Optional, Any
from pathlib import Path
from dataclasses import dataclass
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from app.core.config import settings
from app.core.logging import log_module_event, get_logger
from app.utils.exceptions import ModuleLoadError, ModuleNotFoundError
from app.services.permission_manager import permission_registry
from app.services.module_config_manager import module_config_manager, ModuleManifest
logger = get_logger(__name__)
@dataclass
class ModuleConfig:
"""Configuration for a module"""
name: str
enabled: bool = True
config: Dict[str, Any] = None
dependencies: List[str] = None
def __post_init__(self):
if self.config is None:
self.config = {}
if self.dependencies is None:
self.dependencies = []
class ModuleFileWatcher(FileSystemEventHandler):
"""Watch for changes in module files"""
def __init__(self, module_manager):
self.module_manager = module_manager
def on_modified(self, event):
if event.is_directory or not event.src_path.endswith('.py'):
return
# Extract module name from path
path_parts = Path(event.src_path).parts
if 'modules' in path_parts:
modules_index = path_parts.index('modules')
if modules_index + 1 < len(path_parts):
module_name = path_parts[modules_index + 1]
if module_name in self.module_manager.modules:
log_module_event("hot_reload", "file_changed", {
"module": module_name,
"file": event.src_path
})
# Schedule reload
asyncio.create_task(self.module_manager.reload_module(module_name))
class ModuleManager:
"""Manages loading, unloading, and execution of modules"""
def __init__(self):
self.modules: Dict[str, Any] = {}
self.module_configs: Dict[str, ModuleConfig] = {}
self.module_order: List[str] = []
self.initialized = False
self.hot_reload_enabled = True
self.file_observer = None
self.fastapi_app = None
async def initialize(self, fastapi_app=None):
"""Initialize the module manager and load all modules"""
if self.initialized:
return
# Store FastAPI app reference for router registration
self.fastapi_app = fastapi_app
log_module_event("module_manager", "initializing", {"action": "start"})
try:
# Load module configurations
await self._load_module_configs()
# Load and initialize modules
await self._load_modules()
self.initialized = True
log_module_event("module_manager", "initialized", {
"modules_count": len(self.modules),
"enabled_modules": [name for name, config in self.module_configs.items() if config.enabled]
})
except Exception as e:
log_module_event("module_manager", "initialization_failed", {"error": str(e)})
raise ModuleLoadError(f"Failed to initialize module manager: {str(e)}")
async def _load_module_configs(self):
"""Load module configurations from dynamic discovery"""
# Initialize permission system
permission_registry.register_platform_permissions()
# Discover modules dynamically from filesystem
try:
discovered_manifests = await module_config_manager.discover_modules("modules")
# Load saved configurations
await module_config_manager.load_saved_configs()
# Convert manifests to ModuleConfig objects
for name, manifest in discovered_manifests.items():
saved_config = module_config_manager.get_module_config(name)
module_config = ModuleConfig(
name=manifest.name,
enabled=manifest.enabled,
config=saved_config,
dependencies=manifest.dependencies
)
self.module_configs[name] = module_config
log_module_event(name, "discovered", {
"version": manifest.version,
"description": manifest.description,
"enabled": manifest.enabled,
"dependencies": manifest.dependencies
})
logger.info(f"Discovered {len(discovered_manifests)} modules: {list(discovered_manifests.keys())}")
except Exception as e:
logger.error(f"Failed to discover modules: {e}")
# Fallback to legacy hard-coded modules
await self._load_legacy_modules()
# Start file watcher for hot-reload
if self.hot_reload_enabled:
await self._start_file_watcher()
async def _load_legacy_modules(self):
"""Fallback to legacy hard-coded module loading"""
logger.warning("Falling back to legacy module configuration")
default_modules = [
ModuleConfig(name="rag", enabled=True, config={}),
ModuleConfig(name="workflow", enabled=True, config={})
]
for config in default_modules:
self.module_configs[config.name] = config
async def _load_modules(self):
"""Load all enabled modules"""
# Sort modules by dependencies
self._sort_modules_by_dependencies()
for module_name in self.module_order:
config = self.module_configs[module_name]
if config.enabled:
await self._load_module(module_name, config)
def _sort_modules_by_dependencies(self):
"""Sort modules by their dependencies using topological sort"""
# Simple topological sort
visited = set()
temp_visited = set()
self.module_order = []
def visit(module_name: str):
if module_name in temp_visited:
raise ModuleLoadError(f"Circular dependency detected involving module: {module_name}")
if module_name in visited:
return
temp_visited.add(module_name)
# Visit dependencies first
config = self.module_configs.get(module_name)
if config and config.dependencies:
for dep in config.dependencies:
if dep in self.module_configs:
visit(dep)
temp_visited.remove(module_name)
visited.add(module_name)
self.module_order.append(module_name)
for module_name in self.module_configs:
if module_name not in visited:
visit(module_name)
async def _load_module(self, module_name: str, config: ModuleConfig):
"""Load a single module"""
try:
log_module_event(module_name, "loading", {"config": config.config})
# Check if module exists in the modules directory
# Try multiple possible locations in order of preference
possible_paths = [
    Path(f"modules/{module_name}"),     # container/runtime path
    Path(f"app/modules/{module_name}")  # legacy in-package path
]
module_dir = None
modules_base_path = None
for path in possible_paths:
if path.exists():
module_dir = path
modules_base_path = path.parent
break
if module_dir and module_dir.exists():
# Use direct import from modules directory
module_path = f"modules.{module_name}.main"
# Add modules directory to Python path if not already there
modules_path_str = str(modules_base_path.absolute())
if modules_path_str not in sys.path:
sys.path.insert(0, modules_path_str)
# Force reload if already imported
if module_path in sys.modules:
importlib.reload(sys.modules[module_path])
module = sys.modules[module_path]
else:
module = importlib.import_module(module_path)
else:
# Final fallback - try app.modules path (legacy)
try:
module_path = f"app.modules.{module_name}.main"
module = importlib.import_module(module_path)
except ImportError:
raise ModuleLoadError(f"Module {module_name} not found in any expected location: {[str(p) for p in possible_paths]}")
# Get the module instance - try multiple patterns
module_instance = None
# Pattern 1: {module_name}_module (e.g., cache_module)
if hasattr(module, f'{module_name}_module'):
module_instance = getattr(module, f'{module_name}_module')
# Pattern 2: Just 'module' attribute
elif hasattr(module, 'module'):
module_instance = getattr(module, 'module')
# Pattern 3: Module class with same name as module (e.g., CacheModule)
elif hasattr(module, f'{module_name.title()}Module'):
module_class = getattr(module, f'{module_name.title()}Module')
if callable(module_class):
module_instance = module_class()
else:
module_instance = module_class
# Pattern 4: Use the module itself as fallback
else:
module_instance = module
self.modules[module_name] = module_instance
# Initialize the module if it has an init function
module_initialized = False
if hasattr(self.modules[module_name], 'initialize'):
try:
import inspect
init_method = self.modules[module_name].initialize
sig = inspect.signature(init_method)
param_count = len([p for p in sig.parameters.values() if p.name != 'self'])
if hasattr(self.modules[module_name], 'config'):
# Pass config if it's a BaseModule
self.modules[module_name].config.update(config.config)
await self.modules[module_name].initialize()
elif param_count > 0:
# Legacy module - pass config as parameter
await self.modules[module_name].initialize(config.config)
else:
# Module initialize method takes no parameters
await self.modules[module_name].initialize()
module_initialized = True
log_module_event(module_name, "initialized", {"success": True})
except Exception as e:
log_module_event(module_name, "initialization_failed", {"error": str(e)})
module_initialized = False
else:
# Module doesn't have initialize method, mark as initialized anyway
module_initialized = True
# Mark module initialization status (safely)
try:
self.modules[module_name].initialized = module_initialized
except AttributeError:
# Module doesn't support the initialized attribute, that's okay
pass
# Register module permissions - check both new and legacy methods
permissions = []
# New BaseModule method
if hasattr(self.modules[module_name], 'get_required_permissions'):
try:
permissions = self.modules[module_name].get_required_permissions()
log_module_event(module_name, "permissions_registered", {
"permissions_count": len(permissions),
"type": "BaseModule"
})
except Exception as e:
log_module_event(module_name, "permissions_failed", {"error": str(e)})
# Legacy method
elif hasattr(self.modules[module_name], 'get_permissions'):
try:
permissions = self.modules[module_name].get_permissions()
log_module_event(module_name, "permissions_registered", {
"permissions_count": len(permissions),
"type": "legacy"
})
except Exception as e:
log_module_event(module_name, "permissions_failed", {"error": str(e)})
# Register permissions with the permission system
if permissions:
permission_registry.register_module(module_name, permissions)
# Register module router with FastAPI app if available
await self._register_module_router(module_name, self.modules[module_name])
log_module_event(module_name, "loaded", {"success": True})
except ImportError as e:
error_msg = f"Module {module_name} import failed: {str(e)}"
log_module_event(module_name, "load_failed", {"error": error_msg, "type": "ImportError"})
# For critical modules, we might want to fail completely
if module_name in ['security', 'cache']:
raise ModuleLoadError(error_msg)
# For optional modules, log warning but continue
import warnings
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
except Exception as e:
error_msg = f"Module {module_name} loading failed: {str(e)}"
log_module_event(module_name, "load_failed", {"error": error_msg, "type": type(e).__name__})
# For critical modules, we might want to fail completely
if module_name in ['security', 'cache']:
raise ModuleLoadError(error_msg)
# For optional modules, log warning but continue
import warnings
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
async def _register_module_router(self, module_name: str, module_instance):
"""Register a module's router with the FastAPI app if it has one"""
if not self.fastapi_app or not module_instance:
return
try:
# Check if module has a router attribute
if hasattr(module_instance, 'router'):
router = getattr(module_instance, 'router')
# Verify it's actually a FastAPI router
from fastapi import APIRouter
if isinstance(router, APIRouter):
# Register the router with the app
self.fastapi_app.include_router(router)
log_module_event(module_name, "router_registered", {
"router_prefix": getattr(router, 'prefix', 'unknown'),
"router_tags": getattr(router, 'tags', [])
})
logger.info(f"Registered router for module {module_name}")
else:
logger.debug(f"Module {module_name} has 'router' attribute but it's not a FastAPI router")
else:
logger.debug(f"Module {module_name} does not have a router")
except Exception as e:
log_module_event(module_name, "router_registration_failed", {
"error": str(e)
})
logger.warning(f"Failed to register router for module {module_name}: {e}")
async def unload_module(self, module_name: str):
"""Unload a module"""
if module_name not in self.modules:
raise ModuleNotFoundError(f"Module {module_name} not loaded")
try:
module = self.modules[module_name]
# Call cleanup if available
if hasattr(module, 'cleanup'):
await module.cleanup()
del self.modules[module_name]
log_module_event(module_name, "unloaded", {"success": True})
except Exception as e:
log_module_event(module_name, "unload_failed", {"error": str(e)})
raise ModuleLoadError(f"Failed to unload module {module_name}: {str(e)}")
async def reload_module(self, module_name: str) -> bool:
"""Reload a module"""
log_module_event(module_name, "reloading", {})
try:
if module_name in self.modules:
await self.unload_module(module_name)
config = self.module_configs.get(module_name)
if config and config.enabled:
await self._load_module(module_name, config)
log_module_event(module_name, "reloaded", {"success": True})
return True
else:
log_module_event(module_name, "reload_skipped", {"reason": "Module disabled or no config"})
return False
except Exception as e:
log_module_event(module_name, "reload_failed", {"error": str(e)})
return False
def get_module(self, module_name: str) -> Optional[Any]:
"""Get a loaded module"""
return self.modules.get(module_name)
def list_modules(self) -> List[str]:
"""List all loaded modules"""
return list(self.modules.keys())
def is_module_loaded(self, module_name: str) -> bool:
"""Check if a module is loaded"""
return module_name in self.modules
async def execute_interceptor_chain(self, chain_type: str, context: Dict[str, Any]) -> Dict[str, Any]:
"""Execute interceptor chain for all loaded modules"""
result_context = context.copy()
for module_name in self.module_order:
if module_name in self.modules:
module = self.modules[module_name]
# Check if module has the interceptor
interceptor_method = f"{chain_type}_interceptor"
if hasattr(module, interceptor_method):
try:
interceptor = getattr(module, interceptor_method)
result_context = await interceptor(result_context)
log_module_event(module_name, "interceptor_executed", {
"chain_type": chain_type,
"success": True
})
except Exception as e:
log_module_event(module_name, "interceptor_failed", {
"chain_type": chain_type,
"error": str(e)
})
# Continue with other modules even if one fails
continue
return result_context
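# A minimal caller-side sketch of running an interceptor chain; the chain_type
# and context keys below are illustrative assumptions, not a fixed contract:
#
#   context = {"path": "/v1/completions", "user_id": "u-123"}
#   context = await module_manager.execute_interceptor_chain("request", context)
#   # every loaded module exposing request_interceptor(context) has been applied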
async def shutdown(self):
"""Shutdown all modules"""
if not self.initialized:
return
log_module_event("module_manager", "shutting_down", {"modules_count": len(self.modules)})
# Unload modules in reverse order
for module_name in reversed(self.module_order):
if module_name in self.modules:
try:
await self.unload_module(module_name)
except Exception as e:
log_module_event(module_name, "shutdown_error", {"error": str(e)})
self.initialized = False
log_module_event("module_manager", "shutdown_complete", {"success": True})
async def cleanup(self):
"""Cleanup method - alias for shutdown"""
await self.shutdown()
async def _start_file_watcher(self):
"""Start watching module files for changes"""
try:
# Try multiple possible locations for modules directory
possible_modules_paths = [
Path("modules"), # Docker container path
Path("modules"), # Container path
Path("app/modules") # Legacy path
]
modules_path = None
for path in possible_modules_paths:
if path.exists():
modules_path = path
break
if modules_path and modules_path.exists():
self.file_observer = Observer()
event_handler = ModuleFileWatcher(self)
self.file_observer.schedule(event_handler, str(modules_path), recursive=True)
self.file_observer.start()
log_module_event("hot_reload", "watcher_started", {"path": str(modules_path)})
else:
log_module_event("hot_reload", "watcher_skipped", {"reason": "No modules directory found"})
except Exception as e:
log_module_event("hot_reload", "watcher_failed", {"error": str(e)})
# Dynamic Module Management Methods
async def enable_module(self, module_name: str) -> bool:
"""Enable a module"""
try:
# Update the manifest status
success = await module_config_manager.update_module_status(module_name, True)
if not success:
return False
# Update local config
if module_name in self.module_configs:
self.module_configs[module_name].enabled = True
# Load the module if not already loaded
if module_name not in self.modules:
config = self.module_configs.get(module_name)
if config:
await self._load_module(module_name, config)
log_module_event(module_name, "enabled", {"success": True})
return True
except Exception as e:
log_module_event(module_name, "enable_failed", {"error": str(e)})
return False
async def disable_module(self, module_name: str) -> bool:
"""Disable a module"""
try:
# Update the manifest status
success = await module_config_manager.update_module_status(module_name, False)
if not success:
return False
# Update local config
if module_name in self.module_configs:
self.module_configs[module_name].enabled = False
# Unload the module if loaded
if module_name in self.modules:
await self.unload_module(module_name)
log_module_event(module_name, "disabled", {"success": True})
return True
except Exception as e:
log_module_event(module_name, "disable_failed", {"error": str(e)})
return False
def get_module_info(self, module_name: str) -> Optional[Dict]:
"""Get comprehensive module information"""
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
return None
config = self.module_configs.get(module_name)
is_loaded = self.is_module_loaded(module_name)
return {
"name": manifest.name,
"version": manifest.version,
"description": manifest.description,
"author": manifest.author,
"category": manifest.category,
"enabled": config.enabled if config else manifest.enabled,
"loaded": is_loaded,
"dependencies": manifest.dependencies,
"optional_dependencies": manifest.optional_dependencies,
"provides": manifest.provides,
"consumes": manifest.consumes,
"endpoints": manifest.endpoints,
"workflow_steps": manifest.workflow_steps,
"permissions": manifest.permissions,
"ui_config": manifest.ui_config,
"has_schema": module_config_manager.get_module_schema(module_name) is not None,
"current_config": module_config_manager.get_module_config(module_name)
}
def list_all_modules(self) -> List[Dict]:
"""List all discovered modules with their information"""
modules = []
for name in module_config_manager.manifests.keys():
module_info = self.get_module_info(name)
if module_info:
modules.append(module_info)
return modules
async def update_module_config(self, module_name: str, config: Dict) -> bool:
"""Update module configuration"""
try:
# Validate and save the configuration
success = await module_config_manager.save_module_config(module_name, config)
if not success:
return False
# Update local config
if module_name in self.module_configs:
self.module_configs[module_name].config = config
# Reload the module if it's currently loaded
if self.is_module_loaded(module_name):
await self.reload_module(module_name)
log_module_event(module_name, "config_updated", {"success": True})
return True
except Exception as e:
log_module_event(module_name, "config_update_failed", {"error": str(e)})
return False
def get_workflow_steps(self) -> Dict[str, List[Dict]]:
"""Get all available workflow steps from modules"""
return module_config_manager.get_workflow_steps()
async def get_module_health(self, module_name: str) -> Dict:
"""Get module health status"""
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
return {"status": "unknown", "message": "Module not found"}
is_loaded = self.is_module_loaded(module_name)
module = self.get_module(module_name) if is_loaded else None
health = {
"status": "healthy" if is_loaded else "stopped",
"loaded": is_loaded,
"enabled": manifest.enabled,
"dependencies_met": self._check_dependencies(module_name),
"last_loaded": None,
"error": None
}
# Check if module has custom health check
if module and hasattr(module, 'get_health'):
try:
custom_health = await module.get_health()
health.update(custom_health)
except Exception as e:
health["status"] = "error"
health["error"] = str(e)
return health
def _check_dependencies(self, module_name: str) -> bool:
"""Check if all module dependencies are met"""
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest or not manifest.dependencies:
return True
for dep in manifest.dependencies:
if not self.is_module_loaded(dep):
return False
return True
# Global module manager instance
module_manager = ModuleManager()
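# Hedged usage sketch for the global instance (awaited from within the
# application's event loop):
#
#   if not module_manager.is_module_loaded("rag"):
#       await module_manager.enable_module("rag")   # persists status, then loads
#   await module_manager.reload_module("rag")       # unload + reload with current config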

View File

@@ -0,0 +1,393 @@
"""
Enhanced Permission Manager for Module-Specific Permissions
Provides hierarchical permission management with wildcard support,
dynamic module permission registration, and fine-grained access control.
"""
import re
from typing import Dict, List, Set, Optional, Any
from dataclasses import dataclass
from enum import Enum
from app.core.logging import get_logger
logger = get_logger(__name__)
class PermissionAction(str, Enum):
"""Standard permission actions"""
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
EXECUTE = "execute"
MANAGE = "manage"
VIEW = "view"
ADMIN = "admin"
@dataclass
class Permission:
"""Permission definition"""
resource: str
action: str
description: str = ""
conditions: Optional[Dict[str, Any]] = None
@dataclass
class PermissionScope:
"""Permission scope for context-aware permissions"""
namespace: str
resource: str
action: str
context: Optional[Dict[str, Any]] = None
class PermissionTree:
"""Hierarchical permission tree for efficient wildcard matching"""
def __init__(self):
self.root = {}
self.permissions: Dict[str, Permission] = {}
def add_permission(self, permission_string: str, permission: Permission):
"""Add a permission to the tree"""
parts = permission_string.split(":")
current = self.root
for part in parts:
if part not in current:
current[part] = {}
current = current[part]
current["_permission"] = permission
self.permissions[permission_string] = permission
def has_permission(self, user_permissions: List[str], required: str) -> bool:
"""Check if user has required permission with wildcard support"""
# Handle None or empty permissions
if not user_permissions:
return False
# Direct match
if required in user_permissions:
return True
# Check wildcard patterns
for user_perm in user_permissions:
if self._matches_wildcard(user_perm, required):
return True
return False
def _matches_wildcard(self, pattern: str, permission: str) -> bool:
"""Check if a wildcard pattern matches a permission"""
if "*" not in pattern:
return pattern == permission
pattern_parts = pattern.split(":")
perm_parts = permission.split(":")
# Handle patterns ending with * (e.g., "platform:*" should match "platform:audit:read")
if pattern_parts[-1] == "*" and len(pattern_parts) == 2:
# Pattern like "platform:*" should match any permission starting with "platform:"
if len(perm_parts) >= 2 and pattern_parts[0] == perm_parts[0]:
return True
# Original logic for exact-length matching (e.g., "platform:audit:*" matches "platform:audit:read")
if len(pattern_parts) != len(perm_parts):
return False
for pattern_part, perm_part in zip(pattern_parts, perm_parts):
if pattern_part == "*":
continue
elif pattern_part != perm_part:
return False
return True
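# Illustrative matches for the rules above (inputs are assumptions; results
# follow directly from the logic in this method):
#   _matches_wildcard("platform:*", "platform:audit:read")       -> True  (two-part prefix rule)
#   _matches_wildcard("platform:audit:*", "platform:audit:read") -> True  (exact-length rule)
#   _matches_wildcard("modules:*:read", "modules:cache:read")    -> True  (segment wildcard)
#   _matches_wildcard("platform:audit:*", "platform:audit")      -> False (length mismatch)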
def get_matching_permissions(self, user_permissions: List[str]) -> Set[str]:
"""Get all permissions that match user's granted permissions"""
matching = set()
for granted in user_permissions:
if "*" in granted:
# Find all permissions matching this wildcard
for perm in self.permissions.keys():
if self._matches_wildcard(granted, perm):
matching.add(perm)
else:
matching.add(granted)
return matching
class ModulePermissionRegistry:
"""Registry for module-specific permissions"""
def __init__(self):
self.tree = PermissionTree()
self.module_permissions: Dict[str, List[Permission]] = {}
self.role_permissions: Dict[str, List[str]] = {}
self.default_roles = self._initialize_default_roles()
def _initialize_default_roles(self) -> Dict[str, List[str]]:
"""Initialize default permission roles"""
return {
"super_admin": [
"platform:*",
"modules:*",
"llm:*"
],
"admin": [
"platform:*",
"modules:*",
"llm:*"
],
"developer": [
"platform:api-keys:*",
"platform:budgets:read",
"llm:completions:execute",
"llm:embeddings:execute",
"modules:*:read",
"modules:*:execute"
],
"user": [
"llm:completions:execute",
"llm:embeddings:execute",
"modules:*:read"
],
"readonly": [
"platform:*:read",
"modules:*:read"
]
}
def register_module(self, module_id: str, permissions: List[Permission]):
"""Register permissions for a module"""
self.module_permissions[module_id] = permissions
for perm in permissions:
perm_string = f"modules:{module_id}:{perm.resource}:{perm.action}"
self.tree.add_permission(perm_string, perm)
logger.info(f"Registered {len(permissions)} permissions for module {module_id}")
def register_platform_permissions(self):
"""Register core platform permissions"""
platform_permissions = [
Permission("users", "create", "Create users"),
Permission("users", "read", "View users"),
Permission("users", "update", "Update users"),
Permission("users", "delete", "Delete users"),
Permission("users", "manage", "Full user management"),
Permission("api-keys", "create", "Create API keys"),
Permission("api-keys", "read", "View API keys"),
Permission("api-keys", "update", "Update API keys"),
Permission("api-keys", "delete", "Delete API keys"),
Permission("api-keys", "manage", "Full API key management"),
Permission("budgets", "create", "Create budgets"),
Permission("budgets", "read", "View budgets"),
Permission("budgets", "update", "Update budgets"),
Permission("budgets", "delete", "Delete budgets"),
Permission("budgets", "manage", "Full budget management"),
Permission("audit", "read", "View audit logs"),
Permission("audit", "export", "Export audit logs"),
Permission("settings", "read", "View settings"),
Permission("settings", "update", "Update settings"),
Permission("settings", "manage", "Full settings management"),
Permission("health", "read", "View health status"),
Permission("metrics", "read", "View metrics"),
Permission("permissions", "read", "View permissions"),
Permission("permissions", "manage", "Manage permissions"),
Permission("roles", "create", "Create roles"),
Permission("roles", "read", "View roles"),
Permission("roles", "update", "Update roles"),
Permission("roles", "delete", "Delete roles"),
]
for perm in platform_permissions:
perm_string = f"platform:{perm.resource}:{perm.action}"
self.tree.add_permission(perm_string, perm)
# Register LLM permissions
llm_permissions = [
Permission("completions", "execute", "Execute chat completions"),
Permission("embeddings", "execute", "Execute embeddings"),
Permission("models", "list", "List available models"),
Permission("usage", "view", "View usage statistics"),
]
for perm in llm_permissions:
perm_string = f"llm:{perm.resource}:{perm.action}"
self.tree.add_permission(perm_string, perm)
logger.info("Registered platform and LLM permissions")
def check_permission(self, user_permissions: List[str], required: str,
context: Optional[Dict[str, Any]] = None) -> bool:
"""Check if user has required permission"""
# Basic permission check
has_perm = self.tree.has_permission(user_permissions, required)
if not has_perm:
return False
# Context-based permission checks
if context:
return self._check_context_permissions(user_permissions, required, context)
return True
def _check_context_permissions(self, user_permissions: List[str],
required: str, context: Optional[Dict[str, Any]] = None) -> bool:
"""Check context-aware permissions"""
# Extract resource owner information
resource_owner = context.get("owner_id")
current_user = context.get("user_id")
# Users can always access their own resources
if resource_owner and current_user and resource_owner == current_user:
return True
# Check for elevated permissions for cross-user access
if resource_owner and resource_owner != current_user:
elevated_required = required.replace(":read", ":manage").replace(":update", ":manage")
return self.tree.has_permission(user_permissions, elevated_required)
return True
def get_user_permissions(self, roles: List[str],
custom_permissions: Optional[List[str]] = None) -> List[str]:
"""Get effective permissions for a user based on roles and custom permissions"""
permissions = set()
# Add role-based permissions
for role in roles:
role_perms = self.role_permissions.get(role, self.default_roles.get(role, []))
permissions.update(role_perms)
# Add custom permissions
if custom_permissions:
permissions.update(custom_permissions)
return list(permissions)
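# Example of effective-permission resolution (role name from the defaults
# above; the custom grant is an illustrative assumption):
#
#   registry = ModulePermissionRegistry()
#   perms = registry.get_user_permissions(["developer"], ["platform:audit:read"])
#   # -> the developer wildcards (e.g. "platform:api-keys:*") plus the custom grant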
def get_module_permissions(self, module_id: str) -> List[Permission]:
"""Get all permissions for a specific module"""
return self.module_permissions.get(module_id, [])
def get_available_permissions(self, namespace: str = None) -> Dict[str, List[Permission]]:
"""Get all available permissions, optionally filtered by namespace"""
if namespace:
filtered = {}
for perm_string, permission in self.tree.permissions.items():
if perm_string.startswith(f"{namespace}:"):
if namespace not in filtered:
filtered[namespace] = []
filtered[namespace].append(permission)
return filtered
# Group by namespace
grouped = {}
for perm_string, permission in self.tree.permissions.items():
namespace = perm_string.split(":")[0]
if namespace not in grouped:
grouped[namespace] = []
grouped[namespace].append(permission)
return grouped
def create_role(self, role_name: str, permissions: List[str]):
"""Create a custom role with specific permissions"""
self.role_permissions[role_name] = permissions
logger.info(f"Created role '{role_name}' with {len(permissions)} permissions")
def validate_permissions(self, permissions: List[str]) -> Dict[str, Any]:
"""Validate a list of permissions"""
valid = []
invalid = []
for perm in permissions:
if perm in self.tree.permissions or self._is_valid_wildcard(perm):
valid.append(perm)
else:
invalid.append(perm)
return {
"valid": valid,
"invalid": invalid,
"is_valid": len(invalid) == 0
}
def _is_valid_wildcard(self, permission: str) -> bool:
"""Check if a wildcard permission is valid"""
if "*" not in permission:
return False
parts = permission.split(":")
# Check if the structure is valid
if len(parts) < 2:
return False
# Check if there are any valid permissions matching this pattern
for existing_perm in self.tree.permissions.keys():
if self.tree._matches_wildcard(permission, existing_perm):
return True
return False
def get_permission_hierarchy(self) -> Dict[str, Any]:
"""Get the permission hierarchy tree structure"""
def build_tree(node, path=""):
tree = {}
for key, value in node.items():
if key == "_permission":
tree["_permission"] = {
"resource": value.resource,
"action": value.action,
"description": value.description
}
else:
current_path = f"{path}:{key}" if path else key
tree[key] = build_tree(value, current_path)
return tree
return build_tree(self.tree.root)
def require_permission(user_permissions: List[str], required_permission: str, context: Optional[Dict[str, Any]] = None):
"""
Guard function that enforces a specific permission.
Raises HTTPException if the user lacks the required permission.
Args:
user_permissions: List of user's permissions
required_permission: The permission string required
context: Optional context for conditional permissions
Raises:
HTTPException: If user doesn't have the required permission
"""
from fastapi import HTTPException, status
if not permission_registry.check_permission(user_permissions, required_permission, context):
logger.warning(f"Permission denied: required '{required_permission}', user has {user_permissions}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Insufficient permissions. Required: {required_permission}"
)
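# Sketch of guarding a FastAPI endpoint with this helper; the route, router,
# and current_user dependency are assumptions for illustration:
#
#   @router.get("/audit")
#   async def read_audit(current_user=Depends(get_current_user)):
#       require_permission(current_user.permissions, "platform:audit:read")
#       ...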
# Global permission registry instance
permission_registry = ModulePermissionRegistry()

View File

@@ -0,0 +1,789 @@
"""
RAG Service
Handles all RAG (Retrieval Augmented Generation) operations including
collections, documents, processing, and vector operations
"""
import os
import uuid
import mimetypes
import logging
from typing import List, Optional, Dict, Any, Tuple
from pathlib import Path
from datetime import datetime
import hashlib
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, func, and_, or_
from sqlalchemy.orm import selectinload
from app.models.rag_collection import RagCollection
from app.models.rag_document import RagDocument
from app.utils.exceptions import APIException
logger = logging.getLogger(__name__)
class RAGService:
"""Service for RAG operations"""
def __init__(self, db: AsyncSession):
self.db = db
self.upload_dir = Path("storage/rag_documents")
self.upload_dir.mkdir(parents=True, exist_ok=True)
# Collection Operations
async def create_collection(self, name: str, description: Optional[str] = None) -> RagCollection:
"""Create a new RAG collection"""
# Check if collection name already exists
stmt = select(RagCollection).where(RagCollection.name == name, RagCollection.is_active == True)
existing = await self.db.scalar(stmt)
if existing:
raise APIException(status_code=400, error_code="COLLECTION_EXISTS", detail=f"Collection '{name}' already exists")
# Generate unique Qdrant collection name
qdrant_name = f"rag_{name.lower().replace(' ', '_').replace('-', '_')}_{uuid.uuid4().hex[:8]}"
# Create collection
collection = RagCollection(
name=name,
description=description,
qdrant_collection_name=qdrant_name,
status='active'
)
self.db.add(collection)
await self.db.commit()
await self.db.refresh(collection)
# Create the corresponding Qdrant collection (errors are logged, not raised)
await self._create_qdrant_collection(qdrant_name)
return collection
async def get_collections(self, skip: int = 0, limit: int = 100) -> List[RagCollection]:
"""Get all active collections"""
stmt = (
select(RagCollection)
.where(RagCollection.is_active == True)
.order_by(RagCollection.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await self.db.execute(stmt)
return result.scalars().all()
async def get_collection(self, collection_id: int) -> Optional[RagCollection]:
"""Get a collection by ID"""
stmt = select(RagCollection).where(
RagCollection.id == collection_id,
RagCollection.is_active == True
)
return await self.db.scalar(stmt)
async def get_all_collections(self, skip: int = 0, limit: int = 100) -> List[dict]:
"""Get all collections from Qdrant (source of truth) with additional metadata from PostgreSQL."""
logger.info("Getting all RAG collections from Qdrant (source of truth)")
all_collections = []
try:
# Get RAG module instance to access Qdrant collections
from app.services.module_manager import module_manager
rag_module = module_manager.get_module("rag")
if not rag_module or not hasattr(rag_module, 'qdrant_client'):
logger.warning("RAG module or Qdrant client not available")
# Fallback to PostgreSQL only
managed_collections = await self.get_collections(skip=skip, limit=limit)
return [
{
"id": collection.id,
"name": collection.name,
"description": collection.description or "",
"document_count": collection.document_count or 0,
"size_bytes": collection.size_bytes or 0,
"vector_count": collection.vector_count or 0,
"status": collection.status,
"created_at": collection.created_at.isoformat() if collection.created_at else "",
"updated_at": collection.updated_at.isoformat() if collection.updated_at else "",
"is_active": collection.is_active,
"qdrant_collection_name": collection.qdrant_collection_name,
"is_managed": True,
"source": "managed"
}
for collection in managed_collections
]
# Get all collections from Qdrant (source of truth) using safe method
qdrant_collection_names = await rag_module._get_collections_safely()
logger.info(f"Found {len(qdrant_collection_names)} collections in Qdrant")
# Get metadata from PostgreSQL for additional info
db_metadata = await self.get_collections(skip=0, limit=1000)
metadata_by_name = {col.qdrant_collection_name: col for col in db_metadata}
# Process each Qdrant collection
for qdrant_name in qdrant_collection_names:
logger.info(f"Processing Qdrant collection: {qdrant_name}")
try:
# Get detailed collection info from Qdrant using safe method
collection_info = await rag_module._get_collection_info_safely(qdrant_name)
point_count = collection_info.get("points_count", 0)
vector_size = collection_info.get("vector_size", 384)
# Estimate collection size (points * vector_size * 4 bytes + metadata overhead)
estimated_size = int(point_count * vector_size * 4 * 1.2) # 20% overhead for metadata
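# Worked example of the estimate above: 10,000 points x 384 dims x 4 bytes
# x 1.2 overhead ~= 18.4 MB.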
# Get metadata from PostgreSQL if available
db_metadata_entry = metadata_by_name.get(qdrant_name)
if db_metadata_entry:
# Use PostgreSQL metadata but Qdrant data for counts/size
collection_data = {
"id": db_metadata_entry.id,
"name": db_metadata_entry.name,
"description": db_metadata_entry.description or "",
"document_count": point_count, # From Qdrant (real data)
"size_bytes": estimated_size, # From Qdrant (real data)
"vector_count": point_count, # From Qdrant (real data)
"status": db_metadata_entry.status,
"created_at": db_metadata_entry.created_at.isoformat() if db_metadata_entry.created_at else "",
"updated_at": db_metadata_entry.updated_at.isoformat() if db_metadata_entry.updated_at else "",
"is_active": db_metadata_entry.is_active,
"qdrant_collection_name": qdrant_name,
"is_managed": True,
"source": "managed"
}
else:
# Collection exists in Qdrant but not in our metadata
now = datetime.utcnow()
collection_data = {
"id": f"ext_{qdrant_name}", # External identifier
"name": qdrant_name,
"description": f"External Qdrant collection (vectors: {vector_size}d, points: {point_count})",
"document_count": point_count, # From Qdrant
"size_bytes": estimated_size, # From Qdrant
"vector_count": point_count, # From Qdrant
"status": "active",
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"is_active": True,
"qdrant_collection_name": qdrant_name,
"is_managed": False,
"source": "external"
}
all_collections.append(collection_data)
except Exception as e:
logger.error(f"Error processing collection {qdrant_name}: {e}")
# Still add the collection but with minimal info
now = datetime.utcnow()
collection_data = {
"id": f"ext_{qdrant_name}",
"name": qdrant_name,
"description": f"External Qdrant collection (error loading details: {str(e)})",
"document_count": 0,
"size_bytes": 0,
"vector_count": 0,
"status": "error",
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"is_active": True,
"qdrant_collection_name": qdrant_name,
"is_managed": False,
"source": "external"
}
all_collections.append(collection_data)
except Exception as e:
logger.error(f"Error fetching collections from Qdrant: {e}")
# Fallback to managed collections only
managed_collections = await self.get_collections(skip=skip, limit=limit)
return [
{
"id": collection.id,
"name": collection.name,
"description": collection.description or "",
"document_count": collection.document_count or 0,
"size_bytes": collection.size_bytes or 0,
"vector_count": collection.vector_count or 0,
"status": collection.status,
"created_at": collection.created_at.isoformat() if collection.created_at else "",
"updated_at": collection.updated_at.isoformat() if collection.updated_at else "",
"is_active": collection.is_active,
"qdrant_collection_name": collection.qdrant_collection_name,
"is_managed": True,
"source": "managed"
}
for collection in managed_collections
]
# Apply pagination
all_collections = all_collections[skip:skip + limit]
logger.info(f"Total collections returned: {len(all_collections)}")
return all_collections
async def delete_collection(self, collection_id: int, cascade: bool = True) -> bool:
"""Delete a collection and optionally all its documents"""
collection = await self.get_collection(collection_id)
if not collection:
return False
# Get all documents in the collection
stmt = select(RagDocument).where(
RagDocument.collection_id == collection_id,
RagDocument.is_deleted == False
)
result = await self.db.execute(stmt)
documents = result.scalars().all()
if documents and not cascade:
raise APIException(
status_code=400,
error_code="COLLECTION_HAS_DOCUMENTS",
detail=f"Cannot delete collection with {len(documents)} documents. Set cascade=true to delete documents along with collection."
)
# Delete all documents in the collection (cascade deletion)
if documents:
for document in documents:
# Soft delete document
document.is_deleted = True
document.deleted_at = datetime.utcnow()
# Delete physical file if it exists
try:
if os.path.exists(document.file_path):
os.remove(document.file_path)
except Exception as e:
logger.warning(f"Failed to delete file {document.file_path}: {e}")
# Soft delete collection
collection.is_active = False
collection.updated_at = datetime.utcnow()
await self.db.commit()
# Delete Qdrant collection
try:
await self._delete_qdrant_collection(collection.qdrant_collection_name)
except Exception as e:
logger.warning(f"Failed to delete Qdrant collection {collection.qdrant_collection_name}: {e}")
return True
# Document Operations
async def upload_document(
self,
collection_id: int,
file_content: bytes,
filename: str,
content_type: Optional[str] = None
) -> RagDocument:
"""Upload and process a document"""
# Verify collection exists
collection = await self.get_collection(collection_id)
if not collection:
raise APIException(status_code=404, error_code="COLLECTION_NOT_FOUND", detail="Collection not found")
# Validate file type
file_ext = Path(filename).suffix.lower()
if not self._is_supported_file_type(file_ext):
raise APIException(
status_code=400,
error_code="UNSUPPORTED_FILE_TYPE",
detail=f"Unsupported file type: {file_ext}. Supported: .pdf, .docx, .doc, .txt, .md"
)
# Generate safe filename
safe_filename = self._generate_safe_filename(filename)
file_path = self.upload_dir / f"{collection_id}" / safe_filename
file_path.parent.mkdir(parents=True, exist_ok=True)
# Save file
with open(file_path, 'wb') as f:
f.write(file_content)
# Detect MIME type
if not content_type:
content_type, _ = mimetypes.guess_type(filename)
# Create document record
document = RagDocument(
collection_id=collection_id,
filename=safe_filename,
original_filename=filename,
file_path=str(file_path),
file_type=file_ext.lstrip('.'),
file_size=len(file_content),
mime_type=content_type,
status='processing'
)
self.db.add(document)
await self.db.commit()
await self.db.refresh(document)
# Load the collection relationship to avoid lazy loading issues
stmt = select(RagDocument).options(selectinload(RagDocument.collection)).where(RagDocument.id == document.id)
result = await self.db.execute(stmt)
document = result.scalar_one()
# Add document to processing queue
from app.services.document_processor import document_processor
await document_processor.add_task(document.id, priority=1)
return document
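# Hedged usage sketch (the session object and file bytes are assumptions):
#
#   service = RAGService(db)
#   doc = await service.upload_document(collection_id=1, file_content=pdf_bytes, filename="report.pdf")
#   # doc.status == 'processing'; the document_processor queue picks it up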
async def get_documents(
self,
collection_id: Optional[int] = None,
skip: int = 0,
limit: int = 100
) -> List[RagDocument]:
"""Get documents, optionally filtered by collection"""
stmt = (
select(RagDocument)
.options(selectinload(RagDocument.collection))
.where(RagDocument.is_deleted == False)
.order_by(RagDocument.created_at.desc())
.offset(skip)
.limit(limit)
)
if collection_id:
stmt = stmt.where(RagDocument.collection_id == collection_id)
result = await self.db.execute(stmt)
return result.scalars().all()
async def get_document(self, document_id: int) -> Optional[RagDocument]:
"""Get a document by ID"""
stmt = (
select(RagDocument)
.options(selectinload(RagDocument.collection))
.where(
RagDocument.id == document_id,
RagDocument.is_deleted == False
)
)
return await self.db.scalar(stmt)
async def delete_document(self, document_id: int) -> bool:
"""Delete a document"""
document = await self.get_document(document_id)
if not document:
return False
# Soft delete document
document.is_deleted = True
document.deleted_at = datetime.utcnow()
document.updated_at = datetime.utcnow()
await self.db.commit()
# Update collection statistics
await self._update_collection_stats(document.collection_id)
# Remove vectors from Qdrant
await self._delete_document_vectors(document.id, document.collection.qdrant_collection_name)
# Remove file
try:
if os.path.exists(document.file_path):
os.remove(document.file_path)
except Exception as e:
print(f"Warning: Could not delete file {document.file_path}: {e}")
return True
async def download_document(self, document_id: int) -> Optional[Tuple[bytes, str, str]]:
"""Download original document file"""
document = await self.get_document(document_id)
if not document or not os.path.exists(document.file_path):
return None
try:
with open(document.file_path, 'rb') as f:
content = f.read()
return content, document.original_filename, document.mime_type or 'application/octet-stream'
except Exception:
return None
# Stats and Analytics
async def get_stats(self) -> Dict[str, Any]:
"""Get RAG system statistics"""
# Collection stats
collection_count_stmt = select(func.count(RagCollection.id)).where(RagCollection.is_active == True)
total_collections = await self.db.scalar(collection_count_stmt)
# Document stats
doc_count_stmt = select(func.count(RagDocument.id)).where(RagDocument.is_deleted == False)
total_documents = await self.db.scalar(doc_count_stmt)
# Processing stats
processing_stmt = select(func.count(RagDocument.id)).where(
RagDocument.is_deleted == False,
RagDocument.status == 'processing'
)
processing_documents = await self.db.scalar(processing_stmt)
# Size stats
size_stmt = select(func.sum(RagDocument.file_size)).where(RagDocument.is_deleted == False)
total_size = await self.db.scalar(size_stmt) or 0
# Vector stats
vector_stmt = select(func.sum(RagDocument.vector_count)).where(RagDocument.is_deleted == False)
total_vectors = await self.db.scalar(vector_stmt) or 0
return {
"collections": {
"total": total_collections or 0,
"active": total_collections or 0
},
"documents": {
"total": total_documents or 0,
"processing": processing_documents or 0,
"processed": (total_documents or 0) - (processing_documents or 0)
},
"storage": {
"total_size_bytes": total_size,
"total_size_mb": round(total_size / (1024 * 1024), 2) if total_size else 0
},
"vectors": {
"total": total_vectors
}
}
# Private Helper Methods
def _is_supported_file_type(self, file_ext: str) -> bool:
"""Check if file type is supported"""
supported_types = {'.pdf', '.docx', '.doc', '.txt', '.md', '.html', '.json', '.csv', '.xlsx', '.xls'}
return file_ext.lower() in supported_types
def _generate_safe_filename(self, filename: str) -> str:
"""Generate a safe filename for storage"""
# Extract extension
path = Path(filename)
ext = path.suffix
name = path.stem
# Create hash of original filename for uniqueness
hash_suffix = hashlib.md5(filename.encode()).hexdigest()[:8]
# Sanitize name
safe_name = "".join(c for c in name if c.isalnum() or c in (' ', '-', '_')).strip()
safe_name = safe_name.replace(' ', '_')
# Combine with timestamp and hash
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{safe_name}_{timestamp}_{hash_suffix}{ext}"
async def _create_qdrant_collection(self, collection_name: str):
"""Create collection in Qdrant vector database"""
try:
# Get RAG module to create the collection
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
except ImportError as e:
logger.error(f"Failed to import module_manager: {e}")
rag_module = None
if rag_module and hasattr(rag_module, 'create_collection'):
success = await rag_module.create_collection(collection_name)
if success:
logger.info(f"Created Qdrant collection: {collection_name}")
else:
logger.error(f"Failed to create Qdrant collection: {collection_name}")
else:
logger.warning("RAG module not available for collection creation")
except Exception as e:
logger.error(f"Error creating Qdrant collection {collection_name}: {e}")
# Don't re-raise the error - collection is already saved in database
# The Qdrant collection can be created later if needed
async def _delete_qdrant_collection(self, collection_name: str):
"""Delete collection from Qdrant vector database"""
try:
# Get RAG module to delete the collection
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
except ImportError as e:
logger.error(f"Failed to import module_manager: {e}")
rag_module = None
if rag_module and hasattr(rag_module, 'delete_collection'):
success = await rag_module.delete_collection(collection_name)
if success:
logger.info(f"Deleted Qdrant collection: {collection_name}")
else:
logger.warning(f"Qdrant collection not found or already deleted: {collection_name}")
else:
logger.warning("RAG module not available for collection deletion")
except Exception as e:
logger.error(f"Error deleting Qdrant collection {collection_name}: {e}")
# Don't re-raise the error for deletion as it's not critical if cleanup fails
async def _update_collection_stats(self, collection_id: int):
"""Update collection statistics (document count, size, etc.)"""
try:
# Get collection
collection = await self.get_collection(collection_id)
if not collection:
return
# Count active documents
stmt = select(func.count(RagDocument.id)).where(
RagDocument.collection_id == collection_id,
RagDocument.is_deleted == False
)
doc_count = await self.db.scalar(stmt) or 0
# Sum file sizes
stmt = select(func.sum(RagDocument.file_size)).where(
RagDocument.collection_id == collection_id,
RagDocument.is_deleted == False
)
total_size = await self.db.scalar(stmt) or 0
# Sum vector counts
stmt = select(func.sum(RagDocument.vector_count)).where(
RagDocument.collection_id == collection_id,
RagDocument.is_deleted == False
)
vector_count = await self.db.scalar(stmt) or 0
# Update collection
collection.document_count = doc_count
collection.size_bytes = total_size
collection.vector_count = vector_count
collection.updated_at = datetime.utcnow()
await self.db.commit()
except Exception as e:
logger.error(f"Failed to update collection stats for {collection_id}: {e}")
async def _delete_document_vectors(self, document_id: int, collection_name: str):
"""Delete document vectors from Qdrant"""
try:
# Get RAG module to delete the document vectors
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
except ImportError as e:
logger.error(f"Failed to import module_manager: {e}")
rag_module = None
if rag_module and hasattr(rag_module, 'delete_document'):
# Create a document ID that matches what was used during indexing
doc_id = str(document_id)
success = await rag_module.delete_document(doc_id, collection_name)
if success:
logger.info(f"Deleted vectors for document {document_id} from collection {collection_name}")
else:
logger.warning(f"No vectors found for document {document_id} in collection {collection_name}")
else:
logger.warning("RAG module not available for document vector deletion")
except Exception as e:
logger.error(f"Error deleting document vectors for {document_id} from {collection_name}: {e}")
# Don't re-raise the error as document deletion should continue
async def _get_qdrant_collections(self) -> List[str]:
"""Get list of all collection names from Qdrant"""
try:
# Get RAG module to access Qdrant collections
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
if rag_module and hasattr(rag_module, '_get_collections_safely'):
return await rag_module._get_collections_safely()
else:
logger.warning("RAG module or safe collections method not available")
return []
except Exception as e:
logger.error(f"Error getting Qdrant collections: {e}")
return []
async def _get_qdrant_collection_point_count(self, collection_name: str) -> int:
"""Get the number of points (documents) in a Qdrant collection"""
try:
# Get RAG module to access Qdrant collections
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
if rag_module and hasattr(rag_module, '_get_collection_info_safely'):
collection_info = await rag_module._get_collection_info_safely(collection_name)
return collection_info.get("points_count", 0)
else:
logger.warning("RAG module or safe collection info method not available")
return 0
except Exception as e:
logger.warning(f"Could not get point count for collection {collection_name}: {e}")
return 0
async def _process_document(self, document_id: int):
"""Process document content and create vectors"""
try:
# Get fresh document from database
async with self.db as session:
document = await session.get(RagDocument, document_id)
if not document:
return
# Process with RAG module (now includes content processing)
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
except ImportError as e:
logger.error(f"Failed to import module_manager: {e}")
rag_module = None
if rag_module:
# Read file content
with open(document.file_path, 'rb') as f:
file_content = f.read()
# Process with RAG module
try:
processed_doc = await rag_module.process_document(
file_content,
document.original_filename,
{}
)
# Success case - update document with processed content
document.converted_content = processed_doc.content
document.word_count = processed_doc.word_count
document.character_count = len(processed_doc.content)
document.document_metadata = processed_doc.metadata
document.status = 'processed'
document.processed_at = datetime.utcnow()
# Index the processed document in the correct Qdrant collection
try:
# Get the collection's Qdrant collection name
stmt = select(RagDocument).options(selectinload(RagDocument.collection)).where(RagDocument.id == document_id)
result = await session.execute(stmt)
doc_with_collection = result.scalar_one()
qdrant_collection_name = doc_with_collection.collection.qdrant_collection_name
# Index in Qdrant with the correct collection name
await rag_module.index_processed_document(processed_doc, qdrant_collection_name)
# Calculate actual vector count (estimate based on content length)
document.vector_count = max(1, len(processed_doc.content) // 500) # ~500 chars per chunk
document.status = 'indexed'
document.indexed_at = datetime.utcnow()
except Exception as index_error:
logger.error(f"Failed to index document {document_id} in Qdrant: {index_error}")
document.status = 'error'
document.processing_error = f"Indexing failed: {str(index_error)}"
# Update collection stats
if document.status == 'indexed':
collection = doc_with_collection.collection
collection.document_count += 1
collection.size_bytes += document.file_size
collection.vector_count += document.vector_count
collection.updated_at = datetime.utcnow()
except Exception as e:
# Error case - mark document as failed
document.status = 'error'
document.processing_error = str(e)
await session.commit()
else:
# No RAG module available
document.status = 'error'
document.processing_error = 'RAG module not available'
await session.commit()
except Exception as e:
# Update document with error status
async with self.db as session:
document = await session.get(RagDocument, document_id)
if document:
document.status = 'error'
document.processing_error = str(e)
await session.commit()
async def reprocess_document(self, document_id: int) -> bool:
"""Restart processing for a stuck or failed document"""
try:
# Get document from database
document = await self.get_document(document_id)
if not document:
logger.error(f"Document {document_id} not found for reprocessing")
return False
# Check if document is in a state where reprocessing makes sense
if document.status not in ['processing', 'error']:
logger.warning(f"Document {document_id} status is '{document.status}', cannot reprocess")
return False
logger.info(f"Restarting processing for document {document_id} (current status: {document.status})")
# Reset document status and clear errors
document.status = 'pending'
document.processing_error = None
document.processed_at = None
document.indexed_at = None
document.updated_at = datetime.utcnow()
await self.db.commit()
# Re-queue document for processing
try:
from app.services.document_processor import document_processor
success = await document_processor.add_task(document_id, priority=1)
if success:
logger.info(f"Document {document_id} successfully re-queued for processing")
else:
logger.error(f"Failed to re-queue document {document_id} for processing")
# Revert status back to error
document.status = 'error'
document.processing_error = "Failed to re-queue for processing"
await self.db.commit()
return success
except Exception as e:
logger.error(f"Error re-queuing document {document_id}: {e}")
# Revert status back to error
document.status = 'error'
document.processing_error = f"Failed to re-queue: {str(e)}"
await self.db.commit()
return False
except Exception as e:
logger.error(f"Error reprocessing document {document_id}: {e}")
return False

View File

@@ -0,0 +1,363 @@
"""
Trusted Execution Environment (TEE) Service
Handles Privatemode.ai TEE integration for confidential computing
"""
import asyncio
import json
import logging
import os
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from enum import Enum
import aiohttp
from fastapi import HTTPException, status
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.serialization import load_pem_public_key
import base64
from app.core.config import settings
logger = logging.getLogger(__name__)
class TEEStatus(str, Enum):
"""TEE environment status"""
HEALTHY = "healthy"
DEGRADED = "degraded"
OFFLINE = "offline"
UNKNOWN = "unknown"
class AttestationStatus(str, Enum):
"""Attestation verification status"""
VERIFIED = "verified"
FAILED = "failed"
PENDING = "pending"
EXPIRED = "expired"
class TEEService:
"""Service for managing Privatemode.ai TEE integration"""
def __init__(self):
self.privatemode_base_url = "http://privatemode-proxy:8080"
self.privatemode_api_key = settings.PRIVATEMODE_API_KEY
self.session: Optional[aiohttp.ClientSession] = None
self.timeout = aiohttp.ClientTimeout(total=300) # 5 minutes timeout
self.attestation_cache = {} # Cache for attestation results
self.attestation_ttl = timedelta(hours=1) # Cache TTL
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create aiohttp session"""
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession(
timeout=self.timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.privatemode_api_key}"
}
)
return self.session
async def close(self):
"""Close the HTTP session"""
if self.session and not self.session.closed:
await self.session.close()
async def health_check(self) -> Dict[str, Any]:
"""Check TEE environment health"""
try:
session = await self._get_session()
async with session.get(f"{self.privatemode_base_url}/health") as response:
if response.status == 200:
health_data = await response.json()
return {
"status": TEEStatus.HEALTHY.value,
"timestamp": datetime.utcnow().isoformat(),
"tee_enabled": health_data.get("tee_enabled", False),
"attestation_available": health_data.get("attestation_available", False),
"secure_memory": health_data.get("secure_memory", False),
"details": health_data
}
else:
return {
"status": TEEStatus.DEGRADED.value,
"timestamp": datetime.utcnow().isoformat(),
"error": f"HTTP {response.status}"
}
except Exception as e:
logger.error(f"TEE health check error: {e}")
return {
"status": TEEStatus.OFFLINE.value,
"timestamp": datetime.utcnow().isoformat(),
"error": str(e)
}
async def get_attestation(self, nonce: Optional[str] = None) -> Dict[str, Any]:
"""Get TEE attestation report"""
try:
if not nonce:
nonce = base64.b64encode(os.urandom(32)).decode()
# Check cache first
cache_key = f"attestation_{nonce}"
if cache_key in self.attestation_cache:
cached_result = self.attestation_cache[cache_key]
if datetime.fromisoformat(cached_result["timestamp"]) + self.attestation_ttl > datetime.utcnow():
return cached_result
session = await self._get_session()
payload = {"nonce": nonce}
async with session.post(
f"{self.privatemode_base_url}/attestation",
json=payload
) as response:
if response.status == 200:
attestation_data = await response.json()
# Process attestation report
result = {
"status": AttestationStatus.VERIFIED.value,
"timestamp": datetime.utcnow().isoformat(),
"nonce": nonce,
"report": attestation_data.get("report"),
"signature": attestation_data.get("signature"),
"certificate_chain": attestation_data.get("certificate_chain"),
"measurements": attestation_data.get("measurements", {}),
"tee_type": attestation_data.get("tee_type", "unknown"),
"verified": True
}
# Cache the result
self.attestation_cache[cache_key] = result
return result
else:
error_text = await response.text()
logger.error(f"TEE attestation failed: {response.status} - {error_text}")
return {
"status": AttestationStatus.FAILED.value,
"timestamp": datetime.utcnow().isoformat(),
"nonce": nonce,
"error": error_text,
"verified": False
}
except Exception as e:
logger.error(f"TEE attestation error: {e}")
return {
"status": AttestationStatus.FAILED.value,
"timestamp": datetime.utcnow().isoformat(),
"nonce": nonce,
"error": str(e),
"verified": False
}
async def verify_attestation(self, attestation_data: Dict[str, Any]) -> Dict[str, Any]:
"""Verify TEE attestation report"""
try:
# Extract components
report = attestation_data.get("report")
signature = attestation_data.get("signature")
cert_chain = attestation_data.get("certificate_chain")
if not all([report, signature, cert_chain]):
return {
"verified": False,
"status": AttestationStatus.FAILED.value,
"error": "Missing required attestation components"
}
# Verify signature (simplified - in production, use proper certificate validation)
try:
# This is a placeholder for actual attestation verification
# In production, you would:
# 1. Validate the certificate chain
# 2. Verify the signature using the public key
# 3. Check measurements against known good values
# 4. Validate the nonce
verification_result = {
"verified": True,
"status": AttestationStatus.VERIFIED.value,
"timestamp": datetime.utcnow().isoformat(),
"certificate_valid": True,
"signature_valid": True,
"measurements_valid": True,
"nonce_valid": True
}
return verification_result
except Exception as verify_error:
logger.error(f"Attestation verification failed: {verify_error}")
return {
"verified": False,
"status": AttestationStatus.FAILED.value,
"error": str(verify_error)
}
except Exception as e:
logger.error(f"Attestation verification error: {e}")
return {
"verified": False,
"status": AttestationStatus.FAILED.value,
"error": str(e)
}
async def get_tee_capabilities(self) -> Dict[str, Any]:
"""Get TEE environment capabilities"""
try:
session = await self._get_session()
async with session.get(f"{self.privatemode_base_url}/capabilities") as response:
if response.status == 200:
capabilities = await response.json()
return {
"timestamp": datetime.utcnow().isoformat(),
"tee_type": capabilities.get("tee_type", "unknown"),
"secure_memory_size": capabilities.get("secure_memory_size", 0),
"encryption_algorithms": capabilities.get("encryption_algorithms", []),
"attestation_types": capabilities.get("attestation_types", []),
"key_management": capabilities.get("key_management", False),
"secure_storage": capabilities.get("secure_storage", False),
"network_isolation": capabilities.get("network_isolation", False),
"confidential_computing": capabilities.get("confidential_computing", False),
"details": capabilities
}
else:
return {
"timestamp": datetime.utcnow().isoformat(),
"error": f"Failed to get capabilities: HTTP {response.status}"
}
except Exception as e:
logger.error(f"TEE capabilities error: {e}")
return {
"timestamp": datetime.utcnow().isoformat(),
"error": str(e)
}
async def create_secure_session(self, user_id: str, api_key_id: int) -> Dict[str, Any]:
"""Create a secure TEE session"""
try:
session = await self._get_session()
payload = {
"user_id": user_id,
"api_key_id": api_key_id,
"timestamp": datetime.utcnow().isoformat(),
"requested_capabilities": [
"confidential_inference",
"secure_memory",
"attestation"
]
}
async with session.post(
f"{self.privatemode_base_url}/session",
json=payload
) as response:
if response.status == 201:
session_data = await response.json()
return {
"session_id": session_data.get("session_id"),
"status": "active",
"timestamp": datetime.utcnow().isoformat(),
"capabilities": session_data.get("capabilities", []),
"expires_at": session_data.get("expires_at"),
"attestation_token": session_data.get("attestation_token")
}
else:
error_text = await response.text()
logger.error(f"TEE session creation failed: {response.status} - {error_text}")
raise HTTPException(
status_code=response.status,
detail=f"Failed to create TEE session: {error_text}"
)
except aiohttp.ClientError as e:
logger.error(f"TEE session creation error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="TEE service unavailable"
)
async def get_privacy_metrics(self) -> Dict[str, Any]:
"""Get privacy and security metrics"""
try:
session = await self._get_session()
async with session.get(f"{self.privatemode_base_url}/metrics") as response:
if response.status == 200:
metrics = await response.json()
return {
"timestamp": datetime.utcnow().isoformat(),
"requests_processed": metrics.get("requests_processed", 0),
"data_encrypted": metrics.get("data_encrypted", 0),
"attestations_verified": metrics.get("attestations_verified", 0),
"secure_sessions": metrics.get("secure_sessions", 0),
"uptime": metrics.get("uptime", 0),
"memory_usage": metrics.get("memory_usage", {}),
"performance": metrics.get("performance", {}),
"privacy_score": metrics.get("privacy_score", 0)
}
else:
return {
"timestamp": datetime.utcnow().isoformat(),
"error": f"Failed to get metrics: HTTP {response.status}"
}
except Exception as e:
logger.error(f"TEE metrics error: {e}")
return {
"timestamp": datetime.utcnow().isoformat(),
"error": str(e)
}
async def list_tee_models(self) -> List[Dict[str, Any]]:
"""List available TEE models"""
try:
session = await self._get_session()
async with session.get(f"{self.privatemode_base_url}/models") as response:
if response.status == 200:
models_data = await response.json()
models = []
for model in models_data.get("models", []):
models.append({
"id": model.get("id"),
"name": model.get("name"),
"type": model.get("type", "chat"),
"provider": "privatemode",
"tee_enabled": True,
"confidential_computing": True,
"secure_inference": True,
"attestation_required": model.get("attestation_required", False),
"max_tokens": model.get("max_tokens", 4096),
"cost_per_token": model.get("cost_per_token", 0.0),
"availability": model.get("availability", "available")
})
return models
else:
logger.error(f"Failed to get TEE models: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"TEE models error: {e}")
return []
async def cleanup_expired_cache(self):
"""Clean up expired attestation cache entries"""
current_time = datetime.utcnow()
expired_keys = []
for key, cached_data in self.attestation_cache.items():
if datetime.fromisoformat(cached_data["timestamp"]) + self.attestation_ttl <= current_time:
expired_keys.append(key)
for key in expired_keys:
del self.attestation_cache[key]
logger.info(f"Cleaned up {len(expired_keys)} expired attestation cache entries")
# Global TEE service instance
tee_service = TEEService()
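Usage note: a minimal sketch of driving the service above, assuming it is importable as app.services.tee_service (the import path is an assumption) and the PrivateMode endpoint is reachable. create_secure_session raises an HTTPException on failure, while the metrics and model listings degrade to error payloads or empty lists instead of raising.

import asyncio

from fastapi import HTTPException

from app.services.tee_service import tee_service  # hypothetical import path


async def demo() -> None:
    try:
        session = await tee_service.create_secure_session(user_id="user-123", api_key_id=42)
        print("session:", session["session_id"], "expires:", session["expires_at"])
    except HTTPException as exc:
        # Raised with 503 when the PrivateMode endpoint is unreachable
        print("TEE unavailable:", exc.detail)

    # These calls return error dicts / empty lists instead of raising
    print(await tee_service.get_privacy_metrics())
    print(await tee_service.list_tee_models())


if __name__ == "__main__":
    asyncio.run(demo())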

View File

@@ -0,0 +1,3 @@
"""
Utilities package
"""

View File

@@ -0,0 +1,149 @@
"""
Custom exceptions
"""
from typing import Optional, Dict, Any
from fastapi import HTTPException, status
class CustomHTTPException(HTTPException):
"""Base custom HTTP exception"""
def __init__(
self,
status_code: int,
error_code: str,
detail: str,
details: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
):
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.error_code = error_code
self.details = details or {}
class AuthenticationError(CustomHTTPException):
"""Authentication error"""
def __init__(self, detail: str = "Authentication failed", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
error_code="AUTHENTICATION_ERROR",
detail=detail,
details=details,
headers={"WWW-Authenticate": "Bearer"},
)
class AuthorizationError(CustomHTTPException):
"""Authorization error"""
def __init__(self, detail: str = "Insufficient permissions", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
error_code="AUTHORIZATION_ERROR",
detail=detail,
details=details,
)
class ValidationError(CustomHTTPException):
"""Validation error"""
def __init__(self, detail: str = "Invalid data", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
error_code="VALIDATION_ERROR",
detail=detail,
details=details,
)
class NotFoundError(CustomHTTPException):
"""Not found error"""
def __init__(self, detail: str = "Resource not found", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
error_code="NOT_FOUND",
detail=detail,
details=details,
)
class ConflictError(CustomHTTPException):
"""Conflict error"""
def __init__(self, detail: str = "Resource conflict", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
error_code="CONFLICT",
detail=detail,
details=details,
)
class RateLimitError(CustomHTTPException):
"""Rate limit error"""
def __init__(self, detail: str = "Rate limit exceeded", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
error_code="RATE_LIMIT_EXCEEDED",
detail=detail,
details=details,
)
class BudgetExceededError(CustomHTTPException):
"""Budget exceeded error"""
def __init__(self, detail: str = "Budget exceeded", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
error_code="BUDGET_EXCEEDED",
detail=detail,
details=details,
)
class ModuleError(CustomHTTPException):
"""Module error"""
def __init__(self, detail: str = "Module error", details: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error_code="MODULE_ERROR",
detail=detail,
details=details,
)
class CircuitBreakerOpen(Exception):
"""Circuit breaker is open"""
pass
class ModuleLoadError(Exception):
"""Module load error"""
pass
class ModuleNotFoundError(Exception):
"""Module not found error (note: this shadows Python's built-in ModuleNotFoundError)"""
pass
class ModuleFatalError(Exception):
"""Fatal module error"""
pass
class ConfigurationError(Exception):
"""Configuration error"""
pass
# Aliases for backwards compatibility
RateLimitExceeded = RateLimitError
APIException = CustomHTTPException # Generic API exception alias
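Usage note: because every exception above subclasses CustomHTTPException (itself an HTTPException), one FastAPI handler can serialize error_code and details uniformly; Starlette resolves handlers via the MRO, so subclasses land in the same handler. A minimal sketch, assuming this file is importable as app.utils.exceptions (the path is an assumption):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from app.utils.exceptions import BudgetExceededError, CustomHTTPException  # hypothetical path

app = FastAPI()


@app.exception_handler(CustomHTTPException)
async def custom_exception_handler(request: Request, exc: CustomHTTPException) -> JSONResponse:
    # Surface error_code and details alongside the standard detail field
    return JSONResponse(
        status_code=exc.status_code,
        content={"error_code": exc.error_code, "detail": exc.detail, "details": exc.details},
        headers=exc.headers,
    )


@app.get("/budget-demo")
async def budget_demo():
    raise BudgetExceededError(details={"budget": 100.0, "spent": 112.5})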

View File

@@ -0,0 +1,7 @@
{
"name": "Confidential Empire",
"version": "1.0.0",
"debug": true,
"log_level": "INFO",
"timezone": "UTC"
}

View File

@@ -0,0 +1,5 @@
{
"redis_url": "redis://empire-redis:6379/0",
"timeout": 30,
"max_connections": 10
}

View File

@@ -0,0 +1,10 @@
{
"interval": 30,
"alert_thresholds": {
"cpu_warning": 80,
"cpu_critical": 95,
"memory_warning": 85,
"memory_critical": 95
},
"retention_hours": 24
}
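Usage note: a small sketch of how the thresholds above could drive alert levels; the file path and consuming code are assumptions, since the monitoring service itself is not part of this excerpt.

import json

with open("backend/config/monitoring.json") as fh:  # hypothetical path
    cfg = json.load(fh)

def classify(metric: str, value: float) -> str:
    """Map a cpu/memory percentage onto ok/warning/critical."""
    t = cfg["alert_thresholds"]
    if value >= t[f"{metric}_critical"]:
        return "critical"
    if value >= t[f"{metric}_warning"]:
        return "warning"
    return "ok"

print(classify("cpu", 83))     # -> "warning" with the defaults above
print(classify("memory", 96))  # -> "critical"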

6
backend/modules/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,6 @@
"""
Cache module for Confidential Empire platform
"""
from .main import CacheModule
__all__ = ["CacheModule"]

281
backend/modules/cache/main.py vendored Normal file
View File

@@ -0,0 +1,281 @@
"""
Cache module implementation with Redis backend
"""
import asyncio
import hashlib
import json
import logging
from typing import Any, Dict, Optional, Union
from datetime import datetime, timedelta
import redis.asyncio as redis
from redis.asyncio import Redis
from contextlib import asynccontextmanager
from app.core.config import settings
from app.core.logging import log_module_event
logger = logging.getLogger(__name__)
class CacheModule:
"""Redis-based cache module for request/response caching"""
def __init__(self):
self.redis_client: Optional[Redis] = None
self.config: Dict[str, Any] = {}
self.enabled = False
self.stats = {
"hits": 0,
"misses": 0,
"errors": 0,
"total_requests": 0
}
async def initialize(self):
"""Initialize the cache module"""
try:
# Initialize Redis connection
redis_url = getattr(settings, 'REDIS_URL', 'redis://localhost:6379/0')
self.redis_client = redis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
retry_on_timeout=True
)
# Test connection
await self.redis_client.ping()
self.enabled = True
log_module_event("cache", "initialized", {
"provider": self.config.get("provider", "redis"),
"ttl": self.config.get("ttl", 3600),
"max_size": self.config.get("max_size", 10000)
})
except Exception as e:
logger.error(f"Failed to initialize cache module: {e}")
log_module_event("cache", "initialization_failed", {"error": str(e)})
self.enabled = False
raise
async def cleanup(self):
"""Cleanup cache resources"""
if self.redis_client:
await self.redis_client.close()
self.redis_client = None
self.enabled = False
log_module_event("cache", "cleanup", {"success": True})
def _get_cache_key(self, key: str, prefix: str = "ce") -> str:
"""Generate cache key with prefix"""
return f"{prefix}:{key}"
async def get(self, key: str, default: Any = None) -> Any:
"""Get value from cache"""
if not self.enabled:
return default
try:
cache_key = self._get_cache_key(key)
value = await self.redis_client.get(cache_key)
self.stats["total_requests"] += 1
if value is None:
self.stats["misses"] += 1
return default
self.stats["hits"] += 1
# Try to deserialize JSON
try:
return json.loads(value)
except json.JSONDecodeError:
return value
except Exception as e:
logger.error(f"Cache get error: {e}")
self.stats["errors"] += 1
return default
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""Set value in cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key)
ttl = ttl or self.config.get("ttl", 3600)
# Serialize complex objects as JSON
if isinstance(value, (dict, list, tuple)):
value = json.dumps(value)
await self.redis_client.setex(cache_key, ttl, value)
return True
except Exception as e:
logger.error(f"Cache set error: {e}")
self.stats["errors"] += 1
return False
async def delete(self, key: str) -> bool:
"""Delete key from cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key)
result = await self.redis_client.delete(cache_key)
return result > 0
except Exception as e:
logger.error(f"Cache delete error: {e}")
self.stats["errors"] += 1
return False
async def exists(self, key: str) -> bool:
"""Check if key exists in cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key)
return await self.redis_client.exists(cache_key) > 0
except Exception as e:
logger.error(f"Cache exists error: {e}")
self.stats["errors"] += 1
return False
async def clear_pattern(self, pattern: str) -> int:
"""Clear keys matching pattern"""
if not self.enabled:
return 0
try:
cache_pattern = self._get_cache_key(pattern)
# NOTE: KEYS is O(N) and blocks Redis; SCAN is preferable for large keyspaces
keys = await self.redis_client.keys(cache_pattern)
if keys:
return await self.redis_client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Cache clear pattern error: {e}")
self.stats["errors"] += 1
return 0
async def clear_all(self) -> bool:
"""Clear all cache entries"""
if not self.enabled:
return False
try:
await self.redis_client.flushdb()
return True
except Exception as e:
logger.error(f"Cache clear all error: {e}")
self.stats["errors"] += 1
return False
async def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
stats = self.stats.copy()
if self.enabled:
try:
info = await self.redis_client.info()
stats.update({
"redis_memory_used": info.get("used_memory_human", "N/A"),
"redis_connected_clients": info.get("connected_clients", 0),
"redis_total_commands": info.get("total_commands_processed", 0),
"hit_rate": round(
(stats["hits"] / stats["total_requests"]) * 100, 2
) if stats["total_requests"] > 0 else 0
})
except Exception as e:
logger.error(f"Error getting Redis stats: {e}")
return stats
async def pre_request_interceptor(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-request interceptor for caching"""
if not self.enabled:
return context
request = context.get("request")
if not request:
return context
# Only cache GET requests
if request.method != "GET":
return context
# Generate cache key from request
cache_key = f"request:{request.method}:{request.url.path}"
if request.query_params:
# Use a stable digest: Python's built-in hash() is salted per process
cache_key += ":" + hashlib.md5(str(request.query_params).encode()).hexdigest()
# Check if cached response exists
cached_response = await self.get(cache_key)
if cached_response:
log_module_event("cache", "hit", {"cache_key": cache_key})
context["cached_response"] = cached_response
context["cache_hit"] = True
else:
log_module_event("cache", "miss", {"cache_key": cache_key})
context["cache_key"] = cache_key
context["cache_hit"] = False
return context
async def post_response_interceptor(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Post-response interceptor for caching"""
if not self.enabled:
return context
# Skip if this was a cache hit
if context.get("cache_hit"):
return context
cache_key = context.get("cache_key")
response = context.get("response")
if cache_key and response and response.status_code == 200:
# Cache successful responses
cache_data = {
"status_code": response.status_code,
"headers": dict(response.headers),
"body": response.body.decode() if hasattr(response, 'body') else None,
"timestamp": datetime.utcnow().isoformat()
}
await self.set(cache_key, cache_data)
log_module_event("cache", "stored", {"cache_key": cache_key})
return context
# Global cache instance
cache_module = CacheModule()
# Module interface functions
async def initialize():
"""Initialize cache module"""
await cache_module.initialize()
async def cleanup():
"""Cleanup cache module"""
await cache_module.cleanup()
async def pre_request_interceptor(context: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-request interceptor"""
return await cache_module.pre_request_interceptor(context)
async def post_response_interceptor(context: Dict[str, Any]) -> Dict[str, Any]:
"""Post-response interceptor"""
return await cache_module.post_response_interceptor(context)
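Usage note: a minimal sketch of the cache API above, assuming the module is importable as modules.cache.main (an assumption) and that the configured Redis URL points at a reachable server.

import asyncio

from modules.cache.main import cache_module  # hypothetical import path


async def demo() -> None:
    await cache_module.initialize()  # raises if Redis is unreachable
    try:
        await cache_module.set("greeting", {"text": "hello"}, ttl=60)
        print(await cache_module.get("greeting"))           # {'text': 'hello'}
        print(await cache_module.exists("greeting"))        # True
        await cache_module.delete("greeting")
        print(await cache_module.get("greeting", "gone"))   # 'gone'
        print(await cache_module.get_stats())
    finally:
        await cache_module.cleanup()


if __name__ == "__main__":
    asyncio.run(demo())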

View File

@@ -0,0 +1,21 @@
"""
Chatbot Module - AI Chatbot with RAG Integration
This module provides AI chatbot capabilities with:
- Multiple personality types (Assistant, Customer Support, Teacher, etc.)
- RAG integration for knowledge-based responses
- Conversation memory and context management
- Workflow integration as building blocks
- UI-configurable settings
"""
from .main import ChatbotModule, create_module
__version__ = "1.0.0"
__author__ = "AI Gateway Team"
# Export main classes for easy importing
__all__ = [
"ChatbotModule",
"create_module"
]

View File

@@ -0,0 +1,126 @@
{
"title": "Chatbot Configuration",
"type": "object",
"properties": {
"name": {
"type": "string",
"title": "Chatbot Name",
"description": "Display name for this chatbot instance",
"minLength": 1,
"maxLength": 100
},
"chatbot_type": {
"type": "string",
"title": "Chatbot Type",
"description": "Select the type of chatbot personality",
"enum": ["assistant", "customer_support", "teacher", "researcher", "creative_writer", "custom"],
"enumNames": ["General Assistant", "Customer Support", "Teacher", "Researcher", "Creative Writer", "Custom"],
"default": "assistant"
},
"model": {
"type": "string",
"title": "AI Model",
"description": "Choose the LLM model for responses",
"enum": ["gpt-4", "gpt-3.5-turbo", "claude-3-sonnet", "claude-3-opus", "llama-70b"],
"default": "gpt-3.5-turbo"
},
"system_prompt": {
"type": "string",
"title": "System Prompt",
"description": "Define the chatbot's personality and behavior instructions",
"ui:widget": "textarea",
"ui:options": {
"rows": 6,
"placeholder": "You are a helpful AI assistant..."
}
},
"use_rag": {
"type": "boolean",
"title": "Enable Knowledge Base",
"description": "Use RAG to search knowledge base for context",
"default": false
},
"rag_collection": {
"type": "string",
"title": "Knowledge Base Collection",
"description": "Select which document collection to search",
"ui:widget": "rag-collection-selector",
"ui:condition": "use_rag === true"
},
"rag_top_k": {
"type": "integer",
"title": "Knowledge Base Results",
"description": "Number of relevant documents to include",
"minimum": 1,
"maximum": 10,
"default": 5,
"ui:condition": "use_rag === true"
},
"temperature": {
"type": "number",
"title": "Response Creativity",
"description": "Controls randomness (0.0 = focused, 1.0 = creative)",
"minimum": 0,
"maximum": 1,
"default": 0.7,
"ui:widget": "range",
"ui:options": {
"step": 0.1
}
},
"max_tokens": {
"type": "integer",
"title": "Maximum Response Length",
"description": "Maximum number of tokens in response",
"minimum": 50,
"maximum": 4000,
"default": 1000,
"ui:widget": "range",
"ui:options": {
"step": 50
}
},
"memory_length": {
"type": "integer",
"title": "Conversation Memory",
"description": "Number of previous message pairs to remember",
"minimum": 1,
"maximum": 50,
"default": 10,
"ui:widget": "range"
},
"fallback_responses": {
"type": "array",
"title": "Fallback Responses",
"description": "Responses to use when the AI cannot answer",
"items": {
"type": "string",
"title": "Fallback Response"
},
"default": [
"I'm not sure how to help with that. Could you please rephrase your question?",
"I don't have enough information to answer that question accurately.",
"That's outside my knowledge area. Is there something else I can help you with?"
],
"ui:options": {
"orderable": true,
"addable": true,
"removable": true
}
}
},
"required": ["name", "chatbot_type", "model"],
"ui:order": [
"name",
"chatbot_type",
"model",
"system_prompt",
"use_rag",
"rag_collection",
"rag_top_k",
"temperature",
"max_tokens",
"memory_length",
"fallback_responses"
]
}
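Usage note: the schema above mixes standard JSON Schema with react-jsonschema-form-style "ui:*" and enumNames extensions; standard validators ignore the unknown keywords, so instance data can still be checked with the jsonschema package. A minimal sketch (the file path is an assumption):

import json

from jsonschema import Draft7Validator

with open("backend/modules/chatbot/config_schema.json") as fh:  # hypothetical path
    schema = json.load(fh)

config = {
    "name": "Docs Helper",
    "chatbot_type": "assistant",
    "model": "gpt-3.5-turbo",
    "temperature": 0.3,
}

errors = sorted(Draft7Validator(schema).iter_errors(config), key=str)
for err in errors:
    print("/".join(map(str, err.path)), err.message)
print("valid" if not errors else "invalid")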

View File

@@ -0,0 +1,182 @@
{
"name": "Customer Support Workflow",
"description": "Intelligent customer support workflow with intent classification, knowledge base search, and chatbot response generation",
"version": "1.0",
"variables": {
"support_chatbot_id": "cs-bot-001",
"escalation_threshold": 0.3,
"max_attempts": 3
},
"steps": [
{
"id": "classify_intent",
"name": "Classify Customer Intent",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are an intent classifier for customer support. Classify the customer message into one of these categories: technical_issue, billing_question, feature_request, complaint, general_inquiry. Also provide a confidence score between 0 and 1. Respond with JSON: {\"intent\": \"category\", \"confidence\": 0.95, \"reasoning\": \"explanation\"}"
},
{
"role": "user",
"content": "{{ inputs.customer_message }}"
}
],
"output_variable": "intent_classification"
},
{
"id": "search_knowledge_base",
"name": "Search Knowledge Base",
"type": "workflow_step",
"module": "rag",
"action": "search",
"config": {
"query": "{{ inputs.customer_message }}",
"collection": "support_documentation",
"top_k": 5,
"include_metadata": true
},
"output_variable": "knowledge_results"
},
{
"id": "check_confidence",
"name": "Check Intent Confidence",
"type": "condition",
"condition": "JSON.parse(steps.classify_intent.result).confidence > variables.escalation_threshold",
"true_steps": [
{
"id": "generate_chatbot_response",
"name": "Generate Chatbot Response",
"type": "workflow_step",
"module": "chatbot",
"action": "workflow_chat_step",
"config": {
"message": "{{ inputs.customer_message }}",
"chatbot_id": "{{ variables.support_chatbot_id }}",
"use_rag": true,
"context": {
"intent": "{{ steps.classify_intent.result }}",
"knowledge_base_results": "{{ steps.search_knowledge_base.result }}",
"customer_history": "{{ inputs.customer_history }}",
"additional_instructions": "Be empathetic and professional. If you cannot fully resolve the issue, offer to escalate to a human agent."
}
},
"output_variable": "chatbot_response"
},
{
"id": "analyze_response_quality",
"name": "Analyze Response Quality",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "Analyze if this customer support response adequately addresses the customer's question. Consider completeness, accuracy, and helpfulness. Respond with JSON: {\"quality_score\": 0.85, \"is_adequate\": true, \"requires_escalation\": false, \"reasoning\": \"explanation\"}"
},
{
"role": "user",
"content": "Customer Question: {{ inputs.customer_message }}\\n\\nChatbot Response: {{ steps.generate_chatbot_response.result.response }}\\n\\nKnowledge Base Context: {{ steps.search_knowledge_base.result }}"
}
],
"output_variable": "response_quality"
},
{
"id": "final_response_decision",
"name": "Final Response Decision",
"type": "condition",
"condition": "JSON.parse(steps.analyze_response_quality.result).is_adequate === true",
"true_steps": [
{
"id": "send_chatbot_response",
"name": "Send Chatbot Response",
"type": "output",
"config": {
"response_type": "chatbot_response",
"message": "{{ steps.generate_chatbot_response.result.response }}",
"sources": "{{ steps.generate_chatbot_response.result.sources }}",
"confidence": "{{ JSON.parse(steps.classify_intent.result).confidence }}",
"quality_score": "{{ JSON.parse(steps.analyze_response_quality.result).quality_score }}"
}
}
],
"false_steps": [
{
"id": "escalate_to_human",
"name": "Escalate to Human Agent",
"type": "output",
"config": {
"response_type": "human_escalation",
"message": "I'd like to connect you with one of our human support agents who can better assist with your specific situation. Please hold on while I transfer you.",
"escalation_reason": "Response quality below threshold",
"intent": "{{ steps.classify_intent.result }}",
"attempted_response": "{{ steps.generate_chatbot_response.result.response }}",
"priority": "normal"
}
}
]
}
],
"false_steps": [
{
"id": "low_confidence_escalation",
"name": "Low Confidence Escalation",
"type": "output",
"config": {
"response_type": "human_escalation",
"message": "I want to make sure you get the best possible help. Let me connect you with one of our human support agents.",
"escalation_reason": "Low intent classification confidence",
"intent": "{{ steps.classify_intent.result }}",
"priority": "high"
}
}
]
},
{
"id": "log_interaction",
"name": "Log Customer Interaction",
"type": "workflow_step",
"module": "analytics",
"action": "log_event",
"config": {
"event_type": "customer_support_interaction",
"data": {
"customer_message": "{{ inputs.customer_message }}",
"intent_classification": "{{ steps.classify_intent.result }}",
"response_generated": "{{ steps.generate_chatbot_response.result.response }}",
"knowledge_base_used": "{{ steps.search_knowledge_base.result }}",
"escalated": "{{ outputs.response_type === 'human_escalation' }}",
"workflow_execution_time": "{{ execution_time }}",
"timestamp": "{{ current_timestamp }}"
}
}
}
],
"outputs": {
"response_type": "string",
"message": "string",
"sources": "array",
"escalation_reason": "string",
"confidence": "number",
"quality_score": "number"
},
"error_handling": {
"retry_failed_steps": true,
"max_retries": 2,
"fallback_response": "I apologize, but I'm experiencing technical difficulties. Please contact our support team directly for assistance."
},
"metadata": {
"created_by": "support_team",
"use_case": "customer_support_automation",
"tags": ["customer_support", "chatbot", "rag", "escalation"],
"estimated_execution_time": "5-15 seconds"
}
}
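Usage note: the workflow above references values through "{{ path.to.value }}" placeholders. The execution engine is not shown in this excerpt; the sketch below illustrates one plausible resolution strategy (the chatbot module later in this commit ships a similar _substitute_template_variables helper, which returns unmatched placeholders unchanged rather than raising).

import re
from typing import Any, Dict


def render(template: str, context: Dict[str, Any]) -> str:
    def replace(match: re.Match) -> str:
        value: Any = context
        for part in match.group(1).strip().split("."):
            value = value[part]  # missing paths raise KeyError in this sketch
        return str(value)

    return re.sub(r"\{\{\s*([^}]+?)\s*\}\}", replace, template)


ctx = {"inputs": {"customer_message": "My invoice is wrong"}}
print(render("Intent for: {{ inputs.customer_message }}", ctx))
# -> Intent for: My invoice is wrong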

View File

@@ -0,0 +1,893 @@
"""
Chatbot Module Implementation
Provides AI chatbot capabilities with:
- RAG integration for knowledge-based responses
- Custom prompts and personalities
- Conversation memory and context
- Workflow integration as building blocks
- UI-configurable settings
"""
import asyncio
import json
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Union
from dataclasses import dataclass
from pydantic import BaseModel, Field
from enum import Enum
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.litellm_client import LiteLLMClient
from app.services.base_module import BaseModule, Permission
from app.models.user import User
from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
from app.core.security import get_current_user
from app.db.database import get_db
from app.core.config import settings
# Import protocols for type hints and dependency injection
from ..protocols import RAGServiceProtocol, LiteLLMClientProtocol
logger = get_logger(__name__)
class ChatbotType(str, Enum):
"""Types of chatbot personalities"""
ASSISTANT = "assistant"
CUSTOMER_SUPPORT = "customer_support"
TEACHER = "teacher"
RESEARCHER = "researcher"
CREATIVE_WRITER = "creative_writer"
CUSTOM = "custom"
class MessageRole(str, Enum):
"""Message roles in conversation"""
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
@dataclass
class ChatbotConfig:
"""Chatbot configuration"""
name: str
chatbot_type: str # Changed from ChatbotType enum to str to allow custom types
model: str
rag_collection: Optional[str] = None
system_prompt: str = ""
temperature: float = 0.7
max_tokens: int = 1000
memory_length: int = 10  # Number of previous message pairs to remember
use_rag: bool = False
rag_top_k: int = 5
fallback_responses: List[str] = None
def __post_init__(self):
if self.fallback_responses is None:
self.fallback_responses = [
"I'm not sure how to help with that. Could you please rephrase your question?",
"I don't have enough information to answer that question accurately.",
"That's outside my knowledge area. Is there something else I can help you with?"
]
class ChatMessage(BaseModel):
"""Individual chat message"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
role: MessageRole
content: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
metadata: Dict[str, Any] = Field(default_factory=dict)
sources: Optional[List[Dict[str, Any]]] = None
class Conversation(BaseModel):
"""Conversation state"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
chatbot_id: str
user_id: str
messages: List[ChatMessage] = Field(default_factory=list)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
metadata: Dict[str, Any] = Field(default_factory=dict)
class ChatRequest(BaseModel):
"""Chat completion request"""
message: str
conversation_id: Optional[str] = None
chatbot_id: str
use_rag: Optional[bool] = None
context: Optional[Dict[str, Any]] = None
class ChatResponse(BaseModel):
"""Chat completion response"""
response: str
conversation_id: str
message_id: str
sources: Optional[List[Dict[str, Any]]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
class ChatbotInstance(BaseModel):
"""Configured chatbot instance"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
config: ChatbotConfig
created_by: str
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
is_active: bool = True
class ChatbotModule(BaseModule):
"""Main chatbot module implementation"""
def __init__(self, litellm_client: Optional[LiteLLMClientProtocol] = None,
rag_service: Optional[RAGServiceProtocol] = None):
super().__init__("chatbot")
self.litellm_client = litellm_client
self.rag_module = rag_service # Keep same name for compatibility
self.db_session = None
# System prompts will be loaded from database
self.system_prompts = {}
async def initialize(self, **kwargs):
"""Initialize the chatbot module"""
await super().initialize(**kwargs)
# Get dependencies from global services if not already injected
if not self.litellm_client:
try:
from app.services.litellm_client import litellm_client
self.litellm_client = litellm_client
logger.info("LiteLLM client injected from global service")
except Exception as e:
logger.warning(f"Could not inject LiteLLM client: {e}")
if not self.rag_module:
try:
# Try to get RAG module from module manager
from app.services.module_manager import module_manager
if hasattr(module_manager, 'modules') and 'rag' in module_manager.modules:
self.rag_module = module_manager.modules['rag']
logger.info("RAG module injected from module manager")
except Exception as e:
logger.warning(f"Could not inject RAG module: {e}")
# Load prompt templates from database
await self._load_prompt_templates()
logger.info("Chatbot module initialized")
logger.info(f"LiteLLM client available after init: {self.litellm_client is not None}")
logger.info(f"RAG module available after init: {self.rag_module is not None}")
logger.info(f"Loaded {len(self.system_prompts)} prompt templates")
async def _ensure_dependencies(self):
"""Lazy load dependencies if not available"""
if not self.litellm_client:
try:
from app.services.litellm_client import litellm_client
self.litellm_client = litellm_client
logger.info("LiteLLM client lazy loaded")
except Exception as e:
logger.warning(f"Could not lazy load LiteLLM client: {e}")
if not self.rag_module:
try:
# Try to get RAG module from module manager
from app.services.module_manager import module_manager
if hasattr(module_manager, 'modules') and 'rag' in module_manager.modules:
self.rag_module = module_manager.modules['rag']
logger.info("RAG module lazy loaded from module manager")
except Exception as e:
logger.warning(f"Could not lazy load RAG module: {e}")
async def _load_prompt_templates(self):
"""Load prompt templates from database"""
try:
from app.db.database import SessionLocal
from app.models.prompt_template import PromptTemplate
from sqlalchemy import select
db = SessionLocal()
try:
result = db.execute(
select(PromptTemplate)
.where(PromptTemplate.is_active == True)
)
templates = result.scalars().all()
for template in templates:
self.system_prompts[template.type_key] = template.system_prompt
logger.info(f"Loaded {len(self.system_prompts)} prompt templates from database")
finally:
db.close()
except Exception as e:
logger.warning(f"Could not load prompt templates from database: {e}")
# Fallback to hardcoded prompts
self.system_prompts = {
"assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations.",
"customer_support": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions.",
"teacher": "You are an experienced educational tutor. Break down complex concepts into understandable parts. Be patient, supportive, and encouraging.",
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information.",
"creative_writer": "You are an experienced creative writing mentor and storytelling expert.",
"custom": "You are a helpful AI assistant. Your personality and behavior will be defined by custom instructions."
}
async def get_system_prompt_for_type(self, chatbot_type: str) -> str:
"""Get system prompt for a specific chatbot type"""
if chatbot_type in self.system_prompts:
return self.system_prompts[chatbot_type]
# If not found, try to reload templates
await self._load_prompt_templates()
return self.system_prompts.get(chatbot_type, self.system_prompts.get("assistant",
"You are a helpful AI assistant. Provide accurate, concise, and friendly responses."))
async def create_chatbot(self, config: ChatbotConfig, user_id: str, db: Session) -> ChatbotInstance:
"""Create a new chatbot instance"""
# Set system prompt based on type if not provided or empty
if not config.system_prompt or config.system_prompt.strip() == "":
config.system_prompt = await self.get_system_prompt_for_type(config.chatbot_type)
# Create database record
db_chatbot = DBChatbotInstance(
name=config.name,
description=f"{config.chatbot_type.replace('_', ' ').title()} chatbot",
config=config.__dict__,
created_by=user_id
)
db.add(db_chatbot)
db.commit()
db.refresh(db_chatbot)
# Convert to response model
chatbot = ChatbotInstance(
id=db_chatbot.id,
name=db_chatbot.name,
config=ChatbotConfig(**db_chatbot.config),
created_by=db_chatbot.created_by,
created_at=db_chatbot.created_at,
updated_at=db_chatbot.updated_at,
is_active=db_chatbot.is_active
)
logger.info(f"Created new chatbot: {chatbot.name} ({chatbot.id})")
return chatbot
async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
"""Generate chat completion response"""
# Get chatbot configuration from database
db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
if not db_chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
chatbot_config = ChatbotConfig(**db_chatbot.config)
# Get or create conversation
conversation = await self._get_or_create_conversation(
request.conversation_id, request.chatbot_id, user_id, db
)
# Create user message
user_message = DBMessage(
conversation_id=conversation.id,
role=MessageRole.USER.value,
content=request.message
)
db.add(user_message)
db.commit()
db.refresh(user_message)
logger.info(f"Created user message with ID {user_message.id} for conversation {conversation.id}")
try:
# Force the session to see the committed changes
db.expire_all()
# Get conversation history for context - INCLUDING the message we just created
messages = db.query(DBMessage).filter(
DBMessage.conversation_id == conversation.id
).order_by(DBMessage.timestamp.desc()).limit(chatbot_config.memory_length * 2 + 1).all()
logger.info(f"Query for conversation_id={conversation.id}, memory_length={chatbot_config.memory_length}")
logger.info(f"Found {len(messages)} messages in conversation history")
# If we don't have any messages, manually add the user message we just created
if len(messages) == 0:
logger.warning(f"No messages found in query, but we just created message {user_message.id}")
logger.warning(f"Using the user message we just created")
messages = [user_message]
for idx, msg in enumerate(messages):
logger.info(f"Message {idx}: id={msg.id}, role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
# Generate response
response_content, sources = await self._generate_response(
request.message, messages, chatbot_config, request.context, db
)
# Create assistant message
assistant_message = DBMessage(
conversation_id=conversation.id,
role=MessageRole.ASSISTANT.value,
content=response_content,
sources=sources,
metadata={"model": chatbot_config.model, "temperature": chatbot_config.temperature}
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
# Update conversation timestamp
conversation.updated_at = datetime.utcnow()
db.commit()
return ChatResponse(
response=response_content,
conversation_id=conversation.id,
message_id=assistant_message.id,
sources=sources
)
except Exception as e:
logger.error(f"Chat completion failed: {e}")
# Return fallback response
fallback = chatbot_config.fallback_responses[0] if chatbot_config.fallback_responses else "I'm having trouble responding right now."
assistant_message = DBMessage(
conversation_id=conversation.id,
role=MessageRole.ASSISTANT.value,
content=fallback,
metadata={"error": str(e), "fallback": True}
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
return ChatResponse(
response=fallback,
conversation_id=conversation.id,
message_id=assistant_message.id,
metadata={"error": str(e), "fallback": True}
)
async def _generate_response(self, message: str, db_messages: List[DBMessage],
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
"""Generate response using LLM with optional RAG"""
# Lazy load dependencies if not available
await self._ensure_dependencies()
sources = None
rag_context = ""
# RAG search if enabled
if config.use_rag and config.rag_collection and self.rag_module:
logger.info(f"RAG search enabled for collection: {config.rag_collection}")
try:
# Get the Qdrant collection name from RAG collection
qdrant_collection_name = await self._get_qdrant_collection_name(config.rag_collection, db)
logger.info(f"Qdrant collection name: {qdrant_collection_name}")
if qdrant_collection_name:
logger.info(f"Searching RAG documents: query='{message[:50]}...', max_results={config.rag_top_k}")
rag_results = await self.rag_module.search_documents(
query=message,
max_results=config.rag_top_k,
collection_name=qdrant_collection_name
)
if rag_results:
logger.info(f"RAG search found {len(rag_results)} results")
sources = [{"title": f"Document {i+1}", "content": result.document.content[:200]}
for i, result in enumerate(rag_results)]
# Build full RAG context from all results
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
])
# Detailed RAG logging - ALWAYS log for debugging
logger.info("=== COMPREHENSIVE RAG SEARCH RESULTS ===")
logger.info(f"Query: '{message}'")
logger.info(f"Collection: {qdrant_collection_name}")
logger.info(f"Number of results: {len(rag_results)}")
for i, result in enumerate(rag_results):
logger.info(f"\\n--- RAG Result {i+1} ---")
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
logger.info(f"Full Content ({len(result.document.content)} chars):")
logger.info(f"{result.document.content}")
if hasattr(result.document, 'metadata'):
logger.info(f"Metadata: {result.document.metadata}")
logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(rag_context)
logger.info("=== END RAG SEARCH RESULTS ===")
else:
logger.warning("RAG search returned no results")
else:
logger.warning(f"RAG collection '{config.rag_collection}' not found in database")
except Exception as e:
logger.warning(f"RAG search failed: {e}")
import traceback
logger.warning(f"RAG search traceback: {traceback.format_exc()}")
# Build conversation context
messages = self._build_conversation_messages(db_messages, config, rag_context, context)
# CRITICAL: Add the current user message to the messages array
# This ensures the LLM knows what the user is asking, not just the history
messages.append({
"role": "user",
"content": message
})
logger.info(f"Added current user message to messages array")
# LLM completion
logger.info(f"Attempting LLM completion with model: {config.model}")
logger.info(f"Messages to send: {len(messages)} messages")
# Always log detailed prompts for debugging
logger.info("=== COMPREHENSIVE LLM REQUEST ===")
logger.info(f"Model: {config.model}")
logger.info(f"Temperature: {config.temperature}")
logger.info(f"Max tokens: {config.max_tokens}")
logger.info(f"RAG enabled: {config.use_rag}")
logger.info(f"RAG collection: {config.rag_collection}")
if config.use_rag and rag_context:
logger.info(f"RAG context added: {len(rag_context)} characters")
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
for i, msg in enumerate(messages):
logger.info(f"\\n--- Message {i+1} ---")
logger.info(f"Role: {msg['role']}")
logger.info(f"Content ({len(msg['content'])} chars):")
# Truncate long content for logging (full RAG context can be very long)
if len(msg['content']) > 500:
logger.info(f"{msg['content'][:500]}... [truncated, total {len(msg['content'])} chars]")
else:
logger.info(msg['content'])
logger.info("=== END COMPREHENSIVE LLM REQUEST ===")
if self.litellm_client:
try:
logger.info("Calling LiteLLM client create_chat_completion...")
response = await self.litellm_client.create_chat_completion(
model=config.model,
messages=messages,
user_id="chatbot_user",
api_key_id="chatbot_api_key",
temperature=config.temperature,
max_tokens=config.max_tokens
)
logger.info(f"LiteLLM response received, response keys: {list(response.keys())}")
# Extract response content from the LiteLLM response format
if 'choices' in response and response['choices']:
content = response['choices'][0]['message']['content']
logger.info(f"Response content length: {len(content)}")
# Always log response for debugging
logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
logger.info(f"Response content ({len(content)} chars):")
logger.info(content)
if 'usage' in response:
usage = response['usage']
logger.info(f"Token usage - Prompt: {usage.get('prompt_tokens', 'N/A')}, Completion: {usage.get('completion_tokens', 'N/A')}, Total: {usage.get('total_tokens', 'N/A')}")
if sources:
logger.info(f"RAG sources included: {len(sources)} documents")
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
return content, sources
else:
logger.warning("No choices in LiteLLM response")
return "I received an empty response from the AI model.", sources
except Exception as e:
logger.error(f"LiteLLM completion failed: {e}")
raise
else:
logger.warning("No LiteLLM client available, using fallback")
# Fallback if no LLM client
return "I'm currently unable to process your request. Please try again later.", None
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]:
"""Build messages array for LLM completion"""
messages = []
# System prompt
system_prompt = config.system_prompt
if rag_context:
system_prompt += rag_context
if context and context.get('additional_instructions'):
system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
messages.append({"role": "system", "content": system_prompt})
logger.info(f"Building messages from {len(db_messages)} database messages")
# Conversation history (messages are already limited by memory_length in the query)
# Reverse to get chronological order
# Include ALL messages - the current user message is needed for the LLM to respond!
for idx, msg in enumerate(reversed(db_messages)):
logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
if msg.role in ["user", "assistant"]:
messages.append({
"role": msg.role,
"content": msg.content
})
logger.info(f"Added message with role {msg.role} to LLM messages")
else:
logger.info(f"Skipped message with role {msg.role}")
logger.info(f"Final messages array has {len(messages)} messages")
return messages
async def _get_or_create_conversation(self, conversation_id: Optional[str],
chatbot_id: str, user_id: str, db: Session) -> DBConversation:
"""Get existing conversation or create new one"""
if conversation_id:
conversation = db.query(DBConversation).filter(DBConversation.id == conversation_id).first()
if conversation:
return conversation
# Create new conversation
conversation = DBConversation(
chatbot_id=chatbot_id,
user_id=user_id,
title="New Conversation"
)
db.add(conversation)
db.commit()
db.refresh(conversation)
return conversation
def get_router(self) -> APIRouter:
"""Get FastAPI router for chatbot endpoints"""
router = APIRouter(prefix="/chatbot", tags=["chatbot"])
@router.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
request: ChatRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Chat completion endpoint"""
return await self.chat_completion(request, str(current_user['id']), db)
@router.post("/create", response_model=ChatbotInstance)
async def create_chatbot_endpoint(
config: ChatbotConfig,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Create new chatbot instance"""
return await self.create_chatbot(config, str(current_user['id']), db)
@router.get("/list", response_model=List[ChatbotInstance])
async def list_chatbots_endpoint(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""List user's chatbots"""
db_chatbots = db.query(DBChatbotInstance).filter(
(DBChatbotInstance.created_by == str(current_user['id'])) |
(DBChatbotInstance.created_by == "system")
).all()
chatbots = []
for db_chatbot in db_chatbots:
chatbot = ChatbotInstance(
id=db_chatbot.id,
name=db_chatbot.name,
config=ChatbotConfig(**db_chatbot.config),
created_by=db_chatbot.created_by,
created_at=db_chatbot.created_at,
updated_at=db_chatbot.updated_at,
is_active=db_chatbot.is_active
)
chatbots.append(chatbot)
return chatbots
@router.get("/conversations/{conversation_id}", response_model=Conversation)
async def get_conversation_endpoint(
conversation_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get conversation history"""
conversation = db.query(DBConversation).filter(
DBConversation.id == conversation_id
).first()
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
# Check if user owns this conversation
if conversation.user_id != str(current_user['id']):
raise HTTPException(status_code=403, detail="Not authorized")
# Get messages
messages = db.query(DBMessage).filter(
DBMessage.conversation_id == conversation_id
).order_by(DBMessage.timestamp).all()
# Convert to response model
chat_messages = []
for msg in messages:
chat_message = ChatMessage(
id=msg.id,
role=MessageRole(msg.role),
content=msg.content,
timestamp=msg.timestamp,
metadata=msg.metadata or {},
sources=msg.sources
)
chat_messages.append(chat_message)
response_conversation = Conversation(
id=conversation.id,
chatbot_id=conversation.chatbot_id,
user_id=conversation.user_id,
messages=chat_messages,
created_at=conversation.created_at,
updated_at=conversation.updated_at,
metadata=conversation.context_data or {}
)
return response_conversation
@router.get("/types", response_model=List[Dict[str, str]])
async def get_chatbot_types_endpoint():
"""Get available chatbot types and their descriptions"""
return [
{"type": "assistant", "name": "General Assistant", "description": "Helpful AI assistant for general questions"},
{"type": "customer_support", "name": "Customer Support", "description": "Professional customer service chatbot"},
{"type": "teacher", "name": "Teacher", "description": "Educational tutor and learning assistant"},
{"type": "researcher", "name": "Researcher", "description": "Research assistant with fact-checking focus"},
{"type": "creative_writer", "name": "Creative Writer", "description": "Creative writing and storytelling assistant"},
{"type": "custom", "name": "Custom", "description": "Custom chatbot with user-defined personality"}
]
return router
# API Compatibility Methods
async def chat(self, chatbot_config: Dict[str, Any], message: str,
conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
"""Chat method for API compatibility"""
logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")
# Lazy load dependencies
await self._ensure_dependencies()
logger.info(f"LiteLLM client available: {self.litellm_client is not None}")
logger.info(f"RAG module available: {self.rag_module is not None}")
try:
# Create a minimal database session for the chat
from app.db.database import SessionLocal
db = SessionLocal()
try:
# Convert config dict to ChatbotConfig
config = ChatbotConfig(
name=chatbot_config.get("name", "Unknown"),
chatbot_type=chatbot_config.get("chatbot_type", "assistant"),
model=chatbot_config.get("model", "gpt-3.5-turbo"),
system_prompt=chatbot_config.get("system_prompt", ""),
temperature=chatbot_config.get("temperature", 0.7),
max_tokens=chatbot_config.get("max_tokens", 1000),
memory_length=chatbot_config.get("memory_length", 10),
use_rag=chatbot_config.get("use_rag", False),
rag_collection=chatbot_config.get("rag_collection"),
rag_top_k=chatbot_config.get("rag_top_k", 5),
fallback_responses=chatbot_config.get("fallback_responses", [])
)
# Generate response using internal method with empty message history
response_content, sources = await self._generate_response(
message, [], config, None, db
)
return {
"response": response_content,
"sources": sources,
"conversation_id": None,
"message_id": f"msg_{uuid.uuid4()}"
}
finally:
db.close()
except Exception as e:
logger.error(f"Chat method failed: {e}")
fallback_responses = chatbot_config.get("fallback_responses", [
"I'm sorry, I'm having trouble processing your request right now."
])
return {
"response": fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request.",
"sources": None,
"conversation_id": None,
"message_id": f"msg_{uuid.uuid4()}"
}
# Workflow Integration Methods
async def workflow_chat_step(self, context: Dict[str, Any], step_config: Dict[str, Any], db: Session) -> Dict[str, Any]:
"""Execute chatbot as a workflow step"""
message = step_config.get('message', '')
chatbot_id = step_config.get('chatbot_id')
use_rag = step_config.get('use_rag', False)
# Template substitution from context
message = self._substitute_template_variables(message, context)
request = ChatRequest(
message=message,
chatbot_id=chatbot_id,
use_rag=use_rag,
context=step_config.get('context', {})
)
# Use system user for workflow executions
response = await self.chat_completion(request, "workflow_system", db)
return {
"response": response.response,
"conversation_id": response.conversation_id,
"sources": response.sources,
"metadata": response.metadata
}
def _substitute_template_variables(self, template: str, context: Dict[str, Any]) -> str:
"""Simple template variable substitution"""
import re
def replace_var(match):
var_path = match.group(1)
try:
# Simple dot notation support: context.user.name
value = context
for part in var_path.split('.'):
value = value[part]
return str(value)
except (KeyError, TypeError):
return match.group(0) # Return original if not found
return re.sub(r'\{\{\s*([^}]+)\s*\}\}', replace_var, template)
async def _get_qdrant_collection_name(self, collection_identifier: str, db: Session) -> Optional[str]:
"""Get Qdrant collection name from RAG collection ID, name, or direct Qdrant collection"""
try:
from app.models.rag_collection import RagCollection
from sqlalchemy import select
logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")
# First check if this might be a direct Qdrant collection name
# (e.g., starts with "ext_", "rag_", or contains specific patterns)
if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier:
# Check if this collection exists in Qdrant directly
actual_collection_name = collection_identifier
# Remove "ext_" prefix if present
if collection_identifier.startswith("ext_"):
actual_collection_name = collection_identifier[4:]
logger.info(f"Checking if '{actual_collection_name}' exists in Qdrant directly")
if self.rag_module:
try:
# Try to verify the collection exists in Qdrant
from qdrant_client import QdrantClient
qdrant_client = QdrantClient(host="shifra-qdrant", port=6333)
collections = qdrant_client.get_collections()
collection_names = [c.name for c in collections.collections]
if actual_collection_name in collection_names:
logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
return actual_collection_name
except Exception as e:
logger.warning(f"Error checking Qdrant collections: {e}")
rag_collection = None
# Then try PostgreSQL lookup by ID if numeric
if collection_identifier.isdigit():
logger.info(f"Treating '{collection_identifier}' as collection ID")
stmt = select(RagCollection).where(
RagCollection.id == int(collection_identifier),
RagCollection.is_active == True
)
result = db.execute(stmt)
rag_collection = result.scalar_one_or_none()
# If not found by ID, try to look up by name in PostgreSQL
if not rag_collection:
logger.info(f"Collection not found by ID, trying by name: '{collection_identifier}'")
stmt = select(RagCollection).where(
RagCollection.name == collection_identifier,
RagCollection.is_active == True
)
result = db.execute(stmt)
rag_collection = result.scalar_one_or_none()
if rag_collection:
logger.info(f"Found RAG collection: ID={rag_collection.id}, name='{rag_collection.name}', qdrant_collection='{rag_collection.qdrant_collection_name}'")
return rag_collection.qdrant_collection_name
else:
logger.warning(f"RAG collection '{collection_identifier}' not found in database (tried both ID and name)")
return None
except Exception as e:
logger.error(f"Error looking up RAG collection '{collection_identifier}': {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
return None
# Required abstract methods from BaseModule
async def cleanup(self):
"""Cleanup chatbot module resources"""
logger.info("Chatbot module cleanup completed")
def get_required_permissions(self) -> List[Permission]:
"""Get required permissions for chatbot module"""
return [
Permission("chatbots", "create", "Create chatbot instances"),
Permission("chatbots", "configure", "Configure chatbot settings"),
Permission("chatbots", "chat", "Use chatbot for conversations"),
Permission("chatbots", "manage", "Manage all chatbots")
]
async def process_request(self, request_type: str, data: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""Process chatbot requests"""
if request_type == "chat":
# Handle chat requests
chat_request = ChatRequest(**data)
user_id = context.get("user_id", "anonymous")
db = context.get("db")
if db:
response = await self.chat_completion(chat_request, user_id, db)
return {
"success": True,
"response": response.response,
"conversation_id": response.conversation_id,
"sources": response.sources
}
return {"success": False, "error": "Database session required for chat requests"}
return {"success": False, "error": f"Unknown request type: {request_type}"}
# Module factory function
def create_module(litellm_client: Optional[LiteLLMClientProtocol] = None,
rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
"""Factory function to create chatbot module instance"""
return ChatbotModule(litellm_client=litellm_client, rag_service=rag_service)
# Create module instance (dependencies will be injected via factory)
chatbot_module = ChatbotModule()
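Usage note: a minimal sketch of driving the module above through its API-compatibility chat() method. It assumes a configured database (SessionLocal) and a reachable LiteLLM client at runtime, and the import path is an assumption.

import asyncio

from modules.chatbot.main import create_module  # hypothetical import path


async def demo() -> None:
    bot = create_module()   # LiteLLM/RAG dependencies are lazy-loaded
    await bot.initialize()
    result = await bot.chat(
        chatbot_config={
            "name": "Demo Bot",
            "chatbot_type": "assistant",
            "model": "gpt-3.5-turbo",
            "temperature": 0.2,
        },
        message="What can you help me with?",
        user_id="demo-user",
    )
    print(result["response"])


if __name__ == "__main__":
    asyncio.run(demo())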

View File

@@ -0,0 +1,146 @@
name: chatbot
version: 1.0.0
description: "AI Chatbot with RAG integration and customizable prompts"
author: "AI Gateway Team"
category: "conversation"
# Module lifecycle
enabled: true
auto_start: true
dependencies:
- rag
- workflow
optional_dependencies:
- analytics
# Configuration
config_schema: "./config_schema.json"
ui_components: "./ui_components/"
# Module capabilities
provides:
- "chat_completion"
- "conversation_management"
- "chatbot_configuration"
- "workflow_chat_step"
consumes:
- "rag_search"
- "llm_completion"
- "workflow_execution"
# API endpoints
endpoints:
- path: "/chatbot/chat"
method: "POST"
description: "Generate chat completion"
- path: "/chatbot/create"
method: "POST"
description: "Create new chatbot instance"
- path: "/chatbot/list"
method: "GET"
description: "List user chatbots"
# Workflow integration
workflow_steps:
- name: "chatbot_response"
description: "Generate chatbot response with optional RAG context"
inputs:
- name: "message"
type: "string"
required: true
description: "User message to respond to"
- name: "chatbot_id"
type: "string"
required: true
description: "ID of configured chatbot instance"
- name: "use_rag"
type: "boolean"
required: false
default: false
description: "Whether to use RAG for context"
- name: "context"
type: "object"
required: false
description: "Additional context data"
outputs:
- name: "response"
type: "string"
description: "Generated chatbot response"
- name: "conversation_id"
type: "string"
description: "Conversation ID for follow-up"
- name: "sources"
type: "array"
description: "RAG sources used (if any)"
# UI Configuration
ui_config:
icon: "message-circle"
color: "#10B981"
category: "AI & ML"
# Configuration forms
forms:
- name: "basic_config"
title: "Basic Settings"
fields: ["name", "chatbot_type", "model"]
- name: "personality"
title: "Personality & Prompts"
fields: ["system_prompt", "temperature", "fallback_responses"]
- name: "knowledge_base"
title: "Knowledge Base"
fields: ["use_rag", "rag_collection", "rag_top_k"]
- name: "advanced"
title: "Advanced Settings"
fields: ["max_tokens", "memory_length"]
# Permissions
permissions:
- name: "chatbot.create"
description: "Create new chatbot instances"
- name: "chatbot.configure"
description: "Configure chatbot settings"
- name: "chatbot.chat"
description: "Use chatbot for conversations"
- name: "chatbot.manage"
description: "Manage all chatbots (admin)"
# Analytics events
analytics_events:
- name: "chatbot_created"
description: "New chatbot instance created"
- name: "chat_message_sent"
description: "User sent message to chatbot"
- name: "chat_response_generated"
description: "Chatbot generated response"
- name: "rag_context_used"
description: "RAG context was used in response"
# Health checks
health_checks:
- name: "llm_connectivity"
description: "Check LLM client connection"
- name: "rag_availability"
description: "Check RAG module availability"
- name: "conversation_memory"
description: "Check conversation storage health"
# Documentation
documentation:
readme: "./README.md"
examples: "./examples/"
api_docs: "./docs/api.md"
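Usage note: a short sketch of reading the manifest with PyYAML; the loader used by the real module manager is not shown in this excerpt, and the file path is an assumption.

import yaml

with open("backend/modules/chatbot/module.yaml") as fh:  # hypothetical path
    manifest = yaml.safe_load(fh)

print(manifest["name"], manifest["version"])        # chatbot 1.0.0
print(manifest["dependencies"])                     # ['rag', 'workflow']
print([ep["path"] for ep in manifest["endpoints"]])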

225
backend/modules/factory.py Normal file
View File

@@ -0,0 +1,225 @@
"""
Module Factory for Confidential Empire
This factory creates and wires up all modules with their dependencies.
It ensures proper dependency injection while maintaining optimal performance
through direct method calls and minimal indirection.
"""
from typing import Dict, Optional, Any
import logging
# Import all modules
from .rag.main import RAGModule
from .chatbot.main import ChatbotModule, create_module as create_chatbot_module
from .workflow.main import WorkflowModule
# Import services that modules depend on
from app.services.litellm_client import LiteLLMClient
# Import protocols for type safety
from .protocols import (
RAGServiceProtocol,
ChatbotServiceProtocol,
LiteLLMClientProtocol,
WorkflowServiceProtocol,
ServiceRegistry
)
logger = logging.getLogger(__name__)
class ModuleFactory:
"""Factory for creating and wiring module dependencies"""
def __init__(self):
self.modules: Dict[str, Any] = {}
self.initialized = False
async def create_all_modules(self, config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
"""
Create all modules with proper dependency injection
Args:
config: Optional configuration for modules
Returns:
Dictionary of created modules with their dependencies wired
"""
config = config or {}
logger.info("Creating modules with dependency injection...")
# Step 1: Create LiteLLM client (shared dependency)
litellm_client = LiteLLMClient()
# Step 2: Create RAG module (no dependencies on other modules)
rag_module = RAGModule(config=config.get("rag", {}))
# Step 3: Create chatbot module with RAG dependency
chatbot_module = create_chatbot_module(
litellm_client=litellm_client,
rag_service=rag_module # RAG module implements RAGServiceProtocol
)
# Step 4: Create workflow module with chatbot dependency
workflow_module = WorkflowModule(
chatbot_service=chatbot_module # Chatbot module implements ChatbotServiceProtocol
)
# Store all modules
modules = {
"rag": rag_module,
"chatbot": chatbot_module,
"workflow": workflow_module
}
logger.info(f"Created {len(modules)} modules with dependencies wired")
# Initialize all modules
await self._initialize_modules(modules, config)
self.modules = modules
self.initialized = True
return modules
async def _initialize_modules(self, modules: Dict[str, Any], config: Dict[str, Any]):
"""Initialize all modules in dependency order"""
# Initialize in dependency order (modules with no deps first)
initialization_order = [
("rag", modules["rag"]),
("chatbot", modules["chatbot"]), # Depends on RAG
("workflow", modules["workflow"]) # Depends on Chatbot
]
for module_name, module in initialization_order:
try:
logger.info(f"Initializing {module_name} module...")
module_config = config.get(module_name, {})
# Different modules have different initialization patterns
if hasattr(module, 'initialize'):
if module_name == "rag":
await module.initialize()
else:
await module.initialize(**module_config)
logger.info(f"{module_name} module initialized successfully")
except Exception as e:
logger.error(f"❌ Failed to initialize {module_name} module: {e}")
raise RuntimeError(f"Module initialization failed: {module_name}") from e
async def cleanup_all_modules(self):
"""Cleanup all modules in reverse dependency order"""
if not self.initialized:
return
# Cleanup in reverse order
cleanup_order = ["workflow", "chatbot", "rag"]
for module_name in cleanup_order:
if module_name in self.modules:
try:
logger.info(f"Cleaning up {module_name} module...")
module = self.modules[module_name]
if hasattr(module, 'cleanup'):
await module.cleanup()
logger.info(f"{module_name} module cleaned up")
except Exception as e:
logger.error(f"❌ Error cleaning up {module_name}: {e}")
self.modules.clear()
self.initialized = False
def get_module(self, name: str) -> Optional[Any]:
"""Get a module by name"""
return self.modules.get(name)
def is_initialized(self) -> bool:
"""Check if factory is initialized"""
return self.initialized
# Global factory instance
module_factory = ModuleFactory()
# Convenience functions for external use
async def create_modules(config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
"""Create all modules with dependencies wired"""
return await module_factory.create_all_modules(config)
async def cleanup_modules():
"""Cleanup all modules"""
await module_factory.cleanup_all_modules()
def get_module(name: str) -> Optional[Any]:
"""Get a module by name"""
return module_factory.get_module(name)
def get_all_modules() -> Dict[str, Any]:
"""Get all modules"""
return module_factory.modules.copy()
# Factory functions for individual modules (for testing/special cases)
def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
"""Create RAG module"""
return RAGModule(config=config or {})
def create_chatbot_with_rag(rag_service: RAGServiceProtocol,
litellm_client: LiteLLMClientProtocol) -> ChatbotModule:
"""Create chatbot module with RAG dependency"""
return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)
def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> WorkflowModule:
"""Create workflow module with chatbot dependency"""
return WorkflowModule(chatbot_service=chatbot_service)
# Module registry for backward compatibility
class ModuleRegistry:
"""Registry that provides access to modules (for backward compatibility)"""
def __init__(self, factory: ModuleFactory):
self._factory = factory
@property
def modules(self) -> Dict[str, Any]:
"""Get all modules (compatible with existing module_manager interface)"""
return self._factory.modules
def get(self, name: str) -> Optional[Any]:
"""Get module by name"""
return self._factory.get_module(name)
def __getitem__(self, name: str) -> Any:
"""Support dictionary-style access"""
module = self.get(name)
if module is None:
raise KeyError(f"Module '{name}' not found")
return module
def keys(self):
"""Get module names"""
return self._factory.modules.keys()
def values(self):
"""Get module instances"""
return self._factory.modules.values()
def items(self):
"""Get module name-instance pairs"""
return self._factory.modules.items()
# Create registry instance for backward compatibility
module_registry = ModuleRegistry(module_factory)
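# Usage sketch (illustrative, not executed on import): wiring the modules from an
# application startup/shutdown hook. The hook names below are assumptions for
# demonstration; create_modules, cleanup_modules and dictionary access are the
# real API defined above, and the config keys are examples only.
#
#   from modules.factory import create_modules, cleanup_modules
#
#   async def on_startup():
#       registry = await create_modules({"rag": {}, "chatbot": {}, "workflow": {}})
#       chatbot = registry["chatbot"]  # direct reference; calls are plain methods
#
#   async def on_shutdown():
#       await cleanup_modules()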

258
backend/modules/protocols.py Normal file
View File

@@ -0,0 +1,258 @@
"""
Module Protocols for Confidential Empire
This file defines the interface contracts that modules must implement for inter-module communication.
Using typing.Protocol provides static, structural type checking (via tools such as mypy) with zero runtime overhead.
"""
from typing import Protocol, Dict, List, Any, Optional
from abc import abstractmethod
class RAGServiceProtocol(Protocol):
"""Protocol for RAG (Retrieval-Augmented Generation) service interface"""
@abstractmethod
async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
"""
Search for relevant documents
Args:
query: Search query string
collection_name: Name of the collection to search in
top_k: Number of top results to return
Returns:
Dictionary containing search results with 'results' key
"""
...
@abstractmethod
async def index_document(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
"""
Index a document in the vector database
Args:
content: Document content to index
metadata: Optional metadata for the document
Returns:
Document ID
"""
...
@abstractmethod
async def delete_document(self, document_id: str) -> bool:
"""
Delete a document from the vector database
Args:
document_id: ID of document to delete
Returns:
True if successfully deleted
"""
...
class ChatbotServiceProtocol(Protocol):
"""Protocol for Chatbot service interface"""
@abstractmethod
async def chat_completion(self, request: Any, user_id: str, db: Any) -> Any:
"""
Generate chat completion response
Args:
request: Chat request object
user_id: ID of the user making the request
db: Database session
Returns:
Chat response object
"""
...
@abstractmethod
async def create_chatbot(self, config: Any, user_id: str, db: Any) -> Any:
"""
Create a new chatbot instance
Args:
config: Chatbot configuration
user_id: ID of the user creating the chatbot
db: Database session
Returns:
Created chatbot instance
"""
...
class LiteLLMClientProtocol(Protocol):
"""Protocol for LiteLLM client interface"""
@abstractmethod
async def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
"""
Create a completion using the specified model
Args:
model: Model name to use
messages: List of messages for the conversation
**kwargs: Additional parameters for the completion
Returns:
Completion response object
"""
...
@abstractmethod
async def create_chat_completion(self, model: str, messages: List[Dict[str, str]],
user_id: str, api_key_id: str, **kwargs) -> Any:
"""
Create a chat completion with user tracking
Args:
model: Model name to use
messages: List of messages for the conversation
user_id: ID of the user making the request
api_key_id: API key identifier
**kwargs: Additional parameters
Returns:
Chat completion response
"""
...
class CacheServiceProtocol(Protocol):
"""Protocol for Cache service interface"""
@abstractmethod
async def get(self, key: str, default: Any = None) -> Any:
"""
Get value from cache
Args:
key: Cache key
default: Default value if key not found
Returns:
Cached value or default
"""
...
@abstractmethod
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds
Returns:
True if successfully cached
"""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""
Delete key from cache
Args:
key: Cache key to delete
Returns:
True if successfully deleted
"""
...
class SecurityServiceProtocol(Protocol):
"""Protocol for Security service interface"""
@abstractmethod
async def analyze_request(self, request: Any) -> Any:
"""
Perform security analysis on a request
Args:
request: Request object to analyze
Returns:
Security analysis result
"""
...
@abstractmethod
async def validate_request(self, request: Any) -> bool:
"""
Validate request for security compliance
Args:
request: Request object to validate
Returns:
True if request is valid/safe
"""
...
class WorkflowServiceProtocol(Protocol):
"""Protocol for Workflow service interface"""
@abstractmethod
async def execute_workflow(self, workflow: Any, input_data: Optional[Dict[str, Any]] = None) -> Any:
"""
Execute a workflow definition
Args:
workflow: Workflow definition to execute
input_data: Optional input data for the workflow
Returns:
Workflow execution result
"""
...
@abstractmethod
async def get_execution(self, execution_id: str) -> Any:
"""
Get workflow execution status
Args:
execution_id: ID of the execution to retrieve
Returns:
Execution status object
"""
...
class ModuleServiceProtocol(Protocol):
"""Base protocol for all module services"""
@abstractmethod
async def initialize(self, **kwargs) -> None:
"""Initialize the module"""
...
@abstractmethod
async def cleanup(self) -> None:
"""Cleanup module resources"""
...
@abstractmethod
def get_required_permissions(self) -> List[Any]:
"""Get required permissions for this module"""
...
# Type aliases for common service combinations
ServiceRegistry = Dict[str, ModuleServiceProtocol]
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]
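# Conformance sketch (illustrative): any class with matching method signatures
# satisfies a Protocol structurally, with no inheritance required; a static
# checker such as mypy validates the final assignment. The in-memory cache
# below is an assumption for demonstration, not part of the platform.
#
#   class InMemoryCache:
#       def __init__(self) -> None:
#           self._store: Dict[str, Any] = {}
#
#       async def get(self, key: str, default: Any = None) -> Any:
#           return self._store.get(key, default)
#
#       async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
#           self._store[key] = value  # ttl is ignored in this sketch
#           return True
#
#       async def delete(self, key: str) -> bool:
#           return self._store.pop(key, None) is not None
#
#   cache: CacheServiceProtocol = InMemoryCache()  # accepted: structural match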

6
backend/modules/rag/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
RAG (Retrieval-Augmented Generation) module for Confidential Empire platform
"""
from .main import RAGModule
__all__ = ["RAGModule"]

1591
backend/modules/rag/main.py Normal file

File diff suppressed because it is too large

10
backend/modules/workflow/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Workflow Module for Confidential Empire
Provides workflow orchestration capabilities for chaining multiple LLM calls,
conditional logic, and data transformations.
"""
from .main import WorkflowModule
__all__ = ["WorkflowModule"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,389 @@
{
"templates": [
{
"id": "simple_chatbot_interaction",
"name": "Simple Chatbot Interaction",
"description": "Basic workflow that processes user input through a configured chatbot",
"version": "1.0",
"variables": {
"user_message": "Hello, I need help with my account",
"chatbot_id": "customer_support_bot"
},
"steps": [
{
"id": "chatbot_response",
"name": "Get Chatbot Response",
"type": "chatbot",
"chatbot_id": "{chatbot_id}",
"message_template": "{user_message}",
"output_variable": "bot_response",
"create_new_conversation": true,
"save_conversation_id": "conversation_id"
}
],
"outputs": {
"response": "{bot_response}",
"conversation_id": "{conversation_id}"
},
"metadata": {
"created_by": "system",
"use_case": "customer_support",
"tags": ["chatbot", "simple", "customer_support"]
}
},
{
"id": "multi_turn_customer_support",
"name": "Multi-Turn Customer Support Flow",
"description": "Advanced customer support workflow with intent classification, knowledge base lookup, and escalation",
"version": "1.0",
"variables": {
"user_message": "My order hasn't arrived yet",
"customer_support_chatbot": "support_assistant",
"escalation_chatbot": "human_handoff_bot",
"rag_collection": "support_knowledge_base"
},
"steps": [
{
"id": "classify_intent",
"name": "Classify User Intent",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are an intent classifier. Classify the user's message into one of: order_inquiry, technical_support, billing, general_question, escalation_needed. Respond with only the classification."
},
{
"role": "user",
"content": "{user_message}"
}
],
"output_variable": "intent"
},
{
"id": "handle_order_inquiry",
"name": "Handle Order Inquiry",
"type": "chatbot",
"conditions": ["{intent} == 'order_inquiry'"],
"chatbot_id": "{customer_support_chatbot}",
"message_template": "Customer inquiry about order: {user_message}",
"output_variable": "support_response",
"context_variables": {
"intent": "intent",
"inquiry_type": "order_inquiry"
}
},
{
"id": "handle_technical_support",
"name": "Handle Technical Support",
"type": "chatbot",
"conditions": ["{intent} == 'technical_support'"],
"chatbot_id": "{customer_support_chatbot}",
"message_template": "Technical support request: {user_message}",
"output_variable": "support_response",
"context_variables": {
"intent": "intent",
"inquiry_type": "technical_support"
}
},
{
"id": "escalate_to_human",
"name": "Escalate to Human Agent",
"type": "chatbot",
"conditions": ["{intent} == 'escalation_needed'"],
"chatbot_id": "{escalation_chatbot}",
"message_template": "Customer needs human assistance: {user_message}",
"output_variable": "escalation_response"
},
{
"id": "general_response",
"name": "General Support Response",
"type": "chatbot",
"conditions": ["{intent} == 'general_question' or {intent} == 'billing'"],
"chatbot_id": "{customer_support_chatbot}",
"message_template": "{user_message}",
"output_variable": "support_response"
},
{
"id": "format_final_response",
"name": "Format Final Response",
"type": "transform",
"input_variable": "support_response",
"output_variable": "final_response",
"transformation": "extract:response"
}
],
"outputs": {
"intent": "{intent}",
"response": "{final_response}",
"escalation_response": "{escalation_response}"
},
"error_handling": {
"retry_failed_steps": true,
"max_retries": 2,
"fallback_response": "I apologize, but I'm experiencing technical difficulties. Please try again later or contact support directly."
},
"metadata": {
"created_by": "system",
"use_case": "customer_support",
"tags": ["chatbot", "multi_turn", "intent_classification", "escalation"]
}
},
{
"id": "research_assistant_workflow",
"name": "AI Research Assistant",
"description": "Research workflow that uses specialized chatbots for different research tasks",
"version": "1.0",
"variables": {
"research_topic": "artificial intelligence trends 2024",
"researcher_chatbot": "ai_researcher",
"analyst_chatbot": "data_analyst",
"writer_chatbot": "content_writer"
},
"steps": [
{
"id": "initial_research",
"name": "Conduct Initial Research",
"type": "chatbot",
"chatbot_id": "{researcher_chatbot}",
"message_template": "Please research the following topic and provide key findings: {research_topic}",
"output_variable": "research_findings",
"create_new_conversation": true,
"save_conversation_id": "research_conversation"
},
{
"id": "analyze_findings",
"name": "Analyze Research Findings",
"type": "chatbot",
"chatbot_id": "{analyst_chatbot}",
"message_template": "Please analyze these research findings and identify key trends and insights: {research_findings}",
"output_variable": "analysis_results",
"create_new_conversation": true
},
{
"id": "create_summary",
"name": "Create Executive Summary",
"type": "chatbot",
"chatbot_id": "{writer_chatbot}",
"message_template": "Create an executive summary based on this research and analysis:\n\nTopic: {research_topic}\nResearch: {research_findings}\nAnalysis: {analysis_results}",
"output_variable": "executive_summary"
},
{
"id": "follow_up_questions",
"name": "Generate Follow-up Questions",
"type": "chatbot",
"chatbot_id": "{researcher_chatbot}",
"message_template": "Based on this research on {research_topic}, what are 5 important follow-up questions that should be investigated further?",
"output_variable": "follow_up_questions",
"conversation_id": "research_conversation"
}
],
"outputs": {
"research_findings": "{research_findings}",
"analysis": "{analysis_results}",
"summary": "{executive_summary}",
"next_steps": "{follow_up_questions}",
"conversation_id": "{research_conversation}"
},
"metadata": {
"created_by": "system",
"use_case": "research_automation",
"tags": ["chatbot", "research", "analysis", "multi_agent"]
}
},
{
"id": "content_creation_pipeline",
"name": "AI Content Creation Pipeline",
"description": "Multi-stage content creation using different specialized chatbots",
"version": "1.0",
"variables": {
"content_brief": "Write a blog post about sustainable technology innovations",
"target_audience": "tech-savvy professionals",
"content_length": "1500 words",
"research_bot": "researcher_assistant",
"writer_bot": "creative_writer",
"editor_bot": "content_editor"
},
"steps": [
{
"id": "research_phase",
"name": "Research Content Topic",
"type": "chatbot",
"chatbot_id": "{research_bot}",
"message_template": "Research this content brief: {content_brief}. Target audience: {target_audience}. Provide key points, statistics, and current trends.",
"output_variable": "research_data",
"create_new_conversation": true,
"save_conversation_id": "content_conversation"
},
{
"id": "create_outline",
"name": "Create Content Outline",
"type": "chatbot",
"chatbot_id": "{writer_bot}",
"message_template": "Create a detailed outline for: {content_brief}\nTarget audience: {target_audience}\nLength: {content_length}\nResearch data: {research_data}",
"output_variable": "content_outline"
},
{
"id": "write_content",
"name": "Write First Draft",
"type": "chatbot",
"chatbot_id": "{writer_bot}",
"message_template": "Write the full content based on this outline: {content_outline}\nBrief: {content_brief}\nResearch: {research_data}\nTarget length: {content_length}",
"output_variable": "first_draft"
},
{
"id": "edit_content",
"name": "Edit and Polish Content",
"type": "chatbot",
"chatbot_id": "{editor_bot}",
"message_template": "Please edit and improve this content for clarity, engagement, and professional tone:\n\n{first_draft}",
"output_variable": "final_content"
},
{
"id": "generate_metadata",
"name": "Generate SEO Metadata",
"type": "parallel",
"steps": [
{
"id": "create_title_options",
"name": "Generate Title Options",
"type": "chatbot",
"chatbot_id": "{writer_bot}",
"message_template": "Generate 5 compelling SEO-optimized titles for this content: {final_content}",
"output_variable": "title_options"
},
{
"id": "create_meta_description",
"name": "Create Meta Description",
"type": "chatbot",
"chatbot_id": "{editor_bot}",
"message_template": "Create an SEO-optimized meta description (150-160 characters) for this content: {final_content}",
"output_variable": "meta_description"
}
]
}
],
"outputs": {
"research": "{research_data}",
"outline": "{content_outline}",
"draft": "{first_draft}",
"final_content": "{final_content}",
"titles": "{title_options}",
"meta_description": "{meta_description}",
"conversation_id": "{content_conversation}"
},
"metadata": {
"created_by": "system",
"use_case": "content_creation",
"tags": ["chatbot", "content", "writing", "parallel", "seo"]
}
},
{
"id": "demo_chatbot_workflow",
"name": "Demo Interactive Chatbot",
"description": "A demonstration workflow showcasing interactive chatbot capabilities with multi-turn conversation, context awareness, and intelligent response handling",
"version": "1.0.0",
"steps": [
{
"id": "welcome_interaction",
"name": "Welcome User",
"type": "chatbot",
"chatbot_id": "demo_assistant",
"message_template": "Hello! I'm a demo AI assistant. What can I help you with today? Feel free to ask me about anything - technology, general questions, or just have a conversation!",
"output_variable": "welcome_response",
"create_new_conversation": true,
"save_conversation_id": "demo_conversation_id",
"enabled": true
},
{
"id": "analyze_user_intent",
"name": "Analyze User Intent",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "Analyze the user's response and classify their intent. Categories: question, casual_chat, technical_help, information_request, creative_task, other. Respond with only the category name."
},
{
"role": "user",
"content": "{user_input}"
}
],
"output_variable": "user_intent",
"parameters": {
"temperature": 0.3,
"max_tokens": 50
},
"enabled": true
},
{
"id": "personalized_response",
"name": "Generate Personalized Response",
"type": "chatbot",
"chatbot_id": "demo_assistant",
"message_template": "User intent: {user_intent}. User message: {user_input}. Please provide a helpful, engaging response tailored to their specific need.",
"output_variable": "personalized_response",
"conversation_id": "demo_conversation_id",
"context_variables": {
"intent": "user_intent",
"previous_welcome": "welcome_response"
},
"enabled": true
},
{
"id": "follow_up_suggestions",
"name": "Generate Follow-up Suggestions",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "Based on the conversation, suggest 2-3 relevant follow-up questions or topics the user might be interested in. Format as a simple list."
},
{
"role": "user",
"content": "User intent: {user_intent}, Response given: {personalized_response}"
}
],
"output_variable": "follow_up_suggestions",
"parameters": {
"temperature": 0.7,
"max_tokens": 150
},
"enabled": true
},
{
"id": "conversation_summary",
"name": "Create Conversation Summary",
"type": "transform",
"input_variable": "personalized_response",
"output_variable": "conversation_summary",
"transformation": "extract:content",
"enabled": true
}
],
"variables": {
"user_input": "I'm interested in learning about artificial intelligence and how it's changing the world",
"demo_assistant": "assistant"
},
"metadata": {
"created_by": "demo_system",
"use_case": "demonstration",
"tags": ["demo", "chatbot", "interactive", "multi_turn"],
"demo_instructions": {
"description": "This workflow demonstrates key chatbot capabilities including conversation continuity, intent analysis, personalized responses, and follow-up suggestions.",
"usage": "Execute this workflow with different user inputs to see how the chatbot adapts its responses based on intent analysis and conversation context.",
"features": [
"Multi-turn conversation with persistent conversation ID",
"Intent classification for tailored responses",
"Context-aware personalized interactions",
"Automatic follow-up suggestions",
"Conversation summarization"
]
}
},
"timeout": 300
}
]
}

79
backend/requirements.txt Normal file
View File

@@ -0,0 +1,79 @@
# Core framework
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.4.2
pydantic-settings==2.0.3
# Database
sqlalchemy==2.0.23
alembic==1.12.1
psycopg2-binary==2.9.9
asyncpg==0.29.0
# Redis
redis==5.0.1
aioredis==2.0.1
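# NOTE: aioredis is deprecated upstream; redis>=4.2 already ships asyncio support (redis.asyncio).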
# Authentication & Security
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.0.1
python-multipart==0.0.6
cryptography==41.0.7
itsdangerous==2.1.2
# HTTP Client
httpx==0.25.2
aiohttp==3.9.0
# Background tasks
celery==5.3.4
flower==2.0.1
# Validation & Serialization
email-validator==2.1.0
python-dateutil==2.8.2
jsonschema==4.19.2
# Logging & Monitoring
structlog==23.2.0
prometheus-client==0.19.0
opentelemetry-api==1.21.0
opentelemetry-sdk==1.21.0
psutil==5.9.6
# Vector Database
qdrant-client==1.7.0
# Text Processing
tiktoken==0.5.1
# NLP and Content Processing (required for RAG module with integrated content processing)
nltk==3.8.1
spacy==3.7.2
markitdown==0.0.1a2
python-docx==1.1.0
sentence-transformers==2.6.1
# Optional heavy ML dependencies (commented out for lighter deployments)
# transformers==4.35.2
# Configuration
pyyaml==6.0.1
python-dotenv==1.0.0
# Module System
watchdog==3.0.0
click==8.2.1
# Development
pytest==7.4.3
pytest-asyncio==0.21.1
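aiosqlite==0.19.0  # assumed pin: async SQLite driver required by the sqlite+aiosqlite URL in tests/conftest.py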
pytest-cov==4.1.0
black==23.11.0
isort==5.12.0
flake8==6.1.0
mypy==1.7.1
# API Documentation
swagger-ui-bundle==0.0.9

3
backend/tests/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
Test suite for Confidential Empire platform.
"""

View File

@@ -0,0 +1,132 @@
"""
Test LLM API endpoints.
"""
import pytest
from httpx import AsyncClient
from unittest.mock import patch  # patch() substitutes an AsyncMock for async targets (Python 3.8+)
class TestLLMEndpoints:
"""Test LLM API endpoints."""
@pytest.mark.asyncio
async def test_chat_completion_success(self, client: AsyncClient):
"""Test successful chat completion."""
# Mock the LiteLLM client response
mock_response = {
"choices": [
{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
}
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 15,
"total_tokens": 25
}
}
with patch("app.services.litellm_client.LiteLLMClient.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "choices" in data
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?"
@pytest.mark.asyncio
async def test_chat_completion_unauthorized(self, client: AsyncClient):
"""Test chat completion without API key."""
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
}
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_embeddings_success(self, client: AsyncClient):
"""Test successful embeddings generation."""
mock_response = {
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
}
],
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}
with patch("app.services.litellm_client.LiteLLMClient.create_embedding") as mock_embeddings:
mock_embeddings.return_value = mock_response
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "text-embedding-ada-002",
"input": "Hello world"
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert len(data["data"][0]["embedding"]) == 3
@pytest.mark.asyncio
async def test_budget_exceeded(self, client: AsyncClient):
"""Test budget exceeded scenario."""
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_check:
mock_check.side_effect = Exception("Budget exceeded")
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 402 # Payment required
@pytest.mark.asyncio
async def test_model_validation(self, client: AsyncClient):
"""Test model validation."""
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "invalid-model",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 400

View File

@@ -0,0 +1,137 @@
"""
Test TEE API endpoints.
"""
import pytest
from httpx import AsyncClient
from unittest.mock import patch  # patch() substitutes an AsyncMock for async targets (Python 3.8+)
class TestTEEEndpoints:
"""Test TEE API endpoints."""
@pytest.mark.asyncio
async def test_tee_health_check(self, client: AsyncClient):
"""Test TEE health check endpoint."""
mock_health = {
"status": "healthy",
"timestamp": "2024-01-01T00:00:00Z",
"version": "1.0.0"
}
with patch("app.services.tee_service.TEEService.health_check") as mock_check:
mock_check.return_value = mock_health
response = await client.get(
"/api/v1/tee/health",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
@pytest.mark.asyncio
async def test_tee_capabilities(self, client: AsyncClient):
"""Test TEE capabilities endpoint."""
mock_capabilities = {
"hardware_security": True,
"encryption_at_rest": True,
"memory_protection": True,
"supported_models": ["gpt-3.5-turbo", "claude-3-haiku"]
}
with patch("app.services.tee_service.TEEService.get_tee_capabilities") as mock_caps:
mock_caps.return_value = mock_capabilities
response = await client.get(
"/api/v1/tee/capabilities",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert data["hardware_security"] is True
assert "supported_models" in data
@pytest.mark.asyncio
async def test_tee_attestation(self, client: AsyncClient):
"""Test TEE attestation endpoint."""
mock_attestation = {
"attestation_document": "base64_encoded_document",
"signature": "signature_data",
"timestamp": "2024-01-01T00:00:00Z",
"valid": True
}
with patch("app.services.tee_service.TEEService.get_attestation") as mock_att:
mock_att.return_value = mock_attestation
response = await client.get(
"/api/v1/tee/attestation",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert data["valid"] is True
assert "attestation_document" in data
@pytest.mark.asyncio
async def test_tee_session_creation(self, client: AsyncClient):
"""Test TEE secure session creation."""
mock_session = {
"session_id": "secure-session-123",
"public_key": "public_key_data",
"expires_at": "2024-01-01T01:00:00Z"
}
with patch("app.services.tee_service.TEEService.create_secure_session") as mock_session_create:
mock_session_create.return_value = mock_session
response = await client.post(
"/api/v1/tee/session",
json={"model": "gpt-3.5-turbo"},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "session_id" in data
assert "public_key" in data
@pytest.mark.asyncio
async def test_tee_metrics(self, client: AsyncClient):
"""Test TEE metrics endpoint."""
mock_metrics = {
"total_requests": 1000,
"successful_requests": 995,
"failed_requests": 5,
"avg_response_time": 0.125,
"privacy_score": 95.8,
"security_level": "high"
}
with patch("app.services.tee_service.TEEService.get_privacy_metrics") as mock_metrics_get:
mock_metrics_get.return_value = mock_metrics
response = await client.get(
"/api/v1/tee/metrics",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert data["privacy_score"] == 95.8
assert data["security_level"] == "high"
@pytest.mark.asyncio
async def test_tee_unauthorized(self, client: AsyncClient):
"""Test TEE endpoints without authentication."""
response = await client.get("/api/v1/tee/health")
assert response.status_code == 401
response = await client.get("/api/v1/tee/capabilities")
assert response.status_code == 401
response = await client.get("/api/v1/tee/attestation")
assert response.status_code == 401

95
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,95 @@
"""
Pytest configuration and fixtures for testing.
"""
import pytest
import asyncio
from httpx import AsyncClient
from sqlalchemy.pool import StaticPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from app.main import app
from app.db.database import get_db, Base
# Test database URL
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def test_engine():
"""Create test database engine."""
engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
# Cleanup
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture
async def test_db(test_engine):
"""Create test database session."""
async_session = sessionmaker(
test_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session() as session:
yield session
@pytest.fixture
async def client(test_db):
"""Create test client."""
async def override_get_db():
yield test_db
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(app=app, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
@pytest.fixture
def test_user_data():
"""Test user data."""
return {
"email": "test@example.com",
"username": "testuser",
"full_name": "Test User",
"password": "testpassword123"
}
@pytest.fixture
def test_api_key_data():
"""Test API key data."""
return {
"name": "Test API Key",
"scopes": ["llm.chat", "llm.embeddings"],
"budget_limit": 100.0,
"budget_period": "monthly"
}
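# Example consumer of the fixtures above (illustrative; mirrors the real tests):
#
#   @pytest.mark.asyncio
#   async def test_health(client):
#       response = await client.get("/health")
#       assert response.status_code == 200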

View File

@@ -0,0 +1,795 @@
#!/usr/bin/env python3
"""
Comprehensive Platform Integration Test
Tests all major platform functionality including:
- User authentication
- API key creation and management
- Budget enforcement
- LLM API (OpenAI compatible via LiteLLM)
- RAG system with real documents
- Ollama integration
- Module system
"""
import asyncio
import aiohttp
import json
import logging
import sys
import time
from typing import Dict, Any, Optional, List
from datetime import datetime
from pathlib import Path
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
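# Usage sketch: run this file directly once the stack is up. The URLs below are
# the script defaults; pass backend_url/frontend_url to PlatformTester to
# override. The process exits 0 when every test passes, 1 otherwise (see main()).
#
#   python tests/test_comprehensive_platform.py   # path is an assumption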
class PlatformTester:
"""Comprehensive platform integration tester"""
def __init__(self,
backend_url: str = "http://localhost:58000",
frontend_url: str = "http://localhost:53000"):
self.backend_url = backend_url
self.frontend_url = frontend_url
self.session: Optional[aiohttp.ClientSession] = None
# Test data storage
self.user_data = {}
self.api_keys = []
self.budgets = []
self.collections = []
self.documents = []
# Test results
self.results = {
"passed": 0,
"failed": 0,
"tests": []
}
async def __aenter__(self):
timeout = aiohttp.ClientTimeout(total=120)
self.session = aiohttp.ClientSession(timeout=timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
def log_test(self, test_name: str, success: bool, details: str = "", data: Optional[Dict] = None):
"""Log test result"""
status = "PASS" if success else "FAIL"
logger.info(f"[{status}] {test_name}: {details}")
self.results["tests"].append({
"name": test_name,
"success": success,
"details": details,
"data": data or {},
"timestamp": datetime.utcnow().isoformat()
})
if success:
self.results["passed"] += 1
else:
self.results["failed"] += 1
async def test_platform_health(self):
"""Test 1: Platform health and availability"""
logger.info("=" * 60)
logger.info("TEST 1: Platform Health Check")
logger.info("=" * 60)
try:
# Test backend health
async with self.session.get(f"{self.backend_url}/health") as response:
if response.status == 200:
health_data = await response.json()
self.log_test("Backend Health", True, f"Status: {health_data.get('status')}", health_data)
else:
self.log_test("Backend Health", False, f"HTTP {response.status}")
return False
# Test frontend availability
try:
async with self.session.get(f"{self.frontend_url}") as response:
if response.status == 200:
self.log_test("Frontend Availability", True, "Frontend accessible")
else:
self.log_test("Frontend Availability", False, f"HTTP {response.status}")
except Exception as e:
self.log_test("Frontend Availability", False, f"Connection error: {e}")
# Test API documentation
async with self.session.get(f"{self.backend_url}/api/v1/docs") as response:
if response.status == 200:
self.log_test("API Documentation", True, "Swagger UI accessible")
else:
self.log_test("API Documentation", False, f"HTTP {response.status}")
return True
except Exception as e:
self.log_test("Platform Health", False, f"Connection error: {e}")
return False
async def test_user_authentication(self):
"""Test 2: User registration and authentication"""
logger.info("=" * 60)
logger.info("TEST 2: User Authentication")
logger.info("=" * 60)
try:
# Create unique test user
timestamp = int(time.time())
test_email = f"test_{timestamp}@platform-test.com"
test_password = "TestPassword123!"
test_username = f"test_user_{timestamp}"
# Register user
register_data = {
"email": test_email,
"password": test_password,
"username": test_username
}
async with self.session.post(
f"{self.backend_url}/api/v1/auth/register",
json=register_data
) as response:
if response.status == 201:
user_data = await response.json()
self.user_data = user_data
self.log_test("User Registration", True, f"User created: {user_data.get('email')}", user_data)
else:
error_data = await response.json()
self.log_test("User Registration", False, f"HTTP {response.status}: {error_data}")
return False
# Login user
login_data = {
"email": test_email,
"password": test_password
}
async with self.session.post(
f"{self.backend_url}/api/v1/auth/login",
json=login_data
) as response:
if response.status == 200:
login_response = await response.json()
self.user_data["access_token"] = login_response["access_token"]
self.user_data["refresh_token"] = login_response["refresh_token"]
self.log_test("User Login", True, "Authentication successful", {"token_type": login_response.get("token_type")})
else:
error_data = await response.json()
self.log_test("User Login", False, f"HTTP {response.status}: {error_data}")
return False
# Test token verification
headers = {"Authorization": f"Bearer {self.user_data['access_token']}"}
async with self.session.get(f"{self.backend_url}/api/v1/auth/me", headers=headers) as response:
if response.status == 200:
user_info = await response.json()
self.log_test("Token Verification", True, f"User info retrieved: {user_info.get('email')}", user_info)
else:
error_data = await response.json()
self.log_test("Token Verification", False, f"HTTP {response.status}: {error_data}")
return False
return True
except Exception as e:
self.log_test("User Authentication", False, f"Error: {e}")
return False
async def test_api_key_management(self):
"""Test 3: API key creation and management"""
logger.info("=" * 60)
logger.info("TEST 3: API Key Management")
logger.info("=" * 60)
if not self.user_data.get("access_token"):
self.log_test("API Key Management", False, "No access token available")
return False
try:
headers = {"Authorization": f"Bearer {self.user_data['access_token']}"}
# Create API key
api_key_data = {
"name": "Test API Key",
"description": "API key for comprehensive platform testing",
"scopes": ["chat.completions", "embeddings.create", "models.list"],
"expires_in_days": 30
}
async with self.session.post(
f"{self.backend_url}/api/v1/api-keys",
json=api_key_data,
headers=headers
) as response:
if response.status == 201:
api_key_response = await response.json()
self.api_keys.append(api_key_response)
self.log_test("API Key Creation", True, f"Key created: {api_key_response.get('name')}", {
"key_id": api_key_response.get("id"),
"key_prefix": api_key_response.get("key_prefix")
})
else:
error_data = await response.json()
self.log_test("API Key Creation", False, f"HTTP {response.status}: {error_data}")
return False
# List API keys
async with self.session.get(f"{self.backend_url}/api/v1/api-keys", headers=headers) as response:
if response.status == 200:
keys_list = await response.json()
self.log_test("API Key Listing", True, f"Found {len(keys_list)} keys", {"count": len(keys_list)})
else:
error_data = await response.json()
self.log_test("API Key Listing", False, f"HTTP {response.status}: {error_data}")
return True
except Exception as e:
self.log_test("API Key Management", False, f"Error: {e}")
return False
async def test_budget_system(self):
"""Test 4: Budget creation and enforcement"""
logger.info("=" * 60)
logger.info("TEST 4: Budget System")
logger.info("=" * 60)
if not self.user_data.get("access_token") or not self.api_keys:
self.log_test("Budget System", False, "Prerequisites not met")
return False
try:
headers = {"Authorization": f"Bearer {self.user_data['access_token']}"}
api_key_id = self.api_keys[0].get("id")
# Create budget
budget_data = {
"name": "Test Budget",
"description": "Budget for comprehensive testing",
"api_key_id": api_key_id,
"budget_type": "monthly",
"limit_cents": 10000, # $100.00
"alert_thresholds": [50, 80, 95]
}
async with self.session.post(
f"{self.backend_url}/api/v1/budgets",
json=budget_data,
headers=headers
) as response:
if response.status == 201:
budget_response = await response.json()
self.budgets.append(budget_response)
self.log_test("Budget Creation", True, f"Budget created: {budget_response.get('name')}", {
"budget_id": budget_response.get("id"),
"limit": budget_response.get("limit_cents")
})
else:
error_data = await response.json()
self.log_test("Budget Creation", False, f"HTTP {response.status}: {error_data}")
return False
# Get budget status
async with self.session.get(f"{self.backend_url}/api/v1/llm/budget/status", headers=headers) as response:
if response.status == 200:
budget_status = await response.json()
self.log_test("Budget Status Check", True, "Budget status retrieved", budget_status)
else:
error_data = await response.json()
self.log_test("Budget Status Check", False, f"HTTP {response.status}: {error_data}")
return True
except Exception as e:
self.log_test("Budget System", False, f"Error: {e}")
return False
async def test_llm_integration(self):
"""Test 5: LLM API (OpenAI compatible via LiteLLM)"""
logger.info("=" * 60)
logger.info("TEST 5: LLM Integration (OpenAI Compatible)")
logger.info("=" * 60)
if not self.api_keys:
self.log_test("LLM Integration", False, "No API key available")
return False
try:
# Use API key for authentication
api_key = self.api_keys[0].get("api_key", "")
if not api_key:
# Try to use the token as fallback
api_key = self.user_data.get("access_token", "")
headers = {"Authorization": f"Bearer {api_key}"}
# Test 1: List available models
async with self.session.get(f"{self.backend_url}/api/v1/llm/models", headers=headers) as response:
if response.status == 200:
models_data = await response.json()
model_count = len(models_data.get("models", []))
self.log_test("List Models", True, f"Found {model_count} models", {"model_count": model_count})
else:
error_data = await response.json()
self.log_test("List Models", False, f"HTTP {response.status}: {error_data}")
# Test 2: Chat completion (joke request as specified)
chat_data = {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Tell me a programming joke"}
],
"max_tokens": 150,
"temperature": 0.7
}
async with self.session.post(
f"{self.backend_url}/api/v1/llm/chat/completions",
json=chat_data,
headers=headers
) as response:
if response.status == 200:
chat_response = await response.json()
joke = chat_response.get("choices", [{}])[0].get("message", {}).get("content", "")
tokens_used = chat_response.get("usage", {}).get("total_tokens", 0)
self.log_test("Chat Completion", True, f"Joke received ({tokens_used} tokens)", {
"joke_preview": joke[:100] + "..." if len(joke) > 100 else joke,
"tokens_used": tokens_used
})
else:
error_data = await response.json()
self.log_test("Chat Completion", False, f"HTTP {response.status}: {error_data}")
# Test 3: Embeddings
embedding_data = {
"model": "text-embedding-ada-002",
"input": "This is a test sentence for embedding generation."
}
async with self.session.post(
f"{self.backend_url}/api/v1/llm/embeddings",
json=embedding_data,
headers=headers
) as response:
if response.status == 200:
embedding_response = await response.json()
embedding_dim = len(embedding_response.get("data", [{}])[0].get("embedding", []))
self.log_test("Embeddings", True, f"Embedding generated ({embedding_dim} dimensions)", {
"dimension": embedding_dim,
"tokens_used": embedding_response.get("usage", {}).get("total_tokens", 0)
})
else:
error_data = await response.json()
self.log_test("Embeddings", False, f"HTTP {response.status}: {error_data}")
return True
except Exception as e:
self.log_test("LLM Integration", False, f"Error: {e}")
return False
async def test_ollama_integration(self):
"""Test 6: Ollama integration"""
logger.info("=" * 60)
logger.info("TEST 6: Ollama Integration")
logger.info("=" * 60)
try:
# Test Ollama proxy health
ollama_url = "http://localhost:11434"
try:
async with self.session.get(f"{ollama_url}/api/tags") as response:
if response.status == 200:
models_data = await response.json()
model_count = len(models_data.get("models", []))
self.log_test("Ollama Connection", True, f"Connected to Ollama ({model_count} models)", {
"model_count": model_count,
"models": [m.get("name") for m in models_data.get("models", [])][:5]
})
# Test Ollama chat if models are available
if model_count > 0:
model_name = models_data["models"][0]["name"]
chat_data = {
"model": model_name,
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"stream": False
}
try:
async with self.session.post(
f"{ollama_url}/api/chat",
json=chat_data,
timeout=aiohttp.ClientTimeout(total=30)
) as chat_response:
if chat_response.status == 200:
chat_result = await chat_response.json()
response_text = chat_result.get("message", {}).get("content", "")
self.log_test("Ollama Chat", True, f"Response from {model_name}", {
"model": model_name,
"response_preview": response_text[:100] + "..." if len(response_text) > 100 else response_text
})
else:
self.log_test("Ollama Chat", False, f"HTTP {chat_response.status}")
except asyncio.TimeoutError:
self.log_test("Ollama Chat", False, "Timeout - model may be loading")
else:
self.log_test("Ollama Models", False, "No models available in Ollama")
else:
self.log_test("Ollama Connection", False, f"HTTP {response.status}")
except Exception as e:
self.log_test("Ollama Connection", False, f"Connection error: {e}")
return True
except Exception as e:
self.log_test("Ollama Integration", False, f"Error: {e}")
return False
async def test_rag_system(self):
"""Test 7: RAG system with real document processing"""
logger.info("=" * 60)
logger.info("TEST 7: RAG System")
logger.info("=" * 60)
if not self.user_data.get("access_token"):
self.log_test("RAG System", False, "No access token available")
return False
try:
headers = {"Authorization": f"Bearer {self.user_data['access_token']}"}
# Create test collection
collection_data = {
"name": f"Test Collection {int(time.time())}",
"description": "Comprehensive test collection for RAG functionality"
}
async with self.session.post(
f"{self.backend_url}/api/v1/rag/collections",
json=collection_data,
headers=headers
) as response:
if response.status == 200:
collection_response = await response.json()
collection = collection_response.get("collection", {})
self.collections.append(collection)
self.log_test("RAG Collection Creation", True, f"Collection created: {collection.get('name')}", {
"collection_id": collection.get("id"),
"name": collection.get("name")
})
else:
error_data = await response.json()
self.log_test("RAG Collection Creation", False, f"HTTP {response.status}: {error_data}")
return False
# Create test document for upload
test_content = f"""# Test Document for RAG System
This is a comprehensive test document created at {datetime.utcnow().isoformat()}.
## Introduction
This document contains various types of content to test the RAG system's ability to:
- Extract and process text content
- Generate meaningful embeddings
- Index content for search and retrieval
## Technical Details
The RAG system should be able to process this document and make it searchable.
Key capabilities include:
- Document chunking and processing
- Vector embedding generation
- Semantic search functionality
- Content retrieval and ranking
## Testing Scenarios
This document will be used to test:
1. Document upload and processing
2. Content extraction and conversion
3. Vector generation and indexing
4. Search and retrieval accuracy
## Keywords for Search Testing
artificial intelligence, machine learning, natural language processing,
vector database, semantic search, document processing, text analysis
"""
# Upload document
collection_id = self.collections[-1]["id"]
# Create form data
form_data = aiohttp.FormData()
form_data.add_field('collection_id', str(collection_id))
form_data.add_field('file', test_content.encode(),
filename='test_document.txt',
content_type='text/plain')
async with self.session.post(
f"{self.backend_url}/api/v1/rag/documents",
data=form_data,
headers=headers
) as response:
if response.status == 200:
document_response = await response.json()
document = document_response.get("document", {})
self.documents.append(document)
self.log_test("RAG Document Upload", True, f"Document uploaded: {document.get('filename')}", {
"document_id": document.get("id"),
"filename": document.get("filename"),
"size": document.get("size")
})
else:
error_data = await response.json()
self.log_test("RAG Document Upload", False, f"HTTP {response.status}: {error_data}")
return False
# Wait for document processing (check status multiple times)
document_id = self.documents[-1]["id"]
processing_complete = False
for attempt in range(30): # Wait up to 60 seconds
await asyncio.sleep(2)
async with self.session.get(
f"{self.backend_url}/api/v1/rag/documents/{document_id}",
headers=headers
) as response:
if response.status == 200:
doc_status = await response.json()
document_info = doc_status.get("document", {})
status = document_info.get("status", "unknown")
if status in ["processed", "indexed"]:
processing_complete = True
word_count = document_info.get("word_count", 0)
vector_count = document_info.get("vector_count", 0)
self.log_test("RAG Document Processing", True, f"Processing complete: {status}", {
"status": status,
"word_count": word_count,
"vector_count": vector_count,
"processing_time": f"{(attempt + 1) * 2} seconds"
})
break
elif status == "error":
error_msg = document_info.get("processing_error", "Unknown error")
self.log_test("RAG Document Processing", False, f"Processing failed: {error_msg}")
break
elif attempt == 29:
self.log_test("RAG Document Processing", False, "Processing timeout after 60 seconds")
break
# Test document search (if processing completed)
if processing_complete:
search_query = "artificial intelligence machine learning"
# Note: Search endpoint might not be implemented yet, so we'll test what's available
async with self.session.get(
f"{self.backend_url}/api/v1/rag/stats",
headers=headers
) as response:
if response.status == 200:
rag_stats = await response.json()
stats = rag_stats.get("stats", {})
self.log_test("RAG Statistics", True, "RAG stats retrieved", stats)
else:
error_data = await response.json()
self.log_test("RAG Statistics", False, f"HTTP {response.status}: {error_data}")
return True
except Exception as e:
self.log_test("RAG System", False, f"Error: {e}")
return False
async def test_module_system(self):
"""Test 8: Module system functionality"""
logger.info("=" * 60)
logger.info("TEST 8: Module System")
logger.info("=" * 60)
try:
# Test modules status
async with self.session.get(f"{self.backend_url}/api/v1/modules/status") as response:
if response.status == 200:
modules_status = await response.json()
enabled_count = len([m for m in modules_status if m.get("enabled")])
total_count = len(modules_status)
self.log_test("Module System Status", True, f"{enabled_count}/{total_count} modules enabled", {
"enabled_modules": enabled_count,
"total_modules": total_count,
"modules": [m.get("name") for m in modules_status if m.get("enabled")]
})
else:
error_data = await response.json()
self.log_test("Module System Status", False, f"HTTP {response.status}: {error_data}")
# Test individual module info
test_modules = ["rag", "content", "cache"]
for module_name in test_modules:
async with self.session.get(f"{self.backend_url}/api/v1/modules/{module_name}") as response:
if response.status == 200:
module_info = await response.json()
self.log_test(f"Module Info ({module_name})", True, f"Module info retrieved", {
"module": module_name,
"enabled": module_info.get("enabled"),
"version": module_info.get("version")
})
else:
self.log_test(f"Module Info ({module_name})", False, f"HTTP {response.status}")
return True
except Exception as e:
self.log_test("Module System", False, f"Error: {e}")
return False
async def cleanup_test_data(self):
"""Cleanup test data created during testing"""
logger.info("=" * 60)
logger.info("CLEANUP: Removing test data")
logger.info("=" * 60)
if not self.user_data.get("access_token"):
return
headers = {"Authorization": f"Bearer {self.user_data['access_token']}"}
try:
# Delete documents
for document in self.documents:
doc_id = document.get("id")
try:
async with self.session.delete(
f"{self.backend_url}/api/v1/rag/documents/{doc_id}",
headers=headers
) as response:
if response.status == 200:
logger.info(f"Deleted document {doc_id}")
else:
logger.warning(f"Failed to delete document {doc_id}: HTTP {response.status}")
except Exception as e:
logger.warning(f"Error deleting document {doc_id}: {e}")
# Delete collections
for collection in self.collections:
collection_id = collection.get("id")
try:
async with self.session.delete(
f"{self.backend_url}/api/v1/rag/collections/{collection_id}",
headers=headers
) as response:
if response.status == 200:
logger.info(f"Deleted collection {collection_id}")
else:
logger.warning(f"Failed to delete collection {collection_id}: HTTP {response.status}")
except Exception as e:
logger.warning(f"Error deleting collection {collection_id}: {e}")
# Delete budgets
for budget in self.budgets:
budget_id = budget.get("id")
try:
async with self.session.delete(
f"{self.backend_url}/api/v1/budgets/{budget_id}",
headers=headers
) as response:
if response.status == 200:
logger.info(f"Deleted budget {budget_id}")
else:
logger.warning(f"Failed to delete budget {budget_id}: HTTP {response.status}")
except Exception as e:
logger.warning(f"Error deleting budget {budget_id}: {e}")
# Delete API keys
for api_key in self.api_keys:
key_id = api_key.get("id")
try:
async with self.session.delete(
f"{self.backend_url}/api/v1/api-keys/{key_id}",
headers=headers
) as response:
if response.status == 200:
logger.info(f"Deleted API key {key_id}")
else:
logger.warning(f"Failed to delete API key {key_id}: HTTP {response.status}")
except Exception as e:
logger.warning(f"Error deleting API key {key_id}: {e}")
logger.info("Cleanup completed")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
async def run_all_tests(self):
"""Run all platform tests"""
logger.info("🚀 Starting Comprehensive Platform Integration Tests")
logger.info("=" * 80)
start_time = time.time()
# Run all tests in sequence
tests = [
self.test_platform_health,
self.test_user_authentication,
self.test_api_key_management,
self.test_budget_system,
self.test_llm_integration,
self.test_ollama_integration,
self.test_rag_system,
self.test_module_system
]
for test_func in tests:
try:
await test_func()
except Exception as e:
logger.error(f"Unexpected error in {test_func.__name__}: {e}")
self.log_test(test_func.__name__, False, f"Unexpected error: {e}")
# Brief pause between tests
await asyncio.sleep(1)
# Cleanup
await self.cleanup_test_data()
# Print final results
end_time = time.time()
duration = end_time - start_time
logger.info("=" * 80)
logger.info("COMPREHENSIVE PLATFORM TEST RESULTS")
logger.info("=" * 80)
logger.info(f"Total Tests: {self.results['passed'] + self.results['failed']}")
logger.info(f"Passed: {self.results['passed']}")
logger.info(f"Failed: {self.results['failed']}")
logger.info(f"Success Rate: {(self.results['passed'] / (self.results['passed'] + self.results['failed']) * 100):.1f}%")
logger.info(f"Duration: {duration:.2f} seconds")
if self.results['failed'] > 0:
logger.info("\nFailed Tests:")
for test in self.results['tests']:
if not test['success']:
logger.info(f" - {test['name']}: {test['details']}")
logger.info("=" * 80)
# Save detailed results to file
results_file = f"platform_test_results_{int(time.time())}.json"
with open(results_file, 'w') as f:
json.dump(self.results, f, indent=2)
logger.info(f"Detailed results saved to: {results_file}")
# Return success/failure
return self.results['failed'] == 0
async def main():
"""Main test runner"""
try:
async with PlatformTester() as tester:
success = await tester.run_all_tests()
return 0 if success else 1
except KeyboardInterrupt:
logger.info("\nTest interrupted by user")
return 1
except Exception as e:
logger.error(f"Test runner error: {e}")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)

Some files were not shown because too many files have changed in this diff