mega changes

2025-11-20 11:11:18 +01:00
parent e070c95190
commit 841d79f26b
138 changed files with 21499 additions and 8844 deletions

.gitignore
View File

@@ -66,3 +66,16 @@ frontend/node_modules/
node_modules/
venv/
venv_memory_monitor/
to_delete/
security-review.md
features-plan.md
AGENTS.md
CLAUDE.md
backend/CURL_VERIFICATION_EXAMPLES.md
backend/OPENAI_COMPATIBILITY_GUIDE.md
backend/production-deployment-guide.md
backend/.env.local
backend/.env.test

View File

@@ -1,12 +1,17 @@
import asyncio
import os
from logging.config import fileConfig
from dotenv import load_dotenv
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
# Load environment variables from .env file
load_dotenv()
load_dotenv('.env.local')
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
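
Note on ordering: python-dotenv does not overwrite variables that are already set, so with the two calls above any key present in .env wins over .env.local. If the intent is for .env.local to take precedence, the override flag must be set explicitly — a minimal sketch:

from dotenv import load_dotenv

# Load base settings first, then let local values replace them.
# override=True is required; by default existing variables are kept.
load_dotenv(".env")
load_dotenv(".env.local", override=True)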

View File

@@ -4,7 +4,7 @@ This migration represents the complete, accurate database schema based on the actual
model files in the codebase. All legacy migrations have been consolidated into this
single migration to ensure the database matches what the models expect.
Revision ID: 000_consolidated_ground_truth_schema
Revision ID: 000_ground_truth
Revises:
Create Date: 2025-08-22 10:30:00.000000

View File

@@ -0,0 +1,62 @@
"""add roles table
Revision ID: 001_add_roles_table
Revises: 000_ground_truth
Create Date: 2025-01-30 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '001_add_roles_table'
down_revision = '000_ground_truth'
branch_labels = None
depends_on = None
def upgrade():
# Create roles table
op.create_table(
'roles',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('display_name', sa.String(length=100), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('level', sa.String(length=20), nullable=False),
sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column('can_manage_users', sa.Boolean(), nullable=True, default=False),
sa.Column('can_manage_budgets', sa.Boolean(), nullable=True, default=False),
sa.Column('can_view_reports', sa.Boolean(), nullable=True, default=False),
sa.Column('can_manage_tools', sa.Boolean(), nullable=True, default=False),
sa.Column('inherits_from', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('is_system_role', sa.Boolean(), nullable=True, default=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for roles
op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
op.create_index(op.f('ix_roles_level'), 'roles', ['level'], unique=False)
# Add role_id to users table
op.add_column('users', sa.Column('role_id', sa.Integer(), nullable=True))
op.create_foreign_key(
'fk_users_role_id', 'users', 'roles',
['role_id'], ['id'], ondelete='SET NULL'
)
op.create_index('ix_users_role_id', 'users', ['role_id'])
def downgrade():
# Remove role_id from users
op.drop_index('ix_users_role_id', table_name='users')
op.drop_constraint('fk_users_role_id', table_name='users', type_='foreignkey')
op.drop_column('users', 'role_id')
# Drop roles table
op.drop_index(op.f('ix_roles_level'), table_name='roles')
op.drop_index(op.f('ix_roles_name'), table_name='roles')
op.drop_index(op.f('ix_roles_id'), table_name='roles')
op.drop_table('roles')
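
To apply or revert this revision in isolation, the Alembic command API mirrors the CLI; a hedged sketch, assuming alembic.ini sits in the backend directory:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # path is an assumption
command.upgrade(cfg, "001_add_roles_table")  # runs the upgrade() above
command.downgrade(cfg, "000_ground_truth")   # reverts it, back to the base revision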

View File

@@ -0,0 +1,118 @@
"""add tools tables
Revision ID: 002_add_tools_tables
Revises: 001_add_roles_table
Create Date: 2025-01-30 00:00:01.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '002_add_tools_tables'
down_revision = '001_add_roles_table'
branch_labels = None
depends_on = None
def upgrade():
# Create tool_categories table
op.create_table(
'tool_categories',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('display_name', sa.String(length=100), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('icon', sa.String(length=50), nullable=True),
sa.Column('color', sa.String(length=20), nullable=True),
sa.Column('sort_order', sa.Integer(), nullable=True, default=0),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tool_categories_id'), 'tool_categories', ['id'], unique=False)
op.create_index(op.f('ix_tool_categories_name'), 'tool_categories', ['name'], unique=True)
# Create tools table
op.create_table(
'tools',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('tool_type', sa.String(length=20), nullable=False),
sa.Column('code', sa.Text(), nullable=False),
sa.Column('parameters_schema', sa.JSON(), nullable=True),
sa.Column('return_schema', sa.JSON(), nullable=True),
sa.Column('timeout_seconds', sa.Integer(), nullable=True, default=30),
sa.Column('max_memory_mb', sa.Integer(), nullable=True, default=256),
sa.Column('max_cpu_seconds', sa.Float(), nullable=True, default=10.0),
sa.Column('docker_image', sa.String(length=200), nullable=True),
sa.Column('docker_command', sa.Text(), nullable=True),
sa.Column('is_public', sa.Boolean(), nullable=True, default=False),
sa.Column('is_approved', sa.Boolean(), nullable=True, default=False),
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
sa.Column('category', sa.String(length=50), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('usage_count', sa.Integer(), nullable=True, default=0),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tools_id'), 'tools', ['id'], unique=False)
op.create_index(op.f('ix_tools_name'), 'tools', ['name'], unique=False)
op.create_foreign_key(
'fk_tools_created_by_user_id', 'tools', 'users',
['created_by_user_id'], ['id'], ondelete='CASCADE'
)
# Create tool_executions table
op.create_table(
'tool_executions',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('tool_id', sa.Integer(), nullable=False),
sa.Column('executed_by_user_id', sa.Integer(), nullable=False),
sa.Column('parameters', sa.JSON(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=False, default='pending'),
sa.Column('output', sa.Text(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('return_code', sa.Integer(), nullable=True),
sa.Column('execution_time_ms', sa.Integer(), nullable=True),
sa.Column('memory_used_mb', sa.Float(), nullable=True),
sa.Column('cpu_time_ms', sa.Integer(), nullable=True),
sa.Column('container_id', sa.String(length=100), nullable=True),
sa.Column('docker_logs', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tool_executions_id'), 'tool_executions', ['id'], unique=False)
op.create_foreign_key(
'fk_tool_executions_tool_id', 'tool_executions', 'tools',
['tool_id'], ['id'], ondelete='CASCADE'
)
op.create_foreign_key(
'fk_tool_executions_executed_by_user_id', 'tool_executions', 'users',
['executed_by_user_id'], ['id'], ondelete='CASCADE'
)
def downgrade():
# Drop tool_executions table
op.drop_constraint('fk_tool_executions_executed_by_user_id', table_name='tool_executions', type_='foreignkey')
op.drop_constraint('fk_tool_executions_tool_id', table_name='tool_executions', type_='foreignkey')
op.drop_index(op.f('ix_tool_executions_id'), table_name='tool_executions')
op.drop_table('tool_executions')
# Drop tools table
op.drop_constraint('fk_tools_created_by_user_id', table_name='tools', type_='foreignkey')
op.drop_index(op.f('ix_tools_name'), table_name='tools')
op.drop_index(op.f('ix_tools_id'), table_name='tools')
op.drop_table('tools')
# Drop tool_categories table
op.drop_index(op.f('ix_tool_categories_name'), table_name='tool_categories')
op.drop_index(op.f('ix_tool_categories_id'), table_name='tool_categories')
op.drop_table('tool_categories')
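
The downgrade drops dependents before their parents (tool_executions, then tools, then tool_categories) because the foreign keys point upward. A hedged way to double-check that ordering against a live schema with SQLAlchemy reflection (the DSN is a placeholder):

import sqlalchemy as sa

engine = sa.create_engine("postgresql://user:pass@localhost/db")  # placeholder DSN
metadata = sa.MetaData()
metadata.reflect(engine, only=["users", "tools", "tool_categories", "tool_executions"])

# sorted_tables lists parents before children; reverse it for a safe drop order.
for table in reversed(metadata.sorted_tables):
    print(table.name)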

View File

@@ -0,0 +1,132 @@
"""add notifications tables
Revision ID: 003_add_notifications_tables
Revises: 002_add_tools_tables
Create Date: 2025-01-30 00:00:02.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = '003_add_notifications_tables'
down_revision = '002_add_tools_tables'
branch_labels = None
depends_on = None
def upgrade():
# Create notification_templates table
op.create_table(
'notification_templates',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('notification_type', sa.String(length=20), nullable=False),
sa.Column('subject_template', sa.Text(), nullable=True),
sa.Column('body_template', sa.Text(), nullable=False),
sa.Column('html_template', sa.Text(), nullable=True),
sa.Column('default_priority', sa.String(length=20), nullable=True, default='normal'),
sa.Column('variables', sa.JSON(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notification_templates_id'), 'notification_templates', ['id'], unique=False)
op.create_index(op.f('ix_notification_templates_name'), 'notification_templates', ['name'], unique=True)
# Create notification_channels table
op.create_table(
'notification_channels',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('display_name', sa.String(length=200), nullable=False),
sa.Column('notification_type', sa.String(length=20), nullable=False),
sa.Column('config', sa.JSON(), nullable=False),
sa.Column('credentials', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
sa.Column('is_default', sa.Boolean(), nullable=True, default=False),
sa.Column('rate_limit', sa.Integer(), nullable=True, default=100),
sa.Column('retry_count', sa.Integer(), nullable=True, default=3),
sa.Column('retry_delay_minutes', sa.Integer(), nullable=True, default=5),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('success_count', sa.Integer(), nullable=True, default=0),
sa.Column('failure_count', sa.Integer(), nullable=True, default=0),
sa.Column('last_error', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notification_channels_id'), 'notification_channels', ['id'], unique=False)
op.create_index(op.f('ix_notification_channels_name'), 'notification_channels', ['name'], unique=False)
# Create notifications table
op.create_table(
'notifications',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('subject', sa.String(length=500), nullable=True),
sa.Column('body', sa.Text(), nullable=False),
sa.Column('html_body', sa.Text(), nullable=True),
sa.Column('recipients', sa.JSON(), nullable=False),
sa.Column('cc_recipients', sa.JSON(), nullable=True),
sa.Column('bcc_recipients', sa.JSON(), nullable=True),
sa.Column('priority', sa.String(length=20), nullable=True, default='normal'),
sa.Column('scheduled_at', sa.DateTime(), nullable=True),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('template_id', sa.Integer(), nullable=True),
sa.Column('channel_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=True, default='pending'),
sa.Column('attempts', sa.Integer(), nullable=True, default=0),
sa.Column('max_attempts', sa.Integer(), nullable=True, default=3),
sa.Column('sent_at', sa.DateTime(), nullable=True),
sa.Column('delivered_at', sa.DateTime(), nullable=True),
sa.Column('failed_at', sa.DateTime(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('external_id', sa.String(length=200), nullable=True),
sa.Column('callback_url', sa.String(length=500), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_notifications_id'), 'notifications', ['id'], unique=False)
op.create_index(op.f('ix_notifications_status'), 'notifications', ['status'], unique=False)
op.create_index(op.f('ix_notifications_scheduled_at'), 'notifications', ['scheduled_at'], unique=False)
# Add foreign key constraints
op.create_foreign_key(
'fk_notifications_template_id', 'notifications', 'notification_templates',
['template_id'], ['id'], ondelete='SET NULL'
)
op.create_foreign_key(
'fk_notifications_channel_id', 'notifications', 'notification_channels',
['channel_id'], ['id'], ondelete='CASCADE'
)
op.create_foreign_key(
'fk_notifications_user_id', 'notifications', 'users',
['user_id'], ['id'], ondelete='SET NULL'
)
def downgrade():
# Drop notifications table
op.drop_constraint('fk_notifications_user_id', table_name='notifications', type_='foreignkey')
op.drop_constraint('fk_notifications_channel_id', table_name='notifications', type_='foreignkey')
op.drop_constraint('fk_notifications_template_id', table_name='notifications', type_='foreignkey')
op.drop_index(op.f('ix_notifications_scheduled_at'), table_name='notifications')
op.drop_index(op.f('ix_notifications_status'), table_name='notifications')
op.drop_index(op.f('ix_notifications_id'), table_name='notifications')
op.drop_table('notifications')
# Drop notification_channels table
op.drop_index(op.f('ix_notification_channels_name'), table_name='notification_channels')
op.drop_index(op.f('ix_notification_channels_id'), table_name='notification_channels')
op.drop_table('notification_channels')
# Drop notification_templates table
op.drop_index(op.f('ix_notification_templates_name'), table_name='notification_templates')
op.drop_index(op.f('ix_notification_templates_id'), table_name='notification_templates')
op.drop_table('notification_templates')

View File

@@ -0,0 +1,26 @@
"""Add force_password_change to users
Revision ID: 004_add_force_password_change
Revises: fd999a559a35
Create Date: 2025-01-31 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '004_add_force_password_change'
down_revision = 'fd999a559a35'
branch_labels = None
depends_on = None
def upgrade():
# Add force_password_change column to users table
op.add_column('users', sa.Column('force_password_change', sa.Boolean(), default=False, nullable=False, server_default='false'))
def downgrade():
# Remove force_password_change column from users table
op.drop_column('users', 'force_password_change')
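
The server_default='false' is what makes this safe on a populated table: without it, adding a NOT NULL column fails because existing rows have no value to satisfy the constraint. A hedged illustration of the failing variant:

# This would fail on a table that already has rows, since existing rows
# cannot satisfy NOT NULL without a value the database can fill in:
# op.add_column("users", sa.Column("force_password_change",
#                                  sa.Boolean(), nullable=False))
# The server_default above lets the database backfill existing rows in one step.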

View File

@@ -0,0 +1,79 @@
"""fix user nullable columns
Revision ID: 005_fix_user_nullable_columns
Revises: 004_add_force_password_change
Create Date: 2025-11-20 08:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "005_fix_user_nullable_columns"
down_revision = "004_add_force_password_change"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""
Fix nullable columns in users table:
- Backfill NULL values for account_locked and failed_login_attempts
- Set proper server defaults
- Alter columns to NOT NULL
"""
# Use connection to execute raw SQL for backfilling
conn = op.get_bind()
# Backfill NULL values for account_locked
conn.execute(
sa.text("UPDATE users SET account_locked = FALSE WHERE account_locked IS NULL")
)
# Backfill NULL values for failed_login_attempts
conn.execute(
sa.text("UPDATE users SET failed_login_attempts = 0 WHERE failed_login_attempts IS NULL")
)
# Backfill NULL values for custom_permissions (use empty JSON object)
conn.execute(
sa.text("UPDATE users SET custom_permissions = '{}' WHERE custom_permissions IS NULL")
)
# Now alter columns to NOT NULL with server defaults
# Note: PostgreSQL syntax
op.alter_column('users', 'account_locked',
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false())
op.alter_column('users', 'failed_login_attempts',
existing_type=sa.Integer(),
nullable=False,
server_default='0')
op.alter_column('users', 'custom_permissions',
existing_type=sa.JSON(),
nullable=False,
server_default='{}')
def downgrade() -> None:
"""
Revert columns to nullable (original state from fd999a559a35)
"""
op.alter_column('users', 'account_locked',
existing_type=sa.Boolean(),
nullable=True,
server_default=None)
op.alter_column('users', 'failed_login_attempts',
existing_type=sa.Integer(),
nullable=True,
server_default=None)
op.alter_column('users', 'custom_permissions',
existing_type=sa.JSON(),
nullable=True,
server_default=None)
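
A quick post-migration sanity check, sketched with a synchronous SQLAlchemy connection (the DSN is a placeholder):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:pass@localhost/db")  # placeholder DSN
with engine.connect() as conn:
    remaining = conn.execute(text(
        "SELECT COUNT(*) FROM users "
        "WHERE account_locked IS NULL "
        "OR failed_login_attempts IS NULL "
        "OR custom_permissions IS NULL"
    )).scalar_one()
assert remaining == 0, f"{remaining} rows still contain NULLs"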

View File

@@ -0,0 +1,71 @@
"""fix missing user columns
Revision ID: fd999a559a35
Revises: 003_add_notifications_tables
Create Date: 2025-10-30 11:33:42.236622
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fd999a559a35"
down_revision = "003_add_notifications_tables"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add missing columns to users table
# These columns should have been added in 001_add_roles_table.py but were not
# Use try/except to handle cases where columns might already exist
try:
op.add_column("users", sa.Column("custom_permissions", sa.JSON(), nullable=True, default=dict))
except Exception:
pass # Column might already exist
try:
op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True, default=False))
except Exception:
pass
try:
op.add_column("users", sa.Column("account_locked_until", sa.DateTime(), nullable=True))
except Exception:
pass
try:
op.add_column("users", sa.Column("failed_login_attempts", sa.Integer(), nullable=True, default=0))
except Exception:
pass
try:
op.add_column("users", sa.Column("last_failed_login", sa.DateTime(), nullable=True))
except Exception:
pass
def downgrade() -> None:
# Remove the columns
try:
op.drop_column("users", "last_failed_login")
except Exception:
pass
try:
op.drop_column("users", "failed_login_attempts")
except Exception:
pass
try:
op.drop_column("users", "account_locked_until")
except Exception:
pass
try:
op.drop_column("users", "account_locked")
except Exception:
pass
try:
op.drop_column("users", "custom_permissions")
except Exception:
pass
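
The bare try/except pattern above works, but it swallows every error, not just "column already exists". A hedged alternative that asks the database first, using SQLAlchemy's inspector:

from alembic import op
import sqlalchemy as sa

def _column_exists(table: str, column: str) -> bool:
    # Inspect the live schema instead of catching arbitrary exceptions.
    inspector = sa.inspect(op.get_bind())
    return column in {c["name"] for c in inspector.get_columns(table)}

def upgrade() -> None:
    if not _column_exists("users", "account_locked"):
        op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True, default=False))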

View File

@@ -44,7 +44,9 @@ class HealthChecker:
await session.execute(select(1))
# Check table availability
await session.execute(text("SELECT COUNT(*) FROM information_schema.tables"))
await session.execute(
text("SELECT COUNT(*) FROM information_schema.tables")
)
duration = time.time() - start_time
@@ -54,8 +56,8 @@ class HealthChecker:
"timestamp": datetime.utcnow().isoformat(),
"details": {
"connection": "successful",
"query_execution": "successful"
}
"query_execution": "successful",
},
}
except Exception as e:
@@ -64,10 +66,7 @@ class HealthChecker:
"status": "unhealthy",
"error": str(e),
"timestamp": datetime.utcnow().isoformat(),
"details": {
"connection": "failed",
"error_type": type(e).__name__
}
"details": {"connection": "failed", "error_type": type(e).__name__},
}
async def check_memory_health(self) -> Dict[str, Any]:
@@ -107,7 +106,7 @@ class HealthChecker:
"process_memory_mb": round(process_memory_mb, 2),
"system_memory_percent": memory.percent,
"system_available_gb": round(memory.available / (1024**3), 2),
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -115,7 +114,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_connection_health(self) -> Dict[str, Any]:
@@ -128,8 +127,16 @@ class HealthChecker:
# Analyze connections
total_connections = len(connections)
established_connections = len([c for c in connections if c.status == 'ESTABLISHED'])
http_connections = len([c for c in connections if any(port in str(c.laddr) for port in [80, 8000, 3000])])
established_connections = len(
[c for c in connections if c.status == "ESTABLISHED"]
)
http_connections = len(
[
c
for c in connections
if any(str(port) in str(c.laddr) for port in [80, 8000, 3000])
]
)
# Check for connection issues
connection_status = "healthy"
@@ -154,7 +161,7 @@ class HealthChecker:
"total_connections": total_connections,
"established_connections": established_connections,
"http_connections": http_connections,
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -162,7 +169,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_embedding_service_health(self) -> Dict[str, Any]:
@@ -193,7 +200,7 @@ class HealthChecker:
"response_time_ms": round(duration * 1000, 2),
"timestamp": datetime.utcnow().isoformat(),
"stats": stats,
"issues": issues
"issues": issues,
}
except Exception as e:
@@ -201,7 +208,7 @@ class HealthChecker:
return {
"status": "error",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def check_redis_health(self) -> Dict[str, Any]:
@@ -209,7 +216,7 @@ class HealthChecker:
if not settings.REDIS_URL:
return {
"status": "not_configured",
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
try:
@@ -233,7 +240,7 @@ class HealthChecker:
return {
"status": "healthy",
"response_time_ms": round(duration * 1000, 2),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
except Exception as e:
@@ -241,7 +248,7 @@ class HealthChecker:
return {
"status": "unhealthy",
"error": str(e),
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
async def get_comprehensive_health(self) -> Dict[str, Any]:
@@ -251,7 +258,7 @@ class HealthChecker:
"memory": await self.check_memory_health(),
"connections": await self.check_connection_health(),
"embedding_service": await self.check_embedding_service_health(),
"redis": await self.check_redis_health()
"redis": await self.check_redis_health(),
}
# Determine overall status
@@ -274,12 +281,16 @@ class HealthChecker:
"summary": {
"total_checks": len(checks),
"healthy_checks": len([s for s in statuses if s == "healthy"]),
"degraded_checks": len([s for s in statuses if s in ["warning", "degraded", "unhealthy"]]),
"failed_checks": len([s for s in statuses if s in ["critical", "error"]]),
"total_issues": total_issues
"degraded_checks": len(
[s for s in statuses if s in ["warning", "degraded", "unhealthy"]]
),
"failed_checks": len(
[s for s in statuses if s in ["critical", "error"]]
),
"total_issues": total_issues,
},
"version": "1.0.0",
"uptime_seconds": int(time.time() - psutil.boot_time())
"uptime_seconds": int(time.time() - psutil.boot_time()),
}
@@ -294,7 +305,7 @@ async def basic_health_check():
"status": "healthy",
"app": settings.APP_NAME,
"version": "1.0.0",
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
@@ -307,7 +318,7 @@ async def detailed_health_check():
logger.error(f"Detailed health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Health check failed: {str(e)}"
detail=f"Health check failed: {str(e)}",
)
@@ -320,7 +331,7 @@ async def memory_health_check():
logger.error(f"Memory health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Memory health check failed: {str(e)}"
detail=f"Memory health check failed: {str(e)}",
)
@@ -333,7 +344,7 @@ async def connection_health_check():
logger.error(f"Connection health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Connection health check failed: {str(e)}"
detail=f"Connection health check failed: {str(e)}",
)
@@ -346,5 +357,5 @@ async def embedding_service_health_check():
logger.error(f"Embedding service health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Embedding service health check failed: {str(e)}"
detail=f"Embedding service health check failed: {str(e)}",
)
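
A minimal smoke test against these endpoints, assuming the service listens on localhost:8000 and the routes are mounted at /health and /health/detailed (both paths are assumptions):

import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get("/health")  # basic_health_check; path assumed
        resp.raise_for_status()
        print(resp.json()["status"])
        detailed = await client.get("/health/detailed")  # detailed_health_check; path assumed
        print(detailed.json().get("summary"))

asyncio.run(main())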

View File

@@ -19,6 +19,7 @@ from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
from ..v1.chatbot import router as chatbot_router
from .debugging import router as debugging_router
from ..v1.endpoints.user_management import router as user_management_router
# Create internal API router
internal_api_router = APIRouter()
@@ -27,47 +28,82 @@ internal_api_router = APIRouter()
internal_api_router.include_router(auth_router, prefix="/auth", tags=["internal-auth"])
# Include modules routes (frontend management)
internal_api_router.include_router(modules_router, prefix="/modules", tags=["internal-modules"])
internal_api_router.include_router(
modules_router, prefix="/modules", tags=["internal-modules"]
)
# Include platform routes (frontend platform management)
internal_api_router.include_router(platform_router, prefix="/platform", tags=["internal-platform"])
internal_api_router.include_router(
platform_router, prefix="/platform", tags=["internal-platform"]
)
# Include user management routes (frontend user admin)
internal_api_router.include_router(users_router, prefix="/users", tags=["internal-users"])
internal_api_router.include_router(
users_router, prefix="/users", tags=["internal-users"]
)
# Include API key management routes (frontend API key management)
internal_api_router.include_router(api_keys_router, prefix="/api-keys", tags=["internal-api-keys"])
internal_api_router.include_router(
api_keys_router, prefix="/api-keys", tags=["internal-api-keys"]
)
# Include budget management routes (frontend budget management)
internal_api_router.include_router(budgets_router, prefix="/budgets", tags=["internal-budgets"])
internal_api_router.include_router(
budgets_router, prefix="/budgets", tags=["internal-budgets"]
)
# Include audit log routes (frontend audit viewing)
internal_api_router.include_router(audit_router, prefix="/audit", tags=["internal-audit"])
internal_api_router.include_router(
audit_router, prefix="/audit", tags=["internal-audit"]
)
# Include settings management routes (frontend settings)
internal_api_router.include_router(settings_router, prefix="/settings", tags=["internal-settings"])
internal_api_router.include_router(
settings_router, prefix="/settings", tags=["internal-settings"]
)
# Include analytics routes (frontend analytics viewing)
internal_api_router.include_router(analytics_router, prefix="/analytics", tags=["internal-analytics"])
internal_api_router.include_router(
analytics_router, prefix="/analytics", tags=["internal-analytics"]
)
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
internal_api_router.include_router(
rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"]
)
# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
internal_api_router.include_router(
prompt_templates_router,
prefix="/prompt-templates",
tags=["internal-prompt-templates"],
)
# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
internal_api_router.include_router(
plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]
)
# Include internal LLM routes (frontend LLM service access with JWT auth)
internal_api_router.include_router(llm_internal_router, prefix="/llm", tags=["internal-llm"])
internal_api_router.include_router(
llm_internal_router, prefix="/llm", tags=["internal-llm"]
)
# Include chatbot routes (frontend chatbot management)
internal_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["internal-chatbot"])
internal_api_router.include_router(
chatbot_router, prefix="/chatbot", tags=["internal-chatbot"]
)
# Include debugging routes (troubleshooting and diagnostics)
internal_api_router.include_router(debugging_router, prefix="/debugging", tags=["internal-debugging"])
internal_api_router.include_router(
debugging_router, prefix="/debugging", tags=["internal-debugging"]
)
# Include user management routes (advanced user and role management)
internal_api_router.include_router(
user_management_router, prefix="/user-management", tags=["internal-user-management"]
)
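
For orientation, a hedged sketch of how an aggregate router like this is typically mounted on the application (the import path and the /api/internal prefix are assumptions):

from fastapi import FastAPI
from app.api.internal import internal_api_router  # import path assumed

app = FastAPI()
app.include_router(internal_api_router, prefix="/api/internal")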

View File

@@ -20,23 +20,28 @@ router = APIRouter()
async def get_chatbot_config_debug(
chatbot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get detailed configuration for debugging a specific chatbot"""
# Get chatbot instance
chatbot = db.query(ChatbotInstance).filter(
ChatbotInstance.id == chatbot_id,
ChatbotInstance.user_id == current_user.id
).first()
chatbot = (
db.query(ChatbotInstance)
.filter(
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
)
.first()
)
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
# Get prompt template
prompt_template = db.query(PromptTemplate).filter(
PromptTemplate.type == chatbot.chatbot_type
).first()
prompt_template = (
db.query(PromptTemplate)
.filter(PromptTemplate.type == chatbot.chatbot_type)
.first()
)
# Get RAG collections if configured
rag_collections = []
@@ -44,31 +49,37 @@ async def get_chatbot_config_debug(
collection_ids = chatbot.rag_collection_ids
if isinstance(collection_ids, str):
import json
try:
collection_ids = json.loads(collection_ids)
except json.JSONDecodeError:
collection_ids = []
if collection_ids:
collections = db.query(RagCollection).filter(
RagCollection.id.in_(collection_ids)
).all()
collections = (
db.query(RagCollection)
.filter(RagCollection.id.in_(collection_ids))
.all()
)
rag_collections = [
{
"id": col.id,
"name": col.name,
"document_count": col.document_count,
"qdrant_collection_name": col.qdrant_collection_name,
"is_active": col.is_active
"is_active": col.is_active,
}
for col in collections
]
# Get recent conversations count
from app.models.chatbot import ChatbotConversation
conversation_count = db.query(ChatbotConversation).filter(
ChatbotConversation.chatbot_instance_id == chatbot_id
).count()
conversation_count = (
db.query(ChatbotConversation)
.filter(ChatbotConversation.chatbot_instance_id == chatbot_id)
.count()
)
return {
"chatbot": {
@@ -78,20 +89,20 @@ async def get_chatbot_config_debug(
"description": chatbot.description,
"created_at": chatbot.created_at,
"is_active": chatbot.is_active,
"conversation_count": conversation_count
"conversation_count": conversation_count,
},
"prompt_template": {
"type": prompt_template.type if prompt_template else None,
"system_prompt": prompt_template.system_prompt if prompt_template else None,
"variables": prompt_template.variables if prompt_template else []
"variables": prompt_template.variables if prompt_template else [],
},
"rag_collections": rag_collections,
"configuration": {
"max_tokens": chatbot.max_tokens,
"temperature": chatbot.temperature,
"streaming": chatbot.streaming,
"memory_config": chatbot.memory_config
}
"memory_config": chatbot.memory_config,
},
}
@@ -101,15 +112,18 @@ async def test_rag_search(
query: str = "test query",
top_k: int = 5,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Test RAG search for a specific chatbot"""
# Get chatbot instance
chatbot = db.query(ChatbotInstance).filter(
ChatbotInstance.id == chatbot_id,
ChatbotInstance.user_id == current_user.id
).first()
chatbot = (
db.query(ChatbotInstance)
.filter(
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
)
.first()
)
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found")
@@ -123,6 +137,7 @@ async def test_rag_search(
if chatbot.rag_collection_ids:
if isinstance(chatbot.rag_collection_ids, str):
import json
try:
collection_ids = json.loads(chatbot.rag_collection_ids)
except json.JSONDecodeError:
@@ -134,22 +149,19 @@ async def test_rag_search(
return {
"query": query,
"results": [],
"message": "No RAG collections configured for this chatbot"
"message": "No RAG collections configured for this chatbot",
}
# Perform search
search_results = await rag_module.search(
query=query,
collection_ids=collection_ids,
top_k=top_k,
score_threshold=0.5
query=query, collection_ids=collection_ids, top_k=top_k, score_threshold=0.5
)
return {
"query": query,
"results": search_results,
"collections_searched": collection_ids,
"result_count": len(search_results)
"result_count": len(search_results),
}
except Exception as e:
@@ -157,14 +169,13 @@ async def test_rag_search(
"query": query,
"results": [],
"error": str(e),
"message": "RAG search failed"
"message": "RAG search failed",
}
@router.get("/system/status")
async def get_system_status(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get system status for debugging"""
@@ -179,11 +190,12 @@ async def get_system_status(
module_status = {}
try:
from app.services.module_manager import module_manager
modules = module_manager.list_modules()
for module_name, module_info in modules.items():
module_status[module_name] = {
"status": module_info.get("status", "unknown"),
"enabled": module_info.get("enabled", False)
"enabled": module_info.get("enabled", False),
}
except Exception as e:
module_status = {"error": str(e)}
@@ -192,6 +204,7 @@ async def get_system_status(
redis_status = "not configured"
try:
from app.core.cache import core_cache
await core_cache.ping()
redis_status = "healthy"
except Exception as e:
@@ -201,6 +214,7 @@ async def get_system_status(
qdrant_status = "not configured"
try:
from app.services.qdrant_service import qdrant_service
collections = await qdrant_service.list_collections()
qdrant_status = f"healthy ({len(collections)} collections)"
except Exception as e:
@@ -211,5 +225,5 @@ async def get_system_status(
"modules": module_status,
"redis": redis_status,
"qdrant": qdrant_status,
"timestamp": "UTC"
"timestamp": "UTC",
}

View File

@@ -3,6 +3,7 @@ Public API v1 package - for external clients
"""
from fastapi import APIRouter
from ..v1.auth import router as auth_router
from ..v1.llm import router as llm_router
from ..v1.chatbot import router as chatbot_router
from ..v1.openai_compat import router as openai_router
@@ -10,6 +11,9 @@ from ..v1.openai_compat import router as openai_router
# Create public API router
public_api_router = APIRouter()
# Include authentication routes (needed for login/logout)
public_api_router.include_router(auth_router, prefix="/auth", tags=["authentication"])
# Include OpenAI-compatible routes (chat/completions, models, embeddings)
public_api_router.include_router(openai_router, tags=["openai-compat"])
@@ -17,4 +21,6 @@ public_api_router.include_router(openai_router, tags=["openai-compat"])
public_api_router.include_router(llm_router, prefix="/llm", tags=["public-llm"])
# Include public chatbot API (external chatbot integrations)
public_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["public-chatbot"])
public_api_router.include_router(
chatbot_router, prefix="/chatbot", tags=["public-chatbot"]
)

View File

@@ -16,10 +16,9 @@ logger = logging.getLogger(__name__)
# Create router
router = APIRouter()
@router.get("/collections")
async def list_collections(
current_user: User = Depends(get_current_user)
):
async def list_collections(current_user: User = Depends(get_current_user)):
"""List all available RAG collections"""
try:
from app.services.qdrant_stats_service import qdrant_stats_service
@@ -31,10 +30,7 @@ async def list_collections(
# Extract collection names
collection_names = [col["name"] for col in collections]
return {
"collections": collection_names,
"count": len(collection_names)
}
return {"collections": collection_names, "count": len(collection_names)}
except Exception as e:
logger.error(f"List collections error: {e}")
@@ -45,10 +41,14 @@ async def list_collections(
async def debug_search(
query: str = Query(..., description="Search query"),
max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
collection_name: Optional[str] = Query(None, description="Collection name to search"),
score_threshold: float = Query(
0.3, ge=0.0, le=1.0, description="Minimum score threshold"
),
collection_name: Optional[str] = Query(
None, description="Collection name to search"
),
config: Optional[Dict[str, Any]] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Debug search endpoint with detailed information"""
try:
@@ -56,9 +56,7 @@ async def debug_search(
app_config = settings
# Initialize RAG module with BGE-M3 configuration
rag_config = {
"embedding_model": "BAAI/bge-m3"
}
rag_config = {"embedding_model": "BAAI/bge-m3"}
rag_module = RAGModule(app_config, config=rag_config)
# Get available collections if none specified
@@ -71,9 +69,9 @@ async def debug_search(
"results": [],
"debug_info": {
"error": "No collections available",
"collections_found": 0
"collections_found": 0,
},
"search_time_ms": 0
"search_time_ms": 0,
}
# Perform search
@@ -82,7 +80,7 @@ async def debug_search(
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name,
config=config or {}
config=config or {},
)
return results
@@ -94,7 +92,7 @@ async def debug_search(
"debug_info": {
"error": str(e),
"query": query,
"collection_name": collection_name
"collection_name": collection_name,
},
"search_time_ms": 0
"search_time_ms": 0,
}
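
Exercising the debug search endpoint from a client, with the base URL, route prefix, and bearer token all assumed:

import asyncio
import httpx

async def main() -> None:
    params = {"query": "test query", "max_results": 5, "score_threshold": 0.3}
    headers = {"Authorization": "Bearer <token>"}  # placeholder token
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get("/api/internal/rag/debug/search",  # path assumed
                                params=params, headers=headers)
        print(resp.json())

asyncio.run(main())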

View File

@@ -17,6 +17,9 @@ from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .plugin_registry import router as plugin_registry_router
from .endpoints.tools import router as tools_router
from .endpoints.tool_calling import router as tool_calling_router
from .endpoints.user_management import router as user_management_router
# Create main API router
api_router = APIRouter()
@@ -58,9 +61,23 @@ api_router.include_router(rag_router, prefix="/rag", tags=["rag"])
api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
api_router.include_router(
prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]
)
# Include plugin registry routes
api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
# Include tool management routes
api_router.include_router(tools_router, prefix="/tools", tags=["tools"])
# Include tool calling routes
api_router.include_router(
tool_calling_router, prefix="/tool-calling", tags=["tool-calling"]
)
# Include admin user management routes
api_router.include_router(
user_management_router, prefix="/admin/user-management", tags=["admin", "user-management"]
)

View File

@@ -22,105 +22,98 @@ router = APIRouter()
async def get_usage_metrics(
hours: int = Query(24, ge=1, le=168, description="Hours to analyze (1-168)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get comprehensive usage metrics including costs and budgets"""
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours, user_id=current_user['id'])
return {
"success": True,
"data": metrics,
"period_hours": hours
}
metrics = await analytics.get_usage_metrics(
hours=hours, user_id=current_user["id"]
)
return {"success": True, "data": metrics, "period_hours": hours}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage metrics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting usage metrics: {str(e)}"
)
@router.get("/metrics/system")
async def get_system_metrics(
hours: int = Query(24, ge=1, le=168),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get system-wide metrics (admin only)"""
if not current_user['is_superuser']:
if not current_user["is_superuser"]:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
metrics = await analytics.get_usage_metrics(hours=hours)
return {
"success": True,
"data": metrics,
"period_hours": hours
}
return {"success": True, "data": metrics, "period_hours": hours}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system metrics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system metrics: {str(e)}"
)
@router.get("/health")
async def get_system_health(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get system health status including budget and performance analysis"""
try:
analytics = get_analytics_service()
health = await analytics.get_system_health()
return {
"success": True,
"data": health
}
return {"success": True, "data": health}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system health: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system health: {str(e)}"
)
@router.get("/costs")
async def get_cost_analysis(
days: int = Query(30, ge=1, le=365, description="Days to analyze (1-365)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get detailed cost analysis and trends"""
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days, user_id=current_user['id'])
return {
"success": True,
"data": analysis,
"period_days": days
}
analysis = await analytics.get_cost_analysis(
days=days, user_id=current_user["id"]
)
return {"success": True, "data": analysis, "period_days": days}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting cost analysis: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting cost analysis: {str(e)}"
)
@router.get("/costs/system")
async def get_system_cost_analysis(
days: int = Query(30, ge=1, le=365),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get system-wide cost analysis (admin only)"""
if not current_user['is_superuser']:
if not current_user["is_superuser"]:
raise HTTPException(status_code=403, detail="Admin access required")
try:
analytics = get_analytics_service()
analysis = await analytics.get_cost_analysis(days=days)
return {
"success": True,
"data": analysis,
"period_days": days
}
return {"success": True, "data": analysis, "period_days": days}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting system cost analysis: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting system cost analysis: {str(e)}"
)
@router.get("/endpoints")
async def get_endpoint_stats(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get endpoint usage statistics"""
try:
@@ -133,18 +126,20 @@ async def get_endpoint_stats(
"data": {
"endpoint_stats": dict(analytics.endpoint_stats),
"status_codes": dict(analytics.status_codes),
"model_stats": dict(analytics.model_stats)
}
"model_stats": dict(analytics.model_stats),
},
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting endpoint stats: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting endpoint stats: {str(e)}"
)
@router.get("/usage-trends")
async def get_usage_trends(
days: int = Query(7, ge=1, le=30, description="Days for trend analysis"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
"""Get usage trends over time"""
try:
@@ -155,50 +150,53 @@ async def get_usage_trends(
cutoff_time = datetime.utcnow() - timedelta(days=days)
# Daily usage trends
daily_usage = db.query(
func.date(UsageTracking.created_at).label('date'),
func.count(UsageTracking.id).label('requests'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents')
).filter(
daily_usage = (
db.query(
func.date(UsageTracking.created_at).label("date"),
func.count(UsageTracking.id).label("requests"),
func.sum(UsageTracking.total_tokens).label("tokens"),
func.sum(UsageTracking.cost_cents).label("cost_cents"),
)
.filter(
UsageTracking.created_at >= cutoff_time,
UsageTracking.user_id == current_user['id']
).group_by(
func.date(UsageTracking.created_at)
).order_by('date').all()
UsageTracking.user_id == current_user["id"],
)
.group_by(func.date(UsageTracking.created_at))
.order_by("date")
.all()
)
trends = []
for date, requests, tokens, cost_cents in daily_usage:
trends.append({
trends.append(
{
"date": date.isoformat(),
"requests": requests,
"tokens": tokens or 0,
"cost_cents": cost_cents or 0,
"cost_dollars": (cost_cents or 0) / 100
})
"cost_dollars": (cost_cents or 0) / 100,
}
)
return {
"success": True,
"data": {
"trends": trends,
"period_days": days
}
}
return {"success": True, "data": {"trends": trends, "period_days": days}}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting usage trends: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Error getting usage trends: {str(e)}"
)
@router.get("/overview")
async def get_analytics_overview(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get analytics overview data"""
try:
analytics = get_analytics_service()
# Get basic metrics
metrics = await analytics.get_usage_metrics(hours=24, user_id=current_user['id'])
metrics = await analytics.get_usage_metrics(
hours=24, user_id=current_user["id"]
)
health = await analytics.get_system_health()
return {
@@ -210,8 +208,8 @@ async def get_analytics_overview(
"error_rate": metrics.error_rate,
"budget_usage_percentage": metrics.budget_usage_percentage,
"system_health": health.status,
"health_score": health.score
}
"health_score": health.score,
},
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting overview: {str(e)}")
@@ -219,8 +217,7 @@ async def get_analytics_overview(
@router.get("/modules")
async def get_module_analytics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Get analytics data for all modules"""
try:
@@ -248,10 +245,14 @@ async def get_module_analytics(
"data": {
"modules": module_stats,
"total_modules": len(module_stats),
"system_health": "healthy" if all(m.get("initialized", False) for m in module_stats) else "warning"
}
"system_health": "healthy"
if all(m.get("initialized", False) for m in module_stats)
else "warning",
},
}
except Exception as e:
logger.error(f"Failed to get module analytics: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve module analytics")
raise HTTPException(
status_code=500, detail="Failed to retrieve module analytics"
)

View File

@@ -74,14 +74,16 @@ class APIKeyResponse(BaseModel):
last_used_at: Optional[datetime] = None
total_requests: int
total_tokens: int
total_cost_cents: int = Field(alias='total_cost')
total_cost_cents: int = Field(alias="total_cost")
rate_limit_per_minute: Optional[int] = None
rate_limit_per_hour: Optional[int] = None
rate_limit_per_day: Optional[int] = None
allowed_ips: List[str]
allowed_models: List[str] # Model restrictions
allowed_chatbots: List[str] # Chatbot restrictions
budget_limit: Optional[int] = Field(None, alias='budget_limit_cents') # Budget limit in cents
budget_limit: Optional[int] = Field(
None, alias="budget_limit_cents"
) # Budget limit in cents
budget_type: Optional[str] = None # Budget type
is_unlimited: bool = True # Unlimited budget flag
tags: List[str]
@@ -93,28 +95,28 @@ class APIKeyResponse(BaseModel):
def from_api_key(cls, api_key):
"""Create response from APIKey model with formatted key prefix"""
data = {
'id': api_key.id,
'name': api_key.name,
'description': api_key.description,
'key_prefix': api_key.key_prefix + "..." if api_key.key_prefix else "",
'scopes': api_key.scopes,
'is_active': api_key.is_active,
'expires_at': api_key.expires_at,
'created_at': api_key.created_at,
'last_used_at': api_key.last_used_at,
'total_requests': api_key.total_requests,
'total_tokens': api_key.total_tokens,
'total_cost': api_key.total_cost,
'rate_limit_per_minute': api_key.rate_limit_per_minute,
'rate_limit_per_hour': api_key.rate_limit_per_hour,
'rate_limit_per_day': api_key.rate_limit_per_day,
'allowed_ips': api_key.allowed_ips,
'allowed_models': api_key.allowed_models,
'allowed_chatbots': api_key.allowed_chatbots,
'budget_limit_cents': api_key.budget_limit_cents,
'budget_type': api_key.budget_type,
'is_unlimited': api_key.is_unlimited,
'tags': api_key.tags
"id": api_key.id,
"name": api_key.name,
"description": api_key.description,
"key_prefix": api_key.key_prefix + "..." if api_key.key_prefix else "",
"scopes": api_key.scopes,
"is_active": api_key.is_active,
"expires_at": api_key.expires_at,
"created_at": api_key.created_at,
"last_used_at": api_key.last_used_at,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost": api_key.total_cost,
"rate_limit_per_minute": api_key.rate_limit_per_minute,
"rate_limit_per_hour": api_key.rate_limit_per_hour,
"rate_limit_per_day": api_key.rate_limit_per_day,
"allowed_ips": api_key.allowed_ips,
"allowed_models": api_key.allowed_models,
"allowed_chatbots": api_key.allowed_chatbots,
"budget_limit_cents": api_key.budget_limit_cents,
"budget_type": api_key.budget_type,
"is_unlimited": api_key.is_unlimited,
"tags": api_key.tags,
}
return cls(**data)
@@ -148,13 +150,16 @@ class APIKeyUsageResponse(BaseModel):
def generate_api_key() -> tuple[str, str]:
"""Generate a new API key and return (full_key, key_hash)"""
# Generate random key part (32 characters)
key_part = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))
key_part = "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
)
# Create full key with prefix
full_key = f"{settings.API_KEY_PREFIX}{key_part}"
# Create hash for storage
from app.core.security import get_api_key_hash
key_hash = get_api_key_hash(full_key)
return full_key, key_hash
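
As an aside, secrets.token_urlsafe yields comparable entropy in one call; a hedged alternative sketch (the prefix value is an assumption — the real one comes from settings.API_KEY_PREFIX):

import secrets

API_KEY_PREFIX = "sk-"  # assumed; the real prefix comes from settings

def generate_key_part() -> str:
    # 24 random bytes encode to 32 URL-safe characters (~192 bits of entropy),
    # matching the 32-character key part generated above.
    return secrets.token_urlsafe(24)

full_key = f"{API_KEY_PREFIX}{generate_key_part()}"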
@@ -169,32 +174,40 @@ async def list_api_keys(
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List API keys with pagination and filtering"""
# Check permissions - users can view their own API keys
if user_id and int(user_id) != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if user_id and int(user_id) != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
elif not user_id:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# If no user_id specified and user doesn't have admin permissions, show only their keys
if not user_id and "platform:api-keys:read" not in current_user.get("permissions", []):
user_id = current_user['id']
if not user_id and "platform:api-keys:read" not in current_user.get(
"permissions", []
):
user_id = current_user["id"]
# Build query
query = select(APIKey)
# Apply filters
if user_id:
query = query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
query = query.where(
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if is_active is not None:
query = query.where(APIKey.is_active == is_active)
if search:
query = query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
(APIKey.name.ilike(f"%{search}%"))
| (APIKey.description.ilike(f"%{search}%"))
)
# Get total count using func.count()
@@ -202,13 +215,15 @@ async def list_api_keys(
# Apply same filters for count
if user_id:
total_query = total_query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
total_query = total_query.where(
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if is_active is not None:
total_query = total_query.where(APIKey.is_active == is_active)
if search:
total_query = total_query.where(
(APIKey.name.ilike(f"%{search}%")) |
(APIKey.description.ilike(f"%{search}%"))
(APIKey.name.ilike(f"%{search}%"))
| (APIKey.description.ilike(f"%{search}%"))
)
total_result = await db.execute(total_query)
@@ -225,17 +240,21 @@ async def list_api_keys(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_api_keys",
resource_type="api_key",
details={"page": page, "size": size, "filters": {"user_id": user_id, "is_active": is_active, "search": search}}
details={
"page": page,
"size": size,
"filters": {"user_id": user_id, "is_active": is_active, "search": search},
},
)
return APIKeyListResponse(
api_keys=[APIKeyResponse.model_validate(key) for key in api_keys],
total=total,
page=page,
size=size
size=size,
)
@@ -243,7 +262,7 @@ async def list_api_keys(
async def get_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get API key by ID"""
@@ -254,21 +273,22 @@ async def get_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can view their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_api_key",
resource_type="api_key",
resource_id=api_key_id
resource_id=api_key_id,
)
return APIKeyResponse.model_validate(api_key)
@@ -278,7 +298,7 @@ async def get_api_key(
async def create_api_key(
api_key_data: APIKeyCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new API key"""
@@ -295,7 +315,7 @@ async def create_api_key(
description=api_key_data.description,
key_hash=key_hash,
key_prefix=key_prefix,
user_id=current_user['id'],
user_id=current_user["id"],
scopes=api_key_data.scopes,
expires_at=api_key_data.expires_at,
rate_limit_per_minute=api_key_data.rate_limit_per_minute,
@@ -305,9 +325,11 @@ async def create_api_key(
allowed_models=api_key_data.allowed_models,
allowed_chatbots=api_key_data.allowed_chatbots,
is_unlimited=api_key_data.is_unlimited,
budget_limit_cents=api_key_data.budget_limit_cents if not api_key_data.is_unlimited else None,
budget_limit_cents=api_key_data.budget_limit_cents
if not api_key_data.is_unlimited
else None,
budget_type=api_key_data.budget_type if not api_key_data.is_unlimited else None,
tags=api_key_data.tags
tags=api_key_data.tags,
)
db.add(new_api_key)
@@ -315,19 +337,20 @@ async def create_api_key(
await db.refresh(new_api_key)
# Log audit event asynchronously (non-blocking)
asyncio.create_task(log_audit_event_async(
user_id=str(current_user['id']),
asyncio.create_task(
log_audit_event_async(
user_id=str(current_user["id"]),
action="create_api_key",
resource_type="api_key",
resource_id=str(new_api_key.id),
details={"name": api_key_data.name, "scopes": api_key_data.scopes}
))
details={"name": api_key_data.name, "scopes": api_key_data.scopes},
)
)
logger.info(f"API key created: {new_api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(new_api_key),
secret_key=full_key
api_key=APIKeyResponse.model_validate(new_api_key), secret_key=full_key
)
@@ -336,7 +359,7 @@ async def update_api_key(
api_key_id: str,
api_key_data: APIKeyUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update API key"""
@@ -347,19 +370,20 @@ async def update_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can update their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Store original values for audit
original_values = {
"name": api_key.name,
"scopes": api_key.scopes,
"is_active": api_key.is_active
"is_active": api_key.is_active,
}
# Update API key fields
@@ -373,15 +397,15 @@ async def update_api_key(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(api_key, k) for k in update_data.keys()}
}
"after_values": {k: getattr(api_key, k) for k in update_data.keys()},
},
)
logger.info(f"API key updated: {api_key.name} by {current_user['username']}")
@@ -393,7 +417,7 @@ async def update_api_key(
async def delete_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Delete API key"""
@@ -404,13 +428,14 @@ async def delete_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can delete their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:delete")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:delete"
)
# Delete API key
await db.delete(api_key)
@@ -419,11 +444,11 @@ async def delete_api_key(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="delete_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key deleted: {api_key.name} by {current_user['username']}")
@@ -435,7 +460,7 @@ async def delete_api_key(
async def regenerate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Regenerate API key secret"""
@@ -446,13 +471,14 @@ async def regenerate_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can regenerate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Generate new API key
full_key, key_hash = generate_api_key()
@@ -468,18 +494,17 @@ async def regenerate_api_key(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="regenerate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key regenerated: {api_key.name} by {current_user['username']}")
return APIKeyCreateResponse(
api_key=APIKeyResponse.model_validate(api_key),
secret_key=full_key
api_key=APIKeyResponse.model_validate(api_key), secret_key=full_key
)
@@ -487,7 +512,7 @@ async def regenerate_api_key(
async def get_api_key_usage(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get API key usage statistics"""
@@ -498,13 +523,14 @@ async def get_api_key_usage(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can view their own API key usage
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Calculate usage statistics
from app.models.usage_tracking import UsageTracking
@@ -517,10 +543,9 @@ async def get_api_key_usage(
today_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
func.sum(UsageTracking.cost_cents),
).where(
UsageTracking.api_key_id == api_key_id,
UsageTracking.created_at >= today_start
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= today_start
)
today_result = await db.execute(today_query)
today_stats = today_result.first()
@@ -529,10 +554,9 @@ async def get_api_key_usage(
hour_query = select(
func.count(UsageTracking.id),
func.sum(UsageTracking.total_tokens),
func.sum(UsageTracking.cost_cents)
func.sum(UsageTracking.cost_cents),
).where(
UsageTracking.api_key_id == api_key_id,
UsageTracking.created_at >= hour_start
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= hour_start
)
hour_result = await db.execute(hour_query)
hour_stats = hour_result.first()
@@ -540,10 +564,10 @@ async def get_api_key_usage(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_api_key_usage",
resource_type="api_key",
resource_id=api_key_id
resource_id=api_key_id,
)
return APIKeyUsageResponse(
@@ -557,7 +581,7 @@ async def get_api_key_usage(
requests_this_hour=hour_stats[0] or 0,
tokens_this_hour=hour_stats[1] or 0,
cost_this_hour_cents=hour_stats[2] or 0,
last_used_at=api_key.last_used_at
last_used_at=api_key.last_used_at,
)
@@ -565,7 +589,7 @@ async def get_api_key_usage(
async def activate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Activate API key"""
@@ -576,13 +600,14 @@ async def activate_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can activate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Activate API key
api_key.is_active = True
@@ -591,11 +616,11 @@ async def activate_api_key(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="activate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key activated: {api_key.name} by {current_user['username']}")
@@ -607,7 +632,7 @@ async def activate_api_key(
async def deactivate_api_key(
api_key_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Deactivate API key"""
@@ -618,13 +643,14 @@ async def deactivate_api_key(
if not api_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
)
# Check permissions - users can deactivate their own API keys
if api_key.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
if api_key.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:api-keys:update"
)
# Deactivate API key
api_key.is_active = False
@@ -633,11 +659,11 @@ async def deactivate_api_key(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="deactivate_api_key",
resource_type="api_key",
resource_id=api_key_id,
details={"name": api_key.name}
details={"name": api_key.name},
)
logger.info(f"API key deactivated: {api_key.name} by {current_user['username']}")

View File

@@ -96,7 +96,7 @@ async def list_audit_logs(
severity: Optional[str] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List audit logs with filtering and pagination"""
@@ -128,7 +128,7 @@ async def list_audit_logs(
search_conditions = [
AuditLog.action.ilike(f"%{search}%"),
AuditLog.resource_type.ilike(f"%{search}%"),
AuditLog.details.astext.ilike(f"%{search}%")
AuditLog.details.astext.ilike(f"%{search}%"),
]
conditions.append(or_(*search_conditions))
@@ -166,19 +166,19 @@ async def list_audit_logs(
"end_date": end_date.isoformat() if end_date else None,
"success": success,
"severity": severity,
"search": search
"search": search,
},
"page": page,
"size": size,
"total_results": total
}
"total_results": total,
},
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
size=size,
)
@@ -188,7 +188,7 @@ async def search_audit_logs(
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=1000),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Advanced search for audit logs"""
@@ -205,7 +205,7 @@ async def search_audit_logs(
start_date=search_request.start_date,
end_date=search_request.end_date,
limit=size,
offset=(page - 1) * size
offset=(page - 1) * size,
)
# Get total count for the search
@@ -234,7 +234,7 @@ async def search_audit_logs(
search_conditions = [
AuditLog.action.ilike(f"%{search_request.search_text}%"),
AuditLog.resource_type.ilike(f"%{search_request.search_text}%"),
AuditLog.details.astext.ilike(f"%{search_request.search_text}%")
AuditLog.details.astext.ilike(f"%{search_request.search_text}%"),
]
conditions.append(or_(*search_conditions))
@@ -253,15 +253,15 @@ async def search_audit_logs(
details={
"search_criteria": search_request.model_dump(exclude_unset=True),
"results_count": len(logs),
"total_matches": total
}
"total_matches": total,
},
)
return AuditLogListResponse(
logs=[AuditLogResponse.model_validate(log) for log in logs],
total=total,
page=page,
size=size
size=size,
)
@@ -270,7 +270,7 @@ async def get_audit_statistics(
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get audit log statistics"""
@@ -287,46 +287,62 @@ async def get_audit_statistics(
basic_stats = await get_audit_stats(db, start_date, end_date)
# Get additional statistics
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
# Events by user
user_query = select(
AuditLog.user_id,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.user_id).order_by(func.count(AuditLog.id).desc()).limit(10)
user_query = (
select(AuditLog.user_id, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.user_id)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
user_result = await db.execute(user_query)
events_by_user = dict(user_result.fetchall())
# Events by hour of day
hour_query = select(
func.extract('hour', AuditLog.created_at).label('hour'),
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(func.extract('hour', AuditLog.created_at)).order_by('hour')
hour_query = (
select(
func.extract("hour", AuditLog.created_at).label("hour"),
func.count(AuditLog.id).label("count"),
)
.where(and_(*conditions))
.group_by(func.extract("hour", AuditLog.created_at))
.order_by("hour")
)
hour_result = await db.execute(hour_query)
events_by_hour = dict(hour_result.fetchall())
# Top actions
top_actions_query = select(
AuditLog.action,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.action).order_by(func.count(AuditLog.id).desc()).limit(10)
top_actions_query = (
select(AuditLog.action, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.action)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
top_actions_result = await db.execute(top_actions_query)
top_actions = [{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()]
top_actions = [
{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()
]
# Top resources
top_resources_query = select(
AuditLog.resource_type,
func.count(AuditLog.id).label('count')
).where(and_(*conditions)).group_by(AuditLog.resource_type).order_by(func.count(AuditLog.id).desc()).limit(10)
top_resources_query = (
select(AuditLog.resource_type, func.count(AuditLog.id).label("count"))
.where(and_(*conditions))
.group_by(AuditLog.resource_type)
.order_by(func.count(AuditLog.id).desc())
.limit(10)
)
top_resources_result = await db.execute(top_resources_query)
top_resources = [{"resource_type": row[0], "count": row[1]} for row in top_resources_result.fetchall()]
top_resources = [
{"resource_type": row[0], "count": row[1]}
for row in top_resources_result.fetchall()
]
# Log audit event
await log_audit_event(
@@ -337,8 +353,8 @@ async def get_audit_statistics(
details={
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"total_events": basic_stats["total_events"]
}
"total_events": basic_stats["total_events"],
},
)
return AuditStatsResponse(
@@ -346,7 +362,7 @@ async def get_audit_statistics(
events_by_user=events_by_user,
events_by_hour=events_by_hour,
top_actions=top_actions,
top_resources=top_resources
top_resources=top_resources,
)
@@ -354,7 +370,7 @@ async def get_audit_statistics(
async def get_security_events(
hours: int = Query(24, ge=1, le=168), # Last 24 hours by default, max 1 week
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get security-related events and anomalies"""
@@ -365,13 +381,18 @@ async def get_security_events(
start_time = end_time - timedelta(hours=hours)
# Failed logins
failed_logins_query = select(AuditLog).where(
failed_logins_query = (
select(AuditLog)
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.action == "login",
AuditLog.success == False
AuditLog.success == False,
)
)
.order_by(AuditLog.created_at.desc())
.limit(50)
)
).order_by(AuditLog.created_at.desc()).limit(50)
failed_logins_result = await db.execute(failed_logins_query)
failed_logins = [
@@ -380,18 +401,23 @@ async def get_security_events(
"user_id": log.user_id,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"details": log.details
"details": log.details,
}
for log in failed_logins_result.scalars().all()
]
# High severity events
high_severity_query = select(AuditLog).where(
high_severity_query = (
select(AuditLog)
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.severity.in_(["error", "critical"])
AuditLog.severity.in_(["error", "critical"]),
)
)
.order_by(AuditLog.created_at.desc())
.limit(50)
)
).order_by(AuditLog.created_at.desc()).limit(50)
high_severity_result = await db.execute(high_severity_query)
high_severity_events = [
@@ -403,52 +429,61 @@ async def get_security_events(
"user_id": log.user_id,
"ip_address": log.ip_address,
"success": log.success,
"details": log.details
"details": log.details,
}
for log in high_severity_result.scalars().all()
]
# Suspicious activities (multiple failed attempts from same IP)
suspicious_ips_query = select(
AuditLog.ip_address,
func.count(AuditLog.id).label('failed_count')
).where(
suspicious_ips_query = (
select(AuditLog.ip_address, func.count(AuditLog.id).label("failed_count"))
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.success == False,
AuditLog.ip_address.isnot(None)
AuditLog.ip_address.isnot(None),
)
)
.group_by(AuditLog.ip_address)
.having(func.count(AuditLog.id) >= 5)
.order_by(func.count(AuditLog.id).desc())
)
).group_by(AuditLog.ip_address).having(func.count(AuditLog.id) >= 5).order_by(func.count(AuditLog.id).desc())
suspicious_ips_result = await db.execute(suspicious_ips_query)
suspicious_activities = [
{
"ip_address": row[0],
"failed_attempts": row[1],
"risk_level": "high" if row[1] >= 10 else "medium"
"risk_level": "high" if row[1] >= 10 else "medium",
}
for row in suspicious_ips_result.fetchall()
]
# Unusual access patterns (users accessing from multiple IPs)
unusual_access_query = select(
unusual_access_query = (
select(
AuditLog.user_id,
func.count(func.distinct(AuditLog.ip_address)).label('ip_count'),
func.array_agg(func.distinct(AuditLog.ip_address)).label('ip_addresses')
).where(
func.count(func.distinct(AuditLog.ip_address)).label("ip_count"),
func.array_agg(func.distinct(AuditLog.ip_address)).label("ip_addresses"),
)
.where(
and_(
AuditLog.created_at >= start_time,
AuditLog.user_id.isnot(None),
AuditLog.ip_address.isnot(None)
AuditLog.ip_address.isnot(None),
)
)
.group_by(AuditLog.user_id)
.having(func.count(func.distinct(AuditLog.ip_address)) >= 3)
.order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
)
).group_by(AuditLog.user_id).having(func.count(func.distinct(AuditLog.ip_address)) >= 3).order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
unusual_access_result = await db.execute(unusual_access_query)
unusual_access_patterns = [
{
"user_id": row[0],
"unique_ips": row[1],
"ip_addresses": row[2] if row[2] else []
"ip_addresses": row[2] if row[2] else [],
}
for row in unusual_access_result.fetchall()
]
@@ -464,15 +499,15 @@ async def get_security_events(
"failed_logins_count": len(failed_logins),
"high_severity_count": len(high_severity_events),
"suspicious_ips_count": len(suspicious_activities),
"unusual_access_patterns_count": len(unusual_access_patterns)
}
"unusual_access_patterns_count": len(unusual_access_patterns),
},
)
return SecurityEventsResponse(
suspicious_activities=suspicious_activities,
failed_logins=failed_logins,
unusual_access_patterns=unusual_access_patterns,
high_severity_events=high_severity_events
high_severity_events=high_severity_events,
)
@@ -485,7 +520,7 @@ async def export_audit_logs(
action: Optional[str] = Query(None),
resource_type: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Export audit logs in CSV or JSON format"""
@@ -503,10 +538,7 @@ async def export_audit_logs(
# Build query
query = select(AuditLog)
conditions = [
AuditLog.created_at >= start_date,
AuditLog.created_at <= end_date
]
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
if user_id:
conditions.append(AuditLog.user_id == user_id)
@@ -515,7 +547,11 @@ async def export_audit_logs(
if resource_type:
conditions.append(AuditLog.resource_type == resource_type)
query = query.where(and_(*conditions)).order_by(AuditLog.created_at.desc()).limit(max_records)
query = (
query.where(and_(*conditions))
.order_by(AuditLog.created_at.desc())
.limit(max_records)
)
# Execute query
result = await db.execute(query)
@@ -535,13 +571,14 @@ async def export_audit_logs(
"filters": {
"user_id": user_id,
"action": action,
"resource_type": resource_type
}
}
"resource_type": resource_type,
},
},
)
if format == "json":
from fastapi.responses import JSONResponse
export_data = [
{
"id": str(log.id),
@@ -554,7 +591,7 @@ async def export_audit_logs(
"user_agent": log.user_agent,
"success": log.success,
"severity": log.severity,
"created_at": log.created_at.isoformat()
"created_at": log.created_at.isoformat(),
}
for log in logs
]
@@ -569,14 +606,25 @@ async def export_audit_logs(
writer = csv.writer(output)
# Write header
writer.writerow([
"ID", "User ID", "Action", "Resource Type", "Resource ID",
"IP Address", "Success", "Severity", "Created At", "Details"
])
writer.writerow(
[
"ID",
"User ID",
"Action",
"Resource Type",
"Resource ID",
"IP Address",
"Success",
"Severity",
"Created At",
"Details",
]
)
# Write data
for log in logs:
writer.writerow([
writer.writerow(
[
str(log.id),
log.user_id or "",
log.action,
@@ -586,13 +634,16 @@ async def export_audit_logs(
log.success,
log.severity,
log.created_at.isoformat(),
str(log.details)
])
str(log.details),
]
)
output.seek(0)
return StreamingResponse(
io.BytesIO(output.getvalue().encode()),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"}
headers={
"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"
},
)
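One detail deliberately preserved through the reflow above: AuditLog.success == False keeps the == comparison even though linters flag it (E712). SQLAlchemy overloads == on column objects to emit SQL, whereas the usual lint fix "is False" would be evaluated by Python before the statement is built. A short illustration (sketch only):

from sqlalchemy import select

# Builds the intended SQL: SELECT ... WHERE audit_logs.success = false
stmt_ok = select(AuditLog).where(AuditLog.success == False)  # noqa: E712

# Bug: (AuditLog.success is False) evaluates to the plain Python value False,
# so the statement becomes WHERE false and matches no rows
stmt_wrong = select(AuditLog).where(AuditLog.success is False)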

View File

@@ -9,6 +9,7 @@ from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.core.config import settings
from app.core.logging import get_logger
@@ -39,24 +40,24 @@ class UserRegisterRequest(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
@validator('password')
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters long')
raise ValueError("Password must be at least 8 characters long")
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
raise ValueError("Password must contain at least one uppercase letter")
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
raise ValueError("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
raise ValueError("Password must contain at least one digit")
return v
@validator('username')
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError('Username must be at least 3 characters long')
raise ValueError("Username must be at least 3 characters long")
if not v.isalnum():
raise ValueError('Username must contain only alphanumeric characters')
raise ValueError("Username must contain only alphanumeric characters")
return v
@@ -65,10 +66,10 @@ class UserLoginRequest(BaseModel):
username: Optional[str] = None
password: str
@validator('email')
@validator("email")
def validate_email_or_username(cls, v, values):
if v is None and not values.get('username'):
raise ValueError('Either email or username must be provided')
if v is None and not values.get("username"):
raise ValueError("Either email or username must be provided")
return v
@@ -77,6 +78,8 @@ class TokenResponse(BaseModel):
refresh_token: str
token_type: str = "bearer"
expires_in: int
force_password_change: Optional[bool] = None
message: Optional[str] = None
class UserResponse(BaseModel):
@@ -86,7 +89,7 @@ class UserResponse(BaseModel):
full_name: Optional[str]
is_active: bool
is_verified: bool
role: str
role: Optional[str]
created_at: datetime
class Config:
@@ -101,24 +104,23 @@ class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str
@validator('new_password')
@validator("new_password")
def validate_new_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters long')
raise ValueError("Password must be at least 8 characters long")
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
raise ValueError("Password must contain at least one uppercase letter")
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
raise ValueError("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
raise ValueError("Password must contain at least one digit")
return v
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_data: UserRegisterRequest,
db: AsyncSession = Depends(get_db)
):
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
async def register(user_data: UserRegisterRequest, db: AsyncSession = Depends(get_db)):
"""Register a new user"""
# Check if user already exists
@@ -126,8 +128,7 @@ async def register(
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
)
# Check if username already exists
@@ -135,8 +136,7 @@ async def register(
result = await db.execute(stmt)
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already taken"
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
)
# Create new user
@@ -151,21 +151,27 @@ async def register(
full_name=full_name,
is_active=True,
is_verified=False,
role="user"
role_id=2, # Default to 'user' role (id=2)
)
db.add(user)
await db.commit()
await db.refresh(user)
return UserResponse.from_orm(user)
return UserResponse(
id=user.id,
email=user.email,
username=user.username,
full_name=user.full_name,
is_active=user.is_active,
is_verified=user.is_verified,
role=user.role.name if user.role else None,
created_at=user.created_at,
)
@router.post("/login", response_model=TokenResponse)
async def login(
user_data: UserLoginRequest,
db: AsyncSession = Depends(get_db)
):
async def login(user_data: UserLoginRequest, db: AsyncSession = Depends(get_db)):
"""Login user and return access tokens"""
# Determine identifier for logging and user lookup
@@ -187,9 +193,9 @@ async def login(
query_start = datetime.utcnow()
if user_data.email:
stmt = select(User).where(User.email == user_data.email)
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
else:
stmt = select(User).where(User.username == user_data.username)
stmt = select(User).options(selectinload(User.role)).where(User.username == user_data.username)
result = await db.execute(stmt)
query_end = datetime.utcnow()
@@ -205,13 +211,18 @@ async def login(
identifier_lower = identifier.lower() if identifier else ""
admin_email = settings.ADMIN_EMAIL.lower() if settings.ADMIN_EMAIL else None
if user_data.email and admin_email and identifier_lower == admin_email and settings.ADMIN_PASSWORD:
if (
user_data.email
and admin_email
and identifier_lower == admin_email
and settings.ADMIN_PASSWORD
):
bootstrap_attempted = True
logger.info("LOGIN_ADMIN_BOOTSTRAP_START", email=user_data.email)
try:
await create_default_admin()
# Re-run lookup after bootstrap attempt
stmt = select(User).where(User.email == user_data.email)
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if user:
@@ -234,11 +245,13 @@ async def login(
logger.error("LOGIN_USER_LIST_FAILURE", error=str(e))
if bootstrap_attempted:
logger.warning("LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email)
logger.warning(
"LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
detail="Incorrect email or password",
)
logger.info("LOGIN_USER_FOUND", email=user.email, is_active=user.is_active)
@@ -253,7 +266,7 @@ async def login(
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password"
detail="Incorrect email or password",
)
verify_end = datetime.utcnow()
@@ -264,8 +277,7 @@ async def login(
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User account is disabled"
status_code=status.HTTP_401_UNAUTHORIZED, detail="User account is disabled"
)
# Update last login
@@ -289,14 +301,12 @@ async def login(
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
"role": user.role.name if user.role else None,
},
expires_delta=access_token_expires
expires_delta=access_token_expires,
)
refresh_token = create_refresh_token(
data={"sub": str(user.id), "type": "refresh"}
)
refresh_token = create_refresh_token(data={"sub": str(user.id), "type": "refresh"})
token_end = datetime.utcnow()
logger.info(
"LOGIN_TOKEN_CREATE_SUCCESS",
@@ -309,17 +319,25 @@ async def login(
total_duration_seconds=total_time.total_seconds(),
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
# Check if user needs to change password
response_data = {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
}
# Add force password change flag if needed
if user.force_password_change:
response_data["force_password_change"] = True
response_data["message"] = "Password change required on first login"
return response_data
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
token_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db)
token_data: RefreshTokenRequest, db: AsyncSession = Depends(get_db)
):
"""Refresh access token using refresh token"""
@@ -330,25 +348,28 @@ async def refresh_token(
if not user_id or token_type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
# Get user from database
stmt = select(User).where(User.id == int(user_id))
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive"
detail="User not found or inactive",
)
# Create new access token
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
logger.info(f"REFRESH: Creating new access token with expiration: {access_token_expires}")
logger.info(f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(
f"REFRESH: Creating new access token with expiration: {access_token_expires}"
)
logger.info(
f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
)
logger.info(f"REFRESH: Current UTC time: {datetime.utcnow().isoformat()}")
access_token = create_access_token(
@@ -356,15 +377,15 @@ async def refresh_token(
"sub": str(user.id),
"email": user.email,
"is_superuser": user.is_superuser,
"role": user.role
"role": user.role.name if user.role else None,
},
expires_delta=access_token_expires
expires_delta=access_token_expires,
)
return TokenResponse(
access_token=access_token,
refresh_token=token_data.refresh_token, # Keep same refresh token
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
except HTTPException:
@@ -374,30 +395,37 @@ async def refresh_token(
# Log the actual error for debugging
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get current user information"""
# Get full user details from database
stmt = select(User).where(User.id == int(current_user["id"]))
stmt = select(User).options(selectinload(User.role)).where(User.id == int(current_user["id"]))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return UserResponse.from_orm(user)
return UserResponse(
id=user.id,
email=user.email,
username=user.username,
full_name=user.full_name,
is_active=user.is_active,
is_verified=user.is_verified,
role=user.role.name if user.role else None,
created_at=user.created_at,
)
@router.post("/logout")
@@ -407,14 +435,12 @@ async def logout():
@router.post("/verify-token")
async def verify_user_token(
current_user: dict = Depends(get_current_user)
):
async def verify_user_token(current_user: dict = Depends(get_current_user)):
"""Verify if the current token is valid"""
return {
"valid": True,
"user_id": current_user["id"],
"email": current_user["email"]
"email": current_user["email"],
}
@@ -422,7 +448,7 @@ async def verify_user_token(
async def change_password(
password_data: ChangePasswordRequest,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Change user password"""
@@ -433,15 +459,14 @@ async def change_password(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Verify current password
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect"
detail="Current password is incorrect",
)
# Update password
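The selectinload(User.role) options threaded through this file are what make the later user.role.name accesses safe: with AsyncSession, a lazy relationship load triggered by plain attribute access raises MissingGreenlet rather than silently running IO. A minimal sketch of the pattern, assuming the User.role relationship that the new role_id column implies:

from sqlalchemy import select
from sqlalchemy.orm import selectinload

async def load_user_with_role(db, user_id: int):
    stmt = select(User).options(selectinload(User.role)).where(User.id == user_id)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    # The role is already loaded, so this attribute access performs no implicit IO
    return user, (user.role.name if user and user.role else None)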

View File

@@ -129,26 +129,30 @@ async def list_budgets(
budget_type: Optional[BudgetType] = Query(None),
is_enabled: Optional[bool] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List budgets with pagination and filtering"""
# Check permissions - users can view their own budgets
if user_id and int(user_id) != current_user['id']:
if user_id and int(user_id) != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
elif not user_id:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# If no user_id specified and user doesn't have admin permissions, show only their budgets
if not user_id and "platform:budgets:read" not in current_user.get("permissions", []):
user_id = current_user['id']
if not user_id and "platform:budgets:read" not in current_user.get(
"permissions", []
):
user_id = current_user["id"]
# Build query
query = select(Budget)
# Apply filters
if user_id:
query = query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
query = query.where(
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if budget_type:
query = query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
@@ -159,7 +163,9 @@ async def list_budgets(
# Apply same filters to count query
if user_id:
count_query = count_query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
count_query = count_query.where(
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
)
if budget_type:
count_query = count_query.where(Budget.budget_type == budget_type.value)
if is_enabled is not None:
@@ -182,23 +188,30 @@ async def list_budgets(
usage = await _calculate_budget_usage(db, budget)
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_data.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
budget_responses.append(budget_data)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_budgets",
resource_type="budget",
details={"page": page, "size": size, "filters": {"user_id": user_id, "budget_type": budget_type, "is_enabled": is_enabled}}
details={
"page": page,
"size": size,
"filters": {
"user_id": user_id,
"budget_type": budget_type,
"is_enabled": is_enabled,
},
},
)
return BudgetListResponse(
budgets=budget_responses,
total=total,
page=page,
size=size
budgets=budget_responses, total=total, page=page, size=size
)
@@ -206,7 +219,7 @@ async def list_budgets(
async def get_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get budget by ID"""
@@ -217,12 +230,11 @@ async def get_budget(
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budgets
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate current usage
@@ -231,15 +243,17 @@ async def get_budget(
# Build response
budget_data = BudgetResponse.model_validate(budget)
budget_data.current_usage = usage
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_data.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_budget",
resource_type="budget",
resource_id=budget_id
resource_id=budget_id,
)
return budget_data
@@ -249,7 +263,7 @@ async def get_budget(
async def create_budget(
budget_data: BudgetCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new budget"""
@@ -257,11 +271,17 @@ async def create_budget(
require_permission(current_user.get("permissions", []), "platform:budgets:create")
# If user_id not specified, use current user
target_user_id = budget_data.user_id or current_user['id']
target_user_id = budget_data.user_id or current_user["id"]
# If setting budget for another user, need admin permissions
if int(target_user_id) != current_user['id'] if isinstance(target_user_id, str) else target_user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:admin")
if (
int(target_user_id) != current_user["id"]
if isinstance(target_user_id, str)
else target_user_id != current_user["id"]
):
require_permission(
current_user.get("permissions", []), "platform:budgets:admin"
)
# Calculate period start and end
now = datetime.utcnow()
@@ -281,7 +301,7 @@ async def create_budget(
is_enabled=budget_data.is_enabled,
alert_threshold_percent=budget_data.alert_threshold_percent,
allowed_resources=budget_data.allowed_resources,
metadata=budget_data.metadata
metadata=budget_data.metadata,
)
db.add(new_budget)
@@ -296,11 +316,15 @@ async def create_budget(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="create_budget",
resource_type="budget",
resource_id=str(new_budget.id),
details={"name": budget_data.name, "budget_type": budget_data.budget_type, "limit_amount": budget_data.limit_amount}
details={
"name": budget_data.name,
"budget_type": budget_data.budget_type,
"limit_amount": budget_data.limit_amount,
},
)
logger.info(f"Budget created: {new_budget.name} by {current_user['username']}")
@@ -313,7 +337,7 @@ async def update_budget(
budget_id: str,
budget_data: BudgetUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update budget"""
@@ -324,19 +348,20 @@ async def update_budget(
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can update their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:update")
if budget.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:budgets:update"
)
# Store original values for audit
original_values = {
"name": budget.name,
"limit_amount": budget.limit_amount,
"is_enabled": budget.is_enabled
"is_enabled": budget.is_enabled,
}
# Update budget fields
@@ -346,7 +371,9 @@ async def update_budget(
# Recalculate period if period_type changed
if "period_type" in update_data:
period_start, period_end = _calculate_period_bounds(datetime.utcnow(), budget.period_type)
period_start, period_end = _calculate_period_bounds(
datetime.utcnow(), budget.period_type
)
budget.period_start = period_start
budget.period_end = period_end
@@ -359,20 +386,22 @@ async def update_budget(
# Build response
budget_response = BudgetResponse.model_validate(budget)
budget_response.current_usage = usage
budget_response.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
budget_response.usage_percentage = (
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_budget",
resource_type="budget",
resource_id=budget_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(budget, k) for k in update_data.keys()}
}
"after_values": {k: getattr(budget, k) for k in update_data.keys()},
},
)
logger.info(f"Budget updated: {budget.name} by {current_user['username']}")
@@ -384,7 +413,7 @@ async def update_budget(
async def delete_budget(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Delete budget"""
@@ -395,13 +424,14 @@ async def delete_budget(
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can delete their own budgets
if budget.user_id != current_user['id']:
require_permission(current_user.get("permissions", []), "platform:budgets:delete")
if budget.user_id != current_user["id"]:
require_permission(
current_user.get("permissions", []), "platform:budgets:delete"
)
# Delete budget
await db.delete(budget)
@@ -410,11 +440,11 @@ async def delete_budget(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="delete_budget",
resource_type="budget",
resource_id=budget_id,
details={"name": budget.name}
details={"name": budget.name},
)
logger.info(f"Budget deleted: {budget.name} by {current_user['username']}")
@@ -426,7 +456,7 @@ async def delete_budget(
async def get_budget_usage(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get detailed budget usage information"""
@@ -437,17 +467,18 @@ async def get_budget_usage(
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budget usage
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
usage_percentage = (
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
remaining_amount = max(0, budget.limit_amount - current_usage)
is_exceeded = current_usage > budget.limit_amount
@@ -470,10 +501,10 @@ async def get_budget_usage(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_budget_usage",
resource_type="budget",
resource_id=budget_id
resource_id=budget_id,
)
return BudgetUsageResponse(
@@ -487,7 +518,7 @@ async def get_budget_usage(
is_exceeded=is_exceeded,
days_remaining=days_remaining,
projected_usage=projected_usage,
usage_history=usage_history
usage_history=usage_history,
)
@@ -495,7 +526,7 @@ async def get_budget_usage(
async def get_budget_alerts(
budget_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get budget alerts"""
@@ -506,51 +537,58 @@ async def get_budget_alerts(
if not budget:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Budget not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
)
# Check permissions - users can view their own budget alerts
if budget.user_id != current_user['id']:
if budget.user_id != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:budgets:read")
# Calculate usage
current_usage = await _calculate_budget_usage(db, budget)
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
usage_percentage = (
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
)
alerts = []
# Check for alerts
if usage_percentage >= 100:
alerts.append(BudgetAlertResponse(
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="exceeded",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)"
))
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)",
)
)
elif usage_percentage >= 90:
alerts.append(BudgetAlertResponse(
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="critical",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)"
))
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)",
)
)
elif usage_percentage >= budget.alert_threshold_percent:
alerts.append(BudgetAlertResponse(
alerts.append(
BudgetAlertResponse(
budget_id=budget_id,
budget_name=budget.name,
alert_type="warning",
current_usage=current_usage,
limit_amount=budget.limit_amount,
usage_percentage=usage_percentage,
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)"
))
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)",
)
)
return alerts
@@ -565,7 +603,7 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
# Filter by time period
query = query.where(
UsageTracking.created_at >= budget.period_start,
UsageTracking.created_at <= budget.period_end
UsageTracking.created_at <= budget.period_end,
)
# Filter by user or API key
@@ -594,7 +632,9 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
return float(usage)
def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[datetime, datetime]:
def _calculate_period_bounds(
current_time: datetime, period_type: str
) -> tuple[datetime, datetime]:
"""Calculate period start and end dates"""
if period_type == "hourly":
@@ -606,7 +646,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
elif period_type == "weekly":
# Start of week (Monday)
days_since_monday = current_time.weekday()
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=days_since_monday)
start = current_time.replace(
hour=0, minute=0, second=0, microsecond=0
) - timedelta(days=days_since_monday)
end = start + timedelta(weeks=1) - timedelta(microseconds=1)
elif period_type == "monthly":
start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
@@ -616,7 +658,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
        # Guard the December rollover: month 13 would raise ValueError
        if start.month == 12:
            next_month = start.replace(year=start.year + 1, month=1)
        else:
            next_month = start.replace(month=start.month + 1)
        end = next_month - timedelta(microseconds=1)
elif period_type == "yearly":
start = current_time.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
start = current_time.replace(
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
end = start.replace(year=start.year + 1) - timedelta(microseconds=1)
else:
# Default to daily
@@ -626,7 +670,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
return start, end
async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -> List[dict]:
async def _get_usage_history(
db: AsyncSession, budget: Budget, days: int = 30
) -> List[dict]:
"""Get usage history for the budget"""
end_date = datetime.utcnow()
@@ -634,13 +680,12 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
# Build query
query = select(
func.date(UsageTracking.created_at).label('date'),
func.sum(UsageTracking.total_tokens).label('tokens'),
func.sum(UsageTracking.cost_cents).label('cost_cents'),
func.count(UsageTracking.id).label('requests')
func.date(UsageTracking.created_at).label("date"),
func.sum(UsageTracking.total_tokens).label("tokens"),
func.sum(UsageTracking.cost_cents).label("cost_cents"),
func.count(UsageTracking.id).label("requests"),
).where(
UsageTracking.created_at >= start_date,
UsageTracking.created_at <= end_date
UsageTracking.created_at >= start_date, UsageTracking.created_at <= end_date
)
# Filter by user or API key
@@ -649,7 +694,9 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
elif budget.user_id:
query = query.where(UsageTracking.user_id == budget.user_id)
query = query.group_by(func.date(UsageTracking.created_at)).order_by(func.date(UsageTracking.created_at))
query = query.group_by(func.date(UsageTracking.created_at)).order_by(
func.date(UsageTracking.created_at)
)
result = await db.execute(query)
rows = result.fetchall()
@@ -664,12 +711,14 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
elif budget.budget_type == "requests":
usage_value = row.requests or 0
history.append({
history.append(
{
"date": row.date.isoformat(),
"usage": usage_value,
"tokens": row.tokens or 0,
"cost_dollars": (row.cost_cents or 0) / 100.0,
"requests": row.requests or 0
})
"requests": row.requests or 0,
}
)
return history
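A quick sanity check of _calculate_period_bounds at the period edges; the expected values follow from the definitions above, and the December case relies on the rollover guard noted in the monthly branch:

from datetime import datetime, timedelta

start, end = _calculate_period_bounds(datetime(2025, 3, 19, 14, 30), "weekly")
assert start == datetime(2025, 3, 17)                            # Monday 00:00
assert end == datetime(2025, 3, 24) - timedelta(microseconds=1)  # end of Sunday

start, end = _calculate_period_bounds(datetime(2025, 12, 5), "monthly")
assert start == datetime(2025, 12, 1)
assert end == datetime(2026, 1, 1) - timedelta(microseconds=1)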

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
"""
API v1 endpoints package
"""

View File

@@ -0,0 +1,166 @@
"""
Tool calling API endpoints
Integration between LLM and tool execution
"""
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_calling_service import ToolCallingService
from app.services.llm.models import ChatRequest, ChatResponse
from app.schemas.tool_calling import (
ToolCallRequest,
ToolCallResponse,
ToolExecutionRequest,
ToolValidationRequest,
ToolValidationResponse,
ToolHistoryResponse,
)
router = APIRouter()
@router.post("/chat/completions", response_model=ChatResponse)
async def create_chat_completion_with_tools(
request: ChatRequest,
auto_execute_tools: bool = Query(
True, description="Whether to automatically execute tool calls"
),
max_tool_calls: int = Query(
5, ge=1, le=10, description="Maximum number of tool calls"
),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create chat completion with tool calling support"""
service = ToolCallingService(db)
# Resolve user ID for context
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Set user context in request
request.user_id = str(user_id)
request.api_key_id = 1 # Default for internal usage
response = await service.create_chat_completion_with_tools(
request=request,
user=current_user,
auto_execute_tools=auto_execute_tools,
max_tool_calls=max_tool_calls,
)
return response
@router.post("/chat/completions/stream")
async def create_chat_completion_stream_with_tools(
request: ChatRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create streaming chat completion with tool calling support"""
service = ToolCallingService(db)
# Resolve user ID for context
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Set user context in request
request.user_id = str(user_id)
request.api_key_id = 1 # Default for internal usage
async def stream_generator():
async for chunk in service.create_chat_completion_stream_with_tools(
request=request, user=current_user
):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
stream_generator(),
media_type="text/plain",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
@router.post("/execute", response_model=ToolCallResponse)
async def execute_tool_by_name(
request: ToolExecutionRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a tool by name directly"""
service = ToolCallingService(db)
try:
result = await service.execute_tool_by_name(
tool_name=request.tool_name,
parameters=request.parameters,
user=current_user,
)
return ToolCallResponse(success=True, result=result, error=None)
except Exception as e:
return ToolCallResponse(success=False, result=None, error=str(e))
@router.post("/validate", response_model=ToolValidationResponse)
async def validate_tool_availability(
request: ToolValidationRequest,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Validate which tools are available to the user"""
service = ToolCallingService(db)
availability = await service.validate_tool_availability(
tool_names=request.tool_names, user=current_user
)
return ToolValidationResponse(tool_availability=availability)
@router.get("/history", response_model=ToolHistoryResponse)
async def get_tool_call_history(
limit: int = Query(
50, ge=1, le=100, description="Number of history items to return"
),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get recent tool execution history for the user"""
service = ToolCallingService(db)
history = await service.get_tool_call_history(user=current_user, limit=limit)
return ToolHistoryResponse(history=history, total=len(history))
@router.get("/available")
async def get_available_tools(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tools available for function calling"""
service = ToolCallingService(db)
# Get available tools
tools = await service._get_available_tools_for_user(current_user)
# Convert to OpenAI format
openai_tools = await service._convert_tools_to_openai_format(tools)
return {"tools": openai_tools, "count": len(openai_tools)}

View File

@@ -0,0 +1,386 @@
"""
Tool management and execution API endpoints
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_management_service import ToolManagementService
from app.services.tool_execution_service import ToolExecutionService
from app.schemas.tool import (
ToolCreate,
ToolUpdate,
ToolResponse,
ToolListResponse,
ToolExecutionCreate,
ToolExecutionResponse,
ToolExecutionListResponse,
ToolCategoryCreate,
ToolCategoryResponse,
ToolStatisticsResponse,
)
router = APIRouter()
@router.post("/", response_model=ToolResponse)
async def create_tool(
tool_data: ToolCreate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create a new tool"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.create_tool(
name=tool_data.name,
display_name=tool_data.display_name,
code=tool_data.code,
tool_type=tool_data.tool_type,
created_by_user_id=user_id,
description=tool_data.description,
parameters_schema=tool_data.parameters_schema,
return_schema=tool_data.return_schema,
timeout_seconds=tool_data.timeout_seconds,
max_memory_mb=tool_data.max_memory_mb,
max_cpu_seconds=tool_data.max_cpu_seconds,
docker_image=tool_data.docker_image,
docker_command=tool_data.docker_command,
category=tool_data.category,
tags=tool_data.tags,
is_public=tool_data.is_public,
)
return ToolResponse.from_orm(tool)
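# Minimal ToolCreate payload sketch for the endpoint above; every value is
# invented and tool_type "python" is an assumed enum member. Fields mirror
# the create_tool() keyword arguments.
_EXAMPLE_TOOL_CREATE = {
    "name": "word_count",
    "display_name": "Word Count",
    "code": "def run(text: str) -> int:\n    return len(text.split())",
    "tool_type": "python",
    "description": "Counts words in a text",
    "parameters_schema": {
        "type": "object",
        "properties": {"text": {"type": "string"}},
        "required": ["text"],
    },
    "timeout_seconds": 10,
    "is_public": False,
}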
@router.get("/", response_model=ToolListResponse)
async def list_tools(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
category: Optional[str] = Query(None),
tool_type: Optional[str] = Query(None),
is_public: Optional[bool] = Query(None),
is_approved: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
tags: Optional[List[str]] = Query(None),
created_by_user_id: Optional[int] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List tools with filtering and pagination"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tools = await service.get_tools(
user_id=user_id,
skip=skip,
limit=limit,
category=category,
tool_type=tool_type,
is_public=is_public,
is_approved=is_approved,
search=search,
tags=tags,
created_by_user_id=created_by_user_id,
)
return ToolListResponse(
tools=[ToolResponse.from_orm(tool) for tool in tools],
total=len(tools),
skip=skip,
limit=limit,
)
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tool by ID"""
service = ToolManagementService(db)
tool = await service.get_tool_by_id(tool_id)
if not tool:
raise HTTPException(status_code=404, detail="Tool not found")
# Resolve underlying User object if available
user_obj = (
current_user.get("user_obj")
if isinstance(current_user, dict)
else current_user
)
# Check if user can access this tool
if not user_obj or not tool.can_be_used_by(user_obj):
raise HTTPException(status_code=403, detail="Access denied to this tool")
return ToolResponse.from_orm(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
tool_id: int,
tool_data: ToolUpdate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update tool (only by creator or admin)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.update_tool(
tool_id=tool_id,
user_id=user_id,
display_name=tool_data.display_name,
description=tool_data.description,
code=tool_data.code,
parameters_schema=tool_data.parameters_schema,
return_schema=tool_data.return_schema,
timeout_seconds=tool_data.timeout_seconds,
max_memory_mb=tool_data.max_memory_mb,
max_cpu_seconds=tool_data.max_cpu_seconds,
docker_image=tool_data.docker_image,
docker_command=tool_data.docker_command,
category=tool_data.category,
tags=tool_data.tags,
is_public=tool_data.is_public,
is_active=tool_data.is_active,
)
return ToolResponse.from_orm(tool)
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Delete tool (only by creator or admin)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
await service.delete_tool(tool_id, user_id)
return {"message": "Tool deleted successfully"}
@router.post("/{tool_id}/approve", response_model=ToolResponse)
async def approve_tool(
tool_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Approve tool for public use (admin only)"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
tool = await service.approve_tool(tool_id, user_id)
return ToolResponse.from_orm(tool)
# Tool Execution Endpoints
@router.post("/{tool_id}/execute", response_model=ToolExecutionResponse)
async def execute_tool(
tool_id: int,
execution_data: ToolExecutionCreate,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a tool with given parameters"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
execution = await service.execute_tool(
tool_id=tool_id,
user_id=user_id,
parameters=execution_data.parameters,
timeout_override=execution_data.timeout_override,
)
return ToolExecutionResponse.from_orm(execution)
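# Hypothetical usage sketch: executing a tool by id. The parameters and
# timeout_override values are invented; fields follow ToolExecutionCreate.
def _example_execute_tool(token: str, tool_id: int) -> dict:
    import httpx  # assumed available in the client environment

    resp = httpx.post(
        f"http://localhost:8000/api/v1/tools/{tool_id}/execute",
        json={"parameters": {"text": "hello world"}, "timeout_override": 30},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()  # ToolExecutionResponse payload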
@router.get("/executions", response_model=ToolExecutionListResponse)
async def list_executions(
tool_id: Optional[int] = Query(None),
executed_by_user_id: Optional[int] = Query(None),
status: Optional[str] = Query(None),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List tool executions with filtering"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
executions = await service.get_tool_executions(
tool_id=tool_id,
user_id=user_id,
executed_by_user_id=executed_by_user_id,
status=status,
skip=skip,
limit=limit,
)
return ToolExecutionListResponse(
executions=[
ToolExecutionResponse.from_orm(execution) for execution in executions
],
total=len(executions),
skip=skip,
limit=limit,
)
@router.get("/executions/{execution_id}", response_model=ToolExecutionResponse)
async def get_execution(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get execution details"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
# Fetch through the user-scoped list so the service's permission check
# applies; limit=1 would almost never contain the requested id, so scan a
# window of recent executions instead
executions = await service.get_tool_executions(
user_id=user_id, skip=0, limit=1000
)
execution = next((e for e in executions if e.id == execution_id), None)
if not execution:
raise HTTPException(status_code=404, detail="Execution not found")
return ToolExecutionResponse.from_orm(execution)
@router.post("/executions/{execution_id}/cancel", response_model=ToolExecutionResponse)
async def cancel_execution(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Cancel a running execution"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
execution = await service.cancel_execution(execution_id, user_id)
return ToolExecutionResponse.from_orm(execution)
@router.get("/executions/{execution_id}/logs")
async def get_execution_logs(
execution_id: int,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get real-time logs for execution"""
service = ToolExecutionService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
logs = await service.get_execution_logs(execution_id, user_id)
return logs
# Tool Categories
@router.post("/categories", response_model=ToolCategoryResponse)
async def create_category(
category_data: ToolCategoryCreate,
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Create a new tool category (admin only)"""
user_obj = (
current_user.get("user_obj")
if isinstance(current_user, dict)
else current_user
)
if not user_obj or not user_obj.has_permission("manage_tools"):
raise HTTPException(status_code=403, detail="Admin privileges required")
service = ToolManagementService(db)
category = await service.create_category(
name=category_data.name,
display_name=category_data.display_name,
description=category_data.description,
icon=category_data.icon,
color=category_data.color,
sort_order=category_data.sort_order,
)
return ToolCategoryResponse.from_orm(category)
@router.get("/categories", response_model=List[ToolCategoryResponse])
async def list_categories(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""List all active tool categories"""
service = ToolManagementService(db)
categories = await service.get_categories()
return [ToolCategoryResponse.from_orm(category) for category in categories]
# Statistics
@router.get("/statistics", response_model=ToolStatisticsResponse)
async def get_statistics(
db: AsyncSession = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get tool usage statistics"""
service = ToolManagementService(db)
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
stats = await service.get_tool_statistics(user_id=user_id)
return ToolStatisticsResponse(**stats)
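# Review note (sketch): FastAPI matches routes in declaration order, so the
# static paths /executions, /categories and /statistics above are shadowed by
# the earlier GET /{tool_id} route; "statistics" fails int path-param
# validation and yields a 422. Declaring the static routes before the dynamic
# /{tool_id} routes avoids this.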

View File

@@ -0,0 +1,703 @@
"""
User Management API endpoints
Admin endpoints for managing users, roles, and audit logs
"""
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, validator, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import get_current_user
from app.db.database import get_db
from app.models.user import User
from app.models.role import Role
from app.models.audit_log import AuditLog
from app.services.user_management_service import UserManagementService
from app.services.permission_manager import require_permission
from app.schemas.role import RoleCreate, RoleUpdate
logger = logging.getLogger(__name__)
router = APIRouter()
security = HTTPBearer()
# Request/Response Models
class CreateUserRequest(BaseModel):
email: EmailStr
username: str
password: str
full_name: Optional[str] = None
role_id: Optional[int] = None
is_active: bool = True
force_password_change: bool = False
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if not v.replace("_", "").replace("-", "").isalnum():
raise ValueError("Username must contain only alphanumeric characters, underscores, and hyphens")
return v
class UpdateUserRequest(BaseModel):
email: Optional[EmailStr] = None
username: Optional[str] = None
full_name: Optional[str] = None
role_id: Optional[int] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
custom_permissions: Optional[Dict[str, Any]] = None
@validator("username")
def validate_username(cls, v):
if v is not None and len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
return v
class AdminPasswordResetRequest(BaseModel):
new_password: str
force_change_on_login: bool = True
@validator("new_password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
class UserResponse(BaseModel):
id: int
email: str
username: str
full_name: Optional[str]
role_id: Optional[int]
role: Optional[Dict[str, Any]]
is_active: bool
is_verified: bool
account_locked: bool
force_password_change: bool
created_at: datetime
updated_at: datetime
last_login: Optional[datetime]
audit_summary: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
class UserListResponse(BaseModel):
users: List[UserResponse]
total: int
skip: int
limit: int
class RoleResponse(BaseModel):
id: int
name: str
display_name: str
description: Optional[str]
level: str
is_active: bool
created_at: datetime
class Config:
from_attributes = True
class AuditLogResponse(BaseModel):
id: int
user_id: Optional[int]
action: str
resource_type: str
resource_id: Optional[str]
description: str
details: Dict[str, Any]
severity: str
category: Optional[str]
success: bool
created_at: datetime
class Config:
from_attributes = True
# User Management Endpoints
@router.get("/users", response_model=UserListResponse)
async def get_users(
request: Request,
skip: int = 0,
limit: int = 100,
search: Optional[str] = None,
role_id: Optional[int] = None,
is_active: Optional[bool] = None,
include_audit_summary: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all users with filtering options"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
context={"user_id": current_user["id"]}
)
service = UserManagementService(db)
users_data = await service.get_users(
skip=skip,
limit=limit,
search=search,
role_id=role_id,
is_active=is_active,
)
# Convert to response format
users = []
for user in users_data:
user_dict = user.to_dict()
if include_audit_summary:
# Get audit summary for user
audit_logs = await service.get_user_audit_logs(user.id, limit=10)
user_dict["audit_summary"] = {
"recent_actions": len(audit_logs),
"last_login": user.last_login.isoformat() if user.last_login else None,
}
user_response = UserResponse(**user_dict)
users.append(user_response)
return UserListResponse(
users=users,
total=len(users), # Would need actual count query for large datasets
skip=skip,
limit=limit,
)
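# Hypothetical admin-client sketch for the listing endpoint above; the mount
# prefix is a guess and the token must carry platform:users:read.
def _example_list_users(token: str) -> dict:
    import httpx  # assumed available in the client environment

    resp = httpx.get(
        "http://localhost:8000/api/v1/admin/users",
        params={
            "skip": 0,
            "limit": 25,
            "is_active": True,
            "include_audit_summary": True,
        },
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()  # UserListResponse payload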
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
user_id: int,
include_audit_summary: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get specific user by ID"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
user = await service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
user_dict = user.to_dict()
if include_audit_summary:
# Get audit summary for user
audit_logs = await service.get_user_audit_logs(user_id, limit=10)
user_dict["audit_summary"] = {
"recent_actions": len(audit_logs),
"last_login": user.last_login.isoformat() if user.last_login else None,
}
return UserResponse(**user_dict)
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
user_data: CreateUserRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:create",
)
service = UserManagementService(db)
user = await service.create_user(
email=user_data.email,
username=user_data.username,
password=user_data.password,
full_name=user_data.full_name,
role_id=user_data.role_id,
is_active=user_data.is_active,
is_verified=True, # Admin-created users are verified by default
custom_permissions={}, # Empty by default
)
# Log user creation in audit
await service._log_audit_event(
user_id=current_user["id"],
action="create",
resource_type="user",
resource_id=str(user.id),
description=f"User created by admin: {user.email}",
details={
"created_by": current_user["email"],
"target_user": user.email,
"role_id": user_data.role_id,
},
severity="medium",
)
user_dict = user.to_dict()
return UserResponse(**user_dict)
@router.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
user_id: int,
user_data: UpdateUserRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Update user information"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:update",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
# Get current user for audit comparison
current_user_data = await service.get_user_by_id(user_id)
if not current_user_data:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
old_values = current_user_data.to_dict()
# Update user
user = await service.update_user(
user_id=user_id,
email=user_data.email,
username=user_data.username,
full_name=user_data.full_name,
role_id=user_data.role_id,
is_active=user_data.is_active,
is_verified=user_data.is_verified,
custom_permissions=user_data.custom_permissions,
)
# Log user update in audit
await service._log_audit_event(
user_id=current_user["id"],
action="update",
resource_type="user",
resource_id=str(user_id),
description=f"User updated by admin: {user.email}",
details={
"updated_by": current_user["email"],
"target_user": user.email,
},
old_values=old_values,
new_values=user.to_dict(),
severity="medium",
)
user_dict = user.to_dict()
return UserResponse(**user_dict)
@router.post("/users/{user_id}/password-reset")
async def admin_reset_password(
user_id: int,
password_data: AdminPasswordResetRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Admin reset user password with forced change option"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:manage",
)
service = UserManagementService(db)
user = await service.admin_reset_user_password(
user_id=user_id,
new_password=password_data.new_password,
force_change_on_login=password_data.force_change_on_login,
admin_user_id=current_user["id"],
)
return {
"message": f"Password reset for user {user.email}",
"force_change_on_login": password_data.force_change_on_login,
}
@router.delete("/users/{user_id}")
async def delete_user(
user_id: int,
hard_delete: bool = False,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete or deactivate user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:delete",
)
service = UserManagementService(db)
# Get user info for audit before deletion
user = await service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
user_email = user.email
# Delete user
success = await service.delete_user(user_id, hard_delete=hard_delete)
# Log user deletion in audit
await service._log_audit_event(
user_id=current_user["id"],
action="delete",
resource_type="user",
resource_id=str(user_id),
description=f"User {'hard deleted' if hard_delete else 'deactivated'} by admin: {user_email}",
details={
"deleted_by": current_user["email"],
"target_user": user_email,
"hard_delete": hard_delete,
},
severity="high",
)
return {
"message": f"User {'deleted' if hard_delete else 'deactivated'} successfully",
"user_email": user_email,
}
# Role Management Endpoints
@router.get("/roles", response_model=List[RoleResponse])
async def get_roles(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all available roles"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:read",
)
service = UserManagementService(db)
roles = await service.get_roles(is_active=True)
return [RoleResponse(**role.to_dict()) for role in roles]
@router.post("/roles", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
async def create_role(
role_data: RoleCreate,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:create",
)
service = UserManagementService(db)
# Create role
role = await service.create_role(
name=role_data.name,
display_name=role_data.display_name,
description=role_data.description,
level=role_data.level,
permissions=role_data.permissions,
can_manage_users=role_data.can_manage_users,
can_manage_budgets=role_data.can_manage_budgets,
can_view_reports=role_data.can_view_reports,
can_manage_tools=role_data.can_manage_tools,
inherits_from=role_data.inherits_from,
is_active=role_data.is_active,
is_system_role=role_data.is_system_role,
)
# Log role creation
await service._log_audit_event(
user_id=current_user["id"],
action="create",
resource_type="role",
resource_id=str(role.id),
description=f"Role created: {role.name}",
details={
"created_by": current_user["email"],
"role_name": role.name,
"level": role.level,
},
severity="medium",
)
return RoleResponse(**role.to_dict())
@router.put("/roles/{role_id}", response_model=RoleResponse)
async def update_role(
role_id: int,
role_data: RoleUpdate,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Update a role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:update",
)
service = UserManagementService(db)
# Get current role for audit
current_role = await service.get_role_by_id(role_id)
if not current_role:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Role not found"
)
# Prevent updating system roles
if current_role.is_system_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot modify system roles"
)
old_values = current_role.to_dict()
# Update role
role = await service.update_role(
role_id=role_id,
display_name=role_data.display_name,
description=role_data.description,
permissions=role_data.permissions,
can_manage_users=role_data.can_manage_users,
can_manage_budgets=role_data.can_manage_budgets,
can_view_reports=role_data.can_view_reports,
can_manage_tools=role_data.can_manage_tools,
is_active=role_data.is_active,
)
# Log role update
await service._log_audit_event(
user_id=current_user["id"],
action="update",
resource_type="role",
resource_id=str(role_id),
description=f"Role updated: {role.name}",
details={
"updated_by": current_user["email"],
"role_name": role.name,
},
old_values=old_values,
new_values=role.to_dict(),
severity="medium",
)
return RoleResponse(**role.to_dict())
@router.delete("/roles/{role_id}")
async def delete_role(
role_id: int,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a role"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:roles:delete",
)
service = UserManagementService(db)
# Get role info for audit before deletion
role = await service.get_role_by_id(role_id)
if not role:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Role not found"
)
# Prevent deleting system roles
if role.is_system_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete system roles"
)
role_name = role.name
# Delete role
success = await service.delete_role(role_id)
# Log role deletion
await service._log_audit_event(
user_id=current_user["id"],
action="delete",
resource_type="role",
resource_id=str(role_id),
description=f"Role deleted: {role_name}",
details={
"deleted_by": current_user["email"],
"role_name": role_name,
},
severity="high",
)
return {
"message": f"Role {role_name} deleted successfully",
"role_name": role_name,
}
# Audit Log Endpoints
@router.get("/users/{user_id}/audit-logs", response_model=List[AuditLogResponse])
async def get_user_audit_logs(
user_id: int,
skip: int = 0,
limit: int = 50,
action_filter: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get audit logs for a specific user"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:audit:read",
context={"user_id": current_user["id"], "owner_id": user_id}
)
service = UserManagementService(db)
audit_logs = await service.get_user_audit_logs(
user_id=user_id,
skip=skip,
limit=limit,
action_filter=action_filter,
)
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
@router.get("/audit-logs", response_model=List[AuditLogResponse])
async def get_all_audit_logs(
skip: int = 0,
limit: int = 100,
user_id: Optional[int] = None,
action_filter: Optional[str] = None,
category_filter: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all audit logs with filtering"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:audit:read",
)
# Direct database query for audit logs with filters
from sqlalchemy import select, and_, desc
query = select(AuditLog)
conditions = []
if user_id:
conditions.append(AuditLog.user_id == user_id)
if action_filter:
conditions.append(AuditLog.action == action_filter)
if category_filter:
conditions.append(AuditLog.category == category_filter)
if conditions:
query = query.where(and_(*conditions))
query = query.order_by(desc(AuditLog.created_at))
query = query.offset(skip).limit(limit)
result = await db.execute(query)
audit_logs = result.scalars().all()
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
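# Hypothetical usage sketch for the audit-log filters above; query parameter
# names mirror the endpoint signature, the mount prefix is a guess.
def _example_recent_deletions(token: str) -> list:
    import httpx  # assumed available in the client environment

    resp = httpx.get(
        "http://localhost:8000/api/v1/admin/audit-logs",
        params={"action_filter": "delete", "limit": 20},
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()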
# Statistics Endpoints
@router.get("/statistics")
async def get_user_management_statistics(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get user management statistics"""
# Check permission
require_permission(
current_user.get("permissions", []),
"platform:users:read",
)
service = UserManagementService(db)
user_stats = await service.get_user_statistics()
role_stats = await service.get_role_statistics()
return {
"users": user_stats,
"roles": role_stats,
"generated_at": datetime.utcnow().isoformat(),
}

View File

@@ -12,16 +12,33 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context
from app.services.api_key_auth import (
require_api_key,
RequireScope,
APIKeyAuthService,
get_api_key_context,
)
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.llm.models import (
ChatRequest,
ChatMessage as LLMChatMessage,
EmbeddingRequest as LLMEmbeddingRequest,
)
from app.services.llm.exceptions import (
LLMError,
ProviderError,
SecurityError,
ValidationError,
)
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
check_budget_for_request,
record_request_usage,
BudgetEnforcementService,
atomic_check_and_reserve_budget,
atomic_finalize_usage,
)
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.utils.exceptions import AuthenticationError, AuthorizationError
@@ -30,11 +47,7 @@ from app.middleware.analytics import set_analytics_data
logger = logging.getLogger(__name__)
# Models response cache - simple in-memory cache for performance
_models_cache = {
"data": None,
"cached_at": 0,
"cache_ttl": 900 # 15 minutes cache TTL
}
_models_cache = {"data": None, "cached_at": 0, "cache_ttl": 900} # 15 minutes cache TTL
router = APIRouter()
@@ -44,8 +57,10 @@ async def get_cached_models() -> List[Dict[str, Any]]:
current_time = time.time()
# Check if cache is still valid
if (_models_cache["data"] is not None and
current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]):
if (
_models_cache["data"] is not None
and current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]
):
logger.debug("Returning cached models list")
return _models_cache["data"]
@@ -63,13 +78,17 @@ async def get_cached_models() -> List[Dict[str, Any]]:
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by,
# Add frontend-expected fields
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
"provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
"name": getattr(
model_info, "name", model_info.id
), # Use name if available, fallback to id
"provider": getattr(
model_info, "provider", model_info.owned_by
), # Use provider if available, fallback to owned_by
"capabilities": model_info.capabilities,
"context_window": model_info.context_window,
"max_output_tokens": model_info.max_output_tokens,
"supports_streaming": model_info.supports_streaming,
"supports_function_calling": model_info.supports_function_calling
"supports_function_calling": model_info.supports_function_calling,
}
# Include tasks field if present
if model_info.tasks:
@@ -138,11 +157,12 @@ class ModelsResponse(BaseModel):
# Authentication: Public API endpoints should use require_api_key
# Internal API endpoints should use get_current_user from core.security
# Endpoints
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List available models"""
try:
@@ -155,7 +175,7 @@ async def list_models(
if not await auth_service.check_scope_permission(context, "models.list"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to list models"
detail="Insufficient permissions to list models",
)
# Get models from cache or LLM service
@@ -164,7 +184,9 @@ async def list_models(
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
models = [
model for model in models if model.get("id") in api_key.allowed_models
]
return ModelsResponse(data=models)
@@ -174,14 +196,14 @@ async def list_models(
logger.error(f"Error listing models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to list models"
detail="Failed to list models",
)
@router.post("/models/invalidate-cache")
async def invalidate_models_cache_endpoint(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Invalidate models cache (admin only)"""
# Check for admin permissions
@@ -190,7 +212,7 @@ async def invalidate_models_cache_endpoint(
if not user or not user.get("is_superuser"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
detail="Admin privileges required",
)
else:
# For API key users, check admin permissions
@@ -198,7 +220,7 @@ async def invalidate_models_cache_endpoint(
if not await auth_service.check_scope_permission(context, "admin.cache"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to invalidate cache"
detail="Admin permissions required to invalidate cache",
)
invalidate_models_cache()
@@ -210,7 +232,7 @@ async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create chat completion with budget enforcement"""
try:
@@ -221,23 +243,27 @@ async def create_chat_completion(
auth_service = APIKeyAuthService(db)
# Check permissions
if not await auth_service.check_scope_permission(context, "chat.completions"):
if not await auth_service.check_scope_permission(
context, "chat.completions"
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for chat completions"
detail="Insufficient permissions for chat completions",
)
if not await auth_service.check_model_permission(context, chat_request.model):
if not await auth_service.check_model_permission(
context, chat_request.model
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{chat_request.model}' not allowed"
detail=f"Model '{chat_request.model}' not allowed",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
elif auth_type == "jwt":
# For JWT authentication, we'll skip the detailed permission checks for now
@@ -246,13 +272,13 @@ async def create_chat_completion(
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
detail="User information not available",
)
api_key = None # JWT users don't have API keys
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
detail="Invalid authentication type",
)
# Estimate token usage for budget checking
@@ -265,6 +291,7 @@ async def create_chat_completion(
# Get a synchronous session for budget enforcement
from app.db.database import SessionLocal
sync_db = SessionLocal()
try:
@@ -272,20 +299,32 @@ async def create_chat_completion(
warnings = []
reserved_budget_ids = []
if auth_type == "api_key" and api_key:
is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget(
sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions"
(
is_allowed,
error_message,
budget_warnings,
budget_ids,
) = atomic_check_and_reserve_budget(
sync_db,
api_key,
chat_request.model,
int(estimated_tokens),
"chat/completions",
)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
detail=f"Budget exceeded: {error_message}",
)
warnings = budget_warnings
reserved_budget_ids = budget_ids
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
llm_messages = [
LLMChatMessage(role=msg.role, content=msg.content)
for msg in chat_request.messages
]
# Create LLM service request
llm_request = ChatRequest(
@@ -299,7 +338,9 @@ async def create_chat_completion(
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
api_key_id=context.get("api_key_id", 0)
if auth_type == "api_key"
else 0,
)
# Make request to LLM service
@@ -316,21 +357,25 @@ async def create_chat_completion(
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
"content": choice.message.content,
},
"finish_reason": choice.finish_reason
"finish_reason": choice.finish_reason,
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
"prompt_tokens": llm_response.usage.prompt_tokens
if llm_response.usage
else 0,
"completion_tokens": llm_response.usage.completion_tokens
if llm_response.usage
else 0,
"total_tokens": llm_response.usage.total_tokens
if llm_response.usage
else 0,
}
if llm_response.usage
else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
# Calculate actual cost and update usage
@@ -347,13 +392,20 @@ async def create_chat_completion(
# Finalize actual usage in budgets (only for API key users)
if auth_type == "api_key" and api_key:
atomic_finalize_usage(
sync_db, reserved_budget_ids, api_key, chat_request.model,
input_tokens, output_tokens, "chat/completions"
sync_db,
reserved_budget_ids,
api_key,
chat_request.model,
input_tokens,
output_tokens,
"chat/completions",
)
# Update API key usage statistics
auth_service = APIKeyAuthService(db)
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
await auth_service.update_usage_stats(
context, total_tokens, actual_cost_cents
)
# Set analytics data for middleware
set_analytics_data(
@@ -363,7 +415,7 @@ async def create_chat_completion(
total_tokens=total_tokens,
cost_cents=actual_cost_cents,
budget_ids=reserved_budget_ids,
budget_warnings=warnings
budget_warnings=warnings,
)
# Add budget warnings to response if any
@@ -381,37 +433,37 @@ async def create_chat_completion(
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
detail=f"Security validation failed: {e.message}",
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
detail=f"Request validation failed: {e.message}",
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
detail="Rate limit exceeded",
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
detail="LLM service temporarily unavailable",
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
detail="LLM service error",
)
except Exception as e:
logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
detail="Failed to create chat completion",
)
@@ -419,7 +471,7 @@ async def create_chat_completion(
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create embedding with budget enforcement"""
try:
@@ -429,20 +481,20 @@ async def create_embedding(
if not await auth_service.check_scope_permission(context, "embeddings.create"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions for embeddings"
detail="Insufficient permissions for embeddings",
)
if not await auth_service.check_model_permission(context, request.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{request.model}' not allowed"
detail=f"Model '{request.model}' not allowed",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
# Estimate token usage for budget checking
@@ -460,7 +512,7 @@ async def create_embedding(
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Budget exceeded: {error_message}"
detail=f"Budget exceeded: {error_message}",
)
# Create LLM service request
@@ -469,7 +521,7 @@ async def create_embedding(
input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
api_key_id=context["api_key_id"],
)
# Make request to LLM service
@@ -482,18 +534,24 @@ async def create_embedding(
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
"embedding": emb.embedding,
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
"prompt_tokens": llm_response.usage.prompt_tokens
if llm_response.usage
else 0,
"total_tokens": llm_response.usage.total_tokens
if llm_response.usage
else 0,
}
if llm_response.usage
else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens),
},
}
# Calculate actual cost and update usage
@@ -511,7 +569,9 @@ async def create_embedding(
)
# Update API key usage statistics
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
await auth_service.update_usage_stats(
context, total_tokens, actual_cost_cents
)
# Add budget warnings to response if any
if warnings:
@@ -528,44 +588,42 @@ async def create_embedding(
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
detail=f"Security validation failed: {e.message}",
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
detail=f"Request validation failed: {e.message}",
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
detail="Rate limit exceeded",
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
detail="LLM service temporarily unavailable",
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
detail="LLM service error",
)
except Exception as e:
logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
detail="Failed to create embedding",
)
@router.get("/health")
async def llm_health_check(
context: Dict[str, Any] = Depends(require_api_key)
):
async def llm_health_check(context: Dict[str, Any] = Depends(require_api_key)):
"""Health check for LLM service"""
try:
health_summary = llm_service.get_health_summary()
@@ -585,34 +643,31 @@ async def llm_health_check(
"status": overall_status,
"service": "LLM Service",
"service_status": health_summary,
"provider_status": {name: {
"provider_status": {
name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"error_message": status.error_message,
}
for name, status in provider_status.items()
},
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
"api_key_name": context["api_key_name"],
}
except Exception as e:
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Service",
"error": str(e)
}
return {"status": "unhealthy", "service": "LLM Service", "error": str(e)}
@router.get("/usage")
async def get_usage_stats(
context: Dict[str, Any] = Depends(require_api_key)
):
async def get_usage_stats(context: Dict[str, Any] = Depends(require_api_key)):
"""Get usage statistics for the API key"""
try:
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
return {
@@ -622,15 +677,17 @@ async def get_usage_stats(
"total_tokens": api_key.total_tokens,
"total_cost_cents": api_key.total_cost,
"created_at": api_key.created_at.isoformat(),
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
"rate_limits": {
"per_minute": api_key.rate_limit_per_minute,
"per_hour": api_key.rate_limit_per_hour,
"per_day": api_key.rate_limit_per_day
"per_day": api_key.rate_limit_per_day,
},
"permissions": api_key.permissions,
"scopes": api_key.scopes,
"allowed_models": api_key.allowed_models
"allowed_models": api_key.allowed_models,
}
except HTTPException:
@@ -639,7 +696,7 @@ async def get_usage_stats(
logger.error(f"Error getting usage stats: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get usage statistics"
detail="Failed to get usage statistics",
)
@@ -647,7 +704,7 @@ async def get_usage_stats(
async def get_budget_status(
request: Request,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get current budget status and usage analytics"""
try:
@@ -659,14 +716,14 @@ async def get_budget_status(
if not await auth_service.check_scope_permission(context, "budget.read"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to read budget information"
detail="Insufficient permissions to read budget information",
)
api_key = context.get("api_key")
if not api_key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="API key information not available"
detail="API key information not available",
)
# Convert AsyncSession to Session for budget enforcement
@@ -676,10 +733,7 @@ async def get_budget_status(
budget_service = BudgetEnforcementService(sync_db)
budget_status = budget_service.get_budget_status(api_key)
return {
"object": "budget_status",
"data": budget_status
}
return {"object": "budget_status", "data": budget_status}
finally:
sync_db.close()
@@ -689,7 +743,7 @@ async def get_budget_status(
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User information not available"
detail="User information not available",
)
# Return basic budget info for JWT users
@@ -702,14 +756,14 @@ async def get_budget_status(
"projections": {
"daily_burn_rate": 0.0,
"projected_monthly": 0.0,
"days_remaining": 30
}
}
"days_remaining": 30,
},
},
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication type"
detail="Invalid authentication type",
)
except HTTPException:
@@ -718,7 +772,7 @@ async def get_budget_status(
logger.error(f"Error getting budget status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get budget status"
detail="Failed to get budget status",
)
@@ -726,7 +780,7 @@ async def get_budget_status(
@router.get("/metrics")
async def get_llm_metrics(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get LLM service metrics (admin only)"""
try:
@@ -735,7 +789,7 @@ async def get_llm_metrics(
if not await auth_service.check_scope_permission(context, "admin.metrics"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view metrics"
detail="Admin permissions required to view metrics",
)
metrics = llm_service.get_metrics()
@@ -748,8 +802,8 @@ async def get_llm_metrics(
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
}
"last_updated": metrics.last_updated.isoformat(),
},
}
except HTTPException:
@@ -758,14 +812,14 @@ async def get_llm_metrics(
logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get LLM metrics"
detail="Failed to get LLM metrics",
)
@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get status of all LLM providers"""
try:
@@ -773,7 +827,7 @@ async def get_provider_status(
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
detail="Admin permissions required to view provider status",
)
provider_status = await llm_service.get_provider_status()
@@ -787,10 +841,10 @@ async def get_provider_status(
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
"models_available": status.models_available,
}
for name, status in provider_status.items()
}
},
}
except HTTPException:
@@ -799,5 +853,5 @@ async def get_provider_status(
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
detail="Failed to get provider status",
)

View File

@@ -13,7 +13,12 @@ from app.db.database import get_db
from app.core.security import get_current_user
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.llm.exceptions import (
LLMError,
ProviderError,
SecurityError,
ValidationError,
)
from app.api.v1.llm import get_cached_models # Reuse the caching logic
logger = logging.getLogger(__name__)
@@ -35,14 +40,12 @@ async def list_models(
logger.error(f"Failed to list models: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve models"
detail="Failed to retrieve models",
)
@router.get("/providers/status")
async def get_provider_status(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def get_provider_status(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get status of all LLM providers for authenticated users
"""
@@ -58,23 +61,21 @@ async def get_provider_status(
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
"models_available": status.models_available,
}
for name, status in provider_status.items()
}
},
}
except Exception as e:
logger.error(f"Failed to get provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve provider status"
detail="Failed to retrieve provider status",
)
@router.get("/health")
async def health_check(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def health_check(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get LLM service health status for authenticated users
"""
@@ -83,39 +84,35 @@ async def health_check(
return {
"status": health["status"],
"providers": health.get("providers", {}),
"timestamp": health.get("timestamp")
"timestamp": health.get("timestamp"),
}
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Health check failed"
detail="Health check failed",
)
@router.get("/metrics")
async def get_metrics(
current_user: Dict[str, Any] = Depends(get_current_user)
):
async def get_metrics(current_user: Dict[str, Any] = Depends(get_current_user)):
"""
Get LLM service metrics for authenticated users
"""
try:
metrics = await llm_service.get_metrics()
return {
"object": "metrics",
"data": metrics
}
return {"object": "metrics", "data": metrics}
except Exception as e:
logger.error(f"Failed to get metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve metrics"
detail="Failed to retrieve metrics",
)
class ChatCompletionRequest(BaseModel):
"""Request model for chat completions"""
model: str
messages: List[Dict[str, str]]
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
@@ -128,7 +125,7 @@ class ChatCompletionRequest(BaseModel):
async def create_chat_completion(
request: ChatCompletionRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create chat completion for authenticated frontend users.
@@ -151,11 +148,13 @@ async def create_chat_completion(
top_p=request.top_p,
stream=request.stream,
user_id=user_id,
api_key_id=0 # Special value for JWT-authenticated requests
api_key_id=0, # Special value for JWT-authenticated requests
)
# Log the request for debugging
logger.info(f"Internal chat completion request from user {current_user.get('id')}: model={request.model}")
logger.info(
f"Internal chat completion request from user {current_user.get('id')}: model={request.model}"
)
# Process the request through the LLM service
response = await llm_service.create_chat_completion(chat_request)
@@ -171,17 +170,21 @@ async def create_chat_completion(
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
"content": choice.message.content,
},
"finish_reason": choice.finish_reason
"finish_reason": choice.finish_reason,
}
for choice in response.choices
],
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0
} if response.usage else None
"completion_tokens": response.usage.completion_tokens
if response.usage
else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
}
if response.usage
else None,
}
return formatted_response
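# Hypothetical frontend sketch for this internal, JWT-authenticated endpoint;
# the path prefix is a guess and the model name is invented.
def _example_internal_completion(token: str) -> dict:
    import httpx  # assumed available in the client environment

    resp = httpx.post(
        "http://localhost:8000/api/v1/llm/internal/chat/completions",
        json={
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
        },
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()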
@@ -189,18 +192,17 @@ async def create_chat_completion(
except ValidationError as e:
logger.error(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid request: {str(e)}"
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid request: {str(e)}"
)
except LLMError as e:
logger.error(f"LLM service error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"LLM service error: {str(e)}"
detail=f"LLM service error: {str(e)}",
)
except Exception as e:
logger.error(f"Unexpected error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process chat completion"
detail="Failed to process chat completion",
)

View File

@@ -40,7 +40,7 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
"description": module_info["description"],
"initialized": is_loaded,
"enabled": is_enabled,
"status": status # Add status field for frontend compatibility
"status": status, # Add status field for frontend compatibility
}
# Get module statistics if available and module is loaded
@@ -49,11 +49,14 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
if hasattr(module_instance, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module_instance.get_stats):
stats = await module_instance.get_stats()
else:
stats = module_instance.get_stats()
api_module["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
api_module["stats"] = (
stats.__dict__ if hasattr(stats, "__dict__") else stats
)
except Exception:  # avoid a bare except that would also swallow SystemExit/KeyboardInterrupt
api_module["stats"] = {}
@@ -66,7 +69,7 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
"total": len(modules),
"modules": modules,
"module_count": loaded_count,
"initialized": module_manager.initialized
"initialized": module_manager.initialized,
}
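# Sketch of a helper that could replace the repeated iscoroutinefunction
# checks around module hooks in the surrounding hunks; a minimal version,
# assuming hooks are either plain callables or coroutine functions:
async def _call_maybe_async(fn, *args, **kwargs):
    import asyncio

    if asyncio.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    return fn(*args, **kwargs)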
@@ -106,23 +109,30 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
if hasattr(module_instance, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module_instance.get_stats):
stats_result = await module_instance.get_stats()
else:
stats_result = module_instance.get_stats()
stats = stats_result.__dict__ if hasattr(stats_result, "__dict__") else stats_result
stats = (
stats_result.__dict__
if hasattr(stats_result, "__dict__")
else stats_result
)
except Exception:
stats = {}
modules_with_status.append({
modules_with_status.append(
{
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"status": status,
"enabled": is_enabled,
"loaded": is_loaded,
"stats": stats
})
"stats": stats,
}
)
return {
"modules": modules_with_status,
@@ -130,12 +140,14 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
"running": running_count,
"standby": standby_count,
"failed": failed_count,
"system_initialized": module_manager.initialized
"system_initialized": module_manager.initialized,
}
@router.get("/{module_name}")
async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_info(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get detailed information about a specific module"""
log_api_request("get_module_info", {"module_name": module_name})
@@ -148,14 +160,17 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
"version": getattr(module, "version", "1.0.0"),
"description": getattr(module, "description", ""),
"initialized": getattr(module, "initialized", False),
"enabled": module_manager.module_configs.get(module_name, ModuleConfig(module_name)).enabled,
"capabilities": []
"enabled": module_manager.module_configs.get(
module_name, ModuleConfig(module_name)
).enabled,
"capabilities": [],
}
# Get module capabilities
if hasattr(module, "get_module_info"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_module_info):
info = await module.get_module_info()
else:
@@ -168,11 +183,14 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
if hasattr(module, "get_stats"):
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
module_info["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
module_info["stats"] = (
stats.__dict__ if hasattr(stats, "__dict__") else stats
)
        except Exception:
            module_info["stats"] = {}
@@ -188,7 +206,9 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
@router.post("/{module_name}/enable")
async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def enable_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Enable a module"""
log_api_request("enable_module", {"module_name": module_name})
@@ -204,16 +224,18 @@ async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends
try:
await module_manager._load_module(module_name, config)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to enable module '{module_name}': {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Failed to enable module '{module_name}': {str(e)}",
)
return {
"message": f"Module '{module_name}' enabled successfully",
"enabled": True
}
return {"message": f"Module '{module_name}' enabled successfully", "enabled": True}
@router.post("/{module_name}/disable")
async def disable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def disable_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Disable a module"""
log_api_request("disable_module", {"module_name": module_name})
@@ -229,11 +251,14 @@ async def disable_module(module_name: str, current_user: Dict[str, Any] = Depend
try:
await module_manager.unload_module(module_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to disable module '{module_name}': {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Failed to disable module '{module_name}': {str(e)}",
)
return {
"message": f"Module '{module_name}' disabled successfully",
"enabled": False
"enabled": False,
}
@@ -260,19 +285,21 @@ async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_
"message": f"Reloaded {len(results) - len(failed_modules)}/{len(results)} modules successfully",
"success": False,
"results": results,
"failed_modules": failed_modules
"failed_modules": failed_modules,
}
else:
return {
"message": f"All {len(results)} modules reloaded successfully",
"success": True,
"results": results,
"failed_modules": []
"failed_modules": [],
}
@router.post("/{module_name}/reload")
async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def reload_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Reload a specific module"""
log_api_request("reload_module", {"module_name": module_name})
@@ -282,16 +309,20 @@ async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to reload module '{module_name}'")
raise HTTPException(
status_code=500, detail=f"Failed to reload module '{module_name}'"
)
return {
"message": f"Module '{module_name}' reloaded successfully",
"reloaded": True
"reloaded": True,
}
@router.post("/{module_name}/restart")
async def restart_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def restart_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Restart a specific module (alias for reload)"""
log_api_request("restart_module", {"module_name": module_name})
@@ -301,16 +332,20 @@ async def restart_module(module_name: str, current_user: Dict[str, Any] = Depend
success = await module_manager.reload_module(module_name)
if not success:
raise HTTPException(status_code=500, detail=f"Failed to restart module '{module_name}'")
raise HTTPException(
status_code=500, detail=f"Failed to restart module '{module_name}'"
)
return {
"message": f"Module '{module_name}' restarted successfully",
"restarted": True
"restarted": True,
}
@router.post("/{module_name}/start")
async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def start_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Start a specific module (enable and load)"""
log_api_request("start_module", {"module_name": module_name})
@@ -325,14 +360,13 @@ async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(
if module_name not in module_manager.modules:
await module_manager._load_module(module_name, config)
return {
"message": f"Module '{module_name}' started successfully",
"started": True
}
return {"message": f"Module '{module_name}' started successfully", "started": True}
@router.post("/{module_name}/stop")
async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def stop_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Stop a specific module (disable and unload)"""
log_api_request("stop_module", {"module_name": module_name})
@@ -347,14 +381,13 @@ async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(g
if module_name in module_manager.modules:
await module_manager.unload_module(module_name)
return {
"message": f"Module '{module_name}' stopped successfully",
"stopped": True
}
return {"message": f"Module '{module_name}' stopped successfully", "stopped": True}
@router.get("/{module_name}/stats")
async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_stats(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get module statistics"""
log_api_request("get_module_stats", {"module_name": module_name})
@@ -364,26 +397,39 @@ async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depe
module = module_manager.modules[module_name]
if not hasattr(module, "get_stats"):
raise HTTPException(status_code=404, detail=f"Module '{module_name}' does not provide statistics")
raise HTTPException(
status_code=404,
detail=f"Module '{module_name}' does not provide statistics",
)
try:
import asyncio
if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
return {
"module": module_name,
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to get statistics: {str(e)}"
)
@router.post("/{module_name}/execute")
async def execute_module_action(module_name: str, request_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user)):
async def execute_module_action(
module_name: str,
request_data: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a module action through the interceptor pattern"""
log_api_request("execute_module_action", {"module_name": module_name, "action": request_data.get("action")})
log_api_request(
"execute_module_action",
{"module_name": module_name, "action": request_data.get("action")},
)
if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
@@ -391,14 +437,16 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
module = module_manager.modules[module_name]
# Check if module supports the new interceptor pattern
if hasattr(module, 'execute_with_interceptors'):
if hasattr(module, "execute_with_interceptors"):
try:
# Prepare context (would normally come from authentication middleware)
context = {
"user_id": "test_user", # Would come from authentication
"api_key_id": "test_api_key", # Would come from API key auth
"ip_address": "127.0.0.1", # Would come from request
"user_permissions": [f"modules:{module_name}:*"] # Would come from user/API key permissions
"user_permissions": [
f"modules:{module_name}:*"
], # Would come from user/API key permissions
}
# Execute through interceptor chain
@@ -408,11 +456,13 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": True
"interceptor_pattern": True,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)
# Fallback for legacy modules
else:
@@ -423,6 +473,7 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
method = getattr(module, action)
if callable(method):
import asyncio
if asyncio.iscoroutinefunction(method):
response = await method(request_data)
else:
@@ -432,18 +483,27 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": False
"interceptor_pattern": False,
}
else:
raise HTTPException(status_code=400, detail=f"'{action}' is not callable")
raise HTTPException(
status_code=400, detail=f"'{action}' is not callable"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)
else:
raise HTTPException(status_code=400, detail=f"Action '{action}' not supported by module '{module_name}'")
raise HTTPException(
status_code=400,
detail=f"Action '{action}' not supported by module '{module_name}'",
)
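# Hypothetical sketch (the actual call is elided by the hunk above): a
# signature like execute_with_interceptors(request_data, context) is assumed,
# with the context fields taken from the handler.
async def run_action(module, request_data):
    context = {
        "user_id": "test_user",
        "api_key_id": "test_api_key",
        "ip_address": "127.0.0.1",
        "user_permissions": ["modules:demo:*"],  # "demo" is a placeholder name
    }
    return await module.execute_with_interceptors(request_data, context)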
@router.get("/{module_name}/config")
async def get_module_config(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_config(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get module configuration schema and current values"""
log_api_request("get_module_config", {"module_name": module_name})
@@ -471,10 +531,15 @@ async def get_module_config(module_name: str, current_user: Dict[str, Any] = Dep
dynamic_schema = copy.deepcopy(schema)
# Add enum options for the model field
if "properties" in dynamic_schema and "model" in dynamic_schema["properties"]:
if (
"properties" in dynamic_schema
and "model" in dynamic_schema["properties"]
):
dynamic_schema["properties"]["model"]["enum"] = model_ids
# Set a sensible default if the current default isn't in the list
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
current_default = dynamic_schema["properties"]["model"].get(
"default", "gpt-3.5-turbo"
)
if current_default not in model_ids and model_ids:
dynamic_schema["properties"]["model"]["default"] = model_ids[0]
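# Illustration (not from this commit): the enum-injection step above in
# isolation. Copy the static schema, then constrain and default the "model"
# field to the models actually available; the ids are placeholders.
import copy

schema = {"properties": {"model": {"type": "string", "default": "gpt-3.5-turbo"}}}
model_ids = ["gpt-4o", "gpt-4o-mini"]  # placeholder ids

dynamic_schema = copy.deepcopy(schema)
dynamic_schema["properties"]["model"]["enum"] = model_ids
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
if current_default not in model_ids and model_ids:
    dynamic_schema["properties"]["model"]["default"] = model_ids[0]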
@@ -489,12 +554,16 @@ async def get_module_config(module_name: str, current_user: Dict[str, Any] = Dep
"description": manifest.description,
"schema": schema,
"current_config": current_config,
"has_schema": schema is not None
"has_schema": schema is not None,
}
@router.post("/{module_name}/config")
async def update_module_config(module_name: str, config: dict, current_user: Dict[str, Any] = Depends(get_current_user)):
async def update_module_config(
module_name: str,
config: dict,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update module configuration"""
log_api_request("update_module_config", {"module_name": module_name})
@@ -518,7 +587,7 @@ async def update_module_config(module_name: str, config: dict, current_user: Dic
return {
"message": f"Configuration updated for module '{module_name}'",
"config": config
"config": config,
}
except Exception as e:

View File

@@ -12,9 +12,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db
from app.services.api_key_auth import require_api_key
from app.api.v1.llm import (
get_cached_models, ModelsResponse, ModelInfo,
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
create_embedding as llm_create_embedding
get_cached_models,
ModelsResponse,
ModelInfo,
ChatCompletionRequest,
EmbeddingRequest,
create_chat_completion as llm_chat_completion,
create_embedding as llm_create_embedding,
)
logger = logging.getLogger(__name__)
@@ -22,8 +26,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def openai_error_response(message: str, error_type: str = "invalid_request_error",
status_code: int = 400, code: str = None):
def openai_error_response(
message: str,
error_type: str = "invalid_request_error",
status_code: int = 400,
code: str = None,
):
"""Create OpenAI-compatible error response"""
error_data = {
"error": {
@@ -34,16 +42,13 @@ def openai_error_response(message: str, error_type: str = "invalid_request_error
if code:
error_data["error"]["code"] = code
return JSONResponse(
status_code=status_code,
content=error_data
)
return JSONResponse(status_code=status_code, content=error_data)
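# Usage sketch (illustrative): the error body mirrors OpenAI's format. The
# "message" and "type" keys are inferred from the parameters above; "code"
# is added only when provided.
resp = openai_error_response(
    "Invalid authentication credentials", "authentication_error", 401
)
# resp is a JSONResponse whose payload looks like:
# {"error": {"message": "Invalid authentication credentials",
#            "type": "authentication_error"}}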
@router.get("/models", response_model=ModelsResponse)
async def list_models(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Lists the currently available models, and provides basic information about each one
@@ -55,30 +60,23 @@ async def list_models(
try:
# Delegate to the existing LLM models endpoint
from app.api.v1.llm import list_models as llm_list_models
return await llm_list_models(context, db)
except HTTPException as e:
# Convert FastAPI HTTPException to OpenAI format
if e.status_code == 401:
return openai_error_response(
"Invalid authentication credentials",
"authentication_error",
401
"Invalid authentication credentials", "authentication_error", 401
)
elif e.status_code == 403:
return openai_error_response(
"Insufficient permissions",
"permission_error",
403
"Insufficient permissions", "permission_error", 403
)
else:
return openai_error_response(str(e.detail), "api_error", e.status_code)
except Exception as e:
logger.error(f"Error in OpenAI models endpoint: {e}")
return openai_error_response(
"Internal server error",
"api_error",
500
)
return openai_error_response("Internal server error", "api_error", 500)
@router.post("/chat/completions")
@@ -86,7 +84,7 @@ async def create_chat_completion(
request_body: Request,
chat_request: ChatCompletionRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create chat completion - OpenAI compatible endpoint
@@ -102,7 +100,7 @@ async def create_chat_completion(
async def create_embedding(
request: EmbeddingRequest,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create embedding - OpenAI compatible endpoint
@@ -118,7 +116,7 @@ async def create_embedding(
async def retrieve_model(
model_id: str,
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Retrieve model information - OpenAI compatible endpoint
@@ -133,7 +131,9 @@ async def retrieve_model(
# Filter models based on API key permissions
api_key = context.get("api_key")
if api_key and api_key.allowed_models:
models = [model for model in models if model.get("id") in api_key.allowed_models]
models = [
model for model in models if model.get("id") in api_key.allowed_models
]
# Find the specific model
model = next((m for m in models if m.get("id") == model_id), None)
@@ -141,14 +141,14 @@ async def retrieve_model(
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model '{model_id}' not found"
detail=f"Model '{model_id}' not found",
)
return ModelInfo(
id=model.get("id", model_id),
object="model",
created=model.get("created", 0),
owned_by=model.get("owned_by", "system")
owned_by=model.get("owned_by", "system"),
)
except HTTPException:
@@ -157,5 +157,5 @@ async def retrieve_model(
logger.error(f"Error retrieving model {model_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model information"
detail="Failed to retrieve model information",
)
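# Hypothetical client call to the OpenAI-compatible endpoints above; the
# mount path and bearer-token scheme are assumptions, not taken from this diff.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",  # assumed base path
    headers={"Authorization": "Bearer sk-your-key"},  # assumed API key scheme
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())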

View File

@@ -7,7 +7,11 @@ from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from app.services.permission_manager import permission_registry, Permission, PermissionScope
from app.services.permission_manager import (
permission_registry,
Permission,
PermissionScope,
)
from app.core.logging import get_logger
from app.core.security import get_current_user
@@ -86,7 +90,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in perms
]
@@ -97,7 +101,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
logger.error(f"Error getting permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permissions: {str(e)}"
detail=f"Failed to get permissions: {str(e)}",
)
@@ -112,7 +116,7 @@ async def get_permission_hierarchy():
logger.error(f"Error getting permission hierarchy: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permission hierarchy: {str(e)}"
detail=f"Failed to get permission hierarchy: {str(e)}",
)
@@ -120,44 +124,43 @@ async def get_permission_hierarchy():
async def validate_permissions(request: PermissionValidationRequest):
"""Validate a list of permissions"""
try:
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
return PermissionValidationResponse(**validation_result)
except Exception as e:
logger.error(f"Error validating permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate permissions: {str(e)}"
detail=f"Failed to validate permissions: {str(e)}",
)
@router.post("/permissions/check", response_model=PermissionCheckResponse)
async def check_permission(
request: PermissionCheckRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Check if user has a specific permission"""
try:
has_permission = permission_registry.check_permission(
request.user_permissions,
request.required_permission,
request.context
request.user_permissions, request.required_permission, request.context
)
matching_permissions = list(permission_registry.tree.get_matching_permissions(
request.user_permissions
))
matching_permissions = list(
permission_registry.tree.get_matching_permissions(request.user_permissions)
)
return PermissionCheckResponse(
has_permission=has_permission,
matching_permissions=matching_permissions
has_permission=has_permission, matching_permissions=matching_permissions
)
except Exception as e:
logger.error(f"Error checking permission: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check permission: {str(e)}"
detail=f"Failed to check permission: {str(e)}",
)
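# Illustration (not from this commit): the check performed by the endpoint
# above. The signature matches the call site; the wildcard semantics
# ("modules:demo:*" covering "modules:demo:execute") are an assumption.
from app.services.permission_manager import permission_registry

user_permissions = ["modules:demo:*"]
required = "modules:demo:execute"
allowed = permission_registry.check_permission(user_permissions, required, None)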
@@ -172,7 +175,7 @@ async def get_module_permissions(module_id: str):
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in permissions
]
@@ -181,7 +184,7 @@ async def get_module_permissions(module_id: str):
logger.error(f"Error getting module permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get module permissions: {str(e)}"
detail=f"Failed to get module permissions: {str(e)}",
)
@@ -191,18 +194,19 @@ async def create_role(request: RoleRequest):
"""Create a custom role with specific permissions"""
try:
# Validate permissions first
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
if not validation_result["is_valid"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid permissions: {validation_result['invalid']}"
detail=f"Invalid permissions: {validation_result['invalid']}",
)
permission_registry.create_role(request.role_name, request.permissions)
return RoleResponse(
role_name=request.role_name,
permissions=request.permissions
role_name=request.role_name, permissions=request.permissions
)
except HTTPException:
@@ -211,7 +215,7 @@ async def create_role(request: RoleRequest):
logger.error(f"Error creating role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create role: {str(e)}"
detail=f"Failed to create role: {str(e)}",
)
@@ -220,14 +224,17 @@ async def get_roles():
"""Get all available roles and their permissions"""
try:
# Combine default roles and custom roles
all_roles = {**permission_registry.default_roles, **permission_registry.role_permissions}
all_roles = {
**permission_registry.default_roles,
**permission_registry.role_permissions,
}
return all_roles
except Exception as e:
logger.error(f"Error getting roles: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get roles: {str(e)}"
detail=f"Failed to get roles: {str(e)}",
)
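# Merge-precedence sketch for the role dict union above: custom role
# definitions win over defaults when names collide. Role names and
# permission strings here are placeholders.
defaults = {"admin": ["*"], "viewer": ["reports:read"]}
custom = {"viewer": ["reports:read", "budgets:read"]}
merged = {**defaults, **custom}
assert merged["viewer"] == ["reports:read", "budgets:read"]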
@@ -236,20 +243,17 @@ async def get_role(role_name: str):
"""Get a specific role and its permissions"""
try:
# Check default roles first, then custom roles
permissions = (permission_registry.role_permissions.get(role_name) or
permission_registry.default_roles.get(role_name))
permissions = permission_registry.role_permissions.get(
role_name
) or permission_registry.default_roles.get(role_name)
if permissions is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Role '{role_name}' not found"
detail=f"Role '{role_name}' not found",
)
return RoleResponse(
role_name=role_name,
permissions=permissions,
created=True
)
return RoleResponse(role_name=role_name, permissions=permissions, created=True)
except HTTPException:
raise
@@ -257,7 +261,7 @@ async def get_role(role_name: str):
logger.error(f"Error getting role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get role: {str(e)}"
detail=f"Failed to get role: {str(e)}",
)
@@ -267,21 +271,20 @@ async def calculate_user_permissions(request: UserPermissionsRequest):
"""Calculate effective permissions for a user based on roles and custom permissions"""
try:
effective_permissions = permission_registry.get_user_permissions(
request.roles,
request.custom_permissions
request.roles, request.custom_permissions
)
return UserPermissionsResponse(
effective_permissions=effective_permissions,
roles=request.roles,
custom_permissions=request.custom_permissions or []
custom_permissions=request.custom_permissions or [],
)
except Exception as e:
logger.error(f"Error calculating user permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to calculate user permissions: {str(e)}"
detail=f"Failed to calculate user permissions: {str(e)}",
)
@@ -293,7 +296,9 @@ async def platform_health():
# Get permission system status
total_permissions = len(permission_registry.tree.permissions)
total_modules = len(permission_registry.module_permissions)
total_roles = len(permission_registry.default_roles) + len(permission_registry.role_permissions)
total_roles = len(permission_registry.default_roles) + len(
permission_registry.role_permissions
)
return {
"status": "healthy",
@@ -302,16 +307,13 @@ async def platform_health():
"permission_system": {
"total_permissions": total_permissions,
"registered_modules": total_modules,
"available_roles": total_roles
}
"available_roles": total_roles,
},
}
except Exception as e:
logger.error(f"Error checking platform health: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
return {"status": "unhealthy", "error": str(e)}
@router.get("/metrics")
@@ -324,17 +326,18 @@ async def platform_metrics():
metrics = {
"permissions": {
"total": len(permission_registry.tree.permissions),
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()}
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()},
},
"modules": {
"registered": len(permission_registry.module_permissions),
"names": list(permission_registry.module_permissions.keys())
"names": list(permission_registry.module_permissions.keys()),
},
"roles": {
"default": len(permission_registry.default_roles),
"custom": len(permission_registry.role_permissions),
"total": len(permission_registry.default_roles) + len(permission_registry.role_permissions)
}
"total": len(permission_registry.default_roles)
+ len(permission_registry.role_permissions),
},
}
return metrics
@@ -343,5 +346,5 @@ async def platform_metrics():
logger.error(f"Error getting platform metrics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get platform metrics: {str(e)}"
detail=f"Failed to get platform metrics: {str(e)}",
)

View File

@@ -46,28 +46,27 @@ async def discover_plugins(
category: str = "",
limit: int = 20,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Discover available plugins from repository"""
try:
tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
tag_list = (
[tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
)
plugins = await plugin_discovery.search_available_plugins(
query=query,
tags=tag_list,
category=category if category else None,
limit=limit,
db=db
db=db,
)
return {
"plugins": plugins,
"count": len(plugins),
"query": query,
"filters": {
"tags": tag_list,
"category": category
}
"filters": {"tags": tag_list, "category": category},
}
except Exception as e:
@@ -76,7 +75,9 @@ async def discover_plugins(
@router.get("/categories")
async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_plugin_categories(
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get available plugin categories"""
try:
categories = await plugin_discovery.get_plugin_categories()
@@ -87,37 +88,32 @@ async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_curre
raise HTTPException(status_code=500, detail=f"Failed to get categories: {e}")
@router.get("/installed")
async def get_installed_plugins(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get user's installed plugins"""
try:
plugins = await plugin_discovery.get_installed_plugins(current_user["id"], db)
return {
"plugins": plugins,
"count": len(plugins)
}
return {"plugins": plugins, "count": len(plugins)}
except Exception as e:
logger.error(f"Failed to get installed plugins: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get installed plugins: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to get installed plugins: {e}"
)
@router.get("/updates")
async def check_plugin_updates(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Check for available plugin updates"""
try:
updates = await plugin_discovery.get_plugin_updates(db)
return {
"updates": updates,
"count": len(updates)
}
return {"updates": updates, "count": len(updates)}
except Exception as e:
logger.error(f"Failed to check updates: {e}")
@@ -130,12 +126,15 @@ async def install_plugin(
request: PluginInstallRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from repository"""
try:
if request.source != "repository":
raise HTTPException(status_code=400, detail="Only repository installation supported via this endpoint")
raise HTTPException(
status_code=400,
detail="Only repository installation supported via this endpoint",
)
# Start installation in background
background_tasks.add_task(
@@ -143,14 +142,14 @@ async def install_plugin(
request.plugin_id,
request.version,
current_user["id"],
db
db,
)
return {
"status": "installation_started",
"plugin_id": request.plugin_id,
"version": request.version,
"message": "Plugin installation started in background"
"message": "Plugin installation started in background",
}
except Exception as e:
@@ -163,17 +162,18 @@ async def install_plugin_from_file(
file: UploadFile = File(...),
background_tasks: BackgroundTasks = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from uploaded file"""
try:
# Validate file type
if not file.filename.endswith('.zip'):
if not file.filename.endswith(".zip"):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
# Save uploaded file
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_file:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name
@@ -187,12 +187,13 @@ async def install_plugin_from_file(
return {
"status": "installed",
"result": result,
"message": "Plugin installed successfully"
"message": "Plugin installed successfully",
}
finally:
# Cleanup temp file
import os
os.unlink(temp_file_path)
except Exception as e:
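# Illustration (not from this commit): the upload handling above in
# isolation. Spool the ZIP to a named temp file, hand the path to the
# installer, and always unlink it afterwards.
import os
import tempfile

def spool_upload(content: bytes) -> str:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
        tmp.write(content)
        return tmp.name

path = spool_upload(b"PK...")  # placeholder bytes
try:
    pass  # install from `path` here
finally:
    os.unlink(path)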
@@ -205,7 +206,7 @@ async def uninstall_plugin(
plugin_id: str,
request: PluginUninstallRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Uninstall plugin"""
try:
@@ -216,7 +217,7 @@ async def uninstall_plugin(
return {
"status": "uninstalled",
"result": result,
"message": "Plugin uninstalled successfully"
"message": "Plugin uninstalled successfully",
}
except Exception as e:
@@ -229,7 +230,7 @@ async def uninstall_plugin(
async def enable_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Enable plugin"""
try:
@@ -248,7 +249,7 @@ async def enable_plugin(
return {
"status": "enabled",
"plugin_id": plugin_id,
"message": "Plugin enabled successfully"
"message": "Plugin enabled successfully",
}
except Exception as e:
@@ -260,7 +261,7 @@ async def enable_plugin(
async def disable_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Disable plugin"""
try:
@@ -283,7 +284,7 @@ async def disable_plugin(
return {
"status": "disabled",
"plugin_id": plugin_id,
"message": "Plugin disabled successfully"
"message": "Plugin disabled successfully",
}
except Exception as e:
@@ -295,7 +296,7 @@ async def disable_plugin(
async def load_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Load plugin into runtime"""
try:
@@ -310,7 +311,9 @@ async def load_plugin(
raise HTTPException(status_code=404, detail="Plugin not found")
if plugin.status != "enabled":
raise HTTPException(status_code=400, detail="Plugin must be enabled to load")
raise HTTPException(
status_code=400, detail="Plugin must be enabled to load"
)
if plugin_id in plugin_loader.loaded_plugins:
raise HTTPException(status_code=400, detail="Plugin already loaded")
@@ -322,11 +325,13 @@ async def load_plugin(
plugin_context = plugin_context_manager.create_plugin_context(
plugin_id=plugin_id,
user_id=str(current_user.get("id", "unknown")), # Use actual user ID
session_type="api_load"
session_type="api_load",
)
# Generate plugin token based on context
plugin_token = plugin_context_manager.generate_plugin_token(plugin_context["context_id"])
plugin_token = plugin_context_manager.generate_plugin_token(
plugin_context["context_id"]
)
# Log plugin loading action
plugin_context_manager.add_audit_trail_entry(
@@ -335,8 +340,8 @@ async def load_plugin(
{
"plugin_dir": str(plugin_dir),
"user_id": current_user.get("id", "unknown"),
"action": "load_plugin_with_sandbox"
}
"action": "load_plugin_with_sandbox",
},
)
await plugin_loader.load_plugin_with_sandbox(plugin_dir, plugin_token)
@@ -344,7 +349,7 @@ async def load_plugin(
return {
"status": "loaded",
"plugin_id": plugin_id,
"message": "Plugin loaded successfully"
"message": "Plugin loaded successfully",
}
except Exception as e:
@@ -354,8 +359,7 @@ async def load_plugin(
@router.post("/{plugin_id}/unload")
async def unload_plugin(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user)
plugin_id: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Unload plugin from runtime"""
try:
@@ -369,7 +373,7 @@ async def unload_plugin(
return {
"status": "unloaded",
"plugin_id": plugin_id,
"message": "Plugin unloaded successfully"
"message": "Plugin unloaded successfully",
}
except Exception as e:
@@ -382,7 +386,7 @@ async def unload_plugin(
async def get_plugin_configuration(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get plugin configuration for user with automatic decryption"""
try:
@@ -393,27 +397,25 @@ async def get_plugin_configuration(
plugin_id=plugin_id,
user_id=current_user["id"],
db=db,
decrypt_sensitive=False # Don't decrypt sensitive data for API response
decrypt_sensitive=False, # Don't decrypt sensitive data for API response
)
if config_data is not None:
return {
"plugin_id": plugin_id,
"configuration": config_data,
"has_configuration": True
"has_configuration": True,
}
else:
# Get default configuration from manifest
resolved_config = await plugin_config_manager.get_resolved_configuration(
plugin_id=plugin_id,
user_id=current_user["id"],
db=db
plugin_id=plugin_id, user_id=current_user["id"], db=db
)
return {
"plugin_id": plugin_id,
"configuration": resolved_config,
"has_configuration": False
"has_configuration": False,
}
except Exception as e:
@@ -426,7 +428,7 @@ async def save_plugin_configuration(
plugin_id: str,
config_request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Save plugin configuration for user with automatic encryption of sensitive fields"""
try:
@@ -444,42 +446,46 @@ async def save_plugin_configuration(
config_data=config_data,
config_name=config_name,
config_description=config_description,
db=db
db=db,
)
return {
"status": "saved",
"plugin_id": plugin_id,
"configuration_id": str(saved_config.id),
"message": "Configuration saved successfully with automatic encryption of sensitive fields"
"message": "Configuration saved successfully with automatic encryption of sensitive fields",
}
except Exception as e:
logger.error(f"Failed to save plugin configuration: {e}")
await db.rollback()
raise HTTPException(status_code=500, detail=f"Failed to save configuration: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to save configuration: {e}"
)
@router.get("/{plugin_id}/schema")
async def get_plugin_configuration_schema(
plugin_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get plugin configuration schema from manifest"""
try:
from app.services.plugin_configuration_manager import plugin_config_manager
# Use the new configuration manager to get schema
schema = await plugin_config_manager.get_plugin_configuration_schema(plugin_id, db)
schema = await plugin_config_manager.get_plugin_configuration_schema(
plugin_id, db
)
if not schema:
raise HTTPException(status_code=404, detail=f"No configuration schema available for plugin '{plugin_id}'")
raise HTTPException(
status_code=404,
detail=f"No configuration schema available for plugin '{plugin_id}'",
)
return {
"plugin_id": plugin_id,
"schema": schema
}
return {"plugin_id": plugin_id, "schema": schema}
except HTTPException:
raise
@@ -493,7 +499,7 @@ async def test_plugin_credentials(
plugin_id: str,
test_request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Test plugin credentials (currently supports Zammad)"""
import httpx
@@ -510,29 +516,36 @@ async def test_plugin_credentials(
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
raise HTTPException(
status_code=404, detail=f"Plugin '{plugin_id}' not found"
)
# Check if this is a Zammad plugin
if plugin.name.lower() != 'zammad':
raise HTTPException(status_code=400, detail=f"Credential testing not supported for plugin '{plugin.name}'")
if plugin.name.lower() != "zammad":
raise HTTPException(
status_code=400,
detail=f"Credential testing not supported for plugin '{plugin.name}'",
)
# Extract credentials from request
zammad_url = test_request.get('zammad_url')
api_token = test_request.get('api_token')
zammad_url = test_request.get("zammad_url")
api_token = test_request.get("api_token")
if not zammad_url or not api_token:
raise HTTPException(status_code=400, detail="Both zammad_url and api_token are required")
raise HTTPException(
status_code=400, detail="Both zammad_url and api_token are required"
)
# Clean up the URL (remove trailing slash)
zammad_url = zammad_url.rstrip('/')
zammad_url = zammad_url.rstrip("/")
# Test credentials by making a read-only API call to Zammad
async with httpx.AsyncClient(timeout=10.0) as client:
# Try to get user info - this is a safe read-only operation
test_url = f"{zammad_url}/api/v1/users/me"
headers = {
'Authorization': f'Token token={api_token}',
'Content-Type': 'application/json'
"Authorization": f"Token token={api_token}",
"Content-Type": "application/json",
}
response = await client.get(test_url, headers=headers)
@@ -540,66 +553,68 @@ async def test_plugin_credentials(
if response.status_code == 200:
# Success - credentials are valid
user_data = response.json()
user_email = user_data.get('email', 'unknown')
user_email = user_data.get("email", "unknown")
return {
"success": True,
"message": f"Credentials verified! Connected as: {user_email}",
"zammad_url": zammad_url,
"user_info": {
"email": user_email,
"firstname": user_data.get('firstname', ''),
"lastname": user_data.get('lastname', '')
}
"firstname": user_data.get("firstname", ""),
"lastname": user_data.get("lastname", ""),
},
}
elif response.status_code == 401:
return {
"success": False,
"message": "Invalid API token. Please check your token and try again.",
"error_code": "invalid_token"
"error_code": "invalid_token",
}
elif response.status_code == 404:
return {
"success": False,
"message": "Zammad URL not found. Please verify the URL is correct.",
"error_code": "invalid_url"
"error_code": "invalid_url",
}
else:
error_text = ""
try:
error_data = response.json()
error_text = error_data.get('error', error_data.get('message', ''))
error_text = error_data.get("error", error_data.get("message", ""))
                except Exception:
                    error_text = response.text[:200]
return {
"success": False,
"message": f"Connection failed (HTTP {response.status_code}): {error_text}",
"error_code": "connection_failed"
"error_code": "connection_failed",
}
except httpx.TimeoutException:
return {
"success": False,
"message": "Connection timeout. Please check the Zammad URL and your network connection.",
"error_code": "timeout"
"error_code": "timeout",
}
except httpx.ConnectError:
return {
"success": False,
"message": "Could not connect to Zammad. Please verify the URL is correct and accessible.",
"error_code": "connection_error"
"error_code": "connection_error",
}
except Exception as e:
logger.error(f"Failed to test plugin credentials: {e}")
return {
"success": False,
"message": f"Test failed: {str(e)}",
"error_code": "unknown_error"
"error_code": "unknown_error",
}
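# Standalone sketch of the read-only credential check above; the endpoint
# and header format are taken from the handler, the URL and token are
# placeholders.
import asyncio
import httpx

async def verify_zammad(zammad_url: str, api_token: str) -> bool:
    headers = {
        "Authorization": f"Token token={api_token}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(
            f"{zammad_url.rstrip('/')}/api/v1/users/me", headers=headers
        )
    return resp.status_code == 200

# asyncio.run(verify_zammad("https://zammad.example.com", "token-here"))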
# Background task for plugin installation
async def install_plugin_background(plugin_id: str, version: str, user_id: str, db: AsyncSession):
async def install_plugin_background(
plugin_id: str, version: str, user_id: str, db: AsyncSession
):
"""Background task for plugin installation"""
try:
result = await plugin_installer.install_plugin_from_repository(

View File

@@ -17,7 +17,10 @@ from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.models import (
ChatRequest as LLMChatRequest,
ChatMessage as LLMChatMessage,
)
router = APIRouter()
@@ -59,11 +62,12 @@ class ImprovePromptRequest(BaseModel):
@router.get("/templates", response_model=List[PromptTemplateResponse])
async def list_prompt_templates(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get all prompt templates"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("list_prompt_templates", {"user_id": user_id})
try:
@@ -85,26 +89,36 @@ async def list_prompt_templates(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
template_list.append(template_dict)
return template_list
except Exception as e:
log_api_request("list_prompt_templates_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}")
log_api_request(
"list_prompt_templates_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}"
)
@router.get("/templates/{type_key}", response_model=PromptTemplateResponse)
async def get_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get a specific prompt template by type key"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("get_prompt_template", {"user_id": user_id, "type_key": type_key})
try:
@@ -127,15 +141,23 @@ async def get_prompt_template(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
log_api_request("get_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt template: {str(e)}")
log_api_request(
"get_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt template: {str(e)}"
)
@router.put("/templates/{type_key}")
@@ -143,15 +165,16 @@ async def update_prompt_template(
type_key: str,
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update a prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("update_prompt_template", {
"user_id": user_id,
"type_key": type_key,
"name": request.name
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"update_prompt_template",
{"user_id": user_id, "type_key": type_key, "name": request.name},
)
try:
# Get existing template
@@ -175,7 +198,7 @@ async def update_prompt_template(
system_prompt=request.system_prompt,
is_active=request.is_active,
version=template.version + 1,
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
)
@@ -183,8 +206,7 @@ async def update_prompt_template(
# Return updated template
updated_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
select(PromptTemplate).where(PromptTemplate.type_key == type_key)
)
updated_template = updated_result.scalar_one()
@@ -197,40 +219,51 @@ async def update_prompt_template(
"is_default": updated_template.is_default,
"is_active": updated_template.is_active,
"version": updated_template.version,
"created_at": updated_template.created_at.isoformat() if updated_template.created_at else None,
"updated_at": updated_template.updated_at.isoformat() if updated_template.updated_at else None
"created_at": updated_template.created_at.isoformat()
if updated_template.created_at
else None,
"updated_at": updated_template.updated_at.isoformat()
if updated_template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("update_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to update prompt template: {str(e)}")
log_api_request(
"update_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to update prompt template: {str(e)}"
)
@router.post("/templates/create")
async def create_prompt_template(
request: PromptTemplateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new prompt template"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("create_prompt_template", {
"user_id": user_id,
"type_key": request.type_key,
"name": request.name
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"create_prompt_template",
{"user_id": user_id, "type_key": request.type_key, "name": request.name},
)
try:
# Check if template already exists
existing_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == request.type_key)
select(PromptTemplate).where(PromptTemplate.type_key == request.type_key)
)
if existing_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Prompt template with this type key already exists")
raise HTTPException(
status_code=400,
detail="Prompt template with this type key already exists",
)
# Create new template
template = PromptTemplate(
@@ -243,7 +276,7 @@ async def create_prompt_template(
is_active=request.is_active,
version=1,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
db.add(template)
@@ -259,25 +292,34 @@ async def create_prompt_template(
"is_default": template.is_default,
"is_active": template.is_active,
"version": template.version,
"created_at": template.created_at.isoformat() if template.created_at else None,
"updated_at": template.updated_at.isoformat() if template.updated_at else None
"created_at": template.created_at.isoformat()
if template.created_at
else None,
"updated_at": template.updated_at.isoformat()
if template.updated_at
else None,
}
except HTTPException:
raise
except Exception as e:
await db.rollback()
log_api_request("create_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to create prompt template: {str(e)}")
log_api_request(
"create_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to create prompt template: {str(e)}"
)
@router.get("/variables", response_model=List[PromptVariableResponse])
async def list_prompt_variables(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get all available prompt variables"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("list_prompt_variables", {"user_id": user_id})
try:
@@ -295,25 +337,31 @@ async def list_prompt_variables(
"variable_name": variable.variable_name,
"description": variable.description,
"example_value": variable.example_value,
"is_active": variable.is_active
"is_active": variable.is_active,
}
variable_list.append(variable_dict)
return variable_list
except Exception as e:
log_api_request("list_prompt_variables_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}")
log_api_request(
"list_prompt_variables_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}"
)
@router.post("/templates/{type_key}/reset")
async def reset_prompt_template(
type_key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Reset a prompt template to its default"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("reset_prompt_template", {"user_id": user_id, "type_key": type_key})
# Define default prompts (same as in migration)
@@ -323,7 +371,7 @@ async def reset_prompt_template(
"teacher": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
"creative_writer": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
}
if type_key not in default_prompts:
@@ -337,7 +385,7 @@ async def reset_prompt_template(
.values(
system_prompt=default_prompts[type_key],
version=PromptTemplate.version + 1,
updated_at=datetime.utcnow()
updated_at=datetime.utcnow(),
)
)
@@ -347,22 +395,28 @@ async def reset_prompt_template(
except Exception as e:
await db.rollback()
log_api_request("reset_prompt_template_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to reset prompt template: {str(e)}")
log_api_request(
"reset_prompt_template_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to reset prompt template: {str(e)}"
)
@router.post("/improve")
async def improve_prompt_with_ai(
request: ImprovePromptRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Improve a prompt using AI"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("improve_prompt_with_ai", {
"user_id": user_id,
"chatbot_type": request.chatbot_type
})
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request(
"improve_prompt_with_ai",
{"user_id": user_id, "chatbot_type": request.chatbot_type},
)
try:
# Create system message for improvement
@@ -392,7 +446,7 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
{"role": "user", "content": user_message},
]
# Get available models to use a default model
@@ -406,11 +460,14 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
# Prepare the chat request for the new LLM service
chat_request = LLMChatRequest(
model=default_model,
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
messages=[
LLMChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
],
temperature=0.3,
max_tokens=1000,
user_id=str(user_id),
api_key_id=1 # Using default API key, you might want to make this dynamic
api_key_id=1, # Using default API key, you might want to make this dynamic
)
# Make the AI call
@@ -422,23 +479,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
return {
"improved_prompt": improved_prompt,
"original_prompt": request.current_prompt,
"model_used": default_model
"model_used": default_model,
}
except HTTPException:
raise
except Exception as e:
log_api_request("improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to improve prompt: {str(e)}")
log_api_request(
"improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to improve prompt: {str(e)}"
)
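As a usage reference, here is a minimal client-side sketch for the /improve endpoint above. The router prefix and the bearer-token scheme are assumptions (the diff only shows the "/improve" path and the chatbot_type/current_prompt request fields), so treat the URL as a placeholder.

import asyncio

import httpx


async def improve_prompt(base_url: str, token: str, chatbot_type: str, prompt: str) -> dict:
    # Hypothetical call; "/api/v1/prompt-templates" is an assumed prefix.
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            "/api/v1/prompt-templates/improve",
            json={"chatbot_type": chatbot_type, "current_prompt": prompt},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        # Response shape per the handler above:
        # {"improved_prompt": ..., "original_prompt": ..., "model_used": ...}
        return resp.json()


if __name__ == "__main__":
    print(asyncio.run(improve_prompt("http://localhost:8000", "dev-token", "assistant", "You are a bot.")))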
@router.post("/seed-defaults")
async def seed_default_templates(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Seed default prompt templates for all chatbot types"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
user_id = (
current_user.get("id") if isinstance(current_user, dict) else current_user.id
)
log_api_request("seed_default_templates", {"user_id": user_id})
# Define default prompts (same as in reset)
@@ -446,33 +508,33 @@ async def seed_default_templates(
"assistant": {
"name": "General Assistant",
"description": "A helpful, accurate, and friendly AI assistant",
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style."
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
},
"customer_support": {
"name": "Customer Support Agent",
"description": "Professional customer service representative focused on solving problems",
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations."
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.",
},
"teacher": {
"name": "Educational Tutor",
"description": "Patient and encouraging educational facilitator",
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it."
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
},
"researcher": {
"name": "Research Assistant",
"description": "Thorough researcher focused on evidence-based information",
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence."
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
},
"creative_writer": {
"name": "Creative Writing Mentor",
"description": "Imaginative storytelling expert and writing coach",
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles."
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
},
"custom": {
"name": "Custom Chatbot",
"description": "Customizable AI assistant with user-defined behavior",
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
}
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
},
}
created_templates = []
@@ -530,7 +592,9 @@ async def seed_default_templates(
created_at=now,
updated_at=now,
)
.on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])
.on_conflict_do_nothing(
index_elements=[PromptTemplate.type_key]
)
)
result = await db.execute(stmt)
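The seeding statement above leans on PostgreSQL upsert semantics: on_conflict_do_nothing suppresses the insert when type_key already exists, which makes re-running the seeder idempotent. A self-contained sketch of the same pattern, using an illustrative stand-in model rather than the project's PromptTemplate:

from sqlalchemy import Column, Integer, String
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class Template(Base):  # illustrative stand-in for PromptTemplate
    __tablename__ = "templates"
    id = Column(Integer, primary_key=True)
    type_key = Column(String(50), unique=True, nullable=False)
    system_prompt = Column(String, nullable=False)


def seed_one(session: Session, type_key: str, prompt: str) -> bool:
    stmt = (
        insert(Template)  # postgresql-dialect insert, not sqlalchemy.insert
        .values(type_key=type_key, system_prompt=prompt)
        .on_conflict_do_nothing(index_elements=[Template.type_key])
    )
    # rowcount == 0 means the unique key already existed and nothing was written.
    return session.execute(stmt).rowcount == 1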
@@ -548,10 +612,14 @@ async def seed_default_templates(
"message": "Default templates seeded successfully",
"created": created_templates,
"updated": updated_templates,
"total": len(created_templates) + len(updated_templates)
"total": len(created_templates) + len(updated_templates),
}
except Exception as e:
await db.rollback()
log_api_request("seed_default_templates_error", {"error": str(e), "user_id": user_id})
raise HTTPException(status_code=500, detail=f"Failed to seed default templates: {str(e)}")
log_api_request(
"seed_default_templates_error", {"error": str(e), "user_id": user_id}
)
raise HTTPException(
status_code=500, detail=f"Failed to seed default templates: {str(e)}"
)
@@ -29,6 +29,7 @@ router = APIRouter(tags=["RAG"])
# Request/Response Models
class CollectionCreate(BaseModel):
name: str
description: Optional[str] = None
@@ -78,12 +79,13 @@ class StatsResponse(BaseModel):
# Collection Endpoints
@router.get("/collections", response_model=dict)
async def get_collections(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
try:
@@ -103,7 +105,7 @@ async def get_collections(
"collections": paginated_collections,
"total": len(collections),
"total_documents": stats_data.get("total_documents", 0),
"total_size_bytes": stats_data.get("total_size_bytes", 0)
"total_size_bytes": stats_data.get("total_size_bytes", 0),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
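Because get_collections materializes the full live list from Qdrant before paginating, skip/limit here behaves like a plain list slice rather than a SQL LIMIT/OFFSET. The slice itself falls outside this hunk, so the following is only a sketch of the usual reading of those parameters:

from typing import List, Tuple


def paginate(items: List[dict], skip: int = 0, limit: int = 100) -> Tuple[List[dict], int]:
    # "total" reflects the whole collection list, not just the returned page.
    return items[skip : skip + limit], len(items)


page, total = paginate([{"name": f"col-{i}"} for i in range(250)], skip=100, limit=100)
assert len(page) == 100 and total == 250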
@@ -113,20 +115,19 @@ async def get_collections(
async def create_collection(
collection_data: CollectionCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Create a new RAG collection"""
try:
rag_service = RAGService(db)
collection = await rag_service.create_collection(
name=collection_data.name,
description=collection_data.description
name=collection_data.name, description=collection_data.description
)
return {
"success": True,
"collection": collection.to_dict(),
"message": "Collection created successfully"
"message": "Collection created successfully",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -136,8 +137,7 @@ async def create_collection(
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)
):
"""Get overall RAG statistics - live data directly from Qdrant"""
try:
@@ -147,7 +147,11 @@ async def get_rag_stats(
stats_data = await qdrant_stats_service.get_collections_stats()
# Calculate active collections (collections with documents)
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
active_collections = sum(
1
for col in stats_data.get("collections", [])
if col.get("document_count", 0) > 0
)
# Calculate processing documents from database
processing_docs = 0
@@ -156,7 +160,9 @@ async def get_rag_stats(
from app.models.rag_document import RagDocument, ProcessingStatus
result = await db.execute(
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
select(RagDocument).where(
RagDocument.status == ProcessingStatus.PROCESSING
)
)
processing_docs = len(result.scalars().all())
except Exception:
@@ -167,22 +173,28 @@ async def get_rag_stats(
"stats": {
"collections": {
"total": stats_data.get("total_collections", 0),
"active": active_collections
"active": active_collections,
},
"documents": {
"total": stats_data.get("total_documents", 0),
"processing": processing_docs,
"processed": stats_data.get("total_documents", 0) # Indexed documents
"processed": stats_data.get(
"total_documents", 0
), # Indexed documents
},
"storage": {
"total_size_bytes": stats_data.get("total_size_bytes", 0),
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
"total_size_mb": round(
stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2
),
},
"vectors": {
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
"total": stats_data.get(
"total_documents", 0
) # Same as documents for RAG
},
"last_updated": datetime.utcnow().isoformat(),
},
"last_updated": datetime.utcnow().isoformat()
}
}
return response_data
@@ -194,7 +206,7 @@ async def get_rag_stats(
async def get_collection(
collection_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get a specific collection"""
try:
@@ -204,10 +216,7 @@ async def get_collection(
if not collection:
raise HTTPException(status_code=404, detail="Collection not found")
return {
"success": True,
"collection": collection.to_dict()
}
return {"success": True, "collection": collection.to_dict()}
except HTTPException:
raise
except Exception as e:
@@ -219,7 +228,7 @@ async def delete_collection(
collection_id: int,
cascade: bool = True, # Default to cascade deletion for better UX
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Delete a collection and optionally all its documents"""
try:
@@ -231,7 +240,8 @@ async def delete_collection(
return {
"success": True,
"message": "Collection deleted successfully" + (" (with documents)" if cascade else "")
"message": "Collection deleted successfully"
+ (" (with documents)" if cascade else ""),
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -243,13 +253,14 @@ async def delete_collection(
# Document Endpoints
@router.get("/documents", response_model=dict)
async def get_documents(
collection_id: Optional[str] = None,
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get documents, optionally filtered by collection"""
try:
@@ -260,11 +271,7 @@ async def get_documents(
if collection_id.startswith("ext_"):
# External collections exist only in Qdrant and have no documents in PostgreSQL
# Return empty list since they don't have managed documents
return {
"success": True,
"documents": [],
"total": 0
}
return {"success": True, "documents": [], "total": 0}
else:
# Try to convert to integer for managed collections
try:
@@ -272,29 +279,25 @@ async def get_documents(
except (ValueError, TypeError):
# Attempt to resolve by Qdrant collection name
collection_row = await db.scalar(
select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
select(RagCollection).where(
RagCollection.qdrant_collection_name == collection_id
)
)
if collection_row:
collection_id_int = collection_row.id
else:
# Unknown collection identifier; return empty result instead of erroring out
return {
"success": True,
"documents": [],
"total": 0
}
return {"success": True, "documents": [], "total": 0}
rag_service = RAGService(db)
documents = await rag_service.get_documents(
collection_id=collection_id_int,
skip=skip,
limit=limit
collection_id=collection_id_int, skip=skip, limit=limit
)
return {
"success": True,
"documents": [doc.to_dict() for doc in documents],
"total": len(documents)
"total": len(documents),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
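The branching above resolves three kinds of collection identifiers: "ext_"-prefixed names (external, Qdrant-only, no managed documents), numeric primary keys, and Qdrant collection names. Condensed into one helper for readability; the helper name, the None-for-unknown convention, and the RagCollection import path are mine:

from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.rag_collection import RagCollection  # assumed module path


async def resolve_collection_id(db: AsyncSession, identifier: str) -> Optional[int]:
    if identifier.startswith("ext_"):
        return None  # external collections have no documents in PostgreSQL
    try:
        return int(identifier)  # managed collections use integer primary keys
    except (ValueError, TypeError):
        pass
    row = await db.scalar(
        select(RagCollection).where(RagCollection.qdrant_collection_name == identifier)
    )
    return row.id if row else None  # None: unknown identifier, caller returns empty result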
@@ -305,13 +308,13 @@ async def upload_document(
collection_id: str = Form(...),
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Upload and process a document"""
try:
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
file_extension = filename.split(".")[-1].lower() if "." in filename else ""
# Read file content once and use it for all validations
file_content = await file.read()
@@ -324,50 +327,66 @@ async def upload_document(
try:
# Test file readability based on type
if file_extension == 'jsonl':
if file_extension == "jsonl":
# Validate JSONL format - try to parse first few lines
try:
content_str = file_content.decode('utf-8')
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
content_str = file_content.decode("utf-8")
lines = content_str.strip().split("\n")[:5] # Check first 5 lines
import json
for i, line in enumerate(lines):
if line.strip(): # Skip empty lines
json.loads(line) # Will raise JSONDecodeError if invalid
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
raise HTTPException(
status_code=400, detail="File is not valid UTF-8 text"
)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
raise HTTPException(
status_code=400, detail=f"Invalid JSONL format: {str(e)}"
)
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
elif file_extension in ["txt", "md", "py", "js", "html", "css", "json"]:
# Validate text files can be decoded
try:
file_content.decode('utf-8')
file_content.decode("utf-8")
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
raise HTTPException(
status_code=400, detail="File is not valid UTF-8 text"
)
elif file_extension in ['pdf']:
elif file_extension in ["pdf"]:
# For PDF files, just check if it starts with PDF signature
if not file_content.startswith(b'%PDF'):
raise HTTPException(status_code=400, detail="Invalid PDF file format")
if not file_content.startswith(b"%PDF"):
raise HTTPException(
status_code=400, detail="Invalid PDF file format"
)
elif file_extension in ['docx', 'xlsx', 'pptx']:
elif file_extension in ["docx", "xlsx", "pptx"]:
# For Office documents, check ZIP signature
if not file_content.startswith(b'PK'):
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
if not file_content.startswith(b"PK"):
raise HTTPException(
status_code=400,
detail=f"Invalid {file_extension.upper()} file format",
)
# For other file types, we'll rely on the document processor
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
raise HTTPException(
status_code=400, detail=f"File validation failed: {str(e)}"
)
rag_service = RAGService(db)
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
collection_identifier = (collection_id or "").strip()
if not collection_identifier:
raise HTTPException(status_code=400, detail="Collection identifier is required")
raise HTTPException(
status_code=400, detail="Collection identifier is required"
)
resolved_collection_id: Optional[int] = None
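The upload validation above is content sniffing by signature: a UTF-8 decode plus JSON parse for JSONL, a bare UTF-8 decode for text formats, the %PDF magic for PDFs, and the ZIP magic PK for OOXML containers. The same checks as a standalone, framework-free sketch (the function name is mine; it raises ValueError where the endpoint raises HTTPException):

import json


def validate_upload(content: bytes, extension: str) -> None:
    ext = extension.lower()
    if ext == "jsonl":
        # Parse only the first few lines; a bad row raises json.JSONDecodeError.
        for line in content.decode("utf-8").strip().split("\n")[:5]:
            if line.strip():
                json.loads(line)
    elif ext in {"txt", "md", "py", "js", "html", "css", "json"}:
        content.decode("utf-8")  # raises UnicodeDecodeError for binary data
    elif ext == "pdf":
        if not content.startswith(b"%PDF"):
            raise ValueError("Invalid PDF file format")
    elif ext in {"docx", "xlsx", "pptx"}:
        if not content.startswith(b"PK"):  # OOXML files are ZIP archives
            raise ValueError(f"Invalid {ext.upper()} file format")
    # Other extensions fall through to the document processor, as above.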
@@ -379,7 +398,9 @@ async def upload_document(
qdrant_name = qdrant_name[4:]
try:
collection_record = await rag_service.ensure_collection_record(qdrant_name)
collection_record = await rag_service.ensure_collection_record(
qdrant_name
)
except Exception as ensure_error:
raise HTTPException(status_code=500, detail=str(ensure_error))
@@ -392,13 +413,13 @@ async def upload_document(
collection_id=resolved_collection_id,
file_content=file_content,
filename=filename,
content_type=file.content_type
content_type=file.content_type,
)
return {
"success": True,
"document": document.to_dict(),
"message": "Document uploaded and processing started"
"message": "Document uploaded and processing started",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -412,7 +433,7 @@ async def upload_document(
async def get_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Get a specific document"""
try:
@@ -422,10 +443,7 @@ async def get_document(
if not document:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"document": document.to_dict()
}
return {"success": True, "document": document.to_dict()}
except HTTPException:
raise
except Exception as e:
@@ -436,7 +454,7 @@ async def get_document(
async def delete_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Delete a document"""
try:
@@ -446,10 +464,7 @@ async def delete_document(
if not success:
raise HTTPException(status_code=404, detail="Document not found")
return {
"success": True,
"message": "Document deleted successfully"
}
return {"success": True, "message": "Document deleted successfully"}
except HTTPException:
raise
except Exception as e:
@@ -460,7 +475,7 @@ async def delete_document(
async def reprocess_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Restart processing for a stuck or failed document"""
try:
@@ -475,12 +490,12 @@ async def reprocess_document(
else:
raise HTTPException(
status_code=400,
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed."
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed.",
)
return {
"success": True,
"message": "Document reprocessing started successfully"
"message": "Document reprocessing started successfully",
}
except APIException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -494,7 +509,7 @@ async def reprocess_document(
async def download_document(
document_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""Download the original document file"""
try:
@@ -502,14 +517,16 @@ async def download_document(
result = await rag_service.download_document(document_id)
if not result:
raise HTTPException(status_code=404, detail="Document not found or file not available")
raise HTTPException(
status_code=404, detail="Document not found or file not available"
)
content, filename, mime_type = result
return StreamingResponse(
io.BytesIO(content),
media_type=mime_type,
headers={"Content-Disposition": f"attachment; filename={filename}"}
headers={"Content-Disposition": f"attachment; filename={filename}"},
)
except HTTPException:
raise
@@ -517,9 +534,9 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e))
# Debug Endpoints
@router.post("/debug/search")
async def search_with_debug(
query: str,
@@ -527,13 +544,13 @@ async def search_with_debug(
score_threshold: float = 0.3,
collection_name: str = None,
config: Dict[str, Any] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Enhanced search with comprehensive debug information
"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
rag_module = module_manager.modules.get("rag")
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
@@ -567,7 +584,7 @@ async def search_with_debug(
query,
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name
collection_name=collection_name,
)
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
@@ -575,22 +592,23 @@ async def search_with_debug(
scores = [r.score for r in results if r.score is not None]
if scores:
import statistics
debug_info["score_stats"] = {
"min": min(scores),
"max": max(scores),
"avg": statistics.mean(scores),
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0,
}
# Get collection statistics
try:
from qdrant_client.http.models import Filter
collection_name = collection_name or rag_module.default_collection_name
# Count total documents
count_result = rag_module.qdrant_client.count(
collection_name=collection_name,
count_filter=Filter(must=[])
collection_name=collection_name, count_filter=Filter(must=[])
)
total_points = count_result.count
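The score_stats block a few lines above is a straightforward use of the standard statistics module; isolated, it looks like this (the None filter mirrors the comprehension over result scores):

import statistics
from typing import Dict, List, Optional


def score_stats(scores: List[Optional[float]]) -> Dict[str, float]:
    values = [s for s in scores if s is not None]
    if not values:
        return {}
    return {
        "min": min(values),
        "max": max(values),
        "avg": statistics.mean(values),
        # stdev needs at least two samples, hence the guard.
        "stddev": statistics.stdev(values) if len(values) > 1 else 0,
    }


print(score_stats([0.92, 0.74, None, 0.61]))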
@@ -599,7 +617,7 @@ async def search_with_debug(
collection_name=collection_name,
limit=1000, # Sample for stats
with_payload=True,
with_vectors=False
with_vectors=False,
)
unique_docs = set()
@@ -618,7 +636,7 @@ async def search_with_debug(
debug_info["collection_stats"] = {
"total_documents": len(unique_docs),
"total_chunks": total_points,
"languages": sorted(list(languages))
"languages": sorted(list(languages)),
}
except Exception as e:
@@ -631,16 +649,18 @@ async def search_with_debug(
"document": {
"id": result.document.id,
"content": result.document.content,
"metadata": result.document.metadata
"metadata": result.document.metadata,
},
"score": result.score,
"debug_info": {}
"debug_info": {},
}
# Add hybrid search debug info if available
metadata = result.document.metadata or {}
if "_vector_score" in metadata:
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
enhanced_result["debug_info"]["vector_score"] = metadata[
"_vector_score"
]
if "_bm25_score" in metadata:
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
@@ -652,7 +672,7 @@ async def search_with_debug(
"results": enhanced_results,
"debug_info": debug_info,
"search_time_ms": search_time,
"timestamp": start_time.isoformat()
"timestamp": start_time.isoformat(),
}
except Exception as e:
@@ -661,17 +681,17 @@ async def search_with_debug(
finally:
# Restore original config if modified
if config and 'original_config' in locals():
if config and "original_config" in locals():
rag_module.config = original_config
@router.get("/debug/config")
async def get_current_config(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
"""Get current RAG configuration"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
rag_module = module_manager.modules.get("rag")
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
@@ -679,5 +699,5 @@ async def get_current_config(
"config": rag_module.config,
"embedding_model": rag_module.embedding_model,
"enabled": rag_module.enabled,
"collections": await rag_module._get_collections_safely()
"collections": await rag_module._get_collections_safely(),
}

@@ -88,71 +88,255 @@ class SecurityConfigResponse(BaseModel):
# Global settings storage (in a real app, this would be in database)
SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"platform": {
"app_name": {"value": "Confidential Empire", "type": "string", "description": "Application name"},
"maintenance_mode": {"value": False, "type": "boolean", "description": "Enable maintenance mode"},
"maintenance_message": {"value": None, "type": "string", "description": "Maintenance mode message"},
"debug_mode": {"value": False, "type": "boolean", "description": "Enable debug mode"},
"max_upload_size": {"value": 10485760, "type": "integer", "description": "Maximum upload size in bytes"},
"app_name": {
"value": "Confidential Empire",
"type": "string",
"description": "Application name",
},
"maintenance_mode": {
"value": False,
"type": "boolean",
"description": "Enable maintenance mode",
},
"maintenance_message": {
"value": None,
"type": "string",
"description": "Maintenance mode message",
},
"debug_mode": {
"value": False,
"type": "boolean",
"description": "Enable debug mode",
},
"max_upload_size": {
"value": 10485760,
"type": "integer",
"description": "Maximum upload size in bytes",
},
},
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
"security_headers_enabled": {"value": True, "type": "boolean", "description": "Enable security headers"},
"security_enabled": {
"value": True,
"type": "boolean",
"description": "Enable API security system",
},
"rate_limiting_enabled": {
"value": True,
"type": "boolean",
"description": "Enable rate limiting",
},
"ip_reputation_enabled": {
"value": True,
"type": "boolean",
"description": "Enable IP reputation checking",
},
"anomaly_detection_enabled": {
"value": True,
"type": "boolean",
"description": "Enable anomaly detection",
},
"security_headers_enabled": {
"value": True,
"type": "boolean",
"description": "Enable security headers",
},
# Rate Limiting by Authentication Level
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer", "description": "Rate limit for authenticated users per minute"},
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer", "description": "Rate limit for authenticated users per hour"},
"rate_limit_api_key_per_minute": {"value": 1000, "type": "integer", "description": "Rate limit for API key users per minute"},
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer", "description": "Rate limit for API key users per hour"},
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer", "description": "Rate limit for premium users per minute"},
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
"rate_limit_authenticated_per_minute": {
"value": 200,
"type": "integer",
"description": "Rate limit for authenticated users per minute",
},
"rate_limit_authenticated_per_hour": {
"value": 5000,
"type": "integer",
"description": "Rate limit for authenticated users per hour",
},
"rate_limit_api_key_per_minute": {
"value": 1000,
"type": "integer",
"description": "Rate limit for API key users per minute",
},
"rate_limit_api_key_per_hour": {
"value": 20000,
"type": "integer",
"description": "Rate limit for API key users per hour",
},
"rate_limit_premium_per_minute": {
"value": 5000,
"type": "integer",
"description": "Rate limit for premium users per minute",
},
"rate_limit_premium_per_hour": {
"value": 100000,
"type": "integer",
"description": "Rate limit for premium users per hour",
},
# Security Thresholds
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
"security_warning_threshold": {
"value": 0.6,
"type": "float",
"description": "Risk score threshold for warnings (0.0-1.0)",
},
"anomaly_threshold": {
"value": 0.7,
"type": "float",
"description": "Anomaly severity threshold (0.0-1.0)",
},
# Request Settings
"max_request_size_mb": {"value": 10, "type": "integer", "description": "Maximum request size in MB for standard users"},
"max_request_size_premium_mb": {"value": 50, "type": "integer", "description": "Maximum request size in MB for premium users"},
"enable_cors": {"value": True, "type": "boolean", "description": "Enable CORS headers"},
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list", "description": "Allowed CORS origins"},
"api_key_expiry_days": {"value": 90, "type": "integer", "description": "Default API key expiry in days"},
"max_request_size_mb": {
"value": 10,
"type": "integer",
"description": "Maximum request size in MB for standard users",
},
"max_request_size_premium_mb": {
"value": 50,
"type": "integer",
"description": "Maximum request size in MB for premium users",
},
"enable_cors": {
"value": True,
"type": "boolean",
"description": "Enable CORS headers",
},
"cors_origins": {
"value": ["http://localhost:3000", "http://localhost:53000"],
"type": "list",
"description": "Allowed CORS origins",
},
"api_key_expiry_days": {
"value": 90,
"type": "integer",
"description": "Default API key expiry in days",
},
# IP Security
"blocked_ips": {"value": [], "type": "list", "description": "List of blocked IP addresses"},
"allowed_ips": {"value": [], "type": "list", "description": "List of allowed IP addresses (empty = allow all)"},
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string", "description": "Content Security Policy header"},
"blocked_ips": {
"value": [],
"type": "list",
"description": "List of blocked IP addresses",
},
"allowed_ips": {
"value": [],
"type": "list",
"description": "List of allowed IP addresses (empty = allow all)",
},
"csp_header": {
"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';",
"type": "string",
"description": "Content Security Policy header",
},
},
"security": {
"password_min_length": {"value": 8, "type": "integer", "description": "Minimum password length"},
"password_require_special": {"value": True, "type": "boolean", "description": "Require special characters in passwords"},
"password_require_numbers": {"value": True, "type": "boolean", "description": "Require numbers in passwords"},
"password_require_uppercase": {"value": True, "type": "boolean", "description": "Require uppercase letters in passwords"},
"max_login_attempts": {"value": 5, "type": "integer", "description": "Maximum login attempts before lockout"},
"lockout_duration_minutes": {"value": 15, "type": "integer", "description": "Account lockout duration in minutes"},
"require_2fa": {"value": False, "type": "boolean", "description": "Require two-factor authentication"},
"ip_whitelist_enabled": {"value": False, "type": "boolean", "description": "Enable IP whitelist"},
"allowed_domains": {"value": [], "type": "list", "description": "Allowed email domains for registration"},
"password_min_length": {
"value": 8,
"type": "integer",
"description": "Minimum password length",
},
"password_require_special": {
"value": True,
"type": "boolean",
"description": "Require special characters in passwords",
},
"password_require_numbers": {
"value": True,
"type": "boolean",
"description": "Require numbers in passwords",
},
"password_require_uppercase": {
"value": True,
"type": "boolean",
"description": "Require uppercase letters in passwords",
},
"max_login_attempts": {
"value": 5,
"type": "integer",
"description": "Maximum login attempts before lockout",
},
"lockout_duration_minutes": {
"value": 15,
"type": "integer",
"description": "Account lockout duration in minutes",
},
"require_2fa": {
"value": False,
"type": "boolean",
"description": "Require two-factor authentication",
},
"ip_whitelist_enabled": {
"value": False,
"type": "boolean",
"description": "Enable IP whitelist",
},
"allowed_domains": {
"value": [],
"type": "list",
"description": "Allowed email domains for registration",
},
},
"features": {
"user_registration": {"value": True, "type": "boolean", "description": "Allow user registration"},
"api_key_creation": {"value": True, "type": "boolean", "description": "Allow API key creation"},
"budget_enforcement": {"value": True, "type": "boolean", "description": "Enable budget enforcement"},
"audit_logging": {"value": True, "type": "boolean", "description": "Enable audit logging"},
"module_hot_reload": {"value": True, "type": "boolean", "description": "Enable module hot reload"},
"tee_support": {"value": True, "type": "boolean", "description": "Enable TEE (Trusted Execution Environment) support"},
"advanced_analytics": {"value": True, "type": "boolean", "description": "Enable advanced analytics"},
"user_registration": {
"value": True,
"type": "boolean",
"description": "Allow user registration",
},
"api_key_creation": {
"value": True,
"type": "boolean",
"description": "Allow API key creation",
},
"budget_enforcement": {
"value": True,
"type": "boolean",
"description": "Enable budget enforcement",
},
"audit_logging": {
"value": True,
"type": "boolean",
"description": "Enable audit logging",
},
"module_hot_reload": {
"value": True,
"type": "boolean",
"description": "Enable module hot reload",
},
"tee_support": {
"value": True,
"type": "boolean",
"description": "Enable TEE (Trusted Execution Environment) support",
},
"advanced_analytics": {
"value": True,
"type": "boolean",
"description": "Enable advanced analytics",
},
},
"notifications": {
"email_enabled": {"value": False, "type": "boolean", "description": "Enable email notifications"},
"slack_enabled": {"value": False, "type": "boolean", "description": "Enable Slack notifications"},
"webhook_enabled": {"value": False, "type": "boolean", "description": "Enable webhook notifications"},
"budget_alerts": {"value": True, "type": "boolean", "description": "Enable budget alert notifications"},
"security_alerts": {"value": True, "type": "boolean", "description": "Enable security alert notifications"},
}
"email_enabled": {
"value": False,
"type": "boolean",
"description": "Enable email notifications",
},
"slack_enabled": {
"value": False,
"type": "boolean",
"description": "Enable Slack notifications",
},
"webhook_enabled": {
"value": False,
"type": "boolean",
"description": "Enable webhook notifications",
},
"budget_alerts": {
"value": True,
"type": "boolean",
"description": "Enable budget alert notifications",
},
"security_alerts": {
"value": True,
"type": "boolean",
"description": "Enable security alert notifications",
},
},
}
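Every read of SETTINGS_STORE in the endpoints below chains two dict lookups plus a default, because each setting is a {value, type, description} record rather than a bare value. A small accessor sketch (the helper is illustrative; the handlers above and below inline the pattern):

from typing import Any


def get_setting_value(store: dict, category: str, key: str, default: Any = None) -> Any:
    # Equivalent to store.get("platform", {}).get("max_upload_size", {}).get("value", ...)
    return store.get(category, {}).get(key, {}).get("value", default)


max_upload = get_setting_value(SETTINGS_STORE, "platform", "max_upload_size", 10485760)
cors_on = get_setting_value(SETTINGS_STORE, "api", "enable_cors", True)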
@@ -162,7 +346,7 @@ async def list_settings(
category: Optional[str] = None,
include_secrets: bool = False,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List all settings or settings in a specific category"""
@@ -179,23 +363,26 @@ async def list_settings(
for key, setting in settings.items():
# Hide secret values unless specifically requested and user has permission
if setting.get("is_secret", False) and not include_secrets:
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
if not any(
perm in current_user.get("permissions", [])
for perm in ["platform:settings:admin", "platform:*"]
):
continue
result[cat][key] = {
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
"is_secret": setting.get("is_secret", False),
}
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_settings",
resource_type="setting",
details={"category": category, "include_secrets": include_secrets}
details={"category": category, "include_secrets": include_secrets},
)
return result
@@ -203,8 +390,7 @@ async def list_settings(
@router.get("/system-info", response_model=SystemInfoResponse)
async def get_system_info(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get system information and status"""
@@ -228,6 +414,7 @@ async def get_system_info(
# Get LLM service status
try:
from app.services.llm.service import llm_service
health_summary = llm_service.get_health_summary()
llm_service_status = health_summary.get("service_status", "unknown")
except Exception:
@@ -238,6 +425,7 @@ async def get_system_info(
# Get active users count (last 24 hours)
from datetime import datetime, timedelta
yesterday = datetime.utcnow() - timedelta(days=1)
active_users_query = select(User.id).where(User.last_login >= yesterday)
active_users_result = await db.execute(active_users_query)
@@ -254,9 +442,9 @@ async def get_system_info(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_system_info",
resource_type="system"
resource_type="system",
)
return SystemInfoResponse(
@@ -268,14 +456,13 @@ async def get_system_info(
modules_loaded=modules_loaded,
active_users=active_users,
total_api_keys=total_api_keys,
uptime_seconds=uptime_seconds
uptime_seconds=uptime_seconds,
)
@router.get("/platform-config", response_model=PlatformConfigResponse)
async def get_platform_config(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get platform configuration"""
@@ -289,24 +476,33 @@ async def get_platform_config(
api_settings = SETTINGS_STORE.get("api", {})
return PlatformConfigResponse(
app_name=platform_settings.get("app_name", {}).get("value", "Confidential Empire"),
app_name=platform_settings.get("app_name", {}).get(
"value", "Confidential Empire"
),
debug_mode=platform_settings.get("debug_mode", {}).get("value", False),
log_level=app_settings.LOG_LEVEL,
cors_origins=app_settings.CORS_ORIGINS,
rate_limiting_enabled=api_settings.get("rate_limiting_enabled", {}).get("value", True),
max_upload_size=platform_settings.get("max_upload_size", {}).get("value", 10485760),
rate_limiting_enabled=api_settings.get("rate_limiting_enabled", {}).get(
"value", True
),
max_upload_size=platform_settings.get("max_upload_size", {}).get(
"value", 10485760
),
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
api_key_prefix=app_settings.API_KEY_PREFIX,
features=features,
maintenance_mode=platform_settings.get("maintenance_mode", {}).get("value", False),
maintenance_message=platform_settings.get("maintenance_message", {}).get("value")
maintenance_mode=platform_settings.get("maintenance_mode", {}).get(
"value", False
),
maintenance_message=platform_settings.get("maintenance_message", {}).get(
"value"
),
)
@router.get("/security-config", response_model=SecurityConfigResponse)
async def get_security_config(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Get security configuration"""
@@ -316,16 +512,30 @@ async def get_security_config(
security_settings = SETTINGS_STORE.get("security", {})
return SecurityConfigResponse(
password_min_length=security_settings.get("password_min_length", {}).get("value", 8),
password_require_special=security_settings.get("password_require_special", {}).get("value", True),
password_require_numbers=security_settings.get("password_require_numbers", {}).get("value", True),
password_require_uppercase=security_settings.get("password_require_uppercase", {}).get("value", True),
password_min_length=security_settings.get("password_min_length", {}).get(
"value", 8
),
password_require_special=security_settings.get(
"password_require_special", {}
).get("value", True),
password_require_numbers=security_settings.get(
"password_require_numbers", {}
).get("value", True),
password_require_uppercase=security_settings.get(
"password_require_uppercase", {}
).get("value", True),
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
max_login_attempts=security_settings.get("max_login_attempts", {}).get("value", 5),
lockout_duration_minutes=security_settings.get("lockout_duration_minutes", {}).get("value", 15),
max_login_attempts=security_settings.get("max_login_attempts", {}).get(
"value", 5
),
lockout_duration_minutes=security_settings.get(
"lockout_duration_minutes", {}
).get("value", 15),
require_2fa=security_settings.get("require_2fa", {}).get("value", False),
allowed_domains=security_settings.get("allowed_domains", {}).get("value", []),
ip_whitelist_enabled=security_settings.get("ip_whitelist_enabled", {}).get("value", False)
ip_whitelist_enabled=security_settings.get("ip_whitelist_enabled", {}).get(
"value", False
),
)
@@ -334,7 +544,7 @@ async def get_setting(
category: str,
key: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get a specific setting value"""
@@ -344,28 +554,30 @@ async def get_setting(
if category not in SETTINGS_STORE:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Settings category '{category}' not found"
detail=f"Settings category '{category}' not found",
)
if key not in SETTINGS_STORE[category]:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Setting '{key}' not found in category '{category}'"
detail=f"Setting '{key}' not found in category '{category}'",
)
setting = SETTINGS_STORE[category][key]
# Check if it's a secret setting
if setting.get("is_secret", False):
require_permission(current_user.get("permissions", []), "platform:settings:admin")
require_permission(
current_user.get("permissions", []), "platform:settings:admin"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_setting",
resource_type="setting",
resource_id=f"{category}.{key}"
resource_id=f"{category}.{key}",
)
return {
@@ -374,7 +586,7 @@ async def get_setting(
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
"is_secret": setting.get("is_secret", False),
}
@@ -383,7 +595,7 @@ async def update_category_settings(
category: str,
settings_data: Dict[str, Any],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update multiple settings in a category"""
@@ -393,7 +605,7 @@ async def update_category_settings(
if category not in SETTINGS_STORE:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Settings category '{category}' not found"
detail=f"Settings category '{category}' not found",
)
updated_settings = []
@@ -408,7 +620,9 @@ async def update_category_settings(
# Check if it's a secret setting
if setting.get("is_secret", False):
require_permission(current_user.get("permissions", []), "platform:settings:admin")
require_permission(
current_user.get("permissions", []), "platform:settings:admin"
)
# Store original value for audit
original_value = setting["value"]
@@ -425,7 +639,7 @@ async def update_category_settings(
continue
elif expected_type == "boolean" and not isinstance(new_value, bool):
if isinstance(new_value, str):
new_value = new_value.lower() in ('true', '1', 'yes', 'on')
new_value = new_value.lower() in ("true", "1", "yes", "on")
else:
errors.append(f"Setting '{key}' expects a boolean value")
continue
@@ -445,11 +659,9 @@ async def update_category_settings(
# Update setting
SETTINGS_STORE[category][key]["value"] = new_value
updated_settings.append({
"key": key,
"original_value": original_value,
"new_value": new_value
})
updated_settings.append(
{"key": key, "original_value": original_value, "new_value": new_value}
)
except Exception as e:
errors.append(f"Error updating setting '{key}': {str(e)}")
@@ -457,7 +669,7 @@ async def update_category_settings(
# Log audit event for bulk update
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="bulk_update_settings",
resource_type="setting",
resource_id=category,
@@ -465,25 +677,29 @@ async def update_category_settings(
"updated_count": len(updated_settings),
"errors_count": len(errors),
"updated_settings": updated_settings,
"errors": errors
}
"errors": errors,
},
)
logger.info(f"Bulk settings updated in category '{category}': {len(updated_settings)} settings by {current_user['username']}")
logger.info(
f"Bulk settings updated in category '{category}': {len(updated_settings)} settings by {current_user['username']}"
)
if errors and not updated_settings:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No settings were updated. Errors: {errors}"
detail=f"No settings were updated. Errors: {errors}",
)
return {
"category": category,
"updated_count": len(updated_settings),
"errors_count": len(errors),
"updated_settings": [{"key": s["key"], "new_value": s["new_value"]} for s in updated_settings],
"updated_settings": [
{"key": s["key"], "new_value": s["new_value"]} for s in updated_settings
],
"errors": errors,
"message": f"Updated {len(updated_settings)} settings in category '{category}'"
"message": f"Updated {len(updated_settings)} settings in category '{category}'",
}
@@ -493,7 +709,7 @@ async def update_setting(
key: str,
setting_update: SettingUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update a specific setting"""
@@ -503,20 +719,22 @@ async def update_setting(
if category not in SETTINGS_STORE:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Settings category '{category}' not found"
detail=f"Settings category '{category}' not found",
)
if key not in SETTINGS_STORE[category]:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Setting '{key}' not found in category '{category}'"
detail=f"Setting '{key}' not found in category '{category}'",
)
setting = SETTINGS_STORE[category][key]
# Check if it's a secret setting
if setting.get("is_secret", False):
require_permission(current_user.get("permissions", []), "platform:settings:admin")
require_permission(
current_user.get("permissions", []), "platform:settings:admin"
)
# Store original value for audit
original_value = setting["value"]
@@ -528,22 +746,22 @@ async def update_setting(
if expected_type == "integer" and not isinstance(new_value, int):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects an integer value"
detail=f"Setting '{key}' expects an integer value",
)
elif expected_type == "boolean" and not isinstance(new_value, bool):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a boolean value"
detail=f"Setting '{key}' expects a boolean value",
)
elif expected_type == "float" and not isinstance(new_value, (int, float)):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a numeric value"
detail=f"Setting '{key}' expects a numeric value",
)
elif expected_type == "list" and not isinstance(new_value, list):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Setting '{key}' expects a list value"
detail=f"Setting '{key}' expects a list value",
)
# Update setting
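The chain of isinstance checks above, together with the string-to-boolean coercion in the bulk endpoint, amounts to a small per-type validator. Here is a sketch under the assumption that only the four declared types (integer, boolean, float, list) need handling; note the bool-excluding integer check is a tightening of my own, since the original would accept True for an integer setting (bool subclasses int):

from typing import Any


def coerce_setting(expected_type: str, value: Any) -> Any:
    if expected_type == "integer":
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError("expects an integer value")
    elif expected_type == "boolean":
        if isinstance(value, str):
            # Bulk updates coerce strings; the single-setting endpoint rejects them.
            return value.lower() in ("true", "1", "yes", "on")
        if not isinstance(value, bool):
            raise ValueError("expects a boolean value")
    elif expected_type == "float":
        if not isinstance(value, (int, float)):
            raise ValueError("expects a numeric value")
    elif expected_type == "list":
        if not isinstance(value, list):
            raise ValueError("expects a list value")
    return value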
@@ -554,15 +772,15 @@ async def update_setting(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_setting",
resource_type="setting",
resource_id=f"{category}.{key}",
details={
"original_value": original_value,
"new_value": new_value,
"description_updated": setting_update.description is not None
}
"description_updated": setting_update.description is not None,
},
)
logger.info(f"Setting updated: {category}.{key} by {current_user['username']}")
@@ -573,7 +791,7 @@ async def update_setting(
"value": new_value,
"type": expected_type,
"description": SETTINGS_STORE[category][key].get("description", ""),
"message": "Setting updated successfully"
"message": "Setting updated successfully",
}
@@ -581,7 +799,7 @@ async def update_setting(
async def reset_to_defaults(
category: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Reset settings to default values"""
@@ -603,7 +821,6 @@ async def reset_to_defaults(
"ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
"security_headers_enabled": {"value": True, "type": "boolean"},
# Rate Limiting by Authentication Level
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer"},
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer"},
@@ -611,22 +828,25 @@ async def reset_to_defaults(
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer"},
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer"},
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds
"security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"},
# Request Settings
"max_request_size_mb": {"value": 10, "type": "integer"},
"max_request_size_premium_mb": {"value": 50, "type": "integer"},
"enable_cors": {"value": True, "type": "boolean"},
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list"},
"cors_origins": {
"value": ["http://localhost:3000", "http://localhost:53000"],
"type": "list",
},
"api_key_expiry_days": {"value": 90, "type": "integer"},
# IP Security
"blocked_ips": {"value": [], "type": "list"},
"allowed_ips": {"value": [], "type": "list"},
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string"},
"csp_header": {
"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';",
"type": "string",
},
},
"security": {
"password_min_length": {"value": 8, "type": "integer"},
@@ -647,7 +867,7 @@ async def reset_to_defaults(
"module_hot_reload": {"value": True, "type": "boolean"},
"tee_support": {"value": True, "type": "boolean"},
"advanced_analytics": {"value": True, "type": "boolean"},
}
},
}
reset_categories = [category] if category else list(defaults.keys())
@@ -661,24 +881,25 @@ async def reset_to_defaults(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="reset_settings_to_defaults",
resource_type="setting",
details={"categories_reset": reset_categories}
details={"categories_reset": reset_categories},
)
logger.info(f"Settings reset to defaults: {reset_categories} by {current_user['username']}")
logger.info(
f"Settings reset to defaults: {reset_categories} by {current_user['username']}"
)
return {
"message": f"Settings reset to defaults for categories: {reset_categories}",
"categories_reset": reset_categories
"categories_reset": reset_categories,
}
@router.post("/export")
async def export_settings(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
"""Export all settings to JSON"""
@@ -693,29 +914,32 @@ async def export_settings(
for key, setting in settings.items():
# Skip secret settings for non-admin users
if setting.get("is_secret", False):
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
if not any(
perm in current_user.get("permissions", [])
for perm in ["platform:settings:admin", "platform:*"]
):
continue
export_data[category][key] = {
"value": setting["value"],
"type": setting["type"],
"description": setting.get("description", ""),
"is_secret": setting.get("is_secret", False)
"is_secret": setting.get("is_secret", False),
}
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="export_settings",
resource_type="setting",
details={"categories_exported": list(export_data.keys())}
details={"categories_exported": list(export_data.keys())},
)
return {
"settings": export_data,
"exported_at": datetime.utcnow().isoformat(),
"exported_by": current_user['username']
"exported_by": current_user["username"],
}
@@ -723,7 +947,7 @@ async def export_settings(
async def import_settings(
settings_data: Dict[str, Dict[str, Dict[str, Any]]],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Import settings from JSON"""
@@ -750,15 +974,21 @@ async def import_settings(
# Basic type validation
if expected_type == "integer" and not isinstance(new_value, int):
errors.append(f"Invalid type for {category}.{key}: expected integer")
errors.append(
f"Invalid type for {category}.{key}: expected integer"
)
continue
elif expected_type == "boolean" and not isinstance(new_value, bool):
errors.append(f"Invalid type for {category}.{key}: expected boolean")
errors.append(
f"Invalid type for {category}.{key}: expected boolean"
)
continue
SETTINGS_STORE[category][key]["value"] = new_value
if "description" in setting_data:
SETTINGS_STORE[category][key]["description"] = setting_data["description"]
SETTINGS_STORE[category][key]["description"] = setting_data[
"description"
]
imported_count += 1
@@ -768,20 +998,22 @@ async def import_settings(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="import_settings",
resource_type="setting",
details={
"imported_count": imported_count,
"errors_count": len(errors),
"errors": errors
}
"errors": errors,
},
)
logger.info(f"Settings imported: {imported_count} settings by {current_user['username']}")
logger.info(
f"Settings imported: {imported_count} settings by {current_user['username']}"
)
return {
"message": f"Import completed. {imported_count} settings imported.",
"imported_count": imported_count,
"errors": errors
"errors": errors,
}
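Export and import share the {category: {key: {value, type, ...}}} shape, so the "settings" mapping returned by /export is valid input for /import. Because every stored value is JSON-serializable (booleans, numbers, strings, lists, None), a serialize/deserialize round trip should be lossless; a quick sketch of that invariant:

import copy
import json


def round_trip(settings: dict) -> dict:
    exported = json.dumps(settings)   # what /export emits under "settings"
    return json.loads(exported)       # what /import consumes as its body


snapshot = copy.deepcopy(SETTINGS_STORE)
assert round_trip(snapshot) == snapshot  # holds for JSON-serializable values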
@@ -85,7 +85,7 @@ async def list_users(
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""List all users with pagination and filtering"""
@@ -102,9 +102,9 @@ async def list_users(
query = query.where(User.is_active == is_active)
if search:
query = query.where(
(User.username.ilike(f"%{search}%")) |
(User.email.ilike(f"%{search}%")) |
(User.full_name.ilike(f"%{search}%"))
(User.username.ilike(f"%{search}%"))
| (User.email.ilike(f"%{search}%"))
| (User.full_name.ilike(f"%{search}%"))
)
# Get total count
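The reflowed filter above ORs three case-insensitive ilike matches with the | operator; the same predicate reads a little more explicitly with sqlalchemy.or_, which also scales as columns are added. A sketch against the same User model:

from sqlalchemy import or_, select


def user_search_query(search: str):
    # User is the model imported by this module, as in the handler above.
    pattern = f"%{search}%"
    return select(User).where(
        or_(
            User.username.ilike(pattern),
            User.email.ilike(pattern),
            User.full_name.ilike(pattern),
        )
    )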
@@ -123,17 +123,21 @@ async def list_users(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="list_users",
resource_type="user",
details={"page": page, "size": size, "filters": {"role": role, "is_active": is_active, "search": search}}
details={
"page": page,
"size": size,
"filters": {"role": role, "is_active": is_active, "search": search},
},
)
return UserListResponse(
users=[UserResponse.model_validate(user) for user in users],
total=total,
page=page,
size=size
size=size,
)
@@ -141,12 +145,12 @@ async def list_users(
async def get_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get user by ID"""
# Check permissions (users can view their own profile)
if int(user_id) != current_user['id']:
if int(user_id) != current_user["id"]:
require_permission(current_user.get("permissions", []), "platform:users:read")
# Get user
@@ -156,17 +160,16 @@ async def get_user(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="get_user",
resource_type="user",
resource_id=user_id
resource_id=user_id,
)
return UserResponse.model_validate(user)
@@ -176,7 +179,7 @@ async def get_user(
async def create_user(
user_data: UserCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Create a new user"""
@@ -193,7 +196,7 @@ async def create_user(
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User with this username or email already exists"
detail="User with this username or email already exists",
)
# Create user
@@ -204,7 +207,7 @@ async def create_user(
full_name=user_data.full_name,
hashed_password=hashed_password,
role=user_data.role,
is_active=user_data.is_active
is_active=user_data.is_active,
)
db.add(new_user)
@@ -214,11 +217,15 @@ async def create_user(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="create_user",
resource_type="user",
resource_id=str(new_user.id),
details={"username": user_data.username, "email": user_data.email, "role": user_data.role}
details={
"username": user_data.username,
"email": user_data.email,
"role": user_data.role,
},
)
logger.info(f"User created: {new_user.username} by {current_user['username']}")
@@ -231,12 +238,12 @@ async def update_user(
user_id: str,
user_data: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Update user"""
# Check permissions (users can update their own profile with limited fields)
is_self_update = int(user_id) == current_user['id']
is_self_update = int(user_id) == current_user["id"]
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
@@ -247,8 +254,7 @@ async def update_user(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# For self-updates, restrict what can be changed
@@ -259,7 +265,7 @@ async def update_user(
if restricted_fields:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Cannot update fields: {restricted_fields}"
detail=f"Cannot update fields: {restricted_fields}",
)
# Store original values for audit
@@ -267,7 +273,7 @@ async def update_user(
"username": user.username,
"email": user.email,
"role": user.role,
"is_active": user.is_active
"is_active": user.is_active,
}
# Update user fields
@@ -281,15 +287,15 @@ async def update_user(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="update_user",
resource_type="user",
resource_id=user_id,
details={
"updated_fields": list(update_data.keys()),
"before_values": original_values,
"after_values": {k: getattr(user, k) for k in update_data.keys()}
}
"after_values": {k: getattr(user, k) for k in update_data.keys()},
},
)
logger.info(f"User updated: {user.username} by {current_user['username']}")
@@ -301,7 +307,7 @@ async def update_user(
async def delete_user(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Delete user (soft delete by deactivating)"""
@@ -309,10 +315,10 @@ async def delete_user(
require_permission(current_user.get("permissions", []), "platform:users:delete")
# Prevent self-deletion
if int(user_id) == current_user['id']:
if int(user_id) == current_user["id"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete your own account"
detail="Cannot delete your own account",
)
# Get user
@@ -322,8 +328,7 @@ async def delete_user(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Soft delete by deactivating
@@ -333,11 +338,11 @@ async def delete_user(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="delete_user",
resource_type="user",
resource_id=user_id,
details={"username": user.username, "email": user.email}
details={"username": user.username, "email": user.email},
)
logger.info(f"User deleted: {user.username} by {current_user['username']}")
@@ -350,12 +355,12 @@ async def change_password(
user_id: str,
password_data: PasswordChangeRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Change user password"""
# Users can only change their own password, or admins can change any password
is_self_update = int(user_id) == current_user['id']
is_self_update = int(user_id) == current_user["id"]
if not is_self_update:
require_permission(current_user.get("permissions", []), "platform:users:update")
@@ -366,8 +371,7 @@ async def change_password(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# For self-updates, verify current password
@@ -375,7 +379,7 @@ async def change_password(
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect"
detail="Current password is incorrect",
)
# Update password
@@ -385,14 +389,16 @@ async def change_password(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="change_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
details={"target_user": user.username},
)
logger.info(f"Password changed for user: {user.username} by {current_user['username']}")
logger.info(
f"Password changed for user: {user.username} by {current_user['username']}"
)
return {"message": "Password changed successfully"}
@@ -402,7 +408,7 @@ async def reset_password(
user_id: str,
password_data: PasswordResetRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Reset user password (admin only)"""
@@ -416,8 +422,7 @@ async def reset_password(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# Reset password
@@ -427,14 +432,16 @@ async def reset_password(
# Log audit event
await log_audit_event(
db=db,
user_id=current_user['id'],
user_id=current_user["id"],
action="reset_password",
resource_type="user",
resource_id=user_id,
details={"target_user": user.username}
details={"target_user": user.username},
)
logger.info(f"Password reset for user: {user.username} by {current_user['username']}")
logger.info(
f"Password reset for user: {user.username} by {current_user['username']}"
)
return {"message": "Password reset successfully"}
@@ -443,14 +450,16 @@ async def reset_password(
async def get_user_api_keys(
user_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get API keys for a user"""
# Check permissions (users can view their own API keys)
is_self_request = int(user_id) == current_user['id']
is_self_request = int(user_id) == current_user["id"]
if not is_self_request:
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
require_permission(
current_user.get("permissions", []), "platform:api-keys:read"
)
# Get API keys
query = select(APIKey).where(APIKey.user_id == int(user_id))
@@ -466,8 +475,12 @@ async def get_user_api_keys(
"scopes": api_key.scopes,
"is_active": api_key.is_active,
"created_at": api_key.created_at.isoformat(),
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None
"expires_at": api_key.expires_at.isoformat()
if api_key.expires_at
else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
}
for api_key in api_keys
]

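The recurring guard in these handlers — let callers act on their own record, otherwise demand a platform permission — condenses to the sketch below. The real require_permission helper is imported elsewhere and not shown in this diff, so the wildcard handling here is an assumption.

from fastapi import HTTPException, status

def ensure_self_or_permission(current_user: dict, user_id: str, permission: str) -> None:
    # Self-access short-circuits the check, mirroring get_user/update_user above.
    if int(user_id) == current_user["id"]:
        return
    perms = current_user.get("permissions", [])
    # Assumed wildcard semantics: "platform:*" grants everything.
    if permission in perms or "platform:*" in perms:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Permission '{permission}' required",
    )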
View File

@@ -24,18 +24,13 @@ class CoreCacheService:
self.redis_pool: Optional[ConnectionPool] = None
self.redis_client: Optional[Redis] = None
self.enabled = False
self.stats = {
"hits": 0,
"misses": 0,
"errors": 0,
"total_requests": 0
}
self.stats = {"hits": 0, "misses": 0, "errors": 0, "total_requests": 0}
async def initialize(self):
"""Initialize the core cache service with connection pool"""
try:
# Create Redis connection pool for better resource management
redis_url = getattr(settings, 'REDIS_URL', 'redis://localhost:6379/0')
redis_url = getattr(settings, "REDIS_URL", "redis://localhost:6379/0")
self.redis_pool = ConnectionPool.from_url(
redis_url,
@@ -45,7 +40,7 @@ class CoreCacheService:
socket_timeout=5,
retry_on_timeout=True,
max_connections=20, # Shared pool for all cache operations
health_check_interval=30
health_check_interval=30,
)
self.redis_client = Redis(connection_pool=self.redis_pool)
@@ -105,7 +100,9 @@ class CoreCacheService:
self.stats["errors"] += 1
return default
async def set(self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:
async def set(
self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
) -> bool:
"""Set value in cache"""
if not self.enabled:
return False
@@ -172,7 +169,9 @@ class CoreCacheService:
self.stats["errors"] += 1
return 0
async def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core") -> int:
async def increment(
self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core"
) -> int:
"""Increment counter with optional TTL"""
if not self.enabled:
return 0
@@ -200,18 +199,24 @@ class CoreCacheService:
if self.enabled:
try:
info = await self.redis_client.info()
stats.update({
stats.update(
{
"redis_memory_used": info.get("used_memory_human", "N/A"),
"redis_connected_clients": info.get("connected_clients", 0),
"redis_total_commands": info.get("total_commands_processed", 0),
"redis_keyspace_hits": info.get("keyspace_hits", 0),
"redis_keyspace_misses": info.get("keyspace_misses", 0),
"connection_pool_size": self.redis_pool.connection_pool_size if self.redis_pool else 0,
"connection_pool_size": self.redis_pool.connection_pool_size
if self.redis_pool
else 0,
"hit_rate": round(
(stats["hits"] / stats["total_requests"]) * 100, 2
) if stats["total_requests"] > 0 else 0,
"enabled": True
})
)
if stats["total_requests"] > 0
else 0,
"enabled": True,
}
)
except Exception as e:
logger.error(f"Error getting Redis stats: {e}")
stats["enabled"] = False
@@ -232,7 +237,9 @@ class CoreCacheService:
# Specialized caching methods for common use cases
async def cache_api_key(self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300) -> bool:
async def cache_api_key(
self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300
) -> bool:
"""Cache API key data for authentication"""
return await self.set(key_prefix, api_key_data, ttl, prefix="auth")
@@ -244,36 +251,53 @@ class CoreCacheService:
"""Invalidate cached API key"""
return await self.delete(key_prefix, prefix="auth")
async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool, ttl: int = 300) -> bool:
async def cache_verification_result(
self,
api_key: str,
key_prefix: str,
key_hash: str,
is_valid: bool,
ttl: int = 300,
) -> bool:
"""Cache API key verification result to avoid expensive bcrypt operations"""
verification_data = {
"key_hash": key_hash,
"is_valid": is_valid,
"timestamp": datetime.utcnow().isoformat()
"timestamp": datetime.utcnow().isoformat(),
}
return await self.set(f"verify:{key_prefix}", verification_data, ttl, prefix="auth")
return await self.set(
f"verify:{key_prefix}", verification_data, ttl, prefix="auth"
)
async def get_cached_verification(self, key_prefix: str) -> Optional[Dict[str, Any]]:
async def get_cached_verification(
self, key_prefix: str
) -> Optional[Dict[str, Any]]:
"""Get cached verification result"""
return await self.get(f"verify:{key_prefix}", prefix="auth")
async def cache_rate_limit(self, identifier: str, window_seconds: int, limit: int, current_count: int = 1) -> Dict[str, Any]:
async def cache_rate_limit(
self, identifier: str, window_seconds: int, limit: int, current_count: int = 1
) -> Dict[str, Any]:
"""Cache and track rate limit state"""
key = f"rate_limit:{identifier}:{window_seconds}"
try:
# Use atomic increment with expiry
count = await self.increment(key, current_count, window_seconds, prefix="rate")
count = await self.increment(
key, current_count, window_seconds, prefix="rate"
)
remaining = max(0, limit - count)
reset_time = int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp())
reset_time = int(
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
)
return {
"count": count,
"limit": limit,
"remaining": remaining,
"reset_time": reset_time,
"exceeded": count > limit
"exceeded": count > limit,
}
except Exception as e:
logger.error(f"Rate limit cache error: {e}")
@@ -282,8 +306,10 @@ class CoreCacheService:
"count": 0,
"limit": limit,
"remaining": limit,
"reset_time": int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()),
"exceeded": False
"reset_time": int(
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
),
"exceeded": False,
}
@@ -297,7 +323,9 @@ async def get(key: str, default: Any = None, prefix: str = "core") -> Any:
return await core_cache.get(key, default, prefix)
async def set(key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:
async def set(
key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
) -> bool:
"""Set value in core cache"""
return await core_cache.set(key, value, ttl, prefix)

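As a usage sketch, an endpoint can consult cache_rate_limit before doing work; the counter increment is atomic in Redis, so concurrent requests within the same window are counted correctly. The import path is assumed from this file's location.

from fastapi import HTTPException

from app.core.cache import core_cache  # assumed import path for this module

async def enforce_user_rate_limit(user_id: int) -> None:
    # 20 requests per 60-second window, matching the PrivateMode-aligned
    # defaults in settings; each call increments the Redis counter once.
    state = await core_cache.cache_rate_limit(
        identifier=f"user:{user_id}", window_seconds=60, limit=20
    )
    if state["exceeded"]:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            # state["reset_time"] is a unix timestamp of the window end.
            headers={"X-RateLimit-Reset": str(state["reset_time"])},
        )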
View File

@@ -21,7 +21,9 @@ class Settings(BaseSettings):
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
# Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
LOG_LLM_PROMPTS: bool = (
os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"
) # Set to True to log prompts and context sent to LLM
# Database
DATABASE_URL: str = os.getenv("DATABASE_URL")
@@ -32,11 +34,19 @@ class Settings(BaseSettings):
# Security
JWT_SECRET: str = os.getenv("JWT_SECRET")
JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")) # 7 days
SESSION_EXPIRE_MINUTES: int = int(os.getenv("SESSION_EXPIRE_MINUTES", "1440")) # 24 hours
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
) # 24 hours
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")
) # 7 days
SESSION_EXPIRE_MINUTES: int = int(
os.getenv("SESSION_EXPIRE_MINUTES", "1440")
) # 24 hours
API_KEY_PREFIX: str = os.getenv("API_KEY_PREFIX", "en_")
BCRYPT_ROUNDS: int = int(os.getenv("BCRYPT_ROUNDS", "6")) # Bcrypt work factor - lower for production performance
BCRYPT_ROUNDS: int = int(
os.getenv("BCRYPT_ROUNDS", "6")
) # Bcrypt work factor - lower for production performance
# Admin user provisioning (used only on first startup)
ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL")
@@ -45,12 +55,12 @@ class Settings(BaseSettings):
# Base URL for deriving CORS origins
BASE_URL: str = os.getenv("BASE_URL", "localhost")
@field_validator('CORS_ORIGINS', mode='before')
@field_validator("CORS_ORIGINS", mode="before")
@classmethod
def derive_cors_origins(cls, v, info):
"""Derive CORS origins from BASE_URL if not explicitly set"""
if v is None:
base_url = info.data.get('BASE_URL', 'localhost')
base_url = info.data.get("BASE_URL", "localhost")
# Support both HTTP and HTTPS for production environments
return [f"http://{base_url}", f"https://{base_url}"]
return v if isinstance(v, list) else [v]
@@ -64,14 +74,18 @@ class Settings(BaseSettings):
# LLM Service Security (removed encryption - credentials handled by proxy)
# Plugin System Security
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv("PLUGIN_ENCRYPTION_KEY") # Key for encrypting plugin secrets and configurations
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv(
"PLUGIN_ENCRYPTION_KEY"
) # Key for encrypting plugin secrets and configurations
# API Keys for LLM providers
OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY: Optional[str] = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY: Optional[str] = os.getenv("GOOGLE_API_KEY")
PRIVATEMODE_API_KEY: Optional[str] = os.getenv("PRIVATEMODE_API_KEY")
PRIVATEMODE_PROXY_URL: str = os.getenv("PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1")
PRIVATEMODE_PROXY_URL: str = os.getenv(
"PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1"
)
# Qdrant
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
@@ -79,36 +93,62 @@ class Settings(BaseSettings):
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
# Rate Limiting Configuration
# PrivateMode Standard tier limits (organization-level, not per user)
# These are shared across all API keys and users in the organization
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")
)
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(
os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200")
)
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000")
)
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(
os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000")
)
# Per-user limits (additional protection on top of organization limits)
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200")
)
# API key users (programmatic access)
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200")
)
# Premium/Enterprise API keys
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(
os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")
) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(
os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")
)
# Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
API_MAX_REQUEST_BODY_SIZE: int = int(
os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")
) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(
os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")
) # 50MB for premium
# IP Security
# Security Headers
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
API_CSP_HEADER: str = os.getenv(
"API_CSP_HEADER",
"default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
)
# Monitoring
PROMETHEUS_ENABLED: bool = os.getenv("PROMETHEUS_ENABLED", "True").lower() == "true"
@@ -121,29 +161,46 @@ class Settings(BaseSettings):
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
# RAG Embedding Configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(
os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12")
)
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv(
"RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
)
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(
os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0")
)
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(
os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")
)
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = (
os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
)
RAG_WARN_ON_FALLBACK: bool = (
os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
)
RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(
os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")
)
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(
os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")
)
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
# Plugin configuration
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
PLUGIN_REPOSITORY_URL: str = os.getenv("PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com")
PLUGIN_REPOSITORY_URL: str = os.getenv(
"PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com"
)
# Logging
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "json")
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
model_config = {
"env_file": ".env",
"case_sensitive": True,

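The parsing convention in Settings is uniform: every value arrives from the environment as a string and is coerced explicitly, booleans are compared case-insensitively against "true", and CORS origins are derived from BASE_URL when unset. A self-contained illustration:

import os

os.environ["LOG_LLM_PROMPTS"] = "true"
os.environ["ACCESS_TOKEN_EXPIRE_MINUTES"] = "60"
os.environ["BASE_URL"] = "app.example.com"

# Boolean flags: string compare after lowercasing, defaulting to False.
log_llm_prompts = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"
# Integers: explicit int() over the string default.
access_token_minutes = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))
# CORS derivation mirrors the field_validator above.
base_url = os.getenv("BASE_URL", "localhost")
cors_origins = [f"http://{base_url}", f"https://{base_url}"]

assert log_llm_prompts is True
assert access_token_minutes == 60
assert cors_origins == ["http://app.example.com", "https://app.example.com"]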
View File

@@ -24,7 +24,9 @@ def setup_logging() -> None:
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer() if settings.LOG_FORMAT == "json" else structlog.dev.ConsoleRenderer(),
structlog.processors.JSONRenderer()
if settings.LOG_FORMAT == "json"
else structlog.dev.ConsoleRenderer(),
],
context_class=dict,
logger_factory=LoggerFactory(),

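A quick usage sketch of the configured logger; the module path is assumed.

import structlog

from app.core.logging import setup_logging  # assumed module path

setup_logging()
log = structlog.get_logger("enclava.demo")
# Emits one JSON object per event when LOG_FORMAT=json, or a colorized
# console line via ConsoleRenderer otherwise.
log.info("request_completed", path="/health", duration_ms=3)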
View File

@@ -0,0 +1,367 @@
"""
Permissions Module
Role-based access control decorators and utilities
"""
from datetime import datetime
from functools import wraps
from typing import List, Optional, Union, Callable
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer
from app.models.user import User
security = HTTPBearer()
def require_permission(
user: User, permission: str, resource_id: Optional[Union[str, int]] = None
):
"""
Check if user has the required permission
Args:
user: User object from dependency injection
permission: Required permission string
resource_id: Optional resource ID for resource-specific permissions
Raises:
HTTPException: If user doesn't have the required permission
"""
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
)
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Account is not active"
)
# Check if account is locked
if user.account_locked:
if user.account_locked_until and user.account_locked_until > datetime.utcnow():
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is temporarily locked",
)
else:
# Unlock account if lock period has expired
user.unlock_account()
# Check permission
if not user.has_permission(permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required",
)
def require_permissions(permissions: List[str], require_all: bool = True):
"""
Decorator to require multiple permissions
Args:
permissions: List of required permissions
require_all: If True, user must have all permissions. If False, any one permission is sufficient
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs (assuming it's passed as a dependency)
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Check permissions
if require_all:
# User must have all permissions
for permission in permissions:
if not user.has_permission(permission):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"All of the following permissions required: {', '.join(permissions)}",
)
else:
# User needs at least one permission
if not any(
user.has_permission(permission) for permission in permissions
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"At least one of the following permissions required: {', '.join(permissions)}",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_role(role_names: Union[str, List[str]], require_all: bool = True):
"""
Decorator to require specific roles
Args:
role_names: Required role name(s)
require_all: If True, user must have all roles. If False, any one role is sufficient
"""
if isinstance(role_names, str):
role_names = [role_names]
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Check roles
user_role_names = []
if user.role:
user_role_names.append(user.role.name)
if require_all:
# User must have all roles
for role_name in role_names:
if role_name not in user_role_names:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"All of the following roles required: {', '.join(role_names)}",
)
else:
# User needs at least one role
if not any(role_name in user_role_names for role_name in role_names):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"At least one of the following roles required: {', '.join(role_names)}",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_minimum_role(minimum_role_level: str):
"""
Decorator to require minimum role level based on hierarchy
Args:
minimum_role_level: Minimum required role level
"""
role_hierarchy = {"read_only": 1, "user": 2, "admin": 3, "super_admin": 4}
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user from kwargs
user = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
break
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
# Superusers bypass role checks
if user.is_superuser:
return await func(*args, **kwargs)
# Check role level
if not user.role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Minimum role level '{minimum_role_level}' required",
)
user_level = role_hierarchy.get(user.role.level, 0)
required_level = role_hierarchy.get(minimum_role_level, 0)
if user_level < required_level:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Minimum role level '{minimum_role_level}' required",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def check_resource_permission(
user: User, resource_type: str, resource_id: Union[str, int], action: str
) -> bool:
"""
Check if user has permission to perform action on specific resource
Args:
user: User object
resource_type: Type of resource (e.g., 'user', 'budget', 'api_key')
resource_id: ID of the resource
action: Action to perform (e.g., 'read', 'update', 'delete')
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check basic permissions
permission = f"{action}_{resource_type}"
if user.has_permission(permission):
return True
# Check own resource permissions
if resource_type == "user" and str(resource_id) == str(user.id):
if user.has_permission(f"{action}_own"):
return True
# Check role-based resource access
if user.role:
# Admins can manage all users
if resource_type == "user" and user.role.level in ["admin", "super_admin"]:
return True
# Users with budget permissions can manage budgets
if resource_type == "budget" and user.role.can_manage_budgets:
return True
return False
def require_resource_permission(
resource_type: str, resource_id_param: str = "resource_id", action: str = "read"
):
"""
Decorator to require permission for specific resource
Args:
resource_type: Type of resource
resource_id_param: Name of parameter containing resource ID
action: Action to perform
"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract user and resource ID from kwargs
user = None
resource_id = None
for key, value in kwargs.items():
if isinstance(value, User):
user = value
elif key == resource_id_param:
resource_id = value
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
if resource_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Resource ID not provided",
)
# Check resource permission
if not check_resource_permission(user, resource_type, resource_id, action):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{action}_{resource_type}' required",
)
return await func(*args, **kwargs)
return wrapper
return decorator
def check_budget_permission(user: User, budget_id: int, action: str) -> bool:
"""
Check if user has permission to perform action on specific budget
Args:
user: User object
budget_id: ID of the budget
action: Action to perform
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check if user owns the budget
for budget in user.budgets:
if budget.id == budget_id and budget.is_active:
return user.has_permission(f"{action}_own")
# Check if user can manage all budgets
if user.has_permission(f"{action}_all") or (
user.role and user.role.can_manage_budgets
):
return True
return False
def check_api_key_permission(user: User, api_key_id: int, action: str) -> bool:
"""
Check if user has permission to perform action on specific API key
Args:
user: User object
api_key_id: ID of the API key
action: Action to perform
Returns:
bool: True if user has permission, False otherwise
"""
# Superusers can do anything
if user.is_superuser:
return True
# Check if user owns the API key
for api_key in user.api_keys:
if api_key.id == api_key_id:
return user.has_permission(f"{action}_own")
# Check if user can manage all API keys
if user.has_permission(f"{action}_all"):
return True
return False

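A hedged usage sketch of the decorators above. Note they locate the user by scanning kwargs for an isinstance(value, User) match, so the injected dependency must yield a User model instance; get_current_user_model below is a hypothetical stand-in for such a dependency, and the module path for the decorators is also assumed.

from fastapi import APIRouter, Depends

from app.core.permissions import require_permissions  # assumed module path
from app.models.user import User

router = APIRouter()

async def get_current_user_model() -> User:
    """Hypothetical dependency returning a User ORM instance."""
    raise NotImplementedError

@router.delete("/budgets/{budget_id}")
@require_permissions(["manage_budgets", "delete_all"], require_all=False)
async def delete_budget(
    budget_id: int,
    user: User = Depends(get_current_user_model),
):
    # The decorator has already verified at least one of the two
    # permissions before this body runs.
    return {"deleted": budget_id}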
View File

@@ -24,30 +24,35 @@ logger = logging.getLogger(__name__)
# Password hashing
# Use a lower work factor for better performance in production
pwd_context = CryptContext(
schemes=["bcrypt"],
deprecated="auto",
bcrypt__rounds=settings.BCRYPT_ROUNDS
schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=settings.BCRYPT_ROUNDS
)
# JWT token handling
security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
import time
start_time = time.time()
logger.info(f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")
logger.info(
f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}"
)
try:
# Run password verification in a thread with timeout
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(pwd_context.verify, plain_password, hashed_password)
future = executor.submit(
pwd_context.verify, plain_password, hashed_password
)
result = future.result(timeout=5.0) # 5 second timeout
end_time = time.time()
duration = end_time - start_time
logger.info(f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}")
logger.info(
f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}"
)
if duration > 1:
logger.warning(f"PASSWORD VERIFICATION TOOK TOO LONG: {duration:.3f}s")
@@ -61,24 +66,33 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
except Exception as e:
end_time = time.time()
duration = end_time - start_time
logger.error(f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}")
logger.error(
f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}"
)
raise
def get_password_hash(password: str) -> str:
"""Generate password hash"""
return pwd_context.hash(password)
def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
"""Verify an API key against its hash"""
return pwd_context.verify(plain_api_key, hashed_api_key)
def get_api_key_hash(api_key: str) -> str:
"""Generate API key hash"""
return pwd_context.hash(api_key)
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(
data: Dict[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
"""Create JWT access token"""
import time
start_time = time.time()
logger.info(f"=== CREATE ACCESS TOKEN START ===")
@@ -87,12 +101,16 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire})
logger.info(f"JWT encode start...")
encode_start = time.time()
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
encode_end = time.time()
encode_duration = encode_end - encode_start
@@ -103,7 +121,9 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
logger.info(f"Created access token for user {data.get('sub')}")
logger.info(f"Token expires at: {expire.isoformat()} (UTC)")
logger.info(f"Current UTC time: {datetime.utcnow().isoformat()}")
logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(
f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
)
logger.info(f"JWT encode duration: {encode_duration:.3f}s")
logger.info(f"Total token creation duration: {total_duration:.3f}s")
logger.info(f"=== CREATE ACCESS TOKEN END ===")
@@ -112,17 +132,25 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
except Exception as e:
end_time = time.time()
total_duration = end_time - start_time
logger.error(f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}")
logger.error(
f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}"
)
raise
def create_refresh_token(data: Dict[str, Any]) -> str:
"""Create JWT refresh token"""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
expire = datetime.utcnow() + timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
return encoded_jwt
def verify_token(token: str) -> Dict[str, Any]:
"""Verify JWT token and return payload"""
try:
@@ -133,15 +161,21 @@ def verify_token(token: str) -> Dict[str, Any]:
# Decode without verification first to check expiration
try:
unverified_payload = jwt.get_unverified_claims(token)
exp_timestamp = unverified_payload.get('exp')
exp_timestamp = unverified_payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=None)
logger.info(f"Token expiration time: {exp_datetime.isoformat()} (UTC)")
logger.info(f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds")
logger.info(
f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds"
)
except Exception as decode_error:
logger.warning(f"Could not decode token for expiration check: {decode_error}")
logger.warning(
f"Could not decode token for expiration check: {decode_error}"
)
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
logger.info(f"Token verified successfully for user {payload.get('sub')}")
return payload
except JWTError as e:
@@ -149,9 +183,10 @@ def verify_token(token: str) -> Dict[str, Any]:
logger.warning(f"Current UTC time: {datetime.utcnow().isoformat()}")
raise AuthenticationError("Invalid token")
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Dict[str, Any]:
"""Get current user from JWT token"""
try:
@@ -167,9 +202,10 @@ async def get_current_user(
# Load user from database
from app.models.user import User
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# Query user from database
stmt = select(User).where(User.id == int(user_id))
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
result = await db.execute(stmt)
user = result.scalar_one_or_none()
@@ -181,7 +217,7 @@ async def get_current_user(
"is_superuser": payload.get("is_superuser", False),
"role": payload.get("role", "user"),
"is_active": True,
"permissions": [] # Default to empty list for permissions
"permissions": [], # Default to empty list for permissions
}
# Update last login
@@ -191,39 +227,43 @@ async def get_current_user(
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role string to list for permission calculation
user_roles = [user.role] if user.role else []
# Convert role to name for permission calculation
user_roles = [user.role.name] if user.role else []
# For super admin users, use only role-based permissions, ignore custom permissions
# Custom permissions might contain legacy formats like ['*'] that don't work with new system
# Custom permissions might contain legacy formats like ['*'] or dict formats
custom_permissions = []
if not user.is_superuser:
# Only use custom permissions for non-superuser accounts
if user.permissions:
if isinstance(user.permissions, list):
custom_permissions = user.permissions
# Support both list-based and dict-based custom permission formats
raw_custom_perms = getattr(user, "custom_permissions", None)
if raw_custom_perms:
if isinstance(raw_custom_perms, list):
custom_permissions = raw_custom_perms
elif isinstance(raw_custom_perms, dict):
granted = raw_custom_perms.get("granted")
if isinstance(granted, list):
custom_permissions = granted
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles,
custom_permissions=custom_permissions
roles=user_roles, custom_permissions=custom_permissions
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"is_superuser": user.is_superuser,
"is_active": user.is_active,
"role": user.role,
"role": user.role.name if user.role else None,
"permissions": effective_permissions, # Use calculated permissions
"user_obj": user # Include full user object for other operations
"user_obj": user, # Include full user object for other operations
}
except Exception as e:
logger.error(f"Authentication error: {e}")
raise AuthenticationError("Could not validate credentials")
async def get_current_active_user(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -233,6 +273,7 @@ async def get_current_active_user(
raise AuthenticationError("User account is inactive")
return current_user
async def get_current_superuser(
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -241,6 +282,7 @@ async def get_current_superuser(
raise AuthorizationError("Insufficient privileges")
return current_user
def generate_api_key() -> str:
"""Generate a new API key"""
import secrets
@@ -248,21 +290,23 @@ def generate_api_key() -> str:
# Generate random string
alphabet = string.ascii_letters + string.digits
api_key = ''.join(secrets.choice(alphabet) for _ in range(32))
api_key = "".join(secrets.choice(alphabet) for _ in range(32))
return f"{settings.API_KEY_PREFIX}{api_key}"
def hash_api_key(api_key: str) -> str:
"""Hash API key for storage"""
return get_password_hash(api_key)
def verify_api_key(api_key: str, hashed_key: str) -> bool:  # NOTE: redefines verify_api_key declared earlier in this module
"""Verify API key against hash"""
return verify_password(api_key, hashed_key)
async def get_api_key_user(
request: Request,
db: AsyncSession = Depends(get_db)
request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Get user from API key"""
api_key = request.headers.get("X-API-Key")
@@ -282,10 +326,14 @@ async def get_api_key_user(
key_prefix = api_key[:8]
# Query API key from database
stmt = select(APIKey).join(User).where(
stmt = (
select(APIKey)
.join(User)
.where(
APIKey.key_prefix == key_prefix,
APIKey.is_active == True,
User.is_active == True
User.is_active == True,
)
)
result = await db.execute(stmt)
db_api_key = result.scalar_one_or_none()
@@ -306,7 +354,7 @@ async def get_api_key_user(
await db.commit()
# Load associated user
user_stmt = select(User).where(User.id == db_api_key.user_id)
user_stmt = select(User).options(selectinload(User.role)).where(User.id == db_api_key.user_id)
user_result = await db.execute(user_stmt)
user = user_result.scalar_one_or_none()
@@ -316,22 +364,36 @@ async def get_api_key_user(
# Calculate effective permissions using permission manager
from app.services.permission_manager import permission_registry
# Convert role string to list for permission calculation
user_roles = [user.role] if user.role else []
# Convert role to name for permission calculation
user_roles = [user.role.name] if user.role else []
# Use API key specific permissions if available, otherwise use user permissions
# Use API key specific permissions if available
api_key_permissions = db_api_key.permissions if db_api_key.permissions else []
# Get custom permissions from database (convert dict to list if needed)
custom_permissions = api_key_permissions
if user.permissions:
if isinstance(user.permissions, list):
custom_permissions.extend(user.permissions)
# Normalize permissions into a flat list of granted permission strings
custom_permissions: list[str] = []
# Handle API key permissions that may be stored as list or dict
if isinstance(api_key_permissions, list):
custom_permissions.extend(api_key_permissions)
elif isinstance(api_key_permissions, dict):
api_granted = api_key_permissions.get("granted")
if isinstance(api_granted, list):
custom_permissions.extend(api_granted)
# Merge in user-level custom permissions for non-superusers
raw_user_custom = getattr(user, "custom_permissions", None)
if raw_user_custom and not user.is_superuser:
if isinstance(raw_user_custom, list):
custom_permissions.extend(raw_user_custom)
elif isinstance(raw_user_custom, dict):
user_granted = raw_user_custom.get("granted")
if isinstance(user_granted, list):
custom_permissions.extend(user_granted)
# Calculate effective permissions based on role and custom permissions
effective_permissions = permission_registry.get_user_permissions(
roles=user_roles,
custom_permissions=custom_permissions
roles=user_roles, custom_permissions=custom_permissions
)
return {
@@ -344,12 +406,13 @@ async def get_api_key_user(
"permissions": effective_permissions,
"api_key": db_api_key,
"user_obj": user,
"auth_type": "api_key"
"auth_type": "api_key",
}
except Exception as e:
logger.error(f"API key lookup error: {e}")
return None
class RequiresPermission:
"""Dependency class for permission checking"""
@@ -367,7 +430,14 @@ class RequiresPermission:
role_permissions = {
"user": ["read_own", "create_own", "update_own"],
"admin": ["read_all", "create_all", "update_all", "delete_own"],
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
"super_admin": [
"read_all",
"create_all",
"update_all",
"delete_all",
"manage_users",
"manage_modules",
],
}
if role in role_permissions and self.permission in role_permissions[role]:
@@ -386,6 +456,7 @@ class RequiresPermission:
raise AuthorizationError(f"Permission '{self.permission}' required")
class RequiresRole:
"""Dependency class for role checking"""
@@ -401,11 +472,7 @@ class RequiresRole:
user_role = current_user.get("role", "user")
# Define role hierarchy
role_hierarchy = {
"user": 1,
"admin": 2,
"super_admin": 3
}
role_hierarchy = {"user": 1, "admin": 2, "super_admin": 3}
required_level = role_hierarchy.get(self.role, 0)
user_level = role_hierarchy.get(user_role, 0)
@@ -413,4 +480,6 @@ class RequiresRole:
if user_level >= required_level:
return current_user
raise AuthorizationError(f"Role '{self.role}' required, but user has role '{user_role}'")
raise AuthorizationError(
f"Role '{self.role}' required, but user has role '{user_role}'"
)

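A minimal round-trip with the helpers above, assuming JWT_SECRET is configured; "sub" is the subject claim that get_current_user later reads back as the user id.

from datetime import timedelta

from app.core.security import create_access_token, verify_token  # assumed module path

token = create_access_token(
    data={"sub": "42", "is_superuser": False, "role": "user"},
    expires_delta=timedelta(minutes=15),
)
payload = verify_token(token)
assert payload["sub"] == "42"
assert "exp" in payload  # expiry claim added by create_access_token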
View File

@@ -72,6 +72,7 @@ metadata = MetaData()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Get database session"""
import time
start_time = time.time()
request_id = f"db_{int(time.time() * 1000)}"
@@ -86,7 +87,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
except Exception as e:
# Only log if there's an actual error, not normal operation
if str(e).strip(): # Only log if error message exists
logger.error(f"[{request_id}] Database session error: {str(e)}", exc_info=True)
logger.error(
f"[{request_id}] Database session error: {str(e)}",
exc_info=True,
)
await session.rollback()
raise
finally:
@@ -94,9 +98,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
close_time = time.time() - close_start
total_time = time.time() - start_time
logger.info(f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s")
logger.info(
f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s"
)
except Exception as e:
logger.error(f"[{request_id}] Failed to create database session: {e}", exc_info=True)
logger.error(
f"[{request_id}] Failed to create database session: {e}", exc_info=True
)
raise
@@ -106,8 +114,10 @@ async def init_db():
async with engine.begin() as conn:
# Import all models to ensure they're registered
from app.models.user import User
from app.models.role import Role
from app.models.api_key import APIKey
from app.models.usage_tracking import UsageTracking
# Import additional models - these are available
try:
from app.models.budget import Budget
@@ -127,6 +137,9 @@ async def init_db():
# Tables are now created via migration container - no need to create here
# await conn.run_sync(Base.metadata.create_all) # DISABLED - migrations handle this
# Create default roles if they don't exist
await create_default_roles()
# Create default admin user if no admin exists
await create_default_admin()
@@ -136,9 +149,42 @@ async def init_db():
raise
async def create_default_roles():
"""Create default roles if they don't exist"""
from app.models.role import Role, RoleLevel
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
try:
async with async_session_factory() as session:
# Check if any roles exist
stmt = select(Role).limit(1)
result = await session.execute(stmt)
existing_role = result.scalar_one_or_none()
if existing_role:
logger.info("Roles already exist - skipping default role creation")
return
# Create default roles using the Role.create_default_roles class method
default_roles = Role.create_default_roles()
for role in default_roles:
session.add(role)
await session.commit()
logger.info("Created default roles: read_only, user, admin, super_admin")
except SQLAlchemyError as e:
logger.error(f"Failed to create default roles due to database error: {e}")
raise
async def create_default_admin():
"""Create default admin user if user with ADMIN_EMAIL doesn't exist"""
from app.models.user import User
from app.models.role import Role
from app.core.security import get_password_hash
from app.core.config import settings
from sqlalchemy import select
@@ -159,19 +205,33 @@ async def create_default_admin():
existing_user = result.scalar_one_or_none()
if existing_user:
logger.info(f"User with email {admin_email} already exists - skipping admin creation")
logger.info(
f"User with email {admin_email} already exists - skipping admin creation"
)
return
# Get the super_admin role
stmt = select(Role).where(Role.name == "super_admin")
result = await session.execute(stmt)
super_admin_role = result.scalar_one_or_none()
if not super_admin_role:
logger.error("Super admin role not found - cannot create admin user")
return
# Create admin user from environment variables
# Generate username from email (part before @)
admin_username = admin_email.split('@')[0]
admin_username = admin_email.split("@")[0]
admin_user = User.create_default_admin(
email=admin_email,
username=admin_username,
password_hash=get_password_hash(admin_password)
password_hash=get_password_hash(admin_password),
)
# Assign the super_admin role
admin_user.role_id = super_admin_role.id
session.add(admin_user)
await session.commit()
@@ -179,14 +239,19 @@ async def create_default_admin():
logger.warning("ADMIN USER CREATED FROM ENVIRONMENT")
logger.warning(f"Email: {admin_email}")
logger.warning(f"Username: {admin_username}")
logger.warning("Password: [Set via ADMIN_PASSWORD - only used on first creation]")
logger.warning("Role: Super Administrator")
logger.warning(
"Password: [Set via ADMIN_PASSWORD - only used on first creation]"
)
logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
logger.warning("=" * 60)
except SQLAlchemyError as e:
logger.error(f"Failed to create default admin user due to database error: {e}")
except AttributeError as e:
logger.error(f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'")
logger.error(
f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'"
)
except Exception as e:
logger.error(f"Failed to create default admin user: {e}")
# Don't raise here as this shouldn't block the application startup

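The ordering in init_db matters: roles must exist before the admin user, because create_default_admin looks up super_admin and writes its id into the new user's role_id. A standalone bootstrap sketch, assuming DATABASE_URL and the admin env vars are configured:

import asyncio

from app.db.database import create_default_admin, create_default_roles

async def bootstrap() -> None:
    # Roles first, then admin; reversing this makes create_default_admin
    # log "Super admin role not found" and return without creating a user.
    await create_default_roles()
    await create_default_admin()

if __name__ == "__main__":
    asyncio.run(bootstrap())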
View File

@@ -59,7 +59,10 @@ async def _check_redis_startup():
duration = time.perf_counter() - start
logger.info(
"Startup Redis check succeeded",
extra={"redis_url": settings.REDIS_URL, "duration_seconds": round(duration, 3)},
extra={
"redis_url": settings.REDIS_URL,
"duration_seconds": round(duration, 3),
},
)
except Exception as exc: # noqa: BLE001
logger.warning(
@@ -107,6 +110,7 @@ async def lifespan(app: FastAPI):
# Initialize core cache service (before database to provide caching for auth)
from app.core.cache import core_cache
try:
await core_cache.initialize()
logger.info("Core cache service initialized successfully")
@@ -128,6 +132,7 @@ async def lifespan(app: FastAPI):
# Ensure platform permissions are registered before module discovery
from app.services.permission_manager import permission_registry
permission_registry.register_platform_permissions()
# Initialize LLM service (needed by RAG module) concurrently
@@ -156,6 +161,7 @@ async def lifespan(app: FastAPI):
# Initialize document processor
from app.services.document_processor import document_processor
try:
await document_processor.start()
app.state.document_processor = document_processor
@@ -171,6 +177,7 @@ async def lifespan(app: FastAPI):
# Start background audit worker
from app.services.audit_service import start_audit_worker
try:
start_audit_worker()
except Exception as exc:
@@ -179,10 +186,13 @@ async def lifespan(app: FastAPI):
# Initialize plugin auto-discovery service concurrently
async def initialize_plugins():
from app.services.plugin_autodiscovery import initialize_plugin_autodiscovery
try:
discovery_results = await initialize_plugin_autodiscovery()
app.state.plugin_discovery_results = discovery_results
logger.info(f"Plugin auto-discovery completed: {discovery_results.get('summary')}")
logger.info(
f"Plugin auto-discovery completed: {discovery_results.get('summary')}"
)
except Exception as exc:
logger.warning(f"Plugin auto-discovery failed: {exc}")
app.state.plugin_discovery_results = {"error": str(exc)}
@@ -205,6 +215,7 @@ async def lifespan(app: FastAPI):
# Cleanup embedding service HTTP sessions
from app.services.embedding_service import embedding_service
try:
await embedding_service.cleanup()
logger.info("Embedding service cleaned up successfully")
@@ -213,14 +224,16 @@ async def lifespan(app: FastAPI):
# Close core cache service
from app.core.cache import core_cache
await core_cache.cleanup()
# Close Redis connection for cached API key service
from app.services.cached_api_key import cached_api_key_service
await cached_api_key_service.close()
# Stop document processor
processor = getattr(app.state, 'document_processor', None)
processor = getattr(app.state, "document_processor", None)
if processor:
await processor.stop()
@@ -297,7 +310,9 @@ async def validation_exception_handler(request, exc: RequestValidationError):
"type": error.get("type", ""),
"location": error.get("loc", []),
"message": error.get("msg", ""),
"input": str(error.get("input", "")) if error.get("input") is not None else None
"input": str(error.get("input", ""))
if error.get("input") is not None
else None,
}
errors.append(error_dict)

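Condensed to its shape, the lifespan above initializes the cache first (so auth can use it), runs independent initializers concurrently, and mirrors each startup step with a cleanup after the yield. A minimal sketch with placeholder initializers standing in for the real services:

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

async def init_llm_service() -> None: ...  # placeholder
async def init_plugins() -> None: ...      # placeholder

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Cache before database/auth, then independent services concurrently
    # (core_cache from app.core.cache plays the first role in this diff).
    await asyncio.gather(init_llm_service(), init_plugins())
    yield
    # Shutdown mirrors startup in reverse: cleanup each initialized service.

app = FastAPI(lifespan=lifespan)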
View File

@@ -17,7 +17,7 @@ from app.db.database import get_db
logger = get_logger(__name__)
# Context variable to pass analytics data from endpoints to middleware
analytics_context: ContextVar[dict] = ContextVar('analytics_context', default={})
analytics_context: ContextVar[dict] = ContextVar("analytics_context", default={})
class AnalyticsMiddleware(BaseHTTPMiddleware):
@@ -28,7 +28,12 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
start_time = time.time()
# Skip analytics for health checks and static files
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
if request.url.path in [
"/health",
"/docs",
"/redoc",
"/openapi.json",
] or request.url.path.startswith("/static"):
return await call_next(request)
# Get user info if available from token
@@ -42,6 +47,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
# Try to extract user info from token without full validation
# This is a lightweight check for analytics purposes
from app.core.security import verify_token
try:
payload = verify_token(token)
user_id = int(payload.get("sub"))
@@ -77,7 +83,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
error_message = str(e)
response = JSONResponse(
status_code=500,
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
)
# Calculate timing
@@ -86,7 +92,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
# Get response size
response_size = 0
if hasattr(response, 'body'):
if hasattr(response, "body"):
response_size = len(response.body) if response.body else 0
# Get analytics data from context (set by endpoints)
@@ -107,22 +113,25 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
response_size=response_size,
error_message=error_message,
# Token/cost info populated by LLM endpoints via context
model=context_data.get('model'),
request_tokens=context_data.get('request_tokens', 0),
response_tokens=context_data.get('response_tokens', 0),
total_tokens=context_data.get('total_tokens', 0),
cost_cents=context_data.get('cost_cents', 0),
budget_ids=context_data.get('budget_ids', []),
budget_warnings=context_data.get('budget_warnings', [])
model=context_data.get("model"),
request_tokens=context_data.get("request_tokens", 0),
response_tokens=context_data.get("response_tokens", 0),
total_tokens=context_data.get("total_tokens", 0),
cost_cents=context_data.get("cost_cents", 0),
budget_ids=context_data.get("budget_ids", []),
budget_warnings=context_data.get("budget_warnings", []),
)
# Track the event
try:
from app.services.analytics import analytics_service
if analytics_service is not None:
await analytics_service.track_request(event)
else:
logger.warning("Analytics service not initialized, skipping event tracking")
logger.warning(
"Analytics service not initialized, skipping event tracking"
)
except Exception as e:
logger.error(f"Failed to track analytics event: {e}")
# Don't let analytics failures break the request
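To make the handoff concrete, here is a minimal sketch of an endpoint populating analytics_context before returning, so the middleware above can attach token and cost data to the tracked event; the router, the run_llm stub, and the import path app.middleware.analytics are illustrative assumptions, not part of this commit.

from fastapi import APIRouter

from app.middleware.analytics import analytics_context  # assumed import path

router = APIRouter()


async def run_llm(payload: dict) -> dict:
    # Stand-in for the real LLM call; returns the usage shape assumed below.
    return {
        "model": "gpt-4o-mini",
        "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
        "cost_cents": 2,
    }


@router.post("/api/v1/chat/completions")  # hypothetical endpoint
async def chat_completions(payload: dict):
    result = await run_llm(payload)
    # Set the ContextVar; AnalyticsMiddleware reads these keys after call_next().
    analytics_context.set(
        {
            "model": result["model"],
            "request_tokens": result["usage"]["prompt_tokens"],
            "response_tokens": result["usage"]["completion_tokens"],
            "total_tokens": result["usage"]["total_tokens"],
            "cost_cents": result["cost_cents"],
        }
    )
    return result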

View File

@@ -0,0 +1,393 @@
"""
Audit Logging Middleware
Automatically logs user actions and system events
"""
import time
import json
import logging
from typing import Callable, Optional, Dict, Any
from datetime import datetime
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.audit_log import AuditLog, AuditAction, AuditSeverity
from app.db.database import get_db_session
from app.core.security import verify_token
logger = logging.getLogger(__name__)
class AuditLoggingMiddleware(BaseHTTPMiddleware):
"""Middleware to automatically log user actions and API calls"""
def __init__(self, app, exclude_paths: Optional[list] = None):
super().__init__(app)
# Paths to exclude from audit logging
self.exclude_paths = exclude_paths or [
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/metrics",
"/static",
"/favicon.ico",
]
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip audit logging for excluded paths
if any(request.url.path.startswith(path) for path in self.exclude_paths):
return await call_next(request)
# Skip audit logging for health checks and static assets
if request.url.path in ["/", "/health"] or "/static/" in request.url.path:
return await call_next(request)
start_time = time.time()
# Extract user information from request
user_info = await self._extract_user_info(request)
# Prepare audit data
audit_data = {
"method": request.method,
"path": request.url.path,
"query_params": dict(request.query_params),
"ip_address": self._get_client_ip(request),
"user_agent": request.headers.get("user-agent"),
"timestamp": datetime.utcnow().isoformat(),
}
# Process request
response = await call_next(request)
# Calculate response time
process_time = time.time() - start_time
audit_data["response_time"] = round(process_time * 1000, 2) # milliseconds
audit_data["status_code"] = response.status_code
audit_data["success"] = 200 <= response.status_code < 400
# Log the audit event asynchronously
try:
await self._log_audit_event(user_info, audit_data, request)
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
# Don't fail the request if audit logging fails
return response
async def _extract_user_info(self, request: Request) -> Optional[Dict[str, Any]]:
"""Extract user information from request headers"""
try:
# Try to get user info from Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
payload = verify_token(token)
return {
"user_id": int(payload.get("sub")) if payload.get("sub") else None,
"email": payload.get("email"),
"is_superuser": payload.get("is_superuser", False),
"role": payload.get("role"),
}
except Exception:
# If token verification fails, continue without user info
pass
# Try to get user info from API key header
api_key = request.headers.get("x-api-key")
if api_key:
# Would need to implement API key lookup here
# For now, just indicate it's an API key request
return {
"user_id": None,
"email": "api_key_user",
"is_superuser": False,
"role": "api_user",
"auth_type": "api_key",
}
return None
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address with proxy support"""
# Check for forwarded headers first (for reverse proxy setups)
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# Take the first IP in the chain
return forwarded_for.split(",")[0].strip()
forwarded = request.headers.get("x-forwarded")
if forwarded:
return forwarded
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
# Fall back to direct client IP
return request.client.host if request.client else "unknown"
async def _log_audit_event(
self,
user_info: Optional[Dict[str, Any]],
audit_data: Dict[str, Any],
request: Request
):
"""Log the audit event to database"""
# Determine action based on HTTP method and path
action = self._determine_action(request.method, request.url.path)
# Determine resource type and ID from path
resource_type, resource_id = self._parse_resource_from_path(request.url.path)
# Create description
description = self._create_description(
request.method, request.url.path, audit_data["success"]
)
# Determine severity
severity = self._determine_severity(
request.method, audit_data["status_code"], request.url.path
)
# Create audit log entry
try:
async with get_db_session() as db:
audit_log = AuditLog(
user_id=user_info.get("user_id") if user_info else None,
action=action,
resource_type=resource_type,
resource_id=resource_id,
description=description,
details={
"request": {
"method": audit_data["method"],
"path": audit_data["path"],
"query_params": audit_data["query_params"],
"response_time_ms": audit_data["response_time"],
},
"user_info": user_info,
},
ip_address=audit_data["ip_address"],
user_agent=audit_data["user_agent"],
severity=severity,
category=self._determine_category(request.url.path),
success=audit_data["success"],
tags=self._generate_tags(request.method, request.url.path),
)
db.add(audit_log)
await db.commit()
except Exception as e:
logger.error(f"Failed to save audit log to database: {e}")
# Could implement fallback logging to file here
def _determine_action(self, method: str, path: str) -> str:
"""Determine action type from HTTP method and path"""
method = method.upper()
if method == "GET":
return AuditAction.READ
elif method == "POST":
if "login" in path.lower():
return AuditAction.LOGIN
elif "logout" in path.lower():
return AuditAction.LOGOUT
else:
return AuditAction.CREATE
elif method == "PUT" or method == "PATCH":
return AuditAction.UPDATE
elif method == "DELETE":
return AuditAction.DELETE
else:
return method.lower()
def _parse_resource_from_path(self, path: str) -> tuple[str, Optional[str]]:
"""Parse resource type and ID from URL path"""
path_parts = path.strip("/").split("/")
# Skip API version prefix
if path_parts and path_parts[0] in ["api", "api-internal"]:
path_parts = path_parts[2:]  # Drop the prefix ('api' or 'api-internal') and the version segment
if not path_parts:
return "system", None
resource_type = path_parts[0]
resource_id = None
# Try to find numeric ID in path
for part in path_parts[1:]:
if part.isdigit():
resource_id = part
break
return resource_type, resource_id
def _create_description(self, method: str, path: str, success: bool) -> str:
"""Create human-readable description of the action"""
action_verbs = {
"GET": "accessed" if success else "attempted to access",
"POST": "created" if success else "attempted to create",
"PUT": "updated" if success else "attempted to update",
"PATCH": "modified" if success else "attempted to modify",
"DELETE": "deleted" if success else "attempted to delete",
}
verb = action_verbs.get(method, method.lower())
resource = path.strip("/").split("/")[-1] if "/" in path else path
return f"User {verb} {resource}"
def _determine_severity(self, method: str, status_code: int, path: str) -> str:
"""Determine severity level based on action and outcome"""
# Critical operations
if any(
keyword in path.lower() for keyword in ["delete", "password", "admin", "key"]
):
return AuditSeverity.HIGH
# Failed operations
if status_code >= 400:
if status_code >= 500:
return AuditSeverity.CRITICAL
elif status_code in [401, 403]:
return AuditSeverity.HIGH
else:
return AuditSeverity.MEDIUM
# Write operations
if method in ["POST", "PUT", "PATCH", "DELETE"]:
return AuditSeverity.MEDIUM
# Read operations
return AuditSeverity.LOW
def _determine_category(self, path: str) -> str:
"""Determine category based on path"""
path = path.lower()
if any(keyword in path for keyword in ["auth", "login", "logout", "token"]):
return "authentication"
elif any(keyword in path for keyword in ["user", "admin", "role", "permission"]):
return "user_management"
elif any(keyword in path for keyword in ["api-key", "key"]):
return "security"
elif any(keyword in path for keyword in ["budget", "billing", "usage"]):
return "financial"
elif any(keyword in path for keyword in ["audit", "log"]):
return "audit"
elif any(keyword in path for keyword in ["setting", "config"]):
return "configuration"
else:
return "general"
def _generate_tags(self, method: str, path: str) -> list[str]:
"""Generate tags for the audit log"""
tags = [method.lower()]
path_parts = path.strip("/").split("/")
if path_parts:
tags.append(path_parts[0])
# Add special tags
if "admin" in path.lower():
tags.append("admin_action")
if any(keyword in path.lower() for keyword in ["password", "auth", "login"]):
tags.append("security_action")
return tags
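The tag generation above closes out AuditLoggingMiddleware; a wiring sketch for both middlewares in this file follows the LoginAuditMiddleware class below.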
class LoginAuditMiddleware(BaseHTTPMiddleware):
"""Specialized middleware for login/logout events"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Only process auth-related endpoints
if not any(
path in request.url.path
for path in ["/auth/login", "/auth/logout", "/auth/refresh"]
):
return await call_next(request)
start_time = time.time()
# Store request body for login attempts
request_body = None
if request.method == "POST" and "/login" in request.url.path:
try:
body = await request.body()
if body:
request_body = json.loads(body.decode())
# Restore the raw body so downstream handlers can re-read it
request._body = body
except Exception as e:
logger.warning(f"Failed to parse login request body: {e}")
response = await call_next(request)
# Log login/logout events
try:
await self._log_auth_event(
request, response, request_body, time.time() - start_time
)
except Exception as e:
logger.error(f"Failed to log auth event: {e}")
return response
async def _log_auth_event(
self,
request: Request,
response: Response,
request_body: Optional[dict],
process_time: float,
):
"""Log authentication events"""
success = 200 <= response.status_code < 300
if "/login" in request.url.path:
# Extract email/username from request
identifier = None
if request_body:
identifier = request_body.get("email") or request_body.get("username")
# For successful logins, we could extract user_id from response
# For now, we'll use the identifier
async with get_db_session() as db:
audit_log = AuditLog.create_login_event(
user_id=None, # Would need to extract from response for successful logins
success=success,
ip_address=self._get_client_ip(request),
user_agent=request.headers.get("user-agent"),
error_message=f"HTTP {response.status_code}" if not success else None,
)
# Add additional details
audit_log.details.update({
"identifier": identifier,
"response_time_ms": round(process_time * 1000, 2),
})
db.add(audit_log)
await db.commit()
elif "/logout" in request.url.path:
# Extract user info from token if available
user_id = None
try:
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
payload = verify_token(token)
user_id = int(payload.get("sub")) if payload.get("sub") else None
except Exception:
pass
async with get_db_session() as db:
audit_log = AuditLog.create_logout_event(
user_id=user_id,
session_id=None, # Could extract from token if stored
)
db.add(audit_log)
await db.commit()
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address with proxy support"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
return request.client.host if request.client else "unknown"
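For context, a minimal sketch of how these two middlewares might be wired into the application; the import path and exclude list are assumptions for illustration.

from fastapi import FastAPI

from app.middleware.audit_logging import (  # assumed import path
    AuditLoggingMiddleware,
    LoginAuditMiddleware,
)

app = FastAPI()

# add_middleware() prepends, so LoginAuditMiddleware (registered last) is
# outermost and can buffer the login body before anything else reads it.
app.add_middleware(AuditLoggingMiddleware, exclude_paths=["/docs", "/metrics"])
app.add_middleware(LoginAuditMiddleware)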

View File

@@ -23,8 +23,12 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
request_id = str(uuid4())
# Skip debugging for health checks and static files
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or \
request.url.path.startswith("/static"):
if request.url.path in [
"/health",
"/docs",
"/redoc",
"/openapi.json",
] or request.url.path.startswith("/static"):
return await call_next(request)
# Log request details
@@ -37,7 +41,7 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
try:
request_body = json.loads(body_bytes)
except json.JSONDecodeError:
request_body = body_bytes.decode('utf-8', errors='replace')
request_body = body_bytes.decode("utf-8", errors="replace")
# Restore body for downstream processing
request._body = body_bytes
except Exception:
@@ -45,8 +49,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
# Extract headers we care about
headers_to_log = {
"authorization": request.headers.get("Authorization", "")[:50] + "..." if
request.headers.get("Authorization") else None,
"authorization": request.headers.get("Authorization", "")[:50] + "..."
if request.headers.get("Authorization")
else None,
"content-type": request.headers.get("Content-Type"),
"user-agent": request.headers.get("User-Agent"),
"x-forwarded-for": request.headers.get("X-Forwarded-For"),
@@ -54,7 +59,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
}
# Log request
logger.info("=== API REQUEST DEBUG ===", extra={
logger.info(
"=== API REQUEST DEBUG ===",
extra={
"request_id": request_id,
"method": request.method,
"url": str(request.url),
@@ -63,8 +70,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
"headers": {k: v for k, v in headers_to_log.items() if v is not None},
"body": request_body,
"client_ip": request.client.host if request.client else None,
"timestamp": datetime.utcnow().isoformat()
})
"timestamp": datetime.utcnow().isoformat(),
},
)
# Process the request
start_time = time.time()
@@ -73,33 +81,43 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
# Add timeout detection
try:
logger.info(f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}")
logger.info(
f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}"
)
logger.info(f"Request path: {request.url.path}")
logger.info(f"Request method: {request.method}")
# Check if this is the login endpoint
if request.url.path == "/api-internal/v1/auth/login" and request.method == "POST":
if (
request.url.path == "/api-internal/v1/auth/login"
and request.method == "POST"
):
logger.info(f"=== LOGIN REQUEST DETECTED === {request_id}")
response = await call_next(request)
logger.info(f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}")
logger.info(
f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}"
)
# Capture response body for successful JSON responses
if response.status_code < 400 and isinstance(response, JSONResponse):
try:
response_body = json.loads(response.body.decode('utf-8'))
response_body = json.loads(response.body.decode("utf-8"))
except Exception:
response_body = "[Failed to decode response body]"
except Exception as e:
logger.error(f"Request processing failed: {str(e)}", extra={
logger.error(
f"Request processing failed: {str(e)}",
extra={
"request_id": request_id,
"error": str(e),
"error_type": type(e).__name__
})
"error_type": type(e).__name__,
},
)
response = JSONResponse(
status_code=500,
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
)
# Calculate timing
@@ -107,14 +125,17 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
duration = (end_time - start_time) * 1000 # milliseconds
# Log response
logger.info("=== API RESPONSE DEBUG ===", extra={
logger.info(
"=== API RESPONSE DEBUG ===",
extra={
"request_id": request_id,
"status_code": response.status_code,
"duration_ms": round(duration, 2),
"response_body": response_body,
"response_headers": dict(response.headers),
"timestamp": datetime.utcnow().isoformat()
})
"timestamp": datetime.utcnow().isoformat(),
},
)
return response
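Since this middleware logs full request and response bodies, a plausible pattern is to enable it only behind an environment flag; the flag name below is illustrative, not part of this commit.

import os

from fastapi import FastAPI

from app.middleware.debugging import DebuggingMiddleware  # assumed import path

app = FastAPI()

if os.getenv("ENABLE_DEBUG_MIDDLEWARE") == "1":  # hypothetical opt-in flag
    app.add_middleware(DebuggingMiddleware)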

View File

@@ -9,9 +9,31 @@ from .budget import Budget
from .audit_log import AuditLog
from .rag_collection import RagCollection
from .rag_document import RagDocument
from .chatbot import ChatbotInstance, ChatbotConversation, ChatbotMessage, ChatbotAnalytics
from .chatbot import (
ChatbotInstance,
ChatbotConversation,
ChatbotMessage,
ChatbotAnalytics,
)
from .prompt_template import PromptTemplate, ChatbotPromptVariable
from .plugin import Plugin, PluginConfiguration, PluginInstance, PluginAuditLog, PluginCronJob, PluginAPIGateway
from .plugin import (
Plugin,
PluginConfiguration,
PluginInstance,
PluginAuditLog,
PluginCronJob,
PluginAPIGateway,
)
from .role import Role, RoleLevel
from .tool import Tool, ToolExecution, ToolCategory, ToolType, ToolStatus
from .notification import (
Notification,
NotificationTemplate,
NotificationChannel,
NotificationType,
NotificationPriority,
NotificationStatus,
)
__all__ = [
"User",
@@ -32,5 +54,18 @@ __all__ = [
"PluginInstance",
"PluginAuditLog",
"PluginCronJob",
"PluginAPIGateway"
"PluginAPIGateway",
"Role",
"RoleLevel",
"Tool",
"ToolExecution",
"ToolCategory",
"ToolType",
"ToolStatus",
"Notification",
"NotificationTemplate",
"NotificationChannel",
"NotificationType",
"NotificationPriority",
"NotificationStatus",
]

View File

@@ -3,7 +3,16 @@ API Key model
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
@@ -16,15 +25,21 @@ class APIKey(Base):
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) # Human-readable name for the API key
key_hash = Column(String, unique=True, index=True, nullable=False) # Hashed API key
key_prefix = Column(String, index=True, nullable=False) # First 8 characters for identification
key_prefix = Column(
String, index=True, nullable=False
) # First 8 characters for identification
# User relationship
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="api_keys")
# Related data relationships
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
budgets = relationship(
"Budget", back_populates="api_key", cascade="all, delete-orphan"
)
usage_tracking = relationship(
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
)
# Key status and permissions
is_active = Column(Boolean, default=True)
@@ -40,7 +55,9 @@ class APIKey(Base):
allowed_models = Column(JSON, default=list) # List of allowed LLM models
allowed_endpoints = Column(JSON, default=list) # List of allowed API endpoints
allowed_ips = Column(JSON, default=list) # IP whitelist
allowed_chatbots = Column(JSON, default=list) # List of allowed chatbot IDs for chatbot-specific keys
allowed_chatbots = Column(
JSON, default=list
) # List of allowed chatbot IDs for chatbot-specific keys
# Budget configuration
is_unlimited = Column(Boolean, default=True) # Unlimited budget flag
@@ -63,8 +80,12 @@ class APIKey(Base):
total_cost = Column(Integer, default=0) # In cents
# Relationships
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
usage_tracking = relationship(
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
)
budgets = relationship(
"Budget", back_populates="api_key", cascade="all, delete-orphan"
)
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
def __repr__(self):
@@ -91,14 +112,16 @@ class APIKey(Base):
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"total_requests": self.total_requests,
"total_tokens": self.total_tokens,
"total_cost_cents": self.total_cost,
"is_unlimited": self.is_unlimited,
"budget_limit": self.budget_limit_cents, # Map to budget_limit for API response
"budget_type": self.budget_type
"budget_type": self.budget_type,
}
if include_sensitive:
@@ -222,7 +245,9 @@ class APIKey(Base):
self.allowed_chatbots.remove(chatbot_id)
@classmethod
def create_default_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str) -> "APIKey":
def create_default_key(
cls, user_id: int, name: str, key_hash: str, key_prefix: str
) -> "APIKey":
"""Create a default API key with standard permissions"""
return cls(
name=name,
@@ -230,17 +255,8 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"read": True,
"write": True,
"chat": True,
"embeddings": True
},
scopes=[
"chat.completions",
"embeddings.create",
"models.list"
],
permissions={"read": True, "write": True, "chat": True, "embeddings": True},
scopes=["chat.completions", "embeddings.create", "models.list"],
rate_limit_per_minute=60,
rate_limit_per_hour=3600,
rate_limit_per_day=86400,
@@ -248,12 +264,19 @@ class APIKey(Base):
allowed_endpoints=[], # All endpoints allowed by default
allowed_ips=[], # All IPs allowed by default
description="Default API key with standard permissions",
tags=["default"]
tags=["default"],
)
@classmethod
def create_restricted_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
models: List[str], endpoints: List[str]) -> "APIKey":
def create_restricted_key(
cls,
user_id: int,
name: str,
key_hash: str,
key_prefix: str,
models: List[str],
endpoints: List[str],
) -> "APIKey":
"""Create a restricted API key with limited permissions"""
return cls(
name=name,
@@ -261,13 +284,8 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"read": True,
"chat": True
},
scopes=[
"chat.completions"
],
permissions={"read": True, "chat": True},
scopes=["chat.completions"],
rate_limit_per_minute=30,
rate_limit_per_hour=1800,
rate_limit_per_day=43200,
@@ -275,12 +293,19 @@ class APIKey(Base):
allowed_endpoints=endpoints,
allowed_ips=[],
description="Restricted API key with limited permissions",
tags=["restricted"]
tags=["restricted"],
)
@classmethod
def create_chatbot_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
chatbot_id: str, chatbot_name: str) -> "APIKey":
def create_chatbot_key(
cls,
user_id: int,
name: str,
key_hash: str,
key_prefix: str,
chatbot_id: str,
chatbot_name: str,
) -> "APIKey":
"""Create a chatbot-specific API key"""
return cls(
name=name,
@@ -288,22 +313,18 @@ class APIKey(Base):
key_prefix=key_prefix,
user_id=user_id,
is_active=True,
permissions={
"chatbot": True
},
scopes=[
"chatbot.chat"
],
permissions={"chatbot": True},
scopes=["chatbot.chat"],
rate_limit_per_minute=100,
rate_limit_per_hour=6000,
rate_limit_per_day=144000,
allowed_models=[], # Will use chatbot's configured model
allowed_endpoints=[
f"/api/v1/chatbot/external/{chatbot_id}/chat",
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions"
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions",
],
allowed_ips=[],
allowed_chatbots=[chatbot_id],
description=f"API key for chatbot: {chatbot_name}",
tags=["chatbot", f"chatbot-{chatbot_id}"]
tags=["chatbot", f"chatbot-{chatbot_id}"],
)
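As a usage sketch of the factory above: generate raw key material, hash it, and build a chatbot-scoped key. The hashing scheme and IDs are illustrative; the real key service presumably has its own generation helpers.

import hashlib
import secrets

from app.models.api_key import APIKey  # assumed import path

raw_key = "sk-" + secrets.token_urlsafe(32)  # illustrative key format
key = APIKey.create_chatbot_key(
    user_id=1,
    name="Support bot key",
    key_hash=hashlib.sha256(raw_key.encode()).hexdigest(),
    key_prefix=raw_key[:8],
    chatbot_id="support-bot",
    chatbot_name="Support Bot",
)
# The factory scopes key.allowed_endpoints to the two external chat routes
# and sets key.allowed_chatbots == ["support-bot"].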

View File

@@ -3,7 +3,16 @@ Audit log model for tracking system events and user actions
"""
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey, Text, Boolean
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
JSON,
ForeignKey,
Text,
Boolean,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
from enum import Enum
@@ -11,6 +20,7 @@ from enum import Enum
class AuditAction(str, Enum):
"""Audit action types"""
CREATE = "create"
READ = "read"
UPDATE = "update"
@@ -32,6 +42,7 @@ class AuditAction(str, Enum):
class AuditSeverity(str, Enum):
"""Audit severity levels"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
@@ -51,7 +62,9 @@ class AuditLog(Base):
# Event details
action = Column(String, nullable=False)
resource_type = Column(String, nullable=False) # user, api_key, budget, module, etc.
resource_type = Column(
String, nullable=False
) # user, api_key, budget, module, etc.
resource_id = Column(String, nullable=True) # ID of the affected resource
# Event description and details
@@ -74,7 +87,9 @@ class AuditLog(Base):
# Additional metadata
tags = Column(JSON, default=list)
audit_metadata = Column("metadata", JSON, default=dict) # Map to 'metadata' column in DB
audit_metadata = Column(
"metadata", JSON, default=dict
) # Map to 'metadata' column in DB
# Before/after values for data changes
old_values = Column(JSON, nullable=True)
@@ -84,7 +99,9 @@ class AuditLog(Base):
created_at = Column(DateTime, default=datetime.utcnow, index=True)
def __repr__(self):
return f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
return (
f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
)
def to_dict(self):
"""Convert audit log to dictionary for API responses"""
@@ -108,7 +125,7 @@ class AuditLog(Base):
"metadata": self.audit_metadata,
"old_values": self.old_values,
"new_values": self.new_values,
"created_at": self.created_at.isoformat() if self.created_at else None
"created_at": self.created_at.isoformat() if self.created_at else None,
}
def is_security_event(self) -> bool:
@@ -120,7 +137,7 @@ class AuditLog(Base):
AuditAction.API_KEY_DELETE,
AuditAction.PERMISSION_GRANT,
AuditAction.PERMISSION_REVOKE,
AuditAction.SECURITY_EVENT
AuditAction.SECURITY_EVENT,
]
return self.action in security_actions or self.category == "security"
@@ -150,9 +167,15 @@ class AuditLog(Base):
self.new_values = new_values
@classmethod
def create_login_event(cls, user_id: int, success: bool = True,
ip_address: str = None, user_agent: str = None,
session_id: str = None, error_message: str = None) -> "AuditLog":
def create_login_event(
cls,
user_id: int,
success: bool = True,
ip_address: str = None,
user_agent: str = None,
session_id: str = None,
error_message: str = None,
) -> "AuditLog":
"""Create a login audit event"""
return cls(
user_id=user_id,
@@ -160,10 +183,7 @@ class AuditLog(Base):
resource_type="user",
resource_id=str(user_id),
description=f"User login {'successful' if success else 'failed'}",
details={
"login_method": "password",
"success": success
},
details={"login_method": "password", "success": success},
ip_address=ip_address,
user_agent=user_agent,
session_id=session_id,
@@ -171,7 +191,7 @@ class AuditLog(Base):
category="security",
success=success,
error_message=error_message,
tags=["authentication", "login"]
tags=["authentication", "login"],
)
@classmethod
@@ -183,20 +203,24 @@ class AuditLog(Base):
resource_type="user",
resource_id=str(user_id),
description="User logout",
details={
"logout_method": "manual"
},
details={"logout_method": "manual"},
session_id=session_id,
severity=AuditSeverity.LOW,
category="security",
success=True,
tags=["authentication", "logout"]
tags=["authentication", "logout"],
)
@classmethod
def create_api_key_event(cls, user_id: int, action: str, api_key_id: int,
api_key_name: str, success: bool = True,
error_message: str = None) -> "AuditLog":
def create_api_key_event(
cls,
user_id: int,
action: str,
api_key_id: int,
api_key_name: str,
success: bool = True,
error_message: str = None,
) -> "AuditLog":
"""Create an API key audit event"""
return cls(
user_id=user_id,
@@ -204,21 +228,24 @@ class AuditLog(Base):
resource_type="api_key",
resource_id=str(api_key_id),
description=f"API key {action}: {api_key_name}",
details={
"api_key_name": api_key_name,
"action": action
},
details={"api_key_name": api_key_name, "action": action},
severity=AuditSeverity.MEDIUM,
category="security",
success=success,
error_message=error_message,
tags=["api_key", action]
tags=["api_key", action],
)
@classmethod
def create_budget_event(cls, user_id: int, action: str, budget_id: int,
budget_name: str, details: Dict[str, Any] = None,
success: bool = True) -> "AuditLog":
def create_budget_event(
cls,
user_id: int,
action: str,
budget_id: int,
budget_name: str,
details: Dict[str, Any] = None,
success: bool = True,
) -> "AuditLog":
"""Create a budget audit event"""
return cls(
user_id=user_id,
@@ -227,16 +254,24 @@ class AuditLog(Base):
resource_id=str(budget_id),
description=f"Budget {action}: {budget_name}",
details=details or {},
severity=AuditSeverity.MEDIUM if action == AuditAction.BUDGET_EXCEED else AuditSeverity.LOW,
severity=AuditSeverity.MEDIUM
if action == AuditAction.BUDGET_EXCEED
else AuditSeverity.LOW,
category="financial",
success=success,
tags=["budget", action]
tags=["budget", action],
)
@classmethod
def create_module_event(cls, user_id: int, action: str, module_name: str,
success: bool = True, error_message: str = None,
details: Dict[str, Any] = None) -> "AuditLog":
def create_module_event(
cls,
user_id: int,
action: str,
module_name: str,
success: bool = True,
error_message: str = None,
details: Dict[str, Any] = None,
) -> "AuditLog":
"""Create a module audit event"""
return cls(
user_id=user_id,
@@ -249,12 +284,18 @@ class AuditLog(Base):
category="system",
success=success,
error_message=error_message,
tags=["module", action]
tags=["module", action],
)
@classmethod
def create_permission_event(cls, user_id: int, action: str, target_user_id: int,
permission: str, success: bool = True) -> "AuditLog":
def create_permission_event(
cls,
user_id: int,
action: str,
target_user_id: int,
permission: str,
success: bool = True,
) -> "AuditLog":
"""Create a permission audit event"""
return cls(
user_id=user_id,
@@ -262,21 +303,23 @@ class AuditLog(Base):
resource_type="permission",
resource_id=str(target_user_id),
description=f"Permission {action}: {permission} for user {target_user_id}",
details={
"permission": permission,
"target_user_id": target_user_id
},
details={"permission": permission, "target_user_id": target_user_id},
severity=AuditSeverity.HIGH,
category="security",
success=success,
tags=["permission", action]
tags=["permission", action],
)
@classmethod
def create_security_event(cls, user_id: int, event_type: str, description: str,
def create_security_event(
cls,
user_id: int,
event_type: str,
description: str,
severity: str = AuditSeverity.HIGH,
details: Dict[str, Any] = None,
ip_address: str = None) -> "AuditLog":
ip_address: str = None,
) -> "AuditLog":
"""Create a security audit event"""
return cls(
user_id=user_id,
@@ -289,15 +332,19 @@ class AuditLog(Base):
severity=severity,
category="security",
success=False, # Security events are typically failures
tags=["security", event_type]
tags=["security", event_type],
)
@classmethod
def create_system_event(cls, action: str, description: str,
def create_system_event(
cls,
action: str,
description: str,
resource_type: str = "system",
resource_id: str = None,
severity: str = AuditSeverity.LOW,
details: Dict[str, Any] = None) -> "AuditLog":
details: Dict[str, Any] = None,
) -> "AuditLog":
"""Create a system audit event"""
return cls(
user_id=None, # System events don't have a user
@@ -309,14 +356,20 @@ class AuditLog(Base):
severity=severity,
category="system",
success=True,
tags=["system", action]
tags=["system", action],
)
@classmethod
def create_data_change_event(cls, user_id: int, action: str, resource_type: str,
resource_id: str, description: str,
def create_data_change_event(
cls,
user_id: int,
action: str,
resource_type: str,
resource_id: str,
description: str,
old_values: Dict[str, Any],
new_values: Dict[str, Any]) -> "AuditLog":
new_values: Dict[str, Any],
) -> "AuditLog":
"""Create a data change audit event"""
return cls(
user_id=user_id,
@@ -329,7 +382,7 @@ class AuditLog(Base):
severity=AuditSeverity.LOW,
category="data",
success=True,
tags=["data_change", action]
tags=["data_change", action],
)
def get_summary(self) -> Dict[str, Any]:
@@ -342,5 +395,5 @@ class AuditLog(Base):
"severity": self.severity,
"success": self.success,
"created_at": self.created_at.isoformat() if self.created_at else None,
"user_id": self.user_id
"user_id": self.user_id,
}
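A minimal sketch of using these factories outside the middleware, persisting the event with the same get_db_session helper the middleware imports; the wrapper function itself is hypothetical.

from app.db.database import get_db_session
from app.models.audit_log import AuditLog  # assumed import path


async def record_failed_login(ip: str, user_agent: str) -> None:
    event = AuditLog.create_login_event(
        user_id=None,  # unknown user on a failed attempt, as in the middleware
        success=False,
        ip_address=ip,
        user_agent=user_agent,
        error_message="HTTP 401",
    )
    async with get_db_session() as db:
        db.add(event)
        await db.commit()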

View File

@@ -5,13 +5,24 @@ Budget model for managing spending limits and cost control
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey, Float
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class BudgetType(str, Enum):
"""Budget type enumeration"""
USER = "user"
API_KEY = "api_key"
GLOBAL = "global"
@@ -19,6 +30,7 @@ class BudgetType(str, Enum):
class BudgetPeriod(str, Enum):
"""Budget period types"""
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
@@ -38,18 +50,26 @@ class Budget(Base):
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="budgets")
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True) # Optional: specific to an API key
api_key_id = Column(
Integer, ForeignKey("api_keys.id"), nullable=True
) # Optional: specific to an API key
api_key = relationship("APIKey", back_populates="budgets")
# Usage tracking relationship
usage_tracking = relationship("UsageTracking", back_populates="budget", cascade="all, delete-orphan")
usage_tracking = relationship(
"UsageTracking", back_populates="budget", cascade="all, delete-orphan"
)
# Budget limits (in cents)
limit_cents = Column(Integer, nullable=False) # Maximum spend limit
warning_threshold_cents = Column(Integer, nullable=True) # Warning threshold (e.g., 80% of limit)
warning_threshold_cents = Column(
Integer, nullable=True
) # Warning threshold (e.g., 80% of limit)
# Time period settings
period_type = Column(String, nullable=False, default="monthly") # daily, weekly, monthly, yearly, custom
period_type = Column(
String, nullable=False, default="monthly"
) # daily, weekly, monthly, yearly, custom
period_start = Column(DateTime, nullable=False) # Start of current period
period_end = Column(DateTime, nullable=False) # End of current period
@@ -62,12 +82,18 @@ class Budget(Base):
is_warning_sent = Column(Boolean, default=False)
# Enforcement settings
enforce_hard_limit = Column(Boolean, default=True) # Block requests when limit exceeded
enforce_hard_limit = Column(
Boolean, default=True
) # Block requests when limit exceeded
enforce_warning = Column(Boolean, default=True) # Send warnings at threshold
# Allowed resources (optional filters)
allowed_models = Column(JSON, default=list) # Specific models this budget applies to
allowed_endpoints = Column(JSON, default=list) # Specific endpoints this budget applies to
allowed_models = Column(
JSON, default=list
) # Specific models this budget applies to
allowed_endpoints = Column(
JSON, default=list
) # Specific endpoints this budget applies to
# Metadata
description = Column(Text, nullable=True)
@@ -75,8 +101,12 @@ class Budget(Base):
currency = Column(String, default="USD")
# Auto-renewal settings
auto_renew = Column(Boolean, default=True) # Automatically renew budget for next period
rollover_unused = Column(Boolean, default=False) # Rollover unused budget to next period
auto_renew = Column(
Boolean, default=True
) # Automatically renew budget for next period
rollover_unused = Column(
Boolean, default=False
) # Rollover unused budget to next period
# Notification settings
notification_settings = Column(JSON, default=dict) # Email, webhook, etc.
@@ -99,15 +129,23 @@ class Budget(Base):
"limit_cents": self.limit_cents,
"limit_dollars": self.limit_cents / 100,
"warning_threshold_cents": self.warning_threshold_cents,
"warning_threshold_dollars": self.warning_threshold_cents / 100 if self.warning_threshold_cents else None,
"warning_threshold_dollars": self.warning_threshold_cents / 100
if self.warning_threshold_cents
else None,
"period_type": self.period_type,
"period_start": self.period_start.isoformat() if self.period_start else None,
"period_start": self.period_start.isoformat()
if self.period_start
else None,
"period_end": self.period_end.isoformat() if self.period_end else None,
"current_usage_cents": self.current_usage_cents,
"current_usage_dollars": self.current_usage_cents / 100,
"remaining_cents": max(0, self.limit_cents - self.current_usage_cents),
"remaining_dollars": max(0, (self.limit_cents - self.current_usage_cents) / 100),
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100) if self.limit_cents > 0 else 0,
"remaining_dollars": max(
0, (self.limit_cents - self.current_usage_cents) / 100
),
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100)
if self.limit_cents > 0
else 0,
"is_active": self.is_active,
"is_exceeded": self.is_exceeded,
"is_warning_sent": self.is_warning_sent,
@@ -123,7 +161,9 @@ class Budget(Base):
"notification_settings": self.notification_settings,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_reset_at": self.last_reset_at.isoformat() if self.last_reset_at else None
"last_reset_at": self.last_reset_at.isoformat()
if self.last_reset_at
else None,
}
def is_in_period(self) -> bool:
@@ -161,7 +201,10 @@ class Budget(Base):
self.is_exceeded = True
# Check if warning threshold is reached
if self.warning_threshold_cents and self.current_usage_cents >= self.warning_threshold_cents:
if (
self.warning_threshold_cents
and self.current_usage_cents >= self.warning_threshold_cents
):
if not self.is_warning_sent:
self.is_warning_sent = True
@@ -190,9 +233,13 @@ class Budget(Base):
self.period_start = self.period_end
# Handle month boundaries properly
if self.period_start.month == 12:
next_month = self.period_start.replace(year=self.period_start.year + 1, month=1)
next_month = self.period_start.replace(
year=self.period_start.year + 1, month=1
)
else:
next_month = self.period_start.replace(month=self.period_start.month + 1)
next_month = self.period_start.replace(
month=self.period_start.month + 1
)
self.period_end = next_month
elif self.period_type == "yearly":
self.period_start = self.period_end
@@ -230,7 +277,7 @@ class Budget(Base):
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None,
warning_threshold_percentage: float = 0.8
warning_threshold_percentage: float = 0.8,
) -> "Budget":
"""Create a monthly budget"""
now = datetime.utcnow()
@@ -258,10 +305,7 @@ class Budget(Base):
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True,
notification_settings={
"email_on_warning": True,
"email_on_exceeded": True
}
notification_settings={"email_on_warning": True, "email_on_exceeded": True},
)
@classmethod
@@ -270,7 +314,7 @@ class Budget(Base):
user_id: int,
name: str,
limit_dollars: float,
api_key_id: Optional[int] = None
api_key_id: Optional[int] = None,
) -> "Budget":
"""Create a daily budget"""
now = datetime.utcnow()
@@ -292,5 +336,5 @@ class Budget(Base):
is_active=True,
enforce_hard_limit=True,
enforce_warning=True,
auto_renew=True
auto_renew=True,
)
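To illustrate the monthly factory and the cent-based accounting: create a $50 budget with an 80% warning threshold, record some spend, and read the derived fields from to_dict(). The values are illustrative.

from app.models.budget import Budget  # assumed import path

budget = Budget.create_monthly_budget(
    user_id=1,
    name="Team LLM spend",
    limit_dollars=50.0,  # stored internally as 5000 cents
    warning_threshold_percentage=0.8,
)

# Record $1.23 of spend directly on the counter (column defaults are not yet
# applied on an unflushed instance, hence the `or 0`).
budget.current_usage_cents = (budget.current_usage_cents or 0) + 123

summary = budget.to_dict()
print(summary["remaining_dollars"], summary["usage_percentage"])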

View File

@@ -1,7 +1,16 @@
"""
Database models for chatbot module
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey
from sqlalchemy import (
Column,
Integer,
String,
Text,
Boolean,
DateTime,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
@@ -9,8 +18,10 @@ import uuid
from app.db.database import Base
class ChatbotInstance(Base):
"""Configured chatbot instance"""
__tablename__ = "chatbot_instances"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
@@ -27,7 +38,9 @@ class ChatbotInstance(Base):
is_active = Column(Boolean, default=True)
# Relationships
conversations = relationship("ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan")
conversations = relationship(
"ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ChatbotInstance(id='{self.id}', name='{self.name}')>"
@@ -35,6 +48,7 @@ class ChatbotInstance(Base):
class ChatbotConversation(Base):
"""Conversation state and history"""
__tablename__ = "chatbot_conversations"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
@@ -52,7 +66,9 @@ class ChatbotConversation(Base):
# Relationships
chatbot = relationship("ChatbotInstance", back_populates="conversations")
messages = relationship("ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan")
messages = relationship(
"ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<ChatbotConversation(id='{self.id}', chatbot_id='{self.chatbot_id}')>"
@@ -60,10 +76,13 @@ class ChatbotConversation(Base):
class ChatbotMessage(Base):
"""Individual chat messages in conversations"""
__tablename__ = "chatbot_messages"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
conversation_id = Column(String, ForeignKey("chatbot_conversations.id"), nullable=False)
conversation_id = Column(
String, ForeignKey("chatbot_conversations.id"), nullable=False
)
# Message content
role = Column(String(20), nullable=False) # 'user', 'assistant', 'system'
@@ -85,6 +104,7 @@ class ChatbotMessage(Base):
class ChatbotAnalytics(Base):
"""Analytics and metrics for chatbot usage"""
__tablename__ = "chatbot_analytics"
id = Column(Integer, primary_key=True, autoincrement=True)
@@ -92,7 +112,9 @@ class ChatbotAnalytics(Base):
user_id = Column(String, nullable=False)
# Event tracking
event_type = Column(String(50), nullable=False) # 'message_sent', 'response_generated', etc.
event_type = Column(
String(50), nullable=False
) # 'message_sent', 'response_generated', etc.
event_data = Column(JSON, default=dict)
# Performance metrics

View File

@@ -10,6 +10,7 @@ from enum import Enum
class ModuleStatus(str, Enum):
"""Module status types"""
ACTIVE = "active"
INACTIVE = "inactive"
ERROR = "error"
@@ -19,6 +20,7 @@ class ModuleStatus(str, Enum):
class ModuleType(str, Enum):
"""Module type categories"""
CORE = "core"
INTERCEPTOR = "interceptor"
ANALYTICS = "analytics"
@@ -66,14 +68,20 @@ class Module(Base):
entry_point = Column(String, nullable=True) # Main module entry point
# Interceptor configuration
interceptor_chains = Column(JSON, default=list) # Which chains this module hooks into
interceptor_chains = Column(
JSON, default=list
) # Which chains this module hooks into
execution_order = Column(Integer, default=100) # Order in interceptor chain
# API endpoints
api_endpoints = Column(JSON, default=list) # List of API endpoints this module provides
api_endpoints = Column(
JSON, default=list
) # List of API endpoints this module provides
# Permissions and security
required_permissions = Column(JSON, default=list) # Permissions required to use this module
required_permissions = Column(
JSON, default=list
) # Permissions required to use this module
security_level = Column(String, default="low") # low, medium, high, critical
# Metadata
@@ -130,16 +138,22 @@ class Module(Base):
"metadata": self.module_metadata,
"last_error": self.last_error,
"error_count": self.error_count,
"last_started": self.last_started.isoformat() if self.last_started else None,
"last_stopped": self.last_stopped.isoformat() if self.last_stopped else None,
"last_started": self.last_started.isoformat()
if self.last_started
else None,
"last_stopped": self.last_stopped.isoformat()
if self.last_stopped
else None,
"request_count": self.request_count,
"success_count": self.success_count,
"error_count_runtime": self.error_count_runtime,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"installed_at": self.installed_at.isoformat() if self.installed_at else None,
"installed_at": self.installed_at.isoformat()
if self.installed_at
else None,
"success_rate": self.get_success_rate(),
"uptime": self.get_uptime_seconds() if self.is_running() else 0
"uptime": self.get_uptime_seconds() if self.is_running() else 0,
}
def is_running(self) -> bool:
@@ -315,8 +329,14 @@ class Module(Base):
self.required_permissions.remove(permission)
@classmethod
def create_core_module(cls, name: str, display_name: str, description: str,
version: str, entry_point: str) -> "Module":
def create_core_module(
cls,
name: str,
display_name: str,
description: str,
version: str,
entry_point: str,
) -> "Module":
"""Create a core module"""
return cls(
name=name,
@@ -341,7 +361,7 @@ class Module(Base):
required_permissions=[],
security_level="high",
tags=["core"],
module_metadata={}
module_metadata={},
)
@classmethod
@@ -365,20 +385,12 @@ class Module(Base):
"properties": {
"provider": {"type": "string", "enum": ["redis"]},
"ttl": {"type": "integer", "minimum": 60},
"max_size": {"type": "integer", "minimum": 1000}
"max_size": {"type": "integer", "minimum": 1000},
},
"required": ["provider", "ttl"]
},
config_values={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
},
default_config={
"provider": "redis",
"ttl": 3600,
"max_size": 10000
"required": ["provider", "ttl"],
},
config_values={"provider": "redis", "ttl": 3600, "max_size": 10000},
default_config={"provider": "redis", "ttl": 3600, "max_size": 10000},
dependencies=[],
conflicts=[],
interceptor_chains=["pre_request", "post_response"],
@@ -387,7 +399,7 @@ class Module(Base):
required_permissions=["cache.read", "cache.write"],
security_level="low",
tags=["cache", "performance"],
module_metadata={}
module_metadata={},
)
@classmethod
@@ -412,21 +424,21 @@ class Module(Base):
"vector_db": {"type": "string", "enum": ["qdrant"]},
"embedding_model": {"type": "string"},
"chunk_size": {"type": "integer", "minimum": 100},
"max_results": {"type": "integer", "minimum": 1}
"max_results": {"type": "integer", "minimum": 1},
},
"required": ["vector_db", "embedding_model"]
"required": ["vector_db", "embedding_model"],
},
config_values={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
"max_results": 10,
},
default_config={
"vector_db": "qdrant",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"chunk_size": 512,
"max_results": 10
"max_results": 10,
},
dependencies=[],
conflicts=[],
@@ -436,7 +448,7 @@ class Module(Base):
required_permissions=["rag.read", "rag.write"],
security_level="medium",
tags=["rag", "ai", "search"],
module_metadata={}
module_metadata={},
)
@classmethod
@@ -460,19 +472,19 @@ class Module(Base):
"properties": {
"track_requests": {"type": "boolean"},
"track_responses": {"type": "boolean"},
"retention_days": {"type": "integer", "minimum": 1}
"retention_days": {"type": "integer", "minimum": 1},
},
"required": ["track_requests", "track_responses"]
"required": ["track_requests", "track_responses"],
},
config_values={
"track_requests": True,
"track_responses": True,
"retention_days": 30
"retention_days": 30,
},
default_config={
"track_requests": True,
"track_responses": True,
"retention_days": 30
"retention_days": 30,
},
dependencies=[],
conflicts=[],
@@ -482,7 +494,7 @@ class Module(Base):
required_permissions=["analytics.read"],
security_level="low",
tags=["analytics", "monitoring"],
module_metadata={}
module_metadata={},
)
def get_health_status(self) -> Dict[str, Any]:
@@ -495,5 +507,7 @@ class Module(Base):
"uptime_seconds": self.get_uptime_seconds() if self.is_running() else 0,
"last_error": self.last_error,
"error_count": self.error_count_runtime,
"last_started": self.last_started.isoformat() if self.last_started else None
"last_started": self.last_started.isoformat()
if self.last_started
else None,
}
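A short sketch of the core-module factory whose full signature appears above; the entry-point string is illustrative.

from app.models.module import Module  # assumed import path

mod = Module.create_core_module(
    name="auth",
    display_name="Authentication",
    description="Core authentication services",
    version="1.0.0",
    entry_point="app.modules.auth:main",  # illustrative entry point
)
# Per the factory defaults visible in this diff, the result carries
# security_level="high" and tags=["core"].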

View File

@@ -0,0 +1,295 @@
"""
Notification models for multi-channel communication
"""
from datetime import datetime
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class NotificationType(str, Enum):
"""Notification types"""
EMAIL = "email"
WEBHOOK = "webhook"
SLACK = "slack"
DISCORD = "discord"
SMS = "sms"
PUSH = "push"
class NotificationPriority(str, Enum):
"""Notification priority levels"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
URGENT = "urgent"
class NotificationStatus(str, Enum):
"""Notification delivery status"""
PENDING = "pending"
SENT = "sent"
DELIVERED = "delivered"
FAILED = "failed"
RETRY = "retry"
class NotificationTemplate(Base):
"""Notification template model"""
__tablename__ = "notification_templates"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, unique=True)
display_name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Template content
notification_type = Column(String(20), nullable=False) # NotificationType enum
subject_template = Column(Text, nullable=True) # For email/messages
body_template = Column(Text, nullable=False) # Main content
html_template = Column(Text, nullable=True) # HTML version for email
# Configuration
default_priority = Column(String(20), default=NotificationPriority.NORMAL)
variables = Column(JSON, default=dict) # Expected template variables
template_metadata = Column(JSON, default=dict) # Additional configuration
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
notifications = relationship("Notification", back_populates="template")
def __repr__(self):
return f"<NotificationTemplate(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
def to_dict(self):
"""Convert template to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"notification_type": self.notification_type,
"subject_template": self.subject_template,
"body_template": self.body_template,
"html_template": self.html_template,
"default_priority": self.default_priority,
"variables": self.variables,
"template_metadata": self.template_metadata,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class NotificationChannel(Base):
"""Notification channel configuration"""
__tablename__ = "notification_channels"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False)
display_name = Column(String(200), nullable=False)
notification_type = Column(String(20), nullable=False) # NotificationType enum
# Channel configuration
config = Column(JSON, nullable=False) # Channel-specific settings
credentials = Column(JSON, nullable=True) # Encrypted credentials
# Settings
is_active = Column(Boolean, default=True)
is_default = Column(Boolean, default=False)
rate_limit = Column(Integer, default=100) # Messages per minute
retry_count = Column(Integer, default=3)
retry_delay_minutes = Column(Integer, default=5)
# Health monitoring
last_used_at = Column(DateTime, nullable=True)
success_count = Column(Integer, default=0)
failure_count = Column(Integer, default=0)
last_error = Column(Text, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
notifications = relationship("Notification", back_populates="channel")
def __repr__(self):
return f"<NotificationChannel(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
def to_dict(self):
"""Convert channel to dictionary (excluding sensitive credentials)"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"notification_type": self.notification_type,
"config": self.config,
"is_active": self.is_active,
"is_default": self.is_default,
"rate_limit": self.rate_limit,
"retry_count": self.retry_count,
"retry_delay_minutes": self.retry_delay_minutes,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"success_count": self.success_count,
"failure_count": self.failure_count,
"last_error": self.last_error,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def update_stats(self, success: bool, error_message: Optional[str] = None):
"""Update channel statistics"""
self.last_used_at = datetime.utcnow()
if success:
self.success_count += 1
self.last_error = None
else:
self.failure_count += 1
self.last_error = error_message
class Notification(Base):
"""Individual notification instance"""
__tablename__ = "notifications"
id = Column(Integer, primary_key=True, index=True)
# Content
subject = Column(String(500), nullable=True)
body = Column(Text, nullable=False)
html_body = Column(Text, nullable=True)
# Recipients
recipients = Column(JSON, nullable=False) # List of recipient addresses/IDs
cc_recipients = Column(JSON, nullable=True) # CC recipients (for email)
bcc_recipients = Column(JSON, nullable=True) # BCC recipients (for email)
# Configuration
priority = Column(String(20), default=NotificationPriority.NORMAL)
scheduled_at = Column(DateTime, nullable=True) # For scheduled delivery
expires_at = Column(DateTime, nullable=True) # Expiration time
# References
template_id = Column(
Integer, ForeignKey("notification_templates.id"), nullable=True
)
channel_id = Column(
Integer, ForeignKey("notification_channels.id"), nullable=False
)
user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Triggering user
# Status tracking
status = Column(String(20), default=NotificationStatus.PENDING)
attempts = Column(Integer, default=0)
max_attempts = Column(Integer, default=3)
# Delivery tracking
sent_at = Column(DateTime, nullable=True)
delivered_at = Column(DateTime, nullable=True)
failed_at = Column(DateTime, nullable=True)
error_message = Column(Text, nullable=True)
# External references
external_id = Column(String(200), nullable=True) # Provider message ID
callback_url = Column(String(500), nullable=True) # Delivery callback
# Metadata
notification_metadata = Column(JSON, default=dict)
tags = Column(JSON, default=list)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
template = relationship("NotificationTemplate", back_populates="notifications")
channel = relationship("NotificationChannel", back_populates="notifications")
user = relationship("User", back_populates="notifications")
def __repr__(self):
return f"<Notification(id={self.id}, status='{self.status}', channel='{self.channel.name if self.channel else 'unknown'}')>"
def to_dict(self):
"""Convert notification to dictionary"""
return {
"id": self.id,
"subject": self.subject,
"body": self.body,
"html_body": self.html_body,
"recipients": self.recipients,
"cc_recipients": self.cc_recipients,
"bcc_recipients": self.bcc_recipients,
"priority": self.priority,
"scheduled_at": self.scheduled_at.isoformat()
if self.scheduled_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"template_id": self.template_id,
"channel_id": self.channel_id,
"user_id": self.user_id,
"status": self.status,
"attempts": self.attempts,
"max_attempts": self.max_attempts,
"sent_at": self.sent_at.isoformat() if self.sent_at else None,
"delivered_at": self.delivered_at.isoformat()
if self.delivered_at
else None,
"failed_at": self.failed_at.isoformat() if self.failed_at else None,
"error_message": self.error_message,
"external_id": self.external_id,
"callback_url": self.callback_url,
"notification_metadata": self.notification_metadata,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def mark_sent(self, external_id: Optional[str] = None):
"""Mark notification as sent"""
self.status = NotificationStatus.SENT
self.sent_at = datetime.utcnow()
self.external_id = external_id
def mark_delivered(self):
"""Mark notification as delivered"""
self.status = NotificationStatus.DELIVERED
self.delivered_at = datetime.utcnow()
def mark_failed(self, error_message: str):
"""Mark notification as failed"""
self.status = NotificationStatus.FAILED
self.failed_at = datetime.utcnow()
self.error_message = error_message
self.attempts += 1
def can_retry(self) -> bool:
"""Check if notification can be retried"""
return (
self.status in [NotificationStatus.FAILED, NotificationStatus.RETRY]
and self.attempts < self.max_attempts
and (self.expires_at is None or self.expires_at > datetime.utcnow())
)
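
A sketch of the retry loop these helpers are designed for; the provider client is hypothetical and error handling is simplified:

def attempt_delivery(notification, provider) -> bool:
    # The first attempt runs while the notification is still PENDING;
    # later attempts are gated by can_retry() until max_attempts is hit.
    while notification.status == NotificationStatus.PENDING or notification.can_retry():
        try:
            message_id = provider.send(notification.recipients, notification.body)
            notification.mark_sent(external_id=message_id)
            return True
        except Exception as exc:
            notification.mark_failed(str(exc))
    return False
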


@@ -2,7 +2,17 @@
Plugin System Database Models
Defines the database schema for the isolated plugin architecture
"""
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
Boolean,
JSON,
ForeignKey,
Index,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -13,12 +23,15 @@ from app.db.database import Base
class Plugin(Base):
"""Plugin registry - tracks all installed plugins"""
__tablename__ = "plugins"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(100), unique=True, nullable=False, index=True)
slug = Column(
String(100), unique=True, nullable=False, index=True
) # URL-safe identifier
# Metadata
display_name = Column(String(200), nullable=False)
@@ -66,20 +79,29 @@ class Plugin(Base):
# Relationships
installed_by_user = relationship("User", back_populates="installed_plugins")
configurations = relationship(
"PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan"
)
instances = relationship(
"PluginInstance", back_populates="plugin", cascade="all, delete-orphan"
)
audit_logs = relationship(
"PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan"
)
cron_jobs = relationship(
"PluginCronJob", back_populates="plugin", cascade="all, delete-orphan"
)
# Indexes for performance
__table_args__ = (
Index("idx_plugin_status_enabled", "status", "enabled"),
Index("idx_plugin_user_status", "installed_by_user_id", "status"),
)
class PluginConfiguration(Base):
"""Plugin configuration instances - per user/environment configs"""
__tablename__ = "plugin_configurations"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -107,17 +129,20 @@ class PluginConfiguration(Base):
# Constraints
__table_args__ = (
Index("idx_plugin_config_user_active", "plugin_id", "user_id", "is_active"),
)
class PluginInstance(Base):
"""Plugin runtime instances - tracks running plugin processes"""
__tablename__ = "plugin_instances"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
configuration_id = Column(
UUID(as_uuid=True), ForeignKey("plugin_configurations.id")
)
# Runtime information
instance_name = Column(String(200), nullable=False)
@@ -148,13 +173,12 @@ class PluginInstance(Base):
plugin = relationship("Plugin", back_populates="instances")
configuration = relationship("PluginConfiguration")
__table_args__ = (Index("idx_plugin_instance_status", "plugin_id", "status"),)
class PluginAuditLog(Base):
"""Audit logging for all plugin activities"""
__tablename__ = "plugin_audit_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -162,7 +186,9 @@ class PluginAuditLog(Base):
instance_id = Column(UUID(as_uuid=True), ForeignKey("plugin_instances.id"))
# Event details
event_type = Column(
String(50), nullable=False, index=True
) # api_call, config_change, error, etc.
action = Column(String(100), nullable=False)
resource = Column(String(200)) # Resource being accessed
@@ -194,14 +220,15 @@ class PluginAuditLog(Base):
api_key = relationship("APIKey")
__table_args__ = (
Index("idx_plugin_audit_plugin_time", "plugin_id", "timestamp"),
Index("idx_plugin_audit_user_time", "user_id", "timestamp"),
Index("idx_plugin_audit_event_type", "event_type", "timestamp"),
)
class PluginCronJob(Base):
"""Plugin scheduled jobs and cron tasks"""
__tablename__ = "plugin_cron_jobs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -209,7 +236,9 @@ class PluginCronJob(Base):
# Job identification
job_name = Column(String(200), nullable=False)
job_id = Column(
String(100), nullable=False, unique=True, index=True
) # Unique scheduler ID
# Schedule configuration
schedule = Column(String(100), nullable=False) # Cron expression
@@ -245,25 +274,32 @@ class PluginCronJob(Base):
created_by_user = relationship("User")
__table_args__ = (
Index("idx_plugin_cron_next_run", "enabled", "next_run_at"),
Index("idx_plugin_cron_plugin", "plugin_id", "enabled"),
)
class PluginAPIGateway(Base):
"""API gateway configuration for plugin routing"""
__tablename__ = "plugin_api_gateways"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
plugin_id = Column(
UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True
)
# API routing configuration
base_path = Column(
String(200), nullable=False, unique=True
) # /api/v1/plugins/zammad
internal_url = Column(String(500), nullable=False) # http://plugin-zammad:8000
# Security settings
require_authentication = Column(Boolean, default=True, nullable=False)
allowed_methods = Column(
JSON, default=["GET", "POST", "PUT", "DELETE"]
) # HTTP methods
rate_limit_per_minute = Column(Integer, default=60)
rate_limit_per_hour = Column(Integer, default=1000)
@@ -303,8 +339,10 @@ Add to APIKey model:
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
"""
class PluginPermission(Base):
"""Plugin permission grants - tracks user permissions for plugins"""
__tablename__ = "plugin_permissions"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
@@ -312,8 +350,12 @@ class PluginPermission(Base):
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Permission details
permission_name = Column(
String(200), nullable=False
) # e.g., 'chatbot:invoke', 'rag:query'
granted = Column(
Boolean, default=True, nullable=False
) # True=granted, False=revoked
# Grant/revoke tracking
granted_at = Column(DateTime, nullable=False, default=func.now())
@@ -332,7 +374,7 @@ class PluginPermission(Base):
revoked_by_user = relationship("User", foreign_keys=[revoked_by_user_id])
__table_args__ = (
Index("idx_plugin_permission_user_plugin", "user_id", "plugin_id"),
Index("idx_plugin_permission_plugin_name", "plugin_id", "permission_name"),
Index("idx_plugin_permission_active", "plugin_id", "user_id", "granted"),
)


@@ -10,18 +10,23 @@ from datetime import datetime
class PromptTemplate(Base):
"""Editable prompt templates for different chatbot types"""
__tablename__ = "prompt_templates"
id = Column(String, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True) # Human readable name
type_key = Column(
String(100), nullable=False, unique=True, index=True
) # assistant, customer_support, etc.
description = Column(Text, nullable=True)
system_prompt = Column(Text, nullable=False)
is_default = Column(Boolean, default=True, nullable=False)
is_active = Column(Boolean, default=True, nullable=False)
version = Column(Integer, default=1, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self):
return f"<PromptTemplate(type_key='{self.type_key}', name='{self.name}')>"
@@ -29,10 +34,13 @@ class PromptTemplate(Base):
class ChatbotPromptVariable(Base):
"""Available variables that can be used in prompts"""
__tablename__ = "prompt_variables"
id = Column(String, primary_key=True, index=True)
variable_name = Column(
String(100), nullable=False, unique=True, index=True
) # {user_name}, {context}, etc.
description = Column(Text, nullable=True)
example_value = Column(String(500), nullable=True)
is_active = Column(Boolean, default=True, nullable=False)


@@ -15,7 +15,9 @@ class RagCollection(Base):
id = Column(Integer, primary_key=True, index=True)
name = Column(String(255), nullable=False, index=True)
description = Column(Text, nullable=True)
qdrant_collection_name = Column(
String(255), nullable=False, unique=True, index=True
)
# Metadata
document_count = Column(Integer, default=0, nullable=False)
@@ -23,15 +25,26 @@ class RagCollection(Base):
vector_count = Column(Integer, default=0, nullable=False)
# Status tracking
status = Column(
String(50), default="active", nullable=False
) # 'active', 'indexing', 'error'
is_active = Column(Boolean, default=True, nullable=False)
# Timestamps
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
# Relationships
documents = relationship("RagDocument", back_populates="collection", cascade="all, delete-orphan")
documents = relationship(
"RagDocument", back_populates="collection", cascade="all, delete-orphan"
)
def to_dict(self):
"""Convert model to dictionary for API responses"""
@@ -45,7 +58,7 @@ class RagCollection(Base):
"status": self.status,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_active": self.is_active
"is_active": self.is_active,
}
def __repr__(self):


@@ -3,7 +3,17 @@ RAG Document Model
Represents documents within RAG collections
"""
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
Boolean,
BigInteger,
ForeignKey,
JSON,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
@@ -15,7 +25,12 @@ class RagDocument(Base):
id = Column(Integer, primary_key=True, index=True)
# Collection relationship
collection_id = Column(
Integer,
ForeignKey("rag_collections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
collection = relationship("RagCollection", back_populates="documents")
# File information
@@ -27,7 +42,9 @@ class RagDocument(Base):
mime_type = Column(String(100), nullable=True)
# Processing status
status = Column(
String(50), default="processing", nullable=False
) # 'processing', 'processed', 'error', 'indexed'
processing_error = Column(Text, nullable=True)
# Content information
@@ -36,17 +53,30 @@ class RagDocument(Base):
character_count = Column(Integer, default=0, nullable=False)
# Vector information
vector_count = Column(
Integer, default=0, nullable=False
) # number of chunks/vectors created
chunk_size = Column(
Integer, default=1000, nullable=False
) # chunk size used for vectorization
# Metadata extracted from document
document_metadata = Column(
JSON, nullable=True
) # language, entities, keywords, etc.
# Processing timestamps
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
processed_at = Column(DateTime(timezone=True), nullable=True)
indexed_at = Column(DateTime(timezone=True), nullable=True)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
# Soft delete
is_deleted = Column(Boolean, default=False, nullable=False)
@@ -72,10 +102,12 @@ class RagDocument(Base):
"chunk_size": self.chunk_size,
"metadata": self.document_metadata or {},
"created_at": self.created_at.isoformat() if self.created_at else None,
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
"processed_at": self.processed_at.isoformat()
if self.processed_at
else None,
"indexed_at": self.indexed_at.isoformat() if self.indexed_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"is_deleted": self.is_deleted
"is_deleted": self.is_deleted,
}
def __repr__(self):

backend/app/models/role.py Normal file

@@ -0,0 +1,158 @@
"""
Role model for hierarchical permission management
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
from sqlalchemy.orm import relationship
from app.db.database import Base
class RoleLevel(str, Enum):
"""Role hierarchy levels"""
READ_ONLY = "read_only" # Level 1: Can only view
USER = "user" # Level 2: Can create and manage own resources
ADMIN = "admin" # Level 3: Can manage users and settings
SUPER_ADMIN = "super_admin" # Level 4: Full system access
class Role(Base):
"""Role model with hierarchical permissions"""
__tablename__ = "roles"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50), unique=True, nullable=False)
display_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
level = Column(String(20), nullable=False) # RoleLevel enum
# Permissions configuration
permissions = Column(JSON, default=dict) # Granular permissions
can_manage_users = Column(Boolean, default=False)
can_manage_budgets = Column(Boolean, default=False)
can_view_reports = Column(Boolean, default=False)
can_manage_tools = Column(Boolean, default=False)
# Role hierarchy
inherits_from = Column(JSON, default=list) # List of parent role names
# Status
is_active = Column(Boolean, default=True)
is_system_role = Column(Boolean, default=False) # System roles cannot be deleted
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
users = relationship("User", back_populates="role")
def __repr__(self):
return f"<Role(id={self.id}, name='{self.name}', level='{self.level}')>"
def to_dict(self):
"""Convert role to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"level": self.level,
"permissions": self.permissions,
"can_manage_users": self.can_manage_users,
"can_manage_budgets": self.can_manage_budgets,
"can_view_reports": self.can_view_reports,
"can_manage_tools": self.can_manage_tools,
"inherits_from": self.inherits_from,
"is_active": self.is_active,
"is_system_role": self.is_system_role,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def has_permission(self, permission: str) -> bool:
"""Check if role has a specific permission"""
if self.level == RoleLevel.SUPER_ADMIN:
return True
# Check denied permissions first (an explicit deny overrides a grant,
# matching get_effective_permissions on the User model)
if permission in self.permissions.get("denied", []):
return False
# Check direct permissions
if permission in self.permissions.get("granted", []):
return True
# Check inherited permissions (simplified)
for parent_role in self.inherits_from:
# This would require recursive checking in a real implementation
pass
return False
@classmethod
def create_default_roles(cls):
"""Create default system roles"""
roles = [
cls(
name="read_only",
display_name="Read Only",
description="Can view own data only",
level=RoleLevel.READ_ONLY,
permissions={
"granted": ["read_own"],
"denied": ["create", "update", "delete"],
},
is_system_role=True,
),
cls(
name="user",
display_name="User",
description="Standard user with full access to own resources",
level=RoleLevel.USER,
permissions={
"granted": ["read_own", "create_own", "update_own", "delete_own"],
"denied": ["manage_users", "manage_all"],
},
inherits_from=["read_only"],
is_system_role=True,
),
cls(
name="admin",
display_name="Administrator",
description="Can manage users and view reports",
level=RoleLevel.ADMIN,
permissions={
"granted": [
"read_all",
"create_all",
"update_all",
"manage_users",
"view_reports",
],
"denied": ["system_settings"],
},
inherits_from=["user"],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
is_system_role=True,
),
cls(
name="super_admin",
display_name="Super Administrator",
description="Full system access",
level=RoleLevel.SUPER_ADMIN,
permissions={"granted": ["*"]}, # All permissions
inherits_from=["admin"],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
can_manage_tools=True,
is_system_role=True,
),
]
return roles
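
A sketch of how init_db might seed these defaults at startup; the synchronous Session usage is illustrative:

from sqlalchemy.orm import Session

def seed_default_roles(session: Session) -> None:
    # Insert only the system roles that do not exist yet
    for role in Role.create_default_roles():
        if not session.query(Role).filter_by(name=role.name).first():
            session.add(role)
    session.commit()
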

backend/app/models/tool.py Normal file

@@ -0,0 +1,272 @@
"""
Tool model for custom tool execution
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
class ToolType(str, Enum):
"""Tool execution types"""
PYTHON = "python"
BASH = "bash"
DOCKER = "docker"
API = "api"
CUSTOM = "custom"
class ToolStatus(str, Enum):
"""Tool execution status"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
class Tool(Base):
"""Tool definition model"""
__tablename__ = "tools"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(100), nullable=False, index=True)
display_name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Tool configuration
tool_type = Column(String(20), nullable=False) # ToolType enum
code = Column(Text, nullable=False) # Tool implementation code
parameters_schema = Column(JSON, default=dict) # JSON schema for parameters
return_schema = Column(JSON, default=dict) # Expected return format
# Execution settings
timeout_seconds = Column(Integer, default=30)
max_memory_mb = Column(Integer, default=256)
max_cpu_seconds = Column(Float, default=10.0)
# Docker settings (for docker type tools)
docker_image = Column(String(200), nullable=True)
docker_command = Column(Text, nullable=True)
# Access control
is_public = Column(Boolean, default=False) # Public tools available to all users
is_approved = Column(Boolean, default=False) # Admin approved for security
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Categories and tags
category = Column(String(50), nullable=True)
tags = Column(JSON, default=list)
# Usage tracking
usage_count = Column(Integer, default=0)
last_used_at = Column(DateTime, nullable=True)
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
created_by = relationship("User", back_populates="created_tools")
executions = relationship(
"ToolExecution", back_populates="tool", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<Tool(id={self.id}, name='{self.name}', type='{self.tool_type}')>"
def to_dict(self):
"""Convert tool to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"tool_type": self.tool_type,
"parameters_schema": self.parameters_schema,
"return_schema": self.return_schema,
"timeout_seconds": self.timeout_seconds,
"max_memory_mb": self.max_memory_mb,
"max_cpu_seconds": self.max_cpu_seconds,
"docker_image": self.docker_image,
"is_public": self.is_public,
"is_approved": self.is_approved,
"created_by_user_id": self.created_by_user_id,
"category": self.category,
"tags": self.tags,
"usage_count": self.usage_count,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
def increment_usage(self):
"""Increment usage count and update last used timestamp"""
self.usage_count += 1
self.last_used_at = datetime.utcnow()
def can_be_used_by(self, user) -> bool:
"""Check if user can use this tool"""
# Tool creator can always use their tools
if self.created_by_user_id == user.id:
return True
# Public and approved tools can be used by anyone
if self.is_public and self.is_approved:
return True
# Admin users can use any tool
if user.has_permission("manage_tools"):
return True
return False
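
A sketch of gating an execution endpoint on these checks; FastAPI's HTTPException is assumed because the rest of the backend uses FastAPI, but the helper itself is illustrative:

from fastapi import HTTPException

def authorize_tool_run(tool, user) -> None:
    # Refuse before allocating any sandbox resources
    if not tool.is_active:
        raise HTTPException(status_code=404, detail="Tool not found")
    if not tool.can_be_used_by(user):
        raise HTTPException(status_code=403, detail="Not permitted to run this tool")
    tool.increment_usage()
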
class ToolExecution(Base):
"""Tool execution instance model"""
__tablename__ = "tool_executions"
id = Column(Integer, primary_key=True, index=True)
# Tool and user references
tool_id = Column(Integer, ForeignKey("tools.id"), nullable=False)
executed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
# Execution details
parameters = Column(JSON, default=dict) # Input parameters
status = Column(String(20), nullable=False, default=ToolStatus.PENDING)
# Results
output = Column(Text, nullable=True) # Tool output
error_message = Column(Text, nullable=True) # Error details if failed
return_code = Column(Integer, nullable=True) # Exit code
# Resource usage
execution_time_ms = Column(Integer, nullable=True) # Actual execution time
memory_used_mb = Column(Float, nullable=True) # Peak memory usage
cpu_time_ms = Column(Integer, nullable=True) # CPU time used
# Docker execution details
container_id = Column(String(100), nullable=True) # Docker container ID
docker_logs = Column(Text, nullable=True) # Docker container logs
# Timestamps
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
tool = relationship("Tool", back_populates="executions")
executed_by = relationship("User", back_populates="tool_executions")
def __repr__(self):
return f"<ToolExecution(id={self.id}, tool_id={self.tool_id}, status='{self.status}')>"
def to_dict(self):
"""Convert execution to dictionary"""
return {
"id": self.id,
"tool_id": self.tool_id,
"tool_name": self.tool.name if self.tool else None,
"executed_by_user_id": self.executed_by_user_id,
"parameters": self.parameters,
"status": self.status,
"output": self.output,
"error_message": self.error_message,
"return_code": self.return_code,
"execution_time_ms": self.execution_time_ms,
"memory_used_mb": self.memory_used_mb,
"cpu_time_ms": self.cpu_time_ms,
"container_id": self.container_id,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
@property
def duration_seconds(self) -> float:
"""Calculate execution duration in seconds"""
if self.started_at and self.completed_at:
return (self.completed_at - self.started_at).total_seconds()
return 0.0
def is_running(self) -> bool:
"""Check if execution is currently running"""
return self.status in [ToolStatus.PENDING, ToolStatus.RUNNING]
def is_finished(self) -> bool:
"""Check if execution is finished (success or failure)"""
return self.status in [
ToolStatus.COMPLETED,
ToolStatus.FAILED,
ToolStatus.TIMEOUT,
ToolStatus.CANCELLED,
]
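
A sketch of closing out a run with these fields; the helper is illustrative and assumes started_at was set when the sandbox launched:

from datetime import datetime

def finish_execution(execution, output: str, return_code: int) -> None:
    # Record results and derive timing from the timestamps above
    execution.completed_at = datetime.utcnow()
    execution.output = output
    execution.return_code = return_code
    execution.status = ToolStatus.COMPLETED if return_code == 0 else ToolStatus.FAILED
    execution.execution_time_ms = int(execution.duration_seconds * 1000)
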
class ToolCategory(Base):
"""Tool category for organization"""
__tablename__ = "tool_categories"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50), unique=True, nullable=False)
display_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
# Visual
icon = Column(String(50), nullable=True) # Icon name
color = Column(String(20), nullable=True) # Color code
# Ordering
sort_order = Column(Integer, default=0)
# Status
is_active = Column(Boolean, default=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def __repr__(self):
return f"<ToolCategory(id={self.id}, name='{self.name}')>"
def to_dict(self):
"""Convert category to dictionary"""
return {
"id": self.id,
"name": self.name,
"display_name": self.display_name,
"description": self.description,
"icon": self.icon,
"color": self.color,
"sort_order": self.sort_order,
"is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}


@@ -4,7 +4,17 @@ Usage Tracking model for API key usage statistics
from datetime import datetime
from typing import Optional
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
@@ -82,7 +92,7 @@ class UsageTracking(Base):
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"request_metadata": self.request_metadata,
"created_at": self.created_at.isoformat() if self.created_at else None
"created_at": self.created_at.isoformat() if self.created_at else None,
}
@classmethod
@@ -102,7 +112,7 @@ class UsageTracking(Base):
session_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_metadata: Optional[dict] = None,
) -> "UsageTracking":
"""Create a new usage tracking record"""
return cls(
@@ -121,5 +131,5 @@ class UsageTracking(Base):
session_id=session_id,
ip_address=ip_address,
user_agent=user_agent,
request_metadata=request_metadata or {},
)


@@ -4,16 +4,21 @@ User model
from datetime import datetime
from typing import Optional, List
from enum import Enum
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
JSON,
ForeignKey,
Numeric,
)
from sqlalchemy.orm import relationship
from sqlalchemy import inspect as sa_inspect
from app.db.database import Base
from decimal import Decimal
class User(Base):
@@ -27,14 +32,21 @@ class User(Base):
hashed_password = Column(String, nullable=False)
full_name = Column(String, nullable=True)
# Account status
is_active = Column(Boolean, default=True)
is_verified = Column(Boolean, default=False)
is_superuser = Column(Boolean, default=False)  # Legacy field for compatibility
# Role-based access control (using new Role model)
role_id = Column(Integer, ForeignKey("roles.id"), nullable=True)
custom_permissions = Column(JSON, default=dict)  # Custom permissions override
# Account management
account_locked = Column(Boolean, default=False)
account_locked_until = Column(DateTime, nullable=True)
failed_login_attempts = Column(Integer, default=0)
last_failed_login = Column(DateTime, nullable=True)
force_password_change = Column(Boolean, default=False)
# Profile information
avatar_url = Column(String, nullable=True)
@@ -52,27 +64,59 @@ class User(Base):
notification_settings = Column(JSON, default=dict)
# Relationships
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
usage_tracking = relationship("UsageTracking", back_populates="user", cascade="all, delete-orphan")
budgets = relationship("Budget", back_populates="user", cascade="all, delete-orphan")
audit_logs = relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
role = relationship("Role", back_populates="users")
api_keys = relationship(
"APIKey", back_populates="user", cascade="all, delete-orphan"
)
usage_tracking = relationship(
"UsageTracking", back_populates="user", cascade="all, delete-orphan"
)
budgets = relationship(
"Budget", back_populates="user", cascade="all, delete-orphan"
)
audit_logs = relationship(
"AuditLog", back_populates="user", cascade="all, delete-orphan"
)
installed_plugins = relationship("Plugin", back_populates="installed_by_user")
created_tools = relationship(
"Tool", back_populates="created_by", cascade="all, delete-orphan"
)
tool_executions = relationship(
"ToolExecution", back_populates="executed_by", cascade="all, delete-orphan"
)
notifications = relationship(
"Notification", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<User(id={self.id}, email='{self.email}', username='{self.username}')>"
def to_dict(self):
"""Convert user to dictionary for API responses"""
# Check if role relationship is loaded to avoid lazy loading in async context
inspector = sa_inspect(self)
role_loaded = "role" not in inspector.unloaded
return {
"id": self.id,
"email": self.email,
"username": self.username,
"full_name": self.full_name,
"is_active": self.is_active,
"is_superuser": self.is_superuser,
"is_verified": self.is_verified,
"role": self.role,
"permissions": self.permissions,
"is_superuser": self.is_superuser,
"role_id": self.role_id,
"role": self.role.to_dict() if role_loaded and self.role else None,
"custom_permissions": self.custom_permissions,
"account_locked": self.account_locked,
"account_locked_until": self.account_locked_until.isoformat()
if self.account_locked_until
else None,
"failed_login_attempts": self.failed_login_attempts,
"last_failed_login": self.last_failed_login.isoformat()
if self.last_failed_login
else None,
"force_password_change": self.force_password_change,
"avatar_url": self.avatar_url,
"bio": self.bio,
"company": self.company,
@@ -81,35 +125,49 @@ class User(Base):
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
"preferences": self.preferences,
"notification_settings": self.notification_settings
"notification_settings": self.notification_settings,
}
def has_permission(self, permission: str) -> bool:
"""Check if user has a specific permission"""
"""Check if user has a specific permission using role hierarchy"""
if self.is_superuser:
return True
# Check custom permissions first (override)
if permission in self.custom_permissions.get("denied", []):
return False
if permission in self.custom_permissions.get("granted", []):
return True
# Check role permissions if user has a role assigned
if self.role:
return self.role.has_permission(permission)
return False
def can_access_module(self, module_name: str) -> bool:
"""Check if user can access a specific module"""
if self.is_superuser:
return True
# Check custom permissions first
module_permissions = self.custom_permissions.get("modules", {})
if module_name in module_permissions:
return module_permissions[module_name]
# Check role permissions
if self.role:
# For admin roles, allow all modules
if self.role.level in ["admin", "super_admin"]:
return True
# For regular users, check module access
elif self.role.level == "user":
return True # Basic users can access all modules
# For read-only users, limit access
elif self.role.level == "read_only":
return module_name in ["chatbot", "analytics"] # Only certain modules
return False
def update_last_login(self):
"""Update the last login timestamp"""
@@ -127,8 +185,97 @@ class User(Base):
self.notification_settings = {}
self.notification_settings.update(settings)
def get_effective_permissions(self) -> dict:
"""Get all effective permissions combining role and custom permissions"""
permissions = {"granted": set(), "denied": set()}
# Start with role permissions
if self.role:
role_perms = self.role.permissions
permissions["granted"].update(role_perms.get("granted", []))
permissions["denied"].update(role_perms.get("denied", []))
# Apply custom permissions (override role permissions)
permissions["granted"].update(self.custom_permissions.get("granted", []))
permissions["denied"].update(self.custom_permissions.get("denied", []))
# Remove any denied permissions from granted
permissions["granted"] -= permissions["denied"]
return {
"granted": list(permissions["granted"]),
"denied": list(permissions["denied"]),
}
def can_create_api_key(self) -> bool:
"""Check if user can create API keys based on role and limits"""
if not self.is_active or self.account_locked:
return False
# Check permission
if not self.has_permission("create_api_key"):
return False
# Check if user has reached their API key limit
current_keys = [key for key in self.api_keys if key.is_active]
max_keys = (
self.role.permissions.get("limits", {}).get("max_api_keys", 5)
if self.role
else 5
)
return len(current_keys) < max_keys
def can_create_tool(self) -> bool:
"""Check if user can create custom tools"""
return (
self.is_active
and not self.account_locked
and self.has_permission("create_tool")
)
def is_budget_exceeded(self) -> bool:
"""Check if user has exceeded their budget limits"""
if not self.budgets:
return False
active_budget = next((b for b in self.budgets if b.is_active), None)
if not active_budget:
return False
return active_budget.current_usage > active_budget.limit
def lock_account(self, duration_hours: int = 24):
"""Lock user account for specified duration"""
from datetime import timedelta
self.account_locked = True
self.account_locked_until = datetime.utcnow() + timedelta(hours=duration_hours)
def unlock_account(self):
"""Unlock user account"""
self.account_locked = False
self.account_locked_until = None
self.failed_login_attempts = 0
def record_failed_login(self):
"""Record a failed login attempt"""
self.failed_login_attempts += 1
self.last_failed_login = datetime.utcnow()
# Lock account after 5 failed attempts
if self.failed_login_attempts >= 5:
self.lock_account(24) # Lock for 24 hours
def reset_failed_logins(self):
"""Reset failed login counter"""
self.failed_login_attempts = 0
self.last_failed_login = None
@classmethod
def create_default_admin(
cls, email: str, username: str, password_hash: str
) -> "User":
"""Create a default admin user"""
return cls(
email=email,
@@ -136,24 +283,16 @@ class User(Base):
hashed_password=password_hash,
full_name="System Administrator",
is_active=True,
is_superuser=True,  # Legacy compatibility
is_verified=True,
role="super_admin",
permissions={
"modules": {
"cache": True,
"analytics": True,
"rag": True
}
},
preferences={
"theme": "dark",
"language": "en",
"timezone": "UTC"
# Note: role_id will be set after role is created in init_db
custom_permissions={
"modules": {"cache": True, "analytics": True, "rag": True}
},
preferences={"theme": "dark", "language": "en", "timezone": "UTC"},
notification_settings={
"email_notifications": True,
"security_alerts": True,
"system_updates": True
}
"system_updates": True,
},
)
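
Taken together, the precedence is: is_superuser, then custom_permissions (deny wins over grant), then the assigned Role. A small illustrative check against an already-loaded user object:

# A custom deny blocks a permission even if the role grants it
user.custom_permissions = {"granted": [], "denied": ["create_api_key"]}
assert user.has_permission("create_api_key") is False
assert user.can_create_api_key() is False

# Five failed logins lock the account for 24 hours
for _ in range(5):
    user.record_failed_login()
assert user.account_locked is True
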


@@ -15,7 +15,4 @@ __version__ = "1.0.0"
__author__ = "Enclava Team"
# Export main classes for easy importing
__all__ = ["ChatbotModule", "create_module"]



@@ -23,7 +23,7 @@ from .protocols import (
ChatbotServiceProtocol,
LiteLLMClientProtocol,
WorkflowServiceProtocol,
ServiceRegistry,
)
logger = logging.getLogger(__name__)
@@ -36,7 +36,9 @@ class ModuleFactory:
self.modules: Dict[str, Any] = {}
self.initialized = False
async def create_all_modules(
self, config: Optional[Dict[str, Any]] = None
) -> ServiceRegistry:
"""
Create all modules with proper dependency injection
@@ -59,7 +61,7 @@ class ModuleFactory:
# Step 3: Create chatbot module with RAG dependency
chatbot_module = create_chatbot_module(
litellm_client=litellm_client,
rag_service=rag_module,  # RAG module implements RAGServiceProtocol
)
# Step 4: Create workflow module with chatbot dependency
@@ -71,7 +73,7 @@ class ModuleFactory:
modules = {
"rag": rag_module,
"chatbot": chatbot_module,
"workflow": workflow_module
"workflow": workflow_module,
}
logger.info(f"Created {len(modules)} modules with dependencies wired")
@@ -84,14 +86,16 @@ class ModuleFactory:
return modules
async def _initialize_modules(
self, modules: Dict[str, Any], config: Dict[str, Any]
):
"""Initialize all modules in dependency order"""
# Initialize in dependency order (modules with no deps first)
initialization_order = [
("rag", modules["rag"]),
("chatbot", modules["chatbot"]), # Depends on RAG
("workflow", modules["workflow"]) # Depends on Chatbot
("workflow", modules["workflow"]), # Depends on Chatbot
]
for module_name, module in initialization_order:
@@ -100,7 +104,7 @@ class ModuleFactory:
module_config = config.get(module_name, {})
# Different modules have different initialization patterns
if hasattr(module, "initialize"):
if module_name == "rag":
await module.initialize()
else:
@@ -110,7 +114,9 @@ class ModuleFactory:
except Exception as e:
logger.error(f"❌ Failed to initialize {module_name} module: {e}")
raise RuntimeError(f"Module initialization failed: {module_name}") from e
raise RuntimeError(
f"Module initialization failed: {module_name}"
) from e
async def cleanup_all_modules(self):
"""Cleanup all modules in reverse dependency order"""
@@ -125,7 +131,7 @@ class ModuleFactory:
try:
logger.info(f"Cleaning up {module_name} module...")
module = self.modules[module_name]
if hasattr(module, "cleanup"):
await module.cleanup()
logger.info(f"{module_name} module cleaned up")
except Exception as e:
@@ -174,13 +180,16 @@ def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
return RAGModule(config=config or {})
def create_chatbot_with_rag(
rag_service: RAGServiceProtocol, litellm_client: LiteLLMClientProtocol
) -> ChatbotModule:
"""Create chatbot module with RAG dependency"""
return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)
def create_workflow_with_chatbot(
chatbot_service: ChatbotServiceProtocol,
) -> WorkflowModule:
"""Create workflow module with chatbot dependency"""
return WorkflowModule(chatbot_service=chatbot_service)
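
A sketch of wiring everything at application startup through the factory; the config dict keys follow the per-module lookup in _initialize_modules:

async def build_modules(config: dict):
    factory = ModuleFactory()
    # Creates RAG first, then chatbot (needs RAG), then workflow (needs chatbot)
    modules = await factory.create_all_modules(config)
    return modules["rag"], modules["chatbot"], modules["workflow"]
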


@@ -14,7 +14,9 @@ class RAGServiceProtocol(Protocol):
"""Protocol for RAG (Retrieval-Augmented Generation) service interface"""
@abstractmethod
async def search(
self, query: str, collection_name: str, top_k: int
) -> Dict[str, Any]:
"""
Search for relevant documents
@@ -29,7 +31,9 @@ class RAGServiceProtocol(Protocol):
...
@abstractmethod
async def index_document(
self, content: str, metadata: Dict[str, Any] = None
) -> str:
"""
Index a document in the vector database
@@ -94,7 +98,9 @@ class LiteLLMClientProtocol(Protocol):
"""Protocol for LiteLLM client interface"""
@abstractmethod
async def completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Any:
"""
Create a completion using the specified model
@@ -109,8 +115,14 @@ class LiteLLMClientProtocol(Protocol):
...
@abstractmethod
async def create_chat_completion(
self,
model: str,
messages: List[Dict[str, str]],
user_id: str,
api_key_id: str,
**kwargs,
) -> Any:
"""
Create a chat completion with user tracking
@@ -207,7 +219,9 @@ class WorkflowServiceProtocol(Protocol):
"""Protocol for Workflow service interface"""
@abstractmethod
async def execute_workflow(
self, workflow: Any, input_data: Dict[str, Any] = None
) -> Any:
"""
Execute a workflow definition



@@ -13,77 +13,102 @@ from pathlib import Path
class PluginRuntimeSpec(BaseModel):
"""Plugin runtime requirements and dependencies"""
python_version: str = Field("3.11", description="Required Python version")
dependencies: List[str] = Field(default_factory=list, description="Required Python packages")
environment_variables: Dict[str, str] = Field(default_factory=dict, description="Required environment variables")
@validator('python_version')
python_version: str = Field("3.11", description="Required Python version")
dependencies: List[str] = Field(
default_factory=list, description="Required Python packages"
)
environment_variables: Dict[str, str] = Field(
default_factory=dict, description="Required environment variables"
)
@validator("python_version")
def validate_python_version(cls, v):
if not v.startswith(("3.9", "3.10", "3.11", "3.12")):
raise ValueError("Python version must be 3.9, 3.10, 3.11, or 3.12")
return v
class PluginPermissions(BaseModel):
"""Plugin permission specifications"""
platform_apis: List[str] = Field(
default_factory=list, description="Platform API access scopes"
)
plugin_scopes: List[str] = Field(
default_factory=list, description="Plugin-specific permission scopes"
)
external_domains: List[str] = Field(
default_factory=list, description="Allowed external domains"
)
@validator("platform_apis")
def validate_platform_apis(cls, v):
allowed_apis = [
"chatbot:invoke",
"chatbot:manage",
"chatbot:read",
"rag:query",
"rag:manage",
"rag:read",
"llm:completion",
"llm:embeddings",
"llm:models",
"workflow:execute",
"workflow:read",
"cache:read",
"cache:write",
]
for api in v:
if api not in allowed_apis and not api.endswith(":*"):
raise ValueError(f"Invalid platform API scope: {api}")
return v
class PluginDatabaseSpec(BaseModel):
"""Plugin database configuration"""
schema: str = Field(..., description="Database schema name")
migrations_path: str = Field("./migrations", description="Path to migration files")
auto_migrate: bool = Field(True, description="Auto-run migrations on startup")
@validator("schema")
def validate_schema_name(cls, v):
if not v.startswith("plugin_"):
raise ValueError('Database schema must start with "plugin_"')
if not v.replace("plugin_", "").replace("_", "").isalnum():
raise ValueError(
"Schema name must contain only alphanumeric characters and underscores"
)
return v
class PluginAPIEndpoint(BaseModel):
"""Plugin API endpoint specification"""
path: str = Field(..., description="API endpoint path")
methods: List[str] = Field(default=["GET"], description="Allowed HTTP methods")
description: str = Field("", description="Endpoint description")
auth_required: bool = Field(True, description="Whether authentication is required")
@validator("methods")
def validate_methods(cls, v):
allowed_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
for method in v:
if method not in allowed_methods:
raise ValueError(f"Invalid HTTP method: {method}")
return v
@validator("path")
def validate_path(cls, v):
if not v.startswith("/"):
raise ValueError('API path must start with "/"')
return v
class PluginCronJob(BaseModel):
"""Plugin scheduled job specification"""
name: str = Field(..., description="Job name")
schedule: str = Field(..., description="Cron expression")
function: str = Field(..., description="Function to execute")
@@ -92,40 +117,55 @@ class PluginCronJob(BaseModel):
timeout_seconds: int = Field(300, description="Job timeout in seconds")
max_retries: int = Field(3, description="Maximum retry attempts")
@validator("schedule")
def validate_cron_expression(cls, v):
# Basic cron validation - should have 5 parts
parts = v.split()
if len(parts) != 5:
raise ValueError(
"Cron expression must have 5 parts (minute hour day month weekday)"
)
return v
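
The check only counts whitespace-separated fields, so standard five-field expressions pass while Quartz-style six-field ones are rejected; for example:

assert len("*/5 * * * *".split()) == 5    # every five minutes: accepted
assert len("0 3 * * 1-5".split()) == 5    # 03:00 on weekdays: accepted
assert len("0 0 12 * * ?".split()) != 5   # six fields (with seconds): rejected
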
class PluginUIConfig(BaseModel):
"""Plugin UI configuration"""
configuration_schema: str = Field("./config_schema.json", description="JSON schema for configuration")
ui_components: str = Field("./ui/components", description="Path to UI components")
pages: List[Dict[str, str]] = Field(default_factory=list, description="Plugin pages")
@validator('pages')
configuration_schema: str = Field(
"./config_schema.json", description="JSON schema for configuration"
)
ui_components: str = Field("./ui/components", description="Path to UI components")
pages: List[Dict[str, str]] = Field(
default_factory=list, description="Plugin pages"
)
@validator("pages")
def validate_pages(cls, v):
required_fields = ["name", "path", "component"]
for page in v:
for field in required_fields:
if field not in page:
raise ValueError(f"Page must have {field} field")
return v
class PluginExternalServices(BaseModel):
"""Plugin external service configuration"""
allowed_domains: List[str] = Field(
default_factory=list, description="Allowed external domains"
)
webhooks: List[Dict[str, str]] = Field(
default_factory=list, description="Webhook configurations"
)
rate_limits: Dict[str, int] = Field(
default_factory=dict, description="Rate limits per domain"
)
class PluginMetadata(BaseModel):
"""Plugin metadata information"""
name: str = Field(..., description="Plugin name (must be unique)")
version: str = Field(..., description="Plugin version (semantic versioning)")
description: str = Field(..., description="Plugin description")
@@ -133,58 +173,78 @@ class PluginMetadata(BaseModel):
license: str = Field("MIT", description="Plugin license")
homepage: Optional[HttpUrl] = Field(None, description="Plugin homepage URL")
repository: Optional[HttpUrl] = Field(None, description="Plugin repository URL")
tags: List[str] = Field(
default_factory=list, description="Plugin tags for discovery"
)
@validator("name")
def validate_name(cls, v):
if not v.replace("-", "").replace("_", "").isalnum():
raise ValueError(
"Plugin name must contain only alphanumeric characters, hyphens, and underscores"
)
if len(v) < 3 or len(v) > 50:
raise ValueError("Plugin name must be between 3 and 50 characters")
return v.lower()
@validator("version")
def validate_version(cls, v):
# Basic semantic versioning validation
parts = v.split(".")
if len(parts) != 3:
raise ValueError("Version must follow semantic versioning (x.y.z)")
for part in parts:
if not part.isdigit():
raise ValueError("Version parts must be numeric")
return v
class PluginManifest(BaseModel):
"""Complete plugin manifest specification"""
apiVersion: str = Field("v1", description="Manifest API version")
kind: str = Field("Plugin", description="Resource kind")
metadata: PluginMetadata = Field(..., description="Plugin metadata")
spec: "PluginSpec" = Field(..., description="Plugin specification")
@validator("apiVersion")
def validate_api_version(cls, v):
if v not in ["v1"]:
raise ValueError("Unsupported API version")
return v
@validator("kind")
def validate_kind(cls, v):
if v != "Plugin":
raise ValueError('Kind must be "Plugin"')
return v
class PluginSpec(BaseModel):
"""Plugin specification details"""
runtime: PluginRuntimeSpec = Field(
default_factory=PluginRuntimeSpec, description="Runtime requirements"
)
permissions: PluginPermissions = Field(
default_factory=PluginPermissions, description="Permission requirements"
)
database: Optional[PluginDatabaseSpec] = Field(
None, description="Database configuration"
)
api_endpoints: List[PluginAPIEndpoint] = Field(
default_factory=list, description="API endpoints"
)
cron_jobs: List[PluginCronJob] = Field(
default_factory=list, description="Scheduled jobs"
)
ui_config: Optional[PluginUIConfig] = Field(None, description="UI configuration")
external_services: Optional[PluginExternalServices] = Field(
None, description="External service configuration"
)
config_schema: Dict[str, Any] = Field(
default_factory=dict, description="Plugin configuration JSON schema"
)
# Update forward reference
@@ -194,18 +254,14 @@ PluginManifest.model_rebuild()
class PluginManifestValidator:
"""Plugin manifest validation and parsing utilities"""
REQUIRED_FILES = ["manifest.yaml", "main.py", "requirements.txt"]
OPTIONAL_FILES = [
"config_schema.json",
"README.md",
"ui/components",
"migrations",
"tests",
]
@classmethod
@@ -217,7 +273,7 @@ class PluginManifestValidator:
raise FileNotFoundError(f"Manifest file not found: {manifest_path}")
try:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest_data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in manifest file: {e}")
@@ -243,19 +299,19 @@ class PluginManifestValidator:
raise FileNotFoundError(f"Required file missing: {required_file}")
# Validate main.py contains plugin class
- main_py_path = plugin_dir / 'main.py'
- with open(main_py_path, 'r', encoding='utf-8') as f:
+ main_py_path = plugin_dir / "main.py"
+ with open(main_py_path, "r", encoding="utf-8") as f:
main_content = f.read()
- if 'BasePlugin' not in main_content:
+ if "BasePlugin" not in main_content:
raise ValueError("main.py must contain a class inheriting from BasePlugin")
# Validate requirements.txt format
- requirements_path = plugin_dir / 'requirements.txt'
- with open(requirements_path, 'r', encoding='utf-8') as f:
+ requirements_path = plugin_dir / "requirements.txt"
+ with open(requirements_path, "r", encoding="utf-8") as f:
requirements = f.read().strip()
- if requirements and not all(line.strip() for line in requirements.split('\n')):
+ if requirements and not all(line.strip() for line in requirements.split("\n")):
raise ValueError("Invalid requirements.txt format")
# Validate config schema if specified
@@ -264,7 +320,8 @@ class PluginManifestValidator:
if schema_path.exists():
try:
import json
- with open(schema_path, 'r', encoding='utf-8') as f:
+ with open(schema_path, "r", encoding="utf-8") as f:
json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON schema: {e}")
@@ -283,7 +340,7 @@ class PluginManifestValidator:
"compatible": True,
"warnings": [],
"errors": [],
"platform_version": "1.0.0"
"platform_version": "1.0.0",
}
# Check platform API compatibility
@@ -319,38 +376,55 @@ class PluginManifestValidator:
def _is_platform_api_supported(cls, api: str) -> bool:
"""Check if platform API is supported"""
supported_apis = [
- 'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
- 'rag:query', 'rag:manage', 'rag:read',
- 'llm:completion', 'llm:embeddings', 'llm:models',
- 'workflow:execute', 'workflow:read',
- 'cache:read', 'cache:write'
+ "chatbot:invoke",
+ "chatbot:manage",
+ "chatbot:read",
+ "rag:query",
+ "rag:manage",
+ "rag:read",
+ "llm:completion",
+ "llm:embeddings",
+ "llm:models",
+ "workflow:execute",
+ "workflow:read",
+ "cache:read",
+ "cache:write",
]
# Support wildcard permissions
- if api.endswith(':*'):
+ if api.endswith(":*"):
base_api = api[:-2]
- return any(supported.startswith(base_api + ':') for supported in supported_apis)
+ return any(
+ supported.startswith(base_api + ":") for supported in supported_apis
+ )
return api in supported_apis
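To make the wildcard branch concrete, here is the same matching logic exercised against a reduced list (an illustrative standalone sketch):

SUPPORTED = ["rag:query", "rag:manage", "rag:read", "cache:read", "cache:write"]

def api_supported(api: str) -> bool:
    # "rag:*" matches every supported permission under the "rag" namespace
    if api.endswith(":*"):
        base = api[:-2]
        return any(s.startswith(base + ":") for s in SUPPORTED)
    return api in SUPPORTED

assert api_supported("rag:*")
assert api_supported("cache:read")
assert not api_supported("llm:*")  # nothing under "llm" in this reduced list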
@classmethod
def _is_python_version_supported(cls, version: str) -> bool:
"""Check if Python version is supported"""
- supported_versions = ['3.9', '3.10', '3.11', '3.12']
+ supported_versions = ["3.9", "3.10", "3.11", "3.12"]
return any(version.startswith(v) for v in supported_versions)
@classmethod
def _is_dependency_conflicting(cls, dependency: str) -> bool:
"""Check if dependency might conflict with platform"""
# Extract package name (before ==, >=, etc.)
- package_name = dependency.split('==')[0].split('>=')[0].split('<=')[0].split('>')[0].split('<')[0].strip()
+ package_name = (
+ dependency.split("==")[0]
+ .split(">=")[0]
+ .split("<=")[0]
+ .split(">")[0]
+ .split("<")[0]
+ .strip()
+ )
# Known conflicting packages
conflicting_packages = [
- 'sqlalchemy', # Platform uses specific version
- 'fastapi', # Platform uses specific version
- 'pydantic', # Platform uses specific version
- 'alembic' # Platform migration system
+ "sqlalchemy", # Platform uses specific version
+ "fastapi", # Platform uses specific version
+ "pydantic", # Platform uses specific version
+ "alembic", # Platform migration system
]
return package_name.lower() in conflicting_packages
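The chained split covers the common version operators; where the packaging library is available (it ships alongside pip in most environments), a hypothetical, more robust alternative would be:

from packaging.requirements import Requirement

def package_name(dependency: str) -> str:
    # Handles extras, environment markers, and all PEP 508 operators
    return Requirement(dependency).name.lower()

assert package_name("SQLAlchemy>=2.0,<3.0") == "sqlalchemy"
assert package_name("fastapi[all]==0.110.0") == "fastapi"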
@@ -359,8 +433,10 @@ class PluginManifestValidator:
def generate_manifest_hash(cls, manifest: PluginManifest) -> str:
"""Generate hash for manifest content verification"""
manifest_dict = manifest.dict()
- manifest_str = yaml.dump(manifest_dict, sort_keys=True, default_flow_style=False)
- return hashlib.sha256(manifest_str.encode('utf-8')).hexdigest()
+ manifest_str = yaml.dump(
+ manifest_dict, sort_keys=True, default_flow_style=False
+ )
+ return hashlib.sha256(manifest_str.encode("utf-8")).hexdigest()
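sort_keys=True is what makes the digest stable: manifests with identical content but different key order serialize to the same YAML and hash to the same value. A minimal demonstration:

import hashlib
import yaml

a = {"name": "demo", "version": "1.0.0"}
b = {"version": "1.0.0", "name": "demo"}  # same content, different key order

def digest(d):
    return hashlib.sha256(
        yaml.dump(d, sort_keys=True, default_flow_style=False).encode("utf-8")
    ).hexdigest()

assert digest(a) == digest(b)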
@classmethod
def create_example_manifest(cls, plugin_name: str) -> PluginManifest:
@@ -372,29 +448,25 @@ class PluginManifestValidator:
description=f"Example {plugin_name} plugin for Enclava platform",
author="Enclava Team",
license="MIT",
tags=["integration", "example"]
tags=["integration", "example"],
),
spec=PluginSpec(
runtime=PluginRuntimeSpec(
python_version="3.11",
- dependencies=[
- "aiohttp>=3.8.0",
- "pydantic>=2.0.0"
- ]
+ dependencies=["aiohttp>=3.8.0", "pydantic>=2.0.0"],
),
permissions=PluginPermissions(
platform_apis=["chatbot:invoke", "rag:query"],
plugin_scopes=["read", "write"]
plugin_scopes=["read", "write"],
),
database=PluginDatabaseSpec(
schema=f"plugin_{plugin_name}",
migrations_path="./migrations"
schema=f"plugin_{plugin_name}", migrations_path="./migrations"
),
api_endpoints=[
PluginAPIEndpoint(
path="/status",
methods=["GET"],
description="Plugin health status"
description="Plugin health status",
)
],
ui_config=PluginUIConfig(
@@ -403,11 +475,11 @@ class PluginManifestValidator:
{
"name": "dashboard",
"path": f"/plugins/{plugin_name}",
"component": f"{plugin_name.title()}Dashboard"
"component": f"{plugin_name.title()}Dashboard",
}
]
)
)
],
),
),
)
@@ -423,7 +495,7 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
"manifest": manifest,
"compatibility": compatibility,
"hash": manifest_hash,
"errors": []
"errors": [],
}
except Exception as e:
@@ -432,5 +504,5 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
"manifest": None,
"compatibility": None,
"hash": None,
"errors": [str(e)]
"errors": [str(e)],
}
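Since failures are reported in the "errors" list rather than raised, callers branch on that key; a hypothetical usage (the path is illustrative):

result = validate_manifest_file("plugins/example/manifest.yaml")
if result["errors"]:
    print("Manifest rejected:", "; ".join(result["errors"]))
else:
    compatible = result["compatibility"]["compatible"]
    print(f"Manifest {result['hash'][:12]} parsed; compatible={compatible}")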

backend/app/schemas/role.py
@@ -0,0 +1,367 @@
"""
Role Management Schemas
Pydantic models for role management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, validator
class RoleBase(BaseModel):
"""Base role schema"""
name: str
display_name: str
description: Optional[str] = None
level: str = "user"
permissions: Dict[str, Any] = {}
can_manage_users: bool = False
can_manage_budgets: bool = False
can_view_reports: bool = False
can_manage_tools: bool = False
inherits_from: List[str] = []
is_active: bool = True
@validator("name")
def validate_name(cls, v):
if len(v) < 2:
raise ValueError("Role name must be at least 2 characters long")
if len(v) > 50:
raise ValueError("Role name must be less than 50 characters long")
if not all(c.isalnum() or c == "_" for c in v):
raise ValueError(
"Role name must contain only alphanumeric characters and underscores"
)
return v.lower()
@validator("display_name")
def validate_display_name(cls, v):
if len(v) < 2:
raise ValueError("Display name must be at least 2 characters long")
if len(v) > 100:
raise ValueError("Display name must be less than 100 characters long")
return v
@validator("level")
def validate_level(cls, v):
valid_levels = ["read_only", "user", "admin", "super_admin"]
if v not in valid_levels:
raise ValueError(f'Level must be one of: {", ".join(valid_levels)}')
return v
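These validators run at parse time, so malformed payloads never construct a model; for example (hypothetical values):

from pydantic import ValidationError

role = RoleBase(name="Team_Lead", display_name="Team Lead", level="admin")
assert role.name == "team_lead"  # validate_name lower-cases the stored name

try:
    RoleBase(name="x", display_name="A", level="owner")
except ValidationError as exc:
    # Collects all three failures: short name, short display name, bad level
    assert len(exc.errors()) == 3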
class RoleCreate(RoleBase):
"""Schema for creating a role"""
is_system_role: bool = False
class RoleUpdate(BaseModel):
"""Schema for updating a role"""
display_name: Optional[str] = None
description: Optional[str] = None
permissions: Optional[Dict[str, Any]] = None
can_manage_users: Optional[bool] = None
can_manage_budgets: Optional[bool] = None
can_view_reports: Optional[bool] = None
can_manage_tools: Optional[bool] = None
is_active: Optional[bool] = None
@validator("display_name")
def validate_display_name(cls, v):
if v is not None:
if len(v) < 2:
raise ValueError("Display name must be at least 2 characters long")
if len(v) > 100:
raise ValueError("Display name must be less than 100 characters long")
return v
class RoleResponse(BaseModel):
"""Role response schema"""
id: int
name: str
display_name: str
description: Optional[str]
level: str
permissions: Dict[str, Any]
can_manage_users: bool
can_manage_budgets: bool
can_view_reports: bool
can_manage_tools: bool
inherits_from: List[str]
is_active: bool
is_system_role: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
user_count: Optional[int] = 0 # Number of users with this role
class Config:
from_attributes = True
@classmethod
def from_orm(cls, obj):
"""Create response from ORM object"""
data = obj.to_dict()
# Add user count if available
if hasattr(obj, "users"):
data["user_count"] = len([u for u in obj.users if u.is_active])
return cls(**data)
class RoleListResponse(BaseModel):
"""Role list response schema"""
roles: List[RoleResponse]
total: int
skip: int
limit: int
class RoleAssignmentRequest(BaseModel):
"""Schema for role assignment"""
role_id: int
@validator("role_id")
def validate_role_id(cls, v):
if v <= 0:
raise ValueError("Role ID must be a positive integer")
return v
class RoleBulkAction(BaseModel):
"""Schema for bulk role actions"""
role_ids: List[int]
action: str # activate, deactivate, delete
action_data: Optional[Dict[str, Any]] = None
@validator("action")
def validate_action(cls, v):
valid_actions = ["activate", "deactivate", "delete"]
if v not in valid_actions:
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
return v
@validator("role_ids")
def validate_role_ids(cls, v):
if not v:
raise ValueError("At least one role ID must be provided")
if len(v) > 50:
raise ValueError("Cannot perform bulk action on more than 50 roles at once")
return v
class RoleStatistics(BaseModel):
"""Role statistics schema"""
total_roles: int
active_roles: int
system_roles: int
roles_by_level: Dict[str, int]
roles_with_users: int
unused_roles: int
class RolePermission(BaseModel):
"""Individual role permission schema"""
permission: str
granted: bool = True
description: Optional[str] = None
class RolePermissionTemplate(BaseModel):
"""Role permission template schema"""
name: str
display_name: str
description: str
level: str
permissions: List[RolePermission]
can_manage_users: bool = False
can_manage_budgets: bool = False
can_view_reports: bool = False
can_manage_tools: bool = False
inherits_from: List[str] = []
class RoleHierarchy(BaseModel):
"""Role hierarchy schema"""
role: RoleResponse
parent_roles: List[RoleResponse] = []
child_roles: List[RoleResponse] = []
effective_permissions: Dict[str, Any]
class RoleComparison(BaseModel):
"""Role comparison schema"""
role1: RoleResponse
role2: RoleResponse
common_permissions: List[str]
unique_to_role1: List[str]
unique_to_role2: List[str]
class RoleUsage(BaseModel):
"""Role usage statistics schema"""
role_id: int
role_name: str
user_count: int
active_user_count: int
last_assigned: Optional[datetime]
usage_trend: Dict[str, int] # Monthly usage data
class RoleSearchFilter(BaseModel):
"""Role search filter schema"""
search: Optional[str] = None
level: Optional[str] = None
is_active: Optional[bool] = None
is_system_role: Optional[bool] = None
has_users: Optional[bool] = None
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
# Predefined permission templates
ROLE_TEMPLATES = {
"read_only": RolePermissionTemplate(
name="read_only",
display_name="Read Only",
description="Can view own data only",
level="read_only",
permissions=[
RolePermission(
permission="read_own", granted=True, description="Read own data"
),
RolePermission(
permission="create", granted=False, description="Create new resources"
),
RolePermission(
permission="update",
granted=False,
description="Update existing resources",
),
RolePermission(
permission="delete", granted=False, description="Delete resources"
),
],
),
"user": RolePermissionTemplate(
name="user",
display_name="User",
description="Standard user with full access to own resources",
level="user",
permissions=[
RolePermission(
permission="read_own", granted=True, description="Read own data"
),
RolePermission(
permission="create_own",
granted=True,
description="Create own resources",
),
RolePermission(
permission="update_own",
granted=True,
description="Update own resources",
),
RolePermission(
permission="delete_own",
granted=True,
description="Delete own resources",
),
RolePermission(
permission="manage_users",
granted=False,
description="Manage other users",
),
RolePermission(
permission="manage_all",
granted=False,
description="Manage all resources",
),
],
inherits_from=["read_only"],
),
"admin": RolePermissionTemplate(
name="admin",
display_name="Administrator",
description="Can manage users and view reports",
level="admin",
permissions=[
RolePermission(
permission="read_all", granted=True, description="Read all data"
),
RolePermission(
permission="create_all",
granted=True,
description="Create any resources",
),
RolePermission(
permission="update_all",
granted=True,
description="Update any resources",
),
RolePermission(
permission="manage_users", granted=True, description="Manage users"
),
RolePermission(
permission="view_reports",
granted=True,
description="View system reports",
),
RolePermission(
permission="system_settings",
granted=False,
description="Modify system settings",
),
],
can_manage_users=True,
can_view_reports=True,
inherits_from=["user"],
),
"super_admin": RolePermissionTemplate(
name="super_admin",
display_name="Super Administrator",
description="Full system access",
level="super_admin",
permissions=[
RolePermission(permission="*", granted=True, description="All permissions")
],
can_manage_users=True,
can_manage_budgets=True,
can_view_reports=True,
can_manage_tools=True,
inherits_from=["admin"],
),
}
# Common permission definitions
COMMON_PERMISSIONS = {
"read_own": "Read own data and resources",
"read_all": "Read all data and resources",
"create_own": "Create own resources",
"create_all": "Create any resources",
"update_own": "Update own resources",
"update_all": "Update any resources",
"delete_own": "Delete own resources",
"delete_all": "Delete any resources",
"manage_users": "Manage user accounts",
"manage_roles": "Manage role assignments",
"manage_budgets": "Manage budget settings",
"view_reports": "View system reports",
"manage_tools": "Manage custom tools",
"system_settings": "Modify system settings",
"create_api_key": "Create API keys",
"create_tool": "Create custom tools",
"export_data": "Export system data",
}
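Since inherits_from is declared on RolePermissionTemplate, a role's effective permission set can be resolved by walking the inheritance chain; a minimal resolver sketch over these templates (not part of the schema module):

def effective_permissions(role_name: str, templates=ROLE_TEMPLATES) -> set:
    # Collect granted permission names for a role and all its ancestors
    seen, stack, granted = set(), [role_name], set()
    while stack:
        name = stack.pop()
        if name in seen or name not in templates:
            continue
        seen.add(name)
        template = templates[name]
        granted |= {p.permission for p in template.permissions if p.granted}
        stack.extend(template.inherits_from)
    return granted

# "admin" inherits "user", which inherits "read_only"
assert "read_own" in effective_permissions("admin")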

backend/app/schemas/tool.py
@@ -0,0 +1,219 @@
"""
Tool schemas for API requests and responses
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, validator
# Tool Creation and Update Schemas
class ToolCreate(BaseModel):
"""Schema for creating a new tool"""
name: str = Field(..., min_length=1, max_length=100, description="Unique tool name")
display_name: str = Field(
..., min_length=1, max_length=200, description="Display name for the tool"
)
description: Optional[str] = Field(None, description="Tool description")
tool_type: str = Field(..., description="Tool type (python, bash, docker, api)")
code: str = Field(..., min_length=1, description="Tool implementation code")
parameters_schema: Optional[Dict[str, Any]] = Field(
None, description="JSON schema for parameters"
)
return_schema: Optional[Dict[str, Any]] = Field(
None, description="Expected return format schema"
)
timeout_seconds: Optional[int] = Field(
30, ge=1, le=300, description="Execution timeout in seconds"
)
max_memory_mb: Optional[int] = Field(
256, ge=1, le=1024, description="Maximum memory in MB"
)
max_cpu_seconds: Optional[float] = Field(
10.0, ge=0.1, le=60.0, description="Maximum CPU time in seconds"
)
docker_image: Optional[str] = Field(
None, max_length=200, description="Docker image for execution"
)
docker_command: Optional[str] = Field(None, description="Docker command to run")
category: Optional[str] = Field(None, max_length=50, description="Tool category")
tags: Optional[List[str]] = Field(None, description="Tool tags")
is_public: Optional[bool] = Field(False, description="Whether tool is public")
@validator("tool_type")
def validate_tool_type(cls, v):
valid_types = ["python", "bash", "docker", "api", "custom"]
if v not in valid_types:
raise ValueError(f"Tool type must be one of: {valid_types}")
return v
@validator("tags")
def validate_tags(cls, v):
if v is not None and len(v) > 10:
raise ValueError("Maximum 10 tags allowed")
return v
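A hypothetical payload that satisfies these constraints:

tool = ToolCreate(
    name="word_count",
    display_name="Word Count",
    tool_type="python",
    code="def run(text):\n    return len(text.split())",
    tags=["text", "utility"],
)
assert tool.timeout_seconds == 30  # resource limits fall back to their defaults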
class ToolUpdate(BaseModel):
"""Schema for updating a tool"""
display_name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
code: Optional[str] = Field(None, min_length=1)
parameters_schema: Optional[Dict[str, Any]] = None
return_schema: Optional[Dict[str, Any]] = None
timeout_seconds: Optional[int] = Field(None, ge=1, le=300)
max_memory_mb: Optional[int] = Field(None, ge=1, le=1024)
max_cpu_seconds: Optional[float] = Field(None, ge=0.1, le=60.0)
docker_image: Optional[str] = Field(None, max_length=200)
docker_command: Optional[str] = None
category: Optional[str] = Field(None, max_length=50)
tags: Optional[List[str]] = None
is_public: Optional[bool] = None
is_active: Optional[bool] = None
@validator("tags")
def validate_tags(cls, v):
if v is not None and len(v) > 10:
raise ValueError("Maximum 10 tags allowed")
return v
# Tool Response Schemas
class ToolResponse(BaseModel):
"""Schema for tool response"""
id: int
name: str
display_name: str
description: Optional[str]
tool_type: str
parameters_schema: Dict[str, Any]
return_schema: Dict[str, Any]
timeout_seconds: int
max_memory_mb: int
max_cpu_seconds: float
docker_image: Optional[str]
is_public: bool
is_approved: bool
created_by_user_id: int
category: Optional[str]
tags: List[str]
usage_count: int
last_used_at: Optional[datetime]
is_active: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
class ToolListResponse(BaseModel):
"""Schema for tool list response"""
tools: List[ToolResponse]
total: int
skip: int
limit: int
# Tool Execution Schemas
class ToolExecutionCreate(BaseModel):
"""Schema for creating a tool execution"""
parameters: Dict[str, Any] = Field(..., description="Parameters for tool execution")
timeout_override: Optional[int] = Field(
None, ge=1, le=300, description="Override default timeout"
)
class ToolExecutionResponse(BaseModel):
"""Schema for tool execution response"""
id: int
tool_id: int
tool_name: Optional[str]
executed_by_user_id: int
parameters: Dict[str, Any]
status: str
output: Optional[str]
error_message: Optional[str]
return_code: Optional[int]
execution_time_ms: Optional[int]
memory_used_mb: Optional[float]
cpu_time_ms: Optional[int]
container_id: Optional[str]
started_at: Optional[datetime]
completed_at: Optional[datetime]
created_at: Optional[datetime]
class Config:
from_attributes = True
class ToolExecutionListResponse(BaseModel):
"""Schema for tool execution list response"""
executions: List[ToolExecutionResponse]
total: int
skip: int
limit: int
# Tool Category Schemas
class ToolCategoryCreate(BaseModel):
"""Schema for creating a tool category"""
name: str = Field(
..., min_length=1, max_length=50, description="Unique category name"
)
display_name: str = Field(
..., min_length=1, max_length=100, description="Display name"
)
description: Optional[str] = Field(None, description="Category description")
icon: Optional[str] = Field(None, max_length=50, description="Icon name")
color: Optional[str] = Field(None, max_length=20, description="Color code")
sort_order: Optional[int] = Field(0, description="Sort order")
class ToolCategoryResponse(BaseModel):
"""Schema for tool category response"""
id: int
name: str
display_name: str
description: Optional[str]
icon: Optional[str]
color: Optional[str]
sort_order: int
is_active: bool
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
# Statistics Schema
class ToolStatisticsResponse(BaseModel):
"""Schema for tool statistics response"""
total_tools: int
public_tools: int
tools_by_type: Dict[str, int]
total_executions: int
executions_by_status: Dict[str, int]
recent_executions: int
top_tools: List[Dict[str, Any]]
user_tools: Optional[int] = None
user_executions: Optional[int] = None

@@ -0,0 +1,68 @@
"""
Tool calling schemas for API requests and responses
"""
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
class ToolExecutionRequest(BaseModel):
"""Schema for executing a tool by name"""
tool_name: str = Field(..., description="Name of the tool to execute")
parameters: Dict[str, Any] = Field(
default_factory=dict, description="Parameters for tool execution"
)
class ToolCallResponse(BaseModel):
"""Schema for tool call response"""
success: bool = Field(..., description="Whether the tool call was successful")
result: Optional[Dict[str, Any]] = Field(None, description="Tool execution result")
error: Optional[str] = Field(None, description="Error message if failed")
class ToolValidationRequest(BaseModel):
"""Schema for validating tool availability"""
tool_names: List[str] = Field(..., description="List of tool names to validate")
class ToolValidationResponse(BaseModel):
"""Schema for tool validation response"""
tool_availability: Dict[str, bool] = Field(
..., description="Tool name to availability mapping"
)
class ToolHistoryItem(BaseModel):
"""Schema for tool execution history item"""
id: int = Field(..., description="Execution ID")
tool_name: str = Field(..., description="Tool name")
parameters: Dict[str, Any] = Field(..., description="Execution parameters")
status: str = Field(..., description="Execution status")
output: Optional[str] = Field(None, description="Tool output")
error_message: Optional[str] = Field(None, description="Error message if failed")
execution_time_ms: Optional[int] = Field(
None, description="Execution time in milliseconds"
)
created_at: Optional[str] = Field(None, description="Creation timestamp")
completed_at: Optional[str] = Field(None, description="Completion timestamp")
class ToolHistoryResponse(BaseModel):
"""Schema for tool execution history response"""
history: List[ToolHistoryItem] = Field(..., description="Tool execution history")
total: int = Field(..., description="Total number of history items")
class ToolCallRequest(BaseModel):
"""Schema for tool call request (placeholder for future use)"""
message: str = Field(..., description="Chat message")
tools: Optional[List[str]] = Field(
None, description="Specific tools to make available"
)
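Taken together these models describe a simple execute/respond round trip; a hypothetical exchange:

req = ToolExecutionRequest(tool_name="word_count", parameters={"text": "one two three"})
ok = ToolCallResponse(success=True, result={"count": 3})
err = ToolCallResponse(success=False, error="tool not found")
assert ok.result and err.error  # by convention only one of result/error is set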

backend/app/schemas/user.py
@@ -0,0 +1,260 @@
"""
User Management Schemas
Pydantic models for user management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, EmailStr, validator
class UserBase(BaseModel):
"""Base user schema"""
email: EmailStr
username: str
full_name: Optional[str] = None
is_active: bool = True
is_verified: bool = False
@validator("username")
def validate_username(cls, v):
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if len(v) > 50:
raise ValueError("Username must be less than 50 characters long")
return v
@validator("email")
def validate_email(cls, v):
if len(v) > 255:
raise ValueError("Email must be less than 255 characters long")
return v
class UserCreate(UserBase):
"""Schema for creating a user"""
password: str
role_id: Optional[int] = None
custom_permissions: Dict[str, Any] = {}
@validator("password")
def validate_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
class UserUpdate(BaseModel):
"""Schema for updating a user"""
email: Optional[EmailStr] = None
username: Optional[str] = None
full_name: Optional[str] = None
role_id: Optional[int] = None
custom_permissions: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
@validator("username")
def validate_username(cls, v):
if v is not None:
if len(v) < 3:
raise ValueError("Username must be at least 3 characters long")
if len(v) > 50:
raise ValueError("Username must be less than 50 characters long")
return v
class PasswordChange(BaseModel):
"""Schema for changing password"""
current_password: Optional[str] = None
new_password: str
confirm_password: str
@validator("new_password")
def validate_new_password(cls, v):
if len(v) < 8:
raise ValueError("Password must be at least 8 characters long")
return v
@validator("confirm_password")
def passwords_match(cls, v, values):
if "new_password" in values and v != values["new_password"]:
raise ValueError("Passwords do not match")
return v
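The confirm check works because validators run in field order: values holds only the fields validated so far, so confirm_password can compare against new_password (and the comparison is skipped if new_password itself failed, since its key is then absent). For example:

from pydantic import ValidationError

try:
    PasswordChange(new_password="s3cret-pass", confirm_password="s3cret-typo")
except ValidationError as exc:
    assert "Passwords do not match" in str(exc)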
class RoleInfo(BaseModel):
"""Role information schema"""
id: int
name: str
display_name: str
level: str
permissions: Dict[str, Any]
class Config:
from_attributes = True
class UserResponse(BaseModel):
"""User response schema"""
id: int
email: str
username: str
full_name: Optional[str]
is_active: bool
is_verified: bool
is_superuser: bool
role_id: Optional[int]
role: Optional[RoleInfo]
custom_permissions: Dict[str, Any]
account_locked: Optional[bool] = False
account_locked_until: Optional[datetime]
failed_login_attempts: Optional[int] = 0
last_failed_login: Optional[datetime]
avatar_url: Optional[str]
bio: Optional[str]
company: Optional[str]
website: Optional[str]
created_at: Optional[datetime]
updated_at: Optional[datetime]
last_login: Optional[datetime]
preferences: Dict[str, Any]
notification_settings: Dict[str, Any]
class Config:
from_attributes = True
@classmethod
def from_orm(cls, obj):
"""Create response from ORM object with proper role handling"""
data = obj.to_dict()
if obj.role:
data["role"] = RoleInfo.from_orm(obj.role)
return cls(**data)
class UserListResponse(BaseModel):
"""User list response schema"""
users: List[UserResponse]
total: int
skip: int
limit: int
class AccountLockResponse(BaseModel):
"""Account lock response schema"""
user_id: int
is_locked: bool
locked_until: Optional[datetime]
message: str
class UserProfileUpdate(BaseModel):
"""Schema for user profile updates (limited fields)"""
full_name: Optional[str] = None
avatar_url: Optional[str] = None
bio: Optional[str] = None
company: Optional[str] = None
website: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
notification_settings: Optional[Dict[str, Any]] = None
class UserPreferences(BaseModel):
"""User preferences schema"""
theme: Optional[str] = "light"
language: Optional[str] = "en"
timezone: Optional[str] = "UTC"
email_notifications: Optional[bool] = True
security_alerts: Optional[bool] = True
system_updates: Optional[bool] = True
class UserSearchFilter(BaseModel):
"""User search filter schema"""
search: Optional[str] = None
role_id: Optional[int] = None
is_active: Optional[bool] = None
is_verified: Optional[bool] = None
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
last_login_after: Optional[datetime] = None
last_login_before: Optional[datetime] = None
class UserBulkAction(BaseModel):
"""Schema for bulk user actions"""
user_ids: List[int]
action: str # activate, deactivate, lock, unlock, assign_role, remove_role
action_data: Optional[Dict[str, Any]] = None
@validator("action")
def validate_action(cls, v):
valid_actions = [
"activate",
"deactivate",
"lock",
"unlock",
"assign_role",
"remove_role",
]
if v not in valid_actions:
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
return v
@validator("user_ids")
def validate_user_ids(cls, v):
if not v:
raise ValueError("At least one user ID must be provided")
if len(v) > 100:
raise ValueError(
"Cannot perform bulk action on more than 100 users at once"
)
return v
class UserStatistics(BaseModel):
"""User statistics schema"""
total_users: int
active_users: int
verified_users: int
locked_users: int
users_by_role: Dict[str, int]
recent_registrations: int
registrations_by_month: Dict[str, int]
class UserActivity(BaseModel):
"""User activity schema"""
user_id: int
action: str
resource_type: Optional[str] = None
resource_id: Optional[int] = None
details: Optional[Dict[str, Any]] = None
timestamp: datetime
ip_address: Optional[str] = None
user_agent: Optional[str] = None
class UserActivityFilter(BaseModel):
"""User activity filter schema"""
user_id: Optional[int] = None
action: Optional[str] = None
resource_type: Optional[str] = None
date_from: Optional[datetime] = None
date_to: Optional[datetime] = None
ip_address: Optional[str] = None

@@ -27,6 +27,7 @@ logger = get_logger(__name__)
@dataclass
class RequestEvent:
"""Enhanced request event data structure with budget integration"""
timestamp: datetime
method: str
path: str
@@ -55,6 +56,7 @@ class RequestEvent:
@dataclass
class UsageMetrics:
"""Usage metrics including costs and tokens"""
total_requests: int
successful_requests: int
failed_requests: int
@@ -84,6 +86,7 @@ class UsageMetrics:
@dataclass
class SystemHealth:
"""System health including budget and usage analysis"""
status: str # healthy, warning, critical
score: int # 0-100
issues: List[str]
@@ -113,20 +116,20 @@ class AnalyticsService:
self.cache_ttl = 300 # 5 minutes cache TTL
# Statistics counters
- self.endpoint_stats = defaultdict(lambda: {
+ self.endpoint_stats = defaultdict(
+ lambda: {
"count": 0,
"total_time": 0,
"errors": 0,
"avg_time": 0,
"total_tokens": 0,
- "total_cost_cents": 0
- })
+ "total_cost_cents": 0,
+ }
+ )
self.status_codes = defaultdict(int)
- self.model_stats = defaultdict(lambda: {
- "count": 0,
- "total_tokens": 0,
- "total_cost_cents": 0
- })
+ self.model_stats = defaultdict(
+ lambda: {"count": 0, "total_tokens": 0, "total_cost_cents": 0}
+ )
# Start cleanup task
asyncio.create_task(self._cleanup_old_events())
@@ -165,13 +168,19 @@ class AnalyticsService:
# Clear metrics cache to force recalculation
self.metrics_cache.clear()
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
logger.debug(
f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s"
)
except Exception as e:
logger.error(f"Error tracking request: {e}")
- async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
- api_key_id: Optional[int] = None) -> UsageMetrics:
+ async def get_usage_metrics(
+ self,
+ hours: int = 24,
+ user_id: Optional[int] = None,
+ api_key_id: Optional[int] = None,
+ ) -> UsageMetrics:
"""Get comprehensive usage metrics including costs and budgets"""
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
@@ -207,7 +216,9 @@ class AnalyticsService:
failed_requests = total_requests - successful_requests
if total_requests > 0:
- avg_response_time = sum(e.response_time for e in recent_events) / total_requests
+ avg_response_time = (
+ sum(e.response_time for e in recent_events) / total_requests
+ )
requests_per_minute = total_requests / (hours * 60)
error_rate = (failed_requests / total_requests) * 100
else:
@@ -230,9 +241,14 @@ class AnalyticsService:
budget_query = self.db.query(Budget).filter(Budget.is_active == True)
if user_id:
budget_query = budget_query.filter(
- or_(Budget.user_id == user_id, Budget.api_key_id.in_(
- self.db.query(APIKey.id).filter(APIKey.user_id == user_id).subquery()
- ))
+ or_(
+ Budget.user_id == user_id,
+ Budget.api_key_id.in_(
+ self.db.query(APIKey.id)
+ .filter(APIKey.user_id == user_id)
+ .subquery()
+ ),
+ )
)
budgets = budget_query.all()
@@ -253,7 +269,9 @@ class AnalyticsService:
top_endpoints = [
{"endpoint": endpoint, "count": count}
- for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
+ for endpoint, count in sorted(
+ endpoint_counts.items(), key=lambda x: x[1], reverse=True
+ )[:10]
]
# Status codes from memory
@@ -262,21 +280,27 @@ class AnalyticsService:
status_counts[str(event.status_code)] += 1
# Top models from database
- model_usage = self.db.query(
+ model_usage = (
+ self.db.query(
UsageTracking.model,
- func.count(UsageTracking.id).label('count'),
- func.sum(UsageTracking.total_tokens).label('tokens'),
- func.sum(UsageTracking.cost_cents).label('cost')
- ).filter(and_(*filters)).filter(
- UsageTracking.model.is_not(None)
- ).group_by(UsageTracking.model).order_by(desc('count')).limit(10).all()
+ func.count(UsageTracking.id).label("count"),
+ func.sum(UsageTracking.total_tokens).label("tokens"),
+ func.sum(UsageTracking.cost_cents).label("cost"),
+ )
+ .filter(and_(*filters))
+ .filter(UsageTracking.model.is_not(None))
+ .group_by(UsageTracking.model)
+ .order_by(desc("count"))
+ .limit(10)
+ .all()
+ )
top_models = [
{
"model": model,
"count": count,
"total_tokens": tokens or 0,
"total_cost_cents": cost or 0
"total_cost_cents": cost or 0,
}
for model, count, tokens, cost in model_usage
]
@@ -300,7 +324,7 @@ class AnalyticsService:
top_endpoints=top_endpoints,
status_codes=dict(status_counts),
top_models=top_models,
- timestamp=datetime.utcnow()
+ timestamp=datetime.utcnow(),
)
# Cache the result
@@ -310,13 +334,24 @@ class AnalyticsService:
except Exception as e:
logger.error(f"Error getting usage metrics: {e}")
return UsageMetrics(
- total_requests=0, successful_requests=0, failed_requests=0,
- avg_response_time=0, requests_per_minute=0, error_rate=0,
- total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
- avg_cost_per_request_cents=0, total_budget_cents=0,
- used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
- top_endpoints=[], status_codes={}, top_models=[],
- timestamp=datetime.utcnow()
+ total_requests=0,
+ successful_requests=0,
+ failed_requests=0,
+ avg_response_time=0,
+ requests_per_minute=0,
+ error_rate=0,
+ total_tokens=0,
+ total_cost_cents=0,
+ avg_tokens_per_request=0,
+ avg_cost_per_request_cents=0,
+ total_budget_cents=0,
+ used_budget_cents=0,
+ budget_usage_percentage=0,
+ active_budgets=0,
+ top_endpoints=[],
+ status_codes={},
+ top_models=[],
+ timestamp=datetime.utcnow(),
)
async def get_system_health(self) -> SystemHealth:
@@ -347,22 +382,30 @@ class AnalyticsService:
recommendations.append("Optimize slow endpoints and database queries")
elif metrics.avg_response_time > 2.0:
health_score -= 10
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
issues.append(
f"Elevated response time: {metrics.avg_response_time:.2f}s"
)
recommendations.append("Monitor performance trends")
# Check budget usage
if metrics.budget_usage_percentage > 90:
health_score -= 20
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
issues.append(
f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%"
)
recommendations.append("Review budget limits and usage patterns")
elif metrics.budget_usage_percentage > 75:
health_score -= 10
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
issues.append(
f"Budget usage high: {metrics.budget_usage_percentage:.1f}%"
)
recommendations.append("Monitor spending trends")
# Check for budgets near or over limit
budgets = self.db.query(Budget).filter(Budget.is_active == True).all()
- budgets_near_limit = sum(1 for b in budgets if b.current_usage_cents >= b.limit_cents * 0.8)
+ budgets_near_limit = sum(
+ 1 for b in budgets if b.current_usage_cents >= b.limit_cents * 0.8
+ )
budgets_exceeded = sum(1 for b in budgets if b.is_exceeded)
if budgets_exceeded > 0:
@@ -393,21 +436,28 @@ class AnalyticsService:
budget_usage_percentage=metrics.budget_usage_percentage,
budgets_near_limit=budgets_near_limit,
budgets_exceeded=budgets_exceeded,
- timestamp=datetime.utcnow()
+ timestamp=datetime.utcnow(),
)
except Exception as e:
logger.error(f"Error getting system health: {e}")
return SystemHealth(
status="error", score=0,
status="error",
score=0,
issues=[f"Health check failed: {str(e)}"],
recommendations=["Check system logs and restart services"],
avg_response_time=0, error_rate=0, requests_per_minute=0,
budget_usage_percentage=0, budgets_near_limit=0,
budgets_exceeded=0, timestamp=datetime.utcnow()
avg_response_time=0,
error_rate=0,
requests_per_minute=0,
budget_usage_percentage=0,
budgets_near_limit=0,
budgets_exceeded=0,
timestamp=datetime.utcnow(),
)
- async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
+ async def get_cost_analysis(
+ self, days: int = 30, user_id: Optional[int] = None
+ ) -> Dict[str, Any]:
"""Get detailed cost analysis and trends"""
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@@ -448,9 +498,15 @@ class AnalyticsService:
total_requests = len(usage_records)
efficiency_metrics = {
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
"cost_per_token": (total_cost / total_tokens)
if total_tokens > 0
else 0,
"cost_per_request": (total_cost / total_requests)
if total_requests > 0
else 0,
"tokens_per_request": (total_tokens / total_requests)
if total_requests > 0
else 0,
}
return {
@@ -465,7 +521,7 @@ class AnalyticsService:
"requests_by_model": dict(requests_by_model),
"daily_costs": dict(daily_costs),
"cost_by_endpoint": dict(cost_by_endpoint),
"analysis_timestamp": datetime.utcnow().isoformat()
"analysis_timestamp": datetime.utcnow().isoformat(),
}
except Exception as e:
@@ -538,20 +594,20 @@ class InMemoryAnalyticsService:
self.cache_ttl = 300 # 5 minutes cache TTL
# Statistics counters
- self.endpoint_stats = defaultdict(lambda: {
+ self.endpoint_stats = defaultdict(
+ lambda: {
"count": 0,
"total_time": 0,
"errors": 0,
"avg_time": 0,
"total_tokens": 0,
- "total_cost_cents": 0
- })
+ "total_cost_cents": 0,
+ }
+ )
self.status_codes = defaultdict(int)
- self.model_stats = defaultdict(lambda: {
- "count": 0,
- "total_tokens": 0,
- "total_cost_cents": 0
- })
+ self.model_stats = defaultdict(
+ lambda: {"count": 0, "total_tokens": 0, "total_cost_cents": 0}
+ )
# Start cleanup task
asyncio.create_task(self._cleanup_old_events())
@@ -590,13 +646,19 @@ class InMemoryAnalyticsService:
# Clear metrics cache to force recalculation
self.metrics_cache.clear()
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
logger.debug(
f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s"
)
except Exception as e:
logger.error(f"Error tracking request: {e}")
- async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
- api_key_id: Optional[int] = None) -> UsageMetrics:
+ async def get_usage_metrics(
+ self,
+ hours: int = 24,
+ user_id: Optional[int] = None,
+ api_key_id: Optional[int] = None,
+ ) -> UsageMetrics:
"""Get comprehensive usage metrics including costs and budgets"""
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
@@ -622,7 +684,9 @@ class InMemoryAnalyticsService:
failed_requests = total_requests - successful_requests
if total_requests > 0:
- avg_response_time = sum(e.response_time for e in recent_events) / total_requests
+ avg_response_time = (
+ sum(e.response_time for e in recent_events) / total_requests
+ )
requests_per_minute = total_requests / (hours * 60)
error_rate = (failed_requests / total_requests) * 100
else:
@@ -658,7 +722,9 @@ class InMemoryAnalyticsService:
top_endpoints = [
{"endpoint": endpoint, "count": count}
- for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
+ for endpoint, count in sorted(
+ endpoint_counts.items(), key=lambda x: x[1], reverse=True
+ )[:10]
]
# Status codes from memory
@@ -679,9 +745,11 @@ class InMemoryAnalyticsService:
"model": model,
"count": data["count"],
"total_tokens": data["tokens"],
"total_cost_cents": data["cost"]
"total_cost_cents": data["cost"],
}
for model, data in sorted(model_usage.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
for model, data in sorted(
model_usage.items(), key=lambda x: x[1]["count"], reverse=True
)[:10]
]
# Create metrics object
@@ -703,7 +771,7 @@ class InMemoryAnalyticsService:
top_endpoints=top_endpoints,
status_codes=dict(status_counts),
top_models=top_models,
- timestamp=datetime.utcnow()
+ timestamp=datetime.utcnow(),
)
# Cache the result
@@ -713,13 +781,24 @@ class InMemoryAnalyticsService:
except Exception as e:
logger.error(f"Error getting usage metrics: {e}")
return UsageMetrics(
- total_requests=0, successful_requests=0, failed_requests=0,
- avg_response_time=0, requests_per_minute=0, error_rate=0,
- total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
- avg_cost_per_request_cents=0, total_budget_cents=0,
- used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
- top_endpoints=[], status_codes={}, top_models=[],
- timestamp=datetime.utcnow()
+ total_requests=0,
+ successful_requests=0,
+ failed_requests=0,
+ avg_response_time=0,
+ requests_per_minute=0,
+ error_rate=0,
+ total_tokens=0,
+ total_cost_cents=0,
+ avg_tokens_per_request=0,
+ avg_cost_per_request_cents=0,
+ total_budget_cents=0,
+ used_budget_cents=0,
+ budget_usage_percentage=0,
+ active_budgets=0,
+ top_endpoints=[],
+ status_codes={},
+ top_models=[],
+ timestamp=datetime.utcnow(),
)
async def get_system_health(self) -> SystemHealth:
@@ -750,17 +829,23 @@ class InMemoryAnalyticsService:
recommendations.append("Optimize slow endpoints and database queries")
elif metrics.avg_response_time > 2.0:
health_score -= 10
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
issues.append(
f"Elevated response time: {metrics.avg_response_time:.2f}s"
)
recommendations.append("Monitor performance trends")
# Check budget usage
if metrics.budget_usage_percentage > 90:
health_score -= 20
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
issues.append(
f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%"
)
recommendations.append("Review budget limits and usage patterns")
elif metrics.budget_usage_percentage > 75:
health_score -= 10
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
issues.append(
f"Budget usage high: {metrics.budget_usage_percentage:.1f}%"
)
recommendations.append("Monitor spending trends")
# Determine overall status
@@ -782,21 +867,28 @@ class InMemoryAnalyticsService:
budget_usage_percentage=metrics.budget_usage_percentage,
budgets_near_limit=0, # Mock values since no DB access
budgets_exceeded=0,
- timestamp=datetime.utcnow()
+ timestamp=datetime.utcnow(),
)
except Exception as e:
logger.error(f"Error getting system health: {e}")
return SystemHealth(
status="error", score=0,
status="error",
score=0,
issues=[f"Health check failed: {str(e)}"],
recommendations=["Check system logs and restart services"],
avg_response_time=0, error_rate=0, requests_per_minute=0,
budget_usage_percentage=0, budgets_near_limit=0,
budgets_exceeded=0, timestamp=datetime.utcnow()
avg_response_time=0,
error_rate=0,
requests_per_minute=0,
budget_usage_percentage=0,
budgets_near_limit=0,
budgets_exceeded=0,
timestamp=datetime.utcnow(),
)
- async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
+ async def get_cost_analysis(
+ self, days: int = 30, user_id: Optional[int] = None
+ ) -> Dict[str, Any]:
"""Get detailed cost analysis and trends"""
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@@ -835,9 +927,15 @@ class InMemoryAnalyticsService:
total_requests = len(events)
efficiency_metrics = {
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
"cost_per_token": (total_cost / total_tokens)
if total_tokens > 0
else 0,
"cost_per_request": (total_cost / total_requests)
if total_requests > 0
else 0,
"tokens_per_request": (total_tokens / total_requests)
if total_requests > 0
else 0,
}
return {
@@ -852,7 +950,7 @@ class InMemoryAnalyticsService:
"requests_by_model": dict(requests_by_model),
"daily_costs": dict(daily_costs),
"cost_by_endpoint": dict(cost_by_endpoint),
"analysis_timestamp": datetime.utcnow().isoformat()
"analysis_timestamp": datetime.utcnow().isoformat(),
}
except Exception as e:
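Both services guard the efficiency ratios against empty windows with the same conditional pattern; in isolation (illustrative numbers):

def efficiency(total_cost, total_tokens, total_requests):
    # Zero-division guards mirror the efficiency_metrics blocks above
    return {
        "cost_per_token": total_cost / total_tokens if total_tokens > 0 else 0,
        "cost_per_request": total_cost / total_requests if total_requests > 0 else 0,
        "tokens_per_request": total_tokens / total_requests if total_requests > 0 else 0,
    }

m = efficiency(120.0, 4000, 80)
assert (m["cost_per_token"], m["cost_per_request"], m["tokens_per_request"]) == (0.03, 1.5, 50.0)
assert efficiency(0, 0, 0)["cost_per_token"] == 0  # empty window stays at zero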

@@ -27,7 +27,9 @@ class APIKeyAuthService:
def __init__(self, db: AsyncSession):
self.db = db
- async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:
+ async def validate_api_key(
+ self, api_key: str, request: Request
+ ) -> Optional[Dict[str, Any]]:
"""Validate API key and return user context using Redis cache for performance"""
try:
if not api_key:
@@ -41,10 +43,14 @@ class APIKeyAuthService:
key_prefix = api_key[:8]
# Try cached verification first
- cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)
+ cached_verification = await cached_api_key_service.verify_api_key_cached(
+ api_key, key_prefix
+ )
# Get API key data from cache or database
- context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)
+ context = await cached_api_key_service.get_cached_api_key(
+ key_prefix, self.db
+ )
if not context:
logger.warning(f"API key not found: {key_prefix}")
@@ -56,7 +62,7 @@ class APIKeyAuthService:
if not cached_verification:
# Get the actual key hash for verification (this should be in the cached context)
db_api_key = None
- if not hasattr(api_key_obj, 'key_hash'):
+ if not hasattr(api_key_obj, "key_hash"):
# Fallback: fetch full API key from database for hash
stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
result = await self.db.execute(stmt)
@@ -73,7 +79,9 @@ class APIKeyAuthService:
return None
# Cache successful verification
- await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)
+ await cached_api_key_service.cache_verification_result(
+ api_key, key_prefix, key_hash, True
+ )
# Check if key is valid (expiry, active status)
if not api_key_obj.is_valid():
@@ -89,7 +97,9 @@ class APIKeyAuthService:
return None
# Update last used timestamp asynchronously (performance optimization)
- await cached_api_key_service.update_last_used(context["api_key_id"], self.db)
+ await cached_api_key_service.update_last_used(
+ context["api_key_id"], self.db
+ )
return context
@@ -97,7 +107,9 @@ class APIKeyAuthService:
logger.error(f"API key validation error: {e}")
return None
- async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:
+ async def check_endpoint_permission(
+ self, context: Dict[str, Any], endpoint: str
+ ) -> bool:
"""Check if API key has permission to access endpoint"""
api_key: APIKey = context.get("api_key")
if not api_key:
@@ -121,21 +133,24 @@ class APIKeyAuthService:
return api_key.has_scope(scope)
- async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0):
+ async def update_usage_stats(
+ self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0
+ ):
"""Update API key usage statistics"""
try:
api_key: APIKey = context.get("api_key")
if api_key:
api_key.update_usage(tokens_used, cost_cents)
await self.db.commit()
logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
logger.info(
f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents"
)
except Exception as e:
logger.error(f"Failed to update usage stats: {e}")
async def get_api_key_context(
- request: Request,
- db: AsyncSession = Depends(get_db)
+ request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
"""Dependency to get API key context from request"""
auth_service = APIKeyAuthService(db)
@@ -170,7 +185,7 @@ async def require_api_key(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Valid API key required",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
return context
@@ -190,7 +205,7 @@ async def get_current_api_key_user(
if not user or not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User or API key not found in context"
detail="User or API key not found in context",
)
return user, api_key
@@ -210,7 +225,7 @@ async def get_api_key_auth(
if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key not found in context"
detail="API key not found in context",
)
return api_key
@@ -227,7 +242,7 @@ class RequireScope:
if not await auth_service.check_scope_permission(context, self.scope):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Scope '{self.scope}' required"
detail=f"Scope '{self.scope}' required",
)
return context
@@ -243,6 +258,6 @@ class RequireModel:
if not await auth_service.check_model_permission(context, self.model):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Model '{self.model}' not allowed"
detail=f"Model '{self.model}' not allowed",
)
return context
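Both RequireScope and RequireModel are written to be used as FastAPI dependencies; a hypothetical route guarded by a scope (the endpoint path and scope string are illustrative):

from fastapi import APIRouter, Depends

router = APIRouter()

@router.post("/v1/chat/completions")
async def chat_completions(context: dict = Depends(RequireScope("llm:completion"))):
    # context carries the validated API key, user, and permission data
    return {"user_id": context.get("user_id")}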

@@ -68,7 +68,7 @@ async def log_audit_event_async(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
severity: str = "info",
):
"""
Log an audit event asynchronously (non-blocking)
@@ -96,7 +96,7 @@ async def log_audit_event_async(
"user_agent": user_agent,
"success": success,
"severity": severity,
"created_at": datetime.utcnow()
"created_at": datetime.utcnow(),
}
# Queue the audit event (non-blocking)
@@ -122,7 +122,7 @@ async def log_audit_event(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
success: bool = True,
severity: str = "info"
severity: str = "info",
):
"""
Log an audit event to the database
@@ -157,13 +157,15 @@ async def log_audit_event(
user_agent=user_agent,
success=success,
severity=severity,
- created_at=datetime.utcnow()
+ created_at=datetime.utcnow(),
)
db.add(audit_log)
await db.flush() # Don't commit here, let the caller control the transaction
logger.debug(f"Audit event logged: {action} on {resource_type} by user {user_id}")
logger.debug(
f"Audit event logged: {action} on {resource_type} by user {user_id}"
)
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
@@ -179,7 +181,7 @@ async def get_audit_logs(
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
):
"""
Query audit logs with filtering
@@ -230,7 +232,7 @@ async def get_audit_logs(
async def get_audit_stats(
db: AsyncSession,
start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None
+ end_date: Optional[datetime] = None,
):
"""
Get audit statistics
@@ -260,28 +262,36 @@ async def get_audit_stats(
total_events = total_result.scalar()
# Events by action
- action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action)
+ action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(
+ AuditLog.action
+ )
if conditions:
action_query = action_query.where(and_(*conditions))
action_result = await db.execute(action_query)
events_by_action = dict(action_result.fetchall())
# Events by resource type
- resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(AuditLog.resource_type)
+ resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(
+ AuditLog.resource_type
+ )
if conditions:
resource_query = resource_query.where(and_(*conditions))
resource_result = await db.execute(resource_query)
events_by_resource = dict(resource_result.fetchall())
# Events by severity
- severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(AuditLog.severity)
+ severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(
+ AuditLog.severity
+ )
if conditions:
severity_query = severity_query.where(and_(*conditions))
severity_result = await db.execute(severity_query)
events_by_severity = dict(severity_result.fetchall())
# Success rate
- success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(AuditLog.success)
+ success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(
+ AuditLog.success
+ )
if conditions:
success_query = success_query.where(and_(*conditions))
success_result = await db.execute(success_query)
@@ -292,6 +302,10 @@ async def get_audit_stats(
"events_by_action": events_by_action,
"events_by_resource_type": events_by_resource,
"events_by_severity": events_by_severity,
"success_rate": success_stats.get(True, 0) / total_events if total_events > 0 else 0,
"failure_rate": success_stats.get(False, 0) / total_events if total_events > 0 else 0
"success_rate": success_stats.get(True, 0) / total_events
if total_events > 0
else 0,
"failure_rate": success_stats.get(False, 0) / total_events
if total_events > 0
else 0,
}

@@ -22,6 +22,7 @@ logger = get_logger(__name__)
@dataclass
class Permission:
"""Represents a module permission"""
resource: str
action: str
description: str
@@ -33,6 +34,7 @@ class Permission:
@dataclass
class ModuleMetrics:
"""Module performance metrics"""
requests_processed: int = 0
average_response_time: float = 0.0
error_rate: float = 0.0
@@ -48,6 +50,7 @@ class ModuleMetrics:
@dataclass
class ModuleHealth:
"""Module health status"""
status: str = "healthy" # healthy, warning, error
message: str = "Module is functioning normally"
uptime: float = 0.0
@@ -67,7 +70,7 @@ class BaseModule(ABC):
self.metrics = ModuleMetrics()
self.health = ModuleHealth()
self.initialized = False
- self.interceptors: List['ModuleInterceptor'] = []
+ self.interceptors: List["ModuleInterceptor"] = []
# Register default interceptors
self._register_default_interceptors()
@@ -88,7 +91,9 @@ class BaseModule(ABC):
return []
@abstractmethod
- async def process_request(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
+ async def process_request(
+ self, request: Dict[str, Any], context: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Process a module request"""
pass
@@ -115,10 +120,12 @@ class BaseModule(ABC):
ValidationInterceptor(),
MetricsInterceptor(self),
SecurityInterceptor(),
- AuditInterceptor(self)
+ AuditInterceptor(self),
]
- async def execute_with_interceptors(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
+ async def execute_with_interceptors(
+ self, request: Dict[str, Any], context: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Execute request through interceptor chain"""
start_time = time.time()
@@ -145,7 +152,7 @@ class BaseModule(ABC):
# Error handling interceptors
for interceptor in reversed(self.interceptors):
- if hasattr(interceptor, 'handle_error'):
+ if hasattr(interceptor, "handle_error"):
await interceptor.handle_error(request, context, e)
raise
@@ -168,7 +175,9 @@ class BaseModule(ABC):
if not success:
self.metrics.total_errors += 1
- self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
+ self.metrics.error_rate = (
+ self.metrics.total_errors / self.metrics.requests_processed
+ )
# Update health status based on error rate
if self.metrics.error_rate > 0.1: # More than 10% error rate
@@ -176,9 +185,13 @@ class BaseModule(ABC):
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
elif self.metrics.error_rate > 0.05: # More than 5% error rate
self.health.status = "warning"
self.health.message = f"Elevated error rate: {self.metrics.error_rate:.2%}"
self.health.message = (
f"Elevated error rate: {self.metrics.error_rate:.2%}"
)
else:
- self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
+ self.metrics.error_rate = (
+ self.metrics.total_errors / self.metrics.requests_processed
+ )
if self.metrics.error_rate <= 0.05:
self.health.status = "healthy"
self.health.message = "Module is functioning normally"
@@ -190,12 +203,16 @@ class ModuleInterceptor(ABC):
"""Base class for module interceptors"""
@abstractmethod
- async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ async def pre_process(
+ self, request: Dict[str, Any], context: Dict[str, Any]
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Pre-process the request"""
return request, context
@abstractmethod
- async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
+ async def post_process(
+ self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Post-process the response"""
return response
@@ -203,14 +220,18 @@ class ModuleInterceptor(ABC):
class AuthenticationInterceptor(ModuleInterceptor):
"""Handles authentication for module requests"""
- async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ async def pre_process(
+ self, request: Dict[str, Any], context: Dict[str, Any]
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Check if user is authenticated (context should contain user info from API auth)
if not context.get("user_id") and not context.get("api_key_id"):
raise AuthenticationError("Authentication required for module access")
return request, context
- async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
+ async def post_process(
+ self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
+ ) -> Dict[str, Any]:
return response
@@ -220,23 +241,31 @@ class PermissionInterceptor(ModuleInterceptor):
def __init__(self, module: BaseModule):
self.module = module
- async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ async def pre_process(
+ self, request: Dict[str, Any], context: Dict[str, Any]
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
action = request.get("action", "execute")
user_permissions = context.get("user_permissions", [])
if not self.module.check_access(user_permissions, action):
raise AuthenticationError(f"Insufficient permissions for module action: {action}")
raise AuthenticationError(
f"Insufficient permissions for module action: {action}"
)
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
return response
class ValidationInterceptor(ModuleInterceptor):
"""Handles request validation and sanitization"""
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Sanitize request data
sanitized_request = self._sanitize_request(request)
@@ -245,7 +274,9 @@ class ValidationInterceptor(ModuleInterceptor):
return sanitized_request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Sanitize response data
return self._sanitize_response(response)
@@ -275,7 +306,9 @@ class ValidationInterceptor(ModuleInterceptor):
max_length = 10000
if len(value) > max_length:
value = value[:max_length]
logger.warning(f"Truncated long string in request: {len(value)} chars")
logger.warning(
f"Truncated long string in request: {len(value)} chars"
)
return value
elif isinstance(value, dict):
@@ -302,7 +335,9 @@ class ValidationInterceptor(ModuleInterceptor):
request_str = json.dumps(request)
max_size = 10 * 1024 * 1024 # 10MB
if len(request_str.encode()) > max_size:
raise ValidationError(f"Request size exceeds maximum allowed ({max_size} bytes)")
raise ValidationError(
f"Request size exceeds maximum allowed ({max_size} bytes)"
)
class MetricsInterceptor(ModuleInterceptor):
@@ -311,11 +346,15 @@ class MetricsInterceptor(ModuleInterceptor):
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_metrics_start_time"] = time.time()
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Metrics are updated in the base module execute_with_interceptors method
return response
@@ -323,13 +362,15 @@ class MetricsInterceptor(ModuleInterceptor):
class SecurityInterceptor(ModuleInterceptor):
"""Handles security-related processing"""
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Add security headers to context
context["security_headers"] = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
}
# Check for suspicious patterns
@@ -337,7 +378,9 @@ class SecurityInterceptor(ModuleInterceptor):
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Remove any sensitive information from response
return self._remove_sensitive_data(response)
@@ -346,9 +389,18 @@ class SecurityInterceptor(ModuleInterceptor):
request_str = json.dumps(request).lower()
suspicious_patterns = [
"union select", "drop table", "insert into", "delete from",
"script>", "javascript:", "eval(", "expression(",
"../", "..\\", "file://", "ftp://",
"union select",
"drop table",
"insert into",
"delete from",
"script>",
"javascript:",
"eval(",
"expression(",
"../",
"..\\",
"file://",
"ftp://",
]
for pattern in suspicious_patterns:
@@ -363,7 +415,9 @@ class SecurityInterceptor(ModuleInterceptor):
def clean_dict(obj):
if isinstance(obj, dict):
return {
k: "***REDACTED***" if any(sk in k.lower() for sk in sensitive_keys) else clean_dict(v)
k: "***REDACTED***"
if any(sk in k.lower() for sk in sensitive_keys)
else clean_dict(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
@@ -380,25 +434,39 @@ class AuditInterceptor(ModuleInterceptor):
def __init__(self, module: BaseModule):
self.module = module
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_audit_start_time"] = time.time()
context["_audit_request_hash"] = self._hash_request(request)
return request, context
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
await self._log_audit_event(request, context, response, success=True)
return response
async def handle_error(self, request: Dict[str, Any], context: Dict[str, Any], error: Exception):
async def handle_error(
self, request: Dict[str, Any], context: Dict[str, Any], error: Exception
):
"""Handle error logging"""
await self._log_audit_event(request, context, {"error": str(error)}, success=False)
await self._log_audit_event(
request, context, {"error": str(error)}, success=False
)
def _hash_request(self, request: Dict[str, Any]) -> str:
"""Create a hash of the request for audit purposes"""
request_str = json.dumps(request, sort_keys=True)
return hashlib.sha256(request_str.encode()).hexdigest()[:16]
async def _log_audit_event(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any], success: bool):
async def _log_audit_event(
self,
request: Dict[str, Any],
context: Dict[str, Any],
response: Dict[str, Any],
success: bool,
):
"""Log audit event"""
duration = time.time() - context.get("_audit_start_time", time.time())
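
To make the interceptor contract above concrete, here is a minimal sketch of a custom interceptor in the same style; the class name and the latency field are illustrative, and only the pre_process/post_process pair from ModuleInterceptor is assumed.

import time
from typing import Any, Dict, Tuple


class TimingInterceptor(ModuleInterceptor):
    """Illustrative only: stamps request latency onto the response."""

    async def pre_process(
        self, request: Dict[str, Any], context: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # record the start time in the shared context, as MetricsInterceptor does
        context["_timing_start"] = time.time()
        return request, context

    async def post_process(
        self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
    ) -> Dict[str, Any]:
        # surface elapsed time so callers can see per-request latency
        response["latency_ms"] = (time.time() - context["_timing_start"]) * 1000
        return response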

View File

@@ -28,6 +28,7 @@ from sqlalchemy.orm import Session
@dataclass
class PluginContext:
"""Plugin execution context with user and authentication info"""
user_id: Optional[str] = None
api_key_id: Optional[str] = None
user_permissions: List[str] = None
@@ -45,15 +46,19 @@ class PlatformAPIClient:
self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
self.logger = get_logger(f"plugin.{plugin_id}.api_client")
async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
async def _make_request(
self, method: str, endpoint: str, **kwargs
) -> Dict[str, Any]:
"""Make authenticated request to platform API"""
headers = kwargs.setdefault('headers', {})
headers.update({
'Authorization': f'Bearer {self.plugin_token}',
'X-Plugin-ID': self.plugin_id,
'X-Platform-Client': 'plugin',
'Content-Type': 'application/json'
})
headers = kwargs.setdefault("headers", {})
headers.update(
{
"Authorization": f"Bearer {self.plugin_token}",
"X-Plugin-ID": self.plugin_id,
"X-Platform-Client": "plugin",
"Content-Type": "application/json",
}
)
url = f"{self.base_url}{endpoint}"
@@ -64,10 +69,10 @@ class PlatformAPIClient:
error_text = await response.text()
raise HTTPException(
status_code=response.status,
detail=f"Platform API error: {error_text}"
detail=f"Platform API error: {error_text}",
)
if response.content_type == 'application/json':
if response.content_type == "application/json":
return await response.json()
else:
return {"data": await response.text()}
@@ -75,73 +80,65 @@ class PlatformAPIClient:
except aiohttp.ClientError as e:
self.logger.error(f"Platform API client error: {e}")
raise HTTPException(
status_code=503,
detail=f"Platform API unavailable: {str(e)}"
status_code=503, detail=f"Platform API unavailable: {str(e)}"
)
async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""GET request to platform API"""
return await self._make_request('GET', endpoint, **kwargs)
return await self._make_request("GET", endpoint, **kwargs)
async def post(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
async def post(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""POST request to platform API"""
if data:
kwargs['json'] = data
return await self._make_request('POST', endpoint, **kwargs)
kwargs["json"] = data
return await self._make_request("POST", endpoint, **kwargs)
async def put(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
async def put(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""PUT request to platform API"""
if data:
kwargs['json'] = data
return await self._make_request('PUT', endpoint, **kwargs)
kwargs["json"] = data
return await self._make_request("PUT", endpoint, **kwargs)
async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""DELETE request to platform API"""
return await self._make_request('DELETE', endpoint, **kwargs)
return await self._make_request("DELETE", endpoint, **kwargs)
# Platform-specific API methods
async def call_chatbot_api(self, chatbot_id: str, message: str,
context: Dict[str, Any] = None) -> Dict[str, Any]:
async def call_chatbot_api(
self, chatbot_id: str, message: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Consume platform chatbot API"""
return await self.post(
f"/api/v1/chatbot/external/{chatbot_id}/chat",
{
"message": message,
"context": context or {}
}
{"message": message, "context": context or {}},
)
async def call_llm_api(self, model: str, messages: List[Dict[str, Any]],
**kwargs) -> Dict[str, Any]:
async def call_llm_api(
self, model: str, messages: List[Dict[str, Any]], **kwargs
) -> Dict[str, Any]:
"""Consume platform LLM API"""
return await self.post(
"/api/v1/llm/chat/completions",
{
"model": model,
"messages": messages,
**kwargs
}
{"model": model, "messages": messages, **kwargs},
)
async def search_rag(self, collection: str, query: str,
top_k: int = 5) -> Dict[str, Any]:
async def search_rag(
self, collection: str, query: str, top_k: int = 5
) -> Dict[str, Any]:
"""Consume platform RAG API"""
return await self.post(
f"/api/v1/rag/collections/{collection}/search",
{
"query": query,
"top_k": top_k
}
{"query": query, "top_k": top_k},
)
async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
"""Generate embeddings via platform API"""
return await self.post(
"/api/v1/llm/embeddings",
{
"model": model,
"input": input_text
}
"/api/v1/llm/embeddings", {"model": model, "input": input_text}
)
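
As an illustration of how a plugin would chain these helpers, the sketch below pulls RAG context and feeds it to the LLM endpoint; the constructor arguments, plugin ID, and model name are placeholders, not part of this commit.

# client wiring is assumed: __init__ stores plugin_id and plugin_token as used above
client = PlatformAPIClient(plugin_id="weather-bot", plugin_token=token)

hits = await client.search_rag("docs", "how do I rotate an API key?", top_k=3)
reply = await client.call_llm_api(
    "gpt-4",
    [{"role": "user", "content": f"Answer using this context: {hits}"}],
    temperature=0.2,
)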
@@ -157,13 +154,14 @@ class PluginConfigManager:
try:
# Use dependency injection to get database session
from app.db.database import SessionLocal
db = SessionLocal()
try:
# Query for active configuration
query = db.query(PluginConfiguration).filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.is_active == True
PluginConfiguration.is_active == True,
)
if user_id:
@@ -176,10 +174,14 @@ class PluginConfigManager:
config = query.first()
if config:
self.logger.debug(f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}"
)
return config.config_data or {}
else:
self.logger.debug(f"No configuration found for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"No configuration found for plugin {self.plugin_id}, user {user_id}"
)
return {}
finally:
@@ -189,21 +191,30 @@ class PluginConfigManager:
self.logger.error(f"Failed to get configuration: {e}")
return {}
async def save_config(self, config: Dict[str, Any], user_id: str,
async def save_config(
self,
config: Dict[str, Any],
user_id: str,
name: str = "Default Configuration",
description: str = None) -> bool:
description: str = None,
) -> bool:
"""Save plugin configuration for user"""
try:
from app.db.database import SessionLocal
db = SessionLocal()
try:
# Check if configuration already exists
existing_config = db.query(PluginConfiguration).filter(
existing_config = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
PluginConfiguration.name == name
).first()
PluginConfiguration.name == name,
)
.first()
)
if existing_config:
# Update existing configuration
@@ -211,7 +222,9 @@ class PluginConfigManager:
existing_config.description = description
existing_config.is_active = True
self.logger.info(f"Updated configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.info(
f"Updated configuration for plugin {self.plugin_id}, user {user_id}"
)
else:
# Create new configuration
new_config = PluginConfiguration(
@@ -222,20 +235,26 @@ class PluginConfigManager:
config_data=config,
is_active=True,
is_default=(name == "Default Configuration"),
created_by_user_id=user_id
created_by_user_id=user_id,
)
# If this is the first configuration for this user/plugin, make it default
existing_count = db.query(PluginConfiguration).filter(
existing_count = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id
).count()
PluginConfiguration.user_id == user_id,
)
.count()
)
if existing_count == 0:
new_config.is_default = True
db.add(new_config)
self.logger.info(f"Created new configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.info(
f"Created new configuration for plugin {self.plugin_id}, user {user_id}"
)
db.commit()
return True
@@ -251,11 +270,13 @@ class PluginConfigManager:
self.logger.error(f"Failed to save configuration: {e}")
return False
async def validate_config(self, config: Dict[str, Any],
schema: Dict[str, Any]) -> Tuple[bool, List[str]]:
async def validate_config(
self, config: Dict[str, Any], schema: Dict[str, Any]
) -> Tuple[bool, List[str]]:
"""Validate configuration against JSON schema"""
try:
import jsonschema
jsonschema.validate(config, schema)
return True, []
except jsonschema.ValidationError as e:
@@ -273,20 +294,27 @@ class PluginLogger:
# Sensitive data patterns to filter
self.sensitive_patterns = [
r'password', r'token', r'key', r'secret', r'api_key',
r'bearer', r'authorization', r'credential'
r"password",
r"token",
r"key",
r"secret",
r"api_key",
r"bearer",
r"authorization",
r"credential",
]
def _filter_sensitive_data(self, message: str) -> str:
"""Filter sensitive data from log messages"""
import re
filtered_message = message
for pattern in self.sensitive_patterns:
filtered_message = re.sub(
f'{pattern}[=:]\s*["\']?([^"\'\\s]+)["\']?',
f'{pattern}=***REDACTED***',
f"{pattern}[=:]\s*[\"']?([^\"'\\s]+)[\"']?",
f"{pattern}=***REDACTED***",
filtered_message,
flags=re.IGNORECASE
flags=re.IGNORECASE,
)
return filtered_message
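
A quick worked example of the redaction above (input and output are made up; plugin_logger stands in for a PluginLogger instance):

line = 'calling provider with api_key="sk-123" for user alice'
print(plugin_logger._filter_sensitive_data(line))
# -> calling provider with api_key=***REDACTED*** for user alice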
@@ -360,7 +388,7 @@ class BasePlugin(ABC):
"request_count": self._request_count,
"error_count": self._error_count,
"error_rate": round(error_rate, 3),
"initialized": self.initialized
"initialized": self.initialized,
}
async def get_configuration_schema(self) -> Dict[str, Any]:
@@ -404,16 +432,19 @@ class BasePlugin(ABC):
def get_auth_context(self) -> PluginContext:
"""Dependency to get authentication context in API endpoints"""
async def _get_context(request: Request) -> PluginContext:
# Extract authentication info from request
# This would be populated by the plugin API gateway
return PluginContext(
user_id=request.headers.get('X-User-ID'),
api_key_id=request.headers.get('X-API-Key-ID'),
user_permissions=request.headers.get('X-User-Permissions', '').split(','),
ip_address=request.headers.get('X-Real-IP'),
user_agent=request.headers.get('User-Agent'),
request_id=request.headers.get('X-Request-ID')
user_id=request.headers.get("X-User-ID"),
api_key_id=request.headers.get("X-API-Key-ID"),
user_permissions=request.headers.get("X-User-Permissions", "").split(
","
),
ip_address=request.headers.get("X-Real-IP"),
user_agent=request.headers.get("User-Agent"),
request_id=request.headers.get("X-Request-ID"),
)
return Depends(_get_context)
@@ -430,28 +461,54 @@ class PluginSecurityManager:
BLOCKED_IMPORTS = {
# Core platform modules
'app.db', 'app.models', 'app.core', 'app.services',
'sqlalchemy', 'alembic',
"app.db",
"app.models",
"app.core",
"app.services",
"sqlalchemy",
"alembic",
# Security sensitive
'subprocess', 'eval', 'exec', 'compile', '__import__',
'os.system', 'os.popen', 'os.spawn',
"subprocess",
"eval",
"exec",
"compile",
"__import__",
"os.system",
"os.popen",
"os.spawn",
# System access
'socket', 'multiprocessing', 'threading'
"socket",
"multiprocessing",
"threading",
}
ALLOWED_IMPORTS = {
# Standard library
'asyncio', 'aiohttp', 'json', 'datetime', 'typing', 'pydantic',
'logging', 'time', 'uuid', 'hashlib', 'base64', 'pathlib',
're', 'urllib.parse', 'dataclasses', 'enum',
"asyncio",
"aiohttp",
"json",
"datetime",
"typing",
"pydantic",
"logging",
"time",
"uuid",
"hashlib",
"base64",
"pathlib",
"re",
"urllib.parse",
"dataclasses",
"enum",
# Approved third-party
'httpx', 'requests', 'pandas', 'numpy', 'yaml',
"httpx",
"requests",
"pandas",
"numpy",
"yaml",
# Plugin framework
'app.services.base_plugin', 'app.schemas.plugin_manifest'
"app.services.base_plugin",
"app.schemas.plugin_manifest",
}
@classmethod
@@ -459,7 +516,9 @@ class PluginSecurityManager:
"""Validate if plugin can import a module"""
# Block dangerous imports
if any(import_name.startswith(blocked) for blocked in cls.BLOCKED_IMPORTS):
raise SecurityError(f"Import '{import_name}' not allowed in plugin environment")
raise SecurityError(
f"Import '{import_name}' not allowed in plugin environment"
)
# Allow explicit safe imports
if any(import_name.startswith(allowed) for allowed in cls.ALLOWED_IMPORTS):
@@ -474,12 +533,12 @@ class PluginSecurityManager:
def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
"""Create isolated environment for plugin execution"""
return {
'max_memory_mb': 128,
'max_cpu_percent': 25,
'max_disk_mb': 100,
'max_api_calls_per_minute': 100,
'allowed_domains': [], # Will be populated from manifest
'network_timeout_seconds': 30
"max_memory_mb": 128,
"max_cpu_percent": 25,
"max_disk_mb": 100,
"max_api_calls_per_minute": 100,
"allowed_domains": [], # Will be populated from manifest
"network_timeout_seconds": 30,
}
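
A minimal sketch of how the sandbox and import policy above behave together; the plugin ID is arbitrary, and since validate_import's return value for allowed modules is not shown in this hunk, only its raising behavior is relied on.

sandbox = PluginSecurityManager.create_plugin_sandbox("weather-bot")
assert sandbox["max_memory_mb"] == 128  # resource caps come from the dict above

PluginSecurityManager.validate_import("httpx")  # approved third-party: passes

try:
    PluginSecurityManager.validate_import("subprocess")
except SecurityError:
    pass  # security-sensitive modules are rejected before the plugin loads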
@@ -499,7 +558,9 @@ class PluginLoader:
validation_result = validate_manifest_file(manifest_path)
if not validation_result["valid"]:
raise ValidationError(f"Invalid plugin manifest: {validation_result['errors']}")
raise ValidationError(
f"Invalid plugin manifest: {validation_result['errors']}"
)
manifest = validation_result["manifest"]
@@ -511,8 +572,7 @@ class PluginLoader:
# Load plugin module
main_py_path = plugin_dir / "main.py"
spec = importlib.util.spec_from_file_location(
f"plugin_{manifest.metadata.name}",
main_py_path
f"plugin_{manifest.metadata.name}", main_py_path
)
if not spec or not spec.loader:
@@ -536,14 +596,18 @@ class PluginLoader:
plugin_class = None
for attr_name in dir(plugin_module):
attr = getattr(plugin_module, attr_name)
if (isinstance(attr, type) and
issubclass(attr, BasePlugin) and
attr is not BasePlugin):
if (
isinstance(attr, type)
and issubclass(attr, BasePlugin)
and attr is not BasePlugin
):
plugin_class = attr
break
if not plugin_class:
raise ValidationError("Plugin must contain a class inheriting from BasePlugin")
raise ValidationError(
"Plugin must contain a class inheriting from BasePlugin"
)
# Instantiate plugin
plugin_instance = plugin_class(manifest, plugin_token)
@@ -562,21 +626,30 @@ class PluginLoader:
def _validate_plugin_security(self, main_py_path: Path):
"""Validate plugin code for security issues"""
with open(main_py_path, 'r', encoding='utf-8') as f:
with open(main_py_path, "r", encoding="utf-8") as f:
code_content = f.read()
# Check for dangerous patterns
dangerous_patterns = [
'eval(', 'exec(', 'compile(',
'subprocess.', 'os.system', 'os.popen',
'__import__', 'importlib.import_module',
'from app.db', 'from app.models',
'sqlalchemy', 'SessionLocal'
"eval(",
"exec(",
"compile(",
"subprocess.",
"os.system",
"os.popen",
"__import__",
"importlib.import_module",
"from app.db",
"from app.models",
"sqlalchemy",
"SessionLocal",
]
for pattern in dangerous_patterns:
if pattern in code_content:
raise SecurityError(f"Dangerous pattern detected in plugin code: {pattern}")
raise SecurityError(
f"Dangerous pattern detected in plugin code: {pattern}"
)
async def unload_plugin(self, plugin_id: str) -> bool:
"""Unload a plugin and cleanup resources"""

View File

@@ -21,11 +21,13 @@ logger = get_logger(__name__)
class BudgetEnforcementError(Exception):
"""Custom exception for budget enforcement failures"""
pass
class BudgetExceededError(BudgetEnforcementError):
"""Exception raised when budget would be exceeded"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -34,6 +36,7 @@ class BudgetExceededError(BudgetEnforcementError):
class BudgetWarningError(BudgetEnforcementError):
"""Exception raised when budget warning threshold is reached"""
def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -42,6 +45,7 @@ class BudgetWarningError(BudgetEnforcementError):
class BudgetConcurrencyError(BudgetEnforcementError):
"""Exception raised when budget update fails due to concurrency"""
def __init__(self, message: str, retry_count: int = 0):
super().__init__(message)
self.retry_count = retry_count
@@ -49,6 +53,7 @@ class BudgetConcurrencyError(BudgetEnforcementError):
class BudgetAtomicError(BudgetEnforcementError):
"""Exception raised when atomic budget operation fails"""
def __init__(self, message: str, budget_id: int, requested_amount: int):
super().__init__(message)
self.budget_id = budget_id
@@ -68,7 +73,7 @@ class BudgetEnforcementService:
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""
Atomically check budget compliance and reserve spending
@@ -86,16 +91,27 @@ class BudgetEnforcementService:
# Try atomic reservation with retries
for attempt in range(self.max_retries):
try:
return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
return self._attempt_atomic_reservation(
budgets, estimated_cost, api_key.id, attempt
)
except BudgetConcurrencyError as e:
if attempt == self.max_retries - 1:
logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
return False, f"Budget check temporarily unavailable (concurrency limit)", [], []
logger.error(
f"Atomic budget reservation failed after {self.max_retries} attempts: {e}"
)
return (
False,
f"Budget check temporarily unavailable (concurrency limit)",
[],
[],
)
# Exponential backoff with jitter
delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
delay = self.retry_delay_base * (2**attempt) + random.uniform(0, 0.1)
time.sleep(delay)
logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
logger.info(
f"Retrying atomic budget reservation (attempt {attempt + 2})"
)
except Exception as e:
logger.error(f"Unexpected error in atomic budget reservation: {e}")
return False, f"Budget check failed: {str(e)}", [], []
@@ -103,11 +119,7 @@ class BudgetEnforcementService:
return False, "Budget check failed after maximum retries", [], []
def _attempt_atomic_reservation(
self,
budgets: List[Budget],
estimated_cost: int,
api_key_id: int,
attempt: int
self, budgets: List[Budget], estimated_cost: int, api_key_id: int, attempt: int
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Attempt to atomically reserve budget across all applicable budgets"""
warnings = []
@@ -119,12 +131,17 @@ class BudgetEnforcementService:
for budget in budgets:
# Lock budget row for update to prevent concurrent modifications
locked_budget = self.db.query(Budget).filter(
Budget.id == budget.id
).with_for_update().first()
locked_budget = (
self.db.query(Budget)
.filter(Budget.id == budget.id)
.with_for_update()
.first()
)
if not locked_budget:
raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)
raise BudgetAtomicError(
f"Budget {budget.id} not found", budget.id, estimated_cost
)
# Reset budget if expired and auto-renew enabled
if locked_budget.is_expired() and locked_budget.auto_renew:
@@ -144,28 +161,42 @@ class BudgetEnforcementService:
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(locked_budget.limit_cents - locked_budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
logger.warning(
f"Budget exceeded for API key {api_key_id}: {error_msg}"
)
self.db.rollback()
return False, error_msg, warnings, []
# Check warning threshold
if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
if (
locked_budget.would_exceed_warning(estimated_cost)
and not locked_budget.is_warning_sent
):
warning_msg = (
f"Budget '{locked_budget.name}' approaching limit. "
f"Usage will be ${(locked_budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${locked_budget.limit_cents/100:.2f} "
f"({((locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100):.1f}%)"
)
warnings.append({
warnings.append(
{
"type": "budget_warning",
"budget_id": locked_budget.id,
"budget_name": locked_budget.name,
"message": warning_msg,
"current_usage_cents": locked_budget.current_usage_cents + estimated_cost,
"current_usage_cents": locked_budget.current_usage_cents
+ estimated_cost,
"limit_cents": locked_budget.limit_cents,
"usage_percentage": (locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100
})
logger.info(f"Budget warning for API key {api_key_id}: {warning_msg}")
"usage_percentage": (
locked_budget.current_usage_cents + estimated_cost
)
/ locked_budget.limit_cents
* 100,
}
)
logger.info(
f"Budget warning for API key {api_key_id}: {warning_msg}"
)
# Reserve the budget (temporarily add estimated cost)
self._atomic_reserve_usage(locked_budget, estimated_cost)
@@ -173,12 +204,16 @@ class BudgetEnforcementService:
# Commit the reservation
self.db.commit()
logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
logger.debug(
f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}"
)
return True, None, warnings, reserved_budget_ids
except IntegrityError as e:
self.db.rollback()
raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt)
raise BudgetConcurrencyError(
f"Database integrity error during budget reservation: {e}", attempt
)
except Exception as e:
self.db.rollback()
logger.error(f"Error in atomic budget reservation: {e}")
@@ -203,24 +238,35 @@ class BudgetEnforcementService:
.values(
current_usage_cents=Budget.current_usage_cents + amount_cents,
updated_at=datetime.utcnow(),
is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
is_exceeded=Budget.current_usage_cents + amount_cents
>= Budget.limit_cents,
is_warning_sent=(
Budget.is_warning_sent |
((Budget.warning_threshold_cents.isnot(None)) &
(Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
Budget.is_warning_sent
| (
(Budget.warning_threshold_cents.isnot(None))
& (
Budget.current_usage_cents + amount_cents
>= Budget.warning_threshold_cents
)
)
),
)
)
if result.rowcount != 1:
raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)
raise BudgetAtomicError(
f"Failed to update budget {budget.id}", budget.id, amount_cents
)
# Update the in-memory object to reflect changes
budget.current_usage_cents += amount_cents
budget.updated_at = datetime.utcnow()
if budget.current_usage_cents >= budget.limit_cents:
budget.is_exceeded = True
if budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
if (
budget.warning_threshold_cents
and budget.current_usage_cents >= budget.warning_threshold_cents
):
budget.is_warning_sent = True
def atomic_finalize_usage(
@@ -230,7 +276,7 @@ class BudgetEnforcementService:
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> List[Budget]:
"""
Finalize actual usage and adjust reservations
@@ -250,7 +296,9 @@ class BudgetEnforcementService:
return []
try:
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
actual_cost = CostCalculator.calculate_cost_cents(
model_name, input_tokens, output_tokens
)
updated_budgets = []
# Begin transaction for finalization
@@ -258,9 +306,12 @@ class BudgetEnforcementService:
for budget_id in reserved_budget_ids:
# Lock budget for update
budget = self.db.query(Budget).filter(
Budget.id == budget_id
).with_for_update().first()
budget = (
self.db.query(Budget)
.filter(Budget.id == budget_id)
.with_for_update()
.first()
)
if not budget:
logger.warning(f"Budget {budget_id} not found during finalization")
@@ -270,7 +321,9 @@ class BudgetEnforcementService:
# Calculate adjustment (actual cost - estimated cost already reserved)
# Note: We don't know the exact estimated cost that was reserved
# So we'll just set it to the actual cost (safe, since we already reserved)
self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
self._atomic_set_actual_usage(
budget, actual_cost, input_tokens, output_tokens
)
updated_budgets.append(budget)
logger.debug(
@@ -287,7 +340,9 @@ class BudgetEnforcementService:
self.db.rollback()
return []
def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):
def _atomic_set_actual_usage(
self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int
):
"""Set the actual usage cost (replacing any reservation)"""
# For simplicity, we'll just ensure the current usage reflects actual cost
# In a more sophisticated system, you might track reservations separately
@@ -300,7 +355,7 @@ class BudgetEnforcementService:
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""
Check if a request complies with budget limits
@@ -346,27 +401,41 @@ class BudgetEnforcementService:
f"Requested: ${estimated_cost/100:.4f}, "
f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
)
logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
logger.warning(
f"Budget exceeded for API key {api_key.id}: {error_msg}"
)
return False, error_msg, warnings
# Check if request would trigger warning
if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
if (
budget.would_exceed_warning(estimated_cost)
and not budget.is_warning_sent
):
warning_msg = (
f"Budget '{budget.name}' approaching limit. "
f"Usage will be ${(budget.current_usage_cents + estimated_cost)/100:.2f} "
f"of ${budget.limit_cents/100:.2f} "
f"({((budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100):.1f}%)"
)
warnings.append({
warnings.append(
{
"type": "budget_warning",
"budget_id": budget.id,
"budget_name": budget.name,
"message": warning_msg,
"current_usage_cents": budget.current_usage_cents + estimated_cost,
"current_usage_cents": budget.current_usage_cents
+ estimated_cost,
"limit_cents": budget.limit_cents,
"usage_percentage": (budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100
})
logger.info(f"Budget warning for API key {api_key.id}: {warning_msg}")
"usage_percentage": (
budget.current_usage_cents + estimated_cost
)
/ budget.limit_cents
* 100,
}
)
logger.info(
f"Budget warning for API key {api_key.id}: {warning_msg}"
)
return True, None, warnings
@@ -381,7 +450,7 @@ class BudgetEnforcementService:
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> List[Budget]:
"""
Record actual usage against applicable budgets
@@ -398,7 +467,9 @@ class BudgetEnforcementService:
"""
try:
# Calculate actual cost
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
actual_cost = CostCalculator.calculate_cost_cents(
model_name, input_tokens, output_tokens
)
# Get applicable budgets
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
@@ -427,10 +498,7 @@ class BudgetEnforcementService:
return []
def _get_applicable_budgets(
self,
api_key: APIKey,
model_name: str = None,
endpoint: str = None
self, api_key: APIKey, model_name: str = None, endpoint: str = None
) -> List[Budget]:
"""Get budgets that apply to the given request"""
@@ -438,9 +506,11 @@ class BudgetEnforcementService:
conditions = [
Budget.is_active == True,
or_(
and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)), # User budget
Budget.api_key_id == api_key.id # API key specific budget
)
and_(
Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)
), # User budget
Budget.api_key_id == api_key.id, # API key specific budget
),
]
# Query budgets
@@ -492,7 +562,7 @@ class BudgetEnforcementService:
"warning_budgets": 0,
"total_limit_cents": 0,
"total_usage_cents": 0,
"budgets": []
"budgets": [],
}
for budget in budgets:
@@ -500,12 +570,14 @@ class BudgetEnforcementService:
continue
budget_info = budget.to_dict()
budget_info.update({
budget_info.update(
{
"is_expired": budget.is_expired(),
"days_remaining": budget.get_period_days_remaining(),
"daily_burn_rate": budget.get_daily_burn_rate(),
"projected_spend": budget.get_projected_spend()
})
"projected_spend": budget.get_projected_spend(),
}
)
status["budgets"].append(budget_info)
status["active_budgets"] += 1
@@ -514,18 +586,25 @@ class BudgetEnforcementService:
if budget.is_exceeded:
status["exceeded_budgets"] += 1
elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
elif (
budget.warning_threshold_cents
and budget.current_usage_cents >= budget.warning_threshold_cents
):
status["warning_budgets"] += 1
# Calculate overall percentages
if status["total_limit_cents"] > 0:
status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
status["overall_usage_percentage"] = (
status["total_usage_cents"] / status["total_limit_cents"]
) * 100
else:
status["overall_usage_percentage"] = 0
status["total_limit_dollars"] = status["total_limit_cents"] / 100
status["total_usage_dollars"] = status["total_usage_cents"] / 100
status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
status["total_remaining_cents"] = max(
0, status["total_limit_cents"] - status["total_usage_cents"]
)
status["total_remaining_dollars"] = status["total_remaining_cents"] / 100
return status
@@ -538,14 +617,11 @@ class BudgetEnforcementService:
"active_budgets": 0,
"exceeded_budgets": 0,
"warning_budgets": 0,
"budgets": []
"budgets": [],
}
def create_default_user_budget(
self,
user_id: int,
limit_dollars: float = 10.0,
period_type: str = "monthly"
self, user_id: int, limit_dollars: float = 10.0, period_type: str = "monthly"
) -> Budget:
"""Create a default budget for a new user"""
try:
@@ -553,19 +629,21 @@ class BudgetEnforcementService:
budget = Budget.create_monthly_budget(
user_id=user_id,
name="Default Monthly Budget",
limit_dollars=limit_dollars
limit_dollars=limit_dollars,
)
else:
budget = Budget.create_daily_budget(
user_id=user_id,
name="Default Daily Budget",
limit_dollars=limit_dollars
limit_dollars=limit_dollars,
)
self.db.add(budget)
self.db.commit()
logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")
logger.info(
f"Created default budget for user {user_id}: ${limit_dollars} {period_type}"
)
return budget
@@ -577,13 +655,17 @@ class BudgetEnforcementService:
def check_and_reset_expired_budgets(self):
"""Background task to check and reset expired budgets"""
try:
expired_budgets = self.db.query(Budget).filter(
expired_budgets = (
self.db.query(Budget)
.filter(
and_(
Budget.is_active == True,
Budget.auto_renew == True,
Budget.period_end < datetime.utcnow()
Budget.period_end < datetime.utcnow(),
)
)
.all()
)
).all()
for budget in expired_budgets:
self._reset_expired_budget(budget)
@@ -596,17 +678,20 @@ class BudgetEnforcementService:
# Convenience functions
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
db: Session,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
"""DEPRECATED: Convenience function to check budget compliance (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.check_budget_compliance(api_key, model_name, estimated_tokens, endpoint)
return service.check_budget_compliance(
api_key, model_name, estimated_tokens, endpoint
)
def record_request_usage(
@@ -615,11 +700,13 @@ def record_request_usage(
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> List[Budget]:
"""DEPRECATED: Convenience function to record actual usage (race conditions possible)"""
service = BudgetEnforcementService(db)
return service.record_usage(api_key, model_name, input_tokens, output_tokens, endpoint)
return service.record_usage(
api_key, model_name, input_tokens, output_tokens, endpoint
)
# ATOMIC VERSIONS: Race-condition-free budget enforcement
@@ -628,11 +715,13 @@ def atomic_check_and_reserve_budget(
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""Atomic convenience function to check budget compliance and reserve spending"""
service = BudgetEnforcementService(db)
return service.atomic_check_and_reserve_budget(api_key, model_name, estimated_tokens, endpoint)
return service.atomic_check_and_reserve_budget(
api_key, model_name, estimated_tokens, endpoint
)
def atomic_finalize_usage(
@@ -642,8 +731,10 @@ def atomic_finalize_usage(
model_name: str,
input_tokens: int,
output_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> List[Budget]:
"""Atomic convenience function to finalize actual usage after request completion"""
service = BudgetEnforcementService(db)
return service.atomic_finalize_usage(reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint)
return service.atomic_finalize_usage(
reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint
)
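
Putting the two atomic helpers together, a caller would reserve before the provider call and finalize after it; the handler shape, status code, and token counts below are illustrative, and the leading positional arguments (db, then the reservation info) are assumed from the service methods these functions wrap.

ok, error, warnings, reserved_ids = atomic_check_and_reserve_budget(
    db, api_key, model_name="gpt-4", estimated_tokens=1500
)
if not ok:
    raise HTTPException(status_code=402, detail=error)  # rejection style is a choice

# ... perform the LLM call, obtaining real token counts ...

atomic_finalize_usage(
    db, reserved_ids, api_key,
    model_name="gpt-4", input_tokens=900, output_tokens=450,
)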

View File

@@ -31,9 +31,13 @@ class CachedAPIKeyService:
async def close(self):
"""Close method for compatibility - core cache handles its own lifecycle"""
logger.info("Cached API key service close called - core cache handles lifecycle")
logger.info(
"Cached API key service close called - core cache handles lifecycle"
)
async def get_cached_api_key(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
async def get_cached_api_key(
self, key_prefix: str, db: AsyncSession
) -> Optional[Dict[str, Any]]:
"""
Get API key data from cache or database
Returns: Dictionary with api_key, user, and api_key_id
@@ -57,19 +61,18 @@ class CachedAPIKeyService:
return {
"api_key": api_key,
"user": user,
"api_key_id": api_key_data.get("id")
"api_key_id": api_key_data.get("id"),
}
logger.debug(f"API key cache miss for prefix: {key_prefix}, fetching from database")
logger.debug(
f"API key cache miss for prefix: {key_prefix}, fetching from database"
)
# Cache miss - fetch from database with optimized query
stmt = (
select(APIKey, User)
.join(User, APIKey.user_id == User.id)
.options(
joinedload(APIKey.user),
joinedload(User.api_keys)
)
.options(joinedload(APIKey.user), joinedload(User.api_keys))
.where(APIKey.key_prefix == key_prefix)
.where(APIKey.is_active == True)
)
@@ -86,11 +89,7 @@ class CachedAPIKeyService:
# Cache for future requests
await self._cache_api_key_data(key_prefix, api_key, user)
return {
"api_key": api_key,
"user": user,
"api_key_id": api_key.id
}
return {"api_key": api_key, "user": user, "api_key_id": api_key.id}
except Exception as e:
logger.error(f"Error retrieving API key for prefix {key_prefix}: {e}")
@@ -118,17 +117,25 @@ class CachedAPIKeyService:
"allowed_ips": api_key.allowed_ips,
"description": api_key.description,
"tags": api_key.tags,
"created_at": api_key.created_at.isoformat() if api_key.created_at else None,
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat()
if api_key.created_at
else None,
"updated_at": api_key.updated_at.isoformat()
if api_key.updated_at
else None,
"last_used_at": api_key.last_used_at.isoformat()
if api_key.last_used_at
else None,
"expires_at": api_key.expires_at.isoformat()
if api_key.expires_at
else None,
"total_requests": api_key.total_requests,
"total_tokens": api_key.total_tokens,
"total_cost": api_key.total_cost,
"is_unlimited": api_key.is_unlimited,
"budget_limit_cents": api_key.budget_limit_cents,
"budget_type": api_key.budget_type,
"allowed_chatbots": api_key.allowed_chatbots
"allowed_chatbots": api_key.allowed_chatbots,
},
"user_data": {
"id": user.id,
@@ -137,11 +144,17 @@ class CachedAPIKeyService:
"is_active": user.is_active,
"is_superuser": user.is_superuser,
"role": user.role,
"created_at": user.created_at.isoformat() if user.created_at else None,
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
"last_login": user.last_login.isoformat() if user.last_login else None
"created_at": user.created_at.isoformat()
if user.created_at
else None,
"updated_at": user.updated_at.isoformat()
if user.updated_at
else None,
"last_login": user.last_login.isoformat()
if user.last_login
else None,
},
"cached_at": datetime.utcnow().isoformat()
"cached_at": datetime.utcnow().isoformat(),
}
await core_cache.cache_api_key(key_prefix, cache_data, self.cache_ttl)
@@ -150,7 +163,9 @@ class CachedAPIKeyService:
except Exception as e:
logger.error(f"Error caching API key data for prefix {key_prefix}: {e}")
async def verify_api_key_cached(self, api_key: str, key_prefix: str) -> Optional[bool]:
async def verify_api_key_cached(
self, api_key: str, key_prefix: str
) -> Optional[bool]:
"""
Verify API key using cached hash to avoid expensive bcrypt operations
Returns: True if verified, False if invalid, None if not cached
@@ -161,25 +176,39 @@ class CachedAPIKeyService:
if cached_verification:
# Check if cache is still valid (within TTL)
cached_timestamp = datetime.fromisoformat(cached_verification["timestamp"])
if datetime.utcnow() - cached_timestamp < timedelta(seconds=self.verification_cache_ttl):
logger.debug(f"API key verification cache hit for prefix: {key_prefix}")
cached_timestamp = datetime.fromisoformat(
cached_verification["timestamp"]
)
if datetime.utcnow() - cached_timestamp < timedelta(
seconds=self.verification_cache_ttl
):
logger.debug(
f"API key verification cache hit for prefix: {key_prefix}"
)
return cached_verification.get("is_valid", False)
return None # Not cached or expired
except Exception as e:
logger.error(f"Error checking verification cache for prefix {key_prefix}: {e}")
logger.error(
f"Error checking verification cache for prefix {key_prefix}: {e}"
)
return None
async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool):
async def cache_verification_result(
self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool
):
"""Cache API key verification result to avoid expensive bcrypt operations"""
try:
await core_cache.cache_verification_result(api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl)
await core_cache.cache_verification_result(
api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl
)
logger.debug(f"Cached verification result for prefix: {key_prefix}")
except Exception as e:
logger.error(f"Error caching verification result for prefix {key_prefix}: {e}")
logger.error(
f"Error caching verification result for prefix {key_prefix}: {e}"
)
async def invalidate_api_key_cache(self, key_prefix: str):
"""Invalidate cached API key data"""
@@ -187,7 +216,9 @@ class CachedAPIKeyService:
await core_cache.invalidate_api_key(key_prefix)
# Also invalidate verification cache
verification_keys = await core_cache.clear_pattern(f"verify:{key_prefix}*", prefix="auth")
verification_keys = await core_cache.clear_pattern(
f"verify:{key_prefix}*", prefix="auth"
)
logger.debug(f"Invalidated cache for API key prefix: {key_prefix}")
@@ -206,10 +237,7 @@ class CachedAPIKeyService:
return # Skip update if recent
# Update database
stmt = (
select(APIKey)
.where(APIKey.id == api_key_id)
)
stmt = select(APIKey).where(APIKey.id == api_key_id)
result = await db.execute(stmt)
api_key = result.scalar_one_or_none()
@@ -218,7 +246,9 @@ class CachedAPIKeyService:
await db.commit()
# Cache that we updated to prevent spam
await core_cache.set(cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf")
await core_cache.set(
cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf"
)
logger.debug(f"Updated last_used_at for API key {api_key_id}")
@@ -232,14 +262,14 @@ class CachedAPIKeyService:
return {
"cache_backend": "core_cache",
"cache_enabled": core_stats.get("enabled", False),
"cache_stats": core_stats
"cache_stats": core_stats,
}
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return {
"cache_backend": "core_cache",
"cache_enabled": False,
"error": str(e)
"error": str(e),
}
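
A sketch of how a caller combines the verification cache with the expensive fallback; the service construction and the bcrypt helper are assumed, not shown in this hunk.

service = CachedAPIKeyService()  # construction details assumed

cached = await service.verify_api_key_cached(raw_key, key_prefix)
if cached is None:
    # not cached or expired: fall back to the slow hash comparison
    is_valid = verify_with_bcrypt(raw_key, stored_hash)  # hypothetical helper
    await service.cache_verification_result(raw_key, key_prefix, stored_hash, is_valid)
else:
    is_valid = cached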

View File

@@ -25,6 +25,7 @@ logger = get_logger(__name__)
@dataclass
class ConfigVersion:
"""Configuration version metadata"""
version: str
timestamp: datetime
checksum: str
@@ -36,6 +37,7 @@ class ConfigVersion:
@dataclass
class ConfigSchema:
"""Configuration schema definition"""
name: str
required_fields: List[str]
optional_fields: List[str]
@@ -46,6 +48,7 @@ class ConfigSchema:
@dataclass
class ConfigStats:
"""Configuration manager statistics"""
total_configs: int
active_watchers: int
config_versions: int
@@ -78,19 +81,19 @@ class ConfigWatcher(FileSystemEventHandler):
self.last_modified[path] = current_time
# Trigger hot reload for config files
if path.endswith(('.json', '.yaml', '.yml', '.toml')):
if path.endswith((".json", ".yaml", ".yml", ".toml")):
# Schedule coroutine in a thread-safe way
try:
loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(
lambda: asyncio.create_task(self.config_manager.reload_config_file(path))
lambda: asyncio.create_task(
self.config_manager.reload_config_file(path)
)
)
except RuntimeError:
# No running loop, schedule for later
threading.Thread(
target=self._schedule_reload,
args=(path,),
daemon=True
target=self._schedule_reload, args=(path,), daemon=True
).start()
def _schedule_reload(self, path: str):
@@ -110,7 +113,7 @@ class ConfigManager:
self.versions: Dict[str, List[ConfigVersion]] = {}
self.watchers: Dict[str, Observer] = {}
self.config_paths: Dict[str, Path] = {}
self.environment = os.getenv('ENVIRONMENT', 'development')
self.environment = os.getenv("ENVIRONMENT", "development")
self.start_time = time.time()
self.stats = ConfigStats(
total_configs=0,
@@ -119,7 +122,7 @@ class ConfigManager:
hot_reloads_performed=0,
validation_errors=0,
last_reload_time=datetime.now(),
uptime=0
uptime=0,
)
# Base configuration directories
@@ -157,7 +160,9 @@ class ConfigManager:
for field, expected_type in schema.field_types.items():
if field in config_data:
if not isinstance(config_data[field], expected_type):
logger.error(f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}")
logger.error(
f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}"
)
self.stats.validation_errors += 1
return False
@@ -165,7 +170,9 @@ class ConfigManager:
for field, validator in schema.validators.items():
if field in config_data:
if not validator(config_data[field]):
logger.error(f"Validation failed for field '{field}' in config '{name}'")
logger.error(
f"Validation failed for field '{field}' in config '{name}'"
)
self.stats.validation_errors += 1
return False
@@ -182,7 +189,9 @@ class ConfigManager:
json_str = json.dumps(data, sort_keys=True)
return hashlib.sha256(json_str.encode()).hexdigest()
def _create_version(self, name: str, config_data: Dict[str, Any], description: str = "Auto-save") -> ConfigVersion:
def _create_version(
self, name: str, config_data: Dict[str, Any], description: str = "Auto-save"
) -> ConfigVersion:
"""Create a new configuration version"""
version_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checksum = self._calculate_checksum(config_data)
@@ -191,9 +200,9 @@ class ConfigManager:
version=version_id,
timestamp=datetime.now(),
checksum=checksum,
author=os.getenv('USER', 'system'),
author=os.getenv("USER", "system"),
description=description,
config_data=config_data.copy()
config_data=config_data.copy(),
)
if name not in self.versions:
@@ -209,8 +218,9 @@ class ConfigManager:
logger.debug(f"Created version {version_id} for config '{name}'")
return version
async def set_config(self, name: str, config_data: Dict[str, Any],
description: str = "Manual update") -> bool:
async def set_config(
self, name: str, config_data: Dict[str, Any], description: str = "Manual update"
) -> bool:
"""Set configuration with validation and versioning"""
try:
# Validate configuration
@@ -247,13 +257,15 @@ class ConfigManager:
return default
async def get_config_value(self, config_name: str, key: str, default: Any = None) -> Any:
async def get_config_value(
self, config_name: str, key: str, default: Any = None
) -> Any:
"""Get specific value from configuration"""
config = await self.get_config(config_name)
if config is None:
return default
keys = key.split('.')
keys = key.split(".")
value = config
try:
@@ -269,7 +281,7 @@ class ConfigManager:
try:
# Save as regular JSON
with open(file_path, 'w') as f:
with open(file_path, "w") as f:
json.dump(config_data, f, indent=2)
logger.debug(f"Saved config '{name}' to {file_path}")
@@ -288,7 +300,7 @@ class ConfigManager:
try:
# Load regular JSON
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
return json.load(f)
except Exception as e:
@@ -302,11 +314,11 @@ class ConfigManager:
config_name = path.stem
# Load updated configuration
if path.suffix == '.json':
with open(path, 'r') as f:
if path.suffix == ".json":
with open(path, "r") as f:
new_config = json.load(f)
elif path.suffix in ['.yaml', '.yml']:
with open(path, 'r') as f:
elif path.suffix in [".yaml", ".yml"]:
with open(path, "r") as f:
new_config = yaml.safe_load(f)
else:
logger.warning(f"Unsupported config file format: {path.suffix}")
@@ -317,9 +329,13 @@ class ConfigManager:
self.configs[config_name] = new_config
self.stats.hot_reloads_performed += 1
self.stats.last_reload_time = datetime.now()
logger.info(f"Hot reloaded configuration '{config_name}' from {file_path}")
logger.info(
f"Hot reloaded configuration '{config_name}' from {file_path}"
)
else:
logger.error(f"Failed to hot reload '{config_name}' - validation failed")
logger.error(
f"Failed to hot reload '{config_name}' - validation failed"
)
except Exception as e:
logger.error(f"Error hot reloading config from {file_path}: {str(e)}")
@@ -375,7 +391,7 @@ async def _register_default_schemas():
required_fields=["host", "port", "name"],
optional_fields=["username", "password", "ssl"],
field_types={"host": str, "port": int, "name": str, "ssl": bool},
validators={"port": lambda x: 1 <= x <= 65535}
validators={"port": lambda x: 1 <= x <= 65535},
)
manager.register_schema("database", db_schema)
@@ -385,7 +401,7 @@ async def _register_default_schemas():
required_fields=["redis_url"],
optional_fields=["timeout", "max_connections"],
field_types={"redis_url": str, "timeout": int, "max_connections": int},
validators={"timeout": lambda x: x > 0}
validators={"timeout": lambda x: x > 0},
)
manager.register_schema("cache", cache_schema)
@@ -400,13 +416,13 @@ async def _load_default_configs():
"version": "1.0.0",
"debug": manager.environment == "development",
"log_level": "INFO",
"timezone": "UTC"
"timezone": "UTC",
},
"cache": {
"redis_url": "redis://empire-redis:6379/0",
"timeout": 30,
"max_connections": 10
}
"max_connections": 10,
},
}
for name, config in default_configs.items():

View File

@@ -27,7 +27,7 @@ class ConversationService:
chatbot_id: str,
user_id: str,
conversation_id: Optional[str] = None,
title: Optional[str] = None
title: Optional[str] = None,
) -> ChatbotConversation:
"""Get existing conversation or create a new one"""
@@ -38,7 +38,7 @@ class ConversationService:
ChatbotConversation.id == conversation_id,
ChatbotConversation.chatbot_id == chatbot_id,
ChatbotConversation.user_id == user_id,
ChatbotConversation.is_active == True
ChatbotConversation.is_active == True,
)
)
result = await self.db.execute(stmt)
@@ -48,7 +48,9 @@ class ConversationService:
logger.info(f"Found existing conversation {conversation_id}")
return conversation
else:
logger.warning(f"Conversation {conversation_id} not found or not accessible")
logger.warning(
f"Conversation {conversation_id} not found or not accessible"
)
# Create new conversation
if not title:
@@ -61,21 +63,20 @@ class ConversationService:
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
is_active=True,
context_data={}
context_data={},
)
self.db.add(conversation)
await self.db.commit()
await self.db.refresh(conversation)
logger.info(f"Created new conversation {conversation.id} for chatbot {chatbot_id}")
logger.info(
f"Created new conversation {conversation.id} for chatbot {chatbot_id}"
)
return conversation
async def get_conversation_history(
self,
conversation_id: str,
limit: int = 20,
include_system: bool = False
self, conversation_id: str, limit: int = 20, include_system: bool = False
) -> List[Dict[str, Any]]:
"""
Load conversation history for a conversation
@@ -96,7 +97,7 @@ class ConversationService:
# Optionally exclude system messages
if not include_system:
stmt = stmt.where(ChatbotMessage.role != 'system')
stmt = stmt.where(ChatbotMessage.role != "system")
# Order by timestamp descending and limit
stmt = stmt.order_by(desc(ChatbotMessage.timestamp)).limit(limit)
@@ -107,19 +108,27 @@ class ConversationService:
# Convert to list and reverse to get chronological order (oldest first)
history = []
for msg in reversed(messages):
history.append({
history.append(
{
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
"timestamp": msg.timestamp.isoformat()
if msg.timestamp
else None,
"metadata": msg.message_metadata or {},
"sources": msg.sources
})
"sources": msg.sources,
}
)
logger.info(f"Loaded {len(history)} messages for conversation {conversation_id}")
logger.info(
f"Loaded {len(history)} messages for conversation {conversation_id}"
)
return history
except Exception as e:
logger.error(f"Failed to load conversation history for {conversation_id}: {e}")
logger.error(
f"Failed to load conversation history for {conversation_id}: {e}"
)
return [] # Return empty list on error to avoid breaking chat
async def add_message(
@@ -128,11 +137,11 @@ class ConversationService:
role: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
sources: Optional[List[Dict[str, Any]]] = None
sources: Optional[List[Dict[str, Any]]] = None,
) -> ChatbotMessage:
"""Add a message to a conversation"""
if role not in ['user', 'assistant', 'system']:
if role not in ["user", "assistant", "system"]:
raise ValueError(f"Invalid message role: {role}")
message = ChatbotMessage(
@@ -141,13 +150,15 @@ class ConversationService:
content=content,
timestamp=datetime.utcnow(),
message_metadata=metadata or {},
sources=sources
sources=sources,
)
self.db.add(message)
# Update conversation timestamp
stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
stmt = select(ChatbotConversation).where(
ChatbotConversation.id == conversation_id
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
@@ -164,18 +175,19 @@ class ConversationService:
"""Get statistics for a conversation"""
# Count messages by role
stmt = select(
ChatbotMessage.role,
func.count(ChatbotMessage.id).label('count')
).where(
ChatbotMessage.conversation_id == conversation_id
).group_by(ChatbotMessage.role)
stmt = (
select(ChatbotMessage.role, func.count(ChatbotMessage.id).label("count"))
.where(ChatbotMessage.conversation_id == conversation_id)
.group_by(ChatbotMessage.role)
)
result = await self.db.execute(stmt)
role_counts = {row.role: row.count for row in result}
# Get conversation info
stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
stmt = select(ChatbotConversation).where(
ChatbotConversation.id == conversation_id
)
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
@@ -185,12 +197,16 @@ class ConversationService:
return {
"conversation_id": conversation_id,
"title": conversation.title,
"created_at": conversation.created_at.isoformat() if conversation.created_at else None,
"updated_at": conversation.updated_at.isoformat() if conversation.updated_at else None,
"created_at": conversation.created_at.isoformat()
if conversation.created_at
else None,
"updated_at": conversation.updated_at.isoformat()
if conversation.updated_at
else None,
"total_messages": sum(role_counts.values()),
"user_messages": role_counts.get('user', 0),
"assistant_messages": role_counts.get('assistant', 0),
"system_messages": role_counts.get('system', 0)
"user_messages": role_counts.get("user", 0),
"assistant_messages": role_counts.get("assistant", 0),
"system_messages": role_counts.get("system", 0),
}
async def archive_old_conversations(self, days_inactive: int = 30) -> int:
@@ -202,7 +218,7 @@ class ConversationService:
stmt = select(ChatbotConversation).where(
and_(
ChatbotConversation.updated_at < cutoff_date,
ChatbotConversation.is_active == True
ChatbotConversation.is_active == True,
)
)
@@ -224,12 +240,16 @@ class ConversationService:
"""Delete a conversation and all its messages"""
# Verify ownership
stmt = select(ChatbotConversation).where(
stmt = (
select(ChatbotConversation)
.where(
and_(
ChatbotConversation.id == conversation_id,
ChatbotConversation.user_id == user_id
ChatbotConversation.user_id == user_id,
)
)
.options(selectinload(ChatbotConversation.messages))
)
).options(selectinload(ChatbotConversation.messages))
result = await self.db.execute(stmt)
conversation = result.scalar_one_or_none()
@@ -245,7 +265,9 @@ class ConversationService:
await self.db.delete(conversation)
await self.db.commit()
logger.info(f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages")
logger.info(
f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages"
)
return True
async def get_user_conversations(
@@ -253,21 +275,25 @@ class ConversationService:
user_id: str,
chatbot_id: Optional[str] = None,
limit: int = 50,
skip: int = 0
skip: int = 0,
) -> List[Dict[str, Any]]:
"""Get list of conversations for a user"""
stmt = select(ChatbotConversation).where(
and_(
ChatbotConversation.user_id == user_id,
ChatbotConversation.is_active == True
ChatbotConversation.is_active == True,
)
)
if chatbot_id:
stmt = stmt.where(ChatbotConversation.chatbot_id == chatbot_id)
stmt = stmt.order_by(desc(ChatbotConversation.updated_at)).offset(skip).limit(limit)
stmt = (
stmt.order_by(desc(ChatbotConversation.updated_at))
.offset(skip)
.limit(limit)
)
result = await self.db.execute(stmt)
conversations = result.scalars().all()
@@ -281,14 +307,20 @@ class ConversationService:
msg_count_result = await self.db.execute(msg_count_stmt)
message_count = msg_count_result.scalar() or 0
conversation_list.append({
conversation_list.append(
{
"id": conv.id,
"chatbot_id": conv.chatbot_id,
"title": conv.title,
"message_count": message_count,
"created_at": conv.created_at.isoformat() if conv.created_at else None,
"updated_at": conv.updated_at.isoformat() if conv.updated_at else None,
"context_data": conv.context_data or {}
})
"created_at": conv.created_at.isoformat()
if conv.created_at
else None,
"updated_at": conv.updated_at.isoformat()
if conv.updated_at
else None,
"context_data": conv.context_data or {},
}
)
return conversation_list
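
A minimal usage sketch for the service above, assuming it is constructed with an AsyncSession (it references self.db throughout) and that the import paths below match this repository's layout; the async_session_factory path appears later in this commit, the service path is an assumption:

import asyncio

from app.db.database import async_session_factory  # path used elsewhere in this commit
from app.services.conversation_service import ConversationService  # assumed path


async def demo() -> None:
    async with async_session_factory() as session:
        svc = ConversationService(session)  # assumed constructor: takes the db session
        await svc.add_message(
            conversation_id="conv-123",
            role="user",
            content="Hello!",
            metadata={"source": "demo"},
        )
        history = await svc.get_conversation_history("conv-123", limit=10)
        for msg in history:
            print(msg["role"], msg["content"])


asyncio.run(demo())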

View File

@@ -17,20 +17,22 @@ class CostCalculator:
"gpt-4": {"input": 300, "output": 600}, # $0.03/$0.06 per 1K tokens
"gpt-4-turbo": {"input": 100, "output": 300}, # $0.01/$0.03 per 1K tokens
"gpt-3.5-turbo": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
# Anthropic Models
"claude-3-opus": {"input": 150, "output": 750}, # $0.015/$0.075 per 1K tokens
"claude-3-sonnet": {"input": 30, "output": 150}, # $0.003/$0.015 per 1K tokens
"claude-3-haiku": {"input": 25, "output": 125}, # $0.00025/$0.00125 per 1K tokens
"claude-3-haiku": {
"input": 25,
"output": 125,
},  # $0.0025/$0.0125 per 1K tokens (in the same 1/10000-dollar units as above)
# Google Models
"gemini-pro": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {
"input": 5,
"output": 15,
}, # $0.0005/$0.0015 per 1K tokens
# Privatemode.ai Models (estimated pricing)
"privatemode-llama-70b": {"input": 40, "output": 80}, # Estimated pricing
"privatemode-mixtral": {"input": 20, "output": 40}, # Estimated pricing
# Embedding Models
"text-embedding-ada-002": {"input": 1, "output": 0},  # $0.0001 per 1K tokens
"text-embedding-3-small": {"input": 2, "output": 0},  # $0.0002 per 1K tokens
@@ -49,7 +51,9 @@ class CostCalculator:
# Look up pricing
pricing = cls.MODEL_PRICING.get(normalized_name, cls.DEFAULT_PRICING)
logger.debug(f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}")
logger.debug(
f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}"
)
return pricing
@classmethod
@@ -61,7 +65,7 @@ class CostCalculator:
normalized = model_name.lower()
for prefix in prefixes:
if normalized.startswith(prefix):
normalized = normalized[len(prefix):]
normalized = normalized[len(prefix) :]
break
# Handle special cases
@@ -80,10 +84,7 @@ class CostCalculator:
@classmethod
def calculate_cost_cents(
cls,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0
cls, model_name: str, input_tokens: int = 0, output_tokens: int = 0
) -> int:
"""
Calculate cost in cents for given token usage
@@ -147,7 +148,7 @@ class CostCalculator:
return {
"input": pricing_cents["input"] / 10000, # Convert 1/10000ths to dollars
"output": pricing_cents["output"] / 10000,
"currency": "USD"
"currency": "USD",
}
@classmethod
@@ -172,7 +173,9 @@ class CostCalculator:
# Convenience functions for common operations
def calculate_request_cost(model_name: str, input_tokens: int, output_tokens: int) -> int:
def calculate_request_cost(
model_name: str, input_tokens: int, output_tokens: int
) -> int:
"""Calculate cost for a single request"""
return CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
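
Worked example of the pricing units: MODEL_PRICING values are 1/10000ths of a dollar per 1K tokens (the /10000 conversion above turns them back into dollars). Assuming calculate_cost_cents prorates linearly by token count, which this hunk does not show, the arithmetic works out as follows:

from app.services.cost_calculator import calculate_request_cost  # assumed import path

# gpt-4: input 300 units -> $0.03 per 1K tokens; output 600 units -> $0.06 per 1K tokens
# 1000 input tokens  -> $0.03 -> 3 cents
#  500 output tokens -> $0.03 -> 3 cents
cost_cents = calculate_request_cost("gpt-4", input_tokens=1000, output_tokens=500)
print(cost_cents)  # expected: 6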

View File

@@ -33,6 +33,7 @@ class ProcessingStatus(str, Enum):
@dataclass
class ProcessingTask:
"""Document processing task"""
document_id: int
priority: int = 1
retry_count: int = 0
@@ -57,7 +58,7 @@ class DocumentProcessor:
"processed_count": 0,
"error_count": 0,
"queue_size": 0,
"active_workers": 0
"active_workers": 0,
}
self._rag_module = None
self._rag_module_lock = asyncio.Lock()
@@ -111,11 +112,15 @@ class DocumentProcessor:
self.stats["queue_size"] = self.processing_queue.qsize()
logger.info(f"Added processing task for document {document_id} (priority: {priority})")
logger.info(
f"Added processing task for document {document_id} (priority: {priority})"
)
return True
except Exception as e:
logger.error(f"Failed to add processing task for document {document_id}: {e}")
logger.error(
f"Failed to add processing task for document {document_id}: {e}"
)
return False
async def _worker(self, worker_name: str):
@@ -126,10 +131,7 @@ class DocumentProcessor:
task: Optional[ProcessingTask] = None
try:
# Get task from queue (wait up to 1 second)
task = await asyncio.wait_for(
self.processing_queue.get(),
timeout=1.0
)
task = await asyncio.wait_for(self.processing_queue.get(), timeout=1.0)
self.stats["active_workers"] += 1
self.stats["queue_size"] = self.processing_queue.qsize()
@@ -141,14 +143,20 @@ class DocumentProcessor:
if success:
self.stats["processed_count"] += 1
logger.info(f"{worker_name}: Successfully processed document {task.document_id}")
logger.info(
f"{worker_name}: Successfully processed document {task.document_id}"
)
else:
# Retry logic
if task.retry_count < task.max_retries:
task.retry_count += 1
await asyncio.sleep(2 ** task.retry_count) # Exponential backoff
await asyncio.sleep(
2**task.retry_count
) # Exponential backoff
try:
await asyncio.wait_for(self.processing_queue.put(task), timeout=5.0)
await asyncio.wait_for(
self.processing_queue.put(task), timeout=5.0
)
except asyncio.TimeoutError:
logger.error(
"%s: Failed to requeue document %s due to saturated queue",
@@ -157,10 +165,14 @@ class DocumentProcessor:
)
self.stats["error_count"] += 1
continue
logger.warning(f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})")
logger.warning(
f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})"
)
else:
self.stats["error_count"] += 1
logger.error(f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries")
logger.error(
f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries"
)
except asyncio.TimeoutError:
# No tasks in queue, continue
@@ -183,30 +195,34 @@ class DocumentProcessor:
async def _get_rag_module(self):
"""Resolve and cache the RAG module instance"""
async with self._rag_module_lock:
if self._rag_module and getattr(self._rag_module, 'enabled', False):
if self._rag_module and getattr(self._rag_module, "enabled", False):
return self._rag_module
if not module_manager.initialized:
await module_manager.initialize()
rag_module = module_manager.get_module('rag')
rag_module = module_manager.get_module("rag")
if not rag_module:
enabled = await module_manager.enable_module('rag')
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("Failed to enable RAG module")
rag_module = module_manager.get_module('rag')
rag_module = module_manager.get_module("rag")
if not rag_module:
raise RuntimeError("RAG module not available after enable attempt")
if not getattr(rag_module, 'enabled', True):
enabled = await module_manager.enable_module('rag')
if not getattr(rag_module, "enabled", True):
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("RAG module is disabled and could not be re-enabled")
rag_module = module_manager.get_module('rag')
if not rag_module or not getattr(rag_module, 'enabled', True):
raise RuntimeError("RAG module is disabled and could not be re-enabled")
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)
rag_module = module_manager.get_module("rag")
if not rag_module or not getattr(rag_module, "enabled", True):
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)
self._rag_module = rag_module
logger.info("DocumentProcessor cached RAG module instance for reuse")
@@ -216,6 +232,7 @@ class DocumentProcessor:
"""Process a single document"""
from datetime import datetime
from app.db.database import async_session_factory
async with async_session_factory() as session:
try:
# Get document from database
@@ -245,42 +262,61 @@ class DocumentProcessor:
if not rag_module or not rag_module.enabled:
raise Exception("RAG module not available or not enabled")
logger.info(f"RAG module loaded successfully for document {task.document_id}")
logger.info(
f"RAG module loaded successfully for document {task.document_id}"
)
# Read file content
logger.info(f"Reading file content for document {task.document_id}: {document.file_path}")
logger.info(
f"Reading file content for document {task.document_id}: {document.file_path}"
)
file_path = Path(document.file_path)
try:
file_content = await asyncio.to_thread(file_path.read_bytes)
except FileNotFoundError:
logger.error(f"File not found for document {task.document_id}: {document.file_path}")
logger.error(
f"File not found for document {task.document_id}: {document.file_path}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = "Document file not found on disk"
await session.commit()
return False
except Exception as exc:
logger.error(f"Failed reading file for document {task.document_id}: {exc}")
logger.error(
f"Failed reading file for document {task.document_id}: {exc}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = f"Failed to read file: {exc}"
await session.commit()
return False
logger.info(f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes")
logger.info(
f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes"
)
# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
logger.info(
f"Starting document processing for document {task.document_id} with RAG module"
)
# Special handling for JSONL files - skip processing phase
if document.file_type == 'jsonl':
if document.file_type == "jsonl":
# For JSONL files, we don't need to process content here
# The optimized JSONL processor will handle everything during indexing
document.converted_content = f"JSONL file with {len(file_content)} bytes"
document.converted_content = (
f"JSONL file with {len(file_content)} bytes"
)
document.word_count = 0 # Will be updated during indexing
document.character_count = len(file_content)
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
document.document_metadata = {
"file_path": document.file_path,
"processed": "jsonl",
}
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
logger.info(
f"JSONL document {task.document_id} marked for optimized processing"
)
else:
# Standard processing for other file types
try:
@@ -289,11 +325,13 @@ class DocumentProcessor:
rag_module.process_document(
file_content,
document.original_filename,
{"file_path": document.file_path}
{"file_path": document.file_path},
),
timeout=300.0 # 5 minute timeout
timeout=300.0, # 5 minute timeout
)
logger.info(
f"Document processing completed for document {task.document_id}"
)
logger.info(f"Document processing completed for document {task.document_id}")
# Update document with processed content
document.converted_content = processed_doc.content
@@ -303,16 +341,22 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
logger.error(
f"Document processing timed out for document {task.document_id}"
)
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
logger.error(
f"Document processing failed for document {task.document_id}: {e}"
)
raise
# Index in RAG system using same RAG module
if rag_module and document.converted_content:
try:
logger.info(f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}")
logger.info(
f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}"
)
# Index the document content in the correct Qdrant collection
doc_metadata = {
@@ -320,12 +364,12 @@ class DocumentProcessor:
"document_id": document.id,
"filename": document.original_filename,
"file_type": document.file_type,
**document.document_metadata
**document.document_metadata,
}
# Use the correct Qdrant collection name for this document
# For JSONL files, we need to use the processed document flow
if document.file_type == 'jsonl':
if document.file_type == "jsonl":
# Create a ProcessedDocument for the JSONL processor
from app.modules.rag.main import ProcessedDocument
from datetime import datetime
@@ -333,7 +377,9 @@ class DocumentProcessor:
# Derive a placeholder hash from the document ID (not the file contents)
processed_at = datetime.utcnow()
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
file_hash = hashlib.md5(
str(document.id).encode()
).hexdigest()
processed_doc = ProcessedDocument(
id=str(document.id),
@@ -341,12 +387,14 @@ class DocumentProcessor:
extracted_text="", # Will be filled by JSONL processor
metadata={
**doc_metadata,
"file_path": document.file_path
"file_path": document.file_path,
},
original_filename=document.original_filename,
file_type=document.file_type,
mime_type=document.mime_type,
language=document.document_metadata.get('language', 'EN'),
language=document.document_metadata.get(
"language", "EN"
),
word_count=0, # Will be updated during processing
sentence_count=0, # Will be updated during processing
entities=[],
@@ -354,16 +402,16 @@ class DocumentProcessor:
processing_time=0.0,
processed_at=processed_at,
file_hash=file_hash,
file_size=document.file_size
file_size=document.file_size,
)
# The JSONL processor will read the original file
await asyncio.wait_for(
rag_module.index_processed_document(
processed_doc=processed_doc,
collection_name=document.collection.qdrant_collection_name
collection_name=document.collection.qdrant_collection_name,
),
timeout=300.0 # 5 minute timeout for JSONL processing
timeout=300.0, # 5 minute timeout for JSONL processing
)
else:
# Use standard indexing for other file types
@@ -371,15 +419,19 @@ class DocumentProcessor:
rag_module.index_document(
content=document.converted_content,
metadata=doc_metadata,
collection_name=document.collection.qdrant_collection_name
collection_name=document.collection.qdrant_collection_name,
),
timeout=120.0 # 2 minute timeout for indexing
timeout=120.0, # 2 minute timeout for indexing
)
logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")
logger.info(
f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}"
)
# Update vector count (approximate)
document.vector_count = max(1, len(document.converted_content) // 1000)
document.vector_count = max(
1, len(document.converted_content) // 1000
)
document.status = ProcessingStatus.INDEXED
document.indexed_at = datetime.utcnow()
@@ -392,7 +444,9 @@ class DocumentProcessor:
collection.updated_at = datetime.utcnow()
except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
logger.error(
f"Failed to index document {task.document_id} in RAG: {e}"
)
# Mark as error since indexing failed
document.status = ProcessingStatus.ERROR
document.processing_error = f"Indexing failed: {str(e)}"
@@ -403,7 +457,7 @@ class DocumentProcessor:
except Exception as e:
# Mark document as error
if 'document' in locals() and document:
if "document" in locals() and document:
document.status = ProcessingStatus.ERROR
document.processing_error = str(e)
await session.commit()
@@ -417,7 +471,7 @@ class DocumentProcessor:
**self.stats,
"running": self.running,
"worker_count": len(self.workers),
"queue_size": self.processing_queue.qsize()
"queue_size": self.processing_queue.qsize(),
}
async def get_queue_status(self) -> Dict[str, Any]:
@@ -427,7 +481,7 @@ class DocumentProcessor:
"max_queue_size": self.max_queue_size,
"queue_full": self.processing_queue.full(),
"active_workers": self.stats["active_workers"],
"max_workers": self.max_workers
"max_workers": self.max_workers,
}
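
The worker loop above combines a timeout-bounded dequeue, exponential backoff (2**retry_count), and a bounded requeue that gives up on a saturated queue. Distilled into a standalone sketch; the names and the stand-in work function are illustrative only:

import asyncio
import random


async def worker(queue: asyncio.Queue, max_retries: int = 3) -> None:
    while True:
        try:
            # Wait up to 1 second for work, then loop so shutdown checks can run
            doc_id, retry_count = await asyncio.wait_for(queue.get(), timeout=1.0)
        except asyncio.TimeoutError:
            continue
        success = random.random() > 0.5  # stand-in for the real processing call
        if not success and retry_count < max_retries:
            await asyncio.sleep(2 ** (retry_count + 1))  # exponential backoff
            try:
                # Bounded requeue: a saturated queue drops the task instead of blocking
                await asyncio.wait_for(queue.put((doc_id, retry_count + 1)), timeout=5.0)
            except asyncio.TimeoutError:
                pass  # the real processor logs this and counts an error
        queue.task_done()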

View File

@@ -19,7 +19,9 @@ class EmbeddingService:
"""Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3")
self.model_name = model_name or getattr(
settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3"
)
self.dimension = 1024 # bge-m3 produces 1024-d vectors
self.initialized = False
self.local_model = None
@@ -56,7 +58,9 @@ class EmbeddingService:
return False
except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.error(
f"Failed to load local embedding model {self.model_name}: {exc}"
)
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
@@ -102,13 +106,21 @@ class EmbeddingService:
except Exception as exc:
logger.error(f"Local embedding generation failed: {exc}")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
return self._generate_fallback_embeddings(
texts, duration=time.time() - start_time
)
logger.warning("Local embedding model unavailable; using fallback random embeddings")
logger.warning(
"Local embedding model unavailable; using fallback random embeddings"
)
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
return self._generate_fallback_embeddings(
texts, duration=time.time() - start_time
)
def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
def _generate_fallback_embeddings(
self, texts: List[str], duration: Optional[float] = None
) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
@@ -155,7 +167,7 @@ class EmbeddingService:
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": self.backend,
"initialized": self.initialized
"initialized": self.initialized,
}
async def cleanup(self):

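
When the local model cannot be loaded, the service degrades to random embeddings. A sketch of such a fallback; seeding each vector from a hash of the text so results are repeatable is an assumption added here, not necessarily what _generate_fallback_embeddings does:

import hashlib
from typing import List

import numpy as np


def fallback_embeddings(texts: List[str], dimension: int = 1024) -> List[List[float]]:
    vectors = []
    for text in texts:
        # Deterministic seed per text (assumption for illustration)
        seed = int.from_bytes(hashlib.sha256(text.encode()).digest()[:4], "big")
        rng = np.random.default_rng(seed)
        vec = rng.standard_normal(dimension)
        vec /= np.linalg.norm(vec)  # unit-normalize, as cosine search expects
        vectors.append(vec.tolist())
    return vectors


print(len(fallback_embeddings(["hello"])[0]))  # 1024, matching bge-m3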
View File

@@ -21,17 +21,30 @@ class EnhancedEmbeddingService(EmbeddingService):
def __init__(self, model_name: Optional[str] = None):
super().__init__(model_name)
self.rate_limit_tracker = {
'requests_count': 0,
'window_start': time.time(),
'window_size': 60, # 1 minute window
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
'last_rate_limit_error': None
"requests_count": 0,
"window_start": time.time(),
"window_size": 60, # 1 minute window
"max_requests_per_minute": int(
getattr(settings, "RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", 12)
), # Configurable
"retry_delays": [
int(x)
for x in getattr(
settings, "RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
).split(",")
], # Exponential backoff
"delay_between_batches": float(
getattr(settings, "RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", 1.0)
),
"delay_per_request": float(
getattr(settings, "RAG_EMBEDDING_DELAY_PER_REQUEST", 0.5)
),
"last_rate_limit_error": None,
}
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
async def get_embeddings_with_retry(
self, texts: List[str], max_retries: int = None
) -> tuple[List[List[float]], bool]:
"""
Get embeddings with retry bookkeeping.
"""
@@ -51,13 +64,20 @@ class EnhancedEmbeddingService(EmbeddingService):
return {
**base_stats,
"rate_limit_info": {
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
"window_reset_in_seconds": max(0,
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
"requests_in_current_window": self.rate_limit_tracker["requests_count"],
"max_requests_per_minute": self.rate_limit_tracker[
"max_requests_per_minute"
],
"window_reset_in_seconds": max(
0,
self.rate_limit_tracker["window_start"]
+ self.rate_limit_tracker["window_size"]
- time.time(),
),
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
}
"last_rate_limit_error": self.rate_limit_tracker[
"last_rate_limit_error"
],
},
}
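
The tracker above does fixed one-minute-window bookkeeping, with window_reset_in_seconds computed as max(0, window_start + window_size - now). A sketch of the same bookkeeping with a wait-on-exhaustion policy; the enforcement side is not shown in this hunk, so that policy is an assumption:

import asyncio
import time


class FixedWindowLimiter:
    def __init__(self, max_per_minute: int = 12, window_size: float = 60.0):
        self.max_per_minute = max_per_minute
        self.window_size = window_size
        self.window_start = time.time()
        self.requests_count = 0

    async def acquire(self) -> None:
        now = time.time()
        if now - self.window_start >= self.window_size:
            self.window_start = now
            self.requests_count = 0
        if self.requests_count >= self.max_per_minute:
            reset_in = max(0.0, self.window_start + self.window_size - now)
            await asyncio.sleep(reset_in)  # wait out the current window
            self.window_start = time.time()
            self.requests_count = 0
        self.requests_count += 1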

View File

@@ -14,6 +14,7 @@ from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.http.models import Batch
from app.modules.rag.main import ProcessedDocument
# from app.core.analytics import log_module_event # Analytics module not available
logger = logging.getLogger(__name__)
@@ -26,8 +27,13 @@ class JSONLProcessor:
self.rag_module = rag_module
self.config = rag_module.config
async def process_and_index_jsonl(self, collection_name: str, content: bytes,
filename: str, metadata: Dict[str, Any]) -> str:
async def process_and_index_jsonl(
self,
collection_name: str,
content: bytes,
filename: str,
metadata: Dict[str, Any],
) -> str:
"""Process and index a JSONL file efficiently
Processes each JSON line as a separate document to avoid
@@ -35,8 +41,8 @@ class JSONLProcessor:
"""
try:
# Decode content
jsonl_content = content.decode('utf-8', errors='replace')
lines = jsonl_content.strip().split('\n')
jsonl_content = content.decode("utf-8", errors="replace")
lines = jsonl_content.strip().split("\n")
logger.info(f"Processing JSONL file {filename} with {len(lines)} lines")
@@ -58,28 +64,38 @@ class JSONLProcessor:
batch_start,
base_doc_id,
filename,
metadata
metadata,
)
processed_count += len(batch_lines)
# Log progress
if processed_count % 50 == 0:
logger.info(f"Processed {processed_count}/{len(lines)} lines from {filename}")
logger.info(
f"Processed {processed_count}/{len(lines)} lines from {filename}"
)
# Small delay to prevent resource exhaustion
await asyncio.sleep(0.05)
logger.info(f"Successfully processed JSONL file {filename} with {len(lines)} lines")
logger.info(
f"Successfully processed JSONL file {filename} with {len(lines)} lines"
)
return base_doc_id
except Exception as e:
logger.error(f"Error processing JSONL file {filename}: {e}")
raise
async def _process_jsonl_batch(self, collection_name: str, lines: List[str],
start_idx: int, base_doc_id: str,
filename: str, metadata: Dict[str, Any]) -> None:
async def _process_jsonl_batch(
self,
collection_name: str,
lines: List[str],
start_idx: int,
base_doc_id: str,
filename: str,
metadata: Dict[str, Any],
) -> None:
"""Process a batch of JSONL lines"""
try:
points = []
@@ -98,14 +114,14 @@ class JSONLProcessor:
continue
# Handle helpjuice export format
if 'payload' in data and data['payload'] is not None:
payload = data['payload']
article_id = data.get('id', f'article_{line_idx}')
if "payload" in data and data["payload"] is not None:
payload = data["payload"]
article_id = data.get("id", f"article_{line_idx}")
# Extract Q&A
question = payload.get('question', '')
answer = payload.get('answer', '')
language = payload.get('language', 'EN')
question = payload.get("question", "")
answer = payload.get("answer", "")
language = payload.get("language", "EN")
if question or answer:
# Create Q&A content
@@ -120,15 +136,18 @@ class JSONLProcessor:
"line_number": line_idx,
"content_type": "qa_pair",
"question": question[:100], # Truncate for metadata
"processed_at": datetime.utcnow().isoformat()
"processed_at": datetime.utcnow().isoformat(),
}
# Generate single embedding for the Q&A pair
embeddings = await self.rag_module._generate_embeddings([content])
embeddings = await self.rag_module._generate_embeddings(
[content]
)
# Create point
point_id = str(uuid.uuid4())
points.append(PointStruct(
points.append(
PointStruct(
id=point_id,
vector=embeddings[0],
payload={
@@ -136,9 +155,10 @@ class JSONLProcessor:
"document_id": f"{base_doc_id}_{article_id}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
"chunk_count": 1,
},
)
)
# Handle generic JSON format
else:
@@ -146,12 +166,19 @@ class JSONLProcessor:
# For larger JSON objects, we might need to chunk
if len(content) > 1000:
chunks = self.rag_module._chunk_text(content, chunk_size=500)
embeddings = await self.rag_module._generate_embeddings(chunks)
chunks = self.rag_module._chunk_text(
content, chunk_size=500
)
embeddings = await self.rag_module._generate_embeddings(
chunks
)
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
for i, (chunk, embedding) in enumerate(
zip(chunks, embeddings)
):
point_id = str(uuid.uuid4())
points.append(PointStruct(
points.append(
PointStruct(
id=point_id,
vector=embedding,
payload={
@@ -162,14 +189,18 @@ class JSONLProcessor:
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks)
}
))
"chunk_count": len(chunks),
},
)
)
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings([content])
embeddings = await self.rag_module._generate_embeddings(
[content]
)
point_id = str(uuid.uuid4())
points.append(PointStruct(
points.append(
PointStruct(
id=point_id,
vector=embeddings[0],
payload={
@@ -180,9 +211,10 @@ class JSONLProcessor:
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
"chunk_count": 1,
},
)
)
except json.JSONDecodeError as e:
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
@@ -194,8 +226,7 @@ class JSONLProcessor:
# Insert all points in this batch
if points:
self.rag_module.qdrant_client.upsert(
collection_name=collection_name,
points=points
collection_name=collection_name, points=points
)
# Update stats

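
The helpjuice branch above expects lines carrying an id plus a payload with question, answer, and language keys. A sample line and the Q&A content it yields; the exact content template is elided by the hunk, so the Question/Answer formatting below is an assumption:

import json

line = json.dumps({
    "id": "article_42",
    "payload": {
        "question": "How do I reset my password?",
        "answer": "Use the 'Forgot password' link on the login page.",
        "language": "EN",
    },
})

data = json.loads(line)
payload = data["payload"]
content = f"Question: {payload['question']}\nAnswer: {payload['answer']}"
print(content)  # this single string is embedded as one Q&A point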
View File

@@ -17,5 +17,5 @@ __all__ = [
"EmbeddingResponse",
"LLMError",
"ProviderError",
"SecurityError"
"SecurityError",
]

View File

@@ -15,30 +15,51 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""
name: str = Field(..., description="Provider name")
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
provider_type: str = Field(
..., description="Provider type (e.g., 'openai', 'privatemode')"
)
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
default_model: Optional[str] = Field(None, description="Default model for this provider")
supported_models: List[str] = Field(default_factory=list, description="List of supported models")
capabilities: List[str] = Field(default_factory=list, description="Provider capabilities")
default_model: Optional[str] = Field(
None, description="Default model for this provider"
)
supported_models: List[str] = Field(
default_factory=list, description="List of supported models"
)
capabilities: List[str] = Field(
default_factory=list, description="Provider capabilities"
)
priority: int = Field(1, description="Provider priority (lower = higher priority)")
# Rate limiting
max_requests_per_minute: Optional[int] = Field(None, description="Max requests per minute")
max_requests_per_hour: Optional[int] = Field(None, description="Max requests per hour")
max_requests_per_minute: Optional[int] = Field(
None, description="Max requests per minute"
)
max_requests_per_hour: Optional[int] = Field(
None, description="Max requests per hour"
)
# Model-specific settings
supports_streaming: bool = Field(False, description="Whether provider supports streaming")
supports_function_calling: bool = Field(False, description="Whether provider supports function calling")
max_context_window: Optional[int] = Field(None, description="Maximum context window size")
supports_streaming: bool = Field(
False, description="Whether provider supports streaming"
)
supports_function_calling: bool = Field(
False, description="Whether provider supports function calling"
)
max_context_window: Optional[int] = Field(
None, description="Maximum context window size"
)
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
# Resilience configuration
resilience: ResilienceConfig = Field(default_factory=ResilienceConfig, description="Resilience settings")
resilience: ResilienceConfig = Field(
default_factory=ResilienceConfig, description="Resilience settings"
)
@validator('priority')
@validator("priority")
def validate_priority(cls, v):
if v < 1:
raise ValueError("Priority must be >= 1")
@@ -50,32 +71,45 @@ class LLMServiceConfig(BaseModel):
# Global settings
default_provider: str = Field("privatemode", description="Default provider to use")
enable_detailed_logging: bool = Field(False, description="Enable detailed request/response logging")
enable_detailed_logging: bool = Field(
False, description="Enable detailed request/response logging"
)
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
enable_metrics_collection: bool = Field(
True, description="Enable metrics collection"
)
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
max_response_length: int = Field(
32000, ge=1000, description="Maximum response length"
)
# Performance settings
default_timeout_ms: int = Field(30000, ge=1000, le=300000, description="Default request timeout")
max_concurrent_requests: int = Field(100, ge=1, le=1000, description="Maximum concurrent requests")
default_timeout_ms: int = Field(
30000, ge=1000, le=300000, description="Default request timeout"
)
max_concurrent_requests: int = Field(
100, ge=1, le=1000, description="Maximum concurrent requests"
)
# Provider configurations
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
providers: Dict[str, ProviderConfig] = Field(
default_factory=dict, description="Provider configurations"
)
# Token rate limiting (organization-wide)
token_limits_per_minute: Dict[str, int] = Field(
default_factory=lambda: {
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000 # PrivateMode Standard tier
"completion_tokens": 10000, # PrivateMode Standard tier
},
description="Token rate limits per minute (organization-wide)"
description="Token rate limits per minute (organization-wide)",
)
# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
model_routing: Dict[str, str] = Field(
default_factory=dict, description="Model to provider routing"
)
def create_default_config(env_vars=None) -> LLMServiceConfig:
@@ -105,13 +139,11 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000, # PrivateMode may be slower due to TEE
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=120000
)
circuit_breaker_reset_timeout_ms=120000,
),
)
providers: Dict[str, ProviderConfig] = {
"privatemode": privatemode_config
}
providers: Dict[str, ProviderConfig] = {"privatemode": privatemode_config}
if env.OPENAI_API_KEY:
providers["openai"] = ProviderConfig(
@@ -126,7 +158,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
"gpt-4o",
"gpt-3.5-turbo",
"text-embedding-3-large",
"text-embedding-3-small"
"text-embedding-3-small",
],
capabilities=["chat", "embeddings"],
priority=2,
@@ -139,8 +171,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=750,
timeout_ms=45000,
circuit_breaker_threshold=6,
circuit_breaker_reset_timeout_ms=60000
)
circuit_breaker_reset_timeout_ms=60000,
),
)
if env.ANTHROPIC_API_KEY:
@@ -154,7 +186,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
supported_models=[
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
"claude-3-haiku-20240307",
],
capabilities=["chat"],
priority=3,
@@ -167,8 +199,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000,
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=90000
)
circuit_breaker_reset_timeout_ms=90000,
),
)
if env.GOOGLE_API_KEY:
@@ -181,7 +213,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
default_model="models/gemini-1.5-pro-latest",
supported_models=[
"models/gemini-1.5-pro-latest",
"models/gemini-1.5-flash-latest"
"models/gemini-1.5-flash-latest",
],
capabilities=["chat", "multimodal"],
priority=4,
@@ -194,13 +226,13 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=45000,
circuit_breaker_threshold=4,
circuit_breaker_reset_timeout_ms=60000
)
circuit_breaker_reset_timeout_ms=60000,
),
)
default_provider = next(
(name for name, provider in providers.items() if provider.enabled),
"privatemode"
"privatemode",
)
# Create main configuration
@@ -208,7 +240,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
default_provider=default_provider,
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
providers=providers,
model_routing={} # Will be populated dynamically from provider models
model_routing={}, # Will be populated dynamically from provider models
)
return config
@@ -241,7 +273,7 @@ class EnvironmentVariables:
"privatemode": self.PRIVATEMODE_API_KEY,
"openai": self.OPENAI_API_KEY,
"anthropic": self.ANTHROPIC_API_KEY,
"google": self.GOOGLE_API_KEY
"google": self.GOOGLE_API_KEY,
}
return key_mapping.get(provider_name.lower())
@@ -307,8 +339,11 @@ class ConfigurationManager:
if missing_keys:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")
logger.warning(
f"Missing API keys for enabled providers: {', '.join(missing_keys)}"
)
# Validate default provider is enabled
default_provider = self._config.default_provider
@@ -323,8 +358,11 @@ class ConfigurationManager:
if invalid_routes:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")
logger.warning(
f"Model routes point to disabled providers: {', '.join(invalid_routes)}"
)
async def refresh_provider_models(self, provider_name: str, models: List[str]):
"""Update supported models for a provider dynamically"""
@@ -343,6 +381,7 @@ class ConfigurationManager:
self._config.model_routing[model] = provider_name
import logging
logger = logging.getLogger(__name__)
logger.info(f"Updated {provider_name} with {len(models)} models: {models}")
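
A sketch of inspecting the default configuration this factory builds: PrivateMode is always present, and the cloud providers are added only when their API keys are set (import path assumed):

from app.services.llm.config import create_default_config  # assumed path

config = create_default_config()
print(config.default_provider)  # "privatemode" unless that provider is disabled
for name, provider in config.providers.items():
    print(name, provider.priority, provider.capabilities)
print(config.token_limits_per_minute)  # {"prompt_tokens": 20000, "completion_tokens": 10000}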

View File

@@ -8,7 +8,9 @@ Custom exceptions for LLM service operations.
class LLMError(Exception):
"""Base exception for LLM service errors"""
def __init__(self, message: str, error_code: str = "LLM_ERROR", details: dict = None):
def __init__(
self, message: str, error_code: str = "LLM_ERROR", details: dict = None
):
super().__init__(message)
self.message = message
self.error_code = error_code
@@ -18,7 +20,13 @@ class LLMError(Exception):
class ProviderError(LLMError):
"""Exception for LLM provider-specific errors"""
def __init__(self, message: str, provider: str, error_code: str = "PROVIDER_ERROR", details: dict = None):
def __init__(
self,
message: str,
provider: str,
error_code: str = "PROVIDER_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.provider = provider
@@ -26,7 +34,13 @@ class ProviderError(LLMError):
class SecurityError(LLMError):
"""Exception for security-related errors"""
def __init__(self, message: str, risk_score: float = 0.0, error_code: str = "SECURITY_ERROR", details: dict = None):
def __init__(
self,
message: str,
risk_score: float = 0.0,
error_code: str = "SECURITY_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.risk_score = risk_score
@@ -34,14 +48,22 @@ class SecurityError(LLMError):
class ConfigurationError(LLMError):
"""Exception for configuration-related errors"""
def __init__(self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None):
def __init__(
self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None
):
super().__init__(message, error_code, details)
class RateLimitError(LLMError):
"""Exception for rate limiting errors"""
def __init__(self, message: str, retry_after: int = None, error_code: str = "RATE_LIMIT_ERROR", details: dict = None):
def __init__(
self,
message: str,
retry_after: int = None,
error_code: str = "RATE_LIMIT_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.retry_after = retry_after
@@ -49,7 +71,13 @@ class RateLimitError(LLMError):
class TimeoutError(LLMError):
"""Exception for timeout errors"""
def __init__(self, message: str, timeout_duration: float = None, error_code: str = "TIMEOUT_ERROR", details: dict = None):
def __init__(
self,
message: str,
timeout_duration: float = None,
error_code: str = "TIMEOUT_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.timeout_duration = timeout_duration
@@ -57,6 +85,12 @@ class TimeoutError(LLMError):
class ValidationError(LLMError):
"""Exception for request validation errors"""
def __init__(self, message: str, field: str = None, error_code: str = "VALIDATION_ERROR", details: dict = None):
def __init__(
self,
message: str,
field: str = None,
error_code: str = "VALIDATION_ERROR",
details: dict = None,
):
super().__init__(message, error_code, details)
self.field = field
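
All of these exceptions carry (message, error_code, details) from the base class plus their own context fields, so handlers can branch on attributes instead of parsing messages. A handling sketch (import path assumed):

from app.services.llm.exceptions import LLMError, ProviderError, RateLimitError  # assumed path

try:
    raise RateLimitError("Too many requests", retry_after=30)
except RateLimitError as exc:
    print(exc.error_code, exc.retry_after)  # RATE_LIMIT_ERROR 30
except ProviderError as exc:
    print(exc.provider, exc.error_code)
except LLMError as exc:
    print(exc.error_code)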

View File

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
@dataclass
class RequestMetric:
"""Individual request metric"""
timestamp: datetime
provider: str
model: str
@@ -45,7 +46,9 @@ class MetricsCollector:
"""
self.max_history_size = max_history_size
self._metrics: deque = deque(maxlen=max_history_size)
self._provider_metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
self._provider_metrics: Dict[str, deque] = defaultdict(
lambda: deque(maxlen=1000)
)
self._lock = threading.RLock()
# Aggregated metrics cache
@@ -53,7 +56,9 @@ class MetricsCollector:
self._cached_metrics: Optional[LLMMetrics] = None
self._cache_ttl_seconds = 60 # Cache for 1 minute
logger.info(f"Metrics collector initialized with max history: {max_history_size}")
logger.info(
f"Metrics collector initialized with max history: {max_history_size}"
)
def record_request(
self,
@@ -66,7 +71,7 @@ class MetricsCollector:
security_risk_score: float = 0.0,
error_code: Optional[str] = None,
user_id: Optional[str] = None,
api_key_id: Optional[int] = None
api_key_id: Optional[int] = None,
):
"""Record a request metric"""
metric = RequestMetric(
@@ -80,7 +85,7 @@ class MetricsCollector:
security_risk_score=security_risk_score,
error_code=error_code,
user_id=user_id,
api_key_id=api_key_id
api_key_id=api_key_id,
)
with self._lock:
@@ -93,18 +98,25 @@ class MetricsCollector:
# Log significant events
if not success:
logger.warning(f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}")
logger.warning(
f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}"
)
elif security_risk_score > 0.6:
logger.info(f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}")
logger.info(
f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}"
)
def get_metrics(self, force_refresh: bool = False) -> LLMMetrics:
"""Get aggregated metrics"""
with self._lock:
# Check cache validity
if (not force_refresh and
self._cached_metrics and
self._cache_timestamp and
(datetime.utcnow() - self._cache_timestamp).total_seconds() < self._cache_ttl_seconds):
if (
not force_refresh
and self._cached_metrics
and self._cache_timestamp
and (datetime.utcnow() - self._cache_timestamp).total_seconds()
< self._cache_ttl_seconds
):
return self._cached_metrics
# Calculate fresh metrics
@@ -136,7 +148,9 @@ class MetricsCollector:
provider_metrics = {}
for provider, provider_data in self._provider_metrics.items():
if provider_data:
provider_metrics[provider] = self._calculate_provider_metrics(provider_data)
provider_metrics[provider] = self._calculate_provider_metrics(
provider_data
)
return LLMMetrics(
total_requests=total_requests,
@@ -145,7 +159,7 @@ class MetricsCollector:
average_latency_ms=avg_latency,
average_risk_score=avg_risk_score,
provider_metrics=provider_metrics,
last_updated=datetime.utcnow()
last_updated=datetime.utcnow(),
)
def _calculate_provider_metrics(self, provider_data: deque) -> Dict[str, Any]:
@@ -168,7 +182,9 @@ class MetricsCollector:
for metric in provider_data:
if metric.token_usage:
total_prompt_tokens += metric.token_usage.get("prompt_tokens", 0)
total_completion_tokens += metric.token_usage.get("completion_tokens", 0)
total_completion_tokens += metric.token_usage.get(
"completion_tokens", 0
)
total_tokens += metric.token_usage.get("total_tokens", 0)
# Model distribution
@@ -198,12 +214,14 @@ class MetricsCollector:
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"avg_prompt_tokens": total_prompt_tokens / total if total > 0 else 0,
"avg_completion_tokens": total_completion_tokens / successful if successful > 0 else 0
"avg_completion_tokens": total_completion_tokens / successful
if successful > 0
else 0,
},
"model_distribution": dict(model_counts),
"request_type_distribution": dict(request_type_counts),
"error_distribution": dict(error_counts),
"recent_requests": total
"recent_requests": total,
}
def get_provider_metrics(self, provider: str) -> Optional[Dict[str, Any]]:
@@ -228,7 +246,11 @@ class MetricsCollector:
with self._lock:
for metric in self._metrics:
if metric.timestamp >= cutoff_time and not metric.success and metric.error_code:
if (
metric.timestamp >= cutoff_time
and not metric.success
and metric.error_code
):
error_counts[metric.error_code] += 1
return dict(error_counts)
@@ -252,7 +274,7 @@ class MetricsCollector:
"max_latency_ms": max(latencies),
"p95_latency_ms": self._percentile(latencies, 95),
"p99_latency_ms": self._percentile(latencies, 99),
"request_count": len(latencies)
"request_count": len(latencies),
}
return performance
@@ -307,9 +329,11 @@ class MetricsCollector:
"total_requests_5min": total_recent,
"average_latency_ms": metrics.average_latency_ms,
"error_count_1hour": sum(error_metrics.values()),
"top_errors": dict(sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]),
"top_errors": dict(
sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]
),
"provider_count": len(metrics.provider_metrics),
"last_updated": datetime.utcnow().isoformat()
"last_updated": datetime.utcnow().isoformat(),
}
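
get_metrics serves a cached aggregate for up to 60 seconds unless force_refresh is passed. The same guard, distilled into a standalone TTL cache:

from datetime import datetime
from typing import Any, Callable, Optional


class TTLCache:
    def __init__(self, compute: Callable[[], Any], ttl_seconds: float = 60.0):
        self._compute = compute
        self._ttl = ttl_seconds
        self._value: Optional[Any] = None
        self._stamp: Optional[datetime] = None

    def get(self, force_refresh: bool = False) -> Any:
        fresh = (
            not force_refresh
            and self._value is not None
            and self._stamp is not None
            and (datetime.utcnow() - self._stamp).total_seconds() < self._ttl
        )
        if fresh:
            return self._value  # serve the cached aggregate while it is fresh
        self._value = self._compute()
        self._stamp = datetime.utcnow()
        return self._value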

View File

@@ -9,15 +9,30 @@ from pydantic import BaseModel, Field, validator
from datetime import datetime
class ToolCall(BaseModel):
"""Tool call in a message"""
id: str = Field(..., description="Tool call identifier")
type: str = Field("function", description="Tool call type")
function: Dict[str, Any] = Field(..., description="Function call details")
class ChatMessage(BaseModel):
"""Individual chat message"""
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
name: Optional[str] = Field(None, description="Optional message name")
@validator('role')
role: str = Field(..., description="Message role (system, user, assistant)")
content: Optional[str] = Field(None, description="Message content")
name: Optional[str] = Field(None, description="Optional message name")
tool_calls: Optional[List[ToolCall]] = Field(
None, description="Tool calls in this message"
)
tool_call_id: Optional[str] = Field(
None, description="Tool call ID for tool responses"
)
@validator("role")
def validate_role(cls, v):
allowed_roles = {'system', 'user', 'assistant', 'function'}
allowed_roles = {"system", "user", "assistant", "function", "tool"}
if v not in allowed_roles:
raise ValueError(f"Role must be one of {allowed_roles}")
return v
@@ -25,21 +40,38 @@ class ChatMessage(BaseModel):
class ChatRequest(BaseModel):
"""Chat completion request"""
model: str = Field(..., description="Model identifier")
messages: List[ChatMessage] = Field(..., description="Chat messages")
temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="Maximum tokens to generate")
top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter")
temperature: Optional[float] = Field(
0.7, ge=0.0, le=2.0, description="Sampling temperature"
)
max_tokens: Optional[int] = Field(
None, ge=1, le=32000, description="Maximum tokens to generate"
)
top_p: Optional[float] = Field(
1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter"
)
top_k: Optional[int] = Field(None, ge=1, description="Top-k sampling parameter")
frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Frequency penalty")
presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Presence penalty")
frequency_penalty: Optional[float] = Field(
0.0, ge=-2.0, le=2.0, description="Frequency penalty"
)
presence_penalty: Optional[float] = Field(
0.0, ge=-2.0, le=2.0, description="Presence penalty"
)
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
stream: Optional[bool] = Field(False, description="Stream response")
tools: Optional[List[Dict[str, Any]]] = Field(
None, description="Available tools for function calling"
)
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
None, description="Tool choice preference"
)
user_id: str = Field(..., description="User identifier")
api_key_id: int = Field(..., description="API key identifier")
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
@validator('messages')
@validator("messages")
def validate_messages(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
@@ -48,6 +80,7 @@ class ChatRequest(BaseModel):
class TokenUsage(BaseModel):
"""Token usage information"""
prompt_tokens: int = Field(..., description="Tokens in the prompt")
completion_tokens: int = Field(..., description="Tokens in the completion")
total_tokens: int = Field(..., description="Total tokens used")
@@ -55,13 +88,17 @@ class TokenUsage(BaseModel):
class ChatChoice(BaseModel):
"""Chat completion choice"""
index: int = Field(..., description="Choice index")
message: ChatMessage = Field(..., description="Generated message")
finish_reason: Optional[str] = Field(None, description="Reason for completion finish")
finish_reason: Optional[str] = Field(
None, description="Reason for completion finish"
)
class ChatResponse(BaseModel):
"""Chat completion response"""
id: str = Field(..., description="Response identifier")
object: str = Field("chat.completion", description="Object type")
created: int = Field(..., description="Creation timestamp")
@@ -72,17 +109,26 @@ class ChatResponse(BaseModel):
system_fingerprint: Optional[str] = Field(None, description="System fingerprint")
# Security fields maintained for backward compatibility
security_check: Optional[bool] = Field(None, description="Whether security check passed")
security_check: Optional[bool] = Field(
None, description="Whether security check passed"
)
risk_score: Optional[float] = Field(None, description="Security risk score")
detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
detected_patterns: Optional[List[str]] = Field(
None, description="Detected security patterns"
)
# Performance metrics
latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
latency_ms: Optional[float] = Field(
None, description="Response latency in milliseconds"
)
provider_latency_ms: Optional[float] = Field(
None, description="Provider-specific latency"
)
class EmbeddingRequest(BaseModel):
"""Embedding generation request"""
model: str = Field(..., description="Embedding model identifier")
input: Union[str, List[str]] = Field(..., description="Text to embed")
encoding_format: Optional[str] = Field("float", description="Encoding format")
@@ -91,19 +137,22 @@ class EmbeddingRequest(BaseModel):
api_key_id: int = Field(..., description="API key identifier")
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
@validator('input')
@validator("input")
def validate_input(cls, v):
if isinstance(v, str):
if not v.strip():
raise ValueError("Input text cannot be empty")
elif isinstance(v, list):
if not v or not all(isinstance(item, str) and item.strip() for item in v):
raise ValueError("Input list cannot be empty and must contain non-empty strings")
raise ValueError(
"Input list cannot be empty and must contain non-empty strings"
)
return v
class EmbeddingData(BaseModel):
"""Single embedding data"""
object: str = Field("embedding", description="Object type")
index: int = Field(..., description="Embedding index")
embedding: List[float] = Field(..., description="Embedding vector")
@@ -111,6 +160,7 @@ class EmbeddingData(BaseModel):
class EmbeddingResponse(BaseModel):
"""Embedding generation response"""
object: str = Field("list", description="Object type")
data: List[EmbeddingData] = Field(..., description="Embedding data")
model: str = Field(..., description="Model used")
@@ -118,56 +168,90 @@ class EmbeddingResponse(BaseModel):
usage: Optional[TokenUsage] = Field(None, description="Token usage")
# Security fields maintained for backward compatibility
security_check: Optional[bool] = Field(None, description="Whether security check passed")
security_check: Optional[bool] = Field(
None, description="Whether security check passed"
)
risk_score: Optional[float] = Field(None, description="Security risk score")
detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
detected_patterns: Optional[List[str]] = Field(
None, description="Detected security patterns"
)
# Performance metrics
latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
latency_ms: Optional[float] = Field(
None, description="Response latency in milliseconds"
)
provider_latency_ms: Optional[float] = Field(
None, description="Provider-specific latency"
)
class ModelInfo(BaseModel):
"""Model information"""
id: str = Field(..., description="Model identifier")
object: str = Field("model", description="Object type")
created: Optional[int] = Field(None, description="Creation timestamp")
owned_by: str = Field(..., description="Model owner")
provider: str = Field(..., description="Provider name")
capabilities: List[str] = Field(default_factory=list, description="Model capabilities")
capabilities: List[str] = Field(
default_factory=list, description="Model capabilities"
)
context_window: Optional[int] = Field(None, description="Context window size")
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
supports_streaming: bool = Field(False, description="Whether model supports streaming")
supports_function_calling: bool = Field(False, description="Whether model supports function calling")
tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
supports_streaming: bool = Field(
False, description="Whether model supports streaming"
)
supports_function_calling: bool = Field(
False, description="Whether model supports function calling"
)
tasks: Optional[List[str]] = Field(
None, description="Model tasks (e.g., generate, embed, vision)"
)
class ProviderStatus(BaseModel):
"""Provider health status"""
provider: str = Field(..., description="Provider name")
status: str = Field(..., description="Status (healthy, degraded, unavailable)")
latency_ms: Optional[float] = Field(None, description="Average latency")
success_rate: Optional[float] = Field(None, description="Success rate (0.0 to 1.0)")
last_check: datetime = Field(..., description="Last health check timestamp")
error_message: Optional[str] = Field(None, description="Error message if unhealthy")
models_available: List[str] = Field(default_factory=list, description="Available models")
models_available: List[str] = Field(
default_factory=list, description="Available models"
)
class LLMMetrics(BaseModel):
"""LLM service metrics"""
total_requests: int = Field(0, description="Total requests processed")
successful_requests: int = Field(0, description="Successful requests")
failed_requests: int = Field(0, description="Failed requests")
average_latency_ms: float = Field(0.0, description="Average response latency")
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
last_updated: datetime = Field(default_factory=datetime.utcnow, description="Last metrics update")
provider_metrics: Dict[str, Dict[str, Any]] = Field(
default_factory=dict, description="Per-provider metrics"
)
last_updated: datetime = Field(
default_factory=datetime.utcnow, description="Last metrics update"
)
class ResilienceConfig(BaseModel):
"""Configuration for resilience patterns"""
max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
retry_delay_ms: int = Field(1000, ge=100, le=30000, description="Initial retry delay")
retry_exponential_base: float = Field(2.0, ge=1.1, le=5.0, description="Exponential backoff base")
retry_delay_ms: int = Field(
1000, ge=100, le=30000, description="Initial retry delay"
)
retry_exponential_base: float = Field(
2.0, ge=1.1, le=5.0, description="Exponential backoff base"
)
timeout_ms: int = Field(30000, ge=1000, le=300000, description="Request timeout")
circuit_breaker_threshold: int = Field(5, ge=1, le=50, description="Circuit breaker failure threshold")
circuit_breaker_reset_timeout_ms: int = Field(60000, ge=10000, le=600000, description="Circuit breaker reset timeout")
circuit_breaker_threshold: int = Field(
5, ge=1, le=50, description="Circuit breaker failure threshold"
)
circuit_breaker_reset_timeout_ms: int = Field(
60000, ge=10000, le=600000, description="Circuit breaker reset timeout"
)

View File

@@ -9,8 +9,12 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
from ..models import (
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
ModelInfo, ProviderStatus
ChatRequest,
ChatResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ProviderStatus,
)
from ..config import ProviderConfig
@@ -80,7 +84,9 @@ class BaseLLMProvider(ABC):
pass
@abstractmethod
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
async def create_chat_completion_stream(
self, request: ChatRequest
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Create streaming chat completion
@@ -121,7 +127,7 @@ class BaseLLMProvider(ABC):
async def cleanup(self):
"""Cleanup provider resources"""
if self._session and hasattr(self._session, 'close'):
if self._session and hasattr(self._session, "close"):
await self._session.close()
logger.debug(f"Cleaned up session for {self.name} provider")
@@ -147,24 +153,27 @@ class BaseLLMProvider(ABC):
context_window=self.config.max_context_window,
max_output_tokens=self.config.max_output_tokens,
supports_streaming=self.config.supports_streaming,
supports_function_calling=self.config.supports_function_calling
supports_function_calling=self.config.supports_function_calling,
)
def _validate_request(self, request: Any):
"""Base request validation (override for provider-specific validation)"""
if hasattr(request, 'model') and not self.supports_model(request.model):
if hasattr(request, "model") and not self.supports_model(request.model):
from ..exceptions import ValidationError
raise ValidationError(
f"Model '{request.model}' not supported by provider '{self.name}'",
field="model"
field="model",
)
def _create_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
def _create_headers(
self, additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""Create HTTP headers for requests"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
"User-Agent": f"Enclava-LLM-Service/{self.name}"
"User-Agent": f"Enclava-LLM-Service/{self.name}",
}
if additional_headers:
@@ -172,7 +181,9 @@ class BaseLLMProvider(ABC):
return headers
def _handle_http_error(self, status_code: int, response_text: str, provider_context: str = ""):
def _handle_http_error(
self, status_code: int, response_text: str, provider_context: str = ""
):
"""Handle HTTP errors consistently across providers"""
from ..exceptions import ProviderError, RateLimitError, ValidationError
@@ -183,40 +194,44 @@ class BaseLLMProvider(ABC):
f"Authentication failed for {context}",
provider=self.name,
error_code="AUTHENTICATION_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif status_code == 403:
raise ProviderError(
f"Access forbidden for {context}",
provider=self.name,
error_code="AUTHORIZATION_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif status_code == 429:
raise RateLimitError(
f"Rate limit exceeded for {context}",
error_code="RATE_LIMIT_ERROR",
details={"status_code": status_code, "response": response_text, "provider": self.name}
details={
"status_code": status_code,
"response": response_text,
"provider": self.name,
},
)
elif status_code == 400:
raise ValidationError(
f"Bad request for {context}: {response_text}",
error_code="BAD_REQUEST",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
elif 500 <= status_code < 600:
raise ProviderError(
f"Server error for {context}: {response_text}",
provider=self.name,
error_code="SERVER_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
else:
raise ProviderError(
f"HTTP error {status_code} for {context}: {response_text}",
provider=self.name,
error_code="HTTP_ERROR",
details={"status_code": status_code, "response": response_text}
details={"status_code": status_code, "response": response_text},
)
def __str__(self) -> str:
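Taken together, _handle_http_error gives every provider one status-to-exception mapping: 401/403 become ProviderError with auth error codes, 429 becomes RateLimitError (treated as non-retryable by the retry manager), 400 becomes ValidationError, and 5xx or anything else becomes ProviderError. A standalone sketch of that mapping, with local stand-ins for the classes imported from ..exceptions:

class ProviderError(Exception): ...
class RateLimitError(Exception): ...
class ValidationError(Exception): ...

def handle_http_error(status_code: int, response_text: str) -> None:
    # Simplified mirror of BaseLLMProvider._handle_http_error above.
    if status_code in (401, 403):
        raise ProviderError(f"Auth failure ({status_code}): {response_text}")
    if status_code == 429:
        raise RateLimitError(response_text)
    if status_code == 400:
        raise ValidationError(response_text)
    if 500 <= status_code < 600:
        raise ProviderError(f"Server error {status_code}: {response_text}")
    raise ProviderError(f"HTTP error {status_code}: {response_text}")

try:
    handle_http_error(429, "slow down")
except RateLimitError as exc:
    print("rate limited:", exc)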

View File

@@ -15,9 +15,16 @@ import aiohttp
from .base import BaseLLMProvider
from ..models import (
ChatRequest, ChatResponse, ChatMessage, ChatChoice, TokenUsage,
EmbeddingRequest, EmbeddingResponse, EmbeddingData,
ModelInfo, ProviderStatus
ChatRequest,
ChatResponse,
ChatMessage,
ChatChoice,
TokenUsage,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingData,
ModelInfo,
ProviderStatus,
)
from ..config import ProviderConfig
from ..exceptions import ProviderError, ValidationError, TimeoutError
@@ -30,7 +37,7 @@ class PrivateModeProvider(BaseLLMProvider):
def __init__(self, config: ProviderConfig, api_key: str):
super().__init__(config, api_key)
self.base_url = config.base_url.rstrip('/')
self.base_url = config.base_url.rstrip("/")
self._session: Optional[aiohttp.ClientSession] = None
# TEE-specific settings
@@ -52,17 +59,19 @@ class PrivateModeProvider(BaseLLMProvider):
limit=100, # Connection pool limit
limit_per_host=50,
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True
use_dns_cache=True,
)
# Create session with security headers
timeout = aiohttp.ClientTimeout(total=self.config.resilience.timeout_ms / 1000.0)
timeout = aiohttp.ClientTimeout(
total=self.config.resilience.timeout_ms / 1000.0
)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers=self._create_headers(),
trust_env=False # Ignore proxy settings from environment variables
trust_env=False, # Ignore proxy settings from environment variables
)
logger.debug("Created new secure HTTP session for PrivateMode")
@@ -82,7 +91,9 @@ class PrivateModeProvider(BaseLLMProvider):
if response.status == 200:
models_data = await response.json()
models = [model.get("id", "") for model in models_data.get("data", [])]
models = [
model.get("id", "") for model in models_data.get("data", [])
]
return ProviderStatus(
provider=self.provider_name,
@@ -90,7 +101,7 @@ class PrivateModeProvider(BaseLLMProvider):
latency_ms=latency,
success_rate=1.0,
last_check=datetime.utcnow(),
models_available=models
models_available=models,
)
else:
error_text = await response.text()
@@ -101,7 +112,7 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=f"HTTP {response.status}: {error_text}",
models_available=[]
models_available=[],
)
except Exception as e:
@@ -115,7 +126,7 @@ class PrivateModeProvider(BaseLLMProvider):
success_rate=0.0,
last_check=datetime.utcnow(),
error_message=str(e),
models_available=[]
models_available=[],
)
async def get_models(self) -> List[ModelInfo]:
@@ -151,7 +162,9 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities.append("vision")
# Check for function calling support in the API response
supports_function_calling = model_data.get("supports_function_calling", False)
supports_function_calling = model_data.get(
"supports_function_calling", False
)
if supports_function_calling:
capabilities.append("function_calling")
@@ -164,9 +177,11 @@ class PrivateModeProvider(BaseLLMProvider):
capabilities=capabilities,
context_window=model_data.get("context_window"),
max_output_tokens=model_data.get("max_output_tokens"),
supports_streaming=model_data.get("supports_streaming", True),
supports_streaming=model_data.get(
"supports_streaming", True
),
supports_function_calling=supports_function_calling,
tasks=tasks # Pass through tasks field from PrivateMode API
tasks=tasks, # Pass through tasks field from PrivateMode API
)
models.append(model_info)
@@ -174,7 +189,9 @@ class PrivateModeProvider(BaseLLMProvider):
return models
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "models endpoint")
self._handle_http_error(
response.status, error_text, "models endpoint"
)
return [] # Never reached due to exception
except Exception as e:
@@ -186,7 +203,7 @@ class PrivateModeProvider(BaseLLMProvider):
"Failed to retrieve models from PrivateMode",
provider=self.provider_name,
error_code="MODEL_RETRIEVAL_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
@@ -205,12 +222,12 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": False # Non-streaming version
"stream": False, # Non-streaming version
}
# Add optional parameters
@@ -234,12 +251,11 @@ class PrivateModeProvider(BaseLLMProvider):
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
"enclava_request_id": str(uuid.uuid4()),
**(request.metadata or {})
**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
@@ -254,9 +270,9 @@ class PrivateModeProvider(BaseLLMProvider):
index=choice_data.get("index", 0),
message=ChatMessage(
role=message_data.get("role", "assistant"),
content=message_data.get("content", "")
content=message_data.get("content", ""),
),
finish_reason=choice_data.get("finish_reason")
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
@@ -265,7 +281,7 @@ class PrivateModeProvider(BaseLLMProvider):
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0)
total_tokens=usage_data.get("total_tokens", 0),
)
# Create response
@@ -281,15 +297,19 @@ class PrivateModeProvider(BaseLLMProvider):
security_check=True, # Will be set by security manager
risk_score=0.0, # Will be set by security manager
latency_ms=provider_latency,
provider_latency_ms=provider_latency
provider_latency_ms=provider_latency,
)
logger.debug(f"PrivateMode chat completion successful in {provider_latency:.2f}ms")
logger.debug(
f"PrivateMode chat completion successful in {provider_latency:.2f}ms"
)
return chat_response
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "chat completion")
self._handle_http_error(
response.status, error_text, "chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode request error: {e}")
@@ -297,7 +317,7 @@ class PrivateModeProvider(BaseLLMProvider):
"Network error communicating with PrivateMode",
provider=self.provider_name,
error_code="NETWORK_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
@@ -308,10 +328,12 @@ class PrivateModeProvider(BaseLLMProvider):
"Unexpected error during chat completion",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
async def create_chat_completion_stream(
self, request: ChatRequest
) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
self._validate_request(request)
@@ -325,12 +347,12 @@ class PrivateModeProvider(BaseLLMProvider):
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {})
**({"name": msg.name} if msg.name else {}),
}
for msg in request.messages
],
"temperature": request.temperature,
"stream": True
"stream": True,
}
# Add optional parameters
@@ -349,12 +371,11 @@ class PrivateModeProvider(BaseLLMProvider):
payload["user"] = f"user_{request.user_id}"
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
f"{self.base_url}/chat/completions", json=payload
) as response:
if response.status == 200:
async for line in response.content:
line = line.decode('utf-8').strip()
line = line.decode("utf-8").strip()
if line.startswith("data: "):
data_str = line[6:] # Remove "data: " prefix
@@ -366,11 +387,15 @@ class PrivateModeProvider(BaseLLMProvider):
chunk_data = json.loads(data_str)
yield chunk_data
except json.JSONDecodeError:
logger.warning(f"Failed to parse streaming chunk: {data_str}")
logger.warning(
f"Failed to parse streaming chunk: {data_str}"
)
continue
else:
error_text = await response.text()
self._handle_http_error(response.status, error_text, "streaming chat completion")
self._handle_http_error(
response.status, error_text, "streaming chat completion"
)
except aiohttp.ClientError as e:
logger.error(f"PrivateMode streaming error: {e}")
@@ -378,7 +403,7 @@ class PrivateModeProvider(BaseLLMProvider):
"Network error during streaming",
provider=self.provider_name,
error_code="STREAMING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -394,7 +419,7 @@ class PrivateModeProvider(BaseLLMProvider):
payload = {
"model": request.model,
"input": request.input,
"user": f"user_{request.user_id}"
"user": f"user_{request.user_id}",
}
# Add optional parameters
@@ -408,12 +433,11 @@ class PrivateModeProvider(BaseLLMProvider):
"user_id": request.user_id,
"api_key_id": request.api_key_id,
"timestamp": datetime.utcnow().isoformat(),
**(request.metadata or {})
**(request.metadata or {}),
}
async with session.post(
f"{self.base_url}/embeddings",
json=payload
f"{self.base_url}/embeddings", json=payload
) as response:
provider_latency = (time.time() - start_time) * 1000
@@ -426,7 +450,7 @@ class PrivateModeProvider(BaseLLMProvider):
embedding = EmbeddingData(
object="embedding",
index=emb_data.get("index", 0),
embedding=emb_data.get("embedding", [])
embedding=emb_data.get("embedding", []),
)
embeddings.append(embedding)
@@ -435,7 +459,9 @@ class PrivateModeProvider(BaseLLMProvider):
usage = TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=0, # No completion tokens for embeddings
total_tokens=usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
total_tokens=usage_data.get(
"total_tokens", usage_data.get("prompt_tokens", 0)
),
)
return EmbeddingResponse(
@@ -447,13 +473,15 @@ class PrivateModeProvider(BaseLLMProvider):
security_check=True, # Will be set by security manager
risk_score=0.0, # Will be set by security manager
latency_ms=provider_latency,
provider_latency_ms=provider_latency
provider_latency_ms=provider_latency,
)
else:
error_text = await response.text()
# Log the detailed error response from the provider
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
logger.error(
f"PrivateMode embedding error - Status {response.status}: {error_text}"
)
self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e:
@@ -462,7 +490,7 @@ class PrivateModeProvider(BaseLLMProvider):
"Network error during embedding generation",
provider=self.provider_name,
error_code="EMBEDDING_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
except Exception as e:
if isinstance(e, (ProviderError, ValidationError)):
@@ -473,7 +501,7 @@ class PrivateModeProvider(BaseLLMProvider):
"Unexpected error during embedding generation",
provider=self.provider_name,
error_code="UNEXPECTED_ERROR",
details={"error": str(e)}
details={"error": str(e)},
)
async def cleanup(self):
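The streaming loop above follows the usual server-sent-events convention: each chunk arrives as a `data: {json}` line, malformed chunks are skipped, and (in lines elided from this hunk) a literal `data: [DONE]` is assumed to terminate the stream. A network-free sketch of that parsing:

import json

def parse_sse_lines(raw_lines):
    """Yield decoded JSON chunks from SSE byte lines (illustrative sketch)."""
    for raw in raw_lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue
        data_str = line[6:]           # drop the "data: " prefix
        if data_str == "[DONE]":      # assumed stream terminator
            return
        try:
            yield json.loads(data_str)
        except json.JSONDecodeError:
            continue                  # skip malformed chunks, as the provider does

sample = [b'data: {"delta": "Hel"}\n', b'data: {"delta": "lo"}\n', b"data: [DONE]\n"]
print(list(parse_sse_lines(sample)))  # [{'delta': 'Hel'}, {'delta': 'lo'}]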

View File

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
class CircuitBreakerState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, blocking requests
HALF_OPEN = "half_open" # Testing if service recovered
@@ -28,6 +29,7 @@ class CircuitBreakerState(Enum):
@dataclass
class CircuitBreakerStats:
"""Circuit breaker statistics"""
failure_count: int = 0
success_count: int = 0
last_failure_time: Optional[datetime] = None
@@ -51,7 +53,9 @@ class CircuitBreaker:
if self.state == CircuitBreakerState.OPEN:
# Check if reset timeout has passed
if (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
if (
datetime.utcnow() - self.stats.state_change_time
).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
self._transition_to_half_open()
return True
return False
@@ -72,7 +76,9 @@ class CircuitBreaker:
# Reset failure count on success
self.stats.failure_count = 0
logger.debug(f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}")
logger.debug(
f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}"
)
def record_failure(self):
"""Record failed request"""
@@ -85,27 +91,35 @@ class CircuitBreaker:
elif self.state == CircuitBreakerState.HALF_OPEN:
self._transition_to_open()
logger.warning(f"Circuit breaker [{self.provider_name}]: Failure recorded, "
f"count={self.stats.failure_count}, state={self.state.value}")
logger.warning(
f"Circuit breaker [{self.provider_name}]: Failure recorded, "
f"count={self.stats.failure_count}, state={self.state.value}"
)
def _transition_to_open(self):
"""Transition to OPEN state"""
self.state = CircuitBreakerState.OPEN
self.stats.state_change_time = datetime.utcnow()
logger.error(f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures")
logger.error(
f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures"
)
def _transition_to_half_open(self):
"""Transition to HALF_OPEN state"""
self.state = CircuitBreakerState.HALF_OPEN
self.stats.state_change_time = datetime.utcnow()
logger.info(f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing")
logger.info(
f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing"
)
def _transition_to_closed(self):
"""Transition to CLOSED state"""
self.state = CircuitBreakerState.CLOSED
self.stats.state_change_time = datetime.utcnow()
self.stats.failure_count = 0 # Reset failure count
logger.info(f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered")
logger.info(
f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered"
)
def get_stats(self) -> Dict[str, Any]:
"""Get circuit breaker statistics"""
@@ -113,10 +127,17 @@ class CircuitBreaker:
"state": self.state.value,
"failure_count": self.stats.failure_count,
"success_count": self.stats.success_count,
"last_failure_time": self.stats.last_failure_time.isoformat() if self.stats.last_failure_time else None,
"last_success_time": self.stats.last_success_time.isoformat() if self.stats.last_success_time else None,
"last_failure_time": self.stats.last_failure_time.isoformat()
if self.stats.last_failure_time
else None,
"last_success_time": self.stats.last_success_time.isoformat()
if self.stats.last_success_time
else None,
"state_change_time": self.stats.state_change_time.isoformat(),
"time_in_current_state_ms": (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000
"time_in_current_state_ms": (
datetime.utcnow() - self.stats.state_change_time
).total_seconds()
* 1000,
}
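The state machine driven by record_success/record_failure is: CLOSED counts failures until the threshold opens the circuit, OPEN blocks requests until the reset timeout permits a HALF_OPEN probe, and the probe's outcome either closes or re-opens it. A compressed sketch with a threshold of 2 and the timer replaced by an explicit call:

from enum import Enum

class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

class TinyBreaker:
    """Illustrative reduction of the CircuitBreaker above (no timers)."""
    def __init__(self, threshold: int = 2):
        self.threshold = threshold
        self.failures = 0
        self.state = State.CLOSED

    def record_failure(self):
        self.failures += 1
        if self.state is State.HALF_OPEN or self.failures >= self.threshold:
            self.state = State.OPEN

    def record_success(self):
        if self.state is State.HALF_OPEN:
            self.state = State.CLOSED
        self.failures = 0

    def try_half_open(self):
        # In the real class this happens once the reset timeout elapses.
        if self.state is State.OPEN:
            self.state = State.HALF_OPEN

b = TinyBreaker()
b.record_failure()
b.record_failure()
print(b.state)      # State.OPEN: requests are now blocked
b.try_half_open()
b.record_success()
print(b.state)      # State.CLOSED: service recovered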
@@ -132,7 +153,7 @@ class RetryManager:
*args,
retryable_exceptions: tuple = (Exception,),
non_retryable_exceptions: tuple = (RateLimitError,),
**kwargs
**kwargs,
) -> Any:
"""Execute function with retry logic"""
last_exception = None
@@ -149,11 +170,15 @@ class RetryManager:
last_exception = e
if attempt == self.config.max_retries:
logger.error(f"All {self.config.max_retries + 1} attempts failed. Last error: {e}")
logger.error(
f"All {self.config.max_retries + 1} attempts failed. Last error: {e}"
)
raise
delay = self._calculate_delay(attempt)
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms...")
logger.warning(
f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms..."
)
await asyncio.sleep(delay / 1000.0)
@@ -165,10 +190,13 @@ class RetryManager:
def _calculate_delay(self, attempt: int) -> int:
"""Calculate delay for exponential backoff"""
delay = self.config.retry_delay_ms * (self.config.retry_exponential_base ** attempt)
delay = self.config.retry_delay_ms * (
self.config.retry_exponential_base**attempt
)
# Add some jitter to prevent thundering herd
import random
jitter = random.uniform(0.8, 1.2)
return int(delay * jitter)
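With the defaults above (retry_delay_ms=1000, retry_exponential_base=2.0), attempt n nominally waits 1000 * 2**n ms, i.e. 1000, 2000, 4000, 8000 ms, each scaled by a uniform jitter in [0.8, 1.2]. A quick check of that arithmetic:

import random

def calculate_delay(attempt: int, base_ms: int = 1000, exp_base: float = 2.0) -> int:
    delay = base_ms * (exp_base ** attempt)
    jitter = random.uniform(0.8, 1.2)  # spread retries to avoid a thundering herd
    return int(delay * jitter)

for attempt in range(4):
    # Nominal delays before jitter: 1000, 2000, 4000, 8000 ms.
    print(attempt, calculate_delay(attempt))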
@@ -181,18 +209,16 @@ class TimeoutManager:
self.config = config
async def execute_with_timeout(
self,
func: Callable,
*args,
timeout_override: Optional[int] = None,
**kwargs
self, func: Callable, *args, timeout_override: Optional[int] = None, **kwargs
) -> Any:
"""Execute function with timeout"""
timeout_ms = timeout_override or self.config.timeout_ms
timeout_seconds = timeout_ms / 1000.0
try:
return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout_seconds)
return await asyncio.wait_for(
func(*args, **kwargs), timeout=timeout_seconds
)
except asyncio.TimeoutError:
error_msg = f"Request timed out after {timeout_ms}ms"
logger.error(error_msg)
@@ -216,7 +242,7 @@ class ResilienceManager:
retryable_exceptions: tuple = (Exception,),
non_retryable_exceptions: tuple = (RateLimitError,),
timeout_override: Optional[int] = None,
**kwargs
**kwargs,
) -> Any:
"""Execute function with full resilience patterns"""
@@ -237,14 +263,16 @@ class ResilienceManager:
retryable_exceptions=retryable_exceptions,
non_retryable_exceptions=non_retryable_exceptions,
timeout_override=timeout_override,
**kwargs
**kwargs,
)
# Record success
self.circuit_breaker.record_success()
execution_time = (time.time() - start_time) * 1000
logger.debug(f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms")
logger.debug(
f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms"
)
return result
@@ -253,7 +281,9 @@ class ResilienceManager:
self.circuit_breaker.record_failure()
execution_time = (time.time() - start_time) * 1000
logger.error(f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}")
logger.error(
f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}"
)
raise
@@ -281,8 +311,8 @@ class ResilienceManager:
"config": {
"max_retries": self.config.max_retries,
"timeout_ms": self.config.timeout_ms,
"circuit_breaker_threshold": self.config.circuit_breaker_threshold
}
"circuit_breaker_threshold": self.config.circuit_breaker_threshold,
},
}
@@ -293,11 +323,15 @@ class ResilienceManagerFactory:
_default_config = ResilienceConfig()
@classmethod
def get_manager(cls, provider_name: str, config: Optional[ResilienceConfig] = None) -> ResilienceManager:
def get_manager(
cls, provider_name: str, config: Optional[ResilienceConfig] = None
) -> ResilienceManager:
"""Get or create resilience manager for provider"""
if provider_name not in cls._managers:
manager_config = config or cls._default_config
cls._managers[provider_name] = ResilienceManager(manager_config, provider_name)
cls._managers[provider_name] = ResilienceManager(
manager_config, provider_name
)
return cls._managers[provider_name]
@@ -305,8 +339,7 @@ class ResilienceManagerFactory:
def get_all_health_status(cls) -> Dict[str, Dict[str, Any]]:
"""Get health status for all managed providers"""
return {
name: manager.get_health_status()
for name, manager in cls._managers.items()
name: manager.get_health_status() for name, manager in cls._managers.items()
}
@classmethod
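A hypothetical call site inside the same package, tying the factory to execute_with_resilience; the flaky_call coroutine and provider name are invented, and the relative import mirrors the one used by the service module below:

import asyncio
from .resilience import ResilienceManagerFactory  # same relative import the service uses

async def flaky_call(prompt: str) -> str:
    # Hypothetical provider coroutine; raise ProviderError here to exercise retries.
    return f"echo: {prompt}"

async def main():
    manager = ResilienceManagerFactory.get_manager("privatemode")
    result = await manager.execute_with_resilience(
        flaky_call,
        "hello",
        timeout_override=5000,  # milliseconds, overrides config.timeout_ms for this call
    )
    print(result)

asyncio.run(main())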

View File

@@ -12,18 +12,28 @@ from typing import Dict, Any, Optional, List, AsyncGenerator
from datetime import datetime
from .models import (
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
ModelInfo, ProviderStatus, LLMMetrics
ChatRequest,
ChatResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ProviderStatus,
LLMMetrics,
)
from .config import config_manager, ProviderConfig
from ...core.config import settings
from .resilience import ResilienceManagerFactory
# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
LLMError, ProviderError, SecurityError, ConfigurationError,
ValidationError, TimeoutError
LLMError,
ProviderError,
SecurityError,
ConfigurationError,
ValidationError,
TimeoutError,
)
logger = logging.getLogger(__name__)
@@ -52,7 +62,9 @@ class LLMService:
try:
# Get configuration
config = config_manager.get_config()
logger.info(f"Initializing LLM service with {len(config.providers)} configured providers")
logger.info(
f"Initializing LLM service with {len(config.providers)} configured providers"
)
# Initialize enabled providers
enabled_providers = config_manager.get_enabled_providers()
@@ -70,13 +82,17 @@ class LLMService:
default_provider = config.default_provider
if default_provider not in self._providers:
available_providers = list(self._providers.keys())
logger.warning(f"Default provider '{default_provider}' not available, using '{available_providers[0]}'")
logger.warning(
f"Default provider '{default_provider}' not available, using '{available_providers[0]}'"
)
config.default_provider = available_providers[0]
self._initialized = True
initialization_time = (time.time() - start_time) * 1000
logger.info(f"LLM Service initialized successfully in {initialization_time:.2f}ms")
logger.info(
f"LLM Service initialized successfully in {initialization_time:.2f}ms"
)
logger.info(f"Available providers: {list(self._providers.keys())}")
except Exception as e:
@@ -106,12 +122,16 @@ class LLMService:
# Test provider health
health_status = await provider.health_check()
if health_status.status == "unavailable":
logger.error(f"Provider '{provider_name}' failed health check: {health_status.error_message}")
logger.error(
f"Provider '{provider_name}' failed health check: {health_status.error_message}"
)
return
# Register provider
self._providers[provider_name] = provider
logger.info(f"Provider '{provider_name}' initialized successfully (status: {health_status.status})")
logger.info(
f"Provider '{provider_name}' initialized successfully (status: {health_status.status})"
)
# Fetch and update models dynamically
await self._refresh_provider_models(provider_name, provider)
@@ -126,7 +146,9 @@ class LLMService:
else:
raise ConfigurationError(f"Unknown provider type: {config.name}")
async def _refresh_provider_models(self, provider_name: str, provider: BaseLLMProvider):
async def _refresh_provider_models(
self, provider_name: str, provider: BaseLLMProvider
):
"""Fetch and update models dynamically from provider"""
try:
# Get models from provider
@@ -136,10 +158,14 @@ class LLMService:
# Update configuration
await config_manager.refresh_provider_models(provider_name, model_ids)
logger.info(f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}")
logger.info(
f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}"
)
except Exception as e:
logger.error(f"Failed to refresh models for provider '{provider_name}': {e}")
logger.error(
f"Failed to refresh models for provider '{provider_name}': {e}"
)
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
"""Create chat completion with security and resilience"""
@@ -157,8 +183,10 @@ class LLMService:
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -169,22 +197,18 @@ class LLMService:
provider.create_chat_completion,
request,
retryable_exceptions=(ProviderError, TimeoutError),
non_retryable_exceptions=(ValidationError,)
non_retryable_exceptions=(ValidationError,),
)
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
return response
except Exception as e:
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Chat completion failed for provider %s (model=%s, latency=%.2fms, error=%s)",
@@ -195,7 +219,9 @@ class LLMService:
)
raise
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
async def create_chat_completion_stream(
self, request: ChatRequest
) -> AsyncGenerator[Dict[str, Any], None]:
"""Create streaming chat completion"""
if not self._initialized:
await self.initialize()
@@ -203,13 +229,15 @@ class LLMService:
# Security validation disabled - always allow streaming requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute streaming with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -219,13 +247,13 @@ class LLMService:
provider.create_chat_completion_stream,
request,
retryable_exceptions=(ProviderError, TimeoutError),
non_retryable_exceptions=(ValidationError,)
non_retryable_exceptions=(ValidationError,),
):
yield chunk
except Exception as e:
# Record streaming failure - metrics disabled
error_code = getattr(e, 'error_code', e.__class__.__name__)
error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Streaming chat completion failed for provider %s (model=%s, error=%s)",
provider_name,
@@ -242,13 +270,15 @@ class LLMService:
# Security validation disabled - always allow embedding requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
provider = self._providers.get(provider_name)
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
raise ProviderError(
f"No available provider for model '{request.model}'",
provider=provider_name,
)
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -259,20 +289,18 @@ class LLMService:
provider.create_embedding,
request,
retryable_exceptions=(ProviderError, TimeoutError),
non_retryable_exceptions=(ValidationError,)
non_retryable_exceptions=(ValidationError,),
)
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
return response
except Exception as e:
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
error_code = getattr(e, "error_code", e.__class__.__name__)
logger.exception(
"Embedding request failed for provider %s (model=%s, latency=%.2fms, error=%s)",
provider_name,
@@ -327,7 +355,7 @@ class LLMService:
status="unavailable",
last_check=datetime.utcnow(),
error_message=str(e),
models_available=[]
models_available=[],
)
return status_dict
@@ -336,10 +364,7 @@ class LLMService:
"""Get service metrics - metrics disabled"""
# return metrics_collector.get_metrics()
return LLMMetrics(
total_requests=0,
success_rate=0.0,
avg_latency_ms=0,
error_rates={}
total_requests=0, successful_requests=0, failed_requests=0, average_latency_ms=0.0, provider_metrics={}
)
def get_health_summary(self) -> Dict[str, Any]:
@@ -349,11 +374,13 @@ class LLMService:
return {
"service_status": "healthy" if self._initialized else "initializing",
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
"startup_time": self._startup_time.isoformat()
if self._startup_time
else None,
"provider_count": len(self._providers),
"active_providers": list(self._providers.keys()),
"metrics": {"status": "disabled"},
"resilience": resilience_health
"resilience": resilience_health,
}
def _get_provider_for_model(self, model: str) -> str:
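From a caller's perspective the flow above reduces to: build a ChatRequest, let _get_provider_for_model pick a provider, and await the completion through the resilience wrapper. A hypothetical call site (the model id and user id are invented; the constructor signature and import paths are assumed from the code above):

import asyncio
from .models import ChatRequest, ChatMessage  # same relative import used above

async def main():
    service = LLMService()      # assumed default constructor
    await service.initialize()  # loads providers and runs health checks

    request = ChatRequest(
        model="privatemode-llama",  # hypothetical model id
        messages=[ChatMessage(role="user", content="Hello!")],
        temperature=0.2,
        user_id=1,                  # hypothetical user
    )
    response = await service.create_chat_completion(request)
    print(response.choices[0].message.content)

asyncio.run(main())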

View File

@@ -15,6 +15,7 @@ from app.core.logging import log_module_event, log_security_event
@dataclass
class MetricData:
"""Individual metric data point"""
timestamp: datetime
value: float
labels: Dict[str, str] = field(default_factory=dict)
@@ -23,6 +24,7 @@ class MetricData:
@dataclass
class RequestMetrics:
"""Request-related metrics"""
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
@@ -37,6 +39,7 @@ class RequestMetrics:
@dataclass
class SystemMetrics:
"""System-related metrics"""
uptime: float = 0.0
memory_usage: float = 0.0
cpu_usage: float = 0.0
@@ -78,12 +81,16 @@ class MetricsService:
# Store historical data
self._store_metric("uptime", self.system_metrics.uptime)
self._store_metric("active_connections", self.system_metrics.active_connections)
self._store_metric(
"active_connections", self.system_metrics.active_connections
)
await asyncio.sleep(60) # Collect every minute
except Exception as e:
log_module_event("metrics_service", "system_metrics_error", {"error": str(e)})
log_module_event(
"metrics_service", "system_metrics_error", {"error": str(e)}
)
await asyncio.sleep(60)
async def _cleanup_old_metrics(self):
@@ -103,20 +110,20 @@ class MetricsService:
log_module_event("metrics_service", "cleanup_error", {"error": str(e)})
await asyncio.sleep(3600)
def _store_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
def _store_metric(
self, name: str, value: float, labels: Optional[Dict[str, str]] = None
):
"""Store a metric data point"""
if labels is None:
labels = {}
metric_data = MetricData(
timestamp=datetime.now(),
value=value,
labels=labels
)
metric_data = MetricData(timestamp=datetime.now(), value=value, labels=labels)
self.metric_history[name].append(metric_data)
def start_request(self, request_id: str, endpoint: str, user_id: Optional[str] = None):
def start_request(
self, request_id: str, endpoint: str, user_id: Optional[str] = None
):
"""Start tracking a request"""
self.active_requests[request_id] = time.time()
@@ -124,13 +131,15 @@ class MetricsService:
self.request_metrics.total_requests += 1
# Track by endpoint
self.request_metrics.requests_by_endpoint[endpoint] = \
self.request_metrics.requests_by_endpoint[endpoint] = (
self.request_metrics.requests_by_endpoint.get(endpoint, 0) + 1
)
# Track by user
if user_id:
self.request_metrics.requests_by_user[user_id] = \
self.request_metrics.requests_by_user[user_id] = (
self.request_metrics.requests_by_user.get(user_id, 0) + 1
)
# Store metric
self._store_metric("requests_total", self.request_metrics.total_requests)
@@ -139,9 +148,14 @@ class MetricsService:
if user_id:
self._store_metric("requests_by_user", 1, {"user_id": user_id})
def end_request(self, request_id: str, success: bool = True,
model: Optional[str] = None, tokens_used: int = 0,
cost: float = 0.0):
def end_request(
self,
request_id: str,
success: bool = True,
model: Optional[str] = None,
tokens_used: int = 0,
cost: float = 0.0,
):
"""End tracking a request"""
if request_id not in self.active_requests:
return
@@ -158,7 +172,9 @@ class MetricsService:
# Update average response time
if self.response_times:
self.request_metrics.average_response_time = sum(self.response_times) / len(self.response_times)
self.request_metrics.average_response_time = sum(self.response_times) / len(
self.response_times
)
# Update token and cost metrics
self.request_metrics.total_tokens_used += tokens_used
@@ -166,8 +182,9 @@ class MetricsService:
# Track by model
if model:
self.request_metrics.requests_by_model[model] = \
self.request_metrics.requests_by_model[model] = (
self.request_metrics.requests_by_model.get(model, 0) + 1
)
# Store metrics
self._store_metric("response_time", response_time)
@@ -180,8 +197,13 @@ class MetricsService:
# Clean up
del self.active_requests[request_id]
def record_error(self, error_type: str, error_message: str,
endpoint: Optional[str] = None, user_id: Optional[str] = None):
def record_error(
self,
error_type: str,
error_message: str,
endpoint: Optional[str] = None,
user_id: Optional[str] = None,
):
"""Record an error occurrence"""
labels = {"error_type": error_type}
@@ -193,16 +215,23 @@ class MetricsService:
self._store_metric("errors_total", 1, labels)
# Log security events for authentication/authorization errors
if error_type in ["authentication_failed", "authorization_failed", "invalid_api_key"]:
log_security_event(error_type, user_id or "anonymous", {
"error": error_message,
"endpoint": endpoint
})
if error_type in [
"authentication_failed",
"authorization_failed",
"invalid_api_key",
]:
log_security_event(
error_type,
user_id or "anonymous",
{"error": error_message, "endpoint": endpoint},
)
def record_module_status(self, module_name: str, is_healthy: bool):
"""Record module health status"""
self.system_metrics.module_status[module_name] = is_healthy
self._store_metric("module_health", 1 if is_healthy else 0, {"module": module_name})
self._store_metric(
"module_health", 1 if is_healthy else 0, {"module": module_name}
)
def get_current_metrics(self) -> Dict[str, Any]:
"""Get current metrics snapshot"""
@@ -212,25 +241,28 @@ class MetricsService:
"successful_requests": self.request_metrics.successful_requests,
"failed_requests": self.request_metrics.failed_requests,
"success_rate": (
self.request_metrics.successful_requests / self.request_metrics.total_requests
if self.request_metrics.total_requests > 0 else 0
self.request_metrics.successful_requests
/ self.request_metrics.total_requests
if self.request_metrics.total_requests > 0
else 0
),
"average_response_time": self.request_metrics.average_response_time,
"total_tokens_used": self.request_metrics.total_tokens_used,
"total_cost": self.request_metrics.total_cost,
"requests_by_model": dict(self.request_metrics.requests_by_model),
"requests_by_user": dict(self.request_metrics.requests_by_user),
"requests_by_endpoint": dict(self.request_metrics.requests_by_endpoint)
"requests_by_endpoint": dict(self.request_metrics.requests_by_endpoint),
},
"system_metrics": {
"uptime": self.system_metrics.uptime,
"active_connections": self.system_metrics.active_connections,
"module_status": dict(self.system_metrics.module_status)
}
"module_status": dict(self.system_metrics.module_status),
},
}
def get_metrics_history(self, metric_name: str,
hours: int = 1) -> List[Dict[str, Any]]:
def get_metrics_history(
self, metric_name: str, hours: int = 1
) -> List[Dict[str, Any]]:
"""Get historical metrics data"""
if metric_name not in self.metric_history:
return []
@@ -241,7 +273,7 @@ class MetricsService:
{
"timestamp": data.timestamp.isoformat(),
"value": data.value,
"labels": data.labels
"labels": data.labels,
}
for data in self.metric_history[metric_name]
if data.timestamp > cutoff_time
@@ -251,18 +283,27 @@ class MetricsService:
"""Get top metrics by type"""
if metric_type == "models":
return dict(
sorted(self.request_metrics.requests_by_model.items(),
key=lambda x: x[1], reverse=True)[:limit]
sorted(
self.request_metrics.requests_by_model.items(),
key=lambda x: x[1],
reverse=True,
)[:limit]
)
elif metric_type == "users":
return dict(
sorted(self.request_metrics.requests_by_user.items(),
key=lambda x: x[1], reverse=True)[:limit]
sorted(
self.request_metrics.requests_by_user.items(),
key=lambda x: x[1],
reverse=True,
)[:limit]
)
elif metric_type == "endpoints":
return dict(
sorted(self.request_metrics.requests_by_endpoint.items(),
key=lambda x: x[1], reverse=True)[:limit]
sorted(
self.request_metrics.requests_by_endpoint.items(),
key=lambda x: x[1],
reverse=True,
)[:limit]
)
else:
return {}
@@ -275,11 +316,13 @@ class MetricsService:
"active_connections": self.system_metrics.active_connections,
"total_requests": self.request_metrics.total_requests,
"success_rate": (
self.request_metrics.successful_requests / self.request_metrics.total_requests
if self.request_metrics.total_requests > 0 else 1.0
self.request_metrics.successful_requests
/ self.request_metrics.total_requests
if self.request_metrics.total_requests > 0
else 1.0
),
"modules": self.system_metrics.module_status,
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now().isoformat(),
}
async def reset_metrics(self):
@@ -305,4 +348,5 @@ def setup_metrics(app):
# Initialize metrics service
import asyncio
asyncio.create_task(metrics_service.initialize())
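Usage is bracket-style: start_request records the arrival and the per-endpoint/user counters, end_request computes latency plus token and cost totals, and record_error feeds the error counters (and the security log for auth failures). A minimal sketch, assuming MetricsService can be constructed with no arguments:

import uuid

svc = MetricsService()  # assumed no-argument constructor
req_id = str(uuid.uuid4())

svc.start_request(req_id, endpoint="/v1/chat/completions", user_id="42")
# ... handle the request ...
svc.end_request(req_id, success=True, model="privatemode-llama",
                tokens_used=128, cost=0.0004)

svc.record_error("invalid_api_key", "key expired",
                 endpoint="/v1/chat/completions", user_id="42")

print(svc.get_current_metrics()["request_metrics"]["success_rate"])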

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
@dataclass
class ModuleManifest:
"""Module manifest loaded from module.yaml"""
name: str
version: str
description: str
@@ -69,7 +70,9 @@ class ModuleConfigManager:
self.schemas: Dict[str, Dict] = {}
self.configs: Dict[str, Dict] = {}
async def discover_modules(self, modules_path: str = "modules") -> Dict[str, ModuleManifest]:
async def discover_modules(
self, modules_path: str = "modules"
) -> Dict[str, ModuleManifest]:
"""Discover modules from filesystem using module.yaml manifests"""
discovered_modules = {}
@@ -91,14 +94,16 @@ class ModuleConfigManager:
if not manifest_path.exists():
# Check if it's a legacy module (has main.py but no manifest)
if (module_dir / "main.py").exists():
logger.info(f"Legacy module found (no manifest): {module_dir.name}")
logger.info(
f"Legacy module found (no manifest): {module_dir.name}"
)
# Create a basic manifest for legacy modules
manifest = ModuleManifest(
name=module_dir.name,
version="1.0.0",
description=f"Legacy {module_dir.name} module",
author="System",
category="legacy"
category="legacy",
)
discovered_modules[manifest.name] = manifest
continue
@@ -118,14 +123,16 @@ class ModuleConfigManager:
async def _load_module_manifest(self, manifest_path: Path) -> ModuleManifest:
"""Load and validate a module manifest file"""
try:
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest_data = yaml.safe_load(f)
# Validate required fields
required_fields = ['name', 'version', 'description', 'author']
required_fields = ["name", "version", "description", "author"]
for field in required_fields:
if field not in manifest_data:
raise ConfigurationError(f"Missing required field '{field}' in {manifest_path}")
raise ConfigurationError(
f"Missing required field '{field}' in {manifest_path}"
)
manifest = ModuleManifest(**manifest_data)
@@ -147,7 +154,7 @@ class ModuleConfigManager:
async def _load_module_schema(self, module_name: str, schema_path: Path):
"""Load JSON schema for module configuration"""
try:
with open(schema_path, 'r', encoding='utf-8') as f:
with open(schema_path, "r", encoding="utf-8") as f:
schema = json.load(f)
self.schemas[module_name] = schema
@@ -174,26 +181,32 @@ class ModuleConfigManager:
"""Validate module configuration against its schema"""
schema = self.schemas.get(module_name)
if not schema:
logger.info(f"No schema found for module {module_name}, skipping validation")
logger.info(
f"No schema found for module {module_name}, skipping validation"
)
return {"valid": True, "errors": []}
try:
validate(instance=config, schema=schema, format_checker=draft7_format_checker)
validate(
instance=config, schema=schema, format_checker=draft7_format_checker
)
return {"valid": True, "errors": []}
except ValidationError as e:
return {
"valid": False,
"errors": [{
"errors": [
{
"path": list(e.path),
"message": e.message,
"invalid_value": e.instance
}]
"invalid_value": e.instance,
}
],
}
except Exception as e:
return {
"valid": False,
"errors": [{"message": f"Schema validation failed: {str(e)}"}]
"errors": [{"message": f"Schema validation failed: {str(e)}"}],
}
async def save_module_config(self, module_name: str, config: Dict) -> bool:
@@ -202,7 +215,9 @@ class ModuleConfigManager:
validation_result = await self.validate_config(module_name, config)
if not validation_result["valid"]:
error_messages = [error["message"] for error in validation_result["errors"]]
raise ConfigurationError(f"Invalid configuration for {module_name}: {', '.join(error_messages)}")
raise ConfigurationError(
f"Invalid configuration for {module_name}: {', '.join(error_messages)}"
)
# Save configuration
self.configs[module_name] = config
@@ -214,7 +229,7 @@ class ModuleConfigManager:
config_file = config_dir / f"{module_name}.json"
try:
with open(config_file, 'w', encoding='utf-8') as f:
with open(config_file, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Saved configuration for module: {module_name}")
@@ -233,7 +248,7 @@ class ModuleConfigManager:
for config_file in config_dir.glob("*.json"):
module_name = config_file.stem
try:
with open(config_file, 'r', encoding='utf-8') as f:
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
self.configs[module_name] = config
@@ -246,7 +261,8 @@ class ModuleConfigManager:
"""List all discovered modules with their metadata"""
modules = []
for name, manifest in self.manifests.items():
modules.append({
modules.append(
{
"name": manifest.name,
"version": manifest.version,
"description": manifest.description,
@@ -258,12 +274,12 @@ class ModuleConfigManager:
"consumes": manifest.consumes,
"has_schema": name in self.schemas,
"has_config": name in self.configs,
"ui_config": manifest.ui_config
})
"ui_config": manifest.ui_config,
}
)
return modules
async def update_module_status(self, module_name: str, enabled: bool) -> bool:
"""Update module enabled status"""
manifest = self.manifests.get(module_name)
@@ -279,7 +295,7 @@ class ModuleConfigManager:
if manifest_path.exists():
try:
manifest_dict = asdict(manifest)
with open(manifest_path, 'w', encoding='utf-8') as f:
with open(manifest_path, "w", encoding="utf-8") as f:
yaml.dump(manifest_dict, f, default_flow_style=False)
logger.info(f"Updated module status: {module_name} enabled={enabled}")

View File

@@ -23,6 +23,7 @@ logger = get_logger(__name__)
@dataclass
class ModuleConfig:
"""Configuration for a module"""
name: str
enabled: bool = True
config: Dict[str, Any] = None
@@ -52,21 +53,24 @@ class ModuleFileWatcher(FileSystemEventHandler):
return parts[0] if parts else None
def on_modified(self, event):
if event.is_directory or not event.src_path.endswith('.py'):
if event.is_directory or not event.src_path.endswith(".py"):
return
module_name = self._resolve_module_name(event.src_path)
if not module_name or module_name not in self.module_manager.modules:
return
log_module_event("hot_reload", "file_changed", {
"module": module_name,
"file": event.src_path
})
log_module_event(
"hot_reload",
"file_changed",
{"module": module_name, "file": event.src_path},
)
loop = self.module_manager.loop
if not loop or loop.is_closed():
logger.debug("Hot reload skipped for %s; event loop unavailable", module_name)
logger.debug(
"Hot reload skipped for %s; event loop unavailable", module_name
)
return
try:
@@ -75,7 +79,8 @@ class ModuleFileWatcher(FileSystemEventHandler):
loop,
)
future.add_done_callback(
lambda f: f.exception() and logger.warning(
lambda f: f.exception()
and logger.warning(
"Module reload error for %s: %s", module_name, f.exception()
)
)
@@ -95,7 +100,9 @@ class ModuleManager:
self.file_observer = None
self.fastapi_app = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.modules_root = (Path(__file__).resolve().parent.parent / "modules").resolve()
self.modules_root = (
Path(__file__).resolve().parent.parent / "modules"
).resolve()
async def initialize(self, fastapi_app=None):
"""Initialize the module manager and load all modules"""
@@ -117,13 +124,23 @@ class ModuleManager:
await self._load_modules()
self.initialized = True
log_module_event("module_manager", "initialized", {
log_module_event(
"module_manager",
"initialized",
{
"modules_count": len(self.modules),
"enabled_modules": [name for name, config in self.module_configs.items() if config.enabled]
})
"enabled_modules": [
name
for name, config in self.module_configs.items()
if config.enabled
],
},
)
except Exception as e:
log_module_event("module_manager", "initialization_failed", {"error": str(e)})
log_module_event(
"module_manager", "initialization_failed", {"error": str(e)}
)
raise ModuleLoadError(f"Failed to initialize module manager: {str(e)}")
async def _load_module_configs(self):
@@ -138,7 +155,9 @@ class ModuleManager:
logger.warning("Modules directory not found at %s", self.modules_root)
return
discovered_manifests = await module_config_manager.discover_modules(str(self.modules_root))
discovered_manifests = await module_config_manager.discover_modules(
str(self.modules_root)
)
# Load saved configurations
await module_config_manager.load_saved_configs()
@@ -150,7 +169,9 @@ class ModuleManager:
for name, manifest in discovered_manifests.items():
# Skip modules that are now core infrastructure
if name in EXCLUDED_MODULES:
logger.info(f"Skipping module '{name}' - now integrated as core infrastructure")
logger.info(
f"Skipping module '{name}' - now integrated as core infrastructure"
)
continue
saved_config = module_config_manager.get_module_config(name)
@@ -159,19 +180,25 @@ class ModuleManager:
name=manifest.name,
enabled=manifest.enabled,
config=saved_config,
dependencies=manifest.dependencies
dependencies=manifest.dependencies,
)
self.module_configs[name] = module_config
log_module_event(name, "discovered", {
log_module_event(
name,
"discovered",
{
"version": manifest.version,
"description": manifest.description,
"enabled": manifest.enabled,
"dependencies": manifest.dependencies
})
"dependencies": manifest.dependencies,
},
)
logger.info(f"Discovered {len(discovered_manifests)} modules: {list(discovered_manifests.keys())}")
logger.info(
f"Discovered {len(discovered_manifests)} modules: {list(discovered_manifests.keys())}"
)
except Exception as e:
logger.error(f"Failed to discover modules: {e}")
@@ -186,9 +213,7 @@ class ModuleManager:
"""Fallback to legacy hard-coded module loading"""
logger.warning("Falling back to legacy module configuration")
default_modules = [
ModuleConfig(name="rag", enabled=True, config={})
]
default_modules = [ModuleConfig(name="rag", enabled=True, config={})]
for config in default_modules:
self.module_configs[config.name] = config
@@ -212,7 +237,9 @@ class ModuleManager:
def visit(module_name: str):
if module_name in temp_visited:
raise ModuleLoadError(f"Circular dependency detected involving module: {module_name}")
raise ModuleLoadError(
f"Circular dependency detected involving module: {module_name}"
)
if module_name in visited:
return
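visit() above is a depth-first topological sort: temp_visited marks modules on the current recursion path (seeing one twice means a cycle), while visited marks modules already placed in the load order. The same idea in isolation, with a toy dependency map:

def load_order(deps: dict[str, list[str]]) -> list[str]:
    """Order modules so dependencies load first (DFS sketch)."""
    order, visited, on_path = [], set(), set()

    def visit(name: str):
        if name in on_path:
            raise ValueError(f"Circular dependency involving: {name}")
        if name in visited:
            return
        on_path.add(name)
        for dep in deps.get(name, []):
            visit(dep)
        on_path.discard(name)
        visited.add(name)
        order.append(name)

    for name in deps:
        visit(name)
    return order

print(load_order({"rag": ["cache"], "cache": [], "chat": ["rag"]}))
# ['cache', 'rag', 'chat']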
@@ -263,14 +290,14 @@ class ModuleManager:
module_instance = None
# Pattern 1: {module_name}_module (e.g., cache_module)
if hasattr(module, f'{module_name}_module'):
module_instance = getattr(module, f'{module_name}_module')
if hasattr(module, f"{module_name}_module"):
module_instance = getattr(module, f"{module_name}_module")
# Pattern 2: Just 'module' attribute
elif hasattr(module, 'module'):
module_instance = getattr(module, 'module')
elif hasattr(module, "module"):
module_instance = getattr(module, "module")
# Pattern 3: Module class with same name as module (e.g., CacheModule)
elif hasattr(module, f'{module_name.title()}Module'):
module_class = getattr(module, f'{module_name.title()}Module')
elif hasattr(module, f"{module_name.title()}Module"):
module_class = getattr(module, f"{module_name.title()}Module")
if callable(module_class):
module_instance = module_class()
else:
@@ -283,14 +310,17 @@ class ModuleManager:
# Initialize the module if it has an init function
module_initialized = False
if hasattr(self.modules[module_name], 'initialize'):
if hasattr(self.modules[module_name], "initialize"):
try:
import inspect
init_method = self.modules[module_name].initialize
sig = inspect.signature(init_method)
param_count = len([p for p in sig.parameters.values() if p.name != 'self'])
param_count = len(
[p for p in sig.parameters.values() if p.name != "self"]
)
if hasattr(self.modules[module_name], 'config'):
if hasattr(self.modules[module_name], "config"):
# Pass config if it's a BaseModule
self.modules[module_name].config.update(config.config)
await self.modules[module_name].initialize()
@@ -303,7 +333,9 @@ class ModuleManager:
module_initialized = True
log_module_event(module_name, "initialized", {"success": True})
except Exception as e:
log_module_event(module_name, "initialization_failed", {"error": str(e)})
log_module_event(
module_name, "initialization_failed", {"error": str(e)}
)
module_initialized = False
else:
# Module doesn't have initialize method, mark as initialized anyway
@@ -320,26 +352,32 @@ class ModuleManager:
permissions = []
# New BaseModule method
if hasattr(self.modules[module_name], 'get_required_permissions'):
if hasattr(self.modules[module_name], "get_required_permissions"):
try:
permissions = self.modules[module_name].get_required_permissions()
log_module_event(module_name, "permissions_registered", {
"permissions_count": len(permissions),
"type": "BaseModule"
})
log_module_event(
module_name,
"permissions_registered",
{"permissions_count": len(permissions), "type": "BaseModule"},
)
except Exception as e:
log_module_event(module_name, "permissions_failed", {"error": str(e)})
log_module_event(
module_name, "permissions_failed", {"error": str(e)}
)
# Legacy method
elif hasattr(self.modules[module_name], 'get_permissions'):
elif hasattr(self.modules[module_name], "get_permissions"):
try:
permissions = self.modules[module_name].get_permissions()
log_module_event(module_name, "permissions_registered", {
"permissions_count": len(permissions),
"type": "legacy"
})
log_module_event(
module_name,
"permissions_registered",
{"permissions_count": len(permissions), "type": "legacy"},
)
except Exception as e:
log_module_event(module_name, "permissions_failed", {"error": str(e)})
log_module_event(
module_name, "permissions_failed", {"error": str(e)}
)
# Register permissions with the permission system
if permissions:
@@ -352,21 +390,29 @@ class ModuleManager:
except ImportError as e:
error_msg = f"Module {module_name} import failed: {str(e)}"
log_module_event(module_name, "load_failed", {"error": error_msg, "type": "ImportError"})
log_module_event(
module_name, "load_failed", {"error": error_msg, "type": "ImportError"}
)
# For critical modules, we might want to fail completely
if module_name in ['security', 'cache']:
if module_name in ["security", "cache"]:
raise ModuleLoadError(error_msg)
# For optional modules, log warning but continue
import warnings
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
except Exception as e:
error_msg = f"Module {module_name} loading failed: {str(e)}"
log_module_event(module_name, "load_failed", {"error": error_msg, "type": type(e).__name__})
log_module_event(
module_name,
"load_failed",
{"error": error_msg, "type": type(e).__name__},
)
# For critical modules, we might want to fail completely
if module_name in ['security', 'cache']:
if module_name in ["security", "cache"]:
raise ModuleLoadError(error_msg)
# For optional modules, log warning but continue
import warnings
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
async def _register_module_router(self, module_name: str, module_instance):
@@ -376,30 +422,37 @@ class ModuleManager:
try:
# Check if module has a router attribute
if hasattr(module_instance, 'router'):
router = getattr(module_instance, 'router')
if hasattr(module_instance, "router"):
router = getattr(module_instance, "router")
# Verify it's actually a FastAPI router
from fastapi import APIRouter
if isinstance(router, APIRouter):
# Register the router with the app
self.fastapi_app.include_router(router)
log_module_event(module_name, "router_registered", {
"router_prefix": getattr(router, 'prefix', 'unknown'),
"router_tags": getattr(router, 'tags', [])
})
log_module_event(
module_name,
"router_registered",
{
"router_prefix": getattr(router, "prefix", "unknown"),
"router_tags": getattr(router, "tags", []),
},
)
logger.info(f"Registered router for module {module_name}")
else:
logger.debug(f"Module {module_name} has 'router' attribute but it's not a FastAPI router")
logger.debug(
f"Module {module_name} has 'router' attribute but it's not a FastAPI router"
)
else:
logger.debug(f"Module {module_name} does not have a router")
except Exception as e:
log_module_event(module_name, "router_registration_failed", {
"error": str(e)
})
log_module_event(
module_name, "router_registration_failed", {"error": str(e)}
)
logger.warning(f"Failed to register router for module {module_name}: {e}")
async def unload_module(self, module_name: str):
@@ -411,7 +464,7 @@ class ModuleManager:
module = self.modules[module_name]
# Call cleanup if available
if hasattr(module, "cleanup"):
await module.cleanup()
del self.modules[module_name]
@@ -435,7 +488,11 @@ class ModuleManager:
log_module_event(module_name, "reloaded", {"success": True})
return True
else:
log_module_event(
module_name,
"reload_skipped",
{"reason": "Module disabled or no config"},
)
return False
except Exception as e:
log_module_event(module_name, "reload_failed", {"error": str(e)})
@@ -453,7 +510,9 @@ class ModuleManager:
"""Check if a module is loaded"""
return module_name in self.modules
async def execute_interceptor_chain(
self, chain_type: str, context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute interceptor chain for all loaded modules"""
result_context = context.copy()
@@ -468,15 +527,17 @@ class ModuleManager:
interceptor = getattr(module, interceptor_method)
result_context = await interceptor(result_context)
log_module_event(
module_name,
"interceptor_executed",
{"chain_type": chain_type, "success": True},
)
except Exception as e:
log_module_event(
module_name,
"interceptor_failed",
{"chain_type": chain_type, "error": str(e)},
)
# Continue with other modules even if one fails
continue
@@ -487,7 +548,9 @@ class ModuleManager:
if not self.initialized:
return
log_module_event(
"module_manager", "shutting_down", {"modules_count": len(self.modules)}
)
# Unload modules in reverse order
for module_name in reversed(self.module_order):
@@ -519,14 +582,22 @@ class ModuleManager:
return
if not self.modules_root.exists():
log_module_event(
"hot_reload",
"watcher_skipped",
{"reason": f"No modules directory at {self.modules_root}"},
)
return
self.file_observer = Observer()
event_handler = ModuleFileWatcher(self, self.modules_root)
self.file_observer.schedule(
event_handler, str(self.modules_root), recursive=True
)
self.file_observer.start()
log_module_event(
"hot_reload", "watcher_started", {"path": str(self.modules_root)}
)
except Exception as e:
log_module_event("hot_reload", "watcher_failed", {"error": str(e)})
@@ -536,7 +607,9 @@ class ModuleManager:
"""Enable a module"""
try:
# Update the manifest status
success = await module_config_manager.update_module_status(
module_name, True
)
if not success:
return False
@@ -561,7 +634,9 @@ class ModuleManager:
"""Disable a module"""
try:
# Update the manifest status
success = await module_config_manager.update_module_status(
module_name, False
)
if not success:
return False
@@ -604,8 +679,9 @@ class ModuleManager:
"endpoints": manifest.endpoints,
"permissions": manifest.permissions,
"ui_config": manifest.ui_config,
"has_schema": module_config_manager.get_module_schema(module_name) is not None,
"current_config": module_config_manager.get_module_config(module_name)
"has_schema": module_config_manager.get_module_schema(module_name)
is not None,
"current_config": module_config_manager.get_module_config(module_name),
}
def list_all_modules(self) -> List[Dict]:
@@ -621,7 +697,9 @@ class ModuleManager:
"""Update module configuration"""
try:
# Validate and save the configuration
success = await module_config_manager.save_module_config(
module_name, config
)
if not success:
return False
@@ -640,7 +718,6 @@ class ModuleManager:
log_module_event(module_name, "config_update_failed", {"error": str(e)})
return False
async def get_module_health(self, module_name: str) -> Dict:
"""Get module health status"""
manifest = module_config_manager.get_module_manifest(module_name)
@@ -656,11 +733,11 @@ class ModuleManager:
"enabled": manifest.enabled,
"dependencies_met": self._check_dependencies(module_name),
"last_loaded": None,
"error": None
"error": None,
}
# Check if module has custom health check
if module and hasattr(module, "get_health"):
try:
custom_health = await module.get_health()
health.update(custom_health)

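For orientation, the duck-typed surface the manager probes for above (a router attribute, get_permissions, get_health, cleanup, and per-chain interceptor methods) can be satisfied by a very small module. A minimal sketch; the class and the pre_request_interceptor naming are hypothetical, since the exact chain-to-method mapping is not visible in this hunk:

from fastapi import APIRouter

class ExampleModule:
    """Hypothetical module exercising the hooks ModuleManager looks for"""

    def __init__(self):
        # Picked up by _register_module_router()
        self.router = APIRouter(prefix="/example", tags=["example"])

    def get_permissions(self):
        # Picked up by the legacy get_permissions() branch
        return ["modules:example:read", "modules:example:execute"]

    async def pre_request_interceptor(self, context: dict) -> dict:
        # Assumed naming for execute_interceptor_chain("pre_request", ...)
        context.setdefault("trace", []).append("example")
        return context

    async def get_health(self) -> dict:
        # Merged into the dict built by get_module_health()
        return {"status": "healthy"}

    async def cleanup(self):
        # Awaited by unload_module()
        pass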
View File

@@ -0,0 +1,556 @@
"""
Notification Service
Multi-channel notification system with email, webhooks, and other providers
"""
import asyncio
import json
import logging
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from jinja2 import Template, Environment, DictLoader
import aiohttp
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import and_, or_, desc, func
from fastapi import HTTPException, status
from app.models.notification import (
Notification,
NotificationTemplate,
NotificationChannel,
NotificationType,
NotificationStatus,
NotificationPriority,
)
from app.models.user import User
from app.core.config import settings
logger = logging.getLogger(__name__)
class NotificationService:
"""Service for managing and sending notifications"""
def __init__(self, db: AsyncSession):
self.db = db
self.jinja_env = Environment(loader=DictLoader({}))
async def send_notification(
self,
recipients: List[str],
subject: Optional[str] = None,
body: str = "",
html_body: Optional[str] = None,
notification_type: NotificationType = NotificationType.EMAIL,
priority: NotificationPriority = NotificationPriority.NORMAL,
template_name: Optional[str] = None,
template_variables: Optional[Dict[str, Any]] = None,
channel_name: Optional[str] = None,
user_id: Optional[int] = None,
scheduled_at: Optional[datetime] = None,
expires_at: Optional[datetime] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
) -> Notification:
"""Send a notification through specified channel"""
# Get or create channel
channel = await self._get_or_default_channel(notification_type, channel_name)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active channel found for type: {notification_type}",
)
# Process template if specified
if template_name:
template = await self._get_template(template_name)
if template:
rendered = await self._render_template(
template, template_variables or {}
)
subject = subject or rendered.get("subject")
body = body or rendered.get("body")
html_body = html_body or rendered.get("html_body")
# Create notification record
notification = Notification(
subject=subject,
body=body,
html_body=html_body,
recipients=recipients,
priority=priority,
scheduled_at=scheduled_at,
expires_at=expires_at,
channel_id=channel.id,
user_id=user_id,
metadata=metadata or {},
tags=tags or [],
)
self.db.add(notification)
await self.db.commit()
await self.db.refresh(notification)
# Send immediately if not scheduled
if scheduled_at is None or scheduled_at <= datetime.utcnow():
await self._deliver_notification(notification)
return notification
async def send_email(
self,
recipients: List[str],
subject: str,
body: str,
html_body: Optional[str] = None,
cc_recipients: Optional[List[str]] = None,
bcc_recipients: Optional[List[str]] = None,
**kwargs,
) -> Notification:
"""Send email notification"""
notification = await self.send_notification(
recipients=recipients,
subject=subject,
body=body,
html_body=html_body,
notification_type=NotificationType.EMAIL,
**kwargs,
)
if cc_recipients:
notification.cc_recipients = cc_recipients
if bcc_recipients:
notification.bcc_recipients = bcc_recipients
await self.db.commit()
return notification
async def send_webhook(
self,
webhook_url: str,
payload: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Notification:
"""Send webhook notification"""
# Use webhook URL as recipient
notification = await self.send_notification(
recipients=[webhook_url],
body=json.dumps(payload),
notification_type=NotificationType.WEBHOOK,
metadata={"headers": headers or {}},
**kwargs,
)
return notification
async def send_slack_message(
self, channel: str, message: str, **kwargs
) -> Notification:
"""Send Slack message"""
notification = await self.send_notification(
recipients=[channel],
body=message,
notification_type=NotificationType.SLACK,
**kwargs,
)
return notification
async def process_scheduled_notifications(self):
"""Process notifications that are scheduled for delivery"""
now = datetime.utcnow()
# Get pending scheduled notifications that are due
stmt = select(Notification).where(
and_(
Notification.status == NotificationStatus.PENDING,
Notification.scheduled_at <= now,
or_(Notification.expires_at.is_(None), Notification.expires_at > now),
)
)
result = await self.db.execute(stmt)
notifications = result.scalars().all()
processed_count = 0
for notification in notifications:
try:
await self._deliver_notification(notification)
processed_count += 1
except Exception as e:
logger.error(
f"Failed to process scheduled notification {notification.id}: {e}"
)
logger.info(f"Processed {processed_count} scheduled notifications")
return processed_count
async def retry_failed_notifications(self):
"""Retry failed notifications that can be retried"""
# Get failed notifications that can be retried
stmt = select(Notification).where(
and_(
Notification.status.in_(
[NotificationStatus.FAILED, NotificationStatus.RETRY]
),
Notification.attempts < Notification.max_attempts,
or_(
Notification.expires_at.is_(None),
Notification.expires_at > datetime.utcnow(),
),
)
)
result = await self.db.execute(stmt)
notifications = result.scalars().all()
retried_count = 0
for notification in notifications:
# Check retry delay
if notification.failed_at:
retry_delay = timedelta(
minutes=notification.channel.retry_delay_minutes
)
if datetime.utcnow() - notification.failed_at < retry_delay:
continue
try:
await self._deliver_notification(notification)
retried_count += 1
except Exception as e:
logger.error(f"Failed to retry notification {notification.id}: {e}")
logger.info(f"Retried {retried_count} failed notifications")
return retried_count
async def _deliver_notification(self, notification: Notification):
"""Deliver a notification through its channel"""
channel = await self._get_channel_by_id(notification.channel_id)
if not channel or not channel.is_active:
notification.mark_failed("Channel not available")
await self.db.commit()
return
try:
if channel.notification_type == NotificationType.EMAIL:
await self._send_email(notification, channel)
elif channel.notification_type == NotificationType.WEBHOOK:
await self._send_webhook(notification, channel)
elif channel.notification_type == NotificationType.SLACK:
await self._send_slack(notification, channel)
else:
raise ValueError(
f"Unsupported notification type: {channel.notification_type}"
)
# Update channel stats
channel.update_stats(success=True)
except Exception as e:
logger.error(f"Failed to deliver notification {notification.id}: {e}")
notification.mark_failed(str(e))
channel.update_stats(success=False, error_message=str(e))
await self.db.commit()
async def _send_email(
self, notification: Notification, channel: NotificationChannel
):
"""Send email through SMTP"""
config = channel.config
credentials = channel.credentials or {}
# Create message
msg = MIMEMultipart("alternative")
msg["Subject"] = notification.subject or "No Subject"
msg["From"] = config.get("from_email", "noreply@example.com")
msg["To"] = ", ".join(notification.recipients)
if notification.cc_recipients:
msg["Cc"] = ", ".join(notification.cc_recipients)
# Add text part
text_part = MIMEText(notification.body, "plain", "utf-8")
msg.attach(text_part)
# Add HTML part if available
if notification.html_body:
html_part = MIMEText(notification.html_body, "html", "utf-8")
msg.attach(html_part)
# Send email
smtp_host = config.get("smtp_host", "localhost")
smtp_port = config.get("smtp_port", 587)
username = credentials.get("username")
password = credentials.get("password")
use_tls = config.get("use_tls", True)
with smtplib.SMTP(smtp_host, smtp_port) as server:
if use_tls:
server.starttls()
if username and password:
server.login(username, password)
all_recipients = notification.recipients[:]
if notification.cc_recipients:
all_recipients.extend(notification.cc_recipients)
if notification.bcc_recipients:
all_recipients.extend(notification.bcc_recipients)
server.sendmail(msg["From"], all_recipients, msg.as_string())
notification.mark_sent()
async def _send_webhook(
self, notification: Notification, channel: NotificationChannel
):
"""Send webhook HTTP request"""
webhook_url = notification.recipients[0] # URL is stored as recipient
headers = notification.metadata.get("headers", {})
headers.setdefault("Content-Type", "application/json")
# Parse body as JSON payload
try:
payload = json.loads(notification.body)
except json.JSONDecodeError:
payload = {"message": notification.body}
async with aiohttp.ClientSession() as session:
async with session.post(
webhook_url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=30),
) as response:
if response.status >= 400:
raise Exception(
f"Webhook failed with status {response.status}: {await response.text()}"
)
external_id = response.headers.get("X-Message-ID")
notification.mark_sent(external_id)
async def _send_slack(
self, notification: Notification, channel: NotificationChannel
):
"""Send Slack message"""
credentials = channel.credentials or {}
webhook_url = credentials.get("webhook_url")
if not webhook_url:
raise ValueError("Slack webhook URL not configured")
payload = {
"channel": notification.recipients[0],
"text": notification.body,
"username": channel.config.get("username", "Enclava Bot"),
}
if notification.subject:
payload["attachments"] = [
{"title": notification.subject, "text": notification.body}
]
async with aiohttp.ClientSession() as session:
async with session.post(
webhook_url, json=payload, timeout=aiohttp.ClientTimeout(total=30)
) as response:
if response.status >= 400:
raise Exception(
f"Slack webhook failed with status {response.status}: {await response.text()}"
)
notification.mark_sent()
async def _get_or_default_channel(
self, notification_type: NotificationType, channel_name: Optional[str] = None
) -> Optional[NotificationChannel]:
"""Get specific channel or default for notification type"""
if channel_name:
stmt = select(NotificationChannel).where(
and_(
NotificationChannel.name == channel_name,
NotificationChannel.is_active == True,
)
)
else:
stmt = select(NotificationChannel).where(
and_(
NotificationChannel.notification_type == notification_type,
NotificationChannel.is_active == True,
NotificationChannel.is_default == True,
)
)
result = await self.db.execute(stmt)
return result.scalar_one_or_none()
async def _get_channel_by_id(
self, channel_id: int
) -> Optional[NotificationChannel]:
"""Get channel by ID"""
stmt = select(NotificationChannel).where(NotificationChannel.id == channel_id)
result = await self.db.execute(stmt)
return result.scalar_one_or_none()
async def _get_template(self, template_name: str) -> Optional[NotificationTemplate]:
"""Get notification template by name"""
stmt = select(NotificationTemplate).where(
and_(
NotificationTemplate.name == template_name,
NotificationTemplate.is_active == True,
)
)
result = await self.db.execute(stmt)
return result.scalar_one_or_none()
async def _render_template(
self, template: NotificationTemplate, variables: Dict[str, Any]
) -> Dict[str, str]:
"""Render template with variables"""
rendered = {}
# Render subject
if template.subject_template:
subject_tmpl = Template(template.subject_template)
rendered["subject"] = subject_tmpl.render(**variables)
# Render body
body_tmpl = Template(template.body_template)
rendered["body"] = body_tmpl.render(**variables)
# Render HTML body
if template.html_template:
html_tmpl = Template(template.html_template)
rendered["html_body"] = html_tmpl.render(**variables)
return rendered
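    # Example (hypothetical, not part of the service): the rendering above is
    # plain Jinja2 string templating, e.g.
    #
    #   from jinja2 import Template
    #   Template("Budget {{ name }} is at {{ pct }}%").render(name="dev", pct=85)
    #   # -> 'Budget dev is at 85%'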
# Management methods
async def create_template(
self,
name: str,
display_name: str,
notification_type: NotificationType,
body_template: str,
subject_template: Optional[str] = None,
html_template: Optional[str] = None,
description: Optional[str] = None,
default_priority: NotificationPriority = NotificationPriority.NORMAL,
variables: Optional[Dict[str, Any]] = None,
) -> NotificationTemplate:
"""Create notification template"""
template = NotificationTemplate(
name=name,
display_name=display_name,
description=description,
notification_type=notification_type,
subject_template=subject_template,
body_template=body_template,
html_template=html_template,
default_priority=default_priority,
variables=variables or {},
)
self.db.add(template)
await self.db.commit()
await self.db.refresh(template)
return template
async def create_channel(
self,
name: str,
display_name: str,
notification_type: NotificationType,
config: Dict[str, Any],
credentials: Optional[Dict[str, Any]] = None,
is_default: bool = False,
) -> NotificationChannel:
"""Create notification channel"""
channel = NotificationChannel(
name=name,
display_name=display_name,
notification_type=notification_type,
config=config,
credentials=credentials,
is_default=is_default,
)
self.db.add(channel)
await self.db.commit()
await self.db.refresh(channel)
return channel
async def get_notification_stats(self) -> Dict[str, Any]:
"""Get notification statistics"""
# Total notifications
total_notifications = await self.db.execute(select(func.count(Notification.id)))
total_count = total_notifications.scalar()
# Notifications by status
status_counts = await self.db.execute(
select(Notification.status, func.count(Notification.id)).group_by(
Notification.status
)
)
status_stats = dict(status_counts.all())
# Recent notifications (last 24h)
twenty_four_hours_ago = datetime.utcnow() - timedelta(hours=24)
recent_notifications = await self.db.execute(
select(func.count(Notification.id)).where(
Notification.created_at >= twenty_four_hours_ago
)
)
recent_count = recent_notifications.scalar()
# Channel performance
channel_stats = await self.db.execute(
select(
NotificationChannel.name,
NotificationChannel.success_count,
NotificationChannel.failure_count,
)
)
channel_performance = [
{
"name": name,
"success_count": success,
"failure_count": failure,
"success_rate": success / (success + failure)
if (success + failure) > 0
else 0,
}
for name, success, failure in channel_stats.all()
]
return {
"total_notifications": total_count,
"status_breakdown": status_stats,
"recent_notifications": recent_count,
"channel_performance": channel_performance,
}
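A usage sketch for the service above. The async_session factory and a pre-configured default email channel are assumptions, not shown in this commit:

import asyncio

async def notify_budget_alert():
    # async_session is a hypothetical AsyncSession factory
    async with async_session() as db:
        service = NotificationService(db)
        await service.send_email(
            recipients=["ops@example.com"],
            subject="Budget threshold reached",
            body="The dev team budget is at 85% of its monthly limit.",
        )
        # Deliver anything scheduled earlier that is now due
        await service.process_scheduled_notifications()

asyncio.run(notify_budget_alert())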

View File

@@ -15,7 +15,9 @@ logger = logging.getLogger(__name__)
class OllamaEmbeddingService:
"""Service for generating text embeddings using Ollama"""
def __init__(
self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"
):
self.model_name = model_name
self.base_url = base_url
self.dimension = 1024 # bge-m3 dimension
@@ -37,20 +39,28 @@ class OllamaEmbeddingService:
return False
data = await resp.json()
models = [
model["name"].split(":")[0] for model in data.get("models", [])
]
if self.model_name not in models:
logger.error(
f"Model {self.model_name} not found in Ollama. Available: {models}"
)
return False
# Test embedding generation
test_embedding = await self.get_embedding("test")
if not test_embedding or len(test_embedding) != self.dimension:
logger.error(
f"Failed to generate test embedding with {self.model_name}"
)
return False
self.initialized = True
logger.info(
f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})"
)
return True
except Exception as e:
@@ -85,27 +95,32 @@ class OllamaEmbeddingService:
# Call Ollama embedding API
async with self._session.post(
f"{self.base_url}/api/embeddings",
json={"model": self.model_name, "prompt": text},
) as resp:
if resp.status != 200:
logger.error(
f"Ollama embedding request failed: {resp.status}"
)
embeddings.append(self._generate_fallback_embedding(text))
continue
result = await resp.json()
if "embedding" in result:
embedding = result["embedding"]
if len(embedding) == self.dimension:
embeddings.append(embedding)
else:
logger.warning(
f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}"
)
embeddings.append(
self._generate_fallback_embedding(text)
)
else:
logger.error(
f"No embedding in Ollama response for text: {text[:50]}..."
)
embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
@@ -156,7 +171,7 @@ class OllamaEmbeddingService:
"dimension": self.dimension,
"backend": "Ollama",
"base_url": self.base_url,
"initialized": self.initialized
"initialized": self.initialized,
}
async def cleanup(self):

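A quick usage sketch for the embedding service above; only initialize(), get_embedding() and cleanup() are visible in this diff, so treat the rest as assumptions:

import asyncio

async def main():
    service = OllamaEmbeddingService()  # defaults: bge-m3 via the Docker bridge IP
    if await service.initialize():
        vector = await service.get_embedding("hello world")
        print(len(vector))  # 1024 for bge-m3
    await service.cleanup()

asyncio.run(main())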
View File

@@ -10,13 +10,16 @@ from typing import Dict, List, Set, Optional, Any
from dataclasses import dataclass
from enum import Enum
from fastapi import HTTPException, status
from app.core.logging import get_logger
from app.utils.exceptions import CustomHTTPException
logger = get_logger(__name__)
class PermissionAction(str, Enum):
"""Standard permission actions"""
CREATE = "create"
READ = "read"
UPDATE = "update"
@@ -30,6 +33,7 @@ class PermissionAction(str, Enum):
@dataclass
class Permission:
"""Permission definition"""
resource: str
action: str
description: str = ""
@@ -39,6 +43,7 @@ class Permission:
@dataclass
class PermissionScope:
"""Permission scope for context-aware permissions"""
namespace: str
resource: str
action: str
@@ -137,33 +142,22 @@ class ModulePermissionRegistry:
def _initialize_default_roles(self) -> Dict[str, List[str]]:
"""Initialize default permission roles"""
return {
"super_admin": [
"platform:*",
"modules:*",
"llm:*"
],
"admin": [
"platform:*",
"modules:*",
"llm:*"
],
"super_admin": ["platform:*", "modules:*", "llm:*"],
"admin": ["platform:*", "modules:*", "llm:*"],
"developer": [
"platform:api-keys:*",
"platform:budgets:read",
"llm:completions:execute",
"llm:embeddings:execute",
"modules:*:read",
"modules:*:execute"
"modules:*:execute",
],
"user": [
"llm:completions:execute",
"llm:embeddings:execute",
"modules:*:read"
"modules:*:read",
],
"readonly": [
"platform:*:read",
"modules:*:read"
]
"readonly": ["platform:*:read", "modules:*:read"],
}
def register_module(self, module_id: str, permissions: List[Permission]):
@@ -187,32 +181,25 @@ class ModulePermissionRegistry:
Permission("users", "update", "Update users"),
Permission("users", "delete", "Delete users"),
Permission("users", "manage", "Full user management"),
Permission("api-keys", "create", "Create API keys"),
Permission("api-keys", "read", "View API keys"),
Permission("api-keys", "update", "Update API keys"),
Permission("api-keys", "delete", "Delete API keys"),
Permission("api-keys", "manage", "Full API key management"),
Permission("budgets", "create", "Create budgets"),
Permission("budgets", "read", "View budgets"),
Permission("budgets", "update", "Update budgets"),
Permission("budgets", "delete", "Delete budgets"),
Permission("budgets", "manage", "Full budget management"),
Permission("audit", "read", "View audit logs"),
Permission("audit", "export", "Export audit logs"),
Permission("settings", "read", "View settings"),
Permission("settings", "update", "Update settings"),
Permission("settings", "manage", "Full settings management"),
Permission("health", "read", "View health status"),
Permission("metrics", "read", "View metrics"),
Permission("permissions", "read", "View permissions"),
Permission("permissions", "manage", "Manage permissions"),
Permission("roles", "create", "Create roles"),
Permission("roles", "read", "View roles"),
Permission("roles", "update", "Update roles"),
@@ -238,8 +225,9 @@ class ModulePermissionRegistry:
logger.info("Registered platform and LLM permissions")
self._platform_permissions_registered = True
def check_permission(
self, user_permissions: List[str], required: str, context: Dict[str, Any] = None
) -> bool:
"""Check if user has required permission"""
# Basic permission check
has_perm = self.tree.has_permission(user_permissions, required)
@@ -253,8 +241,9 @@ class ModulePermissionRegistry:
return True
def _check_context_permissions(
self, user_permissions: List[str], required: str, context: Dict[str, Any]
) -> bool:
"""Check context-aware permissions"""
# Extract resource owner information
resource_owner = context.get("owner_id")
@@ -266,24 +255,32 @@ class ModulePermissionRegistry:
# Check for elevated permissions for cross-user access
if resource_owner and resource_owner != current_user:
elevated_required = required.replace(":read", ":manage").replace(
":update", ":manage"
)
return self.tree.has_permission(user_permissions, elevated_required)
return True
def get_user_permissions(
self, roles: List[str], custom_permissions: List[str] = None
) -> List[str]:
"""Get effective permissions for a user based on roles and custom permissions"""
import time
start_time = time.time()
logger.info(
f"=== GET USER PERMISSIONS START === Roles: {roles}, Custom perms: {custom_permissions}"
)
try:
permissions = set()
# Add role-based permissions
for role in roles:
role_perms = self.role_permissions.get(
role, self.default_roles.get(role, [])
)
logger.info(f"Role '{role}' has {len(role_perms)} permissions")
permissions.update(role_perms)
@@ -295,20 +292,26 @@ class ModulePermissionRegistry:
result = list(permissions)
end_time = time.time()
duration = end_time - start_time
logger.info(
f"=== GET USER PERMISSIONS END === Total permissions: {len(result)}, Duration: {duration:.3f}s"
)
return result
except Exception as e:
end_time = time.time()
duration = end_time - start_time
logger.error(
f"=== GET USER PERMISSIONS FAILED === Duration: {duration:.3f}s, Error: {e}"
)
raise
def get_module_permissions(self, module_id: str) -> List[Permission]:
"""Get all permissions for a specific module"""
return self.module_permissions.get(module_id, [])
def get_available_permissions(
self, namespace: str = None
) -> Dict[str, List[Permission]]:
"""Get all available permissions, optionally filtered by namespace"""
if namespace:
filtered = {}
@@ -345,11 +348,7 @@ class ModulePermissionRegistry:
else:
invalid.append(perm)
return {"valid": valid, "invalid": invalid, "is_valid": len(invalid) == 0}
def _is_valid_wildcard(self, permission: str) -> bool:
"""Check if a wildcard permission is valid"""
@@ -371,6 +370,7 @@ class ModulePermissionRegistry:
def get_permission_hierarchy(self) -> Dict[str, Any]:
"""Get the permission hierarchy tree structure"""
def build_tree(node, path=""):
tree = {}
for key, value in node.items():
@@ -378,7 +378,7 @@ class ModulePermissionRegistry:
tree["_permission"] = {
"resource": value.resource,
"action": value.action,
"description": value.description
"description": value.description,
}
else:
current_path = f"{path}:{key}" if path else key
@@ -388,7 +388,11 @@ class ModulePermissionRegistry:
return build_tree(self.tree.root)
def require_permission(
user_permissions: List[str],
required_permission: str,
context: Optional[Dict[str, Any]] = None,
):
"""
Decorator function to require a specific permission
Raises HTTPException if user doesn't have the required permission
@@ -403,11 +407,46 @@ def require_permission(user_permissions: List[str], required_permission: str, co
"""
from fastapi import HTTPException, status
if not permission_registry.check_permission(
user_permissions, required_permission, context
):
logger.warning(
f"Permission denied: required '{required_permission}', user has {user_permissions}"
)
# Create user-friendly error message
permission_parts = required_permission.split(":")
if len(permission_parts) >= 2:
resource = permission_parts[0]
action = permission_parts[1]
# Create more specific error messages based on resource type
if resource == "platform":
if action == "settings":
error_message = "You don't have permission to access platform settings. Contact your administrator to request access."
elif action == "analytics":
error_message = "You don't have permission to access analytics. Contact your administrator to request access."
else:
error_message = f"You don't have permission to {action} platform resources. Contact your administrator to request access."
elif resource == "tools":
error_message = f"You don't have permission to {action} tools. Contact your administrator to request tool management access."
elif resource == "users":
error_message = f"You don't have permission to {action} user accounts. Contact your administrator to request user management access."
elif resource == "api_keys":
error_message = f"You don't have permission to {action} API keys. Contact your administrator to request API key management access."
else:
error_message = f"You don't have permission to {action} {resource}. Contact your administrator to request access."
else:
error_message = f"Insufficient permissions. Contact your administrator to request access to: {required_permission}"
raise CustomHTTPException(
status_code=status.HTTP_403_FORBIDDEN,
error_code="INSUFFICIENT_PERMISSIONS",
detail=error_message,
details={
"required_permission": required_permission,
"suggestion": "Contact your administrator to request the necessary permissions for this operation."
}
)
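A sketch of how a caller might exercise the guard above; the permission list is hypothetical, and the strings follow the namespace:resource:action scheme used by the registry:

user_permissions = ["platform:api-keys:*", "llm:completions:execute"]

require_permission(user_permissions, "llm:completions:execute")  # passes silently
try:
    require_permission(user_permissions, "platform:settings:update")
except Exception as exc:  # CustomHTTPException carrying the friendly message built above
    print(exc)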

View File

@@ -47,7 +47,7 @@ class PluginAutoDiscovery:
plugin_info = await self._discover_plugin(item)
if plugin_info:
discovered.append(plugin_info)
self.discovered_plugins[plugin_info["slug"]] = plugin_info
logger.info(f"Discovered {len(discovered)} plugins in directory")
return discovered
@@ -69,7 +69,9 @@ class PluginAutoDiscovery:
validation_result = validate_manifest_file(manifest_path)
if not validation_result["valid"]:
logger.warning(
f"Invalid manifest for plugin {plugin_path.name}: {validation_result['errors']}"
)
return None
manifest = validation_result["manifest"]
@@ -86,6 +88,7 @@ class PluginAutoDiscovery:
# Convert manifest to JSON-serializable format
import json
manifest_dict = json.loads(manifest.json())
plugin_info = {
@@ -101,10 +104,12 @@ class PluginAutoDiscovery:
"main_py_path": str(main_py_path),
"manifest_hash": manifest_hash,
"package_hash": package_hash,
"discovered_at": datetime.now(timezone.utc)
"discovered_at": datetime.now(timezone.utc),
}
logger.info(
f"Discovered plugin: {manifest.metadata.name} v{manifest.metadata.version}"
)
return plugin_info
except Exception as e:
@@ -136,23 +141,33 @@ class PluginAutoDiscovery:
finally:
db.close()
successful_registrations = sum(
1 for success in registration_results.values() if success
)
logger.info(
f"Plugin registration complete: {successful_registrations}/{len(registration_results)} successful"
)
return registration_results
async def _register_single_plugin(
self, db: Session, plugin_info: Dict[str, Any]
) -> bool:
"""Register a single plugin in the database"""
try:
plugin_slug = plugin_info["slug"]
# Check if plugin already exists by slug
existing_plugin = (
db.query(Plugin).filter(Plugin.slug == plugin_slug).first()
)
if existing_plugin:
# Update existing plugin if version is different
if existing_plugin.version != plugin_info["version"]:
logger.info(
f"Updating plugin {plugin_slug}: {existing_plugin.version} -> {plugin_info['version']}"
)
existing_plugin.version = plugin_info["version"]
existing_plugin.description = plugin_info["description"]
@@ -190,7 +205,7 @@ class PluginAutoDiscovery:
package_hash=plugin_info["package_hash"],
status="installed",
installed_by_user_id=1, # System installation
auto_enable=True, # Auto-enable discovered plugins
)
db.add(plugin)
@@ -204,10 +219,14 @@ class PluginAutoDiscovery:
except Exception as e:
db.rollback()
logger.error(
f"Database error registering plugin {plugin_info['slug']}: {e}"
)
return False
async def _setup_plugin_database(
self, plugin_id: str, plugin_info: Dict[str, Any]
) -> bool:
"""Setup database schema for plugin"""
try:
manifest_data = plugin_info["manifest_data"]
@@ -234,11 +253,12 @@ class PluginAutoDiscovery:
try:
# Load plugin into sandbox using the correct method
plugin_dir = Path(plugin_info["plugin_path"])
plugin_token = (
f"plugin_{plugin_slug}_token" # Generate a token for the plugin
)
plugin_instance = await plugin_loader.load_plugin_with_sandbox(
plugin_dir, plugin_token
)
if plugin_instance:
@@ -253,7 +273,9 @@ class PluginAutoDiscovery:
loading_results[plugin_slug] = False
successful_loads = sum(1 for success in loading_results.values() if success)
logger.info(
f"Plugin loading complete: {successful_loads}/{len(loading_results)} successful"
)
return loading_results
@@ -268,8 +290,8 @@ class PluginAutoDiscovery:
"summary": {
"total_discovered": 0,
"successful_registrations": 0,
"successful_loads": 0
}
"successful_loads": 0,
},
}
try:
@@ -285,16 +307,22 @@ class PluginAutoDiscovery:
# Step 2: Register plugins in database
registration_results = await self.register_discovered_plugins()
results["registered"] = registration_results
results["summary"]["successful_registrations"] = sum(1 for success in registration_results.values() if success)
results["summary"]["successful_registrations"] = sum(
1 for success in registration_results.values() if success
)
# Step 3: Load plugins into sandbox
loading_results = await self.load_discovered_plugins()
results["loaded"] = loading_results
results["summary"]["successful_loads"] = sum(1 for success in loading_results.values() if success)
results["summary"]["successful_loads"] = sum(
1 for success in loading_results.values() if success
)
logger.info(
f"Auto-discovery complete! Discovered: {results['summary']['total_discovered']}, "
f"Registered: {results['summary']['successful_registrations']}, "
f"Loaded: {results['summary']['successful_loads']}")
f"Loaded: {results['summary']['successful_loads']}"
)
return results
@@ -310,7 +338,7 @@ class PluginAutoDiscovery:
"plugins_dir_exists": self.plugins_dir.exists(),
"discovered_plugins": list(self.discovered_plugins.keys()),
"discovery_count": len(self.discovered_plugins),
"last_scan": datetime.now(timezone.utc).isoformat()
"last_scan": datetime.now(timezone.utc).isoformat(),
}
@@ -329,7 +357,14 @@ async def initialize_plugin_autodiscovery() -> Dict[str, Any]:
except Exception as e:
logger.error(f"Plugin auto-discovery initialization failed: {e}")
return {"error": str(e), "summary": {"total_discovered": 0, "successful_registrations": 0, "successful_loads": 0}}
return {
"error": str(e),
"summary": {
"total_discovered": 0,
"successful_registrations": 0,
"successful_loads": 0,
},
}
# Convenience function for manual plugin discovery

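Finally, a hypothetical startup hook wiring the auto-discovery entry point above into an application boot sequence:

import asyncio

# Runs the full discover -> register -> load pipeline once at startup
results = asyncio.run(initialize_plugin_autodiscovery())
print(results["summary"])  # e.g. {'total_discovered': 2, 'successful_registrations': 2, 'successful_loads': 2}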
Some files were not shown because too many files have changed in this diff.