mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
mega changes
This commit is contained in:
13
.gitignore
vendored
13
.gitignore
vendored
@@ -66,3 +66,16 @@ frontend/node_modules/
|
||||
node_modules/
|
||||
venv/
|
||||
venv_memory_monitor/
|
||||
to_delete/
|
||||
security-review.md
|
||||
features-plan.md
|
||||
AGENTS.md
|
||||
CLAUDE.md
|
||||
backend/CURL_VERIFICATION_EXAMPLES.md
|
||||
backend/OPENAI_COMPATIBILITY_GUIDE.md
|
||||
backend/production-deployment-guide.md
|
||||
features-plan.md
|
||||
security-review.md
|
||||
backend/.env.local
|
||||
backend/.env.test
|
||||
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import asyncio
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
load_dotenv('.env.local')
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
@@ -4,7 +4,7 @@ This migration represents the complete, accurate database schema based on the ac
|
||||
model files in the codebase. All legacy migrations have been consolidated into this
|
||||
single migration to ensure the database matches what the models expect.
|
||||
|
||||
Revision ID: 000_consolidated_ground_truth_schema
|
||||
Revision ID: 000_ground_truth
|
||||
Revises:
|
||||
Create Date: 2025-08-22 10:30:00.000000
|
||||
|
||||
|
||||
62
backend/alembic/versions/001_add_roles_table.py
Normal file
62
backend/alembic/versions/001_add_roles_table.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""add roles table
|
||||
|
||||
Revision ID: 001
|
||||
Revises: 000_ground_truth
|
||||
Create Date: 2025-01-30 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = '001_add_roles_table'
|
||||
down_revision = '000_ground_truth'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# Create roles table
|
||||
op.create_table(
|
||||
'roles',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=50), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('level', sa.String(length=20), nullable=False),
|
||||
sa.Column('permissions', sa.JSON(), nullable=True),
|
||||
sa.Column('can_manage_users', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('can_manage_budgets', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('can_view_reports', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('can_manage_tools', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('inherits_from', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
|
||||
sa.Column('is_system_role', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for roles
|
||||
op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
|
||||
op.create_index(op.f('ix_roles_level'), 'roles', ['level'], unique=False)
|
||||
|
||||
# Add role_id to users table
|
||||
op.add_column('users', sa.Column('role_id', sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_users_role_id', 'users', 'roles',
|
||||
['role_id'], ['id'], ondelete='SET NULL'
|
||||
)
|
||||
op.create_index('ix_users_role_id', 'users', ['role_id'])
|
||||
|
||||
def downgrade():
|
||||
# Remove role_id from users
|
||||
op.drop_index('ix_users_role_id', table_name='users')
|
||||
op.drop_constraint('fk_users_role_id', table_name='users', type_='foreignkey')
|
||||
op.drop_column('users', 'role_id')
|
||||
|
||||
# Drop roles table
|
||||
op.drop_index(op.f('ix_roles_level'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||
op.drop_table('roles')
|
||||
118
backend/alembic/versions/002_add_tools_tables.py
Normal file
118
backend/alembic/versions/002_add_tools_tables.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""add tools tables
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001_add_roles_table
|
||||
Create Date: 2025-01-30 00:00:01.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = '002_add_tools_tables'
|
||||
down_revision = '001_add_roles_table'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# Create tool_categories table
|
||||
op.create_table(
|
||||
'tool_categories',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=50), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('icon', sa.String(length=50), nullable=True),
|
||||
sa.Column('color', sa.String(length=20), nullable=True),
|
||||
sa.Column('sort_order', sa.Integer(), nullable=True, default=0),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_tool_categories_id'), 'tool_categories', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_tool_categories_name'), 'tool_categories', ['name'], unique=True)
|
||||
|
||||
# Create tools table
|
||||
op.create_table(
|
||||
'tools',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('tool_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('code', sa.Text(), nullable=False),
|
||||
sa.Column('parameters_schema', sa.JSON(), nullable=True),
|
||||
sa.Column('return_schema', sa.JSON(), nullable=True),
|
||||
sa.Column('timeout_seconds', sa.Integer(), nullable=True, default=30),
|
||||
sa.Column('max_memory_mb', sa.Integer(), nullable=True, default=256),
|
||||
sa.Column('max_cpu_seconds', sa.Float(), nullable=True, default=10.0),
|
||||
sa.Column('docker_image', sa.String(length=200), nullable=True),
|
||||
sa.Column('docker_command', sa.Text(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('is_approved', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('category', sa.String(length=50), nullable=True),
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
sa.Column('usage_count', sa.Integer(), nullable=True, default=0),
|
||||
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_tools_id'), 'tools', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_tools_name'), 'tools', ['name'], unique=False)
|
||||
op.create_foreign_key(
|
||||
'fk_tools_created_by_user_id', 'tools', 'users',
|
||||
['created_by_user_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create tool_executions table
|
||||
op.create_table(
|
||||
'tool_executions',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('tool_id', sa.Integer(), nullable=False),
|
||||
sa.Column('executed_by_user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('parameters', sa.JSON(), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, default='pending'),
|
||||
sa.Column('output', sa.Text(), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('return_code', sa.Integer(), nullable=True),
|
||||
sa.Column('execution_time_ms', sa.Integer(), nullable=True),
|
||||
sa.Column('memory_used_mb', sa.Float(), nullable=True),
|
||||
sa.Column('cpu_time_ms', sa.Integer(), nullable=True),
|
||||
sa.Column('container_id', sa.String(length=100), nullable=True),
|
||||
sa.Column('docker_logs', sa.Text(), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_tool_executions_id'), 'tool_executions', ['id'], unique=False)
|
||||
op.create_foreign_key(
|
||||
'fk_tool_executions_tool_id', 'tool_executions', 'tools',
|
||||
['tool_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_tool_executions_executed_by_user_id', 'tool_executions', 'users',
|
||||
['executed_by_user_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
|
||||
def downgrade():
|
||||
# Drop tool_executions table
|
||||
op.drop_constraint('fk_tool_executions_executed_by_user_id', table_name='tool_executions', type_='foreignkey')
|
||||
op.drop_constraint('fk_tool_executions_tool_id', table_name='tool_executions', type_='foreignkey')
|
||||
op.drop_index(op.f('ix_tool_executions_id'), table_name='tool_executions')
|
||||
op.drop_table('tool_executions')
|
||||
|
||||
# Drop tools table
|
||||
op.drop_constraint('fk_tools_created_by_user_id', table_name='tools', type_='foreignkey')
|
||||
op.drop_index(op.f('ix_tools_name'), table_name='tools')
|
||||
op.drop_index(op.f('ix_tools_id'), table_name='tools')
|
||||
op.drop_table('tools')
|
||||
|
||||
# Drop tool_categories table
|
||||
op.drop_index(op.f('ix_tool_categories_name'), table_name='tool_categories')
|
||||
op.drop_index(op.f('ix_tool_categories_id'), table_name='tool_categories')
|
||||
op.drop_table('tool_categories')
|
||||
132
backend/alembic/versions/003_add_notifications_tables.py
Normal file
132
backend/alembic/versions/003_add_notifications_tables.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""add notifications tables
|
||||
|
||||
Revision ID: 003_add_notifications_tables
|
||||
Revises: 002_add_tools_tables
|
||||
Create Date: 2025-01-30 00:00:02.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = '003_add_notifications_tables'
|
||||
down_revision = '002_add_tools_tables'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# Create notification_templates table
|
||||
op.create_table(
|
||||
'notification_templates',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('notification_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('subject_template', sa.Text(), nullable=True),
|
||||
sa.Column('body_template', sa.Text(), nullable=False),
|
||||
sa.Column('html_template', sa.Text(), nullable=True),
|
||||
sa.Column('default_priority', sa.String(length=20), nullable=True, default='normal'),
|
||||
sa.Column('variables', sa.JSON(), nullable=True),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_notification_templates_id'), 'notification_templates', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_notification_templates_name'), 'notification_templates', ['name'], unique=True)
|
||||
|
||||
# Create notification_channels table
|
||||
op.create_table(
|
||||
'notification_channels',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=200), nullable=False),
|
||||
sa.Column('notification_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('config', sa.JSON(), nullable=False),
|
||||
sa.Column('credentials', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
|
||||
sa.Column('is_default', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('rate_limit', sa.Integer(), nullable=True, default=100),
|
||||
sa.Column('retry_count', sa.Integer(), nullable=True, default=3),
|
||||
sa.Column('retry_delay_minutes', sa.Integer(), nullable=True, default=5),
|
||||
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('success_count', sa.Integer(), nullable=True, default=0),
|
||||
sa.Column('failure_count', sa.Integer(), nullable=True, default=0),
|
||||
sa.Column('last_error', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_notification_channels_id'), 'notification_channels', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_notification_channels_name'), 'notification_channels', ['name'], unique=False)
|
||||
|
||||
# Create notifications table
|
||||
op.create_table(
|
||||
'notifications',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('subject', sa.String(length=500), nullable=True),
|
||||
sa.Column('body', sa.Text(), nullable=False),
|
||||
sa.Column('html_body', sa.Text(), nullable=True),
|
||||
sa.Column('recipients', sa.JSON(), nullable=False),
|
||||
sa.Column('cc_recipients', sa.JSON(), nullable=True),
|
||||
sa.Column('bcc_recipients', sa.JSON(), nullable=True),
|
||||
sa.Column('priority', sa.String(length=20), nullable=True, default='normal'),
|
||||
sa.Column('scheduled_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('template_id', sa.Integer(), nullable=True),
|
||||
sa.Column('channel_id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=True, default='pending'),
|
||||
sa.Column('attempts', sa.Integer(), nullable=True, default=0),
|
||||
sa.Column('max_attempts', sa.Integer(), nullable=True, default=3),
|
||||
sa.Column('sent_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('delivered_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('failed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('external_id', sa.String(length=200), nullable=True),
|
||||
sa.Column('callback_url', sa.String(length=500), nullable=True),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_notifications_id'), 'notifications', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_notifications_status'), 'notifications', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_notifications_scheduled_at'), 'notifications', ['scheduled_at'], unique=False)
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
'fk_notifications_template_id', 'notifications', 'notification_templates',
|
||||
['template_id'], ['id'], ondelete='SET NULL'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_notifications_channel_id', 'notifications', 'notification_channels',
|
||||
['channel_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_notifications_user_id', 'notifications', 'users',
|
||||
['user_id'], ['id'], ondelete='SET NULL'
|
||||
)
|
||||
|
||||
def downgrade():
|
||||
# Drop notifications table
|
||||
op.drop_constraint('fk_notifications_user_id', table_name='notifications', type_='foreignkey')
|
||||
op.drop_constraint('fk_notifications_channel_id', table_name='notifications', type_='foreignkey')
|
||||
op.drop_constraint('fk_notifications_template_id', table_name='notifications', type_='foreignkey')
|
||||
op.drop_index(op.f('ix_notifications_scheduled_at'), table_name='notifications')
|
||||
op.drop_index(op.f('ix_notifications_status'), table_name='notifications')
|
||||
op.drop_index(op.f('ix_notifications_id'), table_name='notifications')
|
||||
op.drop_table('notifications')
|
||||
|
||||
# Drop notification_channels table
|
||||
op.drop_index(op.f('ix_notification_channels_name'), table_name='notification_channels')
|
||||
op.drop_index(op.f('ix_notification_channels_id'), table_name='notification_channels')
|
||||
op.drop_table('notification_channels')
|
||||
|
||||
# Drop notification_templates table
|
||||
op.drop_index(op.f('ix_notification_templates_name'), table_name='notification_templates')
|
||||
op.drop_index(op.f('ix_notification_templates_id'), table_name='notification_templates')
|
||||
op.drop_table('notification_templates')
|
||||
26
backend/alembic/versions/004_add_force_password_change.py
Normal file
26
backend/alembic/versions/004_add_force_password_change.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Add force_password_change to users
|
||||
|
||||
Revision ID: 004_add_force_password_change
|
||||
Revises: fd999a559a35
|
||||
Create Date: 2025-01-31 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '004_add_force_password_change'
|
||||
down_revision = 'fd999a559a35'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add force_password_change column to users table
|
||||
op.add_column('users', sa.Column('force_password_change', sa.Boolean(), default=False, nullable=False, server_default='false'))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Remove force_password_change column from users table
|
||||
op.drop_column('users', 'force_password_change')
|
||||
79
backend/alembic/versions/005_fix_user_nullable_columns.py
Normal file
79
backend/alembic/versions/005_fix_user_nullable_columns.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""fix user nullable columns
|
||||
|
||||
Revision ID: 005_fix_user_nullable_columns
|
||||
Revises: 004_add_force_password_change
|
||||
Create Date: 2025-11-20 08:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "005_fix_user_nullable_columns"
|
||||
down_revision = "004_add_force_password_change"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Fix nullable columns in users table:
|
||||
- Backfill NULL values for account_locked and failed_login_attempts
|
||||
- Set proper server defaults
|
||||
- Alter columns to NOT NULL
|
||||
"""
|
||||
# Use connection to execute raw SQL for backfilling
|
||||
conn = op.get_bind()
|
||||
|
||||
# Backfill NULL values for account_locked
|
||||
conn.execute(
|
||||
sa.text("UPDATE users SET account_locked = FALSE WHERE account_locked IS NULL")
|
||||
)
|
||||
|
||||
# Backfill NULL values for failed_login_attempts
|
||||
conn.execute(
|
||||
sa.text("UPDATE users SET failed_login_attempts = 0 WHERE failed_login_attempts IS NULL")
|
||||
)
|
||||
|
||||
# Backfill NULL values for custom_permissions (use empty JSON object)
|
||||
conn.execute(
|
||||
sa.text("UPDATE users SET custom_permissions = '{}' WHERE custom_permissions IS NULL")
|
||||
)
|
||||
|
||||
# Now alter columns to NOT NULL with server defaults
|
||||
# Note: PostgreSQL syntax
|
||||
op.alter_column('users', 'account_locked',
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false())
|
||||
|
||||
op.alter_column('users', 'failed_login_attempts',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
server_default='0')
|
||||
|
||||
op.alter_column('users', 'custom_permissions',
|
||||
existing_type=sa.JSON(),
|
||||
nullable=False,
|
||||
server_default='{}')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Revert columns to nullable (original state from fd999a559a35)
|
||||
"""
|
||||
op.alter_column('users', 'account_locked',
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=None)
|
||||
|
||||
op.alter_column('users', 'failed_login_attempts',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
server_default=None)
|
||||
|
||||
op.alter_column('users', 'custom_permissions',
|
||||
existing_type=sa.JSON(),
|
||||
nullable=True,
|
||||
server_default=None)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""fix missing user columns
|
||||
|
||||
Revision ID: fd999a559a35
|
||||
Revises: 003_add_notifications_tables
|
||||
Create Date: 2025-10-30 11:33:42.236622
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fd999a559a35"
|
||||
down_revision = "003_add_notifications_tables"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing columns to users table
|
||||
# These columns should have been added in 001_add_roles_table.py but were not
|
||||
|
||||
# Use try/except to handle cases where columns might already exist
|
||||
try:
|
||||
op.add_column("users", sa.Column("custom_permissions", sa.JSON(), nullable=True, default=dict))
|
||||
except Exception:
|
||||
pass # Column might already exist
|
||||
|
||||
try:
|
||||
op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True, default=False))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.add_column("users", sa.Column("account_locked_until", sa.DateTime(), nullable=True))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.add_column("users", sa.Column("failed_login_attempts", sa.Integer(), nullable=True, default=0))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.add_column("users", sa.Column("last_failed_login", sa.DateTime(), nullable=True))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the columns
|
||||
try:
|
||||
op.drop_column("users", "last_failed_login")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
op.drop_column("users", "failed_login_attempts")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
op.drop_column("users", "account_locked_until")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
op.drop_column("users", "account_locked")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
op.drop_column("users", "custom_permissions")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -44,7 +44,9 @@ class HealthChecker:
|
||||
await session.execute(select(1))
|
||||
|
||||
# Check table availability
|
||||
await session.execute(text("SELECT COUNT(*) FROM information_schema.tables"))
|
||||
await session.execute(
|
||||
text("SELECT COUNT(*) FROM information_schema.tables")
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
@@ -54,8 +56,8 @@ class HealthChecker:
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"details": {
|
||||
"connection": "successful",
|
||||
"query_execution": "successful"
|
||||
}
|
||||
"query_execution": "successful",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -64,10 +66,7 @@ class HealthChecker:
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"details": {
|
||||
"connection": "failed",
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
"details": {"connection": "failed", "error_type": type(e).__name__},
|
||||
}
|
||||
|
||||
async def check_memory_health(self) -> Dict[str, Any]:
|
||||
@@ -107,7 +106,7 @@ class HealthChecker:
|
||||
"process_memory_mb": round(process_memory_mb, 2),
|
||||
"system_memory_percent": memory.percent,
|
||||
"system_available_gb": round(memory.available / (1024**3), 2),
|
||||
"issues": issues
|
||||
"issues": issues,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -115,7 +114,7 @@ class HealthChecker:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def check_connection_health(self) -> Dict[str, Any]:
|
||||
@@ -128,8 +127,16 @@ class HealthChecker:
|
||||
|
||||
# Analyze connections
|
||||
total_connections = len(connections)
|
||||
established_connections = len([c for c in connections if c.status == 'ESTABLISHED'])
|
||||
http_connections = len([c for c in connections if any(port in str(c.laddr) for port in [80, 8000, 3000])])
|
||||
established_connections = len(
|
||||
[c for c in connections if c.status == "ESTABLISHED"]
|
||||
)
|
||||
http_connections = len(
|
||||
[
|
||||
c
|
||||
for c in connections
|
||||
if any(port in str(c.laddr) for port in [80, 8000, 3000])
|
||||
]
|
||||
)
|
||||
|
||||
# Check for connection issues
|
||||
connection_status = "healthy"
|
||||
@@ -154,7 +161,7 @@ class HealthChecker:
|
||||
"total_connections": total_connections,
|
||||
"established_connections": established_connections,
|
||||
"http_connections": http_connections,
|
||||
"issues": issues
|
||||
"issues": issues,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -162,7 +169,7 @@ class HealthChecker:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def check_embedding_service_health(self) -> Dict[str, Any]:
|
||||
@@ -193,7 +200,7 @@ class HealthChecker:
|
||||
"response_time_ms": round(duration * 1000, 2),
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"stats": stats,
|
||||
"issues": issues
|
||||
"issues": issues,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -201,7 +208,7 @@ class HealthChecker:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def check_redis_health(self) -> Dict[str, Any]:
|
||||
@@ -209,7 +216,7 @@ class HealthChecker:
|
||||
if not settings.REDIS_URL:
|
||||
return {
|
||||
"status": "not_configured",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -233,7 +240,7 @@ class HealthChecker:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"response_time_ms": round(duration * 1000, 2),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -241,7 +248,7 @@ class HealthChecker:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def get_comprehensive_health(self) -> Dict[str, Any]:
|
||||
@@ -251,7 +258,7 @@ class HealthChecker:
|
||||
"memory": await self.check_memory_health(),
|
||||
"connections": await self.check_connection_health(),
|
||||
"embedding_service": await self.check_embedding_service_health(),
|
||||
"redis": await self.check_redis_health()
|
||||
"redis": await self.check_redis_health(),
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
@@ -274,12 +281,16 @@ class HealthChecker:
|
||||
"summary": {
|
||||
"total_checks": len(checks),
|
||||
"healthy_checks": len([s for s in statuses if s == "healthy"]),
|
||||
"degraded_checks": len([s for s in statuses if s in ["warning", "degraded", "unhealthy"]]),
|
||||
"failed_checks": len([s for s in statuses if s in ["critical", "error"]]),
|
||||
"total_issues": total_issues
|
||||
"degraded_checks": len(
|
||||
[s for s in statuses if s in ["warning", "degraded", "unhealthy"]]
|
||||
),
|
||||
"failed_checks": len(
|
||||
[s for s in statuses if s in ["critical", "error"]]
|
||||
),
|
||||
"total_issues": total_issues,
|
||||
},
|
||||
"version": "1.0.0",
|
||||
"uptime_seconds": int(time.time() - psutil.boot_time())
|
||||
"uptime_seconds": int(time.time() - psutil.boot_time()),
|
||||
}
|
||||
|
||||
|
||||
@@ -294,7 +305,7 @@ async def basic_health_check():
|
||||
"status": "healthy",
|
||||
"app": settings.APP_NAME,
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@@ -307,7 +318,7 @@ async def detailed_health_check():
|
||||
logger.error(f"Detailed health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Health check failed: {str(e)}"
|
||||
detail=f"Health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -320,7 +331,7 @@ async def memory_health_check():
|
||||
logger.error(f"Memory health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Memory health check failed: {str(e)}"
|
||||
detail=f"Memory health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -333,7 +344,7 @@ async def connection_health_check():
|
||||
logger.error(f"Connection health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Connection health check failed: {str(e)}"
|
||||
detail=f"Connection health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -346,5 +357,5 @@ async def embedding_service_health_check():
|
||||
logger.error(f"Embedding service health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Embedding service health check failed: {str(e)}"
|
||||
detail=f"Embedding service health check failed: {str(e)}",
|
||||
)
|
||||
@@ -19,6 +19,7 @@ from ..v1.platform import router as platform_router
|
||||
from ..v1.llm_internal import router as llm_internal_router
|
||||
from ..v1.chatbot import router as chatbot_router
|
||||
from .debugging import router as debugging_router
|
||||
from ..v1.endpoints.user_management import router as user_management_router
|
||||
|
||||
# Create internal API router
|
||||
internal_api_router = APIRouter()
|
||||
@@ -27,47 +28,82 @@ internal_api_router = APIRouter()
|
||||
internal_api_router.include_router(auth_router, prefix="/auth", tags=["internal-auth"])
|
||||
|
||||
# Include modules routes (frontend management)
|
||||
internal_api_router.include_router(modules_router, prefix="/modules", tags=["internal-modules"])
|
||||
internal_api_router.include_router(
|
||||
modules_router, prefix="/modules", tags=["internal-modules"]
|
||||
)
|
||||
|
||||
# Include platform routes (frontend platform management)
|
||||
internal_api_router.include_router(platform_router, prefix="/platform", tags=["internal-platform"])
|
||||
internal_api_router.include_router(
|
||||
platform_router, prefix="/platform", tags=["internal-platform"]
|
||||
)
|
||||
|
||||
# Include user management routes (frontend user admin)
|
||||
internal_api_router.include_router(users_router, prefix="/users", tags=["internal-users"])
|
||||
internal_api_router.include_router(
|
||||
users_router, prefix="/users", tags=["internal-users"]
|
||||
)
|
||||
|
||||
# Include API key management routes (frontend API key management)
|
||||
internal_api_router.include_router(api_keys_router, prefix="/api-keys", tags=["internal-api-keys"])
|
||||
internal_api_router.include_router(
|
||||
api_keys_router, prefix="/api-keys", tags=["internal-api-keys"]
|
||||
)
|
||||
|
||||
# Include budget management routes (frontend budget management)
|
||||
internal_api_router.include_router(budgets_router, prefix="/budgets", tags=["internal-budgets"])
|
||||
internal_api_router.include_router(
|
||||
budgets_router, prefix="/budgets", tags=["internal-budgets"]
|
||||
)
|
||||
|
||||
# Include audit log routes (frontend audit viewing)
|
||||
internal_api_router.include_router(audit_router, prefix="/audit", tags=["internal-audit"])
|
||||
internal_api_router.include_router(
|
||||
audit_router, prefix="/audit", tags=["internal-audit"]
|
||||
)
|
||||
|
||||
# Include settings management routes (frontend settings)
|
||||
internal_api_router.include_router(settings_router, prefix="/settings", tags=["internal-settings"])
|
||||
internal_api_router.include_router(
|
||||
settings_router, prefix="/settings", tags=["internal-settings"]
|
||||
)
|
||||
|
||||
# Include analytics routes (frontend analytics viewing)
|
||||
internal_api_router.include_router(analytics_router, prefix="/analytics", tags=["internal-analytics"])
|
||||
internal_api_router.include_router(
|
||||
analytics_router, prefix="/analytics", tags=["internal-analytics"]
|
||||
)
|
||||
|
||||
# Include RAG routes (frontend RAG document management)
|
||||
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
|
||||
|
||||
# Include RAG debug routes (for demo and debugging)
|
||||
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
|
||||
internal_api_router.include_router(
|
||||
rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"]
|
||||
)
|
||||
|
||||
# Include prompt template routes (frontend prompt template management)
|
||||
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
|
||||
internal_api_router.include_router(
|
||||
prompt_templates_router,
|
||||
prefix="/prompt-templates",
|
||||
tags=["internal-prompt-templates"],
|
||||
)
|
||||
|
||||
|
||||
# Include plugin registry routes (frontend plugin management)
|
||||
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
|
||||
internal_api_router.include_router(
|
||||
plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]
|
||||
)
|
||||
|
||||
# Include internal LLM routes (frontend LLM service access with JWT auth)
|
||||
internal_api_router.include_router(llm_internal_router, prefix="/llm", tags=["internal-llm"])
|
||||
internal_api_router.include_router(
|
||||
llm_internal_router, prefix="/llm", tags=["internal-llm"]
|
||||
)
|
||||
|
||||
# Include chatbot routes (frontend chatbot management)
|
||||
internal_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["internal-chatbot"])
|
||||
internal_api_router.include_router(
|
||||
chatbot_router, prefix="/chatbot", tags=["internal-chatbot"]
|
||||
)
|
||||
|
||||
# Include debugging routes (troubleshooting and diagnostics)
|
||||
internal_api_router.include_router(debugging_router, prefix="/debugging", tags=["internal-debugging"])
|
||||
internal_api_router.include_router(
|
||||
debugging_router, prefix="/debugging", tags=["internal-debugging"]
|
||||
)
|
||||
|
||||
# Include user management routes (advanced user and role management)
|
||||
internal_api_router.include_router(
|
||||
user_management_router, prefix="/user-management", tags=["internal-user-management"]
|
||||
)
|
||||
|
||||
@@ -20,23 +20,28 @@ router = APIRouter()
|
||||
async def get_chatbot_config_debug(
|
||||
chatbot_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed configuration for debugging a specific chatbot"""
|
||||
|
||||
# Get chatbot instance
|
||||
chatbot = db.query(ChatbotInstance).filter(
|
||||
ChatbotInstance.id == chatbot_id,
|
||||
ChatbotInstance.user_id == current_user.id
|
||||
).first()
|
||||
chatbot = (
|
||||
db.query(ChatbotInstance)
|
||||
.filter(
|
||||
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not chatbot:
|
||||
raise HTTPException(status_code=404, detail="Chatbot not found")
|
||||
|
||||
# Get prompt template
|
||||
prompt_template = db.query(PromptTemplate).filter(
|
||||
PromptTemplate.type == chatbot.chatbot_type
|
||||
).first()
|
||||
prompt_template = (
|
||||
db.query(PromptTemplate)
|
||||
.filter(PromptTemplate.type == chatbot.chatbot_type)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Get RAG collections if configured
|
||||
rag_collections = []
|
||||
@@ -44,31 +49,37 @@ async def get_chatbot_config_debug(
|
||||
collection_ids = chatbot.rag_collection_ids
|
||||
if isinstance(collection_ids, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
collection_ids = json.loads(collection_ids)
|
||||
except:
|
||||
collection_ids = []
|
||||
|
||||
if collection_ids:
|
||||
collections = db.query(RagCollection).filter(
|
||||
RagCollection.id.in_(collection_ids)
|
||||
).all()
|
||||
collections = (
|
||||
db.query(RagCollection)
|
||||
.filter(RagCollection.id.in_(collection_ids))
|
||||
.all()
|
||||
)
|
||||
rag_collections = [
|
||||
{
|
||||
"id": col.id,
|
||||
"name": col.name,
|
||||
"document_count": col.document_count,
|
||||
"qdrant_collection_name": col.qdrant_collection_name,
|
||||
"is_active": col.is_active
|
||||
"is_active": col.is_active,
|
||||
}
|
||||
for col in collections
|
||||
]
|
||||
|
||||
# Get recent conversations count
|
||||
from app.models.chatbot import ChatbotConversation
|
||||
conversation_count = db.query(ChatbotConversation).filter(
|
||||
ChatbotConversation.chatbot_instance_id == chatbot_id
|
||||
).count()
|
||||
|
||||
conversation_count = (
|
||||
db.query(ChatbotConversation)
|
||||
.filter(ChatbotConversation.chatbot_instance_id == chatbot_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"chatbot": {
|
||||
@@ -78,20 +89,20 @@ async def get_chatbot_config_debug(
|
||||
"description": chatbot.description,
|
||||
"created_at": chatbot.created_at,
|
||||
"is_active": chatbot.is_active,
|
||||
"conversation_count": conversation_count
|
||||
"conversation_count": conversation_count,
|
||||
},
|
||||
"prompt_template": {
|
||||
"type": prompt_template.type if prompt_template else None,
|
||||
"system_prompt": prompt_template.system_prompt if prompt_template else None,
|
||||
"variables": prompt_template.variables if prompt_template else []
|
||||
"variables": prompt_template.variables if prompt_template else [],
|
||||
},
|
||||
"rag_collections": rag_collections,
|
||||
"configuration": {
|
||||
"max_tokens": chatbot.max_tokens,
|
||||
"temperature": chatbot.temperature,
|
||||
"streaming": chatbot.streaming,
|
||||
"memory_config": chatbot.memory_config
|
||||
}
|
||||
"memory_config": chatbot.memory_config,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -101,15 +112,18 @@ async def test_rag_search(
|
||||
query: str = "test query",
|
||||
top_k: int = 5,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Test RAG search for a specific chatbot"""
|
||||
|
||||
# Get chatbot instance
|
||||
chatbot = db.query(ChatbotInstance).filter(
|
||||
ChatbotInstance.id == chatbot_id,
|
||||
ChatbotInstance.user_id == current_user.id
|
||||
).first()
|
||||
chatbot = (
|
||||
db.query(ChatbotInstance)
|
||||
.filter(
|
||||
ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not chatbot:
|
||||
raise HTTPException(status_code=404, detail="Chatbot not found")
|
||||
@@ -123,6 +137,7 @@ async def test_rag_search(
|
||||
if chatbot.rag_collection_ids:
|
||||
if isinstance(chatbot.rag_collection_ids, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
collection_ids = json.loads(chatbot.rag_collection_ids)
|
||||
except:
|
||||
@@ -134,22 +149,19 @@ async def test_rag_search(
|
||||
return {
|
||||
"query": query,
|
||||
"results": [],
|
||||
"message": "No RAG collections configured for this chatbot"
|
||||
"message": "No RAG collections configured for this chatbot",
|
||||
}
|
||||
|
||||
# Perform search
|
||||
search_results = await rag_module.search(
|
||||
query=query,
|
||||
collection_ids=collection_ids,
|
||||
top_k=top_k,
|
||||
score_threshold=0.5
|
||||
query=query, collection_ids=collection_ids, top_k=top_k, score_threshold=0.5
|
||||
)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"results": search_results,
|
||||
"collections_searched": collection_ids,
|
||||
"result_count": len(search_results)
|
||||
"result_count": len(search_results),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -157,14 +169,13 @@ async def test_rag_search(
|
||||
"query": query,
|
||||
"results": [],
|
||||
"error": str(e),
|
||||
"message": "RAG search failed"
|
||||
"message": "RAG search failed",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/system/status")
|
||||
async def get_system_status(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get system status for debugging"""
|
||||
|
||||
@@ -179,11 +190,12 @@ async def get_system_status(
|
||||
module_status = {}
|
||||
try:
|
||||
from app.services.module_manager import module_manager
|
||||
|
||||
modules = module_manager.list_modules()
|
||||
for module_name, module_info in modules.items():
|
||||
module_status[module_name] = {
|
||||
"status": module_info.get("status", "unknown"),
|
||||
"enabled": module_info.get("enabled", False)
|
||||
"enabled": module_info.get("enabled", False),
|
||||
}
|
||||
except Exception as e:
|
||||
module_status = {"error": str(e)}
|
||||
@@ -192,6 +204,7 @@ async def get_system_status(
|
||||
redis_status = "not configured"
|
||||
try:
|
||||
from app.core.cache import core_cache
|
||||
|
||||
await core_cache.ping()
|
||||
redis_status = "healthy"
|
||||
except Exception as e:
|
||||
@@ -201,6 +214,7 @@ async def get_system_status(
|
||||
qdrant_status = "not configured"
|
||||
try:
|
||||
from app.services.qdrant_service import qdrant_service
|
||||
|
||||
collections = await qdrant_service.list_collections()
|
||||
qdrant_status = f"healthy ({len(collections)} collections)"
|
||||
except Exception as e:
|
||||
@@ -211,5 +225,5 @@ async def get_system_status(
|
||||
"modules": module_status,
|
||||
"redis": redis_status,
|
||||
"qdrant": qdrant_status,
|
||||
"timestamp": "UTC"
|
||||
"timestamp": "UTC",
|
||||
}
|
||||
@@ -3,6 +3,7 @@ Public API v1 package - for external clients
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from ..v1.auth import router as auth_router
|
||||
from ..v1.llm import router as llm_router
|
||||
from ..v1.chatbot import router as chatbot_router
|
||||
from ..v1.openai_compat import router as openai_router
|
||||
@@ -10,6 +11,9 @@ from ..v1.openai_compat import router as openai_router
|
||||
# Create public API router
|
||||
public_api_router = APIRouter()
|
||||
|
||||
# Include authentication routes (needed for login/logout)
|
||||
public_api_router.include_router(auth_router, prefix="/auth", tags=["authentication"])
|
||||
|
||||
# Include OpenAI-compatible routes (chat/completions, models, embeddings)
|
||||
public_api_router.include_router(openai_router, tags=["openai-compat"])
|
||||
|
||||
@@ -17,4 +21,6 @@ public_api_router.include_router(openai_router, tags=["openai-compat"])
|
||||
public_api_router.include_router(llm_router, prefix="/llm", tags=["public-llm"])
|
||||
|
||||
# Include public chatbot API (external chatbot integrations)
|
||||
public_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["public-chatbot"])
|
||||
public_api_router.include_router(
|
||||
chatbot_router, prefix="/chatbot", tags=["public-chatbot"]
|
||||
)
|
||||
|
||||
@@ -16,10 +16,9 @@ logger = logging.getLogger(__name__)
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/collections")
|
||||
async def list_collections(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
async def list_collections(current_user: User = Depends(get_current_user)):
|
||||
"""List all available RAG collections"""
|
||||
try:
|
||||
from app.services.qdrant_stats_service import qdrant_stats_service
|
||||
@@ -31,10 +30,7 @@ async def list_collections(
|
||||
# Extract collection names
|
||||
collection_names = [col["name"] for col in collections]
|
||||
|
||||
return {
|
||||
"collections": collection_names,
|
||||
"count": len(collection_names)
|
||||
}
|
||||
return {"collections": collection_names, "count": len(collection_names)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"List collections error: {e}")
|
||||
@@ -45,10 +41,14 @@ async def list_collections(
|
||||
async def debug_search(
|
||||
query: str = Query(..., description="Search query"),
|
||||
max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
|
||||
score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
|
||||
collection_name: Optional[str] = Query(None, description="Collection name to search"),
|
||||
score_threshold: float = Query(
|
||||
0.3, ge=0.0, le=1.0, description="Minimum score threshold"
|
||||
),
|
||||
collection_name: Optional[str] = Query(
|
||||
None, description="Collection name to search"
|
||||
),
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Debug search endpoint with detailed information"""
|
||||
try:
|
||||
@@ -56,9 +56,7 @@ async def debug_search(
|
||||
app_config = settings
|
||||
|
||||
# Initialize RAG module with BGE-M3 configuration
|
||||
rag_config = {
|
||||
"embedding_model": "BAAI/bge-m3"
|
||||
}
|
||||
rag_config = {"embedding_model": "BAAI/bge-m3"}
|
||||
rag_module = RAGModule(app_config, config=rag_config)
|
||||
|
||||
# Get available collections if none specified
|
||||
@@ -71,9 +69,9 @@ async def debug_search(
|
||||
"results": [],
|
||||
"debug_info": {
|
||||
"error": "No collections available",
|
||||
"collections_found": 0
|
||||
"collections_found": 0,
|
||||
},
|
||||
"search_time_ms": 0
|
||||
"search_time_ms": 0,
|
||||
}
|
||||
|
||||
# Perform search
|
||||
@@ -82,7 +80,7 @@ async def debug_search(
|
||||
max_results=max_results,
|
||||
score_threshold=score_threshold,
|
||||
collection_name=collection_name,
|
||||
config=config or {}
|
||||
config=config or {},
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -94,7 +92,7 @@ async def debug_search(
|
||||
"debug_info": {
|
||||
"error": str(e),
|
||||
"query": query,
|
||||
"collection_name": collection_name
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
"search_time_ms": 0
|
||||
"search_time_ms": 0,
|
||||
}
|
||||
@@ -17,6 +17,9 @@ from .rag import router as rag_router
|
||||
from .chatbot import router as chatbot_router
|
||||
from .prompt_templates import router as prompt_templates_router
|
||||
from .plugin_registry import router as plugin_registry_router
|
||||
from .endpoints.tools import router as tools_router
|
||||
from .endpoints.tool_calling import router as tool_calling_router
|
||||
from .endpoints.user_management import router as user_management_router
|
||||
|
||||
# Create main API router
|
||||
api_router = APIRouter()
|
||||
@@ -58,9 +61,23 @@ api_router.include_router(rag_router, prefix="/rag", tags=["rag"])
|
||||
api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
|
||||
|
||||
# Include prompt template routes
|
||||
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
|
||||
|
||||
api_router.include_router(
|
||||
prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]
|
||||
)
|
||||
|
||||
|
||||
# Include plugin registry routes
|
||||
api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
|
||||
|
||||
# Include tool management routes
|
||||
api_router.include_router(tools_router, prefix="/tools", tags=["tools"])
|
||||
|
||||
# Include tool calling routes
|
||||
api_router.include_router(
|
||||
tool_calling_router, prefix="/tool-calling", tags=["tool-calling"]
|
||||
)
|
||||
|
||||
# Include admin user management routes
|
||||
api_router.include_router(
|
||||
user_management_router, prefix="/admin/user-management", tags=["admin", "user-management"]
|
||||
)
|
||||
|
||||
@@ -22,105 +22,98 @@ router = APIRouter()
|
||||
async def get_usage_metrics(
|
||||
hours: int = Query(24, ge=1, le=168, description="Hours to analyze (1-168)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get comprehensive usage metrics including costs and budgets"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
metrics = await analytics.get_usage_metrics(hours=hours, user_id=current_user['id'])
|
||||
return {
|
||||
"success": True,
|
||||
"data": metrics,
|
||||
"period_hours": hours
|
||||
}
|
||||
metrics = await analytics.get_usage_metrics(
|
||||
hours=hours, user_id=current_user["id"]
|
||||
)
|
||||
return {"success": True, "data": metrics, "period_hours": hours}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting usage metrics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/system")
|
||||
async def get_system_metrics(
|
||||
hours: int = Query(24, ge=1, le=168),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get system-wide metrics (admin only)"""
|
||||
if not current_user['is_superuser']:
|
||||
if not current_user["is_superuser"]:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
metrics = await analytics.get_usage_metrics(hours=hours)
|
||||
return {
|
||||
"success": True,
|
||||
"data": metrics,
|
||||
"period_hours": hours
|
||||
}
|
||||
return {"success": True, "data": metrics, "period_hours": hours}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system metrics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def get_system_health(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get system health status including budget and performance analysis"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
health = await analytics.get_system_health()
|
||||
return {
|
||||
"success": True,
|
||||
"data": health
|
||||
}
|
||||
return {"success": True, "data": health}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system health: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system health: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs")
|
||||
async def get_cost_analysis(
|
||||
days: int = Query(30, ge=1, le=365, description="Days to analyze (1-365)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get detailed cost analysis and trends"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
analysis = await analytics.get_cost_analysis(days=days, user_id=current_user['id'])
|
||||
return {
|
||||
"success": True,
|
||||
"data": analysis,
|
||||
"period_days": days
|
||||
}
|
||||
analysis = await analytics.get_cost_analysis(
|
||||
days=days, user_id=current_user["id"]
|
||||
)
|
||||
return {"success": True, "data": analysis, "period_days": days}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting cost analysis: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting cost analysis: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs/system")
|
||||
async def get_system_cost_analysis(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get system-wide cost analysis (admin only)"""
|
||||
if not current_user['is_superuser']:
|
||||
if not current_user["is_superuser"]:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
analysis = await analytics.get_cost_analysis(days=days)
|
||||
return {
|
||||
"success": True,
|
||||
"data": analysis,
|
||||
"period_days": days
|
||||
}
|
||||
return {"success": True, "data": analysis, "period_days": days}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system cost analysis: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system cost analysis: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/endpoints")
|
||||
async def get_endpoint_stats(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get endpoint usage statistics"""
|
||||
try:
|
||||
@@ -133,18 +126,20 @@ async def get_endpoint_stats(
|
||||
"data": {
|
||||
"endpoint_stats": dict(analytics.endpoint_stats),
|
||||
"status_codes": dict(analytics.status_codes),
|
||||
"model_stats": dict(analytics.model_stats)
|
||||
}
|
||||
"model_stats": dict(analytics.model_stats),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting endpoint stats: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting endpoint stats: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage-trends")
|
||||
async def get_usage_trends(
|
||||
days: int = Query(7, ge=1, le=30, description="Days for trend analysis"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get usage trends over time"""
|
||||
try:
|
||||
@@ -155,50 +150,53 @@ async def get_usage_trends(
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# Daily usage trends
|
||||
daily_usage = db.query(
|
||||
func.date(UsageTracking.created_at).label('date'),
|
||||
func.count(UsageTracking.id).label('requests'),
|
||||
func.sum(UsageTracking.total_tokens).label('tokens'),
|
||||
func.sum(UsageTracking.cost_cents).label('cost_cents')
|
||||
).filter(
|
||||
daily_usage = (
|
||||
db.query(
|
||||
func.date(UsageTracking.created_at).label("date"),
|
||||
func.count(UsageTracking.id).label("requests"),
|
||||
func.sum(UsageTracking.total_tokens).label("tokens"),
|
||||
func.sum(UsageTracking.cost_cents).label("cost_cents"),
|
||||
)
|
||||
.filter(
|
||||
UsageTracking.created_at >= cutoff_time,
|
||||
UsageTracking.user_id == current_user['id']
|
||||
).group_by(
|
||||
func.date(UsageTracking.created_at)
|
||||
).order_by('date').all()
|
||||
UsageTracking.user_id == current_user["id"],
|
||||
)
|
||||
.group_by(func.date(UsageTracking.created_at))
|
||||
.order_by("date")
|
||||
.all()
|
||||
)
|
||||
|
||||
trends = []
|
||||
for date, requests, tokens, cost_cents in daily_usage:
|
||||
trends.append({
|
||||
trends.append(
|
||||
{
|
||||
"date": date.isoformat(),
|
||||
"requests": requests,
|
||||
"tokens": tokens or 0,
|
||||
"cost_cents": cost_cents or 0,
|
||||
"cost_dollars": (cost_cents or 0) / 100
|
||||
})
|
||||
"cost_dollars": (cost_cents or 0) / 100,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"trends": trends,
|
||||
"period_days": days
|
||||
}
|
||||
}
|
||||
return {"success": True, "data": {"trends": trends, "period_days": days}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage trends: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting usage trends: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/overview")
|
||||
async def get_analytics_overview(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get analytics overview data"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
|
||||
# Get basic metrics
|
||||
metrics = await analytics.get_usage_metrics(hours=24, user_id=current_user['id'])
|
||||
metrics = await analytics.get_usage_metrics(
|
||||
hours=24, user_id=current_user["id"]
|
||||
)
|
||||
health = await analytics.get_system_health()
|
||||
|
||||
return {
|
||||
@@ -210,8 +208,8 @@ async def get_analytics_overview(
|
||||
"error_rate": metrics.error_rate,
|
||||
"budget_usage_percentage": metrics.budget_usage_percentage,
|
||||
"system_health": health.status,
|
||||
"health_score": health.score
|
||||
}
|
||||
"health_score": health.score,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting overview: {str(e)}")
|
||||
@@ -219,8 +217,7 @@ async def get_analytics_overview(
|
||||
|
||||
@router.get("/modules")
|
||||
async def get_module_analytics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get analytics data for all modules"""
|
||||
try:
|
||||
@@ -248,10 +245,14 @@ async def get_module_analytics(
|
||||
"data": {
|
||||
"modules": module_stats,
|
||||
"total_modules": len(module_stats),
|
||||
"system_health": "healthy" if all(m.get("initialized", False) for m in module_stats) else "warning"
|
||||
}
|
||||
"system_health": "healthy"
|
||||
if all(m.get("initialized", False) for m in module_stats)
|
||||
else "warning",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get module analytics: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve module analytics")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve module analytics"
|
||||
)
|
||||
|
||||
@@ -74,14 +74,16 @@ class APIKeyResponse(BaseModel):
|
||||
last_used_at: Optional[datetime] = None
|
||||
total_requests: int
|
||||
total_tokens: int
|
||||
total_cost_cents: int = Field(alias='total_cost')
|
||||
total_cost_cents: int = Field(alias="total_cost")
|
||||
rate_limit_per_minute: Optional[int] = None
|
||||
rate_limit_per_hour: Optional[int] = None
|
||||
rate_limit_per_day: Optional[int] = None
|
||||
allowed_ips: List[str]
|
||||
allowed_models: List[str] # Model restrictions
|
||||
allowed_chatbots: List[str] # Chatbot restrictions
|
||||
budget_limit: Optional[int] = Field(None, alias='budget_limit_cents') # Budget limit in cents
|
||||
budget_limit: Optional[int] = Field(
|
||||
None, alias="budget_limit_cents"
|
||||
) # Budget limit in cents
|
||||
budget_type: Optional[str] = None # Budget type
|
||||
is_unlimited: bool = True # Unlimited budget flag
|
||||
tags: List[str]
|
||||
@@ -93,28 +95,28 @@ class APIKeyResponse(BaseModel):
|
||||
def from_api_key(cls, api_key):
|
||||
"""Create response from APIKey model with formatted key prefix"""
|
||||
data = {
|
||||
'id': api_key.id,
|
||||
'name': api_key.name,
|
||||
'description': api_key.description,
|
||||
'key_prefix': api_key.key_prefix + "..." if api_key.key_prefix else "",
|
||||
'scopes': api_key.scopes,
|
||||
'is_active': api_key.is_active,
|
||||
'expires_at': api_key.expires_at,
|
||||
'created_at': api_key.created_at,
|
||||
'last_used_at': api_key.last_used_at,
|
||||
'total_requests': api_key.total_requests,
|
||||
'total_tokens': api_key.total_tokens,
|
||||
'total_cost': api_key.total_cost,
|
||||
'rate_limit_per_minute': api_key.rate_limit_per_minute,
|
||||
'rate_limit_per_hour': api_key.rate_limit_per_hour,
|
||||
'rate_limit_per_day': api_key.rate_limit_per_day,
|
||||
'allowed_ips': api_key.allowed_ips,
|
||||
'allowed_models': api_key.allowed_models,
|
||||
'allowed_chatbots': api_key.allowed_chatbots,
|
||||
'budget_limit_cents': api_key.budget_limit_cents,
|
||||
'budget_type': api_key.budget_type,
|
||||
'is_unlimited': api_key.is_unlimited,
|
||||
'tags': api_key.tags
|
||||
"id": api_key.id,
|
||||
"name": api_key.name,
|
||||
"description": api_key.description,
|
||||
"key_prefix": api_key.key_prefix + "..." if api_key.key_prefix else "",
|
||||
"scopes": api_key.scopes,
|
||||
"is_active": api_key.is_active,
|
||||
"expires_at": api_key.expires_at,
|
||||
"created_at": api_key.created_at,
|
||||
"last_used_at": api_key.last_used_at,
|
||||
"total_requests": api_key.total_requests,
|
||||
"total_tokens": api_key.total_tokens,
|
||||
"total_cost": api_key.total_cost,
|
||||
"rate_limit_per_minute": api_key.rate_limit_per_minute,
|
||||
"rate_limit_per_hour": api_key.rate_limit_per_hour,
|
||||
"rate_limit_per_day": api_key.rate_limit_per_day,
|
||||
"allowed_ips": api_key.allowed_ips,
|
||||
"allowed_models": api_key.allowed_models,
|
||||
"allowed_chatbots": api_key.allowed_chatbots,
|
||||
"budget_limit_cents": api_key.budget_limit_cents,
|
||||
"budget_type": api_key.budget_type,
|
||||
"is_unlimited": api_key.is_unlimited,
|
||||
"tags": api_key.tags,
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@@ -148,13 +150,16 @@ class APIKeyUsageResponse(BaseModel):
|
||||
def generate_api_key() -> tuple[str, str]:
|
||||
"""Generate a new API key and return (full_key, key_hash)"""
|
||||
# Generate random key part (32 characters)
|
||||
key_part = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))
|
||||
key_part = "".join(
|
||||
secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
|
||||
)
|
||||
|
||||
# Create full key with prefix
|
||||
full_key = f"{settings.API_KEY_PREFIX}{key_part}"
|
||||
|
||||
# Create hash for storage
|
||||
from app.core.security import get_api_key_hash
|
||||
|
||||
key_hash = get_api_key_hash(full_key)
|
||||
|
||||
return full_key, key_hash
|
||||
@@ -169,32 +174,40 @@ async def list_api_keys(
|
||||
is_active: Optional[bool] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List API keys with pagination and filtering"""
|
||||
|
||||
# Check permissions - users can view their own API keys
|
||||
if user_id and int(user_id) != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
|
||||
if user_id and int(user_id) != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:read"
|
||||
)
|
||||
elif not user_id:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:read"
|
||||
)
|
||||
|
||||
# If no user_id specified and user doesn't have admin permissions, show only their keys
|
||||
if not user_id and "platform:api-keys:read" not in current_user.get("permissions", []):
|
||||
user_id = current_user['id']
|
||||
if not user_id and "platform:api-keys:read" not in current_user.get(
|
||||
"permissions", []
|
||||
):
|
||||
user_id = current_user["id"]
|
||||
|
||||
# Build query
|
||||
query = select(APIKey)
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
|
||||
query = query.where(
|
||||
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
|
||||
)
|
||||
if is_active is not None:
|
||||
query = query.where(APIKey.is_active == is_active)
|
||||
if search:
|
||||
query = query.where(
|
||||
(APIKey.name.ilike(f"%{search}%")) |
|
||||
(APIKey.description.ilike(f"%{search}%"))
|
||||
(APIKey.name.ilike(f"%{search}%"))
|
||||
| (APIKey.description.ilike(f"%{search}%"))
|
||||
)
|
||||
|
||||
# Get total count using func.count()
|
||||
@@ -202,13 +215,15 @@ async def list_api_keys(
|
||||
|
||||
# Apply same filters for count
|
||||
if user_id:
|
||||
total_query = total_query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
|
||||
total_query = total_query.where(
|
||||
APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
|
||||
)
|
||||
if is_active is not None:
|
||||
total_query = total_query.where(APIKey.is_active == is_active)
|
||||
if search:
|
||||
total_query = total_query.where(
|
||||
(APIKey.name.ilike(f"%{search}%")) |
|
||||
(APIKey.description.ilike(f"%{search}%"))
|
||||
(APIKey.name.ilike(f"%{search}%"))
|
||||
| (APIKey.description.ilike(f"%{search}%"))
|
||||
)
|
||||
|
||||
total_result = await db.execute(total_query)
|
||||
@@ -225,17 +240,21 @@ async def list_api_keys(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="list_api_keys",
|
||||
resource_type="api_key",
|
||||
details={"page": page, "size": size, "filters": {"user_id": user_id, "is_active": is_active, "search": search}}
|
||||
details={
|
||||
"page": page,
|
||||
"size": size,
|
||||
"filters": {"user_id": user_id, "is_active": is_active, "search": search},
|
||||
},
|
||||
)
|
||||
|
||||
return APIKeyListResponse(
|
||||
api_keys=[APIKeyResponse.model_validate(key) for key in api_keys],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
@@ -243,7 +262,7 @@ async def list_api_keys(
|
||||
async def get_api_key(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get API key by ID"""
|
||||
|
||||
@@ -254,21 +273,22 @@ async def get_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can view their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:read"
|
||||
)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id
|
||||
resource_id=api_key_id,
|
||||
)
|
||||
|
||||
return APIKeyResponse.model_validate(api_key)
|
||||
@@ -278,7 +298,7 @@ async def get_api_key(
|
||||
async def create_api_key(
|
||||
api_key_data: APIKeyCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new API key"""
|
||||
|
||||
@@ -295,7 +315,7 @@ async def create_api_key(
|
||||
description=api_key_data.description,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
scopes=api_key_data.scopes,
|
||||
expires_at=api_key_data.expires_at,
|
||||
rate_limit_per_minute=api_key_data.rate_limit_per_minute,
|
||||
@@ -305,9 +325,11 @@ async def create_api_key(
|
||||
allowed_models=api_key_data.allowed_models,
|
||||
allowed_chatbots=api_key_data.allowed_chatbots,
|
||||
is_unlimited=api_key_data.is_unlimited,
|
||||
budget_limit_cents=api_key_data.budget_limit_cents if not api_key_data.is_unlimited else None,
|
||||
budget_limit_cents=api_key_data.budget_limit_cents
|
||||
if not api_key_data.is_unlimited
|
||||
else None,
|
||||
budget_type=api_key_data.budget_type if not api_key_data.is_unlimited else None,
|
||||
tags=api_key_data.tags
|
||||
tags=api_key_data.tags,
|
||||
)
|
||||
|
||||
db.add(new_api_key)
|
||||
@@ -315,19 +337,20 @@ async def create_api_key(
|
||||
await db.refresh(new_api_key)
|
||||
|
||||
# Log audit event asynchronously (non-blocking)
|
||||
asyncio.create_task(log_audit_event_async(
|
||||
user_id=str(current_user['id']),
|
||||
asyncio.create_task(
|
||||
log_audit_event_async(
|
||||
user_id=str(current_user["id"]),
|
||||
action="create_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=str(new_api_key.id),
|
||||
details={"name": api_key_data.name, "scopes": api_key_data.scopes}
|
||||
))
|
||||
details={"name": api_key_data.name, "scopes": api_key_data.scopes},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"API key created: {new_api_key.name} by {current_user['username']}")
|
||||
|
||||
return APIKeyCreateResponse(
|
||||
api_key=APIKeyResponse.model_validate(new_api_key),
|
||||
secret_key=full_key
|
||||
api_key=APIKeyResponse.model_validate(new_api_key), secret_key=full_key
|
||||
)
|
||||
|
||||
|
||||
@@ -336,7 +359,7 @@ async def update_api_key(
|
||||
api_key_id: str,
|
||||
api_key_data: APIKeyUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update API key"""
|
||||
|
||||
@@ -347,19 +370,20 @@ async def update_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can update their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:update"
|
||||
)
|
||||
|
||||
# Store original values for audit
|
||||
original_values = {
|
||||
"name": api_key.name,
|
||||
"scopes": api_key.scopes,
|
||||
"is_active": api_key.is_active
|
||||
"is_active": api_key.is_active,
|
||||
}
|
||||
|
||||
# Update API key fields
|
||||
@@ -373,15 +397,15 @@ async def update_api_key(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="update_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id,
|
||||
details={
|
||||
"updated_fields": list(update_data.keys()),
|
||||
"before_values": original_values,
|
||||
"after_values": {k: getattr(api_key, k) for k in update_data.keys()}
|
||||
}
|
||||
"after_values": {k: getattr(api_key, k) for k in update_data.keys()},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"API key updated: {api_key.name} by {current_user['username']}")
|
||||
@@ -393,7 +417,7 @@ async def update_api_key(
|
||||
async def delete_api_key(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete API key"""
|
||||
|
||||
@@ -404,13 +428,14 @@ async def delete_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can delete their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:delete")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:delete"
|
||||
)
|
||||
|
||||
# Delete API key
|
||||
await db.delete(api_key)
|
||||
@@ -419,11 +444,11 @@ async def delete_api_key(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="delete_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id,
|
||||
details={"name": api_key.name}
|
||||
details={"name": api_key.name},
|
||||
)
|
||||
|
||||
logger.info(f"API key deleted: {api_key.name} by {current_user['username']}")
|
||||
@@ -435,7 +460,7 @@ async def delete_api_key(
|
||||
async def regenerate_api_key(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Regenerate API key secret"""
|
||||
|
||||
@@ -446,13 +471,14 @@ async def regenerate_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can regenerate their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:update"
|
||||
)
|
||||
|
||||
# Generate new API key
|
||||
full_key, key_hash = generate_api_key()
|
||||
@@ -468,18 +494,17 @@ async def regenerate_api_key(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="regenerate_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id,
|
||||
details={"name": api_key.name}
|
||||
details={"name": api_key.name},
|
||||
)
|
||||
|
||||
logger.info(f"API key regenerated: {api_key.name} by {current_user['username']}")
|
||||
|
||||
return APIKeyCreateResponse(
|
||||
api_key=APIKeyResponse.model_validate(api_key),
|
||||
secret_key=full_key
|
||||
api_key=APIKeyResponse.model_validate(api_key), secret_key=full_key
|
||||
)
|
||||
|
||||
|
||||
@@ -487,7 +512,7 @@ async def regenerate_api_key(
|
||||
async def get_api_key_usage(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get API key usage statistics"""
|
||||
|
||||
@@ -498,13 +523,14 @@ async def get_api_key_usage(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can view their own API key usage
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:read"
|
||||
)
|
||||
|
||||
# Calculate usage statistics
|
||||
from app.models.usage_tracking import UsageTracking
|
||||
@@ -517,10 +543,9 @@ async def get_api_key_usage(
|
||||
today_query = select(
|
||||
func.count(UsageTracking.id),
|
||||
func.sum(UsageTracking.total_tokens),
|
||||
func.sum(UsageTracking.cost_cents)
|
||||
func.sum(UsageTracking.cost_cents),
|
||||
).where(
|
||||
UsageTracking.api_key_id == api_key_id,
|
||||
UsageTracking.created_at >= today_start
|
||||
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= today_start
|
||||
)
|
||||
today_result = await db.execute(today_query)
|
||||
today_stats = today_result.first()
|
||||
@@ -529,10 +554,9 @@ async def get_api_key_usage(
|
||||
hour_query = select(
|
||||
func.count(UsageTracking.id),
|
||||
func.sum(UsageTracking.total_tokens),
|
||||
func.sum(UsageTracking.cost_cents)
|
||||
func.sum(UsageTracking.cost_cents),
|
||||
).where(
|
||||
UsageTracking.api_key_id == api_key_id,
|
||||
UsageTracking.created_at >= hour_start
|
||||
UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= hour_start
|
||||
)
|
||||
hour_result = await db.execute(hour_query)
|
||||
hour_stats = hour_result.first()
|
||||
@@ -540,10 +564,10 @@ async def get_api_key_usage(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_api_key_usage",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id
|
||||
resource_id=api_key_id,
|
||||
)
|
||||
|
||||
return APIKeyUsageResponse(
|
||||
@@ -557,7 +581,7 @@ async def get_api_key_usage(
|
||||
requests_this_hour=hour_stats[0] or 0,
|
||||
tokens_this_hour=hour_stats[1] or 0,
|
||||
cost_this_hour_cents=hour_stats[2] or 0,
|
||||
last_used_at=api_key.last_used_at
|
||||
last_used_at=api_key.last_used_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -565,7 +589,7 @@ async def get_api_key_usage(
|
||||
async def activate_api_key(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Activate API key"""
|
||||
|
||||
@@ -576,13 +600,14 @@ async def activate_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can activate their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:update"
|
||||
)
|
||||
|
||||
# Activate API key
|
||||
api_key.is_active = True
|
||||
@@ -591,11 +616,11 @@ async def activate_api_key(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="activate_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id,
|
||||
details={"name": api_key.name}
|
||||
details={"name": api_key.name},
|
||||
)
|
||||
|
||||
logger.info(f"API key activated: {api_key.name} by {current_user['username']}")
|
||||
@@ -607,7 +632,7 @@ async def activate_api_key(
|
||||
async def deactivate_api_key(
|
||||
api_key_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Deactivate API key"""
|
||||
|
||||
@@ -618,13 +643,14 @@ async def deactivate_api_key(
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="API key not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can deactivate their own API keys
|
||||
if api_key.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:update")
|
||||
if api_key.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:update"
|
||||
)
|
||||
|
||||
# Deactivate API key
|
||||
api_key.is_active = False
|
||||
@@ -633,11 +659,11 @@ async def deactivate_api_key(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="deactivate_api_key",
|
||||
resource_type="api_key",
|
||||
resource_id=api_key_id,
|
||||
details={"name": api_key.name}
|
||||
details={"name": api_key.name},
|
||||
)
|
||||
|
||||
logger.info(f"API key deactivated: {api_key.name} by {current_user['username']}")
|
||||
|
||||
@@ -96,7 +96,7 @@ async def list_audit_logs(
|
||||
severity: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List audit logs with filtering and pagination"""
|
||||
|
||||
@@ -128,7 +128,7 @@ async def list_audit_logs(
|
||||
search_conditions = [
|
||||
AuditLog.action.ilike(f"%{search}%"),
|
||||
AuditLog.resource_type.ilike(f"%{search}%"),
|
||||
AuditLog.details.astext.ilike(f"%{search}%")
|
||||
AuditLog.details.astext.ilike(f"%{search}%"),
|
||||
]
|
||||
conditions.append(or_(*search_conditions))
|
||||
|
||||
@@ -166,19 +166,19 @@ async def list_audit_logs(
|
||||
"end_date": end_date.isoformat() if end_date else None,
|
||||
"success": success,
|
||||
"severity": severity,
|
||||
"search": search
|
||||
"search": search,
|
||||
},
|
||||
"page": page,
|
||||
"size": size,
|
||||
"total_results": total
|
||||
}
|
||||
"total_results": total,
|
||||
},
|
||||
)
|
||||
|
||||
return AuditLogListResponse(
|
||||
logs=[AuditLogResponse.model_validate(log) for log in logs],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ async def search_audit_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=1000),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Advanced search for audit logs"""
|
||||
|
||||
@@ -205,7 +205,7 @@ async def search_audit_logs(
|
||||
start_date=search_request.start_date,
|
||||
end_date=search_request.end_date,
|
||||
limit=size,
|
||||
offset=(page - 1) * size
|
||||
offset=(page - 1) * size,
|
||||
)
|
||||
|
||||
# Get total count for the search
|
||||
@@ -234,7 +234,7 @@ async def search_audit_logs(
|
||||
search_conditions = [
|
||||
AuditLog.action.ilike(f"%{search_request.search_text}%"),
|
||||
AuditLog.resource_type.ilike(f"%{search_request.search_text}%"),
|
||||
AuditLog.details.astext.ilike(f"%{search_request.search_text}%")
|
||||
AuditLog.details.astext.ilike(f"%{search_request.search_text}%"),
|
||||
]
|
||||
conditions.append(or_(*search_conditions))
|
||||
|
||||
@@ -253,15 +253,15 @@ async def search_audit_logs(
|
||||
details={
|
||||
"search_criteria": search_request.model_dump(exclude_unset=True),
|
||||
"results_count": len(logs),
|
||||
"total_matches": total
|
||||
}
|
||||
"total_matches": total,
|
||||
},
|
||||
)
|
||||
|
||||
return AuditLogListResponse(
|
||||
logs=[AuditLogResponse.model_validate(log) for log in logs],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ async def get_audit_statistics(
|
||||
start_date: Optional[datetime] = Query(None),
|
||||
end_date: Optional[datetime] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get audit log statistics"""
|
||||
|
||||
@@ -287,46 +287,62 @@ async def get_audit_statistics(
|
||||
basic_stats = await get_audit_stats(db, start_date, end_date)
|
||||
|
||||
# Get additional statistics
|
||||
conditions = [
|
||||
AuditLog.created_at >= start_date,
|
||||
AuditLog.created_at <= end_date
|
||||
]
|
||||
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
|
||||
|
||||
# Events by user
|
||||
user_query = select(
|
||||
AuditLog.user_id,
|
||||
func.count(AuditLog.id).label('count')
|
||||
).where(and_(*conditions)).group_by(AuditLog.user_id).order_by(func.count(AuditLog.id).desc()).limit(10)
|
||||
user_query = (
|
||||
select(AuditLog.user_id, func.count(AuditLog.id).label("count"))
|
||||
.where(and_(*conditions))
|
||||
.group_by(AuditLog.user_id)
|
||||
.order_by(func.count(AuditLog.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
user_result = await db.execute(user_query)
|
||||
events_by_user = dict(user_result.fetchall())
|
||||
|
||||
# Events by hour of day
|
||||
hour_query = select(
|
||||
func.extract('hour', AuditLog.created_at).label('hour'),
|
||||
func.count(AuditLog.id).label('count')
|
||||
).where(and_(*conditions)).group_by(func.extract('hour', AuditLog.created_at)).order_by('hour')
|
||||
hour_query = (
|
||||
select(
|
||||
func.extract("hour", AuditLog.created_at).label("hour"),
|
||||
func.count(AuditLog.id).label("count"),
|
||||
)
|
||||
.where(and_(*conditions))
|
||||
.group_by(func.extract("hour", AuditLog.created_at))
|
||||
.order_by("hour")
|
||||
)
|
||||
|
||||
hour_result = await db.execute(hour_query)
|
||||
events_by_hour = dict(hour_result.fetchall())
|
||||
|
||||
# Top actions
|
||||
top_actions_query = select(
|
||||
AuditLog.action,
|
||||
func.count(AuditLog.id).label('count')
|
||||
).where(and_(*conditions)).group_by(AuditLog.action).order_by(func.count(AuditLog.id).desc()).limit(10)
|
||||
top_actions_query = (
|
||||
select(AuditLog.action, func.count(AuditLog.id).label("count"))
|
||||
.where(and_(*conditions))
|
||||
.group_by(AuditLog.action)
|
||||
.order_by(func.count(AuditLog.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
top_actions_result = await db.execute(top_actions_query)
|
||||
top_actions = [{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()]
|
||||
top_actions = [
|
||||
{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()
|
||||
]
|
||||
|
||||
# Top resources
|
||||
top_resources_query = select(
|
||||
AuditLog.resource_type,
|
||||
func.count(AuditLog.id).label('count')
|
||||
).where(and_(*conditions)).group_by(AuditLog.resource_type).order_by(func.count(AuditLog.id).desc()).limit(10)
|
||||
top_resources_query = (
|
||||
select(AuditLog.resource_type, func.count(AuditLog.id).label("count"))
|
||||
.where(and_(*conditions))
|
||||
.group_by(AuditLog.resource_type)
|
||||
.order_by(func.count(AuditLog.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
top_resources_result = await db.execute(top_resources_query)
|
||||
top_resources = [{"resource_type": row[0], "count": row[1]} for row in top_resources_result.fetchall()]
|
||||
top_resources = [
|
||||
{"resource_type": row[0], "count": row[1]}
|
||||
for row in top_resources_result.fetchall()
|
||||
]
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
@@ -337,8 +353,8 @@ async def get_audit_statistics(
|
||||
details={
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"total_events": basic_stats["total_events"]
|
||||
}
|
||||
"total_events": basic_stats["total_events"],
|
||||
},
|
||||
)
|
||||
|
||||
return AuditStatsResponse(
|
||||
@@ -346,7 +362,7 @@ async def get_audit_statistics(
|
||||
events_by_user=events_by_user,
|
||||
events_by_hour=events_by_hour,
|
||||
top_actions=top_actions,
|
||||
top_resources=top_resources
|
||||
top_resources=top_resources,
|
||||
)
|
||||
|
||||
|
||||
@@ -354,7 +370,7 @@ async def get_audit_statistics(
|
||||
async def get_security_events(
|
||||
hours: int = Query(24, ge=1, le=168), # Last 24 hours by default, max 1 week
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get security-related events and anomalies"""
|
||||
|
||||
@@ -365,13 +381,18 @@ async def get_security_events(
|
||||
start_time = end_time - timedelta(hours=hours)
|
||||
|
||||
# Failed logins
|
||||
failed_logins_query = select(AuditLog).where(
|
||||
failed_logins_query = (
|
||||
select(AuditLog)
|
||||
.where(
|
||||
and_(
|
||||
AuditLog.created_at >= start_time,
|
||||
AuditLog.action == "login",
|
||||
AuditLog.success == False
|
||||
AuditLog.success == False,
|
||||
)
|
||||
)
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
).order_by(AuditLog.created_at.desc()).limit(50)
|
||||
|
||||
failed_logins_result = await db.execute(failed_logins_query)
|
||||
failed_logins = [
|
||||
@@ -380,18 +401,23 @@ async def get_security_events(
|
||||
"user_id": log.user_id,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"details": log.details
|
||||
"details": log.details,
|
||||
}
|
||||
for log in failed_logins_result.scalars().all()
|
||||
]
|
||||
|
||||
# High severity events
|
||||
high_severity_query = select(AuditLog).where(
|
||||
high_severity_query = (
|
||||
select(AuditLog)
|
||||
.where(
|
||||
and_(
|
||||
AuditLog.created_at >= start_time,
|
||||
AuditLog.severity.in_(["error", "critical"])
|
||||
AuditLog.severity.in_(["error", "critical"]),
|
||||
)
|
||||
)
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
).order_by(AuditLog.created_at.desc()).limit(50)
|
||||
|
||||
high_severity_result = await db.execute(high_severity_query)
|
||||
high_severity_events = [
|
||||
@@ -403,52 +429,61 @@ async def get_security_events(
|
||||
"user_id": log.user_id,
|
||||
"ip_address": log.ip_address,
|
||||
"success": log.success,
|
||||
"details": log.details
|
||||
"details": log.details,
|
||||
}
|
||||
for log in high_severity_result.scalars().all()
|
||||
]
|
||||
|
||||
# Suspicious activities (multiple failed attempts from same IP)
|
||||
suspicious_ips_query = select(
|
||||
AuditLog.ip_address,
|
||||
func.count(AuditLog.id).label('failed_count')
|
||||
).where(
|
||||
suspicious_ips_query = (
|
||||
select(AuditLog.ip_address, func.count(AuditLog.id).label("failed_count"))
|
||||
.where(
|
||||
and_(
|
||||
AuditLog.created_at >= start_time,
|
||||
AuditLog.success == False,
|
||||
AuditLog.ip_address.isnot(None)
|
||||
AuditLog.ip_address.isnot(None),
|
||||
)
|
||||
)
|
||||
.group_by(AuditLog.ip_address)
|
||||
.having(func.count(AuditLog.id) >= 5)
|
||||
.order_by(func.count(AuditLog.id).desc())
|
||||
)
|
||||
).group_by(AuditLog.ip_address).having(func.count(AuditLog.id) >= 5).order_by(func.count(AuditLog.id).desc())
|
||||
|
||||
suspicious_ips_result = await db.execute(suspicious_ips_query)
|
||||
suspicious_activities = [
|
||||
{
|
||||
"ip_address": row[0],
|
||||
"failed_attempts": row[1],
|
||||
"risk_level": "high" if row[1] >= 10 else "medium"
|
||||
"risk_level": "high" if row[1] >= 10 else "medium",
|
||||
}
|
||||
for row in suspicious_ips_result.fetchall()
|
||||
]
|
||||
|
||||
# Unusual access patterns (users accessing from multiple IPs)
|
||||
unusual_access_query = select(
|
||||
unusual_access_query = (
|
||||
select(
|
||||
AuditLog.user_id,
|
||||
func.count(func.distinct(AuditLog.ip_address)).label('ip_count'),
|
||||
func.array_agg(func.distinct(AuditLog.ip_address)).label('ip_addresses')
|
||||
).where(
|
||||
func.count(func.distinct(AuditLog.ip_address)).label("ip_count"),
|
||||
func.array_agg(func.distinct(AuditLog.ip_address)).label("ip_addresses"),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
AuditLog.created_at >= start_time,
|
||||
AuditLog.user_id.isnot(None),
|
||||
AuditLog.ip_address.isnot(None)
|
||||
AuditLog.ip_address.isnot(None),
|
||||
)
|
||||
)
|
||||
.group_by(AuditLog.user_id)
|
||||
.having(func.count(func.distinct(AuditLog.ip_address)) >= 3)
|
||||
.order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
|
||||
)
|
||||
).group_by(AuditLog.user_id).having(func.count(func.distinct(AuditLog.ip_address)) >= 3).order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
|
||||
|
||||
unusual_access_result = await db.execute(unusual_access_query)
|
||||
unusual_access_patterns = [
|
||||
{
|
||||
"user_id": row[0],
|
||||
"unique_ips": row[1],
|
||||
"ip_addresses": row[2] if row[2] else []
|
||||
"ip_addresses": row[2] if row[2] else [],
|
||||
}
|
||||
for row in unusual_access_result.fetchall()
|
||||
]
|
||||
@@ -464,15 +499,15 @@ async def get_security_events(
|
||||
"failed_logins_count": len(failed_logins),
|
||||
"high_severity_count": len(high_severity_events),
|
||||
"suspicious_ips_count": len(suspicious_activities),
|
||||
"unusual_access_patterns_count": len(unusual_access_patterns)
|
||||
}
|
||||
"unusual_access_patterns_count": len(unusual_access_patterns),
|
||||
},
|
||||
)
|
||||
|
||||
return SecurityEventsResponse(
|
||||
suspicious_activities=suspicious_activities,
|
||||
failed_logins=failed_logins,
|
||||
unusual_access_patterns=unusual_access_patterns,
|
||||
high_severity_events=high_severity_events
|
||||
high_severity_events=high_severity_events,
|
||||
)
|
||||
|
||||
|
||||
@@ -485,7 +520,7 @@ async def export_audit_logs(
|
||||
action: Optional[str] = Query(None),
|
||||
resource_type: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Export audit logs in CSV or JSON format"""
|
||||
|
||||
@@ -503,10 +538,7 @@ async def export_audit_logs(
|
||||
|
||||
# Build query
|
||||
query = select(AuditLog)
|
||||
conditions = [
|
||||
AuditLog.created_at >= start_date,
|
||||
AuditLog.created_at <= end_date
|
||||
]
|
||||
conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]
|
||||
|
||||
if user_id:
|
||||
conditions.append(AuditLog.user_id == user_id)
|
||||
@@ -515,7 +547,11 @@ async def export_audit_logs(
|
||||
if resource_type:
|
||||
conditions.append(AuditLog.resource_type == resource_type)
|
||||
|
||||
query = query.where(and_(*conditions)).order_by(AuditLog.created_at.desc()).limit(max_records)
|
||||
query = (
|
||||
query.where(and_(*conditions))
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(max_records)
|
||||
)
|
||||
|
||||
# Execute query
|
||||
result = await db.execute(query)
|
||||
@@ -535,13 +571,14 @@ async def export_audit_logs(
|
||||
"filters": {
|
||||
"user_id": user_id,
|
||||
"action": action,
|
||||
"resource_type": resource_type
|
||||
}
|
||||
}
|
||||
"resource_type": resource_type,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if format == "json":
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
export_data = [
|
||||
{
|
||||
"id": str(log.id),
|
||||
@@ -554,7 +591,7 @@ async def export_audit_logs(
|
||||
"user_agent": log.user_agent,
|
||||
"success": log.success,
|
||||
"severity": log.severity,
|
||||
"created_at": log.created_at.isoformat()
|
||||
"created_at": log.created_at.isoformat(),
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
@@ -569,14 +606,25 @@ async def export_audit_logs(
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write header
|
||||
writer.writerow([
|
||||
"ID", "User ID", "Action", "Resource Type", "Resource ID",
|
||||
"IP Address", "Success", "Severity", "Created At", "Details"
|
||||
])
|
||||
writer.writerow(
|
||||
[
|
||||
"ID",
|
||||
"User ID",
|
||||
"Action",
|
||||
"Resource Type",
|
||||
"Resource ID",
|
||||
"IP Address",
|
||||
"Success",
|
||||
"Severity",
|
||||
"Created At",
|
||||
"Details",
|
||||
]
|
||||
)
|
||||
|
||||
# Write data
|
||||
for log in logs:
|
||||
writer.writerow([
|
||||
writer.writerow(
|
||||
[
|
||||
str(log.id),
|
||||
log.user_id or "",
|
||||
log.action,
|
||||
@@ -586,13 +634,16 @@ async def export_audit_logs(
|
||||
log.success,
|
||||
log.severity,
|
||||
log.created_at.isoformat(),
|
||||
str(log.details)
|
||||
])
|
||||
str(log.details),
|
||||
]
|
||||
)
|
||||
|
||||
output.seek(0)
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(output.getvalue().encode()),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"}
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"
|
||||
},
|
||||
)
|
||||
@@ -9,6 +9,7 @@ from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, EmailStr, validator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
@@ -39,24 +40,24 @@ class UserRegisterRequest(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
@validator('password')
|
||||
@validator("password")
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters long')
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
raise ValueError("Password must contain at least one uppercase letter")
|
||||
if not any(c.islower() for c in v):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
raise ValueError("Password must contain at least one lowercase letter")
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
raise ValueError("Password must contain at least one digit")
|
||||
return v
|
||||
|
||||
@validator('username')
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if len(v) < 3:
|
||||
raise ValueError('Username must be at least 3 characters long')
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
if not v.isalnum():
|
||||
raise ValueError('Username must contain only alphanumeric characters')
|
||||
raise ValueError("Username must contain only alphanumeric characters")
|
||||
return v
|
||||
|
||||
|
||||
@@ -65,10 +66,10 @@ class UserLoginRequest(BaseModel):
|
||||
username: Optional[str] = None
|
||||
password: str
|
||||
|
||||
@validator('email')
|
||||
@validator("email")
|
||||
def validate_email_or_username(cls, v, values):
|
||||
if v is None and not values.get('username'):
|
||||
raise ValueError('Either email or username must be provided')
|
||||
if v is None and not values.get("username"):
|
||||
raise ValueError("Either email or username must be provided")
|
||||
return v
|
||||
|
||||
|
||||
@@ -77,6 +78,8 @@ class TokenResponse(BaseModel):
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
force_password_change: Optional[bool] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
@@ -86,7 +89,7 @@ class UserResponse(BaseModel):
|
||||
full_name: Optional[str]
|
||||
is_active: bool
|
||||
is_verified: bool
|
||||
role: str
|
||||
role: Optional[str]
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
@@ -101,24 +104,23 @@ class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@validator('new_password')
|
||||
@validator("new_password")
|
||||
def validate_new_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters long')
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
raise ValueError("Password must contain at least one uppercase letter")
|
||||
if not any(c.islower() for c in v):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
raise ValueError("Password must contain at least one lowercase letter")
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
raise ValueError("Password must contain at least one digit")
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserRegisterRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@router.post(
|
||||
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def register(user_data: UserRegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Register a new user"""
|
||||
|
||||
# Check if user already exists
|
||||
@@ -126,8 +128,7 @@ async def register(
|
||||
result = await db.execute(stmt)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
|
||||
)
|
||||
|
||||
# Check if username already exists
|
||||
@@ -135,8 +136,7 @@ async def register(
|
||||
result = await db.execute(stmt)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
|
||||
)
|
||||
|
||||
# Create new user
|
||||
@@ -151,21 +151,27 @@ async def register(
|
||||
full_name=full_name,
|
||||
is_active=True,
|
||||
is_verified=False,
|
||||
role="user"
|
||||
role_id=2, # Default to 'user' role (id=2)
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserResponse.from_orm(user)
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role.name if user.role else None,
|
||||
created_at=user.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
user_data: UserLoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
async def login(user_data: UserLoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Login user and return access tokens"""
|
||||
|
||||
# Determine identifier for logging and user lookup
|
||||
@@ -187,9 +193,9 @@ async def login(
|
||||
query_start = datetime.utcnow()
|
||||
|
||||
if user_data.email:
|
||||
stmt = select(User).where(User.email == user_data.email)
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
|
||||
else:
|
||||
stmt = select(User).where(User.username == user_data.username)
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.username == user_data.username)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
query_end = datetime.utcnow()
|
||||
@@ -205,13 +211,18 @@ async def login(
|
||||
identifier_lower = identifier.lower() if identifier else ""
|
||||
admin_email = settings.ADMIN_EMAIL.lower() if settings.ADMIN_EMAIL else None
|
||||
|
||||
if user_data.email and admin_email and identifier_lower == admin_email and settings.ADMIN_PASSWORD:
|
||||
if (
|
||||
user_data.email
|
||||
and admin_email
|
||||
and identifier_lower == admin_email
|
||||
and settings.ADMIN_PASSWORD
|
||||
):
|
||||
bootstrap_attempted = True
|
||||
logger.info("LOGIN_ADMIN_BOOTSTRAP_START", email=user_data.email)
|
||||
try:
|
||||
await create_default_admin()
|
||||
# Re-run lookup after bootstrap attempt
|
||||
stmt = select(User).where(User.email == user_data.email)
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
@@ -234,11 +245,13 @@ async def login(
|
||||
logger.error("LOGIN_USER_LIST_FAILURE", error=str(e))
|
||||
|
||||
if bootstrap_attempted:
|
||||
logger.warning("LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email)
|
||||
logger.warning(
|
||||
"LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password"
|
||||
detail="Incorrect email or password",
|
||||
)
|
||||
|
||||
logger.info("LOGIN_USER_FOUND", email=user.email, is_active=user.is_active)
|
||||
@@ -253,7 +266,7 @@ async def login(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password"
|
||||
detail="Incorrect email or password",
|
||||
)
|
||||
|
||||
verify_end = datetime.utcnow()
|
||||
@@ -264,8 +277,7 @@ async def login(
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User account is disabled"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="User account is disabled"
|
||||
)
|
||||
|
||||
# Update last login
|
||||
@@ -289,14 +301,12 @@ async def login(
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"is_superuser": user.is_superuser,
|
||||
"role": user.role
|
||||
"role": user.role.name if user.role else None,
|
||||
},
|
||||
expires_delta=access_token_expires
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
data={"sub": str(user.id), "type": "refresh"}
|
||||
)
|
||||
refresh_token = create_refresh_token(data={"sub": str(user.id), "type": "refresh"})
|
||||
token_end = datetime.utcnow()
|
||||
logger.info(
|
||||
"LOGIN_TOKEN_CREATE_SUCCESS",
|
||||
@@ -309,17 +319,25 @@ async def login(
|
||||
total_duration_seconds=total_time.total_seconds(),
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
)
|
||||
# Check if user needs to change password
|
||||
response_data = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
}
|
||||
|
||||
# Add force password change flag if needed
|
||||
if user.force_password_change:
|
||||
response_data["force_password_change"] = True
|
||||
response_data["message"] = "Password change required on first login"
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_token(
|
||||
token_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
token_data: RefreshTokenRequest, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Refresh access token using refresh token"""
|
||||
|
||||
@@ -330,25 +348,28 @@ async def refresh_token(
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
stmt = select(User).where(User.id == int(user_id))
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
logger.info(f"REFRESH: Creating new access token with expiration: {access_token_expires}")
|
||||
logger.info(f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
|
||||
logger.info(
|
||||
f"REFRESH: Creating new access token with expiration: {access_token_expires}"
|
||||
)
|
||||
logger.info(
|
||||
f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
|
||||
)
|
||||
logger.info(f"REFRESH: Current UTC time: {datetime.utcnow().isoformat()}")
|
||||
|
||||
access_token = create_access_token(
|
||||
@@ -356,15 +377,15 @@ async def refresh_token(
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"is_superuser": user.is_superuser,
|
||||
"role": user.role
|
||||
"role": user.role.name if user.role else None,
|
||||
},
|
||||
expires_delta=access_token_expires
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=token_data.refresh_token, # Keep same refresh token
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -374,30 +395,37 @@ async def refresh_token(
|
||||
# Log the actual error for debugging
|
||||
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get current user information"""
|
||||
|
||||
# Get full user details from database
|
||||
stmt = select(User).where(User.id == int(current_user["id"]))
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.id == int(current_user["id"]))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
return UserResponse.from_orm(user)
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role.name if user.role else None,
|
||||
created_at=user.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
@@ -407,14 +435,12 @@ async def logout():
|
||||
|
||||
|
||||
@router.post("/verify-token")
|
||||
async def verify_user_token(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
async def verify_user_token(current_user: dict = Depends(get_current_user)):
|
||||
"""Verify if the current token is valid"""
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": current_user["id"],
|
||||
"email": current_user["email"]
|
||||
"email": current_user["email"],
|
||||
}
|
||||
|
||||
|
||||
@@ -422,7 +448,7 @@ async def verify_user_token(
|
||||
async def change_password(
|
||||
password_data: ChangePasswordRequest,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Change user password"""
|
||||
|
||||
@@ -433,15 +459,14 @@ async def change_password(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
detail="Current password is incorrect",
|
||||
)
|
||||
|
||||
# Update password
|
||||
|
||||
@@ -129,26 +129,30 @@ async def list_budgets(
|
||||
budget_type: Optional[BudgetType] = Query(None),
|
||||
is_enabled: Optional[bool] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List budgets with pagination and filtering"""
|
||||
|
||||
# Check permissions - users can view their own budgets
|
||||
if user_id and int(user_id) != current_user['id']:
|
||||
if user_id and int(user_id) != current_user["id"]:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:read")
|
||||
elif not user_id:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:read")
|
||||
|
||||
# If no user_id specified and user doesn't have admin permissions, show only their budgets
|
||||
if not user_id and "platform:budgets:read" not in current_user.get("permissions", []):
|
||||
user_id = current_user['id']
|
||||
if not user_id and "platform:budgets:read" not in current_user.get(
|
||||
"permissions", []
|
||||
):
|
||||
user_id = current_user["id"]
|
||||
|
||||
# Build query
|
||||
query = select(Budget)
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
|
||||
query = query.where(
|
||||
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
|
||||
)
|
||||
if budget_type:
|
||||
query = query.where(Budget.budget_type == budget_type.value)
|
||||
if is_enabled is not None:
|
||||
@@ -159,7 +163,9 @@ async def list_budgets(
|
||||
|
||||
# Apply same filters to count query
|
||||
if user_id:
|
||||
count_query = count_query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
|
||||
count_query = count_query.where(
|
||||
Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
|
||||
)
|
||||
if budget_type:
|
||||
count_query = count_query.where(Budget.budget_type == budget_type.value)
|
||||
if is_enabled is not None:
|
||||
@@ -182,23 +188,30 @@ async def list_budgets(
|
||||
usage = await _calculate_budget_usage(db, budget)
|
||||
budget_data = BudgetResponse.model_validate(budget)
|
||||
budget_data.current_usage = usage
|
||||
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
budget_data.usage_percentage = (
|
||||
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
)
|
||||
budget_responses.append(budget_data)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="list_budgets",
|
||||
resource_type="budget",
|
||||
details={"page": page, "size": size, "filters": {"user_id": user_id, "budget_type": budget_type, "is_enabled": is_enabled}}
|
||||
details={
|
||||
"page": page,
|
||||
"size": size,
|
||||
"filters": {
|
||||
"user_id": user_id,
|
||||
"budget_type": budget_type,
|
||||
"is_enabled": is_enabled,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return BudgetListResponse(
|
||||
budgets=budget_responses,
|
||||
total=total,
|
||||
page=page,
|
||||
size=size
|
||||
budgets=budget_responses, total=total, page=page, size=size
|
||||
)
|
||||
|
||||
|
||||
@@ -206,7 +219,7 @@ async def list_budgets(
|
||||
async def get_budget(
|
||||
budget_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get budget by ID"""
|
||||
|
||||
@@ -217,12 +230,11 @@ async def get_budget(
|
||||
|
||||
if not budget:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Budget not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can view their own budgets
|
||||
if budget.user_id != current_user['id']:
|
||||
if budget.user_id != current_user["id"]:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:read")
|
||||
|
||||
# Calculate current usage
|
||||
@@ -231,15 +243,17 @@ async def get_budget(
|
||||
# Build response
|
||||
budget_data = BudgetResponse.model_validate(budget)
|
||||
budget_data.current_usage = usage
|
||||
budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
budget_data.usage_percentage = (
|
||||
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_budget",
|
||||
resource_type="budget",
|
||||
resource_id=budget_id
|
||||
resource_id=budget_id,
|
||||
)
|
||||
|
||||
return budget_data
|
||||
@@ -249,7 +263,7 @@ async def get_budget(
|
||||
async def create_budget(
|
||||
budget_data: BudgetCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new budget"""
|
||||
|
||||
@@ -257,11 +271,17 @@ async def create_budget(
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:create")
|
||||
|
||||
# If user_id not specified, use current user
|
||||
target_user_id = budget_data.user_id or current_user['id']
|
||||
target_user_id = budget_data.user_id or current_user["id"]
|
||||
|
||||
# If setting budget for another user, need admin permissions
|
||||
if int(target_user_id) != current_user['id'] if isinstance(target_user_id, str) else target_user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:admin")
|
||||
if (
|
||||
int(target_user_id) != current_user["id"]
|
||||
if isinstance(target_user_id, str)
|
||||
else target_user_id != current_user["id"]
|
||||
):
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:budgets:admin"
|
||||
)
|
||||
|
||||
# Calculate period start and end
|
||||
now = datetime.utcnow()
|
||||
@@ -281,7 +301,7 @@ async def create_budget(
|
||||
is_enabled=budget_data.is_enabled,
|
||||
alert_threshold_percent=budget_data.alert_threshold_percent,
|
||||
allowed_resources=budget_data.allowed_resources,
|
||||
metadata=budget_data.metadata
|
||||
metadata=budget_data.metadata,
|
||||
)
|
||||
|
||||
db.add(new_budget)
|
||||
@@ -296,11 +316,15 @@ async def create_budget(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="create_budget",
|
||||
resource_type="budget",
|
||||
resource_id=str(new_budget.id),
|
||||
details={"name": budget_data.name, "budget_type": budget_data.budget_type, "limit_amount": budget_data.limit_amount}
|
||||
details={
|
||||
"name": budget_data.name,
|
||||
"budget_type": budget_data.budget_type,
|
||||
"limit_amount": budget_data.limit_amount,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Budget created: {new_budget.name} by {current_user['username']}")
|
||||
@@ -313,7 +337,7 @@ async def update_budget(
|
||||
budget_id: str,
|
||||
budget_data: BudgetUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update budget"""
|
||||
|
||||
@@ -324,19 +348,20 @@ async def update_budget(
|
||||
|
||||
if not budget:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Budget not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can update their own budgets
|
||||
if budget.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:update")
|
||||
if budget.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:budgets:update"
|
||||
)
|
||||
|
||||
# Store original values for audit
|
||||
original_values = {
|
||||
"name": budget.name,
|
||||
"limit_amount": budget.limit_amount,
|
||||
"is_enabled": budget.is_enabled
|
||||
"is_enabled": budget.is_enabled,
|
||||
}
|
||||
|
||||
# Update budget fields
|
||||
@@ -346,7 +371,9 @@ async def update_budget(
|
||||
|
||||
# Recalculate period if period_type changed
|
||||
if "period_type" in update_data:
|
||||
period_start, period_end = _calculate_period_bounds(datetime.utcnow(), budget.period_type)
|
||||
period_start, period_end = _calculate_period_bounds(
|
||||
datetime.utcnow(), budget.period_type
|
||||
)
|
||||
budget.period_start = period_start
|
||||
budget.period_end = period_end
|
||||
|
||||
@@ -359,20 +386,22 @@ async def update_budget(
|
||||
# Build response
|
||||
budget_response = BudgetResponse.model_validate(budget)
|
||||
budget_response.current_usage = usage
|
||||
budget_response.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
budget_response.usage_percentage = (
|
||||
(usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="update_budget",
|
||||
resource_type="budget",
|
||||
resource_id=budget_id,
|
||||
details={
|
||||
"updated_fields": list(update_data.keys()),
|
||||
"before_values": original_values,
|
||||
"after_values": {k: getattr(budget, k) for k in update_data.keys()}
|
||||
}
|
||||
"after_values": {k: getattr(budget, k) for k in update_data.keys()},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Budget updated: {budget.name} by {current_user['username']}")
|
||||
@@ -384,7 +413,7 @@ async def update_budget(
|
||||
async def delete_budget(
|
||||
budget_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete budget"""
|
||||
|
||||
@@ -395,13 +424,14 @@ async def delete_budget(
|
||||
|
||||
if not budget:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Budget not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can delete their own budgets
|
||||
if budget.user_id != current_user['id']:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:delete")
|
||||
if budget.user_id != current_user["id"]:
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:budgets:delete"
|
||||
)
|
||||
|
||||
# Delete budget
|
||||
await db.delete(budget)
|
||||
@@ -410,11 +440,11 @@ async def delete_budget(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="delete_budget",
|
||||
resource_type="budget",
|
||||
resource_id=budget_id,
|
||||
details={"name": budget.name}
|
||||
details={"name": budget.name},
|
||||
)
|
||||
|
||||
logger.info(f"Budget deleted: {budget.name} by {current_user['username']}")
|
||||
@@ -426,7 +456,7 @@ async def delete_budget(
|
||||
async def get_budget_usage(
|
||||
budget_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get detailed budget usage information"""
|
||||
|
||||
@@ -437,17 +467,18 @@ async def get_budget_usage(
|
||||
|
||||
if not budget:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Budget not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can view their own budget usage
|
||||
if budget.user_id != current_user['id']:
|
||||
if budget.user_id != current_user["id"]:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:read")
|
||||
|
||||
# Calculate usage
|
||||
current_usage = await _calculate_budget_usage(db, budget)
|
||||
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
usage_percentage = (
|
||||
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
)
|
||||
remaining_amount = max(0, budget.limit_amount - current_usage)
|
||||
is_exceeded = current_usage > budget.limit_amount
|
||||
|
||||
@@ -470,10 +501,10 @@ async def get_budget_usage(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_budget_usage",
|
||||
resource_type="budget",
|
||||
resource_id=budget_id
|
||||
resource_id=budget_id,
|
||||
)
|
||||
|
||||
return BudgetUsageResponse(
|
||||
@@ -487,7 +518,7 @@ async def get_budget_usage(
|
||||
is_exceeded=is_exceeded,
|
||||
days_remaining=days_remaining,
|
||||
projected_usage=projected_usage,
|
||||
usage_history=usage_history
|
||||
usage_history=usage_history,
|
||||
)
|
||||
|
||||
|
||||
@@ -495,7 +526,7 @@ async def get_budget_usage(
|
||||
async def get_budget_alerts(
|
||||
budget_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get budget alerts"""
|
||||
|
||||
@@ -506,51 +537,58 @@ async def get_budget_alerts(
|
||||
|
||||
if not budget:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Budget not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
|
||||
)
|
||||
|
||||
# Check permissions - users can view their own budget alerts
|
||||
if budget.user_id != current_user['id']:
|
||||
if budget.user_id != current_user["id"]:
|
||||
require_permission(current_user.get("permissions", []), "platform:budgets:read")
|
||||
|
||||
# Calculate usage
|
||||
current_usage = await _calculate_budget_usage(db, budget)
|
||||
usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
usage_percentage = (
|
||||
(current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
)
|
||||
|
||||
alerts = []
|
||||
|
||||
# Check for alerts
|
||||
if usage_percentage >= 100:
|
||||
alerts.append(BudgetAlertResponse(
|
||||
alerts.append(
|
||||
BudgetAlertResponse(
|
||||
budget_id=budget_id,
|
||||
budget_name=budget.name,
|
||||
alert_type="exceeded",
|
||||
current_usage=current_usage,
|
||||
limit_amount=budget.limit_amount,
|
||||
usage_percentage=usage_percentage,
|
||||
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)"
|
||||
))
|
||||
message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)",
|
||||
)
|
||||
)
|
||||
elif usage_percentage >= 90:
|
||||
alerts.append(BudgetAlertResponse(
|
||||
alerts.append(
|
||||
BudgetAlertResponse(
|
||||
budget_id=budget_id,
|
||||
budget_name=budget.name,
|
||||
alert_type="critical",
|
||||
current_usage=current_usage,
|
||||
limit_amount=budget.limit_amount,
|
||||
usage_percentage=usage_percentage,
|
||||
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)"
|
||||
))
|
||||
message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)",
|
||||
)
|
||||
)
|
||||
elif usage_percentage >= budget.alert_threshold_percent:
|
||||
alerts.append(BudgetAlertResponse(
|
||||
alerts.append(
|
||||
BudgetAlertResponse(
|
||||
budget_id=budget_id,
|
||||
budget_name=budget.name,
|
||||
alert_type="warning",
|
||||
current_usage=current_usage,
|
||||
limit_amount=budget.limit_amount,
|
||||
usage_percentage=usage_percentage,
|
||||
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)"
|
||||
))
|
||||
message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)",
|
||||
)
|
||||
)
|
||||
|
||||
return alerts
|
||||
|
||||
@@ -565,7 +603,7 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
|
||||
# Filter by time period
|
||||
query = query.where(
|
||||
UsageTracking.created_at >= budget.period_start,
|
||||
UsageTracking.created_at <= budget.period_end
|
||||
UsageTracking.created_at <= budget.period_end,
|
||||
)
|
||||
|
||||
# Filter by user or API key
|
||||
@@ -594,7 +632,9 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
|
||||
return float(usage)
|
||||
|
||||
|
||||
def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[datetime, datetime]:
|
||||
def _calculate_period_bounds(
|
||||
current_time: datetime, period_type: str
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Calculate period start and end dates"""
|
||||
|
||||
if period_type == "hourly":
|
||||
@@ -606,7 +646,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
|
||||
elif period_type == "weekly":
|
||||
# Start of week (Monday)
|
||||
days_since_monday = current_time.weekday()
|
||||
start = current_time.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=days_since_monday)
|
||||
start = current_time.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
) - timedelta(days=days_since_monday)
|
||||
end = start + timedelta(weeks=1) - timedelta(microseconds=1)
|
||||
elif period_type == "monthly":
|
||||
start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
@@ -616,7 +658,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
|
||||
next_month = start.replace(month=start.month + 1)
|
||||
end = next_month - timedelta(microseconds=1)
|
||||
elif period_type == "yearly":
|
||||
start = current_time.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
start = current_time.replace(
|
||||
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
end = start.replace(year=start.year + 1) - timedelta(microseconds=1)
|
||||
else:
|
||||
# Default to daily
|
||||
@@ -626,7 +670,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
|
||||
return start, end
|
||||
|
||||
|
||||
async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -> List[dict]:
|
||||
async def _get_usage_history(
|
||||
db: AsyncSession, budget: Budget, days: int = 30
|
||||
) -> List[dict]:
|
||||
"""Get usage history for the budget"""
|
||||
|
||||
end_date = datetime.utcnow()
|
||||
@@ -634,13 +680,12 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
|
||||
|
||||
# Build query
|
||||
query = select(
|
||||
func.date(UsageTracking.created_at).label('date'),
|
||||
func.sum(UsageTracking.total_tokens).label('tokens'),
|
||||
func.sum(UsageTracking.cost_cents).label('cost_cents'),
|
||||
func.count(UsageTracking.id).label('requests')
|
||||
func.date(UsageTracking.created_at).label("date"),
|
||||
func.sum(UsageTracking.total_tokens).label("tokens"),
|
||||
func.sum(UsageTracking.cost_cents).label("cost_cents"),
|
||||
func.count(UsageTracking.id).label("requests"),
|
||||
).where(
|
||||
UsageTracking.created_at >= start_date,
|
||||
UsageTracking.created_at <= end_date
|
||||
UsageTracking.created_at >= start_date, UsageTracking.created_at <= end_date
|
||||
)
|
||||
|
||||
# Filter by user or API key
|
||||
@@ -649,7 +694,9 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
|
||||
elif budget.user_id:
|
||||
query = query.where(UsageTracking.user_id == budget.user_id)
|
||||
|
||||
query = query.group_by(func.date(UsageTracking.created_at)).order_by(func.date(UsageTracking.created_at))
|
||||
query = query.group_by(func.date(UsageTracking.created_at)).order_by(
|
||||
func.date(UsageTracking.created_at)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.fetchall()
|
||||
@@ -664,12 +711,14 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
|
||||
elif budget.budget_type == "requests":
|
||||
usage_value = row.requests or 0
|
||||
|
||||
history.append({
|
||||
history.append(
|
||||
{
|
||||
"date": row.date.isoformat(),
|
||||
"usage": usage_value,
|
||||
"tokens": row.tokens or 0,
|
||||
"cost_dollars": (row.cost_cents or 0) / 100.0,
|
||||
"requests": row.requests or 0
|
||||
})
|
||||
"requests": row.requests or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return history
|
||||
File diff suppressed because it is too large
Load Diff
3
backend/app/api/v1/endpoints/__init__.py
Normal file
3
backend/app/api/v1/endpoints/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API v1 endpoints package
|
||||
"""
|
||||
166
backend/app/api/v1/endpoints/tool_calling.py
Normal file
166
backend/app/api/v1/endpoints/tool_calling.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Tool calling API endpoints
|
||||
Integration between LLM and tool execution
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.tool_calling_service import ToolCallingService
|
||||
from app.services.llm.models import ChatRequest, ChatResponse
|
||||
from app.schemas.tool_calling import (
|
||||
ToolCallRequest,
|
||||
ToolCallResponse,
|
||||
ToolExecutionRequest,
|
||||
ToolValidationRequest,
|
||||
ToolValidationResponse,
|
||||
ToolHistoryResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/chat/completions", response_model=ChatResponse)
|
||||
async def create_chat_completion_with_tools(
|
||||
request: ChatRequest,
|
||||
auto_execute_tools: bool = Query(
|
||||
True, description="Whether to automatically execute tool calls"
|
||||
),
|
||||
max_tool_calls: int = Query(
|
||||
5, ge=1, le=10, description="Maximum number of tool calls"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Create chat completion with tool calling support"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
# Resolve user ID for context
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
# Set user context in request
|
||||
request.user_id = str(user_id)
|
||||
request.api_key_id = 1 # Default for internal usage
|
||||
|
||||
response = await service.create_chat_completion_with_tools(
|
||||
request=request,
|
||||
user=current_user,
|
||||
auto_execute_tools=auto_execute_tools,
|
||||
max_tool_calls=max_tool_calls,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/chat/completions/stream")
|
||||
async def create_chat_completion_stream_with_tools(
|
||||
request: ChatRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Create streaming chat completion with tool calling support"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
# Resolve user ID for context
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
# Set user context in request
|
||||
request.user_id = str(user_id)
|
||||
request.api_key_id = 1 # Default for internal usage
|
||||
|
||||
async def stream_generator():
|
||||
async for chunk in service.create_chat_completion_stream_with_tools(
|
||||
request=request, user=current_user
|
||||
):
|
||||
yield f"data: {chunk}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="text/plain",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/execute", response_model=ToolCallResponse)
|
||||
async def execute_tool_by_name(
|
||||
request: ToolExecutionRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Execute a tool by name directly"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
try:
|
||||
result = await service.execute_tool_by_name(
|
||||
tool_name=request.tool_name,
|
||||
parameters=request.parameters,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
return ToolCallResponse(success=True, result=result, error=None)
|
||||
|
||||
except Exception as e:
|
||||
return ToolCallResponse(success=False, result=None, error=str(e))
|
||||
|
||||
|
||||
@router.post("/validate", response_model=ToolValidationResponse)
|
||||
async def validate_tool_availability(
|
||||
request: ToolValidationRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Validate which tools are available to the user"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
availability = await service.validate_tool_availability(
|
||||
tool_names=request.tool_names, user=current_user
|
||||
)
|
||||
|
||||
return ToolValidationResponse(tool_availability=availability)
|
||||
|
||||
|
||||
@router.get("/history", response_model=ToolHistoryResponse)
|
||||
async def get_tool_call_history(
|
||||
limit: int = Query(
|
||||
50, ge=1, le=100, description="Number of history items to return"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get recent tool execution history for the user"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
history = await service.get_tool_call_history(user=current_user, limit=limit)
|
||||
|
||||
return ToolHistoryResponse(history=history, total=len(history))
|
||||
|
||||
|
||||
@router.get("/available")
|
||||
async def get_available_tools(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get tools available for function calling"""
|
||||
|
||||
service = ToolCallingService(db)
|
||||
|
||||
# Get available tools
|
||||
tools = await service._get_available_tools_for_user(current_user)
|
||||
|
||||
# Convert to OpenAI format
|
||||
openai_tools = await service._convert_tools_to_openai_format(tools)
|
||||
|
||||
return {"tools": openai_tools, "count": len(openai_tools)}
|
||||
386
backend/app/api/v1/endpoints/tools.py
Normal file
386
backend/app/api/v1/endpoints/tools.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Tool management and execution API endpoints
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.tool_management_service import ToolManagementService
|
||||
from app.services.tool_execution_service import ToolExecutionService
|
||||
from app.schemas.tool import (
|
||||
ToolCreate,
|
||||
ToolUpdate,
|
||||
ToolResponse,
|
||||
ToolListResponse,
|
||||
ToolExecutionCreate,
|
||||
ToolExecutionResponse,
|
||||
ToolExecutionListResponse,
|
||||
ToolCategoryCreate,
|
||||
ToolCategoryResponse,
|
||||
ToolStatisticsResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/", response_model=ToolResponse)
|
||||
async def create_tool(
|
||||
tool_data: ToolCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new tool"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
tool = await service.create_tool(
|
||||
name=tool_data.name,
|
||||
display_name=tool_data.display_name,
|
||||
code=tool_data.code,
|
||||
tool_type=tool_data.tool_type,
|
||||
created_by_user_id=user_id,
|
||||
description=tool_data.description,
|
||||
parameters_schema=tool_data.parameters_schema,
|
||||
return_schema=tool_data.return_schema,
|
||||
timeout_seconds=tool_data.timeout_seconds,
|
||||
max_memory_mb=tool_data.max_memory_mb,
|
||||
max_cpu_seconds=tool_data.max_cpu_seconds,
|
||||
docker_image=tool_data.docker_image,
|
||||
docker_command=tool_data.docker_command,
|
||||
category=tool_data.category,
|
||||
tags=tool_data.tags,
|
||||
is_public=tool_data.is_public,
|
||||
)
|
||||
|
||||
return ToolResponse.from_orm(tool)
|
||||
|
||||
|
||||
@router.get("/", response_model=ToolListResponse)
|
||||
async def list_tools(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
category: Optional[str] = Query(None),
|
||||
tool_type: Optional[str] = Query(None),
|
||||
is_public: Optional[bool] = Query(None),
|
||||
is_approved: Optional[bool] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
tags: Optional[List[str]] = Query(None),
|
||||
created_by_user_id: Optional[int] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""List tools with filtering and pagination"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
tools = await service.get_tools(
|
||||
user_id=user_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
category=category,
|
||||
tool_type=tool_type,
|
||||
is_public=is_public,
|
||||
is_approved=is_approved,
|
||||
search=search,
|
||||
tags=tags,
|
||||
created_by_user_id=created_by_user_id,
|
||||
)
|
||||
|
||||
return ToolListResponse(
|
||||
tools=[ToolResponse.from_orm(tool) for tool in tools],
|
||||
total=len(tools),
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ToolResponse)
|
||||
async def get_tool(
|
||||
tool_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get tool by ID"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
tool = await service.get_tool_by_id(tool_id)
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="Tool not found")
|
||||
|
||||
# Resolve underlying User object if available
|
||||
user_obj = (
|
||||
current_user.get("user_obj")
|
||||
if isinstance(current_user, dict)
|
||||
else current_user
|
||||
)
|
||||
|
||||
# Check if user can access this tool
|
||||
if not user_obj or not tool.can_be_used_by(user_obj):
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tool")
|
||||
|
||||
return ToolResponse.from_orm(tool)
|
||||
|
||||
|
||||
@router.put("/{tool_id}", response_model=ToolResponse)
|
||||
async def update_tool(
|
||||
tool_id: int,
|
||||
tool_data: ToolUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update tool (only by creator or admin)"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
tool = await service.update_tool(
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
display_name=tool_data.display_name,
|
||||
description=tool_data.description,
|
||||
code=tool_data.code,
|
||||
parameters_schema=tool_data.parameters_schema,
|
||||
return_schema=tool_data.return_schema,
|
||||
timeout_seconds=tool_data.timeout_seconds,
|
||||
max_memory_mb=tool_data.max_memory_mb,
|
||||
max_cpu_seconds=tool_data.max_cpu_seconds,
|
||||
docker_image=tool_data.docker_image,
|
||||
docker_command=tool_data.docker_command,
|
||||
category=tool_data.category,
|
||||
tags=tool_data.tags,
|
||||
is_public=tool_data.is_public,
|
||||
is_active=tool_data.is_active,
|
||||
)
|
||||
|
||||
return ToolResponse.from_orm(tool)
|
||||
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Delete tool (only by creator or admin)"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
await service.delete_tool(tool_id, user_id)
|
||||
return {"message": "Tool deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{tool_id}/approve", response_model=ToolResponse)
|
||||
async def approve_tool(
|
||||
tool_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Approve tool for public use (admin only)"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
tool = await service.approve_tool(tool_id, user_id)
|
||||
return ToolResponse.from_orm(tool)
|
||||
|
||||
|
||||
# Tool Execution Endpoints
|
||||
|
||||
|
||||
@router.post("/{tool_id}/execute", response_model=ToolExecutionResponse)
|
||||
async def execute_tool(
|
||||
tool_id: int,
|
||||
execution_data: ToolExecutionCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Execute a tool with given parameters"""
|
||||
service = ToolExecutionService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
execution = await service.execute_tool(
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
parameters=execution_data.parameters,
|
||||
timeout_override=execution_data.timeout_override,
|
||||
)
|
||||
|
||||
return ToolExecutionResponse.from_orm(execution)
|
||||
|
||||
|
||||
@router.get("/executions", response_model=ToolExecutionListResponse)
|
||||
async def list_executions(
|
||||
tool_id: Optional[int] = Query(None),
|
||||
executed_by_user_id: Optional[int] = Query(None),
|
||||
status: Optional[str] = Query(None),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""List tool executions with filtering"""
|
||||
service = ToolExecutionService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
executions = await service.get_tool_executions(
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
executed_by_user_id=executed_by_user_id,
|
||||
status=status,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return ToolExecutionListResponse(
|
||||
executions=[
|
||||
ToolExecutionResponse.from_orm(execution) for execution in executions
|
||||
],
|
||||
total=len(executions),
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=ToolExecutionResponse)
|
||||
async def get_execution(
|
||||
execution_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get execution details"""
|
||||
service = ToolExecutionService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
# Get execution through list with filter to ensure permission check
|
||||
executions = await service.get_tool_executions(
|
||||
user_id=user_id, skip=0, limit=1
|
||||
)
|
||||
|
||||
execution = next((e for e in executions if e.id == execution_id), None)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
return ToolExecutionResponse.from_orm(execution)
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/cancel", response_model=ToolExecutionResponse)
|
||||
async def cancel_execution(
|
||||
execution_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Cancel a running execution"""
|
||||
service = ToolExecutionService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
execution = await service.cancel_execution(execution_id, user_id)
|
||||
return ToolExecutionResponse.from_orm(execution)
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}/logs")
|
||||
async def get_execution_logs(
|
||||
execution_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get real-time logs for execution"""
|
||||
service = ToolExecutionService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
logs = await service.get_execution_logs(execution_id, user_id)
|
||||
return logs
|
||||
|
||||
|
||||
# Tool Categories
|
||||
|
||||
|
||||
@router.post("/categories", response_model=ToolCategoryResponse)
|
||||
async def create_category(
|
||||
category_data: ToolCategoryCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new tool category (admin only)"""
|
||||
user_obj = (
|
||||
current_user.get("user_obj")
|
||||
if isinstance(current_user, dict)
|
||||
else current_user
|
||||
)
|
||||
|
||||
if not user_obj or not user_obj.has_permission("manage_tools"):
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required")
|
||||
|
||||
service = ToolManagementService(db)
|
||||
|
||||
category = await service.create_category(
|
||||
name=category_data.name,
|
||||
display_name=category_data.display_name,
|
||||
description=category_data.description,
|
||||
icon=category_data.icon,
|
||||
color=category_data.color,
|
||||
sort_order=category_data.sort_order,
|
||||
)
|
||||
|
||||
return ToolCategoryResponse.from_orm(category)
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[ToolCategoryResponse])
|
||||
async def list_categories(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""List all active tool categories"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
categories = await service.get_categories()
|
||||
return [ToolCategoryResponse.from_orm(category) for category in categories]
|
||||
|
||||
|
||||
# Statistics
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=ToolStatisticsResponse)
|
||||
async def get_statistics(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get tool usage statistics"""
|
||||
service = ToolManagementService(db)
|
||||
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
|
||||
stats = await service.get_tool_statistics(user_id=user_id)
|
||||
return ToolStatisticsResponse(**stats)
|
||||
703
backend/app/api/v1/endpoints/user_management.py
Normal file
703
backend/app/api/v1/endpoints/user_management.py
Normal file
@@ -0,0 +1,703 @@
|
||||
"""
|
||||
User Management API endpoints
|
||||
Admin endpoints for managing users, roles, and audit logs
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, EmailStr, validator, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.services.user_management_service import UserManagementService
|
||||
from app.services.permission_manager import require_permission
|
||||
from app.schemas.role import RoleCreate, RoleUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
username: str
|
||||
password: str
|
||||
full_name: Optional[str] = None
|
||||
role_id: Optional[int] = None
|
||||
is_active: bool = True
|
||||
force_password_change: bool = False
|
||||
|
||||
@validator("password")
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
return v
|
||||
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if len(v) < 3:
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
if not v.replace("_", "").replace("-", "").isalnum():
|
||||
raise ValueError("Username must contain only alphanumeric characters, underscores, and hyphens")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
email: Optional[EmailStr] = None
|
||||
username: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
role_id: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_verified: Optional[bool] = None
|
||||
custom_permissions: Optional[Dict[str, Any]] = None
|
||||
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if v is not None and len(v) < 3:
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class AdminPasswordResetRequest(BaseModel):
|
||||
new_password: str
|
||||
force_change_on_login: bool = True
|
||||
|
||||
@validator("new_password")
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
username: str
|
||||
full_name: Optional[str]
|
||||
role_id: Optional[int]
|
||||
role: Optional[Dict[str, Any]]
|
||||
is_active: bool
|
||||
is_verified: bool
|
||||
account_locked: bool
|
||||
force_password_change: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_login: Optional[datetime]
|
||||
audit_summary: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: List[UserResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class RoleResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
level: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AuditLogResponse(BaseModel):
|
||||
id: int
|
||||
user_id: Optional[int]
|
||||
action: str
|
||||
resource_type: str
|
||||
resource_id: Optional[str]
|
||||
description: str
|
||||
details: Dict[str, Any]
|
||||
severity: str
|
||||
category: Optional[str]
|
||||
success: bool
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# User Management Endpoints
|
||||
@router.get("/users", response_model=UserListResponse)
|
||||
async def get_users(
|
||||
request: Request,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search: Optional[str] = None,
|
||||
role_id: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
include_audit_summary: bool = False,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get all users with filtering options"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:read",
|
||||
context={"user_id": current_user["id"]}
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
users_data = await service.get_users(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
search=search,
|
||||
role_id=role_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
users = []
|
||||
for user in users_data:
|
||||
user_dict = user.to_dict()
|
||||
if include_audit_summary:
|
||||
# Get audit summary for user
|
||||
audit_logs = await service.get_user_audit_logs(user.id, limit=10)
|
||||
user_dict["audit_summary"] = {
|
||||
"recent_actions": len(audit_logs),
|
||||
"last_login": user.last_login.isoformat() if user.last_login else None,
|
||||
}
|
||||
user_response = UserResponse(**user_dict)
|
||||
users.append(user_response)
|
||||
|
||||
return UserListResponse(
|
||||
users=users,
|
||||
total=len(users), # Would need actual count query for large datasets
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
include_audit_summary: bool = False,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get specific user by ID"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:read",
|
||||
context={"user_id": current_user["id"], "owner_id": user_id}
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
user = await service.get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
user_dict = user.to_dict()
|
||||
|
||||
if include_audit_summary:
|
||||
# Get audit summary for user
|
||||
audit_logs = await service.get_user_audit_logs(user_id, limit=10)
|
||||
user_dict["audit_summary"] = {
|
||||
"recent_actions": len(audit_logs),
|
||||
"last_login": user.last_login.isoformat() if user.last_login else None,
|
||||
}
|
||||
|
||||
return UserResponse(**user_dict)
|
||||
|
||||
|
||||
@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
user_data: CreateUserRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new user"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:create",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
user = await service.create_user(
|
||||
email=user_data.email,
|
||||
username=user_data.username,
|
||||
password=user_data.password,
|
||||
full_name=user_data.full_name,
|
||||
role_id=user_data.role_id,
|
||||
is_active=user_data.is_active,
|
||||
is_verified=True, # Admin-created users are verified by default
|
||||
custom_permissions={}, # Empty by default
|
||||
)
|
||||
|
||||
# Log user creation in audit
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="create",
|
||||
resource_type="user",
|
||||
resource_id=str(user.id),
|
||||
description=f"User created by admin: {user.email}",
|
||||
details={
|
||||
"created_by": current_user["email"],
|
||||
"target_user": user.email,
|
||||
"role_id": user_data.role_id,
|
||||
},
|
||||
severity="medium",
|
||||
)
|
||||
|
||||
user_dict = user.to_dict()
|
||||
return UserResponse(**user_dict)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_data: UpdateUserRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update user information"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:update",
|
||||
context={"user_id": current_user["id"], "owner_id": user_id}
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
# Get current user for audit comparison
|
||||
current_user_data = await service.get_user_by_id(user_id)
|
||||
if not current_user_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
old_values = current_user_data.to_dict()
|
||||
|
||||
# Update user
|
||||
user = await service.update_user(
|
||||
user_id=user_id,
|
||||
email=user_data.email,
|
||||
username=user_data.username,
|
||||
full_name=user_data.full_name,
|
||||
role_id=user_data.role_id,
|
||||
is_active=user_data.is_active,
|
||||
is_verified=user_data.is_verified,
|
||||
custom_permissions=user_data.custom_permissions,
|
||||
)
|
||||
|
||||
# Log user update in audit
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="update",
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"User updated by admin: {user.email}",
|
||||
details={
|
||||
"updated_by": current_user["email"],
|
||||
"target_user": user.email,
|
||||
},
|
||||
old_values=old_values,
|
||||
new_values=user.to_dict(),
|
||||
severity="medium",
|
||||
)
|
||||
|
||||
user_dict = user.to_dict()
|
||||
return UserResponse(**user_dict)
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/password-reset")
|
||||
async def admin_reset_password(
|
||||
user_id: int,
|
||||
password_data: AdminPasswordResetRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Admin reset user password with forced change option"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:manage",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
user = await service.admin_reset_user_password(
|
||||
user_id=user_id,
|
||||
new_password=password_data.new_password,
|
||||
force_change_on_login=password_data.force_change_on_login,
|
||||
admin_user_id=current_user["id"],
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Password reset for user {user.email}",
|
||||
"force_change_on_login": password_data.force_change_on_login,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
hard_delete: bool = False,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete or deactivate user"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:delete",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
# Get user info for audit before deletion
|
||||
user = await service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
user_email = user.email
|
||||
|
||||
# Delete user
|
||||
success = await service.delete_user(user_id, hard_delete=hard_delete)
|
||||
|
||||
# Log user deletion in audit
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="delete",
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"User {'hard deleted' if hard_delete else 'deactivated'} by admin: {user_email}",
|
||||
details={
|
||||
"deleted_by": current_user["email"],
|
||||
"target_user": user_email,
|
||||
"hard_delete": hard_delete,
|
||||
},
|
||||
severity="high",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"User {'deleted' if hard_delete else 'deactivated'} successfully",
|
||||
"user_email": user_email,
|
||||
}
|
||||
|
||||
|
||||
# Role Management Endpoints
|
||||
@router.get("/roles", response_model=List[RoleResponse])
|
||||
async def get_roles(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get all available roles"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:roles:read",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
roles = await service.get_roles(is_active=True)
|
||||
|
||||
return [RoleResponse(**role.to_dict()) for role in roles]
|
||||
|
||||
|
||||
@router.post("/roles", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_role(
|
||||
role_data: RoleCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new role"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:roles:create",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
# Create role
|
||||
role = await service.create_role(
|
||||
name=role_data.name,
|
||||
display_name=role_data.display_name,
|
||||
description=role_data.description,
|
||||
level=role_data.level,
|
||||
permissions=role_data.permissions,
|
||||
can_manage_users=role_data.can_manage_users,
|
||||
can_manage_budgets=role_data.can_manage_budgets,
|
||||
can_view_reports=role_data.can_view_reports,
|
||||
can_manage_tools=role_data.can_manage_tools,
|
||||
inherits_from=role_data.inherits_from,
|
||||
is_active=role_data.is_active,
|
||||
is_system_role=role_data.is_system_role,
|
||||
)
|
||||
|
||||
# Log role creation
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="create",
|
||||
resource_type="role",
|
||||
resource_id=str(role.id),
|
||||
description=f"Role created: {role.name}",
|
||||
details={
|
||||
"created_by": current_user["email"],
|
||||
"role_name": role.name,
|
||||
"level": role.level,
|
||||
},
|
||||
severity="medium",
|
||||
)
|
||||
|
||||
return RoleResponse(**role.to_dict())
|
||||
|
||||
|
||||
@router.put("/roles/{role_id}", response_model=RoleResponse)
|
||||
async def update_role(
|
||||
role_id: int,
|
||||
role_data: RoleUpdate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a role"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:roles:update",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
# Get current role for audit
|
||||
current_role = await service.get_role_by_id(role_id)
|
||||
if not current_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Role not found"
|
||||
)
|
||||
|
||||
# Prevent updating system roles
|
||||
if current_role.is_system_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot modify system roles"
|
||||
)
|
||||
|
||||
old_values = current_role.to_dict()
|
||||
|
||||
# Update role
|
||||
role = await service.update_role(
|
||||
role_id=role_id,
|
||||
display_name=role_data.display_name,
|
||||
description=role_data.description,
|
||||
permissions=role_data.permissions,
|
||||
can_manage_users=role_data.can_manage_users,
|
||||
can_manage_budgets=role_data.can_manage_budgets,
|
||||
can_view_reports=role_data.can_view_reports,
|
||||
can_manage_tools=role_data.can_manage_tools,
|
||||
is_active=role_data.is_active,
|
||||
)
|
||||
|
||||
# Log role update
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="update",
|
||||
resource_type="role",
|
||||
resource_id=str(role_id),
|
||||
description=f"Role updated: {role.name}",
|
||||
details={
|
||||
"updated_by": current_user["email"],
|
||||
"role_name": role.name,
|
||||
},
|
||||
old_values=old_values,
|
||||
new_values=role.to_dict(),
|
||||
severity="medium",
|
||||
)
|
||||
|
||||
return RoleResponse(**role.to_dict())
|
||||
|
||||
|
||||
@router.delete("/roles/{role_id}")
|
||||
async def delete_role(
|
||||
role_id: int,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete a role"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:roles:delete",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
|
||||
# Get role info for audit before deletion
|
||||
role = await service.get_role_by_id(role_id)
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Role not found"
|
||||
)
|
||||
|
||||
# Prevent deleting system roles
|
||||
if role.is_system_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot delete system roles"
|
||||
)
|
||||
|
||||
role_name = role.name
|
||||
|
||||
# Delete role
|
||||
success = await service.delete_role(role_id)
|
||||
|
||||
# Log role deletion
|
||||
await service._log_audit_event(
|
||||
user_id=current_user["id"],
|
||||
action="delete",
|
||||
resource_type="role",
|
||||
resource_id=str(role_id),
|
||||
description=f"Role deleted: {role_name}",
|
||||
details={
|
||||
"deleted_by": current_user["email"],
|
||||
"role_name": role_name,
|
||||
},
|
||||
severity="high",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Role {role_name} deleted successfully",
|
||||
"role_name": role_name,
|
||||
}
|
||||
|
||||
|
||||
# Audit Log Endpoints
|
||||
@router.get("/users/{user_id}/audit-logs", response_model=List[AuditLogResponse])
|
||||
async def get_user_audit_logs(
|
||||
user_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
action_filter: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get audit logs for a specific user"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:audit:read",
|
||||
context={"user_id": current_user["id"], "owner_id": user_id}
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
audit_logs = await service.get_user_audit_logs(
|
||||
user_id=user_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
action_filter=action_filter,
|
||||
)
|
||||
|
||||
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogResponse])
|
||||
async def get_all_audit_logs(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
user_id: Optional[int] = None,
|
||||
action_filter: Optional[str] = None,
|
||||
category_filter: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get all audit logs with filtering"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:audit:read",
|
||||
)
|
||||
|
||||
# Direct database query for audit logs with filters
|
||||
from sqlalchemy import select, and_, desc
|
||||
|
||||
query = select(AuditLog)
|
||||
conditions = []
|
||||
|
||||
if user_id:
|
||||
conditions.append(AuditLog.user_id == user_id)
|
||||
if action_filter:
|
||||
conditions.append(AuditLog.action == action_filter)
|
||||
if category_filter:
|
||||
conditions.append(AuditLog.category == category_filter)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
query = query.order_by(desc(AuditLog.created_at))
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
audit_logs = result.scalars().all()
|
||||
|
||||
return [AuditLogResponse(**log.to_dict()) for log in audit_logs]
|
||||
|
||||
|
||||
# Statistics Endpoints
|
||||
@router.get("/statistics")
|
||||
async def get_user_management_statistics(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get user management statistics"""
|
||||
|
||||
# Check permission
|
||||
require_permission(
|
||||
current_user.get("permissions", []),
|
||||
"platform:users:read",
|
||||
)
|
||||
|
||||
service = UserManagementService(db)
|
||||
user_stats = await service.get_user_statistics()
|
||||
role_stats = await service.get_role_statistics()
|
||||
|
||||
return {
|
||||
"users": user_stats,
|
||||
"roles": role_stats,
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
@@ -12,16 +12,33 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context
|
||||
from app.services.api_key_auth import (
|
||||
require_api_key,
|
||||
RequireScope,
|
||||
APIKeyAuthService,
|
||||
get_api_key_context,
|
||||
)
|
||||
from app.core.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.config import settings
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
|
||||
from app.services.llm.models import (
|
||||
ChatRequest,
|
||||
ChatMessage as LLMChatMessage,
|
||||
EmbeddingRequest as LLMEmbeddingRequest,
|
||||
)
|
||||
from app.services.llm.exceptions import (
|
||||
LLMError,
|
||||
ProviderError,
|
||||
SecurityError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.services.budget_enforcement import (
|
||||
check_budget_for_request, record_request_usage, BudgetEnforcementService,
|
||||
atomic_check_and_reserve_budget, atomic_finalize_usage
|
||||
check_budget_for_request,
|
||||
record_request_usage,
|
||||
BudgetEnforcementService,
|
||||
atomic_check_and_reserve_budget,
|
||||
atomic_finalize_usage,
|
||||
)
|
||||
from app.services.cost_calculator import CostCalculator, estimate_request_cost
|
||||
from app.utils.exceptions import AuthenticationError, AuthorizationError
|
||||
@@ -30,11 +47,7 @@ from app.middleware.analytics import set_analytics_data
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Models response cache - simple in-memory cache for performance
|
||||
_models_cache = {
|
||||
"data": None,
|
||||
"cached_at": 0,
|
||||
"cache_ttl": 900 # 15 minutes cache TTL
|
||||
}
|
||||
_models_cache = {"data": None, "cached_at": 0, "cache_ttl": 900} # 15 minutes cache TTL
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -44,8 +57,10 @@ async def get_cached_models() -> List[Dict[str, Any]]:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if cache is still valid
|
||||
if (_models_cache["data"] is not None and
|
||||
current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]):
|
||||
if (
|
||||
_models_cache["data"] is not None
|
||||
and current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]
|
||||
):
|
||||
logger.debug("Returning cached models list")
|
||||
return _models_cache["data"]
|
||||
|
||||
@@ -63,13 +78,17 @@ async def get_cached_models() -> List[Dict[str, Any]]:
|
||||
"created": model_info.created or int(time.time()),
|
||||
"owned_by": model_info.owned_by,
|
||||
# Add frontend-expected fields
|
||||
"name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id
|
||||
"provider": getattr(model_info, 'provider', model_info.owned_by), # Use provider if available, fallback to owned_by
|
||||
"name": getattr(
|
||||
model_info, "name", model_info.id
|
||||
), # Use name if available, fallback to id
|
||||
"provider": getattr(
|
||||
model_info, "provider", model_info.owned_by
|
||||
), # Use provider if available, fallback to owned_by
|
||||
"capabilities": model_info.capabilities,
|
||||
"context_window": model_info.context_window,
|
||||
"max_output_tokens": model_info.max_output_tokens,
|
||||
"supports_streaming": model_info.supports_streaming,
|
||||
"supports_function_calling": model_info.supports_function_calling
|
||||
"supports_function_calling": model_info.supports_function_calling,
|
||||
}
|
||||
# Include tasks field if present
|
||||
if model_info.tasks:
|
||||
@@ -138,11 +157,12 @@ class ModelsResponse(BaseModel):
|
||||
# Authentication: Public API endpoints should use require_api_key
|
||||
# Internal API endpoints should use get_current_user from core.security
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/models", response_model=ModelsResponse)
|
||||
async def list_models(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List available models"""
|
||||
try:
|
||||
@@ -155,7 +175,7 @@ async def list_models(
|
||||
if not await auth_service.check_scope_permission(context, "models.list"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to list models"
|
||||
detail="Insufficient permissions to list models",
|
||||
)
|
||||
|
||||
# Get models from cache or LLM service
|
||||
@@ -164,7 +184,9 @@ async def list_models(
|
||||
# Filter models based on API key permissions
|
||||
api_key = context.get("api_key")
|
||||
if api_key and api_key.allowed_models:
|
||||
models = [model for model in models if model.get("id") in api_key.allowed_models]
|
||||
models = [
|
||||
model for model in models if model.get("id") in api_key.allowed_models
|
||||
]
|
||||
|
||||
return ModelsResponse(data=models)
|
||||
|
||||
@@ -174,14 +196,14 @@ async def list_models(
|
||||
logger.error(f"Error listing models: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to list models"
|
||||
detail="Failed to list models",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/models/invalidate-cache")
|
||||
async def invalidate_models_cache_endpoint(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Invalidate models cache (admin only)"""
|
||||
# Check for admin permissions
|
||||
@@ -190,7 +212,7 @@ async def invalidate_models_cache_endpoint(
|
||||
if not user or not user.get("is_superuser"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required"
|
||||
detail="Admin privileges required",
|
||||
)
|
||||
else:
|
||||
# For API key users, check admin permissions
|
||||
@@ -198,7 +220,7 @@ async def invalidate_models_cache_endpoint(
|
||||
if not await auth_service.check_scope_permission(context, "admin.cache"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin permissions required to invalidate cache"
|
||||
detail="Admin permissions required to invalidate cache",
|
||||
)
|
||||
|
||||
invalidate_models_cache()
|
||||
@@ -210,7 +232,7 @@ async def create_chat_completion(
|
||||
request_body: Request,
|
||||
chat_request: ChatCompletionRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create chat completion with budget enforcement"""
|
||||
try:
|
||||
@@ -221,23 +243,27 @@ async def create_chat_completion(
|
||||
auth_service = APIKeyAuthService(db)
|
||||
|
||||
# Check permissions
|
||||
if not await auth_service.check_scope_permission(context, "chat.completions"):
|
||||
if not await auth_service.check_scope_permission(
|
||||
context, "chat.completions"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions for chat completions"
|
||||
detail="Insufficient permissions for chat completions",
|
||||
)
|
||||
|
||||
if not await auth_service.check_model_permission(context, chat_request.model):
|
||||
if not await auth_service.check_model_permission(
|
||||
context, chat_request.model
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Model '{chat_request.model}' not allowed"
|
||||
detail=f"Model '{chat_request.model}' not allowed",
|
||||
)
|
||||
|
||||
api_key = context.get("api_key")
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="API key information not available"
|
||||
detail="API key information not available",
|
||||
)
|
||||
elif auth_type == "jwt":
|
||||
# For JWT authentication, we'll skip the detailed permission checks for now
|
||||
@@ -246,13 +272,13 @@ async def create_chat_completion(
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User information not available"
|
||||
detail="User information not available",
|
||||
)
|
||||
api_key = None # JWT users don't have API keys
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication type"
|
||||
detail="Invalid authentication type",
|
||||
)
|
||||
|
||||
# Estimate token usage for budget checking
|
||||
@@ -265,6 +291,7 @@ async def create_chat_completion(
|
||||
|
||||
# Get a synchronous session for budget enforcement
|
||||
from app.db.database import SessionLocal
|
||||
|
||||
sync_db = SessionLocal()
|
||||
|
||||
try:
|
||||
@@ -272,20 +299,32 @@ async def create_chat_completion(
|
||||
warnings = []
|
||||
reserved_budget_ids = []
|
||||
if auth_type == "api_key" and api_key:
|
||||
is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget(
|
||||
sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions"
|
||||
(
|
||||
is_allowed,
|
||||
error_message,
|
||||
budget_warnings,
|
||||
budget_ids,
|
||||
) = atomic_check_and_reserve_budget(
|
||||
sync_db,
|
||||
api_key,
|
||||
chat_request.model,
|
||||
int(estimated_tokens),
|
||||
"chat/completions",
|
||||
)
|
||||
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Budget exceeded: {error_message}"
|
||||
detail=f"Budget exceeded: {error_message}",
|
||||
)
|
||||
warnings = budget_warnings
|
||||
reserved_budget_ids = budget_ids
|
||||
|
||||
# Convert messages to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
|
||||
llm_messages = [
|
||||
LLMChatMessage(role=msg.role, content=msg.content)
|
||||
for msg in chat_request.messages
|
||||
]
|
||||
|
||||
# Create LLM service request
|
||||
llm_request = ChatRequest(
|
||||
@@ -299,7 +338,9 @@ async def create_chat_completion(
|
||||
stop=chat_request.stop,
|
||||
stream=chat_request.stream or False,
|
||||
user_id=str(context.get("user_id", "anonymous")),
|
||||
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
|
||||
api_key_id=context.get("api_key_id", 0)
|
||||
if auth_type == "api_key"
|
||||
else 0,
|
||||
)
|
||||
|
||||
# Make request to LLM service
|
||||
@@ -316,21 +357,25 @@ async def create_chat_completion(
|
||||
"index": choice.index,
|
||||
"message": {
|
||||
"role": choice.message.role,
|
||||
"content": choice.message.content
|
||||
"content": choice.message.content,
|
||||
},
|
||||
"finish_reason": choice.finish_reason
|
||||
"finish_reason": choice.finish_reason,
|
||||
}
|
||||
for choice in llm_response.choices
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
|
||||
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
|
||||
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
|
||||
} if llm_response.usage else {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"prompt_tokens": llm_response.usage.prompt_tokens
|
||||
if llm_response.usage
|
||||
else 0,
|
||||
"completion_tokens": llm_response.usage.completion_tokens
|
||||
if llm_response.usage
|
||||
else 0,
|
||||
"total_tokens": llm_response.usage.total_tokens
|
||||
if llm_response.usage
|
||||
else 0,
|
||||
}
|
||||
if llm_response.usage
|
||||
else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
# Calculate actual cost and update usage
|
||||
@@ -347,13 +392,20 @@ async def create_chat_completion(
|
||||
# Finalize actual usage in budgets (only for API key users)
|
||||
if auth_type == "api_key" and api_key:
|
||||
atomic_finalize_usage(
|
||||
sync_db, reserved_budget_ids, api_key, chat_request.model,
|
||||
input_tokens, output_tokens, "chat/completions"
|
||||
sync_db,
|
||||
reserved_budget_ids,
|
||||
api_key,
|
||||
chat_request.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
"chat/completions",
|
||||
)
|
||||
|
||||
# Update API key usage statistics
|
||||
auth_service = APIKeyAuthService(db)
|
||||
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
|
||||
await auth_service.update_usage_stats(
|
||||
context, total_tokens, actual_cost_cents
|
||||
)
|
||||
|
||||
# Set analytics data for middleware
|
||||
set_analytics_data(
|
||||
@@ -363,7 +415,7 @@ async def create_chat_completion(
|
||||
total_tokens=total_tokens,
|
||||
cost_cents=actual_cost_cents,
|
||||
budget_ids=reserved_budget_ids,
|
||||
budget_warnings=warnings
|
||||
budget_warnings=warnings,
|
||||
)
|
||||
|
||||
# Add budget warnings to response if any
|
||||
@@ -381,37 +433,37 @@ async def create_chat_completion(
|
||||
logger.warning(f"Security error in chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Security validation failed: {e.message}"
|
||||
detail=f"Security validation failed: {e.message}",
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Validation error in chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Request validation failed: {e.message}"
|
||||
detail=f"Request validation failed: {e.message}",
|
||||
)
|
||||
except ProviderError as e:
|
||||
logger.error(f"Provider error in chat completion: {e}")
|
||||
if "rate limit" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
detail="Rate limit exceeded",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LLM service temporarily unavailable"
|
||||
detail="LLM service temporarily unavailable",
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error in chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="LLM service error"
|
||||
detail="LLM service error",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create chat completion"
|
||||
detail="Failed to create chat completion",
|
||||
)
|
||||
|
||||
|
||||
@@ -419,7 +471,7 @@ async def create_chat_completion(
|
||||
async def create_embedding(
|
||||
request: EmbeddingRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create embedding with budget enforcement"""
|
||||
try:
|
||||
@@ -429,20 +481,20 @@ async def create_embedding(
|
||||
if not await auth_service.check_scope_permission(context, "embeddings.create"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions for embeddings"
|
||||
detail="Insufficient permissions for embeddings",
|
||||
)
|
||||
|
||||
if not await auth_service.check_model_permission(context, request.model):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Model '{request.model}' not allowed"
|
||||
detail=f"Model '{request.model}' not allowed",
|
||||
)
|
||||
|
||||
api_key = context.get("api_key")
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="API key information not available"
|
||||
detail="API key information not available",
|
||||
)
|
||||
|
||||
# Estimate token usage for budget checking
|
||||
@@ -460,7 +512,7 @@ async def create_embedding(
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Budget exceeded: {error_message}"
|
||||
detail=f"Budget exceeded: {error_message}",
|
||||
)
|
||||
|
||||
# Create LLM service request
|
||||
@@ -469,7 +521,7 @@ async def create_embedding(
|
||||
input=request.input,
|
||||
encoding_format=request.encoding_format,
|
||||
user_id=str(context["user_id"]),
|
||||
api_key_id=context["api_key_id"]
|
||||
api_key_id=context["api_key_id"],
|
||||
)
|
||||
|
||||
# Make request to LLM service
|
||||
@@ -482,18 +534,24 @@ async def create_embedding(
|
||||
{
|
||||
"object": emb.object,
|
||||
"index": emb.index,
|
||||
"embedding": emb.embedding
|
||||
"embedding": emb.embedding,
|
||||
}
|
||||
for emb in llm_response.data
|
||||
],
|
||||
"model": llm_response.model,
|
||||
"usage": {
|
||||
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
|
||||
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
|
||||
} if llm_response.usage else {
|
||||
"prompt_tokens": int(estimated_tokens),
|
||||
"total_tokens": int(estimated_tokens)
|
||||
"prompt_tokens": llm_response.usage.prompt_tokens
|
||||
if llm_response.usage
|
||||
else 0,
|
||||
"total_tokens": llm_response.usage.total_tokens
|
||||
if llm_response.usage
|
||||
else 0,
|
||||
}
|
||||
if llm_response.usage
|
||||
else {
|
||||
"prompt_tokens": int(estimated_tokens),
|
||||
"total_tokens": int(estimated_tokens),
|
||||
},
|
||||
}
|
||||
|
||||
# Calculate actual cost and update usage
|
||||
@@ -511,7 +569,9 @@ async def create_embedding(
|
||||
)
|
||||
|
||||
# Update API key usage statistics
|
||||
await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
|
||||
await auth_service.update_usage_stats(
|
||||
context, total_tokens, actual_cost_cents
|
||||
)
|
||||
|
||||
# Add budget warnings to response if any
|
||||
if warnings:
|
||||
@@ -528,44 +588,42 @@ async def create_embedding(
|
||||
logger.warning(f"Security error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Security validation failed: {e.message}"
|
||||
detail=f"Security validation failed: {e.message}",
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Validation error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Request validation failed: {e.message}"
|
||||
detail=f"Request validation failed: {e.message}",
|
||||
)
|
||||
except ProviderError as e:
|
||||
logger.error(f"Provider error in embedding: {e}")
|
||||
if "rate limit" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
detail="Rate limit exceeded",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LLM service temporarily unavailable"
|
||||
detail="LLM service temporarily unavailable",
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="LLM service error"
|
||||
detail="LLM service error",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create embedding"
|
||||
detail="Failed to create embedding",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def llm_health_check(
|
||||
context: Dict[str, Any] = Depends(require_api_key)
|
||||
):
|
||||
async def llm_health_check(context: Dict[str, Any] = Depends(require_api_key)):
|
||||
"""Health check for LLM service"""
|
||||
try:
|
||||
health_summary = llm_service.get_health_summary()
|
||||
@@ -585,34 +643,31 @@ async def llm_health_check(
|
||||
"status": overall_status,
|
||||
"service": "LLM Service",
|
||||
"service_status": health_summary,
|
||||
"provider_status": {name: {
|
||||
"provider_status": {
|
||||
name: {
|
||||
"status": status.status,
|
||||
"latency_ms": status.latency_ms,
|
||||
"error_message": status.error_message
|
||||
} for name, status in provider_status.items()},
|
||||
"error_message": status.error_message,
|
||||
}
|
||||
for name, status in provider_status.items()
|
||||
},
|
||||
"user_id": context["user_id"],
|
||||
"api_key_name": context["api_key_name"]
|
||||
"api_key_name": context["api_key_name"],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"LLM health check error: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"service": "LLM Service",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"status": "unhealthy", "service": "LLM Service", "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_usage_stats(
|
||||
context: Dict[str, Any] = Depends(require_api_key)
|
||||
):
|
||||
async def get_usage_stats(context: Dict[str, Any] = Depends(require_api_key)):
|
||||
"""Get usage statistics for the API key"""
|
||||
try:
|
||||
api_key = context.get("api_key")
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="API key information not available"
|
||||
detail="API key information not available",
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -622,15 +677,17 @@ async def get_usage_stats(
|
||||
"total_tokens": api_key.total_tokens,
|
||||
"total_cost_cents": api_key.total_cost,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
|
||||
"last_used_at": api_key.last_used_at.isoformat()
|
||||
if api_key.last_used_at
|
||||
else None,
|
||||
"rate_limits": {
|
||||
"per_minute": api_key.rate_limit_per_minute,
|
||||
"per_hour": api_key.rate_limit_per_hour,
|
||||
"per_day": api_key.rate_limit_per_day
|
||||
"per_day": api_key.rate_limit_per_day,
|
||||
},
|
||||
"permissions": api_key.permissions,
|
||||
"scopes": api_key.scopes,
|
||||
"allowed_models": api_key.allowed_models
|
||||
"allowed_models": api_key.allowed_models,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -639,7 +696,7 @@ async def get_usage_stats(
|
||||
logger.error(f"Error getting usage stats: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get usage statistics"
|
||||
detail="Failed to get usage statistics",
|
||||
)
|
||||
|
||||
|
||||
@@ -647,7 +704,7 @@ async def get_usage_stats(
|
||||
async def get_budget_status(
|
||||
request: Request,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get current budget status and usage analytics"""
|
||||
try:
|
||||
@@ -659,14 +716,14 @@ async def get_budget_status(
|
||||
if not await auth_service.check_scope_permission(context, "budget.read"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to read budget information"
|
||||
detail="Insufficient permissions to read budget information",
|
||||
)
|
||||
|
||||
api_key = context.get("api_key")
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="API key information not available"
|
||||
detail="API key information not available",
|
||||
)
|
||||
|
||||
# Convert AsyncSession to Session for budget enforcement
|
||||
@@ -676,10 +733,7 @@ async def get_budget_status(
|
||||
budget_service = BudgetEnforcementService(sync_db)
|
||||
budget_status = budget_service.get_budget_status(api_key)
|
||||
|
||||
return {
|
||||
"object": "budget_status",
|
||||
"data": budget_status
|
||||
}
|
||||
return {"object": "budget_status", "data": budget_status}
|
||||
finally:
|
||||
sync_db.close()
|
||||
|
||||
@@ -689,7 +743,7 @@ async def get_budget_status(
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User information not available"
|
||||
detail="User information not available",
|
||||
)
|
||||
|
||||
# Return basic budget info for JWT users
|
||||
@@ -702,14 +756,14 @@ async def get_budget_status(
|
||||
"projections": {
|
||||
"daily_burn_rate": 0.0,
|
||||
"projected_monthly": 0.0,
|
||||
"days_remaining": 30
|
||||
}
|
||||
}
|
||||
"days_remaining": 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication type"
|
||||
detail="Invalid authentication type",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -718,7 +772,7 @@ async def get_budget_status(
|
||||
logger.error(f"Error getting budget status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get budget status"
|
||||
detail="Failed to get budget status",
|
||||
)
|
||||
|
||||
|
||||
@@ -726,7 +780,7 @@ async def get_budget_status(
|
||||
@router.get("/metrics")
|
||||
async def get_llm_metrics(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get LLM service metrics (admin only)"""
|
||||
try:
|
||||
@@ -735,7 +789,7 @@ async def get_llm_metrics(
|
||||
if not await auth_service.check_scope_permission(context, "admin.metrics"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin permissions required to view metrics"
|
||||
detail="Admin permissions required to view metrics",
|
||||
)
|
||||
|
||||
metrics = llm_service.get_metrics()
|
||||
@@ -748,8 +802,8 @@ async def get_llm_metrics(
|
||||
"average_latency_ms": metrics.average_latency_ms,
|
||||
"average_risk_score": metrics.average_risk_score,
|
||||
"provider_metrics": metrics.provider_metrics,
|
||||
"last_updated": metrics.last_updated.isoformat()
|
||||
}
|
||||
"last_updated": metrics.last_updated.isoformat(),
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -758,14 +812,14 @@ async def get_llm_metrics(
|
||||
logger.error(f"Error getting LLM metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get LLM metrics"
|
||||
detail="Failed to get LLM metrics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/providers/status")
|
||||
async def get_provider_status(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get status of all LLM providers"""
|
||||
try:
|
||||
@@ -773,7 +827,7 @@ async def get_provider_status(
|
||||
if not await auth_service.check_scope_permission(context, "admin.status"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin permissions required to view provider status"
|
||||
detail="Admin permissions required to view provider status",
|
||||
)
|
||||
|
||||
provider_status = await llm_service.get_provider_status()
|
||||
@@ -787,10 +841,10 @@ async def get_provider_status(
|
||||
"success_rate": status.success_rate,
|
||||
"last_check": status.last_check.isoformat(),
|
||||
"error_message": status.error_message,
|
||||
"models_available": status.models_available
|
||||
"models_available": status.models_available,
|
||||
}
|
||||
for name, status in provider_status.items()
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -799,5 +853,5 @@ async def get_provider_status(
|
||||
logger.error(f"Error getting provider status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get provider status"
|
||||
detail="Failed to get provider status",
|
||||
)
|
||||
@@ -13,7 +13,12 @@ from app.db.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
|
||||
from app.services.llm.exceptions import (
|
||||
LLMError,
|
||||
ProviderError,
|
||||
SecurityError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.api.v1.llm import get_cached_models # Reuse the caching logic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,14 +40,12 @@ async def list_models(
|
||||
logger.error(f"Failed to list models: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve models"
|
||||
detail="Failed to retrieve models",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/providers/status")
|
||||
async def get_provider_status(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def get_provider_status(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get status of all LLM providers for authenticated users
|
||||
"""
|
||||
@@ -58,23 +61,21 @@ async def get_provider_status(
|
||||
"success_rate": status.success_rate,
|
||||
"last_check": status.last_check.isoformat(),
|
||||
"error_message": status.error_message,
|
||||
"models_available": status.models_available
|
||||
"models_available": status.models_available,
|
||||
}
|
||||
for name, status in provider_status.items()
|
||||
}
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve provider status"
|
||||
detail="Failed to retrieve provider status",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def health_check(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get LLM service health status for authenticated users
|
||||
"""
|
||||
@@ -83,39 +84,35 @@ async def health_check(
|
||||
return {
|
||||
"status": health["status"],
|
||||
"providers": health.get("providers", {}),
|
||||
"timestamp": health.get("timestamp")
|
||||
"timestamp": health.get("timestamp"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Health check failed"
|
||||
detail="Health check failed",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def get_metrics(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get LLM service metrics for authenticated users
|
||||
"""
|
||||
try:
|
||||
metrics = await llm_service.get_metrics()
|
||||
return {
|
||||
"object": "metrics",
|
||||
"data": metrics
|
||||
}
|
||||
return {"object": "metrics", "data": metrics}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve metrics"
|
||||
detail="Failed to retrieve metrics",
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""Request model for chat completions"""
|
||||
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
|
||||
@@ -128,7 +125,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
async def create_chat_completion(
|
||||
request: ChatCompletionRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create chat completion for authenticated frontend users.
|
||||
@@ -151,11 +148,13 @@ async def create_chat_completion(
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
user_id=user_id,
|
||||
api_key_id=0 # Special value for JWT-authenticated requests
|
||||
api_key_id=0, # Special value for JWT-authenticated requests
|
||||
)
|
||||
|
||||
# Log the request for debugging
|
||||
logger.info(f"Internal chat completion request from user {current_user.get('id')}: model={request.model}")
|
||||
logger.info(
|
||||
f"Internal chat completion request from user {current_user.get('id')}: model={request.model}"
|
||||
)
|
||||
|
||||
# Process the request through the LLM service
|
||||
response = await llm_service.create_chat_completion(chat_request)
|
||||
@@ -171,17 +170,21 @@ async def create_chat_completion(
|
||||
"index": choice.index,
|
||||
"message": {
|
||||
"role": choice.message.role,
|
||||
"content": choice.message.content
|
||||
"content": choice.message.content,
|
||||
},
|
||||
"finish_reason": choice.finish_reason
|
||||
"finish_reason": choice.finish_reason,
|
||||
}
|
||||
for choice in response.choices
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0
|
||||
} if response.usage else None
|
||||
"completion_tokens": response.usage.completion_tokens
|
||||
if response.usage
|
||||
else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
}
|
||||
if response.usage
|
||||
else None,
|
||||
}
|
||||
|
||||
return formatted_response
|
||||
@@ -189,18 +192,17 @@ async def create_chat_completion(
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation error in chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid request: {str(e)}"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid request: {str(e)}"
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"LLM service error: {str(e)}"
|
||||
detail=f"LLM service error: {str(e)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in chat completion: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to process chat completion"
|
||||
detail="Failed to process chat completion",
|
||||
)
|
||||
@@ -40,7 +40,7 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
|
||||
"description": module_info["description"],
|
||||
"initialized": is_loaded,
|
||||
"enabled": is_enabled,
|
||||
"status": status # Add status field for frontend compatibility
|
||||
"status": status, # Add status field for frontend compatibility
|
||||
}
|
||||
|
||||
# Get module statistics if available and module is loaded
|
||||
@@ -49,11 +49,14 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
|
||||
if hasattr(module_instance, "get_stats"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module_instance.get_stats):
|
||||
stats = await module_instance.get_stats()
|
||||
else:
|
||||
stats = module_instance.get_stats()
|
||||
api_module["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
api_module["stats"] = (
|
||||
stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
)
|
||||
except:
|
||||
api_module["stats"] = {}
|
||||
|
||||
@@ -66,7 +69,7 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
|
||||
"total": len(modules),
|
||||
"modules": modules,
|
||||
"module_count": loaded_count,
|
||||
"initialized": module_manager.initialized
|
||||
"initialized": module_manager.initialized,
|
||||
}
|
||||
|
||||
|
||||
@@ -106,23 +109,30 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
|
||||
if hasattr(module_instance, "get_stats"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module_instance.get_stats):
|
||||
stats_result = await module_instance.get_stats()
|
||||
else:
|
||||
stats_result = module_instance.get_stats()
|
||||
stats = stats_result.__dict__ if hasattr(stats_result, "__dict__") else stats_result
|
||||
stats = (
|
||||
stats_result.__dict__
|
||||
if hasattr(stats_result, "__dict__")
|
||||
else stats_result
|
||||
)
|
||||
except:
|
||||
stats = {}
|
||||
|
||||
modules_with_status.append({
|
||||
modules_with_status.append(
|
||||
{
|
||||
"name": name,
|
||||
"version": module_info["version"],
|
||||
"description": module_info["description"],
|
||||
"status": status,
|
||||
"enabled": is_enabled,
|
||||
"loaded": is_loaded,
|
||||
"stats": stats
|
||||
})
|
||||
"stats": stats,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"modules": modules_with_status,
|
||||
@@ -130,12 +140,14 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
|
||||
"running": running_count,
|
||||
"standby": standby_count,
|
||||
"failed": failed_count,
|
||||
"system_initialized": module_manager.initialized
|
||||
"system_initialized": module_manager.initialized,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{module_name}")
|
||||
async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_module_info(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed information about a specific module"""
|
||||
log_api_request("get_module_info", {"module_name": module_name})
|
||||
|
||||
@@ -148,14 +160,17 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
|
||||
"version": getattr(module, "version", "1.0.0"),
|
||||
"description": getattr(module, "description", ""),
|
||||
"initialized": getattr(module, "initialized", False),
|
||||
"enabled": module_manager.module_configs.get(module_name, ModuleConfig(module_name)).enabled,
|
||||
"capabilities": []
|
||||
"enabled": module_manager.module_configs.get(
|
||||
module_name, ModuleConfig(module_name)
|
||||
).enabled,
|
||||
"capabilities": [],
|
||||
}
|
||||
|
||||
# Get module capabilities
|
||||
if hasattr(module, "get_module_info"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module.get_module_info):
|
||||
info = await module.get_module_info()
|
||||
else:
|
||||
@@ -168,11 +183,14 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
|
||||
if hasattr(module, "get_stats"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module.get_stats):
|
||||
stats = await module.get_stats()
|
||||
else:
|
||||
stats = module.get_stats()
|
||||
module_info["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
module_info["stats"] = (
|
||||
stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
)
|
||||
except:
|
||||
module_info["stats"] = {}
|
||||
|
||||
@@ -188,7 +206,9 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
|
||||
|
||||
|
||||
@router.post("/{module_name}/enable")
|
||||
async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def enable_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Enable a module"""
|
||||
log_api_request("enable_module", {"module_name": module_name})
|
||||
|
||||
@@ -204,16 +224,18 @@ async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends
|
||||
try:
|
||||
await module_manager._load_module(module_name, config)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to enable module '{module_name}': {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to enable module '{module_name}': {str(e)}",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' enabled successfully",
|
||||
"enabled": True
|
||||
}
|
||||
return {"message": f"Module '{module_name}' enabled successfully", "enabled": True}
|
||||
|
||||
|
||||
@router.post("/{module_name}/disable")
|
||||
async def disable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def disable_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Disable a module"""
|
||||
log_api_request("disable_module", {"module_name": module_name})
|
||||
|
||||
@@ -229,11 +251,14 @@ async def disable_module(module_name: str, current_user: Dict[str, Any] = Depend
|
||||
try:
|
||||
await module_manager.unload_module(module_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to disable module '{module_name}': {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to disable module '{module_name}': {str(e)}",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' disabled successfully",
|
||||
"enabled": False
|
||||
"enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -260,19 +285,21 @@ async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_
|
||||
"message": f"Reloaded {len(results) - len(failed_modules)}/{len(results)} modules successfully",
|
||||
"success": False,
|
||||
"results": results,
|
||||
"failed_modules": failed_modules
|
||||
"failed_modules": failed_modules,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"message": f"All {len(results)} modules reloaded successfully",
|
||||
"success": True,
|
||||
"results": results,
|
||||
"failed_modules": []
|
||||
"failed_modules": [],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{module_name}/reload")
|
||||
async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def reload_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Reload a specific module"""
|
||||
log_api_request("reload_module", {"module_name": module_name})
|
||||
|
||||
@@ -282,16 +309,20 @@ async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends
|
||||
success = await module_manager.reload_module(module_name)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reload module '{module_name}'")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to reload module '{module_name}'"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' reloaded successfully",
|
||||
"reloaded": True
|
||||
"reloaded": True,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{module_name}/restart")
|
||||
async def restart_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def restart_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Restart a specific module (alias for reload)"""
|
||||
log_api_request("restart_module", {"module_name": module_name})
|
||||
|
||||
@@ -301,16 +332,20 @@ async def restart_module(module_name: str, current_user: Dict[str, Any] = Depend
|
||||
success = await module_manager.reload_module(module_name)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart module '{module_name}'")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to restart module '{module_name}'"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' restarted successfully",
|
||||
"restarted": True
|
||||
"restarted": True,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{module_name}/start")
|
||||
async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def start_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Start a specific module (enable and load)"""
|
||||
log_api_request("start_module", {"module_name": module_name})
|
||||
|
||||
@@ -325,14 +360,13 @@ async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(
|
||||
if module_name not in module_manager.modules:
|
||||
await module_manager._load_module(module_name, config)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' started successfully",
|
||||
"started": True
|
||||
}
|
||||
return {"message": f"Module '{module_name}' started successfully", "started": True}
|
||||
|
||||
|
||||
@router.post("/{module_name}/stop")
|
||||
async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def stop_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Stop a specific module (disable and unload)"""
|
||||
log_api_request("stop_module", {"module_name": module_name})
|
||||
|
||||
@@ -347,14 +381,13 @@ async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(g
|
||||
if module_name in module_manager.modules:
|
||||
await module_manager.unload_module(module_name)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' stopped successfully",
|
||||
"stopped": True
|
||||
}
|
||||
return {"message": f"Module '{module_name}' stopped successfully", "stopped": True}
|
||||
|
||||
|
||||
@router.get("/{module_name}/stats")
|
||||
async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_module_stats(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get module statistics"""
|
||||
log_api_request("get_module_stats", {"module_name": module_name})
|
||||
|
||||
@@ -364,26 +397,39 @@ async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depe
|
||||
module = module_manager.modules[module_name]
|
||||
|
||||
if not hasattr(module, "get_stats"):
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' does not provide statistics")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Module '{module_name}' does not provide statistics",
|
||||
)
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module.get_stats):
|
||||
stats = await module.get_stats()
|
||||
else:
|
||||
stats = module.get_stats()
|
||||
return {
|
||||
"module": module_name,
|
||||
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get statistics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{module_name}/execute")
|
||||
async def execute_module_action(module_name: str, request_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def execute_module_action(
|
||||
module_name: str,
|
||||
request_data: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Execute a module action through the interceptor pattern"""
|
||||
log_api_request("execute_module_action", {"module_name": module_name, "action": request_data.get("action")})
|
||||
log_api_request(
|
||||
"execute_module_action",
|
||||
{"module_name": module_name, "action": request_data.get("action")},
|
||||
)
|
||||
|
||||
if module_name not in module_manager.modules:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
@@ -391,14 +437,16 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
|
||||
module = module_manager.modules[module_name]
|
||||
|
||||
# Check if module supports the new interceptor pattern
|
||||
if hasattr(module, 'execute_with_interceptors'):
|
||||
if hasattr(module, "execute_with_interceptors"):
|
||||
try:
|
||||
# Prepare context (would normally come from authentication middleware)
|
||||
context = {
|
||||
"user_id": "test_user", # Would come from authentication
|
||||
"api_key_id": "test_api_key", # Would come from API key auth
|
||||
"ip_address": "127.0.0.1", # Would come from request
|
||||
"user_permissions": [f"modules:{module_name}:*"] # Would come from user/API key permissions
|
||||
"user_permissions": [
|
||||
f"modules:{module_name}:*"
|
||||
], # Would come from user/API key permissions
|
||||
}
|
||||
|
||||
# Execute through interceptor chain
|
||||
@@ -408,11 +456,13 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
|
||||
"module": module_name,
|
||||
"success": True,
|
||||
"response": response,
|
||||
"interceptor_pattern": True
|
||||
"interceptor_pattern": True,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Module execution failed: {str(e)}"
|
||||
)
|
||||
|
||||
# Fallback for legacy modules
|
||||
else:
|
||||
@@ -423,6 +473,7 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
|
||||
method = getattr(module, action)
|
||||
if callable(method):
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
response = await method(request_data)
|
||||
else:
|
||||
@@ -432,18 +483,27 @@ async def execute_module_action(module_name: str, request_data: Dict[str, Any],
|
||||
"module": module_name,
|
||||
"success": True,
|
||||
"response": response,
|
||||
"interceptor_pattern": False
|
||||
"interceptor_pattern": False,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"'{action}' is not callable")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"'{action}' is not callable"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Module execution failed: {str(e)}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Action '{action}' not supported by module '{module_name}'")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Action '{action}' not supported by module '{module_name}'",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{module_name}/config")
|
||||
async def get_module_config(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_module_config(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get module configuration schema and current values"""
|
||||
log_api_request("get_module_config", {"module_name": module_name})
|
||||
|
||||
@@ -471,10 +531,15 @@ async def get_module_config(module_name: str, current_user: Dict[str, Any] = Dep
|
||||
dynamic_schema = copy.deepcopy(schema)
|
||||
|
||||
# Add enum options for the model field
|
||||
if "properties" in dynamic_schema and "model" in dynamic_schema["properties"]:
|
||||
if (
|
||||
"properties" in dynamic_schema
|
||||
and "model" in dynamic_schema["properties"]
|
||||
):
|
||||
dynamic_schema["properties"]["model"]["enum"] = model_ids
|
||||
# Set a sensible default if the current default isn't in the list
|
||||
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
|
||||
current_default = dynamic_schema["properties"]["model"].get(
|
||||
"default", "gpt-3.5-turbo"
|
||||
)
|
||||
if current_default not in model_ids and model_ids:
|
||||
dynamic_schema["properties"]["model"]["default"] = model_ids[0]
|
||||
|
||||
@@ -489,12 +554,16 @@ async def get_module_config(module_name: str, current_user: Dict[str, Any] = Dep
|
||||
"description": manifest.description,
|
||||
"schema": schema,
|
||||
"current_config": current_config,
|
||||
"has_schema": schema is not None
|
||||
"has_schema": schema is not None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{module_name}/config")
|
||||
async def update_module_config(module_name: str, config: dict, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def update_module_config(
|
||||
module_name: str,
|
||||
config: dict,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update module configuration"""
|
||||
log_api_request("update_module_config", {"module_name": module_name})
|
||||
|
||||
@@ -518,7 +587,7 @@ async def update_module_config(module_name: str, config: dict, current_user: Dic
|
||||
|
||||
return {
|
||||
"message": f"Configuration updated for module '{module_name}'",
|
||||
"config": config
|
||||
"config": config,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,9 +12,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db
|
||||
from app.services.api_key_auth import require_api_key
|
||||
from app.api.v1.llm import (
|
||||
get_cached_models, ModelsResponse, ModelInfo,
|
||||
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
|
||||
create_embedding as llm_create_embedding
|
||||
get_cached_models,
|
||||
ModelsResponse,
|
||||
ModelInfo,
|
||||
ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
create_chat_completion as llm_chat_completion,
|
||||
create_embedding as llm_create_embedding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,8 +26,12 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def openai_error_response(message: str, error_type: str = "invalid_request_error",
|
||||
status_code: int = 400, code: str = None):
|
||||
def openai_error_response(
|
||||
message: str,
|
||||
error_type: str = "invalid_request_error",
|
||||
status_code: int = 400,
|
||||
code: str = None,
|
||||
):
|
||||
"""Create OpenAI-compatible error response"""
|
||||
error_data = {
|
||||
"error": {
|
||||
@@ -34,16 +42,13 @@ def openai_error_response(message: str, error_type: str = "invalid_request_error
|
||||
if code:
|
||||
error_data["error"]["code"] = code
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=error_data
|
||||
)
|
||||
return JSONResponse(status_code=status_code, content=error_data)
|
||||
|
||||
|
||||
@router.get("/models", response_model=ModelsResponse)
|
||||
async def list_models(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Lists the currently available models, and provides basic information about each one
|
||||
@@ -55,30 +60,23 @@ async def list_models(
|
||||
try:
|
||||
# Delegate to the existing LLM models endpoint
|
||||
from app.api.v1.llm import list_models as llm_list_models
|
||||
|
||||
return await llm_list_models(context, db)
|
||||
except HTTPException as e:
|
||||
# Convert FastAPI HTTPException to OpenAI format
|
||||
if e.status_code == 401:
|
||||
return openai_error_response(
|
||||
"Invalid authentication credentials",
|
||||
"authentication_error",
|
||||
401
|
||||
"Invalid authentication credentials", "authentication_error", 401
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
return openai_error_response(
|
||||
"Insufficient permissions",
|
||||
"permission_error",
|
||||
403
|
||||
"Insufficient permissions", "permission_error", 403
|
||||
)
|
||||
else:
|
||||
return openai_error_response(str(e.detail), "api_error", e.status_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in OpenAI models endpoint: {e}")
|
||||
return openai_error_response(
|
||||
"Internal server error",
|
||||
"api_error",
|
||||
500
|
||||
)
|
||||
return openai_error_response("Internal server error", "api_error", 500)
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
@@ -86,7 +84,7 @@ async def create_chat_completion(
|
||||
request_body: Request,
|
||||
chat_request: ChatCompletionRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create chat completion - OpenAI compatible endpoint
|
||||
@@ -102,7 +100,7 @@ async def create_chat_completion(
|
||||
async def create_embedding(
|
||||
request: EmbeddingRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create embedding - OpenAI compatible endpoint
|
||||
@@ -118,7 +116,7 @@ async def create_embedding(
|
||||
async def retrieve_model(
|
||||
model_id: str,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve model information - OpenAI compatible endpoint
|
||||
@@ -133,7 +131,9 @@ async def retrieve_model(
|
||||
# Filter models based on API key permissions
|
||||
api_key = context.get("api_key")
|
||||
if api_key and api_key.allowed_models:
|
||||
models = [model for model in models if model.get("id") in api_key.allowed_models]
|
||||
models = [
|
||||
model for model in models if model.get("id") in api_key.allowed_models
|
||||
]
|
||||
|
||||
# Find the specific model
|
||||
model = next((m for m in models if m.get("id") == model_id), None)
|
||||
@@ -141,14 +141,14 @@ async def retrieve_model(
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model '{model_id}' not found"
|
||||
detail=f"Model '{model_id}' not found",
|
||||
)
|
||||
|
||||
return ModelInfo(
|
||||
id=model.get("id", model_id),
|
||||
object="model",
|
||||
created=model.get("created", 0),
|
||||
owned_by=model.get("owned_by", "system")
|
||||
owned_by=model.get("owned_by", "system"),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -157,5 +157,5 @@ async def retrieve_model(
|
||||
logger.error(f"Error retrieving model {model_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model information"
|
||||
detail="Failed to retrieve model information",
|
||||
)
|
||||
@@ -7,7 +7,11 @@ from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.permission_manager import permission_registry, Permission, PermissionScope
|
||||
from app.services.permission_manager import (
|
||||
permission_registry,
|
||||
Permission,
|
||||
PermissionScope,
|
||||
)
|
||||
from app.core.logging import get_logger
|
||||
from app.core.security import get_current_user
|
||||
|
||||
@@ -86,7 +90,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
|
||||
resource=perm.resource,
|
||||
action=perm.action,
|
||||
description=perm.description,
|
||||
conditions=getattr(perm, 'conditions', None)
|
||||
conditions=getattr(perm, "conditions", None),
|
||||
)
|
||||
for perm in perms
|
||||
]
|
||||
@@ -97,7 +101,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
|
||||
logger.error(f"Error getting permissions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get permissions: {str(e)}"
|
||||
detail=f"Failed to get permissions: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -112,7 +116,7 @@ async def get_permission_hierarchy():
|
||||
logger.error(f"Error getting permission hierarchy: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get permission hierarchy: {str(e)}"
|
||||
detail=f"Failed to get permission hierarchy: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -120,44 +124,43 @@ async def get_permission_hierarchy():
|
||||
async def validate_permissions(request: PermissionValidationRequest):
|
||||
"""Validate a list of permissions"""
|
||||
try:
|
||||
validation_result = permission_registry.validate_permissions(request.permissions)
|
||||
validation_result = permission_registry.validate_permissions(
|
||||
request.permissions
|
||||
)
|
||||
return PermissionValidationResponse(**validation_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating permissions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to validate permissions: {str(e)}"
|
||||
detail=f"Failed to validate permissions: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/permissions/check", response_model=PermissionCheckResponse)
|
||||
async def check_permission(
|
||||
request: PermissionCheckRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Check if user has a specific permission"""
|
||||
try:
|
||||
has_permission = permission_registry.check_permission(
|
||||
request.user_permissions,
|
||||
request.required_permission,
|
||||
request.context
|
||||
request.user_permissions, request.required_permission, request.context
|
||||
)
|
||||
|
||||
matching_permissions = list(permission_registry.tree.get_matching_permissions(
|
||||
request.user_permissions
|
||||
))
|
||||
matching_permissions = list(
|
||||
permission_registry.tree.get_matching_permissions(request.user_permissions)
|
||||
)
|
||||
|
||||
return PermissionCheckResponse(
|
||||
has_permission=has_permission,
|
||||
matching_permissions=matching_permissions
|
||||
has_permission=has_permission, matching_permissions=matching_permissions
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking permission: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to check permission: {str(e)}"
|
||||
detail=f"Failed to check permission: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -172,7 +175,7 @@ async def get_module_permissions(module_id: str):
|
||||
resource=perm.resource,
|
||||
action=perm.action,
|
||||
description=perm.description,
|
||||
conditions=getattr(perm, 'conditions', None)
|
||||
conditions=getattr(perm, "conditions", None),
|
||||
)
|
||||
for perm in permissions
|
||||
]
|
||||
@@ -181,7 +184,7 @@ async def get_module_permissions(module_id: str):
|
||||
logger.error(f"Error getting module permissions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get module permissions: {str(e)}"
|
||||
detail=f"Failed to get module permissions: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -191,18 +194,19 @@ async def create_role(request: RoleRequest):
|
||||
"""Create a custom role with specific permissions"""
|
||||
try:
|
||||
# Validate permissions first
|
||||
validation_result = permission_registry.validate_permissions(request.permissions)
|
||||
validation_result = permission_registry.validate_permissions(
|
||||
request.permissions
|
||||
)
|
||||
if not validation_result["is_valid"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid permissions: {validation_result['invalid']}"
|
||||
detail=f"Invalid permissions: {validation_result['invalid']}",
|
||||
)
|
||||
|
||||
permission_registry.create_role(request.role_name, request.permissions)
|
||||
|
||||
return RoleResponse(
|
||||
role_name=request.role_name,
|
||||
permissions=request.permissions
|
||||
role_name=request.role_name, permissions=request.permissions
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -211,7 +215,7 @@ async def create_role(request: RoleRequest):
|
||||
logger.error(f"Error creating role: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create role: {str(e)}"
|
||||
detail=f"Failed to create role: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -220,14 +224,17 @@ async def get_roles():
|
||||
"""Get all available roles and their permissions"""
|
||||
try:
|
||||
# Combine default roles and custom roles
|
||||
all_roles = {**permission_registry.default_roles, **permission_registry.role_permissions}
|
||||
all_roles = {
|
||||
**permission_registry.default_roles,
|
||||
**permission_registry.role_permissions,
|
||||
}
|
||||
return all_roles
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting roles: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get roles: {str(e)}"
|
||||
detail=f"Failed to get roles: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -236,20 +243,17 @@ async def get_role(role_name: str):
|
||||
"""Get a specific role and its permissions"""
|
||||
try:
|
||||
# Check default roles first, then custom roles
|
||||
permissions = (permission_registry.role_permissions.get(role_name) or
|
||||
permission_registry.default_roles.get(role_name))
|
||||
permissions = permission_registry.role_permissions.get(
|
||||
role_name
|
||||
) or permission_registry.default_roles.get(role_name)
|
||||
|
||||
if permissions is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Role '{role_name}' not found"
|
||||
detail=f"Role '{role_name}' not found",
|
||||
)
|
||||
|
||||
return RoleResponse(
|
||||
role_name=role_name,
|
||||
permissions=permissions,
|
||||
created=True
|
||||
)
|
||||
return RoleResponse(role_name=role_name, permissions=permissions, created=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -257,7 +261,7 @@ async def get_role(role_name: str):
|
||||
logger.error(f"Error getting role: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get role: {str(e)}"
|
||||
detail=f"Failed to get role: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -267,21 +271,20 @@ async def calculate_user_permissions(request: UserPermissionsRequest):
|
||||
"""Calculate effective permissions for a user based on roles and custom permissions"""
|
||||
try:
|
||||
effective_permissions = permission_registry.get_user_permissions(
|
||||
request.roles,
|
||||
request.custom_permissions
|
||||
request.roles, request.custom_permissions
|
||||
)
|
||||
|
||||
return UserPermissionsResponse(
|
||||
effective_permissions=effective_permissions,
|
||||
roles=request.roles,
|
||||
custom_permissions=request.custom_permissions or []
|
||||
custom_permissions=request.custom_permissions or [],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating user permissions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to calculate user permissions: {str(e)}"
|
||||
detail=f"Failed to calculate user permissions: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -293,7 +296,9 @@ async def platform_health():
|
||||
# Get permission system status
|
||||
total_permissions = len(permission_registry.tree.permissions)
|
||||
total_modules = len(permission_registry.module_permissions)
|
||||
total_roles = len(permission_registry.default_roles) + len(permission_registry.role_permissions)
|
||||
total_roles = len(permission_registry.default_roles) + len(
|
||||
permission_registry.role_permissions
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
@@ -302,16 +307,13 @@ async def platform_health():
|
||||
"permission_system": {
|
||||
"total_permissions": total_permissions,
|
||||
"registered_modules": total_modules,
|
||||
"available_roles": total_roles
|
||||
}
|
||||
"available_roles": total_roles,
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking platform health: {str(e)}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e)
|
||||
}
|
||||
return {"status": "unhealthy", "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
@@ -324,17 +326,18 @@ async def platform_metrics():
|
||||
metrics = {
|
||||
"permissions": {
|
||||
"total": len(permission_registry.tree.permissions),
|
||||
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()}
|
||||
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()},
|
||||
},
|
||||
"modules": {
|
||||
"registered": len(permission_registry.module_permissions),
|
||||
"names": list(permission_registry.module_permissions.keys())
|
||||
"names": list(permission_registry.module_permissions.keys()),
|
||||
},
|
||||
"roles": {
|
||||
"default": len(permission_registry.default_roles),
|
||||
"custom": len(permission_registry.role_permissions),
|
||||
"total": len(permission_registry.default_roles) + len(permission_registry.role_permissions)
|
||||
}
|
||||
"total": len(permission_registry.default_roles)
|
||||
+ len(permission_registry.role_permissions),
|
||||
},
|
||||
}
|
||||
|
||||
return metrics
|
||||
@@ -343,5 +346,5 @@ async def platform_metrics():
|
||||
logger.error(f"Error getting platform metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get platform metrics: {str(e)}"
|
||||
detail=f"Failed to get platform metrics: {str(e)}",
|
||||
)
|
||||
@@ -46,28 +46,27 @@ async def discover_plugins(
|
||||
category: str = "",
|
||||
limit: int = 20,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Discover available plugins from repository"""
|
||||
try:
|
||||
tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
|
||||
tag_list = (
|
||||
[tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
|
||||
)
|
||||
|
||||
plugins = await plugin_discovery.search_available_plugins(
|
||||
query=query,
|
||||
tags=tag_list,
|
||||
category=category if category else None,
|
||||
limit=limit,
|
||||
db=db
|
||||
db=db,
|
||||
)
|
||||
|
||||
return {
|
||||
"plugins": plugins,
|
||||
"count": len(plugins),
|
||||
"query": query,
|
||||
"filters": {
|
||||
"tags": tag_list,
|
||||
"category": category
|
||||
}
|
||||
"filters": {"tags": tag_list, "category": category},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -76,7 +75,9 @@ async def discover_plugins(
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_plugin_categories(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get available plugin categories"""
|
||||
try:
|
||||
categories = await plugin_discovery.get_plugin_categories()
|
||||
@@ -87,37 +88,32 @@ async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_curre
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get categories: {e}")
|
||||
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get user's installed plugins"""
|
||||
try:
|
||||
plugins = await plugin_discovery.get_installed_plugins(current_user["id"], db)
|
||||
return {
|
||||
"plugins": plugins,
|
||||
"count": len(plugins)
|
||||
}
|
||||
return {"plugins": plugins, "count": len(plugins)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get installed plugins: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get installed plugins: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get installed plugins: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/updates")
|
||||
async def check_plugin_updates(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Check for available plugin updates"""
|
||||
try:
|
||||
updates = await plugin_discovery.get_plugin_updates(db)
|
||||
return {
|
||||
"updates": updates,
|
||||
"count": len(updates)
|
||||
}
|
||||
return {"updates": updates, "count": len(updates)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check updates: {e}")
|
||||
@@ -130,12 +126,15 @@ async def install_plugin(
|
||||
request: PluginInstallRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Install plugin from repository"""
|
||||
try:
|
||||
if request.source != "repository":
|
||||
raise HTTPException(status_code=400, detail="Only repository installation supported via this endpoint")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Only repository installation supported via this endpoint",
|
||||
)
|
||||
|
||||
# Start installation in background
|
||||
background_tasks.add_task(
|
||||
@@ -143,14 +142,14 @@ async def install_plugin(
|
||||
request.plugin_id,
|
||||
request.version,
|
||||
current_user["id"],
|
||||
db
|
||||
db,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "installation_started",
|
||||
"plugin_id": request.plugin_id,
|
||||
"version": request.version,
|
||||
"message": "Plugin installation started in background"
|
||||
"message": "Plugin installation started in background",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -163,17 +162,18 @@ async def install_plugin_from_file(
|
||||
file: UploadFile = File(...),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Install plugin from uploaded file"""
|
||||
try:
|
||||
# Validate file type
|
||||
if not file.filename.endswith('.zip'):
|
||||
if not file.filename.endswith(".zip"):
|
||||
raise HTTPException(status_code=400, detail="Only ZIP files are supported")
|
||||
|
||||
# Save uploaded file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_file:
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
|
||||
content = await file.read()
|
||||
temp_file.write(content)
|
||||
temp_file_path = temp_file.name
|
||||
@@ -187,12 +187,13 @@ async def install_plugin_from_file(
|
||||
return {
|
||||
"status": "installed",
|
||||
"result": result,
|
||||
"message": "Plugin installed successfully"
|
||||
"message": "Plugin installed successfully",
|
||||
}
|
||||
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
import os
|
||||
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
except Exception as e:
|
||||
@@ -205,7 +206,7 @@ async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
request: PluginUninstallRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Uninstall plugin"""
|
||||
try:
|
||||
@@ -216,7 +217,7 @@ async def uninstall_plugin(
|
||||
return {
|
||||
"status": "uninstalled",
|
||||
"result": result,
|
||||
"message": "Plugin uninstalled successfully"
|
||||
"message": "Plugin uninstalled successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -229,7 +230,7 @@ async def uninstall_plugin(
|
||||
async def enable_plugin(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Enable plugin"""
|
||||
try:
|
||||
@@ -248,7 +249,7 @@ async def enable_plugin(
|
||||
return {
|
||||
"status": "enabled",
|
||||
"plugin_id": plugin_id,
|
||||
"message": "Plugin enabled successfully"
|
||||
"message": "Plugin enabled successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -260,7 +261,7 @@ async def enable_plugin(
|
||||
async def disable_plugin(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Disable plugin"""
|
||||
try:
|
||||
@@ -283,7 +284,7 @@ async def disable_plugin(
|
||||
return {
|
||||
"status": "disabled",
|
||||
"plugin_id": plugin_id,
|
||||
"message": "Plugin disabled successfully"
|
||||
"message": "Plugin disabled successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -295,7 +296,7 @@ async def disable_plugin(
|
||||
async def load_plugin(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Load plugin into runtime"""
|
||||
try:
|
||||
@@ -310,7 +311,9 @@ async def load_plugin(
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
if plugin.status != "enabled":
|
||||
raise HTTPException(status_code=400, detail="Plugin must be enabled to load")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Plugin must be enabled to load"
|
||||
)
|
||||
|
||||
if plugin_id in plugin_loader.loaded_plugins:
|
||||
raise HTTPException(status_code=400, detail="Plugin already loaded")
|
||||
@@ -322,11 +325,13 @@ async def load_plugin(
|
||||
plugin_context = plugin_context_manager.create_plugin_context(
|
||||
plugin_id=plugin_id,
|
||||
user_id=str(current_user.get("id", "unknown")), # Use actual user ID
|
||||
session_type="api_load"
|
||||
session_type="api_load",
|
||||
)
|
||||
|
||||
# Generate plugin token based on context
|
||||
plugin_token = plugin_context_manager.generate_plugin_token(plugin_context["context_id"])
|
||||
plugin_token = plugin_context_manager.generate_plugin_token(
|
||||
plugin_context["context_id"]
|
||||
)
|
||||
|
||||
# Log plugin loading action
|
||||
plugin_context_manager.add_audit_trail_entry(
|
||||
@@ -335,8 +340,8 @@ async def load_plugin(
|
||||
{
|
||||
"plugin_dir": str(plugin_dir),
|
||||
"user_id": current_user.get("id", "unknown"),
|
||||
"action": "load_plugin_with_sandbox"
|
||||
}
|
||||
"action": "load_plugin_with_sandbox",
|
||||
},
|
||||
)
|
||||
|
||||
await plugin_loader.load_plugin_with_sandbox(plugin_dir, plugin_token)
|
||||
@@ -344,7 +349,7 @@ async def load_plugin(
|
||||
return {
|
||||
"status": "loaded",
|
||||
"plugin_id": plugin_id,
|
||||
"message": "Plugin loaded successfully"
|
||||
"message": "Plugin loaded successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -354,8 +359,7 @@ async def load_plugin(
|
||||
|
||||
@router.post("/{plugin_id}/unload")
|
||||
async def unload_plugin(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
plugin_id: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Unload plugin from runtime"""
|
||||
try:
|
||||
@@ -369,7 +373,7 @@ async def unload_plugin(
|
||||
return {
|
||||
"status": "unloaded",
|
||||
"plugin_id": plugin_id,
|
||||
"message": "Plugin unloaded successfully"
|
||||
"message": "Plugin unloaded successfully",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -382,7 +386,7 @@ async def unload_plugin(
|
||||
async def get_plugin_configuration(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get plugin configuration for user with automatic decryption"""
|
||||
try:
|
||||
@@ -393,27 +397,25 @@ async def get_plugin_configuration(
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user["id"],
|
||||
db=db,
|
||||
decrypt_sensitive=False # Don't decrypt sensitive data for API response
|
||||
decrypt_sensitive=False, # Don't decrypt sensitive data for API response
|
||||
)
|
||||
|
||||
if config_data is not None:
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"configuration": config_data,
|
||||
"has_configuration": True
|
||||
"has_configuration": True,
|
||||
}
|
||||
else:
|
||||
# Get default configuration from manifest
|
||||
resolved_config = await plugin_config_manager.get_resolved_configuration(
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user["id"],
|
||||
db=db
|
||||
plugin_id=plugin_id, user_id=current_user["id"], db=db
|
||||
)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"configuration": resolved_config,
|
||||
"has_configuration": False
|
||||
"has_configuration": False,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -426,7 +428,7 @@ async def save_plugin_configuration(
|
||||
plugin_id: str,
|
||||
config_request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Save plugin configuration for user with automatic encryption of sensitive fields"""
|
||||
try:
|
||||
@@ -444,42 +446,46 @@ async def save_plugin_configuration(
|
||||
config_data=config_data,
|
||||
config_name=config_name,
|
||||
config_description=config_description,
|
||||
db=db
|
||||
db=db,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "saved",
|
||||
"plugin_id": plugin_id,
|
||||
"configuration_id": str(saved_config.id),
|
||||
"message": "Configuration saved successfully with automatic encryption of sensitive fields"
|
||||
"message": "Configuration saved successfully with automatic encryption of sensitive fields",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save plugin configuration: {e}")
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save configuration: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to save configuration: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/schema")
|
||||
async def get_plugin_configuration_schema(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get plugin configuration schema from manifest"""
|
||||
try:
|
||||
from app.services.plugin_configuration_manager import plugin_config_manager
|
||||
|
||||
# Use the new configuration manager to get schema
|
||||
schema = await plugin_config_manager.get_plugin_configuration_schema(plugin_id, db)
|
||||
schema = await plugin_config_manager.get_plugin_configuration_schema(
|
||||
plugin_id, db
|
||||
)
|
||||
|
||||
if not schema:
|
||||
raise HTTPException(status_code=404, detail=f"No configuration schema available for plugin '{plugin_id}'")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No configuration schema available for plugin '{plugin_id}'",
|
||||
)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"schema": schema
|
||||
}
|
||||
return {"plugin_id": plugin_id, "schema": schema}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -493,7 +499,7 @@ async def test_plugin_credentials(
|
||||
plugin_id: str,
|
||||
test_request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Test plugin credentials (currently supports Zammad)"""
|
||||
import httpx
|
||||
@@ -510,29 +516,36 @@ async def test_plugin_credentials(
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Plugin '{plugin_id}' not found"
|
||||
)
|
||||
|
||||
# Check if this is a Zammad plugin
|
||||
if plugin.name.lower() != 'zammad':
|
||||
raise HTTPException(status_code=400, detail=f"Credential testing not supported for plugin '{plugin.name}'")
|
||||
if plugin.name.lower() != "zammad":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Credential testing not supported for plugin '{plugin.name}'",
|
||||
)
|
||||
|
||||
# Extract credentials from request
|
||||
zammad_url = test_request.get('zammad_url')
|
||||
api_token = test_request.get('api_token')
|
||||
zammad_url = test_request.get("zammad_url")
|
||||
api_token = test_request.get("api_token")
|
||||
|
||||
if not zammad_url or not api_token:
|
||||
raise HTTPException(status_code=400, detail="Both zammad_url and api_token are required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Both zammad_url and api_token are required"
|
||||
)
|
||||
|
||||
# Clean up the URL (remove trailing slash)
|
||||
zammad_url = zammad_url.rstrip('/')
|
||||
zammad_url = zammad_url.rstrip("/")
|
||||
|
||||
# Test credentials by making a read-only API call to Zammad
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try to get user info - this is a safe read-only operation
|
||||
test_url = f"{zammad_url}/api/v1/users/me"
|
||||
headers = {
|
||||
'Authorization': f'Token token={api_token}',
|
||||
'Content-Type': 'application/json'
|
||||
"Authorization": f"Token token={api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await client.get(test_url, headers=headers)
|
||||
@@ -540,66 +553,68 @@ async def test_plugin_credentials(
|
||||
if response.status_code == 200:
|
||||
# Success - credentials are valid
|
||||
user_data = response.json()
|
||||
user_email = user_data.get('email', 'unknown')
|
||||
user_email = user_data.get("email", "unknown")
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Credentials verified! Connected as: {user_email}",
|
||||
"zammad_url": zammad_url,
|
||||
"user_info": {
|
||||
"email": user_email,
|
||||
"firstname": user_data.get('firstname', ''),
|
||||
"lastname": user_data.get('lastname', '')
|
||||
}
|
||||
"firstname": user_data.get("firstname", ""),
|
||||
"lastname": user_data.get("lastname", ""),
|
||||
},
|
||||
}
|
||||
elif response.status_code == 401:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Invalid API token. Please check your token and try again.",
|
||||
"error_code": "invalid_token"
|
||||
"error_code": "invalid_token",
|
||||
}
|
||||
elif response.status_code == 404:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Zammad URL not found. Please verify the URL is correct.",
|
||||
"error_code": "invalid_url"
|
||||
"error_code": "invalid_url",
|
||||
}
|
||||
else:
|
||||
error_text = ""
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_text = error_data.get('error', error_data.get('message', ''))
|
||||
error_text = error_data.get("error", error_data.get("message", ""))
|
||||
except:
|
||||
error_text = response.text[:200]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed (HTTP {response.status_code}): {error_text}",
|
||||
"error_code": "connection_failed"
|
||||
"error_code": "connection_failed",
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Connection timeout. Please check the Zammad URL and your network connection.",
|
||||
"error_code": "timeout"
|
||||
"error_code": "timeout",
|
||||
}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Could not connect to Zammad. Please verify the URL is correct and accessible.",
|
||||
"error_code": "connection_error"
|
||||
"error_code": "connection_error",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test plugin credentials: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Test failed: {str(e)}",
|
||||
"error_code": "unknown_error"
|
||||
"error_code": "unknown_error",
|
||||
}
|
||||
|
||||
|
||||
# Background task for plugin installation
|
||||
async def install_plugin_background(plugin_id: str, version: str, user_id: str, db: AsyncSession):
|
||||
async def install_plugin_background(
|
||||
plugin_id: str, version: str, user_id: str, db: AsyncSession
|
||||
):
|
||||
"""Background task for plugin installation"""
|
||||
try:
|
||||
result = await plugin_installer.install_plugin_from_repository(
|
||||
|
||||
@@ -17,7 +17,10 @@ from app.core.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.logging import log_api_request
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.models import (
|
||||
ChatRequest as LLMChatRequest,
|
||||
ChatMessage as LLMChatMessage,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -59,11 +62,12 @@ class ImprovePromptRequest(BaseModel):
|
||||
|
||||
@router.get("/templates", response_model=List[PromptTemplateResponse])
|
||||
async def list_prompt_templates(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get all prompt templates"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("list_prompt_templates", {"user_id": user_id})
|
||||
|
||||
try:
|
||||
@@ -85,26 +89,36 @@ async def list_prompt_templates(
|
||||
"is_default": template.is_default,
|
||||
"is_active": template.is_active,
|
||||
"version": template.version,
|
||||
"created_at": template.created_at.isoformat() if template.created_at else None,
|
||||
"updated_at": template.updated_at.isoformat() if template.updated_at else None
|
||||
"created_at": template.created_at.isoformat()
|
||||
if template.created_at
|
||||
else None,
|
||||
"updated_at": template.updated_at.isoformat()
|
||||
if template.updated_at
|
||||
else None,
|
||||
}
|
||||
template_list.append(template_dict)
|
||||
|
||||
return template_list
|
||||
|
||||
except Exception as e:
|
||||
log_api_request("list_prompt_templates_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}")
|
||||
log_api_request(
|
||||
"list_prompt_templates_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/templates/{type_key}", response_model=PromptTemplateResponse)
|
||||
async def get_prompt_template(
|
||||
type_key: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get a specific prompt template by type key"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("get_prompt_template", {"user_id": user_id, "type_key": type_key})
|
||||
|
||||
try:
|
||||
@@ -127,15 +141,23 @@ async def get_prompt_template(
|
||||
"is_default": template.is_default,
|
||||
"is_active": template.is_active,
|
||||
"version": template.version,
|
||||
"created_at": template.created_at.isoformat() if template.created_at else None,
|
||||
"updated_at": template.updated_at.isoformat() if template.updated_at else None
|
||||
"created_at": template.created_at.isoformat()
|
||||
if template.created_at
|
||||
else None,
|
||||
"updated_at": template.updated_at.isoformat()
|
||||
if template.updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
log_api_request("get_prompt_template_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt template: {str(e)}")
|
||||
log_api_request(
|
||||
"get_prompt_template_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch prompt template: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/templates/{type_key}")
|
||||
@@ -143,15 +165,16 @@ async def update_prompt_template(
|
||||
type_key: str,
|
||||
request: PromptTemplateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a prompt template"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
log_api_request("update_prompt_template", {
|
||||
"user_id": user_id,
|
||||
"type_key": type_key,
|
||||
"name": request.name
|
||||
})
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request(
|
||||
"update_prompt_template",
|
||||
{"user_id": user_id, "type_key": type_key, "name": request.name},
|
||||
)
|
||||
|
||||
try:
|
||||
# Get existing template
|
||||
@@ -175,7 +198,7 @@ async def update_prompt_template(
|
||||
system_prompt=request.system_prompt,
|
||||
is_active=request.is_active,
|
||||
version=template.version + 1,
|
||||
updated_at=datetime.utcnow()
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -183,8 +206,7 @@ async def update_prompt_template(
|
||||
|
||||
# Return updated template
|
||||
updated_result = await db.execute(
|
||||
select(PromptTemplate)
|
||||
.where(PromptTemplate.type_key == type_key)
|
||||
select(PromptTemplate).where(PromptTemplate.type_key == type_key)
|
||||
)
|
||||
updated_template = updated_result.scalar_one()
|
||||
|
||||
@@ -197,40 +219,51 @@ async def update_prompt_template(
|
||||
"is_default": updated_template.is_default,
|
||||
"is_active": updated_template.is_active,
|
||||
"version": updated_template.version,
|
||||
"created_at": updated_template.created_at.isoformat() if updated_template.created_at else None,
|
||||
"updated_at": updated_template.updated_at.isoformat() if updated_template.updated_at else None
|
||||
"created_at": updated_template.created_at.isoformat()
|
||||
if updated_template.created_at
|
||||
else None,
|
||||
"updated_at": updated_template.updated_at.isoformat()
|
||||
if updated_template.updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("update_prompt_template_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update prompt template: {str(e)}")
|
||||
log_api_request(
|
||||
"update_prompt_template_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update prompt template: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/templates/create")
|
||||
async def create_prompt_template(
|
||||
request: PromptTemplateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new prompt template"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
log_api_request("create_prompt_template", {
|
||||
"user_id": user_id,
|
||||
"type_key": request.type_key,
|
||||
"name": request.name
|
||||
})
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request(
|
||||
"create_prompt_template",
|
||||
{"user_id": user_id, "type_key": request.type_key, "name": request.name},
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if template already exists
|
||||
existing_result = await db.execute(
|
||||
select(PromptTemplate)
|
||||
.where(PromptTemplate.type_key == request.type_key)
|
||||
select(PromptTemplate).where(PromptTemplate.type_key == request.type_key)
|
||||
)
|
||||
if existing_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Prompt template with this type key already exists")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Prompt template with this type key already exists",
|
||||
)
|
||||
|
||||
# Create new template
|
||||
template = PromptTemplate(
|
||||
@@ -243,7 +276,7 @@ async def create_prompt_template(
|
||||
is_active=request.is_active,
|
||||
version=1,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(template)
|
||||
@@ -259,25 +292,34 @@ async def create_prompt_template(
|
||||
"is_default": template.is_default,
|
||||
"is_active": template.is_active,
|
||||
"version": template.version,
|
||||
"created_at": template.created_at.isoformat() if template.created_at else None,
|
||||
"updated_at": template.updated_at.isoformat() if template.updated_at else None
|
||||
"created_at": template.created_at.isoformat()
|
||||
if template.created_at
|
||||
else None,
|
||||
"updated_at": template.updated_at.isoformat()
|
||||
if template.updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("create_prompt_template_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create prompt template: {str(e)}")
|
||||
log_api_request(
|
||||
"create_prompt_template_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create prompt template: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/variables", response_model=List[PromptVariableResponse])
|
||||
async def list_prompt_variables(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get all available prompt variables"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("list_prompt_variables", {"user_id": user_id})
|
||||
|
||||
try:
|
||||
@@ -295,25 +337,31 @@ async def list_prompt_variables(
|
||||
"variable_name": variable.variable_name,
|
||||
"description": variable.description,
|
||||
"example_value": variable.example_value,
|
||||
"is_active": variable.is_active
|
||||
"is_active": variable.is_active,
|
||||
}
|
||||
variable_list.append(variable_dict)
|
||||
|
||||
return variable_list
|
||||
|
||||
except Exception as e:
|
||||
log_api_request("list_prompt_variables_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}")
|
||||
log_api_request(
|
||||
"list_prompt_variables_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/templates/{type_key}/reset")
|
||||
async def reset_prompt_template(
|
||||
type_key: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Reset a prompt template to its default"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("reset_prompt_template", {"user_id": user_id, "type_key": type_key})
|
||||
|
||||
# Define default prompts (same as in migration)
|
||||
@@ -323,7 +371,7 @@ async def reset_prompt_template(
|
||||
"teacher": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
|
||||
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
|
||||
"creative_writer": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
|
||||
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
|
||||
"custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
|
||||
}
|
||||
|
||||
if type_key not in default_prompts:
|
||||
@@ -337,7 +385,7 @@ async def reset_prompt_template(
|
||||
.values(
|
||||
system_prompt=default_prompts[type_key],
|
||||
version=PromptTemplate.version + 1,
|
||||
updated_at=datetime.utcnow()
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -347,22 +395,28 @@ async def reset_prompt_template(
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("reset_prompt_template_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reset prompt template: {str(e)}")
|
||||
log_api_request(
|
||||
"reset_prompt_template_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to reset prompt template: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/improve")
|
||||
async def improve_prompt_with_ai(
|
||||
request: ImprovePromptRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Improve a prompt using AI"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
log_api_request("improve_prompt_with_ai", {
|
||||
"user_id": user_id,
|
||||
"chatbot_type": request.chatbot_type
|
||||
})
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request(
|
||||
"improve_prompt_with_ai",
|
||||
{"user_id": user_id, "chatbot_type": request.chatbot_type},
|
||||
)
|
||||
|
||||
try:
|
||||
# Create system message for improvement
|
||||
@@ -392,7 +446,7 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "user", "content": user_message}
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
|
||||
# Get available models to use a default model
|
||||
@@ -406,11 +460,14 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
|
||||
# Prepare the chat request for the new LLM service
|
||||
chat_request = LLMChatRequest(
|
||||
model=default_model,
|
||||
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
|
||||
messages=[
|
||||
LLMChatMessage(role=msg["role"], content=msg["content"])
|
||||
for msg in messages
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=1000,
|
||||
user_id=str(user_id),
|
||||
api_key_id=1 # Using default API key, you might want to make this dynamic
|
||||
api_key_id=1, # Using default API key, you might want to make this dynamic
|
||||
)
|
||||
|
||||
# Make the AI call
|
||||
@@ -422,23 +479,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
|
||||
return {
|
||||
"improved_prompt": improved_prompt,
|
||||
"original_prompt": request.current_prompt,
|
||||
"model_used": default_model
|
||||
"model_used": default_model,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
log_api_request("improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to improve prompt: {str(e)}")
|
||||
log_api_request(
|
||||
"improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to improve prompt: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/seed-defaults")
|
||||
async def seed_default_templates(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Seed default prompt templates for all chatbot types"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("seed_default_templates", {"user_id": user_id})
|
||||
|
||||
# Define default prompts (same as in reset)
|
||||
@@ -446,33 +508,33 @@ async def seed_default_templates(
|
||||
"assistant": {
|
||||
"name": "General Assistant",
|
||||
"description": "A helpful, accurate, and friendly AI assistant",
|
||||
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style."
|
||||
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
|
||||
},
|
||||
"customer_support": {
|
||||
"name": "Customer Support Agent",
|
||||
"description": "Professional customer service representative focused on solving problems",
|
||||
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations."
|
||||
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.",
|
||||
},
|
||||
"teacher": {
|
||||
"name": "Educational Tutor",
|
||||
"description": "Patient and encouraging educational facilitator",
|
||||
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it."
|
||||
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
|
||||
},
|
||||
"researcher": {
|
||||
"name": "Research Assistant",
|
||||
"description": "Thorough researcher focused on evidence-based information",
|
||||
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence."
|
||||
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
|
||||
},
|
||||
"creative_writer": {
|
||||
"name": "Creative Writing Mentor",
|
||||
"description": "Imaginative storytelling expert and writing coach",
|
||||
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles."
|
||||
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
|
||||
},
|
||||
"custom": {
|
||||
"name": "Custom Chatbot",
|
||||
"description": "Customizable AI assistant with user-defined behavior",
|
||||
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
|
||||
}
|
||||
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
|
||||
},
|
||||
}
|
||||
|
||||
created_templates = []
|
||||
@@ -530,7 +592,9 @@ async def seed_default_templates(
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[PromptTemplate.type_key]
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
@@ -548,10 +612,14 @@ async def seed_default_templates(
|
||||
"message": "Default templates seeded successfully",
|
||||
"created": created_templates,
|
||||
"updated": updated_templates,
|
||||
"total": len(created_templates) + len(updated_templates)
|
||||
"total": len(created_templates) + len(updated_templates),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("seed_default_templates_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to seed default templates: {str(e)}")
|
||||
log_api_request(
|
||||
"seed_default_templates_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to seed default templates: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ router = APIRouter(tags=["RAG"])
|
||||
|
||||
# Request/Response Models
|
||||
|
||||
|
||||
class CollectionCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
@@ -78,12 +79,13 @@ class StatsResponse(BaseModel):
|
||||
|
||||
# Collection Endpoints
|
||||
|
||||
|
||||
@router.get("/collections", response_model=dict)
|
||||
async def get_collections(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
|
||||
try:
|
||||
@@ -103,7 +105,7 @@ async def get_collections(
|
||||
"collections": paginated_collections,
|
||||
"total": len(collections),
|
||||
"total_documents": stats_data.get("total_documents", 0),
|
||||
"total_size_bytes": stats_data.get("total_size_bytes", 0)
|
||||
"total_size_bytes": stats_data.get("total_size_bytes", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -113,20 +115,19 @@ async def get_collections(
|
||||
async def create_collection(
|
||||
collection_data: CollectionCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new RAG collection"""
|
||||
try:
|
||||
rag_service = RAGService(db)
|
||||
collection = await rag_service.create_collection(
|
||||
name=collection_data.name,
|
||||
description=collection_data.description
|
||||
name=collection_data.name, description=collection_data.description
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"collection": collection.to_dict(),
|
||||
"message": "Collection created successfully"
|
||||
"message": "Collection created successfully",
|
||||
}
|
||||
except APIException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
@@ -136,8 +137,7 @@ async def create_collection(
|
||||
|
||||
@router.get("/stats", response_model=dict)
|
||||
async def get_rag_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get overall RAG statistics - live data directly from Qdrant"""
|
||||
try:
|
||||
@@ -147,7 +147,11 @@ async def get_rag_stats(
|
||||
stats_data = await qdrant_stats_service.get_collections_stats()
|
||||
|
||||
# Calculate active collections (collections with documents)
|
||||
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
|
||||
active_collections = sum(
|
||||
1
|
||||
for col in stats_data.get("collections", [])
|
||||
if col.get("document_count", 0) > 0
|
||||
)
|
||||
|
||||
# Calculate processing documents from database
|
||||
processing_docs = 0
|
||||
@@ -156,7 +160,9 @@ async def get_rag_stats(
|
||||
from app.models.rag_document import RagDocument, ProcessingStatus
|
||||
|
||||
result = await db.execute(
|
||||
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
|
||||
select(RagDocument).where(
|
||||
RagDocument.status == ProcessingStatus.PROCESSING
|
||||
)
|
||||
)
|
||||
processing_docs = len(result.scalars().all())
|
||||
except Exception:
|
||||
@@ -167,22 +173,28 @@ async def get_rag_stats(
|
||||
"stats": {
|
||||
"collections": {
|
||||
"total": stats_data.get("total_collections", 0),
|
||||
"active": active_collections
|
||||
"active": active_collections,
|
||||
},
|
||||
"documents": {
|
||||
"total": stats_data.get("total_documents", 0),
|
||||
"processing": processing_docs,
|
||||
"processed": stats_data.get("total_documents", 0) # Indexed documents
|
||||
"processed": stats_data.get(
|
||||
"total_documents", 0
|
||||
), # Indexed documents
|
||||
},
|
||||
"storage": {
|
||||
"total_size_bytes": stats_data.get("total_size_bytes", 0),
|
||||
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
|
||||
"total_size_mb": round(
|
||||
stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2
|
||||
),
|
||||
},
|
||||
"vectors": {
|
||||
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
|
||||
"total": stats_data.get(
|
||||
"total_documents", 0
|
||||
) # Same as documents for RAG
|
||||
},
|
||||
"last_updated": datetime.utcnow().isoformat(),
|
||||
},
|
||||
"last_updated": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return response_data
|
||||
@@ -194,7 +206,7 @@ async def get_rag_stats(
|
||||
async def get_collection(
|
||||
collection_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific collection"""
|
||||
try:
|
||||
@@ -204,10 +216,7 @@ async def get_collection(
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"collection": collection.to_dict()
|
||||
}
|
||||
return {"success": True, "collection": collection.to_dict()}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -219,7 +228,7 @@ async def delete_collection(
|
||||
collection_id: int,
|
||||
cascade: bool = True, # Default to cascade deletion for better UX
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a collection and optionally all its documents"""
|
||||
try:
|
||||
@@ -231,7 +240,8 @@ async def delete_collection(
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Collection deleted successfully" + (" (with documents)" if cascade else "")
|
||||
"message": "Collection deleted successfully"
|
||||
+ (" (with documents)" if cascade else ""),
|
||||
}
|
||||
except APIException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
@@ -243,13 +253,14 @@ async def delete_collection(
|
||||
|
||||
# Document Endpoints
|
||||
|
||||
|
||||
@router.get("/documents", response_model=dict)
|
||||
async def get_documents(
|
||||
collection_id: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get documents, optionally filtered by collection"""
|
||||
try:
|
||||
@@ -260,11 +271,7 @@ async def get_documents(
|
||||
if collection_id.startswith("ext_"):
|
||||
# External collections exist only in Qdrant and have no documents in PostgreSQL
|
||||
# Return empty list since they don't have managed documents
|
||||
return {
|
||||
"success": True,
|
||||
"documents": [],
|
||||
"total": 0
|
||||
}
|
||||
return {"success": True, "documents": [], "total": 0}
|
||||
else:
|
||||
# Try to convert to integer for managed collections
|
||||
try:
|
||||
@@ -272,29 +279,25 @@ async def get_documents(
|
||||
except (ValueError, TypeError):
|
||||
# Attempt to resolve by Qdrant collection name
|
||||
collection_row = await db.scalar(
|
||||
select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
|
||||
select(RagCollection).where(
|
||||
RagCollection.qdrant_collection_name == collection_id
|
||||
)
|
||||
)
|
||||
if collection_row:
|
||||
collection_id_int = collection_row.id
|
||||
else:
|
||||
# Unknown collection identifier; return empty result instead of erroring out
|
||||
return {
|
||||
"success": True,
|
||||
"documents": [],
|
||||
"total": 0
|
||||
}
|
||||
return {"success": True, "documents": [], "total": 0}
|
||||
|
||||
rag_service = RAGService(db)
|
||||
documents = await rag_service.get_documents(
|
||||
collection_id=collection_id_int,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
collection_id=collection_id_int, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"documents": [doc.to_dict() for doc in documents],
|
||||
"total": len(documents)
|
||||
"total": len(documents),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -305,13 +308,13 @@ async def upload_document(
|
||||
collection_id: str = Form(...),
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Upload and process a document"""
|
||||
try:
|
||||
# Validate file can be read before processing
|
||||
filename = file.filename or "unknown"
|
||||
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
|
||||
file_extension = filename.split(".")[-1].lower() if "." in filename else ""
|
||||
|
||||
# Read file content once and use it for all validations
|
||||
file_content = await file.read()
|
||||
@@ -324,50 +327,66 @@ async def upload_document(
|
||||
|
||||
try:
|
||||
# Test file readability based on type
|
||||
if file_extension == 'jsonl':
|
||||
if file_extension == "jsonl":
|
||||
# Validate JSONL format - try to parse first few lines
|
||||
try:
|
||||
content_str = file_content.decode('utf-8')
|
||||
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
|
||||
content_str = file_content.decode("utf-8")
|
||||
lines = content_str.strip().split("\n")[:5] # Check first 5 lines
|
||||
import json
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip(): # Skip empty lines
|
||||
json.loads(line) # Will raise JSONDecodeError if invalid
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="File is not valid UTF-8 text"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid JSONL format: {str(e)}"
|
||||
)
|
||||
|
||||
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
|
||||
elif file_extension in ["txt", "md", "py", "js", "html", "css", "json"]:
|
||||
# Validate text files can be decoded
|
||||
try:
|
||||
file_content.decode('utf-8')
|
||||
file_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="File is not valid UTF-8 text"
|
||||
)
|
||||
|
||||
elif file_extension in ['pdf']:
|
||||
elif file_extension in ["pdf"]:
|
||||
# For PDF files, just check if it starts with PDF signature
|
||||
if not file_content.startswith(b'%PDF'):
|
||||
raise HTTPException(status_code=400, detail="Invalid PDF file format")
|
||||
if not file_content.startswith(b"%PDF"):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid PDF file format"
|
||||
)
|
||||
|
||||
elif file_extension in ['docx', 'xlsx', 'pptx']:
|
||||
elif file_extension in ["docx", "xlsx", "pptx"]:
|
||||
# For Office documents, check ZIP signature
|
||||
if not file_content.startswith(b'PK'):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
|
||||
if not file_content.startswith(b"PK"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid {file_extension.upper()} file format",
|
||||
)
|
||||
|
||||
# For other file types, we'll rely on the document processor
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"File validation failed: {str(e)}"
|
||||
)
|
||||
|
||||
rag_service = RAGService(db)
|
||||
|
||||
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
|
||||
collection_identifier = (collection_id or "").strip()
|
||||
if not collection_identifier:
|
||||
raise HTTPException(status_code=400, detail="Collection identifier is required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Collection identifier is required"
|
||||
)
|
||||
|
||||
resolved_collection_id: Optional[int] = None
|
||||
|
||||
@@ -379,7 +398,9 @@ async def upload_document(
|
||||
qdrant_name = qdrant_name[4:]
|
||||
|
||||
try:
|
||||
collection_record = await rag_service.ensure_collection_record(qdrant_name)
|
||||
collection_record = await rag_service.ensure_collection_record(
|
||||
qdrant_name
|
||||
)
|
||||
except Exception as ensure_error:
|
||||
raise HTTPException(status_code=500, detail=str(ensure_error))
|
||||
|
||||
@@ -392,13 +413,13 @@ async def upload_document(
|
||||
collection_id=resolved_collection_id,
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
content_type=file.content_type
|
||||
content_type=file.content_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"document": document.to_dict(),
|
||||
"message": "Document uploaded and processing started"
|
||||
"message": "Document uploaded and processing started",
|
||||
}
|
||||
except APIException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
@@ -412,7 +433,7 @@ async def upload_document(
|
||||
async def get_document(
|
||||
document_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific document"""
|
||||
try:
|
||||
@@ -422,10 +443,7 @@ async def get_document(
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"document": document.to_dict()
|
||||
}
|
||||
return {"success": True, "document": document.to_dict()}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -436,7 +454,7 @@ async def get_document(
|
||||
async def delete_document(
|
||||
document_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a document"""
|
||||
try:
|
||||
@@ -446,10 +464,7 @@ async def delete_document(
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Document deleted successfully"
|
||||
}
|
||||
return {"success": True, "message": "Document deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -460,7 +475,7 @@ async def delete_document(
|
||||
async def reprocess_document(
|
||||
document_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Restart processing for a stuck or failed document"""
|
||||
try:
|
||||
@@ -475,12 +490,12 @@ async def reprocess_document(
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed."
|
||||
detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed.",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Document reprocessing started successfully"
|
||||
"message": "Document reprocessing started successfully",
|
||||
}
|
||||
except APIException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
@@ -494,7 +509,7 @@ async def reprocess_document(
|
||||
async def download_document(
|
||||
document_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Download the original document file"""
|
||||
try:
|
||||
@@ -502,14 +517,16 @@ async def download_document(
|
||||
result = await rag_service.download_document(document_id)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Document not found or file not available")
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Document not found or file not available"
|
||||
)
|
||||
|
||||
content, filename, mime_type = result
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type=mime_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -517,9 +534,9 @@ async def download_document(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
# Debug Endpoints
|
||||
|
||||
|
||||
@router.post("/debug/search")
|
||||
async def search_with_debug(
|
||||
query: str,
|
||||
@@ -527,13 +544,13 @@ async def search_with_debug(
|
||||
score_threshold: float = 0.3,
|
||||
collection_name: str = None,
|
||||
config: Dict[str, Any] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Enhanced search with comprehensive debug information
|
||||
"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
rag_module = module_manager.modules.get("rag")
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
@@ -567,7 +584,7 @@ async def search_with_debug(
|
||||
query,
|
||||
max_results=max_results,
|
||||
score_threshold=score_threshold,
|
||||
collection_name=collection_name
|
||||
collection_name=collection_name,
|
||||
)
|
||||
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
|
||||
|
||||
@@ -575,22 +592,23 @@ async def search_with_debug(
|
||||
scores = [r.score for r in results if r.score is not None]
|
||||
if scores:
|
||||
import statistics
|
||||
|
||||
debug_info["score_stats"] = {
|
||||
"min": min(scores),
|
||||
"max": max(scores),
|
||||
"avg": statistics.mean(scores),
|
||||
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
|
||||
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0,
|
||||
}
|
||||
|
||||
# Get collection statistics
|
||||
try:
|
||||
from qdrant_client.http.models import Filter
|
||||
|
||||
collection_name = collection_name or rag_module.default_collection_name
|
||||
|
||||
# Count total documents
|
||||
count_result = rag_module.qdrant_client.count(
|
||||
collection_name=collection_name,
|
||||
count_filter=Filter(must=[])
|
||||
collection_name=collection_name, count_filter=Filter(must=[])
|
||||
)
|
||||
total_points = count_result.count
|
||||
|
||||
@@ -599,7 +617,7 @@ async def search_with_debug(
|
||||
collection_name=collection_name,
|
||||
limit=1000, # Sample for stats
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
unique_docs = set()
|
||||
@@ -618,7 +636,7 @@ async def search_with_debug(
|
||||
debug_info["collection_stats"] = {
|
||||
"total_documents": len(unique_docs),
|
||||
"total_chunks": total_points,
|
||||
"languages": sorted(list(languages))
|
||||
"languages": sorted(list(languages)),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -631,16 +649,18 @@ async def search_with_debug(
|
||||
"document": {
|
||||
"id": result.document.id,
|
||||
"content": result.document.content,
|
||||
"metadata": result.document.metadata
|
||||
"metadata": result.document.metadata,
|
||||
},
|
||||
"score": result.score,
|
||||
"debug_info": {}
|
||||
"debug_info": {},
|
||||
}
|
||||
|
||||
# Add hybrid search debug info if available
|
||||
metadata = result.document.metadata or {}
|
||||
if "_vector_score" in metadata:
|
||||
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
|
||||
enhanced_result["debug_info"]["vector_score"] = metadata[
|
||||
"_vector_score"
|
||||
]
|
||||
if "_bm25_score" in metadata:
|
||||
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
|
||||
|
||||
@@ -652,7 +672,7 @@ async def search_with_debug(
|
||||
"results": enhanced_results,
|
||||
"debug_info": debug_info,
|
||||
"search_time_ms": search_time,
|
||||
"timestamp": start_time.isoformat()
|
||||
"timestamp": start_time.isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -661,17 +681,17 @@ async def search_with_debug(
|
||||
|
||||
finally:
|
||||
# Restore original config if modified
|
||||
if config and 'original_config' in locals():
|
||||
if config and "original_config" in locals():
|
||||
rag_module.config = original_config
|
||||
|
||||
|
||||
@router.get("/debug/config")
|
||||
async def get_current_config(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current RAG configuration"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
rag_module = module_manager.modules.get("rag")
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
@@ -679,5 +699,5 @@ async def get_current_config(
|
||||
"config": rag_module.config,
|
||||
"embedding_model": rag_module.embedding_model,
|
||||
"enabled": rag_module.enabled,
|
||||
"collections": await rag_module._get_collections_safely()
|
||||
"collections": await rag_module._get_collections_safely(),
|
||||
}
|
||||
|
||||
@@ -88,71 +88,255 @@ class SecurityConfigResponse(BaseModel):
|
||||
# Global settings storage (in a real app, this would be in database)
|
||||
SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
|
||||
"platform": {
|
||||
"app_name": {"value": "Confidential Empire", "type": "string", "description": "Application name"},
|
||||
"maintenance_mode": {"value": False, "type": "boolean", "description": "Enable maintenance mode"},
|
||||
"maintenance_message": {"value": None, "type": "string", "description": "Maintenance mode message"},
|
||||
"debug_mode": {"value": False, "type": "boolean", "description": "Enable debug mode"},
|
||||
"max_upload_size": {"value": 10485760, "type": "integer", "description": "Maximum upload size in bytes"},
|
||||
"app_name": {
|
||||
"value": "Confidential Empire",
|
||||
"type": "string",
|
||||
"description": "Application name",
|
||||
},
|
||||
"maintenance_mode": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable maintenance mode",
|
||||
},
|
||||
"maintenance_message": {
|
||||
"value": None,
|
||||
"type": "string",
|
||||
"description": "Maintenance mode message",
|
||||
},
|
||||
"debug_mode": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable debug mode",
|
||||
},
|
||||
"max_upload_size": {
|
||||
"value": 10485760,
|
||||
"type": "integer",
|
||||
"description": "Maximum upload size in bytes",
|
||||
},
|
||||
},
|
||||
"api": {
|
||||
# Security Settings
|
||||
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
|
||||
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
|
||||
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
|
||||
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
|
||||
"security_headers_enabled": {"value": True, "type": "boolean", "description": "Enable security headers"},
|
||||
|
||||
"security_enabled": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable API security system",
|
||||
},
|
||||
"rate_limiting_enabled": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable rate limiting",
|
||||
},
|
||||
"ip_reputation_enabled": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable IP reputation checking",
|
||||
},
|
||||
"anomaly_detection_enabled": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable anomaly detection",
|
||||
},
|
||||
"security_headers_enabled": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable security headers",
|
||||
},
|
||||
# Rate Limiting by Authentication Level
|
||||
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer", "description": "Rate limit for authenticated users per minute"},
|
||||
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer", "description": "Rate limit for authenticated users per hour"},
|
||||
"rate_limit_api_key_per_minute": {"value": 1000, "type": "integer", "description": "Rate limit for API key users per minute"},
|
||||
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer", "description": "Rate limit for API key users per hour"},
|
||||
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer", "description": "Rate limit for premium users per minute"},
|
||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
|
||||
|
||||
"rate_limit_authenticated_per_minute": {
|
||||
"value": 200,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for authenticated users per minute",
|
||||
},
|
||||
"rate_limit_authenticated_per_hour": {
|
||||
"value": 5000,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for authenticated users per hour",
|
||||
},
|
||||
"rate_limit_api_key_per_minute": {
|
||||
"value": 1000,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for API key users per minute",
|
||||
},
|
||||
"rate_limit_api_key_per_hour": {
|
||||
"value": 20000,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for API key users per hour",
|
||||
},
|
||||
"rate_limit_premium_per_minute": {
|
||||
"value": 5000,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for premium users per minute",
|
||||
},
|
||||
"rate_limit_premium_per_hour": {
|
||||
"value": 100000,
|
||||
"type": "integer",
|
||||
"description": "Rate limit for premium users per hour",
|
||||
},
|
||||
# Security Thresholds
|
||||
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
|
||||
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
|
||||
|
||||
"security_warning_threshold": {
|
||||
"value": 0.6,
|
||||
"type": "float",
|
||||
"description": "Risk score threshold for warnings (0.0-1.0)",
|
||||
},
|
||||
"anomaly_threshold": {
|
||||
"value": 0.7,
|
||||
"type": "float",
|
||||
"description": "Anomaly severity threshold (0.0-1.0)",
|
||||
},
|
||||
# Request Settings
|
||||
"max_request_size_mb": {"value": 10, "type": "integer", "description": "Maximum request size in MB for standard users"},
|
||||
"max_request_size_premium_mb": {"value": 50, "type": "integer", "description": "Maximum request size in MB for premium users"},
|
||||
"enable_cors": {"value": True, "type": "boolean", "description": "Enable CORS headers"},
|
||||
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list", "description": "Allowed CORS origins"},
|
||||
"api_key_expiry_days": {"value": 90, "type": "integer", "description": "Default API key expiry in days"},
|
||||
|
||||
"max_request_size_mb": {
|
||||
"value": 10,
|
||||
"type": "integer",
|
||||
"description": "Maximum request size in MB for standard users",
|
||||
},
|
||||
"max_request_size_premium_mb": {
|
||||
"value": 50,
|
||||
"type": "integer",
|
||||
"description": "Maximum request size in MB for premium users",
|
||||
},
|
||||
"enable_cors": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable CORS headers",
|
||||
},
|
||||
"cors_origins": {
|
||||
"value": ["http://localhost:3000", "http://localhost:53000"],
|
||||
"type": "list",
|
||||
"description": "Allowed CORS origins",
|
||||
},
|
||||
"api_key_expiry_days": {
|
||||
"value": 90,
|
||||
"type": "integer",
|
||||
"description": "Default API key expiry in days",
|
||||
},
|
||||
# IP Security
|
||||
"blocked_ips": {"value": [], "type": "list", "description": "List of blocked IP addresses"},
|
||||
"allowed_ips": {"value": [], "type": "list", "description": "List of allowed IP addresses (empty = allow all)"},
|
||||
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string", "description": "Content Security Policy header"},
|
||||
"blocked_ips": {
|
||||
"value": [],
|
||||
"type": "list",
|
||||
"description": "List of blocked IP addresses",
|
||||
},
|
||||
"allowed_ips": {
|
||||
"value": [],
|
||||
"type": "list",
|
||||
"description": "List of allowed IP addresses (empty = allow all)",
|
||||
},
|
||||
"csp_header": {
|
||||
"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';",
|
||||
"type": "string",
|
||||
"description": "Content Security Policy header",
|
||||
},
|
||||
},
|
||||
"security": {
|
||||
"password_min_length": {"value": 8, "type": "integer", "description": "Minimum password length"},
|
||||
"password_require_special": {"value": True, "type": "boolean", "description": "Require special characters in passwords"},
|
||||
"password_require_numbers": {"value": True, "type": "boolean", "description": "Require numbers in passwords"},
|
||||
"password_require_uppercase": {"value": True, "type": "boolean", "description": "Require uppercase letters in passwords"},
|
||||
"max_login_attempts": {"value": 5, "type": "integer", "description": "Maximum login attempts before lockout"},
|
||||
"lockout_duration_minutes": {"value": 15, "type": "integer", "description": "Account lockout duration in minutes"},
|
||||
"require_2fa": {"value": False, "type": "boolean", "description": "Require two-factor authentication"},
|
||||
"ip_whitelist_enabled": {"value": False, "type": "boolean", "description": "Enable IP whitelist"},
|
||||
"allowed_domains": {"value": [], "type": "list", "description": "Allowed email domains for registration"},
|
||||
"password_min_length": {
|
||||
"value": 8,
|
||||
"type": "integer",
|
||||
"description": "Minimum password length",
|
||||
},
|
||||
"password_require_special": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Require special characters in passwords",
|
||||
},
|
||||
"password_require_numbers": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Require numbers in passwords",
|
||||
},
|
||||
"password_require_uppercase": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Require uppercase letters in passwords",
|
||||
},
|
||||
"max_login_attempts": {
|
||||
"value": 5,
|
||||
"type": "integer",
|
||||
"description": "Maximum login attempts before lockout",
|
||||
},
|
||||
"lockout_duration_minutes": {
|
||||
"value": 15,
|
||||
"type": "integer",
|
||||
"description": "Account lockout duration in minutes",
|
||||
},
|
||||
"require_2fa": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Require two-factor authentication",
|
||||
},
|
||||
"ip_whitelist_enabled": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable IP whitelist",
|
||||
},
|
||||
"allowed_domains": {
|
||||
"value": [],
|
||||
"type": "list",
|
||||
"description": "Allowed email domains for registration",
|
||||
},
|
||||
},
|
||||
"features": {
|
||||
"user_registration": {"value": True, "type": "boolean", "description": "Allow user registration"},
|
||||
"api_key_creation": {"value": True, "type": "boolean", "description": "Allow API key creation"},
|
||||
"budget_enforcement": {"value": True, "type": "boolean", "description": "Enable budget enforcement"},
|
||||
"audit_logging": {"value": True, "type": "boolean", "description": "Enable audit logging"},
|
||||
"module_hot_reload": {"value": True, "type": "boolean", "description": "Enable module hot reload"},
|
||||
"tee_support": {"value": True, "type": "boolean", "description": "Enable TEE (Trusted Execution Environment) support"},
|
||||
"advanced_analytics": {"value": True, "type": "boolean", "description": "Enable advanced analytics"},
|
||||
"user_registration": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Allow user registration",
|
||||
},
|
||||
"api_key_creation": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Allow API key creation",
|
||||
},
|
||||
"budget_enforcement": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable budget enforcement",
|
||||
},
|
||||
"audit_logging": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable audit logging",
|
||||
},
|
||||
"module_hot_reload": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable module hot reload",
|
||||
},
|
||||
"tee_support": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable TEE (Trusted Execution Environment) support",
|
||||
},
|
||||
"advanced_analytics": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable advanced analytics",
|
||||
},
|
||||
},
|
||||
"notifications": {
|
||||
"email_enabled": {"value": False, "type": "boolean", "description": "Enable email notifications"},
|
||||
"slack_enabled": {"value": False, "type": "boolean", "description": "Enable Slack notifications"},
|
||||
"webhook_enabled": {"value": False, "type": "boolean", "description": "Enable webhook notifications"},
|
||||
"budget_alerts": {"value": True, "type": "boolean", "description": "Enable budget alert notifications"},
|
||||
"security_alerts": {"value": True, "type": "boolean", "description": "Enable security alert notifications"},
|
||||
}
|
||||
"email_enabled": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable email notifications",
|
||||
},
|
||||
"slack_enabled": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable Slack notifications",
|
||||
},
|
||||
"webhook_enabled": {
|
||||
"value": False,
|
||||
"type": "boolean",
|
||||
"description": "Enable webhook notifications",
|
||||
},
|
||||
"budget_alerts": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable budget alert notifications",
|
||||
},
|
||||
"security_alerts": {
|
||||
"value": True,
|
||||
"type": "boolean",
|
||||
"description": "Enable security alert notifications",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -162,7 +346,7 @@ async def list_settings(
|
||||
category: Optional[str] = None,
|
||||
include_secrets: bool = False,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List all settings or settings in a specific category"""
|
||||
|
||||
@@ -179,23 +363,26 @@ async def list_settings(
|
||||
for key, setting in settings.items():
|
||||
# Hide secret values unless specifically requested and user has permission
|
||||
if setting.get("is_secret", False) and not include_secrets:
|
||||
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
|
||||
if not any(
|
||||
perm in current_user.get("permissions", [])
|
||||
for perm in ["platform:settings:admin", "platform:*"]
|
||||
):
|
||||
continue
|
||||
|
||||
result[cat][key] = {
|
||||
"value": setting["value"],
|
||||
"type": setting["type"],
|
||||
"description": setting.get("description", ""),
|
||||
"is_secret": setting.get("is_secret", False)
|
||||
"is_secret": setting.get("is_secret", False),
|
||||
}
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="list_settings",
|
||||
resource_type="setting",
|
||||
details={"category": category, "include_secrets": include_secrets}
|
||||
details={"category": category, "include_secrets": include_secrets},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -203,8 +390,7 @@ async def list_settings(
|
||||
|
||||
@router.get("/system-info", response_model=SystemInfoResponse)
|
||||
async def get_system_info(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get system information and status"""
|
||||
|
||||
@@ -228,6 +414,7 @@ async def get_system_info(
|
||||
# Get LLM service status
|
||||
try:
|
||||
from app.services.llm.service import llm_service
|
||||
|
||||
health_summary = llm_service.get_health_summary()
|
||||
llm_service_status = health_summary.get("service_status", "unknown")
|
||||
except Exception:
|
||||
@@ -238,6 +425,7 @@ async def get_system_info(
|
||||
|
||||
# Get active users count (last 24 hours)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
yesterday = datetime.utcnow() - timedelta(days=1)
|
||||
active_users_query = select(User.id).where(User.last_login >= yesterday)
|
||||
active_users_result = await db.execute(active_users_query)
|
||||
@@ -254,9 +442,9 @@ async def get_system_info(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_system_info",
|
||||
resource_type="system"
|
||||
resource_type="system",
|
||||
)
|
||||
|
||||
return SystemInfoResponse(
|
||||
@@ -268,14 +456,13 @@ async def get_system_info(
|
||||
modules_loaded=modules_loaded,
|
||||
active_users=active_users,
|
||||
total_api_keys=total_api_keys,
|
||||
uptime_seconds=uptime_seconds
|
||||
uptime_seconds=uptime_seconds,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/platform-config", response_model=PlatformConfigResponse)
|
||||
async def get_platform_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get platform configuration"""
|
||||
|
||||
@@ -289,24 +476,33 @@ async def get_platform_config(
|
||||
api_settings = SETTINGS_STORE.get("api", {})
|
||||
|
||||
return PlatformConfigResponse(
|
||||
app_name=platform_settings.get("app_name", {}).get("value", "Confidential Empire"),
|
||||
app_name=platform_settings.get("app_name", {}).get(
|
||||
"value", "Confidential Empire"
|
||||
),
|
||||
debug_mode=platform_settings.get("debug_mode", {}).get("value", False),
|
||||
log_level=app_settings.LOG_LEVEL,
|
||||
cors_origins=app_settings.CORS_ORIGINS,
|
||||
rate_limiting_enabled=api_settings.get("rate_limiting_enabled", {}).get("value", True),
|
||||
max_upload_size=platform_settings.get("max_upload_size", {}).get("value", 10485760),
|
||||
rate_limiting_enabled=api_settings.get("rate_limiting_enabled", {}).get(
|
||||
"value", True
|
||||
),
|
||||
max_upload_size=platform_settings.get("max_upload_size", {}).get(
|
||||
"value", 10485760
|
||||
),
|
||||
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
|
||||
api_key_prefix=app_settings.API_KEY_PREFIX,
|
||||
features=features,
|
||||
maintenance_mode=platform_settings.get("maintenance_mode", {}).get("value", False),
|
||||
maintenance_message=platform_settings.get("maintenance_message", {}).get("value")
|
||||
maintenance_mode=platform_settings.get("maintenance_mode", {}).get(
|
||||
"value", False
|
||||
),
|
||||
maintenance_message=platform_settings.get("maintenance_message", {}).get(
|
||||
"value"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/security-config", response_model=SecurityConfigResponse)
|
||||
async def get_security_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get security configuration"""
|
||||
|
||||
@@ -316,16 +512,30 @@ async def get_security_config(
|
||||
security_settings = SETTINGS_STORE.get("security", {})
|
||||
|
||||
return SecurityConfigResponse(
|
||||
password_min_length=security_settings.get("password_min_length", {}).get("value", 8),
|
||||
password_require_special=security_settings.get("password_require_special", {}).get("value", True),
|
||||
password_require_numbers=security_settings.get("password_require_numbers", {}).get("value", True),
|
||||
password_require_uppercase=security_settings.get("password_require_uppercase", {}).get("value", True),
|
||||
password_min_length=security_settings.get("password_min_length", {}).get(
|
||||
"value", 8
|
||||
),
|
||||
password_require_special=security_settings.get(
|
||||
"password_require_special", {}
|
||||
).get("value", True),
|
||||
password_require_numbers=security_settings.get(
|
||||
"password_require_numbers", {}
|
||||
).get("value", True),
|
||||
password_require_uppercase=security_settings.get(
|
||||
"password_require_uppercase", {}
|
||||
).get("value", True),
|
||||
session_timeout_minutes=app_settings.SESSION_EXPIRE_MINUTES,
|
||||
max_login_attempts=security_settings.get("max_login_attempts", {}).get("value", 5),
|
||||
lockout_duration_minutes=security_settings.get("lockout_duration_minutes", {}).get("value", 15),
|
||||
max_login_attempts=security_settings.get("max_login_attempts", {}).get(
|
||||
"value", 5
|
||||
),
|
||||
lockout_duration_minutes=security_settings.get(
|
||||
"lockout_duration_minutes", {}
|
||||
).get("value", 15),
|
||||
require_2fa=security_settings.get("require_2fa", {}).get("value", False),
|
||||
allowed_domains=security_settings.get("allowed_domains", {}).get("value", []),
|
||||
ip_whitelist_enabled=security_settings.get("ip_whitelist_enabled", {}).get("value", False)
|
||||
ip_whitelist_enabled=security_settings.get("ip_whitelist_enabled", {}).get(
|
||||
"value", False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -334,7 +544,7 @@ async def get_setting(
|
||||
category: str,
|
||||
key: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get a specific setting value"""
|
||||
|
||||
@@ -344,28 +554,30 @@ async def get_setting(
|
||||
if category not in SETTINGS_STORE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Settings category '{category}' not found"
|
||||
detail=f"Settings category '{category}' not found",
|
||||
)
|
||||
|
||||
if key not in SETTINGS_STORE[category]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Setting '{key}' not found in category '{category}'"
|
||||
detail=f"Setting '{key}' not found in category '{category}'",
|
||||
)
|
||||
|
||||
setting = SETTINGS_STORE[category][key]
|
||||
|
||||
# Check if it's a secret setting
|
||||
if setting.get("is_secret", False):
|
||||
require_permission(current_user.get("permissions", []), "platform:settings:admin")
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:settings:admin"
|
||||
)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_setting",
|
||||
resource_type="setting",
|
||||
resource_id=f"{category}.{key}"
|
||||
resource_id=f"{category}.{key}",
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -374,7 +586,7 @@ async def get_setting(
|
||||
"value": setting["value"],
|
||||
"type": setting["type"],
|
||||
"description": setting.get("description", ""),
|
||||
"is_secret": setting.get("is_secret", False)
|
||||
"is_secret": setting.get("is_secret", False),
|
||||
}
|
||||
|
||||
|
||||
@@ -383,7 +595,7 @@ async def update_category_settings(
|
||||
category: str,
|
||||
settings_data: Dict[str, Any],
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update multiple settings in a category"""
|
||||
|
||||
@@ -393,7 +605,7 @@ async def update_category_settings(
|
||||
if category not in SETTINGS_STORE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Settings category '{category}' not found"
|
||||
detail=f"Settings category '{category}' not found",
|
||||
)
|
||||
|
||||
updated_settings = []
|
||||
@@ -408,7 +620,9 @@ async def update_category_settings(
|
||||
|
||||
# Check if it's a secret setting
|
||||
if setting.get("is_secret", False):
|
||||
require_permission(current_user.get("permissions", []), "platform:settings:admin")
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:settings:admin"
|
||||
)
|
||||
|
||||
# Store original value for audit
|
||||
original_value = setting["value"]
|
||||
@@ -425,7 +639,7 @@ async def update_category_settings(
|
||||
continue
|
||||
elif expected_type == "boolean" and not isinstance(new_value, bool):
|
||||
if isinstance(new_value, str):
|
||||
new_value = new_value.lower() in ('true', '1', 'yes', 'on')
|
||||
new_value = new_value.lower() in ("true", "1", "yes", "on")
|
||||
else:
|
||||
errors.append(f"Setting '{key}' expects a boolean value")
|
||||
continue
|
||||
@@ -445,11 +659,9 @@ async def update_category_settings(
|
||||
|
||||
# Update setting
|
||||
SETTINGS_STORE[category][key]["value"] = new_value
|
||||
updated_settings.append({
|
||||
"key": key,
|
||||
"original_value": original_value,
|
||||
"new_value": new_value
|
||||
})
|
||||
updated_settings.append(
|
||||
{"key": key, "original_value": original_value, "new_value": new_value}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error updating setting '{key}': {str(e)}")
|
||||
@@ -457,7 +669,7 @@ async def update_category_settings(
|
||||
# Log audit event for bulk update
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="bulk_update_settings",
|
||||
resource_type="setting",
|
||||
resource_id=category,
|
||||
@@ -465,25 +677,29 @@ async def update_category_settings(
|
||||
"updated_count": len(updated_settings),
|
||||
"errors_count": len(errors),
|
||||
"updated_settings": updated_settings,
|
||||
"errors": errors
|
||||
}
|
||||
"errors": errors,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Bulk settings updated in category '{category}': {len(updated_settings)} settings by {current_user['username']}")
|
||||
logger.info(
|
||||
f"Bulk settings updated in category '{category}': {len(updated_settings)} settings by {current_user['username']}"
|
||||
)
|
||||
|
||||
if errors and not updated_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No settings were updated. Errors: {errors}"
|
||||
detail=f"No settings were updated. Errors: {errors}",
|
||||
)
|
||||
|
||||
return {
|
||||
"category": category,
|
||||
"updated_count": len(updated_settings),
|
||||
"errors_count": len(errors),
|
||||
"updated_settings": [{"key": s["key"], "new_value": s["new_value"]} for s in updated_settings],
|
||||
"updated_settings": [
|
||||
{"key": s["key"], "new_value": s["new_value"]} for s in updated_settings
|
||||
],
|
||||
"errors": errors,
|
||||
"message": f"Updated {len(updated_settings)} settings in category '{category}'"
|
||||
"message": f"Updated {len(updated_settings)} settings in category '{category}'",
|
||||
}
|
||||
|
||||
|
||||
@@ -493,7 +709,7 @@ async def update_setting(
|
||||
key: str,
|
||||
setting_update: SettingUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update a specific setting"""
|
||||
|
||||
@@ -503,20 +719,22 @@ async def update_setting(
|
||||
if category not in SETTINGS_STORE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Settings category '{category}' not found"
|
||||
detail=f"Settings category '{category}' not found",
|
||||
)
|
||||
|
||||
if key not in SETTINGS_STORE[category]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Setting '{key}' not found in category '{category}'"
|
||||
detail=f"Setting '{key}' not found in category '{category}'",
|
||||
)
|
||||
|
||||
setting = SETTINGS_STORE[category][key]
|
||||
|
||||
# Check if it's a secret setting
|
||||
if setting.get("is_secret", False):
|
||||
require_permission(current_user.get("permissions", []), "platform:settings:admin")
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:settings:admin"
|
||||
)
|
||||
|
||||
# Store original value for audit
|
||||
original_value = setting["value"]
|
||||
@@ -528,22 +746,22 @@ async def update_setting(
|
||||
if expected_type == "integer" and not isinstance(new_value, int):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Setting '{key}' expects an integer value"
|
||||
detail=f"Setting '{key}' expects an integer value",
|
||||
)
|
||||
elif expected_type == "boolean" and not isinstance(new_value, bool):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Setting '{key}' expects a boolean value"
|
||||
detail=f"Setting '{key}' expects a boolean value",
|
||||
)
|
||||
elif expected_type == "float" and not isinstance(new_value, (int, float)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Setting '{key}' expects a numeric value"
|
||||
detail=f"Setting '{key}' expects a numeric value",
|
||||
)
|
||||
elif expected_type == "list" and not isinstance(new_value, list):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Setting '{key}' expects a list value"
|
||||
detail=f"Setting '{key}' expects a list value",
|
||||
)
|
||||
|
||||
# Update setting
|
||||
@@ -554,15 +772,15 @@ async def update_setting(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="update_setting",
|
||||
resource_type="setting",
|
||||
resource_id=f"{category}.{key}",
|
||||
details={
|
||||
"original_value": original_value,
|
||||
"new_value": new_value,
|
||||
"description_updated": setting_update.description is not None
|
||||
}
|
||||
"description_updated": setting_update.description is not None,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Setting updated: {category}.{key} by {current_user['username']}")
|
||||
@@ -573,7 +791,7 @@ async def update_setting(
|
||||
"value": new_value,
|
||||
"type": expected_type,
|
||||
"description": SETTINGS_STORE[category][key].get("description", ""),
|
||||
"message": "Setting updated successfully"
|
||||
"message": "Setting updated successfully",
|
||||
}
|
||||
|
||||
|
||||
@@ -581,7 +799,7 @@ async def update_setting(
|
||||
async def reset_to_defaults(
|
||||
category: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Reset settings to default values"""
|
||||
|
||||
@@ -603,7 +821,6 @@ async def reset_to_defaults(
|
||||
"ip_reputation_enabled": {"value": True, "type": "boolean"},
|
||||
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
|
||||
"security_headers_enabled": {"value": True, "type": "boolean"},
|
||||
|
||||
# Rate Limiting by Authentication Level
|
||||
"rate_limit_authenticated_per_minute": {"value": 200, "type": "integer"},
|
||||
"rate_limit_authenticated_per_hour": {"value": 5000, "type": "integer"},
|
||||
@@ -611,22 +828,25 @@ async def reset_to_defaults(
|
||||
"rate_limit_api_key_per_hour": {"value": 20000, "type": "integer"},
|
||||
"rate_limit_premium_per_minute": {"value": 5000, "type": "integer"},
|
||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
|
||||
|
||||
# Security Thresholds
|
||||
"security_warning_threshold": {"value": 0.6, "type": "float"},
|
||||
"anomaly_threshold": {"value": 0.7, "type": "float"},
|
||||
|
||||
# Request Settings
|
||||
"max_request_size_mb": {"value": 10, "type": "integer"},
|
||||
"max_request_size_premium_mb": {"value": 50, "type": "integer"},
|
||||
"enable_cors": {"value": True, "type": "boolean"},
|
||||
"cors_origins": {"value": ["http://localhost:3000", "http://localhost:53000"], "type": "list"},
|
||||
"cors_origins": {
|
||||
"value": ["http://localhost:3000", "http://localhost:53000"],
|
||||
"type": "list",
|
||||
},
|
||||
"api_key_expiry_days": {"value": 90, "type": "integer"},
|
||||
|
||||
# IP Security
|
||||
"blocked_ips": {"value": [], "type": "list"},
|
||||
"allowed_ips": {"value": [], "type": "list"},
|
||||
"csp_header": {"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';", "type": "string"},
|
||||
"csp_header": {
|
||||
"value": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline';",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"security": {
|
||||
"password_min_length": {"value": 8, "type": "integer"},
|
||||
@@ -647,7 +867,7 @@ async def reset_to_defaults(
|
||||
"module_hot_reload": {"value": True, "type": "boolean"},
|
||||
"tee_support": {"value": True, "type": "boolean"},
|
||||
"advanced_analytics": {"value": True, "type": "boolean"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
reset_categories = [category] if category else list(defaults.keys())
|
||||
@@ -661,24 +881,25 @@ async def reset_to_defaults(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="reset_settings_to_defaults",
|
||||
resource_type="setting",
|
||||
details={"categories_reset": reset_categories}
|
||||
details={"categories_reset": reset_categories},
|
||||
)
|
||||
|
||||
logger.info(f"Settings reset to defaults: {reset_categories} by {current_user['username']}")
|
||||
logger.info(
|
||||
f"Settings reset to defaults: {reset_categories} by {current_user['username']}"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Settings reset to defaults for categories: {reset_categories}",
|
||||
"categories_reset": reset_categories
|
||||
"categories_reset": reset_categories,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/export")
|
||||
async def export_settings(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Export all settings to JSON"""
|
||||
|
||||
@@ -693,29 +914,32 @@ async def export_settings(
|
||||
for key, setting in settings.items():
|
||||
# Skip secret settings for non-admin users
|
||||
if setting.get("is_secret", False):
|
||||
if not any(perm in current_user.get("permissions", []) for perm in ["platform:settings:admin", "platform:*"]):
|
||||
if not any(
|
||||
perm in current_user.get("permissions", [])
|
||||
for perm in ["platform:settings:admin", "platform:*"]
|
||||
):
|
||||
continue
|
||||
|
||||
export_data[category][key] = {
|
||||
"value": setting["value"],
|
||||
"type": setting["type"],
|
||||
"description": setting.get("description", ""),
|
||||
"is_secret": setting.get("is_secret", False)
|
||||
"is_secret": setting.get("is_secret", False),
|
||||
}
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="export_settings",
|
||||
resource_type="setting",
|
||||
details={"categories_exported": list(export_data.keys())}
|
||||
details={"categories_exported": list(export_data.keys())},
|
||||
)
|
||||
|
||||
return {
|
||||
"settings": export_data,
|
||||
"exported_at": datetime.utcnow().isoformat(),
|
||||
"exported_by": current_user['username']
|
||||
"exported_by": current_user["username"],
|
||||
}
|
||||
|
||||
|
||||
@@ -723,7 +947,7 @@ async def export_settings(
|
||||
async def import_settings(
|
||||
settings_data: Dict[str, Dict[str, Dict[str, Any]]],
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Import settings from JSON"""
|
||||
|
||||
@@ -750,15 +974,21 @@ async def import_settings(
|
||||
|
||||
# Basic type validation
|
||||
if expected_type == "integer" and not isinstance(new_value, int):
|
||||
errors.append(f"Invalid type for {category}.{key}: expected integer")
|
||||
errors.append(
|
||||
f"Invalid type for {category}.{key}: expected integer"
|
||||
)
|
||||
continue
|
||||
elif expected_type == "boolean" and not isinstance(new_value, bool):
|
||||
errors.append(f"Invalid type for {category}.{key}: expected boolean")
|
||||
errors.append(
|
||||
f"Invalid type for {category}.{key}: expected boolean"
|
||||
)
|
||||
continue
|
||||
|
||||
SETTINGS_STORE[category][key]["value"] = new_value
|
||||
if "description" in setting_data:
|
||||
SETTINGS_STORE[category][key]["description"] = setting_data["description"]
|
||||
SETTINGS_STORE[category][key]["description"] = setting_data[
|
||||
"description"
|
||||
]
|
||||
|
||||
imported_count += 1
|
||||
|
||||
@@ -768,20 +998,22 @@ async def import_settings(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="import_settings",
|
||||
resource_type="setting",
|
||||
details={
|
||||
"imported_count": imported_count,
|
||||
"errors_count": len(errors),
|
||||
"errors": errors
|
||||
}
|
||||
"errors": errors,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Settings imported: {imported_count} settings by {current_user['username']}")
|
||||
logger.info(
|
||||
f"Settings imported: {imported_count} settings by {current_user['username']}"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Import completed. {imported_count} settings imported.",
|
||||
"imported_count": imported_count,
|
||||
"errors": errors
|
||||
"errors": errors,
|
||||
}
|
||||
@@ -85,7 +85,7 @@ async def list_users(
|
||||
is_active: Optional[bool] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List all users with pagination and filtering"""
|
||||
|
||||
@@ -102,9 +102,9 @@ async def list_users(
|
||||
query = query.where(User.is_active == is_active)
|
||||
if search:
|
||||
query = query.where(
|
||||
(User.username.ilike(f"%{search}%")) |
|
||||
(User.email.ilike(f"%{search}%")) |
|
||||
(User.full_name.ilike(f"%{search}%"))
|
||||
(User.username.ilike(f"%{search}%"))
|
||||
| (User.email.ilike(f"%{search}%"))
|
||||
| (User.full_name.ilike(f"%{search}%"))
|
||||
)
|
||||
|
||||
# Get total count
|
||||
@@ -123,17 +123,21 @@ async def list_users(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="list_users",
|
||||
resource_type="user",
|
||||
details={"page": page, "size": size, "filters": {"role": role, "is_active": is_active, "search": search}}
|
||||
details={
|
||||
"page": page,
|
||||
"size": size,
|
||||
"filters": {"role": role, "is_active": is_active, "search": search},
|
||||
},
|
||||
)
|
||||
|
||||
return UserListResponse(
|
||||
users=[UserResponse.model_validate(user) for user in users],
|
||||
total=total,
|
||||
page=page,
|
||||
size=size
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
@@ -141,12 +145,12 @@ async def list_users(
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get user by ID"""
|
||||
|
||||
# Check permissions (users can view their own profile)
|
||||
if int(user_id) != current_user['id']:
|
||||
if int(user_id) != current_user["id"]:
|
||||
require_permission(current_user.get("permissions", []), "platform:users:read")
|
||||
|
||||
# Get user
|
||||
@@ -156,17 +160,16 @@ async def get_user(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="get_user",
|
||||
resource_type="user",
|
||||
resource_id=user_id
|
||||
resource_id=user_id,
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(user)
|
||||
@@ -176,7 +179,7 @@ async def get_user(
|
||||
async def create_user(
|
||||
user_data: UserCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a new user"""
|
||||
|
||||
@@ -193,7 +196,7 @@ async def create_user(
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User with this username or email already exists"
|
||||
detail="User with this username or email already exists",
|
||||
)
|
||||
|
||||
# Create user
|
||||
@@ -204,7 +207,7 @@ async def create_user(
|
||||
full_name=user_data.full_name,
|
||||
hashed_password=hashed_password,
|
||||
role=user_data.role,
|
||||
is_active=user_data.is_active
|
||||
is_active=user_data.is_active,
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
@@ -214,11 +217,15 @@ async def create_user(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="create_user",
|
||||
resource_type="user",
|
||||
resource_id=str(new_user.id),
|
||||
details={"username": user_data.username, "email": user_data.email, "role": user_data.role}
|
||||
details={
|
||||
"username": user_data.username,
|
||||
"email": user_data.email,
|
||||
"role": user_data.role,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"User created: {new_user.username} by {current_user['username']}")
|
||||
@@ -231,12 +238,12 @@ async def update_user(
|
||||
user_id: str,
|
||||
user_data: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Update user"""
|
||||
|
||||
# Check permissions (users can update their own profile with limited fields)
|
||||
is_self_update = int(user_id) == current_user['id']
|
||||
is_self_update = int(user_id) == current_user["id"]
|
||||
if not is_self_update:
|
||||
require_permission(current_user.get("permissions", []), "platform:users:update")
|
||||
|
||||
@@ -247,8 +254,7 @@ async def update_user(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# For self-updates, restrict what can be changed
|
||||
@@ -259,7 +265,7 @@ async def update_user(
|
||||
if restricted_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Cannot update fields: {restricted_fields}"
|
||||
detail=f"Cannot update fields: {restricted_fields}",
|
||||
)
|
||||
|
||||
# Store original values for audit
|
||||
@@ -267,7 +273,7 @@ async def update_user(
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active
|
||||
"is_active": user.is_active,
|
||||
}
|
||||
|
||||
# Update user fields
|
||||
@@ -281,15 +287,15 @@ async def update_user(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="update_user",
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
details={
|
||||
"updated_fields": list(update_data.keys()),
|
||||
"before_values": original_values,
|
||||
"after_values": {k: getattr(user, k) for k in update_data.keys()}
|
||||
}
|
||||
"after_values": {k: getattr(user, k) for k in update_data.keys()},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"User updated: {user.username} by {current_user['username']}")
|
||||
@@ -301,7 +307,7 @@ async def update_user(
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Delete user (soft delete by deactivating)"""
|
||||
|
||||
@@ -309,10 +315,10 @@ async def delete_user(
|
||||
require_permission(current_user.get("permissions", []), "platform:users:delete")
|
||||
|
||||
# Prevent self-deletion
|
||||
if int(user_id) == current_user['id']:
|
||||
if int(user_id) == current_user["id"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete your own account"
|
||||
detail="Cannot delete your own account",
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -322,8 +328,7 @@ async def delete_user(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Soft delete by deactivating
|
||||
@@ -333,11 +338,11 @@ async def delete_user(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="delete_user",
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
details={"username": user.username, "email": user.email}
|
||||
details={"username": user.username, "email": user.email},
|
||||
)
|
||||
|
||||
logger.info(f"User deleted: {user.username} by {current_user['username']}")
|
||||
@@ -350,12 +355,12 @@ async def change_password(
|
||||
user_id: str,
|
||||
password_data: PasswordChangeRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Change user password"""
|
||||
|
||||
# Users can only change their own password, or admins can change any password
|
||||
is_self_update = int(user_id) == current_user['id']
|
||||
is_self_update = int(user_id) == current_user["id"]
|
||||
if not is_self_update:
|
||||
require_permission(current_user.get("permissions", []), "platform:users:update")
|
||||
|
||||
@@ -366,8 +371,7 @@ async def change_password(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# For self-updates, verify current password
|
||||
@@ -375,7 +379,7 @@ async def change_password(
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
detail="Current password is incorrect",
|
||||
)
|
||||
|
||||
# Update password
|
||||
@@ -385,14 +389,16 @@ async def change_password(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="change_password",
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
details={"target_user": user.username}
|
||||
details={"target_user": user.username},
|
||||
)
|
||||
|
||||
logger.info(f"Password changed for user: {user.username} by {current_user['username']}")
|
||||
logger.info(
|
||||
f"Password changed for user: {user.username} by {current_user['username']}"
|
||||
)
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
@@ -402,7 +408,7 @@ async def reset_password(
|
||||
user_id: str,
|
||||
password_data: PasswordResetRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Reset user password (admin only)"""
|
||||
|
||||
@@ -416,8 +422,7 @@ async def reset_password(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# Reset password
|
||||
@@ -427,14 +432,16 @@ async def reset_password(
|
||||
# Log audit event
|
||||
await log_audit_event(
|
||||
db=db,
|
||||
user_id=current_user['id'],
|
||||
user_id=current_user["id"],
|
||||
action="reset_password",
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
details={"target_user": user.username}
|
||||
details={"target_user": user.username},
|
||||
)
|
||||
|
||||
logger.info(f"Password reset for user: {user.username} by {current_user['username']}")
|
||||
logger.info(
|
||||
f"Password reset for user: {user.username} by {current_user['username']}"
|
||||
)
|
||||
|
||||
return {"message": "Password reset successfully"}
|
||||
|
||||
@@ -443,14 +450,16 @@ async def reset_password(
|
||||
async def get_user_api_keys(
|
||||
user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get API keys for a user"""
|
||||
|
||||
# Check permissions (users can view their own API keys)
|
||||
is_self_request = int(user_id) == current_user['id']
|
||||
is_self_request = int(user_id) == current_user["id"]
|
||||
if not is_self_request:
|
||||
require_permission(current_user.get("permissions", []), "platform:api-keys:read")
|
||||
require_permission(
|
||||
current_user.get("permissions", []), "platform:api-keys:read"
|
||||
)
|
||||
|
||||
# Get API keys
|
||||
query = select(APIKey).where(APIKey.user_id == int(user_id))
|
||||
@@ -466,8 +475,12 @@ async def get_user_api_keys(
|
||||
"scopes": api_key.scopes,
|
||||
"is_active": api_key.is_active,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None
|
||||
"expires_at": api_key.expires_at.isoformat()
|
||||
if api_key.expires_at
|
||||
else None,
|
||||
"last_used_at": api_key.last_used_at.isoformat()
|
||||
if api_key.last_used_at
|
||||
else None,
|
||||
}
|
||||
for api_key in api_keys
|
||||
]
|
||||
@@ -24,18 +24,13 @@ class CoreCacheService:
|
||||
self.redis_pool: Optional[ConnectionPool] = None
|
||||
self.redis_client: Optional[Redis] = None
|
||||
self.enabled = False
|
||||
self.stats = {
|
||||
"hits": 0,
|
||||
"misses": 0,
|
||||
"errors": 0,
|
||||
"total_requests": 0
|
||||
}
|
||||
self.stats = {"hits": 0, "misses": 0, "errors": 0, "total_requests": 0}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the core cache service with connection pool"""
|
||||
try:
|
||||
# Create Redis connection pool for better resource management
|
||||
redis_url = getattr(settings, 'REDIS_URL', 'redis://localhost:6379/0')
|
||||
redis_url = getattr(settings, "REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
self.redis_pool = ConnectionPool.from_url(
|
||||
redis_url,
|
||||
@@ -45,7 +40,7 @@ class CoreCacheService:
|
||||
socket_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
max_connections=20, # Shared pool for all cache operations
|
||||
health_check_interval=30
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
self.redis_client = Redis(connection_pool=self.redis_pool)
|
||||
@@ -105,7 +100,9 @@ class CoreCacheService:
|
||||
self.stats["errors"] += 1
|
||||
return default
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
|
||||
) -> bool:
|
||||
"""Set value in cache"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
@@ -172,7 +169,9 @@ class CoreCacheService:
|
||||
self.stats["errors"] += 1
|
||||
return 0
|
||||
|
||||
async def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core") -> int:
|
||||
async def increment(
|
||||
self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core"
|
||||
) -> int:
|
||||
"""Increment counter with optional TTL"""
|
||||
if not self.enabled:
|
||||
return 0
|
||||
@@ -200,18 +199,24 @@ class CoreCacheService:
|
||||
if self.enabled:
|
||||
try:
|
||||
info = await self.redis_client.info()
|
||||
stats.update({
|
||||
stats.update(
|
||||
{
|
||||
"redis_memory_used": info.get("used_memory_human", "N/A"),
|
||||
"redis_connected_clients": info.get("connected_clients", 0),
|
||||
"redis_total_commands": info.get("total_commands_processed", 0),
|
||||
"redis_keyspace_hits": info.get("keyspace_hits", 0),
|
||||
"redis_keyspace_misses": info.get("keyspace_misses", 0),
|
||||
"connection_pool_size": self.redis_pool.connection_pool_size if self.redis_pool else 0,
|
||||
"connection_pool_size": self.redis_pool.connection_pool_size
|
||||
if self.redis_pool
|
||||
else 0,
|
||||
"hit_rate": round(
|
||||
(stats["hits"] / stats["total_requests"]) * 100, 2
|
||||
) if stats["total_requests"] > 0 else 0,
|
||||
"enabled": True
|
||||
})
|
||||
)
|
||||
if stats["total_requests"] > 0
|
||||
else 0,
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting Redis stats: {e}")
|
||||
stats["enabled"] = False
|
||||
@@ -232,7 +237,9 @@ class CoreCacheService:
|
||||
|
||||
# Specialized caching methods for common use cases
|
||||
|
||||
async def cache_api_key(self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300) -> bool:
|
||||
async def cache_api_key(
|
||||
self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300
|
||||
) -> bool:
|
||||
"""Cache API key data for authentication"""
|
||||
return await self.set(key_prefix, api_key_data, ttl, prefix="auth")
|
||||
|
||||
@@ -244,36 +251,53 @@ class CoreCacheService:
|
||||
"""Invalidate cached API key"""
|
||||
return await self.delete(key_prefix, prefix="auth")
|
||||
|
||||
async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool, ttl: int = 300) -> bool:
|
||||
async def cache_verification_result(
|
||||
self,
|
||||
api_key: str,
|
||||
key_prefix: str,
|
||||
key_hash: str,
|
||||
is_valid: bool,
|
||||
ttl: int = 300,
|
||||
) -> bool:
|
||||
"""Cache API key verification result to avoid expensive bcrypt operations"""
|
||||
verification_data = {
|
||||
"key_hash": key_hash,
|
||||
"is_valid": is_valid,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
return await self.set(f"verify:{key_prefix}", verification_data, ttl, prefix="auth")
|
||||
return await self.set(
|
||||
f"verify:{key_prefix}", verification_data, ttl, prefix="auth"
|
||||
)
|
||||
|
||||
async def get_cached_verification(self, key_prefix: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_cached_verification(
|
||||
self, key_prefix: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached verification result"""
|
||||
return await self.get(f"verify:{key_prefix}", prefix="auth")
|
||||
|
||||
async def cache_rate_limit(self, identifier: str, window_seconds: int, limit: int, current_count: int = 1) -> Dict[str, Any]:
|
||||
async def cache_rate_limit(
|
||||
self, identifier: str, window_seconds: int, limit: int, current_count: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""Cache and track rate limit state"""
|
||||
key = f"rate_limit:{identifier}:{window_seconds}"
|
||||
|
||||
try:
|
||||
# Use atomic increment with expiry
|
||||
count = await self.increment(key, current_count, window_seconds, prefix="rate")
|
||||
count = await self.increment(
|
||||
key, current_count, window_seconds, prefix="rate"
|
||||
)
|
||||
|
||||
remaining = max(0, limit - count)
|
||||
reset_time = int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp())
|
||||
reset_time = int(
|
||||
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
|
||||
)
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"limit": limit,
|
||||
"remaining": remaining,
|
||||
"reset_time": reset_time,
|
||||
"exceeded": count > limit
|
||||
"exceeded": count > limit,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit cache error: {e}")
|
||||
@@ -282,8 +306,10 @@ class CoreCacheService:
|
||||
"count": 0,
|
||||
"limit": limit,
|
||||
"remaining": limit,
|
||||
"reset_time": int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()),
|
||||
"exceeded": False
|
||||
"reset_time": int(
|
||||
(datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
|
||||
),
|
||||
"exceeded": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -297,7 +323,9 @@ async def get(key: str, default: Any = None, prefix: str = "core") -> Any:
|
||||
return await core_cache.get(key, default, prefix)
|
||||
|
||||
|
||||
async def set(key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:
|
||||
async def set(
|
||||
key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
|
||||
) -> bool:
|
||||
"""Set value in core cache"""
|
||||
return await core_cache.set(key, value, ttl, prefix)
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ class Settings(BaseSettings):
|
||||
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
|
||||
|
||||
# Detailed logging for LLM interactions
|
||||
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
|
||||
LOG_LLM_PROMPTS: bool = (
|
||||
os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"
|
||||
) # Set to True to log prompts and context sent to LLM
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = os.getenv("DATABASE_URL")
|
||||
@@ -32,11 +34,19 @@ class Settings(BaseSettings):
|
||||
# Security
|
||||
JWT_SECRET: str = os.getenv("JWT_SECRET")
|
||||
JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")) # 7 days
|
||||
SESSION_EXPIRE_MINUTES: int = int(os.getenv("SESSION_EXPIRE_MINUTES", "1440")) # 24 hours
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
|
||||
) # 24 hours
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")
|
||||
) # 7 days
|
||||
SESSION_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("SESSION_EXPIRE_MINUTES", "1440")
|
||||
) # 24 hours
|
||||
API_KEY_PREFIX: str = os.getenv("API_KEY_PREFIX", "en_")
|
||||
BCRYPT_ROUNDS: int = int(os.getenv("BCRYPT_ROUNDS", "6")) # Bcrypt work factor - lower for production performance
|
||||
BCRYPT_ROUNDS: int = int(
|
||||
os.getenv("BCRYPT_ROUNDS", "6")
|
||||
) # Bcrypt work factor - lower for production performance
|
||||
|
||||
# Admin user provisioning (used only on first startup)
|
||||
ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL")
|
||||
@@ -45,12 +55,12 @@ class Settings(BaseSettings):
|
||||
# Base URL for deriving CORS origins
|
||||
BASE_URL: str = os.getenv("BASE_URL", "localhost")
|
||||
|
||||
@field_validator('CORS_ORIGINS', mode='before')
|
||||
@field_validator("CORS_ORIGINS", mode="before")
|
||||
@classmethod
|
||||
def derive_cors_origins(cls, v, info):
|
||||
"""Derive CORS origins from BASE_URL if not explicitly set"""
|
||||
if v is None:
|
||||
base_url = info.data.get('BASE_URL', 'localhost')
|
||||
base_url = info.data.get("BASE_URL", "localhost")
|
||||
# Support both HTTP and HTTPS for production environments
|
||||
return [f"http://{base_url}", f"https://{base_url}"]
|
||||
return v if isinstance(v, list) else [v]
|
||||
@@ -64,14 +74,18 @@ class Settings(BaseSettings):
|
||||
# LLM Service Security (removed encryption - credentials handled by proxy)
|
||||
|
||||
# Plugin System Security
|
||||
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv("PLUGIN_ENCRYPTION_KEY") # Key for encrypting plugin secrets and configurations
|
||||
PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv(
|
||||
"PLUGIN_ENCRYPTION_KEY"
|
||||
) # Key for encrypting plugin secrets and configurations
|
||||
|
||||
# API Keys for LLM providers
|
||||
OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
|
||||
ANTHROPIC_API_KEY: Optional[str] = os.getenv("ANTHROPIC_API_KEY")
|
||||
GOOGLE_API_KEY: Optional[str] = os.getenv("GOOGLE_API_KEY")
|
||||
PRIVATEMODE_API_KEY: Optional[str] = os.getenv("PRIVATEMODE_API_KEY")
|
||||
PRIVATEMODE_PROXY_URL: str = os.getenv("PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1")
|
||||
PRIVATEMODE_PROXY_URL: str = os.getenv(
|
||||
"PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1"
|
||||
)
|
||||
|
||||
# Qdrant
|
||||
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
|
||||
@@ -79,36 +93,62 @@ class Settings(BaseSettings):
|
||||
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
|
||||
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
|
||||
|
||||
# Rate Limiting Configuration
|
||||
|
||||
# PrivateMode Standard tier limits (organization-level, not per user)
|
||||
# These are shared across all API keys and users in the organization
|
||||
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
|
||||
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
|
||||
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
|
||||
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
|
||||
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(
|
||||
os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")
|
||||
)
|
||||
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(
|
||||
os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200")
|
||||
)
|
||||
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(
|
||||
os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000")
|
||||
)
|
||||
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(
|
||||
os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000")
|
||||
)
|
||||
|
||||
# Per-user limits (additional protection on top of organization limits)
|
||||
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")) # Match PrivateMode
|
||||
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
|
||||
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(
|
||||
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")
|
||||
) # Match PrivateMode
|
||||
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(
|
||||
os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200")
|
||||
)
|
||||
|
||||
# API key users (programmatic access)
|
||||
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")) # Match PrivateMode
|
||||
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
|
||||
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(
|
||||
os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")
|
||||
) # Match PrivateMode
|
||||
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(
|
||||
os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200")
|
||||
)
|
||||
|
||||
# Premium/Enterprise API keys
|
||||
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
|
||||
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
|
||||
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(
|
||||
os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")
|
||||
) # Match PrivateMode
|
||||
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(
|
||||
os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")
|
||||
)
|
||||
|
||||
# Request Size Limits
|
||||
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
|
||||
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
|
||||
API_MAX_REQUEST_BODY_SIZE: int = int(
|
||||
os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")
|
||||
) # 10MB
|
||||
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(
|
||||
os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")
|
||||
) # 50MB for premium
|
||||
|
||||
# IP Security
|
||||
|
||||
# Security Headers
|
||||
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
||||
API_CSP_HEADER: str = os.getenv(
|
||||
"API_CSP_HEADER",
|
||||
"default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
|
||||
)
|
||||
|
||||
# Monitoring
|
||||
PROMETHEUS_ENABLED: bool = os.getenv("PROMETHEUS_ENABLED", "True").lower() == "true"
|
||||
@@ -121,29 +161,46 @@ class Settings(BaseSettings):
|
||||
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
|
||||
|
||||
# RAG Embedding Configuration
|
||||
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
|
||||
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(
|
||||
os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12")
|
||||
)
|
||||
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
|
||||
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
|
||||
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
|
||||
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
|
||||
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
|
||||
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
|
||||
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
|
||||
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv(
|
||||
"RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
|
||||
)
|
||||
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(
|
||||
os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0")
|
||||
)
|
||||
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(
|
||||
os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")
|
||||
)
|
||||
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = (
|
||||
os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
|
||||
)
|
||||
RAG_WARN_ON_FALLBACK: bool = (
|
||||
os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
|
||||
)
|
||||
RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
|
||||
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
|
||||
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
|
||||
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(
|
||||
os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")
|
||||
)
|
||||
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(
|
||||
os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")
|
||||
)
|
||||
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
|
||||
|
||||
# Plugin configuration
|
||||
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
|
||||
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
|
||||
PLUGIN_REPOSITORY_URL: str = os.getenv("PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com")
|
||||
PLUGIN_REPOSITORY_URL: str = os.getenv(
|
||||
"PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com"
|
||||
)
|
||||
|
||||
# Logging
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "json")
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"case_sensitive": True,
|
||||
|
||||
@@ -24,7 +24,9 @@ def setup_logging() -> None:
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
structlog.processors.JSONRenderer() if settings.LOG_FORMAT == "json" else structlog.dev.ConsoleRenderer(),
|
||||
structlog.processors.JSONRenderer()
|
||||
if settings.LOG_FORMAT == "json"
|
||||
else structlog.dev.ConsoleRenderer(),
|
||||
],
|
||||
context_class=dict,
|
||||
logger_factory=LoggerFactory(),
|
||||
|
||||
367
backend/app/core/permissions.py
Normal file
367
backend/app/core/permissions.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Permissions Module
|
||||
Role-based access control decorators and utilities
|
||||
"""
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Union, Callable
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def require_permission(
|
||||
user: User, permission: str, resource_id: Optional[Union[str, int]] = None
|
||||
):
|
||||
"""
|
||||
Check if user has the required permission
|
||||
|
||||
Args:
|
||||
user: User object from dependency injection
|
||||
permission: Required permission string
|
||||
resource_id: Optional resource ID for resource-specific permissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have the required permission
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Account is not active"
|
||||
)
|
||||
|
||||
# Check if account is locked
|
||||
if user.account_locked:
|
||||
if user.account_locked_until and user.account_locked_until > datetime.utcnow():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is temporarily locked",
|
||||
)
|
||||
else:
|
||||
# Unlock account if lock period has expired
|
||||
user.unlock_account()
|
||||
|
||||
# Check permission
|
||||
if not user.has_permission(permission):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{permission}' required",
|
||||
)
|
||||
|
||||
|
||||
def require_permissions(permissions: List[str], require_all: bool = True):
|
||||
"""
|
||||
Decorator to require multiple permissions
|
||||
|
||||
Args:
|
||||
permissions: List of required permissions
|
||||
require_all: If True, user must have all permissions. If False, any one permission is sufficient
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user from kwargs (assuming it's passed as a dependency)
|
||||
user = None
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, User):
|
||||
user = value
|
||||
break
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
# Check permissions
|
||||
if require_all:
|
||||
# User must have all permissions
|
||||
for permission in permissions:
|
||||
if not user.has_permission(permission):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"All of the following permissions required: {', '.join(permissions)}",
|
||||
)
|
||||
else:
|
||||
# User needs at least one permission
|
||||
if not any(
|
||||
user.has_permission(permission) for permission in permissions
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"At least one of the following permissions required: {', '.join(permissions)}",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_role(role_names: Union[str, List[str]], require_all: bool = True):
|
||||
"""
|
||||
Decorator to require specific roles
|
||||
|
||||
Args:
|
||||
role_names: Required role name(s)
|
||||
require_all: If True, user must have all roles. If False, any one role is sufficient
|
||||
"""
|
||||
if isinstance(role_names, str):
|
||||
role_names = [role_names]
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user from kwargs
|
||||
user = None
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, User):
|
||||
user = value
|
||||
break
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
# Check roles
|
||||
user_role_names = []
|
||||
if user.role:
|
||||
user_role_names.append(user.role.name)
|
||||
|
||||
if require_all:
|
||||
# User must have all roles
|
||||
for role_name in role_names:
|
||||
if role_name not in user_role_names:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"All of the following roles required: {', '.join(role_names)}",
|
||||
)
|
||||
else:
|
||||
# User needs at least one role
|
||||
if not any(role_name in user_role_names for role_name in role_names):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"At least one of the following roles required: {', '.join(role_names)}",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_minimum_role(minimum_role_level: str):
|
||||
"""
|
||||
Decorator to require minimum role level based on hierarchy
|
||||
|
||||
Args:
|
||||
minimum_role_level: Minimum required role level
|
||||
"""
|
||||
role_hierarchy = {"read_only": 1, "user": 2, "admin": 3, "super_admin": 4}
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user from kwargs
|
||||
user = None
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, User):
|
||||
user = value
|
||||
break
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
# Superusers bypass role checks
|
||||
if user.is_superuser:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Check role level
|
||||
if not user.role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Minimum role level '{minimum_role_level}' required",
|
||||
)
|
||||
|
||||
user_level = role_hierarchy.get(user.role.level, 0)
|
||||
required_level = role_hierarchy.get(minimum_role_level, 0)
|
||||
|
||||
if user_level < required_level:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Minimum role level '{minimum_role_level}' required",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_resource_permission(
|
||||
user: User, resource_type: str, resource_id: Union[str, int], action: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has permission to perform action on specific resource
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
resource_type: Type of resource (e.g., 'user', 'budget', 'api_key')
|
||||
resource_id: ID of the resource
|
||||
action: Action to perform (e.g., 'read', 'update', 'delete')
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
# Superusers can do anything
|
||||
if user.is_superuser:
|
||||
return True
|
||||
|
||||
# Check basic permissions
|
||||
permission = f"{action}_{resource_type}"
|
||||
if user.has_permission(permission):
|
||||
return True
|
||||
|
||||
# Check own resource permissions
|
||||
if resource_type == "user" and str(resource_id) == str(user.id):
|
||||
if user.has_permission(f"{action}_own"):
|
||||
return True
|
||||
|
||||
# Check role-based resource access
|
||||
if user.role:
|
||||
# Admins can manage all users
|
||||
if resource_type == "user" and user.role.level in ["admin", "super_admin"]:
|
||||
return True
|
||||
|
||||
# Users with budget permissions can manage budgets
|
||||
if resource_type == "budget" and user.role.can_manage_budgets:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def require_resource_permission(
|
||||
resource_type: str, resource_id_param: str = "resource_id", action: str = "read"
|
||||
):
|
||||
"""
|
||||
Decorator to require permission for specific resource
|
||||
|
||||
Args:
|
||||
resource_type: Type of resource
|
||||
resource_id_param: Name of parameter containing resource ID
|
||||
action: Action to perform
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user and resource ID from kwargs
|
||||
user = None
|
||||
resource_id = None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, User):
|
||||
user = value
|
||||
elif key == resource_id_param:
|
||||
resource_id = value
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
if resource_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Resource ID not provided",
|
||||
)
|
||||
|
||||
# Check resource permission
|
||||
if not check_resource_permission(user, resource_type, resource_id, action):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{action}_{resource_type}' required",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_budget_permission(user: User, budget_id: int, action: str) -> bool:
|
||||
"""
|
||||
Check if user has permission to perform action on specific budget
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
budget_id: ID of the budget
|
||||
action: Action to perform
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
# Superusers can do anything
|
||||
if user.is_superuser:
|
||||
return True
|
||||
|
||||
# Check if user owns the budget
|
||||
for budget in user.budgets:
|
||||
if budget.id == budget_id and budget.is_active:
|
||||
return user.has_permission(f"{action}_own")
|
||||
|
||||
# Check if user can manage all budgets
|
||||
if user.has_permission(f"{action}_all") or (
|
||||
user.role and user.role.can_manage_budgets
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_api_key_permission(user: User, api_key_id: int, action: str) -> bool:
|
||||
"""
|
||||
Check if user has permission to perform action on specific API key
|
||||
|
||||
Args:
|
||||
user: User object
|
||||
api_key_id: ID of the API key
|
||||
action: Action to perform
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
# Superusers can do anything
|
||||
if user.is_superuser:
|
||||
return True
|
||||
|
||||
# Check if user owns the API key
|
||||
for api_key in user.api_keys:
|
||||
if api_key.id == api_key_id:
|
||||
return user.has_permission(f"{action}_own")
|
||||
|
||||
# Check if user can manage all API keys
|
||||
if user.has_permission(f"{action}_all"):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -24,30 +24,35 @@ logger = logging.getLogger(__name__)
|
||||
# Password hashing
|
||||
# Use a lower work factor for better performance in production
|
||||
pwd_context = CryptContext(
|
||||
schemes=["bcrypt"],
|
||||
deprecated="auto",
|
||||
bcrypt__rounds=settings.BCRYPT_ROUNDS
|
||||
schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=settings.BCRYPT_ROUNDS
|
||||
)
|
||||
|
||||
# JWT token handling
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")
|
||||
logger.info(
|
||||
f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Run password verification in a thread with timeout
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(pwd_context.verify, plain_password, hashed_password)
|
||||
future = executor.submit(
|
||||
pwd_context.verify, plain_password, hashed_password
|
||||
)
|
||||
result = future.result(timeout=5.0) # 5 second timeout
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
logger.info(f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}")
|
||||
logger.info(
|
||||
f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}"
|
||||
)
|
||||
|
||||
if duration > 1:
|
||||
logger.warning(f"PASSWORD VERIFICATION TOOK TOO LONG: {duration:.3f}s")
|
||||
@@ -61,24 +66,33 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
logger.error(f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}")
|
||||
logger.error(
|
||||
f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate password hash"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
|
||||
"""Verify an API key against its hash"""
|
||||
return pwd_context.verify(plain_api_key, hashed_api_key)
|
||||
|
||||
|
||||
def get_api_key_hash(api_key: str) -> str:
|
||||
"""Generate API key hash"""
|
||||
return pwd_context.hash(api_key)
|
||||
|
||||
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
|
||||
def create_access_token(
|
||||
data: Dict[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""Create JWT access token"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(f"=== CREATE ACCESS TOKEN START ===")
|
||||
|
||||
@@ -87,12 +101,16 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
logger.info(f"JWT encode start...")
|
||||
encode_start = time.time()
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
encode_end = time.time()
|
||||
encode_duration = encode_end - encode_start
|
||||
|
||||
@@ -103,7 +121,9 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
|
||||
logger.info(f"Created access token for user {data.get('sub')}")
|
||||
logger.info(f"Token expires at: {expire.isoformat()} (UTC)")
|
||||
logger.info(f"Current UTC time: {datetime.utcnow().isoformat()}")
|
||||
logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
|
||||
logger.info(
|
||||
f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
|
||||
)
|
||||
logger.info(f"JWT encode duration: {encode_duration:.3f}s")
|
||||
logger.info(f"Total token creation duration: {total_duration:.3f}s")
|
||||
logger.info(f"=== CREATE ACCESS TOKEN END ===")
|
||||
@@ -112,17 +132,25 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta]
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
total_duration = end_time - start_time
|
||||
logger.error(f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}")
|
||||
logger.error(
|
||||
f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def create_refresh_token(data: Dict[str, Any]) -> str:
|
||||
"""Create JWT refresh token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_token(token: str) -> Dict[str, Any]:
|
||||
"""Verify JWT token and return payload"""
|
||||
try:
|
||||
@@ -133,15 +161,21 @@ def verify_token(token: str) -> Dict[str, Any]:
|
||||
# Decode without verification first to check expiration
|
||||
try:
|
||||
unverified_payload = jwt.get_unverified_claims(token)
|
||||
exp_timestamp = unverified_payload.get('exp')
|
||||
exp_timestamp = unverified_payload.get("exp")
|
||||
if exp_timestamp:
|
||||
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=None)
|
||||
logger.info(f"Token expiration time: {exp_datetime.isoformat()} (UTC)")
|
||||
logger.info(f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds")
|
||||
logger.info(
|
||||
f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds"
|
||||
)
|
||||
except Exception as decode_error:
|
||||
logger.warning(f"Could not decode token for expiration check: {decode_error}")
|
||||
logger.warning(
|
||||
f"Could not decode token for expiration check: {decode_error}"
|
||||
)
|
||||
|
||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
logger.info(f"Token verified successfully for user {payload.get('sub')}")
|
||||
return payload
|
||||
except JWTError as e:
|
||||
@@ -149,9 +183,10 @@ def verify_token(token: str) -> Dict[str, Any]:
|
||||
logger.warning(f"Current UTC time: {datetime.utcnow().isoformat()}")
|
||||
raise AuthenticationError("Invalid token")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current user from JWT token"""
|
||||
try:
|
||||
@@ -167,9 +202,10 @@ async def get_current_user(
|
||||
# Load user from database
|
||||
from app.models.user import User
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Query user from database
|
||||
stmt = select(User).where(User.id == int(user_id))
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -181,7 +217,7 @@ async def get_current_user(
|
||||
"is_superuser": payload.get("is_superuser", False),
|
||||
"role": payload.get("role", "user"),
|
||||
"is_active": True,
|
||||
"permissions": [] # Default to empty list for permissions
|
||||
"permissions": [], # Default to empty list for permissions
|
||||
}
|
||||
|
||||
# Update last login
|
||||
@@ -191,39 +227,43 @@ async def get_current_user(
|
||||
# Calculate effective permissions using permission manager
|
||||
from app.services.permission_manager import permission_registry
|
||||
|
||||
# Convert role string to list for permission calculation
|
||||
user_roles = [user.role] if user.role else []
|
||||
# Convert role to name for permission calculation
|
||||
user_roles = [user.role.name] if user.role else []
|
||||
|
||||
# For super admin users, use only role-based permissions, ignore custom permissions
|
||||
# Custom permissions might contain legacy formats like ['*'] that don't work with new system
|
||||
# Custom permissions might contain legacy formats like ['*'] or dict formats
|
||||
custom_permissions = []
|
||||
if not user.is_superuser:
|
||||
# Only use custom permissions for non-superuser accounts
|
||||
if user.permissions:
|
||||
if isinstance(user.permissions, list):
|
||||
custom_permissions = user.permissions
|
||||
# Support both list-based and dict-based custom permission formats
|
||||
raw_custom_perms = getattr(user, "custom_permissions", None)
|
||||
if raw_custom_perms:
|
||||
if isinstance(raw_custom_perms, list):
|
||||
custom_permissions = raw_custom_perms
|
||||
elif isinstance(raw_custom_perms, dict):
|
||||
granted = raw_custom_perms.get("granted")
|
||||
if isinstance(granted, list):
|
||||
custom_permissions = granted
|
||||
|
||||
# Calculate effective permissions based on role and custom permissions
|
||||
effective_permissions = permission_registry.get_user_permissions(
|
||||
roles=user_roles,
|
||||
custom_permissions=custom_permissions
|
||||
roles=user_roles, custom_permissions=custom_permissions
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
"is_active": user.is_active,
|
||||
"role": user.role,
|
||||
"role": user.role.name if user.role else None,
|
||||
"permissions": effective_permissions, # Use calculated permissions
|
||||
"user_obj": user # Include full user object for other operations
|
||||
"user_obj": user, # Include full user object for other operations
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
raise AuthenticationError("Could not validate credentials")
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
@@ -233,6 +273,7 @@ async def get_current_active_user(
|
||||
raise AuthenticationError("User account is inactive")
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_superuser(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
@@ -241,6 +282,7 @@ async def get_current_superuser(
|
||||
raise AuthorizationError("Insufficient privileges")
|
||||
return current_user
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""Generate a new API key"""
|
||||
import secrets
|
||||
@@ -248,21 +290,23 @@ def generate_api_key() -> str:
|
||||
|
||||
# Generate random string
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
api_key = ''.join(secrets.choice(alphabet) for _ in range(32))
|
||||
api_key = "".join(secrets.choice(alphabet) for _ in range(32))
|
||||
|
||||
return f"{settings.API_KEY_PREFIX}{api_key}"
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""Hash API key for storage"""
|
||||
return get_password_hash(api_key)
|
||||
|
||||
|
||||
def verify_api_key(api_key: str, hashed_key: str) -> bool:
|
||||
"""Verify API key against hash"""
|
||||
return verify_password(api_key, hashed_key)
|
||||
|
||||
|
||||
async def get_api_key_user(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, db: AsyncSession = Depends(get_db)
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get user from API key"""
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
@@ -282,10 +326,14 @@ async def get_api_key_user(
|
||||
key_prefix = api_key[:8]
|
||||
|
||||
# Query API key from database
|
||||
stmt = select(APIKey).join(User).where(
|
||||
stmt = (
|
||||
select(APIKey)
|
||||
.join(User)
|
||||
.where(
|
||||
APIKey.key_prefix == key_prefix,
|
||||
APIKey.is_active == True,
|
||||
User.is_active == True
|
||||
User.is_active == True,
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
db_api_key = result.scalar_one_or_none()
|
||||
@@ -306,7 +354,7 @@ async def get_api_key_user(
|
||||
await db.commit()
|
||||
|
||||
# Load associated user
|
||||
user_stmt = select(User).where(User.id == db_api_key.user_id)
|
||||
user_stmt = select(User).options(selectinload(User.role)).where(User.id == db_api_key.user_id)
|
||||
user_result = await db.execute(user_stmt)
|
||||
user = user_result.scalar_one_or_none()
|
||||
|
||||
@@ -316,22 +364,36 @@ async def get_api_key_user(
|
||||
# Calculate effective permissions using permission manager
|
||||
from app.services.permission_manager import permission_registry
|
||||
|
||||
# Convert role string to list for permission calculation
|
||||
user_roles = [user.role] if user.role else []
|
||||
# Convert role to name for permission calculation
|
||||
user_roles = [user.role.name] if user.role else []
|
||||
|
||||
# Use API key specific permissions if available, otherwise use user permissions
|
||||
# Use API key specific permissions if available
|
||||
api_key_permissions = db_api_key.permissions if db_api_key.permissions else []
|
||||
|
||||
# Get custom permissions from database (convert dict to list if needed)
|
||||
custom_permissions = api_key_permissions
|
||||
if user.permissions:
|
||||
if isinstance(user.permissions, list):
|
||||
custom_permissions.extend(user.permissions)
|
||||
# Normalize permissions into a flat list of granted permission strings
|
||||
custom_permissions: list[str] = []
|
||||
|
||||
# Handle API key permissions that may be stored as list or dict
|
||||
if isinstance(api_key_permissions, list):
|
||||
custom_permissions.extend(api_key_permissions)
|
||||
elif isinstance(api_key_permissions, dict):
|
||||
api_granted = api_key_permissions.get("granted")
|
||||
if isinstance(api_granted, list):
|
||||
custom_permissions.extend(api_granted)
|
||||
|
||||
# Merge in user-level custom permissions for non-superusers
|
||||
raw_user_custom = getattr(user, "custom_permissions", None)
|
||||
if raw_user_custom and not user.is_superuser:
|
||||
if isinstance(raw_user_custom, list):
|
||||
custom_permissions.extend(raw_user_custom)
|
||||
elif isinstance(raw_user_custom, dict):
|
||||
user_granted = raw_user_custom.get("granted")
|
||||
if isinstance(user_granted, list):
|
||||
custom_permissions.extend(user_granted)
|
||||
|
||||
# Calculate effective permissions based on role and custom permissions
|
||||
effective_permissions = permission_registry.get_user_permissions(
|
||||
roles=user_roles,
|
||||
custom_permissions=custom_permissions
|
||||
roles=user_roles, custom_permissions=custom_permissions
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -344,12 +406,13 @@ async def get_api_key_user(
|
||||
"permissions": effective_permissions,
|
||||
"api_key": db_api_key,
|
||||
"user_obj": user,
|
||||
"auth_type": "api_key"
|
||||
"auth_type": "api_key",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"API key lookup error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class RequiresPermission:
|
||||
"""Dependency class for permission checking"""
|
||||
|
||||
@@ -367,7 +430,14 @@ class RequiresPermission:
|
||||
role_permissions = {
|
||||
"user": ["read_own", "create_own", "update_own"],
|
||||
"admin": ["read_all", "create_all", "update_all", "delete_own"],
|
||||
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
|
||||
"super_admin": [
|
||||
"read_all",
|
||||
"create_all",
|
||||
"update_all",
|
||||
"delete_all",
|
||||
"manage_users",
|
||||
"manage_modules",
|
||||
],
|
||||
}
|
||||
|
||||
if role in role_permissions and self.permission in role_permissions[role]:
|
||||
@@ -386,6 +456,7 @@ class RequiresPermission:
|
||||
|
||||
raise AuthorizationError(f"Permission '{self.permission}' required")
|
||||
|
||||
|
||||
class RequiresRole:
|
||||
"""Dependency class for role checking"""
|
||||
|
||||
@@ -401,11 +472,7 @@ class RequiresRole:
|
||||
user_role = current_user.get("role", "user")
|
||||
|
||||
# Define role hierarchy
|
||||
role_hierarchy = {
|
||||
"user": 1,
|
||||
"admin": 2,
|
||||
"super_admin": 3
|
||||
}
|
||||
role_hierarchy = {"user": 1, "admin": 2, "super_admin": 3}
|
||||
|
||||
required_level = role_hierarchy.get(self.role, 0)
|
||||
user_level = role_hierarchy.get(user_role, 0)
|
||||
@@ -413,4 +480,6 @@ class RequiresRole:
|
||||
if user_level >= required_level:
|
||||
return current_user
|
||||
|
||||
raise AuthorizationError(f"Role '{self.role}' required, but user has role '{user_role}'")
|
||||
raise AuthorizationError(
|
||||
f"Role '{self.role}' required, but user has role '{user_role}'"
|
||||
)
|
||||
|
||||
@@ -72,6 +72,7 @@ metadata = MetaData()
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get database session"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
request_id = f"db_{int(time.time() * 1000)}"
|
||||
|
||||
@@ -86,7 +87,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
except Exception as e:
|
||||
# Only log if there's an actual error, not normal operation
|
||||
if str(e).strip(): # Only log if error message exists
|
||||
logger.error(f"[{request_id}] Database session error: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"[{request_id}] Database session error: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
@@ -94,9 +98,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await session.close()
|
||||
close_time = time.time() - close_start
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s")
|
||||
logger.info(
|
||||
f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{request_id}] Failed to create database session: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"[{request_id}] Failed to create database session: {e}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -106,8 +114,10 @@ async def init_db():
|
||||
async with engine.begin() as conn:
|
||||
# Import all models to ensure they're registered
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.usage_tracking import UsageTracking
|
||||
|
||||
# Import additional models - these are available
|
||||
try:
|
||||
from app.models.budget import Budget
|
||||
@@ -127,6 +137,9 @@ async def init_db():
|
||||
# Tables are now created via migration container - no need to create here
|
||||
# await conn.run_sync(Base.metadata.create_all) # DISABLED - migrations handle this
|
||||
|
||||
# Create default roles if they don't exist
|
||||
await create_default_roles()
|
||||
|
||||
# Create default admin user if no admin exists
|
||||
await create_default_admin()
|
||||
|
||||
@@ -136,9 +149,42 @@ async def init_db():
|
||||
raise
|
||||
|
||||
|
||||
async def create_default_roles():
|
||||
"""Create default roles if they don't exist"""
|
||||
from app.models.role import Role, RoleLevel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
try:
|
||||
async with async_session_factory() as session:
|
||||
# Check if any roles exist
|
||||
stmt = select(Role).limit(1)
|
||||
result = await session.execute(stmt)
|
||||
existing_role = result.scalar_one_or_none()
|
||||
|
||||
if existing_role:
|
||||
logger.info("Roles already exist - skipping default role creation")
|
||||
return
|
||||
|
||||
# Create default roles using the Role.create_default_roles class method
|
||||
default_roles = Role.create_default_roles()
|
||||
|
||||
for role in default_roles:
|
||||
session.add(role)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info("Created default roles: read_only, user, admin, super_admin")
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Failed to create default roles due to database error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def create_default_admin():
|
||||
"""Create default admin user if user with ADMIN_EMAIL doesn't exist"""
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
from app.core.security import get_password_hash
|
||||
from app.core.config import settings
|
||||
from sqlalchemy import select
|
||||
@@ -159,19 +205,33 @@ async def create_default_admin():
|
||||
existing_user = result.scalar_one_or_none()
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"User with email {admin_email} already exists - skipping admin creation")
|
||||
logger.info(
|
||||
f"User with email {admin_email} already exists - skipping admin creation"
|
||||
)
|
||||
return
|
||||
|
||||
# Get the super_admin role
|
||||
stmt = select(Role).where(Role.name == "super_admin")
|
||||
result = await session.execute(stmt)
|
||||
super_admin_role = result.scalar_one_or_none()
|
||||
|
||||
if not super_admin_role:
|
||||
logger.error("Super admin role not found - cannot create admin user")
|
||||
return
|
||||
|
||||
# Create admin user from environment variables
|
||||
# Generate username from email (part before @)
|
||||
admin_username = admin_email.split('@')[0]
|
||||
admin_username = admin_email.split("@")[0]
|
||||
|
||||
admin_user = User.create_default_admin(
|
||||
email=admin_email,
|
||||
username=admin_username,
|
||||
password_hash=get_password_hash(admin_password)
|
||||
password_hash=get_password_hash(admin_password),
|
||||
)
|
||||
|
||||
# Assign the super_admin role
|
||||
admin_user.role_id = super_admin_role.id
|
||||
|
||||
session.add(admin_user)
|
||||
await session.commit()
|
||||
|
||||
@@ -179,14 +239,19 @@ async def create_default_admin():
|
||||
logger.warning("ADMIN USER CREATED FROM ENVIRONMENT")
|
||||
logger.warning(f"Email: {admin_email}")
|
||||
logger.warning(f"Username: {admin_username}")
|
||||
logger.warning("Password: [Set via ADMIN_PASSWORD - only used on first creation]")
|
||||
logger.warning("Role: Super Administrator")
|
||||
logger.warning(
|
||||
"Password: [Set via ADMIN_PASSWORD - only used on first creation]"
|
||||
)
|
||||
logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
|
||||
logger.warning("=" * 60)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Failed to create default admin user due to database error: {e}")
|
||||
except AttributeError as e:
|
||||
logger.error(f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'")
|
||||
logger.error(
|
||||
f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create default admin user: {e}")
|
||||
# Don't raise here as this shouldn't block the application startup
|
||||
|
||||
@@ -59,7 +59,10 @@ async def _check_redis_startup():
|
||||
duration = time.perf_counter() - start
|
||||
logger.info(
|
||||
"Startup Redis check succeeded",
|
||||
extra={"redis_url": settings.REDIS_URL, "duration_seconds": round(duration, 3)},
|
||||
extra={
|
||||
"redis_url": settings.REDIS_URL,
|
||||
"duration_seconds": round(duration, 3),
|
||||
},
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
@@ -107,6 +110,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Initialize core cache service (before database to provide caching for auth)
|
||||
from app.core.cache import core_cache
|
||||
|
||||
try:
|
||||
await core_cache.initialize()
|
||||
logger.info("Core cache service initialized successfully")
|
||||
@@ -128,6 +132,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Ensure platform permissions are registered before module discovery
|
||||
from app.services.permission_manager import permission_registry
|
||||
|
||||
permission_registry.register_platform_permissions()
|
||||
|
||||
# Initialize LLM service (needed by RAG module) concurrently
|
||||
@@ -156,6 +161,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Initialize document processor
|
||||
from app.services.document_processor import document_processor
|
||||
|
||||
try:
|
||||
await document_processor.start()
|
||||
app.state.document_processor = document_processor
|
||||
@@ -171,6 +177,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Start background audit worker
|
||||
from app.services.audit_service import start_audit_worker
|
||||
|
||||
try:
|
||||
start_audit_worker()
|
||||
except Exception as exc:
|
||||
@@ -179,10 +186,13 @@ async def lifespan(app: FastAPI):
|
||||
# Initialize plugin auto-discovery service concurrently
|
||||
async def initialize_plugins():
|
||||
from app.services.plugin_autodiscovery import initialize_plugin_autodiscovery
|
||||
|
||||
try:
|
||||
discovery_results = await initialize_plugin_autodiscovery()
|
||||
app.state.plugin_discovery_results = discovery_results
|
||||
logger.info(f"Plugin auto-discovery completed: {discovery_results.get('summary')}")
|
||||
logger.info(
|
||||
f"Plugin auto-discovery completed: {discovery_results.get('summary')}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Plugin auto-discovery failed: {exc}")
|
||||
app.state.plugin_discovery_results = {"error": str(exc)}
|
||||
@@ -205,6 +215,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Cleanup embedding service HTTP sessions
|
||||
from app.services.embedding_service import embedding_service
|
||||
|
||||
try:
|
||||
await embedding_service.cleanup()
|
||||
logger.info("Embedding service cleaned up successfully")
|
||||
@@ -213,14 +224,16 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Close core cache service
|
||||
from app.core.cache import core_cache
|
||||
|
||||
await core_cache.cleanup()
|
||||
|
||||
# Close Redis connection for cached API key service
|
||||
from app.services.cached_api_key import cached_api_key_service
|
||||
|
||||
await cached_api_key_service.close()
|
||||
|
||||
# Stop document processor
|
||||
processor = getattr(app.state, 'document_processor', None)
|
||||
processor = getattr(app.state, "document_processor", None)
|
||||
if processor:
|
||||
await processor.stop()
|
||||
|
||||
@@ -297,7 +310,9 @@ async def validation_exception_handler(request, exc: RequestValidationError):
|
||||
"type": error.get("type", ""),
|
||||
"location": error.get("loc", []),
|
||||
"message": error.get("msg", ""),
|
||||
"input": str(error.get("input", "")) if error.get("input") is not None else None
|
||||
"input": str(error.get("input", ""))
|
||||
if error.get("input") is not None
|
||||
else None,
|
||||
}
|
||||
errors.append(error_dict)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from app.db.database import get_db
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Context variable to pass analytics data from endpoints to middleware
|
||||
analytics_context: ContextVar[dict] = ContextVar('analytics_context', default={})
|
||||
analytics_context: ContextVar[dict] = ContextVar("analytics_context", default={})
|
||||
|
||||
|
||||
class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
@@ -28,7 +28,12 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
start_time = time.time()
|
||||
|
||||
# Skip analytics for health checks and static files
|
||||
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
|
||||
if request.url.path in [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
] or request.url.path.startswith("/static"):
|
||||
return await call_next(request)
|
||||
|
||||
# Get user info if available from token
|
||||
@@ -42,6 +47,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
# Try to extract user info from token without full validation
|
||||
# This is a lightweight check for analytics purposes
|
||||
from app.core.security import verify_token
|
||||
|
||||
try:
|
||||
payload = verify_token(token)
|
||||
user_id = int(payload.get("sub"))
|
||||
@@ -77,7 +83,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
error_message = str(e)
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
|
||||
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
|
||||
)
|
||||
|
||||
# Calculate timing
|
||||
@@ -86,7 +92,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Get response size
|
||||
response_size = 0
|
||||
if hasattr(response, 'body'):
|
||||
if hasattr(response, "body"):
|
||||
response_size = len(response.body) if response.body else 0
|
||||
|
||||
# Get analytics data from context (set by endpoints)
|
||||
@@ -107,22 +113,25 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
|
||||
response_size=response_size,
|
||||
error_message=error_message,
|
||||
# Token/cost info populated by LLM endpoints via context
|
||||
model=context_data.get('model'),
|
||||
request_tokens=context_data.get('request_tokens', 0),
|
||||
response_tokens=context_data.get('response_tokens', 0),
|
||||
total_tokens=context_data.get('total_tokens', 0),
|
||||
cost_cents=context_data.get('cost_cents', 0),
|
||||
budget_ids=context_data.get('budget_ids', []),
|
||||
budget_warnings=context_data.get('budget_warnings', [])
|
||||
model=context_data.get("model"),
|
||||
request_tokens=context_data.get("request_tokens", 0),
|
||||
response_tokens=context_data.get("response_tokens", 0),
|
||||
total_tokens=context_data.get("total_tokens", 0),
|
||||
cost_cents=context_data.get("cost_cents", 0),
|
||||
budget_ids=context_data.get("budget_ids", []),
|
||||
budget_warnings=context_data.get("budget_warnings", []),
|
||||
)
|
||||
|
||||
# Track the event
|
||||
try:
|
||||
from app.services.analytics import analytics_service
|
||||
|
||||
if analytics_service is not None:
|
||||
await analytics_service.track_request(event)
|
||||
else:
|
||||
logger.warning("Analytics service not initialized, skipping event tracking")
|
||||
logger.warning(
|
||||
"Analytics service not initialized, skipping event tracking"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track analytics event: {e}")
|
||||
# Don't let analytics failures break the request
|
||||
|
||||
393
backend/app/middleware/audit_middleware.py
Normal file
393
backend/app/middleware/audit_middleware.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Audit Logging Middleware
|
||||
Automatically logs user actions and system events
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.audit_log import AuditLog, AuditAction, AuditSeverity
|
||||
from app.db.database import get_db_session
|
||||
from app.core.security import verify_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to automatically log user actions and API calls"""
|
||||
|
||||
def __init__(self, app, exclude_paths: Optional[list] = None):
|
||||
super().__init__(app)
|
||||
# Paths to exclude from audit logging
|
||||
self.exclude_paths = exclude_paths or [
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/static",
|
||||
"/favicon.ico",
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip audit logging for excluded paths
|
||||
if any(request.url.path.startswith(path) for path in self.exclude_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip audit logging for health checks and static assets
|
||||
if request.url.path in ["/", "/health"] or "/static/" in request.url.path:
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract user information from request
|
||||
user_info = await self._extract_user_info(request)
|
||||
|
||||
# Prepare audit data
|
||||
audit_data = {
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"ip_address": self._get_client_ip(request),
|
||||
"user_agent": request.headers.get("user-agent"),
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate response time
|
||||
process_time = time.time() - start_time
|
||||
audit_data["response_time"] = round(process_time * 1000, 2) # milliseconds
|
||||
audit_data["status_code"] = response.status_code
|
||||
audit_data["success"] = 200 <= response.status_code < 400
|
||||
|
||||
# Log the audit event asynchronously
|
||||
try:
|
||||
await self._log_audit_event(user_info, audit_data, request)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
# Don't fail the request if audit logging fails
|
||||
|
||||
return response
|
||||
|
||||
async def _extract_user_info(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Extract user information from request headers"""
|
||||
try:
|
||||
# Try to get user info from Authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.split(" ")[1]
|
||||
payload = verify_token(token)
|
||||
return {
|
||||
"user_id": int(payload.get("sub")) if payload.get("sub") else None,
|
||||
"email": payload.get("email"),
|
||||
"is_superuser": payload.get("is_superuser", False),
|
||||
"role": payload.get("role"),
|
||||
}
|
||||
except Exception:
|
||||
# If token verification fails, continue without user info
|
||||
pass
|
||||
|
||||
# Try to get user info from API key header
|
||||
api_key = request.headers.get("x-api-key")
|
||||
if api_key:
|
||||
# Would need to implement API key lookup here
|
||||
# For now, just indicate it's an API key request
|
||||
return {
|
||||
"user_id": None,
|
||||
"email": "api_key_user",
|
||||
"is_superuser": False,
|
||||
"role": "api_user",
|
||||
"auth_type": "api_key",
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP address with proxy support"""
|
||||
# Check for forwarded headers first (for reverse proxy setups)
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# Take the first IP in the chain
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
forwarded = request.headers.get("x-forwarded")
|
||||
if forwarded:
|
||||
return forwarded
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct client IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
async def _log_audit_event(
|
||||
self,
|
||||
user_info: Optional[Dict[str, Any]],
|
||||
audit_data: Dict[str, Any],
|
||||
request: Request
|
||||
):
|
||||
"""Log the audit event to database"""
|
||||
|
||||
# Determine action based on HTTP method and path
|
||||
action = self._determine_action(request.method, request.url.path)
|
||||
|
||||
# Determine resource type and ID from path
|
||||
resource_type, resource_id = self._parse_resource_from_path(request.url.path)
|
||||
|
||||
# Create description
|
||||
description = self._create_description(request.method, request.url.path, audit_data["success"])
|
||||
|
||||
# Determine severity
|
||||
severity = self._determine_severity(request.method, audit_data["status_code"], request.url.path)
|
||||
|
||||
# Create audit log entry
|
||||
try:
|
||||
async with get_db_session() as db:
|
||||
audit_log = AuditLog(
|
||||
user_id=user_info.get("user_id") if user_info else None,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
description=description,
|
||||
details={
|
||||
"request": {
|
||||
"method": audit_data["method"],
|
||||
"path": audit_data["path"],
|
||||
"query_params": audit_data["query_params"],
|
||||
"response_time_ms": audit_data["response_time"],
|
||||
},
|
||||
"user_info": user_info,
|
||||
},
|
||||
ip_address=audit_data["ip_address"],
|
||||
user_agent=audit_data["user_agent"],
|
||||
severity=severity,
|
||||
category=self._determine_category(request.url.path),
|
||||
success=audit_data["success"],
|
||||
tags=self._generate_tags(request.method, request.url.path),
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save audit log to database: {e}")
|
||||
# Could implement fallback logging to file here
|
||||
|
||||
def _determine_action(self, method: str, path: str) -> str:
|
||||
"""Determine action type from HTTP method and path"""
|
||||
method = method.upper()
|
||||
|
||||
if method == "GET":
|
||||
return AuditAction.READ
|
||||
elif method == "POST":
|
||||
if "login" in path.lower():
|
||||
return AuditAction.LOGIN
|
||||
elif "logout" in path.lower():
|
||||
return AuditAction.LOGOUT
|
||||
else:
|
||||
return AuditAction.CREATE
|
||||
elif method == "PUT" or method == "PATCH":
|
||||
return AuditAction.UPDATE
|
||||
elif method == "DELETE":
|
||||
return AuditAction.DELETE
|
||||
else:
|
||||
return method.lower()
|
||||
|
||||
def _parse_resource_from_path(self, path: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse resource type and ID from URL path"""
|
||||
path_parts = path.strip("/").split("/")
|
||||
|
||||
# Skip API version prefix
|
||||
if path_parts and path_parts[0] in ["api", "api-internal"]:
|
||||
path_parts = path_parts[2:] # Skip 'api' and 'v1'
|
||||
|
||||
if not path_parts:
|
||||
return "system", None
|
||||
|
||||
resource_type = path_parts[0]
|
||||
resource_id = None
|
||||
|
||||
# Try to find numeric ID in path
|
||||
for part in path_parts[1:]:
|
||||
if part.isdigit():
|
||||
resource_id = part
|
||||
break
|
||||
|
||||
return resource_type, resource_id
|
||||
|
||||
def _create_description(self, method: str, path: str, success: bool) -> str:
|
||||
"""Create human-readable description of the action"""
|
||||
action_verbs = {
|
||||
"GET": "accessed" if success else "attempted to access",
|
||||
"POST": "created" if success else "attempted to create",
|
||||
"PUT": "updated" if success else "attempted to update",
|
||||
"PATCH": "modified" if success else "attempted to modify",
|
||||
"DELETE": "deleted" if success else "attempted to delete",
|
||||
}
|
||||
|
||||
verb = action_verbs.get(method, method.lower())
|
||||
resource = path.strip("/").split("/")[-1] if "/" in path else path
|
||||
|
||||
return f"User {verb} {resource}"
|
||||
|
||||
def _determine_severity(self, method: str, status_code: int, path: str) -> str:
|
||||
"""Determine severity level based on action and outcome"""
|
||||
|
||||
# Critical operations
|
||||
if any(keyword in path.lower() for keyword in ["delete", "password", "admin", "key"]):
|
||||
return AuditSeverity.HIGH
|
||||
|
||||
# Failed operations
|
||||
if status_code >= 400:
|
||||
if status_code >= 500:
|
||||
return AuditSeverity.CRITICAL
|
||||
elif status_code in [401, 403]:
|
||||
return AuditSeverity.HIGH
|
||||
else:
|
||||
return AuditSeverity.MEDIUM
|
||||
|
||||
# Write operations
|
||||
if method in ["POST", "PUT", "PATCH", "DELETE"]:
|
||||
return AuditSeverity.MEDIUM
|
||||
|
||||
# Read operations
|
||||
return AuditSeverity.LOW
|
||||
|
||||
def _determine_category(self, path: str) -> str:
|
||||
"""Determine category based on path"""
|
||||
path = path.lower()
|
||||
|
||||
if any(keyword in path for keyword in ["auth", "login", "logout", "token"]):
|
||||
return "authentication"
|
||||
elif any(keyword in path for keyword in ["user", "admin", "role", "permission"]):
|
||||
return "user_management"
|
||||
elif any(keyword in path for keyword in ["api-key", "key"]):
|
||||
return "security"
|
||||
elif any(keyword in path for keyword in ["budget", "billing", "usage"]):
|
||||
return "financial"
|
||||
elif any(keyword in path for keyword in ["audit", "log"]):
|
||||
return "audit"
|
||||
elif any(keyword in path for keyword in ["setting", "config"]):
|
||||
return "configuration"
|
||||
else:
|
||||
return "general"
|
||||
|
||||
def _generate_tags(self, method: str, path: str) -> list[str]:
|
||||
"""Generate tags for the audit log"""
|
||||
tags = [method.lower()]
|
||||
|
||||
path_parts = path.strip("/").split("/")
|
||||
if path_parts:
|
||||
tags.append(path_parts[0])
|
||||
|
||||
# Add special tags
|
||||
if "admin" in path.lower():
|
||||
tags.append("admin_action")
|
||||
if any(keyword in path.lower() for keyword in ["password", "auth", "login"]):
|
||||
tags.append("security_action")
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
class LoginAuditMiddleware(BaseHTTPMiddleware):
|
||||
"""Specialized middleware for login/logout events"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Only process auth-related endpoints
|
||||
if not any(path in request.url.path for path in ["/auth/login", "/auth/logout", "/auth/refresh"]):
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Store request body for login attempts
|
||||
request_body = None
|
||||
if request.method == "POST" and "/login" in request.url.path:
|
||||
try:
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_body = json.loads(body.decode())
|
||||
# Re-create request with body for downstream processing
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from io import BytesIO
|
||||
request._body = body
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse login request body: {e}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# Log login/logout events
|
||||
try:
|
||||
await self._log_auth_event(request, response, request_body, time.time() - start_time)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log auth event: {e}")
|
||||
|
||||
return response
|
||||
|
||||
async def _log_auth_event(self, request: Request, response: Response, request_body: dict, process_time: float):
|
||||
"""Log authentication events"""
|
||||
|
||||
success = 200 <= response.status_code < 300
|
||||
|
||||
if "/login" in request.url.path:
|
||||
# Extract email/username from request
|
||||
identifier = None
|
||||
if request_body:
|
||||
identifier = request_body.get("email") or request_body.get("username")
|
||||
|
||||
# For successful logins, we could extract user_id from response
|
||||
# For now, we'll use the identifier
|
||||
|
||||
async with get_db_session() as db:
|
||||
audit_log = AuditLog.create_login_event(
|
||||
user_id=None, # Would need to extract from response for successful logins
|
||||
success=success,
|
||||
ip_address=self._get_client_ip(request),
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
error_message=f"HTTP {response.status_code}" if not success else None,
|
||||
)
|
||||
|
||||
# Add additional details
|
||||
audit_log.details.update({
|
||||
"identifier": identifier,
|
||||
"response_time_ms": round(process_time * 1000, 2),
|
||||
})
|
||||
|
||||
db.add(audit_log)
|
||||
await db.commit()
|
||||
|
||||
elif "/logout" in request.url.path:
|
||||
# Extract user info from token if available
|
||||
user_id = None
|
||||
try:
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.split(" ")[1]
|
||||
payload = verify_token(token)
|
||||
user_id = int(payload.get("sub")) if payload.get("sub") else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async with get_db_session() as db:
|
||||
audit_log = AuditLog.create_logout_event(
|
||||
user_id=user_id,
|
||||
session_id=None, # Could extract from token if stored
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
await db.commit()
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP address with proxy support"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
return request.client.host if request.client else "unknown"
|
||||
@@ -23,8 +23,12 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
request_id = str(uuid4())
|
||||
|
||||
# Skip debugging for health checks and static files
|
||||
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or \
|
||||
request.url.path.startswith("/static"):
|
||||
if request.url.path in [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
] or request.url.path.startswith("/static"):
|
||||
return await call_next(request)
|
||||
|
||||
# Log request details
|
||||
@@ -37,7 +41,7 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
try:
|
||||
request_body = json.loads(body_bytes)
|
||||
except json.JSONDecodeError:
|
||||
request_body = body_bytes.decode('utf-8', errors='replace')
|
||||
request_body = body_bytes.decode("utf-8", errors="replace")
|
||||
# Restore body for downstream processing
|
||||
request._body = body_bytes
|
||||
except Exception:
|
||||
@@ -45,8 +49,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Extract headers we care about
|
||||
headers_to_log = {
|
||||
"authorization": request.headers.get("Authorization", "")[:50] + "..." if
|
||||
request.headers.get("Authorization") else None,
|
||||
"authorization": request.headers.get("Authorization", "")[:50] + "..."
|
||||
if request.headers.get("Authorization")
|
||||
else None,
|
||||
"content-type": request.headers.get("Content-Type"),
|
||||
"user-agent": request.headers.get("User-Agent"),
|
||||
"x-forwarded-for": request.headers.get("X-Forwarded-For"),
|
||||
@@ -54,7 +59,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
}
|
||||
|
||||
# Log request
|
||||
logger.info("=== API REQUEST DEBUG ===", extra={
|
||||
logger.info(
|
||||
"=== API REQUEST DEBUG ===",
|
||||
extra={
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
@@ -63,8 +70,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
"headers": {k: v for k, v in headers_to_log.items() if v is not None},
|
||||
"body": request_body,
|
||||
"client_ip": request.client.host if request.client else None,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# Process the request
|
||||
start_time = time.time()
|
||||
@@ -73,33 +81,43 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Add timeout detection
|
||||
try:
|
||||
logger.info(f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}")
|
||||
logger.info(
|
||||
f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}"
|
||||
)
|
||||
logger.info(f"Request path: {request.url.path}")
|
||||
logger.info(f"Request method: {request.method}")
|
||||
|
||||
# Check if this is the login endpoint
|
||||
if request.url.path == "/api-internal/v1/auth/login" and request.method == "POST":
|
||||
if (
|
||||
request.url.path == "/api-internal/v1/auth/login"
|
||||
and request.method == "POST"
|
||||
):
|
||||
logger.info(f"=== LOGIN REQUEST DETECTED === {request_id}")
|
||||
|
||||
response = await call_next(request)
|
||||
logger.info(f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}")
|
||||
logger.info(
|
||||
f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}"
|
||||
)
|
||||
|
||||
# Capture response body for successful JSON responses
|
||||
if response.status_code < 400 and isinstance(response, JSONResponse):
|
||||
try:
|
||||
response_body = json.loads(response.body.decode('utf-8'))
|
||||
response_body = json.loads(response.body.decode("utf-8"))
|
||||
except:
|
||||
response_body = "[Failed to decode response body]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Request processing failed: {str(e)}", extra={
|
||||
logger.error(
|
||||
f"Request processing failed: {str(e)}",
|
||||
extra={
|
||||
"request_id": request_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
})
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
)
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
|
||||
content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
|
||||
)
|
||||
|
||||
# Calculate timing
|
||||
@@ -107,14 +125,17 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
|
||||
duration = (end_time - start_time) * 1000 # milliseconds
|
||||
|
||||
# Log response
|
||||
logger.info("=== API RESPONSE DEBUG ===", extra={
|
||||
logger.info(
|
||||
"=== API RESPONSE DEBUG ===",
|
||||
extra={
|
||||
"request_id": request_id,
|
||||
"status_code": response.status_code,
|
||||
"duration_ms": round(duration, 2),
|
||||
"response_body": response_body,
|
||||
"response_headers": dict(response.headers),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -9,9 +9,31 @@ from .budget import Budget
|
||||
from .audit_log import AuditLog
|
||||
from .rag_collection import RagCollection
|
||||
from .rag_document import RagDocument
|
||||
from .chatbot import ChatbotInstance, ChatbotConversation, ChatbotMessage, ChatbotAnalytics
|
||||
from .chatbot import (
|
||||
ChatbotInstance,
|
||||
ChatbotConversation,
|
||||
ChatbotMessage,
|
||||
ChatbotAnalytics,
|
||||
)
|
||||
from .prompt_template import PromptTemplate, ChatbotPromptVariable
|
||||
from .plugin import Plugin, PluginConfiguration, PluginInstance, PluginAuditLog, PluginCronJob, PluginAPIGateway
|
||||
from .plugin import (
|
||||
Plugin,
|
||||
PluginConfiguration,
|
||||
PluginInstance,
|
||||
PluginAuditLog,
|
||||
PluginCronJob,
|
||||
PluginAPIGateway,
|
||||
)
|
||||
from .role import Role, RoleLevel
|
||||
from .tool import Tool, ToolExecution, ToolCategory, ToolType, ToolStatus
|
||||
from .notification import (
|
||||
Notification,
|
||||
NotificationTemplate,
|
||||
NotificationChannel,
|
||||
NotificationType,
|
||||
NotificationPriority,
|
||||
NotificationStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -32,5 +54,18 @@ __all__ = [
|
||||
"PluginInstance",
|
||||
"PluginAuditLog",
|
||||
"PluginCronJob",
|
||||
"PluginAPIGateway"
|
||||
"PluginAPIGateway",
|
||||
"Role",
|
||||
"RoleLevel",
|
||||
"Tool",
|
||||
"ToolExecution",
|
||||
"ToolCategory",
|
||||
"ToolType",
|
||||
"ToolStatus",
|
||||
"Notification",
|
||||
"NotificationTemplate",
|
||||
"NotificationChannel",
|
||||
"NotificationType",
|
||||
"NotificationPriority",
|
||||
"NotificationStatus",
|
||||
]
|
||||
@@ -3,7 +3,16 @@ API Key model
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
@@ -16,15 +25,21 @@ class APIKey(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False) # Human-readable name for the API key
|
||||
key_hash = Column(String, unique=True, index=True, nullable=False) # Hashed API key
|
||||
key_prefix = Column(String, index=True, nullable=False) # First 8 characters for identification
|
||||
key_prefix = Column(
|
||||
String, index=True, nullable=False
|
||||
) # First 8 characters for identification
|
||||
|
||||
# User relationship
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
user = relationship("User", back_populates="api_keys")
|
||||
|
||||
# Related data relationships
|
||||
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
|
||||
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
|
||||
budgets = relationship(
|
||||
"Budget", back_populates="api_key", cascade="all, delete-orphan"
|
||||
)
|
||||
usage_tracking = relationship(
|
||||
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Key status and permissions
|
||||
is_active = Column(Boolean, default=True)
|
||||
@@ -40,7 +55,9 @@ class APIKey(Base):
|
||||
allowed_models = Column(JSON, default=list) # List of allowed LLM models
|
||||
allowed_endpoints = Column(JSON, default=list) # List of allowed API endpoints
|
||||
allowed_ips = Column(JSON, default=list) # IP whitelist
|
||||
allowed_chatbots = Column(JSON, default=list) # List of allowed chatbot IDs for chatbot-specific keys
|
||||
allowed_chatbots = Column(
|
||||
JSON, default=list
|
||||
) # List of allowed chatbot IDs for chatbot-specific keys
|
||||
|
||||
# Budget configuration
|
||||
is_unlimited = Column(Boolean, default=True) # Unlimited budget flag
|
||||
@@ -63,8 +80,12 @@ class APIKey(Base):
|
||||
total_cost = Column(Integer, default=0) # In cents
|
||||
|
||||
# Relationships
|
||||
usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
|
||||
budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
|
||||
usage_tracking = relationship(
|
||||
"UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
|
||||
)
|
||||
budgets = relationship(
|
||||
"Budget", back_populates="api_key", cascade="all, delete-orphan"
|
||||
)
|
||||
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
|
||||
|
||||
def __repr__(self):
|
||||
@@ -91,14 +112,16 @@ class APIKey(Base):
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"total_requests": self.total_requests,
|
||||
"total_tokens": self.total_tokens,
|
||||
"total_cost_cents": self.total_cost,
|
||||
"is_unlimited": self.is_unlimited,
|
||||
"budget_limit": self.budget_limit_cents, # Map to budget_limit for API response
|
||||
"budget_type": self.budget_type
|
||||
"budget_type": self.budget_type,
|
||||
}
|
||||
|
||||
if include_sensitive:
|
||||
@@ -222,7 +245,9 @@ class APIKey(Base):
|
||||
self.allowed_chatbots.remove(chatbot_id)
|
||||
|
||||
@classmethod
|
||||
def create_default_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str) -> "APIKey":
|
||||
def create_default_key(
|
||||
cls, user_id: int, name: str, key_hash: str, key_prefix: str
|
||||
) -> "APIKey":
|
||||
"""Create a default API key with standard permissions"""
|
||||
return cls(
|
||||
name=name,
|
||||
@@ -230,17 +255,8 @@ class APIKey(Base):
|
||||
key_prefix=key_prefix,
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
permissions={
|
||||
"read": True,
|
||||
"write": True,
|
||||
"chat": True,
|
||||
"embeddings": True
|
||||
},
|
||||
scopes=[
|
||||
"chat.completions",
|
||||
"embeddings.create",
|
||||
"models.list"
|
||||
],
|
||||
permissions={"read": True, "write": True, "chat": True, "embeddings": True},
|
||||
scopes=["chat.completions", "embeddings.create", "models.list"],
|
||||
rate_limit_per_minute=60,
|
||||
rate_limit_per_hour=3600,
|
||||
rate_limit_per_day=86400,
|
||||
@@ -248,12 +264,19 @@ class APIKey(Base):
|
||||
allowed_endpoints=[], # All endpoints allowed by default
|
||||
allowed_ips=[], # All IPs allowed by default
|
||||
description="Default API key with standard permissions",
|
||||
tags=["default"]
|
||||
tags=["default"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_restricted_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
|
||||
models: List[str], endpoints: List[str]) -> "APIKey":
|
||||
def create_restricted_key(
|
||||
cls,
|
||||
user_id: int,
|
||||
name: str,
|
||||
key_hash: str,
|
||||
key_prefix: str,
|
||||
models: List[str],
|
||||
endpoints: List[str],
|
||||
) -> "APIKey":
|
||||
"""Create a restricted API key with limited permissions"""
|
||||
return cls(
|
||||
name=name,
|
||||
@@ -261,13 +284,8 @@ class APIKey(Base):
|
||||
key_prefix=key_prefix,
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
permissions={
|
||||
"read": True,
|
||||
"chat": True
|
||||
},
|
||||
scopes=[
|
||||
"chat.completions"
|
||||
],
|
||||
permissions={"read": True, "chat": True},
|
||||
scopes=["chat.completions"],
|
||||
rate_limit_per_minute=30,
|
||||
rate_limit_per_hour=1800,
|
||||
rate_limit_per_day=43200,
|
||||
@@ -275,12 +293,19 @@ class APIKey(Base):
|
||||
allowed_endpoints=endpoints,
|
||||
allowed_ips=[],
|
||||
description="Restricted API key with limited permissions",
|
||||
tags=["restricted"]
|
||||
tags=["restricted"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_chatbot_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
|
||||
chatbot_id: str, chatbot_name: str) -> "APIKey":
|
||||
def create_chatbot_key(
|
||||
cls,
|
||||
user_id: int,
|
||||
name: str,
|
||||
key_hash: str,
|
||||
key_prefix: str,
|
||||
chatbot_id: str,
|
||||
chatbot_name: str,
|
||||
) -> "APIKey":
|
||||
"""Create a chatbot-specific API key"""
|
||||
return cls(
|
||||
name=name,
|
||||
@@ -288,22 +313,18 @@ class APIKey(Base):
|
||||
key_prefix=key_prefix,
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
permissions={
|
||||
"chatbot": True
|
||||
},
|
||||
scopes=[
|
||||
"chatbot.chat"
|
||||
],
|
||||
permissions={"chatbot": True},
|
||||
scopes=["chatbot.chat"],
|
||||
rate_limit_per_minute=100,
|
||||
rate_limit_per_hour=6000,
|
||||
rate_limit_per_day=144000,
|
||||
allowed_models=[], # Will use chatbot's configured model
|
||||
allowed_endpoints=[
|
||||
f"/api/v1/chatbot/external/{chatbot_id}/chat",
|
||||
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions"
|
||||
f"/api/v1/chatbot/external/{chatbot_id}/chat/completions",
|
||||
],
|
||||
allowed_ips=[],
|
||||
allowed_chatbots=[chatbot_id],
|
||||
description=f"API key for chatbot: {chatbot_name}",
|
||||
tags=["chatbot", f"chatbot-{chatbot_id}"]
|
||||
tags=["chatbot", f"chatbot-{chatbot_id}"],
|
||||
)
|
||||
@@ -3,7 +3,16 @@ Audit log model for tracking system events and user actions
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey, Text, Boolean
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Text,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
from enum import Enum
|
||||
@@ -11,6 +20,7 @@ from enum import Enum
|
||||
|
||||
class AuditAction(str, Enum):
|
||||
"""Audit action types"""
|
||||
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
@@ -32,6 +42,7 @@ class AuditAction(str, Enum):
|
||||
|
||||
class AuditSeverity(str, Enum):
|
||||
"""Audit severity levels"""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
@@ -51,7 +62,9 @@ class AuditLog(Base):
|
||||
|
||||
# Event details
|
||||
action = Column(String, nullable=False)
|
||||
resource_type = Column(String, nullable=False) # user, api_key, budget, module, etc.
|
||||
resource_type = Column(
|
||||
String, nullable=False
|
||||
) # user, api_key, budget, module, etc.
|
||||
resource_id = Column(String, nullable=True) # ID of the affected resource
|
||||
|
||||
# Event description and details
|
||||
@@ -74,7 +87,9 @@ class AuditLog(Base):
|
||||
|
||||
# Additional metadata
|
||||
tags = Column(JSON, default=list)
|
||||
audit_metadata = Column("metadata", JSON, default=dict) # Map to 'metadata' column in DB
|
||||
audit_metadata = Column(
|
||||
"metadata", JSON, default=dict
|
||||
) # Map to 'metadata' column in DB
|
||||
|
||||
# Before/after values for data changes
|
||||
old_values = Column(JSON, nullable=True)
|
||||
@@ -84,7 +99,9 @@ class AuditLog(Base):
|
||||
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
|
||||
return (
|
||||
f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert audit log to dictionary for API responses"""
|
||||
@@ -108,7 +125,7 @@ class AuditLog(Base):
|
||||
"metadata": self.audit_metadata,
|
||||
"old_values": self.old_values,
|
||||
"new_values": self.new_values,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
def is_security_event(self) -> bool:
|
||||
@@ -120,7 +137,7 @@ class AuditLog(Base):
|
||||
AuditAction.API_KEY_DELETE,
|
||||
AuditAction.PERMISSION_GRANT,
|
||||
AuditAction.PERMISSION_REVOKE,
|
||||
AuditAction.SECURITY_EVENT
|
||||
AuditAction.SECURITY_EVENT,
|
||||
]
|
||||
return self.action in security_actions or self.category == "security"
|
||||
|
||||
@@ -150,9 +167,15 @@ class AuditLog(Base):
|
||||
self.new_values = new_values
|
||||
|
||||
@classmethod
|
||||
def create_login_event(cls, user_id: int, success: bool = True,
|
||||
ip_address: str = None, user_agent: str = None,
|
||||
session_id: str = None, error_message: str = None) -> "AuditLog":
|
||||
def create_login_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
success: bool = True,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
session_id: str = None,
|
||||
error_message: str = None,
|
||||
) -> "AuditLog":
|
||||
"""Create a login audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -160,10 +183,7 @@ class AuditLog(Base):
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"User login {'successful' if success else 'failed'}",
|
||||
details={
|
||||
"login_method": "password",
|
||||
"success": success
|
||||
},
|
||||
details={"login_method": "password", "success": success},
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
session_id=session_id,
|
||||
@@ -171,7 +191,7 @@ class AuditLog(Base):
|
||||
category="security",
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
tags=["authentication", "login"]
|
||||
tags=["authentication", "login"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -183,20 +203,24 @@ class AuditLog(Base):
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description="User logout",
|
||||
details={
|
||||
"logout_method": "manual"
|
||||
},
|
||||
details={"logout_method": "manual"},
|
||||
session_id=session_id,
|
||||
severity=AuditSeverity.LOW,
|
||||
category="security",
|
||||
success=True,
|
||||
tags=["authentication", "logout"]
|
||||
tags=["authentication", "logout"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_api_key_event(cls, user_id: int, action: str, api_key_id: int,
|
||||
api_key_name: str, success: bool = True,
|
||||
error_message: str = None) -> "AuditLog":
|
||||
def create_api_key_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
action: str,
|
||||
api_key_id: int,
|
||||
api_key_name: str,
|
||||
success: bool = True,
|
||||
error_message: str = None,
|
||||
) -> "AuditLog":
|
||||
"""Create an API key audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -204,21 +228,24 @@ class AuditLog(Base):
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key_id),
|
||||
description=f"API key {action}: {api_key_name}",
|
||||
details={
|
||||
"api_key_name": api_key_name,
|
||||
"action": action
|
||||
},
|
||||
details={"api_key_name": api_key_name, "action": action},
|
||||
severity=AuditSeverity.MEDIUM,
|
||||
category="security",
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
tags=["api_key", action]
|
||||
tags=["api_key", action],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_budget_event(cls, user_id: int, action: str, budget_id: int,
|
||||
budget_name: str, details: Dict[str, Any] = None,
|
||||
success: bool = True) -> "AuditLog":
|
||||
def create_budget_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
action: str,
|
||||
budget_id: int,
|
||||
budget_name: str,
|
||||
details: Dict[str, Any] = None,
|
||||
success: bool = True,
|
||||
) -> "AuditLog":
|
||||
"""Create a budget audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -227,16 +254,24 @@ class AuditLog(Base):
|
||||
resource_id=str(budget_id),
|
||||
description=f"Budget {action}: {budget_name}",
|
||||
details=details or {},
|
||||
severity=AuditSeverity.MEDIUM if action == AuditAction.BUDGET_EXCEED else AuditSeverity.LOW,
|
||||
severity=AuditSeverity.MEDIUM
|
||||
if action == AuditAction.BUDGET_EXCEED
|
||||
else AuditSeverity.LOW,
|
||||
category="financial",
|
||||
success=success,
|
||||
tags=["budget", action]
|
||||
tags=["budget", action],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_module_event(cls, user_id: int, action: str, module_name: str,
|
||||
success: bool = True, error_message: str = None,
|
||||
details: Dict[str, Any] = None) -> "AuditLog":
|
||||
def create_module_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
action: str,
|
||||
module_name: str,
|
||||
success: bool = True,
|
||||
error_message: str = None,
|
||||
details: Dict[str, Any] = None,
|
||||
) -> "AuditLog":
|
||||
"""Create a module audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -249,12 +284,18 @@ class AuditLog(Base):
|
||||
category="system",
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
tags=["module", action]
|
||||
tags=["module", action],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_permission_event(cls, user_id: int, action: str, target_user_id: int,
|
||||
permission: str, success: bool = True) -> "AuditLog":
|
||||
def create_permission_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
action: str,
|
||||
target_user_id: int,
|
||||
permission: str,
|
||||
success: bool = True,
|
||||
) -> "AuditLog":
|
||||
"""Create a permission audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -262,21 +303,23 @@ class AuditLog(Base):
|
||||
resource_type="permission",
|
||||
resource_id=str(target_user_id),
|
||||
description=f"Permission {action}: {permission} for user {target_user_id}",
|
||||
details={
|
||||
"permission": permission,
|
||||
"target_user_id": target_user_id
|
||||
},
|
||||
details={"permission": permission, "target_user_id": target_user_id},
|
||||
severity=AuditSeverity.HIGH,
|
||||
category="security",
|
||||
success=success,
|
||||
tags=["permission", action]
|
||||
tags=["permission", action],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_security_event(cls, user_id: int, event_type: str, description: str,
|
||||
def create_security_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
event_type: str,
|
||||
description: str,
|
||||
severity: str = AuditSeverity.HIGH,
|
||||
details: Dict[str, Any] = None,
|
||||
ip_address: str = None) -> "AuditLog":
|
||||
ip_address: str = None,
|
||||
) -> "AuditLog":
|
||||
"""Create a security audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -289,15 +332,19 @@ class AuditLog(Base):
|
||||
severity=severity,
|
||||
category="security",
|
||||
success=False, # Security events are typically failures
|
||||
tags=["security", event_type]
|
||||
tags=["security", event_type],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_system_event(cls, action: str, description: str,
|
||||
def create_system_event(
|
||||
cls,
|
||||
action: str,
|
||||
description: str,
|
||||
resource_type: str = "system",
|
||||
resource_id: str = None,
|
||||
severity: str = AuditSeverity.LOW,
|
||||
details: Dict[str, Any] = None) -> "AuditLog":
|
||||
details: Dict[str, Any] = None,
|
||||
) -> "AuditLog":
|
||||
"""Create a system audit event"""
|
||||
return cls(
|
||||
user_id=None, # System events don't have a user
|
||||
@@ -309,14 +356,20 @@ class AuditLog(Base):
|
||||
severity=severity,
|
||||
category="system",
|
||||
success=True,
|
||||
tags=["system", action]
|
||||
tags=["system", action],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_data_change_event(cls, user_id: int, action: str, resource_type: str,
|
||||
resource_id: str, description: str,
|
||||
def create_data_change_event(
|
||||
cls,
|
||||
user_id: int,
|
||||
action: str,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
description: str,
|
||||
old_values: Dict[str, Any],
|
||||
new_values: Dict[str, Any]) -> "AuditLog":
|
||||
new_values: Dict[str, Any],
|
||||
) -> "AuditLog":
|
||||
"""Create a data change audit event"""
|
||||
return cls(
|
||||
user_id=user_id,
|
||||
@@ -329,7 +382,7 @@ class AuditLog(Base):
|
||||
severity=AuditSeverity.LOW,
|
||||
category="data",
|
||||
success=True,
|
||||
tags=["data_change", action]
|
||||
tags=["data_change", action],
|
||||
)
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
@@ -342,5 +395,5 @@ class AuditLog(Base):
|
||||
"severity": self.severity,
|
||||
"success": self.success,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"user_id": self.user_id
|
||||
"user_id": self.user_id,
|
||||
}
|
||||
@@ -5,13 +5,24 @@ Budget model for managing spending limits and cost control
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey, Float
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class BudgetType(str, Enum):
|
||||
"""Budget type enumeration"""
|
||||
|
||||
USER = "user"
|
||||
API_KEY = "api_key"
|
||||
GLOBAL = "global"
|
||||
@@ -19,6 +30,7 @@ class BudgetType(str, Enum):
|
||||
|
||||
class BudgetPeriod(str, Enum):
|
||||
"""Budget period types"""
|
||||
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
@@ -38,18 +50,26 @@ class Budget(Base):
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
user = relationship("User", back_populates="budgets")
|
||||
|
||||
api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True) # Optional: specific to an API key
|
||||
api_key_id = Column(
|
||||
Integer, ForeignKey("api_keys.id"), nullable=True
|
||||
) # Optional: specific to an API key
|
||||
api_key = relationship("APIKey", back_populates="budgets")
|
||||
|
||||
# Usage tracking relationship
|
||||
usage_tracking = relationship("UsageTracking", back_populates="budget", cascade="all, delete-orphan")
|
||||
usage_tracking = relationship(
|
||||
"UsageTracking", back_populates="budget", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Budget limits (in cents)
|
||||
limit_cents = Column(Integer, nullable=False) # Maximum spend limit
|
||||
warning_threshold_cents = Column(Integer, nullable=True) # Warning threshold (e.g., 80% of limit)
|
||||
warning_threshold_cents = Column(
|
||||
Integer, nullable=True
|
||||
) # Warning threshold (e.g., 80% of limit)
|
||||
|
||||
# Time period settings
|
||||
period_type = Column(String, nullable=False, default="monthly") # daily, weekly, monthly, yearly, custom
|
||||
period_type = Column(
|
||||
String, nullable=False, default="monthly"
|
||||
) # daily, weekly, monthly, yearly, custom
|
||||
period_start = Column(DateTime, nullable=False) # Start of current period
|
||||
period_end = Column(DateTime, nullable=False) # End of current period
|
||||
|
||||
@@ -62,12 +82,18 @@ class Budget(Base):
|
||||
is_warning_sent = Column(Boolean, default=False)
|
||||
|
||||
# Enforcement settings
|
||||
enforce_hard_limit = Column(Boolean, default=True) # Block requests when limit exceeded
|
||||
enforce_hard_limit = Column(
|
||||
Boolean, default=True
|
||||
) # Block requests when limit exceeded
|
||||
enforce_warning = Column(Boolean, default=True) # Send warnings at threshold
|
||||
|
||||
# Allowed resources (optional filters)
|
||||
allowed_models = Column(JSON, default=list) # Specific models this budget applies to
|
||||
allowed_endpoints = Column(JSON, default=list) # Specific endpoints this budget applies to
|
||||
allowed_models = Column(
|
||||
JSON, default=list
|
||||
) # Specific models this budget applies to
|
||||
allowed_endpoints = Column(
|
||||
JSON, default=list
|
||||
) # Specific endpoints this budget applies to
|
||||
|
||||
# Metadata
|
||||
description = Column(Text, nullable=True)
|
||||
@@ -75,8 +101,12 @@ class Budget(Base):
|
||||
currency = Column(String, default="USD")
|
||||
|
||||
# Auto-renewal settings
|
||||
auto_renew = Column(Boolean, default=True) # Automatically renew budget for next period
|
||||
rollover_unused = Column(Boolean, default=False) # Rollover unused budget to next period
|
||||
auto_renew = Column(
|
||||
Boolean, default=True
|
||||
) # Automatically renew budget for next period
|
||||
rollover_unused = Column(
|
||||
Boolean, default=False
|
||||
) # Rollover unused budget to next period
|
||||
|
||||
# Notification settings
|
||||
notification_settings = Column(JSON, default=dict) # Email, webhook, etc.
|
||||
@@ -99,15 +129,23 @@ class Budget(Base):
|
||||
"limit_cents": self.limit_cents,
|
||||
"limit_dollars": self.limit_cents / 100,
|
||||
"warning_threshold_cents": self.warning_threshold_cents,
|
||||
"warning_threshold_dollars": self.warning_threshold_cents / 100 if self.warning_threshold_cents else None,
|
||||
"warning_threshold_dollars": self.warning_threshold_cents / 100
|
||||
if self.warning_threshold_cents
|
||||
else None,
|
||||
"period_type": self.period_type,
|
||||
"period_start": self.period_start.isoformat() if self.period_start else None,
|
||||
"period_start": self.period_start.isoformat()
|
||||
if self.period_start
|
||||
else None,
|
||||
"period_end": self.period_end.isoformat() if self.period_end else None,
|
||||
"current_usage_cents": self.current_usage_cents,
|
||||
"current_usage_dollars": self.current_usage_cents / 100,
|
||||
"remaining_cents": max(0, self.limit_cents - self.current_usage_cents),
|
||||
"remaining_dollars": max(0, (self.limit_cents - self.current_usage_cents) / 100),
|
||||
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100) if self.limit_cents > 0 else 0,
|
||||
"remaining_dollars": max(
|
||||
0, (self.limit_cents - self.current_usage_cents) / 100
|
||||
),
|
||||
"usage_percentage": (self.current_usage_cents / self.limit_cents * 100)
|
||||
if self.limit_cents > 0
|
||||
else 0,
|
||||
"is_active": self.is_active,
|
||||
"is_exceeded": self.is_exceeded,
|
||||
"is_warning_sent": self.is_warning_sent,
|
||||
@@ -123,7 +161,9 @@ class Budget(Base):
|
||||
"notification_settings": self.notification_settings,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_reset_at": self.last_reset_at.isoformat() if self.last_reset_at else None
|
||||
"last_reset_at": self.last_reset_at.isoformat()
|
||||
if self.last_reset_at
|
||||
else None,
|
||||
}
|
||||
|
||||
def is_in_period(self) -> bool:
|
||||
@@ -161,7 +201,10 @@ class Budget(Base):
|
||||
self.is_exceeded = True
|
||||
|
||||
# Check if warning threshold is reached
|
||||
if self.warning_threshold_cents and self.current_usage_cents >= self.warning_threshold_cents:
|
||||
if (
|
||||
self.warning_threshold_cents
|
||||
and self.current_usage_cents >= self.warning_threshold_cents
|
||||
):
|
||||
if not self.is_warning_sent:
|
||||
self.is_warning_sent = True
|
||||
|
||||
@@ -190,9 +233,13 @@ class Budget(Base):
|
||||
self.period_start = self.period_end
|
||||
# Handle month boundaries properly
|
||||
if self.period_start.month == 12:
|
||||
next_month = self.period_start.replace(year=self.period_start.year + 1, month=1)
|
||||
next_month = self.period_start.replace(
|
||||
year=self.period_start.year + 1, month=1
|
||||
)
|
||||
else:
|
||||
next_month = self.period_start.replace(month=self.period_start.month + 1)
|
||||
next_month = self.period_start.replace(
|
||||
month=self.period_start.month + 1
|
||||
)
|
||||
self.period_end = next_month
|
||||
elif self.period_type == "yearly":
|
||||
self.period_start = self.period_end
|
||||
@@ -230,7 +277,7 @@ class Budget(Base):
|
||||
name: str,
|
||||
limit_dollars: float,
|
||||
api_key_id: Optional[int] = None,
|
||||
warning_threshold_percentage: float = 0.8
|
||||
warning_threshold_percentage: float = 0.8,
|
||||
) -> "Budget":
|
||||
"""Create a monthly budget"""
|
||||
now = datetime.utcnow()
|
||||
@@ -258,10 +305,7 @@ class Budget(Base):
|
||||
enforce_hard_limit=True,
|
||||
enforce_warning=True,
|
||||
auto_renew=True,
|
||||
notification_settings={
|
||||
"email_on_warning": True,
|
||||
"email_on_exceeded": True
|
||||
}
|
||||
notification_settings={"email_on_warning": True, "email_on_exceeded": True},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -270,7 +314,7 @@ class Budget(Base):
|
||||
user_id: int,
|
||||
name: str,
|
||||
limit_dollars: float,
|
||||
api_key_id: Optional[int] = None
|
||||
api_key_id: Optional[int] = None,
|
||||
) -> "Budget":
|
||||
"""Create a daily budget"""
|
||||
now = datetime.utcnow()
|
||||
@@ -292,5 +336,5 @@ class Budget(Base):
|
||||
is_active=True,
|
||||
enforce_hard_limit=True,
|
||||
enforce_warning=True,
|
||||
auto_renew=True
|
||||
auto_renew=True,
|
||||
)
|
||||
@@ -1,7 +1,16 @@
|
||||
"""
|
||||
Database models for chatbot module
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
Boolean,
|
||||
DateTime,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
@@ -9,8 +18,10 @@ import uuid
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class ChatbotInstance(Base):
|
||||
"""Configured chatbot instance"""
|
||||
|
||||
__tablename__ = "chatbot_instances"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
@@ -27,7 +38,9 @@ class ChatbotInstance(Base):
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Relationships
|
||||
conversations = relationship("ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan")
|
||||
conversations = relationship(
|
||||
"ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotInstance(id='{self.id}', name='{self.name}')>"
|
||||
@@ -35,6 +48,7 @@ class ChatbotInstance(Base):
|
||||
|
||||
class ChatbotConversation(Base):
|
||||
"""Conversation state and history"""
|
||||
|
||||
__tablename__ = "chatbot_conversations"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
@@ -52,7 +66,9 @@ class ChatbotConversation(Base):
|
||||
|
||||
# Relationships
|
||||
chatbot = relationship("ChatbotInstance", back_populates="conversations")
|
||||
messages = relationship("ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan")
|
||||
messages = relationship(
|
||||
"ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotConversation(id='{self.id}', chatbot_id='{self.chatbot_id}')>"
|
||||
@@ -60,10 +76,13 @@ class ChatbotConversation(Base):
|
||||
|
||||
class ChatbotMessage(Base):
|
||||
"""Individual chat messages in conversations"""
|
||||
|
||||
__tablename__ = "chatbot_messages"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
conversation_id = Column(String, ForeignKey("chatbot_conversations.id"), nullable=False)
|
||||
conversation_id = Column(
|
||||
String, ForeignKey("chatbot_conversations.id"), nullable=False
|
||||
)
|
||||
|
||||
# Message content
|
||||
role = Column(String(20), nullable=False) # 'user', 'assistant', 'system'
|
||||
@@ -85,6 +104,7 @@ class ChatbotMessage(Base):
|
||||
|
||||
class ChatbotAnalytics(Base):
|
||||
"""Analytics and metrics for chatbot usage"""
|
||||
|
||||
__tablename__ = "chatbot_analytics"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
@@ -92,7 +112,9 @@ class ChatbotAnalytics(Base):
|
||||
user_id = Column(String, nullable=False)
|
||||
|
||||
# Event tracking
|
||||
event_type = Column(String(50), nullable=False) # 'message_sent', 'response_generated', etc.
|
||||
event_type = Column(
|
||||
String(50), nullable=False
|
||||
) # 'message_sent', 'response_generated', etc.
|
||||
event_data = Column(JSON, default=dict)
|
||||
|
||||
# Performance metrics
|
||||
|
||||
@@ -10,6 +10,7 @@ from enum import Enum
|
||||
|
||||
class ModuleStatus(str, Enum):
|
||||
"""Module status types"""
|
||||
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ERROR = "error"
|
||||
@@ -19,6 +20,7 @@ class ModuleStatus(str, Enum):
|
||||
|
||||
class ModuleType(str, Enum):
|
||||
"""Module type categories"""
|
||||
|
||||
CORE = "core"
|
||||
INTERCEPTOR = "interceptor"
|
||||
ANALYTICS = "analytics"
|
||||
@@ -66,14 +68,20 @@ class Module(Base):
|
||||
entry_point = Column(String, nullable=True) # Main module entry point
|
||||
|
||||
# Interceptor configuration
|
||||
interceptor_chains = Column(JSON, default=list) # Which chains this module hooks into
|
||||
interceptor_chains = Column(
|
||||
JSON, default=list
|
||||
) # Which chains this module hooks into
|
||||
execution_order = Column(Integer, default=100) # Order in interceptor chain
|
||||
|
||||
# API endpoints
|
||||
api_endpoints = Column(JSON, default=list) # List of API endpoints this module provides
|
||||
api_endpoints = Column(
|
||||
JSON, default=list
|
||||
) # List of API endpoints this module provides
|
||||
|
||||
# Permissions and security
|
||||
required_permissions = Column(JSON, default=list) # Permissions required to use this module
|
||||
required_permissions = Column(
|
||||
JSON, default=list
|
||||
) # Permissions required to use this module
|
||||
security_level = Column(String, default="low") # low, medium, high, critical
|
||||
|
||||
# Metadata
|
||||
@@ -130,16 +138,22 @@ class Module(Base):
|
||||
"metadata": self.module_metadata,
|
||||
"last_error": self.last_error,
|
||||
"error_count": self.error_count,
|
||||
"last_started": self.last_started.isoformat() if self.last_started else None,
|
||||
"last_stopped": self.last_stopped.isoformat() if self.last_stopped else None,
|
||||
"last_started": self.last_started.isoformat()
|
||||
if self.last_started
|
||||
else None,
|
||||
"last_stopped": self.last_stopped.isoformat()
|
||||
if self.last_stopped
|
||||
else None,
|
||||
"request_count": self.request_count,
|
||||
"success_count": self.success_count,
|
||||
"error_count_runtime": self.error_count_runtime,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"installed_at": self.installed_at.isoformat() if self.installed_at else None,
|
||||
"installed_at": self.installed_at.isoformat()
|
||||
if self.installed_at
|
||||
else None,
|
||||
"success_rate": self.get_success_rate(),
|
||||
"uptime": self.get_uptime_seconds() if self.is_running() else 0
|
||||
"uptime": self.get_uptime_seconds() if self.is_running() else 0,
|
||||
}
|
||||
|
||||
def is_running(self) -> bool:
|
||||
@@ -315,8 +329,14 @@ class Module(Base):
|
||||
self.required_permissions.remove(permission)
|
||||
|
||||
@classmethod
|
||||
def create_core_module(cls, name: str, display_name: str, description: str,
|
||||
version: str, entry_point: str) -> "Module":
|
||||
def create_core_module(
|
||||
cls,
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: str,
|
||||
version: str,
|
||||
entry_point: str,
|
||||
) -> "Module":
|
||||
"""Create a core module"""
|
||||
return cls(
|
||||
name=name,
|
||||
@@ -341,7 +361,7 @@ class Module(Base):
|
||||
required_permissions=[],
|
||||
security_level="high",
|
||||
tags=["core"],
|
||||
module_metadata={}
|
||||
module_metadata={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -365,20 +385,12 @@ class Module(Base):
|
||||
"properties": {
|
||||
"provider": {"type": "string", "enum": ["redis"]},
|
||||
"ttl": {"type": "integer", "minimum": 60},
|
||||
"max_size": {"type": "integer", "minimum": 1000}
|
||||
"max_size": {"type": "integer", "minimum": 1000},
|
||||
},
|
||||
"required": ["provider", "ttl"]
|
||||
},
|
||||
config_values={
|
||||
"provider": "redis",
|
||||
"ttl": 3600,
|
||||
"max_size": 10000
|
||||
},
|
||||
default_config={
|
||||
"provider": "redis",
|
||||
"ttl": 3600,
|
||||
"max_size": 10000
|
||||
"required": ["provider", "ttl"],
|
||||
},
|
||||
config_values={"provider": "redis", "ttl": 3600, "max_size": 10000},
|
||||
default_config={"provider": "redis", "ttl": 3600, "max_size": 10000},
|
||||
dependencies=[],
|
||||
conflicts=[],
|
||||
interceptor_chains=["pre_request", "post_response"],
|
||||
@@ -387,7 +399,7 @@ class Module(Base):
|
||||
required_permissions=["cache.read", "cache.write"],
|
||||
security_level="low",
|
||||
tags=["cache", "performance"],
|
||||
module_metadata={}
|
||||
module_metadata={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -412,21 +424,21 @@ class Module(Base):
|
||||
"vector_db": {"type": "string", "enum": ["qdrant"]},
|
||||
"embedding_model": {"type": "string"},
|
||||
"chunk_size": {"type": "integer", "minimum": 100},
|
||||
"max_results": {"type": "integer", "minimum": 1}
|
||||
"max_results": {"type": "integer", "minimum": 1},
|
||||
},
|
||||
"required": ["vector_db", "embedding_model"]
|
||||
"required": ["vector_db", "embedding_model"],
|
||||
},
|
||||
config_values={
|
||||
"vector_db": "qdrant",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"chunk_size": 512,
|
||||
"max_results": 10
|
||||
"max_results": 10,
|
||||
},
|
||||
default_config={
|
||||
"vector_db": "qdrant",
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"chunk_size": 512,
|
||||
"max_results": 10
|
||||
"max_results": 10,
|
||||
},
|
||||
dependencies=[],
|
||||
conflicts=[],
|
||||
@@ -436,7 +448,7 @@ class Module(Base):
|
||||
required_permissions=["rag.read", "rag.write"],
|
||||
security_level="medium",
|
||||
tags=["rag", "ai", "search"],
|
||||
module_metadata={}
|
||||
module_metadata={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -460,19 +472,19 @@ class Module(Base):
|
||||
"properties": {
|
||||
"track_requests": {"type": "boolean"},
|
||||
"track_responses": {"type": "boolean"},
|
||||
"retention_days": {"type": "integer", "minimum": 1}
|
||||
"retention_days": {"type": "integer", "minimum": 1},
|
||||
},
|
||||
"required": ["track_requests", "track_responses"]
|
||||
"required": ["track_requests", "track_responses"],
|
||||
},
|
||||
config_values={
|
||||
"track_requests": True,
|
||||
"track_responses": True,
|
||||
"retention_days": 30
|
||||
"retention_days": 30,
|
||||
},
|
||||
default_config={
|
||||
"track_requests": True,
|
||||
"track_responses": True,
|
||||
"retention_days": 30
|
||||
"retention_days": 30,
|
||||
},
|
||||
dependencies=[],
|
||||
conflicts=[],
|
||||
@@ -482,7 +494,7 @@ class Module(Base):
|
||||
required_permissions=["analytics.read"],
|
||||
security_level="low",
|
||||
tags=["analytics", "monitoring"],
|
||||
module_metadata={}
|
||||
module_metadata={},
|
||||
)
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
@@ -495,5 +507,7 @@ class Module(Base):
|
||||
"uptime_seconds": self.get_uptime_seconds() if self.is_running() else 0,
|
||||
"last_error": self.last_error,
|
||||
"error_count": self.error_count_runtime,
|
||||
"last_started": self.last_started.isoformat() if self.last_started else None
|
||||
"last_started": self.last_started.isoformat()
|
||||
if self.last_started
|
||||
else None,
|
||||
}
|
||||
295
backend/app/models/notification.py
Normal file
295
backend/app/models/notification.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Notification models for multi-channel communication
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
"""Notification types"""
|
||||
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
SLACK = "slack"
|
||||
DISCORD = "discord"
|
||||
SMS = "sms"
|
||||
PUSH = "push"
|
||||
|
||||
|
||||
class NotificationPriority(str, Enum):
|
||||
"""Notification priority levels"""
|
||||
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
URGENT = "urgent"
|
||||
|
||||
|
||||
class NotificationStatus(str, Enum):
|
||||
"""Notification delivery status"""
|
||||
|
||||
PENDING = "pending"
|
||||
SENT = "sent"
|
||||
DELIVERED = "delivered"
|
||||
FAILED = "failed"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class NotificationTemplate(Base):
|
||||
"""Notification template model"""
|
||||
|
||||
__tablename__ = "notification_templates"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False, unique=True)
|
||||
display_name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Template content
|
||||
notification_type = Column(String(20), nullable=False) # NotificationType enum
|
||||
subject_template = Column(Text, nullable=True) # For email/messages
|
||||
body_template = Column(Text, nullable=False) # Main content
|
||||
html_template = Column(Text, nullable=True) # HTML version for email
|
||||
|
||||
# Configuration
|
||||
default_priority = Column(String(20), default=NotificationPriority.NORMAL)
|
||||
variables = Column(JSON, default=dict) # Expected template variables
|
||||
template_metadata = Column(JSON, default=dict) # Additional configuration
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
notifications = relationship("Notification", back_populates="template")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NotificationTemplate(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert template to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"notification_type": self.notification_type,
|
||||
"subject_template": self.subject_template,
|
||||
"body_template": self.body_template,
|
||||
"html_template": self.html_template,
|
||||
"default_priority": self.default_priority,
|
||||
"variables": self.variables,
|
||||
"template_metadata": self.template_metadata,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class NotificationChannel(Base):
|
||||
"""Notification channel configuration"""
|
||||
|
||||
__tablename__ = "notification_channels"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False)
|
||||
display_name = Column(String(200), nullable=False)
|
||||
notification_type = Column(String(20), nullable=False) # NotificationType enum
|
||||
|
||||
# Channel configuration
|
||||
config = Column(JSON, nullable=False) # Channel-specific settings
|
||||
credentials = Column(JSON, nullable=True) # Encrypted credentials
|
||||
|
||||
# Settings
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
rate_limit = Column(Integer, default=100) # Messages per minute
|
||||
retry_count = Column(Integer, default=3)
|
||||
retry_delay_minutes = Column(Integer, default=5)
|
||||
|
||||
# Health monitoring
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
success_count = Column(Integer, default=0)
|
||||
failure_count = Column(Integer, default=0)
|
||||
last_error = Column(Text, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
notifications = relationship("Notification", back_populates="channel")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NotificationChannel(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert channel to dictionary (excluding sensitive credentials)"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"notification_type": self.notification_type,
|
||||
"config": self.config,
|
||||
"is_active": self.is_active,
|
||||
"is_default": self.is_default,
|
||||
"rate_limit": self.rate_limit,
|
||||
"retry_count": self.retry_count,
|
||||
"retry_delay_minutes": self.retry_delay_minutes,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"success_count": self.success_count,
|
||||
"failure_count": self.failure_count,
|
||||
"last_error": self.last_error,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def update_stats(self, success: bool, error_message: Optional[str] = None):
|
||||
"""Update channel statistics"""
|
||||
self.last_used_at = datetime.utcnow()
|
||||
if success:
|
||||
self.success_count += 1
|
||||
self.last_error = None
|
||||
else:
|
||||
self.failure_count += 1
|
||||
self.last_error = error_message
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
"""Individual notification instance"""
|
||||
|
||||
__tablename__ = "notifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Content
|
||||
subject = Column(String(500), nullable=True)
|
||||
body = Column(Text, nullable=False)
|
||||
html_body = Column(Text, nullable=True)
|
||||
|
||||
# Recipients
|
||||
recipients = Column(JSON, nullable=False) # List of recipient addresses/IDs
|
||||
cc_recipients = Column(JSON, nullable=True) # CC recipients (for email)
|
||||
bcc_recipients = Column(JSON, nullable=True) # BCC recipients (for email)
|
||||
|
||||
# Configuration
|
||||
priority = Column(String(20), default=NotificationPriority.NORMAL)
|
||||
scheduled_at = Column(DateTime, nullable=True) # For scheduled delivery
|
||||
expires_at = Column(DateTime, nullable=True) # Expiration time
|
||||
|
||||
# References
|
||||
template_id = Column(
|
||||
Integer, ForeignKey("notification_templates.id"), nullable=True
|
||||
)
|
||||
channel_id = Column(Integer, ForeignKey("notification_channels.id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Triggering user
|
||||
|
||||
# Status tracking
|
||||
status = Column(String(20), default=NotificationStatus.PENDING)
|
||||
attempts = Column(Integer, default=0)
|
||||
max_attempts = Column(Integer, default=3)
|
||||
|
||||
# Delivery tracking
|
||||
sent_at = Column(DateTime, nullable=True)
|
||||
delivered_at = Column(DateTime, nullable=True)
|
||||
failed_at = Column(DateTime, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# External references
|
||||
external_id = Column(String(200), nullable=True) # Provider message ID
|
||||
callback_url = Column(String(500), nullable=True) # Delivery callback
|
||||
|
||||
# Metadata
|
||||
notification_metadata = Column(JSON, default=dict)
|
||||
tags = Column(JSON, default=list)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
template = relationship("NotificationTemplate", back_populates="notifications")
|
||||
channel = relationship("NotificationChannel", back_populates="notifications")
|
||||
user = relationship("User", back_populates="notifications")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Notification(id={self.id}, status='{self.status}', channel='{self.channel.name if self.channel else 'unknown'}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert notification to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"body": self.body,
|
||||
"html_body": self.html_body,
|
||||
"recipients": self.recipients,
|
||||
"cc_recipients": self.cc_recipients,
|
||||
"bcc_recipients": self.bcc_recipients,
|
||||
"priority": self.priority,
|
||||
"scheduled_at": self.scheduled_at.isoformat()
|
||||
if self.scheduled_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"template_id": self.template_id,
|
||||
"channel_id": self.channel_id,
|
||||
"user_id": self.user_id,
|
||||
"status": self.status,
|
||||
"attempts": self.attempts,
|
||||
"max_attempts": self.max_attempts,
|
||||
"sent_at": self.sent_at.isoformat() if self.sent_at else None,
|
||||
"delivered_at": self.delivered_at.isoformat()
|
||||
if self.delivered_at
|
||||
else None,
|
||||
"failed_at": self.failed_at.isoformat() if self.failed_at else None,
|
||||
"error_message": self.error_message,
|
||||
"external_id": self.external_id,
|
||||
"callback_url": self.callback_url,
|
||||
"notification_metadata": self.notification_metadata,
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def mark_sent(self, external_id: Optional[str] = None):
|
||||
"""Mark notification as sent"""
|
||||
self.status = NotificationStatus.SENT
|
||||
self.sent_at = datetime.utcnow()
|
||||
self.external_id = external_id
|
||||
|
||||
def mark_delivered(self):
|
||||
"""Mark notification as delivered"""
|
||||
self.status = NotificationStatus.DELIVERED
|
||||
self.delivered_at = datetime.utcnow()
|
||||
|
||||
def mark_failed(self, error_message: str):
|
||||
"""Mark notification as failed"""
|
||||
self.status = NotificationStatus.FAILED
|
||||
self.failed_at = datetime.utcnow()
|
||||
self.error_message = error_message
|
||||
self.attempts += 1
|
||||
|
||||
def can_retry(self) -> bool:
|
||||
"""Check if notification can be retried"""
|
||||
return (
|
||||
self.status in [NotificationStatus.FAILED, NotificationStatus.RETRY]
|
||||
and self.attempts < self.max_attempts
|
||||
and (self.expires_at is None or self.expires_at > datetime.utcnow())
|
||||
)
|
||||
@@ -2,7 +2,17 @@
|
||||
Plugin System Database Models
|
||||
Defines the database schema for the isolated plugin architecture
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey, Index
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
DateTime,
|
||||
Boolean,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Index,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
@@ -13,12 +23,15 @@ from app.db.database import Base
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin registry - tracks all installed plugins"""
|
||||
|
||||
__tablename__ = "plugins"
|
||||
|
||||
# Primary identification
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(100), unique=True, nullable=False, index=True)
|
||||
slug = Column(String(100), unique=True, nullable=False, index=True) # URL-safe identifier
|
||||
slug = Column(
|
||||
String(100), unique=True, nullable=False, index=True
|
||||
) # URL-safe identifier
|
||||
|
||||
# Metadata
|
||||
display_name = Column(String(200), nullable=False)
|
||||
@@ -66,20 +79,29 @@ class Plugin(Base):
|
||||
|
||||
# Relationships
|
||||
installed_by_user = relationship("User", back_populates="installed_plugins")
|
||||
configurations = relationship("PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan")
|
||||
instances = relationship("PluginInstance", back_populates="plugin", cascade="all, delete-orphan")
|
||||
audit_logs = relationship("PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan")
|
||||
cron_jobs = relationship("PluginCronJob", back_populates="plugin", cascade="all, delete-orphan")
|
||||
configurations = relationship(
|
||||
"PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
instances = relationship(
|
||||
"PluginInstance", back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
audit_logs = relationship(
|
||||
"PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
cron_jobs = relationship(
|
||||
"PluginCronJob", back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_status_enabled', 'status', 'enabled'),
|
||||
Index('idx_plugin_user_status', 'installed_by_user_id', 'status'),
|
||||
Index("idx_plugin_status_enabled", "status", "enabled"),
|
||||
Index("idx_plugin_user_status", "installed_by_user_id", "status"),
|
||||
)
|
||||
|
||||
|
||||
class PluginConfiguration(Base):
|
||||
"""Plugin configuration instances - per user/environment configs"""
|
||||
|
||||
__tablename__ = "plugin_configurations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -107,17 +129,20 @@ class PluginConfiguration(Base):
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_config_user_active', 'plugin_id', 'user_id', 'is_active'),
|
||||
Index("idx_plugin_config_user_active", "plugin_id", "user_id", "is_active"),
|
||||
)
|
||||
|
||||
|
||||
class PluginInstance(Base):
|
||||
"""Plugin runtime instances - tracks running plugin processes"""
|
||||
|
||||
__tablename__ = "plugin_instances"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
|
||||
configuration_id = Column(UUID(as_uuid=True), ForeignKey("plugin_configurations.id"))
|
||||
configuration_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("plugin_configurations.id")
|
||||
)
|
||||
|
||||
# Runtime information
|
||||
instance_name = Column(String(200), nullable=False)
|
||||
@@ -148,13 +173,12 @@ class PluginInstance(Base):
|
||||
plugin = relationship("Plugin", back_populates="instances")
|
||||
configuration = relationship("PluginConfiguration")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_instance_status', 'plugin_id', 'status'),
|
||||
)
|
||||
__table_args__ = (Index("idx_plugin_instance_status", "plugin_id", "status"),)
|
||||
|
||||
|
||||
class PluginAuditLog(Base):
|
||||
"""Audit logging for all plugin activities"""
|
||||
|
||||
__tablename__ = "plugin_audit_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -162,7 +186,9 @@ class PluginAuditLog(Base):
|
||||
instance_id = Column(UUID(as_uuid=True), ForeignKey("plugin_instances.id"))
|
||||
|
||||
# Event details
|
||||
event_type = Column(String(50), nullable=False, index=True) # api_call, config_change, error, etc.
|
||||
event_type = Column(
|
||||
String(50), nullable=False, index=True
|
||||
) # api_call, config_change, error, etc.
|
||||
action = Column(String(100), nullable=False)
|
||||
resource = Column(String(200)) # Resource being accessed
|
||||
|
||||
@@ -194,14 +220,15 @@ class PluginAuditLog(Base):
|
||||
api_key = relationship("APIKey")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_audit_plugin_time', 'plugin_id', 'timestamp'),
|
||||
Index('idx_plugin_audit_user_time', 'user_id', 'timestamp'),
|
||||
Index('idx_plugin_audit_event_type', 'event_type', 'timestamp'),
|
||||
Index("idx_plugin_audit_plugin_time", "plugin_id", "timestamp"),
|
||||
Index("idx_plugin_audit_user_time", "user_id", "timestamp"),
|
||||
Index("idx_plugin_audit_event_type", "event_type", "timestamp"),
|
||||
)
|
||||
|
||||
|
||||
class PluginCronJob(Base):
|
||||
"""Plugin scheduled jobs and cron tasks"""
|
||||
|
||||
__tablename__ = "plugin_cron_jobs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -209,7 +236,9 @@ class PluginCronJob(Base):
|
||||
|
||||
# Job identification
|
||||
job_name = Column(String(200), nullable=False)
|
||||
job_id = Column(String(100), nullable=False, unique=True, index=True) # Unique scheduler ID
|
||||
job_id = Column(
|
||||
String(100), nullable=False, unique=True, index=True
|
||||
) # Unique scheduler ID
|
||||
|
||||
# Schedule configuration
|
||||
schedule = Column(String(100), nullable=False) # Cron expression
|
||||
@@ -245,25 +274,32 @@ class PluginCronJob(Base):
|
||||
created_by_user = relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_cron_next_run', 'enabled', 'next_run_at'),
|
||||
Index('idx_plugin_cron_plugin', 'plugin_id', 'enabled'),
|
||||
Index("idx_plugin_cron_next_run", "enabled", "next_run_at"),
|
||||
Index("idx_plugin_cron_plugin", "plugin_id", "enabled"),
|
||||
)
|
||||
|
||||
|
||||
class PluginAPIGateway(Base):
|
||||
"""API gateway configuration for plugin routing"""
|
||||
|
||||
__tablename__ = "plugin_api_gateways"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True)
|
||||
plugin_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True
|
||||
)
|
||||
|
||||
# API routing configuration
|
||||
base_path = Column(String(200), nullable=False, unique=True) # /api/v1/plugins/zammad
|
||||
base_path = Column(
|
||||
String(200), nullable=False, unique=True
|
||||
) # /api/v1/plugins/zammad
|
||||
internal_url = Column(String(500), nullable=False) # http://plugin-zammad:8000
|
||||
|
||||
# Security settings
|
||||
require_authentication = Column(Boolean, default=True, nullable=False)
|
||||
allowed_methods = Column(JSON, default=["GET", "POST", "PUT", "DELETE"]) # HTTP methods
|
||||
allowed_methods = Column(
|
||||
JSON, default=["GET", "POST", "PUT", "DELETE"]
|
||||
) # HTTP methods
|
||||
rate_limit_per_minute = Column(Integer, default=60)
|
||||
rate_limit_per_hour = Column(Integer, default=1000)
|
||||
|
||||
@@ -303,8 +339,10 @@ Add to APIKey model:
|
||||
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
|
||||
"""
|
||||
|
||||
|
||||
class PluginPermission(Base):
|
||||
"""Plugin permission grants - tracks user permissions for plugins"""
|
||||
|
||||
__tablename__ = "plugin_permissions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
@@ -312,8 +350,12 @@ class PluginPermission(Base):
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Permission details
|
||||
permission_name = Column(String(200), nullable=False) # e.g., 'chatbot:invoke', 'rag:query'
|
||||
granted = Column(Boolean, default=True, nullable=False) # True=granted, False=revoked
|
||||
permission_name = Column(
|
||||
String(200), nullable=False
|
||||
) # e.g., 'chatbot:invoke', 'rag:query'
|
||||
granted = Column(
|
||||
Boolean, default=True, nullable=False
|
||||
) # True=granted, False=revoked
|
||||
|
||||
# Grant/revoke tracking
|
||||
granted_at = Column(DateTime, nullable=False, default=func.now())
|
||||
@@ -332,7 +374,7 @@ class PluginPermission(Base):
|
||||
revoked_by_user = relationship("User", foreign_keys=[revoked_by_user_id])
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_plugin_permission_user_plugin', 'user_id', 'plugin_id'),
|
||||
Index('idx_plugin_permission_plugin_name', 'plugin_id', 'permission_name'),
|
||||
Index('idx_plugin_permission_active', 'plugin_id', 'user_id', 'granted'),
|
||||
Index("idx_plugin_permission_user_plugin", "user_id", "plugin_id"),
|
||||
Index("idx_plugin_permission_plugin_name", "plugin_id", "permission_name"),
|
||||
Index("idx_plugin_permission_active", "plugin_id", "user_id", "granted"),
|
||||
)
|
||||
|
||||
@@ -10,18 +10,23 @@ from datetime import datetime
|
||||
|
||||
class PromptTemplate(Base):
|
||||
"""Editable prompt templates for different chatbot types"""
|
||||
|
||||
__tablename__ = "prompt_templates"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False, index=True) # Human readable name
|
||||
type_key = Column(String(100), nullable=False, unique=True, index=True) # assistant, customer_support, etc.
|
||||
type_key = Column(
|
||||
String(100), nullable=False, unique=True, index=True
|
||||
) # assistant, customer_support, etc.
|
||||
description = Column(Text, nullable=True)
|
||||
system_prompt = Column(Text, nullable=False)
|
||||
is_default = Column(Boolean, default=True, nullable=False)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
version = Column(Integer, default=1, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PromptTemplate(type_key='{self.type_key}', name='{self.name}')>"
|
||||
@@ -29,10 +34,13 @@ class PromptTemplate(Base):
|
||||
|
||||
class ChatbotPromptVariable(Base):
|
||||
"""Available variables that can be used in prompts"""
|
||||
|
||||
__tablename__ = "prompt_variables"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
variable_name = Column(String(100), nullable=False, unique=True, index=True) # {user_name}, {context}, etc.
|
||||
variable_name = Column(
|
||||
String(100), nullable=False, unique=True, index=True
|
||||
) # {user_name}, {context}, etc.
|
||||
description = Column(Text, nullable=True)
|
||||
example_value = Column(String(500), nullable=True)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
@@ -15,7 +15,9 @@ class RagCollection(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
qdrant_collection_name = Column(String(255), nullable=False, unique=True, index=True)
|
||||
qdrant_collection_name = Column(
|
||||
String(255), nullable=False, unique=True, index=True
|
||||
)
|
||||
|
||||
# Metadata
|
||||
document_count = Column(Integer, default=0, nullable=False)
|
||||
@@ -23,15 +25,26 @@ class RagCollection(Base):
|
||||
vector_count = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Status tracking
|
||||
status = Column(String(50), default='active', nullable=False) # 'active', 'indexing', 'error'
|
||||
status = Column(
|
||||
String(50), default="active", nullable=False
|
||||
) # 'active', 'indexing', 'error'
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
documents = relationship("RagDocument", back_populates="collection", cascade="all, delete-orphan")
|
||||
documents = relationship(
|
||||
"RagDocument", back_populates="collection", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert model to dictionary for API responses"""
|
||||
@@ -45,7 +58,7 @@ class RagCollection(Base):
|
||||
"status": self.status,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"is_active": self.is_active
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -3,7 +3,17 @@ RAG Document Model
|
||||
Represents documents within RAG collections
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, BigInteger, ForeignKey, JSON
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
DateTime,
|
||||
Boolean,
|
||||
BigInteger,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
@@ -15,7 +25,12 @@ class RagDocument(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Collection relationship
|
||||
collection_id = Column(Integer, ForeignKey("rag_collections.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
collection_id = Column(
|
||||
Integer,
|
||||
ForeignKey("rag_collections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
collection = relationship("RagCollection", back_populates="documents")
|
||||
|
||||
# File information
|
||||
@@ -27,7 +42,9 @@ class RagDocument(Base):
|
||||
mime_type = Column(String(100), nullable=True)
|
||||
|
||||
# Processing status
|
||||
status = Column(String(50), default='processing', nullable=False) # 'processing', 'processed', 'error', 'indexed'
|
||||
status = Column(
|
||||
String(50), default="processing", nullable=False
|
||||
) # 'processing', 'processed', 'error', 'indexed'
|
||||
processing_error = Column(Text, nullable=True)
|
||||
|
||||
# Content information
|
||||
@@ -36,17 +53,30 @@ class RagDocument(Base):
|
||||
character_count = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Vector information
|
||||
vector_count = Column(Integer, default=0, nullable=False) # number of chunks/vectors created
|
||||
chunk_size = Column(Integer, default=1000, nullable=False) # chunk size used for vectorization
|
||||
vector_count = Column(
|
||||
Integer, default=0, nullable=False
|
||||
) # number of chunks/vectors created
|
||||
chunk_size = Column(
|
||||
Integer, default=1000, nullable=False
|
||||
) # chunk size used for vectorization
|
||||
|
||||
# Metadata extracted from document
|
||||
document_metadata = Column(JSON, nullable=True) # language, entities, keywords, etc.
|
||||
document_metadata = Column(
|
||||
JSON, nullable=True
|
||||
) # language, entities, keywords, etc.
|
||||
|
||||
# Processing timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
indexed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
is_deleted = Column(Boolean, default=False, nullable=False)
|
||||
@@ -72,10 +102,12 @@ class RagDocument(Base):
|
||||
"chunk_size": self.chunk_size,
|
||||
"metadata": self.document_metadata or {},
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
|
||||
"processed_at": self.processed_at.isoformat()
|
||||
if self.processed_at
|
||||
else None,
|
||||
"indexed_at": self.indexed_at.isoformat() if self.indexed_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"is_deleted": self.is_deleted
|
||||
"is_deleted": self.is_deleted,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
158
backend/app/models/role.py
Normal file
158
backend/app/models/role.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Role model for hierarchical permission management
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class RoleLevel(str, Enum):
|
||||
"""Role hierarchy levels"""
|
||||
|
||||
READ_ONLY = "read_only" # Level 1: Can only view
|
||||
USER = "user" # Level 2: Can create and manage own resources
|
||||
ADMIN = "admin" # Level 3: Can manage users and settings
|
||||
SUPER_ADMIN = "super_admin" # Level 4: Full system access
|
||||
|
||||
|
||||
class Role(Base):
|
||||
"""Role model with hierarchical permissions"""
|
||||
|
||||
__tablename__ = "roles"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(50), unique=True, nullable=False)
|
||||
display_name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
level = Column(String(20), nullable=False) # RoleLevel enum
|
||||
|
||||
# Permissions configuration
|
||||
permissions = Column(JSON, default=dict) # Granular permissions
|
||||
can_manage_users = Column(Boolean, default=False)
|
||||
can_manage_budgets = Column(Boolean, default=False)
|
||||
can_view_reports = Column(Boolean, default=False)
|
||||
can_manage_tools = Column(Boolean, default=False)
|
||||
|
||||
# Role hierarchy
|
||||
inherits_from = Column(JSON, default=list) # List of parent role names
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_system_role = Column(Boolean, default=False) # System roles cannot be deleted
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
users = relationship("User", back_populates="role")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Role(id={self.id}, name='{self.name}', level='{self.level}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert role to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"level": self.level,
|
||||
"permissions": self.permissions,
|
||||
"can_manage_users": self.can_manage_users,
|
||||
"can_manage_budgets": self.can_manage_budgets,
|
||||
"can_view_reports": self.can_view_reports,
|
||||
"can_manage_tools": self.can_manage_tools,
|
||||
"inherits_from": self.inherits_from,
|
||||
"is_active": self.is_active,
|
||||
"is_system_role": self.is_system_role,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def has_permission(self, permission: str) -> bool:
|
||||
"""Check if role has a specific permission"""
|
||||
if self.level == RoleLevel.SUPER_ADMIN:
|
||||
return True
|
||||
|
||||
# Check direct permissions
|
||||
if permission in self.permissions.get("granted", []):
|
||||
return True
|
||||
|
||||
# Check denied permissions
|
||||
if permission in self.permissions.get("denied", []):
|
||||
return False
|
||||
|
||||
# Check inherited permissions (simplified)
|
||||
for parent_role in self.inherits_from:
|
||||
# This would require recursive checking in a real implementation
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_default_roles(cls):
|
||||
"""Create default system roles"""
|
||||
roles = [
|
||||
cls(
|
||||
name="read_only",
|
||||
display_name="Read Only",
|
||||
description="Can view own data only",
|
||||
level=RoleLevel.READ_ONLY,
|
||||
permissions={
|
||||
"granted": ["read_own"],
|
||||
"denied": ["create", "update", "delete"],
|
||||
},
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="user",
|
||||
display_name="User",
|
||||
description="Standard user with full access to own resources",
|
||||
level=RoleLevel.USER,
|
||||
permissions={
|
||||
"granted": ["read_own", "create_own", "update_own", "delete_own"],
|
||||
"denied": ["manage_users", "manage_all"],
|
||||
},
|
||||
inherits_from=["read_only"],
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="admin",
|
||||
display_name="Administrator",
|
||||
description="Can manage users and view reports",
|
||||
level=RoleLevel.ADMIN,
|
||||
permissions={
|
||||
"granted": [
|
||||
"read_all",
|
||||
"create_all",
|
||||
"update_all",
|
||||
"manage_users",
|
||||
"view_reports",
|
||||
],
|
||||
"denied": ["system_settings"],
|
||||
},
|
||||
inherits_from=["user"],
|
||||
can_manage_users=True,
|
||||
can_manage_budgets=True,
|
||||
can_view_reports=True,
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="super_admin",
|
||||
display_name="Super Administrator",
|
||||
description="Full system access",
|
||||
level=RoleLevel.SUPER_ADMIN,
|
||||
permissions={"granted": ["*"]}, # All permissions
|
||||
inherits_from=["admin"],
|
||||
can_manage_users=True,
|
||||
can_manage_budgets=True,
|
||||
can_view_reports=True,
|
||||
can_manage_tools=True,
|
||||
is_system_role=True,
|
||||
),
|
||||
]
|
||||
return roles
|
||||
272
backend/app/models/tool.py
Normal file
272
backend/app/models/tool.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Tool model for custom tool execution
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class ToolType(str, Enum):
|
||||
"""Tool execution types"""
|
||||
|
||||
PYTHON = "python"
|
||||
BASH = "bash"
|
||||
DOCKER = "docker"
|
||||
API = "api"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ToolStatus(str, Enum):
|
||||
"""Tool execution status"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
"""Tool definition model"""
|
||||
|
||||
__tablename__ = "tools"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
display_name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Tool configuration
|
||||
tool_type = Column(String(20), nullable=False) # ToolType enum
|
||||
code = Column(Text, nullable=False) # Tool implementation code
|
||||
parameters_schema = Column(JSON, default=dict) # JSON schema for parameters
|
||||
return_schema = Column(JSON, default=dict) # Expected return format
|
||||
|
||||
# Execution settings
|
||||
timeout_seconds = Column(Integer, default=30)
|
||||
max_memory_mb = Column(Integer, default=256)
|
||||
max_cpu_seconds = Column(Float, default=10.0)
|
||||
|
||||
# Docker settings (for docker type tools)
|
||||
docker_image = Column(String(200), nullable=True)
|
||||
docker_command = Column(Text, nullable=True)
|
||||
|
||||
# Access control
|
||||
is_public = Column(Boolean, default=False) # Public tools available to all users
|
||||
is_approved = Column(Boolean, default=False) # Admin approved for security
|
||||
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Categories and tags
|
||||
category = Column(String(50), nullable=True)
|
||||
tags = Column(JSON, default=list)
|
||||
|
||||
# Usage tracking
|
||||
usage_count = Column(Integer, default=0)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
created_by = relationship("User", back_populates="created_tools")
|
||||
executions = relationship(
|
||||
"ToolExecution", back_populates="tool", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tool(id={self.id}, name='{self.name}', type='{self.tool_type}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert tool to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"tool_type": self.tool_type,
|
||||
"parameters_schema": self.parameters_schema,
|
||||
"return_schema": self.return_schema,
|
||||
"timeout_seconds": self.timeout_seconds,
|
||||
"max_memory_mb": self.max_memory_mb,
|
||||
"max_cpu_seconds": self.max_cpu_seconds,
|
||||
"docker_image": self.docker_image,
|
||||
"is_public": self.is_public,
|
||||
"is_approved": self.is_approved,
|
||||
"created_by_user_id": self.created_by_user_id,
|
||||
"category": self.category,
|
||||
"tags": self.tags,
|
||||
"usage_count": self.usage_count,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def increment_usage(self):
|
||||
"""Increment usage count and update last used timestamp"""
|
||||
self.usage_count += 1
|
||||
self.last_used_at = datetime.utcnow()
|
||||
|
||||
def can_be_used_by(self, user) -> bool:
|
||||
"""Check if user can use this tool"""
|
||||
# Tool creator can always use their tools
|
||||
if self.created_by_user_id == user.id:
|
||||
return True
|
||||
|
||||
# Public and approved tools can be used by anyone
|
||||
if self.is_public and self.is_approved:
|
||||
return True
|
||||
|
||||
# Admin users can use any tool
|
||||
if user.has_permission("manage_tools"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class ToolExecution(Base):
|
||||
"""Tool execution instance model"""
|
||||
|
||||
__tablename__ = "tool_executions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Tool and user references
|
||||
tool_id = Column(Integer, ForeignKey("tools.id"), nullable=False)
|
||||
executed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Execution details
|
||||
parameters = Column(JSON, default=dict) # Input parameters
|
||||
status = Column(String(20), nullable=False, default=ToolStatus.PENDING)
|
||||
|
||||
# Results
|
||||
output = Column(Text, nullable=True) # Tool output
|
||||
error_message = Column(Text, nullable=True) # Error details if failed
|
||||
return_code = Column(Integer, nullable=True) # Exit code
|
||||
|
||||
# Resource usage
|
||||
execution_time_ms = Column(Integer, nullable=True) # Actual execution time
|
||||
memory_used_mb = Column(Float, nullable=True) # Peak memory usage
|
||||
cpu_time_ms = Column(Integer, nullable=True) # CPU time used
|
||||
|
||||
# Docker execution details
|
||||
container_id = Column(String(100), nullable=True) # Docker container ID
|
||||
docker_logs = Column(Text, nullable=True) # Docker container logs
|
||||
|
||||
# Timestamps
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
tool = relationship("Tool", back_populates="executions")
|
||||
executed_by = relationship("User", back_populates="tool_executions")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ToolExecution(id={self.id}, tool_id={self.tool_id}, status='{self.status}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert execution to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"tool_id": self.tool_id,
|
||||
"tool_name": self.tool.name if self.tool else None,
|
||||
"executed_by_user_id": self.executed_by_user_id,
|
||||
"parameters": self.parameters,
|
||||
"status": self.status,
|
||||
"output": self.output,
|
||||
"error_message": self.error_message,
|
||||
"return_code": self.return_code,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"memory_used_mb": self.memory_used_mb,
|
||||
"cpu_time_ms": self.cpu_time_ms,
|
||||
"container_id": self.container_id,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float:
|
||||
"""Calculate execution duration in seconds"""
|
||||
if self.started_at and self.completed_at:
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
return 0.0
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if execution is currently running"""
|
||||
return self.status in [ToolStatus.PENDING, ToolStatus.RUNNING]
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
"""Check if execution is finished (success or failure)"""
|
||||
return self.status in [
|
||||
ToolStatus.COMPLETED,
|
||||
ToolStatus.FAILED,
|
||||
ToolStatus.TIMEOUT,
|
||||
ToolStatus.CANCELLED,
|
||||
]
|
||||
|
||||
|
||||
class ToolCategory(Base):
|
||||
"""Tool category for organization"""
|
||||
|
||||
__tablename__ = "tool_categories"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(50), unique=True, nullable=False)
|
||||
display_name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Visual
|
||||
icon = Column(String(50), nullable=True) # Icon name
|
||||
color = Column(String(20), nullable=True) # Color code
|
||||
|
||||
# Ordering
|
||||
sort_order = Column(Integer, default=0)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ToolCategory(id={self.id}, name='{self.name}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert category to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"icon": self.icon,
|
||||
"color": self.color,
|
||||
"sort_order": self.sort_order,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
@@ -4,7 +4,17 @@ Usage Tracking model for API key usage statistics
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey, Float
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
@@ -82,7 +92,7 @@ class UsageTracking(Base):
|
||||
"ip_address": self.ip_address,
|
||||
"user_agent": self.user_agent,
|
||||
"request_metadata": self.request_metadata,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -102,7 +112,7 @@ class UsageTracking(Base):
|
||||
session_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_metadata: Optional[dict] = None
|
||||
request_metadata: Optional[dict] = None,
|
||||
) -> "UsageTracking":
|
||||
"""Create a new usage tracking record"""
|
||||
return cls(
|
||||
@@ -121,5 +131,5 @@ class UsageTracking(Base):
|
||||
session_id=session_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_metadata=request_metadata or {}
|
||||
request_metadata=request_metadata or {},
|
||||
)
|
||||
@@ -4,16 +4,21 @@ User model
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
Numeric,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""User role enumeration"""
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
SUPER_ADMIN = "super_admin"
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class User(Base):
|
||||
@@ -27,14 +32,21 @@ class User(Base):
|
||||
hashed_password = Column(String, nullable=False)
|
||||
full_name = Column(String, nullable=True)
|
||||
|
||||
# User status and permissions
|
||||
# Account status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
is_verified = Column(Boolean, default=False)
|
||||
is_superuser = Column(Boolean, default=False) # Legacy field for compatibility
|
||||
|
||||
# Role-based access control
|
||||
role = Column(String, default=UserRole.USER.value) # user, admin, super_admin
|
||||
permissions = Column(JSON, default=dict) # Custom permissions
|
||||
# Role-based access control (using new Role model)
|
||||
role_id = Column(Integer, ForeignKey("roles.id"), nullable=True)
|
||||
custom_permissions = Column(JSON, default=dict) # Custom permissions override
|
||||
|
||||
# Account management
|
||||
account_locked = Column(Boolean, default=False)
|
||||
account_locked_until = Column(DateTime, nullable=True)
|
||||
failed_login_attempts = Column(Integer, default=0)
|
||||
last_failed_login = Column(DateTime, nullable=True)
|
||||
force_password_change = Column(Boolean, default=False)
|
||||
|
||||
# Profile information
|
||||
avatar_url = Column(String, nullable=True)
|
||||
@@ -52,27 +64,59 @@ class User(Base):
|
||||
notification_settings = Column(JSON, default=dict)
|
||||
|
||||
# Relationships
|
||||
api_keys = relationship("APIKey", back_populates="user", cascade="all, delete-orphan")
|
||||
usage_tracking = relationship("UsageTracking", back_populates="user", cascade="all, delete-orphan")
|
||||
budgets = relationship("Budget", back_populates="user", cascade="all, delete-orphan")
|
||||
audit_logs = relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
role = relationship("Role", back_populates="users")
|
||||
api_keys = relationship(
|
||||
"APIKey", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
usage_tracking = relationship(
|
||||
"UsageTracking", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
budgets = relationship(
|
||||
"Budget", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
audit_logs = relationship(
|
||||
"AuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
installed_plugins = relationship("Plugin", back_populates="installed_by_user")
|
||||
created_tools = relationship(
|
||||
"Tool", back_populates="created_by", cascade="all, delete-orphan"
|
||||
)
|
||||
tool_executions = relationship(
|
||||
"ToolExecution", back_populates="executed_by", cascade="all, delete-orphan"
|
||||
)
|
||||
notifications = relationship(
|
||||
"Notification", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, email='{self.email}', username='{self.username}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert user to dictionary for API responses"""
|
||||
# Check if role relationship is loaded to avoid lazy loading in async context
|
||||
inspector = sa_inspect(self)
|
||||
role_loaded = "role" not in inspector.unloaded
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"email": self.email,
|
||||
"username": self.username,
|
||||
"full_name": self.full_name,
|
||||
"is_active": self.is_active,
|
||||
"is_superuser": self.is_superuser,
|
||||
"is_verified": self.is_verified,
|
||||
"role": self.role,
|
||||
"permissions": self.permissions,
|
||||
"is_superuser": self.is_superuser,
|
||||
"role_id": self.role_id,
|
||||
"role": self.role.to_dict() if role_loaded and self.role else None,
|
||||
"custom_permissions": self.custom_permissions,
|
||||
"account_locked": self.account_locked,
|
||||
"account_locked_until": self.account_locked_until.isoformat()
|
||||
if self.account_locked_until
|
||||
else None,
|
||||
"failed_login_attempts": self.failed_login_attempts,
|
||||
"last_failed_login": self.last_failed_login.isoformat()
|
||||
if self.last_failed_login
|
||||
else None,
|
||||
"force_password_change": self.force_password_change,
|
||||
"avatar_url": self.avatar_url,
|
||||
"bio": self.bio,
|
||||
"company": self.company,
|
||||
@@ -81,35 +125,49 @@ class User(Base):
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||
"preferences": self.preferences,
|
||||
"notification_settings": self.notification_settings
|
||||
"notification_settings": self.notification_settings,
|
||||
}
|
||||
|
||||
def has_permission(self, permission: str) -> bool:
|
||||
"""Check if user has a specific permission"""
|
||||
"""Check if user has a specific permission using role hierarchy"""
|
||||
if self.is_superuser:
|
||||
return True
|
||||
|
||||
# Check role-based permissions
|
||||
role_permissions = {
|
||||
"user": ["read_own", "create_own", "update_own"],
|
||||
"admin": ["read_all", "create_all", "update_all", "delete_own"],
|
||||
"super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
|
||||
}
|
||||
|
||||
if self.role in role_permissions and permission in role_permissions[self.role]:
|
||||
# Check custom permissions first (override)
|
||||
if permission in self.custom_permissions.get("denied", []):
|
||||
return False
|
||||
if permission in self.custom_permissions.get("granted", []):
|
||||
return True
|
||||
|
||||
# Check custom permissions
|
||||
return permission in self.permissions
|
||||
# Check role permissions if user has a role assigned
|
||||
if self.role:
|
||||
return self.role.has_permission(permission)
|
||||
|
||||
return False
|
||||
|
||||
def can_access_module(self, module_name: str) -> bool:
|
||||
"""Check if user can access a specific module"""
|
||||
if self.is_superuser:
|
||||
return True
|
||||
|
||||
# Check module-specific permissions
|
||||
module_permissions = self.permissions.get("modules", {})
|
||||
return module_permissions.get(module_name, False)
|
||||
# Check custom permissions first
|
||||
module_permissions = self.custom_permissions.get("modules", {})
|
||||
if module_name in module_permissions:
|
||||
return module_permissions[module_name]
|
||||
|
||||
# Check role permissions
|
||||
if self.role:
|
||||
# For admin roles, allow all modules
|
||||
if self.role.level in ["admin", "super_admin"]:
|
||||
return True
|
||||
# For regular users, check module access
|
||||
elif self.role.level == "user":
|
||||
return True # Basic users can access all modules
|
||||
# For read-only users, limit access
|
||||
elif self.role.level == "read_only":
|
||||
return module_name in ["chatbot", "analytics"] # Only certain modules
|
||||
|
||||
return False
|
||||
|
||||
def update_last_login(self):
|
||||
"""Update the last login timestamp"""
|
||||
@@ -127,8 +185,97 @@ class User(Base):
|
||||
self.notification_settings = {}
|
||||
self.notification_settings.update(settings)
|
||||
|
||||
def get_effective_permissions(self) -> dict:
|
||||
"""Get all effective permissions combining role and custom permissions"""
|
||||
permissions = {"granted": set(), "denied": set()}
|
||||
|
||||
# Start with role permissions
|
||||
if self.role:
|
||||
role_perms = self.role.permissions
|
||||
permissions["granted"].update(role_perms.get("granted", []))
|
||||
permissions["denied"].update(role_perms.get("denied", []))
|
||||
|
||||
# Apply custom permissions (override role permissions)
|
||||
permissions["granted"].update(self.custom_permissions.get("granted", []))
|
||||
permissions["denied"].update(self.custom_permissions.get("denied", []))
|
||||
|
||||
# Remove any denied permissions from granted
|
||||
permissions["granted"] -= permissions["denied"]
|
||||
|
||||
return {
|
||||
"granted": list(permissions["granted"]),
|
||||
"denied": list(permissions["denied"]),
|
||||
}
|
||||
|
||||
def can_create_api_key(self) -> bool:
|
||||
"""Check if user can create API keys based on role and limits"""
|
||||
if not self.is_active or self.account_locked:
|
||||
return False
|
||||
|
||||
# Check permission
|
||||
if not self.has_permission("create_api_key"):
|
||||
return False
|
||||
|
||||
# Check if user has reached their API key limit
|
||||
current_keys = [key for key in self.api_keys if key.is_active]
|
||||
max_keys = (
|
||||
self.role.permissions.get("limits", {}).get("max_api_keys", 5)
|
||||
if self.role
|
||||
else 5
|
||||
)
|
||||
|
||||
return len(current_keys) < max_keys
|
||||
|
||||
def can_create_tool(self) -> bool:
|
||||
"""Check if user can create custom tools"""
|
||||
return (
|
||||
self.is_active
|
||||
and not self.account_locked
|
||||
and self.has_permission("create_tool")
|
||||
)
|
||||
|
||||
def is_budget_exceeded(self) -> bool:
|
||||
"""Check if user has exceeded their budget limits"""
|
||||
if not self.budgets:
|
||||
return False
|
||||
|
||||
active_budget = next((b for b in self.budgets if b.is_active), None)
|
||||
if not active_budget:
|
||||
return False
|
||||
|
||||
return active_budget.current_usage > active_budget.limit
|
||||
|
||||
def lock_account(self, duration_hours: int = 24):
|
||||
"""Lock user account for specified duration"""
|
||||
from datetime import timedelta
|
||||
|
||||
self.account_locked = True
|
||||
self.account_locked_until = datetime.utcnow() + timedelta(hours=duration_hours)
|
||||
|
||||
def unlock_account(self):
|
||||
"""Unlock user account"""
|
||||
self.account_locked = False
|
||||
self.account_locked_until = None
|
||||
self.failed_login_attempts = 0
|
||||
|
||||
def record_failed_login(self):
|
||||
"""Record a failed login attempt"""
|
||||
self.failed_login_attempts += 1
|
||||
self.last_failed_login = datetime.utcnow()
|
||||
|
||||
# Lock account after 5 failed attempts
|
||||
if self.failed_login_attempts >= 5:
|
||||
self.lock_account(24) # Lock for 24 hours
|
||||
|
||||
def reset_failed_logins(self):
|
||||
"""Reset failed login counter"""
|
||||
self.failed_login_attempts = 0
|
||||
self.last_failed_login = None
|
||||
|
||||
@classmethod
|
||||
def create_default_admin(cls, email: str, username: str, password_hash: str) -> "User":
|
||||
def create_default_admin(
|
||||
cls, email: str, username: str, password_hash: str
|
||||
) -> "User":
|
||||
"""Create a default admin user"""
|
||||
return cls(
|
||||
email=email,
|
||||
@@ -136,24 +283,16 @@ class User(Base):
|
||||
hashed_password=password_hash,
|
||||
full_name="System Administrator",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
is_superuser=True, # Legacy compatibility
|
||||
is_verified=True,
|
||||
role="super_admin",
|
||||
permissions={
|
||||
"modules": {
|
||||
"cache": True,
|
||||
"analytics": True,
|
||||
"rag": True
|
||||
}
|
||||
},
|
||||
preferences={
|
||||
"theme": "dark",
|
||||
"language": "en",
|
||||
"timezone": "UTC"
|
||||
# Note: role_id will be set after role is created in init_db
|
||||
custom_permissions={
|
||||
"modules": {"cache": True, "analytics": True, "rag": True}
|
||||
},
|
||||
preferences={"theme": "dark", "language": "en", "timezone": "UTC"},
|
||||
notification_settings={
|
||||
"email_notifications": True,
|
||||
"security_alerts": True,
|
||||
"system_updates": True
|
||||
}
|
||||
"system_updates": True,
|
||||
},
|
||||
)
|
||||
@@ -15,7 +15,4 @@ __version__ = "1.0.0"
|
||||
__author__ = "Enclava Team"
|
||||
|
||||
# Export main classes for easy importing
|
||||
__all__ = [
|
||||
"ChatbotModule",
|
||||
"create_module"
|
||||
]
|
||||
__all__ = ["ChatbotModule", "create_module"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ from .protocols import (
|
||||
ChatbotServiceProtocol,
|
||||
LiteLLMClientProtocol,
|
||||
WorkflowServiceProtocol,
|
||||
ServiceRegistry
|
||||
ServiceRegistry,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,7 +36,9 @@ class ModuleFactory:
|
||||
self.modules: Dict[str, Any] = {}
|
||||
self.initialized = False
|
||||
|
||||
async def create_all_modules(self, config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
|
||||
async def create_all_modules(
|
||||
self, config: Optional[Dict[str, Any]] = None
|
||||
) -> ServiceRegistry:
|
||||
"""
|
||||
Create all modules with proper dependency injection
|
||||
|
||||
@@ -59,7 +61,7 @@ class ModuleFactory:
|
||||
# Step 3: Create chatbot module with RAG dependency
|
||||
chatbot_module = create_chatbot_module(
|
||||
litellm_client=litellm_client,
|
||||
rag_service=rag_module # RAG module implements RAGServiceProtocol
|
||||
rag_service=rag_module, # RAG module implements RAGServiceProtocol
|
||||
)
|
||||
|
||||
# Step 4: Create workflow module with chatbot dependency
|
||||
@@ -71,7 +73,7 @@ class ModuleFactory:
|
||||
modules = {
|
||||
"rag": rag_module,
|
||||
"chatbot": chatbot_module,
|
||||
"workflow": workflow_module
|
||||
"workflow": workflow_module,
|
||||
}
|
||||
|
||||
logger.info(f"Created {len(modules)} modules with dependencies wired")
|
||||
@@ -84,14 +86,16 @@ class ModuleFactory:
|
||||
|
||||
return modules
|
||||
|
||||
async def _initialize_modules(self, modules: Dict[str, Any], config: Dict[str, Any]):
|
||||
async def _initialize_modules(
|
||||
self, modules: Dict[str, Any], config: Dict[str, Any]
|
||||
):
|
||||
"""Initialize all modules in dependency order"""
|
||||
|
||||
# Initialize in dependency order (modules with no deps first)
|
||||
initialization_order = [
|
||||
("rag", modules["rag"]),
|
||||
("chatbot", modules["chatbot"]), # Depends on RAG
|
||||
("workflow", modules["workflow"]) # Depends on Chatbot
|
||||
("workflow", modules["workflow"]), # Depends on Chatbot
|
||||
]
|
||||
|
||||
for module_name, module in initialization_order:
|
||||
@@ -100,7 +104,7 @@ class ModuleFactory:
|
||||
module_config = config.get(module_name, {})
|
||||
|
||||
# Different modules have different initialization patterns
|
||||
if hasattr(module, 'initialize'):
|
||||
if hasattr(module, "initialize"):
|
||||
if module_name == "rag":
|
||||
await module.initialize()
|
||||
else:
|
||||
@@ -110,7 +114,9 @@ class ModuleFactory:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to initialize {module_name} module: {e}")
|
||||
raise RuntimeError(f"Module initialization failed: {module_name}") from e
|
||||
raise RuntimeError(
|
||||
f"Module initialization failed: {module_name}"
|
||||
) from e
|
||||
|
||||
async def cleanup_all_modules(self):
|
||||
"""Cleanup all modules in reverse dependency order"""
|
||||
@@ -125,7 +131,7 @@ class ModuleFactory:
|
||||
try:
|
||||
logger.info(f"Cleaning up {module_name} module...")
|
||||
module = self.modules[module_name]
|
||||
if hasattr(module, 'cleanup'):
|
||||
if hasattr(module, "cleanup"):
|
||||
await module.cleanup()
|
||||
logger.info(f"✅ {module_name} module cleaned up")
|
||||
except Exception as e:
|
||||
@@ -174,13 +180,16 @@ def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
|
||||
return RAGModule(config=config or {})
|
||||
|
||||
|
||||
def create_chatbot_with_rag(rag_service: RAGServiceProtocol,
|
||||
litellm_client: LiteLLMClientProtocol) -> ChatbotModule:
|
||||
def create_chatbot_with_rag(
|
||||
rag_service: RAGServiceProtocol, litellm_client: LiteLLMClientProtocol
|
||||
) -> ChatbotModule:
|
||||
"""Create chatbot module with RAG dependency"""
|
||||
return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)
|
||||
|
||||
|
||||
def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> WorkflowModule:
|
||||
def create_workflow_with_chatbot(
|
||||
chatbot_service: ChatbotServiceProtocol,
|
||||
) -> WorkflowModule:
|
||||
"""Create workflow module with chatbot dependency"""
|
||||
return WorkflowModule(chatbot_service=chatbot_service)
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ class RAGServiceProtocol(Protocol):
|
||||
"""Protocol for RAG (Retrieval-Augmented Generation) service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
|
||||
async def search(
|
||||
self, query: str, collection_name: str, top_k: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Search for relevant documents
|
||||
|
||||
@@ -29,7 +31,9 @@ class RAGServiceProtocol(Protocol):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def index_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
|
||||
async def index_document(
|
||||
self, content: str, metadata: Dict[str, Any] = None
|
||||
) -> str:
|
||||
"""
|
||||
Index a document in the vector database
|
||||
|
||||
@@ -94,7 +98,9 @@ class LiteLLMClientProtocol(Protocol):
|
||||
"""Protocol for LiteLLM client interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
async def completion(
|
||||
self, model: str, messages: List[Dict[str, str]], **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Create a completion using the specified model
|
||||
|
||||
@@ -109,8 +115,14 @@ class LiteLLMClientProtocol(Protocol):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def create_chat_completion(self, model: str, messages: List[Dict[str, str]],
|
||||
user_id: str, api_key_id: str, **kwargs) -> Any:
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
user_id: str,
|
||||
api_key_id: str,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Create a chat completion with user tracking
|
||||
|
||||
@@ -207,7 +219,9 @@ class WorkflowServiceProtocol(Protocol):
|
||||
"""Protocol for Workflow service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute_workflow(self, workflow: Any, input_data: Dict[str, Any] = None) -> Any:
|
||||
async def execute_workflow(
|
||||
self, workflow: Any, input_data: Dict[str, Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a workflow definition
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,77 +13,102 @@ from pathlib import Path
|
||||
|
||||
class PluginRuntimeSpec(BaseModel):
|
||||
"""Plugin runtime requirements and dependencies"""
|
||||
python_version: str = Field("3.11", description="Required Python version")
|
||||
dependencies: List[str] = Field(default_factory=list, description="Required Python packages")
|
||||
environment_variables: Dict[str, str] = Field(default_factory=dict, description="Required environment variables")
|
||||
|
||||
@validator('python_version')
|
||||
python_version: str = Field("3.11", description="Required Python version")
|
||||
dependencies: List[str] = Field(
|
||||
default_factory=list, description="Required Python packages"
|
||||
)
|
||||
environment_variables: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Required environment variables"
|
||||
)
|
||||
|
||||
@validator("python_version")
|
||||
def validate_python_version(cls, v):
|
||||
if not v.startswith(('3.9', '3.10', '3.11', '3.12')):
|
||||
raise ValueError('Python version must be 3.9, 3.10, 3.11, or 3.12')
|
||||
if not v.startswith(("3.9", "3.10", "3.11", "3.12")):
|
||||
raise ValueError("Python version must be 3.9, 3.10, 3.11, or 3.12")
|
||||
return v
|
||||
|
||||
|
||||
class PluginPermissions(BaseModel):
|
||||
"""Plugin permission specifications"""
|
||||
platform_apis: List[str] = Field(default_factory=list, description="Platform API access scopes")
|
||||
plugin_scopes: List[str] = Field(default_factory=list, description="Plugin-specific permission scopes")
|
||||
external_domains: List[str] = Field(default_factory=list, description="Allowed external domains")
|
||||
|
||||
@validator('platform_apis')
|
||||
platform_apis: List[str] = Field(
|
||||
default_factory=list, description="Platform API access scopes"
|
||||
)
|
||||
plugin_scopes: List[str] = Field(
|
||||
default_factory=list, description="Plugin-specific permission scopes"
|
||||
)
|
||||
external_domains: List[str] = Field(
|
||||
default_factory=list, description="Allowed external domains"
|
||||
)
|
||||
|
||||
@validator("platform_apis")
|
||||
def validate_platform_apis(cls, v):
|
||||
allowed_apis = [
|
||||
'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
|
||||
'rag:query', 'rag:manage', 'rag:read',
|
||||
'llm:completion', 'llm:embeddings', 'llm:models',
|
||||
'workflow:execute', 'workflow:read',
|
||||
'cache:read', 'cache:write'
|
||||
"chatbot:invoke",
|
||||
"chatbot:manage",
|
||||
"chatbot:read",
|
||||
"rag:query",
|
||||
"rag:manage",
|
||||
"rag:read",
|
||||
"llm:completion",
|
||||
"llm:embeddings",
|
||||
"llm:models",
|
||||
"workflow:execute",
|
||||
"workflow:read",
|
||||
"cache:read",
|
||||
"cache:write",
|
||||
]
|
||||
for api in v:
|
||||
if api not in allowed_apis and not api.endswith(':*'):
|
||||
raise ValueError(f'Invalid platform API scope: {api}')
|
||||
if api not in allowed_apis and not api.endswith(":*"):
|
||||
raise ValueError(f"Invalid platform API scope: {api}")
|
||||
return v
|
||||
|
||||
|
||||
class PluginDatabaseSpec(BaseModel):
|
||||
"""Plugin database configuration"""
|
||||
|
||||
schema: str = Field(..., description="Database schema name")
|
||||
migrations_path: str = Field("./migrations", description="Path to migration files")
|
||||
auto_migrate: bool = Field(True, description="Auto-run migrations on startup")
|
||||
|
||||
@validator('schema')
|
||||
@validator("schema")
|
||||
def validate_schema_name(cls, v):
|
||||
if not v.startswith('plugin_'):
|
||||
if not v.startswith("plugin_"):
|
||||
raise ValueError('Database schema must start with "plugin_"')
|
||||
if not v.replace('plugin_', '').replace('_', '').isalnum():
|
||||
raise ValueError('Schema name must contain only alphanumeric characters and underscores')
|
||||
if not v.replace("plugin_", "").replace("_", "").isalnum():
|
||||
raise ValueError(
|
||||
"Schema name must contain only alphanumeric characters and underscores"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class PluginAPIEndpoint(BaseModel):
|
||||
"""Plugin API endpoint specification"""
|
||||
|
||||
path: str = Field(..., description="API endpoint path")
|
||||
methods: List[str] = Field(default=['GET'], description="Allowed HTTP methods")
|
||||
methods: List[str] = Field(default=["GET"], description="Allowed HTTP methods")
|
||||
description: str = Field("", description="Endpoint description")
|
||||
auth_required: bool = Field(True, description="Whether authentication is required")
|
||||
|
||||
@validator('methods')
|
||||
@validator("methods")
|
||||
def validate_methods(cls, v):
|
||||
allowed_methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS']
|
||||
allowed_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
|
||||
for method in v:
|
||||
if method not in allowed_methods:
|
||||
raise ValueError(f'Invalid HTTP method: {method}')
|
||||
raise ValueError(f"Invalid HTTP method: {method}")
|
||||
return v
|
||||
|
||||
@validator('path')
|
||||
@validator("path")
|
||||
def validate_path(cls, v):
|
||||
if not v.startswith('/'):
|
||||
if not v.startswith("/"):
|
||||
raise ValueError('API path must start with "/"')
|
||||
return v
|
||||
|
||||
|
||||
class PluginCronJob(BaseModel):
|
||||
"""Plugin scheduled job specification"""
|
||||
|
||||
name: str = Field(..., description="Job name")
|
||||
schedule: str = Field(..., description="Cron expression")
|
||||
function: str = Field(..., description="Function to execute")
|
||||
@@ -92,40 +117,55 @@ class PluginCronJob(BaseModel):
|
||||
timeout_seconds: int = Field(300, description="Job timeout in seconds")
|
||||
max_retries: int = Field(3, description="Maximum retry attempts")
|
||||
|
||||
@validator('schedule')
|
||||
@validator("schedule")
|
||||
def validate_cron_expression(cls, v):
|
||||
# Basic cron validation - should have 5 parts
|
||||
parts = v.split()
|
||||
if len(parts) != 5:
|
||||
raise ValueError('Cron expression must have 5 parts (minute hour day month weekday)')
|
||||
raise ValueError(
|
||||
"Cron expression must have 5 parts (minute hour day month weekday)"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class PluginUIConfig(BaseModel):
|
||||
"""Plugin UI configuration"""
|
||||
configuration_schema: str = Field("./config_schema.json", description="JSON schema for configuration")
|
||||
ui_components: str = Field("./ui/components", description="Path to UI components")
|
||||
pages: List[Dict[str, str]] = Field(default_factory=list, description="Plugin pages")
|
||||
|
||||
@validator('pages')
|
||||
configuration_schema: str = Field(
|
||||
"./config_schema.json", description="JSON schema for configuration"
|
||||
)
|
||||
ui_components: str = Field("./ui/components", description="Path to UI components")
|
||||
pages: List[Dict[str, str]] = Field(
|
||||
default_factory=list, description="Plugin pages"
|
||||
)
|
||||
|
||||
@validator("pages")
|
||||
def validate_pages(cls, v):
|
||||
required_fields = ['name', 'path', 'component']
|
||||
required_fields = ["name", "path", "component"]
|
||||
for page in v:
|
||||
for field in required_fields:
|
||||
if field not in page:
|
||||
raise ValueError(f'Page must have {field} field')
|
||||
raise ValueError(f"Page must have {field} field")
|
||||
return v
|
||||
|
||||
|
||||
class PluginExternalServices(BaseModel):
|
||||
"""Plugin external service configuration"""
|
||||
allowed_domains: List[str] = Field(default_factory=list, description="Allowed external domains")
|
||||
webhooks: List[Dict[str, str]] = Field(default_factory=list, description="Webhook configurations")
|
||||
rate_limits: Dict[str, int] = Field(default_factory=dict, description="Rate limits per domain")
|
||||
|
||||
allowed_domains: List[str] = Field(
|
||||
default_factory=list, description="Allowed external domains"
|
||||
)
|
||||
webhooks: List[Dict[str, str]] = Field(
|
||||
default_factory=list, description="Webhook configurations"
|
||||
)
|
||||
rate_limits: Dict[str, int] = Field(
|
||||
default_factory=dict, description="Rate limits per domain"
|
||||
)
|
||||
|
||||
|
||||
class PluginMetadata(BaseModel):
|
||||
"""Plugin metadata information"""
|
||||
|
||||
name: str = Field(..., description="Plugin name (must be unique)")
|
||||
version: str = Field(..., description="Plugin version (semantic versioning)")
|
||||
description: str = Field(..., description="Plugin description")
|
||||
@@ -133,58 +173,78 @@ class PluginMetadata(BaseModel):
|
||||
license: str = Field("MIT", description="Plugin license")
|
||||
homepage: Optional[HttpUrl] = Field(None, description="Plugin homepage URL")
|
||||
repository: Optional[HttpUrl] = Field(None, description="Plugin repository URL")
|
||||
tags: List[str] = Field(default_factory=list, description="Plugin tags for discovery")
|
||||
tags: List[str] = Field(
|
||||
default_factory=list, description="Plugin tags for discovery"
|
||||
)
|
||||
|
||||
@validator('name')
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
if not v.replace('-', '').replace('_', '').isalnum():
|
||||
raise ValueError('Plugin name must contain only alphanumeric characters, hyphens, and underscores')
|
||||
if not v.replace("-", "").replace("_", "").isalnum():
|
||||
raise ValueError(
|
||||
"Plugin name must contain only alphanumeric characters, hyphens, and underscores"
|
||||
)
|
||||
if len(v) < 3 or len(v) > 50:
|
||||
raise ValueError('Plugin name must be between 3 and 50 characters')
|
||||
raise ValueError("Plugin name must be between 3 and 50 characters")
|
||||
return v.lower()
|
||||
|
||||
@validator('version')
|
||||
@validator("version")
|
||||
def validate_version(cls, v):
|
||||
# Basic semantic versioning validation
|
||||
parts = v.split('.')
|
||||
parts = v.split(".")
|
||||
if len(parts) != 3:
|
||||
raise ValueError('Version must follow semantic versioning (x.y.z)')
|
||||
raise ValueError("Version must follow semantic versioning (x.y.z)")
|
||||
for part in parts:
|
||||
if not part.isdigit():
|
||||
raise ValueError('Version parts must be numeric')
|
||||
raise ValueError("Version parts must be numeric")
|
||||
return v
|
||||
|
||||
|
||||
class PluginManifest(BaseModel):
|
||||
"""Complete plugin manifest specification"""
|
||||
|
||||
apiVersion: str = Field("v1", description="Manifest API version")
|
||||
kind: str = Field("Plugin", description="Resource kind")
|
||||
metadata: PluginMetadata = Field(..., description="Plugin metadata")
|
||||
spec: "PluginSpec" = Field(..., description="Plugin specification")
|
||||
|
||||
@validator('apiVersion')
|
||||
@validator("apiVersion")
|
||||
def validate_api_version(cls, v):
|
||||
if v not in ['v1']:
|
||||
raise ValueError('Unsupported API version')
|
||||
if v not in ["v1"]:
|
||||
raise ValueError("Unsupported API version")
|
||||
return v
|
||||
|
||||
@validator('kind')
|
||||
@validator("kind")
|
||||
def validate_kind(cls, v):
|
||||
if v != 'Plugin':
|
||||
if v != "Plugin":
|
||||
raise ValueError('Kind must be "Plugin"')
|
||||
return v
|
||||
|
||||
|
||||
class PluginSpec(BaseModel):
|
||||
"""Plugin specification details"""
|
||||
runtime: PluginRuntimeSpec = Field(default_factory=PluginRuntimeSpec, description="Runtime requirements")
|
||||
permissions: PluginPermissions = Field(default_factory=PluginPermissions, description="Permission requirements")
|
||||
database: Optional[PluginDatabaseSpec] = Field(None, description="Database configuration")
|
||||
api_endpoints: List[PluginAPIEndpoint] = Field(default_factory=list, description="API endpoints")
|
||||
cron_jobs: List[PluginCronJob] = Field(default_factory=list, description="Scheduled jobs")
|
||||
|
||||
runtime: PluginRuntimeSpec = Field(
|
||||
default_factory=PluginRuntimeSpec, description="Runtime requirements"
|
||||
)
|
||||
permissions: PluginPermissions = Field(
|
||||
default_factory=PluginPermissions, description="Permission requirements"
|
||||
)
|
||||
database: Optional[PluginDatabaseSpec] = Field(
|
||||
None, description="Database configuration"
|
||||
)
|
||||
api_endpoints: List[PluginAPIEndpoint] = Field(
|
||||
default_factory=list, description="API endpoints"
|
||||
)
|
||||
cron_jobs: List[PluginCronJob] = Field(
|
||||
default_factory=list, description="Scheduled jobs"
|
||||
)
|
||||
ui_config: Optional[PluginUIConfig] = Field(None, description="UI configuration")
|
||||
external_services: Optional[PluginExternalServices] = Field(None, description="External service configuration")
|
||||
config_schema: Dict[str, Any] = Field(default_factory=dict, description="Plugin configuration JSON schema")
|
||||
external_services: Optional[PluginExternalServices] = Field(
|
||||
None, description="External service configuration"
|
||||
)
|
||||
config_schema: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Plugin configuration JSON schema"
|
||||
)
|
||||
|
||||
|
||||
# Update forward reference
|
||||
@@ -194,18 +254,14 @@ PluginManifest.model_rebuild()
|
||||
class PluginManifestValidator:
|
||||
"""Plugin manifest validation and parsing utilities"""
|
||||
|
||||
REQUIRED_FILES = [
|
||||
'manifest.yaml',
|
||||
'main.py',
|
||||
'requirements.txt'
|
||||
]
|
||||
REQUIRED_FILES = ["manifest.yaml", "main.py", "requirements.txt"]
|
||||
|
||||
OPTIONAL_FILES = [
|
||||
'config_schema.json',
|
||||
'README.md',
|
||||
'ui/components',
|
||||
'migrations',
|
||||
'tests'
|
||||
"config_schema.json",
|
||||
"README.md",
|
||||
"ui/components",
|
||||
"migrations",
|
||||
"tests",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -217,7 +273,7 @@ class PluginManifestValidator:
|
||||
raise FileNotFoundError(f"Manifest file not found: {manifest_path}")
|
||||
|
||||
try:
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest_data = yaml.safe_load(f)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML in manifest file: {e}")
|
||||
@@ -243,19 +299,19 @@ class PluginManifestValidator:
|
||||
raise FileNotFoundError(f"Required file missing: {required_file}")
|
||||
|
||||
# Validate main.py contains plugin class
|
||||
main_py_path = plugin_dir / 'main.py'
|
||||
with open(main_py_path, 'r', encoding='utf-8') as f:
|
||||
main_py_path = plugin_dir / "main.py"
|
||||
with open(main_py_path, "r", encoding="utf-8") as f:
|
||||
main_content = f.read()
|
||||
|
||||
if 'BasePlugin' not in main_content:
|
||||
if "BasePlugin" not in main_content:
|
||||
raise ValueError("main.py must contain a class inheriting from BasePlugin")
|
||||
|
||||
# Validate requirements.txt format
|
||||
requirements_path = plugin_dir / 'requirements.txt'
|
||||
with open(requirements_path, 'r', encoding='utf-8') as f:
|
||||
requirements_path = plugin_dir / "requirements.txt"
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
requirements = f.read().strip()
|
||||
|
||||
if requirements and not all(line.strip() for line in requirements.split('\n')):
|
||||
if requirements and not all(line.strip() for line in requirements.split("\n")):
|
||||
raise ValueError("Invalid requirements.txt format")
|
||||
|
||||
# Validate config schema if specified
|
||||
@@ -264,7 +320,8 @@ class PluginManifestValidator:
|
||||
if schema_path.exists():
|
||||
try:
|
||||
import json
|
||||
with open(schema_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(schema_path, "r", encoding="utf-8") as f:
|
||||
json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e}")
|
||||
@@ -283,7 +340,7 @@ class PluginManifestValidator:
|
||||
"compatible": True,
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
"platform_version": "1.0.0"
|
||||
"platform_version": "1.0.0",
|
||||
}
|
||||
|
||||
# Check platform API compatibility
|
||||
@@ -319,38 +376,55 @@ class PluginManifestValidator:
|
||||
def _is_platform_api_supported(cls, api: str) -> bool:
|
||||
"""Check if platform API is supported"""
|
||||
supported_apis = [
|
||||
'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
|
||||
'rag:query', 'rag:manage', 'rag:read',
|
||||
'llm:completion', 'llm:embeddings', 'llm:models',
|
||||
'workflow:execute', 'workflow:read',
|
||||
'cache:read', 'cache:write'
|
||||
"chatbot:invoke",
|
||||
"chatbot:manage",
|
||||
"chatbot:read",
|
||||
"rag:query",
|
||||
"rag:manage",
|
||||
"rag:read",
|
||||
"llm:completion",
|
||||
"llm:embeddings",
|
||||
"llm:models",
|
||||
"workflow:execute",
|
||||
"workflow:read",
|
||||
"cache:read",
|
||||
"cache:write",
|
||||
]
|
||||
|
||||
# Support wildcard permissions
|
||||
if api.endswith(':*'):
|
||||
if api.endswith(":*"):
|
||||
base_api = api[:-2]
|
||||
return any(supported.startswith(base_api + ':') for supported in supported_apis)
|
||||
return any(
|
||||
supported.startswith(base_api + ":") for supported in supported_apis
|
||||
)
|
||||
|
||||
return api in supported_apis
|
||||
|
||||
@classmethod
|
||||
def _is_python_version_supported(cls, version: str) -> bool:
|
||||
"""Check if Python version is supported"""
|
||||
supported_versions = ['3.9', '3.10', '3.11', '3.12']
|
||||
supported_versions = ["3.9", "3.10", "3.11", "3.12"]
|
||||
return any(version.startswith(v) for v in supported_versions)
|
||||
|
||||
@classmethod
|
||||
def _is_dependency_conflicting(cls, dependency: str) -> bool:
|
||||
"""Check if dependency might conflict with platform"""
|
||||
# Extract package name (before ==, >=, etc.)
|
||||
package_name = dependency.split('==')[0].split('>=')[0].split('<=')[0].split('>')[0].split('<')[0].strip()
|
||||
package_name = (
|
||||
dependency.split("==")[0]
|
||||
.split(">=")[0]
|
||||
.split("<=")[0]
|
||||
.split(">")[0]
|
||||
.split("<")[0]
|
||||
.strip()
|
||||
)
|
||||
|
||||
# Known conflicting packages
|
||||
conflicting_packages = [
|
||||
'sqlalchemy', # Platform uses specific version
|
||||
'fastapi', # Platform uses specific version
|
||||
'pydantic', # Platform uses specific version
|
||||
'alembic' # Platform migration system
|
||||
"sqlalchemy", # Platform uses specific version
|
||||
"fastapi", # Platform uses specific version
|
||||
"pydantic", # Platform uses specific version
|
||||
"alembic", # Platform migration system
|
||||
]
|
||||
|
||||
return package_name.lower() in conflicting_packages
|
||||
@@ -359,8 +433,10 @@ class PluginManifestValidator:
|
||||
def generate_manifest_hash(cls, manifest: PluginManifest) -> str:
|
||||
"""Generate hash for manifest content verification"""
|
||||
manifest_dict = manifest.dict()
|
||||
manifest_str = yaml.dump(manifest_dict, sort_keys=True, default_flow_style=False)
|
||||
return hashlib.sha256(manifest_str.encode('utf-8')).hexdigest()
|
||||
manifest_str = yaml.dump(
|
||||
manifest_dict, sort_keys=True, default_flow_style=False
|
||||
)
|
||||
return hashlib.sha256(manifest_str.encode("utf-8")).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def create_example_manifest(cls, plugin_name: str) -> PluginManifest:
|
||||
@@ -372,29 +448,25 @@ class PluginManifestValidator:
|
||||
description=f"Example {plugin_name} plugin for Enclava platform",
|
||||
author="Enclava Team",
|
||||
license="MIT",
|
||||
tags=["integration", "example"]
|
||||
tags=["integration", "example"],
|
||||
),
|
||||
spec=PluginSpec(
|
||||
runtime=PluginRuntimeSpec(
|
||||
python_version="3.11",
|
||||
dependencies=[
|
||||
"aiohttp>=3.8.0",
|
||||
"pydantic>=2.0.0"
|
||||
]
|
||||
dependencies=["aiohttp>=3.8.0", "pydantic>=2.0.0"],
|
||||
),
|
||||
permissions=PluginPermissions(
|
||||
platform_apis=["chatbot:invoke", "rag:query"],
|
||||
plugin_scopes=["read", "write"]
|
||||
plugin_scopes=["read", "write"],
|
||||
),
|
||||
database=PluginDatabaseSpec(
|
||||
schema=f"plugin_{plugin_name}",
|
||||
migrations_path="./migrations"
|
||||
schema=f"plugin_{plugin_name}", migrations_path="./migrations"
|
||||
),
|
||||
api_endpoints=[
|
||||
PluginAPIEndpoint(
|
||||
path="/status",
|
||||
methods=["GET"],
|
||||
description="Plugin health status"
|
||||
description="Plugin health status",
|
||||
)
|
||||
],
|
||||
ui_config=PluginUIConfig(
|
||||
@@ -403,11 +475,11 @@ class PluginManifestValidator:
|
||||
{
|
||||
"name": "dashboard",
|
||||
"path": f"/plugins/{plugin_name}",
|
||||
"component": f"{plugin_name.title()}Dashboard"
|
||||
"component": f"{plugin_name.title()}Dashboard",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -423,7 +495,7 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
"manifest": manifest,
|
||||
"compatibility": compatibility,
|
||||
"hash": manifest_hash,
|
||||
"errors": []
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -432,5 +504,5 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
"manifest": None,
|
||||
"compatibility": None,
|
||||
"hash": None,
|
||||
"errors": [str(e)]
|
||||
"errors": [str(e)],
|
||||
}
|
||||
367
backend/app/schemas/role.py
Normal file
367
backend/app/schemas/role.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Role Management Schemas
|
||||
Pydantic models for role management API
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class RoleBase(BaseModel):
|
||||
"""Base role schema"""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
level: str = "user"
|
||||
permissions: Dict[str, Any] = {}
|
||||
can_manage_users: bool = False
|
||||
can_manage_budgets: bool = False
|
||||
can_view_reports: bool = False
|
||||
can_manage_tools: bool = False
|
||||
inherits_from: List[str] = []
|
||||
is_active: bool = True
|
||||
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
if len(v) < 2:
|
||||
raise ValueError("Role name must be at least 2 characters long")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Role name must be less than 50 characters long")
|
||||
if not v.isalnum() and "_" not in v:
|
||||
raise ValueError(
|
||||
"Role name must contain only alphanumeric characters and underscores"
|
||||
)
|
||||
return v.lower()
|
||||
|
||||
@validator("display_name")
|
||||
def validate_display_name(cls, v):
|
||||
if len(v) < 2:
|
||||
raise ValueError("Display name must be at least 2 characters long")
|
||||
if len(v) > 100:
|
||||
raise ValueError("Display name must be less than 100 characters long")
|
||||
return v
|
||||
|
||||
@validator("level")
|
||||
def validate_level(cls, v):
|
||||
valid_levels = ["read_only", "user", "admin", "super_admin"]
|
||||
if v not in valid_levels:
|
||||
raise ValueError(f'Level must be one of: {", ".join(valid_levels)}')
|
||||
return v
|
||||
|
||||
|
||||
class RoleCreate(RoleBase):
|
||||
"""Schema for creating a role"""
|
||||
|
||||
is_system_role: bool = False
|
||||
|
||||
|
||||
class RoleUpdate(BaseModel):
|
||||
"""Schema for updating a role"""
|
||||
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
permissions: Optional[Dict[str, Any]] = None
|
||||
can_manage_users: Optional[bool] = None
|
||||
can_manage_budgets: Optional[bool] = None
|
||||
can_view_reports: Optional[bool] = None
|
||||
can_manage_tools: Optional[bool] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@validator("display_name")
|
||||
def validate_display_name(cls, v):
|
||||
if v is not None:
|
||||
if len(v) < 2:
|
||||
raise ValueError("Display name must be at least 2 characters long")
|
||||
if len(v) > 100:
|
||||
raise ValueError("Display name must be less than 100 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class RoleResponse(BaseModel):
|
||||
"""Role response schema"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
level: str
|
||||
permissions: Dict[str, Any]
|
||||
can_manage_users: bool
|
||||
can_manage_budgets: bool
|
||||
can_view_reports: bool
|
||||
can_manage_tools: bool
|
||||
inherits_from: List[str]
|
||||
is_active: bool
|
||||
is_system_role: bool
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
user_count: Optional[int] = 0 # Number of users with this role
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj):
|
||||
"""Create response from ORM object"""
|
||||
data = obj.to_dict()
|
||||
# Add user count if available
|
||||
if hasattr(obj, "users"):
|
||||
data["user_count"] = len([u for u in obj.users if u.is_active])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class RoleListResponse(BaseModel):
|
||||
"""Role list response schema"""
|
||||
|
||||
roles: List[RoleResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class RoleAssignmentRequest(BaseModel):
|
||||
"""Schema for role assignment"""
|
||||
|
||||
role_id: int
|
||||
|
||||
@validator("role_id")
|
||||
def validate_role_id(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError("Role ID must be a positive integer")
|
||||
return v
|
||||
|
||||
|
||||
class RoleBulkAction(BaseModel):
|
||||
"""Schema for bulk role actions"""
|
||||
|
||||
role_ids: List[int]
|
||||
action: str # activate, deactivate, delete
|
||||
action_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
@validator("action")
|
||||
def validate_action(cls, v):
|
||||
valid_actions = ["activate", "deactivate", "delete"]
|
||||
if v not in valid_actions:
|
||||
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
|
||||
return v
|
||||
|
||||
@validator("role_ids")
|
||||
def validate_role_ids(cls, v):
|
||||
if not v:
|
||||
raise ValueError("At least one role ID must be provided")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Cannot perform bulk action on more than 50 roles at once")
|
||||
return v
|
||||
|
||||
|
||||
class RoleStatistics(BaseModel):
|
||||
"""Role statistics schema"""
|
||||
|
||||
total_roles: int
|
||||
active_roles: int
|
||||
system_roles: int
|
||||
roles_by_level: Dict[str, int]
|
||||
roles_with_users: int
|
||||
unused_roles: int
|
||||
|
||||
|
||||
class RolePermission(BaseModel):
|
||||
"""Individual role permission schema"""
|
||||
|
||||
permission: str
|
||||
granted: bool = True
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class RolePermissionTemplate(BaseModel):
|
||||
"""Role permission template schema"""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
description: str
|
||||
level: str
|
||||
permissions: List[RolePermission]
|
||||
can_manage_users: bool = False
|
||||
can_manage_budgets: bool = False
|
||||
can_view_reports: bool = False
|
||||
can_manage_tools: bool = False
|
||||
|
||||
|
||||
class RoleHierarchy(BaseModel):
|
||||
"""Role hierarchy schema"""
|
||||
|
||||
role: RoleResponse
|
||||
parent_roles: List[RoleResponse] = []
|
||||
child_roles: List[RoleResponse] = []
|
||||
effective_permissions: Dict[str, Any]
|
||||
|
||||
|
||||
class RoleComparison(BaseModel):
|
||||
"""Role comparison schema"""
|
||||
|
||||
role1: RoleResponse
|
||||
role2: RoleResponse
|
||||
common_permissions: List[str]
|
||||
unique_to_role1: List[str]
|
||||
unique_to_role2: List[str]
|
||||
|
||||
|
||||
class RoleUsage(BaseModel):
|
||||
"""Role usage statistics schema"""
|
||||
|
||||
role_id: int
|
||||
role_name: str
|
||||
user_count: int
|
||||
active_user_count: int
|
||||
last_assigned: Optional[datetime]
|
||||
usage_trend: Dict[str, int] # Monthly usage data
|
||||
|
||||
|
||||
class RoleSearchFilter(BaseModel):
|
||||
"""Role search filter schema"""
|
||||
|
||||
search: Optional[str] = None
|
||||
level: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_system_role: Optional[bool] = None
|
||||
has_users: Optional[bool] = None
|
||||
created_after: Optional[datetime] = None
|
||||
created_before: Optional[datetime] = None
|
||||
|
||||
|
||||
# Predefined permission templates
|
||||
ROLE_TEMPLATES = {
|
||||
"read_only": RolePermissionTemplate(
|
||||
name="read_only",
|
||||
display_name="Read Only",
|
||||
description="Can view own data only",
|
||||
level="read_only",
|
||||
permissions=[
|
||||
RolePermission(
|
||||
permission="read_own", granted=True, description="Read own data"
|
||||
),
|
||||
RolePermission(
|
||||
permission="create", granted=False, description="Create new resources"
|
||||
),
|
||||
RolePermission(
|
||||
permission="update",
|
||||
granted=False,
|
||||
description="Update existing resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="delete", granted=False, description="Delete resources"
|
||||
),
|
||||
],
|
||||
),
|
||||
"user": RolePermissionTemplate(
|
||||
name="user",
|
||||
display_name="User",
|
||||
description="Standard user with full access to own resources",
|
||||
level="user",
|
||||
permissions=[
|
||||
RolePermission(
|
||||
permission="read_own", granted=True, description="Read own data"
|
||||
),
|
||||
RolePermission(
|
||||
permission="create_own",
|
||||
granted=True,
|
||||
description="Create own resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="update_own",
|
||||
granted=True,
|
||||
description="Update own resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="delete_own",
|
||||
granted=True,
|
||||
description="Delete own resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="manage_users",
|
||||
granted=False,
|
||||
description="Manage other users",
|
||||
),
|
||||
RolePermission(
|
||||
permission="manage_all",
|
||||
granted=False,
|
||||
description="Manage all resources",
|
||||
),
|
||||
],
|
||||
inherits_from=["read_only"],
|
||||
),
|
||||
"admin": RolePermissionTemplate(
|
||||
name="admin",
|
||||
display_name="Administrator",
|
||||
description="Can manage users and view reports",
|
||||
level="admin",
|
||||
permissions=[
|
||||
RolePermission(
|
||||
permission="read_all", granted=True, description="Read all data"
|
||||
),
|
||||
RolePermission(
|
||||
permission="create_all",
|
||||
granted=True,
|
||||
description="Create any resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="update_all",
|
||||
granted=True,
|
||||
description="Update any resources",
|
||||
),
|
||||
RolePermission(
|
||||
permission="manage_users", granted=True, description="Manage users"
|
||||
),
|
||||
RolePermission(
|
||||
permission="view_reports",
|
||||
granted=True,
|
||||
description="View system reports",
|
||||
),
|
||||
RolePermission(
|
||||
permission="system_settings",
|
||||
granted=False,
|
||||
description="Modify system settings",
|
||||
),
|
||||
],
|
||||
can_manage_users=True,
|
||||
can_view_reports=True,
|
||||
inherits_from=["user"],
|
||||
),
|
||||
"super_admin": RolePermissionTemplate(
|
||||
name="super_admin",
|
||||
display_name="Super Administrator",
|
||||
description="Full system access",
|
||||
level="super_admin",
|
||||
permissions=[
|
||||
RolePermission(permission="*", granted=True, description="All permissions")
|
||||
],
|
||||
can_manage_users=True,
|
||||
can_manage_budgets=True,
|
||||
can_view_reports=True,
|
||||
can_manage_tools=True,
|
||||
inherits_from=["admin"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Common permission definitions
|
||||
COMMON_PERMISSIONS = {
|
||||
"read_own": "Read own data and resources",
|
||||
"read_all": "Read all data and resources",
|
||||
"create_own": "Create own resources",
|
||||
"create_all": "Create any resources",
|
||||
"update_own": "Update own resources",
|
||||
"update_all": "Update any resources",
|
||||
"delete_own": "Delete own resources",
|
||||
"delete_all": "Delete any resources",
|
||||
"manage_users": "Manage user accounts",
|
||||
"manage_roles": "Manage role assignments",
|
||||
"manage_budgets": "Manage budget settings",
|
||||
"view_reports": "View system reports",
|
||||
"manage_tools": "Manage custom tools",
|
||||
"system_settings": "Modify system settings",
|
||||
"create_api_key": "Create API keys",
|
||||
"create_tool": "Create custom tools",
|
||||
"export_data": "Export system data",
|
||||
}
|
||||
219
backend/app/schemas/tool.py
Normal file
219
backend/app/schemas/tool.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Tool schemas for API requests and responses
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
# Tool Creation and Update Schemas
|
||||
|
||||
|
||||
class ToolCreate(BaseModel):
|
||||
"""Schema for creating a new tool"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Unique tool name")
|
||||
display_name: str = Field(
|
||||
..., min_length=1, max_length=200, description="Display name for the tool"
|
||||
)
|
||||
description: Optional[str] = Field(None, description="Tool description")
|
||||
tool_type: str = Field(..., description="Tool type (python, bash, docker, api)")
|
||||
code: str = Field(..., min_length=1, description="Tool implementation code")
|
||||
parameters_schema: Optional[Dict[str, Any]] = Field(
|
||||
None, description="JSON schema for parameters"
|
||||
)
|
||||
return_schema: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Expected return format schema"
|
||||
)
|
||||
timeout_seconds: Optional[int] = Field(
|
||||
30, ge=1, le=300, description="Execution timeout in seconds"
|
||||
)
|
||||
max_memory_mb: Optional[int] = Field(
|
||||
256, ge=1, le=1024, description="Maximum memory in MB"
|
||||
)
|
||||
max_cpu_seconds: Optional[float] = Field(
|
||||
10.0, ge=0.1, le=60.0, description="Maximum CPU time in seconds"
|
||||
)
|
||||
docker_image: Optional[str] = Field(
|
||||
None, max_length=200, description="Docker image for execution"
|
||||
)
|
||||
docker_command: Optional[str] = Field(None, description="Docker command to run")
|
||||
category: Optional[str] = Field(None, max_length=50, description="Tool category")
|
||||
tags: Optional[List[str]] = Field(None, description="Tool tags")
|
||||
is_public: Optional[bool] = Field(False, description="Whether tool is public")
|
||||
|
||||
@validator("tool_type")
|
||||
def validate_tool_type(cls, v):
|
||||
valid_types = ["python", "bash", "docker", "api", "custom"]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"Tool type must be one of: {valid_types}")
|
||||
return v
|
||||
|
||||
@validator("tags")
|
||||
def validate_tags(cls, v):
|
||||
if v is not None and len(v) > 10:
|
||||
raise ValueError("Maximum 10 tags allowed")
|
||||
return v
|
||||
|
||||
|
||||
class ToolUpdate(BaseModel):
|
||||
"""Schema for updating a tool"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
code: Optional[str] = Field(None, min_length=1)
|
||||
parameters_schema: Optional[Dict[str, Any]] = None
|
||||
return_schema: Optional[Dict[str, Any]] = None
|
||||
timeout_seconds: Optional[int] = Field(None, ge=1, le=300)
|
||||
max_memory_mb: Optional[int] = Field(None, ge=1, le=1024)
|
||||
max_cpu_seconds: Optional[float] = Field(None, ge=0.1, le=60.0)
|
||||
docker_image: Optional[str] = Field(None, max_length=200)
|
||||
docker_command: Optional[str] = None
|
||||
category: Optional[str] = Field(None, max_length=50)
|
||||
tags: Optional[List[str]] = None
|
||||
is_public: Optional[bool] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@validator("tags")
|
||||
def validate_tags(cls, v):
|
||||
if v is not None and len(v) > 10:
|
||||
raise ValueError("Maximum 10 tags allowed")
|
||||
return v
|
||||
|
||||
|
||||
# Tool Response Schemas
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""Schema for tool response"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
tool_type: str
|
||||
parameters_schema: Dict[str, Any]
|
||||
return_schema: Dict[str, Any]
|
||||
timeout_seconds: int
|
||||
max_memory_mb: int
|
||||
max_cpu_seconds: float
|
||||
docker_image: Optional[str]
|
||||
is_public: bool
|
||||
is_approved: bool
|
||||
created_by_user_id: int
|
||||
category: Optional[str]
|
||||
tags: List[str]
|
||||
usage_count: int
|
||||
last_used_at: Optional[datetime]
|
||||
is_active: bool
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""Schema for tool list response"""
|
||||
|
||||
tools: List[ToolResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
# Tool Execution Schemas
|
||||
|
||||
|
||||
class ToolExecutionCreate(BaseModel):
|
||||
"""Schema for creating a tool execution"""
|
||||
|
||||
parameters: Dict[str, Any] = Field(..., description="Parameters for tool execution")
|
||||
timeout_override: Optional[int] = Field(
|
||||
None, ge=1, le=300, description="Override default timeout"
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutionResponse(BaseModel):
|
||||
"""Schema for tool execution response"""
|
||||
|
||||
id: int
|
||||
tool_id: int
|
||||
tool_name: Optional[str]
|
||||
executed_by_user_id: int
|
||||
parameters: Dict[str, Any]
|
||||
status: str
|
||||
output: Optional[str]
|
||||
error_message: Optional[str]
|
||||
return_code: Optional[int]
|
||||
execution_time_ms: Optional[int]
|
||||
memory_used_mb: Optional[float]
|
||||
cpu_time_ms: Optional[int]
|
||||
container_id: Optional[str]
|
||||
started_at: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
created_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ToolExecutionListResponse(BaseModel):
|
||||
"""Schema for tool execution list response"""
|
||||
|
||||
executions: List[ToolExecutionResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
# Tool Category Schemas
|
||||
|
||||
|
||||
class ToolCategoryCreate(BaseModel):
|
||||
"""Schema for creating a tool category"""
|
||||
|
||||
name: str = Field(
|
||||
..., min_length=1, max_length=50, description="Unique category name"
|
||||
)
|
||||
display_name: str = Field(
|
||||
..., min_length=1, max_length=100, description="Display name"
|
||||
)
|
||||
description: Optional[str] = Field(None, description="Category description")
|
||||
icon: Optional[str] = Field(None, max_length=50, description="Icon name")
|
||||
color: Optional[str] = Field(None, max_length=20, description="Color code")
|
||||
sort_order: Optional[int] = Field(0, description="Sort order")
|
||||
|
||||
|
||||
class ToolCategoryResponse(BaseModel):
|
||||
"""Schema for tool category response"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
icon: Optional[str]
|
||||
color: Optional[str]
|
||||
sort_order: int
|
||||
is_active: bool
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Statistics Schema
|
||||
|
||||
|
||||
class ToolStatisticsResponse(BaseModel):
|
||||
"""Schema for tool statistics response"""
|
||||
|
||||
total_tools: int
|
||||
public_tools: int
|
||||
tools_by_type: Dict[str, int]
|
||||
total_executions: int
|
||||
executions_by_status: Dict[str, int]
|
||||
recent_executions: int
|
||||
top_tools: List[Dict[str, Any]]
|
||||
user_tools: Optional[int] = None
|
||||
user_executions: Optional[int] = None
|
||||
68
backend/app/schemas/tool_calling.py
Normal file
68
backend/app/schemas/tool_calling.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Tool calling schemas for API requests and responses
|
||||
"""
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
|
||||
"""Schema for executing a tool by name"""
|
||||
|
||||
tool_name: str = Field(..., description="Name of the tool to execute")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Parameters for tool execution"
|
||||
)
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Schema for tool call response"""
|
||||
|
||||
success: bool = Field(..., description="Whether the tool call was successful")
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="Tool execution result")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
class ToolValidationRequest(BaseModel):
|
||||
"""Schema for validating tool availability"""
|
||||
|
||||
tool_names: List[str] = Field(..., description="List of tool names to validate")
|
||||
|
||||
|
||||
class ToolValidationResponse(BaseModel):
|
||||
"""Schema for tool validation response"""
|
||||
|
||||
tool_availability: Dict[str, bool] = Field(
|
||||
..., description="Tool name to availability mapping"
|
||||
)
|
||||
|
||||
|
||||
class ToolHistoryItem(BaseModel):
|
||||
"""Schema for tool execution history item"""
|
||||
|
||||
id: int = Field(..., description="Execution ID")
|
||||
tool_name: str = Field(..., description="Tool name")
|
||||
parameters: Dict[str, Any] = Field(..., description="Execution parameters")
|
||||
status: str = Field(..., description="Execution status")
|
||||
output: Optional[str] = Field(None, description="Tool output")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
execution_time_ms: Optional[int] = Field(
|
||||
None, description="Execution time in milliseconds"
|
||||
)
|
||||
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
||||
completed_at: Optional[str] = Field(None, description="Completion timestamp")
|
||||
|
||||
|
||||
class ToolHistoryResponse(BaseModel):
|
||||
"""Schema for tool execution history response"""
|
||||
|
||||
history: List[ToolHistoryItem] = Field(..., description="Tool execution history")
|
||||
total: int = Field(..., description="Total number of history items")
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""Schema for tool call request (placeholder for future use)"""
|
||||
|
||||
message: str = Field(..., description="Chat message")
|
||||
tools: Optional[List[str]] = Field(
|
||||
None, description="Specific tools to make available"
|
||||
)
|
||||
260
backend/app/schemas/user.py
Normal file
260
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
User Management Schemas
|
||||
Pydantic models for user management API
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, EmailStr, validator
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""Base user schema"""
|
||||
|
||||
email: EmailStr
|
||||
username: str
|
||||
full_name: Optional[str] = None
|
||||
is_active: bool = True
|
||||
is_verified: bool = False
|
||||
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if len(v) < 3:
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Username must be less than 50 characters long")
|
||||
return v
|
||||
|
||||
@validator("email")
|
||||
def validate_email(cls, v):
|
||||
if len(v) > 255:
|
||||
raise ValueError("Email must be less than 255 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""Schema for creating a user"""
|
||||
|
||||
password: str
|
||||
role_id: Optional[int] = None
|
||||
custom_permissions: Dict[str, Any] = {}
|
||||
|
||||
@validator("password")
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user"""
|
||||
|
||||
email: Optional[EmailStr] = None
|
||||
username: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
role_id: Optional[int] = None
|
||||
custom_permissions: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_verified: Optional[bool] = None
|
||||
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if v is not None:
|
||||
if len(v) < 3:
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
if len(v) > 50:
|
||||
raise ValueError("Username must be less than 50 characters long")
|
||||
return v
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password"""
|
||||
|
||||
current_password: Optional[str] = None
|
||||
new_password: str
|
||||
confirm_password: str
|
||||
|
||||
@validator("new_password")
|
||||
def validate_new_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
return v
|
||||
|
||||
@validator("confirm_password")
|
||||
def passwords_match(cls, v, values):
|
||||
if "new_password" in values and v != values["new_password"]:
|
||||
raise ValueError("Passwords do not match")
|
||||
return v
|
||||
|
||||
|
||||
class RoleInfo(BaseModel):
|
||||
"""Role information schema"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
display_name: str
|
||||
level: str
|
||||
permissions: Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""User response schema"""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
username: str
|
||||
full_name: Optional[str]
|
||||
is_active: bool
|
||||
is_verified: bool
|
||||
is_superuser: bool
|
||||
role_id: Optional[int]
|
||||
role: Optional[RoleInfo]
|
||||
custom_permissions: Dict[str, Any]
|
||||
account_locked: Optional[bool] = False
|
||||
account_locked_until: Optional[datetime]
|
||||
failed_login_attempts: Optional[int] = 0
|
||||
last_failed_login: Optional[datetime]
|
||||
avatar_url: Optional[str]
|
||||
bio: Optional[str]
|
||||
company: Optional[str]
|
||||
website: Optional[str]
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
last_login: Optional[datetime]
|
||||
preferences: Dict[str, Any]
|
||||
notification_settings: Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, obj):
|
||||
"""Create response from ORM object with proper role handling"""
|
||||
data = obj.to_dict()
|
||||
if obj.role:
|
||||
data["role"] = RoleInfo.from_orm(obj.role)
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
"""User list response schema"""
|
||||
|
||||
users: List[UserResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class AccountLockResponse(BaseModel):
|
||||
"""Account lock response schema"""
|
||||
|
||||
user_id: int
|
||||
is_locked: bool
|
||||
locked_until: Optional[datetime]
|
||||
message: str
|
||||
|
||||
|
||||
class UserProfileUpdate(BaseModel):
|
||||
"""Schema for user profile updates (limited fields)"""
|
||||
|
||||
full_name: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
bio: Optional[str] = None
|
||||
company: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
notification_settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class UserPreferences(BaseModel):
|
||||
"""User preferences schema"""
|
||||
|
||||
theme: Optional[str] = "light"
|
||||
language: Optional[str] = "en"
|
||||
timezone: Optional[str] = "UTC"
|
||||
email_notifications: Optional[bool] = True
|
||||
security_alerts: Optional[bool] = True
|
||||
system_updates: Optional[bool] = True
|
||||
|
||||
|
||||
class UserSearchFilter(BaseModel):
|
||||
"""User search filter schema"""
|
||||
|
||||
search: Optional[str] = None
|
||||
role_id: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_verified: Optional[bool] = None
|
||||
created_after: Optional[datetime] = None
|
||||
created_before: Optional[datetime] = None
|
||||
last_login_after: Optional[datetime] = None
|
||||
last_login_before: Optional[datetime] = None
|
||||
|
||||
|
||||
class UserBulkAction(BaseModel):
|
||||
"""Schema for bulk user actions"""
|
||||
|
||||
user_ids: List[int]
|
||||
action: str # activate, deactivate, lock, unlock, assign_role, remove_role
|
||||
action_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
@validator("action")
|
||||
def validate_action(cls, v):
|
||||
valid_actions = [
|
||||
"activate",
|
||||
"deactivate",
|
||||
"lock",
|
||||
"unlock",
|
||||
"assign_role",
|
||||
"remove_role",
|
||||
]
|
||||
if v not in valid_actions:
|
||||
raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
|
||||
return v
|
||||
|
||||
@validator("user_ids")
|
||||
def validate_user_ids(cls, v):
|
||||
if not v:
|
||||
raise ValueError("At least one user ID must be provided")
|
||||
if len(v) > 100:
|
||||
raise ValueError(
|
||||
"Cannot perform bulk action on more than 100 users at once"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class UserStatistics(BaseModel):
|
||||
"""User statistics schema"""
|
||||
|
||||
total_users: int
|
||||
active_users: int
|
||||
verified_users: int
|
||||
locked_users: int
|
||||
users_by_role: Dict[str, int]
|
||||
recent_registrations: int
|
||||
registrations_by_month: Dict[str, int]
|
||||
|
||||
|
||||
class UserActivity(BaseModel):
|
||||
"""User activity schema"""
|
||||
|
||||
user_id: int
|
||||
action: str
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[int] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
timestamp: datetime
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
|
||||
|
||||
class UserActivityFilter(BaseModel):
|
||||
"""User activity filter schema"""
|
||||
|
||||
user_id: Optional[int] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
date_from: Optional[datetime] = None
|
||||
date_to: Optional[datetime] = None
|
||||
ip_address: Optional[str] = None
|
||||
@@ -27,6 +27,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class RequestEvent:
|
||||
"""Enhanced request event data structure with budget integration"""
|
||||
|
||||
timestamp: datetime
|
||||
method: str
|
||||
path: str
|
||||
@@ -55,6 +56,7 @@ class RequestEvent:
|
||||
@dataclass
|
||||
class UsageMetrics:
|
||||
"""Usage metrics including costs and tokens"""
|
||||
|
||||
total_requests: int
|
||||
successful_requests: int
|
||||
failed_requests: int
|
||||
@@ -84,6 +86,7 @@ class UsageMetrics:
|
||||
@dataclass
|
||||
class SystemHealth:
|
||||
"""System health including budget and usage analysis"""
|
||||
|
||||
status: str # healthy, warning, critical
|
||||
score: int # 0-100
|
||||
issues: List[str]
|
||||
@@ -113,20 +116,20 @@ class AnalyticsService:
|
||||
self.cache_ttl = 300 # 5 minutes cache TTL
|
||||
|
||||
# Statistics counters
|
||||
self.endpoint_stats = defaultdict(lambda: {
|
||||
self.endpoint_stats = defaultdict(
|
||||
lambda: {
|
||||
"count": 0,
|
||||
"total_time": 0,
|
||||
"errors": 0,
|
||||
"avg_time": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost_cents": 0
|
||||
})
|
||||
"total_cost_cents": 0,
|
||||
}
|
||||
)
|
||||
self.status_codes = defaultdict(int)
|
||||
self.model_stats = defaultdict(lambda: {
|
||||
"count": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost_cents": 0
|
||||
})
|
||||
self.model_stats = defaultdict(
|
||||
lambda: {"count": 0, "total_tokens": 0, "total_cost_cents": 0}
|
||||
)
|
||||
|
||||
# Start cleanup task
|
||||
asyncio.create_task(self._cleanup_old_events())
|
||||
@@ -165,13 +168,19 @@ class AnalyticsService:
|
||||
# Clear metrics cache to force recalculation
|
||||
self.metrics_cache.clear()
|
||||
|
||||
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
|
||||
logger.debug(
|
||||
f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error tracking request: {e}")
|
||||
|
||||
async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
|
||||
api_key_id: Optional[int] = None) -> UsageMetrics:
|
||||
async def get_usage_metrics(
|
||||
self,
|
||||
hours: int = 24,
|
||||
user_id: Optional[int] = None,
|
||||
api_key_id: Optional[int] = None,
|
||||
) -> UsageMetrics:
|
||||
"""Get comprehensive usage metrics including costs and budgets"""
|
||||
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
|
||||
|
||||
@@ -207,7 +216,9 @@ class AnalyticsService:
|
||||
failed_requests = total_requests - successful_requests
|
||||
|
||||
if total_requests > 0:
|
||||
avg_response_time = sum(e.response_time for e in recent_events) / total_requests
|
||||
avg_response_time = (
|
||||
sum(e.response_time for e in recent_events) / total_requests
|
||||
)
|
||||
requests_per_minute = total_requests / (hours * 60)
|
||||
error_rate = (failed_requests / total_requests) * 100
|
||||
else:
|
||||
@@ -230,9 +241,14 @@ class AnalyticsService:
|
||||
budget_query = self.db.query(Budget).filter(Budget.is_active == True)
|
||||
if user_id:
|
||||
budget_query = budget_query.filter(
|
||||
or_(Budget.user_id == user_id, Budget.api_key_id.in_(
|
||||
self.db.query(APIKey.id).filter(APIKey.user_id == user_id).subquery()
|
||||
))
|
||||
or_(
|
||||
Budget.user_id == user_id,
|
||||
Budget.api_key_id.in_(
|
||||
self.db.query(APIKey.id)
|
||||
.filter(APIKey.user_id == user_id)
|
||||
.subquery()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
budgets = budget_query.all()
|
||||
@@ -253,7 +269,9 @@ class AnalyticsService:
|
||||
|
||||
top_endpoints = [
|
||||
{"endpoint": endpoint, "count": count}
|
||||
for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
for endpoint, count in sorted(
|
||||
endpoint_counts.items(), key=lambda x: x[1], reverse=True
|
||||
)[:10]
|
||||
]
|
||||
|
||||
# Status codes from memory
|
||||
@@ -262,21 +280,27 @@ class AnalyticsService:
|
||||
status_counts[str(event.status_code)] += 1
|
||||
|
||||
# Top models from database
|
||||
model_usage = self.db.query(
|
||||
model_usage = (
|
||||
self.db.query(
|
||||
UsageTracking.model,
|
||||
func.count(UsageTracking.id).label('count'),
|
||||
func.sum(UsageTracking.total_tokens).label('tokens'),
|
||||
func.sum(UsageTracking.cost_cents).label('cost')
|
||||
).filter(and_(*filters)).filter(
|
||||
UsageTracking.model.is_not(None)
|
||||
).group_by(UsageTracking.model).order_by(desc('count')).limit(10).all()
|
||||
func.count(UsageTracking.id).label("count"),
|
||||
func.sum(UsageTracking.total_tokens).label("tokens"),
|
||||
func.sum(UsageTracking.cost_cents).label("cost"),
|
||||
)
|
||||
.filter(and_(*filters))
|
||||
.filter(UsageTracking.model.is_not(None))
|
||||
.group_by(UsageTracking.model)
|
||||
.order_by(desc("count"))
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
|
||||
top_models = [
|
||||
{
|
||||
"model": model,
|
||||
"count": count,
|
||||
"total_tokens": tokens or 0,
|
||||
"total_cost_cents": cost or 0
|
||||
"total_cost_cents": cost or 0,
|
||||
}
|
||||
for model, count, tokens, cost in model_usage
|
||||
]
|
||||
@@ -300,7 +324,7 @@ class AnalyticsService:
|
||||
top_endpoints=top_endpoints,
|
||||
status_codes=dict(status_counts),
|
||||
top_models=top_models,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
@@ -310,13 +334,24 @@ class AnalyticsService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage metrics: {e}")
|
||||
return UsageMetrics(
|
||||
total_requests=0, successful_requests=0, failed_requests=0,
|
||||
avg_response_time=0, requests_per_minute=0, error_rate=0,
|
||||
total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
|
||||
avg_cost_per_request_cents=0, total_budget_cents=0,
|
||||
used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
|
||||
top_endpoints=[], status_codes={}, top_models=[],
|
||||
timestamp=datetime.utcnow()
|
||||
total_requests=0,
|
||||
successful_requests=0,
|
||||
failed_requests=0,
|
||||
avg_response_time=0,
|
||||
requests_per_minute=0,
|
||||
error_rate=0,
|
||||
total_tokens=0,
|
||||
total_cost_cents=0,
|
||||
avg_tokens_per_request=0,
|
||||
avg_cost_per_request_cents=0,
|
||||
total_budget_cents=0,
|
||||
used_budget_cents=0,
|
||||
budget_usage_percentage=0,
|
||||
active_budgets=0,
|
||||
top_endpoints=[],
|
||||
status_codes={},
|
||||
top_models=[],
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def get_system_health(self) -> SystemHealth:
|
||||
@@ -347,22 +382,30 @@ class AnalyticsService:
|
||||
recommendations.append("Optimize slow endpoints and database queries")
|
||||
elif metrics.avg_response_time > 2.0:
|
||||
health_score -= 10
|
||||
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
|
||||
issues.append(
|
||||
f"Elevated response time: {metrics.avg_response_time:.2f}s"
|
||||
)
|
||||
recommendations.append("Monitor performance trends")
|
||||
|
||||
# Check budget usage
|
||||
if metrics.budget_usage_percentage > 90:
|
||||
health_score -= 20
|
||||
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
|
||||
issues.append(
|
||||
f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%"
|
||||
)
|
||||
recommendations.append("Review budget limits and usage patterns")
|
||||
elif metrics.budget_usage_percentage > 75:
|
||||
health_score -= 10
|
||||
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
|
||||
issues.append(
|
||||
f"Budget usage high: {metrics.budget_usage_percentage:.1f}%"
|
||||
)
|
||||
recommendations.append("Monitor spending trends")
|
||||
|
||||
# Check for budgets near or over limit
|
||||
budgets = self.db.query(Budget).filter(Budget.is_active == True).all()
|
||||
budgets_near_limit = sum(1 for b in budgets if b.current_usage_cents >= b.limit_cents * 0.8)
|
||||
budgets_near_limit = sum(
|
||||
1 for b in budgets if b.current_usage_cents >= b.limit_cents * 0.8
|
||||
)
|
||||
budgets_exceeded = sum(1 for b in budgets if b.is_exceeded)
|
||||
|
||||
if budgets_exceeded > 0:
|
||||
@@ -393,21 +436,28 @@ class AnalyticsService:
|
||||
budget_usage_percentage=metrics.budget_usage_percentage,
|
||||
budgets_near_limit=budgets_near_limit,
|
||||
budgets_exceeded=budgets_exceeded,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system health: {e}")
|
||||
return SystemHealth(
|
||||
status="error", score=0,
|
||||
status="error",
|
||||
score=0,
|
||||
issues=[f"Health check failed: {str(e)}"],
|
||||
recommendations=["Check system logs and restart services"],
|
||||
avg_response_time=0, error_rate=0, requests_per_minute=0,
|
||||
budget_usage_percentage=0, budgets_near_limit=0,
|
||||
budgets_exceeded=0, timestamp=datetime.utcnow()
|
||||
avg_response_time=0,
|
||||
error_rate=0,
|
||||
requests_per_minute=0,
|
||||
budget_usage_percentage=0,
|
||||
budgets_near_limit=0,
|
||||
budgets_exceeded=0,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
async def get_cost_analysis(
|
||||
self, days: int = 30, user_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed cost analysis and trends"""
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
@@ -448,9 +498,15 @@ class AnalyticsService:
|
||||
total_requests = len(usage_records)
|
||||
|
||||
efficiency_metrics = {
|
||||
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
|
||||
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
|
||||
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
|
||||
"cost_per_token": (total_cost / total_tokens)
|
||||
if total_tokens > 0
|
||||
else 0,
|
||||
"cost_per_request": (total_cost / total_requests)
|
||||
if total_requests > 0
|
||||
else 0,
|
||||
"tokens_per_request": (total_tokens / total_requests)
|
||||
if total_requests > 0
|
||||
else 0,
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -465,7 +521,7 @@ class AnalyticsService:
|
||||
"requests_by_model": dict(requests_by_model),
|
||||
"daily_costs": dict(daily_costs),
|
||||
"cost_by_endpoint": dict(cost_by_endpoint),
|
||||
"analysis_timestamp": datetime.utcnow().isoformat()
|
||||
"analysis_timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -538,20 +594,20 @@ class InMemoryAnalyticsService:
|
||||
self.cache_ttl = 300 # 5 minutes cache TTL
|
||||
|
||||
# Statistics counters
|
||||
self.endpoint_stats = defaultdict(lambda: {
|
||||
self.endpoint_stats = defaultdict(
|
||||
lambda: {
|
||||
"count": 0,
|
||||
"total_time": 0,
|
||||
"errors": 0,
|
||||
"avg_time": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost_cents": 0
|
||||
})
|
||||
"total_cost_cents": 0,
|
||||
}
|
||||
)
|
||||
self.status_codes = defaultdict(int)
|
||||
self.model_stats = defaultdict(lambda: {
|
||||
"count": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost_cents": 0
|
||||
})
|
||||
self.model_stats = defaultdict(
|
||||
lambda: {"count": 0, "total_tokens": 0, "total_cost_cents": 0}
|
||||
)
|
||||
|
||||
# Start cleanup task
|
||||
asyncio.create_task(self._cleanup_old_events())
|
||||
@@ -590,13 +646,19 @@ class InMemoryAnalyticsService:
|
||||
# Clear metrics cache to force recalculation
|
||||
self.metrics_cache.clear()
|
||||
|
||||
logger.debug(f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s")
|
||||
logger.debug(
|
||||
f"Tracked request: {endpoint} - {event.status_code} - {event.response_time:.3f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error tracking request: {e}")
|
||||
|
||||
async def get_usage_metrics(self, hours: int = 24, user_id: Optional[int] = None,
|
||||
api_key_id: Optional[int] = None) -> UsageMetrics:
|
||||
async def get_usage_metrics(
|
||||
self,
|
||||
hours: int = 24,
|
||||
user_id: Optional[int] = None,
|
||||
api_key_id: Optional[int] = None,
|
||||
) -> UsageMetrics:
|
||||
"""Get comprehensive usage metrics including costs and budgets"""
|
||||
cache_key = f"usage_metrics_{hours}_{user_id}_{api_key_id}"
|
||||
|
||||
@@ -622,7 +684,9 @@ class InMemoryAnalyticsService:
|
||||
failed_requests = total_requests - successful_requests
|
||||
|
||||
if total_requests > 0:
|
||||
avg_response_time = sum(e.response_time for e in recent_events) / total_requests
|
||||
avg_response_time = (
|
||||
sum(e.response_time for e in recent_events) / total_requests
|
||||
)
|
||||
requests_per_minute = total_requests / (hours * 60)
|
||||
error_rate = (failed_requests / total_requests) * 100
|
||||
else:
|
||||
@@ -658,7 +722,9 @@ class InMemoryAnalyticsService:
|
||||
|
||||
top_endpoints = [
|
||||
{"endpoint": endpoint, "count": count}
|
||||
for endpoint, count in sorted(endpoint_counts.items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
for endpoint, count in sorted(
|
||||
endpoint_counts.items(), key=lambda x: x[1], reverse=True
|
||||
)[:10]
|
||||
]
|
||||
|
||||
# Status codes from memory
|
||||
@@ -679,9 +745,11 @@ class InMemoryAnalyticsService:
|
||||
"model": model,
|
||||
"count": data["count"],
|
||||
"total_tokens": data["tokens"],
|
||||
"total_cost_cents": data["cost"]
|
||||
"total_cost_cents": data["cost"],
|
||||
}
|
||||
for model, data in sorted(model_usage.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
|
||||
for model, data in sorted(
|
||||
model_usage.items(), key=lambda x: x[1]["count"], reverse=True
|
||||
)[:10]
|
||||
]
|
||||
|
||||
# Create metrics object
|
||||
@@ -703,7 +771,7 @@ class InMemoryAnalyticsService:
|
||||
top_endpoints=top_endpoints,
|
||||
status_codes=dict(status_counts),
|
||||
top_models=top_models,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
@@ -713,13 +781,24 @@ class InMemoryAnalyticsService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage metrics: {e}")
|
||||
return UsageMetrics(
|
||||
total_requests=0, successful_requests=0, failed_requests=0,
|
||||
avg_response_time=0, requests_per_minute=0, error_rate=0,
|
||||
total_tokens=0, total_cost_cents=0, avg_tokens_per_request=0,
|
||||
avg_cost_per_request_cents=0, total_budget_cents=0,
|
||||
used_budget_cents=0, budget_usage_percentage=0, active_budgets=0,
|
||||
top_endpoints=[], status_codes={}, top_models=[],
|
||||
timestamp=datetime.utcnow()
|
||||
total_requests=0,
|
||||
successful_requests=0,
|
||||
failed_requests=0,
|
||||
avg_response_time=0,
|
||||
requests_per_minute=0,
|
||||
error_rate=0,
|
||||
total_tokens=0,
|
||||
total_cost_cents=0,
|
||||
avg_tokens_per_request=0,
|
||||
avg_cost_per_request_cents=0,
|
||||
total_budget_cents=0,
|
||||
used_budget_cents=0,
|
||||
budget_usage_percentage=0,
|
||||
active_budgets=0,
|
||||
top_endpoints=[],
|
||||
status_codes={},
|
||||
top_models=[],
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def get_system_health(self) -> SystemHealth:
|
||||
@@ -750,17 +829,23 @@ class InMemoryAnalyticsService:
|
||||
recommendations.append("Optimize slow endpoints and database queries")
|
||||
elif metrics.avg_response_time > 2.0:
|
||||
health_score -= 10
|
||||
issues.append(f"Elevated response time: {metrics.avg_response_time:.2f}s")
|
||||
issues.append(
|
||||
f"Elevated response time: {metrics.avg_response_time:.2f}s"
|
||||
)
|
||||
recommendations.append("Monitor performance trends")
|
||||
|
||||
# Check budget usage
|
||||
if metrics.budget_usage_percentage > 90:
|
||||
health_score -= 20
|
||||
issues.append(f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%")
|
||||
issues.append(
|
||||
f"Budget usage critical: {metrics.budget_usage_percentage:.1f}%"
|
||||
)
|
||||
recommendations.append("Review budget limits and usage patterns")
|
||||
elif metrics.budget_usage_percentage > 75:
|
||||
health_score -= 10
|
||||
issues.append(f"Budget usage high: {metrics.budget_usage_percentage:.1f}%")
|
||||
issues.append(
|
||||
f"Budget usage high: {metrics.budget_usage_percentage:.1f}%"
|
||||
)
|
||||
recommendations.append("Monitor spending trends")
|
||||
|
||||
# Determine overall status
|
||||
@@ -782,21 +867,28 @@ class InMemoryAnalyticsService:
|
||||
budget_usage_percentage=metrics.budget_usage_percentage,
|
||||
budgets_near_limit=0, # Mock values since no DB access
|
||||
budgets_exceeded=0,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system health: {e}")
|
||||
return SystemHealth(
|
||||
status="error", score=0,
|
||||
status="error",
|
||||
score=0,
|
||||
issues=[f"Health check failed: {str(e)}"],
|
||||
recommendations=["Check system logs and restart services"],
|
||||
avg_response_time=0, error_rate=0, requests_per_minute=0,
|
||||
budget_usage_percentage=0, budgets_near_limit=0,
|
||||
budgets_exceeded=0, timestamp=datetime.utcnow()
|
||||
avg_response_time=0,
|
||||
error_rate=0,
|
||||
requests_per_minute=0,
|
||||
budget_usage_percentage=0,
|
||||
budgets_near_limit=0,
|
||||
budgets_exceeded=0,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def get_cost_analysis(self, days: int = 30, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
async def get_cost_analysis(
|
||||
self, days: int = 30, user_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed cost analysis and trends"""
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
@@ -835,9 +927,15 @@ class InMemoryAnalyticsService:
|
||||
total_requests = len(events)
|
||||
|
||||
efficiency_metrics = {
|
||||
"cost_per_token": (total_cost / total_tokens) if total_tokens > 0 else 0,
|
||||
"cost_per_request": (total_cost / total_requests) if total_requests > 0 else 0,
|
||||
"tokens_per_request": (total_tokens / total_requests) if total_requests > 0 else 0
|
||||
"cost_per_token": (total_cost / total_tokens)
|
||||
if total_tokens > 0
|
||||
else 0,
|
||||
"cost_per_request": (total_cost / total_requests)
|
||||
if total_requests > 0
|
||||
else 0,
|
||||
"tokens_per_request": (total_tokens / total_requests)
|
||||
if total_requests > 0
|
||||
else 0,
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -852,7 +950,7 @@ class InMemoryAnalyticsService:
|
||||
"requests_by_model": dict(requests_by_model),
|
||||
"daily_costs": dict(daily_costs),
|
||||
"cost_by_endpoint": dict(cost_by_endpoint),
|
||||
"analysis_timestamp": datetime.utcnow().isoformat()
|
||||
"analysis_timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -27,7 +27,9 @@ class APIKeyAuthService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:
|
||||
async def validate_api_key(
|
||||
self, api_key: str, request: Request
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Validate API key and return user context using Redis cache for performance"""
|
||||
try:
|
||||
if not api_key:
|
||||
@@ -41,10 +43,14 @@ class APIKeyAuthService:
|
||||
key_prefix = api_key[:8]
|
||||
|
||||
# Try cached verification first
|
||||
cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)
|
||||
cached_verification = await cached_api_key_service.verify_api_key_cached(
|
||||
api_key, key_prefix
|
||||
)
|
||||
|
||||
# Get API key data from cache or database
|
||||
context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)
|
||||
context = await cached_api_key_service.get_cached_api_key(
|
||||
key_prefix, self.db
|
||||
)
|
||||
|
||||
if not context:
|
||||
logger.warning(f"API key not found: {key_prefix}")
|
||||
@@ -56,7 +62,7 @@ class APIKeyAuthService:
|
||||
if not cached_verification:
|
||||
# Get the actual key hash for verification (this should be in the cached context)
|
||||
db_api_key = None
|
||||
if not hasattr(api_key_obj, 'key_hash'):
|
||||
if not hasattr(api_key_obj, "key_hash"):
|
||||
# Fallback: fetch full API key from database for hash
|
||||
stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
|
||||
result = await self.db.execute(stmt)
|
||||
@@ -73,7 +79,9 @@ class APIKeyAuthService:
|
||||
return None
|
||||
|
||||
# Cache successful verification
|
||||
await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)
|
||||
await cached_api_key_service.cache_verification_result(
|
||||
api_key, key_prefix, key_hash, True
|
||||
)
|
||||
|
||||
# Check if key is valid (expiry, active status)
|
||||
if not api_key_obj.is_valid():
|
||||
@@ -89,7 +97,9 @@ class APIKeyAuthService:
|
||||
return None
|
||||
|
||||
# Update last used timestamp asynchronously (performance optimization)
|
||||
await cached_api_key_service.update_last_used(context["api_key_id"], self.db)
|
||||
await cached_api_key_service.update_last_used(
|
||||
context["api_key_id"], self.db
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@@ -97,7 +107,9 @@ class APIKeyAuthService:
|
||||
logger.error(f"API key validation error: {e}")
|
||||
return None
|
||||
|
||||
async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:
|
||||
async def check_endpoint_permission(
|
||||
self, context: Dict[str, Any], endpoint: str
|
||||
) -> bool:
|
||||
"""Check if API key has permission to access endpoint"""
|
||||
api_key: APIKey = context.get("api_key")
|
||||
if not api_key:
|
||||
@@ -121,21 +133,24 @@ class APIKeyAuthService:
|
||||
|
||||
return api_key.has_scope(scope)
|
||||
|
||||
async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0):
|
||||
async def update_usage_stats(
|
||||
self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0
|
||||
):
|
||||
"""Update API key usage statistics"""
|
||||
try:
|
||||
api_key: APIKey = context.get("api_key")
|
||||
if api_key:
|
||||
api_key.update_usage(tokens_used, cost_cents)
|
||||
await self.db.commit()
|
||||
logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
|
||||
logger.info(
|
||||
f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update usage stats: {e}")
|
||||
|
||||
|
||||
async def get_api_key_context(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, db: AsyncSession = Depends(get_db)
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Dependency to get API key context from request"""
|
||||
auth_service = APIKeyAuthService(db)
|
||||
@@ -170,7 +185,7 @@ async def require_api_key(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Valid API key required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return context
|
||||
|
||||
@@ -190,7 +205,7 @@ async def get_current_api_key_user(
|
||||
if not user or not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User or API key not found in context"
|
||||
detail="User or API key not found in context",
|
||||
)
|
||||
|
||||
return user, api_key
|
||||
@@ -210,7 +225,7 @@ async def get_api_key_auth(
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API key not found in context"
|
||||
detail="API key not found in context",
|
||||
)
|
||||
|
||||
return api_key
|
||||
@@ -227,7 +242,7 @@ class RequireScope:
|
||||
if not await auth_service.check_scope_permission(context, self.scope):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Scope '{self.scope}' required"
|
||||
detail=f"Scope '{self.scope}' required",
|
||||
)
|
||||
return context
|
||||
|
||||
@@ -243,6 +258,6 @@ class RequireModel:
|
||||
if not await auth_service.check_model_permission(context, self.model):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Model '{self.model}' not allowed"
|
||||
detail=f"Model '{self.model}' not allowed",
|
||||
)
|
||||
return context
|
||||
@@ -68,7 +68,7 @@ async def log_audit_event_async(
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
success: bool = True,
|
||||
severity: str = "info"
|
||||
severity: str = "info",
|
||||
):
|
||||
"""
|
||||
Log an audit event asynchronously (non-blocking)
|
||||
@@ -96,7 +96,7 @@ async def log_audit_event_async(
|
||||
"user_agent": user_agent,
|
||||
"success": success,
|
||||
"severity": severity,
|
||||
"created_at": datetime.utcnow()
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
# Queue the audit event (non-blocking)
|
||||
@@ -122,7 +122,7 @@ async def log_audit_event(
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
success: bool = True,
|
||||
severity: str = "info"
|
||||
severity: str = "info",
|
||||
):
|
||||
"""
|
||||
Log an audit event to the database
|
||||
@@ -157,13 +157,15 @@ async def log_audit_event(
|
||||
user_agent=user_agent,
|
||||
success=success,
|
||||
severity=severity,
|
||||
created_at=datetime.utcnow()
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
await db.flush() # Don't commit here, let the caller control the transaction
|
||||
|
||||
logger.debug(f"Audit event logged: {action} on {resource_type} by user {user_id}")
|
||||
logger.debug(
|
||||
f"Audit event logged: {action} on {resource_type} by user {user_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
@@ -179,7 +181,7 @@ async def get_audit_logs(
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
offset: int = 0,
|
||||
):
|
||||
"""
|
||||
Query audit logs with filtering
|
||||
@@ -230,7 +232,7 @@ async def get_audit_logs(
|
||||
async def get_audit_stats(
|
||||
db: AsyncSession,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None,
|
||||
):
|
||||
"""
|
||||
Get audit statistics
|
||||
@@ -260,28 +262,36 @@ async def get_audit_stats(
|
||||
total_events = total_result.scalar()
|
||||
|
||||
# Events by action
|
||||
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action)
|
||||
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.action
|
||||
)
|
||||
if conditions:
|
||||
action_query = action_query.where(and_(*conditions))
|
||||
action_result = await db.execute(action_query)
|
||||
events_by_action = dict(action_result.fetchall())
|
||||
|
||||
# Events by resource type
|
||||
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(AuditLog.resource_type)
|
||||
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.resource_type
|
||||
)
|
||||
if conditions:
|
||||
resource_query = resource_query.where(and_(*conditions))
|
||||
resource_result = await db.execute(resource_query)
|
||||
events_by_resource = dict(resource_result.fetchall())
|
||||
|
||||
# Events by severity
|
||||
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(AuditLog.severity)
|
||||
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.severity
|
||||
)
|
||||
if conditions:
|
||||
severity_query = severity_query.where(and_(*conditions))
|
||||
severity_result = await db.execute(severity_query)
|
||||
events_by_severity = dict(severity_result.fetchall())
|
||||
|
||||
# Success rate
|
||||
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(AuditLog.success)
|
||||
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.success
|
||||
)
|
||||
if conditions:
|
||||
success_query = success_query.where(and_(*conditions))
|
||||
success_result = await db.execute(success_query)
|
||||
@@ -292,6 +302,10 @@ async def get_audit_stats(
|
||||
"events_by_action": events_by_action,
|
||||
"events_by_resource_type": events_by_resource,
|
||||
"events_by_severity": events_by_severity,
|
||||
"success_rate": success_stats.get(True, 0) / total_events if total_events > 0 else 0,
|
||||
"failure_rate": success_stats.get(False, 0) / total_events if total_events > 0 else 0
|
||||
"success_rate": success_stats.get(True, 0) / total_events
|
||||
if total_events > 0
|
||||
else 0,
|
||||
"failure_rate": success_stats.get(False, 0) / total_events
|
||||
if total_events > 0
|
||||
else 0,
|
||||
}
|
||||
@@ -22,6 +22,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class Permission:
|
||||
"""Represents a module permission"""
|
||||
|
||||
resource: str
|
||||
action: str
|
||||
description: str
|
||||
@@ -33,6 +34,7 @@ class Permission:
|
||||
@dataclass
|
||||
class ModuleMetrics:
|
||||
"""Module performance metrics"""
|
||||
|
||||
requests_processed: int = 0
|
||||
average_response_time: float = 0.0
|
||||
error_rate: float = 0.0
|
||||
@@ -48,6 +50,7 @@ class ModuleMetrics:
|
||||
@dataclass
|
||||
class ModuleHealth:
|
||||
"""Module health status"""
|
||||
|
||||
status: str = "healthy" # healthy, warning, error
|
||||
message: str = "Module is functioning normally"
|
||||
uptime: float = 0.0
|
||||
@@ -67,7 +70,7 @@ class BaseModule(ABC):
|
||||
self.metrics = ModuleMetrics()
|
||||
self.health = ModuleHealth()
|
||||
self.initialized = False
|
||||
self.interceptors: List['ModuleInterceptor'] = []
|
||||
self.interceptors: List["ModuleInterceptor"] = []
|
||||
|
||||
# Register default interceptors
|
||||
self._register_default_interceptors()
|
||||
@@ -88,7 +91,9 @@ class BaseModule(ABC):
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def process_request(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_request(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a module request"""
|
||||
pass
|
||||
|
||||
@@ -115,10 +120,12 @@ class BaseModule(ABC):
|
||||
ValidationInterceptor(),
|
||||
MetricsInterceptor(self),
|
||||
SecurityInterceptor(),
|
||||
AuditInterceptor(self)
|
||||
AuditInterceptor(self),
|
||||
]
|
||||
|
||||
async def execute_with_interceptors(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute_with_interceptors(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute request through interceptor chain"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -145,7 +152,7 @@ class BaseModule(ABC):
|
||||
|
||||
# Error handling interceptors
|
||||
for interceptor in reversed(self.interceptors):
|
||||
if hasattr(interceptor, 'handle_error'):
|
||||
if hasattr(interceptor, "handle_error"):
|
||||
await interceptor.handle_error(request, context, e)
|
||||
|
||||
raise
|
||||
@@ -168,7 +175,9 @@ class BaseModule(ABC):
|
||||
|
||||
if not success:
|
||||
self.metrics.total_errors += 1
|
||||
self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
|
||||
self.metrics.error_rate = (
|
||||
self.metrics.total_errors / self.metrics.requests_processed
|
||||
)
|
||||
|
||||
# Update health status based on error rate
|
||||
if self.metrics.error_rate > 0.1: # More than 10% error rate
|
||||
@@ -176,9 +185,13 @@ class BaseModule(ABC):
|
||||
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
|
||||
elif self.metrics.error_rate > 0.05: # More than 5% error rate
|
||||
self.health.status = "warning"
|
||||
self.health.message = f"Elevated error rate: {self.metrics.error_rate:.2%}"
|
||||
self.health.message = (
|
||||
f"Elevated error rate: {self.metrics.error_rate:.2%}"
|
||||
)
|
||||
else:
|
||||
self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
|
||||
self.metrics.error_rate = (
|
||||
self.metrics.total_errors / self.metrics.requests_processed
|
||||
)
|
||||
if self.metrics.error_rate <= 0.05:
|
||||
self.health.status = "healthy"
|
||||
self.health.message = "Module is functioning normally"
|
||||
@@ -190,12 +203,16 @@ class ModuleInterceptor(ABC):
|
||||
"""Base class for module interceptors"""
|
||||
|
||||
@abstractmethod
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Pre-process the request"""
|
||||
return request, context
|
||||
|
||||
@abstractmethod
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Post-process the response"""
|
||||
return response
|
||||
|
||||
@@ -203,14 +220,18 @@ class ModuleInterceptor(ABC):
|
||||
class AuthenticationInterceptor(ModuleInterceptor):
|
||||
"""Handles authentication for module requests"""
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
# Check if user is authenticated (context should contain user info from API auth)
|
||||
if not context.get("user_id") and not context.get("api_key_id"):
|
||||
raise AuthenticationError("Authentication required for module access")
|
||||
|
||||
return request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
return response
|
||||
|
||||
|
||||
@@ -220,23 +241,31 @@ class PermissionInterceptor(ModuleInterceptor):
|
||||
def __init__(self, module: BaseModule):
|
||||
self.module = module
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
action = request.get("action", "execute")
|
||||
user_permissions = context.get("user_permissions", [])
|
||||
|
||||
if not self.module.check_access(user_permissions, action):
|
||||
raise AuthenticationError(f"Insufficient permissions for module action: {action}")
|
||||
raise AuthenticationError(
|
||||
f"Insufficient permissions for module action: {action}"
|
||||
)
|
||||
|
||||
return request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
return response
|
||||
|
||||
|
||||
class ValidationInterceptor(ModuleInterceptor):
|
||||
"""Handles request validation and sanitization"""
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
# Sanitize request data
|
||||
sanitized_request = self._sanitize_request(request)
|
||||
|
||||
@@ -245,7 +274,9 @@ class ValidationInterceptor(ModuleInterceptor):
|
||||
|
||||
return sanitized_request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
# Sanitize response data
|
||||
return self._sanitize_response(response)
|
||||
|
||||
@@ -275,7 +306,9 @@ class ValidationInterceptor(ModuleInterceptor):
|
||||
max_length = 10000
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length]
|
||||
logger.warning(f"Truncated long string in request: {len(value)} chars")
|
||||
logger.warning(
|
||||
f"Truncated long string in request: {len(value)} chars"
|
||||
)
|
||||
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
@@ -302,7 +335,9 @@ class ValidationInterceptor(ModuleInterceptor):
|
||||
request_str = json.dumps(request)
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
if len(request_str.encode()) > max_size:
|
||||
raise ValidationError(f"Request size exceeds maximum allowed ({max_size} bytes)")
|
||||
raise ValidationError(
|
||||
f"Request size exceeds maximum allowed ({max_size} bytes)"
|
||||
)
|
||||
|
||||
|
||||
class MetricsInterceptor(ModuleInterceptor):
|
||||
@@ -311,11 +346,15 @@ class MetricsInterceptor(ModuleInterceptor):
|
||||
def __init__(self, module: BaseModule):
|
||||
self.module = module
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
context["_metrics_start_time"] = time.time()
|
||||
return request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
# Metrics are updated in the base module execute_with_interceptors method
|
||||
return response
|
||||
|
||||
@@ -323,13 +362,15 @@ class MetricsInterceptor(ModuleInterceptor):
|
||||
class SecurityInterceptor(ModuleInterceptor):
|
||||
"""Handles security-related processing"""
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
# Add security headers to context
|
||||
context["security_headers"] = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
}
|
||||
|
||||
# Check for suspicious patterns
|
||||
@@ -337,7 +378,9 @@ class SecurityInterceptor(ModuleInterceptor):
|
||||
|
||||
return request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
# Remove any sensitive information from response
|
||||
return self._remove_sensitive_data(response)
|
||||
|
||||
@@ -346,9 +389,18 @@ class SecurityInterceptor(ModuleInterceptor):
|
||||
request_str = json.dumps(request).lower()
|
||||
|
||||
suspicious_patterns = [
|
||||
"union select", "drop table", "insert into", "delete from",
|
||||
"script>", "javascript:", "eval(", "expression(",
|
||||
"../", "..\\", "file://", "ftp://",
|
||||
"union select",
|
||||
"drop table",
|
||||
"insert into",
|
||||
"delete from",
|
||||
"script>",
|
||||
"javascript:",
|
||||
"eval(",
|
||||
"expression(",
|
||||
"../",
|
||||
"..\\",
|
||||
"file://",
|
||||
"ftp://",
|
||||
]
|
||||
|
||||
for pattern in suspicious_patterns:
|
||||
@@ -363,7 +415,9 @@ class SecurityInterceptor(ModuleInterceptor):
|
||||
def clean_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "***REDACTED***" if any(sk in k.lower() for sk in sensitive_keys) else clean_dict(v)
|
||||
k: "***REDACTED***"
|
||||
if any(sk in k.lower() for sk in sensitive_keys)
|
||||
else clean_dict(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
@@ -380,25 +434,39 @@ class AuditInterceptor(ModuleInterceptor):
|
||||
def __init__(self, module: BaseModule):
|
||||
self.module = module
|
||||
|
||||
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
async def pre_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
context["_audit_start_time"] = time.time()
|
||||
context["_audit_request_hash"] = self._hash_request(request)
|
||||
return request, context
|
||||
|
||||
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def post_process(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
await self._log_audit_event(request, context, response, success=True)
|
||||
return response
|
||||
|
||||
async def handle_error(self, request: Dict[str, Any], context: Dict[str, Any], error: Exception):
|
||||
async def handle_error(
|
||||
self, request: Dict[str, Any], context: Dict[str, Any], error: Exception
|
||||
):
|
||||
"""Handle error logging"""
|
||||
await self._log_audit_event(request, context, {"error": str(error)}, success=False)
|
||||
await self._log_audit_event(
|
||||
request, context, {"error": str(error)}, success=False
|
||||
)
|
||||
|
||||
def _hash_request(self, request: Dict[str, Any]) -> str:
|
||||
"""Create a hash of the request for audit purposes"""
|
||||
request_str = json.dumps(request, sort_keys=True)
|
||||
return hashlib.sha256(request_str.encode()).hexdigest()[:16]
|
||||
|
||||
async def _log_audit_event(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any], success: bool):
|
||||
async def _log_audit_event(
|
||||
self,
|
||||
request: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
response: Dict[str, Any],
|
||||
success: bool,
|
||||
):
|
||||
"""Log audit event"""
|
||||
duration = time.time() - context.get("_audit_start_time", time.time())
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from sqlalchemy.orm import Session
|
||||
@dataclass
|
||||
class PluginContext:
|
||||
"""Plugin execution context with user and authentication info"""
|
||||
|
||||
user_id: Optional[str] = None
|
||||
api_key_id: Optional[str] = None
|
||||
user_permissions: List[str] = None
|
||||
@@ -45,15 +46,19 @@ class PlatformAPIClient:
|
||||
self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
|
||||
self.logger = get_logger(f"plugin.{plugin_id}.api_client")
|
||||
|
||||
async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||
async def _make_request(
|
||||
self, method: str, endpoint: str, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated request to platform API"""
|
||||
headers = kwargs.setdefault('headers', {})
|
||||
headers.update({
|
||||
'Authorization': f'Bearer {self.plugin_token}',
|
||||
'X-Plugin-ID': self.plugin_id,
|
||||
'X-Platform-Client': 'plugin',
|
||||
'Content-Type': 'application/json'
|
||||
})
|
||||
headers = kwargs.setdefault("headers", {})
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {self.plugin_token}",
|
||||
"X-Plugin-ID": self.plugin_id,
|
||||
"X-Platform-Client": "plugin",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
@@ -64,10 +69,10 @@ class PlatformAPIClient:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Platform API error: {error_text}"
|
||||
detail=f"Platform API error: {error_text}",
|
||||
)
|
||||
|
||||
if response.content_type == 'application/json':
|
||||
if response.content_type == "application/json":
|
||||
return await response.json()
|
||||
else:
|
||||
return {"data": await response.text()}
|
||||
@@ -75,73 +80,65 @@ class PlatformAPIClient:
|
||||
except aiohttp.ClientError as e:
|
||||
self.logger.error(f"Platform API client error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Platform API unavailable: {str(e)}"
|
||||
status_code=503, detail=f"Platform API unavailable: {str(e)}"
|
||||
)
|
||||
|
||||
async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||
"""GET request to platform API"""
|
||||
return await self._make_request('GET', endpoint, **kwargs)
|
||||
return await self._make_request("GET", endpoint, **kwargs)
|
||||
|
||||
async def post(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
||||
async def post(
|
||||
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""POST request to platform API"""
|
||||
if data:
|
||||
kwargs['json'] = data
|
||||
return await self._make_request('POST', endpoint, **kwargs)
|
||||
kwargs["json"] = data
|
||||
return await self._make_request("POST", endpoint, **kwargs)
|
||||
|
||||
async def put(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
||||
async def put(
|
||||
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""PUT request to platform API"""
|
||||
if data:
|
||||
kwargs['json'] = data
|
||||
return await self._make_request('PUT', endpoint, **kwargs)
|
||||
kwargs["json"] = data
|
||||
return await self._make_request("PUT", endpoint, **kwargs)
|
||||
|
||||
async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||
"""DELETE request to platform API"""
|
||||
return await self._make_request('DELETE', endpoint, **kwargs)
|
||||
return await self._make_request("DELETE", endpoint, **kwargs)
|
||||
|
||||
# Platform-specific API methods
|
||||
async def call_chatbot_api(self, chatbot_id: str, message: str,
|
||||
context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
async def call_chatbot_api(
|
||||
self, chatbot_id: str, message: str, context: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Consume platform chatbot API"""
|
||||
return await self.post(
|
||||
f"/api/v1/chatbot/external/{chatbot_id}/chat",
|
||||
{
|
||||
"message": message,
|
||||
"context": context or {}
|
||||
}
|
||||
{"message": message, "context": context or {}},
|
||||
)
|
||||
|
||||
async def call_llm_api(self, model: str, messages: List[Dict[str, Any]],
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
async def call_llm_api(
|
||||
self, model: str, messages: List[Dict[str, Any]], **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Consume platform LLM API"""
|
||||
return await self.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**kwargs
|
||||
}
|
||||
{"model": model, "messages": messages, **kwargs},
|
||||
)
|
||||
|
||||
async def search_rag(self, collection: str, query: str,
|
||||
top_k: int = 5) -> Dict[str, Any]:
|
||||
async def search_rag(
|
||||
self, collection: str, query: str, top_k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Consume platform RAG API"""
|
||||
return await self.post(
|
||||
f"/api/v1/rag/collections/{collection}/search",
|
||||
{
|
||||
"query": query,
|
||||
"top_k": top_k
|
||||
}
|
||||
{"query": query, "top_k": top_k},
|
||||
)
|
||||
|
||||
async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
|
||||
"""Generate embeddings via platform API"""
|
||||
return await self.post(
|
||||
"/api/v1/llm/embeddings",
|
||||
{
|
||||
"model": model,
|
||||
"input": input_text
|
||||
}
|
||||
"/api/v1/llm/embeddings", {"model": model, "input": input_text}
|
||||
)
|
||||
|
||||
|
||||
@@ -157,13 +154,14 @@ class PluginConfigManager:
|
||||
try:
|
||||
# Use dependency injection to get database session
|
||||
from app.db.database import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Query for active configuration
|
||||
query = db.query(PluginConfiguration).filter(
|
||||
PluginConfiguration.plugin_id == self.plugin_id,
|
||||
PluginConfiguration.is_active == True
|
||||
PluginConfiguration.is_active == True,
|
||||
)
|
||||
|
||||
if user_id:
|
||||
@@ -176,10 +174,14 @@ class PluginConfigManager:
|
||||
config = query.first()
|
||||
|
||||
if config:
|
||||
self.logger.debug(f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}")
|
||||
self.logger.debug(
|
||||
f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}"
|
||||
)
|
||||
return config.config_data or {}
|
||||
else:
|
||||
self.logger.debug(f"No configuration found for plugin {self.plugin_id}, user {user_id}")
|
||||
self.logger.debug(
|
||||
f"No configuration found for plugin {self.plugin_id}, user {user_id}"
|
||||
)
|
||||
return {}
|
||||
|
||||
finally:
|
||||
@@ -189,21 +191,30 @@ class PluginConfigManager:
|
||||
self.logger.error(f"Failed to get configuration: {e}")
|
||||
return {}
|
||||
|
||||
async def save_config(self, config: Dict[str, Any], user_id: str,
|
||||
async def save_config(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
user_id: str,
|
||||
name: str = "Default Configuration",
|
||||
description: str = None) -> bool:
|
||||
description: str = None,
|
||||
) -> bool:
|
||||
"""Save plugin configuration for user"""
|
||||
try:
|
||||
from app.db.database import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Check if configuration already exists
|
||||
existing_config = db.query(PluginConfiguration).filter(
|
||||
existing_config = (
|
||||
db.query(PluginConfiguration)
|
||||
.filter(
|
||||
PluginConfiguration.plugin_id == self.plugin_id,
|
||||
PluginConfiguration.user_id == user_id,
|
||||
PluginConfiguration.name == name
|
||||
).first()
|
||||
PluginConfiguration.name == name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_config:
|
||||
# Update existing configuration
|
||||
@@ -211,7 +222,9 @@ class PluginConfigManager:
|
||||
existing_config.description = description
|
||||
existing_config.is_active = True
|
||||
|
||||
self.logger.info(f"Updated configuration for plugin {self.plugin_id}, user {user_id}")
|
||||
self.logger.info(
|
||||
f"Updated configuration for plugin {self.plugin_id}, user {user_id}"
|
||||
)
|
||||
else:
|
||||
# Create new configuration
|
||||
new_config = PluginConfiguration(
|
||||
@@ -222,20 +235,26 @@ class PluginConfigManager:
|
||||
config_data=config,
|
||||
is_active=True,
|
||||
is_default=(name == "Default Configuration"),
|
||||
created_by_user_id=user_id
|
||||
created_by_user_id=user_id,
|
||||
)
|
||||
|
||||
# If this is the first configuration for this user/plugin, make it default
|
||||
existing_count = db.query(PluginConfiguration).filter(
|
||||
existing_count = (
|
||||
db.query(PluginConfiguration)
|
||||
.filter(
|
||||
PluginConfiguration.plugin_id == self.plugin_id,
|
||||
PluginConfiguration.user_id == user_id
|
||||
).count()
|
||||
PluginConfiguration.user_id == user_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if existing_count == 0:
|
||||
new_config.is_default = True
|
||||
|
||||
db.add(new_config)
|
||||
self.logger.info(f"Created new configuration for plugin {self.plugin_id}, user {user_id}")
|
||||
self.logger.info(
|
||||
f"Created new configuration for plugin {self.plugin_id}, user {user_id}"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
@@ -251,11 +270,13 @@ class PluginConfigManager:
|
||||
self.logger.error(f"Failed to save configuration: {e}")
|
||||
return False
|
||||
|
||||
async def validate_config(self, config: Dict[str, Any],
|
||||
schema: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
async def validate_config(
|
||||
self, config: Dict[str, Any], schema: Dict[str, Any]
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Validate configuration against JSON schema"""
|
||||
try:
|
||||
import jsonschema
|
||||
|
||||
jsonschema.validate(config, schema)
|
||||
return True, []
|
||||
except jsonschema.ValidationError as e:
|
||||
@@ -273,20 +294,27 @@ class PluginLogger:
|
||||
|
||||
# Sensitive data patterns to filter
|
||||
self.sensitive_patterns = [
|
||||
r'password', r'token', r'key', r'secret', r'api_key',
|
||||
r'bearer', r'authorization', r'credential'
|
||||
r"password",
|
||||
r"token",
|
||||
r"key",
|
||||
r"secret",
|
||||
r"api_key",
|
||||
r"bearer",
|
||||
r"authorization",
|
||||
r"credential",
|
||||
]
|
||||
|
||||
def _filter_sensitive_data(self, message: str) -> str:
|
||||
"""Filter sensitive data from log messages"""
|
||||
import re
|
||||
|
||||
filtered_message = message
|
||||
for pattern in self.sensitive_patterns:
|
||||
filtered_message = re.sub(
|
||||
f'{pattern}[=:]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
f'{pattern}=***REDACTED***',
|
||||
f"{pattern}[=:]\s*[\"']?([^\"'\\s]+)[\"']?",
|
||||
f"{pattern}=***REDACTED***",
|
||||
filtered_message,
|
||||
flags=re.IGNORECASE
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return filtered_message
|
||||
|
||||
@@ -360,7 +388,7 @@ class BasePlugin(ABC):
|
||||
"request_count": self._request_count,
|
||||
"error_count": self._error_count,
|
||||
"error_rate": round(error_rate, 3),
|
||||
"initialized": self.initialized
|
||||
"initialized": self.initialized,
|
||||
}
|
||||
|
||||
async def get_configuration_schema(self) -> Dict[str, Any]:
|
||||
@@ -404,16 +432,19 @@ class BasePlugin(ABC):
|
||||
|
||||
def get_auth_context(self) -> PluginContext:
|
||||
"""Dependency to get authentication context in API endpoints"""
|
||||
|
||||
async def _get_context(request: Request) -> PluginContext:
|
||||
# Extract authentication info from request
|
||||
# This would be populated by the plugin API gateway
|
||||
return PluginContext(
|
||||
user_id=request.headers.get('X-User-ID'),
|
||||
api_key_id=request.headers.get('X-API-Key-ID'),
|
||||
user_permissions=request.headers.get('X-User-Permissions', '').split(','),
|
||||
ip_address=request.headers.get('X-Real-IP'),
|
||||
user_agent=request.headers.get('User-Agent'),
|
||||
request_id=request.headers.get('X-Request-ID')
|
||||
user_id=request.headers.get("X-User-ID"),
|
||||
api_key_id=request.headers.get("X-API-Key-ID"),
|
||||
user_permissions=request.headers.get("X-User-Permissions", "").split(
|
||||
","
|
||||
),
|
||||
ip_address=request.headers.get("X-Real-IP"),
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
request_id=request.headers.get("X-Request-ID"),
|
||||
)
|
||||
|
||||
return Depends(_get_context)
|
||||
@@ -430,28 +461,54 @@ class PluginSecurityManager:
|
||||
|
||||
BLOCKED_IMPORTS = {
|
||||
# Core platform modules
|
||||
'app.db', 'app.models', 'app.core', 'app.services',
|
||||
'sqlalchemy', 'alembic',
|
||||
|
||||
"app.db",
|
||||
"app.models",
|
||||
"app.core",
|
||||
"app.services",
|
||||
"sqlalchemy",
|
||||
"alembic",
|
||||
# Security sensitive
|
||||
'subprocess', 'eval', 'exec', 'compile', '__import__',
|
||||
'os.system', 'os.popen', 'os.spawn',
|
||||
|
||||
"subprocess",
|
||||
"eval",
|
||||
"exec",
|
||||
"compile",
|
||||
"__import__",
|
||||
"os.system",
|
||||
"os.popen",
|
||||
"os.spawn",
|
||||
# System access
|
||||
'socket', 'multiprocessing', 'threading'
|
||||
"socket",
|
||||
"multiprocessing",
|
||||
"threading",
|
||||
}
|
||||
|
||||
ALLOWED_IMPORTS = {
|
||||
# Standard library
|
||||
'asyncio', 'aiohttp', 'json', 'datetime', 'typing', 'pydantic',
|
||||
'logging', 'time', 'uuid', 'hashlib', 'base64', 'pathlib',
|
||||
're', 'urllib.parse', 'dataclasses', 'enum',
|
||||
|
||||
"asyncio",
|
||||
"aiohttp",
|
||||
"json",
|
||||
"datetime",
|
||||
"typing",
|
||||
"pydantic",
|
||||
"logging",
|
||||
"time",
|
||||
"uuid",
|
||||
"hashlib",
|
||||
"base64",
|
||||
"pathlib",
|
||||
"re",
|
||||
"urllib.parse",
|
||||
"dataclasses",
|
||||
"enum",
|
||||
# Approved third-party
|
||||
'httpx', 'requests', 'pandas', 'numpy', 'yaml',
|
||||
|
||||
"httpx",
|
||||
"requests",
|
||||
"pandas",
|
||||
"numpy",
|
||||
"yaml",
|
||||
# Plugin framework
|
||||
'app.services.base_plugin', 'app.schemas.plugin_manifest'
|
||||
"app.services.base_plugin",
|
||||
"app.schemas.plugin_manifest",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -459,7 +516,9 @@ class PluginSecurityManager:
|
||||
"""Validate if plugin can import a module"""
|
||||
# Block dangerous imports
|
||||
if any(import_name.startswith(blocked) for blocked in cls.BLOCKED_IMPORTS):
|
||||
raise SecurityError(f"Import '{import_name}' not allowed in plugin environment")
|
||||
raise SecurityError(
|
||||
f"Import '{import_name}' not allowed in plugin environment"
|
||||
)
|
||||
|
||||
# Allow explicit safe imports
|
||||
if any(import_name.startswith(allowed) for allowed in cls.ALLOWED_IMPORTS):
|
||||
@@ -474,12 +533,12 @@ class PluginSecurityManager:
|
||||
def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
|
||||
"""Create isolated environment for plugin execution"""
|
||||
return {
|
||||
'max_memory_mb': 128,
|
||||
'max_cpu_percent': 25,
|
||||
'max_disk_mb': 100,
|
||||
'max_api_calls_per_minute': 100,
|
||||
'allowed_domains': [], # Will be populated from manifest
|
||||
'network_timeout_seconds': 30
|
||||
"max_memory_mb": 128,
|
||||
"max_cpu_percent": 25,
|
||||
"max_disk_mb": 100,
|
||||
"max_api_calls_per_minute": 100,
|
||||
"allowed_domains": [], # Will be populated from manifest
|
||||
"network_timeout_seconds": 30,
|
||||
}
|
||||
|
||||
|
||||
@@ -499,7 +558,9 @@ class PluginLoader:
|
||||
validation_result = validate_manifest_file(manifest_path)
|
||||
|
||||
if not validation_result["valid"]:
|
||||
raise ValidationError(f"Invalid plugin manifest: {validation_result['errors']}")
|
||||
raise ValidationError(
|
||||
f"Invalid plugin manifest: {validation_result['errors']}"
|
||||
)
|
||||
|
||||
manifest = validation_result["manifest"]
|
||||
|
||||
@@ -511,8 +572,7 @@ class PluginLoader:
|
||||
# Load plugin module
|
||||
main_py_path = plugin_dir / "main.py"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"plugin_{manifest.metadata.name}",
|
||||
main_py_path
|
||||
f"plugin_{manifest.metadata.name}", main_py_path
|
||||
)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
@@ -536,14 +596,18 @@ class PluginLoader:
|
||||
plugin_class = None
|
||||
for attr_name in dir(plugin_module):
|
||||
attr = getattr(plugin_module, attr_name)
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BasePlugin) and
|
||||
attr is not BasePlugin):
|
||||
if (
|
||||
isinstance(attr, type)
|
||||
and issubclass(attr, BasePlugin)
|
||||
and attr is not BasePlugin
|
||||
):
|
||||
plugin_class = attr
|
||||
break
|
||||
|
||||
if not plugin_class:
|
||||
raise ValidationError("Plugin must contain a class inheriting from BasePlugin")
|
||||
raise ValidationError(
|
||||
"Plugin must contain a class inheriting from BasePlugin"
|
||||
)
|
||||
|
||||
# Instantiate plugin
|
||||
plugin_instance = plugin_class(manifest, plugin_token)
|
||||
@@ -562,21 +626,30 @@ class PluginLoader:
|
||||
|
||||
def _validate_plugin_security(self, main_py_path: Path):
|
||||
"""Validate plugin code for security issues"""
|
||||
with open(main_py_path, 'r', encoding='utf-8') as f:
|
||||
with open(main_py_path, "r", encoding="utf-8") as f:
|
||||
code_content = f.read()
|
||||
|
||||
# Check for dangerous patterns
|
||||
dangerous_patterns = [
|
||||
'eval(', 'exec(', 'compile(',
|
||||
'subprocess.', 'os.system', 'os.popen',
|
||||
'__import__', 'importlib.import_module',
|
||||
'from app.db', 'from app.models',
|
||||
'sqlalchemy', 'SessionLocal'
|
||||
"eval(",
|
||||
"exec(",
|
||||
"compile(",
|
||||
"subprocess.",
|
||||
"os.system",
|
||||
"os.popen",
|
||||
"__import__",
|
||||
"importlib.import_module",
|
||||
"from app.db",
|
||||
"from app.models",
|
||||
"sqlalchemy",
|
||||
"SessionLocal",
|
||||
]
|
||||
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in code_content:
|
||||
raise SecurityError(f"Dangerous pattern detected in plugin code: {pattern}")
|
||||
raise SecurityError(
|
||||
f"Dangerous pattern detected in plugin code: {pattern}"
|
||||
)
|
||||
|
||||
async def unload_plugin(self, plugin_id: str) -> bool:
|
||||
"""Unload a plugin and cleanup resources"""
|
||||
|
||||
@@ -21,11 +21,13 @@ logger = get_logger(__name__)
|
||||
|
||||
class BudgetEnforcementError(Exception):
|
||||
"""Custom exception for budget enforcement failures"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BudgetExceededError(BudgetEnforcementError):
|
||||
"""Exception raised when budget would be exceeded"""
|
||||
|
||||
def __init__(self, message: str, budget: Budget, requested_cost: int):
|
||||
super().__init__(message)
|
||||
self.budget = budget
|
||||
@@ -34,6 +36,7 @@ class BudgetExceededError(BudgetEnforcementError):
|
||||
|
||||
class BudgetWarningError(BudgetEnforcementError):
|
||||
"""Exception raised when budget warning threshold is reached"""
|
||||
|
||||
def __init__(self, message: str, budget: Budget, requested_cost: int):
|
||||
super().__init__(message)
|
||||
self.budget = budget
|
||||
@@ -42,6 +45,7 @@ class BudgetWarningError(BudgetEnforcementError):
|
||||
|
||||
class BudgetConcurrencyError(BudgetEnforcementError):
|
||||
"""Exception raised when budget update fails due to concurrency"""
|
||||
|
||||
def __init__(self, message: str, retry_count: int = 0):
|
||||
super().__init__(message)
|
||||
self.retry_count = retry_count
|
||||
@@ -49,6 +53,7 @@ class BudgetConcurrencyError(BudgetEnforcementError):
|
||||
|
||||
class BudgetAtomicError(BudgetEnforcementError):
|
||||
"""Exception raised when atomic budget operation fails"""
|
||||
|
||||
def __init__(self, message: str, budget_id: int, requested_amount: int):
|
||||
super().__init__(message)
|
||||
self.budget_id = budget_id
|
||||
@@ -68,7 +73,7 @@ class BudgetEnforcementService:
|
||||
api_key: APIKey,
|
||||
model_name: str,
|
||||
estimated_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
Atomically check budget compliance and reserve spending
|
||||
@@ -86,16 +91,27 @@ class BudgetEnforcementService:
|
||||
# Try atomic reservation with retries
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
|
||||
return self._attempt_atomic_reservation(
|
||||
budgets, estimated_cost, api_key.id, attempt
|
||||
)
|
||||
except BudgetConcurrencyError as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
|
||||
return False, f"Budget check temporarily unavailable (concurrency limit)", [], []
|
||||
logger.error(
|
||||
f"Atomic budget reservation failed after {self.max_retries} attempts: {e}"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"Budget check temporarily unavailable (concurrency limit)",
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
# Exponential backoff with jitter
|
||||
delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
|
||||
delay = self.retry_delay_base * (2**attempt) + random.uniform(0, 0.1)
|
||||
time.sleep(delay)
|
||||
logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
|
||||
logger.info(
|
||||
f"Retrying atomic budget reservation (attempt {attempt + 2})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in atomic budget reservation: {e}")
|
||||
return False, f"Budget check failed: {str(e)}", [], []
|
||||
@@ -103,11 +119,7 @@ class BudgetEnforcementService:
|
||||
return False, "Budget check failed after maximum retries", [], []
|
||||
|
||||
def _attempt_atomic_reservation(
|
||||
self,
|
||||
budgets: List[Budget],
|
||||
estimated_cost: int,
|
||||
api_key_id: int,
|
||||
attempt: int
|
||||
self, budgets: List[Budget], estimated_cost: int, api_key_id: int, attempt: int
|
||||
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
|
||||
"""Attempt to atomically reserve budget across all applicable budgets"""
|
||||
warnings = []
|
||||
@@ -119,12 +131,17 @@ class BudgetEnforcementService:
|
||||
|
||||
for budget in budgets:
|
||||
# Lock budget row for update to prevent concurrent modifications
|
||||
locked_budget = self.db.query(Budget).filter(
|
||||
Budget.id == budget.id
|
||||
).with_for_update().first()
|
||||
locked_budget = (
|
||||
self.db.query(Budget)
|
||||
.filter(Budget.id == budget.id)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not locked_budget:
|
||||
raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)
|
||||
raise BudgetAtomicError(
|
||||
f"Budget {budget.id} not found", budget.id, estimated_cost
|
||||
)
|
||||
|
||||
# Reset budget if expired and auto-renew enabled
|
||||
if locked_budget.is_expired() and locked_budget.auto_renew:
|
||||
@@ -144,28 +161,42 @@ class BudgetEnforcementService:
|
||||
f"Requested: ${estimated_cost/100:.4f}, "
|
||||
f"Remaining: ${(locked_budget.limit_cents - locked_budget.current_usage_cents)/100:.2f}"
|
||||
)
|
||||
logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
|
||||
logger.warning(
|
||||
f"Budget exceeded for API key {api_key_id}: {error_msg}"
|
||||
)
|
||||
self.db.rollback()
|
||||
return False, error_msg, warnings, []
|
||||
|
||||
# Check warning threshold
|
||||
if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
|
||||
if (
|
||||
locked_budget.would_exceed_warning(estimated_cost)
|
||||
and not locked_budget.is_warning_sent
|
||||
):
|
||||
warning_msg = (
|
||||
f"Budget '{locked_budget.name}' approaching limit. "
|
||||
f"Usage will be ${(locked_budget.current_usage_cents + estimated_cost)/100:.2f} "
|
||||
f"of ${locked_budget.limit_cents/100:.2f} "
|
||||
f"({((locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100):.1f}%)"
|
||||
)
|
||||
warnings.append({
|
||||
warnings.append(
|
||||
{
|
||||
"type": "budget_warning",
|
||||
"budget_id": locked_budget.id,
|
||||
"budget_name": locked_budget.name,
|
||||
"message": warning_msg,
|
||||
"current_usage_cents": locked_budget.current_usage_cents + estimated_cost,
|
||||
"current_usage_cents": locked_budget.current_usage_cents
|
||||
+ estimated_cost,
|
||||
"limit_cents": locked_budget.limit_cents,
|
||||
"usage_percentage": (locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100
|
||||
})
|
||||
logger.info(f"Budget warning for API key {api_key_id}: {warning_msg}")
|
||||
"usage_percentage": (
|
||||
locked_budget.current_usage_cents + estimated_cost
|
||||
)
|
||||
/ locked_budget.limit_cents
|
||||
* 100,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Budget warning for API key {api_key_id}: {warning_msg}"
|
||||
)
|
||||
|
||||
# Reserve the budget (temporarily add estimated cost)
|
||||
self._atomic_reserve_usage(locked_budget, estimated_cost)
|
||||
@@ -173,12 +204,16 @@ class BudgetEnforcementService:
|
||||
|
||||
# Commit the reservation
|
||||
self.db.commit()
|
||||
logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
|
||||
logger.debug(
|
||||
f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}"
|
||||
)
|
||||
return True, None, warnings, reserved_budget_ids
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt)
|
||||
raise BudgetConcurrencyError(
|
||||
f"Database integrity error during budget reservation: {e}", attempt
|
||||
)
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error in atomic budget reservation: {e}")
|
||||
@@ -203,24 +238,35 @@ class BudgetEnforcementService:
|
||||
.values(
|
||||
current_usage_cents=Budget.current_usage_cents + amount_cents,
|
||||
updated_at=datetime.utcnow(),
|
||||
is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
|
||||
is_exceeded=Budget.current_usage_cents + amount_cents
|
||||
>= Budget.limit_cents,
|
||||
is_warning_sent=(
|
||||
Budget.is_warning_sent |
|
||||
((Budget.warning_threshold_cents.isnot(None)) &
|
||||
(Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
|
||||
Budget.is_warning_sent
|
||||
| (
|
||||
(Budget.warning_threshold_cents.isnot(None))
|
||||
& (
|
||||
Budget.current_usage_cents + amount_cents
|
||||
>= Budget.warning_threshold_cents
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if result.rowcount != 1:
|
||||
raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)
|
||||
raise BudgetAtomicError(
|
||||
f"Failed to update budget {budget.id}", budget.id, amount_cents
|
||||
)
|
||||
|
||||
# Update the in-memory object to reflect changes
|
||||
budget.current_usage_cents += amount_cents
|
||||
budget.updated_at = datetime.utcnow()
|
||||
if budget.current_usage_cents >= budget.limit_cents:
|
||||
budget.is_exceeded = True
|
||||
if budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
|
||||
if (
|
||||
budget.warning_threshold_cents
|
||||
and budget.current_usage_cents >= budget.warning_threshold_cents
|
||||
):
|
||||
budget.is_warning_sent = True
|
||||
|
||||
def atomic_finalize_usage(
|
||||
@@ -230,7 +276,7 @@ class BudgetEnforcementService:
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> List[Budget]:
|
||||
"""
|
||||
Finalize actual usage and adjust reservations
|
||||
@@ -250,7 +296,9 @@ class BudgetEnforcementService:
|
||||
return []
|
||||
|
||||
try:
|
||||
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
|
||||
actual_cost = CostCalculator.calculate_cost_cents(
|
||||
model_name, input_tokens, output_tokens
|
||||
)
|
||||
updated_budgets = []
|
||||
|
||||
# Begin transaction for finalization
|
||||
@@ -258,9 +306,12 @@ class BudgetEnforcementService:
|
||||
|
||||
for budget_id in reserved_budget_ids:
|
||||
# Lock budget for update
|
||||
budget = self.db.query(Budget).filter(
|
||||
Budget.id == budget_id
|
||||
).with_for_update().first()
|
||||
budget = (
|
||||
self.db.query(Budget)
|
||||
.filter(Budget.id == budget_id)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not budget:
|
||||
logger.warning(f"Budget {budget_id} not found during finalization")
|
||||
@@ -270,7 +321,9 @@ class BudgetEnforcementService:
|
||||
# Calculate adjustment (actual cost - estimated cost already reserved)
|
||||
# Note: We don't know the exact estimated cost that was reserved
|
||||
# So we'll just set to actual cost (this is safe as we already reserved)
|
||||
self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
|
||||
self._atomic_set_actual_usage(
|
||||
budget, actual_cost, input_tokens, output_tokens
|
||||
)
|
||||
updated_budgets.append(budget)
|
||||
|
||||
logger.debug(
|
||||
@@ -287,7 +340,9 @@ class BudgetEnforcementService:
|
||||
self.db.rollback()
|
||||
return []
|
||||
|
||||
def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):
|
||||
def _atomic_set_actual_usage(
|
||||
self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int
|
||||
):
|
||||
"""Set the actual usage cost (replacing any reservation)"""
|
||||
# For simplicity, we'll just ensure the current usage reflects actual cost
|
||||
# In a more sophisticated system, you might track reservations separately
|
||||
@@ -300,7 +355,7 @@ class BudgetEnforcementService:
|
||||
api_key: APIKey,
|
||||
model_name: str,
|
||||
estimated_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Check if a request complies with budget limits
|
||||
@@ -346,27 +401,41 @@ class BudgetEnforcementService:
|
||||
f"Requested: ${estimated_cost/100:.4f}, "
|
||||
f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
|
||||
)
|
||||
logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
|
||||
logger.warning(
|
||||
f"Budget exceeded for API key {api_key.id}: {error_msg}"
|
||||
)
|
||||
return False, error_msg, warnings
|
||||
|
||||
# Check if request would trigger warning
|
||||
if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
|
||||
if (
|
||||
budget.would_exceed_warning(estimated_cost)
|
||||
and not budget.is_warning_sent
|
||||
):
|
||||
warning_msg = (
|
||||
f"Budget '{budget.name}' approaching limit. "
|
||||
f"Usage will be ${(budget.current_usage_cents + estimated_cost)/100:.2f} "
|
||||
f"of ${budget.limit_cents/100:.2f} "
|
||||
f"({((budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100):.1f}%)"
|
||||
)
|
||||
warnings.append({
|
||||
warnings.append(
|
||||
{
|
||||
"type": "budget_warning",
|
||||
"budget_id": budget.id,
|
||||
"budget_name": budget.name,
|
||||
"message": warning_msg,
|
||||
"current_usage_cents": budget.current_usage_cents + estimated_cost,
|
||||
"current_usage_cents": budget.current_usage_cents
|
||||
+ estimated_cost,
|
||||
"limit_cents": budget.limit_cents,
|
||||
"usage_percentage": (budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100
|
||||
})
|
||||
logger.info(f"Budget warning for API key {api_key.id}: {warning_msg}")
|
||||
"usage_percentage": (
|
||||
budget.current_usage_cents + estimated_cost
|
||||
)
|
||||
/ budget.limit_cents
|
||||
* 100,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Budget warning for API key {api_key.id}: {warning_msg}"
|
||||
)
|
||||
|
||||
return True, None, warnings
|
||||
|
||||
@@ -381,7 +450,7 @@ class BudgetEnforcementService:
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> List[Budget]:
|
||||
"""
|
||||
Record actual usage against applicable budgets
|
||||
@@ -398,7 +467,9 @@ class BudgetEnforcementService:
|
||||
"""
|
||||
try:
|
||||
# Calculate actual cost
|
||||
actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
|
||||
actual_cost = CostCalculator.calculate_cost_cents(
|
||||
model_name, input_tokens, output_tokens
|
||||
)
|
||||
|
||||
# Get applicable budgets
|
||||
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
|
||||
@@ -427,10 +498,7 @@ class BudgetEnforcementService:
|
||||
return []
|
||||
|
||||
def _get_applicable_budgets(
|
||||
self,
|
||||
api_key: APIKey,
|
||||
model_name: str = None,
|
||||
endpoint: str = None
|
||||
self, api_key: APIKey, model_name: str = None, endpoint: str = None
|
||||
) -> List[Budget]:
|
||||
"""Get budgets that apply to the given request"""
|
||||
|
||||
@@ -438,9 +506,11 @@ class BudgetEnforcementService:
|
||||
conditions = [
|
||||
Budget.is_active == True,
|
||||
or_(
|
||||
and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)), # User budget
|
||||
Budget.api_key_id == api_key.id # API key specific budget
|
||||
)
|
||||
and_(
|
||||
Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)
|
||||
), # User budget
|
||||
Budget.api_key_id == api_key.id, # API key specific budget
|
||||
),
|
||||
]
|
||||
|
||||
# Query budgets
|
||||
@@ -492,7 +562,7 @@ class BudgetEnforcementService:
|
||||
"warning_budgets": 0,
|
||||
"total_limit_cents": 0,
|
||||
"total_usage_cents": 0,
|
||||
"budgets": []
|
||||
"budgets": [],
|
||||
}
|
||||
|
||||
for budget in budgets:
|
||||
@@ -500,12 +570,14 @@ class BudgetEnforcementService:
|
||||
continue
|
||||
|
||||
budget_info = budget.to_dict()
|
||||
budget_info.update({
|
||||
budget_info.update(
|
||||
{
|
||||
"is_expired": budget.is_expired(),
|
||||
"days_remaining": budget.get_period_days_remaining(),
|
||||
"daily_burn_rate": budget.get_daily_burn_rate(),
|
||||
"projected_spend": budget.get_projected_spend()
|
||||
})
|
||||
"projected_spend": budget.get_projected_spend(),
|
||||
}
|
||||
)
|
||||
|
||||
status["budgets"].append(budget_info)
|
||||
status["active_budgets"] += 1
|
||||
@@ -514,18 +586,25 @@ class BudgetEnforcementService:
|
||||
|
||||
if budget.is_exceeded:
|
||||
status["exceeded_budgets"] += 1
|
||||
elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
|
||||
elif (
|
||||
budget.warning_threshold_cents
|
||||
and budget.current_usage_cents >= budget.warning_threshold_cents
|
||||
):
|
||||
status["warning_budgets"] += 1
|
||||
|
||||
# Calculate overall percentages
|
||||
if status["total_limit_cents"] > 0:
|
||||
status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
|
||||
status["overall_usage_percentage"] = (
|
||||
status["total_usage_cents"] / status["total_limit_cents"]
|
||||
) * 100
|
||||
else:
|
||||
status["overall_usage_percentage"] = 0
|
||||
|
||||
status["total_limit_dollars"] = status["total_limit_cents"] / 100
|
||||
status["total_usage_dollars"] = status["total_usage_cents"] / 100
|
||||
status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
|
||||
status["total_remaining_cents"] = max(
|
||||
0, status["total_limit_cents"] - status["total_usage_cents"]
|
||||
)
|
||||
status["total_remaining_dollars"] = status["total_remaining_cents"] / 100
|
||||
|
||||
return status
|
||||
@@ -538,14 +617,11 @@ class BudgetEnforcementService:
|
||||
"active_budgets": 0,
|
||||
"exceeded_budgets": 0,
|
||||
"warning_budgets": 0,
|
||||
"budgets": []
|
||||
"budgets": [],
|
||||
}
|
||||
|
||||
def create_default_user_budget(
|
||||
self,
|
||||
user_id: int,
|
||||
limit_dollars: float = 10.0,
|
||||
period_type: str = "monthly"
|
||||
self, user_id: int, limit_dollars: float = 10.0, period_type: str = "monthly"
|
||||
) -> Budget:
|
||||
"""Create a default budget for a new user"""
|
||||
try:
|
||||
@@ -553,19 +629,21 @@ class BudgetEnforcementService:
|
||||
budget = Budget.create_monthly_budget(
|
||||
user_id=user_id,
|
||||
name="Default Monthly Budget",
|
||||
limit_dollars=limit_dollars
|
||||
limit_dollars=limit_dollars,
|
||||
)
|
||||
else:
|
||||
budget = Budget.create_daily_budget(
|
||||
user_id=user_id,
|
||||
name="Default Daily Budget",
|
||||
limit_dollars=limit_dollars
|
||||
limit_dollars=limit_dollars,
|
||||
)
|
||||
|
||||
self.db.add(budget)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")
|
||||
logger.info(
|
||||
f"Created default budget for user {user_id}: ${limit_dollars} {period_type}"
|
||||
)
|
||||
|
||||
return budget
|
||||
|
||||
@@ -577,13 +655,17 @@ class BudgetEnforcementService:
|
||||
def check_and_reset_expired_budgets(self):
|
||||
"""Background task to check and reset expired budgets"""
|
||||
try:
|
||||
expired_budgets = self.db.query(Budget).filter(
|
||||
expired_budgets = (
|
||||
self.db.query(Budget)
|
||||
.filter(
|
||||
and_(
|
||||
Budget.is_active == True,
|
||||
Budget.auto_renew == True,
|
||||
Budget.period_end < datetime.utcnow()
|
||||
Budget.period_end < datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
for budget in expired_budgets:
|
||||
self._reset_expired_budget(budget)
|
||||
@@ -596,17 +678,20 @@ class BudgetEnforcementService:
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
|
||||
def check_budget_for_request(
|
||||
db: Session,
|
||||
api_key: APIKey,
|
||||
model_name: str,
|
||||
estimated_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
|
||||
"""DEPRECATED: Convenience function to check budget compliance (race conditions possible)"""
|
||||
service = BudgetEnforcementService(db)
|
||||
return service.check_budget_compliance(api_key, model_name, estimated_tokens, endpoint)
|
||||
return service.check_budget_compliance(
|
||||
api_key, model_name, estimated_tokens, endpoint
|
||||
)
|
||||
|
||||
|
||||
def record_request_usage(
|
||||
@@ -615,11 +700,13 @@ def record_request_usage(
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> List[Budget]:
|
||||
"""DEPRECATED: Convenience function to record actual usage (race conditions possible)"""
|
||||
service = BudgetEnforcementService(db)
|
||||
return service.record_usage(api_key, model_name, input_tokens, output_tokens, endpoint)
|
||||
return service.record_usage(
|
||||
api_key, model_name, input_tokens, output_tokens, endpoint
|
||||
)
|
||||
|
||||
|
||||
# ATOMIC VERSIONS: Race-condition-free budget enforcement
|
||||
@@ -628,11 +715,13 @@ def atomic_check_and_reserve_budget(
|
||||
api_key: APIKey,
|
||||
model_name: str,
|
||||
estimated_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
|
||||
"""Atomic convenience function to check budget compliance and reserve spending"""
|
||||
service = BudgetEnforcementService(db)
|
||||
return service.atomic_check_and_reserve_budget(api_key, model_name, estimated_tokens, endpoint)
|
||||
return service.atomic_check_and_reserve_budget(
|
||||
api_key, model_name, estimated_tokens, endpoint
|
||||
)
|
||||
|
||||
|
||||
def atomic_finalize_usage(
|
||||
@@ -642,8 +731,10 @@ def atomic_finalize_usage(
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
endpoint: str = None
|
||||
endpoint: str = None,
|
||||
) -> List[Budget]:
|
||||
"""Atomic convenience function to finalize actual usage after request completion"""
|
||||
service = BudgetEnforcementService(db)
|
||||
return service.atomic_finalize_usage(reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint)
|
||||
return service.atomic_finalize_usage(
|
||||
reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint
|
||||
)
|
||||
|
||||
@@ -31,9 +31,13 @@ class CachedAPIKeyService:
|
||||
|
||||
async def close(self):
|
||||
"""Close method for compatibility - core cache handles its own lifecycle"""
|
||||
logger.info("Cached API key service close called - core cache handles lifecycle")
|
||||
logger.info(
|
||||
"Cached API key service close called - core cache handles lifecycle"
|
||||
)
|
||||
|
||||
async def get_cached_api_key(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
|
||||
async def get_cached_api_key(
|
||||
self, key_prefix: str, db: AsyncSession
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get API key data from cache or database
|
||||
Returns: Dictionary with api_key, user, and api_key_id
|
||||
@@ -57,19 +61,18 @@ class CachedAPIKeyService:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"user": user,
|
||||
"api_key_id": api_key_data.get("id")
|
||||
"api_key_id": api_key_data.get("id"),
|
||||
}
|
||||
|
||||
logger.debug(f"API key cache miss for prefix: {key_prefix}, fetching from database")
|
||||
logger.debug(
|
||||
f"API key cache miss for prefix: {key_prefix}, fetching from database"
|
||||
)
|
||||
|
||||
# Cache miss - fetch from database with optimized query
|
||||
stmt = (
|
||||
select(APIKey, User)
|
||||
.join(User, APIKey.user_id == User.id)
|
||||
.options(
|
||||
joinedload(APIKey.user),
|
||||
joinedload(User.api_keys)
|
||||
)
|
||||
.options(joinedload(APIKey.user), joinedload(User.api_keys))
|
||||
.where(APIKey.key_prefix == key_prefix)
|
||||
.where(APIKey.is_active == True)
|
||||
)
|
||||
@@ -86,11 +89,7 @@ class CachedAPIKeyService:
|
||||
# Cache for future requests
|
||||
await self._cache_api_key_data(key_prefix, api_key, user)
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"user": user,
|
||||
"api_key_id": api_key.id
|
||||
}
|
||||
return {"api_key": api_key, "user": user, "api_key_id": api_key.id}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving API key for prefix {key_prefix}: {e}")
|
||||
@@ -118,17 +117,25 @@ class CachedAPIKeyService:
|
||||
"allowed_ips": api_key.allowed_ips,
|
||||
"description": api_key.description,
|
||||
"tags": api_key.tags,
|
||||
"created_at": api_key.created_at.isoformat() if api_key.created_at else None,
|
||||
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
|
||||
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"created_at": api_key.created_at.isoformat()
|
||||
if api_key.created_at
|
||||
else None,
|
||||
"updated_at": api_key.updated_at.isoformat()
|
||||
if api_key.updated_at
|
||||
else None,
|
||||
"last_used_at": api_key.last_used_at.isoformat()
|
||||
if api_key.last_used_at
|
||||
else None,
|
||||
"expires_at": api_key.expires_at.isoformat()
|
||||
if api_key.expires_at
|
||||
else None,
|
||||
"total_requests": api_key.total_requests,
|
||||
"total_tokens": api_key.total_tokens,
|
||||
"total_cost": api_key.total_cost,
|
||||
"is_unlimited": api_key.is_unlimited,
|
||||
"budget_limit_cents": api_key.budget_limit_cents,
|
||||
"budget_type": api_key.budget_type,
|
||||
"allowed_chatbots": api_key.allowed_chatbots
|
||||
"allowed_chatbots": api_key.allowed_chatbots,
|
||||
},
|
||||
"user_data": {
|
||||
"id": user.id,
|
||||
@@ -137,11 +144,17 @@ class CachedAPIKeyService:
|
||||
"is_active": user.is_active,
|
||||
"is_superuser": user.is_superuser,
|
||||
"role": user.role,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
|
||||
"last_login": user.last_login.isoformat() if user.last_login else None
|
||||
"created_at": user.created_at.isoformat()
|
||||
if user.created_at
|
||||
else None,
|
||||
"updated_at": user.updated_at.isoformat()
|
||||
if user.updated_at
|
||||
else None,
|
||||
"last_login": user.last_login.isoformat()
|
||||
if user.last_login
|
||||
else None,
|
||||
},
|
||||
"cached_at": datetime.utcnow().isoformat()
|
||||
"cached_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
await core_cache.cache_api_key(key_prefix, cache_data, self.cache_ttl)
|
||||
@@ -150,7 +163,9 @@ class CachedAPIKeyService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching API key data for prefix {key_prefix}: {e}")
|
||||
|
||||
async def verify_api_key_cached(self, api_key: str, key_prefix: str) -> Optional[bool]:
|
||||
async def verify_api_key_cached(
|
||||
self, api_key: str, key_prefix: str
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Verify API key using cached hash to avoid expensive bcrypt operations
|
||||
Returns: True if verified, False if invalid, None if not cached
|
||||
@@ -161,25 +176,39 @@ class CachedAPIKeyService:
|
||||
|
||||
if cached_verification:
|
||||
# Check if cache is still valid (within TTL)
|
||||
cached_timestamp = datetime.fromisoformat(cached_verification["timestamp"])
|
||||
if datetime.utcnow() - cached_timestamp < timedelta(seconds=self.verification_cache_ttl):
|
||||
logger.debug(f"API key verification cache hit for prefix: {key_prefix}")
|
||||
cached_timestamp = datetime.fromisoformat(
|
||||
cached_verification["timestamp"]
|
||||
)
|
||||
if datetime.utcnow() - cached_timestamp < timedelta(
|
||||
seconds=self.verification_cache_ttl
|
||||
):
|
||||
logger.debug(
|
||||
f"API key verification cache hit for prefix: {key_prefix}"
|
||||
)
|
||||
return cached_verification.get("is_valid", False)
|
||||
|
||||
return None # Not cached or expired
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking verification cache for prefix {key_prefix}: {e}")
|
||||
logger.error(
|
||||
f"Error checking verification cache for prefix {key_prefix}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool):
|
||||
async def cache_verification_result(
|
||||
self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool
|
||||
):
|
||||
"""Cache API key verification result to avoid expensive bcrypt operations"""
|
||||
try:
|
||||
await core_cache.cache_verification_result(api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl)
|
||||
await core_cache.cache_verification_result(
|
||||
api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl
|
||||
)
|
||||
logger.debug(f"Cached verification result for prefix: {key_prefix}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching verification result for prefix {key_prefix}: {e}")
|
||||
logger.error(
|
||||
f"Error caching verification result for prefix {key_prefix}: {e}"
|
||||
)
|
||||
|
||||
async def invalidate_api_key_cache(self, key_prefix: str):
|
||||
"""Invalidate cached API key data"""
|
||||
@@ -187,7 +216,9 @@ class CachedAPIKeyService:
|
||||
await core_cache.invalidate_api_key(key_prefix)
|
||||
|
||||
# Also invalidate verification cache
|
||||
verification_keys = await core_cache.clear_pattern(f"verify:{key_prefix}*", prefix="auth")
|
||||
verification_keys = await core_cache.clear_pattern(
|
||||
f"verify:{key_prefix}*", prefix="auth"
|
||||
)
|
||||
|
||||
logger.debug(f"Invalidated cache for API key prefix: {key_prefix}")
|
||||
|
||||
@@ -206,10 +237,7 @@ class CachedAPIKeyService:
|
||||
return # Skip update if recent
|
||||
|
||||
# Update database
|
||||
stmt = (
|
||||
select(APIKey)
|
||||
.where(APIKey.id == api_key_id)
|
||||
)
|
||||
stmt = select(APIKey).where(APIKey.id == api_key_id)
|
||||
result = await db.execute(stmt)
|
||||
api_key = result.scalar_one_or_none()
|
||||
|
||||
@@ -218,7 +246,9 @@ class CachedAPIKeyService:
|
||||
await db.commit()
|
||||
|
||||
# Cache that we updated to prevent spam
|
||||
await core_cache.set(cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf")
|
||||
await core_cache.set(
|
||||
cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf"
|
||||
)
|
||||
|
||||
logger.debug(f"Updated last_used_at for API key {api_key_id}")
|
||||
|
||||
@@ -232,14 +262,14 @@ class CachedAPIKeyService:
|
||||
return {
|
||||
"cache_backend": "core_cache",
|
||||
"cache_enabled": core_stats.get("enabled", False),
|
||||
"cache_stats": core_stats
|
||||
"cache_stats": core_stats,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cache stats: {e}")
|
||||
return {
|
||||
"cache_backend": "core_cache",
|
||||
"cache_enabled": False,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class ConfigVersion:
|
||||
"""Configuration version metadata"""
|
||||
|
||||
version: str
|
||||
timestamp: datetime
|
||||
checksum: str
|
||||
@@ -36,6 +37,7 @@ class ConfigVersion:
|
||||
@dataclass
|
||||
class ConfigSchema:
|
||||
"""Configuration schema definition"""
|
||||
|
||||
name: str
|
||||
required_fields: List[str]
|
||||
optional_fields: List[str]
|
||||
@@ -46,6 +48,7 @@ class ConfigSchema:
|
||||
@dataclass
|
||||
class ConfigStats:
|
||||
"""Configuration manager statistics"""
|
||||
|
||||
total_configs: int
|
||||
active_watchers: int
|
||||
config_versions: int
|
||||
@@ -78,19 +81,19 @@ class ConfigWatcher(FileSystemEventHandler):
|
||||
self.last_modified[path] = current_time
|
||||
|
||||
# Trigger hot reload for config files
|
||||
if path.endswith(('.json', '.yaml', '.yml', '.toml')):
|
||||
if path.endswith((".json", ".yaml", ".yml", ".toml")):
|
||||
# Schedule coroutine in a thread-safe way
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_soon_threadsafe(
|
||||
lambda: asyncio.create_task(self.config_manager.reload_config_file(path))
|
||||
lambda: asyncio.create_task(
|
||||
self.config_manager.reload_config_file(path)
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# No running loop, schedule for later
|
||||
threading.Thread(
|
||||
target=self._schedule_reload,
|
||||
args=(path,),
|
||||
daemon=True
|
||||
target=self._schedule_reload, args=(path,), daemon=True
|
||||
).start()
|
||||
|
||||
def _schedule_reload(self, path: str):
|
||||
@@ -110,7 +113,7 @@ class ConfigManager:
|
||||
self.versions: Dict[str, List[ConfigVersion]] = {}
|
||||
self.watchers: Dict[str, Observer] = {}
|
||||
self.config_paths: Dict[str, Path] = {}
|
||||
self.environment = os.getenv('ENVIRONMENT', 'development')
|
||||
self.environment = os.getenv("ENVIRONMENT", "development")
|
||||
self.start_time = time.time()
|
||||
self.stats = ConfigStats(
|
||||
total_configs=0,
|
||||
@@ -119,7 +122,7 @@ class ConfigManager:
|
||||
hot_reloads_performed=0,
|
||||
validation_errors=0,
|
||||
last_reload_time=datetime.now(),
|
||||
uptime=0
|
||||
uptime=0,
|
||||
)
|
||||
|
||||
# Base configuration directories
|
||||
@@ -157,7 +160,9 @@ class ConfigManager:
|
||||
for field, expected_type in schema.field_types.items():
|
||||
if field in config_data:
|
||||
if not isinstance(config_data[field], expected_type):
|
||||
logger.error(f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}")
|
||||
logger.error(
|
||||
f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}"
|
||||
)
|
||||
self.stats.validation_errors += 1
|
||||
return False
|
||||
|
||||
@@ -165,7 +170,9 @@ class ConfigManager:
|
||||
for field, validator in schema.validators.items():
|
||||
if field in config_data:
|
||||
if not validator(config_data[field]):
|
||||
logger.error(f"Validation failed for field '{field}' in config '{name}'")
|
||||
logger.error(
|
||||
f"Validation failed for field '{field}' in config '{name}'"
|
||||
)
|
||||
self.stats.validation_errors += 1
|
||||
return False
|
||||
|
||||
@@ -182,7 +189,9 @@ class ConfigManager:
|
||||
json_str = json.dumps(data, sort_keys=True)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
def _create_version(self, name: str, config_data: Dict[str, Any], description: str = "Auto-save") -> ConfigVersion:
|
||||
def _create_version(
|
||||
self, name: str, config_data: Dict[str, Any], description: str = "Auto-save"
|
||||
) -> ConfigVersion:
|
||||
"""Create a new configuration version"""
|
||||
version_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
checksum = self._calculate_checksum(config_data)
|
||||
@@ -191,9 +200,9 @@ class ConfigManager:
|
||||
version=version_id,
|
||||
timestamp=datetime.now(),
|
||||
checksum=checksum,
|
||||
author=os.getenv('USER', 'system'),
|
||||
author=os.getenv("USER", "system"),
|
||||
description=description,
|
||||
config_data=config_data.copy()
|
||||
config_data=config_data.copy(),
|
||||
)
|
||||
|
||||
if name not in self.versions:
|
||||
@@ -209,8 +218,9 @@ class ConfigManager:
|
||||
logger.debug(f"Created version {version_id} for config '{name}'")
|
||||
return version
|
||||
|
||||
async def set_config(self, name: str, config_data: Dict[str, Any],
|
||||
description: str = "Manual update") -> bool:
|
||||
async def set_config(
|
||||
self, name: str, config_data: Dict[str, Any], description: str = "Manual update"
|
||||
) -> bool:
|
||||
"""Set configuration with validation and versioning"""
|
||||
try:
|
||||
# Validate configuration
|
||||
@@ -247,13 +257,15 @@ class ConfigManager:
|
||||
|
||||
return default
|
||||
|
||||
async def get_config_value(self, config_name: str, key: str, default: Any = None) -> Any:
|
||||
async def get_config_value(
|
||||
self, config_name: str, key: str, default: Any = None
|
||||
) -> Any:
|
||||
"""Get specific value from configuration"""
|
||||
config = await self.get_config(config_name)
|
||||
if config is None:
|
||||
return default
|
||||
|
||||
keys = key.split('.')
|
||||
keys = key.split(".")
|
||||
value = config
|
||||
|
||||
try:
|
||||
@@ -269,7 +281,7 @@ class ConfigManager:
|
||||
|
||||
try:
|
||||
# Save as regular JSON
|
||||
with open(file_path, 'w') as f:
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
logger.debug(f"Saved config '{name}' to {file_path}")
|
||||
|
||||
@@ -288,7 +300,7 @@ class ConfigManager:
|
||||
|
||||
try:
|
||||
# Load regular JSON
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
except Exception as e:
|
||||
@@ -302,11 +314,11 @@ class ConfigManager:
|
||||
config_name = path.stem
|
||||
|
||||
# Load updated configuration
|
||||
if path.suffix == '.json':
|
||||
with open(path, 'r') as f:
|
||||
if path.suffix == ".json":
|
||||
with open(path, "r") as f:
|
||||
new_config = json.load(f)
|
||||
elif path.suffix in ['.yaml', '.yml']:
|
||||
with open(path, 'r') as f:
|
||||
elif path.suffix in [".yaml", ".yml"]:
|
||||
with open(path, "r") as f:
|
||||
new_config = yaml.safe_load(f)
|
||||
else:
|
||||
logger.warning(f"Unsupported config file format: {path.suffix}")
|
||||
@@ -317,9 +329,13 @@ class ConfigManager:
|
||||
self.configs[config_name] = new_config
|
||||
self.stats.hot_reloads_performed += 1
|
||||
self.stats.last_reload_time = datetime.now()
|
||||
logger.info(f"Hot reloaded configuration '{config_name}' from {file_path}")
|
||||
logger.info(
|
||||
f"Hot reloaded configuration '{config_name}' from {file_path}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to hot reload '{config_name}' - validation failed")
|
||||
logger.error(
|
||||
f"Failed to hot reload '{config_name}' - validation failed"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error hot reloading config from {file_path}: {str(e)}")
|
||||
@@ -375,7 +391,7 @@ async def _register_default_schemas():
|
||||
required_fields=["host", "port", "name"],
|
||||
optional_fields=["username", "password", "ssl"],
|
||||
field_types={"host": str, "port": int, "name": str, "ssl": bool},
|
||||
validators={"port": lambda x: 1 <= x <= 65535}
|
||||
validators={"port": lambda x: 1 <= x <= 65535},
|
||||
)
|
||||
manager.register_schema("database", db_schema)
|
||||
|
||||
@@ -385,7 +401,7 @@ async def _register_default_schemas():
|
||||
required_fields=["redis_url"],
|
||||
optional_fields=["timeout", "max_connections"],
|
||||
field_types={"redis_url": str, "timeout": int, "max_connections": int},
|
||||
validators={"timeout": lambda x: x > 0}
|
||||
validators={"timeout": lambda x: x > 0},
|
||||
)
|
||||
manager.register_schema("cache", cache_schema)
|
||||
|
||||
@@ -400,13 +416,13 @@ async def _load_default_configs():
|
||||
"version": "1.0.0",
|
||||
"debug": manager.environment == "development",
|
||||
"log_level": "INFO",
|
||||
"timezone": "UTC"
|
||||
"timezone": "UTC",
|
||||
},
|
||||
"cache": {
|
||||
"redis_url": "redis://empire-redis:6379/0",
|
||||
"timeout": 30,
|
||||
"max_connections": 10
|
||||
}
|
||||
"max_connections": 10,
|
||||
},
|
||||
}
|
||||
|
||||
for name, config in default_configs.items():
|
||||
|
||||
@@ -27,7 +27,7 @@ class ConversationService:
|
||||
chatbot_id: str,
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
title: Optional[str] = None
|
||||
title: Optional[str] = None,
|
||||
) -> ChatbotConversation:
|
||||
"""Get existing conversation or create a new one"""
|
||||
|
||||
@@ -38,7 +38,7 @@ class ConversationService:
|
||||
ChatbotConversation.id == conversation_id,
|
||||
ChatbotConversation.chatbot_id == chatbot_id,
|
||||
ChatbotConversation.user_id == user_id,
|
||||
ChatbotConversation.is_active == True
|
||||
ChatbotConversation.is_active == True,
|
||||
)
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
@@ -48,7 +48,9 @@ class ConversationService:
|
||||
logger.info(f"Found existing conversation {conversation_id}")
|
||||
return conversation
|
||||
else:
|
||||
logger.warning(f"Conversation {conversation_id} not found or not accessible")
|
||||
logger.warning(
|
||||
f"Conversation {conversation_id} not found or not accessible"
|
||||
)
|
||||
|
||||
# Create new conversation
|
||||
if not title:
|
||||
@@ -61,21 +63,20 @@ class ConversationService:
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
is_active=True,
|
||||
context_data={}
|
||||
context_data={},
|
||||
)
|
||||
|
||||
self.db.add(conversation)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(conversation)
|
||||
|
||||
logger.info(f"Created new conversation {conversation.id} for chatbot {chatbot_id}")
|
||||
logger.info(
|
||||
f"Created new conversation {conversation.id} for chatbot {chatbot_id}"
|
||||
)
|
||||
return conversation
|
||||
|
||||
async def get_conversation_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
limit: int = 20,
|
||||
include_system: bool = False
|
||||
self, conversation_id: str, limit: int = 20, include_system: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load conversation history for a conversation
|
||||
@@ -96,7 +97,7 @@ class ConversationService:
|
||||
|
||||
# Optionally exclude system messages
|
||||
if not include_system:
|
||||
stmt = stmt.where(ChatbotMessage.role != 'system')
|
||||
stmt = stmt.where(ChatbotMessage.role != "system")
|
||||
|
||||
# Order by timestamp descending and limit
|
||||
stmt = stmt.order_by(desc(ChatbotMessage.timestamp)).limit(limit)
|
||||
@@ -107,19 +108,27 @@ class ConversationService:
|
||||
# Convert to list and reverse to get chronological order (oldest first)
|
||||
history = []
|
||||
for msg in reversed(messages):
|
||||
history.append({
|
||||
history.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
|
||||
"timestamp": msg.timestamp.isoformat()
|
||||
if msg.timestamp
|
||||
else None,
|
||||
"metadata": msg.message_metadata or {},
|
||||
"sources": msg.sources
|
||||
})
|
||||
"sources": msg.sources,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Loaded {len(history)} messages for conversation {conversation_id}")
|
||||
logger.info(
|
||||
f"Loaded {len(history)} messages for conversation {conversation_id}"
|
||||
)
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load conversation history for {conversation_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to load conversation history for {conversation_id}: {e}"
|
||||
)
|
||||
return [] # Return empty list on error to avoid breaking chat
|
||||
|
||||
async def add_message(
|
||||
@@ -128,11 +137,11 @@ class ConversationService:
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
sources: Optional[List[Dict[str, Any]]] = None
|
||||
sources: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> ChatbotMessage:
|
||||
"""Add a message to a conversation"""
|
||||
|
||||
if role not in ['user', 'assistant', 'system']:
|
||||
if role not in ["user", "assistant", "system"]:
|
||||
raise ValueError(f"Invalid message role: {role}")
|
||||
|
||||
message = ChatbotMessage(
|
||||
@@ -141,13 +150,15 @@ class ConversationService:
|
||||
content=content,
|
||||
timestamp=datetime.utcnow(),
|
||||
message_metadata=metadata or {},
|
||||
sources=sources
|
||||
sources=sources,
|
||||
)
|
||||
|
||||
self.db.add(message)
|
||||
|
||||
# Update conversation timestamp
|
||||
stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
|
||||
stmt = select(ChatbotConversation).where(
|
||||
ChatbotConversation.id == conversation_id
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
conversation = result.scalar_one_or_none()
|
||||
|
||||
@@ -164,18 +175,19 @@ class ConversationService:
|
||||
"""Get statistics for a conversation"""
|
||||
|
||||
# Count messages by role
|
||||
stmt = select(
|
||||
ChatbotMessage.role,
|
||||
func.count(ChatbotMessage.id).label('count')
|
||||
).where(
|
||||
ChatbotMessage.conversation_id == conversation_id
|
||||
).group_by(ChatbotMessage.role)
|
||||
stmt = (
|
||||
select(ChatbotMessage.role, func.count(ChatbotMessage.id).label("count"))
|
||||
.where(ChatbotMessage.conversation_id == conversation_id)
|
||||
.group_by(ChatbotMessage.role)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
role_counts = {row.role: row.count for row in result}
|
||||
|
||||
# Get conversation info
|
||||
stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
|
||||
stmt = select(ChatbotConversation).where(
|
||||
ChatbotConversation.id == conversation_id
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
conversation = result.scalar_one_or_none()
|
||||
|
||||
@@ -185,12 +197,16 @@ class ConversationService:
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"title": conversation.title,
|
||||
"created_at": conversation.created_at.isoformat() if conversation.created_at else None,
|
||||
"updated_at": conversation.updated_at.isoformat() if conversation.updated_at else None,
|
||||
"created_at": conversation.created_at.isoformat()
|
||||
if conversation.created_at
|
||||
else None,
|
||||
"updated_at": conversation.updated_at.isoformat()
|
||||
if conversation.updated_at
|
||||
else None,
|
||||
"total_messages": sum(role_counts.values()),
|
||||
"user_messages": role_counts.get('user', 0),
|
||||
"assistant_messages": role_counts.get('assistant', 0),
|
||||
"system_messages": role_counts.get('system', 0)
|
||||
"user_messages": role_counts.get("user", 0),
|
||||
"assistant_messages": role_counts.get("assistant", 0),
|
||||
"system_messages": role_counts.get("system", 0),
|
||||
}
|
||||
|
||||
async def archive_old_conversations(self, days_inactive: int = 30) -> int:
|
||||
@@ -202,7 +218,7 @@ class ConversationService:
|
||||
stmt = select(ChatbotConversation).where(
|
||||
and_(
|
||||
ChatbotConversation.updated_at < cutoff_date,
|
||||
ChatbotConversation.is_active == True
|
||||
ChatbotConversation.is_active == True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -224,12 +240,16 @@ class ConversationService:
|
||||
"""Delete a conversation and all its messages"""
|
||||
|
||||
# Verify ownership
|
||||
stmt = select(ChatbotConversation).where(
|
||||
stmt = (
|
||||
select(ChatbotConversation)
|
||||
.where(
|
||||
and_(
|
||||
ChatbotConversation.id == conversation_id,
|
||||
ChatbotConversation.user_id == user_id
|
||||
ChatbotConversation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.options(selectinload(ChatbotConversation.messages))
|
||||
)
|
||||
).options(selectinload(ChatbotConversation.messages))
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
conversation = result.scalar_one_or_none()
|
||||
@@ -245,7 +265,9 @@ class ConversationService:
|
||||
await self.db.delete(conversation)
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages")
|
||||
logger.info(
|
||||
f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages"
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_user_conversations(
|
||||
@@ -253,21 +275,25 @@ class ConversationService:
|
||||
user_id: str,
|
||||
chatbot_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
skip: int = 0
|
||||
skip: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get list of conversations for a user"""
|
||||
|
||||
stmt = select(ChatbotConversation).where(
|
||||
and_(
|
||||
ChatbotConversation.user_id == user_id,
|
||||
ChatbotConversation.is_active == True
|
||||
ChatbotConversation.is_active == True,
|
||||
)
|
||||
)
|
||||
|
||||
if chatbot_id:
|
||||
stmt = stmt.where(ChatbotConversation.chatbot_id == chatbot_id)
|
||||
|
||||
stmt = stmt.order_by(desc(ChatbotConversation.updated_at)).offset(skip).limit(limit)
|
||||
stmt = (
|
||||
stmt.order_by(desc(ChatbotConversation.updated_at))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
conversations = result.scalars().all()
|
||||
@@ -281,14 +307,20 @@ class ConversationService:
|
||||
msg_count_result = await self.db.execute(msg_count_stmt)
|
||||
message_count = msg_count_result.scalar() or 0
|
||||
|
||||
conversation_list.append({
|
||||
conversation_list.append(
|
||||
{
|
||||
"id": conv.id,
|
||||
"chatbot_id": conv.chatbot_id,
|
||||
"title": conv.title,
|
||||
"message_count": message_count,
|
||||
"created_at": conv.created_at.isoformat() if conv.created_at else None,
|
||||
"updated_at": conv.updated_at.isoformat() if conv.updated_at else None,
|
||||
"context_data": conv.context_data or {}
|
||||
})
|
||||
"created_at": conv.created_at.isoformat()
|
||||
if conv.created_at
|
||||
else None,
|
||||
"updated_at": conv.updated_at.isoformat()
|
||||
if conv.updated_at
|
||||
else None,
|
||||
"context_data": conv.context_data or {},
|
||||
}
|
||||
)
|
||||
|
||||
return conversation_list
|
||||
@@ -17,20 +17,22 @@ class CostCalculator:
|
||||
"gpt-4": {"input": 300, "output": 600}, # $0.03/$0.06 per 1K tokens
|
||||
"gpt-4-turbo": {"input": 100, "output": 300}, # $0.01/$0.03 per 1K tokens
|
||||
"gpt-3.5-turbo": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
|
||||
|
||||
# Anthropic Models
|
||||
"claude-3-opus": {"input": 150, "output": 750}, # $0.015/$0.075 per 1K tokens
|
||||
"claude-3-sonnet": {"input": 30, "output": 150}, # $0.003/$0.015 per 1K tokens
|
||||
"claude-3-haiku": {"input": 25, "output": 125}, # $0.00025/$0.00125 per 1K tokens
|
||||
|
||||
"claude-3-haiku": {
|
||||
"input": 25,
|
||||
"output": 125,
|
||||
}, # $0.00025/$0.00125 per 1K tokens
|
||||
# Google Models
|
||||
"gemini-pro": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
|
||||
"gemini-pro-vision": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
|
||||
|
||||
"gemini-pro-vision": {
|
||||
"input": 5,
|
||||
"output": 15,
|
||||
}, # $0.0005/$0.0015 per 1K tokens
|
||||
# Privatemode.ai Models (estimated pricing)
|
||||
"privatemode-llama-70b": {"input": 40, "output": 80}, # Estimated pricing
|
||||
"privatemode-mixtral": {"input": 20, "output": 40}, # Estimated pricing
|
||||
|
||||
# Embedding Models
|
||||
"text-embedding-ada-002": {"input": 1, "output": 0}, # $0.0001 per 1K tokens
|
||||
"text-embedding-3-small": {"input": 2, "output": 0}, # $0.00002 per 1K tokens
|
||||
@@ -49,7 +51,9 @@ class CostCalculator:
|
||||
# Look up pricing
|
||||
pricing = cls.MODEL_PRICING.get(normalized_name, cls.DEFAULT_PRICING)
|
||||
|
||||
logger.debug(f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}")
|
||||
logger.debug(
|
||||
f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}"
|
||||
)
|
||||
return pricing
|
||||
|
||||
@classmethod
|
||||
@@ -61,7 +65,7 @@ class CostCalculator:
|
||||
normalized = model_name.lower()
|
||||
for prefix in prefixes:
|
||||
if normalized.startswith(prefix):
|
||||
normalized = normalized[len(prefix):]
|
||||
normalized = normalized[len(prefix) :]
|
||||
break
|
||||
|
||||
# Handle special cases
|
||||
@@ -80,10 +84,7 @@ class CostCalculator:
|
||||
|
||||
@classmethod
|
||||
def calculate_cost_cents(
|
||||
cls,
|
||||
model_name: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0
|
||||
cls, model_name: str, input_tokens: int = 0, output_tokens: int = 0
|
||||
) -> int:
|
||||
"""
|
||||
Calculate cost in cents for given token usage
|
||||
@@ -147,7 +148,7 @@ class CostCalculator:
|
||||
return {
|
||||
"input": pricing_cents["input"] / 10000, # Convert 1/10000ths to dollars
|
||||
"output": pricing_cents["output"] / 10000,
|
||||
"currency": "USD"
|
||||
"currency": "USD",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -172,7 +173,9 @@ class CostCalculator:
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
def calculate_request_cost(model_name: str, input_tokens: int, output_tokens: int) -> int:
|
||||
def calculate_request_cost(
|
||||
model_name: str, input_tokens: int, output_tokens: int
|
||||
) -> int:
|
||||
"""Calculate cost for a single request"""
|
||||
return CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class ProcessingStatus(str, Enum):
|
||||
@dataclass
|
||||
class ProcessingTask:
|
||||
"""Document processing task"""
|
||||
|
||||
document_id: int
|
||||
priority: int = 1
|
||||
retry_count: int = 0
|
||||
@@ -57,7 +58,7 @@ class DocumentProcessor:
|
||||
"processed_count": 0,
|
||||
"error_count": 0,
|
||||
"queue_size": 0,
|
||||
"active_workers": 0
|
||||
"active_workers": 0,
|
||||
}
|
||||
self._rag_module = None
|
||||
self._rag_module_lock = asyncio.Lock()
|
||||
@@ -111,11 +112,15 @@ class DocumentProcessor:
|
||||
|
||||
self.stats["queue_size"] = self.processing_queue.qsize()
|
||||
|
||||
logger.info(f"Added processing task for document {document_id} (priority: {priority})")
|
||||
logger.info(
|
||||
f"Added processing task for document {document_id} (priority: {priority})"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add processing task for document {document_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to add processing task for document {document_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def _worker(self, worker_name: str):
|
||||
@@ -126,10 +131,7 @@ class DocumentProcessor:
|
||||
task: Optional[ProcessingTask] = None
|
||||
try:
|
||||
# Get task from queue (wait up to 1 second)
|
||||
task = await asyncio.wait_for(
|
||||
self.processing_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
task = await asyncio.wait_for(self.processing_queue.get(), timeout=1.0)
|
||||
|
||||
self.stats["active_workers"] += 1
|
||||
self.stats["queue_size"] = self.processing_queue.qsize()
|
||||
@@ -141,14 +143,20 @@ class DocumentProcessor:
|
||||
|
||||
if success:
|
||||
self.stats["processed_count"] += 1
|
||||
logger.info(f"{worker_name}: Successfully processed document {task.document_id}")
|
||||
logger.info(
|
||||
f"{worker_name}: Successfully processed document {task.document_id}"
|
||||
)
|
||||
else:
|
||||
# Retry logic
|
||||
if task.retry_count < task.max_retries:
|
||||
task.retry_count += 1
|
||||
await asyncio.sleep(2 ** task.retry_count) # Exponential backoff
|
||||
await asyncio.sleep(
|
||||
2**task.retry_count
|
||||
) # Exponential backoff
|
||||
try:
|
||||
await asyncio.wait_for(self.processing_queue.put(task), timeout=5.0)
|
||||
await asyncio.wait_for(
|
||||
self.processing_queue.put(task), timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s: Failed to requeue document %s due to saturated queue",
|
||||
@@ -157,10 +165,14 @@ class DocumentProcessor:
|
||||
)
|
||||
self.stats["error_count"] += 1
|
||||
continue
|
||||
logger.warning(f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})")
|
||||
logger.warning(
|
||||
f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})"
|
||||
)
|
||||
else:
|
||||
self.stats["error_count"] += 1
|
||||
logger.error(f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries")
|
||||
logger.error(
|
||||
f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No tasks in queue, continue
|
||||
@@ -183,30 +195,34 @@ class DocumentProcessor:
|
||||
async def _get_rag_module(self):
|
||||
"""Resolve and cache the RAG module instance"""
|
||||
async with self._rag_module_lock:
|
||||
if self._rag_module and getattr(self._rag_module, 'enabled', False):
|
||||
if self._rag_module and getattr(self._rag_module, "enabled", False):
|
||||
return self._rag_module
|
||||
|
||||
if not module_manager.initialized:
|
||||
await module_manager.initialize()
|
||||
|
||||
rag_module = module_manager.get_module('rag')
|
||||
rag_module = module_manager.get_module("rag")
|
||||
|
||||
if not rag_module:
|
||||
enabled = await module_manager.enable_module('rag')
|
||||
enabled = await module_manager.enable_module("rag")
|
||||
if not enabled:
|
||||
raise RuntimeError("Failed to enable RAG module")
|
||||
rag_module = module_manager.get_module('rag')
|
||||
rag_module = module_manager.get_module("rag")
|
||||
|
||||
if not rag_module:
|
||||
raise RuntimeError("RAG module not available after enable attempt")
|
||||
|
||||
if not getattr(rag_module, 'enabled', True):
|
||||
enabled = await module_manager.enable_module('rag')
|
||||
if not getattr(rag_module, "enabled", True):
|
||||
enabled = await module_manager.enable_module("rag")
|
||||
if not enabled:
|
||||
raise RuntimeError("RAG module is disabled and could not be re-enabled")
|
||||
rag_module = module_manager.get_module('rag')
|
||||
if not rag_module or not getattr(rag_module, 'enabled', True):
|
||||
raise RuntimeError("RAG module is disabled and could not be re-enabled")
|
||||
raise RuntimeError(
|
||||
"RAG module is disabled and could not be re-enabled"
|
||||
)
|
||||
rag_module = module_manager.get_module("rag")
|
||||
if not rag_module or not getattr(rag_module, "enabled", True):
|
||||
raise RuntimeError(
|
||||
"RAG module is disabled and could not be re-enabled"
|
||||
)
|
||||
|
||||
self._rag_module = rag_module
|
||||
logger.info("DocumentProcessor cached RAG module instance for reuse")
|
||||
@@ -216,6 +232,7 @@ class DocumentProcessor:
|
||||
"""Process a single document"""
|
||||
from datetime import datetime
|
||||
from app.db.database import async_session_factory
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
# Get document from database
|
||||
@@ -245,42 +262,61 @@ class DocumentProcessor:
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise Exception("RAG module not available or not enabled")
|
||||
|
||||
logger.info(f"RAG module loaded successfully for document {task.document_id}")
|
||||
logger.info(
|
||||
f"RAG module loaded successfully for document {task.document_id}"
|
||||
)
|
||||
|
||||
# Read file content
|
||||
logger.info(f"Reading file content for document {task.document_id}: {document.file_path}")
|
||||
logger.info(
|
||||
f"Reading file content for document {task.document_id}: {document.file_path}"
|
||||
)
|
||||
file_path = Path(document.file_path)
|
||||
try:
|
||||
file_content = await asyncio.to_thread(file_path.read_bytes)
|
||||
except FileNotFoundError:
|
||||
logger.error(f"File not found for document {task.document_id}: {document.file_path}")
|
||||
logger.error(
|
||||
f"File not found for document {task.document_id}: {document.file_path}"
|
||||
)
|
||||
document.status = ProcessingStatus.ERROR
|
||||
document.processing_error = "Document file not found on disk"
|
||||
await session.commit()
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed reading file for document {task.document_id}: {exc}")
|
||||
logger.error(
|
||||
f"Failed reading file for document {task.document_id}: {exc}"
|
||||
)
|
||||
document.status = ProcessingStatus.ERROR
|
||||
document.processing_error = f"Failed to read file: {exc}"
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
logger.info(f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes")
|
||||
logger.info(
|
||||
f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes"
|
||||
)
|
||||
|
||||
# Process with RAG module
|
||||
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
|
||||
logger.info(
|
||||
f"Starting document processing for document {task.document_id} with RAG module"
|
||||
)
|
||||
|
||||
# Special handling for JSONL files - skip processing phase
|
||||
if document.file_type == 'jsonl':
|
||||
if document.file_type == "jsonl":
|
||||
# For JSONL files, we don't need to process content here
|
||||
# The optimized JSONL processor will handle everything during indexing
|
||||
document.converted_content = f"JSONL file with {len(file_content)} bytes"
|
||||
document.converted_content = (
|
||||
f"JSONL file with {len(file_content)} bytes"
|
||||
)
|
||||
document.word_count = 0 # Will be updated during indexing
|
||||
document.character_count = len(file_content)
|
||||
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
|
||||
document.document_metadata = {
|
||||
"file_path": document.file_path,
|
||||
"processed": "jsonl",
|
||||
}
|
||||
document.status = ProcessingStatus.PROCESSED
|
||||
document.processed_at = datetime.utcnow()
|
||||
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
|
||||
logger.info(
|
||||
f"JSONL document {task.document_id} marked for optimized processing"
|
||||
)
|
||||
else:
|
||||
# Standard processing for other file types
|
||||
try:
|
||||
@@ -289,11 +325,13 @@ class DocumentProcessor:
|
||||
rag_module.process_document(
|
||||
file_content,
|
||||
document.original_filename,
|
||||
{"file_path": document.file_path}
|
||||
{"file_path": document.file_path},
|
||||
),
|
||||
timeout=300.0 # 5 minute timeout
|
||||
timeout=300.0, # 5 minute timeout
|
||||
)
|
||||
logger.info(
|
||||
f"Document processing completed for document {task.document_id}"
|
||||
)
|
||||
logger.info(f"Document processing completed for document {task.document_id}")
|
||||
|
||||
# Update document with processed content
|
||||
document.converted_content = processed_doc.content
|
||||
@@ -303,16 +341,22 @@ class DocumentProcessor:
|
||||
document.status = ProcessingStatus.PROCESSED
|
||||
document.processed_at = datetime.utcnow()
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Document processing timed out for document {task.document_id}")
|
||||
logger.error(
|
||||
f"Document processing timed out for document {task.document_id}"
|
||||
)
|
||||
raise Exception("Document processing timed out after 5 minutes")
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing failed for document {task.document_id}: {e}")
|
||||
logger.error(
|
||||
f"Document processing failed for document {task.document_id}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Index in RAG system using same RAG module
|
||||
if rag_module and document.converted_content:
|
||||
try:
|
||||
logger.info(f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}")
|
||||
logger.info(
|
||||
f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}"
|
||||
)
|
||||
|
||||
# Index the document content in the correct Qdrant collection
|
||||
doc_metadata = {
|
||||
@@ -320,12 +364,12 @@ class DocumentProcessor:
|
||||
"document_id": document.id,
|
||||
"filename": document.original_filename,
|
||||
"file_type": document.file_type,
|
||||
**document.document_metadata
|
||||
**document.document_metadata,
|
||||
}
|
||||
|
||||
# Use the correct Qdrant collection name for this document
|
||||
# For JSONL files, we need to use the processed document flow
|
||||
if document.file_type == 'jsonl':
|
||||
if document.file_type == "jsonl":
|
||||
# Create a ProcessedDocument for the JSONL processor
|
||||
from app.modules.rag.main import ProcessedDocument
|
||||
from datetime import datetime
|
||||
@@ -333,7 +377,9 @@ class DocumentProcessor:
|
||||
|
||||
# Calculate file hash
|
||||
processed_at = datetime.utcnow()
|
||||
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
|
||||
file_hash = hashlib.md5(
|
||||
str(document.id).encode()
|
||||
).hexdigest()
|
||||
|
||||
processed_doc = ProcessedDocument(
|
||||
id=str(document.id),
|
||||
@@ -341,12 +387,14 @@ class DocumentProcessor:
|
||||
extracted_text="", # Will be filled by JSONL processor
|
||||
metadata={
|
||||
**doc_metadata,
|
||||
"file_path": document.file_path
|
||||
"file_path": document.file_path,
|
||||
},
|
||||
original_filename=document.original_filename,
|
||||
file_type=document.file_type,
|
||||
mime_type=document.mime_type,
|
||||
language=document.document_metadata.get('language', 'EN'),
|
||||
language=document.document_metadata.get(
|
||||
"language", "EN"
|
||||
),
|
||||
word_count=0, # Will be updated during processing
|
||||
sentence_count=0, # Will be updated during processing
|
||||
entities=[],
|
||||
@@ -354,16 +402,16 @@ class DocumentProcessor:
|
||||
processing_time=0.0,
|
||||
processed_at=processed_at,
|
||||
file_hash=file_hash,
|
||||
file_size=document.file_size
|
||||
file_size=document.file_size,
|
||||
)
|
||||
|
||||
# The JSONL processor will read the original file
|
||||
await asyncio.wait_for(
|
||||
rag_module.index_processed_document(
|
||||
processed_doc=processed_doc,
|
||||
collection_name=document.collection.qdrant_collection_name
|
||||
collection_name=document.collection.qdrant_collection_name,
|
||||
),
|
||||
timeout=300.0 # 5 minute timeout for JSONL processing
|
||||
timeout=300.0, # 5 minute timeout for JSONL processing
|
||||
)
|
||||
else:
|
||||
# Use standard indexing for other file types
|
||||
@@ -371,15 +419,19 @@ class DocumentProcessor:
|
||||
rag_module.index_document(
|
||||
content=document.converted_content,
|
||||
metadata=doc_metadata,
|
||||
collection_name=document.collection.qdrant_collection_name
|
||||
collection_name=document.collection.qdrant_collection_name,
|
||||
),
|
||||
timeout=120.0 # 2 minute timeout for indexing
|
||||
timeout=120.0, # 2 minute timeout for indexing
|
||||
)
|
||||
|
||||
logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")
|
||||
logger.info(
|
||||
f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}"
|
||||
)
|
||||
|
||||
# Update vector count (approximate)
|
||||
document.vector_count = max(1, len(document.converted_content) // 1000)
|
||||
document.vector_count = max(
|
||||
1, len(document.converted_content) // 1000
|
||||
)
|
||||
document.status = ProcessingStatus.INDEXED
|
||||
document.indexed_at = datetime.utcnow()
|
||||
|
||||
@@ -392,7 +444,9 @@ class DocumentProcessor:
|
||||
collection.updated_at = datetime.utcnow()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
|
||||
logger.error(
|
||||
f"Failed to index document {task.document_id} in RAG: {e}"
|
||||
)
|
||||
# Mark as error since indexing failed
|
||||
document.status = ProcessingStatus.ERROR
|
||||
document.processing_error = f"Indexing failed: {str(e)}"
|
||||
@@ -403,7 +457,7 @@ class DocumentProcessor:
|
||||
|
||||
except Exception as e:
|
||||
# Mark document as error
|
||||
if 'document' in locals() and document:
|
||||
if "document" in locals() and document:
|
||||
document.status = ProcessingStatus.ERROR
|
||||
document.processing_error = str(e)
|
||||
await session.commit()
|
||||
@@ -417,7 +471,7 @@ class DocumentProcessor:
|
||||
**self.stats,
|
||||
"running": self.running,
|
||||
"worker_count": len(self.workers),
|
||||
"queue_size": self.processing_queue.qsize()
|
||||
"queue_size": self.processing_queue.qsize(),
|
||||
}
|
||||
|
||||
async def get_queue_status(self) -> Dict[str, Any]:
|
||||
@@ -427,7 +481,7 @@ class DocumentProcessor:
|
||||
"max_queue_size": self.max_queue_size,
|
||||
"queue_full": self.processing_queue.full(),
|
||||
"active_workers": self.stats["active_workers"],
|
||||
"max_workers": self.max_workers
|
||||
"max_workers": self.max_workers,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ class EmbeddingService:
|
||||
"""Service for generating text embeddings using a local transformer model"""
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None):
|
||||
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3")
|
||||
self.model_name = model_name or getattr(
|
||||
settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3"
|
||||
)
|
||||
self.dimension = 1024 # bge-m3 produces 1024-d vectors
|
||||
self.initialized = False
|
||||
self.local_model = None
|
||||
@@ -56,7 +58,9 @@ class EmbeddingService:
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
|
||||
logger.error(
|
||||
f"Failed to load local embedding model {self.model_name}: {exc}"
|
||||
)
|
||||
logger.warning("Falling back to random embeddings")
|
||||
self.local_model = None
|
||||
self.initialized = False
|
||||
@@ -102,13 +106,21 @@ class EmbeddingService:
|
||||
except Exception as exc:
|
||||
logger.error(f"Local embedding generation failed: {exc}")
|
||||
self.backend = "fallback_random"
|
||||
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
|
||||
return self._generate_fallback_embeddings(
|
||||
texts, duration=time.time() - start_time
|
||||
)
|
||||
|
||||
logger.warning("Local embedding model unavailable; using fallback random embeddings")
|
||||
logger.warning(
|
||||
"Local embedding model unavailable; using fallback random embeddings"
|
||||
)
|
||||
self.backend = "fallback_random"
|
||||
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
|
||||
return self._generate_fallback_embeddings(
|
||||
texts, duration=time.time() - start_time
|
||||
)
|
||||
|
||||
def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
|
||||
def _generate_fallback_embeddings(
|
||||
self, texts: List[str], duration: float = None
|
||||
) -> List[List[float]]:
|
||||
"""Generate fallback random embeddings when model unavailable"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
@@ -155,7 +167,7 @@ class EmbeddingService:
|
||||
"model_loaded": self.initialized,
|
||||
"dimension": self.dimension,
|
||||
"backend": self.backend,
|
||||
"initialized": self.initialized
|
||||
"initialized": self.initialized,
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
|
||||
@@ -21,17 +21,30 @@ class EnhancedEmbeddingService(EmbeddingService):
|
||||
def __init__(self, model_name: Optional[str] = None):
|
||||
super().__init__(model_name)
|
||||
self.rate_limit_tracker = {
|
||||
'requests_count': 0,
|
||||
'window_start': time.time(),
|
||||
'window_size': 60, # 1 minute window
|
||||
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
|
||||
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
|
||||
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
|
||||
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
|
||||
'last_rate_limit_error': None
|
||||
"requests_count": 0,
|
||||
"window_start": time.time(),
|
||||
"window_size": 60, # 1 minute window
|
||||
"max_requests_per_minute": int(
|
||||
getattr(settings, "RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", 12)
|
||||
), # Configurable
|
||||
"retry_delays": [
|
||||
int(x)
|
||||
for x in getattr(
|
||||
settings, "RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
|
||||
).split(",")
|
||||
], # Exponential backoff
|
||||
"delay_between_batches": float(
|
||||
getattr(settings, "RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", 1.0)
|
||||
),
|
||||
"delay_per_request": float(
|
||||
getattr(settings, "RAG_EMBEDDING_DELAY_PER_REQUEST", 0.5)
|
||||
),
|
||||
"last_rate_limit_error": None,
|
||||
}
|
||||
|
||||
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
|
||||
async def get_embeddings_with_retry(
|
||||
self, texts: List[str], max_retries: int = None
|
||||
) -> tuple[List[List[float]], bool]:
|
||||
"""
|
||||
Get embeddings with retry bookkeeping.
|
||||
"""
|
||||
@@ -51,13 +64,20 @@ class EnhancedEmbeddingService(EmbeddingService):
|
||||
return {
|
||||
**base_stats,
|
||||
"rate_limit_info": {
|
||||
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
|
||||
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
|
||||
"window_reset_in_seconds": max(0,
|
||||
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
|
||||
"requests_in_current_window": self.rate_limit_tracker["requests_count"],
|
||||
"max_requests_per_minute": self.rate_limit_tracker[
|
||||
"max_requests_per_minute"
|
||||
],
|
||||
"window_reset_in_seconds": max(
|
||||
0,
|
||||
self.rate_limit_tracker["window_start"]
|
||||
+ self.rate_limit_tracker["window_size"]
|
||||
- time.time(),
|
||||
),
|
||||
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
|
||||
}
|
||||
"last_rate_limit_error": self.rate_limit_tracker[
|
||||
"last_rate_limit_error"
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
|
||||
from qdrant_client.http.models import Batch
|
||||
|
||||
from app.modules.rag.main import ProcessedDocument
|
||||
|
||||
# from app.core.analytics import log_module_event # Analytics module not available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,8 +27,13 @@ class JSONLProcessor:
|
||||
self.rag_module = rag_module
|
||||
self.config = rag_module.config
|
||||
|
||||
async def process_and_index_jsonl(self, collection_name: str, content: bytes,
|
||||
filename: str, metadata: Dict[str, Any]) -> str:
|
||||
async def process_and_index_jsonl(
|
||||
self,
|
||||
collection_name: str,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Process and index a JSONL file efficiently
|
||||
|
||||
Processes each JSON line as a separate document to avoid
|
||||
@@ -35,8 +41,8 @@ class JSONLProcessor:
|
||||
"""
|
||||
try:
|
||||
# Decode content
|
||||
jsonl_content = content.decode('utf-8', errors='replace')
|
||||
lines = jsonl_content.strip().split('\n')
|
||||
jsonl_content = content.decode("utf-8", errors="replace")
|
||||
lines = jsonl_content.strip().split("\n")
|
||||
|
||||
logger.info(f"Processing JSONL file {filename} with {len(lines)} lines")
|
||||
|
||||
@@ -58,28 +64,38 @@ class JSONLProcessor:
|
||||
batch_start,
|
||||
base_doc_id,
|
||||
filename,
|
||||
metadata
|
||||
metadata,
|
||||
)
|
||||
|
||||
processed_count += len(batch_lines)
|
||||
|
||||
# Log progress
|
||||
if processed_count % 50 == 0:
|
||||
logger.info(f"Processed {processed_count}/{len(lines)} lines from {filename}")
|
||||
logger.info(
|
||||
f"Processed {processed_count}/{len(lines)} lines from {filename}"
|
||||
)
|
||||
|
||||
# Small delay to prevent resource exhaustion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
logger.info(f"Successfully processed JSONL file {filename} with {len(lines)} lines")
|
||||
logger.info(
|
||||
f"Successfully processed JSONL file {filename} with {len(lines)} lines"
|
||||
)
|
||||
return base_doc_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JSONL file {filename}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_jsonl_batch(self, collection_name: str, lines: List[str],
|
||||
start_idx: int, base_doc_id: str,
|
||||
filename: str, metadata: Dict[str, Any]) -> None:
|
||||
async def _process_jsonl_batch(
|
||||
self,
|
||||
collection_name: str,
|
||||
lines: List[str],
|
||||
start_idx: int,
|
||||
base_doc_id: str,
|
||||
filename: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Process a batch of JSONL lines"""
|
||||
try:
|
||||
points = []
|
||||
@@ -98,14 +114,14 @@ class JSONLProcessor:
|
||||
continue
|
||||
|
||||
# Handle helpjuice export format
|
||||
if 'payload' in data and data['payload'] is not None:
|
||||
payload = data['payload']
|
||||
article_id = data.get('id', f'article_{line_idx}')
|
||||
if "payload" in data and data["payload"] is not None:
|
||||
payload = data["payload"]
|
||||
article_id = data.get("id", f"article_{line_idx}")
|
||||
|
||||
# Extract Q&A
|
||||
question = payload.get('question', '')
|
||||
answer = payload.get('answer', '')
|
||||
language = payload.get('language', 'EN')
|
||||
question = payload.get("question", "")
|
||||
answer = payload.get("answer", "")
|
||||
language = payload.get("language", "EN")
|
||||
|
||||
if question or answer:
|
||||
# Create Q&A content
|
||||
@@ -120,15 +136,18 @@ class JSONLProcessor:
|
||||
"line_number": line_idx,
|
||||
"content_type": "qa_pair",
|
||||
"question": question[:100], # Truncate for metadata
|
||||
"processed_at": datetime.utcnow().isoformat()
|
||||
"processed_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# Generate single embedding for the Q&A pair
|
||||
embeddings = await self.rag_module._generate_embeddings([content])
|
||||
embeddings = await self.rag_module._generate_embeddings(
|
||||
[content]
|
||||
)
|
||||
|
||||
# Create point
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=embeddings[0],
|
||||
payload={
|
||||
@@ -136,9 +155,10 @@ class JSONLProcessor:
|
||||
"document_id": f"{base_doc_id}_{article_id}",
|
||||
"content": content,
|
||||
"chunk_index": 0,
|
||||
"chunk_count": 1
|
||||
}
|
||||
))
|
||||
"chunk_count": 1,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Handle generic JSON format
|
||||
else:
|
||||
@@ -146,12 +166,19 @@ class JSONLProcessor:
|
||||
|
||||
# For larger JSON objects, we might need to chunk
|
||||
if len(content) > 1000:
|
||||
chunks = self.rag_module._chunk_text(content, chunk_size=500)
|
||||
embeddings = await self.rag_module._generate_embeddings(chunks)
|
||||
chunks = self.rag_module._chunk_text(
|
||||
content, chunk_size=500
|
||||
)
|
||||
embeddings = await self.rag_module._generate_embeddings(
|
||||
chunks
|
||||
)
|
||||
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
for i, (chunk, embedding) in enumerate(
|
||||
zip(chunks, embeddings)
|
||||
):
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=embedding,
|
||||
payload={
|
||||
@@ -162,14 +189,18 @@ class JSONLProcessor:
|
||||
"document_id": f"{base_doc_id}_line_{line_idx}",
|
||||
"content": chunk,
|
||||
"chunk_index": i,
|
||||
"chunk_count": len(chunks)
|
||||
}
|
||||
))
|
||||
"chunk_count": len(chunks),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Small JSON - no chunking needed
|
||||
embeddings = await self.rag_module._generate_embeddings([content])
|
||||
embeddings = await self.rag_module._generate_embeddings(
|
||||
[content]
|
||||
)
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=point_id,
|
||||
vector=embeddings[0],
|
||||
payload={
|
||||
@@ -180,9 +211,10 @@ class JSONLProcessor:
|
||||
"document_id": f"{base_doc_id}_line_{line_idx}",
|
||||
"content": content,
|
||||
"chunk_index": 0,
|
||||
"chunk_count": 1
|
||||
}
|
||||
))
|
||||
"chunk_count": 1,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
|
||||
@@ -194,8 +226,7 @@ class JSONLProcessor:
|
||||
# Insert all points in this batch
|
||||
if points:
|
||||
self.rag_module.qdrant_client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=points
|
||||
collection_name=collection_name, points=points
|
||||
)
|
||||
|
||||
# Update stats
|
||||
|
||||
@@ -17,5 +17,5 @@ __all__ = [
|
||||
"EmbeddingResponse",
|
||||
"LLMError",
|
||||
"ProviderError",
|
||||
"SecurityError"
|
||||
"SecurityError",
|
||||
]
|
||||
@@ -15,30 +15,51 @@ from .models import ResilienceConfig
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""Configuration for an LLM provider"""
|
||||
|
||||
name: str = Field(..., description="Provider name")
|
||||
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
|
||||
provider_type: str = Field(
|
||||
..., description="Provider type (e.g., 'openai', 'privatemode')"
|
||||
)
|
||||
enabled: bool = Field(True, description="Whether provider is enabled")
|
||||
base_url: str = Field(..., description="Provider base URL")
|
||||
api_key_env_var: str = Field(..., description="Environment variable for API key")
|
||||
default_model: Optional[str] = Field(None, description="Default model for this provider")
|
||||
supported_models: List[str] = Field(default_factory=list, description="List of supported models")
|
||||
capabilities: List[str] = Field(default_factory=list, description="Provider capabilities")
|
||||
default_model: Optional[str] = Field(
|
||||
None, description="Default model for this provider"
|
||||
)
|
||||
supported_models: List[str] = Field(
|
||||
default_factory=list, description="List of supported models"
|
||||
)
|
||||
capabilities: List[str] = Field(
|
||||
default_factory=list, description="Provider capabilities"
|
||||
)
|
||||
priority: int = Field(1, description="Provider priority (lower = higher priority)")
|
||||
|
||||
# Rate limiting
|
||||
max_requests_per_minute: Optional[int] = Field(None, description="Max requests per minute")
|
||||
max_requests_per_hour: Optional[int] = Field(None, description="Max requests per hour")
|
||||
max_requests_per_minute: Optional[int] = Field(
|
||||
None, description="Max requests per minute"
|
||||
)
|
||||
max_requests_per_hour: Optional[int] = Field(
|
||||
None, description="Max requests per hour"
|
||||
)
|
||||
|
||||
# Model-specific settings
|
||||
supports_streaming: bool = Field(False, description="Whether provider supports streaming")
|
||||
supports_function_calling: bool = Field(False, description="Whether provider supports function calling")
|
||||
max_context_window: Optional[int] = Field(None, description="Maximum context window size")
|
||||
supports_streaming: bool = Field(
|
||||
False, description="Whether provider supports streaming"
|
||||
)
|
||||
supports_function_calling: bool = Field(
|
||||
False, description="Whether provider supports function calling"
|
||||
)
|
||||
max_context_window: Optional[int] = Field(
|
||||
None, description="Maximum context window size"
|
||||
)
|
||||
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
|
||||
|
||||
# Resilience configuration
|
||||
resilience: ResilienceConfig = Field(default_factory=ResilienceConfig, description="Resilience settings")
|
||||
resilience: ResilienceConfig = Field(
|
||||
default_factory=ResilienceConfig, description="Resilience settings"
|
||||
)
|
||||
|
||||
@validator('priority')
|
||||
@validator("priority")
|
||||
def validate_priority(cls, v):
|
||||
if v < 1:
|
||||
raise ValueError("Priority must be >= 1")
|
||||
@@ -50,32 +71,45 @@ class LLMServiceConfig(BaseModel):
|
||||
|
||||
# Global settings
|
||||
default_provider: str = Field("privatemode", description="Default provider to use")
|
||||
enable_detailed_logging: bool = Field(False, description="Enable detailed request/response logging")
|
||||
enable_detailed_logging: bool = Field(
|
||||
False, description="Enable detailed request/response logging"
|
||||
)
|
||||
enable_security_checks: bool = Field(True, description="Enable security validation")
|
||||
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
|
||||
enable_metrics_collection: bool = Field(
|
||||
True, description="Enable metrics collection"
|
||||
)
|
||||
|
||||
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
|
||||
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
|
||||
max_response_length: int = Field(
|
||||
32000, ge=1000, description="Maximum response length"
|
||||
)
|
||||
|
||||
# Performance settings
|
||||
default_timeout_ms: int = Field(30000, ge=1000, le=300000, description="Default request timeout")
|
||||
max_concurrent_requests: int = Field(100, ge=1, le=1000, description="Maximum concurrent requests")
|
||||
default_timeout_ms: int = Field(
|
||||
30000, ge=1000, le=300000, description="Default request timeout"
|
||||
)
|
||||
max_concurrent_requests: int = Field(
|
||||
100, ge=1, le=1000, description="Maximum concurrent requests"
|
||||
)
|
||||
|
||||
# Provider configurations
|
||||
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
|
||||
providers: Dict[str, ProviderConfig] = Field(
|
||||
default_factory=dict, description="Provider configurations"
|
||||
)
|
||||
|
||||
# Token rate limiting (organization-wide)
|
||||
token_limits_per_minute: Dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"prompt_tokens": 20000, # PrivateMode Standard tier
|
||||
"completion_tokens": 10000 # PrivateMode Standard tier
|
||||
"completion_tokens": 10000, # PrivateMode Standard tier
|
||||
},
|
||||
description="Token rate limits per minute (organization-wide)"
|
||||
description="Token rate limits per minute (organization-wide)",
|
||||
)
|
||||
|
||||
# Model routing (model_name -> provider_name)
|
||||
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
||||
|
||||
model_routing: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Model to provider routing"
|
||||
)
|
||||
|
||||
|
||||
def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
@@ -105,13 +139,11 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
retry_delay_ms=1000,
|
||||
timeout_ms=60000, # PrivateMode may be slower due to TEE
|
||||
circuit_breaker_threshold=5,
|
||||
circuit_breaker_reset_timeout_ms=120000
|
||||
)
|
||||
circuit_breaker_reset_timeout_ms=120000,
|
||||
),
|
||||
)
|
||||
|
||||
providers: Dict[str, ProviderConfig] = {
|
||||
"privatemode": privatemode_config
|
||||
}
|
||||
providers: Dict[str, ProviderConfig] = {"privatemode": privatemode_config}
|
||||
|
||||
if env.OPENAI_API_KEY:
|
||||
providers["openai"] = ProviderConfig(
|
||||
@@ -126,7 +158,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
"gpt-4o",
|
||||
"gpt-3.5-turbo",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-3-small"
|
||||
"text-embedding-3-small",
|
||||
],
|
||||
capabilities=["chat", "embeddings"],
|
||||
priority=2,
|
||||
@@ -139,8 +171,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
retry_delay_ms=750,
|
||||
timeout_ms=45000,
|
||||
circuit_breaker_threshold=6,
|
||||
circuit_breaker_reset_timeout_ms=60000
|
||||
)
|
||||
circuit_breaker_reset_timeout_ms=60000,
|
||||
),
|
||||
)
|
||||
|
||||
if env.ANTHROPIC_API_KEY:
|
||||
@@ -154,7 +186,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
supported_models=[
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307"
|
||||
"claude-3-haiku-20240307",
|
||||
],
|
||||
capabilities=["chat"],
|
||||
priority=3,
|
||||
@@ -167,8 +199,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
retry_delay_ms=1000,
|
||||
timeout_ms=60000,
|
||||
circuit_breaker_threshold=5,
|
||||
circuit_breaker_reset_timeout_ms=90000
|
||||
)
|
||||
circuit_breaker_reset_timeout_ms=90000,
|
||||
),
|
||||
)
|
||||
|
||||
if env.GOOGLE_API_KEY:
|
||||
@@ -181,7 +213,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
default_model="models/gemini-1.5-pro-latest",
|
||||
supported_models=[
|
||||
"models/gemini-1.5-pro-latest",
|
||||
"models/gemini-1.5-flash-latest"
|
||||
"models/gemini-1.5-flash-latest",
|
||||
],
|
||||
capabilities=["chat", "multimodal"],
|
||||
priority=4,
|
||||
@@ -194,13 +226,13 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
retry_delay_ms=1000,
|
||||
timeout_ms=45000,
|
||||
circuit_breaker_threshold=4,
|
||||
circuit_breaker_reset_timeout_ms=60000
|
||||
)
|
||||
circuit_breaker_reset_timeout_ms=60000,
|
||||
),
|
||||
)
|
||||
|
||||
default_provider = next(
|
||||
(name for name, provider in providers.items() if provider.enabled),
|
||||
"privatemode"
|
||||
"privatemode",
|
||||
)
|
||||
|
||||
# Create main configuration
|
||||
@@ -208,7 +240,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
|
||||
default_provider=default_provider,
|
||||
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
|
||||
providers=providers,
|
||||
model_routing={} # Will be populated dynamically from provider models
|
||||
model_routing={}, # Will be populated dynamically from provider models
|
||||
)
|
||||
|
||||
return config
|
||||
@@ -241,7 +273,7 @@ class EnvironmentVariables:
|
||||
"privatemode": self.PRIVATEMODE_API_KEY,
|
||||
"openai": self.OPENAI_API_KEY,
|
||||
"anthropic": self.ANTHROPIC_API_KEY,
|
||||
"google": self.GOOGLE_API_KEY
|
||||
"google": self.GOOGLE_API_KEY,
|
||||
}
|
||||
|
||||
return key_mapping.get(provider_name.lower())
|
||||
@@ -307,8 +339,11 @@ class ConfigurationManager:
|
||||
|
||||
if missing_keys:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")
|
||||
logger.warning(
|
||||
f"Missing API keys for enabled providers: {', '.join(missing_keys)}"
|
||||
)
|
||||
|
||||
# Validate default provider is enabled
|
||||
default_provider = self._config.default_provider
|
||||
@@ -323,8 +358,11 @@ class ConfigurationManager:
|
||||
|
||||
if invalid_routes:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")
|
||||
logger.warning(
|
||||
f"Model routes point to disabled providers: {', '.join(invalid_routes)}"
|
||||
)
|
||||
|
||||
async def refresh_provider_models(self, provider_name: str, models: List[str]):
|
||||
"""Update supported models for a provider dynamically"""
|
||||
@@ -343,6 +381,7 @@ class ConfigurationManager:
|
||||
self._config.model_routing[model] = provider_name
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Updated {provider_name} with {len(models)} models: {models}")
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ Custom exceptions for LLM service operations.
|
||||
class LLMError(Exception):
|
||||
"""Base exception for LLM service errors"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "LLM_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self, message: str, error_code: str = "LLM_ERROR", details: dict = None
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
@@ -18,7 +20,13 @@ class LLMError(Exception):
|
||||
class ProviderError(LLMError):
|
||||
"""Exception for LLM provider-specific errors"""
|
||||
|
||||
def __init__(self, message: str, provider: str, error_code: str = "PROVIDER_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
provider: str,
|
||||
error_code: str = "PROVIDER_ERROR",
|
||||
details: dict = None,
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
self.provider = provider
|
||||
|
||||
@@ -26,7 +34,13 @@ class ProviderError(LLMError):
|
||||
class SecurityError(LLMError):
|
||||
"""Exception for security-related errors"""
|
||||
|
||||
def __init__(self, message: str, risk_score: float = 0.0, error_code: str = "SECURITY_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
risk_score: float = 0.0,
|
||||
error_code: str = "SECURITY_ERROR",
|
||||
details: dict = None,
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
self.risk_score = risk_score
|
||||
|
||||
@@ -34,14 +48,22 @@ class SecurityError(LLMError):
|
||||
class ConfigurationError(LLMError):
|
||||
"""Exception for configuration-related errors"""
|
||||
|
||||
def __init__(self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
|
||||
|
||||
class RateLimitError(LLMError):
|
||||
"""Exception for rate limiting errors"""
|
||||
|
||||
def __init__(self, message: str, retry_after: int = None, error_code: str = "RATE_LIMIT_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
retry_after: int = None,
|
||||
error_code: str = "RATE_LIMIT_ERROR",
|
||||
details: dict = None,
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
self.retry_after = retry_after
|
||||
|
||||
@@ -49,7 +71,13 @@ class RateLimitError(LLMError):
|
||||
class TimeoutError(LLMError):
|
||||
"""Exception for timeout errors"""
|
||||
|
||||
def __init__(self, message: str, timeout_duration: float = None, error_code: str = "TIMEOUT_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
timeout_duration: float = None,
|
||||
error_code: str = "TIMEOUT_ERROR",
|
||||
details: dict = None,
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
self.timeout_duration = timeout_duration
|
||||
|
||||
@@ -57,6 +85,12 @@ class TimeoutError(LLMError):
|
||||
class ValidationError(LLMError):
|
||||
"""Exception for request validation errors"""
|
||||
|
||||
def __init__(self, message: str, field: str = None, error_code: str = "VALIDATION_ERROR", details: dict = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
field: str = None,
|
||||
error_code: str = "VALIDATION_ERROR",
|
||||
details: dict = None,
|
||||
):
|
||||
super().__init__(message, error_code, details)
|
||||
self.field = field
|
||||
@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class RequestMetric:
|
||||
"""Individual request metric"""
|
||||
|
||||
timestamp: datetime
|
||||
provider: str
|
||||
model: str
|
||||
@@ -45,7 +46,9 @@ class MetricsCollector:
|
||||
"""
|
||||
self.max_history_size = max_history_size
|
||||
self._metrics: deque = deque(maxlen=max_history_size)
|
||||
self._provider_metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
|
||||
self._provider_metrics: Dict[str, deque] = defaultdict(
|
||||
lambda: deque(maxlen=1000)
|
||||
)
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Aggregated metrics cache
|
||||
@@ -53,7 +56,9 @@ class MetricsCollector:
|
||||
self._cached_metrics: Optional[LLMMetrics] = None
|
||||
self._cache_ttl_seconds = 60 # Cache for 1 minute
|
||||
|
||||
logger.info(f"Metrics collector initialized with max history: {max_history_size}")
|
||||
logger.info(
|
||||
f"Metrics collector initialized with max history: {max_history_size}"
|
||||
)
|
||||
|
||||
def record_request(
|
||||
self,
|
||||
@@ -66,7 +71,7 @@ class MetricsCollector:
|
||||
security_risk_score: float = 0.0,
|
||||
error_code: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
api_key_id: Optional[int] = None
|
||||
api_key_id: Optional[int] = None,
|
||||
):
|
||||
"""Record a request metric"""
|
||||
metric = RequestMetric(
|
||||
@@ -80,7 +85,7 @@ class MetricsCollector:
|
||||
security_risk_score=security_risk_score,
|
||||
error_code=error_code,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id
|
||||
api_key_id=api_key_id,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
@@ -93,18 +98,25 @@ class MetricsCollector:
|
||||
|
||||
# Log significant events
|
||||
if not success:
|
||||
logger.warning(f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}")
|
||||
logger.warning(
|
||||
f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}"
|
||||
)
|
||||
elif security_risk_score > 0.6:
|
||||
logger.info(f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}")
|
||||
logger.info(
|
||||
f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}"
|
||||
)
|
||||
|
||||
def get_metrics(self, force_refresh: bool = False) -> LLMMetrics:
|
||||
"""Get aggregated metrics"""
|
||||
with self._lock:
|
||||
# Check cache validity
|
||||
if (not force_refresh and
|
||||
self._cached_metrics and
|
||||
self._cache_timestamp and
|
||||
(datetime.utcnow() - self._cache_timestamp).total_seconds() < self._cache_ttl_seconds):
|
||||
if (
|
||||
not force_refresh
|
||||
and self._cached_metrics
|
||||
and self._cache_timestamp
|
||||
and (datetime.utcnow() - self._cache_timestamp).total_seconds()
|
||||
< self._cache_ttl_seconds
|
||||
):
|
||||
return self._cached_metrics
|
||||
|
||||
# Calculate fresh metrics
|
||||
@@ -136,7 +148,9 @@ class MetricsCollector:
|
||||
provider_metrics = {}
|
||||
for provider, provider_data in self._provider_metrics.items():
|
||||
if provider_data:
|
||||
provider_metrics[provider] = self._calculate_provider_metrics(provider_data)
|
||||
provider_metrics[provider] = self._calculate_provider_metrics(
|
||||
provider_data
|
||||
)
|
||||
|
||||
return LLMMetrics(
|
||||
total_requests=total_requests,
|
||||
@@ -145,7 +159,7 @@ class MetricsCollector:
|
||||
average_latency_ms=avg_latency,
|
||||
average_risk_score=avg_risk_score,
|
||||
provider_metrics=provider_metrics,
|
||||
last_updated=datetime.utcnow()
|
||||
last_updated=datetime.utcnow(),
|
||||
)
|
||||
|
||||
def _calculate_provider_metrics(self, provider_data: deque) -> Dict[str, Any]:
|
||||
@@ -168,7 +182,9 @@ class MetricsCollector:
|
||||
for metric in provider_data:
|
||||
if metric.token_usage:
|
||||
total_prompt_tokens += metric.token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += metric.token_usage.get("completion_tokens", 0)
|
||||
total_completion_tokens += metric.token_usage.get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
total_tokens += metric.token_usage.get("total_tokens", 0)
|
||||
|
||||
# Model distribution
|
||||
@@ -198,12 +214,14 @@ class MetricsCollector:
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"avg_prompt_tokens": total_prompt_tokens / total if total > 0 else 0,
|
||||
"avg_completion_tokens": total_completion_tokens / successful if successful > 0 else 0
|
||||
"avg_completion_tokens": total_completion_tokens / successful
|
||||
if successful > 0
|
||||
else 0,
|
||||
},
|
||||
"model_distribution": dict(model_counts),
|
||||
"request_type_distribution": dict(request_type_counts),
|
||||
"error_distribution": dict(error_counts),
|
||||
"recent_requests": total
|
||||
"recent_requests": total,
|
||||
}
|
||||
|
||||
def get_provider_metrics(self, provider: str) -> Optional[Dict[str, Any]]:
|
||||
@@ -228,7 +246,11 @@ class MetricsCollector:
|
||||
|
||||
with self._lock:
|
||||
for metric in self._metrics:
|
||||
if metric.timestamp >= cutoff_time and not metric.success and metric.error_code:
|
||||
if (
|
||||
metric.timestamp >= cutoff_time
|
||||
and not metric.success
|
||||
and metric.error_code
|
||||
):
|
||||
error_counts[metric.error_code] += 1
|
||||
|
||||
return dict(error_counts)
|
||||
@@ -252,7 +274,7 @@ class MetricsCollector:
|
||||
"max_latency_ms": max(latencies),
|
||||
"p95_latency_ms": self._percentile(latencies, 95),
|
||||
"p99_latency_ms": self._percentile(latencies, 99),
|
||||
"request_count": len(latencies)
|
||||
"request_count": len(latencies),
|
||||
}
|
||||
|
||||
return performance
|
||||
@@ -307,9 +329,11 @@ class MetricsCollector:
|
||||
"total_requests_5min": total_recent,
|
||||
"average_latency_ms": metrics.average_latency_ms,
|
||||
"error_count_1hour": sum(error_metrics.values()),
|
||||
"top_errors": dict(sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]),
|
||||
"top_errors": dict(
|
||||
sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]
|
||||
),
|
||||
"provider_count": len(metrics.provider_metrics),
|
||||
"last_updated": datetime.utcnow().isoformat()
|
||||
"last_updated": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -9,15 +9,30 @@ from pydantic import BaseModel, Field, validator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Tool call in a message"""
|
||||
|
||||
id: str = Field(..., description="Tool call identifier")
|
||||
type: str = Field("function", description="Tool call type")
|
||||
function: Dict[str, Any] = Field(..., description="Function call details")
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Individual chat message"""
|
||||
role: str = Field(..., description="Message role (system, user, assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
name: Optional[str] = Field(None, description="Optional message name")
|
||||
|
||||
@validator('role')
|
||||
role: str = Field(..., description="Message role (system, user, assistant)")
|
||||
content: Optional[str] = Field(None, description="Message content")
|
||||
name: Optional[str] = Field(None, description="Optional message name")
|
||||
tool_calls: Optional[List[ToolCall]] = Field(
|
||||
None, description="Tool calls in this message"
|
||||
)
|
||||
tool_call_id: Optional[str] = Field(
|
||||
None, description="Tool call ID for tool responses"
|
||||
)
|
||||
|
||||
@validator("role")
|
||||
def validate_role(cls, v):
|
||||
allowed_roles = {'system', 'user', 'assistant', 'function'}
|
||||
allowed_roles = {"system", "user", "assistant", "function", "tool"}
|
||||
if v not in allowed_roles:
|
||||
raise ValueError(f"Role must be one of {allowed_roles}")
|
||||
return v
|
||||
@@ -25,21 +40,38 @@ class ChatMessage(BaseModel):
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Chat completion request"""
|
||||
|
||||
model: str = Field(..., description="Model identifier")
|
||||
messages: List[ChatMessage] = Field(..., description="Chat messages")
|
||||
temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
|
||||
max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="Maximum tokens to generate")
|
||||
top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter")
|
||||
temperature: Optional[float] = Field(
|
||||
0.7, ge=0.0, le=2.0, description="Sampling temperature"
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
None, ge=1, le=32000, description="Maximum tokens to generate"
|
||||
)
|
||||
top_p: Optional[float] = Field(
|
||||
1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter"
|
||||
)
|
||||
top_k: Optional[int] = Field(None, ge=1, description="Top-k sampling parameter")
|
||||
frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Frequency penalty")
|
||||
presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Presence penalty")
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
0.0, ge=-2.0, le=2.0, description="Frequency penalty"
|
||||
)
|
||||
presence_penalty: Optional[float] = Field(
|
||||
0.0, ge=-2.0, le=2.0, description="Presence penalty"
|
||||
)
|
||||
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
|
||||
stream: Optional[bool] = Field(False, description="Stream response")
|
||||
tools: Optional[List[Dict[str, Any]]] = Field(
|
||||
None, description="Available tools for function calling"
|
||||
)
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
None, description="Tool choice preference"
|
||||
)
|
||||
user_id: str = Field(..., description="User identifier")
|
||||
api_key_id: int = Field(..., description="API key identifier")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||
|
||||
@validator('messages')
|
||||
@validator("messages")
|
||||
def validate_messages(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
@@ -48,6 +80,7 @@ class ChatRequest(BaseModel):
|
||||
|
||||
class TokenUsage(BaseModel):
|
||||
"""Token usage information"""
|
||||
|
||||
prompt_tokens: int = Field(..., description="Tokens in the prompt")
|
||||
completion_tokens: int = Field(..., description="Tokens in the completion")
|
||||
total_tokens: int = Field(..., description="Total tokens used")
|
||||
@@ -55,13 +88,17 @@ class TokenUsage(BaseModel):
|
||||
|
||||
class ChatChoice(BaseModel):
|
||||
"""Chat completion choice"""
|
||||
|
||||
index: int = Field(..., description="Choice index")
|
||||
message: ChatMessage = Field(..., description="Generated message")
|
||||
finish_reason: Optional[str] = Field(None, description="Reason for completion finish")
|
||||
finish_reason: Optional[str] = Field(
|
||||
None, description="Reason for completion finish"
|
||||
)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat completion response"""
|
||||
|
||||
id: str = Field(..., description="Response identifier")
|
||||
object: str = Field("chat.completion", description="Object type")
|
||||
created: int = Field(..., description="Creation timestamp")
|
||||
@@ -72,17 +109,26 @@ class ChatResponse(BaseModel):
|
||||
system_fingerprint: Optional[str] = Field(None, description="System fingerprint")
|
||||
|
||||
# Security fields maintained for backward compatibility
|
||||
security_check: Optional[bool] = Field(None, description="Whether security check passed")
|
||||
security_check: Optional[bool] = Field(
|
||||
None, description="Whether security check passed"
|
||||
)
|
||||
risk_score: Optional[float] = Field(None, description="Security risk score")
|
||||
detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
|
||||
detected_patterns: Optional[List[str]] = Field(
|
||||
None, description="Detected security patterns"
|
||||
)
|
||||
|
||||
# Performance metrics
|
||||
latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
|
||||
provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
|
||||
latency_ms: Optional[float] = Field(
|
||||
None, description="Response latency in milliseconds"
|
||||
)
|
||||
provider_latency_ms: Optional[float] = Field(
|
||||
None, description="Provider-specific latency"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
"""Embedding generation request"""
|
||||
|
||||
model: str = Field(..., description="Embedding model identifier")
|
||||
input: Union[str, List[str]] = Field(..., description="Text to embed")
|
||||
encoding_format: Optional[str] = Field("float", description="Encoding format")
|
||||
@@ -91,19 +137,22 @@ class EmbeddingRequest(BaseModel):
|
||||
api_key_id: int = Field(..., description="API key identifier")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||
|
||||
@validator('input')
|
||||
@validator("input")
|
||||
def validate_input(cls, v):
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
raise ValueError("Input text cannot be empty")
|
||||
elif isinstance(v, list):
|
||||
if not v or not all(isinstance(item, str) and item.strip() for item in v):
|
||||
raise ValueError("Input list cannot be empty and must contain non-empty strings")
|
||||
raise ValueError(
|
||||
"Input list cannot be empty and must contain non-empty strings"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class EmbeddingData(BaseModel):
|
||||
"""Single embedding data"""
|
||||
|
||||
object: str = Field("embedding", description="Object type")
|
||||
index: int = Field(..., description="Embedding index")
|
||||
embedding: List[float] = Field(..., description="Embedding vector")
|
||||
@@ -111,6 +160,7 @@ class EmbeddingData(BaseModel):
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
"""Embedding generation response"""
|
||||
|
||||
object: str = Field("list", description="Object type")
|
||||
data: List[EmbeddingData] = Field(..., description="Embedding data")
|
||||
model: str = Field(..., description="Model used")
|
||||
@@ -118,56 +168,90 @@ class EmbeddingResponse(BaseModel):
|
||||
usage: Optional[TokenUsage] = Field(None, description="Token usage")
|
||||
|
||||
# Security fields maintained for backward compatibility
|
||||
security_check: Optional[bool] = Field(None, description="Whether security check passed")
|
||||
security_check: Optional[bool] = Field(
|
||||
None, description="Whether security check passed"
|
||||
)
|
||||
risk_score: Optional[float] = Field(None, description="Security risk score")
|
||||
detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
|
||||
detected_patterns: Optional[List[str]] = Field(
|
||||
None, description="Detected security patterns"
|
||||
)
|
||||
|
||||
# Performance metrics
|
||||
latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
|
||||
provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
|
||||
latency_ms: Optional[float] = Field(
|
||||
None, description="Response latency in milliseconds"
|
||||
)
|
||||
provider_latency_ms: Optional[float] = Field(
|
||||
None, description="Provider-specific latency"
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Model information"""
|
||||
|
||||
id: str = Field(..., description="Model identifier")
|
||||
object: str = Field("model", description="Object type")
|
||||
created: Optional[int] = Field(None, description="Creation timestamp")
|
||||
owned_by: str = Field(..., description="Model owner")
|
||||
provider: str = Field(..., description="Provider name")
|
||||
capabilities: List[str] = Field(default_factory=list, description="Model capabilities")
|
||||
capabilities: List[str] = Field(
|
||||
default_factory=list, description="Model capabilities"
|
||||
)
|
||||
context_window: Optional[int] = Field(None, description="Context window size")
|
||||
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
|
||||
supports_streaming: bool = Field(False, description="Whether model supports streaming")
|
||||
supports_function_calling: bool = Field(False, description="Whether model supports function calling")
|
||||
tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
|
||||
supports_streaming: bool = Field(
|
||||
False, description="Whether model supports streaming"
|
||||
)
|
||||
supports_function_calling: bool = Field(
|
||||
False, description="Whether model supports function calling"
|
||||
)
|
||||
tasks: Optional[List[str]] = Field(
|
||||
None, description="Model tasks (e.g., generate, embed, vision)"
|
||||
)
|
||||
|
||||
|
||||
class ProviderStatus(BaseModel):
|
||||
"""Provider health status"""
|
||||
|
||||
provider: str = Field(..., description="Provider name")
|
||||
status: str = Field(..., description="Status (healthy, degraded, unavailable)")
|
||||
latency_ms: Optional[float] = Field(None, description="Average latency")
|
||||
success_rate: Optional[float] = Field(None, description="Success rate (0.0 to 1.0)")
|
||||
last_check: datetime = Field(..., description="Last health check timestamp")
|
||||
error_message: Optional[str] = Field(None, description="Error message if unhealthy")
|
||||
models_available: List[str] = Field(default_factory=list, description="Available models")
|
||||
models_available: List[str] = Field(
|
||||
default_factory=list, description="Available models"
|
||||
)
|
||||
|
||||
|
||||
class LLMMetrics(BaseModel):
|
||||
"""LLM service metrics"""
|
||||
|
||||
total_requests: int = Field(0, description="Total requests processed")
|
||||
successful_requests: int = Field(0, description="Successful requests")
|
||||
failed_requests: int = Field(0, description="Failed requests")
|
||||
average_latency_ms: float = Field(0.0, description="Average response latency")
|
||||
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
|
||||
last_updated: datetime = Field(default_factory=datetime.utcnow, description="Last metrics update")
|
||||
provider_metrics: Dict[str, Dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Per-provider metrics"
|
||||
)
|
||||
last_updated: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="Last metrics update"
|
||||
)
|
||||
|
||||
|
||||
class ResilienceConfig(BaseModel):
|
||||
"""Configuration for resilience patterns"""
|
||||
|
||||
max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
|
||||
retry_delay_ms: int = Field(1000, ge=100, le=30000, description="Initial retry delay")
|
||||
retry_exponential_base: float = Field(2.0, ge=1.1, le=5.0, description="Exponential backoff base")
|
||||
retry_delay_ms: int = Field(
|
||||
1000, ge=100, le=30000, description="Initial retry delay"
|
||||
)
|
||||
retry_exponential_base: float = Field(
|
||||
2.0, ge=1.1, le=5.0, description="Exponential backoff base"
|
||||
)
|
||||
timeout_ms: int = Field(30000, ge=1000, le=300000, description="Request timeout")
|
||||
circuit_breaker_threshold: int = Field(5, ge=1, le=50, description="Circuit breaker failure threshold")
|
||||
circuit_breaker_reset_timeout_ms: int = Field(60000, ge=10000, le=600000, description="Circuit breaker reset timeout")
|
||||
circuit_breaker_threshold: int = Field(
|
||||
5, ge=1, le=50, description="Circuit breaker failure threshold"
|
||||
)
|
||||
circuit_breaker_reset_timeout_ms: int = Field(
|
||||
60000, ge=10000, le=600000, description="Circuit breaker reset timeout"
|
||||
)
|
||||
|
||||
@@ -9,8 +9,12 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
import logging
|
||||
|
||||
from ..models import (
|
||||
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
|
||||
ModelInfo, ProviderStatus
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ModelInfo,
|
||||
ProviderStatus,
|
||||
)
|
||||
from ..config import ProviderConfig
|
||||
|
||||
@@ -80,7 +84,9 @@ class BaseLLMProvider(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
async def create_chat_completion_stream(
|
||||
self, request: ChatRequest
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Create streaming chat completion
|
||||
|
||||
@@ -121,7 +127,7 @@ class BaseLLMProvider(ABC):
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup provider resources"""
|
||||
if self._session and hasattr(self._session, 'close'):
|
||||
if self._session and hasattr(self._session, "close"):
|
||||
await self._session.close()
|
||||
logger.debug(f"Cleaned up session for {self.name} provider")
|
||||
|
||||
@@ -147,24 +153,27 @@ class BaseLLMProvider(ABC):
|
||||
context_window=self.config.max_context_window,
|
||||
max_output_tokens=self.config.max_output_tokens,
|
||||
supports_streaming=self.config.supports_streaming,
|
||||
supports_function_calling=self.config.supports_function_calling
|
||||
supports_function_calling=self.config.supports_function_calling,
|
||||
)
|
||||
|
||||
def _validate_request(self, request: Any):
|
||||
"""Base request validation (override for provider-specific validation)"""
|
||||
if hasattr(request, 'model') and not self.supports_model(request.model):
|
||||
if hasattr(request, "model") and not self.supports_model(request.model):
|
||||
from ..exceptions import ValidationError
|
||||
|
||||
raise ValidationError(
|
||||
f"Model '{request.model}' not supported by provider '{self.name}'",
|
||||
field="model"
|
||||
field="model",
|
||||
)
|
||||
|
||||
def _create_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
def _create_headers(
|
||||
self, additional_headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Create HTTP headers for requests"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": f"Enclava-LLM-Service/{self.name}"
|
||||
"User-Agent": f"Enclava-LLM-Service/{self.name}",
|
||||
}
|
||||
|
||||
if additional_headers:
|
||||
@@ -172,7 +181,9 @@ class BaseLLMProvider(ABC):
|
||||
|
||||
return headers
|
||||
|
||||
def _handle_http_error(self, status_code: int, response_text: str, provider_context: str = ""):
|
||||
def _handle_http_error(
|
||||
self, status_code: int, response_text: str, provider_context: str = ""
|
||||
):
|
||||
"""Handle HTTP errors consistently across providers"""
|
||||
from ..exceptions import ProviderError, RateLimitError, ValidationError
|
||||
|
||||
@@ -183,40 +194,44 @@ class BaseLLMProvider(ABC):
|
||||
f"Authentication failed for {context}",
|
||||
provider=self.name,
|
||||
error_code="AUTHENTICATION_ERROR",
|
||||
details={"status_code": status_code, "response": response_text}
|
||||
details={"status_code": status_code, "response": response_text},
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise ProviderError(
|
||||
f"Access forbidden for {context}",
|
||||
provider=self.name,
|
||||
error_code="AUTHORIZATION_ERROR",
|
||||
details={"status_code": status_code, "response": response_text}
|
||||
details={"status_code": status_code, "response": response_text},
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise RateLimitError(
|
||||
f"Rate limit exceeded for {context}",
|
||||
error_code="RATE_LIMIT_ERROR",
|
||||
details={"status_code": status_code, "response": response_text, "provider": self.name}
|
||||
details={
|
||||
"status_code": status_code,
|
||||
"response": response_text,
|
||||
"provider": self.name,
|
||||
},
|
||||
)
|
||||
elif status_code == 400:
|
||||
raise ValidationError(
|
||||
f"Bad request for {context}: {response_text}",
|
||||
error_code="BAD_REQUEST",
|
||||
details={"status_code": status_code, "response": response_text}
|
||||
details={"status_code": status_code, "response": response_text},
|
||||
)
|
||||
elif 500 <= status_code < 600:
|
||||
raise ProviderError(
|
||||
f"Server error for {context}: {response_text}",
|
||||
provider=self.name,
|
||||
error_code="SERVER_ERROR",
|
||||
details={"status_code": status_code, "response": response_text}
|
||||
details={"status_code": status_code, "response": response_text},
|
||||
)
|
||||
else:
|
||||
raise ProviderError(
|
||||
f"HTTP error {status_code} for {context}: {response_text}",
|
||||
provider=self.name,
|
||||
error_code="HTTP_ERROR",
|
||||
details={"status_code": status_code, "response": response_text}
|
||||
details={"status_code": status_code, "response": response_text},
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -15,9 +15,16 @@ import aiohttp
|
||||
|
||||
from .base import BaseLLMProvider
|
||||
from ..models import (
|
||||
ChatRequest, ChatResponse, ChatMessage, ChatChoice, TokenUsage,
|
||||
EmbeddingRequest, EmbeddingResponse, EmbeddingData,
|
||||
ModelInfo, ProviderStatus
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
ChatMessage,
|
||||
ChatChoice,
|
||||
TokenUsage,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingData,
|
||||
ModelInfo,
|
||||
ProviderStatus,
|
||||
)
|
||||
from ..config import ProviderConfig
|
||||
from ..exceptions import ProviderError, ValidationError, TimeoutError
|
||||
@@ -30,7 +37,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
|
||||
def __init__(self, config: ProviderConfig, api_key: str):
|
||||
super().__init__(config, api_key)
|
||||
self.base_url = config.base_url.rstrip('/')
|
||||
self.base_url = config.base_url.rstrip("/")
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# TEE-specific settings
|
||||
@@ -52,17 +59,19 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
limit=100, # Connection pool limit
|
||||
limit_per_host=50,
|
||||
ttl_dns_cache=300, # DNS cache TTL
|
||||
use_dns_cache=True
|
||||
use_dns_cache=True,
|
||||
)
|
||||
|
||||
# Create session with security headers
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.resilience.timeout_ms / 1000.0)
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self.config.resilience.timeout_ms / 1000.0
|
||||
)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
headers=self._create_headers(),
|
||||
trust_env=False # Don't trust environment variables
|
||||
trust_env=False, # Don't trust environment variables
|
||||
)
|
||||
|
||||
logger.debug("Created new secure HTTP session for PrivateMode")
|
||||
@@ -82,7 +91,9 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
|
||||
if response.status == 200:
|
||||
models_data = await response.json()
|
||||
models = [model.get("id", "") for model in models_data.get("data", [])]
|
||||
models = [
|
||||
model.get("id", "") for model in models_data.get("data", [])
|
||||
]
|
||||
|
||||
return ProviderStatus(
|
||||
provider=self.provider_name,
|
||||
@@ -90,7 +101,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
latency_ms=latency,
|
||||
success_rate=1.0,
|
||||
last_check=datetime.utcnow(),
|
||||
models_available=models
|
||||
models_available=models,
|
||||
)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
@@ -101,7 +112,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
success_rate=0.0,
|
||||
last_check=datetime.utcnow(),
|
||||
error_message=f"HTTP {response.status}: {error_text}",
|
||||
models_available=[]
|
||||
models_available=[],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -115,7 +126,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
success_rate=0.0,
|
||||
last_check=datetime.utcnow(),
|
||||
error_message=str(e),
|
||||
models_available=[]
|
||||
models_available=[],
|
||||
)
|
||||
|
||||
async def get_models(self) -> List[ModelInfo]:
|
||||
@@ -151,7 +162,9 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
capabilities.append("vision")
|
||||
|
||||
# Check for function calling support in the API response
|
||||
supports_function_calling = model_data.get("supports_function_calling", False)
|
||||
supports_function_calling = model_data.get(
|
||||
"supports_function_calling", False
|
||||
)
|
||||
if supports_function_calling:
|
||||
capabilities.append("function_calling")
|
||||
|
||||
@@ -164,9 +177,11 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
capabilities=capabilities,
|
||||
context_window=model_data.get("context_window"),
|
||||
max_output_tokens=model_data.get("max_output_tokens"),
|
||||
supports_streaming=model_data.get("supports_streaming", True),
|
||||
supports_streaming=model_data.get(
|
||||
"supports_streaming", True
|
||||
),
|
||||
supports_function_calling=supports_function_calling,
|
||||
tasks=tasks # Pass through tasks field from PrivateMode API
|
||||
tasks=tasks, # Pass through tasks field from PrivateMode API
|
||||
)
|
||||
models.append(model_info)
|
||||
|
||||
@@ -174,7 +189,9 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
return models
|
||||
else:
|
||||
error_text = await response.text()
|
||||
self._handle_http_error(response.status, error_text, "models endpoint")
|
||||
self._handle_http_error(
|
||||
response.status, error_text, "models endpoint"
|
||||
)
|
||||
return [] # Never reached due to exception
|
||||
|
||||
except Exception as e:
|
||||
@@ -186,7 +203,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Failed to retrieve models from PrivateMode",
|
||||
provider=self.provider_name,
|
||||
error_code="MODEL_RETRIEVAL_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
|
||||
@@ -205,12 +222,12 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
**({"name": msg.name} if msg.name else {})
|
||||
**({"name": msg.name} if msg.name else {}),
|
||||
}
|
||||
for msg in request.messages
|
||||
],
|
||||
"temperature": request.temperature,
|
||||
"stream": False # Non-streaming version
|
||||
"stream": False, # Non-streaming version
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
@@ -234,12 +251,11 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"api_key_id": request.api_key_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"enclava_request_id": str(uuid.uuid4()),
|
||||
**(request.metadata or {})
|
||||
**(request.metadata or {}),
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload
|
||||
f"{self.base_url}/chat/completions", json=payload
|
||||
) as response:
|
||||
provider_latency = (time.time() - start_time) * 1000
|
||||
|
||||
@@ -254,9 +270,9 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
index=choice_data.get("index", 0),
|
||||
message=ChatMessage(
|
||||
role=message_data.get("role", "assistant"),
|
||||
content=message_data.get("content", "")
|
||||
content=message_data.get("content", ""),
|
||||
),
|
||||
finish_reason=choice_data.get("finish_reason")
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
@@ -265,7 +281,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
usage = TokenUsage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0)
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
# Create response
|
||||
@@ -281,15 +297,19 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
security_check=True, # Will be set by security manager
|
||||
risk_score=0.0, # Will be set by security manager
|
||||
latency_ms=provider_latency,
|
||||
provider_latency_ms=provider_latency
|
||||
provider_latency_ms=provider_latency,
|
||||
)
|
||||
|
||||
logger.debug(f"PrivateMode chat completion successful in {provider_latency:.2f}ms")
|
||||
logger.debug(
|
||||
f"PrivateMode chat completion successful in {provider_latency:.2f}ms"
|
||||
)
|
||||
return chat_response
|
||||
|
||||
else:
|
||||
error_text = await response.text()
|
||||
self._handle_http_error(response.status, error_text, "chat completion")
|
||||
self._handle_http_error(
|
||||
response.status, error_text, "chat completion"
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"PrivateMode request error: {e}")
|
||||
@@ -297,7 +317,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Network error communicating with PrivateMode",
|
||||
provider=self.provider_name,
|
||||
error_code="NETWORK_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (ProviderError, ValidationError)):
|
||||
@@ -308,10 +328,12 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Unexpected error during chat completion",
|
||||
provider=self.provider_name,
|
||||
error_code="UNEXPECTED_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
async def create_chat_completion_stream(
|
||||
self, request: ChatRequest
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Create streaming chat completion"""
|
||||
self._validate_request(request)
|
||||
|
||||
@@ -325,12 +347,12 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
**({"name": msg.name} if msg.name else {})
|
||||
**({"name": msg.name} if msg.name else {}),
|
||||
}
|
||||
for msg in request.messages
|
||||
],
|
||||
"temperature": request.temperature,
|
||||
"stream": True
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
@@ -349,12 +371,11 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
payload["user"] = f"user_{request.user_id}"
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload
|
||||
f"{self.base_url}/chat/completions", json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
line = line.decode("utf-8").strip()
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # Remove "data: " prefix
|
||||
@@ -366,11 +387,15 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
chunk_data = json.loads(data_str)
|
||||
yield chunk_data
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse streaming chunk: {data_str}")
|
||||
logger.warning(
|
||||
f"Failed to parse streaming chunk: {data_str}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
error_text = await response.text()
|
||||
self._handle_http_error(response.status, error_text, "streaming chat completion")
|
||||
self._handle_http_error(
|
||||
response.status, error_text, "streaming chat completion"
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"PrivateMode streaming error: {e}")
|
||||
@@ -378,7 +403,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Network error during streaming",
|
||||
provider=self.provider_name,
|
||||
error_code="STREAMING_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
||||
@@ -394,7 +419,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"input": request.input,
|
||||
"user": f"user_{request.user_id}"
|
||||
"user": f"user_{request.user_id}",
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
@@ -408,12 +433,11 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"user_id": request.user_id,
|
||||
"api_key_id": request.api_key_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**(request.metadata or {})
|
||||
**(request.metadata or {}),
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/embeddings",
|
||||
json=payload
|
||||
f"{self.base_url}/embeddings", json=payload
|
||||
) as response:
|
||||
provider_latency = (time.time() - start_time) * 1000
|
||||
|
||||
@@ -426,7 +450,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
embedding = EmbeddingData(
|
||||
object="embedding",
|
||||
index=emb_data.get("index", 0),
|
||||
embedding=emb_data.get("embedding", [])
|
||||
embedding=emb_data.get("embedding", []),
|
||||
)
|
||||
embeddings.append(embedding)
|
||||
|
||||
@@ -435,7 +459,9 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
usage = TokenUsage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=0, # No completion tokens for embeddings
|
||||
total_tokens=usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
|
||||
total_tokens=usage_data.get(
|
||||
"total_tokens", usage_data.get("prompt_tokens", 0)
|
||||
),
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
@@ -447,13 +473,15 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
security_check=True, # Will be set by security manager
|
||||
risk_score=0.0, # Will be set by security manager
|
||||
latency_ms=provider_latency,
|
||||
provider_latency_ms=provider_latency
|
||||
provider_latency_ms=provider_latency,
|
||||
)
|
||||
|
||||
else:
|
||||
error_text = await response.text()
|
||||
# Log the detailed error response from the provider
|
||||
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
|
||||
logger.error(
|
||||
f"PrivateMode embedding error - Status {response.status}: {error_text}"
|
||||
)
|
||||
self._handle_http_error(response.status, error_text, "embeddings")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
@@ -462,7 +490,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Network error during embedding generation",
|
||||
provider=self.provider_name,
|
||||
error_code="EMBEDDING_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (ProviderError, ValidationError)):
|
||||
@@ -473,7 +501,7 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
"Unexpected error during embedding generation",
|
||||
provider=self.provider_name,
|
||||
error_code="UNEXPECTED_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
|
||||
@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class CircuitBreakerState(Enum):
|
||||
"""Circuit breaker states"""
|
||||
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, blocking requests
|
||||
HALF_OPEN = "half_open" # Testing if service recovered
|
||||
@@ -28,6 +29,7 @@ class CircuitBreakerState(Enum):
|
||||
@dataclass
|
||||
class CircuitBreakerStats:
|
||||
"""Circuit breaker statistics"""
|
||||
|
||||
failure_count: int = 0
|
||||
success_count: int = 0
|
||||
last_failure_time: Optional[datetime] = None
|
||||
@@ -51,7 +53,9 @@ class CircuitBreaker:
|
||||
|
||||
if self.state == CircuitBreakerState.OPEN:
|
||||
# Check if reset timeout has passed
|
||||
if (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
|
||||
if (
|
||||
datetime.utcnow() - self.stats.state_change_time
|
||||
).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
|
||||
self._transition_to_half_open()
|
||||
return True
|
||||
return False
|
||||
@@ -72,7 +76,9 @@ class CircuitBreaker:
|
||||
# Reset failure count on success
|
||||
self.stats.failure_count = 0
|
||||
|
||||
logger.debug(f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}")
|
||||
logger.debug(
|
||||
f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}"
|
||||
)
|
||||
|
||||
def record_failure(self):
|
||||
"""Record failed request"""
|
||||
@@ -85,27 +91,35 @@ class CircuitBreaker:
|
||||
elif self.state == CircuitBreakerState.HALF_OPEN:
|
||||
self._transition_to_open()
|
||||
|
||||
logger.warning(f"Circuit breaker [{self.provider_name}]: Failure recorded, "
|
||||
f"count={self.stats.failure_count}, state={self.state.value}")
|
||||
logger.warning(
|
||||
f"Circuit breaker [{self.provider_name}]: Failure recorded, "
|
||||
f"count={self.stats.failure_count}, state={self.state.value}"
|
||||
)
|
||||
|
||||
def _transition_to_open(self):
|
||||
"""Transition to OPEN state"""
|
||||
self.state = CircuitBreakerState.OPEN
|
||||
self.stats.state_change_time = datetime.utcnow()
|
||||
logger.error(f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures")
|
||||
logger.error(
|
||||
f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures"
|
||||
)
|
||||
|
||||
def _transition_to_half_open(self):
|
||||
"""Transition to HALF_OPEN state"""
|
||||
self.state = CircuitBreakerState.HALF_OPEN
|
||||
self.stats.state_change_time = datetime.utcnow()
|
||||
logger.info(f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing")
|
||||
logger.info(
|
||||
f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing"
|
||||
)
|
||||
|
||||
def _transition_to_closed(self):
|
||||
"""Transition to CLOSED state"""
|
||||
self.state = CircuitBreakerState.CLOSED
|
||||
self.stats.state_change_time = datetime.utcnow()
|
||||
self.stats.failure_count = 0 # Reset failure count
|
||||
logger.info(f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered")
|
||||
logger.info(
|
||||
f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered"
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get circuit breaker statistics"""
|
||||
@@ -113,10 +127,17 @@ class CircuitBreaker:
|
||||
"state": self.state.value,
|
||||
"failure_count": self.stats.failure_count,
|
||||
"success_count": self.stats.success_count,
|
||||
"last_failure_time": self.stats.last_failure_time.isoformat() if self.stats.last_failure_time else None,
|
||||
"last_success_time": self.stats.last_success_time.isoformat() if self.stats.last_success_time else None,
|
||||
"last_failure_time": self.stats.last_failure_time.isoformat()
|
||||
if self.stats.last_failure_time
|
||||
else None,
|
||||
"last_success_time": self.stats.last_success_time.isoformat()
|
||||
if self.stats.last_success_time
|
||||
else None,
|
||||
"state_change_time": self.stats.state_change_time.isoformat(),
|
||||
"time_in_current_state_ms": (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000
|
||||
"time_in_current_state_ms": (
|
||||
datetime.utcnow() - self.stats.state_change_time
|
||||
).total_seconds()
|
||||
* 1000,
|
||||
}
|
||||
|
||||
|
||||
@@ -132,7 +153,7 @@ class RetryManager:
|
||||
*args,
|
||||
retryable_exceptions: tuple = (Exception,),
|
||||
non_retryable_exceptions: tuple = (RateLimitError,),
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute function with retry logic"""
|
||||
last_exception = None
|
||||
@@ -149,11 +170,15 @@ class RetryManager:
|
||||
last_exception = e
|
||||
|
||||
if attempt == self.config.max_retries:
|
||||
logger.error(f"All {self.config.max_retries + 1} attempts failed. Last error: {e}")
|
||||
logger.error(
|
||||
f"All {self.config.max_retries + 1} attempts failed. Last error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms...")
|
||||
logger.warning(
|
||||
f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms..."
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay / 1000.0)
|
||||
|
||||
@@ -165,10 +190,13 @@ class RetryManager:
|
||||
|
||||
def _calculate_delay(self, attempt: int) -> int:
|
||||
"""Calculate delay for exponential backoff"""
|
||||
delay = self.config.retry_delay_ms * (self.config.retry_exponential_base ** attempt)
|
||||
delay = self.config.retry_delay_ms * (
|
||||
self.config.retry_exponential_base**attempt
|
||||
)
|
||||
|
||||
# Add some jitter to prevent thundering herd
|
||||
import random
|
||||
|
||||
jitter = random.uniform(0.8, 1.2)
|
||||
|
||||
return int(delay * jitter)
|
||||
@@ -181,18 +209,16 @@ class TimeoutManager:
|
||||
self.config = config
|
||||
|
||||
async def execute_with_timeout(
|
||||
self,
|
||||
func: Callable,
|
||||
*args,
|
||||
timeout_override: Optional[int] = None,
|
||||
**kwargs
|
||||
self, func: Callable, *args, timeout_override: Optional[int] = None, **kwargs
|
||||
) -> Any:
|
||||
"""Execute function with timeout"""
|
||||
timeout_ms = timeout_override or self.config.timeout_ms
|
||||
timeout_seconds = timeout_ms / 1000.0
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout_seconds)
|
||||
return await asyncio.wait_for(
|
||||
func(*args, **kwargs), timeout=timeout_seconds
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Request timed out after {timeout_ms}ms"
|
||||
logger.error(error_msg)
|
||||
@@ -216,7 +242,7 @@ class ResilienceManager:
|
||||
retryable_exceptions: tuple = (Exception,),
|
||||
non_retryable_exceptions: tuple = (RateLimitError,),
|
||||
timeout_override: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute function with full resilience patterns"""
|
||||
|
||||
@@ -237,14 +263,16 @@ class ResilienceManager:
|
||||
retryable_exceptions=retryable_exceptions,
|
||||
non_retryable_exceptions=non_retryable_exceptions,
|
||||
timeout_override=timeout_override,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Record success
|
||||
self.circuit_breaker.record_success()
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.debug(f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms")
|
||||
logger.debug(
|
||||
f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -253,7 +281,9 @@ class ResilienceManager:
|
||||
self.circuit_breaker.record_failure()
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.error(f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}")
|
||||
logger.error(
|
||||
f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}"
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
@@ -281,8 +311,8 @@ class ResilienceManager:
|
||||
"config": {
|
||||
"max_retries": self.config.max_retries,
|
||||
"timeout_ms": self.config.timeout_ms,
|
||||
"circuit_breaker_threshold": self.config.circuit_breaker_threshold
|
||||
}
|
||||
"circuit_breaker_threshold": self.config.circuit_breaker_threshold,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -293,11 +323,15 @@ class ResilienceManagerFactory:
|
||||
_default_config = ResilienceConfig()
|
||||
|
||||
@classmethod
|
||||
def get_manager(cls, provider_name: str, config: Optional[ResilienceConfig] = None) -> ResilienceManager:
|
||||
def get_manager(
|
||||
cls, provider_name: str, config: Optional[ResilienceConfig] = None
|
||||
) -> ResilienceManager:
|
||||
"""Get or create resilience manager for provider"""
|
||||
if provider_name not in cls._managers:
|
||||
manager_config = config or cls._default_config
|
||||
cls._managers[provider_name] = ResilienceManager(manager_config, provider_name)
|
||||
cls._managers[provider_name] = ResilienceManager(
|
||||
manager_config, provider_name
|
||||
)
|
||||
|
||||
return cls._managers[provider_name]
|
||||
|
||||
@@ -305,8 +339,7 @@ class ResilienceManagerFactory:
|
||||
def get_all_health_status(cls) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get health status for all managed providers"""
|
||||
return {
|
||||
name: manager.get_health_status()
|
||||
for name, manager in cls._managers.items()
|
||||
name: manager.get_health_status() for name, manager in cls._managers.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,18 +12,28 @@ from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||
from datetime import datetime
|
||||
|
||||
from .models import (
|
||||
ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
|
||||
ModelInfo, ProviderStatus, LLMMetrics
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ModelInfo,
|
||||
ProviderStatus,
|
||||
LLMMetrics,
|
||||
)
|
||||
from .config import config_manager, ProviderConfig
|
||||
from ...core.config import settings
|
||||
|
||||
from .resilience import ResilienceManagerFactory
|
||||
|
||||
# from .metrics import metrics_collector
|
||||
from .providers import BaseLLMProvider, PrivateModeProvider
|
||||
from .exceptions import (
|
||||
LLMError, ProviderError, SecurityError, ConfigurationError,
|
||||
ValidationError, TimeoutError
|
||||
LLMError,
|
||||
ProviderError,
|
||||
SecurityError,
|
||||
ConfigurationError,
|
||||
ValidationError,
|
||||
TimeoutError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -52,7 +62,9 @@ class LLMService:
|
||||
try:
|
||||
# Get configuration
|
||||
config = config_manager.get_config()
|
||||
logger.info(f"Initializing LLM service with {len(config.providers)} configured providers")
|
||||
logger.info(
|
||||
f"Initializing LLM service with {len(config.providers)} configured providers"
|
||||
)
|
||||
|
||||
# Initialize enabled providers
|
||||
enabled_providers = config_manager.get_enabled_providers()
|
||||
@@ -70,13 +82,17 @@ class LLMService:
|
||||
default_provider = config.default_provider
|
||||
if default_provider not in self._providers:
|
||||
available_providers = list(self._providers.keys())
|
||||
logger.warning(f"Default provider '{default_provider}' not available, using '{available_providers[0]}'")
|
||||
logger.warning(
|
||||
f"Default provider '{default_provider}' not available, using '{available_providers[0]}'"
|
||||
)
|
||||
config.default_provider = available_providers[0]
|
||||
|
||||
self._initialized = True
|
||||
initialization_time = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(f"LLM Service initialized successfully in {initialization_time:.2f}ms")
|
||||
logger.info(
|
||||
f"LLM Service initialized successfully in {initialization_time:.2f}ms"
|
||||
)
|
||||
logger.info(f"Available providers: {list(self._providers.keys())}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -106,12 +122,16 @@ class LLMService:
|
||||
# Test provider health
|
||||
health_status = await provider.health_check()
|
||||
if health_status.status == "unavailable":
|
||||
logger.error(f"Provider '{provider_name}' failed health check: {health_status.error_message}")
|
||||
logger.error(
|
||||
f"Provider '{provider_name}' failed health check: {health_status.error_message}"
|
||||
)
|
||||
return
|
||||
|
||||
# Register provider
|
||||
self._providers[provider_name] = provider
|
||||
logger.info(f"Provider '{provider_name}' initialized successfully (status: {health_status.status})")
|
||||
logger.info(
|
||||
f"Provider '{provider_name}' initialized successfully (status: {health_status.status})"
|
||||
)
|
||||
|
||||
# Fetch and update models dynamically
|
||||
await self._refresh_provider_models(provider_name, provider)
|
||||
@@ -126,7 +146,9 @@ class LLMService:
|
||||
else:
|
||||
raise ConfigurationError(f"Unknown provider type: {config.name}")
|
||||
|
||||
async def _refresh_provider_models(self, provider_name: str, provider: BaseLLMProvider):
|
||||
async def _refresh_provider_models(
|
||||
self, provider_name: str, provider: BaseLLMProvider
|
||||
):
|
||||
"""Fetch and update models dynamically from provider"""
|
||||
try:
|
||||
# Get models from provider
|
||||
@@ -136,10 +158,14 @@ class LLMService:
|
||||
# Update configuration
|
||||
await config_manager.refresh_provider_models(provider_name, model_ids)
|
||||
|
||||
logger.info(f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}")
|
||||
logger.info(
|
||||
f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh models for provider '{provider_name}': {e}")
|
||||
logger.error(
|
||||
f"Failed to refresh models for provider '{provider_name}': {e}"
|
||||
)
|
||||
|
||||
async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
|
||||
"""Create chat completion with security and resilience"""
|
||||
@@ -157,8 +183,10 @@ class LLMService:
|
||||
provider = self._providers.get(provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
|
||||
|
||||
raise ProviderError(
|
||||
f"No available provider for model '{request.model}'",
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
# Execute with resilience
|
||||
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
|
||||
@@ -169,22 +197,18 @@ class LLMService:
|
||||
provider.create_chat_completion,
|
||||
request,
|
||||
retryable_exceptions=(ProviderError, TimeoutError),
|
||||
non_retryable_exceptions=(ValidationError,)
|
||||
non_retryable_exceptions=(ValidationError,),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# Record successful request - metrics disabled
|
||||
total_latency = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Record failed request - metrics disabled
|
||||
total_latency = (time.time() - start_time) * 1000
|
||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||
error_code = getattr(e, "error_code", e.__class__.__name__)
|
||||
|
||||
logger.exception(
|
||||
"Chat completion failed for provider %s (model=%s, latency=%.2fms, error=%s)",
|
||||
@@ -195,7 +219,9 @@ class LLMService:
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
async def create_chat_completion_stream(
|
||||
self, request: ChatRequest
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Create streaming chat completion"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
@@ -203,13 +229,15 @@ class LLMService:
|
||||
# Security validation disabled - always allow streaming requests
|
||||
risk_score = 0.0
|
||||
|
||||
|
||||
# Get provider
|
||||
provider_name = self._get_provider_for_model(request.model)
|
||||
provider = self._providers.get(provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
|
||||
raise ProviderError(
|
||||
f"No available provider for model '{request.model}'",
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
# Execute streaming with resilience
|
||||
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
|
||||
@@ -219,13 +247,13 @@ class LLMService:
|
||||
provider.create_chat_completion_stream,
|
||||
request,
|
||||
retryable_exceptions=(ProviderError, TimeoutError),
|
||||
non_retryable_exceptions=(ValidationError,)
|
||||
non_retryable_exceptions=(ValidationError,),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
# Record streaming failure - metrics disabled
|
||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||
error_code = getattr(e, "error_code", e.__class__.__name__)
|
||||
logger.exception(
|
||||
"Streaming chat completion failed for provider %s (model=%s, error=%s)",
|
||||
provider_name,
|
||||
@@ -242,13 +270,15 @@ class LLMService:
|
||||
# Security validation disabled - always allow embedding requests
|
||||
risk_score = 0.0
|
||||
|
||||
|
||||
# Get provider
|
||||
provider_name = self._get_provider_for_model(request.model)
|
||||
provider = self._providers.get(provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
|
||||
raise ProviderError(
|
||||
f"No available provider for model '{request.model}'",
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
# Execute with resilience
|
||||
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
|
||||
@@ -259,20 +289,18 @@ class LLMService:
|
||||
provider.create_embedding,
|
||||
request,
|
||||
retryable_exceptions=(ProviderError, TimeoutError),
|
||||
non_retryable_exceptions=(ValidationError,)
|
||||
non_retryable_exceptions=(ValidationError,),
|
||||
)
|
||||
|
||||
|
||||
# Record successful request - metrics disabled
|
||||
total_latency = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Record failed request - metrics disabled
|
||||
total_latency = (time.time() - start_time) * 1000
|
||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||
error_code = getattr(e, "error_code", e.__class__.__name__)
|
||||
logger.exception(
|
||||
"Embedding request failed for provider %s (model=%s, latency=%.2fms, error=%s)",
|
||||
provider_name,
|
||||
@@ -327,7 +355,7 @@ class LLMService:
|
||||
status="unavailable",
|
||||
last_check=datetime.utcnow(),
|
||||
error_message=str(e),
|
||||
models_available=[]
|
||||
models_available=[],
|
||||
)
|
||||
|
||||
return status_dict
|
||||
@@ -336,10 +364,7 @@ class LLMService:
|
||||
"""Get service metrics - metrics disabled"""
|
||||
# return metrics_collector.get_metrics()
|
||||
return LLMMetrics(
|
||||
total_requests=0,
|
||||
success_rate=0.0,
|
||||
avg_latency_ms=0,
|
||||
error_rates={}
|
||||
total_requests=0, success_rate=0.0, avg_latency_ms=0, error_rates={}
|
||||
)
|
||||
|
||||
def get_health_summary(self) -> Dict[str, Any]:
|
||||
@@ -349,11 +374,13 @@ class LLMService:
|
||||
|
||||
return {
|
||||
"service_status": "healthy" if self._initialized else "initializing",
|
||||
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
|
||||
"startup_time": self._startup_time.isoformat()
|
||||
if self._startup_time
|
||||
else None,
|
||||
"provider_count": len(self._providers),
|
||||
"active_providers": list(self._providers.keys()),
|
||||
"metrics": {"status": "disabled"},
|
||||
"resilience": resilience_health
|
||||
"resilience": resilience_health,
|
||||
}
|
||||
|
||||
def _get_provider_for_model(self, model: str) -> str:
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.logging import log_module_event, log_security_event
|
||||
@dataclass
|
||||
class MetricData:
|
||||
"""Individual metric data point"""
|
||||
|
||||
timestamp: datetime
|
||||
value: float
|
||||
labels: Dict[str, str] = field(default_factory=dict)
|
||||
@@ -23,6 +24,7 @@ class MetricData:
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
"""Request-related metrics"""
|
||||
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
failed_requests: int = 0
|
||||
@@ -37,6 +39,7 @@ class RequestMetrics:
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""System-related metrics"""
|
||||
|
||||
uptime: float = 0.0
|
||||
memory_usage: float = 0.0
|
||||
cpu_usage: float = 0.0
|
||||
@@ -78,12 +81,16 @@ class MetricsService:
|
||||
|
||||
# Store historical data
|
||||
self._store_metric("uptime", self.system_metrics.uptime)
|
||||
self._store_metric("active_connections", self.system_metrics.active_connections)
|
||||
self._store_metric(
|
||||
"active_connections", self.system_metrics.active_connections
|
||||
)
|
||||
|
||||
await asyncio.sleep(60) # Collect every minute
|
||||
|
||||
except Exception as e:
|
||||
log_module_event("metrics_service", "system_metrics_error", {"error": str(e)})
|
||||
log_module_event(
|
||||
"metrics_service", "system_metrics_error", {"error": str(e)}
|
||||
)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _cleanup_old_metrics(self):
|
||||
@@ -103,20 +110,20 @@ class MetricsService:
|
||||
log_module_event("metrics_service", "cleanup_error", {"error": str(e)})
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
def _store_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
|
||||
def _store_metric(
|
||||
self, name: str, value: float, labels: Optional[Dict[str, str]] = None
|
||||
):
|
||||
"""Store a metric data point"""
|
||||
if labels is None:
|
||||
labels = {}
|
||||
|
||||
metric_data = MetricData(
|
||||
timestamp=datetime.now(),
|
||||
value=value,
|
||||
labels=labels
|
||||
)
|
||||
metric_data = MetricData(timestamp=datetime.now(), value=value, labels=labels)
|
||||
|
||||
self.metric_history[name].append(metric_data)
|
||||
|
||||
def start_request(self, request_id: str, endpoint: str, user_id: Optional[str] = None):
|
||||
def start_request(
|
||||
self, request_id: str, endpoint: str, user_id: Optional[str] = None
|
||||
):
|
||||
"""Start tracking a request"""
|
||||
self.active_requests[request_id] = time.time()
|
||||
|
||||
@@ -124,13 +131,15 @@ class MetricsService:
|
||||
self.request_metrics.total_requests += 1
|
||||
|
||||
# Track by endpoint
|
||||
self.request_metrics.requests_by_endpoint[endpoint] = \
|
||||
self.request_metrics.requests_by_endpoint[endpoint] = (
|
||||
self.request_metrics.requests_by_endpoint.get(endpoint, 0) + 1
|
||||
)
|
||||
|
||||
# Track by user
|
||||
if user_id:
|
||||
self.request_metrics.requests_by_user[user_id] = \
|
||||
self.request_metrics.requests_by_user[user_id] = (
|
||||
self.request_metrics.requests_by_user.get(user_id, 0) + 1
|
||||
)
|
||||
|
||||
# Store metric
|
||||
self._store_metric("requests_total", self.request_metrics.total_requests)
|
||||
@@ -139,9 +148,14 @@ class MetricsService:
|
||||
if user_id:
|
||||
self._store_metric("requests_by_user", 1, {"user_id": user_id})
|
||||
|
||||
def end_request(self, request_id: str, success: bool = True,
|
||||
model: Optional[str] = None, tokens_used: int = 0,
|
||||
cost: float = 0.0):
|
||||
def end_request(
|
||||
self,
|
||||
request_id: str,
|
||||
success: bool = True,
|
||||
model: Optional[str] = None,
|
||||
tokens_used: int = 0,
|
||||
cost: float = 0.0,
|
||||
):
|
||||
"""End tracking a request"""
|
||||
if request_id not in self.active_requests:
|
||||
return
|
||||
@@ -158,7 +172,9 @@ class MetricsService:
|
||||
|
||||
# Update average response time
|
||||
if self.response_times:
|
||||
self.request_metrics.average_response_time = sum(self.response_times) / len(self.response_times)
|
||||
self.request_metrics.average_response_time = sum(self.response_times) / len(
|
||||
self.response_times
|
||||
)
|
||||
|
||||
# Update token and cost metrics
|
||||
self.request_metrics.total_tokens_used += tokens_used
|
||||
@@ -166,8 +182,9 @@ class MetricsService:
|
||||
|
||||
# Track by model
|
||||
if model:
|
||||
self.request_metrics.requests_by_model[model] = \
|
||||
self.request_metrics.requests_by_model[model] = (
|
||||
self.request_metrics.requests_by_model.get(model, 0) + 1
|
||||
)
|
||||
|
||||
# Store metrics
|
||||
self._store_metric("response_time", response_time)
|
||||
@@ -180,8 +197,13 @@ class MetricsService:
|
||||
# Clean up
|
||||
del self.active_requests[request_id]
|
||||
|
||||
def record_error(self, error_type: str, error_message: str,
|
||||
endpoint: Optional[str] = None, user_id: Optional[str] = None):
|
||||
def record_error(
|
||||
self,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
endpoint: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
):
|
||||
"""Record an error occurrence"""
|
||||
labels = {"error_type": error_type}
|
||||
|
||||
@@ -193,16 +215,23 @@ class MetricsService:
|
||||
self._store_metric("errors_total", 1, labels)
|
||||
|
||||
# Log security events for authentication/authorization errors
|
||||
if error_type in ["authentication_failed", "authorization_failed", "invalid_api_key"]:
|
||||
log_security_event(error_type, user_id or "anonymous", {
|
||||
"error": error_message,
|
||||
"endpoint": endpoint
|
||||
})
|
||||
if error_type in [
|
||||
"authentication_failed",
|
||||
"authorization_failed",
|
||||
"invalid_api_key",
|
||||
]:
|
||||
log_security_event(
|
||||
error_type,
|
||||
user_id or "anonymous",
|
||||
{"error": error_message, "endpoint": endpoint},
|
||||
)
|
||||
|
||||
def record_module_status(self, module_name: str, is_healthy: bool):
|
||||
"""Record module health status"""
|
||||
self.system_metrics.module_status[module_name] = is_healthy
|
||||
self._store_metric("module_health", 1 if is_healthy else 0, {"module": module_name})
|
||||
self._store_metric(
|
||||
"module_health", 1 if is_healthy else 0, {"module": module_name}
|
||||
)
|
||||
|
||||
def get_current_metrics(self) -> Dict[str, Any]:
|
||||
"""Get current metrics snapshot"""
|
||||
@@ -212,25 +241,28 @@ class MetricsService:
|
||||
"successful_requests": self.request_metrics.successful_requests,
|
||||
"failed_requests": self.request_metrics.failed_requests,
|
||||
"success_rate": (
|
||||
self.request_metrics.successful_requests / self.request_metrics.total_requests
|
||||
if self.request_metrics.total_requests > 0 else 0
|
||||
self.request_metrics.successful_requests
|
||||
/ self.request_metrics.total_requests
|
||||
if self.request_metrics.total_requests > 0
|
||||
else 0
|
||||
),
|
||||
"average_response_time": self.request_metrics.average_response_time,
|
||||
"total_tokens_used": self.request_metrics.total_tokens_used,
|
||||
"total_cost": self.request_metrics.total_cost,
|
||||
"requests_by_model": dict(self.request_metrics.requests_by_model),
|
||||
"requests_by_user": dict(self.request_metrics.requests_by_user),
|
||||
"requests_by_endpoint": dict(self.request_metrics.requests_by_endpoint)
|
||||
"requests_by_endpoint": dict(self.request_metrics.requests_by_endpoint),
|
||||
},
|
||||
"system_metrics": {
|
||||
"uptime": self.system_metrics.uptime,
|
||||
"active_connections": self.system_metrics.active_connections,
|
||||
"module_status": dict(self.system_metrics.module_status)
|
||||
}
|
||||
"module_status": dict(self.system_metrics.module_status),
|
||||
},
|
||||
}
|
||||
|
||||
def get_metrics_history(self, metric_name: str,
|
||||
hours: int = 1) -> List[Dict[str, Any]]:
|
||||
def get_metrics_history(
|
||||
self, metric_name: str, hours: int = 1
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get historical metrics data"""
|
||||
if metric_name not in self.metric_history:
|
||||
return []
|
||||
@@ -241,7 +273,7 @@ class MetricsService:
|
||||
{
|
||||
"timestamp": data.timestamp.isoformat(),
|
||||
"value": data.value,
|
||||
"labels": data.labels
|
||||
"labels": data.labels,
|
||||
}
|
||||
for data in self.metric_history[metric_name]
|
||||
if data.timestamp > cutoff_time
|
||||
@@ -251,18 +283,27 @@ class MetricsService:
|
||||
"""Get top metrics by type"""
|
||||
if metric_type == "models":
|
||||
return dict(
|
||||
sorted(self.request_metrics.requests_by_model.items(),
|
||||
key=lambda x: x[1], reverse=True)[:limit]
|
||||
sorted(
|
||||
self.request_metrics.requests_by_model.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
)
|
||||
elif metric_type == "users":
|
||||
return dict(
|
||||
sorted(self.request_metrics.requests_by_user.items(),
|
||||
key=lambda x: x[1], reverse=True)[:limit]
|
||||
sorted(
|
||||
self.request_metrics.requests_by_user.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
)
|
||||
elif metric_type == "endpoints":
|
||||
return dict(
|
||||
sorted(self.request_metrics.requests_by_endpoint.items(),
|
||||
key=lambda x: x[1], reverse=True)[:limit]
|
||||
sorted(
|
||||
self.request_metrics.requests_by_endpoint.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
@@ -275,11 +316,13 @@ class MetricsService:
|
||||
"active_connections": self.system_metrics.active_connections,
|
||||
"total_requests": self.request_metrics.total_requests,
|
||||
"success_rate": (
|
||||
self.request_metrics.successful_requests / self.request_metrics.total_requests
|
||||
if self.request_metrics.total_requests > 0 else 1.0
|
||||
self.request_metrics.successful_requests
|
||||
/ self.request_metrics.total_requests
|
||||
if self.request_metrics.total_requests > 0
|
||||
else 1.0
|
||||
),
|
||||
"modules": self.system_metrics.module_status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
async def reset_metrics(self):
|
||||
@@ -305,4 +348,5 @@ def setup_metrics(app):
|
||||
|
||||
# Initialize metrics service
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(metrics_service.initialize())
|
||||
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class ModuleManifest:
|
||||
"""Module manifest loaded from module.yaml"""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
@@ -69,7 +70,9 @@ class ModuleConfigManager:
|
||||
self.schemas: Dict[str, Dict] = {}
|
||||
self.configs: Dict[str, Dict] = {}
|
||||
|
||||
async def discover_modules(self, modules_path: str = "modules") -> Dict[str, ModuleManifest]:
|
||||
async def discover_modules(
|
||||
self, modules_path: str = "modules"
|
||||
) -> Dict[str, ModuleManifest]:
|
||||
"""Discover modules from filesystem using module.yaml manifests"""
|
||||
discovered_modules = {}
|
||||
|
||||
@@ -91,14 +94,16 @@ class ModuleConfigManager:
|
||||
if not manifest_path.exists():
|
||||
# Check if it's a legacy module (has main.py but no manifest)
|
||||
if (module_dir / "main.py").exists():
|
||||
logger.info(f"Legacy module found (no manifest): {module_dir.name}")
|
||||
logger.info(
|
||||
f"Legacy module found (no manifest): {module_dir.name}"
|
||||
)
|
||||
# Create a basic manifest for legacy modules
|
||||
manifest = ModuleManifest(
|
||||
name=module_dir.name,
|
||||
version="1.0.0",
|
||||
description=f"Legacy {module_dir.name} module",
|
||||
author="System",
|
||||
category="legacy"
|
||||
category="legacy",
|
||||
)
|
||||
discovered_modules[manifest.name] = manifest
|
||||
continue
|
||||
@@ -118,14 +123,16 @@ class ModuleConfigManager:
|
||||
async def _load_module_manifest(self, manifest_path: Path) -> ModuleManifest:
|
||||
"""Load and validate a module manifest file"""
|
||||
try:
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest_data = yaml.safe_load(f)
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ['name', 'version', 'description', 'author']
|
||||
required_fields = ["name", "version", "description", "author"]
|
||||
for field in required_fields:
|
||||
if field not in manifest_data:
|
||||
raise ConfigurationError(f"Missing required field '{field}' in {manifest_path}")
|
||||
raise ConfigurationError(
|
||||
f"Missing required field '{field}' in {manifest_path}"
|
||||
)
|
||||
|
||||
manifest = ModuleManifest(**manifest_data)
|
||||
|
||||
@@ -147,7 +154,7 @@ class ModuleConfigManager:
|
||||
async def _load_module_schema(self, module_name: str, schema_path: Path):
|
||||
"""Load JSON schema for module configuration"""
|
||||
try:
|
||||
with open(schema_path, 'r', encoding='utf-8') as f:
|
||||
with open(schema_path, "r", encoding="utf-8") as f:
|
||||
schema = json.load(f)
|
||||
|
||||
self.schemas[module_name] = schema
|
||||
@@ -174,26 +181,32 @@ class ModuleConfigManager:
|
||||
"""Validate module configuration against its schema"""
|
||||
schema = self.schemas.get(module_name)
|
||||
if not schema:
|
||||
logger.info(f"No schema found for module {module_name}, skipping validation")
|
||||
logger.info(
|
||||
f"No schema found for module {module_name}, skipping validation"
|
||||
)
|
||||
return {"valid": True, "errors": []}
|
||||
|
||||
try:
|
||||
validate(instance=config, schema=schema, format_checker=draft7_format_checker)
|
||||
validate(
|
||||
instance=config, schema=schema, format_checker=draft7_format_checker
|
||||
)
|
||||
return {"valid": True, "errors": []}
|
||||
|
||||
except ValidationError as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"errors": [{
|
||||
"errors": [
|
||||
{
|
||||
"path": list(e.path),
|
||||
"message": e.message,
|
||||
"invalid_value": e.instance
|
||||
}]
|
||||
"invalid_value": e.instance,
|
||||
}
|
||||
],
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"errors": [{"message": f"Schema validation failed: {str(e)}"}]
|
||||
"errors": [{"message": f"Schema validation failed: {str(e)}"}],
|
||||
}
|
||||
|
||||
async def save_module_config(self, module_name: str, config: Dict) -> bool:
|
||||
@@ -202,7 +215,9 @@ class ModuleConfigManager:
|
||||
validation_result = await self.validate_config(module_name, config)
|
||||
if not validation_result["valid"]:
|
||||
error_messages = [error["message"] for error in validation_result["errors"]]
|
||||
raise ConfigurationError(f"Invalid configuration for {module_name}: {', '.join(error_messages)}")
|
||||
raise ConfigurationError(
|
||||
f"Invalid configuration for {module_name}: {', '.join(error_messages)}"
|
||||
)
|
||||
|
||||
# Save configuration
|
||||
self.configs[module_name] = config
|
||||
@@ -214,7 +229,7 @@ class ModuleConfigManager:
|
||||
|
||||
config_file = config_dir / f"{module_name}.json"
|
||||
try:
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
logger.info(f"Saved configuration for module: {module_name}")
|
||||
@@ -233,7 +248,7 @@ class ModuleConfigManager:
|
||||
for config_file in config_dir.glob("*.json"):
|
||||
module_name = config_file.stem
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.configs[module_name] = config
|
||||
@@ -246,7 +261,8 @@ class ModuleConfigManager:
|
||||
"""List all discovered modules with their metadata"""
|
||||
modules = []
|
||||
for name, manifest in self.manifests.items():
|
||||
modules.append({
|
||||
modules.append(
|
||||
{
|
||||
"name": manifest.name,
|
||||
"version": manifest.version,
|
||||
"description": manifest.description,
|
||||
@@ -258,12 +274,12 @@ class ModuleConfigManager:
|
||||
"consumes": manifest.consumes,
|
||||
"has_schema": name in self.schemas,
|
||||
"has_config": name in self.configs,
|
||||
"ui_config": manifest.ui_config
|
||||
})
|
||||
"ui_config": manifest.ui_config,
|
||||
}
|
||||
)
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
async def update_module_status(self, module_name: str, enabled: bool) -> bool:
|
||||
"""Update module enabled status"""
|
||||
manifest = self.manifests.get(module_name)
|
||||
@@ -279,7 +295,7 @@ class ModuleConfigManager:
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
manifest_dict = asdict(manifest)
|
||||
with open(manifest_path, 'w', encoding='utf-8') as f:
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(manifest_dict, f, default_flow_style=False)
|
||||
|
||||
logger.info(f"Updated module status: {module_name} enabled={enabled}")
|
||||
|
||||
@@ -23,6 +23,7 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class ModuleConfig:
|
||||
"""Configuration for a module"""
|
||||
|
||||
name: str
|
||||
enabled: bool = True
|
||||
config: Dict[str, Any] = None
|
||||
@@ -52,21 +53,24 @@ class ModuleFileWatcher(FileSystemEventHandler):
|
||||
return parts[0] if parts else None
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.py'):
|
||||
if event.is_directory or not event.src_path.endswith(".py"):
|
||||
return
|
||||
|
||||
module_name = self._resolve_module_name(event.src_path)
|
||||
if not module_name or module_name not in self.module_manager.modules:
|
||||
return
|
||||
|
||||
log_module_event("hot_reload", "file_changed", {
|
||||
"module": module_name,
|
||||
"file": event.src_path
|
||||
})
|
||||
log_module_event(
|
||||
"hot_reload",
|
||||
"file_changed",
|
||||
{"module": module_name, "file": event.src_path},
|
||||
)
|
||||
|
||||
loop = self.module_manager.loop
|
||||
if not loop or loop.is_closed():
|
||||
logger.debug("Hot reload skipped for %s; event loop unavailable", module_name)
|
||||
logger.debug(
|
||||
"Hot reload skipped for %s; event loop unavailable", module_name
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -75,7 +79,8 @@ class ModuleFileWatcher(FileSystemEventHandler):
|
||||
loop,
|
||||
)
|
||||
future.add_done_callback(
|
||||
lambda f: f.exception() and logger.warning(
|
||||
lambda f: f.exception()
|
||||
and logger.warning(
|
||||
"Module reload error for %s: %s", module_name, f.exception()
|
||||
)
|
||||
)
|
||||
@@ -95,7 +100,9 @@ class ModuleManager:
|
||||
self.file_observer = None
|
||||
self.fastapi_app = None
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self.modules_root = (Path(__file__).resolve().parent.parent / "modules").resolve()
|
||||
self.modules_root = (
|
||||
Path(__file__).resolve().parent.parent / "modules"
|
||||
).resolve()
|
||||
|
||||
async def initialize(self, fastapi_app=None):
|
||||
"""Initialize the module manager and load all modules"""
|
||||
@@ -117,13 +124,23 @@ class ModuleManager:
|
||||
await self._load_modules()
|
||||
|
||||
self.initialized = True
|
||||
log_module_event("module_manager", "initialized", {
|
||||
log_module_event(
|
||||
"module_manager",
|
||||
"initialized",
|
||||
{
|
||||
"modules_count": len(self.modules),
|
||||
"enabled_modules": [name for name, config in self.module_configs.items() if config.enabled]
|
||||
})
|
||||
"enabled_modules": [
|
||||
name
|
||||
for name, config in self.module_configs.items()
|
||||
if config.enabled
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_module_event("module_manager", "initialization_failed", {"error": str(e)})
|
||||
log_module_event(
|
||||
"module_manager", "initialization_failed", {"error": str(e)}
|
||||
)
|
||||
raise ModuleLoadError(f"Failed to initialize module manager: {str(e)}")
|
||||
|
||||
async def _load_module_configs(self):
|
||||
@@ -138,7 +155,9 @@ class ModuleManager:
|
||||
logger.warning("Modules directory not found at %s", self.modules_root)
|
||||
return
|
||||
|
||||
discovered_manifests = await module_config_manager.discover_modules(str(self.modules_root))
|
||||
discovered_manifests = await module_config_manager.discover_modules(
|
||||
str(self.modules_root)
|
||||
)
|
||||
|
||||
# Load saved configurations
|
||||
await module_config_manager.load_saved_configs()
|
||||
@@ -150,7 +169,9 @@ class ModuleManager:
|
||||
for name, manifest in discovered_manifests.items():
|
||||
# Skip modules that are now core infrastructure
|
||||
if name in EXCLUDED_MODULES:
|
||||
logger.info(f"Skipping module '{name}' - now integrated as core infrastructure")
|
||||
logger.info(
|
||||
f"Skipping module '{name}' - now integrated as core infrastructure"
|
||||
)
|
||||
continue
|
||||
|
||||
saved_config = module_config_manager.get_module_config(name)
|
||||
@@ -159,19 +180,25 @@ class ModuleManager:
|
||||
name=manifest.name,
|
||||
enabled=manifest.enabled,
|
||||
config=saved_config,
|
||||
dependencies=manifest.dependencies
|
||||
dependencies=manifest.dependencies,
|
||||
)
|
||||
|
||||
self.module_configs[name] = module_config
|
||||
|
||||
log_module_event(name, "discovered", {
|
||||
log_module_event(
|
||||
name,
|
||||
"discovered",
|
||||
{
|
||||
"version": manifest.version,
|
||||
"description": manifest.description,
|
||||
"enabled": manifest.enabled,
|
||||
"dependencies": manifest.dependencies
|
||||
})
|
||||
"dependencies": manifest.dependencies,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Discovered {len(discovered_manifests)} modules: {list(discovered_manifests.keys())}")
|
||||
logger.info(
|
||||
f"Discovered {len(discovered_manifests)} modules: {list(discovered_manifests.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to discover modules: {e}")
|
||||
@@ -186,9 +213,7 @@ class ModuleManager:
|
||||
"""Fallback to legacy hard-coded module loading"""
|
||||
logger.warning("Falling back to legacy module configuration")
|
||||
|
||||
default_modules = [
|
||||
ModuleConfig(name="rag", enabled=True, config={})
|
||||
]
|
||||
default_modules = [ModuleConfig(name="rag", enabled=True, config={})]
|
||||
|
||||
for config in default_modules:
|
||||
self.module_configs[config.name] = config
|
||||
@@ -212,7 +237,9 @@ class ModuleManager:
|
||||
|
||||
def visit(module_name: str):
|
||||
if module_name in temp_visited:
|
||||
raise ModuleLoadError(f"Circular dependency detected involving module: {module_name}")
|
||||
raise ModuleLoadError(
|
||||
f"Circular dependency detected involving module: {module_name}"
|
||||
)
|
||||
if module_name in visited:
|
||||
return
|
||||
|
||||
@@ -263,14 +290,14 @@ class ModuleManager:
|
||||
module_instance = None
|
||||
|
||||
# Pattern 1: {module_name}_module (e.g., cache_module)
|
||||
if hasattr(module, f'{module_name}_module'):
|
||||
module_instance = getattr(module, f'{module_name}_module')
|
||||
if hasattr(module, f"{module_name}_module"):
|
||||
module_instance = getattr(module, f"{module_name}_module")
|
||||
# Pattern 2: Just 'module' attribute
|
||||
elif hasattr(module, 'module'):
|
||||
module_instance = getattr(module, 'module')
|
||||
elif hasattr(module, "module"):
|
||||
module_instance = getattr(module, "module")
|
||||
# Pattern 3: Module class with same name as module (e.g., CacheModule)
|
||||
elif hasattr(module, f'{module_name.title()}Module'):
|
||||
module_class = getattr(module, f'{module_name.title()}Module')
|
||||
elif hasattr(module, f"{module_name.title()}Module"):
|
||||
module_class = getattr(module, f"{module_name.title()}Module")
|
||||
if callable(module_class):
|
||||
module_instance = module_class()
|
||||
else:
|
||||
@@ -283,14 +310,17 @@ class ModuleManager:
|
||||
|
||||
# Initialize the module if it has an init function
|
||||
module_initialized = False
|
||||
if hasattr(self.modules[module_name], 'initialize'):
|
||||
if hasattr(self.modules[module_name], "initialize"):
|
||||
try:
|
||||
import inspect
|
||||
|
||||
init_method = self.modules[module_name].initialize
|
||||
sig = inspect.signature(init_method)
|
||||
param_count = len([p for p in sig.parameters.values() if p.name != 'self'])
|
||||
param_count = len(
|
||||
[p for p in sig.parameters.values() if p.name != "self"]
|
||||
)
|
||||
|
||||
if hasattr(self.modules[module_name], 'config'):
|
||||
if hasattr(self.modules[module_name], "config"):
|
||||
# Pass config if it's a BaseModule
|
||||
self.modules[module_name].config.update(config.config)
|
||||
await self.modules[module_name].initialize()
|
||||
@@ -303,7 +333,9 @@ class ModuleManager:
|
||||
module_initialized = True
|
||||
log_module_event(module_name, "initialized", {"success": True})
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "initialization_failed", {"error": str(e)})
|
||||
log_module_event(
|
||||
module_name, "initialization_failed", {"error": str(e)}
|
||||
)
|
||||
module_initialized = False
|
||||
else:
|
||||
# Module doesn't have initialize method, mark as initialized anyway
|
||||
@@ -320,26 +352,32 @@ class ModuleManager:
|
||||
permissions = []
|
||||
|
||||
# New BaseModule method
|
||||
if hasattr(self.modules[module_name], 'get_required_permissions'):
|
||||
if hasattr(self.modules[module_name], "get_required_permissions"):
|
||||
try:
|
||||
permissions = self.modules[module_name].get_required_permissions()
|
||||
log_module_event(module_name, "permissions_registered", {
|
||||
"permissions_count": len(permissions),
|
||||
"type": "BaseModule"
|
||||
})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"permissions_registered",
|
||||
{"permissions_count": len(permissions), "type": "BaseModule"},
|
||||
)
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "permissions_failed", {"error": str(e)})
|
||||
log_module_event(
|
||||
module_name, "permissions_failed", {"error": str(e)}
|
||||
)
|
||||
|
||||
# Legacy method
|
||||
elif hasattr(self.modules[module_name], 'get_permissions'):
|
||||
elif hasattr(self.modules[module_name], "get_permissions"):
|
||||
try:
|
||||
permissions = self.modules[module_name].get_permissions()
|
||||
log_module_event(module_name, "permissions_registered", {
|
||||
"permissions_count": len(permissions),
|
||||
"type": "legacy"
|
||||
})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"permissions_registered",
|
||||
{"permissions_count": len(permissions), "type": "legacy"},
|
||||
)
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "permissions_failed", {"error": str(e)})
|
||||
log_module_event(
|
||||
module_name, "permissions_failed", {"error": str(e)}
|
||||
)
|
||||
|
||||
# Register permissions with the permission system
|
||||
if permissions:
|
||||
@@ -352,21 +390,29 @@ class ModuleManager:
|
||||
|
||||
except ImportError as e:
|
||||
error_msg = f"Module {module_name} import failed: {str(e)}"
|
||||
log_module_event(module_name, "load_failed", {"error": error_msg, "type": "ImportError"})
|
||||
log_module_event(
|
||||
module_name, "load_failed", {"error": error_msg, "type": "ImportError"}
|
||||
)
|
||||
# For critical modules, we might want to fail completely
|
||||
if module_name in ['security', 'cache']:
|
||||
if module_name in ["security", "cache"]:
|
||||
raise ModuleLoadError(error_msg)
|
||||
# For optional modules, log warning but continue
|
||||
import warnings
|
||||
|
||||
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
|
||||
except Exception as e:
|
||||
error_msg = f"Module {module_name} loading failed: {str(e)}"
|
||||
log_module_event(module_name, "load_failed", {"error": error_msg, "type": type(e).__name__})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"load_failed",
|
||||
{"error": error_msg, "type": type(e).__name__},
|
||||
)
|
||||
# For critical modules, we might want to fail completely
|
||||
if module_name in ['security', 'cache']:
|
||||
if module_name in ["security", "cache"]:
|
||||
raise ModuleLoadError(error_msg)
|
||||
# For optional modules, log warning but continue
|
||||
import warnings
|
||||
|
||||
warnings.warn(f"Optional module {module_name} failed to load: {str(e)}")
|
||||
|
||||
async def _register_module_router(self, module_name: str, module_instance):
|
||||
@@ -376,30 +422,37 @@ class ModuleManager:
|
||||
|
||||
try:
|
||||
# Check if module has a router attribute
|
||||
if hasattr(module_instance, 'router'):
|
||||
router = getattr(module_instance, 'router')
|
||||
if hasattr(module_instance, "router"):
|
||||
router = getattr(module_instance, "router")
|
||||
|
||||
# Verify it's actually a FastAPI router
|
||||
from fastapi import APIRouter
|
||||
|
||||
if isinstance(router, APIRouter):
|
||||
# Register the router with the app
|
||||
self.fastapi_app.include_router(router)
|
||||
|
||||
log_module_event(module_name, "router_registered", {
|
||||
"router_prefix": getattr(router, 'prefix', 'unknown'),
|
||||
"router_tags": getattr(router, 'tags', [])
|
||||
})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"router_registered",
|
||||
{
|
||||
"router_prefix": getattr(router, "prefix", "unknown"),
|
||||
"router_tags": getattr(router, "tags", []),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Registered router for module {module_name}")
|
||||
else:
|
||||
logger.debug(f"Module {module_name} has 'router' attribute but it's not a FastAPI router")
|
||||
logger.debug(
|
||||
f"Module {module_name} has 'router' attribute but it's not a FastAPI router"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Module {module_name} does not have a router")
|
||||
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "router_registration_failed", {
|
||||
"error": str(e)
|
||||
})
|
||||
log_module_event(
|
||||
module_name, "router_registration_failed", {"error": str(e)}
|
||||
)
|
||||
logger.warning(f"Failed to register router for module {module_name}: {e}")
|
||||
|
||||
async def unload_module(self, module_name: str):
|
||||
@@ -411,7 +464,7 @@ class ModuleManager:
|
||||
module = self.modules[module_name]
|
||||
|
||||
# Call cleanup if available
|
||||
if hasattr(module, 'cleanup'):
|
||||
if hasattr(module, "cleanup"):
|
||||
await module.cleanup()
|
||||
|
||||
del self.modules[module_name]
|
||||
@@ -435,7 +488,11 @@ class ModuleManager:
|
||||
log_module_event(module_name, "reloaded", {"success": True})
|
||||
return True
|
||||
else:
|
||||
log_module_event(module_name, "reload_skipped", {"reason": "Module disabled or no config"})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"reload_skipped",
|
||||
{"reason": "Module disabled or no config"},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "reload_failed", {"error": str(e)})
|
||||
@@ -453,7 +510,9 @@ class ModuleManager:
|
||||
"""Check if a module is loaded"""
|
||||
return module_name in self.modules
|
||||
|
||||
async def execute_interceptor_chain(self, chain_type: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute_interceptor_chain(
|
||||
self, chain_type: str, context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute interceptor chain for all loaded modules"""
|
||||
result_context = context.copy()
|
||||
|
||||
@@ -468,15 +527,17 @@ class ModuleManager:
|
||||
interceptor = getattr(module, interceptor_method)
|
||||
result_context = await interceptor(result_context)
|
||||
|
||||
log_module_event(module_name, "interceptor_executed", {
|
||||
"chain_type": chain_type,
|
||||
"success": True
|
||||
})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"interceptor_executed",
|
||||
{"chain_type": chain_type, "success": True},
|
||||
)
|
||||
except Exception as e:
|
||||
log_module_event(module_name, "interceptor_failed", {
|
||||
"chain_type": chain_type,
|
||||
"error": str(e)
|
||||
})
|
||||
log_module_event(
|
||||
module_name,
|
||||
"interceptor_failed",
|
||||
{"chain_type": chain_type, "error": str(e)},
|
||||
)
|
||||
# Continue with other modules even if one fails
|
||||
continue
|
||||
|
||||
@@ -487,7 +548,9 @@ class ModuleManager:
|
||||
if not self.initialized:
|
||||
return
|
||||
|
||||
log_module_event("module_manager", "shutting_down", {"modules_count": len(self.modules)})
|
||||
log_module_event(
|
||||
"module_manager", "shutting_down", {"modules_count": len(self.modules)}
|
||||
)
|
||||
|
||||
# Unload modules in reverse order
|
||||
for module_name in reversed(self.module_order):
|
||||
@@ -519,14 +582,22 @@ class ModuleManager:
|
||||
return
|
||||
|
||||
if not self.modules_root.exists():
|
||||
log_module_event("hot_reload", "watcher_skipped", {"reason": f"No modules directory at {self.modules_root}"})
|
||||
log_module_event(
|
||||
"hot_reload",
|
||||
"watcher_skipped",
|
||||
{"reason": f"No modules directory at {self.modules_root}"},
|
||||
)
|
||||
return
|
||||
|
||||
self.file_observer = Observer()
|
||||
event_handler = ModuleFileWatcher(self, self.modules_root)
|
||||
self.file_observer.schedule(event_handler, str(self.modules_root), recursive=True)
|
||||
self.file_observer.schedule(
|
||||
event_handler, str(self.modules_root), recursive=True
|
||||
)
|
||||
self.file_observer.start()
|
||||
log_module_event("hot_reload", "watcher_started", {"path": str(self.modules_root)})
|
||||
log_module_event(
|
||||
"hot_reload", "watcher_started", {"path": str(self.modules_root)}
|
||||
)
|
||||
except Exception as e:
|
||||
log_module_event("hot_reload", "watcher_failed", {"error": str(e)})
|
||||
|
||||
@@ -536,7 +607,9 @@ class ModuleManager:
|
||||
"""Enable a module"""
|
||||
try:
|
||||
# Update the manifest status
|
||||
success = await module_config_manager.update_module_status(module_name, True)
|
||||
success = await module_config_manager.update_module_status(
|
||||
module_name, True
|
||||
)
|
||||
if not success:
|
||||
return False
|
||||
|
||||
@@ -561,7 +634,9 @@ class ModuleManager:
|
||||
"""Disable a module"""
|
||||
try:
|
||||
# Update the manifest status
|
||||
success = await module_config_manager.update_module_status(module_name, False)
|
||||
success = await module_config_manager.update_module_status(
|
||||
module_name, False
|
||||
)
|
||||
if not success:
|
||||
return False
|
||||
|
||||
@@ -604,8 +679,9 @@ class ModuleManager:
|
||||
"endpoints": manifest.endpoints,
|
||||
"permissions": manifest.permissions,
|
||||
"ui_config": manifest.ui_config,
|
||||
"has_schema": module_config_manager.get_module_schema(module_name) is not None,
|
||||
"current_config": module_config_manager.get_module_config(module_name)
|
||||
"has_schema": module_config_manager.get_module_schema(module_name)
|
||||
is not None,
|
||||
"current_config": module_config_manager.get_module_config(module_name),
|
||||
}
|
||||
|
||||
def list_all_modules(self) -> List[Dict]:
|
||||
@@ -621,7 +697,9 @@ class ModuleManager:
|
||||
"""Update module configuration"""
|
||||
try:
|
||||
# Validate and save the configuration
|
||||
success = await module_config_manager.save_module_config(module_name, config)
|
||||
success = await module_config_manager.save_module_config(
|
||||
module_name, config
|
||||
)
|
||||
if not success:
|
||||
return False
|
||||
|
||||
@@ -640,7 +718,6 @@ class ModuleManager:
|
||||
log_module_event(module_name, "config_update_failed", {"error": str(e)})
|
||||
return False
|
||||
|
||||
|
||||
async def get_module_health(self, module_name: str) -> Dict:
|
||||
"""Get module health status"""
|
||||
manifest = module_config_manager.get_module_manifest(module_name)
|
||||
@@ -656,11 +733,11 @@ class ModuleManager:
|
||||
"enabled": manifest.enabled,
|
||||
"dependencies_met": self._check_dependencies(module_name),
|
||||
"last_loaded": None,
|
||||
"error": None
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# Check if module has custom health check
|
||||
if module and hasattr(module, 'get_health'):
|
||||
if module and hasattr(module, "get_health"):
|
||||
try:
|
||||
custom_health = await module.get_health()
|
||||
health.update(custom_health)
|
||||
|
||||
556
backend/app/services/notification_service.py
Normal file
556
backend/app/services/notification_service.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Notification Service
|
||||
Multi-channel notification system with email, webhooks, and other providers
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from jinja2 import Template, Environment, DictLoader
|
||||
import aiohttp
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import and_, or_, desc, func
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.notification import (
|
||||
Notification,
|
||||
NotificationTemplate,
|
||||
NotificationChannel,
|
||||
NotificationType,
|
||||
NotificationStatus,
|
||||
NotificationPriority,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationService:
|
||||
"""Service for managing and sending notifications"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.jinja_env = Environment(loader=DictLoader({}))
|
||||
|
||||
async def send_notification(
|
||||
self,
|
||||
recipients: List[str],
|
||||
subject: Optional[str] = None,
|
||||
body: str = "",
|
||||
html_body: Optional[str] = None,
|
||||
notification_type: NotificationType = NotificationType.EMAIL,
|
||||
priority: NotificationPriority = NotificationPriority.NORMAL,
|
||||
template_name: Optional[str] = None,
|
||||
template_variables: Optional[Dict[str, Any]] = None,
|
||||
channel_name: Optional[str] = None,
|
||||
user_id: Optional[int] = None,
|
||||
scheduled_at: Optional[datetime] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> Notification:
|
||||
"""Send a notification through specified channel"""
|
||||
|
||||
# Get or create channel
|
||||
channel = await self._get_or_default_channel(notification_type, channel_name)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No active channel found for type: {notification_type}",
|
||||
)
|
||||
|
||||
# Process template if specified
|
||||
if template_name:
|
||||
template = await self._get_template(template_name)
|
||||
if template:
|
||||
rendered = await self._render_template(
|
||||
template, template_variables or {}
|
||||
)
|
||||
subject = subject or rendered.get("subject")
|
||||
body = body or rendered.get("body")
|
||||
html_body = html_body or rendered.get("html_body")
|
||||
|
||||
# Create notification record
|
||||
notification = Notification(
|
||||
subject=subject,
|
||||
body=body,
|
||||
html_body=html_body,
|
||||
recipients=recipients,
|
||||
priority=priority,
|
||||
scheduled_at=scheduled_at,
|
||||
expires_at=expires_at,
|
||||
channel_id=channel.id,
|
||||
user_id=user_id,
|
||||
metadata=metadata or {},
|
||||
tags=tags or [],
|
||||
)
|
||||
|
||||
self.db.add(notification)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(notification)
|
||||
|
||||
# Send immediately if not scheduled
|
||||
if scheduled_at is None or scheduled_at <= datetime.utcnow():
|
||||
await self._deliver_notification(notification)
|
||||
|
||||
return notification
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
recipients: List[str],
|
||||
subject: str,
|
||||
body: str,
|
||||
html_body: Optional[str] = None,
|
||||
cc_recipients: Optional[List[str]] = None,
|
||||
bcc_recipients: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> Notification:
|
||||
"""Send email notification"""
|
||||
|
||||
notification = await self.send_notification(
|
||||
recipients=recipients,
|
||||
subject=subject,
|
||||
body=body,
|
||||
html_body=html_body,
|
||||
notification_type=NotificationType.EMAIL,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cc_recipients:
|
||||
notification.cc_recipients = cc_recipients
|
||||
if bcc_recipients:
|
||||
notification.bcc_recipients = bcc_recipients
|
||||
|
||||
await self.db.commit()
|
||||
return notification
|
||||
|
||||
async def send_webhook(
|
||||
self,
|
||||
webhook_url: str,
|
||||
payload: Dict[str, Any],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Notification:
|
||||
"""Send webhook notification"""
|
||||
|
||||
# Use webhook URL as recipient
|
||||
notification = await self.send_notification(
|
||||
recipients=[webhook_url],
|
||||
body=json.dumps(payload),
|
||||
notification_type=NotificationType.WEBHOOK,
|
||||
metadata={"headers": headers or {}},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return notification
|
||||
|
||||
async def send_slack_message(
|
||||
self, channel: str, message: str, **kwargs
|
||||
) -> Notification:
|
||||
"""Send Slack message"""
|
||||
|
||||
notification = await self.send_notification(
|
||||
recipients=[channel],
|
||||
body=message,
|
||||
notification_type=NotificationType.SLACK,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return notification
|
||||
|
||||
async def process_scheduled_notifications(self):
|
||||
"""Process notifications that are scheduled for delivery"""
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get pending scheduled notifications that are due
|
||||
stmt = select(Notification).where(
|
||||
and_(
|
||||
Notification.status == NotificationStatus.PENDING,
|
||||
Notification.scheduled_at <= now,
|
||||
or_(Notification.expires_at.is_(None), Notification.expires_at > now),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
notifications = result.scalars().all()
|
||||
|
||||
processed_count = 0
|
||||
for notification in notifications:
|
||||
try:
|
||||
await self._deliver_notification(notification)
|
||||
processed_count += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process scheduled notification {notification.id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"Processed {processed_count} scheduled notifications")
|
||||
return processed_count
|
||||
|
||||
async def retry_failed_notifications(self):
|
||||
"""Retry failed notifications that can be retried"""
|
||||
|
||||
# Get failed notifications that can be retried
|
||||
stmt = select(Notification).where(
|
||||
and_(
|
||||
Notification.status.in_(
|
||||
[NotificationStatus.FAILED, NotificationStatus.RETRY]
|
||||
),
|
||||
Notification.attempts < Notification.max_attempts,
|
||||
or_(
|
||||
Notification.expires_at.is_(None),
|
||||
Notification.expires_at > datetime.utcnow(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
notifications = result.scalars().all()
|
||||
|
||||
retried_count = 0
|
||||
for notification in notifications:
|
||||
# Check retry delay
|
||||
if notification.failed_at:
|
||||
retry_delay = timedelta(
|
||||
minutes=notification.channel.retry_delay_minutes
|
||||
)
|
||||
if datetime.utcnow() - notification.failed_at < retry_delay:
|
||||
continue
|
||||
|
||||
try:
|
||||
await self._deliver_notification(notification)
|
||||
retried_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retry notification {notification.id}: {e}")
|
||||
|
||||
logger.info(f"Retried {retried_count} failed notifications")
|
||||
return retried_count
|
||||
|
||||
async def _deliver_notification(self, notification: Notification):
|
||||
"""Deliver a notification through its channel"""
|
||||
|
||||
channel = await self._get_channel_by_id(notification.channel_id)
|
||||
if not channel or not channel.is_active:
|
||||
notification.mark_failed("Channel not available")
|
||||
await self.db.commit()
|
||||
return
|
||||
|
||||
try:
|
||||
if channel.notification_type == NotificationType.EMAIL:
|
||||
await self._send_email(notification, channel)
|
||||
elif channel.notification_type == NotificationType.WEBHOOK:
|
||||
await self._send_webhook(notification, channel)
|
||||
elif channel.notification_type == NotificationType.SLACK:
|
||||
await self._send_slack(notification, channel)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported notification type: {channel.notification_type}"
|
||||
)
|
||||
|
||||
# Update channel stats
|
||||
channel.update_stats(success=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deliver notification {notification.id}: {e}")
|
||||
notification.mark_failed(str(e))
|
||||
channel.update_stats(success=False, error_message=str(e))
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
async def _send_email(
|
||||
self, notification: Notification, channel: NotificationChannel
|
||||
):
|
||||
"""Send email through SMTP"""
|
||||
|
||||
config = channel.config
|
||||
credentials = channel.credentials or {}
|
||||
|
||||
# Create message
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = notification.subject or "No Subject"
|
||||
msg["From"] = config.get("from_email", "noreply@example.com")
|
||||
msg["To"] = ", ".join(notification.recipients)
|
||||
|
||||
if notification.cc_recipients:
|
||||
msg["Cc"] = ", ".join(notification.cc_recipients)
|
||||
|
||||
# Add text part
|
||||
text_part = MIMEText(notification.body, "plain", "utf-8")
|
||||
msg.attach(text_part)
|
||||
|
||||
# Add HTML part if available
|
||||
if notification.html_body:
|
||||
html_part = MIMEText(notification.html_body, "html", "utf-8")
|
||||
msg.attach(html_part)
|
||||
|
||||
# Send email
|
||||
smtp_host = config.get("smtp_host", "localhost")
|
||||
smtp_port = config.get("smtp_port", 587)
|
||||
username = credentials.get("username")
|
||||
password = credentials.get("password")
|
||||
use_tls = config.get("use_tls", True)
|
||||
|
||||
with smtplib.SMTP(smtp_host, smtp_port) as server:
|
||||
if use_tls:
|
||||
server.starttls()
|
||||
if username and password:
|
||||
server.login(username, password)
|
||||
|
||||
all_recipients = notification.recipients[:]
|
||||
if notification.cc_recipients:
|
||||
all_recipients.extend(notification.cc_recipients)
|
||||
if notification.bcc_recipients:
|
||||
all_recipients.extend(notification.bcc_recipients)
|
||||
|
||||
server.sendmail(msg["From"], all_recipients, msg.as_string())
|
||||
|
||||
notification.mark_sent()
|
||||
|
||||
async def _send_webhook(
|
||||
self, notification: Notification, channel: NotificationChannel
|
||||
):
|
||||
"""Send webhook HTTP request"""
|
||||
|
||||
webhook_url = notification.recipients[0] # URL is stored as recipient
|
||||
headers = notification.metadata.get("headers", {})
|
||||
headers.setdefault("Content-Type", "application/json")
|
||||
|
||||
# Parse body as JSON payload
|
||||
try:
|
||||
payload = json.loads(notification.body)
|
||||
except json.JSONDecodeError:
|
||||
payload = {"message": notification.body}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
raise Exception(
|
||||
f"Webhook failed with status {response.status}: {await response.text()}"
|
||||
)
|
||||
|
||||
external_id = response.headers.get("X-Message-ID")
|
||||
notification.mark_sent(external_id)
|
||||
|
||||
async def _send_slack(
|
||||
self, notification: Notification, channel: NotificationChannel
|
||||
):
|
||||
"""Send Slack message"""
|
||||
|
||||
credentials = channel.credentials or {}
|
||||
webhook_url = credentials.get("webhook_url")
|
||||
|
||||
if not webhook_url:
|
||||
raise ValueError("Slack webhook URL not configured")
|
||||
|
||||
payload = {
|
||||
"channel": notification.recipients[0],
|
||||
"text": notification.body,
|
||||
"username": channel.config.get("username", "Enclava Bot"),
|
||||
}
|
||||
|
||||
if notification.subject:
|
||||
payload["attachments"] = [
|
||||
{"title": notification.subject, "text": notification.body}
|
||||
]
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
webhook_url, json=payload, timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
raise Exception(
|
||||
f"Slack webhook failed with status {response.status}: {await response.text()}"
|
||||
)
|
||||
|
||||
notification.mark_sent()
|
||||
|
||||
async def _get_or_default_channel(
|
||||
self, notification_type: NotificationType, channel_name: Optional[str] = None
|
||||
) -> Optional[NotificationChannel]:
|
||||
"""Get specific channel or default for notification type"""
|
||||
|
||||
if channel_name:
|
||||
stmt = select(NotificationChannel).where(
|
||||
and_(
|
||||
NotificationChannel.name == channel_name,
|
||||
NotificationChannel.is_active == True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = select(NotificationChannel).where(
|
||||
and_(
|
||||
NotificationChannel.notification_type == notification_type,
|
||||
NotificationChannel.is_active == True,
|
||||
NotificationChannel.is_default == True,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _get_channel_by_id(
|
||||
self, channel_id: int
|
||||
) -> Optional[NotificationChannel]:
|
||||
"""Get channel by ID"""
|
||||
|
||||
stmt = select(NotificationChannel).where(NotificationChannel.id == channel_id)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _get_template(self, template_name: str) -> Optional[NotificationTemplate]:
|
||||
"""Get notification template by name"""
|
||||
|
||||
stmt = select(NotificationTemplate).where(
|
||||
and_(
|
||||
NotificationTemplate.name == template_name,
|
||||
NotificationTemplate.is_active == True,
|
||||
)
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _render_template(
|
||||
self, template: NotificationTemplate, variables: Dict[str, Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Render template with variables"""
|
||||
|
||||
rendered = {}
|
||||
|
||||
# Render subject
|
||||
if template.subject_template:
|
||||
subject_tmpl = Template(template.subject_template)
|
||||
rendered["subject"] = subject_tmpl.render(**variables)
|
||||
|
||||
# Render body
|
||||
body_tmpl = Template(template.body_template)
|
||||
rendered["body"] = body_tmpl.render(**variables)
|
||||
|
||||
# Render HTML body
|
||||
if template.html_template:
|
||||
html_tmpl = Template(template.html_template)
|
||||
rendered["html_body"] = html_tmpl.render(**variables)
|
||||
|
||||
return rendered
|
||||
|
||||
# Management methods
|
||||
|
||||
async def create_template(
|
||||
self,
|
||||
name: str,
|
||||
display_name: str,
|
||||
notification_type: NotificationType,
|
||||
body_template: str,
|
||||
subject_template: Optional[str] = None,
|
||||
html_template: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
default_priority: NotificationPriority = NotificationPriority.NORMAL,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> NotificationTemplate:
|
||||
"""Create notification template"""
|
||||
|
||||
template = NotificationTemplate(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
notification_type=notification_type,
|
||||
subject_template=subject_template,
|
||||
body_template=body_template,
|
||||
html_template=html_template,
|
||||
default_priority=default_priority,
|
||||
variables=variables or {},
|
||||
)
|
||||
|
||||
self.db.add(template)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
async def create_channel(
|
||||
self,
|
||||
name: str,
|
||||
display_name: str,
|
||||
notification_type: NotificationType,
|
||||
config: Dict[str, Any],
|
||||
credentials: Optional[Dict[str, Any]] = None,
|
||||
is_default: bool = False,
|
||||
) -> NotificationChannel:
|
||||
"""Create notification channel"""
|
||||
|
||||
channel = NotificationChannel(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
notification_type=notification_type,
|
||||
config=config,
|
||||
credentials=credentials,
|
||||
is_default=is_default,
|
||||
)
|
||||
|
||||
self.db.add(channel)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(channel)
|
||||
|
||||
return channel
|
||||
|
||||
async def get_notification_stats(self) -> Dict[str, Any]:
|
||||
"""Get notification statistics"""
|
||||
|
||||
# Total notifications
|
||||
total_notifications = await self.db.execute(select(func.count(Notification.id)))
|
||||
total_count = total_notifications.scalar()
|
||||
|
||||
# Notifications by status
|
||||
status_counts = await self.db.execute(
|
||||
select(Notification.status, func.count(Notification.id)).group_by(
|
||||
Notification.status
|
||||
)
|
||||
)
|
||||
status_stats = dict(status_counts.all())
|
||||
|
||||
# Recent notifications (last 24h)
|
||||
twenty_four_hours_ago = datetime.utcnow() - timedelta(hours=24)
|
||||
recent_notifications = await self.db.execute(
|
||||
select(func.count(Notification.id)).where(
|
||||
Notification.created_at >= twenty_four_hours_ago
|
||||
)
|
||||
)
|
||||
recent_count = recent_notifications.scalar()
|
||||
|
||||
# Channel performance
|
||||
channel_stats = await self.db.execute(
|
||||
select(
|
||||
NotificationChannel.name,
|
||||
NotificationChannel.success_count,
|
||||
NotificationChannel.failure_count,
|
||||
)
|
||||
)
|
||||
channel_performance = [
|
||||
{
|
||||
"name": name,
|
||||
"success_count": success,
|
||||
"failure_count": failure,
|
||||
"success_rate": success / (success + failure)
|
||||
if (success + failure) > 0
|
||||
else 0,
|
||||
}
|
||||
for name, success, failure in channel_stats.all()
|
||||
]
|
||||
|
||||
return {
|
||||
"total_notifications": total_count,
|
||||
"status_breakdown": status_stats,
|
||||
"recent_notifications": recent_count,
|
||||
"channel_performance": channel_performance,
|
||||
}
|
||||
@@ -15,7 +15,9 @@ logger = logging.getLogger(__name__)
|
||||
class OllamaEmbeddingService:
|
||||
"""Service for generating text embeddings using Ollama"""
|
||||
|
||||
def __init__(self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"):
|
||||
def __init__(
|
||||
self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.dimension = 1024 # bge-m3 dimension
|
||||
@@ -37,20 +39,28 @@ class OllamaEmbeddingService:
|
||||
return False
|
||||
|
||||
data = await resp.json()
|
||||
models = [model['name'].split(':')[0] for model in data.get('models', [])]
|
||||
models = [
|
||||
model["name"].split(":")[0] for model in data.get("models", [])
|
||||
]
|
||||
|
||||
if self.model_name not in models:
|
||||
logger.error(f"Model {self.model_name} not found in Ollama. Available: {models}")
|
||||
logger.error(
|
||||
f"Model {self.model_name} not found in Ollama. Available: {models}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Test embedding generation
|
||||
test_embedding = await self.get_embedding("test")
|
||||
if not test_embedding or len(test_embedding) != self.dimension:
|
||||
logger.error(f"Failed to generate test embedding with {self.model_name}")
|
||||
logger.error(
|
||||
f"Failed to generate test embedding with {self.model_name}"
|
||||
)
|
||||
return False
|
||||
|
||||
self.initialized = True
|
||||
logger.info(f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})")
|
||||
logger.info(
|
||||
f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -85,27 +95,32 @@ class OllamaEmbeddingService:
|
||||
# Call Ollama embedding API
|
||||
async with self._session.post(
|
||||
f"{self.base_url}/api/embeddings",
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"prompt": text
|
||||
}
|
||||
json={"model": self.model_name, "prompt": text},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"Ollama embedding request failed: {resp.status}")
|
||||
logger.error(
|
||||
f"Ollama embedding request failed: {resp.status}"
|
||||
)
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
continue
|
||||
|
||||
result = await resp.json()
|
||||
|
||||
if 'embedding' in result:
|
||||
embedding = result['embedding']
|
||||
if "embedding" in result:
|
||||
embedding = result["embedding"]
|
||||
if len(embedding) == self.dimension:
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}")
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
logger.warning(
|
||||
f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}"
|
||||
)
|
||||
embeddings.append(
|
||||
self._generate_fallback_embedding(text)
|
||||
)
|
||||
else:
|
||||
logger.error(f"No embedding in Ollama response for text: {text[:50]}...")
|
||||
logger.error(
|
||||
f"No embedding in Ollama response for text: {text[:50]}..."
|
||||
)
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
|
||||
except Exception as e:
|
||||
@@ -156,7 +171,7 @@ class OllamaEmbeddingService:
|
||||
"dimension": self.dimension,
|
||||
"backend": "Ollama",
|
||||
"base_url": self.base_url,
|
||||
"initialized": self.initialized
|
||||
"initialized": self.initialized,
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
|
||||
@@ -10,13 +10,16 @@ from typing import Dict, List, Set, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from app.core.logging import get_logger
|
||||
from app.utils.exceptions import CustomHTTPException
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PermissionAction(str, Enum):
|
||||
"""Standard permission actions"""
|
||||
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
@@ -30,6 +33,7 @@ class PermissionAction(str, Enum):
|
||||
@dataclass
|
||||
class Permission:
|
||||
"""Permission definition"""
|
||||
|
||||
resource: str
|
||||
action: str
|
||||
description: str = ""
|
||||
@@ -39,6 +43,7 @@ class Permission:
|
||||
@dataclass
|
||||
class PermissionScope:
|
||||
"""Permission scope for context-aware permissions"""
|
||||
|
||||
namespace: str
|
||||
resource: str
|
||||
action: str
|
||||
@@ -137,33 +142,22 @@ class ModulePermissionRegistry:
|
||||
def _initialize_default_roles(self) -> Dict[str, List[str]]:
|
||||
"""Initialize default permission roles"""
|
||||
return {
|
||||
"super_admin": [
|
||||
"platform:*",
|
||||
"modules:*",
|
||||
"llm:*"
|
||||
],
|
||||
"admin": [
|
||||
"platform:*",
|
||||
"modules:*",
|
||||
"llm:*"
|
||||
],
|
||||
"super_admin": ["platform:*", "modules:*", "llm:*"],
|
||||
"admin": ["platform:*", "modules:*", "llm:*"],
|
||||
"developer": [
|
||||
"platform:api-keys:*",
|
||||
"platform:budgets:read",
|
||||
"llm:completions:execute",
|
||||
"llm:embeddings:execute",
|
||||
"modules:*:read",
|
||||
"modules:*:execute"
|
||||
"modules:*:execute",
|
||||
],
|
||||
"user": [
|
||||
"llm:completions:execute",
|
||||
"llm:embeddings:execute",
|
||||
"modules:*:read"
|
||||
"modules:*:read",
|
||||
],
|
||||
"readonly": [
|
||||
"platform:*:read",
|
||||
"modules:*:read"
|
||||
]
|
||||
"readonly": ["platform:*:read", "modules:*:read"],
|
||||
}
|
||||
|
||||
def register_module(self, module_id: str, permissions: List[Permission]):
|
||||
@@ -187,32 +181,25 @@ class ModulePermissionRegistry:
|
||||
Permission("users", "update", "Update users"),
|
||||
Permission("users", "delete", "Delete users"),
|
||||
Permission("users", "manage", "Full user management"),
|
||||
|
||||
Permission("api-keys", "create", "Create API keys"),
|
||||
Permission("api-keys", "read", "View API keys"),
|
||||
Permission("api-keys", "update", "Update API keys"),
|
||||
Permission("api-keys", "delete", "Delete API keys"),
|
||||
Permission("api-keys", "manage", "Full API key management"),
|
||||
|
||||
Permission("budgets", "create", "Create budgets"),
|
||||
Permission("budgets", "read", "View budgets"),
|
||||
Permission("budgets", "update", "Update budgets"),
|
||||
Permission("budgets", "delete", "Delete budgets"),
|
||||
Permission("budgets", "manage", "Full budget management"),
|
||||
|
||||
Permission("audit", "read", "View audit logs"),
|
||||
Permission("audit", "export", "Export audit logs"),
|
||||
|
||||
Permission("settings", "read", "View settings"),
|
||||
Permission("settings", "update", "Update settings"),
|
||||
Permission("settings", "manage", "Full settings management"),
|
||||
|
||||
Permission("health", "read", "View health status"),
|
||||
Permission("metrics", "read", "View metrics"),
|
||||
|
||||
Permission("permissions", "read", "View permissions"),
|
||||
Permission("permissions", "manage", "Manage permissions"),
|
||||
|
||||
Permission("roles", "create", "Create roles"),
|
||||
Permission("roles", "read", "View roles"),
|
||||
Permission("roles", "update", "Update roles"),
|
||||
@@ -238,8 +225,9 @@ class ModulePermissionRegistry:
|
||||
logger.info("Registered platform and LLM permissions")
|
||||
self._platform_permissions_registered = True
|
||||
|
||||
def check_permission(self, user_permissions: List[str], required: str,
|
||||
context: Dict[str, Any] = None) -> bool:
|
||||
def check_permission(
|
||||
self, user_permissions: List[str], required: str, context: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""Check if user has required permission"""
|
||||
# Basic permission check
|
||||
has_perm = self.tree.has_permission(user_permissions, required)
|
||||
@@ -253,8 +241,9 @@ class ModulePermissionRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _check_context_permissions(self, user_permissions: List[str],
|
||||
required: str, context: Dict[str, Any]) -> bool:
|
||||
def _check_context_permissions(
|
||||
self, user_permissions: List[str], required: str, context: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check context-aware permissions"""
|
||||
# Extract resource owner information
|
||||
resource_owner = context.get("owner_id")
|
||||
@@ -266,24 +255,32 @@ class ModulePermissionRegistry:
|
||||
|
||||
# Check for elevated permissions for cross-user access
|
||||
if resource_owner and resource_owner != current_user:
|
||||
elevated_required = required.replace(":read", ":manage").replace(":update", ":manage")
|
||||
elevated_required = required.replace(":read", ":manage").replace(
|
||||
":update", ":manage"
|
||||
)
|
||||
return self.tree.has_permission(user_permissions, elevated_required)
|
||||
|
||||
return True
|
||||
|
||||
def get_user_permissions(self, roles: List[str],
|
||||
custom_permissions: List[str] = None) -> List[str]:
|
||||
def get_user_permissions(
|
||||
self, roles: List[str], custom_permissions: List[str] = None
|
||||
) -> List[str]:
|
||||
"""Get effective permissions for a user based on roles and custom permissions"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(f"=== GET USER PERMISSIONS START === Roles: {roles}, Custom perms: {custom_permissions}")
|
||||
logger.info(
|
||||
f"=== GET USER PERMISSIONS START === Roles: {roles}, Custom perms: {custom_permissions}"
|
||||
)
|
||||
|
||||
try:
|
||||
permissions = set()
|
||||
|
||||
# Add role-based permissions
|
||||
for role in roles:
|
||||
role_perms = self.role_permissions.get(role, self.default_roles.get(role, []))
|
||||
role_perms = self.role_permissions.get(
|
||||
role, self.default_roles.get(role, [])
|
||||
)
|
||||
logger.info(f"Role '{role}' has {len(role_perms)} permissions")
|
||||
permissions.update(role_perms)
|
||||
|
||||
@@ -295,20 +292,26 @@ class ModulePermissionRegistry:
|
||||
result = list(permissions)
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
logger.info(f"=== GET USER PERMISSIONS END === Total permissions: {len(result)}, Duration: {duration:.3f}s")
|
||||
logger.info(
|
||||
f"=== GET USER PERMISSIONS END === Total permissions: {len(result)}, Duration: {duration:.3f}s"
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
logger.error(f"=== GET USER PERMISSIONS FAILED === Duration: {duration:.3f}s, Error: {e}")
|
||||
logger.error(
|
||||
f"=== GET USER PERMISSIONS FAILED === Duration: {duration:.3f}s, Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_module_permissions(self, module_id: str) -> List[Permission]:
|
||||
"""Get all permissions for a specific module"""
|
||||
return self.module_permissions.get(module_id, [])
|
||||
|
||||
def get_available_permissions(self, namespace: str = None) -> Dict[str, List[Permission]]:
|
||||
def get_available_permissions(
|
||||
self, namespace: str = None
|
||||
) -> Dict[str, List[Permission]]:
|
||||
"""Get all available permissions, optionally filtered by namespace"""
|
||||
if namespace:
|
||||
filtered = {}
|
||||
@@ -345,11 +348,7 @@ class ModulePermissionRegistry:
|
||||
else:
|
||||
invalid.append(perm)
|
||||
|
||||
return {
|
||||
"valid": valid,
|
||||
"invalid": invalid,
|
||||
"is_valid": len(invalid) == 0
|
||||
}
|
||||
return {"valid": valid, "invalid": invalid, "is_valid": len(invalid) == 0}
|
||||
|
||||
def _is_valid_wildcard(self, permission: str) -> bool:
|
||||
"""Check if a wildcard permission is valid"""
|
||||
@@ -371,6 +370,7 @@ class ModulePermissionRegistry:
|
||||
|
||||
def get_permission_hierarchy(self) -> Dict[str, Any]:
|
||||
"""Get the permission hierarchy tree structure"""
|
||||
|
||||
def build_tree(node, path=""):
|
||||
tree = {}
|
||||
for key, value in node.items():
|
||||
@@ -378,7 +378,7 @@ class ModulePermissionRegistry:
|
||||
tree["_permission"] = {
|
||||
"resource": value.resource,
|
||||
"action": value.action,
|
||||
"description": value.description
|
||||
"description": value.description,
|
||||
}
|
||||
else:
|
||||
current_path = f"{path}:{key}" if path else key
|
||||
@@ -388,7 +388,11 @@ class ModulePermissionRegistry:
|
||||
return build_tree(self.tree.root)
|
||||
|
||||
|
||||
def require_permission(user_permissions: List[str], required_permission: str, context: Optional[Dict[str, Any]] = None):
|
||||
def require_permission(
|
||||
user_permissions: List[str],
|
||||
required_permission: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Decorator function to require a specific permission
|
||||
Raises HTTPException if user doesn't have the required permission
|
||||
@@ -403,11 +407,46 @@ def require_permission(user_permissions: List[str], required_permission: str, co
|
||||
"""
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
if not permission_registry.check_permission(user_permissions, required_permission, context):
|
||||
logger.warning(f"Permission denied: required '{required_permission}', user has {user_permissions}")
|
||||
raise HTTPException(
|
||||
if not permission_registry.check_permission(
|
||||
user_permissions, required_permission, context
|
||||
):
|
||||
logger.warning(
|
||||
f"Permission denied: required '{required_permission}', user has {user_permissions}"
|
||||
)
|
||||
|
||||
# Create user-friendly error message
|
||||
permission_parts = required_permission.split(":")
|
||||
if len(permission_parts) >= 2:
|
||||
resource = permission_parts[0]
|
||||
action = permission_parts[1]
|
||||
|
||||
# Create more specific error messages based on resource type
|
||||
if resource == "platform":
|
||||
if action == "settings":
|
||||
error_message = "You don't have permission to access platform settings. Contact your administrator to request access."
|
||||
elif action == "analytics":
|
||||
error_message = "You don't have permission to access analytics. Contact your administrator to request access."
|
||||
else:
|
||||
error_message = f"You don't have permission to {action} platform resources. Contact your administrator to request access."
|
||||
elif resource == "tools":
|
||||
error_message = f"You don't have permission to {action} tools. Contact your administrator to request tool management access."
|
||||
elif resource == "users":
|
||||
error_message = f"You don't have permission to {action} user accounts. Contact your administrator to request user management access."
|
||||
elif resource == "api_keys":
|
||||
error_message = f"You don't have permission to {action} API keys. Contact your administrator to request API key management access."
|
||||
else:
|
||||
error_message = f"You don't have permission to {action} {resource}. Contact your administrator to request access."
|
||||
else:
|
||||
error_message = f"Insufficient permissions. Contact your administrator to request access to: {required_permission}"
|
||||
|
||||
raise CustomHTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required: {required_permission}"
|
||||
error_code="INSUFFICIENT_PERMISSIONS",
|
||||
detail=error_message,
|
||||
details={
|
||||
"required_permission": required_permission,
|
||||
"suggestion": "Contact your administrator to request the necessary permissions for this operation."
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class PluginAutoDiscovery:
|
||||
plugin_info = await self._discover_plugin(item)
|
||||
if plugin_info:
|
||||
discovered.append(plugin_info)
|
||||
self.discovered_plugins[plugin_info['slug']] = plugin_info
|
||||
self.discovered_plugins[plugin_info["slug"]] = plugin_info
|
||||
|
||||
logger.info(f"Discovered {len(discovered)} plugins in directory")
|
||||
return discovered
|
||||
@@ -69,7 +69,9 @@ class PluginAutoDiscovery:
|
||||
validation_result = validate_manifest_file(manifest_path)
|
||||
|
||||
if not validation_result["valid"]:
|
||||
logger.warning(f"Invalid manifest for plugin {plugin_path.name}: {validation_result['errors']}")
|
||||
logger.warning(
|
||||
f"Invalid manifest for plugin {plugin_path.name}: {validation_result['errors']}"
|
||||
)
|
||||
return None
|
||||
|
||||
manifest = validation_result["manifest"]
|
||||
@@ -86,6 +88,7 @@ class PluginAutoDiscovery:
|
||||
|
||||
# Convert manifest to JSON-serializable format
|
||||
import json
|
||||
|
||||
manifest_dict = json.loads(manifest.json())
|
||||
|
||||
plugin_info = {
|
||||
@@ -101,10 +104,12 @@ class PluginAutoDiscovery:
|
||||
"main_py_path": str(main_py_path),
|
||||
"manifest_hash": manifest_hash,
|
||||
"package_hash": package_hash,
|
||||
"discovered_at": datetime.now(timezone.utc)
|
||||
"discovered_at": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
logger.info(f"Discovered plugin: {manifest.metadata.name} v{manifest.metadata.version}")
|
||||
logger.info(
|
||||
f"Discovered plugin: {manifest.metadata.name} v{manifest.metadata.version}"
|
||||
)
|
||||
return plugin_info
|
||||
|
||||
except Exception as e:
|
||||
@@ -136,23 +141,33 @@ class PluginAutoDiscovery:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
successful_registrations = sum(1 for success in registration_results.values() if success)
|
||||
logger.info(f"Plugin registration complete: {successful_registrations}/{len(registration_results)} successful")
|
||||
successful_registrations = sum(
|
||||
1 for success in registration_results.values() if success
|
||||
)
|
||||
logger.info(
|
||||
f"Plugin registration complete: {successful_registrations}/{len(registration_results)} successful"
|
||||
)
|
||||
|
||||
return registration_results
|
||||
|
||||
async def _register_single_plugin(self, db: Session, plugin_info: Dict[str, Any]) -> bool:
|
||||
async def _register_single_plugin(
|
||||
self, db: Session, plugin_info: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Register a single plugin in the database"""
|
||||
try:
|
||||
plugin_slug = plugin_info["slug"]
|
||||
|
||||
# Check if plugin already exists by slug
|
||||
existing_plugin = db.query(Plugin).filter(Plugin.slug == plugin_slug).first()
|
||||
existing_plugin = (
|
||||
db.query(Plugin).filter(Plugin.slug == plugin_slug).first()
|
||||
)
|
||||
|
||||
if existing_plugin:
|
||||
# Update existing plugin if version is different
|
||||
if existing_plugin.version != plugin_info["version"]:
|
||||
logger.info(f"Updating plugin {plugin_slug}: {existing_plugin.version} -> {plugin_info['version']}")
|
||||
logger.info(
|
||||
f"Updating plugin {plugin_slug}: {existing_plugin.version} -> {plugin_info['version']}"
|
||||
)
|
||||
|
||||
existing_plugin.version = plugin_info["version"]
|
||||
existing_plugin.description = plugin_info["description"]
|
||||
@@ -190,7 +205,7 @@ class PluginAutoDiscovery:
|
||||
package_hash=plugin_info["package_hash"],
|
||||
status="installed",
|
||||
installed_by_user_id=1, # System installation
|
||||
auto_enable=True # Auto-enable discovered plugins
|
||||
auto_enable=True, # Auto-enable discovered plugins
|
||||
)
|
||||
|
||||
db.add(plugin)
|
||||
@@ -204,10 +219,14 @@ class PluginAutoDiscovery:
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error registering plugin {plugin_info['slug']}: {e}")
|
||||
logger.error(
|
||||
f"Database error registering plugin {plugin_info['slug']}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
async def _setup_plugin_database(self, plugin_id: str, plugin_info: Dict[str, Any]) -> bool:
|
||||
async def _setup_plugin_database(
|
||||
self, plugin_id: str, plugin_info: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Setup database schema for plugin"""
|
||||
try:
|
||||
manifest_data = plugin_info["manifest_data"]
|
||||
@@ -234,11 +253,12 @@ class PluginAutoDiscovery:
|
||||
try:
|
||||
# Load plugin into sandbox using the correct method
|
||||
plugin_dir = Path(plugin_info["plugin_path"])
|
||||
plugin_token = f"plugin_{plugin_slug}_token" # Generate a token for the plugin
|
||||
plugin_token = (
|
||||
f"plugin_{plugin_slug}_token" # Generate a token for the plugin
|
||||
)
|
||||
|
||||
plugin_instance = await plugin_loader.load_plugin_with_sandbox(
|
||||
plugin_dir,
|
||||
plugin_token
|
||||
plugin_dir, plugin_token
|
||||
)
|
||||
|
||||
if plugin_instance:
|
||||
@@ -253,7 +273,9 @@ class PluginAutoDiscovery:
|
||||
loading_results[plugin_slug] = False
|
||||
|
||||
successful_loads = sum(1 for success in loading_results.values() if success)
|
||||
logger.info(f"Plugin loading complete: {successful_loads}/{len(loading_results)} successful")
|
||||
logger.info(
|
||||
f"Plugin loading complete: {successful_loads}/{len(loading_results)} successful"
|
||||
)
|
||||
|
||||
return loading_results
|
||||
|
||||
@@ -268,8 +290,8 @@ class PluginAutoDiscovery:
|
||||
"summary": {
|
||||
"total_discovered": 0,
|
||||
"successful_registrations": 0,
|
||||
"successful_loads": 0
|
||||
}
|
||||
"successful_loads": 0,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -285,16 +307,22 @@ class PluginAutoDiscovery:
|
||||
# Step 2: Register plugins in database
|
||||
registration_results = await self.register_discovered_plugins()
|
||||
results["registered"] = registration_results
|
||||
results["summary"]["successful_registrations"] = sum(1 for success in registration_results.values() if success)
|
||||
results["summary"]["successful_registrations"] = sum(
|
||||
1 for success in registration_results.values() if success
|
||||
)
|
||||
|
||||
# Step 3: Load plugins into sandbox
|
||||
loading_results = await self.load_discovered_plugins()
|
||||
results["loaded"] = loading_results
|
||||
results["summary"]["successful_loads"] = sum(1 for success in loading_results.values() if success)
|
||||
results["summary"]["successful_loads"] = sum(
|
||||
1 for success in loading_results.values() if success
|
||||
)
|
||||
|
||||
logger.info(f"Auto-discovery complete! Discovered: {results['summary']['total_discovered']}, "
|
||||
logger.info(
|
||||
f"Auto-discovery complete! Discovered: {results['summary']['total_discovered']}, "
|
||||
f"Registered: {results['summary']['successful_registrations']}, "
|
||||
f"Loaded: {results['summary']['successful_loads']}")
|
||||
f"Loaded: {results['summary']['successful_loads']}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -310,7 +338,7 @@ class PluginAutoDiscovery:
|
||||
"plugins_dir_exists": self.plugins_dir.exists(),
|
||||
"discovered_plugins": list(self.discovered_plugins.keys()),
|
||||
"discovery_count": len(self.discovered_plugins),
|
||||
"last_scan": datetime.now(timezone.utc).isoformat()
|
||||
"last_scan": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@@ -329,7 +357,14 @@ async def initialize_plugin_autodiscovery() -> Dict[str, Any]:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Plugin auto-discovery initialization failed: {e}")
|
||||
return {"error": str(e), "summary": {"total_discovered": 0, "successful_registrations": 0, "successful_loads": 0}}
|
||||
return {
|
||||
"error": str(e),
|
||||
"summary": {
|
||||
"total_discovered": 0,
|
||||
"successful_registrations": 0,
|
||||
"successful_loads": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Convenience function for manual plugin discovery
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user