mega changes

2025-11-20 11:11:18 +01:00
parent e070c95190
commit 841d79f26b
138 changed files with 21499 additions and 8844 deletions

.gitignore
View File

@@ -66,3 +66,16 @@ frontend/node_modules/
node_modules/
venv/
venv_memory_monitor/
to_delete/
security-review.md
features-plan.md
AGENTS.md
CLAUDE.md
backend/CURL_VERIFICATION_EXAMPLES.md
backend/OPENAI_COMPATIBILITY_GUIDE.md
backend/production-deployment-guide.md
features-plan.md
security-review.md
backend/.env.local
backend/.env.test

View File

@@ -1,12 +1,17 @@
import asyncio
import os
from logging.config import fileConfig
from dotenv import load_dotenv
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
# Load environment variables from .env file
load_dotenv()
load_dotenv('.env.local')
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
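
One thing worth knowing about the stacked load_dotenv() calls above: python-dotenv does not overwrite variables that are already set, so for keys present in both files the values from .env win over .env.local unless override=True is passed. A minimal sketch, assuming .env.local is meant to take precedence:

from dotenv import load_dotenv

load_dotenv()  # loads .env; sets only variables not already in the environment
load_dotenv('.env.local', override=True)  # lets .env.local win for shared keys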

View File

@@ -4,7 +4,7 @@ This migration represents the complete, accurate database schema based on the ac
model files in the codebase. All legacy migrations have been consolidated into this
single migration to ensure the database matches what the models expect.
Revision ID: 000_consolidated_ground_truth_schema
Revision ID: 000_ground_truth
Revises:
Create Date: 2025-08-22 10:30:00.000000

View File

@@ -0,0 +1,62 @@
"""add roles table
Revision ID: 001_add_roles_table
Revises: 000_ground_truth
Create Date: 2025-01-30 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '001_add_roles_table'
down_revision = '000_ground_truth'
branch_labels = None
depends_on = None
def upgrade():
# Create roles table
op.create_table(
'roles',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('display_name', sa.String(length=100), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('level', sa.String(length=20), nullable=False),
sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column('can_manage_users', sa.Boolean(), nullable=True, default=False),
sa.Column('can_manage_budgets', sa.Boolean(), nullable=True, default=False),
sa.Column('can_view_reports', sa.Boolean(), nullable=True, default=False),
sa.Column('can_manage_tools', sa.Boolean(), nullable=True, default=False),
sa.Column('inherits_from', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('is_system_role', sa.Boolean(), nullable=True, default=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for roles
op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
op.create_index(op.f('ix_roles_level'), 'roles', ['level'], unique=False)
# Add role_id to users table
op.add_column('users', sa.Column('role_id', sa.Integer(), nullable=True))
op.create_foreign_key(
'fk_users_role_id', 'users', 'roles',
['role_id'], ['id'], ondelete='SET NULL'
)
op.create_index('ix_users_role_id', 'users', ['role_id'])
def downgrade():
# Remove role_id from users
op.drop_index('ix_users_role_id', table_name='users')
op.drop_constraint('fk_users_role_id', table_name='users', type_='foreignkey')
op.drop_column('users', 'role_id')
# Drop roles table
op.drop_index(op.f('ix_roles_level'), table_name='roles')
op.drop_index(op.f('ix_roles_name'), table_name='roles')
op.drop_index(op.f('ix_roles_id'), table_name='roles')
op.drop_table('roles')
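
A side note on the boolean columns above: in SQLAlchemy/Alembic, default= is a client-side Python default that never reaches the database schema, so rows inserted via raw SQL or other clients still get NULL. A hedged sketch of the database-enforced alternative (not part of this commit):

import sqlalchemy as sa

# server_default is emitted into the DDL, so every insert path gets FALSE,
# not just inserts made through the ORM
sa.Column('can_manage_users', sa.Boolean(), nullable=False, server_default=sa.false())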

View File

@@ -0,0 +1,118 @@
"""add tools tables
Revision ID: 002_add_tools_tables
Revises: 001_add_roles_table
Create Date: 2025-01-30 00:00:01.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '002_add_tools_tables'
down_revision = '001_add_roles_table'
branch_labels = None
depends_on = None
def upgrade():
# Create tool_categories table
op.create_table(
'tool_categories',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('display_name', sa.String(length=100), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('icon', sa.String(length=50), nullable=True),
sa.Column('color', sa.String(length=20), nullable=True),
sa.Column('sort_order', sa.Integer(), nullable=True, default=0),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tool_categories_id'), 'tool_categories', ['id'], unique=False)
op.create_index(op.f('ix_tool_categories_name'), 'tool_categories', ['name'], unique=True)
# Create tools table
op.create_table(
'tools',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('tool_type', sa.String(length=20), nullable=False),
sa.Column('code', sa.Text(), nullable=False),
sa.Column('parameters_schema', sa.JSON(), nullable=True),
sa.Column('return_schema', sa.JSON(), nullable=True),
sa.Column('timeout_seconds', sa.Integer(), nullable=True, default=30),
sa.Column('max_memory_mb', sa.Integer(), nullable=True, default=256),
sa.Column('max_cpu_seconds', sa.Float(), nullable=True, default=10.0),
sa.Column('docker_image', sa.String(length=200), nullable=True),
sa.Column('docker_command', sa.Text(), nullable=True),
sa.Column('is_public', sa.Boolean(), nullable=True, default=False),
sa.Column('is_approved', sa.Boolean(), nullable=True, default=False),
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
sa.Column('category', sa.String(length=50), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('usage_count', sa.Integer(), nullable=True, default=0),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tools_id'), 'tools', ['id'], unique=False)
op.create_index(op.f('ix_tools_name'), 'tools', ['name'], unique=False)
op.create_foreign_key(
'fk_tools_created_by_user_id', 'tools', 'users',
['created_by_user_id'], ['id'], ondelete='CASCADE'
)
# Create tool_executions table
op.create_table(
'tool_executions',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('tool_id', sa.Integer(), nullable=False),
sa.Column('executed_by_user_id', sa.Integer(), nullable=False),
sa.Column('parameters', sa.JSON(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=False, default='pending'),
sa.Column('output', sa.Text(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('return_code', sa.Integer(), nullable=True),
sa.Column('execution_time_ms', sa.Integer(), nullable=True),
sa.Column('memory_used_mb', sa.Float(), nullable=True),
sa.Column('cpu_time_ms', sa.Integer(), nullable=True),
sa.Column('container_id', sa.String(length=100), nullable=True),
sa.Column('docker_logs', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tool_executions_id'), 'tool_executions', ['id'], unique=False)
op.create_foreign_key(
'fk_tool_executions_tool_id', 'tool_executions', 'tools',
['tool_id'], ['id'], ondelete='CASCADE'
)
op.create_foreign_key(
'fk_tool_executions_executed_by_user_id', 'tool_executions', 'users',
['executed_by_user_id'], ['id'], ondelete='CASCADE'
)
def downgrade():
# Drop tool_executions table
op.drop_constraint('fk_tool_executions_executed_by_user_id', table_name='tool_executions', type_='foreignkey')
op.drop_constraint('fk_tool_executions_tool_id', table_name='tool_executions', type_='foreignkey')
op.drop_index(op.f('ix_tool_executions_id'), table_name='tool_executions')
op.drop_table('tool_executions')
# Drop tools table
op.drop_constraint('fk_tools_created_by_user_id', table_name='tools', type_='foreignkey')
op.drop_index(op.f('ix_tools_name'), table_name='tools')
op.drop_index(op.f('ix_tools_id'), table_name='tools')
op.drop_table('tools')
# Drop tool_categories table
op.drop_index(op.f('ix_tool_categories_name'), table_name='tool_categories')
op.drop_index(op.f('ix_tool_categories_id'), table_name='tool_categories')
op.drop_table('tool_categories')
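
For orientation, a hypothetical ORM model matching a subset of the new tools table (class and attribute names are assumptions; the commit's actual model files are not shown in this excerpt):

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Tool(Base):
    __tablename__ = "tools"

    id = sa.Column(sa.Integer, primary_key=True, index=True)
    name = sa.Column(sa.String(100), nullable=False, index=True)
    tool_type = sa.Column(sa.String(20), nullable=False)
    code = sa.Column(sa.Text, nullable=False)
    # sandbox limits mirrored from the migration defaults
    timeout_seconds = sa.Column(sa.Integer, default=30)
    max_memory_mb = sa.Column(sa.Integer, default=256)
    created_by_user_id = sa.Column(
        sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False
    )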

View File

@@ -0,0 +1,132 @@
"""add notifications tables
Revision ID: 003_add_notifications_tables
Revises: 002_add_tools_tables
Create Date: 2025-01-30 00:00:02.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '003_add_notifications_tables'
down_revision = '002_add_tools_tables'
branch_labels = None
depends_on = None
def upgrade():
# Create notification_templates table
op.create_table(
'notification_templates',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('notification_type', sa.String(length=20), nullable=False),
sa.Column('subject_template', sa.Text(), nullable=True),
sa.Column('body_template', sa.Text(), nullable=False),
sa.Column('html_template', sa.Text(), nullable=True),
sa.Column('default_priority', sa.String(length=20), nullable=True, default='normal'),
sa.Column('variables', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notification_templates_id'), 'notification_templates', ['id'], unique=False)
op.create_index(op.f('ix_notification_templates_name'), 'notification_templates', ['name'], unique=True)
# Create notification_channels table
op.create_table(
'notification_channels',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('notification_type', sa.String(length=20), nullable=False),
sa.Column('config', sa.JSON(), nullable=False),
sa.Column('credentials', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('is_default', sa.Boolean(), nullable=True, default=False),
sa.Column('rate_limit', sa.Integer(), nullable=True, default=100),
sa.Column('retry_count', sa.Integer(), nullable=True, default=3),
sa.Column('retry_delay_minutes', sa.Integer(), nullable=True, default=5),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('success_count', sa.Integer(), nullable=True, default=0),
sa.Column('failure_count', sa.Integer(), nullable=True, default=0),
sa.Column('last_error', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notification_channels_id'), 'notification_channels', ['id'], unique=False)
op.create_index(op.f('ix_notification_channels_name'), 'notification_channels', ['name'], unique=False)
# Create notifications table
op.create_table(
'notifications',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('subject', sa.String(length=500), nullable=True),
sa.Column('body', sa.Text(), nullable=False),
sa.Column('html_body', sa.Text(), nullable=True),
sa.Column('recipients', sa.JSON(), nullable=False),
sa.Column('cc_recipients', sa.JSON(), nullable=True),
sa.Column('bcc_recipients', sa.JSON(), nullable=True),
sa.Column('priority', sa.String(length=20), nullable=True, default='normal'),
sa.Column('scheduled_at', sa.DateTime(), nullable=True),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('template_id', sa.Integer(), nullable=True),
sa.Column('channel_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=True, default='pending'),
sa.Column('attempts', sa.Integer(), nullable=True, default=0),
sa.Column('max_attempts', sa.Integer(), nullable=True, default=3),
sa.Column('sent_at', sa.DateTime(), nullable=True),
sa.Column('delivered_at', sa.DateTime(), nullable=True),
sa.Column('failed_at', sa.DateTime(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('external_id', sa.String(length=200), nullable=True),
sa.Column('callback_url', sa.String(length=500), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notifications_id'), 'notifications', ['id'], unique=False)
op.create_index(op.f('ix_notifications_status'), 'notifications', ['status'], unique=False)
op.create_index(op.f('ix_notifications_scheduled_at'), 'notifications', ['scheduled_at'], unique=False)
# Add foreign key constraints
op.create_foreign_key(
'fk_notifications_template_id', 'notifications', 'notification_templates',
['template_id'], ['id'], ondelete='SET NULL'
)
op.create_foreign_key(
'fk_notifications_channel_id', 'notifications', 'notification_channels',
['channel_id'], ['id'], ondelete='CASCADE'
)
op.create_foreign_key(
'fk_notifications_user_id', 'notifications', 'users',
['user_id'], ['id'], ondelete='SET NULL'
)
def downgrade():
# Drop notifications table
op.drop_constraint('fk_notifications_user_id', table_name='notifications', type_='foreignkey')
op.drop_constraint('fk_notifications_channel_id', table_name='notifications', type_='foreignkey')
op.drop_constraint('fk_notifications_template_id', table_name='notifications', type_='foreignkey')
op.drop_index(op.f('ix_notifications_scheduled_at'), table_name='notifications')
op.drop_index(op.f('ix_notifications_status'), table_name='notifications')
op.drop_index(op.f('ix_notifications_id'), table_name='notifications')
op.drop_table('notifications')
# Drop notification_channels table
op.drop_index(op.f('ix_notification_channels_name'), table_name='notification_channels')
op.drop_index(op.f('ix_notification_channels_id'), table_name='notification_channels')
op.drop_table('notification_channels')
# Drop notification_templates table
op.drop_index(op.f('ix_notification_templates_name'), table_name='notification_templates')
op.drop_index(op.f('ix_notification_templates_id'), table_name='notification_templates')
op.drop_table('notification_templates')

View File

@@ -0,0 +1,26 @@
"""Add force_password_change to users
Revision ID: 004_add_force_password_change
Revises: fd999a559a35
Create Date: 2025-01-31 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '004_add_force_password_change'
down_revision = 'fd999a559a35'
branch_labels = None
depends_on = None
def upgrade():
# Add force_password_change column to users table
op.add_column('users', sa.Column('force_password_change', sa.Boolean(), default=False, nullable=False, server_default='false'))
def downgrade():
# Remove force_password_change column from users table
op.drop_column('users', 'force_password_change')

View File

@@ -0,0 +1,79 @@
"""fix user nullable columns
Revision ID: 005_fix_user_nullable_columns
Revises: 004_add_force_password_change
Create Date: 2025-11-20 08:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "005_fix_user_nullable_columns"
down_revision = "004_add_force_password_change"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""
Fix nullable columns in users table:
- Backfill NULL values for account_locked and failed_login_attempts
- Set proper server defaults
- Alter columns to NOT NULL
"""
# Use connection to execute raw SQL for backfilling
conn = op.get_bind()
# Backfill NULL values for account_locked
conn.execute(
sa.text("UPDATE users SET account_locked = FALSE WHERE account_locked IS NULL")
)
# Backfill NULL values for failed_login_attempts
conn.execute(
sa.text("UPDATE users SET failed_login_attempts = 0 WHERE failed_login_attempts IS NULL")
)
# Backfill NULL values for custom_permissions (use empty JSON object)
conn.execute(
sa.text("UPDATE users SET custom_permissions = '{}' WHERE custom_permissions IS NULL")
)
# Now alter columns to NOT NULL with server defaults
# Note: PostgreSQL syntax
op.alter_column('users', 'account_locked',
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false())
op.alter_column('users', 'failed_login_attempts',
existing_type=sa.Integer(),
nullable=False,
server_default='0')
op.alter_column('users', 'custom_permissions',
existing_type=sa.JSON(),
nullable=False,
server_default='{}')
def downgrade() -> None:
"""
Revert columns to nullable (original state from fd999a559a35)
"""
op.alter_column('users', 'account_locked',
existing_type=sa.Boolean(),
nullable=True,
server_default=None)
op.alter_column('users', 'failed_login_attempts',
existing_type=sa.Integer(),
nullable=True,
server_default=None)
op.alter_column('users', 'custom_permissions',
existing_type=sa.JSON(),
nullable=True,
server_default=None)
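
The backfill-before-ALTER ordering in upgrade() is what keeps the NOT NULL change from failing on existing rows. A hypothetical guard one could place between the two steps (not in the commit; assumes the same conn = op.get_bind() as above):

remaining = conn.execute(
    sa.text("SELECT COUNT(*) FROM users WHERE account_locked IS NULL")
).scalar()
if remaining:
    raise RuntimeError(f"{remaining} rows still NULL; aborting NOT NULL change")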

View File

@@ -0,0 +1,71 @@
"""fix missing user columns
Revision ID: fd999a559a35
Revises: 003_add_notifications_tables
Create Date: 2025-10-30 11:33:42.236622
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fd999a559a35"
down_revision = "003_add_notifications_tables"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add missing columns to users table
# These columns should have been added in 001_add_roles_table.py but were not
# Use try/except to handle cases where columns might already exist
try:
op.add_column("users", sa.Column("custom_permissions", sa.JSON(), nullable=True, default=dict))
except Exception:
pass # Column might already exist
try:
op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True, default=False))
except Exception:
pass
try:
op.add_column("users", sa.Column("account_locked_until", sa.DateTime(), nullable=True))
except Exception:
pass
try:
op.add_column("users", sa.Column("failed_login_attempts", sa.Integer(), nullable=True, default=0))
except Exception:
pass
try:
op.add_column("users", sa.Column("last_failed_login", sa.DateTime(), nullable=True))
except Exception:
pass
def downgrade() -> None:
# Remove the columns
try:
op.drop_column("users", "last_failed_login")
except Exception:
pass
try:
op.drop_column("users", "failed_login_attempts")
except Exception:
pass
try:
op.drop_column("users", "account_locked_until")
except Exception:
pass
try:
op.drop_column("users", "account_locked")
except Exception:
pass
try:
op.drop_column("users", "custom_permissions")
except Exception:
pass
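
The blanket try/except makes this migration re-runnable, but it also swallows unrelated failures, and on PostgreSQL a failed DDL statement aborts the surrounding transaction anyway. A minimal sketch of an explicit existence check instead, assuming the same Alembic context:

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    conn = op.get_bind()
    # ask the database which columns already exist instead of guessing
    existing = {col["name"] for col in sa.inspect(conn).get_columns("users")}
    if "account_locked" not in existing:
        op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True))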

View File

@@ -4,4 +4,4 @@ Enclava - Modular AI Platform
__version__ = "1.0.0"
__author__ = "Enclava Team"
__description__ = "Enclava - Modular AI Platform with confidential processing"
__description__ = "Enclava - Modular AI Platform with confidential processing"

View File

@@ -1,3 +1,3 @@
"""
API package
"""
"""

View File

@@ -44,7 +44,9 @@ class HealthChecker:
await session.execute(select(1))
# Check table availability
await session.execute(text("SELECT COUNT(*) FROM information_schema.tables"))
await session.execute(
text("SELECT COUNT(*) FROM information_schema.tables")
)
duration = time.time() - start_time
@@ -54,8 +56,8 @@ class HealthChecker:
"timestamp": datetime.utcnow().isoformat(),
"details": {
"connection": "successful",
"query_execution": "successful"
}
"query_execution": "successful",
},
}
except Exception as e:
@@ -64,10 +66,7 @@ class HealthChecker:
"status": "unhealthy",
"error": str(e),
"timestamp": datetime.utcnow().isoformat(),
"details": {
"connection": "failed",
"error_type": type(e).__name__
}
"details": {"connection": "failed", "error_type": type(e).__name__},
}
async def check_memory_health(self) -> Dict[str, Any]:
@@ -107,7 +106,7 @@ class HealthChecker:
"process_memory_mb": round(process_memory_mb, 2),
"system_memory_percent": memory.percent,
"system_available_gb": round(memory.available / (1024**3), 2),
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -115,7 +114,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_connection_health(self) -> Dict[str, Any]:
@@ -128,8 +127,16 @@ class HealthChecker:
# Analyze connections
total_connections = len(connections)
established_connections = len([c for c in connections if c.status == 'ESTABLISHED'])
http_connections = len([c for c in connections if any(port in str(c.laddr) for port in [80, 8000, 3000])])
established_connections = len(
[c for c in connections if c.status == "ESTABLISHED"]
)
http_connections = len(
[
c
for c in connections
if any(port in str(c.laddr) for port in [80, 8000, 3000])
]
)
# Check for connection issues
connection_status = "healthy"
@@ -154,7 +161,7 @@ class HealthChecker:
"total_connections": total_connections,
"established_connections": established_connections,
"http_connections": http_connections,
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -162,7 +169,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_embedding_service_health(self) -> Dict[str, Any]:
@@ -193,7 +200,7 @@ class HealthChecker:
"response_time_ms": round(duration * 1000, 2),
"timestamp": datetime.utcnow().isoformat(),
"stats": stats,
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -201,7 +208,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_redis_health(self) -> Dict[str, Any]:
@@ -209,7 +216,7 @@ class HealthChecker:
if not settings.REDIS_URL:
return {
"status": "not_configured",
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
try:
@@ -233,7 +240,7 @@ class HealthChecker:
return {
"status": "healthy",
"response_time_ms": round(duration * 1000, 2),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
except Exception as e:
@@ -241,7 +248,7 @@ class HealthChecker:
return {
"status": "unhealthy",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def get_comprehensive_health(self) -> Dict[str, Any]:
@@ -251,7 +258,7 @@ class HealthChecker:
"memory": await self.check_memory_health(),
"connections": await self.check_connection_health(),
"embedding_service": await self.check_embedding_service_health(),
"redis": await self.check_redis_health()
"redis": await self.check_redis_health(),
}
# Determine overall status
@@ -274,12 +281,16 @@ class HealthChecker:
"summary": {
"total_checks": len(checks),
"healthy_checks": len([s for s in statuses if s == "healthy"]),
"degraded_checks": len([s for s in statuses if s in ["warning", "degraded", "unhealthy"]]),
"failed_checks": len([s for s in statuses if s in ["critical", "error"]]),
"total_issues": total_issues
"degraded_checks": len(
[s for s in statuses if s in ["warning", "degraded", "unhealthy"]]
),
"failed_checks": len(
[s for s in statuses if s in ["critical", "error"]]
),
"total_issues": total_issues,
},
"version": "1.0.0",
"uptime_seconds": int(time.time() - psutil.boot_time())
"uptime_seconds": int(time.time() - psutil.boot_time()),
}
@@ -294,7 +305,7 @@ async def basic_health_check():
"status": "healthy",
"app": settings.APP_NAME,
"version": "1.0.0",
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
@@ -307,7 +318,7 @@ async def detailed_health_check():
logger.error(f"Detailed health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Health check failed: {str(e)}"
detail=f"Health check failed: {str(e)}",
)
@@ -320,7 +331,7 @@ async def memory_health_check():
logger.error(f"Memory health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Memory health check failed: {str(e)}"
detail=f"Memory health check failed: {str(e)}",
)
@@ -333,7 +344,7 @@ async def connection_health_check():
logger.error(f"Connection health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Connection health check failed: {str(e)}"
detail=f"Connection health check failed: {str(e)}",
)
@@ -346,5 +357,5 @@ async def embedding_service_health_check():
logger.error(f"Embedding service health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Embedding service health check failed: {str(e)}"
)
detail=f"Embedding service health check failed: {str(e)}",
)
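
One caution about the connection metrics earlier in this file: any(port in str(c.laddr) for port in [80, 8000, 3000]) tests an int against a string, which raises TypeError at runtime; the reformat preserved the bug. A corrected variant, assuming psutil's usual addr(ip, port) tuples for laddr:

http_connections = len(
    # laddr can be an empty tuple, so guard before reading .port
    [c for c in connections if c.laddr and c.laddr.port in (80, 8000, 3000)]
)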

View File

@@ -19,6 +19,7 @@ from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
from ..v1.chatbot import router as chatbot_router
from .debugging import router as debugging_router
from ..v1.endpoints.user_management import router as user_management_router
# Create internal API router
internal_api_router = APIRouter()
@@ -27,47 +28,82 @@ internal_api_router = APIRouter()
internal_api_router.include_router(auth_router, prefix="/auth", tags=["internal-auth"])
# Include modules routes (frontend management)
internal_api_router.include_router(modules_router, prefix="/modules", tags=["internal-modules"])
internal_api_router.include_router(
modules_router, prefix="/modules", tags=["internal-modules"]
)
# Include platform routes (frontend platform management)
internal_api_router.include_router(platform_router, prefix="/platform", tags=["internal-platform"])
# Include platform routes (frontend platform management)
internal_api_router.include_router(
platform_router, prefix="/platform", tags=["internal-platform"]
)
# Include user management routes (frontend user admin)
internal_api_router.include_router(users_router, prefix="/users", tags=["internal-users"])
internal_api_router.include_router(
users_router, prefix="/users", tags=["internal-users"]
)
# Include API key management routes (frontend API key management)
internal_api_router.include_router(api_keys_router, prefix="/api-keys", tags=["internal-api-keys"])
internal_api_router.include_router(
api_keys_router, prefix="/api-keys", tags=["internal-api-keys"]
)
# Include budget management routes (frontend budget management)
internal_api_router.include_router(budgets_router, prefix="/budgets", tags=["internal-budgets"])
internal_api_router.include_router(
budgets_router, prefix="/budgets", tags=["internal-budgets"]
)
# Include audit log routes (frontend audit viewing)
internal_api_router.include_router(audit_router, prefix="/audit", tags=["internal-audit"])
internal_api_router.include_router(
audit_router, prefix="/audit", tags=["internal-audit"]
)
# Include settings management routes (frontend settings)
internal_api_router.include_router(settings_router, prefix="/settings", tags=["internal-settings"])
internal_api_router.include_router(
settings_router, prefix="/settings", tags=["internal-settings"]
)
# Include analytics routes (frontend analytics viewing)
internal_api_router.include_router(analytics_router, prefix="/analytics", tags=["internal-analytics"])
internal_api_router.include_router(
analytics_router, prefix="/analytics", tags=["internal-analytics"]
)
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
internal_api_router.include_router(
rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"]
)
# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
internal_api_router.include_router(
prompt_templates_router,
prefix="/prompt-templates",
tags=["internal-prompt-templates"],
)
# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
internal_api_router.include_router(
plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]
)
# Include internal LLM routes (frontend LLM service access with JWT auth)
internal_api_router.include_router(llm_internal_router, prefix="/llm", tags=["internal-llm"])
internal_api_router.include_router(
llm_internal_router, prefix="/llm", tags=["internal-llm"]
)
# Include chatbot routes (frontend chatbot management)
internal_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["internal-chatbot"])
internal_api_router.include_router(
chatbot_router, prefix="/chatbot", tags=["internal-chatbot"]
)
# Include debugging routes (troubleshooting and diagnostics)
internal_api_router.include_router(debugging_router, prefix="/debugging", tags=["internal-debugging"])
internal_api_router.include_router(
debugging_router, prefix="/debugging", tags=["internal-debugging"]
)
# Include user management routes (advanced user and role management)
internal_api_router.include_router(
user_management_router, prefix="/user-management", tags=["internal-user-management"]
)

View File

@@ -20,23 +20,28 @@ router = APIRouter()
async def get_chatbot_config_debug(
chatbot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get detailed configuration for debugging a specific chatbot"""
# Get chatbot instance
chatbot = db.query(ChatbotInstance).filter(
ChatbotInstance.id == chatbot_id,
ChatbotInstance.user_id == current_user.id
).first()
chatbot = (
db.query(ChatbotInstance)
.filter(
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
)
.first()
)
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
# Get prompt template
prompt_template = db.query(PromptTemplate).filter(
PromptTemplate.type == chatbot.chatbot_type
).first()
prompt_template = (
db.query(PromptTemplate)
.filter(PromptTemplate.type == chatbot.chatbot_type)
.first()
)
# Get RAG collections if configured
rag_collections = []
@@ -44,31 +49,37 @@ async def get_chatbot_config_debug(
collection_ids = chatbot.rag_collection_ids
if isinstance(collection_ids, str):
import json
try:
collection_ids = json.loads(collection_ids)
except:
collection_ids = []
if collection_ids:
collections = db.query(RagCollection).filter(
RagCollection.id.in_(collection_ids)
).all()
collections = (
db.query(RagCollection)
.filter(RagCollection.id.in_(collection_ids))
.all()
)
rag_collections = [
{
"id": col.id,
"name": col.name,
"document_count": col.document_count,
"qdrant_collection_name": col.qdrant_collection_name,
"is_active": col.is_active
"is_active": col.is_active,
}
for col in collections
]
# Get recent conversations count
from app.models.chatbot import ChatbotConversation
conversation_count = db.query(ChatbotConversation).filter(
ChatbotConversation.chatbot_instance_id == chatbot_id
).count()
conversation_count = (
db.query(ChatbotConversation)
.filter(ChatbotConversation.chatbot_instance_id == chatbot_id)
.count()
)
return {
"chatbot": {
@@ -78,20 +89,20 @@ async def get_chatbot_config_debug(
"description": chatbot.description,
"created_at": chatbot.created_at,
"is_active": chatbot.is_active,
"conversation_count": conversation_count
"conversation_count": conversation_count,
},
"prompt_template": {
"type": prompt_template.type if prompt_template else None,
"system_prompt": prompt_template.system_prompt if prompt_template else None,
"variables": prompt_template.variables if prompt_template else []
"variables": prompt_template.variables if prompt_template else [],
},
"rag_collections": rag_collections,
"configuration": {
"max_tokens": chatbot.max_tokens,
"temperature": chatbot.temperature,
"streaming": chatbot.streaming,
"memory_config": chatbot.memory_config
}
"memory_config": chatbot.memory_config,
},
}
@@ -101,15 +112,18 @@ async def test_rag_search(
query: str = "test query",
top_k: int = 5,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Test RAG search for a specific chatbot"""
# Get chatbot instance
chatbot = db.query(ChatbotInstance).filter(
ChatbotInstance.id == chatbot_id,
ChatbotInstance.user_id == current_user.id
).first()
chatbot = (
db.query(ChatbotInstance)
.filter(
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
)
.first()
)
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
@@ -123,6 +137,7 @@ async def test_rag_search(
if chatbot.rag_collection_ids:
if isinstance(chatbot.rag_collection_ids, str):
import json
try:
collection_ids = json.loads(chatbot.rag_collection_ids)
except:
@@ -134,22 +149,19 @@ async def test_rag_search(
return {
"query": query,
"results": [],
"message": "No RAG collections configured for this chatbot"
"message": "No RAG collections configured for this chatbot",
}
# Perform search
search_results = await rag_module.search(
query=query,
collection_ids=collection_ids,
top_k=top_k,
score_threshold=0.5
query=query, collection_ids=collection_ids, top_k=top_k, score_threshold=0.5
)
return {
"query": query,
"results": search_results,
"collections_searched": collection_ids,
"result_count": len(search_results)
"result_count": len(search_results),
}
except Exception as e:
@@ -157,14 +169,13 @@ async def test_rag_search(
"query": query,
"results": [],
"error": str(e),
"message": "RAG search failed"
"message": "RAG search failed",
}
@router.get("/system/status")
async def get_system_status(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get system status for debugging"""
@@ -179,11 +190,12 @@ async def get_system_status(
module_status = {}
try:
from app.services.module_manager import module_manager
modules = module_manager.list_modules()
for module_name, module_info in modules.items():
module_status[module_name] = {
"status": module_info.get("status", "unknown"),
"enabled": module_info.get("enabled", False)
"enabled": module_info.get("enabled", False),
}
except Exception as e:
module_status = {"error": str(e)}
@@ -192,6 +204,7 @@ async def get_system_status(
redis_status = "not configured"
try:
from app.core.cache import core_cache
await core_cache.ping()
redis_status = "healthy"
except Exception as e:
@@ -201,6 +214,7 @@ async def get_system_status(
qdrant_status = "not configured"
try:
from app.services.qdrant_service import qdrant_service
collections = await qdrant_service.list_collections()
qdrant_status = f"healthy ({len(collections)} collections)"
except Exception as e:
@@ -211,5 +225,5 @@ async def get_system_status(
"modules": module_status,
"redis": redis_status,
"qdrant": qdrant_status,
"timestamp": "UTC"
}
"timestamp": "UTC",
}

View File

@@ -3,6 +3,7 @@ Public API v1 package - for external clients
"""
from fastapi import APIRouter
from ..v1.auth import router as auth_router
from ..v1.llm import router as llm_router
from ..v1.chatbot import router as chatbot_router
from ..v1.openai_compat import router as openai_router
@@ -10,6 +11,9 @@ from ..v1.openai_compat import router as openai_router
# Create public API router
public_api_router = APIRouter()
# Include authentication routes (needed for login/logout)
public_api_router.include_router(auth_router, prefix="/auth", tags=["authentication"])
# Include OpenAI-compatible routes (chat/completions, models, embeddings)
public_api_router.include_router(openai_router, tags=["openai-compat"])
@@ -17,4 +21,6 @@ public_api_router.include_router(openai_router, tags=["openai-compat"])
public_api_router.include_router(llm_router, prefix="/llm", tags=["public-llm"])
# Include public chatbot API (external chatbot integrations)
public_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["public-chatbot"])
public_api_router.include_router(
chatbot_router, prefix="/chatbot", tags=["public-chatbot"]
)

View File

@@ -16,10 +16,9 @@ logger = logging.getLogger(__name__)
# Create router
router = APIRouter()
@router.get("/collections")
async def list_collections(
current_user: User = Depends(get_current_user)
):
async def list_collections(current_user: User = Depends(get_current_user)):
"""List all available RAG collections"""
try:
from app.services.qdrant_stats_service import qdrant_stats_service
@@ -31,10 +30,7 @@ async def list_collections(
# Extract collection names
collection_names = [col["name"] for col in collections]
return {
"collections": collection_names,
"count": len(collection_names)
}
return {"collections": collection_names, "count": len(collection_names)}
except Exception as e:
logger.error(f"List collections error: {e}")
@@ -45,10 +41,14 @@ async def list_collections(
async def debug_search(
query: str = Query(..., description="Search query"),
max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
collection_name: Optional[str] = Query(None, description="Collection name to search"),
score_threshold: float = Query(
0.3, ge=0.0, le=1.0, description="Minimum score threshold"
),
collection_name: Optional[str] = Query(
None, description="Collection name to search"
),
config: Optional[Dict[str, Any]] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Debug search endpoint with detailed information"""
try:
@@ -56,9 +56,7 @@ async def debug_search(
app_config = settings
# Initialize RAG module with BGE-M3 configuration
rag_config = {
"embedding_model": "BAAI/bge-m3"
}
rag_config = {"embedding_model": "BAAI/bge-m3"}
rag_module = RAGModule(app_config, config=rag_config)
# Get available collections if none specified
@@ -71,9 +69,9 @@ async def debug_search(
"results": [],
"debug_info": {
"error": "No collections available",
"collections_found": 0
"collections_found": 0,
},
"search_time_ms": 0
"search_time_ms": 0,
}
# Perform search
@@ -82,7 +80,7 @@ async def debug_search(
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name,
config=config or {}
config=config or {},
)
return results
@@ -94,7 +92,7 @@ async def debug_search(
"debug_info": {
"error": str(e),
"query": query,
"collection_name": collection_name
"collection_name": collection_name,
},
"search_time_ms": 0
}
"search_time_ms": 0,
}

View File

@@ -17,6 +17,9 @@ from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .plugin_registry import router as plugin_registry_router
from .endpoints.tools import router as tools_router
from .endpoints.tool_calling import router as tool_calling_router
from .endpoints.user_management import router as user_management_router
# Create main API router
api_router = APIRouter()
@@ -58,9 +61,23 @@ api_router.include_router(rag_router, prefix="/rag", tags=["rag"])
api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
api_router.include_router(
prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]
)
# Include plugin registry routes
api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
# Include tool management routes
api_router.include_router(tools_router, prefix="/tools", tags=["tools"])
# Include tool calling routes
api_router.include_router(
tool_calling_router, prefix="/tool-calling", tags=["tool-calling"]
)
# Include admin user management routes
api_router.include_router(
user_management_router, prefix="/admin/user-management", tags=["admin", "user-management"]
)

View File

@@ -22,110 +22,103 @@ router = APIRouter()
async def get_usage_metrics(
hours: int = Query(24, ge=1, le=168, description="Hours to analyze (1-168)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get comprehensive usage metrics including costs and budgets"""
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours, user_id=current_user['id'])
return {
"success": True,
"data": metrics,
"period_hours": hours
}
metrics = await analytics.get_usage_metrics(
hours=hours, user_id=current_user["id"]
)
return {"success": True, "data": metrics, "period_hours": hours}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage metrics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting usage metrics: {str(e)}"
)
@router.get("/metrics/system")
async def get_system_metrics(
hours: int = Query(24, ge=1, le=168),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get system-wide metrics (admin only)"""
if not current_user['is_superuser']:
if not current_user["is_superuser"]:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours)
return {
"success": True,
"data": metrics,
"period_hours": hours
}
return {"success": True, "data": metrics, "period_hours": hours}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system metrics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system metrics: {str(e)}"
)
@router.get("/health")
async def get_system_health(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get system health status including budget and performance analysis"""
try:
analytics = get_analytics_service()
health = await analytics.get_system_health()
return {
"success": True,
"data": health
}
return {"success": True, "data": health}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system health: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system health: {str(e)}"
)
@router.get("/costs")
async def get_cost_analysis(
days: int = Query(30, ge=1, le=365, description="Days to analyze (1-365)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get detailed cost analysis and trends"""
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days, user_id=current_user['id'])
return {
"success": True,
"data": analysis,
"period_days": days
}
analysis = await analytics.get_cost_analysis(
days=days, user_id=current_user["id"]
)
return {"success": True, "data": analysis, "period_days": days}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting cost analysis: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting cost analysis: {str(e)}"
)
@router.get("/costs/system")
async def get_system_cost_analysis(
days: int = Query(30, ge=1, le=365),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get system-wide cost analysis (admin only)"""
if not current_user['is_superuser']:
if not current_user["is_superuser"]:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days)
return {
"success": True,
"data": analysis,
"period_days": days
}
return {"success": True, "data": analysis, "period_days": days}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system cost analysis: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system cost analysis: {str(e)}"
)
@router.get("/endpoints")
async def get_endpoint_stats(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get endpoint usage statistics"""
try:
analytics = get_analytics_service()
# For now, return the in-memory stats
# In future, this could be enhanced with database queries
return {
@@ -133,74 +126,79 @@ async def get_endpoint_stats(
"data": {
"endpoint_stats": dict(analytics.endpoint_stats),
"status_codes": dict(analytics.status_codes),
"model_stats": dict(analytics.model_stats)
}
"model_stats": dict(analytics.model_stats),
},
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting endpoint stats: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting endpoint stats: {str(e)}"
)
@router.get("/usage-trends")
async def get_usage_trends(
days: int = Query(7, ge=1, le=30, description="Days for trend analysis"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get usage trends over time"""
try:
from datetime import datetime, timedelta
from sqlalchemy import func
from app.models.usage_tracking import UsageTracking
cutoff_time = datetime.utcnow() - timedelta(days=days)
# Daily usage trends
daily_usage = db.query(
func.date(UsageTracking.created_at).label('date'),
func.count(UsageTracking.id).label('requests'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents')
).filter(
UsageTracking.created_at >= cutoff_time,
UsageTracking.user_id == current_user['id']
).group_by(
func.date(UsageTracking.created_at)
).order_by('date').all()
daily_usage = (
db.query(
func.date(UsageTracking.created_at).label("date"),
func.count(UsageTracking.id).label("requests"),
func.sum(UsageTracking.total_tokens).label("tokens"),
func.sum(UsageTracking.cost_cents).label("cost_cents"),
)
.filter(
UsageTracking.created_at >= cutoff_time,
UsageTracking.user_id == current_user["id"],
)
.group_by(func.date(UsageTracking.created_at))
.order_by("date")
.all()
)
trends = []
for date, requests, tokens, cost_cents in daily_usage:
trends.append({
"date": date.isoformat(),
"requests": requests,
"tokens": tokens or 0,
"cost_cents": cost_cents or 0,
"cost_dollars": (cost_cents or 0) / 100
})
return {
"success": True,
"data": {
"trends": trends,
"period_days": days
}
}
trends.append(
{
"date": date.isoformat(),
"requests": requests,
"tokens": tokens or 0,
"cost_cents": cost_cents or 0,
"cost_dollars": (cost_cents or 0) / 100,
}
)
return {"success": True, "data": {"trends": trends, "period_days": days}}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage trends: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting usage trends: {str(e)}"
)
@router.get("/overview")
async def get_analytics_overview(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get analytics overview data"""
try:
analytics = get_analytics_service()
# Get basic metrics
metrics = await analytics.get_usage_metrics(hours=24, user_id=current_user['id'])
metrics = await analytics.get_usage_metrics(
hours=24, user_id=current_user["id"]
)
health = await analytics.get_system_health()
return {
"success": True,
"data": {
@@ -210,8 +208,8 @@ async def get_analytics_overview(
"error_rate": metrics.error_rate,
"budget_usage_percentage": metrics.budget_usage_percentage,
"system_health": health.status,
"health_score": health.score
}
"health_score": health.score,
},
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting overview: {str(e)}")
@@ -219,16 +217,15 @@ async def get_analytics_overview(
@router.get("/modules")
async def get_module_analytics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get analytics data for all modules"""
try:
module_stats = []
for name, module in module_manager.modules.items():
stats = {"name": name, "initialized": getattr(module, "initialized", False)}
# Get module statistics if available
if hasattr(module, "get_stats"):
try:
@@ -240,18 +237,22 @@ async def get_module_analytics(
except Exception as e:
logger.warning(f"Failed to get stats for module {name}: {e}")
stats["error"] = str(e)
module_stats.append(stats)
return {
"success": True,
"data": {
"modules": module_stats,
"total_modules": len(module_stats),
"system_health": "healthy" if all(m.get("initialized", False) for m in module_stats) else "warning"
}
"system_health": "healthy"
if all(m.get("initialized", False) for m in module_stats)
else "warning",
},
}
except Exception as e:
logger.error(f"Failed to get module analytics: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve module analytics")
raise HTTPException(
status_code=500, detail="Failed to retrieve module analytics"
)

View File

@@ -74,47 +74,49 @@ class APIKeyResponse(BaseModel):
last_used_at: Optional[datetime] = None
total_requests: int
total_tokens: int
total_cost_cents: int = Field(alias='total_cost')
total_cost_cents: int = Field(alias="total_cost")
rate_limit_per_minute: Optional[int] = None
rate_limit_per_hour: Optional[int] = None
rate_limit_per_day: Optional[int] = None
allowed_ips: List[str]
allowed_models: List[str] # Model restrictions
allowed_chatbots: List[str] # Chatbot restrictions
budget_limit: Optional[int] = Field(None, alias='budget_limit_cents') # Budget limit in cents
budget_limit: Optional[int] = Field(
None, alias="budget_limit_cents"
) # Budget limit in cents
budget_type: Optional[str] = None # Budget type
is_unlimited: bool = True # Unlimited budget flag
tags: List[str]
class Config:
from_attributes = True
@classmethod
def from_api_key(cls, api_key):
"""Create response from APIKey model with formatted key prefix"""
data = {
'id': api_key.id,
'name': api_key.name,
'description': api_key.description,
'key_prefix': api_key.key_prefix + "..." if api_key.key_prefix else "",
'scopes': api_key.scopes,
'is_active': api_key.is_active,
'expires_at': api_key.expires_at,
'created_at': api_key.created_at,
'last_used_at': api_key.last_used_at,
'total_requests': api_key.total_requests,
'total_tokens': api_key.total_tokens,
'total_cost': api_key.total_cost,
'rate_limit_per_minute': api_key.rate_limit_per_minute,
'rate_limit_per_hour': api_key.rate_limit_per_hour,
'rate_limit_per_day': api_key.rate_limit_per_day,
'allowed_ips': api_key.allowed_ips,
'allowed_models': api_key.allowed_models,
'allowed_chatbots': api_key.allowed_chatbots,
'budget_limit_cents': api_key.budget_limit_cents,
'budget_type': api_key.budget_type,
'is_unlimited': api_key.is_unlimited,
'tags': api_key.tags
"id": api_key.id,
"name": api_key.name,
"description": api_key.description,
"key_prefix": api_key.key_prefix + "..." if api_key.key_prefix else "",
"scopes": api_key.scopes,
"is_active": api_key.is_active,
"expires_at": api_key.expires_at,
"created_at": api_key.created_at,
"last_used_at": api_key.last_used_at,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost": api_key.total_cost,
"rate_limit_per_minute": api_key.rate_limit_per_minute,
"rate_limit_per_hour": api_key.rate_limit_per_hour,
"rate_limit_per_day": api_key.rate_limit_per_day,
"allowed_ips": api_key.allowed_ips,
"allowed_models": api_key.allowed_models,
"allowed_chatbots": api_key.allowed_chatbots,
"budget_limit_cents": api_key.budget_limit_cents,
"budget_type": api_key.budget_type,
"is_unlimited": api_key.is_unlimited,
"tags": api_key.tags,
}
return cls(**data)
@@ -148,15 +150,18 @@ class APIKeyUsageResponse(BaseModel):
def generate_api_key() -> tuple[str, str]:
"""Generate a new API key and return (full_key, key_hash)"""
# Generate random key part (32 characters)
key_part = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))
key_part = "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
)
# Create full key with prefix
full_key = f"{settings.API_KEY_PREFIX}{key_part}"
# Create hash for storage
from app.core.security import get_api_key_hash
key_hash = get_api_key_hash(full_key)
return full_key, key_hash
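
For context, verification on the request path presumably mirrors this generation scheme: narrow the lookup by the stored prefix, then confirm the hash. A hypothetical sketch (the function and its wiring are assumptions, not shown elsewhere in this excerpt; prefix length mirrors key_prefix = full_key[:8] in create_api_key below):

async def verify_api_key(db, presented_key: str):
    from app.core.security import get_api_key_hash  # same helper used above
    prefix = presented_key[:8]
    result = await db.execute(select(APIKey).where(APIKey.key_prefix == prefix))
    candidate = result.scalar_one_or_none()
    # constant-time comparison avoids leaking hash prefixes via timing
    if candidate and secrets.compare_digest(
        candidate.key_hash, get_api_key_hash(presented_key)
    ):
        return candidate
    return None
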
@@ -169,73 +174,87 @@ async def list_api_keys(
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List API keys with pagination and filtering"""
# Check permissions - users can view their own API keys
if user_id and int(user_id) != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if user_id and int(user_id) != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
elif not user_id:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# If no user_id specified and user doesn't have admin permissions, show only their keys
if not user_id and "platform:api-keys:read" not in current_user.get("permissions", []):
user_id = current_user['id']
if not user_id and "platform:api-keys:read" not in current_user.get(
"permissions", []
):
user_id = current_user["id"]
# Build query
query = select(APIKey)
# Apply filters
if user_id:
query = query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
query = query.where(
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if is_active is not None:
query = query.where(APIKey.is_active == is_active)
if search:
query = query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
(APIKey.name.ilike(f"%{search}%"))
| (APIKey.description.ilike(f"%{search}%"))
)
# Get total count using func.count()
total_query = select(func.count(APIKey.id))
# Apply same filters for count
if user_id:
total_query = total_query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
total_query = total_query.where(
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if is_active is not None:
total_query = total_query.where(APIKey.is_active == is_active)
if search:
total_query = total_query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
(APIKey.name.ilike(f"%{search}%"))
| (APIKey.description.ilike(f"%{search}%"))
)
total_result = await db.execute(total_query)
total = total_result.scalar()
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(APIKey.created_at.desc())
# Execute query
result = await db.execute(query)
api_keys = result.scalars().all()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_api_keys",
resource_type="api_key",
details={"page": page, "size": size, "filters": {"user_id": user_id, "is_active": is_active, "search": search}}
details={
"page": page,
"size": size,
"filters": {"user_id": user_id, "is_active": is_active, "search": search},
},
)
return APIKeyListResponse(
api_keys=[APIKeyResponse.model_validate(key) for key in api_keys],
total=total,
page=page,
size=size
size=size,
)
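
Since the count query above re-applies every filter by hand, the two queries can silently drift apart. A hedged alternative (same SQLAlchemy 2.x select API, taken before pagination is applied) derives the count from the already-filtered query:

total_query = select(func.count()).select_from(query.subquery())
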
@@ -243,34 +262,35 @@ async def list_api_keys(
async def get_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get API key by ID"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can view their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_api_key",
resource_type="api_key",
resource_id=api_key_id
resource_id=api_key_id,
)
return APIKeyResponse.model_validate(api_key)
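Note the ordering here: the key is fetched and 404s before the permission check runs, so a caller probing random IDs can distinguish "does not exist" (404) from "exists but not mine" (403). A hardening sketch, not what the diff does, would collapse both cases into 404:

```python
# Illustrative only: respond 404 whether the key is missing or merely
# unreadable, so valid IDs cannot be enumerated by non-owners.
from fastapi import HTTPException, status

def guard_read(api_key, current_user_id: int, perms: list[str]) -> None:
    readable = api_key is not None and (
        api_key.user_id == current_user_id
        or "platform:api-keys:read" in perms
    )
    if not readable:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "API key not found")
```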
@@ -278,24 +298,24 @@ async def get_api_key(
async def create_api_key(
api_key_data: APIKeyCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new API key"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:api-keys:create")
# Generate API key
full_key, key_hash = generate_api_key()
key_prefix = full_key[:8] # Store only first 8 characters for lookup
# Create API key
new_api_key = APIKey(
name=api_key_data.name,
description=api_key_data.description,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=current_user['id'],
user_id=current_user["id"],
scopes=api_key_data.scopes,
expires_at=api_key_data.expires_at,
rate_limit_per_minute=api_key_data.rate_limit_per_minute,
@@ -305,29 +325,32 @@ async def create_api_key(
allowed_models=api_key_data.allowed_models,
allowed_chatbots=api_key_data.allowed_chatbots,
is_unlimited=api_key_data.is_unlimited,
budget_limit_cents=api_key_data.budget_limit_cents if not api_key_data.is_unlimited else None,
budget_limit_cents=api_key_data.budget_limit_cents
if not api_key_data.is_unlimited
else None,
budget_type=api_key_data.budget_type if not api_key_data.is_unlimited else None,
tags=api_key_data.tags
tags=api_key_data.tags,
)
db.add(new_api_key)
await db.commit()
await db.refresh(new_api_key)
# Log audit event asynchronously (non-blocking)
asyncio.create_task(log_audit_event_async(
user_id=str(current_user['id']),
action="create_api_key",
resource_type="api_key",
resource_id=str(new_api_key.id),
details={"name": api_key_data.name, "scopes": api_key_data.scopes}
))
asyncio.create_task(
log_audit_event_async(
user_id=str(current_user["id"]),
action="create_api_key",
resource_type="api_key",
resource_id=str(new_api_key.id),
details={"name": api_key_data.name, "scopes": api_key_data.scopes},
)
)
logger.info(f"API key created: {new_api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(new_api_key),
secret_key=full_key
api_key=APIKeyResponse.model_validate(new_api_key), secret_key=full_key
)
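`generate_api_key()` itself is not part of this hunk. Given that the handler persists a hash plus the first eight characters as a lookup prefix, a plausible shape is a random URL-safe secret hashed with SHA-256; the sketch below is an assumption for illustration, not the project's actual implementation:

```python
import hashlib
import secrets

def generate_api_key(prefix: str = "sk-") -> tuple[str, str]:
    """Return (full_key, key_hash); only the hash is stored server-side."""
    full_key = prefix + secrets.token_urlsafe(32)
    key_hash = hashlib.sha256(full_key.encode()).hexdigest()
    return full_key, key_hash

full_key, key_hash = generate_api_key()
key_prefix = full_key[:8]  # matches the stored lookup prefix in the handler
```

The full key is only ever returned at creation and on regenerate; afterwards the server holds just the hash and prefix.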
@@ -336,56 +359,57 @@ async def update_api_key(
api_key_id: str,
api_key_data: APIKeyUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can update their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Store original values for audit
original_values = {
"name": api_key.name,
"scopes": api_key.scopes,
"is_active": api_key.is_active
"is_active": api_key.is_active,
}
# Update API key fields
update_data = api_key_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(api_key, field, value)
await db.commit()
await db.refresh(api_key)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(api_key, k) for k in update_data.keys()}
}
"after_values": {k: getattr(api_key, k) for k in update_data.keys()},
},
)
logger.info(f"API key updated: {api_key.name} by {current_user['username']}")
return APIKeyResponse.model_validate(api_key)
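The update path leans on `model_dump(exclude_unset=True)`: only fields the client actually sent are applied, so omitted fields keep their stored values instead of being nulled out. A minimal standalone illustration:

```python
from typing import Optional
from pydantic import BaseModel

class APIKeyUpdate(BaseModel):
    name: Optional[str] = None
    is_active: Optional[bool] = None

patch = APIKeyUpdate.model_validate({"is_active": False})
assert patch.model_dump(exclude_unset=True) == {"is_active": False}
# A bare .model_dump() would also yield name=None and overwrite the column.
```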
@@ -393,41 +417,42 @@ async def update_api_key(
async def delete_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Delete API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can delete their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:delete")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:delete"
)
# Delete API key
await db.delete(api_key)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="delete_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key deleted: {api_key.name} by {current_user['username']}")
return {"message": "API key deleted successfully"}
@@ -435,51 +460,51 @@ async def delete_api_key(
async def regenerate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Regenerate API key secret"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can regenerate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Generate new API key
full_key, key_hash = generate_api_key()
key_prefix = full_key[:8] # Store only first 8 characters for lookup
# Update API key
api_key.key_hash = key_hash
api_key.key_prefix = key_prefix
await db.commit()
await db.refresh(api_key)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="regenerate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key regenerated: {api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(api_key),
secret_key=full_key
api_key=APIKeyResponse.model_validate(api_key), secret_key=full_key
)
@@ -487,65 +512,64 @@ async def regenerate_api_key(
async def get_api_key_usage(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get API key usage statistics"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can view their own API key usage
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Calculate usage statistics
from app.models.usage_tracking import UsageTracking
now = datetime.utcnow()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
hour_start = now.replace(minute=0, second=0, microsecond=0)
# Today's usage
today_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
func.sum(UsageTracking.cost_cents),
).where(
UsageTracking.api_key_id == api_key_id,
UsageTracking.created_at >= today_start
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= today_start
)
today_result = await db.execute(today_query)
today_stats = today_result.first()
# This hour's usage
hour_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
func.sum(UsageTracking.cost_cents),
).where(
UsageTracking.api_key_id == api_key_id,
UsageTracking.created_at >= hour_start
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= hour_start
)
hour_result = await db.execute(hour_query)
hour_stats = hour_result.first()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_api_key_usage",
resource_type="api_key",
resource_id=api_key_id
resource_id=api_key_id,
)
return APIKeyUsageResponse(
api_key_id=api_key_id,
total_requests=api_key.total_requests,
@@ -557,7 +581,7 @@ async def get_api_key_usage(
requests_this_hour=hour_stats[0] or 0,
tokens_this_hour=hour_stats[1] or 0,
cost_this_hour_cents=hour_stats[2] or 0,
last_used_at=api_key.last_used_at
last_used_at=api_key.last_used_at,
)
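Each usage window is a single three-column aggregate. The same shape, extracted into a reusable query builder (the model is passed in, since `UsageTracking` is imported locally in the handler):

```python
from datetime import datetime
from sqlalchemy import func, select

def usage_window_query(UsageTracking, api_key_id, window_start: datetime):
    return select(
        func.count(UsageTracking.id),
        func.sum(UsageTracking.total_tokens),
        func.sum(UsageTracking.cost_cents),
    ).where(
        UsageTracking.api_key_id == api_key_id,
        UsageTracking.created_at >= window_start,
    )
```

`func.sum` over an empty window yields SQL NULL, which arrives as `None`; that is what the `or 0` fallbacks in the response above are guarding against.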
@@ -565,41 +589,42 @@ async def get_api_key_usage(
async def activate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Activate API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can activate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Activate API key
api_key.is_active = True
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="activate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key activated: {api_key.name} by {current_user['username']}")
return {"message": "API key activated successfully"}
@@ -607,39 +632,40 @@ async def activate_api_key(
async def deactivate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Deactivate API key"""
# Get API key
query = select(APIKey).where(APIKey.id == int(api_key_id))
result = await db.execute(query)
api_key = result.scalar_one_or_none()
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can deactivate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Deactivate API key
api_key.is_active = False
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="deactivate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key deactivated: {api_key.name} by {current_user['username']}")
return {"message": "API key deactivated successfully"}
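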

@@ -36,7 +36,7 @@ class AuditLogResponse(BaseModel):
success: bool
severity: str
created_at: datetime
class Config:
from_attributes = True
@@ -96,17 +96,17 @@ async def list_audit_logs(
severity: Optional[str] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List audit logs with filtering and pagination"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Build query
query = select(AuditLog)
conditions = []
# Apply filters
if user_id:
conditions.append(AuditLog.user_id == user_id)
@@ -128,29 +128,29 @@ async def list_audit_logs(
search_conditions = [
AuditLog.action.ilike(f"%{search}%"),
AuditLog.resource_type.ilike(f"%{search}%"),
AuditLog.details.astext.ilike(f"%{search}%")
AuditLog.details.astext.ilike(f"%{search}%"),
]
conditions.append(or_(*search_conditions))
if conditions:
query = query.where(and_(*conditions))
# Get total count
count_query = select(func.count(AuditLog.id))
if conditions:
count_query = count_query.where(and_(*conditions))
total_result = await db.execute(count_query)
total = total_result.scalar()
# Apply pagination and ordering
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(AuditLog.created_at.desc())
# Execute query
result = await db.execute(query)
logs = result.scalars().all()
# Log audit event for this query
await log_audit_event(
db=db,
@@ -166,19 +166,19 @@ async def list_audit_logs(
"end_date": end_date.isoformat() if end_date else None,
"success": success,
"severity": severity,
"search": search
"search": search,
},
"page": page,
"size": size,
"total_results": total
}
"total_results": total,
},
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
size=size,
)
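One detail worth preserving when extending this endpoint: the same `conditions` list feeds both the page query and the count query, so the two sets of filters cannot drift apart. A stand-in illustration with a throwaway table:

```python
from sqlalchemy import (
    Column, Integer, MetaData, String, Table, and_, func, select,
)

audit = Table(
    "audit_logs", MetaData(),
    Column("id", Integer, primary_key=True),
    Column("user_id", String),
    Column("action", String),
)

conditions = [audit.c.user_id == "42", audit.c.action.ilike("%login%")]
page_query = select(audit).where(and_(*conditions)).offset(0).limit(50)
count_query = select(func.count(audit.c.id)).where(and_(*conditions))
```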
@@ -188,13 +188,13 @@ async def search_audit_logs(
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=1000),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Advanced search for audit logs"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Use the audit service function
logs = await get_audit_logs(
db=db,
@@ -205,13 +205,13 @@ async def search_audit_logs(
start_date=search_request.start_date,
end_date=search_request.end_date,
limit=size,
offset=(page - 1) * size
offset=(page - 1) * size,
)
# Get total count for the search
total_query = select(func.count(AuditLog.id))
conditions = []
if search_request.user_id:
conditions.append(AuditLog.user_id == search_request.user_id)
if search_request.action:
@@ -234,16 +234,16 @@ async def search_audit_logs(
search_conditions = [
AuditLog.action.ilike(f"%{search_request.search_text}%"),
AuditLog.resource_type.ilike(f"%{search_request.search_text}%"),
AuditLog.details.astext.ilike(f"%{search_request.search_text}%")
AuditLog.details.astext.ilike(f"%{search_request.search_text}%"),
]
conditions.append(or_(*search_conditions))
if conditions:
total_query = total_query.where(and_(*conditions))
total_result = await db.execute(total_query)
total = total_result.scalar()
# Log audit event
await log_audit_event(
db=db,
@@ -253,15 +253,15 @@ async def search_audit_logs(
details={
"search_criteria": search_request.model_dump(exclude_unset=True),
"results_count": len(logs),
"total_matches": total
}
"total_matches": total,
},
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
size=size,
)
@@ -270,64 +270,80 @@ async def get_audit_statistics(
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get audit log statistics"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
# Default to last 30 days if no dates provided
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Get basic stats using audit service
basic_stats = await get_audit_stats(db, start_date, end_date)
# Get additional statistics
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
# Events by user
user_query = select(
AuditLog.user_id,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.user_id).order_by(func.count(AuditLog.id).desc()).limit(10)
user_query = (
select(AuditLog.user_id, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.user_id)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
user_result = await db.execute(user_query)
events_by_user = dict(user_result.fetchall())
# Events by hour of day
hour_query = select(
func.extract('hour', AuditLog.created_at).label('hour'),
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(func.extract('hour', AuditLog.created_at)).order_by('hour')
hour_query = (
select(
func.extract("hour", AuditLog.created_at).label("hour"),
func.count(AuditLog.id).label("count"),
)
.where(and_(*conditions))
.group_by(func.extract("hour", AuditLog.created_at))
.order_by("hour")
)
hour_result = await db.execute(hour_query)
events_by_hour = dict(hour_result.fetchall())
# Top actions
top_actions_query = select(
AuditLog.action,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.action).order_by(func.count(AuditLog.id).desc()).limit(10)
top_actions_query = (
select(AuditLog.action, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.action)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
top_actions_result = await db.execute(top_actions_query)
top_actions = [{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()]
top_actions = [
{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()
]
# Top resources
top_resources_query = select(
AuditLog.resource_type,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.resource_type).order_by(func.count(AuditLog.id).desc()).limit(10)
top_resources_query = (
select(AuditLog.resource_type, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.resource_type)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
top_resources_result = await db.execute(top_resources_query)
top_resources = [{"resource_type": row[0], "count": row[1]} for row in top_resources_result.fetchall()]
top_resources = [
{"resource_type": row[0], "count": row[1]}
for row in top_resources_result.fetchall()
]
# Log audit event
await log_audit_event(
db=db,
@@ -337,16 +353,16 @@ async def get_audit_statistics(
details={
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"total_events": basic_stats["total_events"]
}
"total_events": basic_stats["total_events"],
},
)
return AuditStatsResponse(
**basic_stats,
events_by_user=events_by_user,
events_by_hour=events_by_hour,
top_actions=top_actions,
top_resources=top_resources
top_resources=top_resources,
)
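The `dict(result.fetchall())` conversions for `events_by_user` and `events_by_hour` work because each row of a two-column SELECT behaves as a `(key, value)` tuple:

```python
rows = [("alice", 12), ("bob", 3)]  # shape of user_result.fetchall()
assert dict(rows) == {"alice": 12, "bob": 3}
```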
@@ -354,25 +370,30 @@ async def get_audit_statistics(
async def get_security_events(
hours: int = Query(24, ge=1, le=168), # Last 24 hours by default, max 1 week
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get security-related events and anomalies"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:read")
end_time = datetime.utcnow()
start_time = end_time - timedelta(hours=hours)
# Failed logins
failed_logins_query = select(AuditLog).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.action == "login",
AuditLog.success == False
failed_logins_query = (
select(AuditLog)
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.action == "login",
AuditLog.success == False,
)
)
).order_by(AuditLog.created_at.desc()).limit(50)
.order_by(AuditLog.created_at.desc())
.limit(50)
)
failed_logins_result = await db.execute(failed_logins_query)
failed_logins = [
{
@@ -380,19 +401,24 @@ async def get_security_events(
"user_id": log.user_id,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"details": log.details
"details": log.details,
}
for log in failed_logins_result.scalars().all()
]
# High severity events
high_severity_query = select(AuditLog).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.severity.in_(["error", "critical"])
high_severity_query = (
select(AuditLog)
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.severity.in_(["error", "critical"]),
)
)
).order_by(AuditLog.created_at.desc()).limit(50)
.order_by(AuditLog.created_at.desc())
.limit(50)
)
high_severity_result = await db.execute(high_severity_query)
high_severity_events = [
{
@@ -403,56 +429,65 @@ async def get_security_events(
"user_id": log.user_id,
"ip_address": log.ip_address,
"success": log.success,
"details": log.details
"details": log.details,
}
for log in high_severity_result.scalars().all()
]
# Suspicious activities (multiple failed attempts from same IP)
suspicious_ips_query = select(
AuditLog.ip_address,
func.count(AuditLog.id).label('failed_count')
).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.success == False,
AuditLog.ip_address.isnot(None)
suspicious_ips_query = (
select(AuditLog.ip_address, func.count(AuditLog.id).label("failed_count"))
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.success == False,
AuditLog.ip_address.isnot(None),
)
)
).group_by(AuditLog.ip_address).having(func.count(AuditLog.id) >= 5).order_by(func.count(AuditLog.id).desc())
.group_by(AuditLog.ip_address)
.having(func.count(AuditLog.id) >= 5)
.order_by(func.count(AuditLog.id).desc())
)
suspicious_ips_result = await db.execute(suspicious_ips_query)
suspicious_activities = [
{
"ip_address": row[0],
"failed_attempts": row[1],
"risk_level": "high" if row[1] >= 10 else "medium"
"risk_level": "high" if row[1] >= 10 else "medium",
}
for row in suspicious_ips_result.fetchall()
]
# Unusual access patterns (users accessing from multiple IPs)
unusual_access_query = select(
AuditLog.user_id,
func.count(func.distinct(AuditLog.ip_address)).label('ip_count'),
func.array_agg(func.distinct(AuditLog.ip_address)).label('ip_addresses')
).where(
and_(
AuditLog.created_at >= start_time,
AuditLog.user_id.isnot(None),
AuditLog.ip_address.isnot(None)
unusual_access_query = (
select(
AuditLog.user_id,
func.count(func.distinct(AuditLog.ip_address)).label("ip_count"),
func.array_agg(func.distinct(AuditLog.ip_address)).label("ip_addresses"),
)
).group_by(AuditLog.user_id).having(func.count(func.distinct(AuditLog.ip_address)) >= 3).order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.user_id.isnot(None),
AuditLog.ip_address.isnot(None),
)
)
.group_by(AuditLog.user_id)
.having(func.count(func.distinct(AuditLog.ip_address)) >= 3)
.order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
)
unusual_access_result = await db.execute(unusual_access_query)
unusual_access_patterns = [
{
"user_id": row[0],
"unique_ips": row[1],
"ip_addresses": row[2] if row[2] else []
"ip_addresses": row[2] if row[2] else [],
}
for row in unusual_access_result.fetchall()
]
# Log audit event
await log_audit_event(
db=db,
@@ -464,15 +499,15 @@ async def get_security_events(
"failed_logins_count": len(failed_logins),
"high_severity_count": len(high_severity_events),
"suspicious_ips_count": len(suspicious_activities),
"unusual_access_patterns_count": len(unusual_access_patterns)
}
"unusual_access_patterns_count": len(unusual_access_patterns),
},
)
return SecurityEventsResponse(
suspicious_activities=suspicious_activities,
failed_logins=failed_logins,
unusual_access_patterns=unusual_access_patterns,
high_severity_events=high_severity_events
high_severity_events=high_severity_events,
)
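Two notes on this endpoint. First, `func.array_agg` in the unusual-access query is PostgreSQL-specific; on SQLite or MySQL the IP list would need a different aggregate. Second, the risk tiering is pure Python layered on top of the SQL `HAVING` floor; isolated, it is just:

```python
# The >= 5 floor comes from the HAVING clause; the >= 10 split from the diff.
def risk_level(failed_attempts: int) -> str:
    return "high" if failed_attempts >= 10 else "medium"

assert risk_level(5) == "medium"
assert risk_level(10) == "high"
```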
@@ -485,42 +520,43 @@ async def export_audit_logs(
action: Optional[str] = Query(None),
resource_type: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Export audit logs in CSV or JSON format"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:audit:export")
# Default to last 30 days if no dates provided
if not end_date:
end_date = datetime.utcnow()
if not start_date:
start_date = end_date - timedelta(days=30)
# Limit export size
max_records = 10000
# Build query
query = select(AuditLog)
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action:
conditions.append(AuditLog.action == action)
if resource_type:
conditions.append(AuditLog.resource_type == resource_type)
query = query.where(and_(*conditions)).order_by(AuditLog.created_at.desc()).limit(max_records)
query = (
query.where(and_(*conditions))
.order_by(AuditLog.created_at.desc())
.limit(max_records)
)
# Execute query
result = await db.execute(query)
logs = result.scalars().all()
# Log export event
await log_audit_event(
db=db,
@@ -535,13 +571,14 @@ async def export_audit_logs(
"filters": {
"user_id": user_id,
"action": action,
"resource_type": resource_type
}
}
"resource_type": resource_type,
},
},
)
if format == "json":
from fastapi.responses import JSONResponse
export_data = [
{
"id": str(log.id),
@@ -554,45 +591,59 @@ async def export_audit_logs(
"user_agent": log.user_agent,
"success": log.success,
"severity": log.severity,
"created_at": log.created_at.isoformat()
"created_at": log.created_at.isoformat(),
}
for log in logs
]
return JSONResponse(content=export_data)
else: # CSV format
import csv
import io
from fastapi.responses import StreamingResponse
output = io.StringIO()
writer = csv.writer(output)
# Write header
writer.writerow([
"ID", "User ID", "Action", "Resource Type", "Resource ID",
"IP Address", "Success", "Severity", "Created At", "Details"
])
writer.writerow(
[
"ID",
"User ID",
"Action",
"Resource Type",
"Resource ID",
"IP Address",
"Success",
"Severity",
"Created At",
"Details",
]
)
# Write data
for log in logs:
writer.writerow([
str(log.id),
log.user_id or "",
log.action,
log.resource_type,
log.resource_id or "",
log.ip_address or "",
log.success,
log.severity,
log.created_at.isoformat(),
str(log.details)
])
writer.writerow(
[
str(log.id),
log.user_id or "",
log.action,
log.resource_type,
log.resource_id or "",
log.ip_address or "",
log.success,
log.severity,
log.created_at.isoformat(),
str(log.details),
]
)
output.seek(0)
return StreamingResponse(
io.BytesIO(output.getvalue().encode()),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"}
)
headers={
"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"
},
)
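Despite the `StreamingResponse`, the CSV body is materialized in memory first (a `StringIO` re-wrapped as `BytesIO`), which is fine under the 10,000-row export cap. The mechanics, reduced to a self-contained helper:

```python
import csv
import io

from fastapi.responses import StreamingResponse

def csv_attachment(header: list[str], rows: list[list], filename: str):
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(header)
    writer.writerows(rows)
    return StreamingResponse(
        io.BytesIO(buf.getvalue().encode()),
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={filename}"},
    )
```

For genuinely large exports, a generator yielding rows would stream without buffering, but that is a different design than shown here.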


@@ -9,6 +9,7 @@ from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.core.config import settings
from app.core.logging import get_logger
@@ -38,25 +39,25 @@ class UserRegisterRequest(BaseModel):
password: str
first_name: Optional[str] = None
last_name: Optional[str] = None
@validator('password')
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters long')
raise ValueError("Password must be at least 8 characters long")
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
raise ValueError("Password must contain at least one uppercase letter")
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
raise ValueError("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
raise ValueError("Password must contain at least one digit")
return v
@validator('username')
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError('Username must be at least 3 characters long')
raise ValueError("Username must be at least 3 characters long")
if not v.isalnum():
raise ValueError('Username must contain only alphanumeric characters')
raise ValueError("Username must contain only alphanumeric characters")
return v
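These validators use the Pydantic v1 `@validator` decorator. Since other files in this commit call v2 APIs such as `model_validate`, these most likely run through v2's deprecated compatibility shim; a v2-native equivalent, assuming Pydantic 2.x, would be:

```python
from pydantic import BaseModel, field_validator

class UserRegisterRequest(BaseModel):
    username: str

    @field_validator("username")
    @classmethod
    def validate_username(cls, v: str) -> str:
        if len(v) < 3:
            raise ValueError("Username must be at least 3 characters long")
        if not v.isalnum():
            raise ValueError("Username must contain only alphanumeric characters")
        return v
```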
@@ -65,10 +66,10 @@ class UserLoginRequest(BaseModel):
username: Optional[str] = None
password: str
@validator('email')
@validator("email")
def validate_email_or_username(cls, v, values):
if v is None and not values.get('username'):
raise ValueError('Either email or username must be provided')
if v is None and not values.get("username"):
raise ValueError("Either email or username must be provided")
return v
@@ -77,6 +78,8 @@ class TokenResponse(BaseModel):
refresh_token: str
token_type: str = "bearer"
expires_in: int
force_password_change: Optional[bool] = None
message: Optional[str] = None
class UserResponse(BaseModel):
@@ -86,9 +89,9 @@ class UserResponse(BaseModel):
full_name: Optional[str]
is_active: bool
is_verified: bool
role: str
role: Optional[str]
created_at: datetime
class Config:
from_attributes = True
@@ -100,50 +103,47 @@ class RefreshTokenRequest(BaseModel):
class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str
@validator('new_password')
@validator("new_password")
def validate_new_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters long')
raise ValueError("Password must be at least 8 characters long")
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
raise ValueError("Password must contain at least one uppercase letter")
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
raise ValueError("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
raise ValueError("Password must contain at least one digit")
return v
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_data: UserRegisterRequest,
db: AsyncSession = Depends(get_db)
):
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
async def register(user_data: UserRegisterRequest, db: AsyncSession = Depends(get_db)):
"""Register a new user"""
# Check if user already exists
stmt = select(User).where(User.email == user_data.email)
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
)
# Check if username already exists
stmt = select(User).where(User.username == user_data.username)
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already taken"
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
)
# Create new user
full_name = None
if user_data.first_name or user_data.last_name:
full_name = f"{user_data.first_name or ''} {user_data.last_name or ''}".strip()
user = User(
email=user_data.email,
username=user_data.username,
@@ -151,23 +151,29 @@ async def register(
full_name=full_name,
is_active=True,
is_verified=False,
role="user"
role_id=2, # Default to 'user' role (id=2)
)
db.add(user)
await db.commit()
await db.refresh(user)
return UserResponse.from_orm(user)
return UserResponse(
id=user.id,
email=user.email,
username=user.username,
full_name=user.full_name,
is_active=user.is_active,
is_verified=user.is_verified,
role=user.role.name if user.role else None,
created_at=user.created_at,
)
@router.post("/login", response_model=TokenResponse)
async def login(
user_data: UserLoginRequest,
db: AsyncSession = Depends(get_db)
):
async def login(user_data: UserLoginRequest, db: AsyncSession = Depends(get_db)):
"""Login user and return access tokens"""
# Determine identifier for logging and user lookup
identifier = user_data.email if user_data.email else user_data.username
logger.info(
@@ -187,9 +193,9 @@ async def login(
query_start = datetime.utcnow()
if user_data.email:
stmt = select(User).where(User.email == user_data.email)
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
else:
stmt = select(User).where(User.username == user_data.username)
stmt = select(User).options(selectinload(User.role)).where(User.username == user_data.username)
result = await db.execute(stmt)
query_end = datetime.utcnow()
@@ -205,13 +211,18 @@ async def login(
identifier_lower = identifier.lower() if identifier else ""
admin_email = settings.ADMIN_EMAIL.lower() if settings.ADMIN_EMAIL else None
if user_data.email and admin_email and identifier_lower == admin_email and settings.ADMIN_PASSWORD:
if (
user_data.email
and admin_email
and identifier_lower == admin_email
and settings.ADMIN_PASSWORD
):
bootstrap_attempted = True
logger.info("LOGIN_ADMIN_BOOTSTRAP_START", email=user_data.email)
try:
await create_default_admin()
# Re-run lookup after bootstrap attempt
stmt = select(User).where(User.email == user_data.email)
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if user:
@@ -232,19 +243,21 @@ async def login(
)
except Exception as e:
logger.error("LOGIN_USER_LIST_FAILURE", error=str(e))
if bootstrap_attempted:
logger.warning("LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email)
logger.warning(
"LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
detail="Incorrect email or password",
)
logger.info("LOGIN_USER_FOUND", email=user.email, is_active=user.is_active)
logger.info("LOGIN_PASSWORD_VERIFY_START")
verify_start = datetime.utcnow()
if not verify_password(user_data.password, user.hashed_password):
verify_end = datetime.utcnow()
logger.warning(
@@ -253,21 +266,20 @@ async def login(
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
detail="Incorrect email or password",
)
verify_end = datetime.utcnow()
logger.info(
"LOGIN_PASSWORD_VERIFY_SUCCESS",
duration_seconds=(verify_end - verify_start).total_seconds(),
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is disabled"
status_code=status.HTTP_401_UNAUTHORIZED, detail="User account is disabled"
)
# Update last login
logger.info("LOGIN_LAST_LOGIN_UPDATE_START")
update_start = datetime.utcnow()
@@ -289,14 +301,12 @@ async def login(
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
"role": user.role.name if user.role else None,
},
expires_delta=access_token_expires
expires_delta=access_token_expires,
)
refresh_token = create_refresh_token(
data={"sub": str(user.id), "type": "refresh"}
)
refresh_token = create_refresh_token(data={"sub": str(user.id), "type": "refresh"})
token_end = datetime.utcnow()
logger.info(
"LOGIN_TOKEN_CREATE_SUCCESS",
@@ -308,65 +318,76 @@ async def login(
"LOGIN_DEBUG_COMPLETE",
total_duration_seconds=total_time.total_seconds(),
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
# Check if user needs to change password
response_data = {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
}
# Add force password change flag if needed
if user.force_password_change:
response_data["force_password_change"] = True
response_data["message"] = "Password change required on first login"
return response_data
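Returning a plain dict here still passes through `response_model=TokenResponse` validation, which is exactly why the two optional fields were added to `TokenResponse`: without them FastAPI would strip `force_password_change` and `message` from the response. Condensed check:

```python
from typing import Optional
from pydantic import BaseModel

class TokenResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int
    force_password_change: Optional[bool] = None
    message: Optional[str] = None

payload = {
    "access_token": "a", "refresh_token": "r", "expires_in": 3600,
    "force_password_change": True,
    "message": "Password change required on first login",
}
assert TokenResponse.model_validate(payload).force_password_change is True
```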
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
token_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db)
token_data: RefreshTokenRequest, db: AsyncSession = Depends(get_db)
):
"""Refresh access token using refresh token"""
try:
payload = verify_token(token_data.refresh_token)
user_id = payload.get("sub")
token_type = payload.get("type")
if not user_id or token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
# Get user from database
stmt = select(User).where(User.id == int(user_id))
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
detail="User not found or inactive",
)
# Create new access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
logger.info(f"REFRESH: Creating new access token with expiration: {access_token_expires}")
logger.info(f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(
f"REFRESH: Creating new access token with expiration: {access_token_expires}"
)
logger.info(
f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
)
logger.info(f"REFRESH: Current UTC time: {datetime.utcnow().isoformat()}")
access_token = create_access_token(
data={
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
"role": user.role.name if user.role else None,
},
expires_delta=access_token_expires
expires_delta=access_token_expires,
)
return TokenResponse(
access_token=access_token,
refresh_token=token_data.refresh_token, # Keep same refresh token
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
except HTTPException:
# Re-raise HTTPException without modification
raise
@@ -374,30 +395,37 @@ async def refresh_token(
# Log the actual error for debugging
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get current user information"""
# Get full user details from database
stmt = select(User).where(User.id == int(current_user["id"]))
stmt = select(User).options(selectinload(User.role)).where(User.id == int(current_user["id"]))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return UserResponse.from_orm(user)
return UserResponse(
id=user.id,
email=user.email,
username=user.username,
full_name=user.full_name,
is_active=user.is_active,
is_verified=user.is_verified,
role=user.role.name if user.role else None,
created_at=user.created_at,
)
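The recurring `options(selectinload(User.role))` in this commit is not cosmetic: with `AsyncSession`, evaluating `user.role.name` on an unloaded relationship triggers a lazy load outside the async I/O path and raises `sqlalchemy.exc.MissingGreenlet`. Eager-loading the role up front is the standard fix:

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

def user_by_id_query(User, user_id: int):
    return (
        select(User)
        .options(selectinload(User.role))  # role fetched in a second SELECT
        .where(User.id == user_id)
    )
```

`selectinload` issues a second SELECT for the related rows rather than a JOIN, which keeps the primary query shape unchanged.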
@router.post("/logout")
@@ -407,14 +435,12 @@ async def logout():
@router.post("/verify-token")
async def verify_user_token(
current_user: dict = Depends(get_current_user)
):
async def verify_user_token(current_user: dict = Depends(get_current_user)):
"""Verify if the current token is valid"""
return {
"valid": True,
"user_id": current_user["id"],
"email": current_user["email"]
"email": current_user["email"],
}
@@ -422,32 +448,31 @@ async def verify_user_token(
async def change_password(
password_data: ChangePasswordRequest,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Change user password"""
# Get user from database
stmt = select(User).where(User.id == int(current_user["id"]))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Verify current password
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect"
detail="Current password is incorrect",
)
# Update password
user.hashed_password = get_password_hash(password_data.new_password)
user.updated_at = datetime.utcnow()
await db.commit()
return {"message": "Password changed successfully"}


@@ -129,76 +129,89 @@ async def list_budgets(
budget_type: Optional[BudgetType] = Query(None),
is_enabled: Optional[bool] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List budgets with pagination and filtering"""
# Check permissions - users can view their own budgets
if user_id and int(user_id) != current_user['id']:
if user_id and int(user_id) != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
elif not user_id:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# If no user_id specified and user doesn't have admin permissions, show only their budgets
if not user_id and "platform:budgets:read" not in current_user.get("permissions", []):
user_id = current_user['id']
if not user_id and "platform:budgets:read" not in current_user.get(
"permissions", []
):
user_id = current_user["id"]
# Build query
query = select(Budget)
# Apply filters
if user_id:
query = query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
query = query.where(
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if budget_type:
query = query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
query = query.where(Budget.is_enabled == is_enabled)
# Get total count
count_query = select(func.count(Budget.id))
# Apply same filters to count query
if user_id:
count_query = count_query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
count_query = count_query.where(
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if budget_type:
count_query = count_query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
count_query = count_query.where(Budget.is_enabled == is_enabled)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size).order_by(Budget.created_at.desc())
# Execute query
result = await db.execute(query)
budgets = result.scalars().all()
# Calculate current usage for each budget
budget_responses = []
for budget in budgets:
usage = await _calculate_budget_usage(db, budget)
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_data.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
budget_responses.append(budget_data)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_budgets",
resource_type="budget",
details={"page": page, "size": size, "filters": {"user_id": user_id, "budget_type": budget_type, "is_enabled": is_enabled}}
details={
"page": page,
"size": size,
"filters": {
"user_id": user_id,
"budget_type": budget_type,
"is_enabled": is_enabled,
},
},
)
return BudgetListResponse(
budgets=budget_responses,
total=total,
page=page,
size=size
budgets=budget_responses, total=total, page=page, size=size
)
@@ -206,42 +219,43 @@ async def list_budgets(
async def get_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get budget by ID"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budgets
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate current usage
usage = await _calculate_budget_usage(db, budget)
# Build response
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_data.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_budget",
resource_type="budget",
resource_id=budget_id
resource_id=budget_id,
)
return budget_data
@@ -249,24 +263,30 @@ async def get_budget(
async def create_budget(
budget_data: BudgetCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new budget"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:budgets:create")
# If user_id not specified, use current user
target_user_id = budget_data.user_id or current_user['id']
target_user_id = budget_data.user_id or current_user["id"]
# If setting budget for another user, need admin permissions
if int(target_user_id) != current_user['id'] if isinstance(target_user_id, str) else target_user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:admin")
if (
int(target_user_id) != current_user["id"]
if isinstance(target_user_id, str)
else target_user_id != current_user["id"]
):
require_permission(
current_user.get("permissions", []), "platform:budgets:admin"
)
# Calculate period start and end
now = datetime.utcnow()
period_start, period_end = _calculate_period_bounds(now, budget_data.period_type)
# Create budget
new_budget = Budget(
name=budget_data.name,
@@ -281,30 +301,34 @@ async def create_budget(
is_enabled=budget_data.is_enabled,
alert_threshold_percent=budget_data.alert_threshold_percent,
allowed_resources=budget_data.allowed_resources,
metadata=budget_data.metadata
metadata=budget_data.metadata,
)
db.add(new_budget)
await db.commit()
await db.refresh(new_budget)
# Build response
budget_response = BudgetResponse.model_validate(new_budget)
budget_response.current_usage = 0.0
budget_response.usage_percentage = 0.0
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="create_budget",
resource_type="budget",
resource_id=str(new_budget.id),
details={"name": budget_data.name, "budget_type": budget_data.budget_type, "limit_amount": budget_data.limit_amount}
details={
"name": budget_data.name,
"budget_type": budget_data.budget_type,
"limit_amount": budget_data.limit_amount,
},
)
logger.info(f"Budget created: {new_budget.name} by {current_user['username']}")
return budget_response
@@ -313,70 +337,75 @@ async def update_budget(
budget_id: str,
budget_data: BudgetUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update budget"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can update their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:update")
if budget.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:budgets:update"
)
# Store original values for audit
original_values = {
"name": budget.name,
"limit_amount": budget.limit_amount,
"is_enabled": budget.is_enabled
"is_enabled": budget.is_enabled,
}
# Update budget fields
update_data = budget_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(budget, field, value)
# Recalculate period if period_type changed
if "period_type" in update_data:
period_start, period_end = _calculate_period_bounds(datetime.utcnow(), budget.period_type)
period_start, period_end = _calculate_period_bounds(
datetime.utcnow(), budget.period_type
)
budget.period_start = period_start
budget.period_end = period_end
await db.commit()
await db.refresh(budget)
# Calculate current usage
usage = await _calculate_budget_usage(db, budget)
# Build response
budget_response = BudgetResponse.model_validate(budget)
budget_response.current_usage = usage
budget_response.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_response.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_budget",
resource_type="budget",
resource_id=budget_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(budget, k) for k in update_data.keys()}
}
"after_values": {k: getattr(budget, k) for k in update_data.keys()},
},
)
logger.info(f"Budget updated: {budget.name} by {current_user['username']}")
return budget_response
@@ -384,41 +413,42 @@ async def update_budget(
async def delete_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Delete budget"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can delete their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:delete")
if budget.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:budgets:delete"
)
# Delete budget
await db.delete(budget)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="delete_budget",
resource_type="budget",
resource_id=budget_id,
details={"name": budget.name}
details={"name": budget.name},
)
logger.info(f"Budget deleted: {budget.name} by {current_user['username']}")
return {"message": "Budget deleted successfully"}
@@ -426,35 +456,36 @@ async def delete_budget(
async def get_budget_usage(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get detailed budget usage information"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budget usage
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
usage_percentage = (
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
remaining_amount = max(0, budget.limit_amount - current_usage)
is_exceeded = current_usage > budget.limit_amount
# Calculate days remaining in period
now = datetime.utcnow()
days_remaining = max(0, (budget.period_end - now).days)
# Calculate projected usage
projected_usage = None
if days_remaining > 0 and current_usage > 0:
@@ -463,19 +494,19 @@ async def get_budget_usage(
daily_rate = current_usage / days_elapsed
total_days = (budget.period_end - budget.period_start).days + 1
projected_usage = daily_rate * total_days
# Get usage history (last 30 days)
usage_history = await _get_usage_history(db, budget, days=30)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_budget_usage",
resource_type="budget",
resource_id=budget_id
resource_id=budget_id,
)
return BudgetUsageResponse(
budget_id=budget_id,
current_usage=current_usage,
@@ -487,7 +518,7 @@ async def get_budget_usage(
is_exceeded=is_exceeded,
days_remaining=days_remaining,
projected_usage=projected_usage,
usage_history=usage_history
usage_history=usage_history,
)
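The projection is a linear extrapolation of the average daily burn rate over the full period. Worked instance, with the arithmetic lifted out of the handler (names are illustrative):

```python
def project_usage(current_usage: float, days_elapsed: int, total_days: int) -> float:
    daily_rate = current_usage / days_elapsed
    return daily_rate * total_days

# Ten days into a 30-day period with $40 spent projects to $120 by period end.
assert project_usage(40.0, 10, 30) == 120.0
```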
@@ -495,85 +526,92 @@ async def get_budget_usage(
async def get_budget_alerts(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get budget alerts"""
# Get budget
query = select(Budget).where(Budget.id == budget_id)
result = await db.execute(query)
budget = result.scalar_one_or_none()
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budget alerts
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
usage_percentage = (
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
alerts = []
# Check for alerts
if usage_percentage >= 100:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="exceeded",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)"
))
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="exceeded",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)",
)
)
elif usage_percentage >= 90:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="critical",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)"
))
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="critical",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)",
)
)
elif usage_percentage >= budget.alert_threshold_percent:
alerts.append(BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="warning",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)"
))
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="warning",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)",
)
)
return alerts
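The branches above fall through from most to least severe, so each budget yields at most one alert. A compact restatement of the tiering, assuming an alert threshold of 80%:

def classify(usage_percentage: float, alert_threshold_percent: float = 80.0):
    # Mirrors the branch order above: exceeded > critical > warning.
    if usage_percentage >= 100:
        return "exceeded"
    if usage_percentage >= 90:
        return "critical"
    if usage_percentage >= alert_threshold_percent:
        return "warning"
    return None

# classify(105.0) -> "exceeded"; classify(92.5) -> "critical"
# classify(85.0)  -> "warning";  classify(50.0) -> None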
# Helper functions
async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
"""Calculate current usage for a budget"""
# Build base query
query = select(UsageTracking)
# Filter by time period
query = query.where(
UsageTracking.created_at >= budget.period_start,
UsageTracking.created_at <= budget.period_end
UsageTracking.created_at <= budget.period_end,
)
# Filter by user or API key
if budget.api_key_id:
query = query.where(UsageTracking.api_key_id == budget.api_key_id)
elif budget.user_id:
query = query.where(UsageTracking.user_id == budget.user_id)
# Calculate usage based on budget type
if budget.budget_type == "tokens":
usage_query = query.with_only_columns(func.sum(UsageTracking.total_tokens))
@@ -583,20 +621,22 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
usage_query = query.with_only_columns(func.count(UsageTracking.id))
else:
return 0.0
result = await db.execute(usage_query)
usage = result.scalar() or 0
# Convert cents to dollars for dollar budgets
if budget.budget_type == "dollars":
usage = usage / 100.0
return float(usage)
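Since dollar costs are aggregated as integer cents, the conversion above reads back as follows (aggregate value assumed for illustration):

# Assumed aggregate for the period: SUM(cost_cents) == 1234
usage = 1234
usage = usage / 100.0  # dollar budgets report 12.34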
def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[datetime, datetime]:
def _calculate_period_bounds(
current_time: datetime, period_type: str
) -> tuple[datetime, datetime]:
"""Calculate period start and end dates"""
if period_type == "hourly":
start = current_time.replace(minute=0, second=0, microsecond=0)
end = start + timedelta(hours=1) - timedelta(microseconds=1)
@@ -606,7 +646,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
elif period_type == "weekly":
# Start of week (Monday)
days_since_monday = current_time.weekday()
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=days_since_monday)
start = current_time.replace(
hour=0, minute=0, second=0, microsecond=0
) - timedelta(days=days_since_monday)
end = start + timedelta(weeks=1) - timedelta(microseconds=1)
elif period_type == "monthly":
start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
@@ -616,44 +658,49 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
next_month = start.replace(month=start.month + 1)
end = next_month - timedelta(microseconds=1)
elif period_type == "yearly":
start = current_time.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
start = current_time.replace(
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
end = start.replace(year=start.year + 1) - timedelta(microseconds=1)
else:
# Default to daily
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) - timedelta(microseconds=1)
return start, end
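A minimal check of the helper's behavior, assuming it is imported as defined above; note that each period ends one microsecond before the next one begins:

from datetime import datetime

now = datetime(2025, 3, 15, 14, 30)

start, end = _calculate_period_bounds(now, "monthly")
# start == datetime(2025, 3, 1, 0, 0, 0)
# end   == datetime(2025, 3, 31, 23, 59, 59, 999999)

start, end = _calculate_period_bounds(now, "weekly")
# start == datetime(2025, 3, 10, 0, 0, 0)  (Monday of that week)
# end   == datetime(2025, 3, 16, 23, 59, 59, 999999)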
async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -> List[dict]:
async def _get_usage_history(
db: AsyncSession, budget: Budget, days: int = 30
) -> List[dict]:
"""Get usage history for the budget"""
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days)
# Build query
query = select(
func.date(UsageTracking.created_at).label('date'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents'),
func.count(UsageTracking.id).label('requests')
func.date(UsageTracking.created_at).label("date"),
func.sum(UsageTracking.total_tokens).label("tokens"),
func.sum(UsageTracking.cost_cents).label("cost_cents"),
func.count(UsageTracking.id).label("requests"),
).where(
UsageTracking.created_at >= start_date,
UsageTracking.created_at <= end_date
UsageTracking.created_at >= start_date, UsageTracking.created_at <= end_date
)
# Filter by user or API key
if budget.api_key_id:
query = query.where(UsageTracking.api_key_id == budget.api_key_id)
elif budget.user_id:
query = query.where(UsageTracking.user_id == budget.user_id)
query = query.group_by(func.date(UsageTracking.created_at)).order_by(func.date(UsageTracking.created_at))
query = query.group_by(func.date(UsageTracking.created_at)).order_by(
func.date(UsageTracking.created_at)
)
result = await db.execute(query)
rows = result.fetchall()
history = []
for row in rows:
usage_value = 0
@@ -663,13 +710,15 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
usage_value = (row.cost_cents or 0) / 100.0
elif budget.budget_type == "requests":
usage_value = row.requests or 0
history.append({
"date": row.date.isoformat(),
"usage": usage_value,
"tokens": row.tokens or 0,
"cost_dollars": (row.cost_cents or 0) / 100.0,
"requests": row.requests or 0
})
return history
history.append(
{
"date": row.date.isoformat(),
"usage": usage_value,
"tokens": row.tokens or 0,
"cost_dollars": (row.cost_cents or 0) / 100.0,
"requests": row.requests or 0,
}
)
return history
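For a dollar budget, each row in the returned history then takes this shape (figures are illustrative):

sample_row = {
    "date": "2025-01-14",
    "usage": 12.34,  # equals cost_dollars when budget_type == "dollars"
    "tokens": 45210,
    "cost_dollars": 12.34,
    "requests": 87,
}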

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
"""
API v1 endpoints package
"""

View File

@@ -0,0 +1,166 @@
"""
Tool calling API endpoints
Integration between LLM and tool execution
"""
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_calling_service import ToolCallingService
from app.services.llm.models import ChatRequest, ChatResponse
from app.schemas.tool_calling import (
ToolCallRequest,
ToolCallResponse,
ToolExecutionRequest,
ToolValidationRequest,
ToolValidationResponse,
ToolHistoryResponse,
)
router = APIRouter()
@router.post("/chat/completions", response_model=ChatResponse)
async def create_chat_completion_with_tools(
request: ChatRequest,
auto_execute_tools: bool = Query(
True, description="Whether to automatically execute tool calls"
),
max_tool_calls: int = Query(
5, ge=1, le=10, description="Maximum number of tool calls"
),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create chat completion with tool calling support"""
service = ToolCallingService(db)
# Resolve user ID for context
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Set user context in request
request.user_id = str(user_id)
request.api_key_id = 1 # Default for internal usage
response = await service.create_chat_completion_with_tools(
request=request,
user=current_user,
auto_execute_tools=auto_execute_tools,
max_tool_calls=max_tool_calls,
)
return response
@router.post("/chat/completions/stream")
async def create_chat_completion_stream_with_tools(
request: ChatRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create streaming chat completion with tool calling support"""
service = ToolCallingService(db)
# Resolve user ID for context
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Set user context in request
request.user_id = str(user_id)
request.api_key_id = 1 # Default for internal usage
async def stream_generator():
async for chunk in service.create_chat_completion_stream_with_tools(
request=request, user=current_user
):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
stream_generator(),
media_type="text/plain",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
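A minimal client sketch for consuming the stream above, assuming the router is mounted at /api/v1/tool-calling and a bearer token is available (both are assumptions, not shown in this diff):

import httpx

async def consume_tool_stream(payload: dict, token: str) -> None:
    # Read server-sent chunks until the [DONE] sentinel emitted above.
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/api/v1/tool-calling/chat/completions/stream",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                print(data)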
@router.post("/execute", response_model=ToolCallResponse)
async def execute_tool_by_name(
request: ToolExecutionRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a tool by name directly"""
service = ToolCallingService(db)
try:
result = await service.execute_tool_by_name(
tool_name=request.tool_name,
parameters=request.parameters,
user=current_user,
)
return ToolCallResponse(success=True, result=result, error=None)
except Exception as e:
return ToolCallResponse(success=False, result=None, error=str(e))
@router.post("/validate", response_model=ToolValidationResponse)
async def validate_tool_availability(
request: ToolValidationRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Validate which tools are available to the user"""
service = ToolCallingService(db)
availability = await service.validate_tool_availability(
tool_names=request.tool_names, user=current_user
)
return ToolValidationResponse(tool_availability=availability)
@router.get("/history", response_model=ToolHistoryResponse)
async def get_tool_call_history(
limit: int = Query(
50, ge=1, le=100, description="Number of history items to return"
),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get recent tool execution history for the user"""
service = ToolCallingService(db)
history = await service.get_tool_call_history(user=current_user, limit=limit)
return ToolHistoryResponse(history=history, total=len(history))
@router.get("/available")
async def get_available_tools(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tools available for function calling"""
service = ToolCallingService(db)
# Get available tools
tools = await service._get_available_tools_for_user(current_user)
# Convert to OpenAI format
openai_tools = await service._convert_tools_to_openai_format(tools)
return {"tools": openai_tools, "count": len(openai_tools)}
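Taken together, a non-streaming call with automatic tool execution might look like the following sketch; the host, mount path, model name, and token are illustrative assumptions:

import httpx

def chat_with_tools(token: str) -> dict:
    # auto_execute_tools / max_tool_calls map to the query parameters above.
    response = httpx.post(
        "http://localhost:8000/api/v1/tool-calling/chat/completions",
        params={"auto_execute_tools": True, "max_tool_calls": 3},
        json={
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        },
        headers={"Authorization": f"Bearer {token}"},
        timeout=60.0,
    )
    response.raise_for_status()
    return response.json()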

View File

@@ -0,0 +1,386 @@
"""
Tool management and execution API endpoints
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_management_service import ToolManagementService
from app.services.tool_execution_service import ToolExecutionService
from app.schemas.tool import (
ToolCreate,
ToolUpdate,
ToolResponse,
ToolListResponse,
ToolExecutionCreate,
ToolExecutionResponse,
ToolExecutionListResponse,
ToolCategoryCreate,
ToolCategoryResponse,
ToolStatisticsResponse,
)
router = APIRouter()
@router.post("/", response_model=ToolResponse)
async def create_tool(
tool_data: ToolCreate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create a new tool"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.create_tool(
name=tool_data.name,
display_name=tool_data.display_name,
code=tool_data.code,
tool_type=tool_data.tool_type,
created_by_user_id=user_id,
description=tool_data.description,
parameters_schema=tool_data.parameters_schema,
return_schema=tool_data.return_schema,
timeout_seconds=tool_data.timeout_seconds,
max_memory_mb=tool_data.max_memory_mb,
max_cpu_seconds=tool_data.max_cpu_seconds,
docker_image=tool_data.docker_image,
docker_command=tool_data.docker_command,
category=tool_data.category,
tags=tool_data.tags,
is_public=tool_data.is_public,
)
return ToolResponse.from_orm(tool)
@router.get("/", response_model=ToolListResponse)
async def list_tools(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
category: Optional[str] = Query(None),
tool_type: Optional[str] = Query(None),
is_public: Optional[bool] = Query(None),
is_approved: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
tags: Optional[List[str]] = Query(None),
created_by_user_id: Optional[int] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List tools with filtering and pagination"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tools = await service.get_tools(
user_id=user_id,
skip=skip,
limit=limit,
category=category,
tool_type=tool_type,
is_public=is_public,
is_approved=is_approved,
search=search,
tags=tags,
created_by_user_id=created_by_user_id,
)
return ToolListResponse(
tools=[ToolResponse.from_orm(tool) for tool in tools],
total=len(tools),
skip=skip,
limit=limit,
)
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tool by ID"""
service = ToolManagementService(db)
tool = await service.get_tool_by_id(tool_id)
if not tool:
raise HTTPException(status_code=404, detail="Tool not found")
# Resolve underlying User object if available
user_obj = (
current_user.get("user_obj")
if isinstance(current_user, dict)
else current_user
)
# Check if user can access this tool
if not user_obj or not tool.can_be_used_by(user_obj):
raise HTTPException(status_code=403, detail="Access denied to this tool")
return ToolResponse.from_orm(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
tool_id: int,
tool_data: ToolUpdate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update tool (only by creator or admin)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.update_tool(
tool_id=tool_id,
user_id=user_id,
display_name=tool_data.display_name,
description=tool_data.description,
code=tool_data.code,
parameters_schema=tool_data.parameters_schema,
return_schema=tool_data.return_schema,
timeout_seconds=tool_data.timeout_seconds,
max_memory_mb=tool_data.max_memory_mb,
max_cpu_seconds=tool_data.max_cpu_seconds,
docker_image=tool_data.docker_image,
docker_command=tool_data.docker_command,
category=tool_data.category,
tags=tool_data.tags,
is_public=tool_data.is_public,
is_active=tool_data.is_active,
)
return ToolResponse.from_orm(tool)
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Delete tool (only by creator or admin)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
await service.delete_tool(tool_id, user_id)
return {"message": "Tool deleted successfully"}
@router.post("/{tool_id}/approve", response_model=ToolResponse)
async def approve_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Approve tool for public use (admin only)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.approve_tool(tool_id, user_id)
return ToolResponse.from_orm(tool)
# Tool Execution Endpoints
@router.post("/{tool_id}/execute", response_model=ToolExecutionResponse)
async def execute_tool(
tool_id: int,
execution_data: ToolExecutionCreate,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a tool with given parameters"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
execution = await service.execute_tool(
tool_id=tool_id,
user_id=user_id,
parameters=execution_data.parameters,
timeout_override=execution_data.timeout_override,
)
return ToolExecutionResponse.from_orm(execution)
@router.get("/executions", response_model=ToolExecutionListResponse)
async def list_executions(
tool_id: Optional[int] = Query(None),
executed_by_user_id: Optional[int] = Query(None),
status: Optional[str] = Query(None),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List tool executions with filtering"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
executions = await service.get_tool_executions(
tool_id=tool_id,
user_id=user_id,
executed_by_user_id=executed_by_user_id,
status=status,
skip=skip,
limit=limit,
)
return ToolExecutionListResponse(
executions=[
ToolExecutionResponse.from_orm(execution) for execution in executions
],
total=len(executions),
skip=skip,
limit=limit,
)
@router.get("/executions/{execution_id}", response_model=ToolExecutionResponse)
async def get_execution(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get execution details"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Fetch through the filtered list so the permission scoping in
# get_tool_executions still applies; the limit must cover enough rows
# for the id lookup below to find the requested execution
executions = await service.get_tool_executions(
user_id=user_id, skip=0, limit=1000
)
execution = next((e for e in executions if e.id == execution_id), None)
if not execution:
raise HTTPException(status_code=404, detail="Execution not found")
return ToolExecutionResponse.from_orm(execution)
@router.post("/executions/{execution_id}/cancel", response_model=ToolExecutionResponse)
async def cancel_execution(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Cancel a running execution"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
execution = await service.cancel_execution(execution_id, user_id)
return ToolExecutionResponse.from_orm(execution)
@router.get("/executions/{execution_id}/logs")
async def get_execution_logs(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get real-time logs for execution"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
logs = await service.get_execution_logs(execution_id, user_id)
return logs
# Tool Categories
@router.post("/categories", response_model=ToolCategoryResponse)
async def create_category(
category_data: ToolCategoryCreate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create a new tool category (admin only)"""
user_obj = (
current_user.get("user_obj")
if isinstance(current_user, dict)
else current_user
)
if not user_obj or not user_obj.has_permission("manage_tools"):
raise HTTPException(status_code=403, detail="Admin privileges required")
service = ToolManagementService(db)
category = await service.create_category(
name=category_data.name,
display_name=category_data.display_name,
description=category_data.description,
icon=category_data.icon,
color=category_data.color,
sort_order=category_data.sort_order,
)
return ToolCategoryResponse.from_orm(category)
@router.get("/categories", response_model=List[ToolCategoryResponse])
async def list_categories(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List all active tool categories"""
service = ToolManagementService(db)
categories = await service.get_categories()
return [ToolCategoryResponse.from_orm(category) for category in categories]
# Statistics
@router.get("/statistics", response_model=ToolStatisticsResponse)
async def get_statistics(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tool usage statistics"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
stats = await service.get_tool_statistics(user_id=user_id)
return ToolStatisticsResponse(**stats)
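End to end, creating and then running a tool through these endpoints could look like the sketch below; the URL prefix, field values, and parameter schema are illustrative assumptions, not values taken from this diff:

import httpx

BASE = "http://localhost:8000/api/v1/tools"  # assumed mount point

def create_and_run(token: str) -> dict:
    headers = {"Authorization": f"Bearer {token}"}
    # Create a small Python tool; schema fields follow ToolCreate above.
    tool = httpx.post(
        f"{BASE}/",
        json={
            "name": "word_count",
            "display_name": "Word Count",
            "tool_type": "python",
            "code": "def run(text: str) -> int:\n    return len(text.split())",
            "parameters_schema": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
            "is_public": False,
        },
        headers=headers,
    ).json()
    # Execute it with parameters matching the declared schema.
    execution = httpx.post(
        f"{BASE}/{tool['id']}/execute",
        json={"parameters": {"text": "one two three"}},
        headers=headers,
    ).json()
    return execution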

View File

@@ -0,0 +1,703 @@
"""
User Management API endpoints
Admin endpoints for managing users, roles, and audit logs
"""
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, validator, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import get_current_user
from app.db.database import get_db
from app.models.user import User
from app.models.role import Role
from app.models.audit_log import AuditLog
from app.services.user_management_service import UserManagementService
from app.services.permission_manager import require_permission
from app.schemas.role import RoleCreate, RoleUpdate
logger = logging.getLogger(__name__)
router = APIRouter()
security = HTTPBearer()
# Request/Response Models
class CreateUserRequest(BaseModel):
email: EmailStr
username: str
password: str
full_name: Optional[str] = None
role_id: Optional[int] = None
is_active: bool = True
force_password_change: bool = False
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if not v.replace("_", "").replace("-", "").isalnum():
raise ValueError("Username must contain only alphanumeric characters, underscores, and hyphens")
return v
class UpdateUserRequest(BaseModel):
email: Optional[EmailStr] = None
username: Optional[str] = None
full_name: Optional[str] = None
role_id: Optional[int] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
custom_permissions: Optional[Dict[str, Any]] = None
@validator("username")
def validate_username(cls, v):
if v is not None and len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
return v
class AdminPasswordResetRequest(BaseModel):
new_password: str
force_change_on_login: bool = True
@validator("new_password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
class UserResponse(BaseModel):
id: int
email: str
username: str
full_name: Optional[str]
role_id: Optional[int]
role: Optional[Dict[str, Any]]
is_active: bool
is_verified: bool
account_locked: bool
force_password_change: bool
created_at: datetime
updated_at: datetime
last_login: Optional[datetime]
audit_summary: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
class UserListResponse(BaseModel):
users: List[UserResponse]
total: int
skip: int
limit: int
class RoleResponse(BaseModel):
id: int
name: str
display_name: str
description: Optional[str]
level: str
is_active: bool
created_at: datetime
class Config:
from_attributes = True
class AuditLogResponse(BaseModel):
id: int
user_id: Optional[int]
action: str
resource_type: str
resource_id: Optional[str]
description: str
details: Dict[str, Any]
severity: str
category: Optional[str]
success: bool
created_at: datetime
class Config:
from_attributes = True
# User Management Endpoints
@router.get("/users", response_model=UserListResponse)
async def get_users(
request: Request,
skip: int = 0,
limit: int = 100,
search: Optional[str] = None,
role_id: Optional[int] = None,
is_active: Optional[bool] = None,
include_audit_summary: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all users with filtering options"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
context={"user_id": current_user["id"]}
)
service = UserManagementService(db)
users_data = await service.get_users(
skip=skip,
limit=limit,
search=search,
role_id=role_id,
is_active=is_active,
)
# Convert to response format
users = []
for user in users_data:
user_dict = user.to_dict()
if include_audit_summary:
# Get audit summary for user
audit_logs = await service.get_user_audit_logs(user.id, limit=10)
user_dict["audit_summary"] = {
"recent_actions": len(audit_logs),
"last_login": user.last_login.isoformat() if user.last_login else None,
}
user_response = UserResponse(**user_dict)
users.append(user_response)
return UserListResponse(
users=users,
total=len(users), # Would need actual count query for large datasets
skip=skip,
limit=limit,
)
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
user_id: int,
include_audit_summary: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get specific user by ID"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
user = await service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
user_dict = user.to_dict()
if include_audit_summary:
# Get audit summary for user
audit_logs = await service.get_user_audit_logs(user_id, limit=10)
user_dict["audit_summary"] = {
"recent_actions": len(audit_logs),
"last_login": user.last_login.isoformat() if user.last_login else None,
}
return UserResponse(**user_dict)
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
user_data: CreateUserRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:create",
)
service = UserManagementService(db)
user = await service.create_user(
email=user_data.email,
username=user_data.username,
password=user_data.password,
full_name=user_data.full_name,
role_id=user_data.role_id,
is_active=user_data.is_active,
is_verified=True, # Admin-created users are verified by default
custom_permissions={}, # Empty by default
)
# Log user creation in audit
await service._log_audit_event(
user_id=current_user["id"],
action="create",
resource_type="user",
resource_id=str(user.id),
description=f"User created by admin: {user.email}",
details={
"created_by": current_user["email"],
"target_user": user.email,
"role_id": user_data.role_id,
},
severity="medium",
)
user_dict = user.to_dict()
return UserResponse(**user_dict)
@router.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
user_id: int,
user_data: UpdateUserRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Update user information"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:update",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
# Get current user for audit comparison
current_user_data = await service.get_user_by_id(user_id)
if not current_user_data:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
old_values = current_user_data.to_dict()
# Update user
user = await service.update_user(
user_id=user_id,
email=user_data.email,
username=user_data.username,
full_name=user_data.full_name,
role_id=user_data.role_id,
is_active=user_data.is_active,
is_verified=user_data.is_verified,
custom_permissions=user_data.custom_permissions,
)
# Log user update in audit
await service._log_audit_event(
user_id=current_user["id"],
action="update",
resource_type="user",
resource_id=str(user_id),
description=f"User updated by admin: {user.email}",
details={
"updated_by": current_user["email"],
"target_user": user.email,
},
old_values=old_values,
new_values=user.to_dict(),
severity="medium",
)
user_dict = user.to_dict()
return UserResponse(**user_dict)
@router.post("/users/{user_id}/password-reset")
async def admin_reset_password(
user_id: int,
password_data: AdminPasswordResetRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Admin reset user password with forced change option"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:manage",
)
service = UserManagementService(db)
user = await service.admin_reset_user_password(
user_id=user_id,
new_password=password_data.new_password,
force_change_on_login=password_data.force_change_on_login,
admin_user_id=current_user["id"],
)
return {
"message": f"Password reset for user {user.email}",
"force_change_on_login": password_data.force_change_on_login,
}
@router.delete("/users/{user_id}")
async def delete_user(
user_id: int,
hard_delete: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete or deactivate user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:delete",
)
service = UserManagementService(db)
# Get user info for audit before deletion
user = await service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
user_email = user.email
# Delete user
success = await service.delete_user(user_id, hard_delete=hard_delete)
# Log user deletion in audit
await service._log_audit_event(
user_id=current_user["id"],
action="delete",
resource_type="user",
resource_id=str(user_id),
description=f"User {'hard deleted' if hard_delete else 'deactivated'} by admin: {user_email}",
details={
"deleted_by": current_user["email"],
"target_user": user_email,
"hard_delete": hard_delete,
},
severity="high",
)
return {
"message": f"User {'deleted' if hard_delete else 'deactivated'} successfully",
"user_email": user_email,
}
# Role Management Endpoints
@router.get("/roles", response_model=List[RoleResponse])
async def get_roles(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all available roles"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:read",
)
service = UserManagementService(db)
roles = await service.get_roles(is_active=True)
return [RoleResponse(**role.to_dict()) for role in roles]
@router.post("/roles", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
async def create_role(
role_data: RoleCreate,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:create",
)
service = UserManagementService(db)
# Create role
role = await service.create_role(
name=role_data.name,
display_name=role_data.display_name,
description=role_data.description,
level=role_data.level,
permissions=role_data.permissions,
can_manage_users=role_data.can_manage_users,
can_manage_budgets=role_data.can_manage_budgets,
can_view_reports=role_data.can_view_reports,
can_manage_tools=role_data.can_manage_tools,
inherits_from=role_data.inherits_from,
is_active=role_data.is_active,
is_system_role=role_data.is_system_role,
)
# Log role creation
await service._log_audit_event(
user_id=current_user["id"],
action="create",
resource_type="role",
resource_id=str(role.id),
description=f"Role created: {role.name}",
details={
"created_by": current_user["email"],
"role_name": role.name,
"level": role.level,
},
severity="medium",
)
return RoleResponse(**role.to_dict())
@router.put("/roles/{role_id}", response_model=RoleResponse)
async def update_role(
role_id: int,
role_data: RoleUpdate,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Update a role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:update",
)
service = UserManagementService(db)
# Get current role for audit
current_role = await service.get_role_by_id(role_id)
if not current_role:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Role not found"
)
# Prevent updating system roles
if current_role.is_system_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot modify system roles"
)
old_values = current_role.to_dict()
# Update role
role = await service.update_role(
role_id=role_id,
display_name=role_data.display_name,
description=role_data.description,
permissions=role_data.permissions,
can_manage_users=role_data.can_manage_users,
can_manage_budgets=role_data.can_manage_budgets,
can_view_reports=role_data.can_view_reports,
can_manage_tools=role_data.can_manage_tools,
is_active=role_data.is_active,
)
# Log role update
await service._log_audit_event(
user_id=current_user["id"],
action="update",
resource_type="role",
resource_id=str(role_id),
description=f"Role updated: {role.name}",
details={
"updated_by": current_user["email"],
"role_name": role.name,
},
old_values=old_values,
new_values=role.to_dict(),
severity="medium",
)
return RoleResponse(**role.to_dict())
@router.delete("/roles/{role_id}")
async def delete_role(
role_id: int,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:delete",
)
service = UserManagementService(db)
# Get role info for audit before deletion
role = await service.get_role_by_id(role_id)
if not role:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Role not found"
)
# Prevent deleting system roles
if role.is_system_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete system roles"
)
role_name = role.name
# Delete role
success = await service.delete_role(role_id)
# Log role deletion
await service._log_audit_event(
user_id=current_user["id"],
action="delete",
resource_type="role",
resource_id=str(role_id),
description=f"Role deleted: {role_name}",
details={
"deleted_by": current_user["email"],
"role_name": role_name,
},
severity="high",
)
return {
"message": f"Role {role_name} deleted successfully",
"role_name": role_name,
}
# Audit Log Endpoints
@router.get("/users/{user_id}/audit-logs", response_model=List[AuditLogResponse])
async def get_user_audit_logs(
user_id: int,
skip: int = 0,
limit: int = 50,
action_filter: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get audit logs for a specific user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:audit:read",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
audit_logs = await service.get_user_audit_logs(
user_id=user_id,
skip=skip,
limit=limit,
action_filter=action_filter,
)
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
@router.get("/audit-logs", response_model=List[AuditLogResponse])
async def get_all_audit_logs(
skip: int = 0,
limit: int = 100,
user_id: Optional[int] = None,
action_filter: Optional[str] = None,
category_filter: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all audit logs with filtering"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:audit:read",
)
# Direct database query for audit logs with filters
from sqlalchemy import select, and_, desc
query = select(AuditLog)
conditions = []
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action_filter:
conditions.append(AuditLog.action == action_filter)
if category_filter:
conditions.append(AuditLog.category == category_filter)
if conditions:
query = query.where(and_(*conditions))
query = query.order_by(desc(AuditLog.created_at))
query = query.offset(skip).limit(limit)
result = await db.execute(query)
audit_logs = result.scalars().all()
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
# Statistics Endpoints
@router.get("/statistics")
async def get_user_management_statistics(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get user management statistics"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
)
service = UserManagementService(db)
user_stats = await service.get_user_statistics()
role_stats = await service.get_role_statistics()
return {
"users": user_stats,
"roles": role_stats,
"generated_at": datetime.utcnow().isoformat(),
}
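As a usage sketch, an admin creating a user through the endpoint above; the base path, credentials, and role_id are assumptions for illustration:

import httpx

def create_user(admin_token: str) -> dict:
    response = httpx.post(
        "http://localhost:8000/api/v1/admin/users",  # assumed mount point
        json={
            "email": "jane@example.com",
            "username": "jane_doe",  # passes the 3-char / charset validator
            "password": "s3cure-pass!",  # passes the 8-char minimum
            "full_name": "Jane Doe",
            "role_id": 2,
            "force_password_change": True,
        },
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    response.raise_for_status()  # expects 201 Created
    return response.json()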

View File

@@ -12,16 +12,33 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context
from app.services.api_key_auth import (
require_api_key,
RequireScope,
APIKeyAuthService,
get_api_key_context,
)
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.llm.models import (
ChatRequest,
ChatMessage as LLMChatMessage,
EmbeddingRequest as LLMEmbeddingRequest,
)
from app.services.llm.exceptions import (
LLMError,
ProviderError,
SecurityError,
ValidationError,
)
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
check_budget_for_request,
record_request_usage,
BudgetEnforcementService,
atomic_check_and_reserve_budget,
atomic_finalize_usage,
)
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.utils.exceptions import AuthenticationError, AuthorizationError
@@ -30,11 +47,7 @@ from app.middleware.analytics import set_analytics_data
logger = logging.getLogger(__name__)
# Models response cache - simple in-memory cache for performance
_models_cache = {
"data": None,
"cached_at": 0,
"cache_ttl": 900 # 15 minutes cache TTL
}
_models_cache = {"data": None, "cached_at": 0, "cache_ttl": 900} # 15 minutes cache TTL
router = APIRouter()
@@ -42,18 +55,20 @@ router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LLM service if cache is stale"""
current_time = time.time()
# Check if cache is still valid
if (_models_cache["data"] is not None and
current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]):
if (
_models_cache["data"] is not None
and current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]
):
logger.debug("Returning cached models list")
return _models_cache["data"]
# Cache miss or stale - fetch from LLM service
try:
logger.debug("Fetching fresh models list from LLM service")
model_infos = await llm_service.get_models()
# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
@@ -63,32 +78,36 @@ async def get_cached_models() -> List[Dict[str, Any]]:
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by,
# Add frontend-expected fields
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
"provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
"name": getattr(
model_info, "name", model_info.id
), # Use name if available, fallback to id
"provider": getattr(
model_info, "provider", model_info.owned_by
), # Use provider if available, fallback to owned_by
"capabilities": model_info.capabilities,
"context_window": model_info.context_window,
"max_output_tokens": model_info.max_output_tokens,
"supports_streaming": model_info.supports_streaming,
"supports_function_calling": model_info.supports_function_calling
"supports_function_calling": model_info.supports_function_calling,
}
# Include tasks field if present
if model_info.tasks:
model_dict["tasks"] = model_info.tasks
models.append(model_dict)
# Update cache
_models_cache["data"] = models
_models_cache["cached_at"] = current_time
return models
except Exception as e:
logger.error(f"Failed to fetch models from LLM service: {e}")
# Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None:
logger.warning("Returning stale cached models due to fetch error")
return _models_cache["data"]
return []
@@ -138,11 +157,12 @@ class ModelsResponse(BaseModel):
# Authentication: Public API endpoints should use require_api_key
# Internal API endpoints should use get_current_user from core.security
# Endpoints
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List available models"""
try:
@@ -155,33 +175,35 @@ async def list_models(
if not await auth_service.check_scope_permission(context, "models.list"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to list models"
detail="Insufficient permissions to list models",
)
# Get models from cache or LLM service
models = await get_cached_models()
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
models = [
model for model in models if model.get("id") in api_key.allowed_models
]
return ModelsResponse(data=models)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error listing models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to list models"
detail="Failed to list models",
)
@router.post("/models/invalidate-cache")
async def invalidate_models_cache_endpoint(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Invalidate models cache (admin only)"""
# Check for admin permissions
@@ -190,7 +212,7 @@ async def invalidate_models_cache_endpoint(
if not user or not user.get("is_superuser"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
detail="Admin privileges required",
)
else:
# For API key users, check admin permissions
@@ -198,9 +220,9 @@ async def invalidate_models_cache_endpoint(
if not await auth_service.check_scope_permission(context, "admin.cache"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to invalidate cache"
detail="Admin permissions required to invalidate cache",
)
invalidate_models_cache()
return {"message": "Models cache invalidated successfully"}
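The invalidate_models_cache() helper itself is not shown in this hunk; a plausible sketch, assuming it simply resets the module-level cache dict defined above:

def invalidate_models_cache() -> None:
    # Hypothetical implementation: drop the cached payload so the next
    # get_cached_models() call refetches from the LLM service.
    _models_cache["data"] = None
    _models_cache["cached_at"] = 0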
@@ -210,34 +232,38 @@ async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create chat completion with budget enforcement"""
try:
auth_type = context.get("auth_type", "api_key")
# Handle different authentication types
if auth_type == "api_key":
auth_service = APIKeyAuthService(db)
# Check permissions
if not await auth_service.check_scope_permission(context, "chat.completions"):
if not await auth_service.check_scope_permission(
context, "chat.completions"
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for chat completions"
detail="Insufficient permissions for chat completions",
)
if not await auth_service.check_model_permission(context, chat_request.model):
if not await auth_service.check_model_permission(
context, chat_request.model
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{chat_request.model}' not allowed"
detail=f"Model '{chat_request.model}' not allowed",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
elif auth_type == "jwt":
# For JWT authentication, we'll skip the detailed permission checks for now
@@ -246,15 +272,15 @@ async def create_chat_completion(
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
detail="User information not available",
)
api_key = None # JWT users don't have API keys
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
detail="Invalid authentication type",
)
# Estimate token usage for budget checking
messages_text = " ".join([msg.content for msg in chat_request.messages])
estimated_tokens = len(messages_text.split()) * 1.3 # Rough token estimation
@@ -262,31 +288,44 @@ async def create_chat_completion(
estimated_tokens += chat_request.max_tokens
else:
estimated_tokens += 150 # Default response length estimate
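# Illustrative arithmetic (the heuristic is rough, not a real tokenizer):
# a 40-word prompt with max_tokens unset reserves
#   40 * 1.3 + 150 = 202 estimated tokens before the budget check.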
# Get a synchronous session for budget enforcement
from app.db.database import SessionLocal
sync_db = SessionLocal()
try:
# Atomic budget check and reservation (only for API key users)
warnings = []
reserved_budget_ids = []
if auth_type == "api_key" and api_key:
is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget(
sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions"
(
is_allowed,
error_message,
budget_warnings,
budget_ids,
) = atomic_check_and_reserve_budget(
sync_db,
api_key,
chat_request.model,
int(estimated_tokens),
"chat/completions",
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
detail=f"Budget exceeded: {error_message}",
)
warnings = budget_warnings
reserved_budget_ids = budget_ids
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
llm_messages = [
LLMChatMessage(role=msg.role, content=msg.content)
for msg in chat_request.messages
]
# Create LLM service request
llm_request = ChatRequest(
model=chat_request.model,
@@ -299,12 +338,14 @@ async def create_chat_completion(
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
api_key_id=context.get("api_key_id", 0)
if auth_type == "api_key"
else 0,
)
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Convert LLM service response to API format
response = {
"id": llm_response.id,
@@ -316,45 +357,56 @@ async def create_chat_completion(
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
"content": choice.message.content,
},
"finish_reason": choice.finish_reason
"finish_reason": choice.finish_reason,
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
"prompt_tokens": llm_response.usage.prompt_tokens
if llm_response.usage
else 0,
"completion_tokens": llm_response.usage.completion_tokens
if llm_response.usage
else 0,
"total_tokens": llm_response.usage.total_tokens
if llm_response.usage
else 0,
}
if llm_response.usage
else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
# Calculate accurate cost
actual_cost_cents = CostCalculator.calculate_cost_cents(
chat_request.model, input_tokens, output_tokens
)
# Finalize actual usage in budgets (only for API key users)
if auth_type == "api_key" and api_key:
atomic_finalize_usage(
sync_db, reserved_budget_ids, api_key, chat_request.model,
input_tokens, output_tokens, "chat/completions"
sync_db,
reserved_budget_ids,
api_key,
chat_request.model,
input_tokens,
output_tokens,
"chat/completions",
)
# Update API key usage statistics
auth_service = APIKeyAuthService(db)
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
await auth_service.update_usage_stats(
context, total_tokens, actual_cost_cents
)
# Set analytics data for middleware
set_analytics_data(
model=chat_request.model,
@@ -363,55 +415,55 @@ async def create_chat_completion(
total_tokens=total_tokens,
cost_cents=actual_cost_cents,
budget_ids=reserved_budget_ids,
budget_warnings=warnings
budget_warnings=warnings,
)
# Add budget warnings to response if any
if warnings:
response["budget_warnings"] = warnings
return response
finally:
sync_db.close()
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
detail=f"Security validation failed: {e.message}",
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
detail=f"Request validation failed: {e.message}",
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
detail="Rate limit exceeded",
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
detail="LLM service temporarily unavailable",
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
detail="LLM service error",
)
except Exception as e:
logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
detail="Failed to create chat completion",
)
@@ -419,62 +471,62 @@ async def create_chat_completion(
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create embedding with budget enforcement"""
try:
auth_service = APIKeyAuthService(db)
# Check permissions
if not await auth_service.check_scope_permission(context, "embeddings.create"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for embeddings"
detail="Insufficient permissions for embeddings",
)
if not await auth_service.check_model_permission(context, request.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{request.model}' not allowed"
detail=f"Model '{request.model}' not allowed",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
# Estimate token usage for budget checking
estimated_tokens = len(request.input.split()) * 1.3 # Rough token estimation
# Convert AsyncSession to Session for budget enforcement
sync_db = Session(bind=db.bind.sync_engine)
try:
# Check budget compliance before making request
is_allowed, error_message, warnings = check_budget_for_request(
sync_db, api_key, request.model, int(estimated_tokens), "embeddings"
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
detail=f"Budget exceeded: {error_message}",
)
# Create LLM service request
llm_request = LLMEmbeddingRequest(
model=request.model,
input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
api_key_id=context["api_key_id"],
)
# Make request to LLM service
llm_response = await llm_service.create_embedding(llm_request)
# Convert LLM service response to API format
response = {
"object": llm_response.object,
@@ -482,139 +534,142 @@ async def create_embedding(
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
"embedding": emb.embedding,
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
"prompt_tokens": llm_response.usage.prompt_tokens
if llm_response.usage
else 0,
"total_tokens": llm_response.usage.total_tokens
if llm_response.usage
else 0,
}
if llm_response.usage
else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens),
},
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens))
# Calculate accurate cost (embeddings typically use input tokens only)
actual_cost_cents = CostCalculator.calculate_cost_cents(
request.model, total_tokens, 0
)
# Record actual usage in budgets
record_request_usage(
sync_db, api_key, request.model, total_tokens, 0, "embeddings"
)
# Update API key usage statistics
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
await auth_service.update_usage_stats(
context, total_tokens, actual_cost_cents
)
# Add budget warnings to response if any
if warnings:
response["budget_warnings"] = warnings
return response
finally:
sync_db.close()
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
detail=f"Security validation failed: {e.message}",
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
detail=f"Request validation failed: {e.message}",
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
detail="Rate limit exceeded",
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
detail="LLM service temporarily unavailable",
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
detail="LLM service error",
)
except Exception as e:
logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
detail="Failed to create embedding",
)
@router.get("/health")
async def llm_health_check(
context: Dict[str, Any] = Depends(require_api_key)
):
async def llm_health_check(context: Dict[str, Any] = Depends(require_api_key)):
"""Health check for LLM service"""
try:
health_summary = llm_service.get_health_summary()
provider_status = await llm_service.get_provider_status()
# Determine overall health
overall_status = "healthy"
if health_summary["service_status"] != "healthy":
overall_status = "degraded"
for provider, status in provider_status.items():
if status.status == "unavailable":
overall_status = "degraded"
break
return {
"status": overall_status,
"service": "LLM Service",
"service_status": health_summary,
"provider_status": {name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"provider_status": {
name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message,
}
for name, status in provider_status.items()
},
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
"api_key_name": context["api_key_name"],
}
except Exception as e:
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Service",
"error": str(e)
}
return {"status": "unhealthy", "service": "LLM Service", "error": str(e)}
@router.get("/usage")
async def get_usage_stats(
context: Dict[str, Any] = Depends(require_api_key)
):
async def get_usage_stats(context: Dict[str, Any] = Depends(require_api_key)):
"""Get usage statistics for the API key"""
try:
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
return {
"api_key_id": api_key.id,
"api_key_name": api_key.name,
@@ -622,24 +677,26 @@ async def get_usage_stats(
"total_tokens": api_key.total_tokens,
"total_cost_cents": api_key.total_cost,
"created_at": api_key.created_at.isoformat(),
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
"rate_limits": {
"per_minute": api_key.rate_limit_per_minute,
"per_hour": api_key.rate_limit_per_hour,
"per_day": api_key.rate_limit_per_day
"per_day": api_key.rate_limit_per_day,
},
"permissions": api_key.permissions,
"scopes": api_key.scopes,
"allowed_models": api_key.allowed_models
"allowed_models": api_key.allowed_models,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting usage stats: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get usage statistics"
detail="Failed to get usage statistics",
)
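# Costs in this commit are tracked as integer cents (total_cost_cents above);
# a trivial display helper, included only as a usage sketch:
def cents_to_dollars(cents: int) -> str:
    return f"${cents / 100:.2f}"

assert cents_to_dollars(1234) == "$12.34"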
@@ -647,51 +704,48 @@ async def get_usage_stats(
async def get_budget_status(
request: Request,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get current budget status and usage analytics"""
try:
auth_type = context.get("auth_type", "api_key")
# Check permissions based on auth type
if auth_type == "api_key":
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "budget.read"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to read budget information"
detail="Insufficient permissions to read budget information",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
# Convert AsyncSession to Session for budget enforcement
sync_db = Session(bind=db.bind.sync_engine)
try:
budget_service = BudgetEnforcementService(sync_db)
budget_status = budget_service.get_budget_status(api_key)
return {
"object": "budget_status",
"data": budget_status
}
return {"object": "budget_status", "data": budget_status}
finally:
sync_db.close()
elif auth_type == "jwt":
# For JWT authentication, return user-level budget information
user = context.get("user")
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
detail="User information not available",
)
# Return basic budget info for JWT users
return {
"object": "budget_status",
@@ -702,23 +756,23 @@ async def get_budget_status(
"projections": {
"daily_burn_rate": 0.0,
"projected_monthly": 0.0,
"days_remaining": 30
}
}
"days_remaining": 30,
},
},
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
detail="Invalid authentication type",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting budget status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get budget status"
detail="Failed to get budget status",
)
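# The AsyncSession -> Session bridge used above, Session(bind=db.bind.sync_engine),
# recurs in several endpoints; a context-manager sketch of the same pattern,
# assuming SQLAlchemy's asyncio engine wraps a sync engine:
from contextlib import contextmanager
from sqlalchemy.orm import Session

@contextmanager
def sync_session_from(async_db):
    db = Session(bind=async_db.bind.sync_engine)
    try:
        yield db
    finally:
        db.close()  # mirrors the explicit try/finally blocks above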
@@ -726,7 +780,7 @@ async def get_budget_status(
@router.get("/metrics")
async def get_llm_metrics(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get LLM service metrics (admin only)"""
try:
@@ -735,9 +789,9 @@ async def get_llm_metrics(
if not await auth_service.check_scope_permission(context, "admin.metrics"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view metrics"
detail="Admin permissions required to view metrics",
)
metrics = llm_service.get_metrics()
return {
"object": "llm_metrics",
@@ -745,27 +799,27 @@ async def get_llm_metrics(
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
}
"last_updated": metrics.last_updated.isoformat(),
},
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get LLM metrics"
detail="Failed to get LLM metrics",
)
@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get status of all LLM providers"""
try:
@@ -773,9 +827,9 @@ async def get_provider_status(
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
detail="Admin permissions required to view provider status",
)
provider_status = await llm_service.get_provider_status()
return {
"object": "provider_status",
@@ -787,17 +841,17 @@ async def get_provider_status(
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
"models_available": status.models_available,
}
for name, status in provider_status.items()
}
},
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
)
detail="Failed to get provider status",
)
View File
@@ -13,7 +13,12 @@ from app.db.database import get_db
from app.core.security import get_current_user
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.llm.exceptions import (
LLMError,
ProviderError,
SecurityError,
ValidationError,
)
from app.api.v1.llm import get_cached_models # Reuse the caching logic
logger = logging.getLogger(__name__)
@@ -35,14 +40,12 @@ async def list_models(
logger.error(f"Failed to list models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve models"
detail="Failed to retrieve models",
)
@router.get("/providers/status")
async def get_provider_status(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def get_provider_status(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get status of all LLM providers for authenticated users
"""
@@ -58,23 +61,21 @@ async def get_provider_status(
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
"models_available": status.models_available,
}
for name, status in provider_status.items()
}
},
}
except Exception as e:
logger.error(f"Failed to get provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve provider status"
detail="Failed to retrieve provider status",
)
@router.get("/health")
async def health_check(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def health_check(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get LLM service health status for authenticated users
"""
@@ -83,39 +84,35 @@ async def health_check(
return {
"status": health["status"],
"providers": health.get("providers", {}),
"timestamp": health.get("timestamp")
"timestamp": health.get("timestamp"),
}
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Health check failed"
detail="Health check failed",
)
@router.get("/metrics")
async def get_metrics(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def get_metrics(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get LLM service metrics for authenticated users
"""
try:
metrics = await llm_service.get_metrics()
return {
"object": "metrics",
"data": metrics
}
return {"object": "metrics", "data": metrics}
except Exception as e:
logger.error(f"Failed to get metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve metrics"
detail="Failed to retrieve metrics",
)
class ChatCompletionRequest(BaseModel):
"""Request model for chat completions"""
model: str
messages: List[Dict[str, str]]
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
@@ -128,7 +125,7 @@ class ChatCompletionRequest(BaseModel):
async def create_chat_completion(
request: ChatCompletionRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create chat completion for authenticated frontend users.
@@ -137,7 +134,7 @@ async def create_chat_completion(
try:
# Get user ID from JWT token context
user_id = str(current_user.get("id", current_user.get("sub", "0")))
# Convert request to LLM service format
# For internal use, we use a special api_key_id of 0 to indicate JWT auth
chat_request = ChatRequest(
@@ -151,15 +148,17 @@ async def create_chat_completion(
top_p=request.top_p,
stream=request.stream,
user_id=user_id,
api_key_id=0 # Special value for JWT-authenticated requests
api_key_id=0, # Special value for JWT-authenticated requests
)
# Log the request for debugging
logger.info(f"Internal chat completion request from user {current_user.get('id')}: model={request.model}")
logger.info(
f"Internal chat completion request from user {current_user.get('id')}: model={request.model}"
)
# Process the request through the LLM service
response = await llm_service.create_chat_completion(chat_request)
# Format the response to match OpenAI's structure
formatted_response = {
"id": response.id,
@@ -171,36 +170,39 @@ async def create_chat_completion(
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
"content": choice.message.content,
},
"finish_reason": choice.finish_reason
"finish_reason": choice.finish_reason,
}
for choice in response.choices
],
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0
} if response.usage else None
"completion_tokens": response.usage.completion_tokens
if response.usage
else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
}
if response.usage
else None,
}
return formatted_response
except ValidationError as e:
logger.error(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid request: {str(e)}"
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid request: {str(e)}"
)
except LLMError as e:
logger.error(f"LLM service error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"LLM service error: {str(e)}"
detail=f"LLM service error: {str(e)}",
)
except Exception as e:
logger.error(f"Unexpected error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process chat completion"
)
detail="Failed to process chat completion",
)
View File
@@ -15,17 +15,17 @@ router = APIRouter()
async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Get list of all discovered modules with their status (enabled and disabled)"""
log_api_request("list_modules", {})
# Get all discovered modules including disabled ones
all_modules = module_manager.list_all_modules()
modules = []
for module_info in all_modules:
# Convert module_info to API format with status field
name = module_info["name"]
is_loaded = module_info["loaded"] # Module is actually loaded in memory
is_enabled = module_info["enabled"] # Module is enabled in config
# Determine status based on enabled + loaded state
if is_enabled and is_loaded:
status = "running"
@@ -33,40 +33,43 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
status = "error" # Enabled but failed to load
else: # not is_enabled (regardless of loaded state)
status = "standby" # Disabled
api_module = {
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"initialized": is_loaded,
"initialized": is_loaded,
"enabled": is_enabled,
"status": status # Add status field for frontend compatibility
"status": status, # Add status field for frontend compatibility
}
# Get module statistics if available and module is loaded
if module_info["loaded"] and module_info["name"] in module_manager.modules:
module_instance = module_manager.modules[module_info["name"]]
if hasattr(module_instance, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module_instance.get_stats):
stats = await module_instance.get_stats()
else:
stats = module_instance.get_stats()
api_module["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
api_module["stats"] = (
stats.__dict__ if hasattr(stats, "__dict__") else stats
)
except:
api_module["stats"] = {}
modules.append(api_module)
# Calculate stats
loaded_count = sum(1 for m in modules if m["initialized"] and m["enabled"])
return {
"total": len(modules),
"modules": modules,
"module_count": loaded_count,
"initialized": module_manager.initialized
"initialized": module_manager.initialized,
}
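# The enabled/loaded -> status mapping used here (and again in
# get_modules_status below) reduces to this table; sketch only:
#   enabled and loaded     -> "running"
#   enabled, not loaded    -> "error"   (failed to load)
#   not enabled            -> "standby" (disabled)
def module_status(enabled: bool, loaded: bool) -> str:
    if enabled and loaded:
        return "running"
    return "error" if enabled else "standby"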
@@ -74,20 +77,20 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Get comprehensive module status - CONSOLIDATED endpoint"""
log_api_request("get_modules_status", {})
# Get all discovered modules including disabled ones
all_modules = module_manager.list_all_modules()
modules_with_status = []
running_count = 0
standby_count = 0
failed_count = 0
for module_info in all_modules:
name = module_info["name"]
is_loaded = module_info["loaded"] # Module is actually loaded in memory
is_enabled = module_info["enabled"] # Module is enabled in config
# Determine status based on enabled + loaded state
if is_enabled and is_loaded:
status = "running"
@@ -98,7 +101,7 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
else: # not is_enabled (regardless of loaded state)
status = "standby" # Disabled
standby_count += 1
# Get module statistics if available and loaded
stats = {}
if is_loaded and name in module_manager.modules:
@@ -106,56 +109,68 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
if hasattr(module_instance, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module_instance.get_stats):
stats_result = await module_instance.get_stats()
else:
stats_result = module_instance.get_stats()
stats = stats_result.__dict__ if hasattr(stats_result, "__dict__") else stats_result
stats = (
stats_result.__dict__
if hasattr(stats_result, "__dict__")
else stats_result
)
except:
stats = {}
modules_with_status.append({
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"status": status,
"enabled": is_enabled,
"loaded": is_loaded,
"stats": stats
})
modules_with_status.append(
{
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"status": status,
"enabled": is_enabled,
"loaded": is_loaded,
"stats": stats,
}
)
return {
"modules": modules_with_status,
"total": len(modules_with_status),
"running": running_count,
"standby": standby_count,
"standby": standby_count,
"failed": failed_count,
"system_initialized": module_manager.initialized
"system_initialized": module_manager.initialized,
}
@router.get("/{module_name}")
async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_info(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get detailed information about a specific module"""
log_api_request("get_module_info", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
module_info = {
"name": module_name,
"version": getattr(module, "version", "1.0.0"),
"description": getattr(module, "description", ""),
"initialized": getattr(module, "initialized", False),
"enabled": module_manager.module_configs.get(module_name, ModuleConfig(module_name)).enabled,
"capabilities": []
"enabled": module_manager.module_configs.get(
module_name, ModuleConfig(module_name)
).enabled,
"capabilities": [],
}
# Get module capabilities
if hasattr(module, "get_module_info"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_module_info):
info = await module.get_module_info()
else:
@@ -163,19 +178,22 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
module_info.update(info)
except:
pass
# Get module statistics
if hasattr(module, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
module_info["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
module_info["stats"] = (
stats.__dict__ if hasattr(stats, "__dict__") else stats
)
except:
module_info["stats"] = {}
# List available methods
methods = []
for attr_name in dir(module):
@@ -183,57 +201,64 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
if callable(attr) and not attr_name.startswith("_"):
methods.append(attr_name)
module_info["methods"] = methods
return module_info
@router.post("/{module_name}/enable")
async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def enable_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Enable a module"""
log_api_request("enable_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Enable the module in config
config = module_manager.module_configs[module_name]
config.enabled = True
# Load the module if not already loaded
if module_name not in module_manager.modules:
try:
await module_manager._load_module(module_name, config)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to enable module '{module_name}': {str(e)}")
return {
"message": f"Module '{module_name}' enabled successfully",
"enabled": True
}
raise HTTPException(
status_code=500,
detail=f"Failed to enable module '{module_name}': {str(e)}",
)
return {"message": f"Module '{module_name}' enabled successfully", "enabled": True}
@router.post("/{module_name}/disable")
async def disable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def disable_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Disable a module"""
log_api_request("disable_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Disable the module in config
config = module_manager.module_configs[module_name]
config.enabled = False
# Unload the module if loaded
if module_name in module_manager.modules:
try:
await module_manager.unload_module(module_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to disable module '{module_name}': {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Failed to disable module '{module_name}': {str(e)}",
)
return {
"message": f"Module '{module_name}' disabled successfully",
"enabled": False
"enabled": False,
}
@@ -241,10 +266,10 @@ async def disable_module(module_name: str, current_user: Dict[str, Any] = Depend
async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Reload all modules"""
log_api_request("reload_all_modules", {})
results = {}
failed_modules = []
for module_name in list(module_manager.modules.keys()):
try:
success = await module_manager.reload_module(module_name)
@@ -254,272 +279,316 @@ async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_
except Exception as e:
results[module_name] = {"success": False, "error": str(e)}
failed_modules.append(module_name)
if failed_modules:
return {
"message": f"Reloaded {len(results) - len(failed_modules)}/{len(results)} modules successfully",
"success": False,
"results": results,
"failed_modules": failed_modules
"failed_modules": failed_modules,
}
else:
return {
"message": f"All {len(results)} modules reloaded successfully",
"success": True,
"results": results,
"failed_modules": []
"failed_modules": [],
}
@router.post("/{module_name}/reload")
async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def reload_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Reload a specific module"""
log_api_request("reload_module", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to reload module '{module_name}'")
raise HTTPException(
status_code=500, detail=f"Failed to reload module '{module_name}'"
)
return {
"message": f"Module '{module_name}' reloaded successfully",
"reloaded": True
"reloaded": True,
}
@router.post("/{module_name}/restart")
async def restart_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def restart_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Restart a specific module (alias for reload)"""
log_api_request("restart_module", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to restart module '{module_name}'")
raise HTTPException(
status_code=500, detail=f"Failed to restart module '{module_name}'"
)
return {
"message": f"Module '{module_name}' restarted successfully",
"restarted": True
"restarted": True,
}
@router.post("/{module_name}/start")
async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def start_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Start a specific module (enable and load)"""
log_api_request("start_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Enable the module
config = module_manager.module_configs[module_name]
config.enabled = True
# Load the module if not already loaded
if module_name not in module_manager.modules:
await module_manager._load_module(module_name, config)
return {
"message": f"Module '{module_name}' started successfully",
"started": True
}
return {"message": f"Module '{module_name}' started successfully", "started": True}
@router.post("/{module_name}/stop")
async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def stop_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Stop a specific module (disable and unload)"""
log_api_request("stop_module", {"module_name": module_name})
if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
# Disable the module
config = module_manager.module_configs[module_name]
config.enabled = False
# Unload the module if loaded
if module_name in module_manager.modules:
await module_manager.unload_module(module_name)
return {
"message": f"Module '{module_name}' stopped successfully",
"stopped": True
}
return {"message": f"Module '{module_name}' stopped successfully", "stopped": True}
@router.get("/{module_name}/stats")
async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_stats(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get module statistics"""
log_api_request("get_module_stats", {"module_name": module_name})
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
if not hasattr(module, "get_stats"):
raise HTTPException(status_code=404, detail=f"Module '{module_name}' does not provide statistics")
raise HTTPException(
status_code=404,
detail=f"Module '{module_name}' does not provide statistics",
)
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
return {
"module": module_name,
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to get statistics: {str(e)}"
)
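# The "call get_stats, awaiting it when it is a coroutine" dispatch is
# repeated throughout this file; it could be factored into one helper
# (sketch, name assumed, not part of the commit):
import asyncio

async def call_maybe_async(fn, *args, **kwargs):
    result = fn(*args, **kwargs)
    if asyncio.iscoroutine(result):
        result = await result
    return result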
@router.post("/{module_name}/execute")
async def execute_module_action(module_name: str, request_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user)):
async def execute_module_action(
module_name: str,
request_data: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a module action through the interceptor pattern"""
log_api_request("execute_module_action", {"module_name": module_name, "action": request_data.get("action")})
log_api_request(
"execute_module_action",
{"module_name": module_name, "action": request_data.get("action")},
)
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
module = module_manager.modules[module_name]
# Check if module supports the new interceptor pattern
if hasattr(module, 'execute_with_interceptors'):
if hasattr(module, "execute_with_interceptors"):
try:
# Prepare context (would normally come from authentication middleware)
context = {
"user_id": "test_user", # Would come from authentication
"api_key_id": "test_api_key", # Would come from API key auth
"ip_address": "127.0.0.1", # Would come from request
"user_permissions": [f"modules:{module_name}:*"] # Would come from user/API key permissions
"user_permissions": [
f"modules:{module_name}:*"
], # Would come from user/API key permissions
}
# Execute through interceptor chain
response = await module.execute_with_interceptors(request_data, context)
return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": True
"interceptor_pattern": True,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)
# Fallback for legacy modules
else:
action = request_data.get("action", "execute")
if hasattr(module, action):
try:
method = getattr(module, action)
if callable(method):
import asyncio
if asyncio.iscoroutinefunction(method):
response = await method(request_data)
else:
response = method(request_data)
return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": False
"interceptor_pattern": False,
}
else:
raise HTTPException(status_code=400, detail=f"'{action}' is not callable")
raise HTTPException(
status_code=400, detail=f"'{action}' is not callable"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)
else:
raise HTTPException(status_code=400, detail=f"Action '{action}' not supported by module '{module_name}'")
raise HTTPException(
status_code=400,
detail=f"Action '{action}' not supported by module '{module_name}'",
)
@router.get("/{module_name}/config")
async def get_module_config(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_config(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get module configuration schema and current values"""
log_api_request("get_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager
from app.services.llm.service import llm_service
import copy
# Get module manifest and schema
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
schema = module_config_manager.get_module_schema(module_name)
current_config = module_config_manager.get_module_config(module_name)
# For Signal module, populate model options dynamically
if module_name == "signal" and schema:
try:
# Get available models from LLM service
models_data = await llm_service.get_models()
model_ids = [model.id for model in models_data]
if model_ids:
# Create a copy of the schema to avoid modifying the original
dynamic_schema = copy.deepcopy(schema)
# Add enum options for the model field
if "properties" in dynamic_schema and "model" in dynamic_schema["properties"]:
if (
"properties" in dynamic_schema
and "model" in dynamic_schema["properties"]
):
dynamic_schema["properties"]["model"]["enum"] = model_ids
# Set a sensible default if the current default isn't in the list
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
current_default = dynamic_schema["properties"]["model"].get(
"default", "gpt-3.5-turbo"
)
if current_default not in model_ids and model_ids:
dynamic_schema["properties"]["model"]["default"] = model_ids[0]
schema = dynamic_schema
except Exception as e:
# If we can't get models, log warning but continue with original schema
logger.warning(f"Failed to get dynamic models for Signal config: {e}")
return {
"module": module_name,
"description": manifest.description,
"schema": schema,
"current_config": current_config,
"has_schema": schema is not None
"has_schema": schema is not None,
}
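# After the dynamic step above, the Signal schema's "model" property would
# look roughly like the following (model IDs illustrative, not from this commit):
example_model_property = {
    "type": "string",
    "enum": ["gpt-4o", "claude-3-5-sonnet"],
    "default": "gpt-4o",  # first available model when the manifest default is missing
}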
@router.post("/{module_name}/config")
async def update_module_config(module_name: str, config: dict, current_user: Dict[str, Any] = Depends(get_current_user)):
async def update_module_config(
module_name: str,
config: dict,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update module configuration"""
log_api_request("update_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager
# Validate module exists
manifest = module_config_manager.get_module_manifest(module_name)
if not manifest:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
try:
# Save configuration
success = await module_config_manager.save_module_config(module_name, config)
if not success:
raise HTTPException(status_code=500, detail="Failed to save configuration")
# Update module manager with new config
success = await module_manager.update_module_config(module_name, config)
if not success:
raise HTTPException(status_code=500, detail="Failed to apply configuration")
return {
"message": f"Configuration updated for module '{module_name}'",
"config": config
"config": config,
}
except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
View File
@@ -12,9 +12,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.services.api_key_auth import require_api_key
from app.api.v1.llm import (
get_cached_models, ModelsResponse, ModelInfo,
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
create_embedding as llm_create_embedding
get_cached_models,
ModelsResponse,
ModelInfo,
ChatCompletionRequest,
EmbeddingRequest,
create_chat_completion as llm_chat_completion,
create_embedding as llm_create_embedding,
)
logger = logging.getLogger(__name__)
@@ -22,8 +26,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def openai_error_response(message: str, error_type: str = "invalid_request_error",
status_code: int = 400, code: str = None):
def openai_error_response(
message: str,
error_type: str = "invalid_request_error",
status_code: int = 400,
code: str = None,
):
"""Create OpenAI-compatible error response"""
error_data = {
"error": {
@@ -33,52 +41,42 @@ def openai_error_response(message: str, error_type: str = "invalid_request_error
}
if code:
error_data["error"]["code"] = code
return JSONResponse(
status_code=status_code,
content=error_data
)
return JSONResponse(status_code=status_code, content=error_data)
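# For reference, a call like
#   openai_error_response("Model 'x' not found", "invalid_request_error", 404, "model_not_found")
# should produce an OpenAI-style body along these lines (the inner
# message/type keys are elided from this hunk, so this is inferred from the
# OpenAI error spec rather than shown by the commit):
# {"error": {"message": "Model 'x' not found",
#            "type": "invalid_request_error",
#            "code": "model_not_found"}}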
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
    Lists the currently available models, and provides basic information about each one
such as the owner and availability.
This endpoint follows the exact OpenAI API specification:
GET /v1/models
"""
try:
# Delegate to the existing LLM models endpoint
from app.api.v1.llm import list_models as llm_list_models
return await llm_list_models(context, db)
except HTTPException as e:
# Convert FastAPI HTTPException to OpenAI format
if e.status_code == 401:
return openai_error_response(
"Invalid authentication credentials",
"authentication_error",
401
"Invalid authentication credentials", "authentication_error", 401
)
elif e.status_code == 403:
return openai_error_response(
"Insufficient permissions",
"permission_error",
403
"Insufficient permissions", "permission_error", 403
)
else:
return openai_error_response(str(e.detail), "api_error", e.status_code)
except Exception as e:
logger.error(f"Error in OpenAI models endpoint: {e}")
return openai_error_response(
"Internal server error",
"api_error",
500
)
return openai_error_response("Internal server error", "api_error", 500)
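# Because these routes mirror the OpenAI spec, the stock openai Python client
# can be pointed at the gateway; a usage sketch with assumed base URL and key:
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-example")
print([m.id for m in client.models.list().data])
reply = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
print(reply.choices[0].message.content)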
@router.post("/chat/completions")
@@ -86,11 +84,11 @@ async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create chat completion - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
POST /v1/chat/completions
"""
@@ -102,11 +100,11 @@ async def create_chat_completion(
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create embedding - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
POST /v1/embeddings
"""
@@ -118,44 +116,46 @@ async def create_embedding(
async def retrieve_model(
model_id: str,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Retrieve model information - OpenAI compatible endpoint
This endpoint follows the exact OpenAI API specification:
GET /v1/models/{model}
"""
try:
# Get all models and find the specific one
models = await get_cached_models()
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
models = [
model for model in models if model.get("id") in api_key.allowed_models
]
# Find the specific model
model = next((m for m in models if m.get("id") == model_id), None)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model '{model_id}' not found"
detail=f"Model '{model_id}' not found",
)
return ModelInfo(
id=model.get("id", model_id),
object="model",
created=model.get("created", 0),
owned_by=model.get("owned_by", "system")
owned_by=model.get("owned_by", "system"),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model information"
)
detail="Failed to retrieve model information",
)
View File
@@ -7,7 +7,11 @@ from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from app.services.permission_manager import permission_registry, Permission, PermissionScope
from app.services.permission_manager import (
permission_registry,
Permission,
PermissionScope,
)
from app.core.logging import get_logger
from app.core.security import get_current_user
@@ -77,7 +81,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
"""Get all available permissions, optionally filtered by namespace"""
try:
permissions = permission_registry.get_available_permissions(namespace)
# Convert to response format
result = {}
for ns, perms in permissions.items():
@@ -86,18 +90,18 @@ async def get_available_permissions(namespace: Optional[str] = None):
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in perms
]
return result
except Exception as e:
logger.error(f"Error getting permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permissions: {str(e)}"
detail=f"Failed to get permissions: {str(e)}",
)
@@ -107,12 +111,12 @@ async def get_permission_hierarchy():
try:
hierarchy = permission_registry.get_permission_hierarchy()
return PermissionHierarchyResponse(hierarchy=hierarchy)
except Exception as e:
logger.error(f"Error getting permission hierarchy: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permission hierarchy: {str(e)}"
detail=f"Failed to get permission hierarchy: {str(e)}",
)
@@ -120,44 +124,43 @@ async def get_permission_hierarchy():
async def validate_permissions(request: PermissionValidationRequest):
"""Validate a list of permissions"""
try:
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
return PermissionValidationResponse(**validation_result)
except Exception as e:
logger.error(f"Error validating permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate permissions: {str(e)}"
detail=f"Failed to validate permissions: {str(e)}",
)
@router.post("/permissions/check", response_model=PermissionCheckResponse)
async def check_permission(
request: PermissionCheckRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Check if user has a specific permission"""
try:
has_permission = permission_registry.check_permission(
request.user_permissions,
request.required_permission,
request.context
request.user_permissions, request.required_permission, request.context
)
matching_permissions = list(permission_registry.tree.get_matching_permissions(
request.user_permissions
))
matching_permissions = list(
permission_registry.tree.get_matching_permissions(request.user_permissions)
)
return PermissionCheckResponse(
has_permission=has_permission,
matching_permissions=matching_permissions
has_permission=has_permission, matching_permissions=matching_permissions
)
except Exception as e:
logger.error(f"Error checking permission: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check permission: {str(e)}"
detail=f"Failed to check permission: {str(e)}",
)
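# Permissions elsewhere in this commit use hierarchical wildcards such as
# "modules:signal:*"; a matching sketch consistent with that shape (the real
# permission_registry logic may differ):
def wildcard_match(granted: str, required: str) -> bool:
    g_parts, r_parts = granted.split(":"), required.split(":")
    for g, r in zip(g_parts, r_parts):
        if g == "*":
            return True
        if g != r:
            return False
    return len(g_parts) == len(r_parts)

assert wildcard_match("modules:signal:*", "modules:signal:send")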
@@ -166,22 +169,22 @@ async def get_module_permissions(module_id: str):
"""Get permissions for a specific module"""
try:
permissions = permission_registry.get_module_permissions(module_id)
return [
PermissionResponse(
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in permissions
]
except Exception as e:
logger.error(f"Error getting module permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get module permissions: {str(e)}"
detail=f"Failed to get module permissions: {str(e)}",
)
@@ -191,27 +194,28 @@ async def create_role(request: RoleRequest):
"""Create a custom role with specific permissions"""
try:
# Validate permissions first
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
if not validation_result["is_valid"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid permissions: {validation_result['invalid']}"
detail=f"Invalid permissions: {validation_result['invalid']}",
)
permission_registry.create_role(request.role_name, request.permissions)
return RoleResponse(
role_name=request.role_name,
permissions=request.permissions
role_name=request.role_name, permissions=request.permissions
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create role: {str(e)}"
detail=f"Failed to create role: {str(e)}",
)
@@ -220,14 +224,17 @@ async def get_roles():
"""Get all available roles and their permissions"""
try:
# Combine default roles and custom roles
all_roles = {**permission_registry.default_roles, **permission_registry.role_permissions}
all_roles = {
**permission_registry.default_roles,
**permission_registry.role_permissions,
}
return all_roles
except Exception as e:
logger.error(f"Error getting roles: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get roles: {str(e)}"
detail=f"Failed to get roles: {str(e)}",
)
@@ -236,28 +243,25 @@ async def get_role(role_name: str):
"""Get a specific role and its permissions"""
try:
# Check default roles first, then custom roles
permissions = (permission_registry.role_permissions.get(role_name) or
permission_registry.default_roles.get(role_name))
permissions = permission_registry.role_permissions.get(
role_name
) or permission_registry.default_roles.get(role_name)
if permissions is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Role '{role_name}' not found"
detail=f"Role '{role_name}' not found",
)
return RoleResponse(
role_name=role_name,
permissions=permissions,
created=True
)
return RoleResponse(role_name=role_name, permissions=permissions, created=True)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get role: {str(e)}"
detail=f"Failed to get role: {str(e)}",
)
@@ -267,21 +271,20 @@ async def calculate_user_permissions(request: UserPermissionsRequest):
"""Calculate effective permissions for a user based on roles and custom permissions"""
try:
effective_permissions = permission_registry.get_user_permissions(
request.roles,
request.custom_permissions
request.roles, request.custom_permissions
)
return UserPermissionsResponse(
effective_permissions=effective_permissions,
roles=request.roles,
custom_permissions=request.custom_permissions or []
custom_permissions=request.custom_permissions or [],
)
except Exception as e:
logger.error(f"Error calculating user permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to calculate user permissions: {str(e)}"
detail=f"Failed to calculate user permissions: {str(e)}",
)
@@ -293,8 +296,10 @@ async def platform_health():
# Get permission system status
total_permissions = len(permission_registry.tree.permissions)
total_modules = len(permission_registry.module_permissions)
total_roles = len(permission_registry.default_roles) + len(permission_registry.role_permissions)
total_roles = len(permission_registry.default_roles) + len(
permission_registry.role_permissions
)
return {
"status": "healthy",
"service": "Confidential Empire Platform API",
@@ -302,16 +307,13 @@ async def platform_health():
"permission_system": {
"total_permissions": total_permissions,
"registered_modules": total_modules,
"available_roles": total_roles
}
"available_roles": total_roles,
},
}
except Exception as e:
logger.error(f"Error checking platform health: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
return {"status": "unhealthy", "error": str(e)}
@router.get("/metrics")
@@ -320,28 +322,29 @@ async def platform_metrics():
try:
# Get permission system metrics
namespaces = permission_registry.get_available_permissions()
metrics = {
"permissions": {
"total": len(permission_registry.tree.permissions),
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()}
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()},
},
"modules": {
"registered": len(permission_registry.module_permissions),
"names": list(permission_registry.module_permissions.keys())
"names": list(permission_registry.module_permissions.keys()),
},
"roles": {
"default": len(permission_registry.default_roles),
"custom": len(permission_registry.role_permissions),
"total": len(permission_registry.default_roles) + len(permission_registry.role_permissions)
}
"total": len(permission_registry.default_roles)
+ len(permission_registry.role_permissions),
},
}
return metrics
except Exception as e:
logger.error(f"Error getting platform metrics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get platform metrics: {str(e)}"
)
detail=f"Failed to get platform metrics: {str(e)}",
)
View File
@@ -46,79 +46,75 @@ async def discover_plugins(
category: str = "",
limit: int = 20,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Discover available plugins from repository"""
try:
tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
tag_list = (
[tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
)
plugins = await plugin_discovery.search_available_plugins(
query=query,
            tags=tag_list,
category=category if category else None,
limit=limit,
db=db
db=db,
)
return {
"plugins": plugins,
"count": len(plugins),
"query": query,
"filters": {
"tags": tag_list,
"category": category
}
"filters": {"tags": tag_list, "category": category},
}
except Exception as e:
logger.error(f"Plugin discovery failed: {e}")
raise HTTPException(status_code=500, detail=f"Discovery failed: {e}")
@router.get("/categories")
async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_plugin_categories(
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get available plugin categories"""
try:
categories = await plugin_discovery.get_plugin_categories()
return {"categories": categories}
except Exception as e:
logger.error(f"Failed to get categories: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get categories: {e}")
@router.get("/installed")
async def get_installed_plugins(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get user's installed plugins"""
try:
plugins = await plugin_discovery.get_installed_plugins(current_user["id"], db)
return {
"plugins": plugins,
"count": len(plugins)
}
return {"plugins": plugins, "count": len(plugins)}
except Exception as e:
logger.error(f"Failed to get installed plugins: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get installed plugins: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to get installed plugins: {e}"
)
@router.get("/updates")
async def check_plugin_updates(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Check for available plugin updates"""
try:
updates = await plugin_discovery.get_plugin_updates(db)
return {
"updates": updates,
"count": len(updates)
}
return {"updates": updates, "count": len(updates)}
except Exception as e:
logger.error(f"Failed to check updates: {e}")
raise HTTPException(status_code=500, detail=f"Failed to check updates: {e}")
@@ -130,29 +126,32 @@ async def install_plugin(
request: PluginInstallRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from repository"""
try:
if request.source != "repository":
raise HTTPException(status_code=400, detail="Only repository installation supported via this endpoint")
raise HTTPException(
status_code=400,
detail="Only repository installation supported via this endpoint",
)
# Start installation in background
background_tasks.add_task(
install_plugin_background,
request.plugin_id,
request.version,
current_user["id"],
db
db,
)
return {
"status": "installation_started",
"plugin_id": request.plugin_id,
"version": request.version,
"message": "Plugin installation started in background"
"message": "Plugin installation started in background",
}
except Exception as e:
logger.error(f"Plugin installation failed: {e}")
raise HTTPException(status_code=500, detail=f"Installation failed: {e}")
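# install_plugin_background is referenced above but not shown in this hunk;
# a plausible shape, hypothetical only (the installer call is an assumption):
async def install_plugin_background(plugin_id, version, user_id, db):
    try:
        await plugin_installer.install_plugin(plugin_id, version, user_id, db)
    except Exception as e:
        logger.error(f"Background installation of {plugin_id} failed: {e}")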
@@ -163,38 +162,40 @@ async def install_plugin_from_file(
file: UploadFile = File(...),
background_tasks: BackgroundTasks = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from uploaded file"""
try:
# Validate file type
if not file.filename.endswith('.zip'):
if not file.filename.endswith(".zip"):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Save uploaded file
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_file:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name
try:
# Install plugin
result = await plugin_installer.install_plugin_from_file(
temp_file_path, current_user["id"], db
)
return {
"status": "installed",
"result": result,
"message": "Plugin installed successfully"
"message": "Plugin installed successfully",
}
finally:
# Cleanup temp file
import os
os.unlink(temp_file_path)
except Exception as e:
logger.error(f"File upload installation failed: {e}")
raise HTTPException(status_code=500, detail=f"Installation failed: {e}")
@@ -205,20 +206,20 @@ async def uninstall_plugin(
plugin_id: str,
request: PluginUninstallRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Uninstall plugin"""
try:
result = await plugin_installer.uninstall_plugin(
plugin_id, current_user["id"], db, request.keep_data
)
return {
"status": "uninstalled",
"result": result,
"message": "Plugin uninstalled successfully"
"message": "Plugin uninstalled successfully",
}
except Exception as e:
logger.error(f"Plugin uninstall failed: {e}")
raise HTTPException(status_code=500, detail=f"Uninstall failed: {e}")
@@ -229,28 +230,28 @@ async def uninstall_plugin(
async def enable_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Enable plugin"""
try:
from app.models.plugin import Plugin
from sqlalchemy import select
stmt = select(Plugin).where(Plugin.id == plugin_id)
result = await db.execute(stmt)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="Plugin not found")
plugin.status = "enabled"
await db.commit()
return {
"status": "enabled",
"plugin_id": plugin_id,
"message": "Plugin enabled successfully"
"message": "Plugin enabled successfully",
}
except Exception as e:
logger.error(f"Plugin enable failed: {e}")
raise HTTPException(status_code=500, detail=f"Enable failed: {e}")
@@ -260,32 +261,32 @@ async def enable_plugin(
async def disable_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Disable plugin"""
try:
from app.models.plugin import Plugin
from sqlalchemy import select
stmt = select(Plugin).where(Plugin.id == plugin_id)
result = await db.execute(stmt)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="Plugin not found")
# Unload if currently loaded
if plugin_id in plugin_loader.loaded_plugins:
await plugin_loader.unload_plugin(plugin_id)
plugin.status = "disabled"
await db.commit()
return {
"status": "disabled",
"status": "disabled",
"plugin_id": plugin_id,
"message": "Plugin disabled successfully"
"message": "Plugin disabled successfully",
}
except Exception as e:
logger.error(f"Plugin disable failed: {e}")
raise HTTPException(status_code=500, detail=f"Disable failed: {e}")
@@ -295,58 +296,62 @@ async def disable_plugin(
async def load_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Load plugin into runtime"""
try:
from app.models.plugin import Plugin
from pathlib import Path
from sqlalchemy import select
stmt = select(Plugin).where(Plugin.id == plugin_id)
result = await db.execute(stmt)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="Plugin not found")
if plugin.status != "enabled":
raise HTTPException(status_code=400, detail="Plugin must be enabled to load")
raise HTTPException(
status_code=400, detail="Plugin must be enabled to load"
)
if plugin_id in plugin_loader.loaded_plugins:
raise HTTPException(status_code=400, detail="Plugin already loaded")
# Load plugin with proper context management
plugin_dir = Path(plugin.plugin_dir)
# Create plugin context for standardized interface
plugin_context = plugin_context_manager.create_plugin_context(
plugin_id=plugin_id,
user_id=str(current_user.get("id", "unknown")), # Use actual user ID
session_type="api_load"
session_type="api_load",
)
# Generate plugin token based on context
plugin_token = plugin_context_manager.generate_plugin_token(plugin_context["context_id"])
plugin_token = plugin_context_manager.generate_plugin_token(
plugin_context["context_id"]
)
# Log plugin loading action
plugin_context_manager.add_audit_trail_entry(
plugin_context["context_id"],
"plugin_load_via_api",
{
"plugin_dir": str(plugin_dir),
"plugin_dir": str(plugin_dir),
"user_id": current_user.get("id", "unknown"),
"action": "load_plugin_with_sandbox"
}
"action": "load_plugin_with_sandbox",
},
)
await plugin_loader.load_plugin_with_sandbox(plugin_dir, plugin_token)
return {
"status": "loaded",
"plugin_id": plugin_id,
"message": "Plugin loaded successfully"
"message": "Plugin loaded successfully",
}
except Exception as e:
logger.error(f"Plugin load failed: {e}")
raise HTTPException(status_code=500, detail=f"Load failed: {e}")
@@ -354,24 +359,23 @@ async def load_plugin(
@router.post("/{plugin_id}/unload")
async def unload_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user)
plugin_id: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Unload plugin from runtime"""
try:
if plugin_id not in plugin_loader.loaded_plugins:
raise HTTPException(status_code=404, detail="Plugin not loaded")
success = await plugin_loader.unload_plugin(plugin_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to unload plugin")
return {
"status": "unloaded",
"plugin_id": plugin_id,
"message": "Plugin unloaded successfully"
"message": "Plugin unloaded successfully",
}
except Exception as e:
logger.error(f"Plugin unload failed: {e}")
raise HTTPException(status_code=500, detail=f"Unload failed: {e}")
@@ -382,40 +386,38 @@ async def unload_plugin(
async def get_plugin_configuration(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get plugin configuration for user with automatic decryption"""
try:
from app.services.plugin_configuration_manager import plugin_config_manager
# Use the new configuration manager to get decrypted configuration
config_data = await plugin_config_manager.get_plugin_configuration(
plugin_id=plugin_id,
user_id=current_user["id"],
db=db,
decrypt_sensitive=False # Don't decrypt sensitive data for API response
decrypt_sensitive=False, # Don't decrypt sensitive data for API response
)
if config_data is not None:
return {
"plugin_id": plugin_id,
"configuration": config_data,
"has_configuration": True
"has_configuration": True,
}
else:
# Get default configuration from manifest
resolved_config = await plugin_config_manager.get_resolved_configuration(
plugin_id=plugin_id,
user_id=current_user["id"],
db=db
plugin_id=plugin_id, user_id=current_user["id"], db=db
)
return {
"plugin_id": plugin_id,
"configuration": resolved_config,
"has_configuration": False
"has_configuration": False,
}
except Exception as e:
logger.error(f"Failed to get plugin configuration: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get configuration: {e}")
@@ -426,17 +428,17 @@ async def save_plugin_configuration(
plugin_id: str,
config_request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Save plugin configuration for user with automatic encryption of sensitive fields"""
try:
from app.services.plugin_configuration_manager import plugin_config_manager
# Extract configuration data and metadata
config_data = config_request.get("configuration", {})
config_name = config_request.get("name", "Default Configuration")
config_description = config_request.get("description")
# Use the new configuration manager to save with automatic encryption
saved_config = await plugin_config_manager.save_plugin_configuration(
plugin_id=plugin_id,
@@ -444,43 +446,47 @@ async def save_plugin_configuration(
config_data=config_data,
config_name=config_name,
config_description=config_description,
db=db
db=db,
)
return {
"status": "saved",
"plugin_id": plugin_id,
"configuration_id": str(saved_config.id),
"message": "Configuration saved successfully with automatic encryption of sensitive fields"
"message": "Configuration saved successfully with automatic encryption of sensitive fields",
}
except Exception as e:
logger.error(f"Failed to save plugin configuration: {e}")
await db.rollback()
raise HTTPException(status_code=500, detail=f"Failed to save configuration: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to save configuration: {e}"
)
@router.get("/{plugin_id}/schema")
async def get_plugin_configuration_schema(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get plugin configuration schema from manifest"""
try:
from app.services.plugin_configuration_manager import plugin_config_manager
# Use the new configuration manager to get schema
schema = await plugin_config_manager.get_plugin_configuration_schema(plugin_id, db)
schema = await plugin_config_manager.get_plugin_configuration_schema(
plugin_id, db
)
if not schema:
raise HTTPException(status_code=404, detail=f"No configuration schema available for plugin '{plugin_id}'")
return {
"plugin_id": plugin_id,
"schema": schema
}
raise HTTPException(
status_code=404,
detail=f"No configuration schema available for plugin '{plugin_id}'",
)
return {"plugin_id": plugin_id, "schema": schema}
except HTTPException:
raise
except Exception as e:
@@ -493,120 +499,129 @@ async def test_plugin_credentials(
plugin_id: str,
test_request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Test plugin credentials (currently supports Zammad)"""
import httpx
try:
logger.info(f"Testing credentials for plugin {plugin_id}")
# Get plugin from database to check its name
from app.models.plugin import Plugin
from sqlalchemy import select
stmt = select(Plugin).where(Plugin.id == plugin_id)
result = await db.execute(stmt)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
raise HTTPException(
status_code=404, detail=f"Plugin '{plugin_id}' not found"
)
# Check if this is a Zammad plugin
if plugin.name.lower() != 'zammad':
raise HTTPException(status_code=400, detail=f"Credential testing not supported for plugin '{plugin.name}'")
if plugin.name.lower() != "zammad":
raise HTTPException(
status_code=400,
detail=f"Credential testing not supported for plugin '{plugin.name}'",
)
# Extract credentials from request
zammad_url = test_request.get('zammad_url')
api_token = test_request.get('api_token')
zammad_url = test_request.get("zammad_url")
api_token = test_request.get("api_token")
if not zammad_url or not api_token:
raise HTTPException(status_code=400, detail="Both zammad_url and api_token are required")
raise HTTPException(
status_code=400, detail="Both zammad_url and api_token are required"
)
# Clean up the URL (remove trailing slash)
zammad_url = zammad_url.rstrip('/')
zammad_url = zammad_url.rstrip("/")
# Test credentials by making a read-only API call to Zammad
async with httpx.AsyncClient(timeout=10.0) as client:
# Try to get user info - this is a safe read-only operation
test_url = f"{zammad_url}/api/v1/users/me"
headers = {
'Authorization': f'Token token={api_token}',
'Content-Type': 'application/json'
"Authorization": f"Token token={api_token}",
"Content-Type": "application/json",
}
response = await client.get(test_url, headers=headers)
if response.status_code == 200:
# Success - credentials are valid
user_data = response.json()
user_email = user_data.get('email', 'unknown')
user_email = user_data.get("email", "unknown")
return {
"success": True,
"message": f"Credentials verified! Connected as: {user_email}",
"zammad_url": zammad_url,
"user_info": {
"email": user_email,
"firstname": user_data.get('firstname', ''),
"lastname": user_data.get('lastname', '')
}
"firstname": user_data.get("firstname", ""),
"lastname": user_data.get("lastname", ""),
},
}
elif response.status_code == 401:
return {
"success": False,
"message": "Invalid API token. Please check your token and try again.",
"error_code": "invalid_token"
"error_code": "invalid_token",
}
elif response.status_code == 404:
return {
"success": False,
"message": "Zammad URL not found. Please verify the URL is correct.",
"error_code": "invalid_url"
"error_code": "invalid_url",
}
else:
error_text = ""
try:
error_data = response.json()
error_text = error_data.get('error', error_data.get('message', ''))
error_text = error_data.get("error", error_data.get("message", ""))
except Exception:  # error body was not JSON; fall back to raw text
error_text = response.text[:200]
return {
"success": False,
"message": f"Connection failed (HTTP {response.status_code}): {error_text}",
"error_code": "connection_failed"
"error_code": "connection_failed",
}
except httpx.TimeoutException:
return {
"success": False,
"message": "Connection timeout. Please check the Zammad URL and your network connection.",
"error_code": "timeout"
"error_code": "timeout",
}
except httpx.ConnectError:
return {
"success": False,
"message": "Could not connect to Zammad. Please verify the URL is correct and accessible.",
"error_code": "connection_error"
"error_code": "connection_error",
}
except Exception as e:
logger.error(f"Failed to test plugin credentials: {e}")
return {
"success": False,
"message": f"Test failed: {str(e)}",
"error_code": "unknown_error"
"error_code": "unknown_error",
}
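Because the credential test boils down to a single read-only GET against Zammad's /api/v1/users/me with Token auth, the same check can be reproduced outside the API. A standalone sketch (URL and token below are placeholders):

import httpx

def verify_zammad_token(zammad_url: str, api_token: str) -> bool:
    # Mirrors the endpoint's check: 200 -> valid, 401 -> bad token, 404 -> bad URL.
    resp = httpx.get(
        f"{zammad_url.rstrip('/')}/api/v1/users/me",
        headers={"Authorization": f"Token token={api_token}"},
        timeout=10.0,
    )
    return resp.status_code == 200

print(verify_zammad_token("https://support.example.com", "s3cret"))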
# Background task for plugin installation
async def install_plugin_background(plugin_id: str, version: str, user_id: str, db: AsyncSession):
async def install_plugin_background(
plugin_id: str, version: str, user_id: str, db: AsyncSession
):
"""Background task for plugin installation"""
try:
result = await plugin_installer.install_plugin_from_repository(
plugin_id, version, user_id, db
)
logger.info(f"Background installation completed: {result}")
except Exception as e:
logger.error(f"Background installation failed: {e}")
# TODO: Notify user of installation failure
# TODO: Notify user of installation failure

View File

@@ -17,7 +17,10 @@ from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.models import (
ChatRequest as LLMChatRequest,
ChatMessage as LLMChatMessage,
)
router = APIRouter()
@@ -59,13 +62,14 @@ class ImprovePromptRequest(BaseModel):
@router.get("/templates", response_model=List[PromptTemplateResponse])
async def list_prompt_templates(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get all prompt templates"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("list_prompt_templates", {"user_id": user_id})
try:
result = await db.execute(
select(PromptTemplate)
@@ -73,7 +77,7 @@ async def list_prompt_templates(
.order_by(PromptTemplate.name)
)
templates = result.scalars().all()
template_list = []
for template in templates:
template_dict = {
@@ -85,28 +89,38 @@ async def list_prompt_templates(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
template_list.append(template_dict)
return template_list
except Exception as e:
log_api_request("list_prompt_templates_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}")
log_api_request(
"list_prompt_templates_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}"
)
@router.get("/templates/{type_key}", response_model=PromptTemplateResponse)
async def get_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get a specific prompt template by type key"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("get_prompt_template", {"user_id": user_id, "type_key": type_key})
try:
result = await db.execute(
select(PromptTemplate)
@@ -114,10 +128,10 @@ async def get_prompt_template(
.where(PromptTemplate.is_active == True)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="Prompt template not found")
return {
"id": template.id,
"name": template.name,
@@ -127,15 +141,23 @@ async def get_prompt_template(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
log_api_request("get_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt template: {str(e)}")
log_api_request(
"get_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt template: {str(e)}"
)
@router.put("/templates/{type_key}")
@@ -143,16 +165,17 @@ async def update_prompt_template(
type_key: str,
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update a prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("update_prompt_template", {
"user_id": user_id,
"type_key": type_key,
"name": request.name
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"update_prompt_template",
{"user_id": user_id, "type_key": type_key, "name": request.name},
)
try:
# Get existing template
result = await db.execute(
@@ -161,10 +184,10 @@ async def update_prompt_template(
.where(PromptTemplate.is_active == True)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail="Prompt template not found")
# Update the template
await db.execute(
update(PromptTemplate)
@@ -175,19 +198,18 @@ async def update_prompt_template(
system_prompt=request.system_prompt,
is_active=request.is_active,
version=template.version + 1,
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
)
await db.commit()
# Return updated template
updated_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
select(PromptTemplate).where(PromptTemplate.type_key == type_key)
)
updated_template = updated_result.scalar_one()
return {
"id": updated_template.id,
"name": updated_template.name,
@@ -197,41 +219,52 @@ async def update_prompt_template(
"is_default": updated_template.is_default,
"is_active": updated_template.is_active,
"version": updated_template.version,
"created_at": updated_template.created_at.isoformat() if updated_template.created_at else None,
"updated_at": updated_template.updated_at.isoformat() if updated_template.updated_at else None
"created_at": updated_template.created_at.isoformat()
if updated_template.created_at
else None,
"updated_at": updated_template.updated_at.isoformat()
if updated_template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("update_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to update prompt template: {str(e)}")
log_api_request(
"update_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to update prompt template: {str(e)}"
)
@router.post("/templates/create")
async def create_prompt_template(
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("create_prompt_template", {
"user_id": user_id,
"type_key": request.type_key,
"name": request.name
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"create_prompt_template",
{"user_id": user_id, "type_key": request.type_key, "name": request.name},
)
try:
# Check if template already exists
existing_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == request.type_key)
select(PromptTemplate).where(PromptTemplate.type_key == request.type_key)
)
if existing_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Prompt template with this type key already exists")
raise HTTPException(
status_code=400,
detail="Prompt template with this type key already exists",
)
# Create new template
template = PromptTemplate(
id=str(uuid.uuid4()),
@@ -243,13 +276,13 @@ async def create_prompt_template(
is_active=request.is_active,
version=1,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
db.add(template)
await db.commit()
await db.refresh(template)
return {
"id": template.id,
"name": template.name,
@@ -259,27 +292,36 @@ async def create_prompt_template(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("create_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to create prompt template: {str(e)}")
log_api_request(
"create_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to create prompt template: {str(e)}"
)
@router.get("/variables", response_model=List[PromptVariableResponse])
async def list_prompt_variables(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get all available prompt variables"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("list_prompt_variables", {"user_id": user_id})
try:
result = await db.execute(
select(ChatbotPromptVariable)
@@ -287,7 +329,7 @@ async def list_prompt_variables(
.order_by(ChatbotPromptVariable.variable_name)
)
variables = result.scalars().all()
variable_list = []
for variable in variables:
variable_dict = {
@@ -295,27 +337,33 @@ async def list_prompt_variables(
"variable_name": variable.variable_name,
"description": variable.description,
"example_value": variable.example_value,
"is_active": variable.is_active
"is_active": variable.is_active,
}
variable_list.append(variable_dict)
return variable_list
except Exception as e:
log_api_request("list_prompt_variables_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}")
log_api_request(
"list_prompt_variables_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}"
)
@router.post("/templates/{type_key}/reset")
async def reset_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Reset a prompt template to its default"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("reset_prompt_template", {"user_id": user_id, "type_key": type_key})
# Define default prompts (same as in migration)
default_prompts = {
"assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
@@ -323,12 +371,12 @@ async def reset_prompt_template(
"teacher": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
"creative_writer": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
}
if type_key not in default_prompts:
raise HTTPException(status_code=404, detail="Unknown prompt template type")
try:
# Update the template to default
await db.execute(
@@ -337,33 +385,39 @@ async def reset_prompt_template(
.values(
system_prompt=default_prompts[type_key],
version=PromptTemplate.version + 1,
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
)
await db.commit()
return {"message": "Prompt template reset to default successfully"}
except Exception as e:
await db.rollback()
log_api_request("reset_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to reset prompt template: {str(e)}")
log_api_request(
"reset_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to reset prompt template: {str(e)}"
)
@router.post("/improve")
async def improve_prompt_with_ai(
request: ImprovePromptRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Improve a prompt using AI"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("improve_prompt_with_ai", {
"user_id": user_id,
"chatbot_type": request.chatbot_type
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"improve_prompt_with_ai",
{"user_id": user_id, "chatbot_type": request.chatbot_type},
)
try:
# Create system message for improvement
system_message = """You are an expert prompt engineer. Your task is to improve the given prompt to make it more effective, clear, and specific for the intended chatbot type.
@@ -392,92 +446,100 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
{"role": "user", "content": user_message},
]
# Get available models to use a default model
models = await llm_service.get_models()
if not models:
raise HTTPException(status_code=503, detail="No LLM models available")
# Use the first available model (you might want to make this configurable)
default_model = models[0].id
# Prepare the chat request for the new LLM service
chat_request = LLMChatRequest(
model=default_model,
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
messages=[
LLMChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
],
temperature=0.3,
max_tokens=1000,
user_id=str(user_id),
api_key_id=1 # Using default API key, you might want to make this dynamic
api_key_id=1, # Using default API key, you might want to make this dynamic
)
# Make the AI call
response = await llm_service.create_chat_completion(chat_request)
# Extract the improved prompt from the response
improved_prompt = response.choices[0].message.content.strip()
return {
"improved_prompt": improved_prompt,
"original_prompt": request.current_prompt,
"model_used": default_model
"model_used": default_model,
}
except HTTPException:
raise
except Exception as e:
log_api_request("improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to improve prompt: {str(e)}")
log_api_request(
"improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to improve prompt: {str(e)}"
)
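For reference, the improve endpoint takes the current prompt and the chatbot type, runs them through the first available model at temperature 0.3, and returns the rewrite alongside the original. A client-side sketch (the /api/v1/prompts prefix and auth header are assumptions; any further ImprovePromptRequest fields are not shown in this diff):

import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/prompts/improve",
    json={
        "current_prompt": "You are a support bot.",
        "chatbot_type": "customer_support",
    },
    headers={"Authorization": "Bearer dev-token"},
    timeout=60.0,  # the LLM round-trip can be slow
)
resp.raise_for_status()
print(resp.json()["improved_prompt"])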
@router.post("/seed-defaults")
async def seed_default_templates(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Seed default prompt templates for all chatbot types"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("seed_default_templates", {"user_id": user_id})
# Define default prompts (same as in reset)
default_prompts = {
"assistant": {
"name": "General Assistant",
"description": "A helpful, accurate, and friendly AI assistant",
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style."
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
},
"customer_support": {
"name": "Customer Support Agent",
"description": "Professional customer service representative focused on solving problems",
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations."
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.",
},
"teacher": {
"name": "Educational Tutor",
"description": "Patient and encouraging educational facilitator",
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it."
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
},
"researcher": {
"name": "Research Assistant",
"description": "Thorough researcher focused on evidence-based information",
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence."
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
},
"creative_writer": {
"name": "Creative Writing Mentor",
"description": "Imaginative storytelling expert and writing coach",
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles."
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
},
"custom": {
"name": "Custom Chatbot",
"description": "Customizable AI assistant with user-defined behavior",
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
}
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
},
}
created_templates = []
updated_templates = []
try:
for type_key, template_data in default_prompts.items():
# Check if template already exists
@@ -530,7 +592,9 @@ async def seed_default_templates(
created_at=now,
updated_at=now,
)
.on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])
.on_conflict_do_nothing(
index_elements=[PromptTemplate.type_key]
)
)
result = await db.execute(stmt)
@@ -541,17 +605,21 @@ async def seed_default_templates(
"prompt_template_seed_skipped",
{"type_key": type_key, "reason": "already_exists"},
)
await db.commit()
return {
"message": "Default templates seeded successfully",
"created": created_templates,
"updated": updated_templates,
"total": len(created_templates) + len(updated_templates)
"total": len(created_templates) + len(updated_templates),
}
except Exception as e:
await db.rollback()
log_api_request("seed_default_templates_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to seed default templates: {str(e)}")
log_api_request(
"seed_default_templates_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to seed default templates: {str(e)}"
)

View File

@@ -29,6 +29,7 @@ router = APIRouter(tags=["RAG"])
# Request/Response Models
class CollectionCreate(BaseModel):
name: str
description: Optional[str] = None
@@ -78,12 +79,13 @@ class StatsResponse(BaseModel):
# Collection Endpoints
@router.get("/collections", response_model=dict)
async def get_collections(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
try:
@@ -103,7 +105,7 @@ async def get_collections(
"collections": paginated_collections,
"total": len(collections),
"total_documents": stats_data.get("total_documents", 0),
"total_size_bytes": stats_data.get("total_size_bytes", 0)
"total_size_bytes": stats_data.get("total_size_bytes", 0),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -113,20 +115,19 @@ async def get_collections(
async def create_collection(
collection_data: CollectionCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Create a new RAG collection"""
try:
rag_service = RAGService(db)
collection = await rag_service.create_collection(
name=collection_data.name,
description=collection_data.description
name=collection_data.name, description=collection_data.description
)
return {
"success": True,
"collection": collection.to_dict(),
"message": "Collection created successfully"
"message": "Collection created successfully",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -136,8 +137,7 @@ async def create_collection(
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get overall RAG statistics - live data directly from Qdrant"""
try:
@@ -147,7 +147,11 @@ async def get_rag_stats(
stats_data = await qdrant_stats_service.get_collections_stats()
# Calculate active collections (collections with documents)
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
active_collections = sum(
1
for col in stats_data.get("collections", [])
if col.get("document_count", 0) > 0
)
# Calculate processing documents from database
processing_docs = 0
@@ -156,7 +160,9 @@ async def get_rag_stats(
from app.models.rag_document import RagDocument, ProcessingStatus
result = await db.execute(
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
select(RagDocument).where(
RagDocument.status == ProcessingStatus.PROCESSING
)
)
processing_docs = len(result.scalars().all())
except Exception:
@@ -167,22 +173,28 @@ async def get_rag_stats(
"stats": {
"collections": {
"total": stats_data.get("total_collections", 0),
"active": active_collections
"active": active_collections,
},
"documents": {
"total": stats_data.get("total_documents", 0),
"processing": processing_docs,
"processed": stats_data.get("total_documents", 0) # Indexed documents
"processed": stats_data.get(
"total_documents", 0
), # Indexed documents
},
"storage": {
"total_size_bytes": stats_data.get("total_size_bytes", 0),
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
"total_size_mb": round(
stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2
),
},
"vectors": {
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
"total": stats_data.get(
"total_documents", 0
) # Same as documents for RAG
},
"last_updated": datetime.utcnow().isoformat()
}
"last_updated": datetime.utcnow().isoformat(),
},
}
return response_data
@@ -194,20 +206,17 @@ async def get_rag_stats(
async def get_collection(
collection_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get a specific collection"""
try:
rag_service = RAGService(db)
collection = await rag_service.get_collection(collection_id)
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
return {
"success": True,
"collection": collection.to_dict()
}
return {"success": True, "collection": collection.to_dict()}
except HTTPException:
raise
except Exception as e:
@@ -219,19 +228,20 @@ async def delete_collection(
collection_id: int,
cascade: bool = True, # Default to cascade deletion for better UX
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Delete a collection and optionally all its documents"""
try:
rag_service = RAGService(db)
success = await rag_service.delete_collection(collection_id, cascade=cascade)
if not success:
raise HTTPException(status_code=404, detail="Collection not found")
return {
"success": True,
"message": "Collection deleted successfully" + (" (with documents)" if cascade else "")
"message": "Collection deleted successfully"
+ (" (with documents)" if cascade else ""),
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -243,13 +253,14 @@ async def delete_collection(
# Document Endpoints
@router.get("/documents", response_model=dict)
async def get_documents(
collection_id: Optional[str] = None,
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get documents, optionally filtered by collection"""
try:
@@ -260,11 +271,7 @@ async def get_documents(
if collection_id.startswith("ext_"):
# External collections exist only in Qdrant and have no documents in PostgreSQL
# Return empty list since they don't have managed documents
return {
"success": True,
"documents": [],
"total": 0
}
return {"success": True, "documents": [], "total": 0}
else:
# Try to convert to integer for managed collections
try:
@@ -272,29 +279,25 @@ async def get_documents(
except (ValueError, TypeError):
# Attempt to resolve by Qdrant collection name
collection_row = await db.scalar(
select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
select(RagCollection).where(
RagCollection.qdrant_collection_name == collection_id
)
)
if collection_row:
collection_id_int = collection_row.id
else:
# Unknown collection identifier; return empty result instead of erroring out
return {
"success": True,
"documents": [],
"total": 0
}
return {"success": True, "documents": [], "total": 0}
rag_service = RAGService(db)
documents = await rag_service.get_documents(
collection_id=collection_id_int,
skip=skip,
limit=limit
collection_id=collection_id_int, skip=skip, limit=limit
)
return {
"success": True,
"documents": [doc.to_dict() for doc in documents],
"total": len(documents)
"total": len(documents),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -305,13 +308,13 @@ async def upload_document(
collection_id: str = Form(...),
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Upload and process a document"""
try:
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
file_extension = filename.split(".")[-1].lower() if "." in filename else ""
# Read file content once and use it for all validations
file_content = await file.read()
@@ -324,50 +327,66 @@ async def upload_document(
try:
# Test file readability based on type
if file_extension == 'jsonl':
if file_extension == "jsonl":
# Validate JSONL format - try to parse first few lines
try:
content_str = file_content.decode('utf-8')
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
content_str = file_content.decode("utf-8")
lines = content_str.strip().split("\n")[:5] # Check first 5 lines
import json
for i, line in enumerate(lines):
if line.strip(): # Skip empty lines
json.loads(line) # Will raise JSONDecodeError if invalid
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
raise HTTPException(
status_code=400, detail="File is not valid UTF-8 text"
)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
raise HTTPException(
status_code=400, detail=f"Invalid JSONL format: {str(e)}"
)
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
elif file_extension in ["txt", "md", "py", "js", "html", "css", "json"]:
# Validate text files can be decoded
try:
file_content.decode('utf-8')
file_content.decode("utf-8")
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
raise HTTPException(
status_code=400, detail="File is not valid UTF-8 text"
)
elif file_extension in ['pdf']:
elif file_extension in ["pdf"]:
# For PDF files, just check if it starts with PDF signature
if not file_content.startswith(b'%PDF'):
raise HTTPException(status_code=400, detail="Invalid PDF file format")
if not file_content.startswith(b"%PDF"):
raise HTTPException(
status_code=400, detail="Invalid PDF file format"
)
elif file_extension in ['docx', 'xlsx', 'pptx']:
elif file_extension in ["docx", "xlsx", "pptx"]:
# For Office documents, check ZIP signature
if not file_content.startswith(b'PK'):
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
if not file_content.startswith(b"PK"):
raise HTTPException(
status_code=400,
detail=f"Invalid {file_extension.upper()} file format",
)
# For other file types, we'll rely on the document processor
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
raise HTTPException(
status_code=400, detail=f"File validation failed: {str(e)}"
)
rag_service = RAGService(db)
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
collection_identifier = (collection_id or "").strip()
if not collection_identifier:
raise HTTPException(status_code=400, detail="Collection identifier is required")
raise HTTPException(
status_code=400, detail="Collection identifier is required"
)
resolved_collection_id: Optional[int] = None
@@ -379,7 +398,9 @@ async def upload_document(
qdrant_name = qdrant_name[4:]
try:
collection_record = await rag_service.ensure_collection_record(qdrant_name)
collection_record = await rag_service.ensure_collection_record(
qdrant_name
)
except Exception as ensure_error:
raise HTTPException(status_code=500, detail=str(ensure_error))
@@ -392,13 +413,13 @@ async def upload_document(
collection_id=resolved_collection_id,
file_content=file_content,
filename=filename,
content_type=file.content_type
content_type=file.content_type,
)
return {
"success": True,
"document": document.to_dict(),
"message": "Document uploaded and processing started"
"message": "Document uploaded and processing started",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
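Since upload_document rejects files on magic bytes and UTF-8 decoding before any heavy processing, a client can run the same cheap checks locally and skip doomed uploads. A simplified sketch (it covers the magic-byte and UTF-8 checks above; the endpoint additionally parses the first few JSONL lines):

SIGNATURES = {"pdf": b"%PDF", "docx": b"PK", "xlsx": b"PK", "pptx": b"PK"}
TEXT_EXTENSIONS = {"txt", "md", "py", "js", "html", "css", "json", "jsonl"}

def looks_uploadable(filename: str, content: bytes) -> bool:
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    magic = SIGNATURES.get(ext)
    if magic is not None:
        return content.startswith(magic)
    if ext in TEXT_EXTENSIONS:
        try:
            content.decode("utf-8")
            return True
        except UnicodeDecodeError:
            return False
    return True  # other types are left to the server-side document processor

print(looks_uploadable("report.pdf", b"%PDF-1.7 ..."))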
@@ -412,20 +433,17 @@ async def upload_document(
async def get_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get a specific document"""
try:
rag_service = RAGService(db)
document = await rag_service.get_document(document_id)
if not document:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"document": document.to_dict()
}
return {"success": True, "document": document.to_dict()}
except HTTPException:
raise
except Exception as e:
@@ -436,20 +454,17 @@ async def get_document(
async def delete_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Delete a document"""
try:
rag_service = RAGService(db)
success = await rag_service.delete_document(document_id)
if not success:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"message": "Document deleted successfully"
}
return {"success": True, "message": "Document deleted successfully"}
except HTTPException:
raise
except Exception as e:
@@ -460,13 +475,13 @@ async def delete_document(
async def reprocess_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Restart processing for a stuck or failed document"""
try:
rag_service = RAGService(db)
success = await rag_service.reprocess_document(document_id)
if not success:
# Get document to check if it exists and its current status
document = await rag_service.get_document(document_id)
@@ -474,13 +489,13 @@ async def reprocess_document(
raise HTTPException(status_code=404, detail="Document not found")
else:
raise HTTPException(
status_code=400,
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed."
status_code=400,
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed.",
)
return {
"success": True,
"message": "Document reprocessing started successfully"
"message": "Document reprocessing started successfully",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -494,22 +509,24 @@ async def reprocess_document(
async def download_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Download the original document file"""
try:
rag_service = RAGService(db)
result = await rag_service.download_document(document_id)
if not result:
raise HTTPException(status_code=404, detail="Document not found or file not available")
raise HTTPException(
status_code=404, detail="Document not found or file not available"
)
content, filename, mime_type = result
return StreamingResponse(
io.BytesIO(content),
media_type=mime_type,
headers={"Content-Disposition": f"attachment; filename={filename}"}
headers={"Content-Disposition": f"attachment; filename={filename}"},
)
except HTTPException:
raise
@@ -517,9 +534,9 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e))
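The download endpoint wraps the stored bytes in a StreamingResponse with a Content-Disposition header, so a client can stream straight to disk instead of buffering. A sketch, assuming an /api/v1/rag prefix and bearer auth (neither the route prefix nor the auth scheme appears in this diff):

import httpx

def download_original(document_id: int, dest: str) -> None:
    url = f"http://localhost:8000/api/v1/rag/documents/{document_id}/download"
    with httpx.stream("GET", url, headers={"Authorization": "Bearer dev-token"}) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as fh:
            for chunk in resp.iter_bytes():  # stream chunks to disk
                fh.write(chunk)

download_original(42, "original.pdf")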
# Debug Endpoints
@router.post("/debug/search")
async def search_with_debug(
query: str,
@@ -527,13 +544,13 @@ async def search_with_debug(
score_threshold: float = 0.3,
collection_name: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Enhanced search with comprehensive debug information
"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
rag_module = module_manager.modules.get("rag")
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
@@ -567,7 +584,7 @@ async def search_with_debug(
query,
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name
collection_name=collection_name,
)
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
@@ -575,22 +592,23 @@ async def search_with_debug(
scores = [r.score for r in results if r.score is not None]
if scores:
import statistics
debug_info["score_stats"] = {
"min": min(scores),
"max": max(scores),
"avg": statistics.mean(scores),
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0,
}
# Get collection statistics
try:
from qdrant_client.http.models import Filter
collection_name = collection_name or rag_module.default_collection_name
# Count total documents
count_result = rag_module.qdrant_client.count(
collection_name=collection_name,
count_filter=Filter(must=[])
collection_name=collection_name, count_filter=Filter(must=[])
)
total_points = count_result.count
@@ -599,7 +617,7 @@ async def search_with_debug(
collection_name=collection_name,
limit=1000, # Sample for stats
with_payload=True,
with_vectors=False
with_vectors=False,
)
unique_docs = set()
@@ -618,7 +636,7 @@ async def search_with_debug(
debug_info["collection_stats"] = {
"total_documents": len(unique_docs),
"total_chunks": total_points,
"languages": sorted(list(languages))
"languages": sorted(list(languages)),
}
except Exception as e:
@@ -631,16 +649,18 @@ async def search_with_debug(
"document": {
"id": result.document.id,
"content": result.document.content,
"metadata": result.document.metadata
"metadata": result.document.metadata,
},
"score": result.score,
"debug_info": {}
"debug_info": {},
}
# Add hybrid search debug info if available
metadata = result.document.metadata or {}
if "_vector_score" in metadata:
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
enhanced_result["debug_info"]["vector_score"] = metadata[
"_vector_score"
]
if "_bm25_score" in metadata:
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
@@ -652,7 +672,7 @@ async def search_with_debug(
"results": enhanced_results,
"debug_info": debug_info,
"search_time_ms": search_time,
"timestamp": start_time.isoformat()
"timestamp": start_time.isoformat(),
}
except Exception as e:
@@ -661,17 +681,17 @@ async def search_with_debug(
finally:
# Restore original config if modified
if config and 'original_config' in locals():
if config and "original_config" in locals():
rag_module.config = original_config
@router.get("/debug/config")
async def get_current_config(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
"""Get current RAG configuration"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
rag_module = module_manager.modules.get("rag")
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
@@ -679,5 +699,5 @@ async def get_current_config(
"config": rag_module.config,
"embedding_model": rag_module.embedding_model,
"enabled": rag_module.enabled,
"collections": await rag_module._get_collections_safely()
"collections": await rag_module._get_collections_safely(),
}

File diff suppressed because it is too large

View File

@@ -85,16 +85,16 @@ async def list_users(
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List all users with pagination and filtering"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:read")
# Build query
query = select(User)
# Apply filters
if role:
query = query.where(User.role == role)
@@ -102,38 +102,42 @@ async def list_users(
query = query.where(User.is_active == is_active)
if search:
query = query.where(
(User.username.ilike(f"%{search}%")) |
(User.email.ilike(f"%{search}%")) |
(User.full_name.ilike(f"%{search}%"))
(User.username.ilike(f"%{search}%"))
| (User.email.ilike(f"%{search}%"))
| (User.full_name.ilike(f"%{search}%"))
)
# Get total count
total_query = select(User.id).select_from(query.subquery())
total_result = await db.execute(total_query)
total = len(total_result.fetchall())
# Apply pagination
offset = (page - 1) * size
query = query.offset(offset).limit(size)
# Execute query
result = await db.execute(query)
users = result.scalars().all()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_users",
resource_type="user",
details={"page": page, "size": size, "filters": {"role": role, "is_active": is_active, "search": search}}
details={
"page": page,
"size": size,
"filters": {"role": role, "is_active": is_active, "search": search},
},
)
return UserListResponse(
users=[UserResponse.model_validate(user) for user in users],
total=total,
page=page,
size=size
size=size,
)
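One detail worth flagging in list_users: the total is computed by fetching every matching id and counting the rows in Python, which pulls the whole filtered id set over the wire. A sketch of the usual database-side count under the same filters (query and db as in the endpoint; func comes from sqlalchemy):

from sqlalchemy import func, select

async def count_filtered(db, query) -> int:
    # Same WHERE clauses as `query`, counted in the database instead of
    # materializing every id and len()-ing the result.
    stmt = select(func.count()).select_from(query.subquery())
    return (await db.execute(stmt)).scalar_one()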
@@ -141,34 +145,33 @@ async def list_users(
async def get_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get user by ID"""
# Check permissions (users can view their own profile)
if int(user_id) != current_user['id']:
if int(user_id) != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:users:read")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_user",
resource_type="user",
resource_id=user_id
resource_id=user_id,
)
return UserResponse.model_validate(user)
@@ -176,26 +179,26 @@ async def get_user(
async def create_user(
user_data: UserCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new user"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:create")
# Check if user already exists
query = select(User).where(
(User.username == user_data.username) | (User.email == user_data.email)
)
result = await db.execute(query)
existing_user = result.scalar_one_or_none()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User with this username or email already exists"
detail="User with this username or email already exists",
)
# Create user
hashed_password = get_password_hash(user_data.password)
new_user = User(
@@ -204,25 +207,29 @@ async def create_user(
full_name=user_data.full_name,
hashed_password=hashed_password,
role=user_data.role,
is_active=user_data.is_active
is_active=user_data.is_active,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="create_user",
resource_type="user",
resource_id=str(new_user.id),
details={"username": user_data.username, "email": user_data.email, "role": user_data.role}
details={
"username": user_data.username,
"email": user_data.email,
"role": user_data.role,
},
)
logger.info(f"User created: {new_user.username} by {current_user['username']}")
return UserResponse.model_validate(new_user)
@@ -231,26 +238,25 @@ async def update_user(
user_id: str,
user_data: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update user"""
# Check permissions (users can update their own profile with limited fields)
is_self_update = int(user_id) == current_user['id']
is_self_update = int(user_id) == current_user["id"]
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# For self-updates, restrict what can be changed
if is_self_update:
allowed_fields = {"username", "email", "full_name"}
@@ -259,41 +265,41 @@ async def update_user(
if restricted_fields:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Cannot update fields: {restricted_fields}"
detail=f"Cannot update fields: {restricted_fields}",
)
# Store original values for audit
original_values = {
"username": user.username,
"email": user.email,
"role": user.role,
"is_active": user.is_active
"is_active": user.is_active,
}
# Update user fields
update_data = user_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(user, field, value)
await db.commit()
await db.refresh(user)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="update_user",
resource_type="user",
resource_id=user_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(user, k) for k in update_data.keys()}
}
"after_values": {k: getattr(user, k) for k in update_data.keys()},
},
)
logger.info(f"User updated: {user.username} by {current_user['username']}")
return UserResponse.model_validate(user)
@@ -301,47 +307,46 @@ async def update_user(
async def delete_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete user (soft delete by deactivating)"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:delete")
# Prevent self-deletion
if int(user_id) == current_user["id"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete your own account"
detail="Cannot delete your own account",
)
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Soft delete by deactivating
user.is_active = False
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="delete_user",
resource_type="user",
resource_id=user_id,
details={"username": user.username, "email": user.email}
details={"username": user.username, "email": user.email},
)
logger.info(f"User deleted: {user.username} by {current_user['username']}")
return {"message": "User deleted successfully"}
@@ -350,50 +355,51 @@ async def change_password(
user_id: str,
password_data: PasswordChangeRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Change user password"""
# Users can only change their own password, or admins can change any password
is_self_update = int(user_id) == current_user["id"]
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# For self-updates, verify current password
if is_self_update:
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect"
detail="Current password is incorrect",
)
# Update password
user.hashed_password = get_password_hash(password_data.new_password)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="change_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
details={"target_user": user.username},
)
logger.info(f"Password changed for user: {user.username} by {current_user['username']}")
logger.info(
f"Password changed for user: {user.username} by {current_user['username']}"
)
return {"message": "Password changed successfully"}
@@ -402,40 +408,41 @@ async def reset_password(
user_id: str,
password_data: PasswordResetRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Reset user password (admin only)"""
# Check permissions
require_permission(current_user.get("permissions", []), "platform:users:update")
# Get user
query = select(User).where(User.id == int(user_id))
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Reset password
user.hashed_password = get_password_hash(password_data.new_password)
await db.commit()
# Log audit event
await log_audit_event(
db=db,
user_id=current_user["id"],
action="reset_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
details={"target_user": user.username},
)
logger.info(f"Password reset for user: {user.username} by {current_user['username']}")
logger.info(
f"Password reset for user: {user.username} by {current_user['username']}"
)
return {"message": "Password reset successfully"}
@@ -443,20 +450,22 @@ async def reset_password(
async def get_user_api_keys(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get API keys for a user"""
# Check permissions (users can view their own API keys)
is_self_request = int(user_id) == current_user["id"]
if not is_self_request:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Get API keys
query = select(APIKey).where(APIKey.user_id == int(user_id))
result = await db.execute(query)
api_keys = result.scalars().all()
# Return safe representation (no key values)
return [
{
@@ -466,8 +475,12 @@ async def get_user_api_keys(
"scopes": api_key.scopes,
"is_active": api_key.is_active,
"created_at": api_key.created_at.isoformat(),
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None
"expires_at": api_key.expires_at.isoformat()
if api_key.expires_at
else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
}
for api_key in api_keys
]
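Note: the endpoint above repeats the isoformat-or-None expression for every optional timestamp. A tiny helper (hypothetical, not part of this commit) would keep that serialization in one place:

from datetime import datetime
from typing import Optional


def iso_or_none(value: Optional[datetime]) -> Optional[str]:
    # Serialize an optional datetime as ISO 8601; pass None through unchanged
    return value.isoformat() if value is not None else None


# e.g. "expires_at": iso_or_none(api_key.expires_at),
#      "last_used_at": iso_or_none(api_key.last_used_at),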

View File

@@ -1,3 +1,3 @@
"""
Core package
"""
"""

View File

@@ -19,24 +19,19 @@ logger = logging.getLogger(__name__)
class CoreCacheService:
"""Core Redis-based cache service for system-wide caching"""
def __init__(self):
self.redis_pool: Optional[ConnectionPool] = None
self.redis_client: Optional[Redis] = None
self.enabled = False
self.stats = {"hits": 0, "misses": 0, "errors": 0, "total_requests": 0}
async def initialize(self):
"""Initialize the core cache service with connection pool"""
try:
# Create Redis connection pool for better resource management
redis_url = getattr(settings, "REDIS_URL", "redis://localhost:6379/0")
self.redis_pool = ConnectionPool.from_url(
redis_url,
encoding="utf-8",
@@ -45,141 +40,145 @@ class CoreCacheService:
socket_timeout=5,
retry_on_timeout=True,
max_connections=20, # Shared pool for all cache operations
health_check_interval=30,
)
self.redis_client = Redis(connection_pool=self.redis_pool)
# Test connection
await self.redis_client.ping()
self.enabled = True
logger.info("Core cache service initialized with Redis connection pool")
except Exception as e:
logger.error(f"Failed to initialize core cache service: {e}")
self.enabled = False
raise
async def cleanup(self):
"""Cleanup cache resources"""
if self.redis_client:
await self.redis_client.close()
self.redis_client = None
if self.redis_pool:
await self.redis_pool.disconnect()
self.redis_pool = None
self.enabled = False
logger.info("Core cache service cleaned up")
def _get_cache_key(self, key: str, prefix: str = "core") -> str:
"""Generate cache key with prefix"""
return f"{prefix}:{key}"
async def get(self, key: str, default: Any = None, prefix: str = "core") -> Any:
"""Get value from cache"""
if not self.enabled:
return default
try:
cache_key = self._get_cache_key(key, prefix)
value = await self.redis_client.get(cache_key)
if value is None:
self.stats["misses"] += 1
return default
self.stats["hits"] += 1
self.stats["total_requests"] += 1
# Try to deserialize JSON
try:
return json.loads(value)
except json.JSONDecodeError:
return value
except Exception as e:
logger.error(f"Cache get error for key {key}: {e}")
self.stats["errors"] += 1
return default
async def set(
self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
) -> bool:
"""Set value in cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key, prefix)
ttl = ttl or 3600 # Default 1 hour TTL
# Serialize complex objects as JSON
if isinstance(value, (dict, list, tuple)):
value = json.dumps(value)
await self.redis_client.setex(cache_key, ttl, value)
return True
except Exception as e:
logger.error(f"Cache set error for key {key}: {e}")
self.stats["errors"] += 1
return False
async def delete(self, key: str, prefix: str = "core") -> bool:
"""Delete key from cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key, prefix)
result = await self.redis_client.delete(cache_key)
return result > 0
except Exception as e:
logger.error(f"Cache delete error for key {key}: {e}")
self.stats["errors"] += 1
return False
async def exists(self, key: str, prefix: str = "core") -> bool:
"""Check if key exists in cache"""
if not self.enabled:
return False
try:
cache_key = self._get_cache_key(key, prefix)
return await self.redis_client.exists(cache_key) > 0
except Exception as e:
logger.error(f"Cache exists error for key {key}: {e}")
self.stats["errors"] += 1
return False
async def clear_pattern(self, pattern: str, prefix: str = "core") -> int:
"""Clear keys matching pattern"""
if not self.enabled:
return 0
try:
cache_pattern = self._get_cache_key(pattern, prefix)
keys = await self.redis_client.keys(cache_pattern)
if keys:
return await self.redis_client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Cache clear pattern error for pattern {pattern}: {e}")
self.stats["errors"] += 1
return 0
async def increment(
self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core"
) -> int:
"""Increment counter with optional TTL"""
if not self.enabled:
return 0
try:
cache_key = self._get_cache_key(key, prefix)
# Use pipeline for atomic increment + expire
async with self.redis_client.pipeline() as pipe:
await pipe.incr(cache_key, amount)
@@ -187,93 +186,118 @@ class CoreCacheService:
await pipe.expire(cache_key, ttl)
results = await pipe.execute()
return results[0]
except Exception as e:
logger.error(f"Cache increment error for key {key}: {e}")
self.stats["errors"] += 1
return 0
async def get_stats(self) -> Dict[str, Any]:
"""Get comprehensive cache statistics"""
stats = self.stats.copy()
if self.enabled:
try:
info = await self.redis_client.info()
stats.update(
{
"redis_memory_used": info.get("used_memory_human", "N/A"),
"redis_connected_clients": info.get("connected_clients", 0),
"redis_total_commands": info.get("total_commands_processed", 0),
"redis_keyspace_hits": info.get("keyspace_hits", 0),
"redis_keyspace_misses": info.get("keyspace_misses", 0),
"connection_pool_size": self.redis_pool.connection_pool_size
if self.redis_pool
else 0,
"hit_rate": round(
(stats["hits"] / stats["total_requests"]) * 100, 2
)
if stats["total_requests"] > 0
else 0,
"enabled": True,
}
)
except Exception as e:
logger.error(f"Error getting Redis stats: {e}")
stats["enabled"] = False
else:
stats["enabled"] = False
return stats
@asynccontextmanager
async def pipeline(self):
"""Context manager for Redis pipeline operations"""
if not self.enabled:
yield None
return
async with self.redis_client.pipeline() as pipe:
yield pipe
# Specialized caching methods for common use cases
async def cache_api_key(
self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300
) -> bool:
"""Cache API key data for authentication"""
return await self.set(key_prefix, api_key_data, ttl, prefix="auth")
async def get_cached_api_key(self, key_prefix: str) -> Optional[Dict[str, Any]]:
"""Get cached API key data"""
return await self.get(key_prefix, prefix="auth")
async def invalidate_api_key(self, key_prefix: str) -> bool:
"""Invalidate cached API key"""
return await self.delete(key_prefix, prefix="auth")
async def cache_verification_result(
self,
api_key: str,
key_prefix: str,
key_hash: str,
is_valid: bool,
ttl: int = 300,
) -> bool:
"""Cache API key verification result to avoid expensive bcrypt operations"""
verification_data = {
"key_hash": key_hash,
"is_valid": is_valid,
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
return await self.set(f"verify:{key_prefix}", verification_data, ttl, prefix="auth")
async def get_cached_verification(self, key_prefix: str) -> Optional[Dict[str, Any]]:
return await self.set(
f"verify:{key_prefix}", verification_data, ttl, prefix="auth"
)
async def get_cached_verification(
self, key_prefix: str
) -> Optional[Dict[str, Any]]:
"""Get cached verification result"""
return await self.get(f"verify:{key_prefix}", prefix="auth")
async def cache_rate_limit(
self, identifier: str, window_seconds: int, limit: int, current_count: int = 1
) -> Dict[str, Any]:
"""Cache and track rate limit state"""
key = f"rate_limit:{identifier}:{window_seconds}"
try:
# Use atomic increment with expiry
count = await self.increment(key, current_count, window_seconds, prefix="rate")
count = await self.increment(
key, current_count, window_seconds, prefix="rate"
)
remaining = max(0, limit - count)
reset_time = int(
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
)
return {
"count": count,
"limit": limit,
"remaining": remaining,
"reset_time": reset_time,
"exceeded": count > limit
"exceeded": count > limit,
}
except Exception as e:
logger.error(f"Rate limit cache error: {e}")
@@ -282,8 +306,10 @@ class CoreCacheService:
"count": 0,
"limit": limit,
"remaining": limit,
"reset_time": int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()),
"exceeded": False
"reset_time": int(
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
),
"exceeded": False,
}
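A minimal consumption sketch for cache_rate_limit (assuming the singleton instance is named core_cache, as the module-level wrappers below imply):

from fastapi import HTTPException, status

async def enforce_rate_limit(identifier: str) -> None:
    # Hypothetical guard: 20 requests per 60-second window, mirroring the PrivateMode defaults
    state = await core_cache.cache_rate_limit(identifier, window_seconds=60, limit=20)
    if state["exceeded"]:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded, retry at {state['reset_time']}",
        )

Note the fail-open fallback above: when Redis is unreachable, cache_rate_limit reports exceeded=False instead of blocking traffic.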
@@ -297,7 +323,9 @@ async def get(key: str, default: Any = None, prefix: str = "core") -> Any:
return await core_cache.get(key, default, prefix)
async def set(
key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
) -> bool:
"""Set value in core cache"""
return await core_cache.set(key, value, ttl, prefix)
@@ -319,4 +347,4 @@ async def clear_pattern(pattern: str, prefix: str = "core") -> int:
async def get_stats() -> Dict[str, Any]:
"""Get core cache statistics"""
return await core_cache.get_stats()
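A quick usage sketch of these module-level wrappers (the import path is an assumption, and core_cache.initialize() must have run during application startup):

import asyncio

from app.core import cache  # assumed import path for this module

async def demo() -> None:
    await cache.set("greeting", {"text": "hello"}, ttl=60)
    value = await cache.get("greeting", default={})
    print(value.get("text"))  # "hello" while the key lives in Redis

asyncio.run(demo())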

View File

@@ -10,7 +10,7 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings"""
# Application
APP_NAME: str = os.getenv("APP_NAME", "Enclava")
APP_DEBUG: bool = os.getenv("APP_DEBUG", "False").lower() == "true"
@@ -19,131 +19,188 @@ class Settings(BaseSettings):
APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
# Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = (
os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"
) # Set to True to log prompts and context sent to LLM
# Database
DATABASE_URL: str = os.getenv("DATABASE_URL")
# Redis
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379")
# Security
JWT_SECRET: str = os.getenv("JWT_SECRET")
JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")) # 7 days
SESSION_EXPIRE_MINUTES: int = int(os.getenv("SESSION_EXPIRE_MINUTES", "1440")) # 24 hours
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
) # 24 hours
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")
) # 7 days
SESSION_EXPIRE_MINUTES: int = int(
os.getenv("SESSION_EXPIRE_MINUTES", "1440")
) # 24 hours
API_KEY_PREFIX: str = os.getenv("API_KEY_PREFIX", "en_")
BCRYPT_ROUNDS: int = int(os.getenv("BCRYPT_ROUNDS", "6")) # Bcrypt work factor - lower for production performance
BCRYPT_ROUNDS: int = int(
os.getenv("BCRYPT_ROUNDS", "6")
) # Bcrypt work factor - lower for production performance
# Admin user provisioning (used only on first startup)
ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL")
ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD")
# Base URL for deriving CORS origins
BASE_URL: str = os.getenv("BASE_URL", "localhost")
@field_validator("CORS_ORIGINS", mode="before")
@classmethod
def derive_cors_origins(cls, v, info):
"""Derive CORS origins from BASE_URL if not explicitly set"""
if v is None:
base_url = info.data.get("BASE_URL", "localhost")
# Support both HTTP and HTTPS for production environments
return [f"http://{base_url}", f"https://{base_url}"]
return v if isinstance(v, list) else [v]
# CORS origins (derived from BASE_URL)
CORS_ORIGINS: Optional[List[str]] = None
# LLM Service Configuration (replaced LiteLLM)
# LLM service configuration is now handled in app/services/llm/config.py
# LLM Service Security (removed encryption - credentials handled by proxy)
# Plugin System Security
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv("PLUGIN_ENCRYPTION_KEY") # Key for encrypting plugin secrets and configurations
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv(
"PLUGIN_ENCRYPTION_KEY"
) # Key for encrypting plugin secrets and configurations
# API Keys for LLM providers
OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY: Optional[str] = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY: Optional[str] = os.getenv("GOOGLE_API_KEY")
PRIVATEMODE_API_KEY: Optional[str] = os.getenv("PRIVATEMODE_API_KEY")
PRIVATEMODE_PROXY_URL: str = os.getenv("PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1")
PRIVATEMODE_PROXY_URL: str = os.getenv(
"PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1"
)
# Qdrant
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
# Rate Limiting Configuration
# PrivateMode Standard tier limits (organization-level, not per user)
# These are shared across all API keys and users in the organization
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")
)
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(
os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200")
)
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000")
)
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000")
)
# Per-user limits (additional protection on top of organization limits)
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200")
)
# API key users (programmatic access)
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200")
)
# Premium/Enterprise API keys
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")
)
# Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
API_MAX_REQUEST_BODY_SIZE: int = int(
os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")
) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(
os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")
) # 50MB for premium
# IP Security
# Security Headers
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
API_CSP_HEADER: str = os.getenv(
"API_CSP_HEADER",
"default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
)
# Monitoring
PROMETHEUS_ENABLED: bool = os.getenv("PROMETHEUS_ENABLED", "True").lower() == "true"
PROMETHEUS_PORT: int = int(os.getenv("PROMETHEUS_PORT", "9090"))
# File uploads
MAX_UPLOAD_SIZE: int = int(os.getenv("MAX_UPLOAD_SIZE", "10485760")) # 10MB
# Module configuration
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
# RAG Embedding Configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(
os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12")
)
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv(
"RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
)
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(
os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0")
)
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(
os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")
)
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = (
os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
)
RAG_WARN_ON_FALLBACK: bool = (
os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(
os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")
)
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(
os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")
)
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
# Plugin configuration
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
PLUGIN_REPOSITORY_URL: str = os.getenv("PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com")
PLUGIN_REPOSITORY_URL: str = os.getenv(
"PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com"
)
# Logging
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "json")
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
model_config = {
"env_file": ".env",
"case_sensitive": True,

View File

@@ -13,7 +13,7 @@ from app.core.config import settings
def setup_logging() -> None:
"""Setup structured logging"""
# Configure structlog
structlog.configure(
processors=[
@@ -24,21 +24,23 @@ def setup_logging() -> None:
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer()
if settings.LOG_FORMAT == "json"
else structlog.dev.ConsoleRenderer(),
],
context_class=dict,
logger_factory=LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
# Configure standard logging
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=getattr(logging, settings.LOG_LEVEL.upper()),
)
# Set specific loggers
logging.getLogger("uvicorn").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
@@ -52,17 +54,17 @@ def get_logger(name: str) -> structlog.stdlib.BoundLogger:
class RequestContextFilter(logging.Filter):
"""Add request context to log records"""
def filter(self, record: logging.LogRecord) -> bool:
# Add request context if available
from contextvars import ContextVar
request_id: ContextVar[str] = ContextVar("request_id", default="")
user_id: ContextVar[str] = ContextVar("user_id", default="")
record.request_id = request_id.get()
record.user_id = user_id.get()
return True
@@ -77,7 +79,7 @@ def log_request(
) -> None:
"""Log HTTP request"""
logger = get_logger("api.request")
log_data = {
"method": method,
"path": path,
@@ -87,7 +89,7 @@ def log_request(
"request_id": request_id,
**kwargs,
}
if status_code >= 500:
logger.error("Request failed", **log_data)
elif status_code >= 400:
@@ -105,7 +107,7 @@ def log_security_event(
) -> None:
"""Log security event"""
logger = get_logger("security")
log_data = {
"event_type": event_type,
"user_id": user_id,
@@ -113,7 +115,7 @@ def log_security_event(
"details": details or {},
**kwargs,
}
logger.warning("Security event", **log_data)
@@ -125,14 +127,14 @@ def log_module_event(
) -> None:
"""Log module event"""
logger = get_logger("module")
log_data = {
"module_id": module_id,
"event_type": event_type,
"details": details or {},
**kwargs,
}
logger.info("Module event", **log_data)
@@ -143,11 +145,11 @@ def log_api_request(
) -> None:
"""Log API request for modules endpoints"""
logger = get_logger("api.module")
log_data = {
"endpoint": endpoint,
"params": params or {},
**kwargs,
}
logger.info("API request", **log_data)
logger.info("API request", **log_data)

View File

@@ -0,0 +1,367 @@
"""
Permissions Module
Role-based access control decorators and utilities
"""
from datetime import datetime
from functools import wraps
from typing import List, Optional, Union, Callable
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer
from app.models.user import User
security = HTTPBearer()
def require_permission(
user: User, permission: str, resource_id: Optional[Union[str, int]] = None
):
"""
Check if user has the required permission
Args:
user: User object from dependency injection
permission: Required permission string
resource_id: Optional resource ID for resource-specific permissions
Raises:
HTTPException: If user doesn't have the required permission
"""
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
)
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Account is not active"
)
# Check if account is locked
if user.account_locked:
if user.account_locked_until and user.account_locked_until > datetime.utcnow():
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is temporarily locked",
)
else:
# Unlock account if lock period has expired
user.unlock_account()
# Check permission
if not user.has_permission(permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required",
)
def require_permissions(permissions: List[str], require_all: bool = True):
"""
Decorator to require multiple permissions
Args:
permissions: List of required permissions
require_all: If True, user must have all permissions. If False, any one permission is sufficient
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs (assuming it's passed as a dependency)
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Check permissions
if require_all:
# User must have all permissions
for permission in permissions:
if not user.has_permission(permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"All of the following permissions required: {', '.join(permissions)}",
)
else:
# User needs at least one permission
if not any(
user.has_permission(permission) for permission in permissions
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"At least one of the following permissions required: {', '.join(permissions)}",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_role(role_names: Union[str, List[str]], require_all: bool = True):
"""
Decorator to require specific roles
Args:
role_names: Required role name(s)
require_all: If True, user must have all roles. If False, any one role is sufficient
"""
if isinstance(role_names, str):
role_names = [role_names]
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Check roles
user_role_names = []
if user.role:
user_role_names.append(user.role.name)
if require_all:
# User must have all roles
for role_name in role_names:
if role_name not in user_role_names:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"All of the following roles required: {', '.join(role_names)}",
)
else:
# User needs at least one role
if not any(role_name in user_role_names for role_name in role_names):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"At least one of the following roles required: {', '.join(role_names)}",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_minimum_role(minimum_role_level: str):
"""
Decorator to require minimum role level based on hierarchy
Args:
minimum_role_level: Minimum required role level
"""
role_hierarchy = {"read_only": 1, "user": 2, "admin": 3, "super_admin": 4}
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Superusers bypass role checks
if user.is_superuser:
return await func(*args, **kwargs)
# Check role level
if not user.role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Minimum role level '{minimum_role_level}' required",
)
user_level = role_hierarchy.get(user.role.level, 0)
required_level = role_hierarchy.get(minimum_role_level, 0)
if user_level < required_level:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Minimum role level '{minimum_role_level}' required",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def check_resource_permission(
user: User, resource_type: str, resource_id: Union[str, int], action: str
) -> bool:
"""
Check if user has permission to perform action on specific resource
Args:
user: User object
resource_type: Type of resource (e.g., 'user', 'budget', 'api_key')
resource_id: ID of the resource
action: Action to perform (e.g., 'read', 'update', 'delete')
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check basic permissions
permission = f"{action}_{resource_type}"
if user.has_permission(permission):
return True
# Check own resource permissions
if resource_type == "user" and str(resource_id) == str(user.id):
if user.has_permission(f"{action}_own"):
return True
# Check role-based resource access
if user.role:
# Admins can manage all users
if resource_type == "user" and user.role.level in ["admin", "super_admin"]:
return True
# Users with budget permissions can manage budgets
if resource_type == "budget" and user.role.can_manage_budgets:
return True
return False
def require_resource_permission(
resource_type: str, resource_id_param: str = "resource_id", action: str = "read"
):
"""
Decorator to require permission for specific resource
Args:
resource_type: Type of resource
resource_id_param: Name of parameter containing resource ID
action: Action to perform
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user and resource ID from kwargs
user = None
resource_id = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
elif key == resource_id_param:
resource_id = value
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
if resource_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Resource ID not provided",
)
# Check resource permission
if not check_resource_permission(user, resource_type, resource_id, action):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{action}_{resource_type}' required",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def check_budget_permission(user: User, budget_id: int, action: str) -> bool:
"""
Check if user has permission to perform action on specific budget
Args:
user: User object
budget_id: ID of the budget
action: Action to perform
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check if user owns the budget
for budget in user.budgets:
if budget.id == budget_id and budget.is_active:
return user.has_permission(f"{action}_own")
# Check if user can manage all budgets
if user.has_permission(f"{action}_all") or (
user.role and user.role.can_manage_budgets
):
return True
return False
def check_api_key_permission(user: User, api_key_id: int, action: str) -> bool:
"""
Check if user has permission to perform action on specific API key
Args:
user: User object
api_key_id: ID of the API key
action: Action to perform
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check if user owns the API key
for api_key in user.api_keys:
if api_key.id == api_key_id:
return user.has_permission(f"{action}_own")
# Check if user can manage all API keys
if user.has_permission(f"{action}_all"):
return True
return False
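A hedged sketch of attaching these decorators to a FastAPI route. They locate the caller by scanning kwargs for a User instance, so the auth dependency must resolve to the ORM model (the dict returned by get_current_user in the security module would not be found):

from fastapi import APIRouter, Depends

from app.core.permissions import require_minimum_role, require_permissions  # assumed path
from app.models.user import User

async def get_current_user_model() -> User:
    # Hypothetical dependency that yields the ORM User, not a dict
    ...

router = APIRouter()

@router.get("/reports")
@require_permissions(["can_view_reports"])
@require_minimum_role("admin")
async def reports(current_user: User = Depends(get_current_user_model)):
    # functools.wraps preserves the signature, so FastAPI still injects current_user as a kwarg
    return {"ok": True}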

View File

@@ -24,34 +24,39 @@ logger = logging.getLogger(__name__)
# Password hashing
# Use a lower work factor for better performance in production
pwd_context = CryptContext(
schemes=["bcrypt"],
deprecated="auto",
schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=settings.BCRYPT_ROUNDS
)
# JWT token handling
security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
import time
start_time = time.time()
logger.info(f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")
logger.info(
f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}"
)
try:
# Run password verification in a thread with timeout
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(
pwd_context.verify, plain_password, hashed_password
)
result = future.result(timeout=5.0) # 5 second timeout
end_time = time.time()
duration = end_time - start_time
logger.info(f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}")
logger.info(
f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}"
)
if duration > 1:
logger.warning(f"PASSWORD VERIFICATION TOOK TOO LONG: {duration:.3f}s")
return result
except concurrent.futures.TimeoutError:
end_time = time.time()
@@ -61,87 +66,116 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
except Exception as e:
end_time = time.time()
duration = end_time - start_time
logger.error(f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}")
logger.error(
f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}"
)
raise
def get_password_hash(password: str) -> str:
"""Generate password hash"""
return pwd_context.hash(password)
def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
"""Verify an API key against its hash"""
return pwd_context.verify(plain_api_key, hashed_api_key)
def get_api_key_hash(api_key: str) -> str:
"""Generate API key hash"""
return pwd_context.hash(api_key)
def create_access_token(
data: Dict[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
"""Create JWT access token"""
import time
start_time = time.time()
logger.info(f"=== CREATE ACCESS TOKEN START ===")
try:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire})
logger.info(f"JWT encode start...")
encode_start = time.time()
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
encode_end = time.time()
encode_duration = encode_end - encode_start
end_time = time.time()
total_duration = end_time - start_time
# Log token creation details
logger.info(f"Created access token for user {data.get('sub')}")
logger.info(f"Token expires at: {expire.isoformat()} (UTC)")
logger.info(f"Current UTC time: {datetime.utcnow().isoformat()}")
logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(
f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
)
logger.info(f"JWT encode duration: {encode_duration:.3f}s")
logger.info(f"Total token creation duration: {total_duration:.3f}s")
logger.info(f"=== CREATE ACCESS TOKEN END ===")
return encoded_jwt
except Exception as e:
end_time = time.time()
total_duration = end_time - start_time
logger.error(f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}")
logger.error(
f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}"
)
raise
def create_refresh_token(data: Dict[str, Any]) -> str:
"""Create JWT refresh token"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
return encoded_jwt
def verify_token(token: str) -> Dict[str, Any]:
"""Verify JWT token and return payload"""
try:
# Log current time before verification
current_time = datetime.utcnow()
logger.info(f"Verifying token at: {current_time.isoformat()} (UTC)")
# Decode without verification first to check expiration
try:
unverified_payload = jwt.get_unverified_claims(token)
exp_timestamp = unverified_payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.utcfromtimestamp(exp_timestamp)  # UTC, to match the utcnow() comparisons below
logger.info(f"Token expiration time: {exp_datetime.isoformat()} (UTC)")
logger.info(f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds")
logger.info(
f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds"
)
except Exception as decode_error:
logger.warning(f"Could not decode token for expiration check: {decode_error}")
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
logger.warning(
f"Could not decode token for expiration check: {decode_error}"
)
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
logger.info(f"Token verified successfully for user {payload.get('sub')}")
return payload
except JWTError as e:
@@ -149,30 +183,32 @@ def verify_token(token: str) -> Dict[str, Any]:
logger.warning(f"Current UTC time: {datetime.utcnow().isoformat()}")
raise AuthenticationError("Invalid token")
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> Dict[str, Any]:
"""Get current user from JWT token"""
try:
# Log server time for debugging clock sync issues
server_time = datetime.utcnow()
logger.info(f"get_current_user called at: {server_time.isoformat()} (UTC)")
payload = verify_token(credentials.credentials)
user_id: str = payload.get("sub")
if user_id is None:
raise AuthenticationError("Invalid token payload")
# Load user from database
from app.models.user import User
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# Query user from database
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
# If user doesn't exist in DB but token is valid, create basic user info from token
return {
@@ -181,49 +217,53 @@ async def get_current_user(
"is_superuser": payload.get("is_superuser", False),
"role": payload.get("role", "user"),
"is_active": True,
"permissions": [] # Default to empty list for permissions
"permissions": [], # Default to empty list for permissions
}
# Update last login
user.update_last_login()
await db.commit()
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role to name for permission calculation
user_roles = [user.role.name] if user.role else []
# For super admin users, use only role-based permissions, ignore custom permissions
# Custom permissions might contain legacy formats like ['*'] or dict formats
custom_permissions = []
if not user.is_superuser:
# Only use custom permissions for non-superuser accounts
# Support both list-based and dict-based custom permission formats
raw_custom_perms = getattr(user, "custom_permissions", None)
if raw_custom_perms:
if isinstance(raw_custom_perms, list):
custom_permissions = raw_custom_perms
elif isinstance(raw_custom_perms, dict):
granted = raw_custom_perms.get("granted")
if isinstance(granted, list):
custom_permissions = granted
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles, custom_permissions=custom_permissions
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"is_superuser": user.is_superuser,
"is_active": user.is_active,
"role": user.role,
"role": user.role.name if user.role else None,
"permissions": effective_permissions, # Use calculated permissions
"user_obj": user # Include full user object for other operations
"user_obj": user, # Include full user object for other operations
}
except Exception as e:
logger.error(f"Authentication error: {e}")
raise AuthenticationError("Could not validate credentials")
async def get_current_active_user(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -233,6 +273,7 @@ async def get_current_active_user(
raise AuthenticationError("User account is inactive")
return current_user
async def get_current_superuser(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -241,99 +282,120 @@ async def get_current_superuser(
raise AuthorizationError("Insufficient privileges")
return current_user
def generate_api_key() -> str:
"""Generate a new API key"""
import secrets
import string
# Generate random string
alphabet = string.ascii_letters + string.digits
api_key = "".join(secrets.choice(alphabet) for _ in range(32))
return f"{settings.API_KEY_PREFIX}{api_key}"
def hash_api_key(api_key: str) -> str:
"""Hash API key for storage"""
return get_password_hash(api_key)
def verify_api_key(api_key: str, hashed_key: str) -> bool:
"""Verify API key against hash"""
return verify_password(api_key, hashed_key)
async def get_api_key_user(
request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Get user from API key"""
api_key = request.headers.get("X-API-Key")
if not api_key:
return None
# Implement API key lookup in database
from app.models.api_key import APIKey
from app.models.user import User
from sqlalchemy import select
from sqlalchemy.orm import selectinload  # needed for the eager role load below
try:
# Extract key prefix for lookup
if len(api_key) < 8:
return None
key_prefix = api_key[:8]
# Query API key from database
stmt = (
select(APIKey)
.join(User)
.where(
APIKey.key_prefix == key_prefix,
APIKey.is_active == True,
User.is_active == True,
)
)
result = await db.execute(stmt)
db_api_key = result.scalar_one_or_none()
if not db_api_key:
return None
# Verify the API key hash
if not verify_api_key(api_key, db_api_key.key_hash):
return None
# Check if key is valid (not expired)
if not db_api_key.is_valid():
return None
# Update last used timestamp
db_api_key.last_used_at = datetime.utcnow()
await db.commit()
# Load associated user
user_stmt = select(User).options(selectinload(User.role)).where(User.id == db_api_key.user_id)
user_result = await db.execute(user_stmt)
user = user_result.scalar_one_or_none()
if not user or not user.is_active:
return None
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role to name for permission calculation
user_roles = [user.role.name] if user.role else []
# Use API key specific permissions if available
api_key_permissions = db_api_key.permissions if db_api_key.permissions else []
# Normalize permissions into a flat list of granted permission strings
custom_permissions: list[str] = []
# Handle API key permissions that may be stored as list or dict
if isinstance(api_key_permissions, list):
custom_permissions.extend(api_key_permissions)
elif isinstance(api_key_permissions, dict):
api_granted = api_key_permissions.get("granted")
if isinstance(api_granted, list):
custom_permissions.extend(api_granted)
# Merge in user-level custom permissions for non-superusers
raw_user_custom = getattr(user, "custom_permissions", None)
if raw_user_custom and not user.is_superuser:
if isinstance(raw_user_custom, list):
custom_permissions.extend(raw_user_custom)
elif isinstance(raw_user_custom, dict):
user_granted = raw_user_custom.get("granted")
if isinstance(user_granted, list):
custom_permissions.extend(user_granted)
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles, custom_permissions=custom_permissions
)
return {
"id": user.id,
"email": user.email,
@@ -344,73 +406,80 @@ async def get_api_key_user(
"permissions": effective_permissions,
"api_key": db_api_key,
"user_obj": user,
"auth_type": "api_key"
"auth_type": "api_key",
}
except Exception as e:
logger.error(f"API key lookup error: {e}")
return None
class RequiresPermission:
"""Dependency class for permission checking"""
def __init__(self, permission: str):
self.permission = permission
def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
# Implement permission checking
# Check if user is superuser (has all permissions)
if current_user.get("is_superuser", False):
return current_user
# Check role-based permissions
role = current_user.get("role", "user")
role_permissions = {
"user": ["read_own", "create_own", "update_own"],
"admin": ["read_all", "create_all", "update_all", "delete_own"],
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
"super_admin": [
"read_all",
"create_all",
"update_all",
"delete_all",
"manage_users",
"manage_modules",
],
}
if role in role_permissions and self.permission in role_permissions[role]:
return current_user
# Check custom permissions
user_permissions = current_user.get("permissions", {})
if self.permission in user_permissions:
return current_user
# If user has access to full user object, use the model's has_permission method
user_obj = current_user.get("user_obj")
if user_obj and hasattr(user_obj, "has_permission"):
if user_obj.has_permission(self.permission):
return current_user
raise AuthorizationError(f"Permission '{self.permission}' required")
class RequiresRole:
"""Dependency class for role checking"""
def __init__(self, role: str):
self.role = role
def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
# Implement role checking
# Superusers have access to everything
if current_user.get("is_superuser", False):
return current_user
user_role = current_user.get("role", "user")
# Define role hierarchy
role_hierarchy = {"user": 1, "admin": 2, "super_admin": 3}
required_level = role_hierarchy.get(self.role, 0)
user_level = role_hierarchy.get(user_role, 0)
if user_level >= required_level:
return current_user
raise AuthorizationError(f"Role '{self.role}' required, but user has role '{user_role}'")
raise AuthorizationError(
f"Role '{self.role}' required, but user has role '{user_role}'"
)
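A minimal round-trip sketch for the JWT helpers (JWT_SECRET must be configured; the claim values are illustrative and the import path is assumed):

from datetime import timedelta

from app.core.security import create_access_token, verify_token  # assumed import path

token = create_access_token({"sub": "42", "role": "user"}, expires_delta=timedelta(minutes=5))
payload = verify_token(token)  # raises AuthenticationError on tampering or expiry
assert payload["sub"] == "42"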

View File

@@ -1,3 +1,3 @@
"""
Database package
"""
"""

View File

@@ -20,10 +20,10 @@ engine = create_async_engine(
echo=settings.APP_DEBUG,
future=True,
pool_pre_ping=True,
pool_size=50, # Increased from 20 for better concurrency
max_overflow=100, # Increased from 30 for burst capacity
pool_recycle=3600, # Recycle connections every hour
pool_timeout=30, # Max time to get connection from pool
connect_args={
"timeout": 5,
"command_timeout": 5,
@@ -46,10 +46,10 @@ sync_engine = create_engine(
echo=settings.APP_DEBUG,
future=True,
pool_pre_ping=True,
pool_size=25, # Increased from 10 for better performance
max_overflow=50, # Increased from 20 for burst capacity
pool_recycle=3600, # Recycle connections every hour
pool_timeout=30, # Max time to get connection from pool
connect_args={
"connect_timeout": 5,
"application_name": "enclava_backend_sync",
@@ -72,11 +72,12 @@ metadata = MetaData()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Get database session"""
import time
start_time = time.time()
request_id = f"db_{int(time.time() * 1000)}"
logger.info(f"[{request_id}] === DATABASE SESSION START ===")
try:
logger.info(f"[{request_id}] Creating database session...")
async with async_session_factory() as session:
@@ -86,7 +87,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
except Exception as e:
# Only log if there's an actual error, not normal operation
if str(e).strip(): # Only log if error message exists
logger.error(f"[{request_id}] Database session error: {str(e)}", exc_info=True)
logger.error(
f"[{request_id}] Database session error: {str(e)}",
exc_info=True,
)
await session.rollback()
raise
finally:
@@ -94,9 +98,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
close_time = time.time() - close_start
total_time = time.time() - start_time
logger.info(f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s")
logger.info(
f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s"
)
except Exception as e:
logger.error(f"[{request_id}] Failed to create database session: {e}", exc_info=True)
logger.error(
f"[{request_id}] Failed to create database session: {e}", exc_info=True
)
raise
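A short consumption sketch, assuming a FastAPI router; because get_db is an async generator dependency, FastAPI drives the rollback-and-close path shown above automatically:

from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter()

# Hypothetical health endpoint: the session comes from async_session_factory
# and is closed (with timing logged) once the response is sent.
@router.get("/healthz/db")
async def db_health(db: AsyncSession = Depends(get_db)):
    await db.execute(text("SELECT 1"))
    return {"database": "ok"}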
@@ -106,44 +114,82 @@ async def init_db():
async with engine.begin() as conn:
# Import all models to ensure they're registered
from app.models.user import User
from app.models.role import Role
from app.models.api_key import APIKey
from app.models.usage_tracking import UsageTracking
# Import additional models where available
try:
from app.models.budget import Budget
except ImportError:
logger.warning("Budget model not available yet")
try:
from app.models.audit_log import AuditLog
except ImportError:
logger.warning("AuditLog model not available yet")
try:
from app.models.module import Module
except ImportError:
logger.warning("Module model not available yet")
# Tables are now created via migration container - no need to create here
# await conn.run_sync(Base.metadata.create_all) # DISABLED - migrations handle this
# Create default roles if they don't exist
await create_default_roles()
# Create default admin user if no admin exists
await create_default_admin()
logger.info("Database initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
async def create_default_roles():
"""Create default roles if they don't exist"""
from app.models.role import Role, RoleLevel
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
try:
async with async_session_factory() as session:
# Check if any roles exist
stmt = select(Role).limit(1)
result = await session.execute(stmt)
existing_role = result.scalar_one_or_none()
if existing_role:
logger.info("Roles already exist - skipping default role creation")
return
# Create default roles using the Role.create_default_roles class method
default_roles = Role.create_default_roles()
for role in default_roles:
session.add(role)
await session.commit()
logger.info("Created default roles: read_only, user, admin, super_admin")
except SQLAlchemyError as e:
logger.error(f"Failed to create default roles due to database error: {e}")
raise
async def create_default_admin():
"""Create default admin user if user with ADMIN_EMAIL doesn't exist"""
from app.models.user import User
from app.models.role import Role
from app.core.security import get_password_hash
from app.core.config import settings
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
try:
admin_email = settings.ADMIN_EMAIL
admin_password = settings.ADMIN_PASSWORD
@@ -151,42 +197,61 @@ async def create_default_admin():
if not admin_email or not admin_password:
logger.info("Admin bootstrap skipped: ADMIN_EMAIL or ADMIN_PASSWORD unset")
return
async with async_session_factory() as session:
# Check if user with ADMIN_EMAIL exists
stmt = select(User).where(User.email == admin_email)
result = await session.execute(stmt)
existing_user = result.scalar_one_or_none()
if existing_user:
logger.info(f"User with email {admin_email} already exists - skipping admin creation")
logger.info(
f"User with email {admin_email} already exists - skipping admin creation"
)
return
# Get the super_admin role
stmt = select(Role).where(Role.name == "super_admin")
result = await session.execute(stmt)
super_admin_role = result.scalar_one_or_none()
if not super_admin_role:
logger.error("Super admin role not found - cannot create admin user")
return
# Create admin user from environment variables
# Generate username from email (part before @)
admin_username = admin_email.split("@")[0]
admin_user = User.create_default_admin(
email=admin_email,
username=admin_username,
password_hash=get_password_hash(admin_password),
)
# Assign the super_admin role
admin_user.role_id = super_admin_role.id
session.add(admin_user)
await session.commit()
logger.warning("=" * 60)
logger.warning("ADMIN USER CREATED FROM ENVIRONMENT")
logger.warning(f"Email: {admin_email}")
logger.warning(f"Username: {admin_username}")
logger.warning("Password: [Set via ADMIN_PASSWORD - only used on first creation]")
logger.warning("Role: Super Administrator")
logger.warning(
"Password: [Set via ADMIN_PASSWORD - only used on first creation]"
)
logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
logger.warning("=" * 60)
except SQLAlchemyError as e:
logger.error(f"Failed to create default admin user due to database error: {e}")
except AttributeError as e:
logger.error(f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'")
logger.error(
f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'"
)
except Exception as e:
logger.error(f"Failed to create default admin user: {e}")
# Don't raise here as this shouldn't block the application startup

View File

@@ -59,7 +59,10 @@ async def _check_redis_startup():
duration = time.perf_counter() - start
logger.info(
"Startup Redis check succeeded",
extra={"redis_url": settings.REDIS_URL, "duration_seconds": round(duration, 3)},
extra={
"redis_url": settings.REDIS_URL,
"duration_seconds": round(duration, 3),
},
)
except Exception as exc: # noqa: BLE001
logger.warning(
@@ -104,9 +107,10 @@ async def lifespan(app: FastAPI):
"""
logger.info("Starting Enclava platform...")
background_tasks = []
# Initialize core cache service (before database to provide caching for auth)
from app.core.cache import core_cache
try:
await core_cache.initialize()
logger.info("Core cache service initialized successfully")
@@ -122,12 +126,13 @@ async def lifespan(app: FastAPI):
# Initialize database
await init_db()
# Initialize config manager
await init_config_manager()
# Ensure platform permissions are registered before module discovery
from app.services.permission_manager import permission_registry
permission_registry.register_platform_permissions()
# Initialize LLM service (needed by RAG module) concurrently
@@ -153,40 +158,45 @@ async def lifespan(app: FastAPI):
await module_manager.initialize(app)
app.state.module_manager = module_manager
logger.info("Module manager initialized successfully")
# Initialize document processor
from app.services.document_processor import document_processor
try:
await document_processor.start()
app.state.document_processor = document_processor
except Exception as exc:
logger.error(f"Document processor failed to start: {exc}")
app.state.document_processor = None
# Setup metrics
try:
setup_metrics(app)
except Exception as exc:
logger.warning(f"Metrics setup failed: {exc}")
# Start background audit worker
from app.services.audit_service import start_audit_worker
try:
start_audit_worker()
except Exception as exc:
logger.warning(f"Audit worker failed to start: {exc}")
# Initialize plugin auto-discovery service concurrently
async def initialize_plugins():
from app.services.plugin_autodiscovery import initialize_plugin_autodiscovery
try:
discovery_results = await initialize_plugin_autodiscovery()
app.state.plugin_discovery_results = discovery_results
logger.info(f"Plugin auto-discovery completed: {discovery_results.get('summary')}")
logger.info(
f"Plugin auto-discovery completed: {discovery_results.get('summary')}"
)
except Exception as exc:
logger.warning(f"Plugin auto-discovery failed: {exc}")
app.state.plugin_discovery_results = {"error": str(exc)}
background_tasks.append(asyncio.create_task(initialize_plugins()))
if background_tasks:
@@ -194,9 +204,9 @@ async def lifespan(app: FastAPI):
for result in results:
if isinstance(result, Exception):
logger.warning(f"Background startup task failed: {result}")
logger.info("Platform started successfully")
try:
yield
finally:
@@ -205,6 +215,7 @@ async def lifespan(app: FastAPI):
# Cleanup embedding service HTTP sessions
from app.services.embedding_service import embedding_service
try:
await embedding_service.cleanup()
logger.info("Embedding service cleaned up successfully")
@@ -213,14 +224,16 @@ async def lifespan(app: FastAPI):
# Close core cache service
from app.core.cache import core_cache
await core_cache.cleanup()
# Close Redis connection for cached API key service
from app.services.cached_api_key import cached_api_key_service
await cached_api_key_service.close()
# Stop document processor
processor = getattr(app.state, "document_processor", None)
if processor:
await processor.stop()
@@ -297,10 +310,12 @@ async def validation_exception_handler(request, exc: RequestValidationError):
"type": error.get("type", ""),
"location": error.get("loc", []),
"message": error.get("msg", ""),
"input": str(error.get("input", "")) if error.get("input") is not None else None
"input": str(error.get("input", ""))
if error.get("input") is not None
else None,
}
errors.append(error_dict)
return JSONResponse(
status_code=422,
content={
@@ -326,7 +341,7 @@ async def general_exception_handler(request, exc: Exception):
# Include Internal API routes (for frontend)
app.include_router(internal_api_router, prefix="/api-internal/v1")
# Include Public API routes (for external clients)
app.include_router(public_api_router, prefix="/api/v1")
# OpenAI-compatible routes are now included in public API router at /api/v1/
@@ -357,7 +372,7 @@ async def root():
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.APP_HOST,

View File

@@ -17,24 +17,29 @@ from app.db.database import get_db
logger = get_logger(__name__)
# Context variable to pass analytics data from endpoints to middleware
analytics_context: ContextVar[dict] = ContextVar("analytics_context", default={})
class AnalyticsMiddleware(BaseHTTPMiddleware):
"""Middleware to automatically track all requests for analytics"""
async def dispatch(self, request: Request, call_next):
# Start timing
start_time = time.time()
# Skip analytics for health checks and static files
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
if request.url.path in [
"/health",
"/docs",
"/redoc",
"/openapi.json",
] or request.url.path.startswith("/static"):
return await call_next(request)
# Get user info if available from token
user_id = None
api_key_id = None
try:
authorization = request.headers.get("Authorization")
if authorization and authorization.startswith("Bearer "):
@@ -42,6 +47,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
# Try to extract user info from token without full validation
# This is a lightweight check for analytics purposes
from app.core.security import verify_token
try:
payload = verify_token(token)
user_id = int(payload.get("sub"))
@@ -51,7 +57,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
except Exception:
# Don't let analytics break the request
pass
# Get client IP
client_ip = request.client.host if request.client else None
if not client_ip:
@@ -59,17 +65,17 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
client_ip = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
if not client_ip:
client_ip = request.headers.get("X-Real-IP", "unknown")
# Get user agent
user_agent = request.headers.get("User-Agent", "")
# Get request size
request_size = int(request.headers.get("Content-Length", 0))
# Process the request
response = None
error_message = None
try:
response = await call_next(request)
except Exception as e:
@@ -77,21 +83,21 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
error_message = str(e)
response = JSONResponse(
status_code=500,
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
)
# Calculate timing
end_time = time.time()
response_time = (end_time - start_time) * 1000 # Convert to milliseconds
# Get response size
response_size = 0
if hasattr(response, "body"):
response_size = len(response.body) if response.body else 0
# Get analytics data from context (set by endpoints)
context_data = analytics_context.get({})
# Create analytics event
event = RequestEvent(
timestamp=datetime.utcnow(),
@@ -107,26 +113,29 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
response_size=response_size,
error_message=error_message,
# Token/cost info populated by LLM endpoints via context
model=context_data.get("model"),
request_tokens=context_data.get("request_tokens", 0),
response_tokens=context_data.get("response_tokens", 0),
total_tokens=context_data.get("total_tokens", 0),
cost_cents=context_data.get("cost_cents", 0),
budget_ids=context_data.get("budget_ids", []),
budget_warnings=context_data.get("budget_warnings", []),
)
# Track the event
try:
from app.services.analytics import analytics_service
if analytics_service is not None:
await analytics_service.track_request(event)
else:
logger.warning("Analytics service not initialized, skipping event tracking")
logger.warning(
"Analytics service not initialized, skipping event tracking"
)
except Exception as e:
logger.error(f"Failed to track analytics event: {e}")
# Don't let analytics failures break the request
return response
@@ -140,4 +149,4 @@ def set_analytics_data(**kwargs):
def setup_analytics_middleware(app):
"""Add analytics middleware to the FastAPI app"""
app.add_middleware(AnalyticsMiddleware)
logger.info("Analytics middleware configured")
logger.info("Analytics middleware configured")

View File

@@ -0,0 +1,393 @@
"""
Audit Logging Middleware
Automatically logs user actions and system events
"""
import time
import json
import logging
from typing import Callable, Optional, Dict, Any
from datetime import datetime
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.audit_log import AuditLog, AuditAction, AuditSeverity
from app.db.database import get_db_session
from app.core.security import verify_token
logger = logging.getLogger(__name__)
class AuditLoggingMiddleware(BaseHTTPMiddleware):
"""Middleware to automatically log user actions and API calls"""
def __init__(self, app, exclude_paths: Optional[list] = None):
super().__init__(app)
# Paths to exclude from audit logging
self.exclude_paths = exclude_paths or [
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/metrics",
"/static",
"/favicon.ico",
]
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip audit logging for excluded paths
if any(request.url.path.startswith(path) for path in self.exclude_paths):
return await call_next(request)
# Skip audit logging for health checks and static assets
if request.url.path in ["/", "/health"] or "/static/" in request.url.path:
return await call_next(request)
start_time = time.time()
# Extract user information from request
user_info = await self._extract_user_info(request)
# Prepare audit data
audit_data = {
"method": request.method,
"path": request.url.path,
"query_params": dict(request.query_params),
"ip_address": self._get_client_ip(request),
"user_agent": request.headers.get("user-agent"),
"timestamp": datetime.utcnow().isoformat(),
}
# Process request
response = await call_next(request)
# Calculate response time
process_time = time.time() - start_time
audit_data["response_time"] = round(process_time * 1000, 2) # milliseconds
audit_data["status_code"] = response.status_code
audit_data["success"] = 200 <= response.status_code < 400
# Log the audit event asynchronously
try:
await self._log_audit_event(user_info, audit_data, request)
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
# Don't fail the request if audit logging fails
return response
async def _extract_user_info(self, request: Request) -> Optional[Dict[str, Any]]:
"""Extract user information from request headers"""
try:
# Try to get user info from Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
payload = verify_token(token)
return {
"user_id": int(payload.get("sub")) if payload.get("sub") else None,
"email": payload.get("email"),
"is_superuser": payload.get("is_superuser", False),
"role": payload.get("role"),
}
except Exception:
# If token verification fails, continue without user info
pass
# Try to get user info from API key header
api_key = request.headers.get("x-api-key")
if api_key:
# Would need to implement API key lookup here
# For now, just indicate it's an API key request
return {
"user_id": None,
"email": "api_key_user",
"is_superuser": False,
"role": "api_user",
"auth_type": "api_key",
}
return None
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address with proxy support"""
# Check for forwarded headers first (for reverse proxy setups)
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# Take the first IP in the chain
return forwarded_for.split(",")[0].strip()
forwarded = request.headers.get("x-forwarded")
if forwarded:
return forwarded
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
# Fall back to direct client IP
return request.client.host if request.client else "unknown"
async def _log_audit_event(
self,
user_info: Optional[Dict[str, Any]],
audit_data: Dict[str, Any],
request: Request
):
"""Log the audit event to database"""
# Determine action based on HTTP method and path
action = self._determine_action(request.method, request.url.path)
# Determine resource type and ID from path
resource_type, resource_id = self._parse_resource_from_path(request.url.path)
# Create description
description = self._create_description(request.method, request.url.path, audit_data["success"])
# Determine severity
severity = self._determine_severity(request.method, audit_data["status_code"], request.url.path)
# Create audit log entry
try:
async with get_db_session() as db:
audit_log = AuditLog(
user_id=user_info.get("user_id") if user_info else None,
action=action,
resource_type=resource_type,
resource_id=resource_id,
description=description,
details={
"request": {
"method": audit_data["method"],
"path": audit_data["path"],
"query_params": audit_data["query_params"],
"response_time_ms": audit_data["response_time"],
},
"user_info": user_info,
},
ip_address=audit_data["ip_address"],
user_agent=audit_data["user_agent"],
severity=severity,
category=self._determine_category(request.url.path),
success=audit_data["success"],
tags=self._generate_tags(request.method, request.url.path),
)
db.add(audit_log)
await db.commit()
except Exception as e:
logger.error(f"Failed to save audit log to database: {e}")
# Could implement fallback logging to file here
def _determine_action(self, method: str, path: str) -> str:
"""Determine action type from HTTP method and path"""
method = method.upper()
if method == "GET":
return AuditAction.READ
elif method == "POST":
if "login" in path.lower():
return AuditAction.LOGIN
elif "logout" in path.lower():
return AuditAction.LOGOUT
else:
return AuditAction.CREATE
elif method == "PUT" or method == "PATCH":
return AuditAction.UPDATE
elif method == "DELETE":
return AuditAction.DELETE
else:
return method.lower()
def _parse_resource_from_path(self, path: str) -> tuple[str, Optional[str]]:
"""Parse resource type and ID from URL path"""
path_parts = path.strip("/").split("/")
# Skip API version prefix
if path_parts and path_parts[0] in ["api", "api-internal"]:
path_parts = path_parts[2:] # Skip 'api' and 'v1'
if not path_parts:
return "system", None
resource_type = path_parts[0]
resource_id = None
# Try to find numeric ID in path
for part in path_parts[1:]:
if part.isdigit():
resource_id = part
break
return resource_type, resource_id
def _create_description(self, method: str, path: str, success: bool) -> str:
"""Create human-readable description of the action"""
action_verbs = {
"GET": "accessed" if success else "attempted to access",
"POST": "created" if success else "attempted to create",
"PUT": "updated" if success else "attempted to update",
"PATCH": "modified" if success else "attempted to modify",
"DELETE": "deleted" if success else "attempted to delete",
}
verb = action_verbs.get(method, method.lower())
resource = path.strip("/").split("/")[-1] if "/" in path else path
return f"User {verb} {resource}"
def _determine_severity(self, method: str, status_code: int, path: str) -> str:
"""Determine severity level based on action and outcome"""
# Critical operations
if any(keyword in path.lower() for keyword in ["delete", "password", "admin", "key"]):
return AuditSeverity.HIGH
# Failed operations
if status_code >= 400:
if status_code >= 500:
return AuditSeverity.CRITICAL
elif status_code in [401, 403]:
return AuditSeverity.HIGH
else:
return AuditSeverity.MEDIUM
# Write operations
if method in ["POST", "PUT", "PATCH", "DELETE"]:
return AuditSeverity.MEDIUM
# Read operations
return AuditSeverity.LOW
def _determine_category(self, path: str) -> str:
"""Determine category based on path"""
path = path.lower()
if any(keyword in path for keyword in ["auth", "login", "logout", "token"]):
return "authentication"
elif any(keyword in path for keyword in ["user", "admin", "role", "permission"]):
return "user_management"
elif any(keyword in path for keyword in ["api-key", "key"]):
return "security"
elif any(keyword in path for keyword in ["budget", "billing", "usage"]):
return "financial"
elif any(keyword in path for keyword in ["audit", "log"]):
return "audit"
elif any(keyword in path for keyword in ["setting", "config"]):
return "configuration"
else:
return "general"
def _generate_tags(self, method: str, path: str) -> list[str]:
"""Generate tags for the audit log"""
tags = [method.lower()]
path_parts = path.strip("/").split("/")
if path_parts:
tags.append(path_parts[0])
# Add special tags
if "admin" in path.lower():
tags.append("admin_action")
if any(keyword in path.lower() for keyword in ["password", "auth", "login"]):
tags.append("security_action")
return tags
class LoginAuditMiddleware(BaseHTTPMiddleware):
"""Specialized middleware for login/logout events"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Only process auth-related endpoints
if not any(path in request.url.path for path in ["/auth/login", "/auth/logout", "/auth/refresh"]):
return await call_next(request)
start_time = time.time()
# Store request body for login attempts
request_body = None
if request.method == "POST" and "/login" in request.url.path:
try:
body = await request.body()
if body:
request_body = json.loads(body.decode())
# Re-create request with body for downstream processing
from starlette.requests import Request as StarletteRequest
from io import BytesIO
request._body = body
except Exception as e:
logger.warning(f"Failed to parse login request body: {e}")
response = await call_next(request)
# Log login/logout events
try:
await self._log_auth_event(request, response, request_body, time.time() - start_time)
except Exception as e:
logger.error(f"Failed to log auth event: {e}")
return response
async def _log_auth_event(self, request: Request, response: Response, request_body: dict, process_time: float):
"""Log authentication events"""
success = 200 <= response.status_code < 300
if "/login" in request.url.path:
# Extract email/username from request
identifier = None
if request_body:
identifier = request_body.get("email") or request_body.get("username")
# For successful logins, we could extract user_id from response
# For now, we'll use the identifier
async with get_db_session() as db:
audit_log = AuditLog.create_login_event(
user_id=None, # Would need to extract from response for successful logins
success=success,
ip_address=self._get_client_ip(request),
user_agent=request.headers.get("user-agent"),
error_message=f"HTTP {response.status_code}" if not success else None,
)
# Add additional details
audit_log.details.update({
"identifier": identifier,
"response_time_ms": round(process_time * 1000, 2),
})
db.add(audit_log)
await db.commit()
elif "/logout" in request.url.path:
# Extract user info from token if available
user_id = None
try:
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
payload = verify_token(token)
user_id = int(payload.get("sub")) if payload.get("sub") else None
except Exception:
pass
async with get_db_session() as db:
audit_log = AuditLog.create_logout_event(
user_id=user_id,
session_id=None, # Could extract from token if stored
)
db.add(audit_log)
await db.commit()
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address with proxy support"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
return request.client.host if request.client else "unknown"

View File

@@ -23,8 +23,12 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
request_id = str(uuid4())
# Skip debugging for health checks and static files
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or \
request.url.path.startswith("/static"):
if request.url.path in [
"/health",
"/docs",
"/redoc",
"/openapi.json",
] or request.url.path.startswith("/static"):
return await call_next(request)
# Log request details
@@ -37,7 +41,7 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
try:
request_body = json.loads(body_bytes)
except json.JSONDecodeError:
request_body = body_bytes.decode("utf-8", errors="replace")
# Restore body for downstream processing
request._body = body_bytes
except Exception:
@@ -45,8 +49,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
# Extract headers we care about
headers_to_log = {
"authorization": request.headers.get("Authorization", "")[:50] + "..." if
request.headers.get("Authorization") else None,
"authorization": request.headers.get("Authorization", "")[:50] + "..."
if request.headers.get("Authorization")
else None,
"content-type": request.headers.get("Content-Type"),
"user-agent": request.headers.get("User-Agent"),
"x-forwarded-for": request.headers.get("X-Forwarded-For"),
@@ -54,17 +59,20 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
}
# Log request
logger.info("=== API REQUEST DEBUG ===", extra={
"request_id": request_id,
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"headers": {k: v for k, v in headers_to_log.items() if v is not None},
"body": request_body,
"client_ip": request.client.host if request.client else None,
"timestamp": datetime.utcnow().isoformat()
})
logger.info(
"=== API REQUEST DEBUG ===",
extra={
"request_id": request_id,
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"headers": {k: v for k, v in headers_to_log.items() if v is not None},
"body": request_body,
"client_ip": request.client.host if request.client else None,
"timestamp": datetime.utcnow().isoformat(),
},
)
# Process the request
start_time = time.time()
@@ -73,33 +81,43 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
# Add timeout detection
try:
logger.info(f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}")
logger.info(
f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}"
)
logger.info(f"Request path: {request.url.path}")
logger.info(f"Request method: {request.method}")
# Check if this is the login endpoint
if request.url.path == "/api-internal/v1/auth/login" and request.method == "POST":
if (
request.url.path == "/api-internal/v1/auth/login"
and request.method == "POST"
):
logger.info(f"=== LOGIN REQUEST DETECTED === {request_id}")
response = await call_next(request)
logger.info(f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}")
logger.info(
f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}"
)
# Capture response body for successful JSON responses
if response.status_code < 400 and isinstance(response, JSONResponse):
try:
response_body = json.loads(response.body.decode("utf-8"))
except Exception:
response_body = "[Failed to decode response body]"
except Exception as e:
logger.error(f"Request processing failed: {str(e)}", extra={
"request_id": request_id,
"error": str(e),
"error_type": type(e).__name__
})
logger.error(
f"Request processing failed: {str(e)}",
extra={
"request_id": request_id,
"error": str(e),
"error_type": type(e).__name__,
},
)
response = JSONResponse(
status_code=500,
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
)
# Calculate timing
@@ -107,14 +125,17 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
duration = (end_time - start_time) * 1000 # milliseconds
# Log response
logger.info("=== API RESPONSE DEBUG ===", extra={
"request_id": request_id,
"status_code": response.status_code,
"duration_ms": round(duration, 2),
"response_body": response_body,
"response_headers": dict(response.headers),
"timestamp": datetime.utcnow().isoformat()
})
logger.info(
"=== API RESPONSE DEBUG ===",
extra={
"request_id": request_id,
"status_code": response.status_code,
"duration_ms": round(duration, 2),
"response_body": response_body,
"response_headers": dict(response.headers),
"timestamp": datetime.utcnow().isoformat(),
},
)
return response
@@ -122,4 +143,4 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
def setup_debugging_middleware(app):
"""Add debugging middleware to the FastAPI app"""
app.add_middleware(DebuggingMiddleware)
logger.info("Debugging middleware configured")
logger.info("Debugging middleware configured")

View File

@@ -9,28 +9,63 @@ from .budget import Budget
from .audit_log import AuditLog
from .rag_collection import RagCollection
from .rag_document import RagDocument
from .chatbot import (
ChatbotInstance,
ChatbotConversation,
ChatbotMessage,
ChatbotAnalytics,
)
from .prompt_template import PromptTemplate, ChatbotPromptVariable
from .plugin import (
Plugin,
PluginConfiguration,
PluginInstance,
PluginAuditLog,
PluginCronJob,
PluginAPIGateway,
)
from .role import Role, RoleLevel
from .tool import Tool, ToolExecution, ToolCategory, ToolType, ToolStatus
from .notification import (
Notification,
NotificationTemplate,
NotificationChannel,
NotificationType,
NotificationPriority,
NotificationStatus,
)
__all__ = [
"User",
"APIKey",
"UsageTracking",
"Budget",
"User",
"APIKey",
"UsageTracking",
"Budget",
"AuditLog",
"RagCollection",
"RagCollection",
"RagDocument",
"ChatbotInstance",
"ChatbotConversation",
"ChatbotConversation",
"ChatbotMessage",
"ChatbotAnalytics",
"PromptTemplate",
"ChatbotPromptVariable",
"Plugin",
"PluginConfiguration",
"PluginInstance",
"PluginInstance",
"PluginAuditLog",
"PluginCronJob",
"PluginAPIGateway"
]
"PluginAPIGateway",
"Role",
"RoleLevel",
"Tool",
"ToolExecution",
"ToolCategory",
"ToolType",
"ToolStatus",
"Notification",
"NotificationTemplate",
"NotificationChannel",
"NotificationType",
"NotificationPriority",
"NotificationStatus",
]

View File

@@ -3,73 +3,94 @@ API Key model
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class APIKey(Base):
"""API Key model for authentication and access control"""
__tablename__ = "api_keys"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) # Human-readable name for the API key
key_hash = Column(String, unique=True, index=True, nullable=False) # Hashed API key
key_prefix = Column(
String, index=True, nullable=False
) # First 8 characters for identification
# User relationship
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="api_keys")
# Related data relationships
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
budgets = relationship(
"Budget", back_populates="api_key", cascade="all, delete-orphan"
)
usage_tracking = relationship(
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
)
# Key status and permissions
is_active = Column(Boolean, default=True)
permissions = Column(JSON, default=dict) # Specific permissions for this key
scopes = Column(JSON, default=list) # OAuth-like scopes
# Usage limits
rate_limit_per_minute = Column(Integer, default=60) # Requests per minute
rate_limit_per_hour = Column(Integer, default=3600) # Requests per hour
rate_limit_per_day = Column(Integer, default=86400) # Requests per day
# Allowed resources
allowed_models = Column(JSON, default=list) # List of allowed LLM models
allowed_endpoints = Column(JSON, default=list) # List of allowed API endpoints
allowed_ips = Column(JSON, default=list) # IP whitelist
allowed_chatbots = Column(
JSON, default=list
) # List of allowed chatbot IDs for chatbot-specific keys
# Budget configuration
is_unlimited = Column(Boolean, default=True) # Unlimited budget flag
budget_limit_cents = Column(Integer, nullable=True) # Budget limit in cents
budget_type = Column(String, nullable=True) # "total" or "monthly"
# Metadata
description = Column(Text, nullable=True)
tags = Column(JSON, default=list) # For organizing keys
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_used_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True) # Optional expiration
# Usage tracking
total_requests = Column(Integer, default=0)
total_tokens = Column(Integer, default=0)
total_cost = Column(Integer, default=0) # In cents
# Relationships
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
usage_tracking = relationship(
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
)
budgets = relationship(
"Budget", back_populates="api_key", cascade="all, delete-orphan"
)
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
def __repr__(self):
return f"<APIKey(id={self.id}, name='{self.name}', user_id={self.user_id})>"
def to_dict(self, include_sensitive: bool = False):
"""Convert API key to dictionary for API responses"""
data = {
@@ -91,138 +112,142 @@ class APIKey(Base):
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"total_requests": self.total_requests,
"total_tokens": self.total_tokens,
"total_cost_cents": self.total_cost,
"is_unlimited": self.is_unlimited,
"budget_limit": self.budget_limit_cents, # Map to budget_limit for API response
"budget_type": self.budget_type
"budget_type": self.budget_type,
}
if include_sensitive:
data["key_hash"] = self.key_hash
return data
def is_expired(self) -> bool:
"""Check if the API key has expired"""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def is_valid(self) -> bool:
"""Check if the API key is valid and active"""
return self.is_active and not self.is_expired()
def has_permission(self, permission: str) -> bool:
"""Check if the API key has a specific permission"""
return permission in self.permissions
def has_scope(self, scope: str) -> bool:
"""Check if the API key has a specific scope"""
return scope in self.scopes
def can_access_model(self, model_name: str) -> bool:
"""Check if the API key can access a specific model"""
if not self.allowed_models: # Empty list means all models allowed
return True
return model_name in self.allowed_models
def can_access_endpoint(self, endpoint: str) -> bool:
"""Check if the API key can access a specific endpoint"""
if not self.allowed_endpoints: # Empty list means all endpoints allowed
return True
return endpoint in self.allowed_endpoints
def can_access_from_ip(self, ip_address: str) -> bool:
"""Check if the API key can be used from a specific IP"""
if not self.allowed_ips: # Empty list means all IPs allowed
return True
return ip_address in self.allowed_ips
def can_access_chatbot(self, chatbot_id: str) -> bool:
"""Check if the API key can access a specific chatbot"""
if not self.allowed_chatbots: # Empty list means all chatbots allowed
return True
return chatbot_id in self.allowed_chatbots
def update_usage(self, tokens_used: int = 0, cost_cents: int = 0):
"""Update usage statistics"""
self.total_requests += 1
self.total_tokens += tokens_used
self.total_cost += cost_cents
self.last_used_at = datetime.utcnow()
def set_expiration(self, days: int):
"""Set expiration date in days from now"""
self.expires_at = datetime.utcnow() + timedelta(days=days)
def extend_expiration(self, days: int):
"""Extend expiration date by specified days"""
if self.expires_at is None:
self.expires_at = datetime.utcnow() + timedelta(days=days)
else:
self.expires_at = self.expires_at + timedelta(days=days)
def revoke(self):
"""Revoke the API key"""
self.is_active = False
self.updated_at = datetime.utcnow()
def add_scope(self, scope: str):
"""Add a scope to the API key"""
if scope not in self.scopes:
self.scopes.append(scope)
def remove_scope(self, scope: str):
"""Remove a scope from the API key"""
if scope in self.scopes:
self.scopes.remove(scope)
def add_allowed_model(self, model_name: str):
"""Add an allowed model"""
if model_name not in self.allowed_models:
self.allowed_models.append(model_name)
def remove_allowed_model(self, model_name: str):
"""Remove an allowed model"""
if model_name in self.allowed_models:
self.allowed_models.remove(model_name)
def add_allowed_endpoint(self, endpoint: str):
"""Add an allowed endpoint"""
if endpoint not in self.allowed_endpoints:
self.allowed_endpoints.append(endpoint)
def remove_allowed_endpoint(self, endpoint: str):
"""Remove an allowed endpoint"""
if endpoint in self.allowed_endpoints:
self.allowed_endpoints.remove(endpoint)
def add_allowed_ip(self, ip_address: str):
"""Add an allowed IP address"""
if ip_address not in self.allowed_ips:
self.allowed_ips.append(ip_address)
def remove_allowed_ip(self, ip_address: str):
"""Remove an allowed IP address"""
if ip_address in self.allowed_ips:
self.allowed_ips.remove(ip_address)
def add_allowed_chatbot(self, chatbot_id: str):
"""Add an allowed chatbot"""
if chatbot_id not in self.allowed_chatbots:
self.allowed_chatbots.append(chatbot_id)
def remove_allowed_chatbot(self, chatbot_id: str):
"""Remove an allowed chatbot"""
if chatbot_id in self.allowed_chatbots:
self.allowed_chatbots.remove(chatbot_id)
@classmethod
def create_default_key(
cls, user_id: int, name: str, key_hash: str, key_prefix: str
) -> "APIKey":
"""Create a default API key with standard permissions"""
return cls(
name=name,
@@ -230,17 +255,8 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={"read": True, "write": True, "chat": True, "embeddings": True},
scopes=["chat.completions", "embeddings.create", "models.list"],
rate_limit_per_minute=60,
rate_limit_per_hour=3600,
rate_limit_per_day=86400,
@@ -248,12 +264,19 @@ class APIKey(Base):
allowed_endpoints=[], # All endpoints allowed by default
allowed_ips=[], # All IPs allowed by default
description="Default API key with standard permissions",
tags=["default"]
tags=["default"],
)
@classmethod
def create_restricted_key(
cls,
user_id: int,
name: str,
key_hash: str,
key_prefix: str,
models: List[str],
endpoints: List[str],
) -> "APIKey":
"""Create a restricted API key with limited permissions"""
return cls(
name=name,
@@ -261,13 +284,8 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={"read": True, "chat": True},
scopes=["chat.completions"],
rate_limit_per_minute=30,
rate_limit_per_hour=1800,
rate_limit_per_day=43200,
@@ -275,12 +293,19 @@ class APIKey(Base):
allowed_endpoints=endpoints,
allowed_ips=[],
description="Restricted API key with limited permissions",
tags=["restricted"]
tags=["restricted"],
)
@classmethod
def create_chatbot_key(
cls,
user_id: int,
name: str,
key_hash: str,
key_prefix: str,
chatbot_id: str,
chatbot_name: str,
) -> "APIKey":
"""Create a chatbot-specific API key"""
return cls(
name=name,
@@ -288,22 +313,18 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={"chatbot": True},
scopes=["chatbot.chat"],
rate_limit_per_minute=100,
rate_limit_per_hour=6000,
rate_limit_per_day=144000,
allowed_models=[], # Will use chatbot's configured model
allowed_endpoints=[
f"/api/v1/chatbot/external/{chatbot_id}/chat",
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions"
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions",
],
allowed_ips=[],
allowed_chatbots=[chatbot_id],
description=f"API key for chatbot: {chatbot_name}",
tags=["chatbot", f"chatbot-{chatbot_id}"]
)
tags=["chatbot", f"chatbot-{chatbot_id}"],
)

View File

@@ -3,7 +3,16 @@ Audit log model for tracking system events and user actions
"""
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
JSON,
ForeignKey,
Text,
Boolean,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
from enum import Enum
@@ -11,6 +20,7 @@ from enum import Enum
class AuditAction(str, Enum):
"""Audit action types"""
CREATE = "create"
READ = "read"
UPDATE = "update"
@@ -32,6 +42,7 @@ class AuditAction(str, Enum):
class AuditSeverity(str, Enum):
"""Audit severity levels"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
@@ -40,52 +51,58 @@ class AuditSeverity(str, Enum):
class AuditLog(Base):
"""Audit log model for tracking system events and user actions"""
__tablename__ = "audit_logs"
id = Column(Integer, primary_key=True, index=True)
# User relationship (nullable for system events)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
user = relationship("User", back_populates="audit_logs")
# Event details
action = Column(String, nullable=False)
resource_type = Column(
String, nullable=False
) # user, api_key, budget, module, etc.
resource_id = Column(String, nullable=True) # ID of the affected resource
# Event description and details
description = Column(Text, nullable=False)
details = Column(JSON, default=dict) # Additional event details
# Request context
ip_address = Column(String, nullable=True)
user_agent = Column(String, nullable=True)
session_id = Column(String, nullable=True)
request_id = Column(String, nullable=True)
# Event classification
severity = Column(String, default=AuditSeverity.LOW)
category = Column(String, nullable=True) # security, access, data, system
# Success/failure tracking
success = Column(Boolean, default=True)
error_message = Column(Text, nullable=True)
# Additional metadata
tags = Column(JSON, default=list)
audit_metadata = Column("metadata", JSON, default=dict) # Map to 'metadata' column in DB
audit_metadata = Column(
"metadata", JSON, default=dict
) # Map to 'metadata' column in DB
# Before/after values for data changes
old_values = Column(JSON, nullable=True)
new_values = Column(JSON, nullable=True)
# Timestamp
created_at = Column(DateTime, default=datetime.utcnow, index=True)
def __repr__(self):
return f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
return (
f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
)
def to_dict(self):
"""Convert audit log to dictionary for API responses"""
return {
@@ -108,9 +125,9 @@ class AuditLog(Base):
"metadata": self.audit_metadata,
"old_values": self.old_values,
"new_values": self.new_values,
"created_at": self.created_at.isoformat() if self.created_at else None
"created_at": self.created_at.isoformat() if self.created_at else None,
}
def is_security_event(self) -> bool:
"""Check if this is a security-related event"""
security_actions = [
@@ -120,39 +137,45 @@ class AuditLog(Base):
AuditAction.API_KEY_DELETE,
AuditAction.PERMISSION_GRANT,
AuditAction.PERMISSION_REVOKE,
AuditAction.SECURITY_EVENT,
]
return self.action in security_actions or self.category == "security"
def is_high_severity(self) -> bool:
"""Check if this is a high severity event"""
return self.severity in [AuditSeverity.HIGH, AuditSeverity.CRITICAL]
def add_tag(self, tag: str):
"""Add a tag to the audit log"""
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag: str):
"""Remove a tag from the audit log"""
if tag in self.tags:
self.tags.remove(tag)
def update_metadata(self, key: str, value: Any):
"""Update metadata"""
if self.audit_metadata is None:
self.audit_metadata = {}
self.audit_metadata[key] = value
def set_before_after(self, old_values: Dict[str, Any], new_values: Dict[str, Any]):
"""Set before and after values for data changes"""
self.old_values = old_values
self.new_values = new_values
@classmethod
def create_login_event(
cls,
user_id: int,
success: bool = True,
ip_address: str = None,
user_agent: str = None,
session_id: str = None,
error_message: str = None,
) -> "AuditLog":
"""Create a login audit event"""
return cls(
user_id=user_id,
@@ -160,10 +183,7 @@ class AuditLog(Base):
resource_type="user",
resource_id=str(user_id),
description=f"User login {'successful' if success else 'failed'}",
details={"login_method": "password", "success": success},
ip_address=ip_address,
user_agent=user_agent,
session_id=session_id,
@@ -171,9 +191,9 @@ class AuditLog(Base):
category="security",
success=success,
error_message=error_message,
tags=["authentication", "login"]
tags=["authentication", "login"],
)
@classmethod
def create_logout_event(cls, user_id: int, session_id: str = None) -> "AuditLog":
"""Create a logout audit event"""
@@ -183,20 +203,24 @@ class AuditLog(Base):
resource_type="user",
resource_id=str(user_id),
description="User logout",
details={"logout_method": "manual"},
session_id=session_id,
severity=AuditSeverity.LOW,
category="security",
success=True,
tags=["authentication", "logout"]
tags=["authentication", "logout"],
)
@classmethod
def create_api_key_event(
cls,
user_id: int,
action: str,
api_key_id: int,
api_key_name: str,
success: bool = True,
error_message: str = None,
) -> "AuditLog":
"""Create an API key audit event"""
return cls(
user_id=user_id,
@@ -204,21 +228,24 @@ class AuditLog(Base):
resource_type="api_key",
resource_id=str(api_key_id),
description=f"API key {action}: {api_key_name}",
details={"api_key_name": api_key_name, "action": action},
severity=AuditSeverity.MEDIUM,
category="security",
success=success,
error_message=error_message,
tags=["api_key", action]
tags=["api_key", action],
)
@classmethod
def create_budget_event(
cls,
user_id: int,
action: str,
budget_id: int,
budget_name: str,
details: Dict[str, Any] = None,
success: bool = True,
) -> "AuditLog":
"""Create a budget audit event"""
return cls(
user_id=user_id,
@@ -227,16 +254,24 @@ class AuditLog(Base):
resource_id=str(budget_id),
description=f"Budget {action}: {budget_name}",
details=details or {},
severity=AuditSeverity.MEDIUM
if action == AuditAction.BUDGET_EXCEED
else AuditSeverity.LOW,
category="financial",
success=success,
tags=["budget", action]
tags=["budget", action],
)
@classmethod
def create_module_event(
cls,
user_id: int,
action: str,
module_name: str,
success: bool = True,
error_message: str = None,
details: Dict[str, Any] = None,
) -> "AuditLog":
"""Create a module audit event"""
return cls(
user_id=user_id,
@@ -249,12 +284,18 @@ class AuditLog(Base):
category="system",
success=success,
error_message=error_message,
tags=["module", action]
tags=["module", action],
)
@classmethod
def create_permission_event(
cls,
user_id: int,
action: str,
target_user_id: int,
permission: str,
success: bool = True,
) -> "AuditLog":
"""Create a permission audit event"""
return cls(
user_id=user_id,
@@ -262,21 +303,23 @@ class AuditLog(Base):
resource_type="permission",
resource_id=str(target_user_id),
description=f"Permission {action}: {permission} for user {target_user_id}",
details={"permission": permission, "target_user_id": target_user_id},
severity=AuditSeverity.HIGH,
category="security",
success=success,
tags=["permission", action]
tags=["permission", action],
)
@classmethod
def create_security_event(
cls,
user_id: int,
event_type: str,
description: str,
severity: str = AuditSeverity.HIGH,
details: Dict[str, Any] = None,
ip_address: str = None,
) -> "AuditLog":
"""Create a security audit event"""
return cls(
user_id=user_id,
@@ -289,15 +332,19 @@ class AuditLog(Base):
severity=severity,
category="security",
success=False, # Security events are typically failures
tags=["security", event_type]
tags=["security", event_type],
)
@classmethod
def create_system_event(
cls,
action: str,
description: str,
resource_type: str = "system",
resource_id: str = None,
severity: str = AuditSeverity.LOW,
details: Dict[str, Any] = None,
) -> "AuditLog":
"""Create a system audit event"""
return cls(
user_id=None, # System events don't have a user
@@ -309,14 +356,20 @@ class AuditLog(Base):
severity=severity,
category="system",
success=True,
tags=["system", action]
tags=["system", action],
)
@classmethod
def create_data_change_event(
cls,
user_id: int,
action: str,
resource_type: str,
resource_id: str,
description: str,
old_values: Dict[str, Any],
new_values: Dict[str, Any],
) -> "AuditLog":
"""Create a data change audit event"""
return cls(
user_id=user_id,
@@ -329,9 +382,9 @@ class AuditLog(Base):
severity=AuditSeverity.LOW,
category="data",
success=True,
tags=["data_change", action]
tags=["data_change", action],
)
def get_summary(self) -> Dict[str, Any]:
"""Get a summary of the audit log"""
return {
@@ -342,5 +395,5 @@ class AuditLog(Base):
"severity": self.severity,
"success": self.success,
"created_at": self.created_at.isoformat() if self.created_at else None,
"user_id": self.user_id
}
"user_id": self.user_id,
}
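A short sketch of the factory methods in use; the values and the surrounding session handling are illustrative:

# Hypothetical handler reacting to repeated login failures.
event = AuditLog.create_security_event(
    user_id=42,
    event_type="brute_force_suspected",
    description="5 failed logins within 60 seconds",
    severity=AuditSeverity.CRITICAL,
    details={"window_seconds": 60, "attempts": 5},
    ip_address="203.0.113.7",
)
# session.add(event); await session.commit()  # persisted like any other row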

View File

@@ -5,13 +5,24 @@ Budget model for managing spending limits and cost control
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class BudgetType(str, Enum):
"""Budget type enumeration"""
USER = "user"
API_KEY = "api_key"
GLOBAL = "global"
@@ -19,6 +30,7 @@ class BudgetType(str, Enum):
class BudgetPeriod(str, Enum):
"""Budget period types"""
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
@@ -28,67 +40,85 @@ class BudgetPeriod(str, Enum):
class Budget(Base):
"""Budget model for setting and managing spending limits"""
__tablename__ = "budgets"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) # Human-readable name for the budget
# User and API Key relationships
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="budgets")
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True) # Optional: specific to an API key
api_key_id = Column(
Integer, ForeignKey("api_keys.id"), nullable=True
) # Optional: specific to an API key
api_key = relationship("APIKey", back_populates="budgets")
# Usage tracking relationship
usage_tracking = relationship("UsageTracking", back_populates="budget", cascade="all, delete-orphan")
usage_tracking = relationship(
"UsageTracking", back_populates="budget", cascade="all, delete-orphan"
)
# Budget limits (in cents)
limit_cents = Column(Integer, nullable=False) # Maximum spend limit
warning_threshold_cents = Column(
Integer, nullable=True
) # Warning threshold (e.g., 80% of limit)
# Time period settings
period_type = Column(String, nullable=False, default="monthly") # daily, weekly, monthly, yearly, custom
period_type = Column(
String, nullable=False, default="monthly"
) # daily, weekly, monthly, yearly, custom
period_start = Column(DateTime, nullable=False) # Start of current period
period_end = Column(DateTime, nullable=False) # End of current period
# Current usage (in cents)
current_usage_cents = Column(Integer, default=0) # Spent in current period
# Budget status
is_active = Column(Boolean, default=True)
is_exceeded = Column(Boolean, default=False)
is_warning_sent = Column(Boolean, default=False)
# Enforcement settings
enforce_hard_limit = Column(
Boolean, default=True
) # Block requests when limit exceeded
enforce_warning = Column(Boolean, default=True) # Send warnings at threshold
# Allowed resources (optional filters)
allowed_models = Column(
JSON, default=list
) # Specific models this budget applies to
allowed_endpoints = Column(
JSON, default=list
) # Specific endpoints this budget applies to
# Metadata
description = Column(Text, nullable=True)
tags = Column(JSON, default=list)
currency = Column(String, default="USD")
# Auto-renewal settings
auto_renew = Column(Boolean, default=True) # Automatically renew budget for next period
rollover_unused = Column(Boolean, default=False) # Rollover unused budget to next period
auto_renew = Column(
Boolean, default=True
) # Automatically renew budget for next period
rollover_unused = Column(
Boolean, default=False
) # Rollover unused budget to next period
# Notification settings
notification_settings = Column(JSON, default=dict) # Email, webhook, etc.
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_reset_at = Column(DateTime, nullable=True) # Last time budget was reset
def __repr__(self):
return f"<Budget(id={self.id}, name='{self.name}', user_id={self.user_id}, limit=${self.limit_cents/100:.2f})>"
def to_dict(self):
"""Convert budget to dictionary for API responses"""
return {
@@ -99,15 +129,23 @@ class Budget(Base):
"limit_cents": self.limit_cents,
"limit_dollars": self.limit_cents / 100,
"warning_threshold_cents": self.warning_threshold_cents,
"warning_threshold_dollars": self.warning_threshold_cents / 100 if self.warning_threshold_cents else None,
"warning_threshold_dollars": self.warning_threshold_cents / 100
if self.warning_threshold_cents
else None,
"period_type": self.period_type,
"period_start": self.period_start.isoformat() if self.period_start else None,
"period_start": self.period_start.isoformat()
if self.period_start
else None,
"period_end": self.period_end.isoformat() if self.period_end else None,
"current_usage_cents": self.current_usage_cents,
"current_usage_dollars": self.current_usage_cents / 100,
"remaining_cents": max(0, self.limit_cents - self.current_usage_cents),
"remaining_dollars": max(0, (self.limit_cents - self.current_usage_cents) / 100),
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100) if self.limit_cents > 0 else 0,
"remaining_dollars": max(
0, (self.limit_cents - self.current_usage_cents) / 100
),
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100)
if self.limit_cents > 0
else 0,
"is_active": self.is_active,
"is_exceeded": self.is_exceeded,
"is_warning_sent": self.is_warning_sent,
@@ -123,62 +161,67 @@ class Budget(Base):
"notification_settings": self.notification_settings,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_reset_at": self.last_reset_at.isoformat() if self.last_reset_at else None
"last_reset_at": self.last_reset_at.isoformat()
if self.last_reset_at
else None,
}
def is_in_period(self) -> bool:
"""Check if current time is within budget period"""
now = datetime.utcnow()
return self.period_start <= now <= self.period_end
def is_expired(self) -> bool:
"""Check if budget period has expired"""
return datetime.utcnow() > self.period_end
def can_spend(self, amount_cents: int) -> bool:
"""Check if spending amount is within budget"""
if not self.is_active or not self.is_in_period():
return False
if not self.enforce_hard_limit:
return True
return (self.current_usage_cents + amount_cents) <= self.limit_cents
def would_exceed_warning(self, amount_cents: int) -> bool:
"""Check if spending amount would exceed warning threshold"""
if not self.warning_threshold_cents:
return False
return (self.current_usage_cents + amount_cents) >= self.warning_threshold_cents
def add_usage(self, amount_cents: int):
"""Add usage to current budget"""
self.current_usage_cents += amount_cents
# Check if budget is exceeded
if self.current_usage_cents >= self.limit_cents:
self.is_exceeded = True
# Check if warning threshold is reached
if self.warning_threshold_cents and self.current_usage_cents >= self.warning_threshold_cents:
if (
self.warning_threshold_cents
and self.current_usage_cents >= self.warning_threshold_cents
):
if not self.is_warning_sent:
self.is_warning_sent = True
self.updated_at = datetime.utcnow()
def reset_period(self):
"""Reset budget for new period"""
if self.rollover_unused and self.current_usage_cents < self.limit_cents:
# Rollover unused budget (note: this permanently adds the unused amount to limit_cents; later resets never restore the original base limit)
unused_amount = self.limit_cents - self.current_usage_cents
self.limit_cents += unused_amount
self.current_usage_cents = 0
self.is_exceeded = False
self.is_warning_sent = False
self.last_reset_at = datetime.utcnow()
# Calculate next period
if self.period_type == "daily":
self.period_start = self.period_end
@@ -190,39 +233,43 @@ class Budget(Base):
self.period_start = self.period_end
# Handle month boundaries properly
if self.period_start.month == 12:
next_month = self.period_start.replace(year=self.period_start.year + 1, month=1)
next_month = self.period_start.replace(
year=self.period_start.year + 1, month=1
)
else:
next_month = self.period_start.replace(month=self.period_start.month + 1)
next_month = self.period_start.replace(
month=self.period_start.month + 1
)
self.period_end = next_month
elif self.period_type == "yearly":
self.period_start = self.period_end
self.period_end = self.period_start.replace(year=self.period_start.year + 1)
self.updated_at = datetime.utcnow()
def get_period_days_remaining(self) -> int:
"""Get number of days remaining in current period"""
if self.is_expired():
return 0
return (self.period_end - datetime.utcnow()).days
def get_daily_burn_rate(self) -> float:
"""Get average daily spend rate in current period"""
if not self.is_in_period():
return 0.0
days_elapsed = (datetime.utcnow() - self.period_start).days
if days_elapsed == 0:
days_elapsed = 1 # Avoid division by zero
return self.current_usage_cents / days_elapsed / 100 # Return in dollars
def get_projected_spend(self) -> float:
"""Get projected spend for entire period based on current burn rate"""
daily_burn = self.get_daily_burn_rate()
total_period_days = (self.period_end - self.period_start).days
return daily_burn * total_period_days
@classmethod
def create_monthly_budget(
cls,
@@ -230,7 +277,7 @@ class Budget(Base):
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None,
warning_threshold_percentage: float = 0.8
warning_threshold_percentage: float = 0.8,
) -> "Budget":
"""Create a monthly budget"""
now = datetime.utcnow()
@@ -241,10 +288,10 @@ class Budget(Base):
period_end = period_start.replace(year=now.year + 1, month=1)
else:
period_end = period_start.replace(month=now.month + 1)
limit_cents = int(limit_dollars * 100)
warning_threshold_cents = int(limit_cents * warning_threshold_percentage)
return cls(
name=name,
user_id=user_id,
@@ -258,28 +305,25 @@ class Budget(Base):
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True,
notification_settings={
"email_on_warning": True,
"email_on_exceeded": True
}
notification_settings={"email_on_warning": True, "email_on_exceeded": True},
)
@classmethod
def create_daily_budget(
cls,
user_id: int,
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None
api_key_id: Optional[int] = None,
) -> "Budget":
"""Create a daily budget"""
now = datetime.utcnow()
period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + timedelta(days=1)
limit_cents = int(limit_dollars * 100)
warning_threshold_cents = int(limit_cents * 0.8) # 80% warning threshold
return cls(
name=name,
user_id=user_id,
@@ -292,5 +336,5 @@ class Budget(Base):
is_active=True,
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True
)
auto_renew=True,
)
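
A minimal usage sketch for the budget helpers above (illustrative only; the import path is an assumption, and the explicit field initialization stands in for column defaults, which SQLAlchemy applies only on flush):

from app.models.budget import Budget  # assumed import path

budget = Budget.create_monthly_budget(
    user_id=1,                         # assumes an existing user row
    name="Team LLM spend",
    limit_dollars=50.0,                # stored internally as 5000 cents
    warning_threshold_percentage=0.8,  # warn at $40
)
# Offline sketch: initialize fields that column defaults would set on flush.
budget.current_usage_cents = 0
budget.is_active = True

request_cost_cents = 125  # $1.25 for a hypothetical completion
if budget.can_spend(request_cost_cents):
    budget.add_usage(request_cost_cents)  # may flip is_exceeded / is_warning_sent
else:
    print("hard limit reached; request would be blocked")

print(budget.to_dict()["usage_percentage"])  # 2.5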

View File

@@ -1,7 +1,16 @@
"""
Database models for chatbot module
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey
from sqlalchemy import (
Column,
Integer,
String,
Text,
Boolean,
DateTime,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
@@ -9,102 +18,115 @@ import uuid
from app.db.database import Base
class ChatbotInstance(Base):
"""Configured chatbot instance"""
__tablename__ = "chatbot_instances"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String(255), nullable=False)
description = Column(Text)
# Configuration stored as JSON
config = Column(JSON, nullable=False)
# Metadata
created_by = Column(String, nullable=False) # User ID
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Relationships
conversations = relationship("ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan")
conversations = relationship(
"ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ChatbotInstance(id='{self.id}', name='{self.name}')>"
class ChatbotConversation(Base):
"""Conversation state and history"""
__tablename__ = "chatbot_conversations"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
user_id = Column(String, nullable=False) # User ID
# Conversation metadata
title = Column(String(255)) # Auto-generated or user-defined title
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
is_active = Column(Boolean, default=True)
# Conversation context and settings
context_data = Column(JSON, default=dict) # Additional context
# Relationships
chatbot = relationship("ChatbotInstance", back_populates="conversations")
messages = relationship("ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan")
messages = relationship(
"ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ChatbotConversation(id='{self.id}', chatbot_id='{self.chatbot_id}')>"
class ChatbotMessage(Base):
"""Individual chat messages in conversations"""
__tablename__ = "chatbot_messages"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
conversation_id = Column(String, ForeignKey("chatbot_conversations.id"), nullable=False)
conversation_id = Column(
String, ForeignKey("chatbot_conversations.id"), nullable=False
)
# Message content
role = Column(String(20), nullable=False) # 'user', 'assistant', 'system'
content = Column(Text, nullable=False)
# Metadata
timestamp = Column(DateTime, default=datetime.utcnow)
message_metadata = Column(JSON, default=dict) # Token counts, model used, etc.
# RAG sources if applicable
sources = Column(JSON) # RAG sources used for this message
# Relationships
conversation = relationship("ChatbotConversation", back_populates="messages")
def __repr__(self):
return f"<ChatbotMessage(id='{self.id}', role='{self.role}')>"
class ChatbotAnalytics(Base):
"""Analytics and metrics for chatbot usage"""
__tablename__ = "chatbot_analytics"
id = Column(Integer, primary_key=True, autoincrement=True)
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
user_id = Column(String, nullable=False)
# Event tracking
event_type = Column(String(50), nullable=False) # 'message_sent', 'response_generated', etc.
event_type = Column(
String(50), nullable=False
) # 'message_sent', 'response_generated', etc.
event_data = Column(JSON, default=dict)
# Performance metrics
response_time_ms = Column(Integer)
token_count = Column(Integer)
cost_cents = Column(Integer)
# Context
model_used = Column(String(100))
rag_used = Column(Boolean, default=False)
timestamp = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<ChatbotAnalytics(id={self.id}, event_type='{self.event_type}')>"
return f"<ChatbotAnalytics(id={self.id}, event_type='{self.event_type}')>"

View File

@@ -10,6 +10,7 @@ from enum import Enum
class ModuleStatus(str, Enum):
"""Module status types"""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
@@ -19,6 +20,7 @@ class ModuleStatus(str, Enum):
class ModuleType(str, Enum):
"""Module type categories"""
CORE = "core"
INTERCEPTOR = "interceptor"
ANALYTICS = "analytics"
@@ -30,75 +32,81 @@ class ModuleType(str, Enum):
class Module(Base):
"""Module model for tracking installed modules and their configurations"""
__tablename__ = "modules"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, index=True, nullable=False)
display_name = Column(String, nullable=False)
description = Column(Text, nullable=True)
# Module classification
module_type = Column(String, default=ModuleType.CUSTOM)
category = Column(String, nullable=True) # cache, rag, analytics, etc.
# Module details
version = Column(String, nullable=False)
author = Column(String, nullable=True)
license = Column(String, nullable=True)
# Module status
status = Column(String, default=ModuleStatus.INACTIVE)
is_enabled = Column(Boolean, default=False)
is_core = Column(Boolean, default=False) # Core modules cannot be disabled
# Configuration
config_schema = Column(JSON, default=dict) # JSON schema for configuration
config_values = Column(JSON, default=dict) # Current configuration values
default_config = Column(JSON, default=dict) # Default configuration
# Dependencies
dependencies = Column(JSON, default=list) # List of module dependencies
conflicts = Column(JSON, default=list) # List of conflicting modules
# Installation details
install_path = Column(String, nullable=True)
entry_point = Column(String, nullable=True) # Main module entry point
# Interceptor configuration
interceptor_chains = Column(JSON, default=list) # Which chains this module hooks into
interceptor_chains = Column(
JSON, default=list
) # Which chains this module hooks into
execution_order = Column(Integer, default=100) # Order in interceptor chain
# API endpoints
api_endpoints = Column(JSON, default=list) # List of API endpoints this module provides
api_endpoints = Column(
JSON, default=list
) # List of API endpoints this module provides
# Permissions and security
required_permissions = Column(JSON, default=list) # Permissions required to use this module
required_permissions = Column(
JSON, default=list
) # Permissions required to use this module
security_level = Column(String, default="low") # low, medium, high, critical
# Metadata
tags = Column(JSON, default=list)
module_metadata = Column(JSON, default=dict)
# Runtime information
last_error = Column(Text, nullable=True)
error_count = Column(Integer, default=0)
last_started = Column(DateTime, nullable=True)
last_stopped = Column(DateTime, nullable=True)
# Statistics
request_count = Column(Integer, default=0)
success_count = Column(Integer, default=0)
error_count_runtime = Column(Integer, default=0)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
installed_at = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<Module(id={self.id}, name='{self.name}', status='{self.status}')>"
def to_dict(self):
"""Convert module to dictionary for API responses"""
return {
@@ -130,81 +138,87 @@ class Module(Base):
"metadata": self.module_metadata,
"last_error": self.last_error,
"error_count": self.error_count,
"last_started": self.last_started.isoformat() if self.last_started else None,
"last_stopped": self.last_stopped.isoformat() if self.last_stopped else None,
"last_started": self.last_started.isoformat()
if self.last_started
else None,
"last_stopped": self.last_stopped.isoformat()
if self.last_stopped
else None,
"request_count": self.request_count,
"success_count": self.success_count,
"error_count_runtime": self.error_count_runtime,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"installed_at": self.installed_at.isoformat() if self.installed_at else None,
"installed_at": self.installed_at.isoformat()
if self.installed_at
else None,
"success_rate": self.get_success_rate(),
"uptime": self.get_uptime_seconds() if self.is_running() else 0
"uptime": self.get_uptime_seconds() if self.is_running() else 0,
}
def is_running(self) -> bool:
"""Check if module is currently running"""
return self.status == ModuleStatus.ACTIVE
def is_healthy(self) -> bool:
"""Check if module is healthy (running without recent errors)"""
return self.is_running() and self.error_count_runtime == 0
def get_success_rate(self) -> float:
"""Get success rate as percentage"""
if self.request_count == 0:
return 100.0
return (self.success_count / self.request_count) * 100
def get_uptime_seconds(self) -> int:
"""Get uptime in seconds"""
if not self.last_started:
return 0
return int((datetime.utcnow() - self.last_started).total_seconds())
def can_be_disabled(self) -> bool:
"""Check if module can be disabled"""
return not self.is_core
def has_dependency(self, module_name: str) -> bool:
"""Check if module has a specific dependency"""
return module_name in self.dependencies
def conflicts_with(self, module_name: str) -> bool:
"""Check if module conflicts with another module"""
return module_name in self.conflicts
def requires_permission(self, permission: str) -> bool:
"""Check if module requires a specific permission"""
return permission in self.required_permissions
def hooks_into_chain(self, chain_name: str) -> bool:
"""Check if module hooks into a specific interceptor chain"""
return chain_name in self.interceptor_chains
def provides_endpoint(self, endpoint: str) -> bool:
"""Check if module provides a specific API endpoint"""
return endpoint in self.api_endpoints
def update_config(self, config_updates: Dict[str, Any]):
"""Update module configuration"""
if self.config_values is None:
self.config_values = {}
self.config_values.update(config_updates)
self.updated_at = datetime.utcnow()
def reset_config(self):
"""Reset configuration to default values"""
self.config_values = self.default_config.copy() if self.default_config else {}
self.updated_at = datetime.utcnow()
def enable(self):
"""Enable the module"""
if self.status != ModuleStatus.ERROR:
self.is_enabled = True
self.status = ModuleStatus.LOADING
self.updated_at = datetime.utcnow()
def disable(self):
"""Disable the module"""
if self.can_be_disabled():
@@ -212,20 +226,20 @@ class Module(Base):
self.status = ModuleStatus.DISABLED
self.last_stopped = datetime.utcnow()
self.updated_at = datetime.utcnow()
def start(self):
"""Start the module"""
self.status = ModuleStatus.ACTIVE
self.last_started = datetime.utcnow()
self.last_error = None
self.updated_at = datetime.utcnow()
def stop(self):
"""Stop the module"""
self.status = ModuleStatus.INACTIVE
self.last_stopped = datetime.utcnow()
self.updated_at = datetime.utcnow()
def set_error(self, error_message: str):
"""Set module error status"""
self.status = ModuleStatus.ERROR
@@ -233,13 +247,13 @@ class Module(Base):
self.error_count += 1
self.error_count_runtime += 1
self.updated_at = datetime.utcnow()
def clear_error(self):
"""Clear error status"""
self.last_error = None
self.error_count_runtime = 0
self.updated_at = datetime.utcnow()
def record_request(self, success: bool = True):
"""Record a request to this module"""
self.request_count += 1
@@ -247,76 +261,82 @@ class Module(Base):
self.success_count += 1
else:
self.error_count_runtime += 1
def add_tag(self, tag: str):
"""Add a tag to the module"""
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag: str):
"""Remove a tag from the module"""
if tag in self.tags:
self.tags.remove(tag)
def update_metadata(self, key: str, value: Any):
"""Update metadata"""
if self.module_metadata is None:
self.module_metadata = {}
self.module_metadata[key] = value
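# Note: tags, module_metadata, and the other JSON columns above use the plain JSON type, so in-place mutation (append/remove/item assignment) is not change-tracked by SQLAlchemy; persisting these helpers' edits may require sqlalchemy.orm.attributes.flag_modified() or Mutable* wrappers.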
def add_dependency(self, module_name: str):
"""Add a dependency"""
if module_name not in self.dependencies:
self.dependencies.append(module_name)
def remove_dependency(self, module_name: str):
"""Remove a dependency"""
if module_name in self.dependencies:
self.dependencies.remove(module_name)
def add_conflict(self, module_name: str):
"""Add a conflict"""
if module_name not in self.conflicts:
self.conflicts.append(module_name)
def remove_conflict(self, module_name: str):
"""Remove a conflict"""
if module_name in self.conflicts:
self.conflicts.remove(module_name)
def add_interceptor_chain(self, chain_name: str):
"""Add an interceptor chain"""
if chain_name not in self.interceptor_chains:
self.interceptor_chains.append(chain_name)
def remove_interceptor_chain(self, chain_name: str):
"""Remove an interceptor chain"""
if chain_name in self.interceptor_chains:
self.interceptor_chains.remove(chain_name)
def add_api_endpoint(self, endpoint: str):
"""Add an API endpoint"""
if endpoint not in self.api_endpoints:
self.api_endpoints.append(endpoint)
def remove_api_endpoint(self, endpoint: str):
"""Remove an API endpoint"""
if endpoint in self.api_endpoints:
self.api_endpoints.remove(endpoint)
def add_required_permission(self, permission: str):
"""Add a required permission"""
if permission not in self.required_permissions:
self.required_permissions.append(permission)
def remove_required_permission(self, permission: str):
"""Remove a required permission"""
if permission in self.required_permissions:
self.required_permissions.remove(permission)
@classmethod
def create_core_module(cls, name: str, display_name: str, description: str,
version: str, entry_point: str) -> "Module":
def create_core_module(
cls,
name: str,
display_name: str,
description: str,
version: str,
entry_point: str,
) -> "Module":
"""Create a core module"""
return cls(
name=name,
@@ -341,9 +361,9 @@ class Module(Base):
required_permissions=[],
security_level="high",
tags=["core"],
module_metadata={}
module_metadata={},
)
@classmethod
def create_cache_module(cls) -> "Module":
"""Create the cache module"""
@@ -365,20 +385,12 @@ class Module(Base):
"properties": {
"provider": {"type": "string", "enum": ["redis"]},
"ttl": {"type": "integer", "minimum": 60},
"max_size": {"type": "integer", "minimum": 1000}
"max_size": {"type": "integer", "minimum": 1000},
},
"required": ["provider", "ttl"]
},
config_values={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
},
default_config={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
"required": ["provider", "ttl"],
},
config_values={"provider": "redis", "ttl": 3600, "max_size": 10000},
default_config={"provider": "redis", "ttl": 3600, "max_size": 10000},
dependencies=[],
conflicts=[],
interceptor_chains=["pre_request", "post_response"],
@@ -387,9 +399,9 @@ class Module(Base):
required_permissions=["cache.read", "cache.write"],
security_level="low",
tags=["cache", "performance"],
module_metadata={}
module_metadata={},
)
@classmethod
def create_rag_module(cls) -> "Module":
"""Create the RAG module"""
@@ -412,21 +424,21 @@ class Module(Base):
"vector_db": {"type": "string", "enum": ["qdrant"]},
"embedding_model": {"type": "string"},
"chunk_size": {"type": "integer", "minimum": 100},
"max_results": {"type": "integer", "minimum": 1}
"max_results": {"type": "integer", "minimum": 1},
},
"required": ["vector_db", "embedding_model"]
"required": ["vector_db", "embedding_model"],
},
config_values={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
"max_results": 10,
},
default_config={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
"max_results": 10,
},
dependencies=[],
conflicts=[],
@@ -436,9 +448,9 @@ class Module(Base):
required_permissions=["rag.read", "rag.write"],
security_level="medium",
tags=["rag", "ai", "search"],
module_metadata={}
module_metadata={},
)
@classmethod
def create_analytics_module(cls) -> "Module":
"""Create the analytics module"""
@@ -460,19 +472,19 @@ class Module(Base):
"properties": {
"track_requests": {"type": "boolean"},
"track_responses": {"type": "boolean"},
"retention_days": {"type": "integer", "minimum": 1}
"retention_days": {"type": "integer", "minimum": 1},
},
"required": ["track_requests", "track_responses"]
"required": ["track_requests", "track_responses"],
},
config_values={
"track_requests": True,
"track_responses": True,
"retention_days": 30
"retention_days": 30,
},
default_config={
"track_requests": True,
"track_responses": True,
"retention_days": 30
"retention_days": 30,
},
dependencies=[],
conflicts=[],
@@ -482,9 +494,9 @@ class Module(Base):
required_permissions=["analytics.read"],
security_level="low",
tags=["analytics", "monitoring"],
module_metadata={}
module_metadata={},
)
def get_health_status(self) -> Dict[str, Any]:
"""Get health status of the module"""
return {
@@ -495,5 +507,7 @@ class Module(Base):
"uptime_seconds": self.get_uptime_seconds() if self.is_running() else 0,
"last_error": self.last_error,
"error_count": self.error_count_runtime,
"last_started": self.last_started.isoformat() if self.last_started else None
}
"last_started": self.last_started.isoformat()
if self.last_started
else None,
}
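
An illustrative lifecycle sketch for the Module model (the import path is an assumption; counters are set explicitly because column defaults apply only on flush):

from app.models.module import Module  # assumed import path

cache = Module.create_cache_module()
# Offline sketch: initialize counters that column defaults would set on flush.
cache.request_count = cache.success_count = cache.error_count_runtime = 0

cache.enable()   # status -> LOADING (skipped if the module is in ERROR)
cache.start()    # status -> ACTIVE, last_started recorded

cache.record_request(success=True)
cache.record_request(success=False)  # bumps error_count_runtime

print(cache.get_success_rate())    # 50.0
print(cache.is_healthy())          # False: one runtime error recorded
print(cache.get_health_status())   # includes uptime_seconds while ACTIVE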

View File

@@ -0,0 +1,295 @@
"""
Notification models for multi-channel communication
"""
from datetime import datetime
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class NotificationType(str, Enum):
"""Notification types"""
EMAIL = "email"
WEBHOOK = "webhook"
SLACK = "slack"
DISCORD = "discord"
SMS = "sms"
PUSH = "push"
class NotificationPriority(str, Enum):
"""Notification priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
URGENT = "urgent"
class NotificationStatus(str, Enum):
"""Notification delivery status"""
PENDING = "pending"
SENT = "sent"
DELIVERED = "delivered"
FAILED = "failed"
RETRY = "retry"
class NotificationTemplate(Base):
"""Notification template model"""
__tablename__ = "notification_templates"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, unique=True)
display_name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Template content
notification_type = Column(String(20), nullable=False) # NotificationType enum
subject_template = Column(Text, nullable=True) # For email/messages
body_template = Column(Text, nullable=False) # Main content
html_template = Column(Text, nullable=True) # HTML version for email
# Configuration
default_priority = Column(String(20), default=NotificationPriority.NORMAL)
variables = Column(JSON, default=dict) # Expected template variables
template_metadata = Column(JSON, default=dict) # Additional configuration
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
notifications = relationship("Notification", back_populates="template")
def __repr__(self):
return f"<NotificationTemplate(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
def to_dict(self):
"""Convert template to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"notification_type": self.notification_type,
"subject_template": self.subject_template,
"body_template": self.body_template,
"html_template": self.html_template,
"default_priority": self.default_priority,
"variables": self.variables,
"template_metadata": self.template_metadata,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class NotificationChannel(Base):
"""Notification channel configuration"""
__tablename__ = "notification_channels"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
display_name = Column(String(200), nullable=False)
notification_type = Column(String(20), nullable=False) # NotificationType enum
# Channel configuration
config = Column(JSON, nullable=False) # Channel-specific settings
credentials = Column(JSON, nullable=True) # Encrypted credentials
# Settings
is_active = Column(Boolean, default=True)
is_default = Column(Boolean, default=False)
rate_limit = Column(Integer, default=100) # Messages per minute
retry_count = Column(Integer, default=3)
retry_delay_minutes = Column(Integer, default=5)
# Health monitoring
last_used_at = Column(DateTime, nullable=True)
success_count = Column(Integer, default=0)
failure_count = Column(Integer, default=0)
last_error = Column(Text, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
notifications = relationship("Notification", back_populates="channel")
def __repr__(self):
return f"<NotificationChannel(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
def to_dict(self):
"""Convert channel to dictionary (excluding sensitive credentials)"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"notification_type": self.notification_type,
"config": self.config,
"is_active": self.is_active,
"is_default": self.is_default,
"rate_limit": self.rate_limit,
"retry_count": self.retry_count,
"retry_delay_minutes": self.retry_delay_minutes,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"success_count": self.success_count,
"failure_count": self.failure_count,
"last_error": self.last_error,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def update_stats(self, success: bool, error_message: Optional[str] = None):
"""Update channel statistics"""
self.last_used_at = datetime.utcnow()
if success:
self.success_count += 1
self.last_error = None
else:
self.failure_count += 1
self.last_error = error_message
class Notification(Base):
"""Individual notification instance"""
__tablename__ = "notifications"
id = Column(Integer, primary_key=True, index=True)
# Content
subject = Column(String(500), nullable=True)
body = Column(Text, nullable=False)
html_body = Column(Text, nullable=True)
# Recipients
recipients = Column(JSON, nullable=False) # List of recipient addresses/IDs
cc_recipients = Column(JSON, nullable=True) # CC recipients (for email)
bcc_recipients = Column(JSON, nullable=True) # BCC recipients (for email)
# Configuration
priority = Column(String(20), default=NotificationPriority.NORMAL)
scheduled_at = Column(DateTime, nullable=True) # For scheduled delivery
expires_at = Column(DateTime, nullable=True) # Expiration time
# References
template_id = Column(
Integer, ForeignKey("notification_templates.id"), nullable=True
)
channel_id = Column(Integer, ForeignKey("notification_channels.id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Triggering user
# Status tracking
status = Column(String(20), default=NotificationStatus.PENDING)
attempts = Column(Integer, default=0)
max_attempts = Column(Integer, default=3)
# Delivery tracking
sent_at = Column(DateTime, nullable=True)
delivered_at = Column(DateTime, nullable=True)
failed_at = Column(DateTime, nullable=True)
error_message = Column(Text, nullable=True)
# External references
external_id = Column(String(200), nullable=True) # Provider message ID
callback_url = Column(String(500), nullable=True) # Delivery callback
# Metadata
notification_metadata = Column(JSON, default=dict)
tags = Column(JSON, default=list)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
template = relationship("NotificationTemplate", back_populates="notifications")
channel = relationship("NotificationChannel", back_populates="notifications")
user = relationship("User", back_populates="notifications")
def __repr__(self):
return f"<Notification(id={self.id}, status='{self.status}', channel='{self.channel.name if self.channel else 'unknown'}')>"
def to_dict(self):
"""Convert notification to dictionary"""
return {
"id": self.id,
"subject": self.subject,
"body": self.body,
"html_body": self.html_body,
"recipients": self.recipients,
"cc_recipients": self.cc_recipients,
"bcc_recipients": self.bcc_recipients,
"priority": self.priority,
"scheduled_at": self.scheduled_at.isoformat()
if self.scheduled_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"template_id": self.template_id,
"channel_id": self.channel_id,
"user_id": self.user_id,
"status": self.status,
"attempts": self.attempts,
"max_attempts": self.max_attempts,
"sent_at": self.sent_at.isoformat() if self.sent_at else None,
"delivered_at": self.delivered_at.isoformat()
if self.delivered_at
else None,
"failed_at": self.failed_at.isoformat() if self.failed_at else None,
"error_message": self.error_message,
"external_id": self.external_id,
"callback_url": self.callback_url,
"notification_metadata": self.notification_metadata,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def mark_sent(self, external_id: Optional[str] = None):
"""Mark notification as sent"""
self.status = NotificationStatus.SENT
self.sent_at = datetime.utcnow()
self.external_id = external_id
def mark_delivered(self):
"""Mark notification as delivered"""
self.status = NotificationStatus.DELIVERED
self.delivered_at = datetime.utcnow()
def mark_failed(self, error_message: str):
"""Mark notification as failed"""
self.status = NotificationStatus.FAILED
self.failed_at = datetime.utcnow()
self.error_message = error_message
self.attempts += 1
def can_retry(self) -> bool:
"""Check if notification can be retried"""
return (
self.status in [NotificationStatus.FAILED, NotificationStatus.RETRY]
and self.attempts < self.max_attempts
and (self.expires_at is None or self.expires_at > datetime.utcnow())
)
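
A hedged delivery sketch around the status helpers above; send_via_channel and its failure are hypothetical stand-ins for a real transport, and channel_id assumes an existing channel row:

from datetime import datetime, timedelta
from app.models.notification import Notification  # assumed import path

note = Notification(
    body="Budget warning: 80% of the monthly limit reached",
    recipients=["ops@example.com"],
    channel_id=1,          # assumes an existing NotificationChannel
    max_attempts=3,
    expires_at=datetime.utcnow() + timedelta(hours=6),
)
note.attempts = 0  # column default applies only on flush

def send_via_channel(notification):  # hypothetical transport; returns a provider id
    raise TimeoutError("smtp upstream timed out")  # simulate a failed attempt

try:
    note.mark_sent(send_via_channel(note))
except Exception as exc:
    note.mark_failed(str(exc))  # sets FAILED, records the error, bumps attempts
    if note.can_retry():        # True here: 1 of 3 attempts, not yet expired
        print("re-enqueue; the channel's retry_delay_minutes governs the backoff")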

View File

@@ -2,7 +2,17 @@
Plugin System Database Models
Defines the database schema for the isolated plugin architecture
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey, Index
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
Boolean,
JSON,
ForeignKey,
Index,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -13,13 +23,16 @@ from app.db.database import Base
class Plugin(Base):
"""Plugin registry - tracks all installed plugins"""
__tablename__ = "plugins"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), unique=True, nullable=False, index=True)
slug = Column(String(100), unique=True, nullable=False, index=True) # URL-safe identifier
slug = Column(
String(100), unique=True, nullable=False, index=True
) # URL-safe identifier
# Metadata
display_name = Column(String(200), nullable=False)
description = Column(Text)
@@ -27,65 +40,74 @@ class Plugin(Base):
author = Column(String(200))
homepage = Column(String(500))
repository = Column(String(500))
# Plugin file information
package_path = Column(String(500), nullable=False) # Path to plugin package
manifest_hash = Column(String(64), nullable=False) # SHA256 of manifest file
package_hash = Column(String(64), nullable=False) # SHA256 of plugin package
package_hash = Column(String(64), nullable=False) # SHA256 of plugin package
# Status and lifecycle
status = Column(String(20), nullable=False, default="installed", index=True)
# Statuses: installing, installed, enabled, disabled, error, uninstalling
enabled = Column(Boolean, default=False, nullable=False, index=True)
auto_enable = Column(Boolean, default=False, nullable=False)
# Installation tracking
installed_at = Column(DateTime, nullable=False, default=func.now())
enabled_at = Column(DateTime)
last_updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
installed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Configuration and requirements
manifest_data = Column(JSON) # Complete plugin.yaml content
config_schema = Column(JSON) # JSON schema for plugin configuration
default_config = Column(JSON) # Default configuration values
# Security and permissions
required_permissions = Column(JSON) # List of required permission scopes
api_scopes = Column(JSON) # Required API access scopes
resource_limits = Column(JSON) # Memory, CPU, storage limits
# Database isolation
database_name = Column(String(100), unique=True) # Isolated database name
database_url = Column(String(1000)) # Connection string for plugin database
# Error tracking
last_error = Column(Text)
error_count = Column(Integer, default=0)
last_error_at = Column(DateTime)
# Relationships
installed_by_user = relationship("User", back_populates="installed_plugins")
configurations = relationship("PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan")
instances = relationship("PluginInstance", back_populates="plugin", cascade="all, delete-orphan")
audit_logs = relationship("PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan")
cron_jobs = relationship("PluginCronJob", back_populates="plugin", cascade="all, delete-orphan")
configurations = relationship(
"PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan"
)
instances = relationship(
"PluginInstance", back_populates="plugin", cascade="all, delete-orphan"
)
audit_logs = relationship(
"PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan"
)
cron_jobs = relationship(
"PluginCronJob", back_populates="plugin", cascade="all, delete-orphan"
)
# Indexes for performance
__table_args__ = (
Index('idx_plugin_status_enabled', 'status', 'enabled'),
Index('idx_plugin_user_status', 'installed_by_user_id', 'status'),
Index("idx_plugin_status_enabled", "status", "enabled"),
Index("idx_plugin_user_status", "installed_by_user_id", "status"),
)
class PluginConfiguration(Base):
"""Plugin configuration instances - per user/environment configs"""
__tablename__ = "plugin_configurations"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Configuration data
name = Column(String(200), nullable=False) # Human-readable config name
description = Column(Text)
@@ -94,133 +116,140 @@ class PluginConfiguration(Base):
schema_version = Column(String(50)) # Schema version for migration support
is_active = Column(Boolean, default=False, nullable=False)
is_default = Column(Boolean, default=False, nullable=False)
# Metadata
created_at = Column(DateTime, nullable=False, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Relationships
plugin = relationship("Plugin", back_populates="configurations")
user = relationship("User", foreign_keys=[user_id])
created_by_user = relationship("User", foreign_keys=[created_by_user_id])
# Constraints
__table_args__ = (
Index('idx_plugin_config_user_active', 'plugin_id', 'user_id', 'is_active'),
Index("idx_plugin_config_user_active", "plugin_id", "user_id", "is_active"),
)
class PluginInstance(Base):
"""Plugin runtime instances - tracks running plugin processes"""
__tablename__ = "plugin_instances"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
configuration_id = Column(UUID(as_uuid=True), ForeignKey("plugin_configurations.id"))
configuration_id = Column(
UUID(as_uuid=True), ForeignKey("plugin_configurations.id")
)
# Runtime information
instance_name = Column(String(200), nullable=False)
process_id = Column(String(100)) # Docker container ID or process ID
status = Column(String(20), nullable=False, default="starting", index=True)
# Statuses: starting, running, stopping, stopped, error, crashed
# Performance tracking
start_time = Column(DateTime, nullable=False, default=func.now())
last_heartbeat = Column(DateTime, default=func.now())
stop_time = Column(DateTime)
restart_count = Column(Integer, default=0)
# Resource usage
memory_usage_mb = Column(Integer)
cpu_usage_percent = Column(Integer)
# Health monitoring
health_status = Column(String(20), default="unknown") # healthy, warning, critical
health_message = Column(Text)
last_health_check = Column(DateTime)
# Error tracking
last_error = Column(Text)
error_count = Column(Integer, default=0)
# Relationships
plugin = relationship("Plugin", back_populates="instances")
configuration = relationship("PluginConfiguration")
__table_args__ = (
Index('idx_plugin_instance_status', 'plugin_id', 'status'),
)
__table_args__ = (Index("idx_plugin_instance_status", "plugin_id", "status"),)
class PluginAuditLog(Base):
"""Audit logging for all plugin activities"""
__tablename__ = "plugin_audit_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
instance_id = Column(UUID(as_uuid=True), ForeignKey("plugin_instances.id"))
# Event details
event_type = Column(String(50), nullable=False, index=True) # api_call, config_change, error, etc.
event_type = Column(
String(50), nullable=False, index=True
) # api_call, config_change, error, etc.
action = Column(String(100), nullable=False)
resource = Column(String(200)) # Resource being accessed
# Context information
user_id = Column(Integer, ForeignKey("users.id"))
api_key_id = Column(Integer, ForeignKey("api_keys.id"))
ip_address = Column(String(45)) # IPv4 or IPv6
user_agent = Column(String(500))
# Request/response data
request_data = Column(JSON) # Sanitized request data
response_status = Column(Integer)
response_data = Column(JSON) # Sanitized response data
# Performance metrics
duration_ms = Column(Integer)
# Status and errors
success = Column(Boolean, nullable=False, index=True)
error_message = Column(Text)
# Timestamps
timestamp = Column(DateTime, nullable=False, default=func.now(), index=True)
# Relationships
plugin = relationship("Plugin", back_populates="audit_logs")
instance = relationship("PluginInstance")
user = relationship("User")
api_key = relationship("APIKey")
__table_args__ = (
Index('idx_plugin_audit_plugin_time', 'plugin_id', 'timestamp'),
Index('idx_plugin_audit_user_time', 'user_id', 'timestamp'),
Index('idx_plugin_audit_event_type', 'event_type', 'timestamp'),
Index("idx_plugin_audit_plugin_time", "plugin_id", "timestamp"),
Index("idx_plugin_audit_user_time", "user_id", "timestamp"),
Index("idx_plugin_audit_event_type", "event_type", "timestamp"),
)
class PluginCronJob(Base):
"""Plugin scheduled jobs and cron tasks"""
__tablename__ = "plugin_cron_jobs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
# Job identification
job_name = Column(String(200), nullable=False)
job_id = Column(String(100), nullable=False, unique=True, index=True) # Unique scheduler ID
job_id = Column(
String(100), nullable=False, unique=True, index=True
) # Unique scheduler ID
# Schedule configuration
schedule = Column(String(100), nullable=False) # Cron expression
timezone = Column(String(50), default="UTC")
enabled = Column(Boolean, default=True, nullable=False, index=True)
# Job details
description = Column(Text)
function_name = Column(String(200), nullable=False) # Plugin function to call
job_data = Column(JSON) # Parameters for the job function
# Execution tracking
last_run_at = Column(DateTime)
next_run_at = Column(DateTime, index=True)
@@ -228,65 +257,72 @@ class PluginCronJob(Base):
run_count = Column(Integer, default=0)
success_count = Column(Integer, default=0)
error_count = Column(Integer, default=0)
# Error handling
last_error = Column(Text)
last_error_at = Column(DateTime)
max_retries = Column(Integer, default=3)
retry_delay_seconds = Column(Integer, default=60)
# Lifecycle
created_at = Column(DateTime, nullable=False, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Relationships
plugin = relationship("Plugin", back_populates="cron_jobs")
created_by_user = relationship("User")
__table_args__ = (
Index('idx_plugin_cron_next_run', 'enabled', 'next_run_at'),
Index('idx_plugin_cron_plugin', 'plugin_id', 'enabled'),
Index("idx_plugin_cron_next_run", "enabled", "next_run_at"),
Index("idx_plugin_cron_plugin", "plugin_id", "enabled"),
)
class PluginAPIGateway(Base):
"""API gateway configuration for plugin routing"""
__tablename__ = "plugin_api_gateways"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True)
plugin_id = Column(
UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True
)
# API routing configuration
base_path = Column(String(200), nullable=False, unique=True) # /api/v1/plugins/zammad
base_path = Column(
String(200), nullable=False, unique=True
) # /api/v1/plugins/zammad
internal_url = Column(String(500), nullable=False) # http://plugin-zammad:8000
# Security settings
require_authentication = Column(Boolean, default=True, nullable=False)
allowed_methods = Column(JSON, default=["GET", "POST", "PUT", "DELETE"]) # HTTP methods
allowed_methods = Column(
JSON, default=["GET", "POST", "PUT", "DELETE"]
) # HTTP methods
rate_limit_per_minute = Column(Integer, default=60)
rate_limit_per_hour = Column(Integer, default=1000)
# CORS settings
cors_enabled = Column(Boolean, default=True, nullable=False)
cors_origins = Column(JSON, default=["*"])
cors_methods = Column(JSON, default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
cors_headers = Column(JSON, default=["*"])
# Circuit breaker settings
circuit_breaker_enabled = Column(Boolean, default=True, nullable=False)
failure_threshold = Column(Integer, default=5)
recovery_timeout_seconds = Column(Integer, default=60)
# Monitoring
enabled = Column(Boolean, default=True, nullable=False, index=True)
last_health_check = Column(DateTime)
health_status = Column(String(20), default="unknown") # healthy, unhealthy, timeout
# Timestamps
created_at = Column(DateTime, nullable=False, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# Relationships
plugin = relationship("Plugin")
@@ -303,36 +339,42 @@ Add to APIKey model:
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
"""
class PluginPermission(Base):
"""Plugin permission grants - tracks user permissions for plugins"""
__tablename__ = "plugin_permissions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Permission details
permission_name = Column(String(200), nullable=False) # e.g., 'chatbot:invoke', 'rag:query'
granted = Column(Boolean, default=True, nullable=False) # True=granted, False=revoked
permission_name = Column(
String(200), nullable=False
) # e.g., 'chatbot:invoke', 'rag:query'
granted = Column(
Boolean, default=True, nullable=False
) # True=granted, False=revoked
# Grant/revoke tracking
granted_at = Column(DateTime, nullable=False, default=func.now())
granted_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
revoked_at = Column(DateTime)
revoked_by_user_id = Column(Integer, ForeignKey("users.id"))
# Metadata
reason = Column(Text) # Reason for grant/revocation
expires_at = Column(DateTime) # Optional expiration time
# Relationships
plugin = relationship("Plugin")
user = relationship("User", foreign_keys=[user_id])
granted_by_user = relationship("User", foreign_keys=[granted_by_user_id])
revoked_by_user = relationship("User", foreign_keys=[revoked_by_user_id])
__table_args__ = (
Index('idx_plugin_permission_user_plugin', 'user_id', 'plugin_id'),
Index('idx_plugin_permission_plugin_name', 'plugin_id', 'permission_name'),
Index('idx_plugin_permission_active', 'plugin_id', 'user_id', 'granted'),
Index("idx_plugin_permission_user_plugin", "user_id", "plugin_id"),
Index("idx_plugin_permission_plugin_name", "plugin_id", "permission_name"),
Index("idx_plugin_permission_active", "plugin_id", "user_id", "granted"),
)
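
The grant/revoke columns above support point-in-time permission checks; one way to evaluate them is sketched below (the session wiring and import path are assumptions, and "newest grant wins" is a policy choice, not something this schema mandates):

from datetime import datetime
from sqlalchemy import select
from app.models.plugin import PluginPermission  # assumed import path

def user_has_plugin_permission(session, user_id, plugin_id, permission):
    """True if the newest matching grant is still granted and unexpired."""
    stmt = (
        select(PluginPermission)
        .where(
            PluginPermission.user_id == user_id,
            PluginPermission.plugin_id == plugin_id,
            PluginPermission.permission_name == permission,
        )
        .order_by(PluginPermission.granted_at.desc())
        .limit(1)
    )
    grant = session.execute(stmt).scalar_one_or_none()
    if grant is None or not grant.granted:
        return False
    return grant.expires_at is None or grant.expires_at > datetime.utcnow()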

View File

@@ -10,33 +10,41 @@ from datetime import datetime
class PromptTemplate(Base):
"""Editable prompt templates for different chatbot types"""
__tablename__ = "prompt_templates"
id = Column(String, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True) # Human readable name
type_key = Column(String(100), nullable=False, unique=True, index=True) # assistant, customer_support, etc.
type_key = Column(
String(100), nullable=False, unique=True, index=True
) # assistant, customer_support, etc.
description = Column(Text, nullable=True)
system_prompt = Column(Text, nullable=False)
is_default = Column(Boolean, default=True, nullable=False)
is_active = Column(Boolean, default=True, nullable=False)
version = Column(Integer, default=1, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self):
return f"<PromptTemplate(type_key='{self.type_key}', name='{self.name}')>"
class ChatbotPromptVariable(Base):
"""Available variables that can be used in prompts"""
__tablename__ = "prompt_variables"
id = Column(String, primary_key=True, index=True)
variable_name = Column(String(100), nullable=False, unique=True, index=True) # {user_name}, {context}, etc.
variable_name = Column(
String(100), nullable=False, unique=True, index=True
) # {user_name}, {context}, etc.
description = Column(Text, nullable=True)
example_value = Column(String(500), nullable=True)
is_active = Column(Boolean, default=True, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<PromptVariable(name='{self.variable_name}')>"
return f"<PromptVariable(name='{self.variable_name}')>"

View File

@@ -15,23 +15,36 @@ class RagCollection(Base):
id = Column(Integer, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True)
description = Column(Text, nullable=True)
qdrant_collection_name = Column(String(255), nullable=False, unique=True, index=True)
qdrant_collection_name = Column(
String(255), nullable=False, unique=True, index=True
)
# Metadata
document_count = Column(Integer, default=0, nullable=False)
size_bytes = Column(BigInteger, default=0, nullable=False)
vector_count = Column(Integer, default=0, nullable=False)
# Status tracking
status = Column(String(50), default='active', nullable=False) # 'active', 'indexing', 'error'
status = Column(
String(50), default="active", nullable=False
) # 'active', 'indexing', 'error'
is_active = Column(Boolean, default=True, nullable=False)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
# Relationships
documents = relationship("RagDocument", back_populates="collection", cascade="all, delete-orphan")
documents = relationship(
"RagDocument", back_populates="collection", cascade="all, delete-orphan"
)
def to_dict(self):
"""Convert model to dictionary for API responses"""
@@ -45,8 +58,8 @@ class RagCollection(Base):
"status": self.status,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_active": self.is_active
"is_active": self.is_active,
}
def __repr__(self):
return f"<RagCollection(id={self.id}, name='{self.name}', documents={self.document_count})>"
return f"<RagCollection(id={self.id}, name='{self.name}', documents={self.document_count})>"

View File

@@ -3,7 +3,17 @@ RAG Document Model
Represents documents within RAG collections
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, BigInteger, ForeignKey, JSON
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
Boolean,
BigInteger,
ForeignKey,
JSON,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
@@ -13,11 +23,16 @@ class RagDocument(Base):
__tablename__ = "rag_documents"
id = Column(Integer, primary_key=True, index=True)
# Collection relationship
collection_id = Column(Integer, ForeignKey("rag_collections.id", ondelete="CASCADE"), nullable=False, index=True)
collection_id = Column(
Integer,
ForeignKey("rag_collections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
collection = relationship("RagCollection", back_populates="documents")
# File information
filename = Column(String(255), nullable=False) # sanitized filename for storage
original_filename = Column(String(255), nullable=False) # user's original filename
@@ -25,29 +40,44 @@ class RagDocument(Base):
file_type = Column(String(50), nullable=False) # pdf, docx, txt, etc.
file_size = Column(BigInteger, nullable=False) # file size in bytes
mime_type = Column(String(100), nullable=True)
# Processing status
status = Column(String(50), default='processing', nullable=False) # 'processing', 'processed', 'error', 'indexed'
status = Column(
String(50), default="processing", nullable=False
) # 'processing', 'processed', 'error', 'indexed'
processing_error = Column(Text, nullable=True)
# Content information
converted_content = Column(Text, nullable=True) # markdown converted content
word_count = Column(Integer, default=0, nullable=False)
character_count = Column(Integer, default=0, nullable=False)
# Vector information
vector_count = Column(Integer, default=0, nullable=False) # number of chunks/vectors created
chunk_size = Column(Integer, default=1000, nullable=False) # chunk size used for vectorization
vector_count = Column(
Integer, default=0, nullable=False
) # number of chunks/vectors created
chunk_size = Column(
Integer, default=1000, nullable=False
) # chunk size used for vectorization
# Metadata extracted from document
document_metadata = Column(JSON, nullable=True) # language, entities, keywords, etc.
document_metadata = Column(
JSON, nullable=True
) # language, entities, keywords, etc.
# Processing timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
processed_at = Column(DateTime(timezone=True), nullable=True)
indexed_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
# Soft delete
is_deleted = Column(Boolean, default=False, nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True)
@@ -72,11 +102,13 @@ class RagDocument(Base):
"chunk_size": self.chunk_size,
"metadata": self.document_metadata or {},
"created_at": self.created_at.isoformat() if self.created_at else None,
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
"processed_at": self.processed_at.isoformat()
if self.processed_at
else None,
"indexed_at": self.indexed_at.isoformat() if self.indexed_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_deleted": self.is_deleted
"is_deleted": self.is_deleted,
}
def __repr__(self):
return f"<RagDocument(id={self.id}, filename='{self.original_filename}', status='{self.status}')>"
return f"<RagDocument(id={self.id}, filename='{self.original_filename}', status='{self.status}')>"

158
backend/app/models/role.py Normal file
View File

@@ -0,0 +1,158 @@
"""
Role model for hierarchical permission management
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
from sqlalchemy.orm import relationship
from app.db.database import Base
class RoleLevel(str, Enum):
"""Role hierarchy levels"""
READ_ONLY = "read_only" # Level 1: Can only view
USER = "user" # Level 2: Can create and manage own resources
ADMIN = "admin" # Level 3: Can manage users and settings
SUPER_ADMIN = "super_admin" # Level 4: Full system access
class Role(Base):
"""Role model with hierarchical permissions"""
__tablename__ = "roles"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50), unique=True, nullable=False)
display_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
level = Column(String(20), nullable=False) # RoleLevel enum
# Permissions configuration
permissions = Column(JSON, default=dict) # Granular permissions
can_manage_users = Column(Boolean, default=False)
can_manage_budgets = Column(Boolean, default=False)
can_view_reports = Column(Boolean, default=False)
can_manage_tools = Column(Boolean, default=False)
# Role hierarchy
inherits_from = Column(JSON, default=list) # List of parent role names
# Status
is_active = Column(Boolean, default=True)
is_system_role = Column(Boolean, default=False) # System roles cannot be deleted
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
users = relationship("User", back_populates="role")
def __repr__(self):
return f"<Role(id={self.id}, name='{self.name}', level='{self.level}')>"
def to_dict(self):
"""Convert role to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"level": self.level,
"permissions": self.permissions,
"can_manage_users": self.can_manage_users,
"can_manage_budgets": self.can_manage_budgets,
"can_view_reports": self.can_view_reports,
"can_manage_tools": self.can_manage_tools,
"inherits_from": self.inherits_from,
"is_active": self.is_active,
"is_system_role": self.is_system_role,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def has_permission(self, permission: str) -> bool:
"""Check if role has a specific permission"""
if self.level == RoleLevel.SUPER_ADMIN:
return True
# Check denied permissions first (an explicit deny overrides a grant,
# matching the override order used in User.has_permission)
if permission in self.permissions.get("denied", []):
return False
# Check directly granted permissions
if permission in self.permissions.get("granted", []):
return True
# Check inherited permissions (simplified)
for parent_role in self.inherits_from:
# This would require recursive checking in a real implementation
pass
return False
@classmethod
def create_default_roles(cls):
"""Create default system roles"""
roles = [
cls(
name="read_only",
display_name="Read Only",
description="Can view own data only",
level=RoleLevel.READ_ONLY,
permissions={
"granted": ["read_own"],
"denied": ["create", "update", "delete"],
},
is_system_role=True,
),
cls(
name="user",
display_name="User",
description="Standard user with full access to own resources",
level=RoleLevel.USER,
permissions={
"granted": ["read_own", "create_own", "update_own", "delete_own"],
"denied": ["manage_users", "manage_all"],
},
inherits_from=["read_only"],
is_system_role=True,
),
cls(
name="admin",
display_name="Administrator",
description="Can manage users and view reports",
level=RoleLevel.ADMIN,
permissions={
"granted": [
"read_all",
"create_all",
"update_all",
"manage_users",
"view_reports",
],
"denied": ["system_settings"],
},
inherits_from=["user"],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
is_system_role=True,
),
cls(
name="super_admin",
display_name="Super Administrator",
description="Full system access",
level=RoleLevel.SUPER_ADMIN,
permissions={"granted": ["*"]}, # All permissions
inherits_from=["admin"],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
can_manage_tools=True,
is_system_role=True,
),
]
return roles
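
For orientation, a minimal sketch of how these defaults and checks combine; the seed_roles helper and the synchronous SessionLocal factory are assumptions for illustration, not part of this commit.

# Illustrative sketch only -- exercises Role.create_default_roles() and
# has_permission(); SessionLocal is an assumed synchronous session factory.
from app.db.database import SessionLocal  # assumption
from app.models.role import Role

def seed_roles() -> None:
    """Insert the four default system roles if none exist yet."""
    with SessionLocal() as session:
        if session.query(Role).count() == 0:
            session.add_all(Role.create_default_roles())
            session.commit()

admin = Role(
    name="admin",
    display_name="Administrator",
    level="admin",
    permissions={"granted": ["manage_users"], "denied": ["system_settings"]},
)
assert admin.has_permission("manage_users")          # explicit grant
assert not admin.has_permission("system_settings")   # explicit deny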

272
backend/app/models/tool.py Normal file
View File

@@ -0,0 +1,272 @@
"""
Tool model for custom tool execution
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class ToolType(str, Enum):
"""Tool execution types"""
PYTHON = "python"
BASH = "bash"
DOCKER = "docker"
API = "api"
CUSTOM = "custom"
class ToolStatus(str, Enum):
"""Tool execution status"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
class Tool(Base):
"""Tool definition model"""
__tablename__ = "tools"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, index=True)
display_name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Tool configuration
tool_type = Column(String(20), nullable=False) # ToolType enum
code = Column(Text, nullable=False) # Tool implementation code
parameters_schema = Column(JSON, default=dict) # JSON schema for parameters
return_schema = Column(JSON, default=dict) # Expected return format
# Execution settings
timeout_seconds = Column(Integer, default=30)
max_memory_mb = Column(Integer, default=256)
max_cpu_seconds = Column(Float, default=10.0)
# Docker settings (for docker type tools)
docker_image = Column(String(200), nullable=True)
docker_command = Column(Text, nullable=True)
# Access control
is_public = Column(Boolean, default=False) # Public tools available to all users
is_approved = Column(Boolean, default=False) # Admin approved for security
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Categories and tags
category = Column(String(50), nullable=True)
tags = Column(JSON, default=list)
# Usage tracking
usage_count = Column(Integer, default=0)
last_used_at = Column(DateTime, nullable=True)
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
created_by = relationship("User", back_populates="created_tools")
executions = relationship(
"ToolExecution", back_populates="tool", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<Tool(id={self.id}, name='{self.name}', type='{self.tool_type}')>"
def to_dict(self):
"""Convert tool to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"tool_type": self.tool_type,
"parameters_schema": self.parameters_schema,
"return_schema": self.return_schema,
"timeout_seconds": self.timeout_seconds,
"max_memory_mb": self.max_memory_mb,
"max_cpu_seconds": self.max_cpu_seconds,
"docker_image": self.docker_image,
"is_public": self.is_public,
"is_approved": self.is_approved,
"created_by_user_id": self.created_by_user_id,
"category": self.category,
"tags": self.tags,
"usage_count": self.usage_count,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def increment_usage(self):
"""Increment usage count and update last used timestamp"""
self.usage_count += 1
self.last_used_at = datetime.utcnow()
def can_be_used_by(self, user) -> bool:
"""Check if user can use this tool"""
# Tool creator can always use their tools
if self.created_by_user_id == user.id:
return True
# Public and approved tools can be used by anyone
if self.is_public and self.is_approved:
return True
# Admin users can use any tool
if user.has_permission("manage_tools"):
return True
return False
class ToolExecution(Base):
"""Tool execution instance model"""
__tablename__ = "tool_executions"
id = Column(Integer, primary_key=True, index=True)
# Tool and user references
tool_id = Column(Integer, ForeignKey("tools.id"), nullable=False)
executed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Execution details
parameters = Column(JSON, default=dict) # Input parameters
status = Column(String(20), nullable=False, default=ToolStatus.PENDING)
# Results
output = Column(Text, nullable=True) # Tool output
error_message = Column(Text, nullable=True) # Error details if failed
return_code = Column(Integer, nullable=True) # Exit code
# Resource usage
execution_time_ms = Column(Integer, nullable=True) # Actual execution time
memory_used_mb = Column(Float, nullable=True) # Peak memory usage
cpu_time_ms = Column(Integer, nullable=True) # CPU time used
# Docker execution details
container_id = Column(String(100), nullable=True) # Docker container ID
docker_logs = Column(Text, nullable=True) # Docker container logs
# Timestamps
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
tool = relationship("Tool", back_populates="executions")
executed_by = relationship("User", back_populates="tool_executions")
def __repr__(self):
return f"<ToolExecution(id={self.id}, tool_id={self.tool_id}, status='{self.status}')>"
def to_dict(self):
"""Convert execution to dictionary"""
return {
"id": self.id,
"tool_id": self.tool_id,
"tool_name": self.tool.name if self.tool else None,
"executed_by_user_id": self.executed_by_user_id,
"parameters": self.parameters,
"status": self.status,
"output": self.output,
"error_message": self.error_message,
"return_code": self.return_code,
"execution_time_ms": self.execution_time_ms,
"memory_used_mb": self.memory_used_mb,
"cpu_time_ms": self.cpu_time_ms,
"container_id": self.container_id,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
@property
def duration_seconds(self) -> float:
"""Calculate execution duration in seconds"""
if self.started_at and self.completed_at:
return (self.completed_at - self.started_at).total_seconds()
return 0.0
def is_running(self) -> bool:
"""Check if execution is currently running"""
return self.status in [ToolStatus.PENDING, ToolStatus.RUNNING]
def is_finished(self) -> bool:
"""Check if execution is finished (success or failure)"""
return self.status in [
ToolStatus.COMPLETED,
ToolStatus.FAILED,
ToolStatus.TIMEOUT,
ToolStatus.CANCELLED,
]
class ToolCategory(Base):
"""Tool category for organization"""
__tablename__ = "tool_categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50), unique=True, nullable=False)
display_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
# Visual
icon = Column(String(50), nullable=True) # Icon name
color = Column(String(20), nullable=True) # Color code
# Ordering
sort_order = Column(Integer, default=0)
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def __repr__(self):
return f"<ToolCategory(id={self.id}, name='{self.name}')>"
def to_dict(self):
"""Convert category to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"icon": self.icon,
"color": self.color,
"sort_order": self.sort_order,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
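
To make the access rules above concrete, a small sketch of can_be_used_by() and the execution-state helpers; the SimpleNamespace stand-ins replace real User rows purely for illustration.

# Illustrative sketch only -- the SimpleNamespace objects mimic just the
# User attributes that Tool.can_be_used_by() actually reads.
from types import SimpleNamespace
from app.models.tool import Tool, ToolExecution, ToolStatus

tool = Tool(
    name="csv_summary", display_name="CSV Summary",
    tool_type="python", code="print('hello')",
    is_public=True, is_approved=False, created_by_user_id=1,
)
owner = SimpleNamespace(id=1, has_permission=lambda p: False)
stranger = SimpleNamespace(id=2, has_permission=lambda p: False)

assert tool.can_be_used_by(owner)         # creators always keep access
assert not tool.can_be_used_by(stranger)  # public but not admin-approved yet

run = ToolExecution(tool_id=1, executed_by_user_id=2, status=ToolStatus.PENDING)
assert run.is_running() and not run.is_finished()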

View File

@@ -4,63 +4,73 @@ Usage Tracking model for API key usage statistics
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey, Float
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class UsageTracking(Base):
"""Usage tracking model for detailed API key usage statistics"""
__tablename__ = "usage_tracking"
id = Column(Integer, primary_key=True, index=True)
# API Key relationship
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=False)
api_key = relationship("APIKey", back_populates="usage_tracking")
# User relationship
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="usage_tracking")
# Budget relationship (optional)
budget_id = Column(Integer, ForeignKey("budgets.id"), nullable=True)
budget = relationship("Budget", back_populates="usage_tracking")
# Request information
endpoint = Column(String, nullable=False) # API endpoint used
method = Column(String, nullable=False) # HTTP method
model = Column(String, nullable=True) # Model used (if applicable)
# Usage metrics
request_tokens = Column(Integer, default=0) # Input tokens
response_tokens = Column(Integer, default=0) # Output tokens
total_tokens = Column(Integer, default=0) # Total tokens used
# Cost tracking
cost_cents = Column(Integer, default=0) # Cost in cents
cost_currency = Column(String, default="USD") # Currency
# Response information
response_time_ms = Column(Integer, nullable=True) # Response time in milliseconds
status_code = Column(Integer, nullable=True) # HTTP status code
# Request metadata
request_id = Column(String, nullable=True) # Unique request identifier
session_id = Column(String, nullable=True) # Session identifier
ip_address = Column(String, nullable=True) # Client IP address
user_agent = Column(String, nullable=True) # User agent
# Additional metadata
request_metadata = Column(JSON, default=dict) # Additional request metadata
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<UsageTracking(id={self.id}, api_key_id={self.api_key_id}, endpoint='{self.endpoint}')>"
def to_dict(self):
"""Convert usage tracking to dictionary for API responses"""
return {
@@ -82,9 +92,9 @@ class UsageTracking(Base):
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"request_metadata": self.request_metadata,
"created_at": self.created_at.isoformat() if self.created_at else None
"created_at": self.created_at.isoformat() if self.created_at else None,
}
@classmethod
def create_tracking_record(
cls,
@@ -102,7 +112,7 @@ class UsageTracking(Base):
session_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_metadata: Optional[dict] = None
request_metadata: Optional[dict] = None,
) -> "UsageTracking":
"""Create a new usage tracking record"""
return cls(
@@ -121,5 +131,5 @@ class UsageTracking(Base):
session_id=session_id,
ip_address=ip_address,
user_agent=user_agent,
request_metadata=request_metadata or {}
)
request_metadata=request_metadata or {},
)
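
A short sketch of what one tracked request looks like; the field names come from the columns above, while the values and the import path are illustrative assumptions.

# Illustrative sketch only -- values are made up; the import path is assumed.
from app.models.usage_tracking import UsageTracking

record = UsageTracking(
    api_key_id=42,
    user_id=7,
    endpoint="/v1/chat/completions",
    method="POST",
    model="gpt-4o-mini",
    request_tokens=512,
    response_tokens=128,
    total_tokens=640,
    cost_cents=3,            # costs are stored as integer cents
    response_time_ms=950,
    status_code=200,
    request_metadata={"stream": False},
)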

View File

@@ -4,75 +4,119 @@ User model
from datetime import datetime
from typing import Optional, List
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Numeric,
)
from sqlalchemy.orm import relationship
from sqlalchemy import inspect as sa_inspect
from app.db.database import Base
class UserRole(str, Enum):
"""User role enumeration"""
USER = "user"
ADMIN = "admin"
SUPER_ADMIN = "super_admin"
from decimal import Decimal
class User(Base):
"""User model for authentication and user management"""
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True, nullable=False)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
full_name = Column(String, nullable=True)
# User status and permissions
# Account status
is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False)
is_verified = Column(Boolean, default=False)
# Role-based access control
role = Column(String, default=UserRole.USER.value) # user, admin, super_admin
permissions = Column(JSON, default=dict) # Custom permissions
is_superuser = Column(Boolean, default=False) # Legacy field for compatibility
# Role-based access control (using new Role model)
role_id = Column(Integer, ForeignKey("roles.id"), nullable=True)
custom_permissions = Column(JSON, default=dict) # Custom permissions override
# Account management
account_locked = Column(Boolean, default=False)
account_locked_until = Column(DateTime, nullable=True)
failed_login_attempts = Column(Integer, default=0)
last_failed_login = Column(DateTime, nullable=True)
force_password_change = Column(Boolean, default=False)
# Profile information
avatar_url = Column(String, nullable=True)
bio = Column(Text, nullable=True)
company = Column(String, nullable=True)
website = Column(String, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
last_login = Column(DateTime, nullable=True)
# Settings
preferences = Column(JSON, default=dict)
notification_settings = Column(JSON, default=dict)
# Relationships
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="user", cascade="all, delete-orphan")
budgets = relationship("Budget", back_populates="user", cascade="all, delete-orphan")
audit_logs = relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
role = relationship("Role", back_populates="users")
api_keys = relationship(
"APIKey", back_populates="user", cascade="all, delete-orphan"
)
usage_tracking = relationship(
"UsageTracking", back_populates="user", cascade="all, delete-orphan"
)
budgets = relationship(
"Budget", back_populates="user", cascade="all, delete-orphan"
)
audit_logs = relationship(
"AuditLog", back_populates="user", cascade="all, delete-orphan"
)
installed_plugins = relationship("Plugin", back_populates="installed_by_user")
created_tools = relationship(
"Tool", back_populates="created_by", cascade="all, delete-orphan"
)
tool_executions = relationship(
"ToolExecution", back_populates="executed_by", cascade="all, delete-orphan"
)
notifications = relationship(
"Notification", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<User(id={self.id}, email='{self.email}', username='{self.username}')>"
def to_dict(self):
"""Convert user to dictionary for API responses"""
# Check if role relationship is loaded to avoid lazy loading in async context
inspector = sa_inspect(self)
role_loaded = "role" not in inspector.unloaded
return {
"id": self.id,
"email": self.email,
"username": self.username,
"full_name": self.full_name,
"is_active": self.is_active,
"is_superuser": self.is_superuser,
"is_verified": self.is_verified,
"role": self.role,
"permissions": self.permissions,
"is_superuser": self.is_superuser,
"role_id": self.role_id,
"role": self.role.to_dict() if role_loaded and self.role else None,
"custom_permissions": self.custom_permissions,
"account_locked": self.account_locked,
"account_locked_until": self.account_locked_until.isoformat()
if self.account_locked_until
else None,
"failed_login_attempts": self.failed_login_attempts,
"last_failed_login": self.last_failed_login.isoformat()
if self.last_failed_login
else None,
"force_password_change": self.force_password_change,
"avatar_url": self.avatar_url,
"bio": self.bio,
"company": self.company,
@@ -81,54 +125,157 @@ class User(Base):
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
"preferences": self.preferences,
"notification_settings": self.notification_settings
"notification_settings": self.notification_settings,
}
def has_permission(self, permission: str) -> bool:
"""Check if user has a specific permission"""
"""Check if user has a specific permission using role hierarchy"""
if self.is_superuser:
return True
# Check role-based permissions
role_permissions = {
"user": ["read_own", "create_own", "update_own"],
"admin": ["read_all", "create_all", "update_all", "delete_own"],
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
}
if self.role in role_permissions and permission in role_permissions[self.role]:
# Check custom permissions first (override)
if permission in self.custom_permissions.get("denied", []):
return False
if permission in self.custom_permissions.get("granted", []):
return True
# Check custom permissions
return permission in self.permissions
# Check role permissions if user has a role assigned
if self.role:
return self.role.has_permission(permission)
return False
def can_access_module(self, module_name: str) -> bool:
"""Check if user can access a specific module"""
if self.is_superuser:
return True
# Check module-specific permissions
module_permissions = self.permissions.get("modules", {})
return module_permissions.get(module_name, False)
# Check custom permissions first
module_permissions = self.custom_permissions.get("modules", {})
if module_name in module_permissions:
return module_permissions[module_name]
# Check role permissions
if self.role:
# For admin roles, allow all modules
if self.role.level in ["admin", "super_admin"]:
return True
# For regular users, check module access
elif self.role.level == "user":
return True # Basic users can access all modules
# For read-only users, limit access
elif self.role.level == "read_only":
return module_name in ["chatbot", "analytics"] # Only certain modules
return False
def update_last_login(self):
"""Update the last login timestamp"""
self.last_login = datetime.utcnow()
def update_preferences(self, preferences: dict):
"""Update user preferences"""
if self.preferences is None:
self.preferences = {}
self.preferences.update(preferences)
def update_notification_settings(self, settings: dict):
"""Update notification settings"""
if self.notification_settings is None:
self.notification_settings = {}
self.notification_settings.update(settings)
def get_effective_permissions(self) -> dict:
"""Get all effective permissions combining role and custom permissions"""
permissions = {"granted": set(), "denied": set()}
# Start with role permissions
if self.role:
role_perms = self.role.permissions
permissions["granted"].update(role_perms.get("granted", []))
permissions["denied"].update(role_perms.get("denied", []))
# Apply custom permissions (override role permissions)
permissions["granted"].update(self.custom_permissions.get("granted", []))
permissions["denied"].update(self.custom_permissions.get("denied", []))
# Remove any denied permissions from granted
permissions["granted"] -= permissions["denied"]
return {
"granted": list(permissions["granted"]),
"denied": list(permissions["denied"]),
}
def can_create_api_key(self) -> bool:
"""Check if user can create API keys based on role and limits"""
if not self.is_active or self.account_locked:
return False
# Check permission
if not self.has_permission("create_api_key"):
return False
# Check if user has reached their API key limit
current_keys = [key for key in self.api_keys if key.is_active]
max_keys = (
self.role.permissions.get("limits", {}).get("max_api_keys", 5)
if self.role
else 5
)
return len(current_keys) < max_keys
def can_create_tool(self) -> bool:
"""Check if user can create custom tools"""
return (
self.is_active
and not self.account_locked
and self.has_permission("create_tool")
)
def is_budget_exceeded(self) -> bool:
"""Check if user has exceeded their budget limits"""
if not self.budgets:
return False
active_budget = next((b for b in self.budgets if b.is_active), None)
if not active_budget:
return False
return active_budget.current_usage > active_budget.limit
def lock_account(self, duration_hours: int = 24):
"""Lock user account for specified duration"""
from datetime import timedelta
self.account_locked = True
self.account_locked_until = datetime.utcnow() + timedelta(hours=duration_hours)
def unlock_account(self):
"""Unlock user account"""
self.account_locked = False
self.account_locked_until = None
self.failed_login_attempts = 0
def record_failed_login(self):
"""Record a failed login attempt"""
self.failed_login_attempts += 1
self.last_failed_login = datetime.utcnow()
# Lock account after 5 failed attempts
if self.failed_login_attempts >= 5:
self.lock_account(24) # Lock for 24 hours
def reset_failed_logins(self):
"""Reset failed login counter"""
self.failed_login_attempts = 0
self.last_failed_login = None
@classmethod
def create_default_admin(cls, email: str, username: str, password_hash: str) -> "User":
def create_default_admin(
cls, email: str, username: str, password_hash: str
) -> "User":
"""Create a default admin user"""
return cls(
email=email,
@@ -136,24 +283,16 @@ class User(Base):
hashed_password=password_hash,
full_name="System Administrator",
is_active=True,
is_superuser=True,
is_superuser=True, # Legacy compatibility
is_verified=True,
role="super_admin",
permissions={
"modules": {
"cache": True,
"analytics": True,
"rag": True
}
},
preferences={
"theme": "dark",
"language": "en",
"timezone": "UTC"
# Note: role_id will be set after role is created in init_db
custom_permissions={
"modules": {"cache": True, "analytics": True, "rag": True}
},
preferences={"theme": "dark", "language": "en", "timezone": "UTC"},
notification_settings={
"email_notifications": True,
"security_alerts": True,
"system_updates": True
}
)
"system_updates": True,
},
)
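
A sketch of the override semantics: custom_permissions wins over the assigned role in both directions. In-memory objects are enough to exercise get_effective_permissions().

# Illustrative sketch only -- in-memory objects, no database involved.
from app.models.role import Role
from app.models.user import User

role = Role(
    name="user", display_name="User", level="user",
    permissions={"granted": ["read_own", "create_own"], "denied": []},
)
user = User(
    email="a@example.com", username="a", hashed_password="x",
    custom_permissions={"granted": ["view_reports"], "denied": ["create_own"]},
)
user.role = role  # normally wired up through users.role_id

effective = user.get_effective_permissions()
assert "view_reports" in effective["granted"]  # custom grant added
assert "create_own" in effective["denied"]     # custom deny beats role grant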

View File

@@ -15,7 +15,4 @@ __version__ = "1.0.0"
__author__ = "Enclava Team"
# Export main classes for easy importing
__all__ = [
"ChatbotModule",
"create_module"
]
__all__ = ["ChatbotModule", "create_module"]

File diff suppressed because it is too large

View File

@@ -10,7 +10,7 @@ from typing import Dict, Optional, Any
import logging
# Import all modules
from .rag.main import RAGModule
from .rag.main import RAGModule
from .chatbot.main import ChatbotModule, create_module as create_chatbot_module
from .workflow.main import WorkflowModule
@@ -19,11 +19,11 @@ from app.services.litellm_client import LiteLLMClient
# Import protocols for type safety
from .protocols import (
RAGServiceProtocol,
ChatbotServiceProtocol,
RAGServiceProtocol,
ChatbotServiceProtocol,
LiteLLMClientProtocol,
WorkflowServiceProtocol,
ServiceRegistry
ServiceRegistry,
)
logger = logging.getLogger(__name__)
@@ -31,113 +31,119 @@ logger = logging.getLogger(__name__)
class ModuleFactory:
"""Factory for creating and wiring module dependencies"""
def __init__(self):
self.modules: Dict[str, Any] = {}
self.initialized = False
async def create_all_modules(self, config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
async def create_all_modules(
self, config: Optional[Dict[str, Any]] = None
) -> ServiceRegistry:
"""
Create all modules with proper dependency injection
Args:
config: Optional configuration for modules
Returns:
Dictionary of created modules with their dependencies wired
"""
config = config or {}
logger.info("Creating modules with dependency injection...")
# Step 1: Create LiteLLM client (shared dependency)
litellm_client = LiteLLMClient()
# Step 2: Create RAG module (no dependencies on other modules)
rag_module = RAGModule(config=config.get("rag", {}))
# Step 3: Create chatbot module with RAG dependency
chatbot_module = create_chatbot_module(
litellm_client=litellm_client,
rag_service=rag_module # RAG module implements RAGServiceProtocol
rag_service=rag_module, # RAG module implements RAGServiceProtocol
)
# Step 4: Create workflow module with chatbot dependency
# Step 4: Create workflow module with chatbot dependency
workflow_module = WorkflowModule(
chatbot_service=chatbot_module # Chatbot module implements ChatbotServiceProtocol
)
# Store all modules
modules = {
"rag": rag_module,
"chatbot": chatbot_module,
"workflow": workflow_module
"workflow": workflow_module,
}
logger.info(f"Created {len(modules)} modules with dependencies wired")
# Initialize all modules
await self._initialize_modules(modules, config)
self.modules = modules
self.initialized = True
return modules
async def _initialize_modules(self, modules: Dict[str, Any], config: Dict[str, Any]):
async def _initialize_modules(
self, modules: Dict[str, Any], config: Dict[str, Any]
):
"""Initialize all modules in dependency order"""
# Initialize in dependency order (modules with no deps first)
initialization_order = [
("rag", modules["rag"]),
("chatbot", modules["chatbot"]), # Depends on RAG
("workflow", modules["workflow"]) # Depends on Chatbot
("workflow", modules["workflow"]), # Depends on Chatbot
]
for module_name, module in initialization_order:
try:
logger.info(f"Initializing {module_name} module...")
module_config = config.get(module_name, {})
# Different modules have different initialization patterns
if hasattr(module, 'initialize'):
if hasattr(module, "initialize"):
if module_name == "rag":
await module.initialize()
else:
await module.initialize(**module_config)
logger.info(f"{module_name} module initialized successfully")
except Exception as e:
logger.error(f"❌ Failed to initialize {module_name} module: {e}")
raise RuntimeError(f"Module initialization failed: {module_name}") from e
raise RuntimeError(
f"Module initialization failed: {module_name}"
) from e
async def cleanup_all_modules(self):
"""Cleanup all modules in reverse dependency order"""
if not self.initialized:
return
# Cleanup in reverse order
cleanup_order = ["workflow", "chatbot", "rag"]
for module_name in cleanup_order:
if module_name in self.modules:
try:
logger.info(f"Cleaning up {module_name} module...")
module = self.modules[module_name]
if hasattr(module, 'cleanup'):
if hasattr(module, "cleanup"):
await module.cleanup()
logger.info(f"{module_name} module cleaned up")
except Exception as e:
logger.error(f"❌ Error cleaning up {module_name}: {e}")
self.modules.clear()
self.initialized = False
def get_module(self, name: str) -> Optional[Any]:
"""Get a module by name"""
return self.modules.get(name)
def is_initialized(self) -> bool:
"""Check if factory is initialized"""
return self.initialized
@@ -174,13 +180,16 @@ def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
return RAGModule(config=config or {})
def create_chatbot_with_rag(rag_service: RAGServiceProtocol,
litellm_client: LiteLLMClientProtocol) -> ChatbotModule:
def create_chatbot_with_rag(
rag_service: RAGServiceProtocol, litellm_client: LiteLLMClientProtocol
) -> ChatbotModule:
"""Create chatbot module with RAG dependency"""
return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)
def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> WorkflowModule:
def create_workflow_with_chatbot(
chatbot_service: ChatbotServiceProtocol,
) -> WorkflowModule:
"""Create workflow module with chatbot dependency"""
return WorkflowModule(chatbot_service=chatbot_service)
@@ -188,38 +197,38 @@ def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> Wor
# Module registry for backward compatibility
class ModuleRegistry:
"""Registry that provides access to modules (for backward compatibility)"""
def __init__(self, factory: ModuleFactory):
self._factory = factory
@property
def modules(self) -> Dict[str, Any]:
"""Get all modules (compatible with existing module_manager interface)"""
return self._factory.modules
def get(self, name: str) -> Optional[Any]:
"""Get module by name"""
return self._factory.get_module(name)
def __getitem__(self, name: str) -> Any:
"""Support dictionary-style access"""
module = self.get(name)
if module is None:
raise KeyError(f"Module '{name}' not found")
return module
def keys(self):
"""Get module names"""
return self._factory.modules.keys()
def values(self):
"""Get module instances"""
"""Get module instances"""
return self._factory.modules.values()
def items(self):
"""Get module name-instance pairs"""
return self._factory.modules.items()
# Create registry instance for backward compatibility
module_registry = ModuleRegistry(module_factory)
module_registry = ModuleRegistry(module_factory)
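
At startup the factory is driven roughly like this; the module path and the module-level module_factory singleton (defined in a suppressed part of this diff) are assumptions.

# Illustrative sketch only -- module path and the `module_factory`
# singleton are assumed; its definition is in the suppressed diff above.
import asyncio
from app.modules.factory import module_factory, module_registry

async def lifecycle() -> None:
    await module_factory.create_all_modules(
        config={"rag": {}, "chatbot": {}, "workflow": {}}
    )
    chatbot = module_registry["chatbot"]  # dict-style access for old callers
    assert module_factory.is_initialized()
    await module_factory.cleanup_all_modules()

asyncio.run(lifecycle())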

View File

@@ -12,44 +12,48 @@ from abc import abstractmethod
class RAGServiceProtocol(Protocol):
"""Protocol for RAG (Retrieval-Augmented Generation) service interface"""
@abstractmethod
async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
async def search(
self, query: str, collection_name: str, top_k: int
) -> Dict[str, Any]:
"""
Search for relevant documents
Args:
query: Search query string
collection_name: Name of the collection to search in
top_k: Number of top results to return
Returns:
Dictionary containing search results with 'results' key
"""
...
@abstractmethod
async def index_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
async def index_document(
self, content: str, metadata: Dict[str, Any] = None
) -> str:
"""
Index a document in the vector database
Args:
content: Document content to index
metadata: Optional metadata for the document
Returns:
Document ID
"""
...
@abstractmethod
async def delete_document(self, document_id: str) -> bool:
"""
Delete a document from the vector database
Args:
document_id: ID of document to delete
Returns:
True if successfully deleted
"""
@@ -58,32 +62,32 @@ class RAGServiceProtocol(Protocol):
class ChatbotServiceProtocol(Protocol):
"""Protocol for Chatbot service interface"""
@abstractmethod
async def chat_completion(self, request: Any, user_id: str, db: Any) -> Any:
"""
Generate chat completion response
Args:
request: Chat request object
user_id: ID of the user making the request
db: Database session
Returns:
Chat response object
"""
...
@abstractmethod
async def create_chatbot(self, config: Any, user_id: str, db: Any) -> Any:
"""
Create a new chatbot instance
Args:
config: Chatbot configuration
user_id: ID of the user creating the chatbot
db: Database session
Returns:
Created chatbot instance
"""
@@ -92,35 +96,43 @@ class ChatbotServiceProtocol(Protocol):
class LiteLLMClientProtocol(Protocol):
"""Protocol for LiteLLM client interface"""
@abstractmethod
async def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
async def completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Any:
"""
Create a completion using the specified model
Args:
model: Model name to use
messages: List of messages for the conversation
**kwargs: Additional parameters for the completion
Returns:
Completion response object
"""
...
@abstractmethod
async def create_chat_completion(self, model: str, messages: List[Dict[str, str]],
user_id: str, api_key_id: str, **kwargs) -> Any:
async def create_chat_completion(
self,
model: str,
messages: List[Dict[str, str]],
user_id: str,
api_key_id: str,
**kwargs,
) -> Any:
"""
Create a chat completion with user tracking
Args:
model: Model name to use
messages: List of messages for the conversation
user_id: ID of the user making the request
api_key_id: API key identifier
**kwargs: Additional parameters
Returns:
Chat completion response
"""
@@ -129,44 +141,44 @@ class LiteLLMClientProtocol(Protocol):
class CacheServiceProtocol(Protocol):
"""Protocol for Cache service interface"""
@abstractmethod
async def get(self, key: str, default: Any = None) -> Any:
"""
Get value from cache
Args:
key: Cache key
default: Default value if key not found
Returns:
Cached value or default
"""
...
@abstractmethod
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
Set value in cache
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds
Returns:
True if successfully cached
"""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""
Delete key from cache
Args:
key: Cache key to delete
Returns:
True if successfully deleted
"""
@@ -175,28 +187,28 @@ class CacheServiceProtocol(Protocol):
class SecurityServiceProtocol(Protocol):
"""Protocol for Security service interface"""
@abstractmethod
async def analyze_request(self, request: Any) -> Any:
"""
Perform security analysis on a request
Args:
request: Request object to analyze
Returns:
Security analysis result
"""
...
@abstractmethod
async def validate_request(self, request: Any) -> bool:
"""
Validate request for security compliance
Args:
request: Request object to validate
Returns:
True if request is valid/safe
"""
@@ -205,29 +217,31 @@ class SecurityServiceProtocol(Protocol):
class WorkflowServiceProtocol(Protocol):
"""Protocol for Workflow service interface"""
@abstractmethod
async def execute_workflow(self, workflow: Any, input_data: Dict[str, Any] = None) -> Any:
async def execute_workflow(
self, workflow: Any, input_data: Dict[str, Any] = None
) -> Any:
"""
Execute a workflow definition
Args:
workflow: Workflow definition to execute
input_data: Optional input data for the workflow
Returns:
Workflow execution result
"""
...
@abstractmethod
async def get_execution(self, execution_id: str) -> Any:
"""
Get workflow execution status
Args:
execution_id: ID of the execution to retrieve
Returns:
Execution status object
"""
@@ -236,17 +250,17 @@ class WorkflowServiceProtocol(Protocol):
class ModuleServiceProtocol(Protocol):
"""Base protocol for all module services"""
@abstractmethod
async def initialize(self, **kwargs) -> None:
"""Initialize the module"""
...
@abstractmethod
async def cleanup(self) -> None:
"""Cleanup module resources"""
...
@abstractmethod
def get_required_permissions(self) -> List[Any]:
"""Get required permissions for this module"""
@@ -255,4 +269,4 @@ class ModuleServiceProtocol(Protocol):
# Type aliases for common service combinations
ServiceRegistry = Dict[str, ModuleServiceProtocol]
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]
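
Because these are typing.Protocol classes, implementations satisfy them structurally, with no inheritance required; a minimal in-memory cache as a sketch (TTL handling deliberately omitted):

# Illustrative sketch only -- satisfies CacheServiceProtocol structurally.
from typing import Any, Dict, Optional

class InMemoryCache:
    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    async def get(self, key: str, default: Any = None) -> Any:
        return self._store.get(key, default)

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        self._store[key] = value  # ttl ignored in this sketch
        return True

    async def delete(self, key: str) -> bool:
        return self._store.pop(key, None) is not None

cache = InMemoryCache()  # usable anywhere a CacheServiceProtocol is expected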

View File

@@ -3,4 +3,4 @@ RAG (Retrieval-Augmented Generation) module for Confidential Empire platform
"""
from .main import RAGModule
__all__ = ["RAGModule"]
__all__ = ["RAGModule"]

File diff suppressed because it is too large

View File

@@ -13,77 +13,102 @@ from pathlib import Path
class PluginRuntimeSpec(BaseModel):
"""Plugin runtime requirements and dependencies"""
python_version: str = Field("3.11", description="Required Python version")
dependencies: List[str] = Field(default_factory=list, description="Required Python packages")
environment_variables: Dict[str, str] = Field(default_factory=dict, description="Required environment variables")
@validator('python_version')
dependencies: List[str] = Field(
default_factory=list, description="Required Python packages"
)
environment_variables: Dict[str, str] = Field(
default_factory=dict, description="Required environment variables"
)
@validator("python_version")
def validate_python_version(cls, v):
if not v.startswith(('3.9', '3.10', '3.11', '3.12')):
raise ValueError('Python version must be 3.9, 3.10, 3.11, or 3.12')
if not v.startswith(("3.9", "3.10", "3.11", "3.12")):
raise ValueError("Python version must be 3.9, 3.10, 3.11, or 3.12")
return v
class PluginPermissions(BaseModel):
"""Plugin permission specifications"""
platform_apis: List[str] = Field(default_factory=list, description="Platform API access scopes")
plugin_scopes: List[str] = Field(default_factory=list, description="Plugin-specific permission scopes")
external_domains: List[str] = Field(default_factory=list, description="Allowed external domains")
@validator('platform_apis')
platform_apis: List[str] = Field(
default_factory=list, description="Platform API access scopes"
)
plugin_scopes: List[str] = Field(
default_factory=list, description="Plugin-specific permission scopes"
)
external_domains: List[str] = Field(
default_factory=list, description="Allowed external domains"
)
@validator("platform_apis")
def validate_platform_apis(cls, v):
allowed_apis = [
'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
'rag:query', 'rag:manage', 'rag:read',
'llm:completion', 'llm:embeddings', 'llm:models',
'workflow:execute', 'workflow:read',
'cache:read', 'cache:write'
"chatbot:invoke",
"chatbot:manage",
"chatbot:read",
"rag:query",
"rag:manage",
"rag:read",
"llm:completion",
"llm:embeddings",
"llm:models",
"workflow:execute",
"workflow:read",
"cache:read",
"cache:write",
]
for api in v:
if api not in allowed_apis and not api.endswith(':*'):
raise ValueError(f'Invalid platform API scope: {api}')
if api not in allowed_apis and not api.endswith(":*"):
raise ValueError(f"Invalid platform API scope: {api}")
return v
class PluginDatabaseSpec(BaseModel):
"""Plugin database configuration"""
schema: str = Field(..., description="Database schema name")
migrations_path: str = Field("./migrations", description="Path to migration files")
auto_migrate: bool = Field(True, description="Auto-run migrations on startup")
@validator('schema')
@validator("schema")
def validate_schema_name(cls, v):
if not v.startswith('plugin_'):
if not v.startswith("plugin_"):
raise ValueError('Database schema must start with "plugin_"')
if not v.replace('plugin_', '').replace('_', '').isalnum():
raise ValueError('Schema name must contain only alphanumeric characters and underscores')
if not v.replace("plugin_", "").replace("_", "").isalnum():
raise ValueError(
"Schema name must contain only alphanumeric characters and underscores"
)
return v
class PluginAPIEndpoint(BaseModel):
"""Plugin API endpoint specification"""
path: str = Field(..., description="API endpoint path")
methods: List[str] = Field(default=['GET'], description="Allowed HTTP methods")
methods: List[str] = Field(default=["GET"], description="Allowed HTTP methods")
description: str = Field("", description="Endpoint description")
auth_required: bool = Field(True, description="Whether authentication is required")
@validator('methods')
@validator("methods")
def validate_methods(cls, v):
allowed_methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS']
allowed_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
for method in v:
if method not in allowed_methods:
raise ValueError(f'Invalid HTTP method: {method}')
raise ValueError(f"Invalid HTTP method: {method}")
return v
@validator('path')
@validator("path")
def validate_path(cls, v):
if not v.startswith('/'):
if not v.startswith("/"):
raise ValueError('API path must start with "/"')
return v
class PluginCronJob(BaseModel):
"""Plugin scheduled job specification"""
name: str = Field(..., description="Job name")
schedule: str = Field(..., description="Cron expression")
function: str = Field(..., description="Function to execute")
@@ -91,41 +116,56 @@ class PluginCronJob(BaseModel):
enabled: bool = Field(True, description="Whether job is enabled by default")
timeout_seconds: int = Field(300, description="Job timeout in seconds")
max_retries: int = Field(3, description="Maximum retry attempts")
@validator('schedule')
@validator("schedule")
def validate_cron_expression(cls, v):
# Basic cron validation - should have 5 parts
parts = v.split()
if len(parts) != 5:
raise ValueError('Cron expression must have 5 parts (minute hour day month weekday)')
raise ValueError(
"Cron expression must have 5 parts (minute hour day month weekday)"
)
return v
class PluginUIConfig(BaseModel):
"""Plugin UI configuration"""
configuration_schema: str = Field("./config_schema.json", description="JSON schema for configuration")
configuration_schema: str = Field(
"./config_schema.json", description="JSON schema for configuration"
)
ui_components: str = Field("./ui/components", description="Path to UI components")
pages: List[Dict[str, str]] = Field(default_factory=list, description="Plugin pages")
@validator('pages')
pages: List[Dict[str, str]] = Field(
default_factory=list, description="Plugin pages"
)
@validator("pages")
def validate_pages(cls, v):
required_fields = ['name', 'path', 'component']
required_fields = ["name", "path", "component"]
for page in v:
for field in required_fields:
if field not in page:
raise ValueError(f'Page must have {field} field')
raise ValueError(f"Page must have {field} field")
return v
class PluginExternalServices(BaseModel):
"""Plugin external service configuration"""
allowed_domains: List[str] = Field(default_factory=list, description="Allowed external domains")
webhooks: List[Dict[str, str]] = Field(default_factory=list, description="Webhook configurations")
rate_limits: Dict[str, int] = Field(default_factory=dict, description="Rate limits per domain")
allowed_domains: List[str] = Field(
default_factory=list, description="Allowed external domains"
)
webhooks: List[Dict[str, str]] = Field(
default_factory=list, description="Webhook configurations"
)
rate_limits: Dict[str, int] = Field(
default_factory=dict, description="Rate limits per domain"
)
class PluginMetadata(BaseModel):
"""Plugin metadata information"""
name: str = Field(..., description="Plugin name (must be unique)")
version: str = Field(..., description="Plugin version (semantic versioning)")
description: str = Field(..., description="Plugin description")
@@ -133,58 +173,78 @@ class PluginMetadata(BaseModel):
license: str = Field("MIT", description="Plugin license")
homepage: Optional[HttpUrl] = Field(None, description="Plugin homepage URL")
repository: Optional[HttpUrl] = Field(None, description="Plugin repository URL")
tags: List[str] = Field(default_factory=list, description="Plugin tags for discovery")
@validator('name')
tags: List[str] = Field(
default_factory=list, description="Plugin tags for discovery"
)
@validator("name")
def validate_name(cls, v):
if not v.replace('-', '').replace('_', '').isalnum():
raise ValueError('Plugin name must contain only alphanumeric characters, hyphens, and underscores')
if not v.replace("-", "").replace("_", "").isalnum():
raise ValueError(
"Plugin name must contain only alphanumeric characters, hyphens, and underscores"
)
if len(v) < 3 or len(v) > 50:
raise ValueError('Plugin name must be between 3 and 50 characters')
raise ValueError("Plugin name must be between 3 and 50 characters")
return v.lower()
@validator('version')
@validator("version")
def validate_version(cls, v):
# Basic semantic versioning validation
parts = v.split('.')
parts = v.split(".")
if len(parts) != 3:
raise ValueError('Version must follow semantic versioning (x.y.z)')
raise ValueError("Version must follow semantic versioning (x.y.z)")
for part in parts:
if not part.isdigit():
raise ValueError('Version parts must be numeric')
raise ValueError("Version parts must be numeric")
return v
class PluginManifest(BaseModel):
"""Complete plugin manifest specification"""
apiVersion: str = Field("v1", description="Manifest API version")
kind: str = Field("Plugin", description="Resource kind")
metadata: PluginMetadata = Field(..., description="Plugin metadata")
spec: "PluginSpec" = Field(..., description="Plugin specification")
@validator('apiVersion')
@validator("apiVersion")
def validate_api_version(cls, v):
if v not in ['v1']:
raise ValueError('Unsupported API version')
if v not in ["v1"]:
raise ValueError("Unsupported API version")
return v
@validator('kind')
@validator("kind")
def validate_kind(cls, v):
if v != 'Plugin':
if v != "Plugin":
raise ValueError('Kind must be "Plugin"')
return v
class PluginSpec(BaseModel):
"""Plugin specification details"""
runtime: PluginRuntimeSpec = Field(default_factory=PluginRuntimeSpec, description="Runtime requirements")
permissions: PluginPermissions = Field(default_factory=PluginPermissions, description="Permission requirements")
database: Optional[PluginDatabaseSpec] = Field(None, description="Database configuration")
api_endpoints: List[PluginAPIEndpoint] = Field(default_factory=list, description="API endpoints")
cron_jobs: List[PluginCronJob] = Field(default_factory=list, description="Scheduled jobs")
runtime: PluginRuntimeSpec = Field(
default_factory=PluginRuntimeSpec, description="Runtime requirements"
)
permissions: PluginPermissions = Field(
default_factory=PluginPermissions, description="Permission requirements"
)
database: Optional[PluginDatabaseSpec] = Field(
None, description="Database configuration"
)
api_endpoints: List[PluginAPIEndpoint] = Field(
default_factory=list, description="API endpoints"
)
cron_jobs: List[PluginCronJob] = Field(
default_factory=list, description="Scheduled jobs"
)
ui_config: Optional[PluginUIConfig] = Field(None, description="UI configuration")
external_services: Optional[PluginExternalServices] = Field(None, description="External service configuration")
config_schema: Dict[str, Any] = Field(default_factory=dict, description="Plugin configuration JSON schema")
external_services: Optional[PluginExternalServices] = Field(
None, description="External service configuration"
)
config_schema: Dict[str, Any] = Field(
default_factory=dict, description="Plugin configuration JSON schema"
)
# Update forward reference
@@ -193,111 +253,108 @@ PluginManifest.model_rebuild()
class PluginManifestValidator:
"""Plugin manifest validation and parsing utilities"""
REQUIRED_FILES = [
'manifest.yaml',
'main.py',
'requirements.txt'
]
REQUIRED_FILES = ["manifest.yaml", "main.py", "requirements.txt"]
OPTIONAL_FILES = [
'config_schema.json',
'README.md',
'ui/components',
'migrations',
'tests'
"config_schema.json",
"README.md",
"ui/components",
"migrations",
"tests",
]
@classmethod
def load_from_file(cls, manifest_path: Union[str, Path]) -> PluginManifest:
"""Load and validate plugin manifest from YAML file"""
manifest_path = Path(manifest_path)
if not manifest_path.exists():
raise FileNotFoundError(f"Manifest file not found: {manifest_path}")
try:
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest_data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in manifest file: {e}")
try:
manifest = PluginManifest(**manifest_data)
except Exception as e:
raise ValueError(f"Invalid manifest structure: {e}")
# Additional validation
cls._validate_plugin_structure(manifest_path.parent, manifest)
return manifest
@classmethod
def _validate_plugin_structure(cls, plugin_dir: Path, manifest: PluginManifest):
"""Validate plugin directory structure and required files"""
# Check required files
for required_file in cls.REQUIRED_FILES:
file_path = plugin_dir / required_file
if not file_path.exists():
raise FileNotFoundError(f"Required file missing: {required_file}")
# Validate main.py contains plugin class
main_py_path = plugin_dir / 'main.py'
with open(main_py_path, 'r', encoding='utf-8') as f:
main_py_path = plugin_dir / "main.py"
with open(main_py_path, "r", encoding="utf-8") as f:
main_content = f.read()
if 'BasePlugin' not in main_content:
if "BasePlugin" not in main_content:
raise ValueError("main.py must contain a class inheriting from BasePlugin")
# Validate requirements.txt format
requirements_path = plugin_dir / 'requirements.txt'
with open(requirements_path, 'r', encoding='utf-8') as f:
requirements_path = plugin_dir / "requirements.txt"
with open(requirements_path, "r", encoding="utf-8") as f:
requirements = f.read().strip()
if requirements and not all(line.strip() for line in requirements.split('\n')):
if requirements and not all(line.strip() for line in requirements.split("\n")):
raise ValueError("Invalid requirements.txt format")
# Validate config schema if specified
if manifest.spec.ui_config and manifest.spec.ui_config.configuration_schema:
schema_path = plugin_dir / manifest.spec.ui_config.configuration_schema
if schema_path.exists():
try:
import json
with open(schema_path, 'r', encoding='utf-8') as f:
with open(schema_path, "r", encoding="utf-8") as f:
json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON schema: {e}")
# Validate migrations if database is specified
if manifest.spec.database:
migrations_path = plugin_dir / manifest.spec.database.migrations_path
if migrations_path.exists() and not migrations_path.is_dir():
raise ValueError("Migrations path must be a directory")
@classmethod
def validate_plugin_compatibility(cls, manifest: PluginManifest) -> Dict[str, Any]:
"""Validate plugin compatibility with platform"""
compatibility_report = {
"compatible": True,
"warnings": [],
"errors": [],
"platform_version": "1.0.0"
"platform_version": "1.0.0",
}
# Check platform API compatibility
unsupported_apis = []
for api in manifest.spec.permissions.platform_apis:
if not cls._is_platform_api_supported(api):
unsupported_apis.append(api)
if unsupported_apis:
compatibility_report["errors"].append(
f"Unsupported platform APIs: {', '.join(unsupported_apis)}"
)
compatibility_report["compatible"] = False
# Check Python version compatibility
required_version = manifest.spec.runtime.python_version
if not cls._is_python_version_supported(required_version):
@@ -305,63 +362,82 @@ class PluginManifestValidator:
f"Unsupported Python version: {required_version}"
)
compatibility_report["compatible"] = False
# Check dependency compatibility
for dependency in manifest.spec.runtime.dependencies:
if cls._is_dependency_conflicting(dependency):
compatibility_report["warnings"].append(
f"Potential dependency conflict: {dependency}"
)
return compatibility_report
@classmethod
def _is_platform_api_supported(cls, api: str) -> bool:
"""Check if platform API is supported"""
supported_apis = [
'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
'rag:query', 'rag:manage', 'rag:read',
'llm:completion', 'llm:embeddings', 'llm:models',
'workflow:execute', 'workflow:read',
'cache:read', 'cache:write'
"chatbot:invoke",
"chatbot:manage",
"chatbot:read",
"rag:query",
"rag:manage",
"rag:read",
"llm:completion",
"llm:embeddings",
"llm:models",
"workflow:execute",
"workflow:read",
"cache:read",
"cache:write",
]
# Support wildcard permissions
if api.endswith(':*'):
if api.endswith(":*"):
base_api = api[:-2]
return any(supported.startswith(base_api + ':') for supported in supported_apis)
return any(
supported.startswith(base_api + ":") for supported in supported_apis
)
return api in supported_apis
@classmethod
def _is_python_version_supported(cls, version: str) -> bool:
"""Check if Python version is supported"""
supported_versions = ['3.9', '3.10', '3.11', '3.12']
supported_versions = ["3.9", "3.10", "3.11", "3.12"]
return any(version.startswith(v) for v in supported_versions)
@classmethod
def _is_dependency_conflicting(cls, dependency: str) -> bool:
"""Check if dependency might conflict with platform"""
# Extract package name (before ==, >=, etc.)
package_name = dependency.split('==')[0].split('>=')[0].split('<=')[0].split('>')[0].split('<')[0].strip()
package_name = (
dependency.split("==")[0]
.split(">=")[0]
.split("<=")[0]
.split(">")[0]
.split("<")[0]
.strip()
)
# Known conflicting packages
conflicting_packages = [
'sqlalchemy', # Platform uses specific version
'fastapi', # Platform uses specific version
'pydantic', # Platform uses specific version
'alembic' # Platform migration system
"sqlalchemy", # Platform uses specific version
"fastapi", # Platform uses specific version
"pydantic", # Platform uses specific version
"alembic", # Platform migration system
]
return package_name.lower() in conflicting_packages
@classmethod
def generate_manifest_hash(cls, manifest: PluginManifest) -> str:
"""Generate hash for manifest content verification"""
manifest_dict = manifest.dict()
manifest_str = yaml.dump(manifest_dict, sort_keys=True, default_flow_style=False)
return hashlib.sha256(manifest_str.encode('utf-8')).hexdigest()
manifest_str = yaml.dump(
manifest_dict, sort_keys=True, default_flow_style=False
)
return hashlib.sha256(manifest_str.encode("utf-8")).hexdigest()
@classmethod
def create_example_manifest(cls, plugin_name: str) -> PluginManifest:
"""Create an example plugin manifest for development"""
@@ -372,29 +448,25 @@ class PluginManifestValidator:
description=f"Example {plugin_name} plugin for Enclava platform",
author="Enclava Team",
license="MIT",
tags=["integration", "example"]
tags=["integration", "example"],
),
spec=PluginSpec(
runtime=PluginRuntimeSpec(
python_version="3.11",
dependencies=[
"aiohttp>=3.8.0",
"pydantic>=2.0.0"
]
dependencies=["aiohttp>=3.8.0", "pydantic>=2.0.0"],
),
permissions=PluginPermissions(
platform_apis=["chatbot:invoke", "rag:query"],
plugin_scopes=["read", "write"]
plugin_scopes=["read", "write"],
),
database=PluginDatabaseSpec(
schema=f"plugin_{plugin_name}",
migrations_path="./migrations"
schema=f"plugin_{plugin_name}", migrations_path="./migrations"
),
api_endpoints=[
PluginAPIEndpoint(
path="/status",
methods=["GET"],
description="Plugin health status"
description="Plugin health status",
)
],
ui_config=PluginUIConfig(
@@ -403,11 +475,11 @@ class PluginManifestValidator:
{
"name": "dashboard",
"path": f"/plugins/{plugin_name}",
"component": f"{plugin_name.title()}Dashboard"
"component": f"{plugin_name.title()}Dashboard",
}
]
)
)
],
),
),
)
@@ -417,20 +489,20 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
manifest = PluginManifestValidator.load_from_file(manifest_path)
compatibility = PluginManifestValidator.validate_plugin_compatibility(manifest)
manifest_hash = PluginManifestValidator.generate_manifest_hash(manifest)
return {
"valid": True,
"manifest": manifest,
"compatibility": compatibility,
"hash": manifest_hash,
"errors": []
"errors": [],
}
except Exception as e:
return {
"valid": False,
"manifest": None,
"compatibility": None,
"hash": None,
"errors": [str(e)]
}
"errors": [str(e)],
}
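# Hypothetical caller (the manifest path is an assumption): the helper reports
# problems through the returned dict instead of raising, so callers branch on
# "valid" rather than wrapping the call in try/except.
#
# result = validate_manifest_file("plugins/example/manifest.yaml")
# if result["valid"]:
#     print("manifest ok:", result["hash"])
# else:
#     print("manifest errors:", result["errors"])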

367
backend/app/schemas/role.py Normal file
View File

@@ -0,0 +1,367 @@
"""
Role Management Schemas
Pydantic models for role management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, validator
class RoleBase(BaseModel):
"""Base role schema"""
name: str
display_name: str
description: Optional[str] = None
level: str = "user"
permissions: Dict[str, Any] = {}
can_manage_users: bool = False
can_manage_budgets: bool = False
can_view_reports: bool = False
can_manage_tools: bool = False
inherits_from: List[str] = []
is_active: bool = True
@validator("name")
def validate_name(cls, v):
if len(v) < 2:
raise ValueError("Role name must be at least 2 characters long")
if len(v) > 50:
raise ValueError("Role name must be less than 50 characters long")
if not v.replace("_", "").isalnum():
raise ValueError(
"Role name must contain only alphanumeric characters and underscores"
)
return v.lower()
@validator("display_name")
def validate_display_name(cls, v):
if len(v) < 2:
raise ValueError("Display name must be at least 2 characters long")
if len(v) > 100:
raise ValueError("Display name must be less than 100 characters long")
return v
@validator("level")
def validate_level(cls, v):
valid_levels = ["read_only", "user", "admin", "super_admin"]
if v not in valid_levels:
raise ValueError(f'Level must be one of: {", ".join(valid_levels)}')
return v
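# Worked examples (values invented) of how the validators above behave:
#
# RoleBase(name="Support_Team", display_name="Support Team").name
#     -> "support_team"   (validate_name lowercases the stored value)
# RoleBase(name="x", display_name="Support Team")
#     -> ValueError: Role name must be at least 2 characters long
# RoleBase(name="support", display_name="Support Team", level="boss")
#     -> ValueError: Level must be one of: read_only, user, admin, super_admin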
class RoleCreate(RoleBase):
"""Schema for creating a role"""
is_system_role: bool = False
class RoleUpdate(BaseModel):
"""Schema for updating a role"""
display_name: Optional[str] = None
description: Optional[str] = None
permissions: Optional[Dict[str, Any]] = None
can_manage_users: Optional[bool] = None
can_manage_budgets: Optional[bool] = None
can_view_reports: Optional[bool] = None
can_manage_tools: Optional[bool] = None
is_active: Optional[bool] = None
@validator("display_name")
def validate_display_name(cls, v):
if v is not None:
if len(v) < 2:
raise ValueError("Display name must be at least 2 characters long")
if len(v) > 100:
raise ValueError("Display name must be less than 100 characters long")
return v
class RoleResponse(BaseModel):
"""Role response schema"""
id: int
name: str
display_name: str
description: Optional[str]
level: str
permissions: Dict[str, Any]
can_manage_users: bool
can_manage_budgets: bool
can_view_reports: bool
can_manage_tools: bool
inherits_from: List[str]
is_active: bool
is_system_role: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
user_count: Optional[int] = 0 # Number of users with this role
class Config:
from_attributes = True
@classmethod
def from_orm(cls, obj):
"""Create response from ORM object"""
data = obj.to_dict()
# Add user count if available
if hasattr(obj, "users"):
data["user_count"] = len([u for u in obj.users if u.is_active])
return cls(**data)
class RoleListResponse(BaseModel):
"""Role list response schema"""
roles: List[RoleResponse]
total: int
skip: int
limit: int
class RoleAssignmentRequest(BaseModel):
"""Schema for role assignment"""
role_id: int
@validator("role_id")
def validate_role_id(cls, v):
if v <= 0:
raise ValueError("Role ID must be a positive integer")
return v
class RoleBulkAction(BaseModel):
"""Schema for bulk role actions"""
role_ids: List[int]
action: str # activate, deactivate, delete
action_data: Optional[Dict[str, Any]] = None
@validator("action")
def validate_action(cls, v):
valid_actions = ["activate", "deactivate", "delete"]
if v not in valid_actions:
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
return v
@validator("role_ids")
def validate_role_ids(cls, v):
if not v:
raise ValueError("At least one role ID must be provided")
if len(v) > 50:
raise ValueError("Cannot perform bulk action on more than 50 roles at once")
return v
class RoleStatistics(BaseModel):
"""Role statistics schema"""
total_roles: int
active_roles: int
system_roles: int
roles_by_level: Dict[str, int]
roles_with_users: int
unused_roles: int
class RolePermission(BaseModel):
"""Individual role permission schema"""
permission: str
granted: bool = True
description: Optional[str] = None
class RolePermissionTemplate(BaseModel):
"""Role permission template schema"""
name: str
display_name: str
description: str
level: str
permissions: List[RolePermission]
can_manage_users: bool = False
can_manage_budgets: bool = False
can_view_reports: bool = False
can_manage_tools: bool = False
class RoleHierarchy(BaseModel):
"""Role hierarchy schema"""
role: RoleResponse
parent_roles: List[RoleResponse] = []
child_roles: List[RoleResponse] = []
effective_permissions: Dict[str, Any]
class RoleComparison(BaseModel):
"""Role comparison schema"""
role1: RoleResponse
role2: RoleResponse
common_permissions: List[str]
unique_to_role1: List[str]
unique_to_role2: List[str]
class RoleUsage(BaseModel):
"""Role usage statistics schema"""
role_id: int
role_name: str
user_count: int
active_user_count: int
last_assigned: Optional[datetime]
usage_trend: Dict[str, int] # Monthly usage data
class RoleSearchFilter(BaseModel):
"""Role search filter schema"""
search: Optional[str] = None
level: Optional[str] = None
is_active: Optional[bool] = None
is_system_role: Optional[bool] = None
has_users: Optional[bool] = None
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
# Predefined permission templates
ROLE_TEMPLATES = {
"read_only": RolePermissionTemplate(
name="read_only",
display_name="Read Only",
description="Can view own data only",
level="read_only",
permissions=[
RolePermission(
permission="read_own", granted=True, description="Read own data"
),
RolePermission(
permission="create", granted=False, description="Create new resources"
),
RolePermission(
permission="update",
granted=False,
description="Update existing resources",
),
RolePermission(
permission="delete", granted=False, description="Delete resources"
),
],
),
"user": RolePermissionTemplate(
name="user",
display_name="User",
description="Standard user with full access to own resources",
level="user",
permissions=[
RolePermission(
permission="read_own", granted=True, description="Read own data"
),
RolePermission(
permission="create_own",
granted=True,
description="Create own resources",
),
RolePermission(
permission="update_own",
granted=True,
description="Update own resources",
),
RolePermission(
permission="delete_own",
granted=True,
description="Delete own resources",
),
RolePermission(
permission="manage_users",
granted=False,
description="Manage other users",
),
RolePermission(
permission="manage_all",
granted=False,
description="Manage all resources",
),
],
inherits_from=["read_only"],
),
"admin": RolePermissionTemplate(
name="admin",
display_name="Administrator",
description="Can manage users and view reports",
level="admin",
permissions=[
RolePermission(
permission="read_all", granted=True, description="Read all data"
),
RolePermission(
permission="create_all",
granted=True,
description="Create any resources",
),
RolePermission(
permission="update_all",
granted=True,
description="Update any resources",
),
RolePermission(
permission="manage_users", granted=True, description="Manage users"
),
RolePermission(
permission="view_reports",
granted=True,
description="View system reports",
),
RolePermission(
permission="system_settings",
granted=False,
description="Modify system settings",
),
],
can_manage_users=True,
can_view_reports=True,
inherits_from=["user"],
),
"super_admin": RolePermissionTemplate(
name="super_admin",
display_name="Super Administrator",
description="Full system access",
level="super_admin",
permissions=[
RolePermission(permission="*", granted=True, description="All permissions")
],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
can_manage_tools=True,
inherits_from=["admin"],
),
}
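# Minimal resolution sketch (helper is an assumption, not part of this commit):
# walking inherits_from depth-first so that child templates override the
# grants they inherit from their parents.
def effective_template_permissions(template_name: str) -> Dict[str, bool]:
    tpl = ROLE_TEMPLATES[template_name]
    merged: Dict[str, bool] = {}
    for parent in tpl.inherits_from:
        merged.update(effective_template_permissions(parent))
    for perm in tpl.permissions:
        merged[perm.permission] = perm.granted
    return merged

# effective_template_permissions("admin")["read_own"] is True (inherited via
# user -> read_only), while "system_settings" stays False because the admin
# template grants it False directly.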
# Common permission definitions
COMMON_PERMISSIONS = {
"read_own": "Read own data and resources",
"read_all": "Read all data and resources",
"create_own": "Create own resources",
"create_all": "Create any resources",
"update_own": "Update own resources",
"update_all": "Update any resources",
"delete_own": "Delete own resources",
"delete_all": "Delete any resources",
"manage_users": "Manage user accounts",
"manage_roles": "Manage role assignments",
"manage_budgets": "Manage budget settings",
"view_reports": "View system reports",
"manage_tools": "Manage custom tools",
"system_settings": "Modify system settings",
"create_api_key": "Create API keys",
"create_tool": "Create custom tools",
"export_data": "Export system data",
}

219
backend/app/schemas/tool.py Normal file
View File

@@ -0,0 +1,219 @@
"""
Tool schemas for API requests and responses
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, validator
# Tool Creation and Update Schemas
class ToolCreate(BaseModel):
"""Schema for creating a new tool"""
name: str = Field(..., min_length=1, max_length=100, description="Unique tool name")
display_name: str = Field(
..., min_length=1, max_length=200, description="Display name for the tool"
)
description: Optional[str] = Field(None, description="Tool description")
tool_type: str = Field(..., description="Tool type (python, bash, docker, api)")
code: str = Field(..., min_length=1, description="Tool implementation code")
parameters_schema: Optional[Dict[str, Any]] = Field(
None, description="JSON schema for parameters"
)
return_schema: Optional[Dict[str, Any]] = Field(
None, description="Expected return format schema"
)
timeout_seconds: Optional[int] = Field(
30, ge=1, le=300, description="Execution timeout in seconds"
)
max_memory_mb: Optional[int] = Field(
256, ge=1, le=1024, description="Maximum memory in MB"
)
max_cpu_seconds: Optional[float] = Field(
10.0, ge=0.1, le=60.0, description="Maximum CPU time in seconds"
)
docker_image: Optional[str] = Field(
None, max_length=200, description="Docker image for execution"
)
docker_command: Optional[str] = Field(None, description="Docker command to run")
category: Optional[str] = Field(None, max_length=50, description="Tool category")
tags: Optional[List[str]] = Field(None, description="Tool tags")
is_public: Optional[bool] = Field(False, description="Whether tool is public")
@validator("tool_type")
def validate_tool_type(cls, v):
valid_types = ["python", "bash", "docker", "api", "custom"]
if v not in valid_types:
raise ValueError(f"Tool type must be one of: {valid_types}")
return v
@validator("tags")
def validate_tags(cls, v):
if v is not None and len(v) > 10:
raise ValueError("Maximum 10 tags allowed")
return v
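# Hypothetical payload (names and code invented) that satisfies the
# constraints above: an allowed tool_type, a timeout inside the 1-300 s bound,
# and at most 10 tags.
EXAMPLE_TOOL_PAYLOAD = {
    "name": "word_count",
    "display_name": "Word Count",
    "tool_type": "python",
    "code": "def run(text):\n    return len(text.split())",
    "timeout_seconds": 10,
    "max_memory_mb": 128,
    "tags": ["text", "utility"],
}
# ToolCreate(**EXAMPLE_TOOL_PAYLOAD) validates; tool_type="ruby" would raise
# a ValueError from validate_tool_type.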
class ToolUpdate(BaseModel):
"""Schema for updating a tool"""
display_name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
code: Optional[str] = Field(None, min_length=1)
parameters_schema: Optional[Dict[str, Any]] = None
return_schema: Optional[Dict[str, Any]] = None
timeout_seconds: Optional[int] = Field(None, ge=1, le=300)
max_memory_mb: Optional[int] = Field(None, ge=1, le=1024)
max_cpu_seconds: Optional[float] = Field(None, ge=0.1, le=60.0)
docker_image: Optional[str] = Field(None, max_length=200)
docker_command: Optional[str] = None
category: Optional[str] = Field(None, max_length=50)
tags: Optional[List[str]] = None
is_public: Optional[bool] = None
is_active: Optional[bool] = None
@validator("tags")
def validate_tags(cls, v):
if v is not None and len(v) > 10:
raise ValueError("Maximum 10 tags allowed")
return v
# Tool Response Schemas
class ToolResponse(BaseModel):
"""Schema for tool response"""
id: int
name: str
display_name: str
description: Optional[str]
tool_type: str
parameters_schema: Dict[str, Any]
return_schema: Dict[str, Any]
timeout_seconds: int
max_memory_mb: int
max_cpu_seconds: float
docker_image: Optional[str]
is_public: bool
is_approved: bool
created_by_user_id: int
category: Optional[str]
tags: List[str]
usage_count: int
last_used_at: Optional[datetime]
is_active: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
class ToolListResponse(BaseModel):
"""Schema for tool list response"""
tools: List[ToolResponse]
total: int
skip: int
limit: int
# Tool Execution Schemas
class ToolExecutionCreate(BaseModel):
"""Schema for creating a tool execution"""
parameters: Dict[str, Any] = Field(..., description="Parameters for tool execution")
timeout_override: Optional[int] = Field(
None, ge=1, le=300, description="Override default timeout"
)
class ToolExecutionResponse(BaseModel):
"""Schema for tool execution response"""
id: int
tool_id: int
tool_name: Optional[str]
executed_by_user_id: int
parameters: Dict[str, Any]
status: str
output: Optional[str]
error_message: Optional[str]
return_code: Optional[int]
execution_time_ms: Optional[int]
memory_used_mb: Optional[float]
cpu_time_ms: Optional[int]
container_id: Optional[str]
started_at: Optional[datetime]
completed_at: Optional[datetime]
created_at: Optional[datetime]
class Config:
from_attributes = True
class ToolExecutionListResponse(BaseModel):
"""Schema for tool execution list response"""
executions: List[ToolExecutionResponse]
total: int
skip: int
limit: int
# Tool Category Schemas
class ToolCategoryCreate(BaseModel):
"""Schema for creating a tool category"""
name: str = Field(
..., min_length=1, max_length=50, description="Unique category name"
)
display_name: str = Field(
..., min_length=1, max_length=100, description="Display name"
)
description: Optional[str] = Field(None, description="Category description")
icon: Optional[str] = Field(None, max_length=50, description="Icon name")
color: Optional[str] = Field(None, max_length=20, description="Color code")
sort_order: Optional[int] = Field(0, description="Sort order")
class ToolCategoryResponse(BaseModel):
"""Schema for tool category response"""
id: int
name: str
display_name: str
description: Optional[str]
icon: Optional[str]
color: Optional[str]
sort_order: int
is_active: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
# Statistics Schema
class ToolStatisticsResponse(BaseModel):
"""Schema for tool statistics response"""
total_tools: int
public_tools: int
tools_by_type: Dict[str, int]
total_executions: int
executions_by_status: Dict[str, int]
recent_executions: int
top_tools: List[Dict[str, Any]]
user_tools: Optional[int] = None
user_executions: Optional[int] = None

View File

@@ -0,0 +1,68 @@
"""
Tool calling schemas for API requests and responses
"""
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class ToolExecutionRequest(BaseModel):
"""Schema for executing a tool by name"""
tool_name: str = Field(..., description="Name of the tool to execute")
parameters: Dict[str, Any] = Field(
default_factory=dict, description="Parameters for tool execution"
)
class ToolCallResponse(BaseModel):
"""Schema for tool call response"""
success: bool = Field(..., description="Whether the tool call was successful")
result: Optional[Dict[str, Any]] = Field(None, description="Tool execution result")
error: Optional[str] = Field(None, description="Error message if failed")
class ToolValidationRequest(BaseModel):
"""Schema for validating tool availability"""
tool_names: List[str] = Field(..., description="List of tool names to validate")
class ToolValidationResponse(BaseModel):
"""Schema for tool validation response"""
tool_availability: Dict[str, bool] = Field(
..., description="Tool name to availability mapping"
)
class ToolHistoryItem(BaseModel):
"""Schema for tool execution history item"""
id: int = Field(..., description="Execution ID")
tool_name: str = Field(..., description="Tool name")
parameters: Dict[str, Any] = Field(..., description="Execution parameters")
status: str = Field(..., description="Execution status")
output: Optional[str] = Field(None, description="Tool output")
error_message: Optional[str] = Field(None, description="Error message if failed")
execution_time_ms: Optional[int] = Field(
None, description="Execution time in milliseconds"
)
created_at: Optional[str] = Field(None, description="Creation timestamp")
completed_at: Optional[str] = Field(None, description="Completion timestamp")
class ToolHistoryResponse(BaseModel):
"""Schema for tool execution history response"""
history: List[ToolHistoryItem] = Field(..., description="Tool execution history")
total: int = Field(..., description="Total number of history items")
class ToolCallRequest(BaseModel):
"""Schema for tool call request (placeholder for future use)"""
message: str = Field(..., description="Chat message")
tools: Optional[List[str]] = Field(
None, description="Specific tools to make available"
)

260
backend/app/schemas/user.py Normal file
View File

@@ -0,0 +1,260 @@
"""
User Management Schemas
Pydantic models for user management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, EmailStr, validator
class UserBase(BaseModel):
"""Base user schema"""
email: EmailStr
username: str
full_name: Optional[str] = None
is_active: bool = True
is_verified: bool = False
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if len(v) > 50:
raise ValueError("Username must be less than 50 characters long")
return v
@validator("email")
def validate_email(cls, v):
if len(v) > 255:
raise ValueError("Email must be less than 255 characters long")
return v
class UserCreate(UserBase):
"""Schema for creating a user"""
password: str
role_id: Optional[int] = None
custom_permissions: Dict[str, Any] = {}
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
class UserUpdate(BaseModel):
"""Schema for updating a user"""
email: Optional[EmailStr] = None
username: Optional[str] = None
full_name: Optional[str] = None
role_id: Optional[int] = None
custom_permissions: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
@validator("username")
def validate_username(cls, v):
if v is not None:
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if len(v) > 50:
raise ValueError("Username must be less than 50 characters long")
return v
class PasswordChange(BaseModel):
"""Schema for changing password"""
current_password: Optional[str] = None
new_password: str
confirm_password: str
@validator("new_password")
def validate_new_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
@validator("confirm_password")
def passwords_match(cls, v, values):
if "new_password" in values and v != values["new_password"]:
raise ValueError("Passwords do not match")
return v
class RoleInfo(BaseModel):
"""Role information schema"""
id: int
name: str
display_name: str
level: str
permissions: Dict[str, Any]
class Config:
from_attributes = True
class UserResponse(BaseModel):
"""User response schema"""
id: int
email: str
username: str
full_name: Optional[str]
is_active: bool
is_verified: bool
is_superuser: bool
role_id: Optional[int]
role: Optional[RoleInfo]
custom_permissions: Dict[str, Any]
account_locked: Optional[bool] = False
account_locked_until: Optional[datetime]
failed_login_attempts: Optional[int] = 0
last_failed_login: Optional[datetime]
avatar_url: Optional[str]
bio: Optional[str]
company: Optional[str]
website: Optional[str]
created_at: Optional[datetime]
updated_at: Optional[datetime]
last_login: Optional[datetime]
preferences: Dict[str, Any]
notification_settings: Dict[str, Any]
class Config:
from_attributes = True
@classmethod
def from_orm(cls, obj):
"""Create response from ORM object with proper role handling"""
data = obj.to_dict()
if obj.role:
data["role"] = RoleInfo.from_orm(obj.role)
return cls(**data)
class UserListResponse(BaseModel):
"""User list response schema"""
users: List[UserResponse]
total: int
skip: int
limit: int
class AccountLockResponse(BaseModel):
"""Account lock response schema"""
user_id: int
is_locked: bool
locked_until: Optional[datetime]
message: str
class UserProfileUpdate(BaseModel):
"""Schema for user profile updates (limited fields)"""
full_name: Optional[str] = None
avatar_url: Optional[str] = None
bio: Optional[str] = None
company: Optional[str] = None
website: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
notification_settings: Optional[Dict[str, Any]] = None
class UserPreferences(BaseModel):
"""User preferences schema"""
theme: Optional[str] = "light"
language: Optional[str] = "en"
timezone: Optional[str] = "UTC"
email_notifications: Optional[bool] = True
security_alerts: Optional[bool] = True
system_updates: Optional[bool] = True
class UserSearchFilter(BaseModel):
"""User search filter schema"""
search: Optional[str] = None
role_id: Optional[int] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
last_login_after: Optional[datetime] = None
last_login_before: Optional[datetime] = None
class UserBulkAction(BaseModel):
"""Schema for bulk user actions"""
user_ids: List[int]
action: str # activate, deactivate, lock, unlock, assign_role, remove_role
action_data: Optional[Dict[str, Any]] = None
@validator("action")
def validate_action(cls, v):
valid_actions = [
"activate",
"deactivate",
"lock",
"unlock",
"assign_role",
"remove_role",
]
if v not in valid_actions:
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
return v
@validator("user_ids")
def validate_user_ids(cls, v):
if not v:
raise ValueError("At least one user ID must be provided")
if len(v) > 100:
raise ValueError(
"Cannot perform bulk action on more than 100 users at once"
)
return v
class UserStatistics(BaseModel):
"""User statistics schema"""
total_users: int
active_users: int
verified_users: int
locked_users: int
users_by_role: Dict[str, int]
recent_registrations: int
registrations_by_month: Dict[str, int]
class UserActivity(BaseModel):
"""User activity schema"""
user_id: int
action: str
resource_type: Optional[str] = None
resource_id: Optional[int] = None
details: Optional[Dict[str, Any]] = None
timestamp: datetime
ip_address: Optional[str] = None
user_agent: Optional[str] = None
class UserActivityFilter(BaseModel):
"""User activity filter schema"""
user_id: Optional[int] = None
action: Optional[str] = None
resource_type: Optional[str] = None
date_from: Optional[datetime] = None
date_to: Optional[datetime] = None
ip_address: Optional[str] = None

View File

@@ -1,3 +1,3 @@
"""
Services package
"""
"""

File diff suppressed because it is too large

View File

@@ -23,40 +23,46 @@ logger = logging.getLogger(__name__)
class APIKeyAuthService:
"""Service for API key authentication and validation"""
def __init__(self, db: AsyncSession):
self.db = db
async def validate_api_key(
self, api_key: str, request: Request
) -> Optional[Dict[str, Any]]:
"""Validate API key and return user context using Redis cache for performance"""
try:
if not api_key:
return None
# Extract key prefix for lookup
if len(api_key) < 8:
logger.warning(f"Invalid API key format: too short")
return None
key_prefix = api_key[:8]
# Try cached verification first
cached_verification = await cached_api_key_service.verify_api_key_cached(
api_key, key_prefix
)
# Get API key data from cache or database
context = await cached_api_key_service.get_cached_api_key(
key_prefix, self.db
)
if not context:
logger.warning(f"API key not found: {key_prefix}")
return None
api_key_obj = context["api_key"]
# If not in verification cache, verify and cache the result
if not cached_verification:
# Get the actual key hash for verification (this should be in the cached context)
db_api_key = None
if not hasattr(api_key_obj, "key_hash"):
# Fallback: fetch full API key from database for hash
stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
result = await self.db.execute(stmt)
@@ -66,76 +72,85 @@ class APIKeyAuthService:
key_hash = db_api_key.key_hash
else:
key_hash = api_key_obj.key_hash
# Verify the API key hash
if not verify_api_key(api_key, key_hash):
logger.warning(f"Invalid API key hash: {key_prefix}")
return None
# Cache successful verification
await cached_api_key_service.cache_verification_result(
api_key, key_prefix, key_hash, True
)
# Check if key is valid (expiry, active status)
if not api_key_obj.is_valid():
logger.warning(f"API key expired or inactive: {key_prefix}")
# Invalidate cache for expired keys
await cached_api_key_service.invalidate_api_key_cache(key_prefix)
return None
# Check IP restrictions
client_ip = request.client.host if request.client else "unknown"
if not api_key_obj.can_access_from_ip(client_ip):
logger.warning(f"IP not allowed for API key {key_prefix}: {client_ip}")
return None
# Update last used timestamp asynchronously (performance optimization)
await cached_api_key_service.update_last_used(
context["api_key_id"], self.db
)
return context
except Exception as e:
logger.error(f"API key validation error: {e}")
return None
async def check_endpoint_permission(
self, context: Dict[str, Any], endpoint: str
) -> bool:
"""Check if API key has permission to access endpoint"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.can_access_endpoint(endpoint)
async def check_model_permission(self, context: Dict[str, Any], model: str) -> bool:
"""Check if API key has permission to access model"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.can_access_model(model)
async def check_scope_permission(self, context: Dict[str, Any], scope: str) -> bool:
"""Check if API key has required scope"""
api_key: APIKey = context.get("api_key")
if not api_key:
return False
return api_key.has_scope(scope)
async def update_usage_stats(
self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0
):
"""Update API key usage statistics"""
try:
api_key: APIKey = context.get("api_key")
if api_key:
api_key.update_usage(tokens_used, cost_cents)
await self.db.commit()
logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
logger.info(
f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents"
)
except Exception as e:
logger.error(f"Failed to update usage stats: {e}")
async def get_api_key_context(
request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Dependency to get API key context from request"""
auth_service = APIKeyAuthService(db)
@@ -170,7 +185,7 @@ async def require_api_key(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Valid API key required",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
return context
@@ -180,19 +195,19 @@ async def get_current_api_key_user(
) -> tuple:
"""
Dependency that returns current user and API key as a tuple
Returns:
tuple: (user, api_key)
"""
user = context.get("user")
api_key = context.get("api_key")
if not user or not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User or API key not found in context"
detail="User or API key not found in context",
)
return user, api_key
@@ -201,48 +216,48 @@ async def get_api_key_auth(
) -> APIKey:
"""
Dependency that returns the authenticated API key object
Returns:
APIKey: The authenticated API key object
"""
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key not found in context"
detail="API key not found in context",
)
return api_key
class RequireScope:
"""Dependency class for scope checking"""
def __init__(self, scope: str):
self.scope = scope
async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
auth_service = APIKeyAuthService(context.get("db"))
if not await auth_service.check_scope_permission(context, self.scope):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Scope '{self.scope}' required"
detail=f"Scope '{self.scope}' required",
)
return context
class RequireModel:
"""Dependency class for model access checking"""
def __init__(self, model: str):
self.model = model
async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
auth_service = APIKeyAuthService(context.get("db"))
if not await auth_service.check_model_permission(context, self.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{self.model}' not allowed"
detail=f"Model '{self.model}' not allowed",
)
return context

View File

@@ -20,17 +20,17 @@ _audit_worker_started = False
async def _audit_worker():
"""Background worker to process audit events"""
from app.db.database import async_session_factory
logger.info("Audit worker started")
while True:
try:
# Get audit event from queue
audit_data = await _audit_queue.get()
if audit_data is None: # Shutdown signal
break
# Process the audit event in a separate database session
async with async_session_factory() as db:
try:
@@ -41,9 +41,9 @@ async def _audit_worker():
except Exception as e:
logger.error(f"Failed to write audit log in background: {e}")
await db.rollback()
_audit_queue.task_done()
except Exception as e:
logger.error(f"Audit worker error: {e}")
await asyncio.sleep(1) # Brief pause before retrying
@@ -68,11 +68,11 @@ async def log_audit_event_async(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
severity: str = "info",
):
"""
Log an audit event asynchronously (non-blocking)
This function queues the audit event for background processing,
so it doesn't block the main request flow.
"""
@@ -80,11 +80,11 @@ async def log_audit_event_async(
# Ensure audit worker is started
if not _audit_worker_started:
start_audit_worker()
audit_details = details or {}
if api_key_id:
audit_details["api_key_id"] = api_key_id
audit_data = {
"user_id": user_id,
"action": action,
@@ -96,16 +96,16 @@ async def log_audit_event_async(
"user_agent": user_agent,
"success": success,
"severity": severity,
"created_at": datetime.utcnow()
"created_at": datetime.utcnow(),
}
# Queue the audit event (non-blocking)
try:
_audit_queue.put_nowait(audit_data)
logger.debug(f"Audit event queued: {action} on {resource_type}")
except asyncio.QueueFull:
logger.warning("Audit queue full, dropping audit event")
except Exception as e:
logger.error(f"Failed to queue audit event: {e}")
# Don't raise - audit failures shouldn't break main operations
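# Illustrative call site (action name, details, and surrounding objects are
# assumptions): the event is queued here and written by the background worker,
# so the request path never blocks on the audit write.
#
# await log_audit_event_async(
#     user_id=current_user.id,
#     action="api_key.create",
#     resource_type="api_key",
#     details={"scopes": ["chat:write"]},
#     ip_address=request.client.host if request.client else None,
# )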
@@ -122,11 +122,11 @@ async def log_audit_event(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
severity: str = "info",
):
"""
Log an audit event to the database
Args:
db: Database session
user_id: ID of the user performing the action
@@ -140,12 +140,12 @@ async def log_audit_event(
success: Whether the action was successful
severity: Severity level (info, warning, error, critical)
"""
try:
audit_details = details or {}
if api_key_id:
audit_details["api_key_id"] = api_key_id
audit_log = AuditLog(
user_id=user_id,
action=action,
@@ -157,14 +157,16 @@ async def log_audit_event(
user_agent=user_agent,
success=success,
severity=severity,
created_at=datetime.utcnow(),
)
db.add(audit_log)
await db.flush() # Don't commit here, let the caller control the transaction
logger.debug(f"Audit event logged: {action} on {resource_type} by user {user_id}")
logger.debug(
f"Audit event logged: {action} on {resource_type} by user {user_id}"
)
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
# Don't raise here as audit logging shouldn't break the main operation
@@ -179,11 +181,11 @@ async def get_audit_logs(
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
offset: int = 0,
):
"""
Query audit logs with filtering
Args:
db: Database session
user_id: Filter by user ID
@@ -194,16 +196,16 @@ async def get_audit_logs(
end_date: Filter by end date
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of audit log entries
"""
from sqlalchemy import select, and_
query = select(AuditLog)
conditions = []
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action:
@@ -216,13 +218,13 @@ async def get_audit_logs(
conditions.append(AuditLog.created_at >= start_date)
if end_date:
conditions.append(AuditLog.created_at <= end_date)
if conditions:
query = query.where(and_(*conditions))
query = query.order_by(AuditLog.created_at.desc())
query = query.offset(offset).limit(limit)
result = await db.execute(query)
return result.scalars().all()
@@ -230,68 +232,80 @@ async def get_audit_logs(
async def get_audit_stats(
db: AsyncSession,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
):
"""
Get audit statistics
Args:
db: Database session
start_date: Start date for statistics
end_date: End date for statistics
Returns:
Dictionary with audit statistics
"""
from sqlalchemy import select, func, and_
conditions = []
if start_date:
conditions.append(AuditLog.created_at >= start_date)
if end_date:
conditions.append(AuditLog.created_at <= end_date)
# Total events
total_query = select(func.count(AuditLog.id))
if conditions:
total_query = total_query.where(and_(*conditions))
total_result = await db.execute(total_query)
total_events = total_result.scalar()
# Events by action
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(
AuditLog.action
)
if conditions:
action_query = action_query.where(and_(*conditions))
action_result = await db.execute(action_query)
events_by_action = dict(action_result.fetchall())
# Events by resource type
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(
AuditLog.resource_type
)
if conditions:
resource_query = resource_query.where(and_(*conditions))
resource_result = await db.execute(resource_query)
events_by_resource = dict(resource_result.fetchall())
# Events by severity
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(
AuditLog.severity
)
if conditions:
severity_query = severity_query.where(and_(*conditions))
severity_result = await db.execute(severity_query)
events_by_severity = dict(severity_result.fetchall())
# Success rate
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(
AuditLog.success
)
if conditions:
success_query = success_query.where(and_(*conditions))
success_result = await db.execute(success_query)
success_stats = dict(success_result.fetchall())
return {
"total_events": total_events,
"events_by_action": events_by_action,
"events_by_resource_type": events_by_resource,
"events_by_severity": events_by_severity,
"success_rate": success_stats.get(True, 0) / total_events if total_events > 0 else 0,
"failure_rate": success_stats.get(False, 0) / total_events if total_events > 0 else 0
}
"success_rate": success_stats.get(True, 0) / total_events
if total_events > 0
else 0,
"failure_rate": success_stats.get(False, 0) / total_events
if total_events > 0
else 0,
}
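# Sketch of a caller (helper is an assumption, not part of this commit):
# aggregating the last 24 hours of audit activity.
from datetime import timedelta

async def daily_audit_summary(db: AsyncSession) -> dict:
    stats = await get_audit_stats(db, start_date=datetime.utcnow() - timedelta(days=1))
    return {
        "events": stats["total_events"],
        "success_rate": round(stats["success_rate"], 3),
    }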

View File

@@ -22,10 +22,11 @@ logger = get_logger(__name__)
@dataclass
class Permission:
"""Represents a module permission"""
resource: str
action: str
description: str
def __str__(self) -> str:
return f"{self.resource}:{self.action}"
@@ -33,26 +34,28 @@ class Permission:
@dataclass
class ModuleMetrics:
"""Module performance metrics"""
requests_processed: int = 0
average_response_time: float = 0.0
error_rate: float = 0.0
last_activity: Optional[str] = None
total_errors: int = 0
uptime_start: float = 0.0
def __post_init__(self):
if self.uptime_start == 0.0:
self.uptime_start = time.time()
@dataclass
class ModuleHealth:
"""Module health status"""
status: str = "healthy" # healthy, warning, error
message: str = "Module is functioning normally"
uptime: float = 0.0
last_check: float = 0.0
def __post_init__(self):
if self.last_check == 0.0:
self.last_check = time.time()
@@ -60,53 +63,55 @@ class ModuleHealth:
class BaseModule(ABC):
"""Base class for all modules with interceptor pattern support"""
def __init__(self, module_id: str, config: Dict[str, Any] = None):
self.module_id = module_id
self.config = config or {}
self.metrics = ModuleMetrics()
self.health = ModuleHealth()
self.initialized = False
self.interceptors: List["ModuleInterceptor"] = []
# Register default interceptors
self._register_default_interceptors()
@abstractmethod
async def initialize(self) -> None:
"""Initialize the module"""
pass
@abstractmethod
async def cleanup(self) -> None:
"""Cleanup module resources"""
pass
@abstractmethod
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
return []
@abstractmethod
async def process_request(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Process a module request"""
pass
def get_health(self) -> ModuleHealth:
"""Get current module health status"""
self.health.uptime = time.time() - self.metrics.uptime_start
self.health.last_check = time.time()
return self.health
def get_metrics(self) -> ModuleMetrics:
"""Get current module metrics"""
return self.metrics
def check_access(self, user_permissions: List[str], action: str) -> bool:
"""Check if user can perform action on this module"""
required = f"modules:{self.module_id}:{action}"
return permission_registry.check_permission(user_permissions, required)
def _register_default_interceptors(self):
"""Register default interceptors for all modules"""
self.interceptors = [
@@ -115,47 +120,49 @@ class BaseModule(ABC):
ValidationInterceptor(),
MetricsInterceptor(self),
SecurityInterceptor(),
AuditInterceptor(self),
]
async def execute_with_interceptors(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute request through interceptor chain"""
start_time = time.time()
try:
# Pre-processing interceptors
for interceptor in self.interceptors:
request, context = await interceptor.pre_process(request, context)
# Execute main module logic
response = await self.process_request(request, context)
# Post-processing interceptors (in reverse order)
for interceptor in reversed(self.interceptors):
response = await interceptor.post_process(request, context, response)
# Update metrics
self._update_metrics(start_time, success=True)
return response
except Exception as e:
# Update error metrics
self._update_metrics(start_time, success=False, error=str(e))
# Error handling interceptors
for interceptor in reversed(self.interceptors):
if hasattr(interceptor, "handle_error"):
await interceptor.handle_error(request, context, e)
raise
def _update_metrics(self, start_time: float, success: bool, error: str = None):
"""Update module metrics"""
duration = time.time() - start_time
self.metrics.requests_processed += 1
# Update average response time
if self.metrics.requests_processed == 1:
self.metrics.average_response_time = duration
@@ -165,94 +172,118 @@ class BaseModule(ABC):
self.metrics.average_response_time = (
alpha * duration + (1 - alpha) * self.metrics.average_response_time
)
if not success:
self.metrics.total_errors += 1
self.metrics.error_rate = (
self.metrics.total_errors / self.metrics.requests_processed
)
# Update health status based on error rate
if self.metrics.error_rate > 0.1: # More than 10% error rate
self.health.status = "error"
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
elif self.metrics.error_rate > 0.05: # More than 5% error rate
self.health.status = "warning"
self.health.message = f"Elevated error rate: {self.metrics.error_rate:.2%}"
self.health.message = (
f"Elevated error rate: {self.metrics.error_rate:.2%}"
)
else:
self.metrics.error_rate = (
self.metrics.total_errors / self.metrics.requests_processed
)
if self.metrics.error_rate <= 0.05:
self.health.status = "healthy"
self.health.message = "Module is functioning normally"
self.metrics.last_activity = time.strftime("%Y-%m-%d %H:%M:%S")
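# Minimal concrete module (illustrative; names invented): subclassing
# BaseModule and calling execute_with_interceptors drives the interceptor
# chain plus the metrics and health bookkeeping shown above.
class EchoModule(BaseModule):
    async def initialize(self) -> None:
        self.initialized = True

    async def cleanup(self) -> None:
        self.initialized = False

    def get_required_permissions(self) -> List[Permission]:
        return [Permission("modules:echo", "execute", "Run the echo module")]

    async def process_request(self, request, context):
        return {"echo": request.get("payload")}

# response = await EchoModule("echo").execute_with_interceptors(
#     {"action": "execute", "payload": "hi"},
#     {"user_id": "u1", "user_permissions": ["modules:echo:execute"]},
# )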
class ModuleInterceptor(ABC):
"""Base class for module interceptors"""
@abstractmethod
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Pre-process the request"""
return request, context
@abstractmethod
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
"""Post-process the response"""
return response
class AuthenticationInterceptor(ModuleInterceptor):
"""Handles authentication for module requests"""
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Check if user is authenticated (context should contain user info from API auth)
if not context.get("user_id") and not context.get("api_key_id"):
raise AuthenticationError("Authentication required for module access")
return request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
return response
class PermissionInterceptor(ModuleInterceptor):
"""Handles permission checking for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
action = request.get("action", "execute")
user_permissions = context.get("user_permissions", [])
if not self.module.check_access(user_permissions, action):
raise AuthenticationError(f"Insufficient permissions for module action: {action}")
raise AuthenticationError(
f"Insufficient permissions for module action: {action}"
)
return request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
return response
class ValidationInterceptor(ModuleInterceptor):
"""Handles request validation and sanitization"""
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Sanitize request data
sanitized_request = self._sanitize_request(request)
# Validate request structure
self._validate_request(sanitized_request)
return sanitized_request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Sanitize response data
return self._sanitize_response(response)
def _sanitize_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Remove potentially dangerous content from request"""
sanitized = copy.deepcopy(request)
# Define dangerous patterns
dangerous_patterns = [
r"<script[^>]*>.*?</script>",
@@ -264,19 +295,21 @@ class ValidationInterceptor(ModuleInterceptor):
r"eval\s*\(",
r"Function\s*\(",
]
def sanitize_value(value):
if isinstance(value, str):
# Remove dangerous patterns
for pattern in dangerous_patterns:
value = re.sub(pattern, "", value, flags=re.IGNORECASE)
# Limit string length
max_length = 10000
if len(value) > max_length:
logger.warning(
f"Truncated long string in request: {len(value)} chars"
)
value = value[:max_length]
return value
elif isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()}
@@ -284,124 +317,159 @@ class ValidationInterceptor(ModuleInterceptor):
return [sanitize_value(item) for item in value]
else:
return value
return sanitize_value(sanitized)
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize response data"""
# Similar sanitization for responses
return self._sanitize_request(response)
def _validate_request(self, request: Dict[str, Any]):
"""Validate request structure"""
# Check for required fields
if not isinstance(request, dict):
raise ValidationError("Request must be a dictionary")
# Check request size
request_str = json.dumps(request)
max_size = 10 * 1024 * 1024 # 10MB
if len(request_str.encode()) > max_size:
raise ValidationError(f"Request size exceeds maximum allowed ({max_size} bytes)")
raise ValidationError(
f"Request size exceeds maximum allowed ({max_size} bytes)"
)
class MetricsInterceptor(ModuleInterceptor):
"""Handles metrics collection for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_metrics_start_time"] = time.time()
return request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Metrics are updated in the base module execute_with_interceptors method
return response
class SecurityInterceptor(ModuleInterceptor):
"""Handles security-related processing"""
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Add security headers to context
context["security_headers"] = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
}
# Check for suspicious patterns
self._check_security_patterns(request)
return request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Remove any sensitive information from response
return self._remove_sensitive_data(response)
def _check_security_patterns(self, request: Dict[str, Any]):
"""Check for suspicious security patterns"""
request_str = json.dumps(request).lower()
suspicious_patterns = [
"union select", "drop table", "insert into", "delete from",
"script>", "javascript:", "eval(", "expression(",
"../", "..\\", "file://", "ftp://",
"union select",
"drop table",
"insert into",
"delete from",
"script>",
"javascript:",
"eval(",
"expression(",
"../",
"..\\",
"file://",
"ftp://",
]
for pattern in suspicious_patterns:
if pattern in request_str:
logger.warning(f"Suspicious pattern detected in request: {pattern}")
# Could implement additional security measures here
def _remove_sensitive_data(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Remove sensitive data from response"""
sensitive_keys = ["password", "secret", "token", "key", "private"]
def clean_dict(obj):
if isinstance(obj, dict):
return {
k: "***REDACTED***" if any(sk in k.lower() for sk in sensitive_keys) else clean_dict(v)
k: "***REDACTED***"
if any(sk in k.lower() for sk in sensitive_keys)
else clean_dict(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [clean_dict(item) for item in obj]
else:
return obj
return clean_dict(response)
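# Worked example (values invented): nested keys containing a sensitive
# substring such as "token" or "password" are redacted before the response
# leaves the module.
#
# SecurityInterceptor()._remove_sensitive_data(
#     {"user": {"name": "ada", "api_token": "abc123"}, "items": [{"password": "x"}]}
# )
# -> {"user": {"name": "ada", "api_token": "***REDACTED***"},
#     "items": [{"password": "***REDACTED***"}]}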
class AuditInterceptor(ModuleInterceptor):
"""Handles audit logging for module requests"""
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_audit_start_time"] = time.time()
context["_audit_request_hash"] = self._hash_request(request)
return request, context
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
await self._log_audit_event(request, context, response, success=True)
return response
async def handle_error(
self, request: Dict[str, Any], context: Dict[str, Any], error: Exception
):
"""Handle error logging"""
await self._log_audit_event(request, context, {"error": str(error)}, success=False)
await self._log_audit_event(
request, context, {"error": str(error)}, success=False
)
def _hash_request(self, request: Dict[str, Any]) -> str:
"""Create a hash of the request for audit purposes"""
request_str = json.dumps(request, sort_keys=True)
return hashlib.sha256(request_str.encode()).hexdigest()[:16]
async def _log_audit_event(
self,
request: Dict[str, Any],
context: Dict[str, Any],
response: Dict[str, Any],
success: bool,
):
"""Log audit event"""
duration = time.time() - context.get("_audit_start_time", time.time())
audit_data = {
"module_id": self.module.module_id,
"action": request.get("action", "execute"),
@@ -413,11 +481,11 @@ class AuditInterceptor(ModuleInterceptor):
"duration_ms": int(duration * 1000),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
}
if not success:
audit_data["error"] = response.get("error", "Unknown error")
# Log the audit event
logger.info(f"Module audit: {json.dumps(audit_data)}")
# Could also store in database for persistent audit trail

View File

@@ -28,6 +28,7 @@ from sqlalchemy.orm import Session
@dataclass
class PluginContext:
"""Plugin execution context with user and authentication info"""
user_id: Optional[str] = None
api_key_id: Optional[str] = None
user_permissions: List[str] = None
@@ -38,25 +39,29 @@ class PluginContext:
class PlatformAPIClient:
"""Secure client for plugins to access platform APIs"""
def __init__(self, plugin_id: str, plugin_token: str):
self.plugin_id = plugin_id
self.plugin_token = plugin_token
self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
self.logger = get_logger(f"plugin.{plugin_id}.api_client")
async def _make_request(
self, method: str, endpoint: str, **kwargs
) -> Dict[str, Any]:
"""Make authenticated request to platform API"""
headers = kwargs.setdefault("headers", {})
headers.update(
{
"Authorization": f"Bearer {self.plugin_token}",
"X-Plugin-ID": self.plugin_id,
"X-Platform-Client": "plugin",
"Content-Type": "application/json",
}
)
url = f"{self.base_url}{endpoint}"
try:
async with aiohttp.ClientSession() as session:
async with session.request(method, url, **kwargs) as response:
@@ -64,154 +69,162 @@ class PlatformAPIClient:
error_text = await response.text()
raise HTTPException(
status_code=response.status,
detail=f"Platform API error: {error_text}"
detail=f"Platform API error: {error_text}",
)
if response.content_type == "application/json":
return await response.json()
else:
return {"data": await response.text()}
except aiohttp.ClientError as e:
self.logger.error(f"Platform API client error: {e}")
raise HTTPException(
status_code=503, detail=f"Platform API unavailable: {str(e)}"
)
async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""GET request to platform API"""
return await self._make_request("GET", endpoint, **kwargs)
async def post(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""POST request to platform API"""
if data:
kwargs["json"] = data
return await self._make_request("POST", endpoint, **kwargs)
async def put(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""PUT request to platform API"""
if data:
kwargs["json"] = data
return await self._make_request("PUT", endpoint, **kwargs)
async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""DELETE request to platform API"""
return await self._make_request("DELETE", endpoint, **kwargs)
# Platform-specific API methods
async def call_chatbot_api(
self, chatbot_id: str, message: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Consume platform chatbot API"""
return await self.post(
f"/api/v1/chatbot/external/{chatbot_id}/chat",
{"message": message, "context": context or {}},
)
async def call_llm_api(
self, model: str, messages: List[Dict[str, Any]], **kwargs
) -> Dict[str, Any]:
"""Consume platform LLM API"""
return await self.post(
"/api/v1/llm/chat/completions",
{"model": model, "messages": messages, **kwargs},
)
async def search_rag(
self, collection: str, query: str, top_k: int = 5
) -> Dict[str, Any]:
"""Consume platform RAG API"""
return await self.post(
f"/api/v1/rag/collections/{collection}/search",
{"query": query, "top_k": top_k},
)
async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
"""Generate embeddings via platform API"""
return await self.post(
"/api/v1/llm/embeddings",
{
"model": model,
"input": input_text
}
"/api/v1/llm/embeddings", {"model": model, "input": input_text}
)
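For orientation, a minimal usage sketch of these wrappers — the plugin id, token, model names, and collection name below are placeholders, not values defined by the platform:

import asyncio

async def demo():
    client = PlatformAPIClient("example-plugin", "example-token")
    reply = await client.call_chatbot_api("support-bot", "Hello!", context={"lang": "en"})
    completion = await client.call_llm_api(
        "example-model", [{"role": "user", "content": "Say hi"}], max_tokens=16
    )
    hits = await client.search_rag("docs", "how do budgets work?", top_k=3)
    vector = await client.get_embeddings("example-embedding-model", "how do budgets work?")
    return reply, completion, hits, vector

# asyncio.run(demo())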
class PluginConfigManager:
"""Manages plugin configuration with validation and encryption"""
def __init__(self, plugin_id: str):
self.plugin_id = plugin_id
self.logger = get_logger(f"plugin.{plugin_id}.config")
async def get_config(self, user_id: Optional[str] = None) -> Dict[str, Any]:
"""Get plugin configuration for user (or default)"""
try:
# Use dependency injection to get database session
from app.db.database import SessionLocal
db = SessionLocal()
try:
# Query for active configuration
query = db.query(PluginConfiguration).filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.is_active == True,
)
if user_id:
# Get user-specific configuration
query = query.filter(PluginConfiguration.user_id == user_id)
else:
# Get default configuration (is_default=True)
query = query.filter(PluginConfiguration.is_default == True)
config = query.first()
if config:
self.logger.debug(f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}"
)
return config.config_data or {}
else:
self.logger.debug(f"No configuration found for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"No configuration found for plugin {self.plugin_id}, user {user_id}"
)
return {}
finally:
db.close()
except Exception as e:
self.logger.error(f"Failed to get configuration: {e}")
return {}
async def save_config(
self,
config: Dict[str, Any],
user_id: str,
name: str = "Default Configuration",
description: str = None,
) -> bool:
"""Save plugin configuration for user"""
try:
from app.db.database import SessionLocal
db = SessionLocal()
try:
# Check if configuration already exists
existing_config = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
PluginConfiguration.name == name,
)
.first()
)
if existing_config:
# Update existing configuration
existing_config.config_data = config
existing_config.description = description
existing_config.is_active = True
self.logger.info(f"Updated configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.info(
f"Updated configuration for plugin {self.plugin_id}, user {user_id}"
)
else:
# Create new configuration
new_config = PluginConfiguration(
@@ -222,40 +235,48 @@ class PluginConfigManager:
config_data=config,
is_active=True,
is_default=(name == "Default Configuration"),
created_by_user_id=user_id,
)
# If this is the first configuration for this user/plugin, make it default
existing_count = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
)
.count()
)
if existing_count == 0:
new_config.is_default = True
db.add(new_config)
self.logger.info(f"Created new configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.info(
f"Created new configuration for plugin {self.plugin_id}, user {user_id}"
)
db.commit()
return True
except Exception as e:
db.rollback()
self.logger.error(f"Database error saving configuration: {e}")
return False
finally:
db.close()
except Exception as e:
self.logger.error(f"Failed to save configuration: {e}")
return False
async def validate_config(
self, config: Dict[str, Any], schema: Dict[str, Any]
) -> Tuple[bool, List[str]]:
"""Validate configuration against JSON schema"""
try:
import jsonschema
jsonschema.validate(config, schema)
return True, []
except jsonschema.ValidationError as e:
@@ -266,45 +287,52 @@ class PluginConfigManager:
class PluginLogger:
"""Plugin-specific logger with security filtering"""
def __init__(self, plugin_id: str):
self.plugin_id = plugin_id
self.logger = get_logger(f"plugin.{plugin_id}")
# Sensitive data patterns to filter
self.sensitive_patterns = [
r"password",
r"token",
r"key",
r"secret",
r"api_key",
r"bearer",
r"authorization",
r"credential",
]
def _filter_sensitive_data(self, message: str) -> str:
"""Filter sensitive data from log messages"""
import re
filtered_message = message
for pattern in self.sensitive_patterns:
filtered_message = re.sub(
rf"{pattern}[=:]\s*[\"']?([^\"'\s]+)[\"']?",
f"{pattern}=***REDACTED***",
filtered_message,
flags=re.IGNORECASE,
)
return filtered_message
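As an illustration of the filter (the message and plugin id are invented), a secret embedded as key=value is masked while the rest of the line survives:

log = PluginLogger("demo-plugin")
masked = log._filter_sensitive_data('connecting with api_key="sk-abc123" to upstream')
# masked == 'connecting with api_key=***REDACTED*** to upstream'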
def info(self, message: str, **kwargs):
"""Log info message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.info(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)
def warning(self, message: str, **kwargs):
"""Log warning message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.warning(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)
def error(self, message: str, **kwargs):
"""Log error message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.error(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)
def debug(self, message: str, **kwargs):
"""Log debug message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
@@ -313,45 +341,45 @@ class PluginLogger:
class BasePlugin(ABC):
"""Base class for all Enclava plugins with security and isolation"""
def __init__(self, manifest: PluginManifest, plugin_token: str):
self.manifest = manifest
self.plugin_id = manifest.metadata.name
self.version = manifest.metadata.version
# Initialize plugin services
self.api_client = PlatformAPIClient(self.plugin_id, plugin_token)
self.config = PluginConfigManager(self.plugin_id)
self.logger = PluginLogger(self.plugin_id)
# Plugin state
self.initialized = False
self._startup_time = time.time()
self._request_count = 0
self._error_count = 0
self.logger.info(f"Plugin {self.plugin_id} v{self.version} instantiated")
@abstractmethod
def get_api_router(self) -> APIRouter:
"""Return FastAPI router for plugin endpoints"""
pass
@abstractmethod
async def initialize(self) -> bool:
"""Initialize plugin resources and connections"""
pass
@abstractmethod
async def cleanup(self) -> bool:
"""Cleanup plugin resources on shutdown"""
pass
async def health_check(self) -> Dict[str, Any]:
"""Plugin health status"""
uptime = time.time() - self._startup_time
error_rate = self._error_count / max(self._request_count, 1)
return {
"status": "healthy" if error_rate < 0.1 else "warning",
"plugin": self.plugin_id,
@@ -360,28 +388,28 @@ class BasePlugin(ABC):
"request_count": self._request_count,
"error_count": self._error_count,
"error_rate": round(error_rate, 3),
"initialized": self.initialized
"initialized": self.initialized,
}
async def get_configuration_schema(self) -> Dict[str, Any]:
"""Return JSON schema for plugin configuration"""
return self.manifest.spec.config_schema
async def execute_cron_job(self, job_name: str) -> bool:
"""Execute scheduled cron job"""
self.logger.info(f"Executing cron job: {job_name}")
# Find job in manifest
job_spec = None
for job in self.manifest.spec.cron_jobs:
if job.name == job_name:
job_spec = job
break
if not job_spec:
self.logger.error(f"Cron job not found: {job_name}")
return False
try:
# Get the function to execute
if hasattr(self, job_spec.function):
@@ -390,34 +418,37 @@ class BasePlugin(ABC):
result = await func()
else:
result = func()
self.logger.info(f"Cron job {job_name} completed successfully")
return bool(result)
else:
self.logger.error(f"Cron job function not found: {job_spec.function}")
return False
except Exception as e:
self.logger.error(f"Cron job {job_name} failed: {e}")
self._error_count += 1
return False
def get_auth_context(self) -> PluginContext:
"""Dependency to get authentication context in API endpoints"""
async def _get_context(request: Request) -> PluginContext:
# Extract authentication info from request
# This would be populated by the plugin API gateway
return PluginContext(
user_id=request.headers.get("X-User-ID"),
api_key_id=request.headers.get("X-API-Key-ID"),
user_permissions=request.headers.get("X-User-Permissions", "").split(
","
),
ip_address=request.headers.get("X-Real-IP"),
user_agent=request.headers.get("User-Agent"),
request_id=request.headers.get("X-Request-ID"),
)
return Depends(_get_context)
def _track_request(self, success: bool = True):
"""Track request metrics"""
self._request_count += 1
@@ -427,164 +458,206 @@ class BasePlugin(ABC):
class PluginSecurityManager:
"""Manages plugin security and isolation"""
BLOCKED_IMPORTS = {
# Core platform modules
"app.db",
"app.models",
"app.core",
"app.services",
"sqlalchemy",
"alembic",
# Security sensitive
"subprocess",
"eval",
"exec",
"compile",
"__import__",
"os.system",
"os.popen",
"os.spawn",
# System access
"socket",
"multiprocessing",
"threading",
}
ALLOWED_IMPORTS = {
# Standard library
"asyncio",
"aiohttp",
"json",
"datetime",
"typing",
"pydantic",
"logging",
"time",
"uuid",
"hashlib",
"base64",
"pathlib",
"re",
"urllib.parse",
"dataclasses",
"enum",
# Approved third-party
"httpx",
"requests",
"pandas",
"numpy",
"yaml",
# Plugin framework
"app.services.base_plugin",
"app.schemas.plugin_manifest",
}
@classmethod
def validate_plugin_import(cls, import_name: str) -> bool:
"""Validate if plugin can import a module"""
# Block dangerous imports
if any(import_name.startswith(blocked) for blocked in cls.BLOCKED_IMPORTS):
raise SecurityError(f"Import '{import_name}' not allowed in plugin environment")
raise SecurityError(
f"Import '{import_name}' not allowed in plugin environment"
)
# Allow explicit safe imports
if any(import_name.startswith(allowed) for allowed in cls.ALLOWED_IMPORTS):
return True
# Log potentially unsafe imports
logger = get_logger("plugin.security")
logger.warning(f"Potentially unsafe import in plugin: {import_name}")
return True
@classmethod
def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
"""Create isolated environment for plugin execution"""
return {
"max_memory_mb": 128,
"max_cpu_percent": 25,
"max_disk_mb": 100,
"max_api_calls_per_minute": 100,
"allowed_domains": [],  # Will be populated from manifest
"network_timeout_seconds": 30,
}
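A short sketch of how the allow/deny lists behave — the module names are arbitrary examples:

PluginSecurityManager.validate_plugin_import("aiohttp")    # explicitly allowed -> True
PluginSecurityManager.validate_plugin_import("some_sdk")   # unknown -> True, but logged as a warning
try:
    PluginSecurityManager.validate_plugin_import("subprocess")
except SecurityError:
    pass  # blocked imports raise rather than return False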
class PluginLoader:
"""Loads and validates plugins from directories"""
def __init__(self):
self.logger = get_logger("plugin.loader")
self.loaded_plugins: Dict[str, BasePlugin] = {}
async def load_plugin(self, plugin_dir: Path, plugin_token: str) -> BasePlugin:
"""Load a plugin from a directory"""
self.logger.info(f"Loading plugin from: {plugin_dir}")
# Load and validate manifest
manifest_path = plugin_dir / "manifest.yaml"
validation_result = validate_manifest_file(manifest_path)
if not validation_result["valid"]:
raise ValidationError(f"Invalid plugin manifest: {validation_result['errors']}")
raise ValidationError(
f"Invalid plugin manifest: {validation_result['errors']}"
)
manifest = validation_result["manifest"]
# Check compatibility
compatibility = validation_result["compatibility"]
if not compatibility["compatible"]:
raise ValidationError(f"Plugin incompatible: {compatibility['errors']}")
# Load plugin module
main_py_path = plugin_dir / "main.py"
spec = importlib.util.spec_from_file_location(
f"plugin_{manifest.metadata.name}",
main_py_path
f"plugin_{manifest.metadata.name}", main_py_path
)
if not spec or not spec.loader:
raise ValidationError(f"Cannot load plugin module: {main_py_path}")
# Security check before loading
self._validate_plugin_security(main_py_path)
# Load module
plugin_module = importlib.util.module_from_spec(spec)
# Add to sys.modules to allow imports
sys.modules[spec.name] = plugin_module
try:
spec.loader.exec_module(plugin_module)
except Exception as e:
raise ValidationError(f"Failed to execute plugin module: {e}")
# Find plugin class
plugin_class = None
for attr_name in dir(plugin_module):
attr = getattr(plugin_module, attr_name)
if (
isinstance(attr, type)
and issubclass(attr, BasePlugin)
and attr is not BasePlugin
):
plugin_class = attr
break
if not plugin_class:
raise ValidationError("Plugin must contain a class inheriting from BasePlugin")
raise ValidationError(
"Plugin must contain a class inheriting from BasePlugin"
)
# Instantiate plugin
plugin_instance = plugin_class(manifest, plugin_token)
# Initialize plugin
try:
await plugin_instance.initialize()
plugin_instance.initialized = True
except Exception as e:
raise ValidationError(f"Plugin initialization failed: {e}")
self.loaded_plugins[manifest.metadata.name] = plugin_instance
self.logger.info(f"Plugin {manifest.metadata.name} loaded successfully")
return plugin_instance
def _validate_plugin_security(self, main_py_path: Path):
"""Validate plugin code for security issues"""
with open(main_py_path, "r", encoding="utf-8") as f:
code_content = f.read()
# Check for dangerous patterns
dangerous_patterns = [
"eval(",
"exec(",
"compile(",
"subprocess.",
"os.system",
"os.popen",
"__import__",
"importlib.import_module",
"from app.db",
"from app.models",
"sqlalchemy",
"SessionLocal",
]
for pattern in dangerous_patterns:
if pattern in code_content:
raise SecurityError(f"Dangerous pattern detected in plugin code: {pattern}")
raise SecurityError(
f"Dangerous pattern detected in plugin code: {pattern}"
)
async def unload_plugin(self, plugin_id: str) -> bool:
"""Unload a plugin and cleanup resources"""
if plugin_id not in self.loaded_plugins:
return False
plugin = self.loaded_plugins[plugin_id]
try:
await plugin.cleanup()
del self.loaded_plugins[plugin_id]
@@ -593,11 +666,11 @@ class PluginLoader:
except Exception as e:
self.logger.error(f"Error unloading plugin {plugin_id}: {e}")
return False
def get_plugin(self, plugin_id: str) -> Optional[BasePlugin]:
"""Get loaded plugin by ID"""
return self.loaded_plugins.get(plugin_id)
def list_loaded_plugins(self) -> List[str]:
"""List all loaded plugin IDs"""
return list(self.loaded_plugins.keys())
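End to end, loading a plugin from disk might look like the following sketch — the directory, token, and plugin name are placeholders, and the directory is assumed to contain a manifest.yaml plus a main.py defining a BasePlugin subclass:

async def run_plugin():
    loader = PluginLoader()
    plugin = await loader.load_plugin(Path("plugins/hello_world"), plugin_token="example-token")
    print(loader.list_loaded_plugins())   # ['hello_world']
    print(await plugin.health_check())    # {'status': 'healthy', ...}
    await loader.unload_plugin("hello_world")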

View File

@@ -21,11 +21,13 @@ logger = get_logger(__name__)
class BudgetEnforcementError(Exception):
"""Custom exception for budget enforcement failures"""
pass
class BudgetExceededError(BudgetEnforcementError):
"""Exception raised when budget would be exceeded"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -34,6 +36,7 @@ class BudgetExceededError(BudgetEnforcementError):
class BudgetWarningError(BudgetEnforcementError):
"""Exception raised when budget warning threshold is reached"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -42,6 +45,7 @@ class BudgetWarningError(BudgetEnforcementError):
class BudgetConcurrencyError(BudgetEnforcementError):
"""Exception raised when budget update fails due to concurrency"""
def __init__(self, message: str, retry_count: int = 0):
super().__init__(message)
self.retry_count = retry_count
@@ -49,6 +53,7 @@ class BudgetConcurrencyError(BudgetEnforcementError):
class BudgetAtomicError(BudgetEnforcementError):
"""Exception raised when atomic budget operation fails"""
def __init__(self, message: str, budget_id: int, requested_amount: int):
super().__init__(message)
self.budget_id = budget_id
@@ -57,84 +62,96 @@ class BudgetAtomicError(BudgetEnforcementError):
class BudgetEnforcementService:
"""Service for enforcing budget limits and tracking usage"""
def __init__(self, db: Session):
self.db = db
self.max_retries = 3
self.retry_delay_base = 0.1 # Base delay in seconds
def atomic_check_and_reserve_budget(
self,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""
Atomically check budget compliance and reserve spending
Returns:
Tuple of (is_allowed, error_message, warnings, reserved_budget_ids)
"""
estimated_cost = estimate_request_cost(model_name, estimated_tokens)
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
if not budgets:
logger.debug(f"No applicable budgets found for API key {api_key.id}")
return True, None, [], []
# Try atomic reservation with retries
for attempt in range(self.max_retries):
try:
return self._attempt_atomic_reservation(
budgets, estimated_cost, api_key.id, attempt
)
except BudgetConcurrencyError as e:
if attempt == self.max_retries - 1:
logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
return False, f"Budget check temporarily unavailable (concurrency limit)", [], []
logger.error(
f"Atomic budget reservation failed after {self.max_retries} attempts: {e}"
)
return (
False,
f"Budget check temporarily unavailable (concurrency limit)",
[],
[],
)
# Exponential backoff with jitter
delay = self.retry_delay_base * (2**attempt) + random.uniform(0, 0.1)
time.sleep(delay)
logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
logger.info(
f"Retrying atomic budget reservation (attempt {attempt + 2})"
)
except Exception as e:
logger.error(f"Unexpected error in atomic budget reservation: {e}")
return False, f"Budget check failed: {str(e)}", [], []
return False, "Budget check failed after maximum retries", [], []
def _attempt_atomic_reservation(
self, budgets: List[Budget], estimated_cost: int, api_key_id: int, attempt: int
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Attempt to atomically reserve budget across all applicable budgets"""
warnings = []
reserved_budget_ids = []
try:
# Begin transaction
self.db.begin()
for budget in budgets:
# Lock budget row for update to prevent concurrent modifications
locked_budget = (
self.db.query(Budget)
.filter(Budget.id == budget.id)
.with_for_update()
.first()
)
if not locked_budget:
raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)
raise BudgetAtomicError(
f"Budget {budget.id} not found", budget.id, estimated_cost
)
# Reset budget if expired and auto-renew enabled
if locked_budget.is_expired() and locked_budget.auto_renew:
self._reset_expired_budget(locked_budget)
self.db.flush() # Ensure reset is applied before checking
# Skip inactive or expired budgets
if not locked_budget.is_active or locked_budget.is_expired():
continue
# Check if request would exceed budget using atomic operation
if not self._atomic_can_spend(locked_budget, estimated_cost):
error_msg = (
@@ -144,56 +161,74 @@ class BudgetEnforcementService:
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(locked_budget.limit_cents - locked_budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
logger.warning(
f"Budget exceeded for API key {api_key_id}: {error_msg}"
)
self.db.rollback()
return False, error_msg, warnings, []
# Check warning threshold
if (
locked_budget.would_exceed_warning(estimated_cost)
and not locked_budget.is_warning_sent
):
warning_msg = (
f"Budget '{locked_budget.name}' approaching limit. "
f"Usage will be ${(locked_budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${locked_budget.limit_cents/100:.2f} "
f"({((locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100):.1f}%)"
)
warnings.append(
{
"type": "budget_warning",
"budget_id": locked_budget.id,
"budget_name": locked_budget.name,
"message": warning_msg,
"current_usage_cents": locked_budget.current_usage_cents
+ estimated_cost,
"limit_cents": locked_budget.limit_cents,
"usage_percentage": (
locked_budget.current_usage_cents + estimated_cost
)
/ locked_budget.limit_cents
* 100,
}
)
logger.info(
f"Budget warning for API key {api_key_id}: {warning_msg}"
)
# Reserve the budget (temporarily add estimated cost)
self._atomic_reserve_usage(locked_budget, estimated_cost)
reserved_budget_ids.append(locked_budget.id)
# Commit the reservation
self.db.commit()
logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
logger.debug(
f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}"
)
return True, None, warnings, reserved_budget_ids
except IntegrityError as e:
self.db.rollback()
raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt)
raise BudgetConcurrencyError(
f"Database integrity error during budget reservation: {e}", attempt
)
except Exception as e:
self.db.rollback()
logger.error(f"Error in atomic budget reservation: {e}")
raise
def _atomic_can_spend(self, budget: Budget, amount_cents: int) -> bool:
"""Atomically check if budget can accommodate spending"""
if not budget.is_active or not budget.is_in_period():
return False
if not budget.enforce_hard_limit:
return True
return (budget.current_usage_cents + amount_cents) <= budget.limit_cents
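A worked example of the hard-limit arithmetic (numbers invented): with a 1000-cent limit and 950 cents already used, a 40-cent request passes (990 <= 1000), a 60-cent request is rejected (1010 > 1000), and with enforce_hard_limit off both pass:

def can_spend(current, limit, amount, hard_limit=True):
    if not hard_limit:
        return True
    return (current + amount) <= limit

assert can_spend(950, 1000, 40) is True
assert can_spend(950, 1000, 60) is False
assert can_spend(950, 1000, 60, hard_limit=False) is True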
def _atomic_reserve_usage(self, budget: Budget, amount_cents: int):
"""Atomically reserve usage in budget (add to current usage)"""
# Use database-level atomic update
@@ -203,26 +238,37 @@ class BudgetEnforcementService:
.values(
current_usage_cents=Budget.current_usage_cents + amount_cents,
updated_at=datetime.utcnow(),
is_exceeded=Budget.current_usage_cents + amount_cents
>= Budget.limit_cents,
is_warning_sent=(
Budget.is_warning_sent
| (
(Budget.warning_threshold_cents.isnot(None))
& (
Budget.current_usage_cents + amount_cents
>= Budget.warning_threshold_cents
)
)
),
)
)
if result.rowcount != 1:
raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)
raise BudgetAtomicError(
f"Failed to update budget {budget.id}", budget.id, amount_cents
)
# Update the in-memory object to reflect changes
budget.current_usage_cents += amount_cents
budget.updated_at = datetime.utcnow()
if budget.current_usage_cents >= budget.limit_cents:
budget.is_exceeded = True
if (
budget.warning_threshold_cents
and budget.current_usage_cents >= budget.warning_threshold_cents
):
budget.is_warning_sent = True
def atomic_finalize_usage(
self,
reserved_budget_ids: List[int],
@@ -230,11 +276,11 @@ class BudgetEnforcementService:
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None,
) -> List[Budget]:
"""
Finalize actual usage and adjust reservations
Args:
reserved_budget_ids: Budget IDs that had usage reserved
api_key: API key that made the request
@@ -242,101 +288,110 @@ class BudgetEnforcementService:
input_tokens: Actual input tokens used
output_tokens: Actual output tokens used
endpoint: API endpoint that was accessed
Returns:
List of budgets that were updated
"""
if not reserved_budget_ids:
return []
try:
actual_cost = CostCalculator.calculate_cost_cents(
model_name, input_tokens, output_tokens
)
updated_budgets = []
# Begin transaction for finalization
self.db.begin()
for budget_id in reserved_budget_ids:
# Lock budget for update
budget = (
self.db.query(Budget)
.filter(Budget.id == budget_id)
.with_for_update()
.first()
)
if not budget:
logger.warning(f"Budget {budget_id} not found during finalization")
continue
if budget.is_active and budget.is_in_period():
# Calculate adjustment (actual cost - estimated cost already reserved)
# Note: We don't know the exact estimated cost that was reserved
# So we'll just set to actual cost (this is safe as we already reserved)
self._atomic_set_actual_usage(
budget, actual_cost, input_tokens, output_tokens
)
updated_budgets.append(budget)
logger.debug(
f"Finalized usage for budget {budget.id}: "
f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
)
# Commit finalization
self.db.commit()
return updated_budgets
except Exception as e:
logger.error(f"Error finalizing budget usage: {e}")
self.db.rollback()
return []
def _atomic_set_actual_usage(
self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int
):
"""Set the actual usage cost (replacing any reservation)"""
# For simplicity, we'll just ensure the current usage reflects actual cost
# In a more sophisticated system, you might track reservations separately
# For now, the reservation system ensures we don't exceed limits
# and the actual cost will be very close to estimated cost
pass # The reservation already added the estimated cost, actual cost adjustment is minimal
def check_budget_compliance(
self,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""
Check if a request complies with budget limits
Args:
api_key: API key making the request
model_name: Model being used
estimated_tokens: Estimated token usage
endpoint: API endpoint being accessed
Returns:
Tuple of (is_allowed, error_message, warnings)
"""
try:
# Calculate estimated cost
estimated_cost = estimate_request_cost(model_name, estimated_tokens)
# Get applicable budgets
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
if not budgets:
logger.debug(f"No applicable budgets found for API key {api_key.id}")
return True, None, []
warnings = []
# Check each budget
for budget in budgets:
# Reset budget if period expired and auto-renew is enabled
if budget.is_expired() and budget.auto_renew:
self._reset_expired_budget(budget)
# Skip inactive or expired budgets
if not budget.is_active or budget.is_expired():
continue
# Check if request would exceed budget
if not budget.can_spend(estimated_cost):
error_msg = (
@@ -346,145 +401,160 @@ class BudgetEnforcementService:
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
logger.warning(
f"Budget exceeded for API key {api_key.id}: {error_msg}"
)
return False, error_msg, warnings
# Check if request would trigger warning
if (
budget.would_exceed_warning(estimated_cost)
and not budget.is_warning_sent
):
warning_msg = (
f"Budget '{budget.name}' approaching limit. "
f"Usage will be ${(budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${budget.limit_cents/100:.2f} "
f"({((budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100):.1f}%)"
)
warnings.append(
{
"type": "budget_warning",
"budget_id": budget.id,
"budget_name": budget.name,
"message": warning_msg,
"current_usage_cents": budget.current_usage_cents
+ estimated_cost,
"limit_cents": budget.limit_cents,
"usage_percentage": (
budget.current_usage_cents + estimated_cost
)
/ budget.limit_cents
* 100,
}
)
logger.info(
f"Budget warning for API key {api_key.id}: {warning_msg}"
)
return True, None, warnings
except Exception as e:
logger.error(f"Error checking budget compliance: {e}")
# Allow request on error to avoid blocking legitimate usage
return True, None, []
def record_usage(
self,
api_key: APIKey,
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None,
) -> List[Budget]:
"""
Record actual usage against applicable budgets
Args:
api_key: API key that made the request
model_name: Model that was used
input_tokens: Actual input tokens used
output_tokens: Actual output tokens used
endpoint: API endpoint that was accessed
Returns:
List of budgets that were updated
"""
try:
# Calculate actual cost
actual_cost = CostCalculator.calculate_cost_cents(
model_name, input_tokens, output_tokens
)
# Get applicable budgets
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
updated_budgets = []
for budget in budgets:
if budget.is_active and budget.is_in_period():
# Add usage to budget
budget.add_usage(actual_cost)
updated_budgets.append(budget)
logger.debug(
f"Recorded usage for budget {budget.id}: "
f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
)
# Commit changes
self.db.commit()
return updated_budgets
except Exception as e:
logger.error(f"Error recording budget usage: {e}")
self.db.rollback()
return []
def _get_applicable_budgets(
self, api_key: APIKey, model_name: str = None, endpoint: str = None
) -> List[Budget]:
"""Get budgets that apply to the given request"""
# Build query conditions
conditions = [
Budget.is_active == True,
or_(
and_(
Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)
),  # User budget
Budget.api_key_id == api_key.id,  # API key specific budget
),
]
# Query budgets
query = self.db.query(Budget).filter(and_(*conditions))
budgets = query.all()
# Filter budgets based on allowed models/endpoints
applicable_budgets = []
for budget in budgets:
# Check model restrictions
if model_name and budget.allowed_models:
if model_name not in budget.allowed_models:
continue
# Check endpoint restrictions
if endpoint and budget.allowed_endpoints:
if endpoint not in budget.allowed_endpoints:
continue
applicable_budgets.append(budget)
return applicable_budgets
def _reset_expired_budget(self, budget: Budget):
"""Reset an expired budget for the next period"""
try:
budget.reset_period()
self.db.commit()
logger.info(
f"Reset expired budget {budget.id} for new period: "
f"{budget.period_start} to {budget.period_end}"
)
except Exception as e:
logger.error(f"Error resetting expired budget {budget.id}: {e}")
self.db.rollback()
def get_budget_status(self, api_key: APIKey) -> Dict[str, Any]:
"""Get comprehensive budget status for an API key"""
try:
budgets = self._get_applicable_budgets(api_key)
status = {
"total_budgets": len(budgets),
"active_budgets": 0,
@@ -492,44 +562,53 @@ class BudgetEnforcementService:
"warning_budgets": 0,
"total_limit_cents": 0,
"total_usage_cents": 0,
"budgets": []
"budgets": [],
}
for budget in budgets:
if not budget.is_active:
continue
budget_info = budget.to_dict()
budget_info.update(
{
"is_expired": budget.is_expired(),
"days_remaining": budget.get_period_days_remaining(),
"daily_burn_rate": budget.get_daily_burn_rate(),
"projected_spend": budget.get_projected_spend(),
}
)
status["budgets"].append(budget_info)
status["active_budgets"] += 1
status["total_limit_cents"] += budget.limit_cents
status["total_usage_cents"] += budget.current_usage_cents
if budget.is_exceeded:
status["exceeded_budgets"] += 1
elif (
budget.warning_threshold_cents
and budget.current_usage_cents >= budget.warning_threshold_cents
):
status["warning_budgets"] += 1
# Calculate overall percentages
if status["total_limit_cents"] > 0:
status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
status["overall_usage_percentage"] = (
status["total_usage_cents"] / status["total_limit_cents"]
) * 100
else:
status["overall_usage_percentage"] = 0
status["total_limit_dollars"] = status["total_limit_cents"] / 100
status["total_usage_dollars"] = status["total_usage_cents"] / 100
status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
status["total_remaining_cents"] = max(
0, status["total_limit_cents"] - status["total_usage_cents"]
)
status["total_remaining_dollars"] = status["total_remaining_cents"] / 100
return status
except Exception as e:
logger.error(f"Error getting budget status: {e}")
return {
@@ -538,14 +617,11 @@ class BudgetEnforcementService:
"active_budgets": 0,
"exceeded_budgets": 0,
"warning_budgets": 0,
"budgets": []
"budgets": [],
}
def create_default_user_budget(
self, user_id: int, limit_dollars: float = 10.0, period_type: str = "monthly"
) -> Budget:
"""Create a default budget for a new user"""
try:
@@ -553,60 +629,69 @@ class BudgetEnforcementService:
budget = Budget.create_monthly_budget(
user_id=user_id,
name="Default Monthly Budget",
limit_dollars=limit_dollars,
)
else:
budget = Budget.create_daily_budget(
user_id=user_id,
name="Default Daily Budget",
limit_dollars=limit_dollars,
)
self.db.add(budget)
self.db.commit()
logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")
logger.info(
f"Created default budget for user {user_id}: ${limit_dollars} {period_type}"
)
return budget
except Exception as e:
logger.error(f"Error creating default budget: {e}")
self.db.rollback()
raise
def check_and_reset_expired_budgets(self):
"""Background task to check and reset expired budgets"""
try:
expired_budgets = (
self.db.query(Budget)
.filter(
and_(
Budget.is_active == True,
Budget.auto_renew == True,
Budget.period_end < datetime.utcnow(),
)
)
.all()
)
for budget in expired_budgets:
self._reset_expired_budget(budget)
logger.info(f"Reset {len(expired_budgets)} expired budgets")
except Exception as e:
logger.error(f"Error in budget reset task: {e}")
# Convenience functions
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
db: Session,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""DEPRECATED: Convenience function to check budget compliance (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.check_budget_compliance(
api_key, model_name, estimated_tokens, endpoint
)
def record_request_usage(
@@ -615,11 +700,13 @@ def record_request_usage(
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None,
) -> List[Budget]:
"""DEPRECATED: Convenience function to record actual usage (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.record_usage(
api_key, model_name, input_tokens, output_tokens, endpoint
)
# ATOMIC VERSIONS: Race-condition-free budget enforcement
@@ -628,11 +715,13 @@ def atomic_check_and_reserve_budget(
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Atomic convenience function to check budget compliance and reserve spending"""
service = BudgetEnforcementService(db)
return service.atomic_check_and_reserve_budget(
api_key, model_name, estimated_tokens, endpoint
)
def atomic_finalize_usage(
@@ -642,8 +731,10 @@ def atomic_finalize_usage(
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None,
) -> List[Budget]:
"""Atomic convenience function to finalize actual usage after request completion"""
service = BudgetEnforcementService(db)
return service.atomic_finalize_usage(
reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint
)
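Putting the two atomic halves together, a request handler would reserve before inference and finalize afterwards. A sketch under stated assumptions: db, api_key, the model name, and run_inference are placeholders for whatever the caller already has, and the leading db argument is assumed to match the service methods above:

allowed, error, warnings, reserved_ids = atomic_check_and_reserve_budget(
    db, api_key, "example-model", estimated_tokens=1500, endpoint="/api/v1/llm/chat/completions"
)
if not allowed:
    raise HTTPException(status_code=402, detail=error)
try:
    result = run_inference()  # hypothetical upstream call
finally:
    atomic_finalize_usage(
        db, reserved_ids, api_key, "example-model",
        input_tokens=1200, output_tokens=300,
        endpoint="/api/v1/llm/chat/completions",
    )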

View File

@@ -23,17 +23,21 @@ logger = logging.getLogger(__name__)
class CachedAPIKeyService:
"""Core cache-backed API key caching service for performance optimization"""
def __init__(self):
self.cache_ttl = 300 # 5 minutes cache TTL
self.verification_cache_ttl = 3600 # 1 hour for verification results
logger.info("Cached API key service initialized with core cache backend")
async def close(self):
"""Close method for compatibility - core cache handles its own lifecycle"""
logger.info("Cached API key service close called - core cache handles lifecycle")
async def get_cached_api_key(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
logger.info(
"Cached API key service close called - core cache handles lifecycle"
)
async def get_cached_api_key(
self, key_prefix: str, db: AsyncSession
) -> Optional[Dict[str, Any]]:
"""
Get API key data from cache or database
Returns: Dictionary with api_key, user, and api_key_id
@@ -43,59 +47,54 @@ class CachedAPIKeyService:
cached_data = await core_cache.get_cached_api_key(key_prefix)
if cached_data:
logger.debug(f"API key cache hit for prefix: {key_prefix}")
# Recreate APIKey object from cached data
api_key_data = cached_data.get("api_key_data", {})
user_data = cached_data.get("user_data", {})
# Create APIKey instance
api_key = APIKey(**api_key_data)
# Create User instance
user = User(**user_data)
return {
"api_key": api_key,
"user": user,
"api_key_id": api_key_data.get("id")
"api_key_id": api_key_data.get("id"),
}
logger.debug(f"API key cache miss for prefix: {key_prefix}, fetching from database")
logger.debug(
f"API key cache miss for prefix: {key_prefix}, fetching from database"
)
# Cache miss - fetch from database with optimized query
stmt = (
select(APIKey, User)
.join(User, APIKey.user_id == User.id)
.options(joinedload(APIKey.user), joinedload(User.api_keys))
.where(APIKey.key_prefix == key_prefix)
.where(APIKey.is_active == True)
)
result = await db.execute(stmt)
api_key_user = result.first()
if not api_key_user:
logger.debug(f"API key not found in database for prefix: {key_prefix}")
return None
api_key, user = api_key_user
# Cache for future requests
await self._cache_api_key_data(key_prefix, api_key, user)
return {"api_key": api_key, "user": user, "api_key_id": api_key.id}
except Exception as e:
logger.error(f"Error retrieving API key for prefix {key_prefix}: {e}")
return None
async def _cache_api_key_data(self, key_prefix: str, api_key: APIKey, user: User):
"""Cache API key and user data"""
try:
@@ -118,17 +117,25 @@ class CachedAPIKeyService:
"allowed_ips": api_key.allowed_ips,
"description": api_key.description,
"tags": api_key.tags,
"created_at": api_key.created_at.isoformat() if api_key.created_at else None,
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat()
if api_key.created_at
else None,
"updated_at": api_key.updated_at.isoformat()
if api_key.updated_at
else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
"expires_at": api_key.expires_at.isoformat()
if api_key.expires_at
else None,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost": api_key.total_cost,
"is_unlimited": api_key.is_unlimited,
"budget_limit_cents": api_key.budget_limit_cents,
"budget_type": api_key.budget_type,
"allowed_chatbots": api_key.allowed_chatbots
"allowed_chatbots": api_key.allowed_chatbots,
},
"user_data": {
"id": user.id,
@@ -137,20 +144,28 @@ class CachedAPIKeyService:
"is_active": user.is_active,
"is_superuser": user.is_superuser,
"role": user.role,
"created_at": user.created_at.isoformat() if user.created_at else None,
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
"last_login": user.last_login.isoformat() if user.last_login else None
"created_at": user.created_at.isoformat()
if user.created_at
else None,
"updated_at": user.updated_at.isoformat()
if user.updated_at
else None,
"last_login": user.last_login.isoformat()
if user.last_login
else None,
},
"cached_at": datetime.utcnow().isoformat()
"cached_at": datetime.utcnow().isoformat(),
}
await core_cache.cache_api_key(key_prefix, cache_data, self.cache_ttl)
logger.debug(f"Cached API key data for prefix: {key_prefix}")
except Exception as e:
logger.error(f"Error caching API key data for prefix {key_prefix}: {e}")
async def verify_api_key_cached(
self, api_key: str, key_prefix: str
) -> Optional[bool]:
"""
Verify API key using cached hash to avoid expensive bcrypt operations
Returns: True if verified, False if invalid, None if not cached
@@ -158,73 +173,88 @@ class CachedAPIKeyService:
try:
# Check verification cache
cached_verification = await core_cache.get_cached_verification(key_prefix)
if cached_verification:
# Check if cache is still valid (within TTL)
cached_timestamp = datetime.fromisoformat(cached_verification["timestamp"])
if datetime.utcnow() - cached_timestamp < timedelta(seconds=self.verification_cache_ttl):
logger.debug(f"API key verification cache hit for prefix: {key_prefix}")
cached_timestamp = datetime.fromisoformat(
cached_verification["timestamp"]
)
if datetime.utcnow() - cached_timestamp < timedelta(
seconds=self.verification_cache_ttl
):
logger.debug(
f"API key verification cache hit for prefix: {key_prefix}"
)
return cached_verification.get("is_valid", False)
return None # Not cached or expired
except Exception as e:
logger.error(f"Error checking verification cache for prefix {key_prefix}: {e}")
logger.error(
f"Error checking verification cache for prefix {key_prefix}: {e}"
)
return None
async def cache_verification_result(
self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool
):
"""Cache API key verification result to avoid expensive bcrypt operations"""
try:
await core_cache.cache_verification_result(
api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl
)
logger.debug(f"Cached verification result for prefix: {key_prefix}")
except Exception as e:
logger.error(f"Error caching verification result for prefix {key_prefix}: {e}")
logger.error(
f"Error caching verification result for prefix {key_prefix}: {e}"
)
async def invalidate_api_key_cache(self, key_prefix: str):
"""Invalidate cached API key data"""
try:
await core_cache.invalidate_api_key(key_prefix)
# Also invalidate verification cache
verification_keys = await core_cache.clear_pattern(f"verify:{key_prefix}*", prefix="auth")
verification_keys = await core_cache.clear_pattern(
f"verify:{key_prefix}*", prefix="auth"
)
logger.debug(f"Invalidated cache for API key prefix: {key_prefix}")
except Exception as e:
logger.error(f"Error invalidating cache for prefix {key_prefix}: {e}")
async def update_last_used(self, api_key_id: int, db: AsyncSession):
"""Update last used timestamp asynchronously for performance"""
try:
# Use core cache to track update requests to avoid database spam
cache_key = f"last_used_update:{api_key_id}"
# Check if we recently updated (within 5 minutes)
last_update = await core_cache.get(cache_key, prefix="perf")
if last_update:
return # Skip update if recent
# Update database
stmt = select(APIKey).where(APIKey.id == api_key_id)
result = await db.execute(stmt)
api_key = result.scalar_one_or_none()
if api_key:
api_key.last_used_at = datetime.utcnow()
await db.commit()
# Cache that we updated to prevent spam
await core_cache.set(
cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf"
)
logger.debug(f"Updated last_used_at for API key {api_key_id}")
except Exception as e:
logger.error(f"Error updating last_used for API key {api_key_id}: {e}")
async def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache performance statistics"""
try:
@@ -232,16 +262,16 @@ class CachedAPIKeyService:
return {
"cache_backend": "core_cache",
"cache_enabled": core_stats.get("enabled", False),
"cache_stats": core_stats
"cache_stats": core_stats,
}
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return {
"cache_backend": "core_cache",
"cache_backend": "core_cache",
"cache_enabled": False,
"error": str(e)
"error": str(e),
}
# Global instance
cached_api_key_service = CachedAPIKeyService()
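The intended cache-aside authentication flow, sketched below — the 8-character prefix scheme, the key_hash field, and the verify_hash helper are assumptions for illustration, not platform code:

async def authenticate(raw_key: str, db) -> bool:
    prefix = raw_key[:8]  # assumed prefix scheme
    cached = await cached_api_key_service.verify_api_key_cached(raw_key, prefix)
    if cached is not None:
        return cached  # fast path: no bcrypt work
    record = await cached_api_key_service.get_cached_api_key(prefix, db)
    if record is None:
        return False
    ok = verify_hash(raw_key, record["api_key"].key_hash)  # hypothetical bcrypt check
    await cached_api_key_service.cache_verification_result(
        raw_key, prefix, record["api_key"].key_hash, ok
    )
    return ok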

View File

@@ -25,6 +25,7 @@ logger = get_logger(__name__)
@dataclass
class ConfigVersion:
"""Configuration version metadata"""
version: str
timestamp: datetime
checksum: str
@@ -36,6 +37,7 @@ class ConfigVersion:
@dataclass
class ConfigSchema:
"""Configuration schema definition"""
name: str
required_fields: List[str]
optional_fields: List[str]
@@ -46,6 +48,7 @@ class ConfigSchema:
@dataclass
class ConfigStats:
"""Configuration manager statistics"""
total_configs: int
active_watchers: int
config_versions: int
@@ -57,42 +60,42 @@ class ConfigStats:
class ConfigWatcher(FileSystemEventHandler):
"""File system watcher for configuration changes"""
def __init__(self, config_manager):
self.config_manager = config_manager
self.debounce_time = 1.0 # 1 second debounce
self.last_modified = {}
def on_modified(self, event):
if event.is_directory:
return
path = event.src_path
current_time = time.time()
# Debounce rapid file changes
if path in self.last_modified:
if current_time - self.last_modified[path] < self.debounce_time:
return
self.last_modified[path] = current_time
# Trigger hot reload for config files
if path.endswith((".json", ".yaml", ".yml", ".toml")):
# Schedule coroutine in a thread-safe way
try:
loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(
lambda: asyncio.create_task(
self.config_manager.reload_config_file(path)
)
)
except RuntimeError:
# No running loop, schedule for later
threading.Thread(
target=self._schedule_reload, args=(path,), daemon=True
).start()
def _schedule_reload(self, path: str):
"""Schedule reload in a new thread if no event loop is available"""
try:
@@ -103,14 +106,14 @@ class ConfigWatcher(FileSystemEventHandler):
class ConfigManager:
"""Core configuration management system"""
def __init__(self):
self.configs: Dict[str, Dict[str, Any]] = {}
self.schemas: Dict[str, ConfigSchema] = {}
self.versions: Dict[str, List[ConfigVersion]] = {}
self.watchers: Dict[str, Observer] = {}
self.config_paths: Dict[str, Path] = {}
self.environment = os.getenv("ENVIRONMENT", "development")
self.start_time = time.time()
self.stats = ConfigStats(
total_configs=0,
@@ -119,32 +122,32 @@ class ConfigManager:
hot_reloads_performed=0,
validation_errors=0,
last_reload_time=datetime.now(),
uptime=0,
)
# Base configuration directories
self.config_base_dir = Path("configs")
self.config_base_dir.mkdir(exist_ok=True)
# Environment-specific directory
self.env_config_dir = self.config_base_dir / self.environment
self.env_config_dir.mkdir(exist_ok=True)
logger.info(f"ConfigManager initialized for environment: {self.environment}")
def register_schema(self, name: str, schema: ConfigSchema):
"""Register a configuration schema for validation"""
self.schemas[name] = schema
logger.info(f"Registered configuration schema: {name}")
def validate_config(self, name: str, config_data: Dict[str, Any]) -> bool:
"""Validate configuration against registered schema"""
if name not in self.schemas:
logger.debug(f"No schema registered for config: {name}")
return True
schema = self.schemas[name]
try:
# Check required fields
for field in schema.required_fields:
@@ -152,189 +155,202 @@ class ConfigManager:
logger.error(f"Missing required field '{field}' in config '{name}'")
self.stats.validation_errors += 1
return False
# Validate field types
for field, expected_type in schema.field_types.items():
if field in config_data:
if not isinstance(config_data[field], expected_type):
logger.error(f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}")
logger.error(
f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}"
)
self.stats.validation_errors += 1
return False
# Run custom validators
for field, validator in schema.validators.items():
if field in config_data:
if not validator(config_data[field]):
logger.error(f"Validation failed for field '{field}' in config '{name}'")
logger.error(
f"Validation failed for field '{field}' in config '{name}'"
)
self.stats.validation_errors += 1
return False
logger.debug(f"Configuration '{name}' passed validation")
return True
except Exception as e:
logger.error(f"Error validating config '{name}': {str(e)}")
self.stats.validation_errors += 1
return False
def _calculate_checksum(self, data: Dict[str, Any]) -> str:
"""Calculate checksum for configuration data"""
json_str = json.dumps(data, sort_keys=True)
return hashlib.sha256(json_str.encode()).hexdigest()
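Because the dump uses sort_keys=True, two configs with the same content always hash identically regardless of key insertion order — a quick standalone check:

import hashlib, json

def checksum(data):
    return hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest()

assert checksum({"a": 1, "b": 2}) == checksum({"b": 2, "a": 1})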
def _create_version(
self, name: str, config_data: Dict[str, Any], description: str = "Auto-save"
) -> ConfigVersion:
"""Create a new configuration version"""
version_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checksum = self._calculate_checksum(config_data)
version = ConfigVersion(
version=version_id,
timestamp=datetime.now(),
checksum=checksum,
author=os.getenv("USER", "system"),
description=description,
config_data=config_data.copy(),
)
if name not in self.versions:
self.versions[name] = []
self.versions[name].append(version)
# Keep only last 10 versions
if len(self.versions[name]) > 10:
self.versions[name] = self.versions[name][-10:]
self.stats.config_versions += 1
logger.debug(f"Created version {version_id} for config '{name}'")
return version
async def set_config(
self, name: str, config_data: Dict[str, Any], description: str = "Manual update"
) -> bool:
"""Set configuration with validation and versioning"""
try:
# Validate configuration
if not self.validate_config(name, config_data):
return False
# Create version before updating
self._create_version(name, config_data, description)
# Store configuration
self.configs[name] = config_data.copy()
self.stats.total_configs = len(self.configs)
# Save to file
await self._save_config_to_file(name, config_data)
logger.info(f"Configuration '{name}' updated successfully")
return True
except Exception as e:
logger.error(f"Error setting config '{name}': {str(e)}")
return False
async def get_config(self, name: str, default: Any = None) -> Any:
"""Get configuration value"""
if name in self.configs:
return self.configs[name]
# Try to load from file if not in memory
config_data = await self._load_config_from_file(name)
if config_data is not None:
self.configs[name] = config_data
return config_data
return default
async def get_config_value(
self, config_name: str, key: str, default: Any = None
) -> Any:
"""Get specific value from configuration"""
config = await self.get_config(config_name)
if config is None:
return default
keys = key.split(".")
value = config
try:
for k in keys:
value = value[k]
return value
except (KeyError, TypeError):
return default
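The dotted-key lookup walks one dictionary level per dot, so "database.pool.size" reaches a value nested three dicts deep — a standalone illustration with an invented config:

config = {"database": {"pool": {"size": 5}}}
value = config
for k in "database.pool.size".split("."):
    value = value[k]
assert value == 5  # a missing key would fall back to the default instead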
async def _save_config_to_file(self, name: str, config_data: Dict[str, Any]):
"""Save configuration to file"""
file_path = self.env_config_dir / f"{name}.json"
try:
# Save as regular JSON
with open(file_path, "w") as f:
json.dump(config_data, f, indent=2)
logger.debug(f"Saved config '{name}' to {file_path}")
self.config_paths[name] = file_path
except Exception as e:
logger.error(f"Error saving config '{name}' to file: {str(e)}")
raise
async def _load_config_from_file(self, name: str) -> Optional[Dict[str, Any]]:
"""Load configuration from file"""
file_path = self.env_config_dir / f"{name}.json"
if not file_path.exists():
return None
try:
# Load regular JSON
with open(file_path, "r") as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading config '{name}' from file: {str(e)}")
return None
async def reload_config_file(self, file_path: str):
"""Hot reload configuration from file change"""
try:
path = Path(file_path)
config_name = path.stem
# Load updated configuration
if path.suffix == ".json":
with open(path, "r") as f:
new_config = json.load(f)
elif path.suffix in [".yaml", ".yml"]:
with open(path, "r") as f:
new_config = yaml.safe_load(f)
else:
logger.warning(f"Unsupported config file format: {path.suffix}")
return
# Validate and update
if self.validate_config(config_name, new_config):
self.configs[config_name] = new_config
self.stats.hot_reloads_performed += 1
self.stats.last_reload_time = datetime.now()
logger.info(f"Hot reloaded configuration '{config_name}' from {file_path}")
logger.info(
f"Hot reloaded configuration '{config_name}' from {file_path}"
)
else:
logger.error(f"Failed to hot reload '{config_name}' - validation failed")
logger.error(
f"Failed to hot reload '{config_name}' - validation failed"
)
except Exception as e:
logger.error(f"Error hot reloading config from {file_path}: {str(e)}")
def get_stats(self) -> Dict[str, Any]:
"""Get configuration management statistics"""
self.stats.uptime = time.time() - self.start_time
return asdict(self.stats)
async def cleanup(self):
"""Cleanup resources"""
for watcher in self.watchers.values():
watcher.stop()
watcher.join()
self.watchers.clear()
logger.info("Configuration management cleanup completed")
@@ -355,37 +371,37 @@ async def init_config_manager():
"""Initialize the global config manager"""
global config_manager
config_manager = ConfigManager()
# Register default schemas
await _register_default_schemas()
# Load default configurations
await _load_default_configs()
logger.info("Configuration manager initialized")
async def _register_default_schemas():
"""Register default configuration schemas"""
manager = get_config_manager()
# Database schema
db_schema = ConfigSchema(
name="database",
required_fields=["host", "port", "name"],
optional_fields=["username", "password", "ssl"],
field_types={"host": str, "port": int, "name": str, "ssl": bool},
validators={"port": lambda x: 1 <= x <= 65535}
validators={"port": lambda x: 1 <= x <= 65535},
)
manager.register_schema("database", db_schema)
# Cache schema
cache_schema = ConfigSchema(
name="cache",
required_fields=["redis_url"],
optional_fields=["timeout", "max_connections"],
field_types={"redis_url": str, "timeout": int, "max_connections": int},
validators={"timeout": lambda x: x > 0}
validators={"timeout": lambda x: x > 0},
)
manager.register_schema("cache", cache_schema)
@@ -393,21 +409,21 @@ async def _register_default_schemas():
async def _load_default_configs():
"""Load default configurations"""
manager = get_config_manager()
default_configs = {
"app": {
"name": "Confidential Empire",
"version": "1.0.0",
"debug": manager.environment == "development",
"log_level": "INFO",
"timezone": "UTC"
"timezone": "UTC",
},
"cache": {
"redis_url": "redis://empire-redis:6379/0",
"timeout": 30,
"max_connections": 10
}
"max_connections": 10,
},
}
for name, config in default_configs.items():
await manager.set_config(name, config, description="Default configuration")
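A minimal usage sketch (not part of the commit) of the dot-path lookup that get_config_value implements via key.split("."); get_config_manager is assumed to be this module's accessor for the global instance:

async def example_config_roundtrip():
    manager = get_config_manager()
    await manager.set_config(
        "database",
        {"host": "db.local", "port": 5432, "name": "empire", "ssl": False},
        description="Example write",
    )
    # Dot notation walks nested dicts: "a.b.c" -> config["a"]["b"]["c"]
    return await manager.get_config_value("database", "ssl", default=False)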

View File

@@ -18,19 +18,19 @@ logger = logging.getLogger(__name__)
class ConversationService:
"""Service for managing chatbot conversations and message history"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_or_create_conversation(
self,
chatbot_id: str,
user_id: str,
conversation_id: Optional[str] = None,
title: Optional[str] = None,
) -> ChatbotConversation:
"""Get existing conversation or create a new one"""
# If conversation_id provided, try to get existing conversation
if conversation_id:
stmt = select(ChatbotConversation).where(
@@ -38,22 +38,24 @@ class ConversationService:
ChatbotConversation.id == conversation_id,
ChatbotConversation.chatbot_id == chatbot_id,
ChatbotConversation.user_id == user_id,
ChatbotConversation.is_active == True,
)
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
if conversation:
logger.info(f"Found existing conversation {conversation_id}")
return conversation
else:
logger.warning(f"Conversation {conversation_id} not found or not accessible")
logger.warning(
f"Conversation {conversation_id} not found or not accessible"
)
# Create new conversation
if not title:
title = f"Chat {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
conversation = ChatbotConversation(
chatbot_id=chatbot_id,
user_id=user_id,
@@ -61,30 +63,29 @@ class ConversationService:
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
is_active=True,
context_data={},
)
self.db.add(conversation)
await self.db.commit()
await self.db.refresh(conversation)
logger.info(f"Created new conversation {conversation.id} for chatbot {chatbot_id}")
logger.info(
f"Created new conversation {conversation.id} for chatbot {chatbot_id}"
)
return conversation
async def get_conversation_history(
self, conversation_id: str, limit: int = 20, include_system: bool = False
) -> List[Dict[str, Any]]:
"""
Load conversation history for a conversation
Args:
conversation_id: ID of the conversation
limit: Maximum number of messages to return (default 20)
include_system: Whether to include system messages (default False)
Returns:
List of messages in chronological order (oldest first)
"""
@@ -93,185 +94,210 @@ class ConversationService:
stmt = select(ChatbotMessage).where(
ChatbotMessage.conversation_id == conversation_id
)
# Optionally exclude system messages
if not include_system:
stmt = stmt.where(ChatbotMessage.role != "system")
# Order by timestamp descending and limit
stmt = stmt.order_by(desc(ChatbotMessage.timestamp)).limit(limit)
result = await self.db.execute(stmt)
messages = result.scalars().all()
# Convert to list and reverse to get chronological order (oldest first)
history = []
for msg in reversed(messages):
history.append(
{
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp.isoformat()
if msg.timestamp
else None,
"metadata": msg.message_metadata or {},
"sources": msg.sources,
}
)
logger.info(
f"Loaded {len(history)} messages for conversation {conversation_id}"
)
return history
except Exception as e:
logger.error(f"Failed to load conversation history for {conversation_id}: {e}")
logger.error(
f"Failed to load conversation history for {conversation_id}: {e}"
)
return [] # Return empty list on error to avoid breaking chat
async def add_message(
self,
conversation_id: str,
role: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
sources: Optional[List[Dict[str, Any]]] = None,
) -> ChatbotMessage:
"""Add a message to a conversation"""
if role not in ["user", "assistant", "system"]:
raise ValueError(f"Invalid message role: {role}")
message = ChatbotMessage(
conversation_id=conversation_id,
role=role,
content=content,
timestamp=datetime.utcnow(),
message_metadata=metadata or {},
sources=sources,
)
self.db.add(message)
# Update conversation timestamp
stmt = select(ChatbotConversation).where(
ChatbotConversation.id == conversation_id
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
if conversation:
conversation.updated_at = datetime.utcnow()
await self.db.commit()
await self.db.refresh(message)
logger.info(f"Added {role} message to conversation {conversation_id}")
return message
async def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
"""Get statistics for a conversation"""
# Count messages by role
stmt = (
select(ChatbotMessage.role, func.count(ChatbotMessage.id).label("count"))
.where(ChatbotMessage.conversation_id == conversation_id)
.group_by(ChatbotMessage.role)
)
result = await self.db.execute(stmt)
role_counts = {row.role: row.count for row in result}
# Get conversation info
stmt = select(ChatbotConversation).where(
ChatbotConversation.id == conversation_id
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
if not conversation:
raise APIException(status_code=404, error_code="CONVERSATION_NOT_FOUND")
return {
"conversation_id": conversation_id,
"title": conversation.title,
"created_at": conversation.created_at.isoformat() if conversation.created_at else None,
"updated_at": conversation.updated_at.isoformat() if conversation.updated_at else None,
"created_at": conversation.created_at.isoformat()
if conversation.created_at
else None,
"updated_at": conversation.updated_at.isoformat()
if conversation.updated_at
else None,
"total_messages": sum(role_counts.values()),
"user_messages": role_counts.get('user', 0),
"assistant_messages": role_counts.get('assistant', 0),
"system_messages": role_counts.get('system', 0)
"user_messages": role_counts.get("user", 0),
"assistant_messages": role_counts.get("assistant", 0),
"system_messages": role_counts.get("system", 0),
}
async def archive_old_conversations(self, days_inactive: int = 30) -> int:
"""Archive conversations that haven't been used in specified days"""
cutoff_date = datetime.utcnow() - timedelta(days=days_inactive)
# Find conversations to archive
stmt = select(ChatbotConversation).where(
and_(
ChatbotConversation.updated_at < cutoff_date,
ChatbotConversation.is_active == True,
)
)
result = await self.db.execute(stmt)
conversations = result.scalars().all()
archived_count = 0
for conversation in conversations:
conversation.is_active = False
archived_count += 1
if archived_count > 0:
await self.db.commit()
logger.info(f"Archived {archived_count} inactive conversations")
return archived_count
async def delete_conversation(self, conversation_id: str, user_id: str) -> bool:
"""Delete a conversation and all its messages"""
# Verify ownership
stmt = (
select(ChatbotConversation)
.where(
and_(
ChatbotConversation.id == conversation_id,
ChatbotConversation.user_id == user_id,
)
)
.options(selectinload(ChatbotConversation.messages))
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
if not conversation:
return False
# Delete all messages first
for message in conversation.messages:
await self.db.delete(message)
# Delete conversation
await self.db.delete(conversation)
await self.db.commit()
logger.info(f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages")
logger.info(
f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages"
)
return True
async def get_user_conversations(
self,
user_id: str,
chatbot_id: Optional[str] = None,
limit: int = 50,
skip: int = 0,
) -> List[Dict[str, Any]]:
"""Get list of conversations for a user"""
stmt = select(ChatbotConversation).where(
and_(
ChatbotConversation.user_id == user_id,
ChatbotConversation.is_active == True,
)
)
if chatbot_id:
stmt = stmt.where(ChatbotConversation.chatbot_id == chatbot_id)
stmt = (
stmt.order_by(desc(ChatbotConversation.updated_at))
.offset(skip)
.limit(limit)
)
result = await self.db.execute(stmt)
conversations = result.scalars().all()
conversation_list = []
for conv in conversations:
# Get message count
@@ -280,15 +306,21 @@ class ConversationService:
)
msg_count_result = await self.db.execute(msg_count_stmt)
message_count = msg_count_result.scalar() or 0
conversation_list.append(
{
"id": conv.id,
"chatbot_id": conv.chatbot_id,
"title": conv.title,
"message_count": message_count,
"created_at": conv.created_at.isoformat()
if conv.created_at
else None,
"updated_at": conv.updated_at.isoformat()
if conv.updated_at
else None,
"context_data": conv.context_data or {},
}
)
return conversation_list
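A hedged end-to-end sketch of the service above; the db session and IDs are illustrative, not taken from the commit:

async def example_chat_turn(db):
    service = ConversationService(db)
    conv = await service.get_or_create_conversation(
        chatbot_id="bot-1", user_id="user-1"
    )
    await service.add_message(conv.id, role="user", content="Hello!")
    await service.add_message(conv.id, role="assistant", content="Hi there.")
    # History returns oldest-first, excluding system messages by default
    return await service.get_conversation_history(conv.id, limit=20)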

View File

@@ -10,60 +10,64 @@ logger = get_logger(__name__)
class CostCalculator:
"""Service for calculating costs based on model usage and token consumption"""
# Model pricing in 1/10000ths of a dollar per 1000 tokens (input/output)
MODEL_PRICING = {
# OpenAI Models
"gpt-4": {"input": 300, "output": 600}, # $0.03/$0.06 per 1K tokens
"gpt-4-turbo": {"input": 100, "output": 300}, # $0.01/$0.03 per 1K tokens
"gpt-3.5-turbo": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
# Anthropic Models
"claude-3-opus": {"input": 150, "output": 750}, # $0.015/$0.075 per 1K tokens
"claude-3-sonnet": {"input": 30, "output": 150}, # $0.003/$0.015 per 1K tokens
"claude-3-haiku": {"input": 25, "output": 125}, # $0.00025/$0.00125 per 1K tokens
"claude-3-haiku": {
"input": 25,
"output": 125,
}, # $0.00025/$0.00125 per 1K tokens
# Google Models
"gemini-pro": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {
"input": 5,
"output": 15,
}, # $0.0005/$0.0015 per 1K tokens
# Privatemode.ai Models (estimated pricing)
"privatemode-llama-70b": {"input": 40, "output": 80}, # Estimated pricing
"privatemode-mixtral": {"input": 20, "output": 40}, # Estimated pricing
# Embedding Models
"text-embedding-ada-002": {"input": 1, "output": 0}, # $0.0001 per 1K tokens
"text-embedding-3-small": {"input": 2, "output": 0}, # $0.00002 per 1K tokens
"text-embedding-3-large": {"input": 13, "output": 0}, # $0.00013 per 1K tokens
}
# Default pricing for unknown models
DEFAULT_PRICING = {"input": 10, "output": 20} # $0.001/$0.002 per 1K tokens
@classmethod
def get_model_pricing(cls, model_name: str) -> Dict[str, int]:
"""Get pricing for a specific model"""
# Normalize model name (remove provider prefixes)
normalized_name = cls._normalize_model_name(model_name)
# Look up pricing
pricing = cls.MODEL_PRICING.get(normalized_name, cls.DEFAULT_PRICING)
logger.debug(f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}")
logger.debug(
f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}"
)
return pricing
@classmethod
def _normalize_model_name(cls, model_name: str) -> str:
"""Normalize model name by removing provider prefixes"""
# Remove common provider prefixes
prefixes = ["openai/", "anthropic/", "google/", "gemini/", "privatemode/"]
normalized = model_name.lower()
for prefix in prefixes:
if normalized.startswith(prefix):
normalized = normalized[len(prefix) :]
break
# Handle special cases
if "claude-3-opus-20240229" in normalized:
return "claude-3-opus"
@@ -75,91 +79,88 @@ class CostCalculator:
return "privatemode-llama-70b"
elif "mistralai/mixtral-8x7b-instruct" in normalized:
return "privatemode-mixtral"
return normalized
@classmethod
def calculate_cost_cents(
cls, model_name: str, input_tokens: int = 0, output_tokens: int = 0
) -> int:
"""
Calculate cost in cents for given token usage
Args:
model_name: Name of the LLM model
input_tokens: Number of input tokens used
output_tokens: Number of output tokens generated
Returns:
Total cost in 1/10000ths of a dollar (the unit used by MODEL_PRICING)
"""
pricing = cls.get_model_pricing(model_name)
# Calculate cost per token type
input_cost_cents = (input_tokens * pricing["input"]) // 1000
output_cost_cents = (output_tokens * pricing["output"]) // 1000
total_cost_cents = input_cost_cents + output_cost_cents
logger.debug(
f"Cost calculation for {model_name}: "
f"input_tokens={input_tokens} (${input_cost_cents/100:.4f}), "
f"output_tokens={output_tokens} (${output_cost_cents/100:.4f}), "
f"total=${total_cost_cents/100:.4f}"
)
return total_cost_cents
@classmethod
def estimate_cost_cents(cls, model_name: str, estimated_tokens: int) -> int:
"""
Estimate cost for a request based on estimated total tokens
Assumes 70% input, 30% output token distribution
Args:
model_name: Name of the LLM model
estimated_tokens: Estimated total tokens for the request
Returns:
Estimated cost in the same 1/10000-dollar units as calculate_cost_cents
"""
input_tokens = int(estimated_tokens * 0.7) # 70% input
output_tokens = int(estimated_tokens * 0.3) # 30% output
return cls.calculate_cost_cents(model_name, input_tokens, output_tokens)
@classmethod
def get_cost_per_1k_tokens(cls, model_name: str) -> Dict[str, float]:
"""
Get cost per 1000 tokens in dollars for display purposes
Args:
model_name: Name of the LLM model
Returns:
Dictionary with input and output costs in dollars per 1K tokens
"""
pricing_cents = cls.get_model_pricing(model_name)
return {
"input": pricing_cents["input"] / 10000, # Convert 1/10000ths to dollars
"output": pricing_cents["output"] / 10000,
"currency": "USD"
"currency": "USD",
}
@classmethod
def get_all_model_pricing(cls) -> Dict[str, Dict[str, float]]:
"""Get pricing for all supported models in dollars"""
pricing_data = {}
for model_name in cls.MODEL_PRICING.keys():
pricing_data[model_name] = cls.get_cost_per_1k_tokens(model_name)
return pricing_data
@classmethod
def format_cost_display(cls, cost_cents: int) -> str:
"""Format cost in 1/1000ths of a dollar for display"""
@@ -172,7 +173,9 @@ class CostCalculator:
# Convenience functions for common operations
def calculate_request_cost(
model_name: str, input_tokens: int, output_tokens: int
) -> int:
"""Calculate cost for a single request"""
return CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
@@ -184,4 +187,4 @@ def estimate_request_cost(model_name: str, estimated_tokens: int) -> int:
def get_model_pricing_display(model_name: str) -> Dict[str, float]:
"""Get model pricing for display"""
return CostCalculator.get_cost_per_1k_tokens(model_name)
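A worked example of the integer arithmetic above, in the 1/10000-dollar units of MODEL_PRICING (values read straight from the table):

# gpt-4 above prices input=300, output=600 per 1K tokens.
cost = CostCalculator.calculate_cost_cents("gpt-4", input_tokens=1000, output_tokens=500)
# (1000 * 300) // 1000 + (500 * 600) // 1000 = 300 + 300 = 600 units -> $0.06
assert cost == 600
# estimate_cost_cents splits the estimate 70/30 before pricing it:
# 1500 tokens -> roughly 1050 input / 450 output -> about 585 units ($0.0585)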

View File

@@ -33,12 +33,13 @@ class ProcessingStatus(str, Enum):
@dataclass
class ProcessingTask:
"""Document processing task"""
document_id: int
priority: int = 1
retry_count: int = 0
max_retries: int = 3
created_at: datetime = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
@@ -46,7 +47,7 @@ class ProcessingTask:
class DocumentProcessor:
"""Async document processor with queue management"""
def __init__(self, max_workers: int = 3, max_queue_size: int = 100):
self.max_workers = max_workers
self.max_queue_size = max_queue_size
@@ -57,49 +58,49 @@ class DocumentProcessor:
"processed_count": 0,
"error_count": 0,
"queue_size": 0,
"active_workers": 0
"active_workers": 0,
}
self._rag_module = None
self._rag_module_lock = asyncio.Lock()
async def start(self):
"""Start the document processor"""
if self.running:
return
self.running = True
logger.info(f"Starting document processor with {self.max_workers} workers")
# Start worker tasks
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.workers.append(worker)
logger.info("Document processor started")
async def stop(self):
"""Stop the document processor"""
if not self.running:
return
self.running = False
logger.info("Stopping document processor...")
# Cancel all workers
for worker in self.workers:
worker.cancel()
# Wait for workers to finish
await asyncio.gather(*self.workers, return_exceptions=True)
self.workers.clear()
logger.info("Document processor stopped")
async def add_task(self, document_id: int, priority: int = 1) -> bool:
"""Add a document processing task to the queue"""
try:
task = ProcessingTask(document_id=document_id, priority=priority)
try:
await asyncio.wait_for(self.processing_queue.put(task), timeout=5.0)
except asyncio.TimeoutError:
@@ -108,47 +109,54 @@ class DocumentProcessor:
document_id,
)
return False
self.stats["queue_size"] = self.processing_queue.qsize()
logger.info(f"Added processing task for document {document_id} (priority: {priority})")
logger.info(
f"Added processing task for document {document_id} (priority: {priority})"
)
return True
except Exception as e:
logger.error(f"Failed to add processing task for document {document_id}: {e}")
logger.error(
f"Failed to add processing task for document {document_id}: {e}"
)
return False
async def _worker(self, worker_name: str):
"""Worker coroutine that processes documents"""
logger.info(f"Started worker: {worker_name}")
while self.running:
task: Optional[ProcessingTask] = None
try:
# Get task from queue (wait up to 1 second)
task = await asyncio.wait_for(self.processing_queue.get(), timeout=1.0)
self.stats["active_workers"] += 1
self.stats["queue_size"] = self.processing_queue.qsize()
logger.info(f"{worker_name}: Processing document {task.document_id}")
# Process the document
success = await self._process_document(task)
if success:
self.stats["processed_count"] += 1
logger.info(f"{worker_name}: Successfully processed document {task.document_id}")
logger.info(
f"{worker_name}: Successfully processed document {task.document_id}"
)
else:
# Retry logic
if task.retry_count < task.max_retries:
task.retry_count += 1
await asyncio.sleep(
2**task.retry_count
) # Exponential backoff
try:
await asyncio.wait_for(
self.processing_queue.put(task), timeout=5.0
)
except asyncio.TimeoutError:
logger.error(
"%s: Failed to requeue document %s due to saturated queue",
@@ -157,11 +165,15 @@ class DocumentProcessor:
)
self.stats["error_count"] += 1
continue
logger.warning(f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})")
logger.warning(
f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})"
)
else:
self.stats["error_count"] += 1
logger.error(f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries")
logger.error(
f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries"
)
except asyncio.TimeoutError:
# No tasks in queue, continue
continue
@@ -183,30 +195,34 @@ class DocumentProcessor:
async def _get_rag_module(self):
"""Resolve and cache the RAG module instance"""
async with self._rag_module_lock:
if self._rag_module and getattr(self._rag_module, "enabled", False):
return self._rag_module
if not module_manager.initialized:
await module_manager.initialize()
rag_module = module_manager.get_module("rag")
if not rag_module:
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("Failed to enable RAG module")
rag_module = module_manager.get_module("rag")
if not rag_module:
raise RuntimeError("RAG module not available after enable attempt")
if not getattr(rag_module, "enabled", True):
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("RAG module is disabled and could not be re-enabled")
rag_module = module_manager.get_module('rag')
if not rag_module or not getattr(rag_module, 'enabled', True):
raise RuntimeError("RAG module is disabled and could not be re-enabled")
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)
rag_module = module_manager.get_module("rag")
if not rag_module or not getattr(rag_module, "enabled", True):
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)
self._rag_module = rag_module
logger.info("DocumentProcessor cached RAG module instance for reuse")
@@ -216,6 +232,7 @@ class DocumentProcessor:
"""Process a single document"""
from datetime import datetime
from app.db.database import async_session_factory
async with async_session_factory() as session:
try:
# Get document from database
@@ -226,11 +243,11 @@ class DocumentProcessor:
)
result = await session.execute(stmt)
document = result.scalar_one_or_none()
if not document:
logger.error(f"Document {task.document_id} not found")
return False
# Update status to processing
document.status = ProcessingStatus.PROCESSING
await session.commit()
@@ -244,43 +261,62 @@ class DocumentProcessor:
if not rag_module or not rag_module.enabled:
raise Exception("RAG module not available or not enabled")
logger.info(f"RAG module loaded successfully for document {task.document_id}")
logger.info(
f"RAG module loaded successfully for document {task.document_id}"
)
# Read file content
logger.info(f"Reading file content for document {task.document_id}: {document.file_path}")
logger.info(
f"Reading file content for document {task.document_id}: {document.file_path}"
)
file_path = Path(document.file_path)
try:
file_content = await asyncio.to_thread(file_path.read_bytes)
except FileNotFoundError:
logger.error(f"File not found for document {task.document_id}: {document.file_path}")
logger.error(
f"File not found for document {task.document_id}: {document.file_path}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = "Document file not found on disk"
await session.commit()
return False
except Exception as exc:
logger.error(f"Failed reading file for document {task.document_id}: {exc}")
logger.error(
f"Failed reading file for document {task.document_id}: {exc}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = f"Failed to read file: {exc}"
await session.commit()
return False
logger.info(f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes")
logger.info(
f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes"
)
# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
logger.info(
f"Starting document processing for document {task.document_id} with RAG module"
)
# Special handling for JSONL files - skip processing phase
if document.file_type == "jsonl":
# For JSONL files, we don't need to process content here
# The optimized JSONL processor will handle everything during indexing
document.converted_content = f"JSONL file with {len(file_content)} bytes"
document.converted_content = (
f"JSONL file with {len(file_content)} bytes"
)
document.word_count = 0 # Will be updated during indexing
document.character_count = len(file_content)
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
document.document_metadata = {
"file_path": document.file_path,
"processed": "jsonl",
}
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
logger.info(
f"JSONL document {task.document_id} marked for optimized processing"
)
else:
# Standard processing for other file types
try:
@@ -289,11 +325,13 @@ class DocumentProcessor:
rag_module.process_document(
file_content,
document.original_filename,
{"file_path": document.file_path}
{"file_path": document.file_path},
),
timeout=300.0, # 5 minute timeout
)
logger.info(
f"Document processing completed for document {task.document_id}"
)
logger.info(f"Document processing completed for document {task.document_id}")
# Update document with processed content
document.converted_content = processed_doc.content
@@ -303,29 +341,35 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
logger.error(
f"Document processing timed out for document {task.document_id}"
)
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
logger.error(
f"Document processing failed for document {task.document_id}: {e}"
)
raise
# Index in RAG system using same RAG module
if rag_module and document.converted_content:
try:
logger.info(f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}")
logger.info(
f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}"
)
# Index the document content in the correct Qdrant collection
doc_metadata = {
"collection_id": document.collection_id,
"document_id": document.id,
"filename": document.original_filename,
"file_type": document.file_type,
**document.document_metadata,
}
# Use the correct Qdrant collection name for this document
# For JSONL files, we need to use the processed document flow
if document.file_type == "jsonl":
# Create a ProcessedDocument for the JSONL processor
from app.modules.rag.main import ProcessedDocument
from datetime import datetime
@@ -333,7 +377,9 @@ class DocumentProcessor:
# Calculate file hash
processed_at = datetime.utcnow()
file_hash = hashlib.md5(
str(document.id).encode()
).hexdigest()
processed_doc = ProcessedDocument(
id=str(document.id),
@@ -341,12 +387,14 @@ class DocumentProcessor:
extracted_text="", # Will be filled by JSONL processor
metadata={
**doc_metadata,
"file_path": document.file_path
"file_path": document.file_path,
},
original_filename=document.original_filename,
file_type=document.file_type,
mime_type=document.mime_type,
language=document.document_metadata.get(
"language", "EN"
),
word_count=0, # Will be updated during processing
sentence_count=0, # Will be updated during processing
entities=[],
@@ -354,16 +402,16 @@ class DocumentProcessor:
processing_time=0.0,
processed_at=processed_at,
file_hash=file_hash,
file_size=document.file_size,
)
# The JSONL processor will read the original file
await asyncio.wait_for(
rag_module.index_processed_document(
processed_doc=processed_doc,
collection_name=document.collection.qdrant_collection_name,
),
timeout=300.0, # 5 minute timeout for JSONL processing
)
else:
# Use standard indexing for other file types
@@ -371,18 +419,22 @@ class DocumentProcessor:
rag_module.index_document(
content=document.converted_content,
metadata=doc_metadata,
collection_name=document.collection.qdrant_collection_name,
),
timeout=120.0, # 2 minute timeout for indexing
)
logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")
logger.info(
f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}"
)
# Update vector count (approximate)
document.vector_count = max(
1, len(document.converted_content) // 1000
)
document.status = ProcessingStatus.INDEXED
document.indexed_at = datetime.utcnow()
# Update collection stats
collection = document.collection
if collection and document.status == ProcessingStatus.INDEXED:
@@ -390,36 +442,38 @@ class DocumentProcessor:
collection.size_bytes += document.file_size
collection.vector_count += document.vector_count
collection.updated_at = datetime.utcnow()
except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
logger.error(
f"Failed to index document {task.document_id} in RAG: {e}"
)
# Mark as error since indexing failed
document.status = ProcessingStatus.ERROR
document.processing_error = f"Indexing failed: {str(e)}"
# Don't raise the exception to avoid retries on indexing failures
await session.commit()
return True
except Exception as e:
# Mark document as error
if "document" in locals() and document:
document.status = ProcessingStatus.ERROR
document.processing_error = str(e)
await session.commit()
logger.error(f"Error processing document {task.document_id}: {e}")
return False
async def get_stats(self) -> Dict[str, Any]:
"""Get processor statistics"""
return {
**self.stats,
"running": self.running,
"worker_count": len(self.workers),
"queue_size": self.processing_queue.qsize()
"queue_size": self.processing_queue.qsize(),
}
async def get_queue_status(self) -> Dict[str, Any]:
"""Get detailed queue status"""
return {
@@ -427,7 +481,7 @@ class DocumentProcessor:
"max_queue_size": self.max_queue_size,
"queue_full": self.processing_queue.full(),
"active_workers": self.stats["active_workers"],
"max_workers": self.max_workers
"max_workers": self.max_workers,
}
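A hedged lifecycle sketch for the processor above; the worker count and document id are illustrative:

async def example_processor_lifecycle():
    processor = DocumentProcessor(max_workers=3, max_queue_size=100)
    await processor.start()
    queued = await processor.add_task(document_id=42, priority=1)
    if not queued:
        pass  # queue stayed full for 5 seconds; caller decides whether to retry
    stats = await processor.get_stats()
    await processor.stop()
    return stats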

View File

@@ -19,7 +19,9 @@ class EmbeddingService:
"""Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(
settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3"
)
self.dimension = 1024 # bge-m3 produces 1024-d vectors
self.initialized = False
self.local_model = None
@@ -56,13 +58,15 @@ class EmbeddingService:
return False
except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.error(
f"Failed to load local embedding model {self.model_name}: {exc}"
)
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding for a single text"""
embeddings = await self.get_embeddings([text])
@@ -102,13 +106,21 @@ class EmbeddingService:
except Exception as exc:
logger.error(f"Local embedding generation failed: {exc}")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(
texts, duration=time.time() - start_time
)
logger.warning("Local embedding model unavailable; using fallback random embeddings")
logger.warning(
"Local embedding model unavailable; using fallback random embeddings"
)
self.backend = "fallback_random"
return self._generate_fallback_embeddings(
texts, duration=time.time() - start_time
)
def _generate_fallback_embeddings(
self, texts: List[str], duration: float = None
) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
@@ -124,30 +136,30 @@ class EmbeddingService:
},
)
return embeddings
def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension or 1024
# Use hash for reproducible random embeddings
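# Note: Python salts str hashes per process unless PYTHONHASHSEED is fixed,
# so these embeddings are only reproducible within one interpreter run.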
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
async def similarity(self, text1: str, text2: str) -> float:
"""Calculate cosine similarity between two texts"""
embeddings = await self.get_embeddings([text1, text2])
# Calculate cosine similarity
vec1 = np.array(embeddings[0])
vec2 = np.array(embeddings[1])
# Normalize vectors
vec1_norm = vec1 / np.linalg.norm(vec1)
vec2_norm = vec2 / np.linalg.norm(vec2)
# Calculate cosine similarity
similarity = np.dot(vec1_norm, vec2_norm)
return float(similarity)
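# Illustrative check (hedged): identical texts yield identical vectors and a
# similarity of 1.0; orthogonal vectors yield 0.0, e.g.
#     v1, v2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
#     assert float(np.dot(v1, v2)) == 0.0  # both already unit-norm
# Caveat: an all-zero embedding makes np.linalg.norm() return 0, so the
# normalization above would divide by zero for degenerate inputs.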
async def get_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics"""
return {
@@ -155,7 +167,7 @@ class EmbeddingService:
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": self.backend,
"initialized": self.initialized
"initialized": self.initialized,
}
async def cleanup(self):

View File

@@ -21,17 +21,30 @@ class EnhancedEmbeddingService(EmbeddingService):
def __init__(self, model_name: Optional[str] = None):
super().__init__(model_name)
self.rate_limit_tracker = {
"requests_count": 0,
"window_start": time.time(),
"window_size": 60, # 1 minute window
"max_requests_per_minute": int(
getattr(settings, "RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", 12)
), # Configurable
"retry_delays": [
int(x)
for x in getattr(
settings, "RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
).split(",")
], # Exponential backoff
"delay_between_batches": float(
getattr(settings, "RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", 1.0)
),
"delay_per_request": float(
getattr(settings, "RAG_EMBEDDING_DELAY_PER_REQUEST", 0.5)
),
"last_rate_limit_error": None,
}
async def get_embeddings_with_retry(
self, texts: List[str], max_retries: int = None
) -> tuple[List[List[float]], bool]:
"""
Get embeddings with retry bookkeeping.
"""
@@ -51,13 +64,20 @@ class EnhancedEmbeddingService(EmbeddingService):
return {
**base_stats,
"rate_limit_info": {
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
"window_reset_in_seconds": max(0,
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
"requests_in_current_window": self.rate_limit_tracker["requests_count"],
"max_requests_per_minute": self.rate_limit_tracker[
"max_requests_per_minute"
],
"window_reset_in_seconds": max(
0,
self.rate_limit_tracker["window_start"]
+ self.rate_limit_tracker["window_size"]
- time.time(),
),
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
}
"last_rate_limit_error": self.rate_limit_tracker[
"last_rate_limit_error"
],
},
}
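A hedged configuration sketch: these are the settings names the tracker above reads via getattr, with the defaults shown there; values are illustrative:

# .env sketch (illustrative values):
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
# RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16        # parsed into [1, 2, 4, 8, 16]
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5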

View File

@@ -14,6 +14,7 @@ from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.http.models import Batch
from app.modules.rag.main import ProcessedDocument
# from app.core.analytics import log_module_event # Analytics module not available
logger = logging.getLogger(__name__)
@@ -26,8 +27,13 @@ class JSONLProcessor:
self.rag_module = rag_module
self.config = rag_module.config
async def process_and_index_jsonl(
self,
collection_name: str,
content: bytes,
filename: str,
metadata: Dict[str, Any],
) -> str:
"""Process and index a JSONL file efficiently
Processes each JSON line as a separate document to avoid
@@ -35,8 +41,8 @@ class JSONLProcessor:
"""
try:
# Decode content
jsonl_content = content.decode("utf-8", errors="replace")
lines = jsonl_content.strip().split("\n")
logger.info(f"Processing JSONL file {filename} with {len(lines)} lines")
@@ -58,28 +64,38 @@ class JSONLProcessor:
batch_start,
base_doc_id,
filename,
metadata,
)
processed_count += len(batch_lines)
# Log progress
if processed_count % 50 == 0:
logger.info(f"Processed {processed_count}/{len(lines)} lines from {filename}")
logger.info(
f"Processed {processed_count}/{len(lines)} lines from {filename}"
)
# Small delay to prevent resource exhaustion
await asyncio.sleep(0.05)
logger.info(f"Successfully processed JSONL file {filename} with {len(lines)} lines")
logger.info(
f"Successfully processed JSONL file {filename} with {len(lines)} lines"
)
return base_doc_id
except Exception as e:
logger.error(f"Error processing JSONL file {filename}: {e}")
raise
async def _process_jsonl_batch(
self,
collection_name: str,
lines: List[str],
start_idx: int,
base_doc_id: str,
filename: str,
metadata: Dict[str, Any],
) -> None:
"""Process a batch of JSONL lines"""
try:
points = []
@@ -98,14 +114,14 @@ class JSONLProcessor:
continue
# Handle helpjuice export format
if "payload" in data and data["payload"] is not None:
payload = data["payload"]
article_id = data.get("id", f"article_{line_idx}")
# Extract Q&A
question = payload.get("question", "")
answer = payload.get("answer", "")
language = payload.get("language", "EN")
if question or answer:
# Create Q&A content
@@ -120,25 +136,29 @@ class JSONLProcessor:
"line_number": line_idx,
"content_type": "qa_pair",
"question": question[:100], # Truncate for metadata
"processed_at": datetime.utcnow().isoformat()
"processed_at": datetime.utcnow().isoformat(),
}
# Generate single embedding for the Q&A pair
embeddings = await self.rag_module._generate_embeddings(
[content]
)
# Create point
point_id = str(uuid.uuid4())
points.append(
PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**doc_metadata,
"document_id": f"{base_doc_id}_{article_id}",
"content": content,
"chunk_index": 0,
"chunk_count": 1,
},
)
)
# Handle generic JSON format
else:
@@ -146,43 +166,55 @@ class JSONLProcessor:
# For larger JSON objects, we might need to chunk
if len(content) > 1000:
chunks = self.rag_module._chunk_text(
content, chunk_size=500
)
embeddings = await self.rag_module._generate_embeddings(
chunks
)
for i, (chunk, embedding) in enumerate(
zip(chunks, embeddings)
):
point_id = str(uuid.uuid4())
points.append(
PointStruct(
id=point_id,
vector=embedding,
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks),
},
)
)
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings(
[content]
)
point_id = str(uuid.uuid4())
points.append(
PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks)
}
))
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings([content])
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
"content": content,
"chunk_index": 0,
"chunk_count": 1,
},
)
)
except json.JSONDecodeError as e:
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
@@ -194,8 +226,7 @@ class JSONLProcessor:
# Insert all points in this batch
if points:
self.rag_module.qdrant_client.upsert(
collection_name=collection_name, points=points
)
# Update stats
@@ -208,4 +239,4 @@ class JSONLProcessor:
except Exception as e:
logger.error(f"Error processing JSONL batch: {e}")
raise
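An illustrative input line for the helpjuice branch above (assumed shape, not taken from the commit):

# {"id": "a1", "payload": {"question": "How do I reset my password?",
#   "answer": "Use the reset link on the login page.", "language": "EN"}}
# -> one qa_pair point whose document_id is f"{base_doc_id}_a1". Any other
#    JSON object falls through to the generic branch and is chunked only when
#    its serialized form exceeds 1000 characters.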

View File

@@ -11,11 +11,11 @@ from .exceptions import LLMError, ProviderError, SecurityError
__all__ = [
"LLMService",
"ChatRequest",
"ChatRequest",
"ChatResponse",
"EmbeddingRequest",
"EmbeddingResponse",
"EmbeddingResponse",
"LLMError",
"ProviderError",
"SecurityError"
]
"SecurityError",
]

View File

@@ -15,30 +15,51 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""
name: str = Field(..., description="Provider name")
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
provider_type: str = Field(
..., description="Provider type (e.g., 'openai', 'privatemode')"
)
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
default_model: Optional[str] = Field(None, description="Default model for this provider")
supported_models: List[str] = Field(default_factory=list, description="List of supported models")
capabilities: List[str] = Field(default_factory=list, description="Provider capabilities")
default_model: Optional[str] = Field(
None, description="Default model for this provider"
)
supported_models: List[str] = Field(
default_factory=list, description="List of supported models"
)
capabilities: List[str] = Field(
default_factory=list, description="Provider capabilities"
)
priority: int = Field(1, description="Provider priority (lower = higher priority)")
# Rate limiting
max_requests_per_minute: Optional[int] = Field(
None, description="Max requests per minute"
)
max_requests_per_hour: Optional[int] = Field(
None, description="Max requests per hour"
)
# Model-specific settings
supports_streaming: bool = Field(
False, description="Whether provider supports streaming"
)
supports_function_calling: bool = Field(
False, description="Whether provider supports function calling"
)
max_context_window: Optional[int] = Field(
None, description="Maximum context window size"
)
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
# Resilience configuration
resilience: ResilienceConfig = Field(
default_factory=ResilienceConfig, description="Resilience settings"
)
@validator("priority")
def validate_priority(cls, v):
if v < 1:
raise ValueError("Priority must be >= 1")
@@ -47,35 +68,48 @@ class ProviderConfig(BaseModel):
class LLMServiceConfig(BaseModel):
"""Main LLM service configuration"""
# Global settings
default_provider: str = Field("privatemode", description="Default provider to use")
enable_detailed_logging: bool = Field(
False, description="Enable detailed request/response logging"
)
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(
True, description="Enable metrics collection"
)
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(
32000, ge=1000, description="Maximum response length"
)
# Performance settings
default_timeout_ms: int = Field(
30000, ge=1000, le=300000, description="Default request timeout"
)
max_concurrent_requests: int = Field(
100, ge=1, le=1000, description="Maximum concurrent requests"
)
# Provider configurations
providers: Dict[str, ProviderConfig] = Field(
default_factory=dict, description="Provider configurations"
)
# Token rate limiting (organization-wide)
token_limits_per_minute: Dict[str, int] = Field(
default_factory=lambda: {
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000 # PrivateMode Standard tier
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000, # PrivateMode Standard tier
},
description="Token rate limits per minute (organization-wide)"
description="Token rate limits per minute (organization-wide)",
)
# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(
default_factory=dict, description="Model to provider routing"
)
def create_default_config(env_vars=None) -> LLMServiceConfig:
@@ -94,8 +128,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
supported_models=[], # Will be populated dynamically from proxy
capabilities=["chat", "embeddings", "tee"],
priority=1,
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
max_requests_per_hour=1200, # 20 req/min * 60 min
supports_streaming=True,
supports_function_calling=True,
max_context_window=128000,
@@ -105,13 +139,11 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000, # PrivateMode may be slower due to TEE
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=120000,
),
)
providers: Dict[str, ProviderConfig] = {"privatemode": privatemode_config}
if env.OPENAI_API_KEY:
providers["openai"] = ProviderConfig(
@@ -126,7 +158,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
"gpt-4o",
"gpt-3.5-turbo",
"text-embedding-3-large",
"text-embedding-3-small"
"text-embedding-3-small",
],
capabilities=["chat", "embeddings"],
priority=2,
@@ -139,10 +171,10 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=750,
timeout_ms=45000,
circuit_breaker_threshold=6,
circuit_breaker_reset_timeout_ms=60000,
),
)
if env.ANTHROPIC_API_KEY:
providers["anthropic"] = ProviderConfig(
name="anthropic",
@@ -154,7 +186,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
supported_models=[
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
"claude-3-haiku-20240307",
],
capabilities=["chat"],
priority=3,
@@ -167,10 +199,10 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000,
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=90000,
),
)
if env.GOOGLE_API_KEY:
providers["google"] = ProviderConfig(
name="google",
@@ -181,7 +213,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
default_model="models/gemini-1.5-pro-latest",
supported_models=[
"models/gemini-1.5-pro-latest",
"models/gemini-1.5-flash-latest"
"models/gemini-1.5-flash-latest",
],
capabilities=["chat", "multimodal"],
priority=4,
@@ -194,169 +226,176 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=45000,
circuit_breaker_threshold=4,
circuit_breaker_reset_timeout_ms=60000,
),
)
default_provider = next(
(name for name, provider in providers.items() if provider.enabled),
"privatemode"
"privatemode",
)
# Create main configuration
config = LLMServiceConfig(
default_provider=default_provider,
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
providers=providers,
model_routing={}, # Will be populated dynamically from provider models
)
return config
@dataclass
class EnvironmentVariables:
"""Environment variables used by LLM service"""
# Provider API keys
PRIVATEMODE_API_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
# Service settings
LOG_LLM_PROMPTS: bool = False
def __post_init__(self):
"""Load values from environment"""
self.PRIVATEMODE_API_KEY = os.getenv("PRIVATEMODE_API_KEY")
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
self.ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
self.GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
self.LOG_LLM_PROMPTS = os.getenv("LOG_LLM_PROMPTS", "false").lower() == "true"
def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for a specific provider"""
key_mapping = {
"privatemode": self.PRIVATEMODE_API_KEY,
"openai": self.OPENAI_API_KEY,
"anthropic": self.ANTHROPIC_API_KEY,
"google": self.GOOGLE_API_KEY
"google": self.GOOGLE_API_KEY,
}
return key_mapping.get(provider_name.lower())
def validate_required_keys(self, enabled_providers: List[str]) -> List[str]:
"""Validate that required API keys are present"""
missing_keys = []
for provider in enabled_providers:
if not self.get_api_key(provider):
missing_keys.append(f"{provider.upper()}_API_KEY")
return missing_keys
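A hedged startup sketch using the helper above to fail fast on misconfiguration:

env = EnvironmentVariables()
missing = env.validate_required_keys(["privatemode", "openai"])
if missing:
    raise RuntimeError(f"Missing API keys: {', '.join(missing)}")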
class ConfigurationManager:
"""Manages LLM service configuration"""
def __init__(self):
self._config: Optional[LLMServiceConfig] = None
self._env_vars = EnvironmentVariables()
def get_config(self) -> LLMServiceConfig:
"""Get current configuration"""
if self._config is None:
self._config = create_default_config(self._env_vars)
self._validate_configuration()
return self._config
def update_config(self, config: LLMServiceConfig):
"""Update configuration"""
self._config = config
self._validate_configuration()
def get_provider_config(self, provider_name: str) -> Optional[ProviderConfig]:
"""Get configuration for a specific provider"""
config = self.get_config()
return config.providers.get(provider_name)
def get_provider_for_model(self, model_name: str) -> Optional[str]:
"""Get provider name for a specific model"""
config = self.get_config()
return config.model_routing.get(model_name)
def get_enabled_providers(self) -> List[str]:
"""Get list of enabled providers"""
config = self.get_config()
return [name for name, provider in config.providers.items() if provider.enabled]
def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for provider"""
return self._env_vars.get_api_key(provider_name)
def _validate_configuration(self):
"""Validate current configuration"""
if not self._config:
return
# Check for enabled providers without API keys
enabled_providers = self.get_enabled_providers()
missing_keys = self._env_vars.validate_required_keys(enabled_providers)
if missing_keys:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")
logger.warning(
f"Missing API keys for enabled providers: {', '.join(missing_keys)}"
)
# Validate default provider is enabled
default_provider = self._config.default_provider
if default_provider not in enabled_providers:
raise ValueError(f"Default provider '{default_provider}' is not enabled")
# Validate model routing points to enabled providers
invalid_routes = []
for model, provider in self._config.model_routing.items():
if provider not in enabled_providers:
invalid_routes.append(f"{model} -> {provider}")
if invalid_routes:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")
logger.warning(
f"Model routes point to disabled providers: {', '.join(invalid_routes)}"
)
async def refresh_provider_models(self, provider_name: str, models: List[str]):
"""Update supported models for a provider dynamically"""
if not self._config:
return
provider_config = self._config.providers.get(provider_name)
if not provider_config:
return
# Update supported models
provider_config.supported_models = models
# Update model routing - map all models to this provider
for model in models:
self._config.model_routing[model] = provider_name
import logging
logger = logging.getLogger(__name__)
logger.info(f"Updated {provider_name} with {len(models)} models: {models}")
async def get_all_available_models(self) -> Dict[str, List[str]]:
"""Get all available models grouped by provider"""
config = self.get_config()
models_by_provider = {}
for provider_name, provider_config in config.providers.items():
if provider_config.enabled:
models_by_provider[provider_name] = provider_config.supported_models
return models_by_provider
def get_model_provider_mapping(self) -> Dict[str, str]:
"""Get current model to provider mapping"""
config = self.get_config()
return dict(config.model_routing)
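# Illustrative sketch (not part of the commit): resolving a model to its
# provider and key through the manager. "some-model" is a placeholder id;
# a module-level config_manager instance is assumed, matching the import
# used by the service layer.
#
# provider_name = config_manager.get_provider_for_model("some-model") or "privatemode"
# provider_cfg = config_manager.get_provider_config(provider_name)
# api_key = config_manager.get_api_key(provider_name)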

View File

@@ -7,8 +7,10 @@ Custom exceptions for LLM service operations.
class LLMError(Exception):
"""Base exception for LLM service errors"""
def __init__(self, message: str, error_code: str = "LLM_ERROR", details: dict = None):
def __init__(
self, message: str, error_code: str = "LLM_ERROR", details: dict = None
):
super().__init__(message)
self.message = message
self.error_code = error_code
@@ -17,46 +19,78 @@ class LLMError(Exception):
class ProviderError(LLMError):
"""Exception for LLM provider-specific errors"""
def __init__(
self,
message: str,
provider: str,
error_code: str = "PROVIDER_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.provider = provider
class SecurityError(LLMError):
"""Exception for security-related errors"""
def __init__(
self,
message: str,
risk_score: float = 0.0,
error_code: str = "SECURITY_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.risk_score = risk_score
class ConfigurationError(LLMError):
"""Exception for configuration-related errors"""
def __init__(self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None):
def __init__(
self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None
):
super().__init__(message, error_code, details)
class RateLimitError(LLMError):
"""Exception for rate limiting errors"""
def __init__(
self,
message: str,
retry_after: int = None,
error_code: str = "RATE_LIMIT_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.retry_after = retry_after
class TimeoutError(LLMError):
"""Exception for timeout errors"""
def __init__(
self,
message: str,
timeout_duration: float = None,
error_code: str = "TIMEOUT_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.timeout_duration = timeout_duration
class ValidationError(LLMError):
"""Exception for request validation errors"""
def __init__(
self,
message: str,
field: str = None,
error_code: str = "VALIDATION_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.field = field
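# Illustrative handling sketch (not part of the commit): callers can catch
# the shared LLMError base and branch on subclass-specific fields.
#
# try:
#     response = await provider.create_chat_completion(request)
# except RateLimitError as e:
#     await asyncio.sleep(e.retry_after or 1)
# except ProviderError as e:
#     logger.error("Provider %s failed: %s (%s)", e.provider, e.message, e.error_code)
# except LLMError as e:
#     logger.error("LLM error %s: %s", e.error_code, e.message)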

View File

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
@dataclass
class RequestMetric:
"""Individual request metric"""
timestamp: datetime
provider: str
model: str
@@ -35,26 +36,30 @@ class RequestMetric:
class MetricsCollector:
"""Collects and aggregates LLM service metrics"""
def __init__(self, max_history_size: int = 10000):
"""
Initialize metrics collector
Args:
max_history_size: Maximum number of metrics to keep in memory
"""
self.max_history_size = max_history_size
self._metrics: deque = deque(maxlen=max_history_size)
self._provider_metrics: Dict[str, deque] = defaultdict(
lambda: deque(maxlen=1000)
)
self._lock = threading.RLock()
# Aggregated metrics cache
self._cache_timestamp: Optional[datetime] = None
self._cached_metrics: Optional[LLMMetrics] = None
self._cache_ttl_seconds = 60 # Cache for 1 minute
logger.info(f"Metrics collector initialized with max history: {max_history_size}")
logger.info(
f"Metrics collector initialized with max history: {max_history_size}"
)
def record_request(
self,
provider: str,
@@ -66,7 +71,7 @@ class MetricsCollector:
security_risk_score: float = 0.0,
error_code: Optional[str] = None,
user_id: Optional[str] = None,
api_key_id: Optional[int] = None,
):
"""Record a request metric"""
metric = RequestMetric(
@@ -80,64 +85,73 @@ class MetricsCollector:
security_risk_score=security_risk_score,
error_code=error_code,
user_id=user_id,
api_key_id=api_key_id,
)
with self._lock:
self._metrics.append(metric)
self._provider_metrics[provider].append(metric)
# Invalidate cache
self._cached_metrics = None
self._cache_timestamp = None
# Log significant events
if not success:
logger.warning(f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}")
logger.warning(
f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}"
)
elif security_risk_score > 0.6:
logger.info(f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}")
logger.info(
f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}"
)
def get_metrics(self, force_refresh: bool = False) -> LLMMetrics:
"""Get aggregated metrics"""
with self._lock:
# Check cache validity
if (
not force_refresh
and self._cached_metrics
and self._cache_timestamp
and (datetime.utcnow() - self._cache_timestamp).total_seconds()
< self._cache_ttl_seconds
):
return self._cached_metrics
# Calculate fresh metrics
metrics = self._calculate_metrics()
# Cache results
self._cached_metrics = metrics
self._cache_timestamp = datetime.utcnow()
return metrics
def _calculate_metrics(self) -> LLMMetrics:
"""Calculate aggregated metrics from recorded data"""
if not self._metrics:
return LLMMetrics()
total_requests = len(self._metrics)
successful_requests = sum(1 for m in self._metrics if m.success)
failed_requests = total_requests - successful_requests
# Calculate averages
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
risk_scores = [m.security_risk_score for m in self._metrics]
avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
avg_risk_score = sum(risk_scores) / len(risk_scores) if risk_scores else 0.0
# Provider-specific metrics
provider_metrics = {}
for provider, provider_data in self._provider_metrics.items():
if provider_data:
provider_metrics[provider] = self._calculate_provider_metrics(
provider_data
)
return LLMMetrics(
total_requests=total_requests,
successful_requests=successful_requests,
@@ -145,48 +159,50 @@ class MetricsCollector:
average_latency_ms=avg_latency,
average_risk_score=avg_risk_score,
provider_metrics=provider_metrics,
last_updated=datetime.utcnow(),
)
def _calculate_provider_metrics(self, provider_data: deque) -> Dict[str, Any]:
"""Calculate metrics for a specific provider"""
if not provider_data:
return {}
total = len(provider_data)
successful = sum(1 for m in provider_data if m.success)
failed = total - successful
latencies = [m.latency_ms for m in provider_data if m.latency_ms > 0]
avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
# Token usage aggregation
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
for metric in provider_data:
if metric.token_usage:
total_prompt_tokens += metric.token_usage.get("prompt_tokens", 0)
total_completion_tokens += metric.token_usage.get("completion_tokens", 0)
total_completion_tokens += metric.token_usage.get(
"completion_tokens", 0
)
total_tokens += metric.token_usage.get("total_tokens", 0)
# Model distribution
model_counts = defaultdict(int)
for metric in provider_data:
model_counts[metric.model] += 1
# Request type distribution
request_type_counts = defaultdict(int)
for metric in provider_data:
request_type_counts[metric.request_type] += 1
# Error analysis
error_counts = defaultdict(int)
for metric in provider_data:
if not metric.success and metric.error_code:
error_counts[metric.error_code] += 1
return {
"total_requests": total,
"successful_requests": successful,
@@ -198,51 +214,57 @@ class MetricsCollector:
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"avg_prompt_tokens": total_prompt_tokens / total if total > 0 else 0,
"avg_completion_tokens": total_completion_tokens / successful if successful > 0 else 0
"avg_completion_tokens": total_completion_tokens / successful
if successful > 0
else 0,
},
"model_distribution": dict(model_counts),
"request_type_distribution": dict(request_type_counts),
"error_distribution": dict(error_counts),
"recent_requests": total
"recent_requests": total,
}
def get_provider_metrics(self, provider: str) -> Optional[Dict[str, Any]]:
"""Get metrics for a specific provider"""
with self._lock:
if provider not in self._provider_metrics:
return None
return self._calculate_provider_metrics(self._provider_metrics[provider])
def get_recent_metrics(self, minutes: int = 5) -> List[RequestMetric]:
"""Get metrics from the last N minutes"""
cutoff_time = datetime.utcnow() - timedelta(minutes=minutes)
with self._lock:
return [m for m in self._metrics if m.timestamp >= cutoff_time]
def get_error_metrics(self, hours: int = 1) -> Dict[str, int]:
"""Get error distribution from the last N hours"""
cutoff_time = datetime.utcnow() - timedelta(hours=hours)
error_counts = defaultdict(int)
with self._lock:
for metric in self._metrics:
if (
metric.timestamp >= cutoff_time
and not metric.success
and metric.error_code
):
error_counts[metric.error_code] += 1
return dict(error_counts)
def get_performance_metrics(self, minutes: int = 15) -> Dict[str, Dict[str, float]]:
"""Get performance metrics by provider from the last N minutes"""
cutoff_time = datetime.utcnow() - timedelta(minutes=minutes)
provider_perf = defaultdict(list)
with self._lock:
for metric in self._metrics:
if metric.timestamp >= cutoff_time and metric.success:
provider_perf[metric.provider].append(metric.latency_ms)
performance = {}
for provider, latencies in provider_perf.items():
if latencies:
@@ -252,26 +274,26 @@ class MetricsCollector:
"max_latency_ms": max(latencies),
"p95_latency_ms": self._percentile(latencies, 95),
"p99_latency_ms": self._percentile(latencies, 99),
"request_count": len(latencies)
"request_count": len(latencies),
}
return performance
def _percentile(self, data: List[float], percentile: int) -> float:
"""Calculate percentile of a list of numbers"""
if not data:
return 0.0
sorted_data = sorted(data)
index = (percentile / 100.0) * (len(sorted_data) - 1)
if index.is_integer():
return sorted_data[int(index)]
else:
lower = sorted_data[int(index)]
upper = sorted_data[int(index) + 1]
return lower + (upper - lower) * (index - int(index))
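# Worked example for the interpolation above: with data = [100, 200, 300, 400]
# and percentile = 95, index = 0.95 * 3 = 2.85, so the result interpolates
# between sorted_data[2] = 300 and sorted_data[3] = 400:
# 300 + (400 - 300) * 0.85 = 385.0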
def clear_metrics(self):
"""Clear all metrics (use with caution)"""
with self._lock:
@@ -279,20 +301,20 @@ class MetricsCollector:
self._provider_metrics.clear()
self._cached_metrics = None
self._cache_timestamp = None
logger.info("All metrics cleared")
def get_health_summary(self) -> Dict[str, Any]:
"""Get a health summary for monitoring"""
metrics = self.get_metrics()
recent_metrics = self.get_recent_metrics(minutes=5)
error_metrics = self.get_error_metrics(hours=1)
# Calculate health scores
total_recent = len(recent_metrics)
successful_recent = sum(1 for m in recent_metrics if m.success)
success_rate = successful_recent / total_recent if total_recent > 0 else 1.0
# Determine health status
if success_rate >= 0.95:
health_status = "healthy"
@@ -300,18 +322,20 @@ class MetricsCollector:
health_status = "degraded"
else:
health_status = "unhealthy"
return {
"health_status": health_status,
"success_rate_5min": success_rate,
"total_requests_5min": total_recent,
"average_latency_ms": metrics.average_latency_ms,
"error_count_1hour": sum(error_metrics.values()),
"top_errors": dict(sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]),
"top_errors": dict(
sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]
),
"provider_count": len(metrics.provider_metrics),
"last_updated": datetime.utcnow().isoformat()
"last_updated": datetime.utcnow().isoformat(),
}
# Global metrics collector instance
metrics_collector = MetricsCollector()
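# Illustrative sketch (not part of the commit): recording one request on the
# shared collector and reading the health summary. The keyword names are
# assumed to mirror the RequestMetric fields; the model id is a placeholder.
#
# metrics_collector.record_request(
#     provider="privatemode",
#     model="some-model",
#     request_type="chat",
#     success=True,
#     latency_ms=142.0,
# )
# summary = metrics_collector.get_health_summary()
# print(summary["health_status"], summary["success_rate_5min"])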

View File

@@ -9,15 +9,30 @@ from pydantic import BaseModel, Field, validator
from datetime import datetime
class ToolCall(BaseModel):
"""Tool call in a message"""
id: str = Field(..., description="Tool call identifier")
type: str = Field("function", description="Tool call type")
function: Dict[str, Any] = Field(..., description="Function call details")
class ChatMessage(BaseModel):
"""Individual chat message"""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
content: Optional[str] = Field(None, description="Message content")
name: Optional[str] = Field(None, description="Optional message name")
tool_calls: Optional[List[ToolCall]] = Field(
None, description="Tool calls in this message"
)
tool_call_id: Optional[str] = Field(
None, description="Tool call ID for tool responses"
)
@validator("role")
def validate_role(cls, v):
allowed_roles = {"system", "user", "assistant", "function", "tool"}
if v not in allowed_roles:
raise ValueError(f"Role must be one of {allowed_roles}")
return v
@@ -25,21 +40,38 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel):
"""Chat completion request"""
model: str = Field(..., description="Model identifier")
messages: List[ChatMessage] = Field(..., description="Chat messages")
temperature: Optional[float] = Field(
0.7, ge=0.0, le=2.0, description="Sampling temperature"
)
max_tokens: Optional[int] = Field(
None, ge=1, le=32000, description="Maximum tokens to generate"
)
top_p: Optional[float] = Field(
1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter"
)
top_k: Optional[int] = Field(None, ge=1, description="Top-k sampling parameter")
frequency_penalty: Optional[float] = Field(
0.0, ge=-2.0, le=2.0, description="Frequency penalty"
)
presence_penalty: Optional[float] = Field(
0.0, ge=-2.0, le=2.0, description="Presence penalty"
)
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
stream: Optional[bool] = Field(False, description="Stream response")
tools: Optional[List[Dict[str, Any]]] = Field(
None, description="Available tools for function calling"
)
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
None, description="Tool choice preference"
)
user_id: str = Field(..., description="User identifier")
api_key_id: int = Field(..., description="API key identifier")
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
@validator("messages")
def validate_messages(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
@@ -48,6 +80,7 @@ class ChatRequest(BaseModel):
class TokenUsage(BaseModel):
"""Token usage information"""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
@@ -55,13 +88,17 @@ class TokenUsage(BaseModel):
class ChatChoice(BaseModel):
"""Chat completion choice"""
index: int = Field(..., description="Choice index")
message: ChatMessage = Field(..., description="Generated message")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)
class ChatResponse(BaseModel):
"""Chat completion response"""
id: str = Field(..., description="Response identifier")
object: str = Field("chat.completion", description="Object type")
created: int = Field(..., description="Creation timestamp")
@@ -70,19 +107,28 @@ class ChatResponse(BaseModel):
choices: List[ChatChoice] = Field(..., description="Generated choices")
usage: Optional[TokenUsage] = Field(None, description="Token usage")
system_fingerprint: Optional[str] = Field(None, description="System fingerprint")
# Security fields maintained for backward compatibility
security_check: Optional[bool] = Field(
None, description="Whether security check passed"
)
risk_score: Optional[float] = Field(None, description="Security risk score")
detected_patterns: Optional[List[str]] = Field(
None, description="Detected security patterns"
)
# Performance metrics
latency_ms: Optional[float] = Field(
None, description="Response latency in milliseconds"
)
provider_latency_ms: Optional[float] = Field(
None, description="Provider-specific latency"
)
class EmbeddingRequest(BaseModel):
"""Embedding generation request"""
model: str = Field(..., description="Embedding model identifier")
input: Union[str, List[str]] = Field(..., description="Text to embed")
encoding_format: Optional[str] = Field("float", description="Encoding format")
@@ -90,20 +136,23 @@ class EmbeddingRequest(BaseModel):
user_id: str = Field(..., description="User identifier")
api_key_id: int = Field(..., description="API key identifier")
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
@validator("input")
def validate_input(cls, v):
if isinstance(v, str):
if not v.strip():
raise ValueError("Input text cannot be empty")
elif isinstance(v, list):
if not v or not all(isinstance(item, str) and item.strip() for item in v):
raise ValueError("Input list cannot be empty and must contain non-empty strings")
raise ValueError(
"Input list cannot be empty and must contain non-empty strings"
)
return v
class EmbeddingData(BaseModel):
"""Single embedding data"""
object: str = Field("embedding", description="Object type")
index: int = Field(..., description="Embedding index")
embedding: List[float] = Field(..., description="Embedding vector")
@@ -111,63 +160,98 @@ class EmbeddingData(BaseModel):
class EmbeddingResponse(BaseModel):
"""Embedding generation response"""
object: str = Field("list", description="Object type")
data: List[EmbeddingData] = Field(..., description="Embedding data")
model: str = Field(..., description="Model used")
provider: str = Field(..., description="Provider used")
usage: Optional[TokenUsage] = Field(None, description="Token usage")
# Security fields maintained for backward compatibility
security_check: Optional[bool] = Field(
None, description="Whether security check passed"
)
risk_score: Optional[float] = Field(None, description="Security risk score")
detected_patterns: Optional[List[str]] = Field(
None, description="Detected security patterns"
)
# Performance metrics
latency_ms: Optional[float] = Field(
None, description="Response latency in milliseconds"
)
provider_latency_ms: Optional[float] = Field(
None, description="Provider-specific latency"
)
class ModelInfo(BaseModel):
"""Model information"""
id: str = Field(..., description="Model identifier")
object: str = Field("model", description="Object type")
created: Optional[int] = Field(None, description="Creation timestamp")
owned_by: str = Field(..., description="Model owner")
provider: str = Field(..., description="Provider name")
capabilities: List[str] = Field(
default_factory=list, description="Model capabilities"
)
context_window: Optional[int] = Field(None, description="Context window size")
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
supports_streaming: bool = Field(
False, description="Whether model supports streaming"
)
supports_function_calling: bool = Field(
False, description="Whether model supports function calling"
)
tasks: Optional[List[str]] = Field(
None, description="Model tasks (e.g., generate, embed, vision)"
)
class ProviderStatus(BaseModel):
"""Provider health status"""
provider: str = Field(..., description="Provider name")
status: str = Field(..., description="Status (healthy, degraded, unavailable)")
latency_ms: Optional[float] = Field(None, description="Average latency")
success_rate: Optional[float] = Field(None, description="Success rate (0.0 to 1.0)")
last_check: datetime = Field(..., description="Last health check timestamp")
error_message: Optional[str] = Field(None, description="Error message if unhealthy")
models_available: List[str] = Field(
default_factory=list, description="Available models"
)
class LLMMetrics(BaseModel):
"""LLM service metrics"""
total_requests: int = Field(0, description="Total requests processed")
successful_requests: int = Field(0, description="Successful requests")
failed_requests: int = Field(0, description="Failed requests")
average_latency_ms: float = Field(0.0, description="Average response latency")
provider_metrics: Dict[str, Dict[str, Any]] = Field(
default_factory=dict, description="Per-provider metrics"
)
last_updated: datetime = Field(
default_factory=datetime.utcnow, description="Last metrics update"
)
class ResilienceConfig(BaseModel):
"""Configuration for resilience patterns"""
max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
retry_delay_ms: int = Field(
1000, ge=100, le=30000, description="Initial retry delay"
)
retry_exponential_base: float = Field(
2.0, ge=1.1, le=5.0, description="Exponential backoff base"
)
timeout_ms: int = Field(30000, ge=1000, le=300000, description="Request timeout")
circuit_breaker_threshold: int = Field(
5, ge=1, le=50, description="Circuit breaker failure threshold"
)
circuit_breaker_reset_timeout_ms: int = Field(
60000, ge=10000, le=600000, description="Circuit breaker reset timeout"
)
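# Illustrative construction sketch (not part of the commit): a minimal
# ChatRequest; the validators above reject empty message lists and roles
# outside {system, user, assistant, function, tool}. The model id is a
# placeholder.
#
# request = ChatRequest(
#     model="some-model",
#     messages=[ChatMessage(role="user", content="Hello")],
#     user_id="user-123",
#     api_key_id=1,
# )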

View File

@@ -7,4 +7,4 @@ Base provider interface and provider implementations.
from .base import BaseLLMProvider
from .privatemode import PrivateModeProvider
__all__ = ["BaseLLMProvider", "PrivateModeProvider"]
__all__ = ["BaseLLMProvider", "PrivateModeProvider"]

View File

@@ -9,8 +9,12 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
from ..models import (
ChatRequest,
ChatResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ProviderStatus,
)
from ..config import ProviderConfig
@@ -19,11 +23,11 @@ logger = logging.getLogger(__name__)
class BaseLLMProvider(ABC):
"""Abstract base class for LLM providers"""
def __init__(self, config: ProviderConfig, api_key: str):
"""
Initialize provider
Args:
config: Provider configuration
api_key: Decrypted API key for the provider
@@ -32,112 +36,114 @@ class BaseLLMProvider(ABC):
self.api_key = api_key
self.name = config.name
self._session = None
logger.info(f"Initializing {self.name} provider")
@property
@abstractmethod
def provider_name(self) -> str:
"""Get provider name"""
pass
@abstractmethod
async def health_check(self) -> ProviderStatus:
"""
Check provider health status
Returns:
ProviderStatus with current health information
"""
pass
@abstractmethod
async def get_models(self) -> List[ModelInfo]:
"""
Get list of available models
Returns:
List of available models with their capabilities
"""
pass
@abstractmethod
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
"""
Create chat completion
Args:
request: Chat completion request
Returns:
Chat completion response
Raises:
ProviderError: If provider-specific error occurs
SecurityError: If security validation fails
ValidationError: If request validation fails
"""
pass
@abstractmethod
async def create_chat_completion_stream(
self, request: ChatRequest
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Create streaming chat completion
Args:
request: Chat completion request with stream=True
Yields:
Streaming response chunks
Raises:
ProviderError: If provider-specific error occurs
SecurityError: If security validation fails
ValidationError: If request validation fails
"""
pass
@abstractmethod
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""
Create embeddings
Args:
request: Embedding generation request
Returns:
Embedding response
Raises:
ProviderError: If provider-specific error occurs
SecurityError: If security validation fails
ValidationError: If request validation fails
"""
pass
async def initialize(self):
"""Initialize provider resources (override if needed)"""
pass
async def cleanup(self):
"""Cleanup provider resources"""
if self._session and hasattr(self._session, "close"):
await self._session.close()
logger.debug(f"Cleaned up session for {self.name} provider")
def supports_model(self, model_name: str) -> bool:
"""Check if provider supports a specific model"""
return model_name in self.config.supported_models
def supports_capability(self, capability: str) -> bool:
"""Check if provider supports a specific capability"""
return capability in self.config.capabilities
def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
"""Get information about a specific model (override for provider-specific info)"""
if not self.supports_model(model_name):
return None
return ModelInfo(
id=model_name,
object="model",
@@ -147,80 +153,89 @@ class BaseLLMProvider(ABC):
context_window=self.config.max_context_window,
max_output_tokens=self.config.max_output_tokens,
supports_streaming=self.config.supports_streaming,
supports_function_calling=self.config.supports_function_calling,
)
def _validate_request(self, request: Any):
"""Base request validation (override for provider-specific validation)"""
if hasattr(request, "model") and not self.supports_model(request.model):
from ..exceptions import ValidationError
raise ValidationError(
f"Model '{request.model}' not supported by provider '{self.name}'",
field="model"
field="model",
)
def _create_headers(
self, additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""Create HTTP headers for requests"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
"User-Agent": f"Enclava-LLM-Service/{self.name}"
"User-Agent": f"Enclava-LLM-Service/{self.name}",
}
if additional_headers:
headers.update(additional_headers)
return headers
def _handle_http_error(
self, status_code: int, response_text: str, provider_context: str = ""
):
"""Handle HTTP errors consistently across providers"""
from ..exceptions import ProviderError, RateLimitError, ValidationError
context = f"{self.name} {provider_context}".strip()
if status_code == 401:
raise ProviderError(
f"Authentication failed for {context}",
provider=self.name,
error_code="AUTHENTICATION_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif status_code == 403:
raise ProviderError(
f"Access forbidden for {context}",
provider=self.name,
error_code="AUTHORIZATION_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif status_code == 429:
raise RateLimitError(
f"Rate limit exceeded for {context}",
error_code="RATE_LIMIT_ERROR",
details={"status_code": status_code, "response": response_text, "provider": self.name}
details={
"status_code": status_code,
"response": response_text,
"provider": self.name,
},
)
elif status_code == 400:
raise ValidationError(
f"Bad request for {context}: {response_text}",
error_code="BAD_REQUEST",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif 500 <= status_code < 600:
raise ProviderError(
f"Server error for {context}: {response_text}",
provider=self.name,
error_code="SERVER_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
else:
raise ProviderError(
f"HTTP error {status_code} for {context}: {response_text}",
provider=self.name,
error_code="HTTP_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
def __str__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})"
return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})"

View File

@@ -15,9 +15,16 @@ import aiohttp
from .base import BaseLLMProvider
from ..models import (
ChatRequest,
ChatResponse,
ChatMessage,
ChatChoice,
TokenUsage,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingData,
ModelInfo,
ProviderStatus,
)
from ..config import ProviderConfig
from ..exceptions import ProviderError, ValidationError, TimeoutError
@@ -27,22 +34,22 @@ logger = logging.getLogger(__name__)
class PrivateModeProvider(BaseLLMProvider):
"""PrivateMode.ai provider with TEE security"""
def __init__(self, config: ProviderConfig, api_key: str):
super().__init__(config, api_key)
self.base_url = config.base_url.rstrip("/")
self._session: Optional[aiohttp.ClientSession] = None
# TEE-specific settings
self.verify_ssl = True # Always verify SSL for security
self.trust_env = False # Don't trust environment proxy settings
logger.info(f"PrivateMode provider initialized with base URL: {self.base_url}")
@property
def provider_name(self) -> str:
return "privatemode"
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create HTTP session with security settings"""
if self._session is None or self._session.closed:
@@ -52,45 +59,49 @@ class PrivateModeProvider(BaseLLMProvider):
limit=100, # Connection pool limit
limit_per_host=50,
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True,
)
# Create session with security headers
timeout = aiohttp.ClientTimeout(
total=self.config.resilience.timeout_ms / 1000.0
)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers=self._create_headers(),
trust_env=False,  # Don't trust environment variables
)
logger.debug("Created new secure HTTP session for PrivateMode")
return self._session
async def health_check(self) -> ProviderStatus:
"""Check PrivateMode.ai service health"""
start_time = time.time()
try:
session = await self._get_session()
# Use a lightweight endpoint for health check
async with session.get(f"{self.base_url}/models") as response:
latency = (time.time() - start_time) * 1000
if response.status == 200:
models_data = await response.json()
models = [model.get("id", "") for model in models_data.get("data", [])]
models = [
model.get("id", "") for model in models_data.get("data", [])
]
return ProviderStatus(
provider=self.provider_name,
status="healthy",
latency_ms=latency,
success_rate=1.0,
last_check=datetime.utcnow(),
models_available=models,
)
else:
error_text = await response.text()
@@ -101,13 +112,13 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=f"HTTP {response.status}: {error_text}",
models_available=[],
)
except Exception as e:
latency = (time.time() - start_time) * 1000
logger.error(f"PrivateMode health check failed: {e}")
return ProviderStatus(
provider=self.provider_name,
status="unavailable",
@@ -115,33 +126,33 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=str(e),
models_available=[],
)
async def get_models(self) -> List[ModelInfo]:
"""Get available models from PrivateMode.ai"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
models_data = data.get("data", [])
models = []
for model_data in models_data:
model_id = model_data.get("id", "")
if not model_id:
continue
# Extract all information directly from API response
# Determine capabilities based on tasks field
tasks = model_data.get("tasks", [])
capabilities = []
# All PrivateMode models have TEE capability
capabilities.append("tee")
# Add capabilities based on tasks
if "generate" in tasks:
capabilities.append("chat")
@@ -149,12 +160,14 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities.append("embeddings")
if "vision" in tasks:
capabilities.append("vision")
# Check for function calling support in the API response
supports_function_calling = model_data.get("supports_function_calling", False)
supports_function_calling = model_data.get(
"supports_function_calling", False
)
if supports_function_calling:
capabilities.append("function_calling")
model_info = ModelInfo(
id=model_id,
object="model",
@@ -164,40 +177,44 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities=capabilities,
context_window=model_data.get("context_window"),
max_output_tokens=model_data.get("max_output_tokens"),
supports_streaming=model_data.get("supports_streaming", True),
supports_streaming=model_data.get(
"supports_streaming", True
),
supports_function_calling=supports_function_calling,
tasks=tasks,  # Pass through tasks field from PrivateMode API
)
models.append(model_info)
logger.info(f"Retrieved {len(models)} models from PrivateMode")
return models
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "models endpoint")
self._handle_http_error(
response.status, error_text, "models endpoint"
)
return [] # Never reached due to exception
except Exception as e:
if isinstance(e, ProviderError):
raise
logger.error(f"Failed to get models from PrivateMode: {e}")
raise ProviderError(
"Failed to retrieve models from PrivateMode",
provider=self.provider_name,
error_code="MODEL_RETRIEVAL_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
"""Create chat completion via PrivateMode.ai"""
self._validate_request(request)
start_time = time.time()
try:
session = await self._get_session()
# Prepare request payload
payload = {
"model": request.model,
@@ -205,14 +222,14 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": False # Non-streaming version
"stream": False, # Non-streaming version
}
# Add optional parameters
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
@@ -224,28 +241,27 @@ class PrivateModeProvider(BaseLLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop is not None:
payload["stop"] = request.stop
# Add user tracking
payload["user"] = f"user_{request.user_id}"
# Add metadata for TEE audit trail
payload["metadata"] = {
"user_id": request.user_id,
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
"enclava_request_id": str(uuid.uuid4()),
**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
# Parse response
choices = []
for choice_data in data.get("choices", []):
@@ -254,20 +270,20 @@ class PrivateModeProvider(BaseLLMProvider):
index=choice_data.get("index", 0),
message=ChatMessage(
role=message_data.get("role", "assistant"),
content=message_data.get("content", "")
content=message_data.get("content", ""),
),
finish_reason=choice_data.get("finish_reason")
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
# Parse token usage
usage_data = data.get("usage", {})
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0)
total_tokens=usage_data.get("total_tokens", 0),
)
# Create response
chat_response = ChatResponse(
id=data.get("id", str(uuid.uuid4())),
@@ -279,45 +295,51 @@ class PrivateModeProvider(BaseLLMProvider):
usage=usage,
system_fingerprint=data.get("system_fingerprint"),
security_check=True, # Will be set by security manager
risk_score=0.0,  # Will be set by security manager
latency_ms=provider_latency,
provider_latency_ms=provider_latency,
)
logger.debug(
f"PrivateMode chat completion successful in {provider_latency:.2f}ms"
)
return chat_response
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "chat completion")
self._handle_http_error(
response.status, error_text, "chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode request error: {e}")
raise ProviderError(
"Network error communicating with PrivateMode",
provider=self.provider_name,
error_code="NETWORK_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
raise
logger.error(f"Unexpected error in PrivateMode chat completion: {e}")
raise ProviderError(
"Unexpected error during chat completion",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_chat_completion_stream(
self, request: ChatRequest
) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
self._validate_request(request)
try:
session = await self._get_session()
# Prepare streaming payload
payload = {
"model": request.model,
@@ -325,14 +347,14 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": True
"stream": True,
}
# Add optional parameters
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
@@ -344,100 +366,104 @@ class PrivateModeProvider(BaseLLMProvider):
payload["presence_penalty"] = request.presence_penalty
if request.stop is not None:
payload["stop"] = request.stop
# Add user tracking
payload["user"] = f"user_{request.user_id}"
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
if response.status == 200:
async for line in response.content:
line = line.decode("utf-8").strip()
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
chunk_data = json.loads(data_str)
yield chunk_data
except json.JSONDecodeError:
logger.warning(f"Failed to parse streaming chunk: {data_str}")
logger.warning(
f"Failed to parse streaming chunk: {data_str}"
)
continue
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "streaming chat completion")
self._handle_http_error(
response.status, error_text, "streaming chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode streaming error: {e}")
raise ProviderError(
"Network error during streaming",
provider=self.provider_name,
error_code="STREAMING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""Create embeddings via PrivateMode.ai"""
self._validate_request(request)
start_time = time.time()
try:
session = await self._get_session()
# Prepare embedding payload
payload = {
"model": request.model,
"input": request.input,
"user": f"user_{request.user_id}"
"user": f"user_{request.user_id}",
}
# Add optional parameters
if request.encoding_format:
payload["encoding_format"] = request.encoding_format
if request.dimensions:
payload["dimensions"] = request.dimensions
# Add metadata
payload["metadata"] = {
"user_id": request.user_id,
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/embeddings",
json=payload
f"{self.base_url}/embeddings", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
if response.status == 200:
data = await response.json()
# Parse embedding data
embeddings = []
for emb_data in data.get("data", []):
embedding = EmbeddingData(
object="embedding",
index=emb_data.get("index", 0),
embedding=emb_data.get("embedding", [])
embedding=emb_data.get("embedding", []),
)
embeddings.append(embedding)
# Parse usage
usage_data = data.get("usage", {})
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=0, # No completion tokens for embeddings
total_tokens=usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
total_tokens=usage_data.get(
"total_tokens", usage_data.get("prompt_tokens", 0)
),
)
return EmbeddingResponse(
object="list",
data=embeddings,
@@ -445,37 +471,39 @@ class PrivateModeProvider(BaseLLMProvider):
provider=self.provider_name,
usage=usage,
security_check=True, # Will be set by security manager
risk_score=0.0,  # Will be set by security manager
latency_ms=provider_latency,
provider_latency_ms=provider_latency,
)
else:
error_text = await response.text()
# Log the detailed error response from the provider
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
logger.error(
f"PrivateMode embedding error - Status {response.status}: {error_text}"
)
self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e:
logger.error(f"PrivateMode embedding error: {e}")
raise ProviderError(
"Network error during embedding generation",
provider=self.provider_name,
error_code="EMBEDDING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
raise
logger.error(f"Unexpected error in PrivateMode embedding: {e}")
raise ProviderError(
"Unexpected error during embedding generation",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def cleanup(self):
"""Cleanup PrivateMode provider resources"""
# Close HTTP session to prevent memory leaks
@@ -485,4 +513,4 @@ class PrivateModeProvider(BaseLLMProvider):
logger.debug("Closed PrivateMode HTTP session")
await super().cleanup()
logger.debug("PrivateMode provider cleanup completed")
logger.debug("PrivateMode provider cleanup completed")

View File

@@ -20,14 +20,16 @@ logger = logging.getLogger(__name__)
class CircuitBreakerState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, blocking requests
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, blocking requests
HALF_OPEN = "half_open" # Testing if service recovered
@dataclass
class CircuitBreakerStats:
"""Circuit breaker statistics"""
failure_count: int = 0
success_count: int = 0
last_failure_time: Optional[datetime] = None
@@ -37,162 +39,186 @@ class CircuitBreakerStats:
class CircuitBreaker:
"""Circuit breaker implementation for provider resilience"""
def __init__(self, config: ResilienceConfig, provider_name: str):
self.config = config
self.provider_name = provider_name
self.state = CircuitBreakerState.CLOSED
self.stats = CircuitBreakerStats()
def can_execute(self) -> bool:
"""Check if request can be executed"""
if self.state == CircuitBreakerState.CLOSED:
return True
if self.state == CircuitBreakerState.OPEN:
# Check if reset timeout has passed
if (
datetime.utcnow() - self.stats.state_change_time
).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
self._transition_to_half_open()
return True
return False
if self.state == CircuitBreakerState.HALF_OPEN:
return True
return False
def record_success(self):
"""Record successful request"""
self.stats.success_count += 1
self.stats.last_success_time = datetime.utcnow()
if self.state == CircuitBreakerState.HALF_OPEN:
self._transition_to_closed()
elif self.state == CircuitBreakerState.CLOSED:
# Reset failure count on success
self.stats.failure_count = 0
logger.debug(f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}")
logger.debug(
f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}"
)
def record_failure(self):
"""Record failed request"""
self.stats.failure_count += 1
self.stats.last_failure_time = datetime.utcnow()
if self.state == CircuitBreakerState.CLOSED:
if self.stats.failure_count >= self.config.circuit_breaker_threshold:
self._transition_to_open()
elif self.state == CircuitBreakerState.HALF_OPEN:
self._transition_to_open()
logger.warning(f"Circuit breaker [{self.provider_name}]: Failure recorded, "
f"count={self.stats.failure_count}, state={self.state.value}")
logger.warning(
f"Circuit breaker [{self.provider_name}]: Failure recorded, "
f"count={self.stats.failure_count}, state={self.state.value}"
)
def _transition_to_open(self):
"""Transition to OPEN state"""
self.state = CircuitBreakerState.OPEN
self.stats.state_change_time = datetime.utcnow()
logger.error(f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures")
logger.error(
f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures"
)
def _transition_to_half_open(self):
"""Transition to HALF_OPEN state"""
self.state = CircuitBreakerState.HALF_OPEN
self.stats.state_change_time = datetime.utcnow()
logger.info(f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing")
logger.info(
f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing"
)
def _transition_to_closed(self):
"""Transition to CLOSED state"""
self.state = CircuitBreakerState.CLOSED
self.stats.state_change_time = datetime.utcnow()
self.stats.failure_count = 0 # Reset failure count
logger.info(f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered")
logger.info(
f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered"
)
def get_stats(self) -> Dict[str, Any]:
"""Get circuit breaker statistics"""
return {
"state": self.state.value,
"failure_count": self.stats.failure_count,
"success_count": self.stats.success_count,
"last_failure_time": self.stats.last_failure_time.isoformat() if self.stats.last_failure_time else None,
"last_success_time": self.stats.last_success_time.isoformat() if self.stats.last_success_time else None,
"last_failure_time": self.stats.last_failure_time.isoformat()
if self.stats.last_failure_time
else None,
"last_success_time": self.stats.last_success_time.isoformat()
if self.stats.last_success_time
else None,
"state_change_time": self.stats.state_change_time.isoformat(),
"time_in_current_state_ms": (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000
"time_in_current_state_ms": (
datetime.utcnow() - self.stats.state_change_time
).total_seconds()
* 1000,
}
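# Illustrative state walk-through (not part of the commit): with the default
# circuit_breaker_threshold of 5, five consecutive failures in CLOSED open
# the breaker; once circuit_breaker_reset_timeout_ms elapses, can_execute()
# moves it to HALF_OPEN, and a single success there closes it again.
#
# breaker = CircuitBreaker(ResilienceConfig(), "privatemode")
# for _ in range(5):
#     breaker.record_failure()
# assert breaker.state is CircuitBreakerState.OPEN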
class RetryManager:
"""Manages retry logic with exponential backoff"""
def __init__(self, config: ResilienceConfig):
self.config = config
async def execute_with_retry(
self,
func: Callable,
*args,
retryable_exceptions: tuple = (Exception,),
non_retryable_exceptions: tuple = (RateLimitError,),
**kwargs,
) -> Any:
"""Execute function with retry logic"""
last_exception = None
for attempt in range(self.config.max_retries + 1):
try:
return await func(*args, **kwargs)
except non_retryable_exceptions as e:
logger.warning(f"Non-retryable exception on attempt {attempt + 1}: {e}")
raise
except retryable_exceptions as e:
last_exception = e
if attempt == self.config.max_retries:
logger.error(f"All {self.config.max_retries + 1} attempts failed. Last error: {e}")
logger.error(
f"All {self.config.max_retries + 1} attempts failed. Last error: {e}"
)
raise
delay = self._calculate_delay(attempt)
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms...")
logger.warning(
f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms..."
)
await asyncio.sleep(delay / 1000.0)
# This should never be reached, but just in case
if last_exception:
raise last_exception
else:
raise LLMError("Unexpected error in retry logic")
def _calculate_delay(self, attempt: int) -> int:
"""Calculate delay for exponential backoff"""
delay = self.config.retry_delay_ms * (
self.config.retry_exponential_base**attempt
)
# Add some jitter to prevent thundering herd
import random
jitter = random.uniform(0.8, 1.2)
return int(delay * jitter)
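# Worked example (illustrative, using the ResilienceConfig defaults
# retry_delay_ms=1000 and retry_exponential_base=2.0): attempt 0 waits
# ~1000ms, attempt 1 ~2000ms, attempt 2 ~4000ms, each scaled by a random
# jitter factor in [0.8, 1.2] to avoid a thundering herd.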
class TimeoutManager:
"""Manages request timeouts"""
def __init__(self, config: ResilienceConfig):
self.config = config
async def execute_with_timeout(
self, func: Callable, *args, timeout_override: Optional[int] = None, **kwargs
) -> Any:
"""Execute function with timeout"""
timeout_ms = timeout_override or self.config.timeout_ms
timeout_seconds = timeout_ms / 1000.0
try:
return await asyncio.wait_for(
func(*args, **kwargs), timeout=timeout_seconds
)
except asyncio.TimeoutError:
error_msg = f"Request timed out after {timeout_ms}ms"
logger.error(error_msg)
@@ -201,14 +227,14 @@ class TimeoutManager:
class ResilienceManager:
"""Comprehensive resilience manager combining all patterns"""
def __init__(self, config: ResilienceConfig, provider_name: str):
self.config = config
self.provider_name = provider_name
self.circuit_breaker = CircuitBreaker(config, provider_name)
self.retry_manager = RetryManager(config)
self.timeout_manager = TimeoutManager(config)
async def execute(
self,
func: Callable,
@@ -216,18 +242,18 @@ class ResilienceManager:
retryable_exceptions: tuple = (Exception,),
non_retryable_exceptions: tuple = (RateLimitError,),
timeout_override: Optional[int] = None,
**kwargs,
) -> Any:
"""Execute function with full resilience patterns"""
# Check circuit breaker
if not self.circuit_breaker.can_execute():
error_msg = f"Circuit breaker is OPEN for provider {self.provider_name}"
logger.error(error_msg)
raise LLMError(error_msg, error_code="CIRCUIT_BREAKER_OPEN")
start_time = time.time()
try:
# Execute with timeout and retry
result = await self.retry_manager.execute_with_retry(
@@ -237,30 +263,34 @@ class ResilienceManager:
retryable_exceptions=retryable_exceptions,
non_retryable_exceptions=non_retryable_exceptions,
timeout_override=timeout_override,
**kwargs,
)
# Record success
self.circuit_breaker.record_success()
execution_time = (time.time() - start_time) * 1000
logger.debug(f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms")
logger.debug(
f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms"
)
return result
except Exception as e:
# Record failure
self.circuit_breaker.record_failure()
execution_time = (time.time() - start_time) * 1000
logger.error(f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}")
logger.error(
f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}"
)
raise
def get_health_status(self) -> Dict[str, Any]:
"""Get comprehensive health status"""
cb_stats = self.circuit_breaker.get_stats()
# Determine overall health
if cb_stats["state"] == "open":
health = "unhealthy"
@@ -273,7 +303,7 @@ class ResilienceManager:
health = "degraded"
else:
health = "healthy"
return {
"provider": self.provider_name,
"health": health,
@@ -281,34 +311,37 @@ class ResilienceManager:
"config": {
"max_retries": self.config.max_retries,
"timeout_ms": self.config.timeout_ms,
"circuit_breaker_threshold": self.config.circuit_breaker_threshold
}
"circuit_breaker_threshold": self.config.circuit_breaker_threshold,
},
}
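    # Editor's sketch (not part of the diff): illustrative shape of the payload
    # returned by get_health_status(); all values are assumed examples, and the
    # circuit-breaker stats entry elided in this hunk is omitted.
    #
    #     {
    #         "provider": "privatemode",
    #         "health": "healthy",
    #         "config": {
    #             "max_retries": 3,
    #             "timeout_ms": 30000,
    #             "circuit_breaker_threshold": 5,
    #         },
    #     }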
class ResilienceManagerFactory:
"""Factory for creating resilience managers"""
_managers: Dict[str, ResilienceManager] = {}
_default_config = ResilienceConfig()
@classmethod
    def get_manager(
        cls, provider_name: str, config: Optional[ResilienceConfig] = None
    ) -> ResilienceManager:
"""Get or create resilience manager for provider"""
if provider_name not in cls._managers:
manager_config = config or cls._default_config
            cls._managers[provider_name] = ResilienceManager(
                manager_config, provider_name
            )
return cls._managers[provider_name]
@classmethod
def get_all_health_status(cls) -> Dict[str, Dict[str, Any]]:
"""Get health status for all managed providers"""
return {
            name: manager.get_health_status() for name, manager in cls._managers.items()
}
@classmethod
def update_config(cls, provider_name: str, config: ResilienceConfig):
"""Update configuration for a specific provider"""
@@ -317,7 +350,7 @@ class ResilienceManagerFactory:
cls._managers[provider_name].circuit_breaker.config = config
cls._managers[provider_name].retry_manager.config = config
cls._managers[provider_name].timeout_manager.config = config
@classmethod
def reset_circuit_breaker(cls, provider_name: str):
"""Manually reset circuit breaker for a provider"""
@@ -325,8 +358,8 @@ class ResilienceManagerFactory:
manager = cls._managers[provider_name]
manager.circuit_breaker._transition_to_closed()
logger.info(f"Manually reset circuit breaker for {provider_name}")
@classmethod
def set_default_config(cls, config: ResilienceConfig):
"""Set default configuration for new managers"""
        cls._default_config = config
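# Editor's sketch (not part of the diff): typical factory flow using only the
# classmethods defined above; "privatemode" matches the provider wired up in
# the LLM service file below, and some_provider_call is a hypothetical coroutine.
#
#     manager = ResilienceManagerFactory.get_manager("privatemode")
#     result = await manager.execute(
#         some_provider_call,
#         retryable_exceptions=(ProviderError, TimeoutError),
#         non_retryable_exceptions=(ValidationError,),
#     )
#     ResilienceManagerFactory.get_all_health_status()  # {"privatemode": {...}}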

View File

@@ -12,18 +12,28 @@ from typing import Dict, Any, Optional, List, AsyncGenerator
from datetime import datetime
from .models import (
ChatRequest,
ChatResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ProviderStatus,
LLMMetrics,
)
from .config import config_manager, ProviderConfig
from ...core.config import settings
from .resilience import ResilienceManagerFactory
# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
LLMError,
ProviderError,
SecurityError,
ConfigurationError,
ValidationError,
TimeoutError,
)
logger = logging.getLogger(__name__)
@@ -31,58 +41,64 @@ logger = logging.getLogger(__name__)
class LLMService:
"""Main LLM service coordinating all components"""
def __init__(self):
"""Initialize LLM service"""
self._providers: Dict[str, BaseLLMProvider] = {}
self._initialized = False
self._startup_time: Optional[datetime] = None
logger.info("LLM Service initialized")
async def initialize(self):
"""Initialize service and providers"""
if self._initialized:
logger.warning("LLM Service already initialized")
return
start_time = time.time()
self._startup_time = datetime.utcnow()
try:
# Get configuration
config = config_manager.get_config()
logger.info(f"Initializing LLM service with {len(config.providers)} configured providers")
logger.info(
f"Initializing LLM service with {len(config.providers)} configured providers"
)
# Initialize enabled providers
enabled_providers = config_manager.get_enabled_providers()
if not enabled_providers:
raise ConfigurationError("No enabled providers found")
for provider_name in enabled_providers:
await self._initialize_provider(provider_name)
# Verify we have at least one working provider
if not self._providers:
raise ConfigurationError("No providers successfully initialized")
# Verify default provider is available
default_provider = config.default_provider
if default_provider not in self._providers:
available_providers = list(self._providers.keys())
logger.warning(f"Default provider '{default_provider}' not available, using '{available_providers[0]}'")
logger.warning(
f"Default provider '{default_provider}' not available, using '{available_providers[0]}'"
)
config.default_provider = available_providers[0]
self._initialized = True
initialization_time = (time.time() - start_time) * 1000
logger.info(f"LLM Service initialized successfully in {initialization_time:.2f}ms")
logger.info(
f"LLM Service initialized successfully in {initialization_time:.2f}ms"
)
logger.info(f"Available providers: {list(self._providers.keys())}")
except Exception as e:
logger.error(f"Failed to initialize LLM service: {e}")
raise ConfigurationError(f"LLM service initialization failed: {e}")
async def _initialize_provider(self, provider_name: str):
"""Initialize a specific provider"""
try:
@@ -90,101 +106,109 @@ class LLMService:
if not provider_config or not provider_config.enabled:
logger.warning(f"Provider '{provider_name}' not enabled, skipping")
return
# Get API key
api_key = config_manager.get_api_key(provider_name)
if not api_key:
logger.error(f"No API key found for provider '{provider_name}'")
return
# Create provider instance
provider = self._create_provider(provider_config, api_key)
# Initialize provider
await provider.initialize()
# Test provider health
health_status = await provider.health_check()
if health_status.status == "unavailable":
logger.error(f"Provider '{provider_name}' failed health check: {health_status.error_message}")
logger.error(
f"Provider '{provider_name}' failed health check: {health_status.error_message}"
)
return
# Register provider
self._providers[provider_name] = provider
logger.info(f"Provider '{provider_name}' initialized successfully (status: {health_status.status})")
logger.info(
f"Provider '{provider_name}' initialized successfully (status: {health_status.status})"
)
# Fetch and update models dynamically
await self._refresh_provider_models(provider_name, provider)
except Exception as e:
logger.error(f"Failed to initialize provider '{provider_name}': {e}")
def _create_provider(self, config: ProviderConfig, api_key: str) -> BaseLLMProvider:
"""Create provider instance based on configuration"""
if config.name == "privatemode":
return PrivateModeProvider(config, api_key)
else:
raise ConfigurationError(f"Unknown provider type: {config.name}")
    async def _refresh_provider_models(
        self, provider_name: str, provider: BaseLLMProvider
    ):
"""Fetch and update models dynamically from provider"""
try:
# Get models from provider
models = await provider.get_models()
model_ids = [model.id for model in models]
# Update configuration
await config_manager.refresh_provider_models(provider_name, model_ids)
logger.info(f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}")
logger.info(
f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}"
)
except Exception as e:
logger.error(f"Failed to refresh models for provider '{provider_name}': {e}")
logger.error(
f"Failed to refresh models for provider '{provider_name}': {e}"
)
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
"""Create chat completion with security and resilience"""
if not self._initialized:
await self.initialize()
# Validate request
if not request.messages:
raise ValidationError("Messages cannot be empty", field="messages")
        # Security validation disabled - always allow chat requests
        risk_score = 0.0
# Get provider for model
provider_name = self._get_provider_for_model(request.model)
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
start_time = time.time()
try:
response = await resilience_manager.execute(
provider.create_chat_completion,
request,
retryable_exceptions=(ProviderError, TimeoutError),
                non_retryable_exceptions=(ValidationError,),
)
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
return response
except Exception as e:
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
            error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Chat completion failed for provider %s (model=%s, latency=%.2fms, error=%s)",
@@ -194,38 +218,42 @@ class LLMService:
error_code,
)
raise
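    # Editor's sketch (not part of the diff): a minimal call. The ChatRequest
    # schema lives in .models and is not shown in this diff, so the field names
    # below are assumptions in the OpenAI style.
    #
    #     request = ChatRequest(
    #         model="some-model-id",
    #         messages=[{"role": "user", "content": "Hello"}],
    #     )
    #     response = await llm_service.create_chat_completion(request)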
    async def create_chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
if not self._initialized:
await self.initialize()
# Security validation disabled - always allow streaming requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute streaming with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
try:
async for chunk in await resilience_manager.execute(
provider.create_chat_completion_stream,
request,
retryable_exceptions=(ProviderError, TimeoutError),
                non_retryable_exceptions=(ValidationError,),
):
yield chunk
except Exception as e:
# Record streaming failure - metrics disabled
            error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Streaming chat completion failed for provider %s (model=%s, error=%s)",
provider_name,
@@ -233,46 +261,46 @@ class LLMService:
error_code,
)
raise
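    # Editor's sketch (not part of the diff): consuming the stream. Each chunk
    # is a Dict[str, Any] per the annotation; its keys are provider-specific.
    #
    #     async for chunk in llm_service.create_chat_completion_stream(request):
    #         handle(chunk)  # hypothetical consumer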
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""Create embeddings with security and resilience"""
if not self._initialized:
await self.initialize()
# Security validation disabled - always allow embedding requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
start_time = time.time()
try:
response = await resilience_manager.execute(
provider.create_embedding,
request,
retryable_exceptions=(ProviderError, TimeoutError),
                non_retryable_exceptions=(ValidationError,),
)
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
return response
except Exception as e:
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
            error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Embedding request failed for provider %s (model=%s, latency=%.2fms, error=%s)",
provider_name,
@@ -281,14 +309,14 @@ class LLMService:
error_code,
)
raise
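    # Editor's sketch (not part of the diff): embeddings take the same
    # resilience path; the EmbeddingRequest fields are assumptions, as the
    # model is not shown in this diff.
    #
    #     emb = await llm_service.create_embedding(
    #         EmbeddingRequest(model="some-embedding-model", input="hello world")
    #     )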
async def get_models(self, provider_name: Optional[str] = None) -> List[ModelInfo]:
"""Get available models from all or specific provider"""
if not self._initialized:
await self.initialize()
models = []
if provider_name:
# Get models from specific provider
provider = self._providers.get(provider_name)
@@ -306,16 +334,16 @@ class LLMService:
models.extend(provider_models)
except Exception as e:
logger.error(f"Failed to get models from {name}: {e}")
return models
async def get_provider_status(self) -> Dict[str, ProviderStatus]:
"""Get health status of all providers"""
if not self._initialized:
await self.initialize()
status_dict = {}
for name, provider in self._providers.items():
try:
status = await provider.health_check()
@@ -327,21 +355,18 @@ class LLMService:
status="unavailable",
last_check=datetime.utcnow(),
error_message=str(e),
                    models_available=[],
)
return status_dict
def get_metrics(self) -> LLMMetrics:
"""Get service metrics - metrics disabled"""
# return metrics_collector.get_metrics()
        return LLMMetrics(
            total_requests=0, success_rate=0.0, avg_latency_ms=0, error_rates={}
        )
def get_health_summary(self) -> Dict[str, Any]:
"""Get comprehensive health summary - metrics disabled"""
# metrics_health = metrics_collector.get_health_summary()
@@ -349,40 +374,42 @@ class LLMService:
return {
"service_status": "healthy" if self._initialized else "initializing",
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
"startup_time": self._startup_time.isoformat()
if self._startup_time
else None,
"provider_count": len(self._providers),
"active_providers": list(self._providers.keys()),
"metrics": {"status": "disabled"},
"resilience": resilience_health
"resilience": resilience_health,
}
def _get_provider_for_model(self, model: str) -> str:
"""Get provider name for a model"""
# Check model routing first
provider_name = config_manager.get_provider_for_model(model)
if provider_name and provider_name in self._providers:
return provider_name
# Fall back to providers that support the model
for name, provider in self._providers.items():
if provider.supports_model(model):
return name
# Use default provider as last resort
config = config_manager.get_config()
if config.default_provider in self._providers:
return config.default_provider
# If nothing else works, use first available provider
if self._providers:
return list(self._providers.keys())[0]
raise ProviderError(f"No provider found for model '{model}'", provider="none")
async def cleanup(self):
"""Cleanup service resources"""
logger.info("Cleaning up LLM service")
# Cleanup providers
for name, provider in self._providers.items():
try:
@@ -390,7 +417,7 @@ class LLMService:
logger.debug(f"Cleaned up provider: {name}")
except Exception as e:
logger.error(f"Error cleaning up provider {name}: {e}")
self._providers.clear()
self._initialized = False
logger.info("LLM service cleanup completed")

Some files were not shown because too many files have changed in this diff.