mirror of https://github.com/aljazceru/enclava.git (synced 2026-01-31 05:24:38 +01:00)

Commit: mega changes
13  .gitignore  vendored
@@ -66,3 +66,16 @@ frontend/node_modules/
node_modules/
venv/
venv_memory_monitor/
+to_delete/
+security-review.md
+features-plan.md
+AGENTS.md
+CLAUDE.md
+backend/CURL_VERIFICATION_EXAMPLES.md
+backend/OPENAI_COMPATIBILITY_GUIDE.md
+backend/production-deployment-guide.md
+features-plan.md
+security-review.md
+backend/.env.local
+backend/.env.test
@@ -1,12 +1,17 @@
import asyncio
+import os
from logging.config import fileConfig
+from dotenv import load_dotenv

from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine

+# Load environment variables from .env file
+load_dotenv()
+load_dotenv('.env.local')
+
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
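One subtlety in the hunk above, not something the diff itself addresses: python-dotenv's `load_dotenv` does not overwrite variables that are already set, so with this ordering the values from `.env` win over anything in `.env.local`. If the usual local-overrides-base behavior is the intent, the second call needs `override=True` — a minimal sketch:

```python
from dotenv import load_dotenv

load_dotenv(".env")                        # base settings; populates os.environ
load_dotenv(".env.local", override=True)   # without override=True these would lose
```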
@@ -4,7 +4,7 @@ This migration represents the complete, accurate database schema based on the actual
model files in the codebase. All legacy migrations have been consolidated into this
single migration to ensure the database matches what the models expect.

-Revision ID: 000_consolidated_ground_truth_schema
+Revision ID: 000_ground_truth
Revises:
Create Date: 2025-08-22 10:30:00.000000
62  backend/alembic/versions/001_add_roles_table.py  Normal file
@@ -0,0 +1,62 @@
"""add roles table

Revision ID: 001
Revises: 000_ground_truth
Create Date: 2025-01-30 00:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers
revision = '001_add_roles_table'
down_revision = '000_ground_truth'
branch_labels = None
depends_on = None


def upgrade():
    # Create roles table
    op.create_table(
        'roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=50), nullable=False),
        sa.Column('display_name', sa.String(length=100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('level', sa.String(length=20), nullable=False),
        sa.Column('permissions', sa.JSON(), nullable=True),
        sa.Column('can_manage_users', sa.Boolean(), nullable=True, default=False),
        sa.Column('can_manage_budgets', sa.Boolean(), nullable=True, default=False),
        sa.Column('can_view_reports', sa.Boolean(), nullable=True, default=False),
        sa.Column('can_manage_tools', sa.Boolean(), nullable=True, default=False),
        sa.Column('inherits_from', sa.JSON(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('is_system_role', sa.Boolean(), nullable=True, default=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )

    # Create indexes for roles
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    op.create_index(op.f('ix_roles_level'), 'roles', ['level'], unique=False)

    # Add role_id to users table
    op.add_column('users', sa.Column('role_id', sa.Integer(), nullable=True))
    op.create_foreign_key(
        'fk_users_role_id', 'users', 'roles',
        ['role_id'], ['id'], ondelete='SET NULL'
    )
    op.create_index('ix_users_role_id', 'users', ['role_id'])


def downgrade():
    # Remove role_id from users
    op.drop_index('ix_users_role_id', table_name='users')
    op.drop_constraint('fk_users_role_id', table_name='users', type_='foreignkey')
    op.drop_column('users', 'role_id')

    # Drop roles table
    op.drop_index(op.f('ix_roles_level'), table_name='roles')
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_index(op.f('ix_roles_id'), table_name='roles')
    op.drop_table('roles')
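A side note on the boolean columns above: `default=` in `sa.Column` is a client-side default that SQLAlchemy applies at INSERT time; it never reaches the database schema, so rows written by any other client end up NULL. Putting the default into the DDL takes `server_default` — a minimal sketch using a hypothetical column, not part of this migration:

```python
import sqlalchemy as sa
from alembic import op

# `server_default` is emitted into the ALTER TABLE statement, so the database
# itself fills the value for any writer; `default=` only helps SQLAlchemy inserts.
op.add_column(
    "roles",
    sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.false()),
)
```

Migration 005 further down exists largely to repair this gap: it backfills NULLs and adds server defaults after the fact.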
118  backend/alembic/versions/002_add_tools_tables.py  Normal file
@@ -0,0 +1,118 @@
"""add tools tables

Revision ID: 002
Revises: 001_add_roles_table
Create Date: 2025-01-30 00:00:01.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers
revision = '002_add_tools_tables'
down_revision = '001_add_roles_table'
branch_labels = None
depends_on = None


def upgrade():
    # Create tool_categories table
    op.create_table(
        'tool_categories',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=50), nullable=False),
        sa.Column('display_name', sa.String(length=100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('icon', sa.String(length=50), nullable=True),
        sa.Column('color', sa.String(length=20), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True, default=0),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_tool_categories_id'), 'tool_categories', ['id'], unique=False)
    op.create_index(op.f('ix_tool_categories_name'), 'tool_categories', ['name'], unique=True)

    # Create tools table
    op.create_table(
        'tools',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('display_name', sa.String(length=200), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('tool_type', sa.String(length=20), nullable=False),
        sa.Column('code', sa.Text(), nullable=False),
        sa.Column('parameters_schema', sa.JSON(), nullable=True),
        sa.Column('return_schema', sa.JSON(), nullable=True),
        sa.Column('timeout_seconds', sa.Integer(), nullable=True, default=30),
        sa.Column('max_memory_mb', sa.Integer(), nullable=True, default=256),
        sa.Column('max_cpu_seconds', sa.Float(), nullable=True, default=10.0),
        sa.Column('docker_image', sa.String(length=200), nullable=True),
        sa.Column('docker_command', sa.Text(), nullable=True),
        sa.Column('is_public', sa.Boolean(), nullable=True, default=False),
        sa.Column('is_approved', sa.Boolean(), nullable=True, default=False),
        sa.Column('created_by_user_id', sa.Integer(), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=True),
        sa.Column('tags', sa.JSON(), nullable=True),
        sa.Column('usage_count', sa.Integer(), nullable=True, default=0),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_tools_id'), 'tools', ['id'], unique=False)
    op.create_index(op.f('ix_tools_name'), 'tools', ['name'], unique=False)
    op.create_foreign_key(
        'fk_tools_created_by_user_id', 'tools', 'users',
        ['created_by_user_id'], ['id'], ondelete='CASCADE'
    )

    # Create tool_executions table
    op.create_table(
        'tool_executions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('tool_id', sa.Integer(), nullable=False),
        sa.Column('executed_by_user_id', sa.Integer(), nullable=False),
        sa.Column('parameters', sa.JSON(), nullable=True),
        sa.Column('status', sa.String(length=20), nullable=False, default='pending'),
        sa.Column('output', sa.Text(), nullable=True),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('return_code', sa.Integer(), nullable=True),
        sa.Column('execution_time_ms', sa.Integer(), nullable=True),
        sa.Column('memory_used_mb', sa.Float(), nullable=True),
        sa.Column('cpu_time_ms', sa.Integer(), nullable=True),
        sa.Column('container_id', sa.String(length=100), nullable=True),
        sa.Column('docker_logs', sa.Text(), nullable=True),
        sa.Column('started_at', sa.DateTime(), nullable=True),
        sa.Column('completed_at', sa.DateTime(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_tool_executions_id'), 'tool_executions', ['id'], unique=False)
    op.create_foreign_key(
        'fk_tool_executions_tool_id', 'tool_executions', 'tools',
        ['tool_id'], ['id'], ondelete='CASCADE'
    )
    op.create_foreign_key(
        'fk_tool_executions_executed_by_user_id', 'tool_executions', 'users',
        ['executed_by_user_id'], ['id'], ondelete='CASCADE'
    )


def downgrade():
    # Drop tool_executions table
    op.drop_constraint('fk_tool_executions_executed_by_user_id', table_name='tool_executions', type_='foreignkey')
    op.drop_constraint('fk_tool_executions_tool_id', table_name='tool_executions', type_='foreignkey')
    op.drop_index(op.f('ix_tool_executions_id'), table_name='tool_executions')
    op.drop_table('tool_executions')

    # Drop tools table
    op.drop_constraint('fk_tools_created_by_user_id', table_name='tools', type_='foreignkey')
    op.drop_index(op.f('ix_tools_name'), table_name='tools')
    op.drop_index(op.f('ix_tools_id'), table_name='tools')
    op.drop_table('tools')

    # Drop tool_categories table
    op.drop_index(op.f('ix_tool_categories_name'), table_name='tool_categories')
    op.drop_index(op.f('ix_tool_categories_id'), table_name='tool_categories')
    op.drop_table('tool_categories')
132  backend/alembic/versions/003_add_notifications_tables.py  Normal file
@@ -0,0 +1,132 @@
"""add notifications tables

Revision ID: 003_add_notifications_tables
Revises: 002_add_tools_tables
Create Date: 2025-01-30 00:00:02.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers
revision = '003_add_notifications_tables'
down_revision = '002_add_tools_tables'
branch_labels = None
depends_on = None


def upgrade():
    # Create notification_templates table
    op.create_table(
        'notification_templates',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('display_name', sa.String(length=200), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('notification_type', sa.String(length=20), nullable=False),
        sa.Column('subject_template', sa.Text(), nullable=True),
        sa.Column('body_template', sa.Text(), nullable=False),
        sa.Column('html_template', sa.Text(), nullable=True),
        sa.Column('default_priority', sa.String(length=20), nullable=True, default='normal'),
        sa.Column('variables', sa.JSON(), nullable=True),
        sa.Column('metadata', sa.JSON(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_notification_templates_id'), 'notification_templates', ['id'], unique=False)
    op.create_index(op.f('ix_notification_templates_name'), 'notification_templates', ['name'], unique=True)

    # Create notification_channels table
    op.create_table(
        'notification_channels',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('display_name', sa.String(length=200), nullable=False),
        sa.Column('notification_type', sa.String(length=20), nullable=False),
        sa.Column('config', sa.JSON(), nullable=False),
        sa.Column('credentials', sa.JSON(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True, default=True),
        sa.Column('is_default', sa.Boolean(), nullable=True, default=False),
        sa.Column('rate_limit', sa.Integer(), nullable=True, default=100),
        sa.Column('retry_count', sa.Integer(), nullable=True, default=3),
        sa.Column('retry_delay_minutes', sa.Integer(), nullable=True, default=5),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('success_count', sa.Integer(), nullable=True, default=0),
        sa.Column('failure_count', sa.Integer(), nullable=True, default=0),
        sa.Column('last_error', sa.Text(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_notification_channels_id'), 'notification_channels', ['id'], unique=False)
    op.create_index(op.f('ix_notification_channels_name'), 'notification_channels', ['name'], unique=False)

    # Create notifications table
    op.create_table(
        'notifications',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('subject', sa.String(length=500), nullable=True),
        sa.Column('body', sa.Text(), nullable=False),
        sa.Column('html_body', sa.Text(), nullable=True),
        sa.Column('recipients', sa.JSON(), nullable=False),
        sa.Column('cc_recipients', sa.JSON(), nullable=True),
        sa.Column('bcc_recipients', sa.JSON(), nullable=True),
        sa.Column('priority', sa.String(length=20), nullable=True, default='normal'),
        sa.Column('scheduled_at', sa.DateTime(), nullable=True),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
        sa.Column('template_id', sa.Integer(), nullable=True),
        sa.Column('channel_id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('status', sa.String(length=20), nullable=True, default='pending'),
        sa.Column('attempts', sa.Integer(), nullable=True, default=0),
        sa.Column('max_attempts', sa.Integer(), nullable=True, default=3),
        sa.Column('sent_at', sa.DateTime(), nullable=True),
        sa.Column('delivered_at', sa.DateTime(), nullable=True),
        sa.Column('failed_at', sa.DateTime(), nullable=True),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('external_id', sa.String(length=200), nullable=True),
        sa.Column('callback_url', sa.String(length=500), nullable=True),
        sa.Column('metadata', sa.JSON(), nullable=True),
        sa.Column('tags', sa.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_notifications_id'), 'notifications', ['id'], unique=False)
    op.create_index(op.f('ix_notifications_status'), 'notifications', ['status'], unique=False)
    op.create_index(op.f('ix_notifications_scheduled_at'), 'notifications', ['scheduled_at'], unique=False)

    # Add foreign key constraints
    op.create_foreign_key(
        'fk_notifications_template_id', 'notifications', 'notification_templates',
        ['template_id'], ['id'], ondelete='SET NULL'
    )
    op.create_foreign_key(
        'fk_notifications_channel_id', 'notifications', 'notification_channels',
        ['channel_id'], ['id'], ondelete='CASCADE'
    )
    op.create_foreign_key(
        'fk_notifications_user_id', 'notifications', 'users',
        ['user_id'], ['id'], ondelete='SET NULL'
    )


def downgrade():
    # Drop notifications table
    op.drop_constraint('fk_notifications_user_id', table_name='notifications', type_='foreignkey')
    op.drop_constraint('fk_notifications_channel_id', table_name='notifications', type_='foreignkey')
    op.drop_constraint('fk_notifications_template_id', table_name='notifications', type_='foreignkey')
    op.drop_index(op.f('ix_notifications_scheduled_at'), table_name='notifications')
    op.drop_index(op.f('ix_notifications_status'), table_name='notifications')
    op.drop_index(op.f('ix_notifications_id'), table_name='notifications')
    op.drop_table('notifications')

    # Drop notification_channels table
    op.drop_index(op.f('ix_notification_channels_name'), table_name='notification_channels')
    op.drop_index(op.f('ix_notification_channels_id'), table_name='notification_channels')
    op.drop_table('notification_channels')

    # Drop notification_templates table
    op.drop_index(op.f('ix_notification_templates_name'), table_name='notification_templates')
    op.drop_index(op.f('ix_notification_templates_id'), table_name='notification_templates')
    op.drop_table('notification_templates')
26  backend/alembic/versions/004_add_force_password_change.py  Normal file
@@ -0,0 +1,26 @@
"""Add force_password_change to users

Revision ID: 004_add_force_password_change
Revises: fd999a559a35
Create Date: 2025-01-31 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '004_add_force_password_change'
down_revision = 'fd999a559a35'
branch_labels = None
depends_on = None


def upgrade():
    # Add force_password_change column to users table
    op.add_column('users', sa.Column('force_password_change', sa.Boolean(), default=False, nullable=False, server_default='false'))


def downgrade():
    # Remove force_password_change column from users table
    op.drop_column('users', 'force_password_change')
79  backend/alembic/versions/005_fix_user_nullable_columns.py  Normal file
@@ -0,0 +1,79 @@
"""fix user nullable columns

Revision ID: 005_fix_user_nullable_columns
Revises: 004_add_force_password_change
Create Date: 2025-11-20 08:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "005_fix_user_nullable_columns"
down_revision = "004_add_force_password_change"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """
    Fix nullable columns in users table:
    - Backfill NULL values for account_locked and failed_login_attempts
    - Set proper server defaults
    - Alter columns to NOT NULL
    """
    # Use connection to execute raw SQL for backfilling
    conn = op.get_bind()

    # Backfill NULL values for account_locked
    conn.execute(
        sa.text("UPDATE users SET account_locked = FALSE WHERE account_locked IS NULL")
    )

    # Backfill NULL values for failed_login_attempts
    conn.execute(
        sa.text("UPDATE users SET failed_login_attempts = 0 WHERE failed_login_attempts IS NULL")
    )

    # Backfill NULL values for custom_permissions (use empty JSON object)
    conn.execute(
        sa.text("UPDATE users SET custom_permissions = '{}' WHERE custom_permissions IS NULL")
    )

    # Now alter columns to NOT NULL with server defaults
    # Note: PostgreSQL syntax
    op.alter_column('users', 'account_locked',
                    existing_type=sa.Boolean(),
                    nullable=False,
                    server_default=sa.false())

    op.alter_column('users', 'failed_login_attempts',
                    existing_type=sa.Integer(),
                    nullable=False,
                    server_default='0')

    op.alter_column('users', 'custom_permissions',
                    existing_type=sa.JSON(),
                    nullable=False,
                    server_default='{}')


def downgrade() -> None:
    """
    Revert columns to nullable (original state from fd999a559a35)
    """
    op.alter_column('users', 'account_locked',
                    existing_type=sa.Boolean(),
                    nullable=True,
                    server_default=None)

    op.alter_column('users', 'failed_login_attempts',
                    existing_type=sa.Integer(),
                    nullable=True,
                    server_default=None)

    op.alter_column('users', 'custom_permissions',
                    existing_type=sa.JSON(),
                    nullable=True,
                    server_default=None)
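Taken together with fd999a559a35 below, the revision chain is 000_ground_truth → 001_add_roles_table → 002_add_tools_tables → 003_add_notifications_tables → fd999a559a35 → 004_add_force_password_change → 005_fix_user_nullable_columns. A minimal sketch of applying the chain programmatically, assuming the standard alembic.ini at the backend root:

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")   # path is an assumption about the project layout
command.upgrade(cfg, "head")  # walks the down_revision chain in order
```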
@@ -0,0 +1,71 @@
"""fix missing user columns

Revision ID: fd999a559a35
Revises: 003_add_notifications_tables
Create Date: 2025-10-30 11:33:42.236622

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "fd999a559a35"
down_revision = "003_add_notifications_tables"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add missing columns to users table
    # These columns should have been added in 001_add_roles_table.py but were not

    # Use try/except to handle cases where columns might already exist
    try:
        op.add_column("users", sa.Column("custom_permissions", sa.JSON(), nullable=True, default=dict))
    except Exception:
        pass  # Column might already exist

    try:
        op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True, default=False))
    except Exception:
        pass

    try:
        op.add_column("users", sa.Column("account_locked_until", sa.DateTime(), nullable=True))
    except Exception:
        pass

    try:
        op.add_column("users", sa.Column("failed_login_attempts", sa.Integer(), nullable=True, default=0))
    except Exception:
        pass

    try:
        op.add_column("users", sa.Column("last_failed_login", sa.DateTime(), nullable=True))
    except Exception:
        pass


def downgrade() -> None:
    # Remove the columns
    try:
        op.drop_column("users", "last_failed_login")
    except Exception:
        pass
    try:
        op.drop_column("users", "failed_login_attempts")
    except Exception:
        pass
    try:
        op.drop_column("users", "account_locked_until")
    except Exception:
        pass
    try:
        op.drop_column("users", "account_locked")
    except Exception:
        pass
    try:
        op.drop_column("users", "custom_permissions")
    except Exception:
        pass
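The try/except-pass pattern above is fragile: it swallows every failure, and on PostgreSQL an error inside the migration's transaction aborts the transaction, so the statements after a caught failure would fail anyway. Probing the live schema first is more robust — a sketch (the helper name is illustrative, not from the codebase):

```python
import sqlalchemy as sa
from alembic import op


def _has_column(table: str, column: str) -> bool:
    """Check the live schema instead of catching add_column failures."""
    inspector = sa.inspect(op.get_bind())
    return column in {col["name"] for col in inspector.get_columns(table)}


def upgrade() -> None:
    if not _has_column("users", "account_locked"):
        op.add_column("users", sa.Column("account_locked", sa.Boolean(), nullable=True))
```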
@@ -4,4 +4,4 @@ Enclava - Modular AI Platform

__version__ = "1.0.0"
__author__ = "Enclava Team"
-__description__ = "Enclava - Modular AI Platform with confidential processing"
+__description__ = "Enclava - Modular AI Platform with confidential processing"

(The removed and added lines are identical text; the underlying change is evidently just a missing end-of-file newline, whose marker the extraction dropped.)
@@ -1,3 +1,3 @@
"""
API package
-"""
+"""

(Again identical text on both sides — another end-of-file newline fix.)
@@ -44,7 +44,9 @@ class HealthChecker:
            await session.execute(select(1))

            # Check table availability
-            await session.execute(text("SELECT COUNT(*) FROM information_schema.tables"))
+            await session.execute(
+                text("SELECT COUNT(*) FROM information_schema.tables")
+            )

            duration = time.time() - start_time
@@ -54,8 +56,8 @@ class HealthChecker:
                "timestamp": datetime.utcnow().isoformat(),
                "details": {
                    "connection": "successful",
-                    "query_execution": "successful"
-                }
+                    "query_execution": "successful",
+                },
            }

        except Exception as e:
@@ -64,10 +66,7 @@ class HealthChecker:
                "status": "unhealthy",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat(),
-                "details": {
-                    "connection": "failed",
-                    "error_type": type(e).__name__
-                }
+                "details": {"connection": "failed", "error_type": type(e).__name__},
            }

    async def check_memory_health(self) -> Dict[str, Any]:
@@ -107,7 +106,7 @@ class HealthChecker:
                "process_memory_mb": round(process_memory_mb, 2),
                "system_memory_percent": memory.percent,
                "system_available_gb": round(memory.available / (1024**3), 2),
-                "issues": issues
+                "issues": issues,
            }

        except Exception as e:
@@ -115,7 +114,7 @@ class HealthChecker:
            return {
                "status": "error",
                "error": str(e),
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

    async def check_connection_health(self) -> Dict[str, Any]:
@@ -128,8 +127,16 @@ class HealthChecker:

            # Analyze connections
            total_connections = len(connections)
-            established_connections = len([c for c in connections if c.status == 'ESTABLISHED'])
-            http_connections = len([c for c in connections if any(port in str(c.laddr) for port in [80, 8000, 3000])])
+            established_connections = len(
+                [c for c in connections if c.status == "ESTABLISHED"]
+            )
+            http_connections = len(
+                [
+                    c
+                    for c in connections
+                    if any(port in str(c.laddr) for port in [80, 8000, 3000])
+                ]
+            )

            # Check for connection issues
            connection_status = "healthy"
@@ -154,7 +161,7 @@ class HealthChecker:
                "total_connections": total_connections,
                "established_connections": established_connections,
                "http_connections": http_connections,
-                "issues": issues
+                "issues": issues,
            }

        except Exception as e:
@@ -162,7 +169,7 @@ class HealthChecker:
            return {
                "status": "error",
                "error": str(e),
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

    async def check_embedding_service_health(self) -> Dict[str, Any]:
@@ -193,7 +200,7 @@ class HealthChecker:
                "response_time_ms": round(duration * 1000, 2),
                "timestamp": datetime.utcnow().isoformat(),
                "stats": stats,
-                "issues": issues
+                "issues": issues,
            }

        except Exception as e:
@@ -201,7 +208,7 @@ class HealthChecker:
            return {
                "status": "error",
                "error": str(e),
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

    async def check_redis_health(self) -> Dict[str, Any]:
@@ -209,7 +216,7 @@ class HealthChecker:
        if not settings.REDIS_URL:
            return {
                "status": "not_configured",
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

        try:
@@ -233,7 +240,7 @@ class HealthChecker:
            return {
                "status": "healthy",
                "response_time_ms": round(duration * 1000, 2),
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

        except Exception as e:
@@ -241,7 +248,7 @@ class HealthChecker:
            return {
                "status": "unhealthy",
                "error": str(e),
-                "timestamp": datetime.utcnow().isoformat()
+                "timestamp": datetime.utcnow().isoformat(),
            }

    async def get_comprehensive_health(self) -> Dict[str, Any]:
@@ -251,7 +258,7 @@ class HealthChecker:
            "memory": await self.check_memory_health(),
            "connections": await self.check_connection_health(),
            "embedding_service": await self.check_embedding_service_health(),
-            "redis": await self.check_redis_health()
+            "redis": await self.check_redis_health(),
        }

        # Determine overall status
@@ -274,12 +281,16 @@ class HealthChecker:
            "summary": {
                "total_checks": len(checks),
                "healthy_checks": len([s for s in statuses if s == "healthy"]),
-                "degraded_checks": len([s for s in statuses if s in ["warning", "degraded", "unhealthy"]]),
-                "failed_checks": len([s for s in statuses if s in ["critical", "error"]]),
-                "total_issues": total_issues
+                "degraded_checks": len(
+                    [s for s in statuses if s in ["warning", "degraded", "unhealthy"]]
+                ),
+                "failed_checks": len(
+                    [s for s in statuses if s in ["critical", "error"]]
+                ),
+                "total_issues": total_issues,
+            },
            "version": "1.0.0",
-            "uptime_seconds": int(time.time() - psutil.boot_time())
+            "uptime_seconds": int(time.time() - psutil.boot_time()),
        }
@@ -294,7 +305,7 @@ async def basic_health_check():
        "status": "healthy",
        "app": settings.APP_NAME,
        "version": "1.0.0",
-        "timestamp": datetime.utcnow().isoformat()
+        "timestamp": datetime.utcnow().isoformat(),
    }
@@ -307,7 +318,7 @@ async def detailed_health_check():
        logger.error(f"Detailed health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"Health check failed: {str(e)}"
+            detail=f"Health check failed: {str(e)}",
        )
@@ -320,7 +331,7 @@ async def memory_health_check():
        logger.error(f"Memory health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"Memory health check failed: {str(e)}"
+            detail=f"Memory health check failed: {str(e)}",
        )
@@ -333,7 +344,7 @@ async def connection_health_check():
        logger.error(f"Connection health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"Connection health check failed: {str(e)}"
+            detail=f"Connection health check failed: {str(e)}",
        )
@@ -346,5 +357,5 @@ async def embedding_service_health_check():
        logger.error(f"Embedding service health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"Embedding service health check failed: {str(e)}"
-        )
+            detail=f"Embedding service health check failed: {str(e)}",
+        )
@@ -19,6 +19,7 @@ from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
from ..v1.chatbot import router as chatbot_router
from .debugging import router as debugging_router
+from ..v1.endpoints.user_management import router as user_management_router

# Create internal API router
internal_api_router = APIRouter()
@@ -27,47 +28,82 @@ internal_api_router = APIRouter()
internal_api_router.include_router(auth_router, prefix="/auth", tags=["internal-auth"])

# Include modules routes (frontend management)
-internal_api_router.include_router(modules_router, prefix="/modules", tags=["internal-modules"])
+internal_api_router.include_router(
+    modules_router, prefix="/modules", tags=["internal-modules"]
+)

# Include platform routes (frontend platform management)
-internal_api_router.include_router(platform_router, prefix="/platform", tags=["internal-platform"])
+# Include platform routes (frontend platform management)
+internal_api_router.include_router(
+    platform_router, prefix="/platform", tags=["internal-platform"]
+)

# Include user management routes (frontend user admin)
-internal_api_router.include_router(users_router, prefix="/users", tags=["internal-users"])
+internal_api_router.include_router(
+    users_router, prefix="/users", tags=["internal-users"]
+)

# Include API key management routes (frontend API key management)
-internal_api_router.include_router(api_keys_router, prefix="/api-keys", tags=["internal-api-keys"])
+internal_api_router.include_router(
+    api_keys_router, prefix="/api-keys", tags=["internal-api-keys"]
+)

# Include budget management routes (frontend budget management)
-internal_api_router.include_router(budgets_router, prefix="/budgets", tags=["internal-budgets"])
+internal_api_router.include_router(
+    budgets_router, prefix="/budgets", tags=["internal-budgets"]
+)

# Include audit log routes (frontend audit viewing)
-internal_api_router.include_router(audit_router, prefix="/audit", tags=["internal-audit"])
+internal_api_router.include_router(
+    audit_router, prefix="/audit", tags=["internal-audit"]
+)

# Include settings management routes (frontend settings)
-internal_api_router.include_router(settings_router, prefix="/settings", tags=["internal-settings"])
+internal_api_router.include_router(
+    settings_router, prefix="/settings", tags=["internal-settings"]
+)

# Include analytics routes (frontend analytics viewing)
-internal_api_router.include_router(analytics_router, prefix="/analytics", tags=["internal-analytics"])
+internal_api_router.include_router(
+    analytics_router, prefix="/analytics", tags=["internal-analytics"]
+)

# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])

# Include RAG debug routes (for demo and debugging)
-internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
+internal_api_router.include_router(
+    rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"]
+)

# Include prompt template routes (frontend prompt template management)
-internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
+internal_api_router.include_router(
+    prompt_templates_router,
+    prefix="/prompt-templates",
+    tags=["internal-prompt-templates"],
+)


# Include plugin registry routes (frontend plugin management)
-internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
+internal_api_router.include_router(
+    plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]
+)

# Include internal LLM routes (frontend LLM service access with JWT auth)
-internal_api_router.include_router(llm_internal_router, prefix="/llm", tags=["internal-llm"])
+internal_api_router.include_router(
+    llm_internal_router, prefix="/llm", tags=["internal-llm"]
+)

# Include chatbot routes (frontend chatbot management)
-internal_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["internal-chatbot"])
+internal_api_router.include_router(
+    chatbot_router, prefix="/chatbot", tags=["internal-chatbot"]
+)

# Include debugging routes (troubleshooting and diagnostics)
-internal_api_router.include_router(debugging_router, prefix="/debugging", tags=["internal-debugging"])
+internal_api_router.include_router(
+    debugging_router, prefix="/debugging", tags=["internal-debugging"]
+)

+# Include user management routes (advanced user and role management)
+internal_api_router.include_router(
+    user_management_router, prefix="/user-management", tags=["internal-user-management"]
+)
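For readers tracing URLs through this file: `include_router` composes prefixes, so a route's final path is the concatenation of every prefix on the way up. A minimal, self-contained illustration (the `/api-internal` mount point is an assumption for the example, not taken from this commit):

```python
from fastapi import APIRouter, FastAPI

modules_router = APIRouter()


@modules_router.get("/list")
async def list_modules():
    # Served at /api-internal/modules/list once both includes below run.
    return []


internal_api_router = APIRouter()
internal_api_router.include_router(modules_router, prefix="/modules")

app = FastAPI()
app.include_router(internal_api_router, prefix="/api-internal")
```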
@@ -20,23 +20,28 @@ router = APIRouter()
async def get_chatbot_config_debug(
    chatbot_id: str,
    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Get detailed configuration for debugging a specific chatbot"""

    # Get chatbot instance
-    chatbot = db.query(ChatbotInstance).filter(
-        ChatbotInstance.id == chatbot_id,
-        ChatbotInstance.user_id == current_user.id
-    ).first()
+    chatbot = (
+        db.query(ChatbotInstance)
+        .filter(
+            ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
+        )
+        .first()
+    )

    if not chatbot:
        raise HTTPException(status_code=404, detail="Chatbot not found")

    # Get prompt template
-    prompt_template = db.query(PromptTemplate).filter(
-        PromptTemplate.type == chatbot.chatbot_type
-    ).first()
+    prompt_template = (
+        db.query(PromptTemplate)
+        .filter(PromptTemplate.type == chatbot.chatbot_type)
+        .first()
+    )

    # Get RAG collections if configured
    rag_collections = []
@@ -44,31 +49,37 @@ async def get_chatbot_config_debug(
        collection_ids = chatbot.rag_collection_ids
        if isinstance(collection_ids, str):
            import json
+
            try:
                collection_ids = json.loads(collection_ids)
            except:
                collection_ids = []

        if collection_ids:
-            collections = db.query(RagCollection).filter(
-                RagCollection.id.in_(collection_ids)
-            ).all()
+            collections = (
+                db.query(RagCollection)
+                .filter(RagCollection.id.in_(collection_ids))
+                .all()
+            )
            rag_collections = [
                {
                    "id": col.id,
                    "name": col.name,
                    "document_count": col.document_count,
                    "qdrant_collection_name": col.qdrant_collection_name,
-                    "is_active": col.is_active
+                    "is_active": col.is_active,
                }
                for col in collections
            ]

    # Get recent conversations count
    from app.models.chatbot import ChatbotConversation
-    conversation_count = db.query(ChatbotConversation).filter(
-        ChatbotConversation.chatbot_instance_id == chatbot_id
-    ).count()
+
+    conversation_count = (
+        db.query(ChatbotConversation)
+        .filter(ChatbotConversation.chatbot_instance_id == chatbot_id)
+        .count()
+    )

    return {
        "chatbot": {
@@ -78,20 +89,20 @@ async def get_chatbot_config_debug(
            "description": chatbot.description,
            "created_at": chatbot.created_at,
            "is_active": chatbot.is_active,
-            "conversation_count": conversation_count
+            "conversation_count": conversation_count,
        },
        "prompt_template": {
            "type": prompt_template.type if prompt_template else None,
            "system_prompt": prompt_template.system_prompt if prompt_template else None,
-            "variables": prompt_template.variables if prompt_template else []
+            "variables": prompt_template.variables if prompt_template else [],
        },
        "rag_collections": rag_collections,
        "configuration": {
            "max_tokens": chatbot.max_tokens,
            "temperature": chatbot.temperature,
            "streaming": chatbot.streaming,
-            "memory_config": chatbot.memory_config
-        }
+            "memory_config": chatbot.memory_config,
+        },
    }
@@ -101,15 +112,18 @@ async def test_rag_search(
    query: str = "test query",
    top_k: int = 5,
    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Test RAG search for a specific chatbot"""

    # Get chatbot instance
-    chatbot = db.query(ChatbotInstance).filter(
-        ChatbotInstance.id == chatbot_id,
-        ChatbotInstance.user_id == current_user.id
-    ).first()
+    chatbot = (
+        db.query(ChatbotInstance)
+        .filter(
+            ChatbotInstance.id == chatbot_id, ChatbotInstance.user_id == current_user.id
+        )
+        .first()
+    )

    if not chatbot:
        raise HTTPException(status_code=404, detail="Chatbot not found")
@@ -123,6 +137,7 @@ async def test_rag_search(
        if chatbot.rag_collection_ids:
            if isinstance(chatbot.rag_collection_ids, str):
                import json
+
                try:
                    collection_ids = json.loads(chatbot.rag_collection_ids)
                except:
@@ -134,22 +149,19 @@ async def test_rag_search(
            return {
                "query": query,
                "results": [],
-                "message": "No RAG collections configured for this chatbot"
+                "message": "No RAG collections configured for this chatbot",
            }

        # Perform search
        search_results = await rag_module.search(
-            query=query,
-            collection_ids=collection_ids,
-            top_k=top_k,
-            score_threshold=0.5
+            query=query, collection_ids=collection_ids, top_k=top_k, score_threshold=0.5
        )

        return {
            "query": query,
            "results": search_results,
            "collections_searched": collection_ids,
-            "result_count": len(search_results)
+            "result_count": len(search_results),
        }

    except Exception as e:
@@ -157,14 +169,13 @@ async def test_rag_search(
            "query": query,
            "results": [],
            "error": str(e),
-            "message": "RAG search failed"
+            "message": "RAG search failed",
        }


@router.get("/system/status")
async def get_system_status(
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Get system status for debugging"""
@@ -179,11 +190,12 @@ async def get_system_status(
    module_status = {}
    try:
        from app.services.module_manager import module_manager
+
        modules = module_manager.list_modules()
        for module_name, module_info in modules.items():
            module_status[module_name] = {
                "status": module_info.get("status", "unknown"),
-                "enabled": module_info.get("enabled", False)
+                "enabled": module_info.get("enabled", False),
            }
    except Exception as e:
        module_status = {"error": str(e)}
@@ -192,6 +204,7 @@ async def get_system_status(
    redis_status = "not configured"
    try:
        from app.core.cache import core_cache
+
        await core_cache.ping()
        redis_status = "healthy"
    except Exception as e:
@@ -201,6 +214,7 @@ async def get_system_status(
    qdrant_status = "not configured"
    try:
        from app.services.qdrant_service import qdrant_service
+
        collections = await qdrant_service.list_collections()
        qdrant_status = f"healthy ({len(collections)} collections)"
    except Exception as e:
@@ -211,5 +225,5 @@ async def get_system_status(
        "modules": module_status,
        "redis": redis_status,
        "qdrant": qdrant_status,
-        "timestamp": "UTC"
-    }
+        "timestamp": "UTC",
+    }
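Two hunks in this file keep a bare `except:` around `json.loads`; besides hiding real bugs, a bare except also catches `KeyboardInterrupt` and `SystemExit`. A narrower helper is safer — a sketch (the function name is illustrative, not from the codebase):

```python
import json
from typing import Any, List


def parse_collection_ids(raw: Any) -> List[Any]:
    """Accept a JSON-encoded list or an already-decoded list; anything else -> []."""
    if isinstance(raw, list):
        return raw
    try:
        parsed = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        return []
    return parsed if isinstance(parsed, list) else []
```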
@@ -3,6 +3,7 @@ Public API v1 package - for external clients
"""

from fastapi import APIRouter
+from ..v1.auth import router as auth_router
from ..v1.llm import router as llm_router
from ..v1.chatbot import router as chatbot_router
from ..v1.openai_compat import router as openai_router
@@ -10,6 +11,9 @@ from ..v1.openai_compat import router as openai_router
# Create public API router
public_api_router = APIRouter()

+# Include authentication routes (needed for login/logout)
+public_api_router.include_router(auth_router, prefix="/auth", tags=["authentication"])
+
# Include OpenAI-compatible routes (chat/completions, models, embeddings)
public_api_router.include_router(openai_router, tags=["openai-compat"])
@@ -17,4 +21,6 @@ public_api_router.include_router(openai_router, tags=["openai-compat"])
public_api_router.include_router(llm_router, prefix="/llm", tags=["public-llm"])

# Include public chatbot API (external chatbot integrations)
-public_api_router.include_router(chatbot_router, prefix="/chatbot", tags=["public-chatbot"])
+public_api_router.include_router(
+    chatbot_router, prefix="/chatbot", tags=["public-chatbot"]
+)
@@ -16,10 +16,9 @@ logger = logging.getLogger(__name__)
# Create router
router = APIRouter()


@router.get("/collections")
-async def list_collections(
-    current_user: User = Depends(get_current_user)
-):
+async def list_collections(current_user: User = Depends(get_current_user)):
    """List all available RAG collections"""
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service
@@ -31,10 +30,7 @@ async def list_collections(current_user: User = Depends(get_current_user)):
        # Extract collection names
        collection_names = [col["name"] for col in collections]

-        return {
-            "collections": collection_names,
-            "count": len(collection_names)
-        }
+        return {"collections": collection_names, "count": len(collection_names)}

    except Exception as e:
        logger.error(f"List collections error: {e}")
@@ -45,10 +41,14 @@ async def list_collections(current_user: User = Depends(get_current_user)):
async def debug_search(
    query: str = Query(..., description="Search query"),
    max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
-    score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
-    collection_name: Optional[str] = Query(None, description="Collection name to search"),
+    score_threshold: float = Query(
+        0.3, ge=0.0, le=1.0, description="Minimum score threshold"
+    ),
+    collection_name: Optional[str] = Query(
+        None, description="Collection name to search"
+    ),
    config: Optional[Dict[str, Any]] = None,
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Debug search endpoint with detailed information"""
    try:
@@ -56,9 +56,7 @@ async def debug_search(
        app_config = settings

        # Initialize RAG module with BGE-M3 configuration
-        rag_config = {
-            "embedding_model": "BAAI/bge-m3"
-        }
+        rag_config = {"embedding_model": "BAAI/bge-m3"}
        rag_module = RAGModule(app_config, config=rag_config)

        # Get available collections if none specified
@@ -71,9 +69,9 @@ async def debug_search(
                "results": [],
                "debug_info": {
                    "error": "No collections available",
-                    "collections_found": 0
+                    "collections_found": 0,
                },
-                "search_time_ms": 0
+                "search_time_ms": 0,
            }

        # Perform search
@@ -82,7 +80,7 @@ async def debug_search(
            max_results=max_results,
            score_threshold=score_threshold,
            collection_name=collection_name,
-            config=config or {}
+            config=config or {},
        )

        return results
@@ -94,7 +92,7 @@ async def debug_search(
            "debug_info": {
                "error": str(e),
                "query": query,
-                "collection_name": collection_name
+                "collection_name": collection_name,
            },
-            "search_time_ms": 0
-        }
+            "search_time_ms": 0,
+        }
@@ -17,6 +17,9 @@ from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .plugin_registry import router as plugin_registry_router
+from .endpoints.tools import router as tools_router
+from .endpoints.tool_calling import router as tool_calling_router
+from .endpoints.user_management import router as user_management_router

# Create main API router
api_router = APIRouter()
@@ -58,9 +61,23 @@ api_router.include_router(rag_router, prefix="/rag", tags=["rag"])
api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])

# Include prompt template routes
-api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
+api_router.include_router(
+    prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]
+)

# Include plugin registry routes
api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
+api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["plugins"])
+
+# Include tool management routes
+api_router.include_router(tools_router, prefix="/tools", tags=["tools"])
+
+# Include tool calling routes
+api_router.include_router(
+    tool_calling_router, prefix="/tool-calling", tags=["tool-calling"]
+)
+
+# Include admin user management routes
+api_router.include_router(
+    user_management_router, prefix="/admin/user-management", tags=["admin", "user-management"]
+)
@@ -22,110 +22,103 @@ router = APIRouter()
|
||||
async def get_usage_metrics(
|
||||
hours: int = Query(24, ge=1, le=168, description="Hours to analyze (1-168)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get comprehensive usage metrics including costs and budgets"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
metrics = await analytics.get_usage_metrics(hours=hours, user_id=current_user['id'])
|
||||
return {
|
||||
"success": True,
|
||||
"data": metrics,
|
||||
"period_hours": hours
|
||||
}
|
||||
metrics = await analytics.get_usage_metrics(
|
||||
hours=hours, user_id=current_user["id"]
|
||||
)
|
||||
return {"success": True, "data": metrics, "period_hours": hours}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting usage metrics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/system")
|
||||
async def get_system_metrics(
|
||||
hours: int = Query(24, ge=1, le=168),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get system-wide metrics (admin only)"""
|
||||
if not current_user['is_superuser']:
|
||||
if not current_user["is_superuser"]:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
metrics = await analytics.get_usage_metrics(hours=hours)
|
||||
return {
|
||||
"success": True,
|
||||
"data": metrics,
|
||||
"period_hours": hours
|
||||
}
|
||||
return {"success": True, "data": metrics, "period_hours": hours}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system metrics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def get_system_health(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get system health status including budget and performance analysis"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
health = await analytics.get_system_health()
|
||||
return {
|
||||
"success": True,
|
||||
"data": health
|
||||
}
|
||||
return {"success": True, "data": health}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system health: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system health: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs")
|
||||
async def get_cost_analysis(
|
||||
days: int = Query(30, ge=1, le=365, description="Days to analyze (1-365)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get detailed cost analysis and trends"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
analysis = await analytics.get_cost_analysis(days=days, user_id=current_user['id'])
|
||||
return {
|
||||
"success": True,
|
||||
"data": analysis,
|
||||
"period_days": days
|
||||
}
|
||||
analysis = await analytics.get_cost_analysis(
|
||||
days=days, user_id=current_user["id"]
|
||||
)
|
||||
return {"success": True, "data": analysis, "period_days": days}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting cost analysis: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting cost analysis: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/costs/system")
|
||||
async def get_system_cost_analysis(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get system-wide cost analysis (admin only)"""
|
||||
if not current_user['is_superuser']:
|
||||
if not current_user["is_superuser"]:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
analysis = await analytics.get_cost_analysis(days=days)
|
||||
return {
|
||||
"success": True,
|
||||
"data": analysis,
|
||||
"period_days": days
|
||||
}
|
||||
return {"success": True, "data": analysis, "period_days": days}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting system cost analysis: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting system cost analysis: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/endpoints")
|
||||
async def get_endpoint_stats(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get endpoint usage statistics"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
|
||||
|
||||
# For now, return the in-memory stats
|
||||
# In future, this could be enhanced with database queries
|
||||
return {
|
||||
@@ -133,74 +126,79 @@ async def get_endpoint_stats(
|
||||
"data": {
|
||||
"endpoint_stats": dict(analytics.endpoint_stats),
|
||||
"status_codes": dict(analytics.status_codes),
|
||||
"model_stats": dict(analytics.model_stats)
|
||||
}
|
||||
"model_stats": dict(analytics.model_stats),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting endpoint stats: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting endpoint stats: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage-trends")
|
||||
async def get_usage_trends(
|
||||
days: int = Query(7, ge=1, le=30, description="Days for trend analysis"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get usage trends over time"""
|
||||
try:
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import func
|
||||
from app.models.usage_tracking import UsageTracking
|
||||
|
||||
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
|
||||
# Daily usage trends
|
||||
daily_usage = db.query(
|
||||
func.date(UsageTracking.created_at).label('date'),
|
||||
func.count(UsageTracking.id).label('requests'),
|
||||
func.sum(UsageTracking.total_tokens).label('tokens'),
|
||||
func.sum(UsageTracking.cost_cents).label('cost_cents')
|
||||
).filter(
|
||||
UsageTracking.created_at >= cutoff_time,
|
||||
UsageTracking.user_id == current_user['id']
|
||||
).group_by(
|
||||
func.date(UsageTracking.created_at)
|
||||
).order_by('date').all()
|
||||
|
||||
daily_usage = (
|
||||
db.query(
|
||||
func.date(UsageTracking.created_at).label("date"),
|
||||
func.count(UsageTracking.id).label("requests"),
|
||||
func.sum(UsageTracking.total_tokens).label("tokens"),
|
||||
func.sum(UsageTracking.cost_cents).label("cost_cents"),
|
||||
)
|
||||
.filter(
|
||||
UsageTracking.created_at >= cutoff_time,
|
||||
UsageTracking.user_id == current_user["id"],
|
||||
)
|
||||
.group_by(func.date(UsageTracking.created_at))
|
||||
.order_by("date")
|
||||
.all()
|
||||
)
|
||||
|
||||
trends = []
|
||||
for date, requests, tokens, cost_cents in daily_usage:
|
||||
trends.append({
|
||||
"date": date.isoformat(),
|
||||
"requests": requests,
|
||||
"tokens": tokens or 0,
|
||||
"cost_cents": cost_cents or 0,
|
||||
"cost_dollars": (cost_cents or 0) / 100
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"trends": trends,
|
||||
"period_days": days
|
||||
}
|
||||
}
|
||||
trends.append(
|
||||
{
|
||||
"date": date.isoformat(),
|
||||
"requests": requests,
|
||||
"tokens": tokens or 0,
|
||||
"cost_cents": cost_cents or 0,
|
||||
"cost_dollars": (cost_cents or 0) / 100,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "data": {"trends": trends, "period_days": days}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage trends: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting usage trends: {str(e)}"
|
||||
)
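# Illustrative /usage-trends response (values invented; keys follow the dicts
# built above):
#   {"success": true, "data": {"trends": [{"date": "2025-01-30", "requests": 42,
#    "tokens": 1200, "cost_cents": 35, "cost_dollars": 0.35}], "period_days": 7}}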
@router.get("/overview")
|
||||
async def get_analytics_overview(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get analytics overview data"""
|
||||
try:
|
||||
analytics = get_analytics_service()
|
||||
|
||||
|
||||
# Get basic metrics
|
||||
metrics = await analytics.get_usage_metrics(hours=24, user_id=current_user['id'])
|
||||
metrics = await analytics.get_usage_metrics(
|
||||
hours=24, user_id=current_user["id"]
|
||||
)
|
||||
health = await analytics.get_system_health()
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
@@ -210,8 +208,8 @@ async def get_analytics_overview(
|
||||
"error_rate": metrics.error_rate,
|
||||
"budget_usage_percentage": metrics.budget_usage_percentage,
|
||||
"system_health": health.status,
|
||||
"health_score": health.score
|
||||
}
|
||||
"health_score": health.score,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting overview: {str(e)}")
|
||||
@@ -219,16 +217,15 @@ async def get_analytics_overview(
|
||||
|
||||
@router.get("/modules")
|
||||
async def get_module_analytics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get analytics data for all modules"""
|
||||
try:
|
||||
module_stats = []
|
||||
|
||||
|
||||
for name, module in module_manager.modules.items():
|
||||
stats = {"name": name, "initialized": getattr(module, "initialized", False)}
|
||||
|
||||
|
||||
# Get module statistics if available
|
||||
if hasattr(module, "get_stats"):
|
||||
try:
|
||||
@@ -240,18 +237,22 @@ async def get_module_analytics(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get stats for module {name}: {e}")
|
||||
stats["error"] = str(e)
|
||||
|
||||
|
||||
module_stats.append(stats)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"modules": module_stats,
|
||||
"total_modules": len(module_stats),
|
||||
"system_health": "healthy" if all(m.get("initialized", False) for m in module_stats) else "warning"
|
||||
}
|
||||
"system_health": "healthy"
|
||||
if all(m.get("initialized", False) for m in module_stats)
|
||||
else "warning",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get module analytics: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve module analytics")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve module analytics"
|
||||
)
|
||||
|
||||
@@ -74,47 +74,49 @@ class APIKeyResponse(BaseModel):
    last_used_at: Optional[datetime] = None
    total_requests: int
    total_tokens: int
    total_cost_cents: int = Field(alias='total_cost')
    total_cost_cents: int = Field(alias="total_cost")
    rate_limit_per_minute: Optional[int] = None
    rate_limit_per_hour: Optional[int] = None
    rate_limit_per_day: Optional[int] = None
    allowed_ips: List[str]
    allowed_models: List[str]  # Model restrictions
    allowed_chatbots: List[str]  # Chatbot restrictions
    budget_limit: Optional[int] = Field(None, alias='budget_limit_cents')  # Budget limit in cents
    budget_limit: Optional[int] = Field(
        None, alias="budget_limit_cents"
    )  # Budget limit in cents
    budget_type: Optional[str] = None  # Budget type
    is_unlimited: bool = True  # Unlimited budget flag
    tags: List[str]

    class Config:
        from_attributes = True

    @classmethod
    def from_api_key(cls, api_key):
        """Create response from APIKey model with formatted key prefix"""
        data = {
            'id': api_key.id,
            'name': api_key.name,
            'description': api_key.description,
            'key_prefix': api_key.key_prefix + "..." if api_key.key_prefix else "",
            'scopes': api_key.scopes,
            'is_active': api_key.is_active,
            'expires_at': api_key.expires_at,
            'created_at': api_key.created_at,
            'last_used_at': api_key.last_used_at,
            'total_requests': api_key.total_requests,
            'total_tokens': api_key.total_tokens,
            'total_cost': api_key.total_cost,
            'rate_limit_per_minute': api_key.rate_limit_per_minute,
            'rate_limit_per_hour': api_key.rate_limit_per_hour,
            'rate_limit_per_day': api_key.rate_limit_per_day,
            'allowed_ips': api_key.allowed_ips,
            'allowed_models': api_key.allowed_models,
            'allowed_chatbots': api_key.allowed_chatbots,
            'budget_limit_cents': api_key.budget_limit_cents,
            'budget_type': api_key.budget_type,
            'is_unlimited': api_key.is_unlimited,
            'tags': api_key.tags
            "id": api_key.id,
            "name": api_key.name,
            "description": api_key.description,
            "key_prefix": api_key.key_prefix + "..." if api_key.key_prefix else "",
            "scopes": api_key.scopes,
            "is_active": api_key.is_active,
            "expires_at": api_key.expires_at,
            "created_at": api_key.created_at,
            "last_used_at": api_key.last_used_at,
            "total_requests": api_key.total_requests,
            "total_tokens": api_key.total_tokens,
            "total_cost": api_key.total_cost,
            "rate_limit_per_minute": api_key.rate_limit_per_minute,
            "rate_limit_per_hour": api_key.rate_limit_per_hour,
            "rate_limit_per_day": api_key.rate_limit_per_day,
            "allowed_ips": api_key.allowed_ips,
            "allowed_models": api_key.allowed_models,
            "allowed_chatbots": api_key.allowed_chatbots,
            "budget_limit_cents": api_key.budget_limit_cents,
            "budget_type": api_key.budget_type,
            "is_unlimited": api_key.is_unlimited,
            "tags": api_key.tags,
        }
        return cls(**data)
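    # Note: the dict keys above are the Pydantic *aliases* ("total_cost",
    # "budget_limit_cents"), not the field names, so cls(**data) populates
    # total_cost_cents and budget_limit through their Field(alias=...) mappings.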

@@ -148,15 +150,18 @@ class APIKeyUsageResponse(BaseModel):
def generate_api_key() -> tuple[str, str]:
    """Generate a new API key and return (full_key, key_hash)"""
    # Generate random key part (32 characters)
    key_part = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))

    key_part = "".join(
        secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
    )

    # Create full key with prefix
    full_key = f"{settings.API_KEY_PREFIX}{key_part}"

    # Create hash for storage
    from app.core.security import get_api_key_hash

    key_hash = get_api_key_hash(full_key)

    return full_key, key_hash
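# Verification sketch (assumption: get_api_key_hash is deterministic, so an
# incoming key can be matched by prefix and then by hash; names other than
# get_api_key_hash are hypothetical):
#
#     async def find_api_key(db, presented_key: str):
#         prefix = presented_key[:8]
#         candidate = await fetch_by_prefix(db, prefix)  # hypothetical helper
#         if candidate and candidate.key_hash == get_api_key_hash(presented_key):
#             return candidate
#         return None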

@@ -169,73 +174,87 @@ async def list_api_keys(
    is_active: Optional[bool] = Query(None),
    search: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """List API keys with pagination and filtering"""

    # Check permissions - users can view their own API keys
    if user_id and int(user_id) != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")
    if user_id and int(user_id) != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:read"
        )
    elif not user_id:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:read"
        )

    # If no user_id specified and user doesn't have admin permissions, show only their keys
    if not user_id and "platform:api-keys:read" not in current_user.get("permissions", []):
        user_id = current_user['id']
    if not user_id and "platform:api-keys:read" not in current_user.get(
        "permissions", []
    ):
        user_id = current_user["id"]

    # Build query
    query = select(APIKey)

    # Apply filters
    if user_id:
        query = query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
        query = query.where(
            APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
        )
    if is_active is not None:
        query = query.where(APIKey.is_active == is_active)
    if search:
        query = query.where(
            (APIKey.name.ilike(f"%{search}%")) |
            (APIKey.description.ilike(f"%{search}%"))
            (APIKey.name.ilike(f"%{search}%"))
            | (APIKey.description.ilike(f"%{search}%"))
        )

    # Get total count using func.count()
    total_query = select(func.count(APIKey.id))

    # Apply same filters for count
    if user_id:
        total_query = total_query.where(APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
        total_query = total_query.where(
            APIKey.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
        )
    if is_active is not None:
        total_query = total_query.where(APIKey.is_active == is_active)
    if search:
        total_query = total_query.where(
            (APIKey.name.ilike(f"%{search}%")) |
            (APIKey.description.ilike(f"%{search}%"))
            (APIKey.name.ilike(f"%{search}%"))
            | (APIKey.description.ilike(f"%{search}%"))
        )

    total_result = await db.execute(total_query)
    total = total_result.scalar()

    # Apply pagination
    offset = (page - 1) * size
    query = query.offset(offset).limit(size).order_by(APIKey.created_at.desc())

    # Execute query
    result = await db.execute(query)
    api_keys = result.scalars().all()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="list_api_keys",
        resource_type="api_key",
        details={"page": page, "size": size, "filters": {"user_id": user_id, "is_active": is_active, "search": search}}
        details={
            "page": page,
            "size": size,
            "filters": {"user_id": user_id, "is_active": is_active, "search": search},
        },
    )

    return APIKeyListResponse(
        api_keys=[APIKeyResponse.model_validate(key) for key in api_keys],
        total=total,
        page=page,
        size=size
        size=size,
    )

@@ -243,34 +262,35 @@ async def list_api_keys(
async def get_api_key(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get API key by ID"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can view their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:read"
        )

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="get_api_key",
        resource_type="api_key",
        resource_id=api_key_id
        resource_id=api_key_id,
    )

    return APIKeyResponse.model_validate(api_key)

@@ -278,24 +298,24 @@ async def get_api_key(
async def create_api_key(
    api_key_data: APIKeyCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Create a new API key"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:api-keys:create")

    # Generate API key
    full_key, key_hash = generate_api_key()
    key_prefix = full_key[:8]  # Store only first 8 characters for lookup

    # Create API key
    new_api_key = APIKey(
        name=api_key_data.name,
        description=api_key_data.description,
        key_hash=key_hash,
        key_prefix=key_prefix,
        user_id=current_user['id'],
        user_id=current_user["id"],
        scopes=api_key_data.scopes,
        expires_at=api_key_data.expires_at,
        rate_limit_per_minute=api_key_data.rate_limit_per_minute,
@@ -305,29 +325,32 @@ async def create_api_key(
        allowed_models=api_key_data.allowed_models,
        allowed_chatbots=api_key_data.allowed_chatbots,
        is_unlimited=api_key_data.is_unlimited,
        budget_limit_cents=api_key_data.budget_limit_cents if not api_key_data.is_unlimited else None,
        budget_limit_cents=api_key_data.budget_limit_cents
        if not api_key_data.is_unlimited
        else None,
        budget_type=api_key_data.budget_type if not api_key_data.is_unlimited else None,
        tags=api_key_data.tags
        tags=api_key_data.tags,
    )

    db.add(new_api_key)
    await db.commit()
    await db.refresh(new_api_key)

    # Log audit event asynchronously (non-blocking)
    asyncio.create_task(log_audit_event_async(
        user_id=str(current_user['id']),
        action="create_api_key",
        resource_type="api_key",
        resource_id=str(new_api_key.id),
        details={"name": api_key_data.name, "scopes": api_key_data.scopes}
    ))

    asyncio.create_task(
        log_audit_event_async(
            user_id=str(current_user["id"]),
            action="create_api_key",
            resource_type="api_key",
            resource_id=str(new_api_key.id),
            details={"name": api_key_data.name, "scopes": api_key_data.scopes},
        )
    )
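    # Caveat (general asyncio behavior, not specific to this codebase): a task
    # created with asyncio.create_task() and not stored anywhere may be
    # garbage-collected before it finishes; the asyncio docs recommend keeping
    # a reference (e.g. in a set) until the task completes.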

    logger.info(f"API key created: {new_api_key.name} by {current_user['username']}")

    return APIKeyCreateResponse(
        api_key=APIKeyResponse.model_validate(new_api_key),
        secret_key=full_key
        api_key=APIKeyResponse.model_validate(new_api_key), secret_key=full_key
    )

@@ -336,56 +359,57 @@ async def update_api_key(
    api_key_id: str,
    api_key_data: APIKeyUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Update API key"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can update their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:update")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:update"
        )

    # Store original values for audit
    original_values = {
        "name": api_key.name,
        "scopes": api_key.scopes,
        "is_active": api_key.is_active
        "is_active": api_key.is_active,
    }

    # Update API key fields
    update_data = api_key_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(api_key, field, value)

    await db.commit()
    await db.refresh(api_key)

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="update_api_key",
        resource_type="api_key",
        resource_id=api_key_id,
        details={
            "updated_fields": list(update_data.keys()),
            "before_values": original_values,
            "after_values": {k: getattr(api_key, k) for k in update_data.keys()}
        }
            "after_values": {k: getattr(api_key, k) for k in update_data.keys()},
        },
    )

    logger.info(f"API key updated: {api_key.name} by {current_user['username']}")

    return APIKeyResponse.model_validate(api_key)

@@ -393,41 +417,42 @@ async def update_api_key(
async def delete_api_key(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Delete API key"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can delete their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:delete")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:delete"
        )

    # Delete API key
    await db.delete(api_key)
    await db.commit()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="delete_api_key",
        resource_type="api_key",
        resource_id=api_key_id,
        details={"name": api_key.name}
        details={"name": api_key.name},
    )

    logger.info(f"API key deleted: {api_key.name} by {current_user['username']}")

    return {"message": "API key deleted successfully"}

@@ -435,51 +460,51 @@ async def delete_api_key(
async def regenerate_api_key(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Regenerate API key secret"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can regenerate their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:update")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:update"
        )

    # Generate new API key
    full_key, key_hash = generate_api_key()
    key_prefix = full_key[:8]  # Store only first 8 characters for lookup

    # Update API key
    api_key.key_hash = key_hash
    api_key.key_prefix = key_prefix

    await db.commit()
    await db.refresh(api_key)

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="regenerate_api_key",
        resource_type="api_key",
        resource_id=api_key_id,
        details={"name": api_key.name}
        details={"name": api_key.name},
    )

    logger.info(f"API key regenerated: {api_key.name} by {current_user['username']}")

    return APIKeyCreateResponse(
        api_key=APIKeyResponse.model_validate(api_key),
        secret_key=full_key
        api_key=APIKeyResponse.model_validate(api_key), secret_key=full_key
    )

@@ -487,65 +512,64 @@ async def regenerate_api_key(
async def get_api_key_usage(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get API key usage statistics"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can view their own API key usage
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:read"
        )

    # Calculate usage statistics
    from app.models.usage_tracking import UsageTracking

    now = datetime.utcnow()
    today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    hour_start = now.replace(minute=0, second=0, microsecond=0)

    # Today's usage
    today_query = select(
        func.count(UsageTracking.id),
        func.sum(UsageTracking.total_tokens),
        func.sum(UsageTracking.cost_cents)
        func.sum(UsageTracking.cost_cents),
    ).where(
        UsageTracking.api_key_id == api_key_id,
        UsageTracking.created_at >= today_start
        UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= today_start
    )
    today_result = await db.execute(today_query)
    today_stats = today_result.first()

    # This hour's usage
    hour_query = select(
        func.count(UsageTracking.id),
        func.sum(UsageTracking.total_tokens),
        func.sum(UsageTracking.cost_cents)
        func.sum(UsageTracking.cost_cents),
    ).where(
        UsageTracking.api_key_id == api_key_id,
        UsageTracking.created_at >= hour_start
        UsageTracking.api_key_id == api_key_id, UsageTracking.created_at >= hour_start
    )
    hour_result = await db.execute(hour_query)
    hour_stats = hour_result.first()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="get_api_key_usage",
        resource_type="api_key",
        resource_id=api_key_id
        resource_id=api_key_id,
    )

    return APIKeyUsageResponse(
        api_key_id=api_key_id,
        total_requests=api_key.total_requests,
@@ -557,7 +581,7 @@ async def get_api_key_usage(
        requests_this_hour=hour_stats[0] or 0,
        tokens_this_hour=hour_stats[1] or 0,
        cost_this_hour_cents=hour_stats[2] or 0,
        last_used_at=api_key.last_used_at
        last_used_at=api_key.last_used_at,
    )

@@ -565,41 +589,42 @@ async def get_api_key_usage(
async def activate_api_key(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Activate API key"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can activate their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:update")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:update"
        )

    # Activate API key
    api_key.is_active = True
    await db.commit()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="activate_api_key",
        resource_type="api_key",
        resource_id=api_key_id,
        details={"name": api_key.name}
        details={"name": api_key.name},
    )

    logger.info(f"API key activated: {api_key.name} by {current_user['username']}")

    return {"message": "API key activated successfully"}

@@ -607,39 +632,40 @@ async def activate_api_key(
async def deactivate_api_key(
    api_key_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Deactivate API key"""

    # Get API key
    query = select(APIKey).where(APIKey.id == int(api_key_id))
    result = await db.execute(query)
    api_key = result.scalar_one_or_none()

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
        )

    # Check permissions - users can deactivate their own API keys
    if api_key.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:api-keys:update")

    if api_key.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:api-keys:update"
        )

    # Deactivate API key
    api_key.is_active = False
    await db.commit()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="deactivate_api_key",
        resource_type="api_key",
        resource_id=api_key_id,
        details={"name": api_key.name}
        details={"name": api_key.name},
    )

    logger.info(f"API key deactivated: {api_key.name} by {current_user['username']}")

    return {"message": "API key deactivated successfully"}

@@ -36,7 +36,7 @@ class AuditLogResponse(BaseModel):
    success: bool
    severity: str
    created_at: datetime

    class Config:
        from_attributes = True

@@ -96,17 +96,17 @@ async def list_audit_logs(
    severity: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """List audit logs with filtering and pagination"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:audit:read")

    # Build query
    query = select(AuditLog)
    conditions = []

    # Apply filters
    if user_id:
        conditions.append(AuditLog.user_id == user_id)
@@ -128,29 +128,29 @@ async def list_audit_logs(
        search_conditions = [
            AuditLog.action.ilike(f"%{search}%"),
            AuditLog.resource_type.ilike(f"%{search}%"),
            AuditLog.details.astext.ilike(f"%{search}%")
            AuditLog.details.astext.ilike(f"%{search}%"),
        ]
        conditions.append(or_(*search_conditions))

    if conditions:
        query = query.where(and_(*conditions))

    # Get total count
    count_query = select(func.count(AuditLog.id))
    if conditions:
        count_query = count_query.where(and_(*conditions))

    total_result = await db.execute(count_query)
    total = total_result.scalar()

    # Apply pagination and ordering
    offset = (page - 1) * size
    query = query.offset(offset).limit(size).order_by(AuditLog.created_at.desc())

    # Execute query
    result = await db.execute(query)
    logs = result.scalars().all()

    # Log audit event for this query
    await log_audit_event(
        db=db,
@@ -166,19 +166,19 @@ async def list_audit_logs(
                "end_date": end_date.isoformat() if end_date else None,
                "success": success,
                "severity": severity,
                "search": search
                "search": search,
            },
            "page": page,
            "size": size,
            "total_results": total
        }
            "total_results": total,
        },
    )

    return AuditLogListResponse(
        logs=[AuditLogResponse.model_validate(log) for log in logs],
        total=total,
        page=page,
        size=size
        size=size,
    )

@@ -188,13 +188,13 @@ async def search_audit_logs(
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=1000),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Advanced search for audit logs"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:audit:read")

    # Use the audit service function
    logs = await get_audit_logs(
        db=db,
@@ -205,13 +205,13 @@ async def search_audit_logs(
        start_date=search_request.start_date,
        end_date=search_request.end_date,
        limit=size,
        offset=(page - 1) * size
        offset=(page - 1) * size,
    )

    # Get total count for the search
    total_query = select(func.count(AuditLog.id))
    conditions = []

    if search_request.user_id:
        conditions.append(AuditLog.user_id == search_request.user_id)
    if search_request.action:
@@ -234,16 +234,16 @@ async def search_audit_logs(
        search_conditions = [
            AuditLog.action.ilike(f"%{search_request.search_text}%"),
            AuditLog.resource_type.ilike(f"%{search_request.search_text}%"),
            AuditLog.details.astext.ilike(f"%{search_request.search_text}%")
            AuditLog.details.astext.ilike(f"%{search_request.search_text}%"),
        ]
        conditions.append(or_(*search_conditions))

    if conditions:
        total_query = total_query.where(and_(*conditions))

    total_result = await db.execute(total_query)
    total = total_result.scalar()

    # Log audit event
    await log_audit_event(
        db=db,
@@ -253,15 +253,15 @@ async def search_audit_logs(
        details={
            "search_criteria": search_request.model_dump(exclude_unset=True),
            "results_count": len(logs),
            "total_matches": total
        }
            "total_matches": total,
        },
    )

    return AuditLogListResponse(
        logs=[AuditLogResponse.model_validate(log) for log in logs],
        total=total,
        page=page,
        size=size
        size=size,
    )

@@ -270,64 +270,80 @@ async def get_audit_statistics(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get audit log statistics"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:audit:read")

    # Default to last 30 days if no dates provided
    if not end_date:
        end_date = datetime.utcnow()
    if not start_date:
        start_date = end_date - timedelta(days=30)

    # Get basic stats using audit service
    basic_stats = await get_audit_stats(db, start_date, end_date)

    # Get additional statistics
    conditions = [
        AuditLog.created_at >= start_date,
        AuditLog.created_at <= end_date
    ]

    conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]

    # Events by user
    user_query = select(
        AuditLog.user_id,
        func.count(AuditLog.id).label('count')
    ).where(and_(*conditions)).group_by(AuditLog.user_id).order_by(func.count(AuditLog.id).desc()).limit(10)

    user_query = (
        select(AuditLog.user_id, func.count(AuditLog.id).label("count"))
        .where(and_(*conditions))
        .group_by(AuditLog.user_id)
        .order_by(func.count(AuditLog.id).desc())
        .limit(10)
    )

    user_result = await db.execute(user_query)
    events_by_user = dict(user_result.fetchall())
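    # dict() over fetchall() works here because each row is a (user_id, count)
    # pair; the same pattern is used for events_by_hour below.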

    # Events by hour of day
    hour_query = select(
        func.extract('hour', AuditLog.created_at).label('hour'),
        func.count(AuditLog.id).label('count')
    ).where(and_(*conditions)).group_by(func.extract('hour', AuditLog.created_at)).order_by('hour')

    hour_query = (
        select(
            func.extract("hour", AuditLog.created_at).label("hour"),
            func.count(AuditLog.id).label("count"),
        )
        .where(and_(*conditions))
        .group_by(func.extract("hour", AuditLog.created_at))
        .order_by("hour")
    )

    hour_result = await db.execute(hour_query)
    events_by_hour = dict(hour_result.fetchall())
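    # Note: func.extract("hour", ...) yields a numeric hour (0-23); on some
    # drivers it comes back as a Decimal, so the keys of events_by_hour may
    # need an int() cast before JSON serialization (an assumption worth
    # verifying against the deployed driver).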

    # Top actions
    top_actions_query = select(
        AuditLog.action,
        func.count(AuditLog.id).label('count')
    ).where(and_(*conditions)).group_by(AuditLog.action).order_by(func.count(AuditLog.id).desc()).limit(10)

    top_actions_query = (
        select(AuditLog.action, func.count(AuditLog.id).label("count"))
        .where(and_(*conditions))
        .group_by(AuditLog.action)
        .order_by(func.count(AuditLog.id).desc())
        .limit(10)
    )

    top_actions_result = await db.execute(top_actions_query)
    top_actions = [{"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()]

    top_actions = [
        {"action": row[0], "count": row[1]} for row in top_actions_result.fetchall()
    ]

    # Top resources
    top_resources_query = select(
        AuditLog.resource_type,
        func.count(AuditLog.id).label('count')
    ).where(and_(*conditions)).group_by(AuditLog.resource_type).order_by(func.count(AuditLog.id).desc()).limit(10)

    top_resources_query = (
        select(AuditLog.resource_type, func.count(AuditLog.id).label("count"))
        .where(and_(*conditions))
        .group_by(AuditLog.resource_type)
        .order_by(func.count(AuditLog.id).desc())
        .limit(10)
    )

    top_resources_result = await db.execute(top_resources_query)
    top_resources = [{"resource_type": row[0], "count": row[1]} for row in top_resources_result.fetchall()]

    top_resources = [
        {"resource_type": row[0], "count": row[1]}
        for row in top_resources_result.fetchall()
    ]

    # Log audit event
    await log_audit_event(
        db=db,
@@ -337,16 +353,16 @@ async def get_audit_statistics(
        details={
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
            "total_events": basic_stats["total_events"]
        }
            "total_events": basic_stats["total_events"],
        },
    )

    return AuditStatsResponse(
        **basic_stats,
        events_by_user=events_by_user,
        events_by_hour=events_by_hour,
        top_actions=top_actions,
        top_resources=top_resources
        top_resources=top_resources,
    )

@@ -354,25 +370,30 @@ async def get_audit_statistics(
async def get_security_events(
    hours: int = Query(24, ge=1, le=168),  # Last 24 hours by default, max 1 week
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get security-related events and anomalies"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:audit:read")

    end_time = datetime.utcnow()
    start_time = end_time - timedelta(hours=hours)

    # Failed logins
    failed_logins_query = select(AuditLog).where(
        and_(
            AuditLog.created_at >= start_time,
            AuditLog.action == "login",
            AuditLog.success == False
    failed_logins_query = (
        select(AuditLog)
        .where(
            and_(
                AuditLog.created_at >= start_time,
                AuditLog.action == "login",
                AuditLog.success == False,
            )
        )
    ).order_by(AuditLog.created_at.desc()).limit(50)

        .order_by(AuditLog.created_at.desc())
        .limit(50)
    )

    failed_logins_result = await db.execute(failed_logins_query)
    failed_logins = [
        {
@@ -380,19 +401,24 @@ async def get_security_events(
            "user_id": log.user_id,
            "ip_address": log.ip_address,
            "user_agent": log.user_agent,
            "details": log.details
            "details": log.details,
        }
        for log in failed_logins_result.scalars().all()
    ]

    # High severity events
    high_severity_query = select(AuditLog).where(
        and_(
            AuditLog.created_at >= start_time,
            AuditLog.severity.in_(["error", "critical"])
    high_severity_query = (
        select(AuditLog)
        .where(
            and_(
                AuditLog.created_at >= start_time,
                AuditLog.severity.in_(["error", "critical"]),
            )
        )
    ).order_by(AuditLog.created_at.desc()).limit(50)

        .order_by(AuditLog.created_at.desc())
        .limit(50)
    )

    high_severity_result = await db.execute(high_severity_query)
    high_severity_events = [
        {
@@ -403,56 +429,65 @@ async def get_security_events(
            "user_id": log.user_id,
            "ip_address": log.ip_address,
            "success": log.success,
            "details": log.details
            "details": log.details,
        }
        for log in high_severity_result.scalars().all()
    ]

    # Suspicious activities (multiple failed attempts from same IP)
    suspicious_ips_query = select(
        AuditLog.ip_address,
        func.count(AuditLog.id).label('failed_count')
    ).where(
        and_(
            AuditLog.created_at >= start_time,
            AuditLog.success == False,
            AuditLog.ip_address.isnot(None)
    suspicious_ips_query = (
        select(AuditLog.ip_address, func.count(AuditLog.id).label("failed_count"))
        .where(
            and_(
                AuditLog.created_at >= start_time,
                AuditLog.success == False,
                AuditLog.ip_address.isnot(None),
            )
        )
    ).group_by(AuditLog.ip_address).having(func.count(AuditLog.id) >= 5).order_by(func.count(AuditLog.id).desc())

        .group_by(AuditLog.ip_address)
        .having(func.count(AuditLog.id) >= 5)
        .order_by(func.count(AuditLog.id).desc())
    )

    suspicious_ips_result = await db.execute(suspicious_ips_query)
    suspicious_activities = [
        {
            "ip_address": row[0],
            "failed_attempts": row[1],
            "risk_level": "high" if row[1] >= 10 else "medium"
            "risk_level": "high" if row[1] >= 10 else "medium",
        }
        for row in suspicious_ips_result.fetchall()
    ]

    # Unusual access patterns (users accessing from multiple IPs)
    unusual_access_query = select(
        AuditLog.user_id,
        func.count(func.distinct(AuditLog.ip_address)).label('ip_count'),
        func.array_agg(func.distinct(AuditLog.ip_address)).label('ip_addresses')
    ).where(
        and_(
            AuditLog.created_at >= start_time,
            AuditLog.user_id.isnot(None),
            AuditLog.ip_address.isnot(None)
    unusual_access_query = (
        select(
            AuditLog.user_id,
            func.count(func.distinct(AuditLog.ip_address)).label("ip_count"),
            func.array_agg(func.distinct(AuditLog.ip_address)).label("ip_addresses"),
        )
    ).group_by(AuditLog.user_id).having(func.count(func.distinct(AuditLog.ip_address)) >= 3).order_by(func.count(func.distinct(AuditLog.ip_address)).desc())

        .where(
            and_(
                AuditLog.created_at >= start_time,
                AuditLog.user_id.isnot(None),
                AuditLog.ip_address.isnot(None),
            )
        )
        .group_by(AuditLog.user_id)
        .having(func.count(func.distinct(AuditLog.ip_address)) >= 3)
        .order_by(func.count(func.distinct(AuditLog.ip_address)).desc())
    )
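    # Portability note: func.array_agg is a PostgreSQL aggregate, so this query
    # assumes a Postgres backend; on other databases the distinct IPs would
    # have to be collected in Python after a plain grouped query.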

    unusual_access_result = await db.execute(unusual_access_query)
    unusual_access_patterns = [
        {
            "user_id": row[0],
            "unique_ips": row[1],
            "ip_addresses": row[2] if row[2] else []
            "ip_addresses": row[2] if row[2] else [],
        }
        for row in unusual_access_result.fetchall()
    ]

    # Log audit event
    await log_audit_event(
        db=db,
@@ -464,15 +499,15 @@ async def get_security_events(
            "failed_logins_count": len(failed_logins),
            "high_severity_count": len(high_severity_events),
            "suspicious_ips_count": len(suspicious_activities),
            "unusual_access_patterns_count": len(unusual_access_patterns)
        }
            "unusual_access_patterns_count": len(unusual_access_patterns),
        },
    )

    return SecurityEventsResponse(
        suspicious_activities=suspicious_activities,
        failed_logins=failed_logins,
        unusual_access_patterns=unusual_access_patterns,
        high_severity_events=high_severity_events
        high_severity_events=high_severity_events,
    )

@@ -485,42 +520,43 @@ async def export_audit_logs(
    action: Optional[str] = Query(None),
    resource_type: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Export audit logs in CSV or JSON format"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:audit:export")

    # Default to last 30 days if no dates provided
    if not end_date:
        end_date = datetime.utcnow()
    if not start_date:
        start_date = end_date - timedelta(days=30)

    # Limit export size
    max_records = 10000

    # Build query
    query = select(AuditLog)
    conditions = [
        AuditLog.created_at >= start_date,
        AuditLog.created_at <= end_date
    ]

    conditions = [AuditLog.created_at >= start_date, AuditLog.created_at <= end_date]

    if user_id:
        conditions.append(AuditLog.user_id == user_id)
    if action:
        conditions.append(AuditLog.action == action)
    if resource_type:
        conditions.append(AuditLog.resource_type == resource_type)

    query = query.where(and_(*conditions)).order_by(AuditLog.created_at.desc()).limit(max_records)

    query = (
        query.where(and_(*conditions))
        .order_by(AuditLog.created_at.desc())
        .limit(max_records)
    )

    # Execute query
    result = await db.execute(query)
    logs = result.scalars().all()

    # Log export event
    await log_audit_event(
        db=db,
@@ -535,13 +571,14 @@ async def export_audit_logs(
            "filters": {
                "user_id": user_id,
                "action": action,
                "resource_type": resource_type
            }
        }
                "resource_type": resource_type,
            },
        },
    )

    if format == "json":
        from fastapi.responses import JSONResponse

        export_data = [
            {
                "id": str(log.id),
@@ -554,45 +591,59 @@ async def export_audit_logs(
                "user_agent": log.user_agent,
                "success": log.success,
                "severity": log.severity,
                "created_at": log.created_at.isoformat()
                "created_at": log.created_at.isoformat(),
            }
            for log in logs
        ]
        return JSONResponse(content=export_data)

    else:  # CSV format
        import csv
        import io
        from fastapi.responses import StreamingResponse

        output = io.StringIO()
        writer = csv.writer(output)

        # Write header
        writer.writerow([
            "ID", "User ID", "Action", "Resource Type", "Resource ID",
            "IP Address", "Success", "Severity", "Created At", "Details"
        ])

        writer.writerow(
            [
                "ID",
                "User ID",
                "Action",
                "Resource Type",
                "Resource ID",
                "IP Address",
                "Success",
                "Severity",
                "Created At",
                "Details",
            ]
        )

        # Write data
        for log in logs:
            writer.writerow([
                str(log.id),
                log.user_id or "",
                log.action,
                log.resource_type,
                log.resource_id or "",
                log.ip_address or "",
                log.success,
                log.severity,
                log.created_at.isoformat(),
                str(log.details)
            ])

            writer.writerow(
                [
                    str(log.id),
                    log.user_id or "",
                    log.action,
                    log.resource_type,
                    log.resource_id or "",
                    log.ip_address or "",
                    log.success,
                    log.severity,
                    log.created_at.isoformat(),
                    str(log.details),
                ]
            )

        output.seek(0)

        return StreamingResponse(
            io.BytesIO(output.getvalue().encode()),
            media_type="text/csv",
            headers={"Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"}
        )
            headers={
                "Content-Disposition": f"attachment; filename=audit_logs_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}.csv"
            },
        )
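        # Illustrative client usage (the exact path prefix is an assumption
        # based on the router, not verified against the deployed app):
        #     curl -H "Authorization: Bearer $TOKEN" \
        #          "https://<host>/.../audit/export?format=csv" -o audit_logs.csv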

@@ -9,6 +9,7 @@ from fastapi.security import HTTPBearer
from pydantic import BaseModel, EmailStr, validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.core.config import settings
from app.core.logging import get_logger
@@ -38,25 +39,25 @@ class UserRegisterRequest(BaseModel):
    password: str
    first_name: Optional[str] = None
    last_name: Optional[str] = None

    @validator('password')

    @validator("password")
    def validate_password(cls, v):
        if len(v) < 8:
            raise ValueError('Password must be at least 8 characters long')
            raise ValueError("Password must be at least 8 characters long")
        if not any(c.isupper() for c in v):
            raise ValueError('Password must contain at least one uppercase letter')
            raise ValueError("Password must contain at least one uppercase letter")
        if not any(c.islower() for c in v):
            raise ValueError('Password must contain at least one lowercase letter')
            raise ValueError("Password must contain at least one lowercase letter")
        if not any(c.isdigit() for c in v):
            raise ValueError('Password must contain at least one digit')
            raise ValueError("Password must contain at least one digit")
        return v

    @validator('username')

    @validator("username")
    def validate_username(cls, v):
        if len(v) < 3:
            raise ValueError('Username must be at least 3 characters long')
            raise ValueError("Username must be at least 3 characters long")
        if not v.isalnum():
            raise ValueError('Username must contain only alphanumeric characters')
            raise ValueError("Username must contain only alphanumeric characters")
        return v
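    # Behavior sketch (pre-v2 @validator API, as imported above): a request
    # like UserRegisterRequest(email="a@b.co", username="ab", password="weak")
    # fails validation, and each ValueError message is surfaced by FastAPI in
    # the 422 response body.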

@@ -65,10 +66,10 @@ class UserLoginRequest(BaseModel):
    username: Optional[str] = None
    password: str

    @validator('email')
    @validator("email")
    def validate_email_or_username(cls, v, values):
        if v is None and not values.get('username'):
            raise ValueError('Either email or username must be provided')
        if v is None and not values.get("username"):
            raise ValueError("Either email or username must be provided")
        return v

@@ -77,6 +78,8 @@ class TokenResponse(BaseModel):
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int
    force_password_change: Optional[bool] = None
    message: Optional[str] = None


class UserResponse(BaseModel):
@@ -86,9 +89,9 @@ class UserResponse(BaseModel):
    full_name: Optional[str]
    is_active: bool
    is_verified: bool
    role: str
    role: Optional[str]
    created_at: datetime

    class Config:
        from_attributes = True

@@ -100,50 +103,47 @@ class RefreshTokenRequest(BaseModel):
class ChangePasswordRequest(BaseModel):
    current_password: str
    new_password: str

    @validator('new_password')

    @validator("new_password")
    def validate_new_password(cls, v):
        if len(v) < 8:
            raise ValueError('Password must be at least 8 characters long')
            raise ValueError("Password must be at least 8 characters long")
        if not any(c.isupper() for c in v):
            raise ValueError('Password must contain at least one uppercase letter')
            raise ValueError("Password must contain at least one uppercase letter")
        if not any(c.islower() for c in v):
            raise ValueError('Password must contain at least one lowercase letter')
            raise ValueError("Password must contain at least one lowercase letter")
        if not any(c.isdigit() for c in v):
            raise ValueError('Password must contain at least one digit')
            raise ValueError("Password must contain at least one digit")
        return v

@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserRegisterRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@router.post(
|
||||
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def register(user_data: UserRegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Register a new user"""
|
||||
|
||||
|
||||
# Check if user already exists
|
||||
stmt = select(User).where(User.email == user_data.email)
|
||||
result = await db.execute(stmt)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
|
||||
)
|
||||
|
||||
|
||||
# Check if username already exists
|
||||
stmt = select(User).where(User.username == user_data.username)
|
||||
result = await db.execute(stmt)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
|
||||
)
|
||||
|
||||
|
||||
# Create new user
|
||||
full_name = None
|
||||
if user_data.first_name or user_data.last_name:
|
||||
full_name = f"{user_data.first_name or ''} {user_data.last_name or ''}".strip()
|
||||
|
||||
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
username=user_data.username,
|
||||
@@ -151,23 +151,29 @@ async def register(
|
||||
full_name=full_name,
|
||||
is_active=True,
|
||||
is_verified=False,
|
||||
role="user"
|
||||
role_id=2, # Default to 'user' role (id=2)
|
||||
)
|
||||
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role.name if user.role else None,
|
||||
created_at=user.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
user_data: UserLoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
async def login(user_data: UserLoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Login user and return access tokens"""
|
||||
|
||||
|
||||
# Determine identifier for logging and user lookup
|
||||
identifier = user_data.email if user_data.email else user_data.username
|
||||
logger.info(
|
||||
@@ -187,9 +193,9 @@ async def login(
|
||||
query_start = datetime.utcnow()
|
||||
|
||||
if user_data.email:
|
||||
stmt = select(User).where(User.email == user_data.email)
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
|
||||
else:
|
||||
stmt = select(User).where(User.username == user_data.username)
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.username == user_data.username)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
query_end = datetime.utcnow()
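    # Why selectinload(User.role): with SQLAlchemy's async sessions, lazily
    # loading a relationship after the query (e.g. reading user.role.name
    # further down) raises instead of silently issuing a new query, so the
    # role must be eagerly loaded up front.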
@@ -205,13 +211,18 @@ async def login(
    identifier_lower = identifier.lower() if identifier else ""
    admin_email = settings.ADMIN_EMAIL.lower() if settings.ADMIN_EMAIL else None

    if user_data.email and admin_email and identifier_lower == admin_email and settings.ADMIN_PASSWORD:
    if (
        user_data.email
        and admin_email
        and identifier_lower == admin_email
        and settings.ADMIN_PASSWORD
    ):
        bootstrap_attempted = True
        logger.info("LOGIN_ADMIN_BOOTSTRAP_START", email=user_data.email)
        try:
            await create_default_admin()
            # Re-run lookup after bootstrap attempt
            stmt = select(User).where(User.email == user_data.email)
            stmt = select(User).options(selectinload(User.role)).where(User.email == user_data.email)
            result = await db.execute(stmt)
            user = result.scalar_one_or_none()
            if user:
@@ -232,19 +243,21 @@ async def login(
            )
        except Exception as e:
            logger.error("LOGIN_USER_LIST_FAILURE", error=str(e))

        if bootstrap_attempted:
            logger.warning("LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email)
            logger.warning(
                "LOGIN_ADMIN_BOOTSTRAP_UNSUCCESSFUL", email=user_data.email
            )

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password"
            detail="Incorrect email or password",
        )

    logger.info("LOGIN_USER_FOUND", email=user.email, is_active=user.is_active)
    logger.info("LOGIN_PASSWORD_VERIFY_START")
    verify_start = datetime.utcnow()

    if not verify_password(user_data.password, user.hashed_password):
        verify_end = datetime.utcnow()
        logger.warning(
@@ -253,21 +266,20 @@ async def login(
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password"
            detail="Incorrect email or password",
        )

    verify_end = datetime.utcnow()
    logger.info(
        "LOGIN_PASSWORD_VERIFY_SUCCESS",
        duration_seconds=(verify_end - verify_start).total_seconds(),
    )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User account is disabled"
            status_code=status.HTTP_401_UNAUTHORIZED, detail="User account is disabled"
        )

    # Update last login
    logger.info("LOGIN_LAST_LOGIN_UPDATE_START")
    update_start = datetime.utcnow()
@@ -289,14 +301,12 @@ async def login(
            "sub": str(user.id),
            "email": user.email,
            "is_superuser": user.is_superuser,
            "role": user.role
            "role": user.role.name if user.role else None,
        },
        expires_delta=access_token_expires
        expires_delta=access_token_expires,
    )

    refresh_token = create_refresh_token(
        data={"sub": str(user.id), "type": "refresh"}
    )
    refresh_token = create_refresh_token(data={"sub": str(user.id), "type": "refresh"})
    token_end = datetime.utcnow()
    logger.info(
        "LOGIN_TOKEN_CREATE_SUCCESS",
@@ -308,65 +318,76 @@ async def login(
        "LOGIN_DEBUG_COMPLETE",
        total_duration_seconds=total_time.total_seconds(),
    )

    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )

    # Check if user needs to change password
    response_data = {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
        "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    }

    # Add force password change flag if needed
    if user.force_password_change:
        response_data["force_password_change"] = True
        response_data["message"] = "Password change required on first login"

    return response_data
|
||||
|
||||
|
||||
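A sketch of how a caller might handle this login response, including the new force_password_change flag; the URL and route prefix are again assumed:

import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/auth/login",
    json={"email": "alice@example.com", "password": "s3cret-pass"},
)
resp.raise_for_status()
tokens = resp.json()
if tokens.get("force_password_change"):
    # Route the user to a change-password screen before anything else
    print(tokens["message"])
headers = {"Authorization": f"Bearer {tokens['access_token']}"}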
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_token(
|
||||
token_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
token_data: RefreshTokenRequest, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Refresh access token using refresh token"""
|
||||
|
||||
|
||||
try:
|
||||
payload = verify_token(token_data.refresh_token)
|
||||
user_id = payload.get("sub")
|
||||
token_type = payload.get("type")
|
||||
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
|
||||
# Get user from database
|
||||
stmt = select(User).where(User.id == int(user_id))
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
|
||||
|
||||
# Create new access token
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
logger.info(f"REFRESH: Creating new access token with expiration: {access_token_expires}")
|
||||
logger.info(f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
|
||||
logger.info(
|
||||
f"REFRESH: Creating new access token with expiration: {access_token_expires}"
|
||||
)
|
||||
logger.info(
|
||||
f"REFRESH: ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
|
||||
)
|
||||
logger.info(f"REFRESH: Current UTC time: {datetime.utcnow().isoformat()}")
|
||||
|
||||
|
||||
access_token = create_access_token(
|
||||
data={
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"is_superuser": user.is_superuser,
|
||||
"role": user.role
|
||||
"role": user.role.name if user.role else None,
|
||||
},
|
||||
expires_delta=access_token_expires
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=token_data.refresh_token, # Keep same refresh token
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException without modification
|
||||
raise
|
||||
@@ -374,30 +395,37 @@ async def refresh_token(
|
||||
# Log the actual error for debugging
|
||||
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token"
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
|
||||
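The usual retry-on-expiry pattern against this pair of endpoints, as a client-side helper sketch (not part of the commit; the refresh URL is an assumption):

import httpx

def request_with_refresh(method, url, tokens, **kwargs):
    """Retry a request once after refreshing an expired access token."""
    headers = {"Authorization": f"Bearer {tokens['access_token']}"}
    resp = httpx.request(method, url, headers=headers, **kwargs)
    if resp.status_code == 401:
        refreshed = httpx.post(
            "http://localhost:8000/api/v1/auth/refresh",
            json={"refresh_token": tokens["refresh_token"]},
        )
        refreshed.raise_for_status()
        tokens.update(refreshed.json())  # same refresh token, new access token
        headers = {"Authorization": f"Bearer {tokens['access_token']}"}
        resp = httpx.request(method, url, headers=headers, **kwargs)
    return resp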
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get current user information"""
|
||||
|
||||
|
||||
# Get full user details from database
|
||||
stmt = select(User).where(User.id == int(current_user["id"]))
|
||||
stmt = select(User).options(selectinload(User.role)).where(User.id == int(current_user["id"]))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role.name if user.role else None,
|
||||
created_at=user.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
@@ -407,14 +435,12 @@ async def logout():
|
||||
|
||||
|
||||
@router.post("/verify-token")
|
||||
async def verify_user_token(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
async def verify_user_token(current_user: dict = Depends(get_current_user)):
|
||||
"""Verify if the current token is valid"""
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": current_user["id"],
|
||||
"email": current_user["email"]
|
||||
"email": current_user["email"],
|
||||
}
|
||||
|
||||
|
||||
@@ -422,32 +448,31 @@ async def verify_user_token(
|
||||
async def change_password(
|
||||
password_data: ChangePasswordRequest,
|
||||
current_user: dict = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Change user password"""
|
||||
|
||||
|
||||
# Get user from database
|
||||
stmt = select(User).where(User.id == int(current_user["id"]))
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
detail="Current password is incorrect",
|
||||
)
|
||||
|
||||
|
||||
# Update password
|
||||
user.hashed_password = get_password_hash(password_data.new_password)
|
||||
user.updated_at = datetime.utcnow()
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
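The recurring change across these auth endpoints is that user.role is now a Role relationship rather than a plain string, so every lookup that later reads user.role.name eager-loads it with selectinload to avoid a lazy-load hitting the database after the session work is done. A standalone sketch of the pattern, using the model and import paths that appear elsewhere in this commit:

from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.models.user import User

async def get_user_with_role(db, email: str):
    # selectinload issues one extra SELECT ... WHERE roles.id IN (...) up front,
    # so user.role is already populated when the session is released.
    stmt = select(User).options(selectinload(User.role)).where(User.email == email)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()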
@@ -129,76 +129,89 @@ async def list_budgets(
    budget_type: Optional[BudgetType] = Query(None),
    is_enabled: Optional[bool] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """List budgets with pagination and filtering"""

    # Check permissions - users can view their own budgets
    if user_id and int(user_id) != current_user['id']:
    if user_id and int(user_id) != current_user["id"]:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")
    elif not user_id:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")

    # If no user_id specified and user doesn't have admin permissions, show only their budgets
    if not user_id and "platform:budgets:read" not in current_user.get("permissions", []):
        user_id = current_user['id']

    if not user_id and "platform:budgets:read" not in current_user.get(
        "permissions", []
    ):
        user_id = current_user["id"]

    # Build query
    query = select(Budget)

    # Apply filters
    if user_id:
        query = query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
        query = query.where(
            Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
        )
    if budget_type:
        query = query.where(Budget.budget_type == budget_type.value)
    if is_enabled is not None:
        query = query.where(Budget.is_enabled == is_enabled)

    # Get total count
    count_query = select(func.count(Budget.id))

    # Apply same filters to count query
    if user_id:
        count_query = count_query.where(Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id))
        count_query = count_query.where(
            Budget.user_id == (int(user_id) if isinstance(user_id, str) else user_id)
        )
    if budget_type:
        count_query = count_query.where(Budget.budget_type == budget_type.value)
    if is_enabled is not None:
        count_query = count_query.where(Budget.is_enabled == is_enabled)

    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    # Apply pagination
    offset = (page - 1) * size
    query = query.offset(offset).limit(size).order_by(Budget.created_at.desc())

    # Execute query
    result = await db.execute(query)
    budgets = result.scalars().all()

    # Calculate current usage for each budget
    budget_responses = []
    for budget in budgets:
        usage = await _calculate_budget_usage(db, budget)
        budget_data = BudgetResponse.model_validate(budget)
        budget_data.current_usage = usage
        budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
        budget_data.usage_percentage = (
            (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
        )
        budget_responses.append(budget_data)

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="list_budgets",
        resource_type="budget",
        details={"page": page, "size": size, "filters": {"user_id": user_id, "budget_type": budget_type, "is_enabled": is_enabled}}
        details={
            "page": page,
            "size": size,
            "filters": {
                "user_id": user_id,
                "budget_type": budget_type,
                "is_enabled": is_enabled,
            },
        },
    )

    return BudgetListResponse(
        budgets=budget_responses,
        total=total,
        page=page,
        size=size
        budgets=budget_responses, total=total, page=page, size=size
    )

@@ -206,42 +219,43 @@ async def list_budgets(
async def get_budget(
    budget_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get budget by ID"""

    # Get budget
    query = select(Budget).where(Budget.id == budget_id)
    result = await db.execute(query)
    budget = result.scalar_one_or_none()

    if not budget:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Budget not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
        )

    # Check permissions - users can view their own budgets
    if budget.user_id != current_user['id']:
    if budget.user_id != current_user["id"]:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")

    # Calculate current usage
    usage = await _calculate_budget_usage(db, budget)

    # Build response
    budget_data = BudgetResponse.model_validate(budget)
    budget_data.current_usage = usage
    budget_data.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0

    budget_data.usage_percentage = (
        (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
    )

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="get_budget",
        resource_type="budget",
        resource_id=budget_id
        resource_id=budget_id,
    )

    return budget_data

@@ -249,24 +263,30 @@ async def get_budget(
async def create_budget(
    budget_data: BudgetCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Create a new budget"""

    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:budgets:create")

    # If user_id not specified, use current user
    target_user_id = budget_data.user_id or current_user['id']

    target_user_id = budget_data.user_id or current_user["id"]

    # If setting budget for another user, need admin permissions
    if int(target_user_id) != current_user['id'] if isinstance(target_user_id, str) else target_user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:budgets:admin")

    if (
        int(target_user_id) != current_user["id"]
        if isinstance(target_user_id, str)
        else target_user_id != current_user["id"]
    ):
        require_permission(
            current_user.get("permissions", []), "platform:budgets:admin"
        )

    # Calculate period start and end
    now = datetime.utcnow()
    period_start, period_end = _calculate_period_bounds(now, budget_data.period_type)

    # Create budget
    new_budget = Budget(
        name=budget_data.name,
@@ -281,30 +301,34 @@ async def create_budget(
        is_enabled=budget_data.is_enabled,
        alert_threshold_percent=budget_data.alert_threshold_percent,
        allowed_resources=budget_data.allowed_resources,
        metadata=budget_data.metadata
        metadata=budget_data.metadata,
    )

    db.add(new_budget)
    await db.commit()
    await db.refresh(new_budget)

    # Build response
    budget_response = BudgetResponse.model_validate(new_budget)
    budget_response.current_usage = 0.0
    budget_response.usage_percentage = 0.0

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="create_budget",
        resource_type="budget",
        resource_id=str(new_budget.id),
        details={"name": budget_data.name, "budget_type": budget_data.budget_type, "limit_amount": budget_data.limit_amount}
        details={
            "name": budget_data.name,
            "budget_type": budget_data.budget_type,
            "limit_amount": budget_data.limit_amount,
        },
    )

    logger.info(f"Budget created: {new_budget.name} by {current_user['username']}")

    return budget_response

@@ -313,70 +337,75 @@ async def update_budget(
    budget_id: str,
    budget_data: BudgetUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Update budget"""

    # Get budget
    query = select(Budget).where(Budget.id == budget_id)
    result = await db.execute(query)
    budget = result.scalar_one_or_none()

    if not budget:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Budget not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
        )

    # Check permissions - users can update their own budgets
    if budget.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:budgets:update")

    if budget.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:budgets:update"
        )

    # Store original values for audit
    original_values = {
        "name": budget.name,
        "limit_amount": budget.limit_amount,
        "is_enabled": budget.is_enabled
        "is_enabled": budget.is_enabled,
    }

    # Update budget fields
    update_data = budget_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(budget, field, value)

    # Recalculate period if period_type changed
    if "period_type" in update_data:
        period_start, period_end = _calculate_period_bounds(datetime.utcnow(), budget.period_type)
        period_start, period_end = _calculate_period_bounds(
            datetime.utcnow(), budget.period_type
        )
        budget.period_start = period_start
        budget.period_end = period_end

    await db.commit()
    await db.refresh(budget)

    # Calculate current usage
    usage = await _calculate_budget_usage(db, budget)

    # Build response
    budget_response = BudgetResponse.model_validate(budget)
    budget_response.current_usage = usage
    budget_response.usage_percentage = (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0

    budget_response.usage_percentage = (
        (usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
    )

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="update_budget",
        resource_type="budget",
        resource_id=budget_id,
        details={
            "updated_fields": list(update_data.keys()),
            "before_values": original_values,
            "after_values": {k: getattr(budget, k) for k in update_data.keys()}
        }
            "after_values": {k: getattr(budget, k) for k in update_data.keys()},
        },
    )

    logger.info(f"Budget updated: {budget.name} by {current_user['username']}")

    return budget_response

@@ -384,41 +413,42 @@ async def update_budget(
async def delete_budget(
    budget_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Delete budget"""

    # Get budget
    query = select(Budget).where(Budget.id == budget_id)
    result = await db.execute(query)
    budget = result.scalar_one_or_none()

    if not budget:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Budget not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
        )

    # Check permissions - users can delete their own budgets
    if budget.user_id != current_user['id']:
        require_permission(current_user.get("permissions", []), "platform:budgets:delete")

    if budget.user_id != current_user["id"]:
        require_permission(
            current_user.get("permissions", []), "platform:budgets:delete"
        )

    # Delete budget
    await db.delete(budget)
    await db.commit()

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="delete_budget",
        resource_type="budget",
        resource_id=budget_id,
        details={"name": budget.name}
        details={"name": budget.name},
    )

    logger.info(f"Budget deleted: {budget.name} by {current_user['username']}")

    return {"message": "Budget deleted successfully"}

@@ -426,35 +456,36 @@ async def delete_budget(
async def get_budget_usage(
    budget_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get detailed budget usage information"""

    # Get budget
    query = select(Budget).where(Budget.id == budget_id)
    result = await db.execute(query)
    budget = result.scalar_one_or_none()

    if not budget:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Budget not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
        )

    # Check permissions - users can view their own budget usage
    if budget.user_id != current_user['id']:
    if budget.user_id != current_user["id"]:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")

    # Calculate usage
    current_usage = await _calculate_budget_usage(db, budget)
    usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
    usage_percentage = (
        (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
    )
    remaining_amount = max(0, budget.limit_amount - current_usage)
    is_exceeded = current_usage > budget.limit_amount

    # Calculate days remaining in period
    now = datetime.utcnow()
    days_remaining = max(0, (budget.period_end - now).days)

    # Calculate projected usage
    projected_usage = None
    if days_remaining > 0 and current_usage > 0:
@@ -463,19 +494,19 @@ async def get_budget_usage(
        daily_rate = current_usage / days_elapsed
        total_days = (budget.period_end - budget.period_start).days + 1
        projected_usage = daily_rate * total_days

    # Get usage history (last 30 days)
    usage_history = await _get_usage_history(db, budget, days=30)

    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="get_budget_usage",
        resource_type="budget",
        resource_id=budget_id
        resource_id=budget_id,
    )

    return BudgetUsageResponse(
        budget_id=budget_id,
        current_usage=current_usage,
@@ -487,7 +518,7 @@ async def get_budget_usage(
        is_exceeded=is_exceeded,
        days_remaining=days_remaining,
        projected_usage=projected_usage,
        usage_history=usage_history
        usage_history=usage_history,
    )

@@ -495,85 +526,92 @@ async def get_budget_usage(
async def get_budget_alerts(
    budget_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get budget alerts"""

    # Get budget
    query = select(Budget).where(Budget.id == budget_id)
    result = await db.execute(query)
    budget = result.scalar_one_or_none()

    if not budget:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Budget not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="Budget not found"
        )

    # Check permissions - users can view their own budget alerts
    if budget.user_id != current_user['id']:
    if budget.user_id != current_user["id"]:
        require_permission(current_user.get("permissions", []), "platform:budgets:read")

    # Calculate usage
    current_usage = await _calculate_budget_usage(db, budget)
    usage_percentage = (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0

    usage_percentage = (
        (current_usage / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
    )

    alerts = []

    # Check for alerts
    if usage_percentage >= 100:
        alerts.append(BudgetAlertResponse(
            budget_id=budget_id,
            budget_name=budget.name,
            alert_type="exceeded",
            current_usage=current_usage,
            limit_amount=budget.limit_amount,
            usage_percentage=usage_percentage,
            message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)"
        ))
        alerts.append(
            BudgetAlertResponse(
                budget_id=budget_id,
                budget_name=budget.name,
                alert_type="exceeded",
                current_usage=current_usage,
                limit_amount=budget.limit_amount,
                usage_percentage=usage_percentage,
                message=f"Budget '{budget.name}' has been exceeded ({usage_percentage:.1f}% used)",
            )
        )
    elif usage_percentage >= 90:
        alerts.append(BudgetAlertResponse(
            budget_id=budget_id,
            budget_name=budget.name,
            alert_type="critical",
            current_usage=current_usage,
            limit_amount=budget.limit_amount,
            usage_percentage=usage_percentage,
            message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)"
        ))
        alerts.append(
            BudgetAlertResponse(
                budget_id=budget_id,
                budget_name=budget.name,
                alert_type="critical",
                current_usage=current_usage,
                limit_amount=budget.limit_amount,
                usage_percentage=usage_percentage,
                message=f"Budget '{budget.name}' is critically high ({usage_percentage:.1f}% used)",
            )
        )
    elif usage_percentage >= budget.alert_threshold_percent:
        alerts.append(BudgetAlertResponse(
            budget_id=budget_id,
            budget_name=budget.name,
            alert_type="warning",
            current_usage=current_usage,
            limit_amount=budget.limit_amount,
            usage_percentage=usage_percentage,
            message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)"
        ))

        alerts.append(
            BudgetAlertResponse(
                budget_id=budget_id,
                budget_name=budget.name,
                alert_type="warning",
                current_usage=current_usage,
                limit_amount=budget.limit_amount,
                usage_percentage=usage_percentage,
                message=f"Budget '{budget.name}' has reached alert threshold ({usage_percentage:.1f}% used)",
            )
        )

    return alerts


# Helper functions
async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
    """Calculate current usage for a budget"""

    # Build base query
    query = select(UsageTracking)

    # Filter by time period
    query = query.where(
        UsageTracking.created_at >= budget.period_start,
        UsageTracking.created_at <= budget.period_end
        UsageTracking.created_at <= budget.period_end,
    )

    # Filter by user or API key
    if budget.api_key_id:
        query = query.where(UsageTracking.api_key_id == budget.api_key_id)
    elif budget.user_id:
        query = query.where(UsageTracking.user_id == budget.user_id)

    # Calculate usage based on budget type
    if budget.budget_type == "tokens":
        usage_query = query.with_only_columns(func.sum(UsageTracking.total_tokens))
@@ -583,20 +621,22 @@ async def _calculate_budget_usage(db: AsyncSession, budget: Budget) -> float:
        usage_query = query.with_only_columns(func.count(UsageTracking.id))
    else:
        return 0.0

    result = await db.execute(usage_query)
    usage = result.scalar() or 0

    # Convert cents to dollars for dollar budgets
    if budget.budget_type == "dollars":
        usage = usage / 100.0

    return float(usage)


def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[datetime, datetime]:
def _calculate_period_bounds(
    current_time: datetime, period_type: str
) -> tuple[datetime, datetime]:
    """Calculate period start and end dates"""

    if period_type == "hourly":
        start = current_time.replace(minute=0, second=0, microsecond=0)
        end = start + timedelta(hours=1) - timedelta(microseconds=1)
@@ -606,7 +646,9 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
    elif period_type == "weekly":
        # Start of week (Monday)
        days_since_monday = current_time.weekday()
        start = current_time.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=days_since_monday)
        start = current_time.replace(
            hour=0, minute=0, second=0, microsecond=0
        ) - timedelta(days=days_since_monday)
        end = start + timedelta(weeks=1) - timedelta(microseconds=1)
    elif period_type == "monthly":
        start = current_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
@@ -616,44 +658,49 @@ def _calculate_period_bounds(current_time: datetime, period_type: str) -> tuple[
        next_month = start.replace(month=start.month + 1)
        end = next_month - timedelta(microseconds=1)
    elif period_type == "yearly":
        start = current_time.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
        start = current_time.replace(
            month=1, day=1, hour=0, minute=0, second=0, microsecond=0
        )
        end = start.replace(year=start.year + 1) - timedelta(microseconds=1)
    else:
        # Default to daily
        start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
        end = start + timedelta(days=1) - timedelta(microseconds=1)

    return start, end


async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -> List[dict]:
async def _get_usage_history(
    db: AsyncSession, budget: Budget, days: int = 30
) -> List[dict]:
    """Get usage history for the budget"""

    end_date = datetime.utcnow()
    start_date = end_date - timedelta(days=days)

    # Build query
    query = select(
        func.date(UsageTracking.created_at).label('date'),
        func.sum(UsageTracking.total_tokens).label('tokens'),
        func.sum(UsageTracking.cost_cents).label('cost_cents'),
        func.count(UsageTracking.id).label('requests')
        func.date(UsageTracking.created_at).label("date"),
        func.sum(UsageTracking.total_tokens).label("tokens"),
        func.sum(UsageTracking.cost_cents).label("cost_cents"),
        func.count(UsageTracking.id).label("requests"),
    ).where(
        UsageTracking.created_at >= start_date,
        UsageTracking.created_at <= end_date
        UsageTracking.created_at >= start_date, UsageTracking.created_at <= end_date
    )

    # Filter by user or API key
    if budget.api_key_id:
        query = query.where(UsageTracking.api_key_id == budget.api_key_id)
    elif budget.user_id:
        query = query.where(UsageTracking.user_id == budget.user_id)

    query = query.group_by(func.date(UsageTracking.created_at)).order_by(func.date(UsageTracking.created_at))

    query = query.group_by(func.date(UsageTracking.created_at)).order_by(
        func.date(UsageTracking.created_at)
    )

    result = await db.execute(query)
    rows = result.fetchall()

    history = []
    for row in rows:
        usage_value = 0
@@ -663,13 +710,15 @@ async def _get_usage_history(db: AsyncSession, budget: Budget, days: int = 30) -
            usage_value = (row.cost_cents or 0) / 100.0
        elif budget.budget_type == "requests":
            usage_value = row.requests or 0

        history.append({
            "date": row.date.isoformat(),
            "usage": usage_value,
            "tokens": row.tokens or 0,
            "cost_dollars": (row.cost_cents or 0) / 100.0,
            "requests": row.requests or 0
        })

    return history

        history.append(
            {
                "date": row.date.isoformat(),
                "usage": usage_value,
                "tokens": row.tokens or 0,
                "cost_dollars": (row.cost_cents or 0) / 100.0,
                "requests": row.requests or 0,
            }
        )

    return history
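Two bits of arithmetic recur in these helpers: costs are stored as integer cost_cents and divided by 100.0 at the read edge, and period bounds are computed by snapping to the period start and subtracting one microsecond from the next boundary. A quick standalone check of both (runnable as-is; the December year-rollover in the monthly branch lives in an elided hunk above, so handling it explicitly here is an assumption about those hidden lines):

from datetime import datetime, timedelta

def month_bounds(now: datetime) -> tuple[datetime, datetime]:
    # Snap to the first of the month, then step to the next month's first instant
    start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    if start.month == 12:
        next_month = start.replace(year=start.year + 1, month=1)
    else:
        next_month = start.replace(month=start.month + 1)
    return start, next_month - timedelta(microseconds=1)

def usage_summary(usage: float, limit_amount: float) -> dict:
    # Mirrors the derived fields the endpoints compute per budget
    pct = (usage / limit_amount * 100) if limit_amount > 0 else 0
    return {
        "usage_percentage": pct,
        "remaining_amount": max(0, limit_amount - usage),
        "is_exceeded": usage > limit_amount,
    }

print(month_bounds(datetime(2025, 12, 15)))
print(usage_summary(usage=1250 / 100.0, limit_amount=10.0))
# {'usage_percentage': 125.0, 'remaining_amount': 0, 'is_exceeded': True}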
File diff suppressed because it is too large

3
backend/app/api/v1/endpoints/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
API v1 endpoints package
"""
166
backend/app/api/v1/endpoints/tool_calling.py
Normal file
@@ -0,0 +1,166 @@
"""
Tool calling API endpoints
Integration between LLM and tool execution
"""
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_calling_service import ToolCallingService
from app.services.llm.models import ChatRequest, ChatResponse
from app.schemas.tool_calling import (
    ToolCallRequest,
    ToolCallResponse,
    ToolExecutionRequest,
    ToolValidationRequest,
    ToolValidationResponse,
    ToolHistoryResponse,
)

router = APIRouter()


@router.post("/chat/completions", response_model=ChatResponse)
async def create_chat_completion_with_tools(
    request: ChatRequest,
    auto_execute_tools: bool = Query(
        True, description="Whether to automatically execute tool calls"
    ),
    max_tool_calls: int = Query(
        5, ge=1, le=10, description="Maximum number of tool calls"
    ),
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Create chat completion with tool calling support"""

    service = ToolCallingService(db)

    # Resolve user ID for context
    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    # Set user context in request
    request.user_id = str(user_id)
    request.api_key_id = 1  # Default for internal usage

    response = await service.create_chat_completion_with_tools(
        request=request,
        user=current_user,
        auto_execute_tools=auto_execute_tools,
        max_tool_calls=max_tool_calls,
    )

    return response


@router.post("/chat/completions/stream")
async def create_chat_completion_stream_with_tools(
    request: ChatRequest,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Create streaming chat completion with tool calling support"""

    service = ToolCallingService(db)

    # Resolve user ID for context
    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    # Set user context in request
    request.user_id = str(user_id)
    request.api_key_id = 1  # Default for internal usage

    async def stream_generator():
        async for chunk in service.create_chat_completion_stream_with_tools(
            request=request, user=current_user
        ):
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/plain",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


@router.post("/execute", response_model=ToolCallResponse)
async def execute_tool_by_name(
    request: ToolExecutionRequest,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Execute a tool by name directly"""

    service = ToolCallingService(db)

    try:
        result = await service.execute_tool_by_name(
            tool_name=request.tool_name,
            parameters=request.parameters,
            user=current_user,
        )

        return ToolCallResponse(success=True, result=result, error=None)

    except Exception as e:
        return ToolCallResponse(success=False, result=None, error=str(e))


@router.post("/validate", response_model=ToolValidationResponse)
async def validate_tool_availability(
    request: ToolValidationRequest,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Validate which tools are available to the user"""

    service = ToolCallingService(db)

    availability = await service.validate_tool_availability(
        tool_names=request.tool_names, user=current_user
    )

    return ToolValidationResponse(tool_availability=availability)


@router.get("/history", response_model=ToolHistoryResponse)
async def get_tool_call_history(
    limit: int = Query(
        50, ge=1, le=100, description="Number of history items to return"
    ),
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get recent tool execution history for the user"""

    service = ToolCallingService(db)

    history = await service.get_tool_call_history(user=current_user, limit=limit)

    return ToolHistoryResponse(history=history, total=len(history))


@router.get("/available")
async def get_available_tools(
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get tools available for function calling"""

    service = ToolCallingService(db)

    # Get available tools
    tools = await service._get_available_tools_for_user(current_user)

    # Convert to OpenAI format
    openai_tools = await service._convert_tools_to_openai_format(tools)

    return {"tools": openai_tools, "count": len(openai_tools)}
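The streaming endpoint emits server-sent-events-style frames ("data: ...\n\n", terminated by "data: [DONE]") over a text/plain response. A minimal consumer sketch; the route prefix and auth header are assumptions:

import httpx

async def consume_tool_stream(body: dict, token: str):
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/api/v1/tool-calling/chat/completions/stream",
            json=body,
            headers={"Authorization": f"Bearer {token}"},
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between frames
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                print(payload)  # one chunk of the streamed completion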
386
backend/app/api/v1/endpoints/tools.py
Normal file
@@ -0,0 +1,386 @@
"""
Tool management and execution API endpoints
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.database import get_db
from app.core.security import get_current_user
from app.services.tool_management_service import ToolManagementService
from app.services.tool_execution_service import ToolExecutionService
from app.schemas.tool import (
    ToolCreate,
    ToolUpdate,
    ToolResponse,
    ToolListResponse,
    ToolExecutionCreate,
    ToolExecutionResponse,
    ToolExecutionListResponse,
    ToolCategoryCreate,
    ToolCategoryResponse,
    ToolStatisticsResponse,
)

router = APIRouter()


@router.post("/", response_model=ToolResponse)
async def create_tool(
    tool_data: ToolCreate,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Create a new tool"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    tool = await service.create_tool(
        name=tool_data.name,
        display_name=tool_data.display_name,
        code=tool_data.code,
        tool_type=tool_data.tool_type,
        created_by_user_id=user_id,
        description=tool_data.description,
        parameters_schema=tool_data.parameters_schema,
        return_schema=tool_data.return_schema,
        timeout_seconds=tool_data.timeout_seconds,
        max_memory_mb=tool_data.max_memory_mb,
        max_cpu_seconds=tool_data.max_cpu_seconds,
        docker_image=tool_data.docker_image,
        docker_command=tool_data.docker_command,
        category=tool_data.category,
        tags=tool_data.tags,
        is_public=tool_data.is_public,
    )

    return ToolResponse.from_orm(tool)


@router.get("/", response_model=ToolListResponse)
async def list_tools(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    category: Optional[str] = Query(None),
    tool_type: Optional[str] = Query(None),
    is_public: Optional[bool] = Query(None),
    is_approved: Optional[bool] = Query(None),
    search: Optional[str] = Query(None),
    tags: Optional[List[str]] = Query(None),
    created_by_user_id: Optional[int] = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """List tools with filtering and pagination"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    tools = await service.get_tools(
        user_id=user_id,
        skip=skip,
        limit=limit,
        category=category,
        tool_type=tool_type,
        is_public=is_public,
        is_approved=is_approved,
        search=search,
        tags=tags,
        created_by_user_id=created_by_user_id,
    )

    return ToolListResponse(
        tools=[ToolResponse.from_orm(tool) for tool in tools],
        total=len(tools),
        skip=skip,
        limit=limit,
    )


@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
    tool_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get tool by ID"""
    service = ToolManagementService(db)

    tool = await service.get_tool_by_id(tool_id)
    if not tool:
        raise HTTPException(status_code=404, detail="Tool not found")

    # Resolve underlying User object if available
    user_obj = (
        current_user.get("user_obj")
        if isinstance(current_user, dict)
        else current_user
    )

    # Check if user can access this tool
    if not user_obj or not tool.can_be_used_by(user_obj):
        raise HTTPException(status_code=403, detail="Access denied to this tool")

    return ToolResponse.from_orm(tool)


@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_id: int,
    tool_data: ToolUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Update tool (only by creator or admin)"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    tool = await service.update_tool(
        tool_id=tool_id,
        user_id=user_id,
        display_name=tool_data.display_name,
        description=tool_data.description,
        code=tool_data.code,
        parameters_schema=tool_data.parameters_schema,
        return_schema=tool_data.return_schema,
        timeout_seconds=tool_data.timeout_seconds,
        max_memory_mb=tool_data.max_memory_mb,
        max_cpu_seconds=tool_data.max_cpu_seconds,
        docker_image=tool_data.docker_image,
        docker_command=tool_data.docker_command,
        category=tool_data.category,
        tags=tool_data.tags,
        is_public=tool_data.is_public,
        is_active=tool_data.is_active,
    )

    return ToolResponse.from_orm(tool)


@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Delete tool (only by creator or admin)"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    await service.delete_tool(tool_id, user_id)
    return {"message": "Tool deleted successfully"}


@router.post("/{tool_id}/approve", response_model=ToolResponse)
async def approve_tool(
    tool_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Approve tool for public use (admin only)"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    tool = await service.approve_tool(tool_id, user_id)
    return ToolResponse.from_orm(tool)


# Tool Execution Endpoints


@router.post("/{tool_id}/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    tool_id: int,
    execution_data: ToolExecutionCreate,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Execute a tool with given parameters"""
    service = ToolExecutionService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    execution = await service.execute_tool(
        tool_id=tool_id,
        user_id=user_id,
        parameters=execution_data.parameters,
        timeout_override=execution_data.timeout_override,
    )

    return ToolExecutionResponse.from_orm(execution)


@router.get("/executions", response_model=ToolExecutionListResponse)
async def list_executions(
    tool_id: Optional[int] = Query(None),
    executed_by_user_id: Optional[int] = Query(None),
    status: Optional[str] = Query(None),
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """List tool executions with filtering"""
    service = ToolExecutionService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    executions = await service.get_tool_executions(
        tool_id=tool_id,
        user_id=user_id,
        executed_by_user_id=executed_by_user_id,
        status=status,
        skip=skip,
        limit=limit,
    )

    return ToolExecutionListResponse(
        executions=[
            ToolExecutionResponse.from_orm(execution) for execution in executions
        ],
        total=len(executions),
        skip=skip,
        limit=limit,
    )


@router.get("/executions/{execution_id}", response_model=ToolExecutionResponse)
async def get_execution(
    execution_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get execution details"""
    service = ToolExecutionService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    # Get execution through list with filter to ensure permission check
    executions = await service.get_tool_executions(
        user_id=user_id, skip=0, limit=1
    )

    execution = next((e for e in executions if e.id == execution_id), None)
    if not execution:
        raise HTTPException(status_code=404, detail="Execution not found")

    return ToolExecutionResponse.from_orm(execution)


@router.post("/executions/{execution_id}/cancel", response_model=ToolExecutionResponse)
async def cancel_execution(
    execution_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Cancel a running execution"""
    service = ToolExecutionService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    execution = await service.cancel_execution(execution_id, user_id)
    return ToolExecutionResponse.from_orm(execution)


@router.get("/executions/{execution_id}/logs")
async def get_execution_logs(
    execution_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get real-time logs for execution"""
    service = ToolExecutionService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    logs = await service.get_execution_logs(execution_id, user_id)
    return logs


# Tool Categories


@router.post("/categories", response_model=ToolCategoryResponse)
async def create_category(
    category_data: ToolCategoryCreate,
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Create a new tool category (admin only)"""
    user_obj = (
        current_user.get("user_obj")
        if isinstance(current_user, dict)
        else current_user
    )

    if not user_obj or not user_obj.has_permission("manage_tools"):
        raise HTTPException(status_code=403, detail="Admin privileges required")

    service = ToolManagementService(db)

    category = await service.create_category(
        name=category_data.name,
        display_name=category_data.display_name,
        description=category_data.description,
        icon=category_data.icon,
        color=category_data.color,
        sort_order=category_data.sort_order,
    )

    return ToolCategoryResponse.from_orm(category)


@router.get("/categories", response_model=List[ToolCategoryResponse])
async def list_categories(
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """List all active tool categories"""
    service = ToolManagementService(db)

    categories = await service.get_categories()
    return [ToolCategoryResponse.from_orm(category) for category in categories]


# Statistics


@router.get("/statistics", response_model=ToolStatisticsResponse)
async def get_statistics(
    db: AsyncSession = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Get tool usage statistics"""
    service = ToolManagementService(db)

    user_id = (
        current_user.get("id") if isinstance(current_user, dict) else current_user.id
    )

    stats = await service.get_tool_statistics(user_id=user_id)
    return ToolStatisticsResponse(**stats)
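Nearly every endpoint in this file repeats the same dance to support both dict-shaped and ORM-shaped principals from get_current_user. If that duality has to stay, a small helper keeps it in one place; this is an illustrative refactor, not part of the commit:

from typing import Any

def resolve_user_id(current_user: Any) -> int:
    """Return the numeric user id whether the dependency yields a dict or a User."""
    if isinstance(current_user, dict):
        return current_user["id"]
    return current_user.id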
703
backend/app/api/v1/endpoints/user_management.py
Normal file
@@ -0,0 +1,703 @@
|
||||
"""
|
||||
User Management API endpoints
|
||||
Admin endpoints for managing users, roles, and audit logs
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer
|
||||
from pydantic import BaseModel, EmailStr, validator, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.security import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.services.user_management_service import UserManagementService
|
||||
from app.services.permission_manager import require_permission
|
||||
from app.schemas.role import RoleCreate, RoleUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateUserRequest(BaseModel):
|
||||
email: EmailStr
|
||||
username: str
|
||||
password: str
|
||||
full_name: Optional[str] = None
|
||||
role_id: Optional[int] = None
|
||||
is_active: bool = True
|
||||
force_password_change: bool = False
|
||||
|
||||
@validator("password")
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
return v
|
||||
|
||||
@validator("username")
|
||||
def validate_username(cls, v):
|
||||
if len(v) < 3:
|
||||
raise ValueError("Username must be at least 3 characters long")
|
||||
if not v.replace("_", "").replace("-", "").isalnum():
|
||||
raise ValueError("Username must contain only alphanumeric characters, underscores, and hyphens")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
    email: Optional[EmailStr] = None
    username: Optional[str] = None
    full_name: Optional[str] = None
    role_id: Optional[int] = None
    is_active: Optional[bool] = None
    is_verified: Optional[bool] = None
    custom_permissions: Optional[Dict[str, Any]] = None

    @validator("username")
    def validate_username(cls, v):
        if v is not None and len(v) < 3:
            raise ValueError("Username must be at least 3 characters long")
        return v


class AdminPasswordResetRequest(BaseModel):
    new_password: str
    force_change_on_login: bool = True

    @validator("new_password")
    def validate_password(cls, v):
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters long")
        return v


class UserResponse(BaseModel):
    id: int
    email: str
    username: str
    full_name: Optional[str]
    role_id: Optional[int]
    role: Optional[Dict[str, Any]]
    is_active: bool
    is_verified: bool
    account_locked: bool
    force_password_change: bool
    created_at: datetime
    updated_at: datetime
    last_login: Optional[datetime]
    audit_summary: Optional[Dict[str, Any]] = None

    class Config:
        from_attributes = True


class UserListResponse(BaseModel):
    users: List[UserResponse]
    total: int
    skip: int
    limit: int


class RoleResponse(BaseModel):
    id: int
    name: str
    display_name: str
    description: Optional[str]
    level: str
    is_active: bool
    created_at: datetime

    class Config:
        from_attributes = True


class AuditLogResponse(BaseModel):
    id: int
    user_id: Optional[int]
    action: str
    resource_type: str
    resource_id: Optional[str]
    description: str
    details: Dict[str, Any]
    severity: str
    category: Optional[str]
    success: bool
    created_at: datetime

    class Config:
        from_attributes = True


# User Management Endpoints
@router.get("/users", response_model=UserListResponse)
async def get_users(
    request: Request,
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    role_id: Optional[int] = None,
    is_active: Optional[bool] = None,
    include_audit_summary: bool = False,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all users with filtering options"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:read",
        context={"user_id": current_user["id"]}
    )

    service = UserManagementService(db)
    users_data = await service.get_users(
        skip=skip,
        limit=limit,
        search=search,
        role_id=role_id,
        is_active=is_active,
    )

    # Convert to response format
    users = []
    for user in users_data:
        user_dict = user.to_dict()
        if include_audit_summary:
            # Get audit summary for user
            audit_logs = await service.get_user_audit_logs(user.id, limit=10)
            user_dict["audit_summary"] = {
                "recent_actions": len(audit_logs),
                "last_login": user.last_login.isoformat() if user.last_login else None,
            }
        user_response = UserResponse(**user_dict)
        users.append(user_response)

    return UserListResponse(
        users=users,
        total=len(users),  # Would need actual count query for large datasets
        skip=skip,
        limit=limit,
    )

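# Editorial sketch (not part of the commit): require_permission is imported from
# app.services.permission_manager and its implementation is not shown in this
# diff; the endpoints here only rely on it raising when the needed scope is
# missing. A minimal stand-in with that assumed contract:
#
#     def require_permission(granted, needed, context=None):
#         if needed not in granted:
#             raise HTTPException(status_code=403, detail="Forbidden")
#
# The real implementation may also expand wildcard scopes or use the context
# (e.g. owner_id) for ownership checks; that behavior is an assumption here.
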
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    include_audit_summary: bool = False,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a specific user by ID"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:read",
        context={"user_id": current_user["id"], "owner_id": user_id}
    )

    service = UserManagementService(db)
    user = await service.get_user_by_id(user_id)

    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    user_dict = user.to_dict()

    if include_audit_summary:
        # Get audit summary for user
        audit_logs = await service.get_user_audit_logs(user_id, limit=10)
        user_dict["audit_summary"] = {
            "recent_actions": len(audit_logs),
            "last_login": user.last_login.isoformat() if user.last_login else None,
        }

    return UserResponse(**user_dict)


@router.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
    user_data: CreateUserRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new user"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:create",
    )

    service = UserManagementService(db)

    user = await service.create_user(
        email=user_data.email,
        username=user_data.username,
        password=user_data.password,
        full_name=user_data.full_name,
        role_id=user_data.role_id,
        is_active=user_data.is_active,
        is_verified=True,  # Admin-created users are verified by default
        custom_permissions={},  # Empty by default
    )

    # Log user creation in audit
    await service._log_audit_event(
        user_id=current_user["id"],
        action="create",
        resource_type="user",
        resource_id=str(user.id),
        description=f"User created by admin: {user.email}",
        details={
            "created_by": current_user["email"],
            "target_user": user.email,
            "role_id": user_data.role_id,
        },
        severity="medium",
    )

    user_dict = user.to_dict()
    return UserResponse(**user_dict)


@router.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    user_data: UpdateUserRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update user information"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:update",
        context={"user_id": current_user["id"], "owner_id": user_id}
    )

    service = UserManagementService(db)

    # Get current user for audit comparison
    current_user_data = await service.get_user_by_id(user_id)
    if not current_user_data:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    old_values = current_user_data.to_dict()

    # Update user
    user = await service.update_user(
        user_id=user_id,
        email=user_data.email,
        username=user_data.username,
        full_name=user_data.full_name,
        role_id=user_data.role_id,
        is_active=user_data.is_active,
        is_verified=user_data.is_verified,
        custom_permissions=user_data.custom_permissions,
    )

    # Log user update in audit
    await service._log_audit_event(
        user_id=current_user["id"],
        action="update",
        resource_type="user",
        resource_id=str(user_id),
        description=f"User updated by admin: {user.email}",
        details={
            "updated_by": current_user["email"],
            "target_user": user.email,
        },
        old_values=old_values,
        new_values=user.to_dict(),
        severity="medium",
    )

    user_dict = user.to_dict()
    return UserResponse(**user_dict)


@router.post("/users/{user_id}/password-reset")
async def admin_reset_password(
    user_id: int,
    password_data: AdminPasswordResetRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Admin-reset a user's password, optionally forcing a change at next login"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:manage",
    )

    service = UserManagementService(db)

    user = await service.admin_reset_user_password(
        user_id=user_id,
        new_password=password_data.new_password,
        force_change_on_login=password_data.force_change_on_login,
        admin_user_id=current_user["id"],
    )

    return {
        "message": f"Password reset for user {user.email}",
        "force_change_on_login": password_data.force_change_on_login,
    }


@router.delete("/users/{user_id}")
async def delete_user(
    user_id: int,
    hard_delete: bool = False,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete or deactivate a user"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:delete",
    )

    service = UserManagementService(db)

    # Get user info for audit before deletion
    user = await service.get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    user_email = user.email

    # Delete user
    success = await service.delete_user(user_id, hard_delete=hard_delete)

    # Log user deletion in audit
    await service._log_audit_event(
        user_id=current_user["id"],
        action="delete",
        resource_type="user",
        resource_id=str(user_id),
        description=f"User {'hard deleted' if hard_delete else 'deactivated'} by admin: {user_email}",
        details={
            "deleted_by": current_user["email"],
            "target_user": user_email,
            "hard_delete": hard_delete,
        },
        severity="high",
    )

    return {
        "message": f"User {'deleted' if hard_delete else 'deactivated'} successfully",
        "user_email": user_email,
    }


# Role Management Endpoints
@router.get("/roles", response_model=List[RoleResponse])
async def get_roles(
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all available roles"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:roles:read",
    )

    service = UserManagementService(db)
    roles = await service.get_roles(is_active=True)

    return [RoleResponse(**role.to_dict()) for role in roles]


@router.post("/roles", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
async def create_role(
    role_data: RoleCreate,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new role"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:roles:create",
    )

    service = UserManagementService(db)

    # Create role
    role = await service.create_role(
        name=role_data.name,
        display_name=role_data.display_name,
        description=role_data.description,
        level=role_data.level,
        permissions=role_data.permissions,
        can_manage_users=role_data.can_manage_users,
        can_manage_budgets=role_data.can_manage_budgets,
        can_view_reports=role_data.can_view_reports,
        can_manage_tools=role_data.can_manage_tools,
        inherits_from=role_data.inherits_from,
        is_active=role_data.is_active,
        is_system_role=role_data.is_system_role,
    )

    # Log role creation
    await service._log_audit_event(
        user_id=current_user["id"],
        action="create",
        resource_type="role",
        resource_id=str(role.id),
        description=f"Role created: {role.name}",
        details={
            "created_by": current_user["email"],
            "role_name": role.name,
            "level": role.level,
        },
        severity="medium",
    )

    return RoleResponse(**role.to_dict())


@router.put("/roles/{role_id}", response_model=RoleResponse)
async def update_role(
    role_id: int,
    role_data: RoleUpdate,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a role"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:roles:update",
    )

    service = UserManagementService(db)

    # Get current role for audit
    current_role = await service.get_role_by_id(role_id)
    if not current_role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Role not found"
        )

    # Prevent updating system roles
    if current_role.is_system_role:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Cannot modify system roles"
        )

    old_values = current_role.to_dict()

    # Update role
    role = await service.update_role(
        role_id=role_id,
        display_name=role_data.display_name,
        description=role_data.description,
        permissions=role_data.permissions,
        can_manage_users=role_data.can_manage_users,
        can_manage_budgets=role_data.can_manage_budgets,
        can_view_reports=role_data.can_view_reports,
        can_manage_tools=role_data.can_manage_tools,
        is_active=role_data.is_active,
    )

    # Log role update
    await service._log_audit_event(
        user_id=current_user["id"],
        action="update",
        resource_type="role",
        resource_id=str(role_id),
        description=f"Role updated: {role.name}",
        details={
            "updated_by": current_user["email"],
            "role_name": role.name,
        },
        old_values=old_values,
        new_values=role.to_dict(),
        severity="medium",
    )

    return RoleResponse(**role.to_dict())


@router.delete("/roles/{role_id}")
async def delete_role(
    role_id: int,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a role"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:roles:delete",
    )

    service = UserManagementService(db)

    # Get role info for audit before deletion
    role = await service.get_role_by_id(role_id)
    if not role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Role not found"
        )

    # Prevent deleting system roles
    if role.is_system_role:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Cannot delete system roles"
        )

    role_name = role.name

    # Delete role
    success = await service.delete_role(role_id)

    # Log role deletion
    await service._log_audit_event(
        user_id=current_user["id"],
        action="delete",
        resource_type="role",
        resource_id=str(role_id),
        description=f"Role deleted: {role_name}",
        details={
            "deleted_by": current_user["email"],
            "role_name": role_name,
        },
        severity="high",
    )

    return {
        "message": f"Role {role_name} deleted successfully",
        "role_name": role_name,
    }


# Audit Log Endpoints
@router.get("/users/{user_id}/audit-logs", response_model=List[AuditLogResponse])
async def get_user_audit_logs(
    user_id: int,
    skip: int = 0,
    limit: int = 50,
    action_filter: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get audit logs for a specific user"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:audit:read",
        context={"user_id": current_user["id"], "owner_id": user_id}
    )

    service = UserManagementService(db)
    audit_logs = await service.get_user_audit_logs(
        user_id=user_id,
        skip=skip,
        limit=limit,
        action_filter=action_filter,
    )

    return [AuditLogResponse(**log.to_dict()) for log in audit_logs]


@router.get("/audit-logs", response_model=List[AuditLogResponse])
async def get_all_audit_logs(
    skip: int = 0,
    limit: int = 100,
    user_id: Optional[int] = None,
    action_filter: Optional[str] = None,
    category_filter: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all audit logs with filtering"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:audit:read",
    )

    # Direct database query for audit logs with filters
    from sqlalchemy import select, and_, desc

    query = select(AuditLog)
    conditions = []

    if user_id:
        conditions.append(AuditLog.user_id == user_id)
    if action_filter:
        conditions.append(AuditLog.action == action_filter)
    if category_filter:
        conditions.append(AuditLog.category == category_filter)

    if conditions:
        query = query.where(and_(*conditions))

    query = query.order_by(desc(AuditLog.created_at))
    query = query.offset(skip).limit(limit)

    result = await db.execute(query)
    audit_logs = result.scalars().all()

    return [AuditLogResponse(**log.to_dict()) for log in audit_logs]

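# Example (hypothetical request, not part of the commit): with this router
# mounted under an admin prefix (the mount path is not shown in this diff),
# the incremental filter composition above serves calls like
#
#     GET /audit-logs?limit=20&action_filter=delete&category_filter=user
#
# which renders roughly as: SELECT * FROM audit_logs
# WHERE action = 'delete' AND category = 'user'
# ORDER BY created_at DESC LIMIT 20 OFFSET 0
# (table and column names assumed from the AuditLog model).
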
# Statistics Endpoints
@router.get("/statistics")
async def get_user_management_statistics(
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get user management statistics"""

    # Check permission
    require_permission(
        current_user.get("permissions", []),
        "platform:users:read",
    )

    service = UserManagementService(db)
    user_stats = await service.get_user_statistics()
    role_stats = await service.get_role_statistics()

    return {
        "users": user_stats,
        "roles": role_stats,
        "generated_at": datetime.utcnow().isoformat(),
    }

@@ -12,16 +12,33 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.db.database import get_db
from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context
from app.services.api_key_auth import (
    require_api_key,
    RequireScope,
    APIKeyAuthService,
    get_api_key_context,
)
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.llm.models import (
    ChatRequest,
    ChatMessage as LLMChatMessage,
    EmbeddingRequest as LLMEmbeddingRequest,
)
from app.services.llm.exceptions import (
    LLMError,
    ProviderError,
    SecurityError,
    ValidationError,
)
from app.services.budget_enforcement import (
    check_budget_for_request, record_request_usage, BudgetEnforcementService,
    atomic_check_and_reserve_budget, atomic_finalize_usage
    check_budget_for_request,
    record_request_usage,
    BudgetEnforcementService,
    atomic_check_and_reserve_budget,
    atomic_finalize_usage,
)
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.utils.exceptions import AuthenticationError, AuthorizationError
@@ -30,11 +47,7 @@ from app.middleware.analytics import set_analytics_data
logger = logging.getLogger(__name__)

# Models response cache - simple in-memory cache for performance
_models_cache = {
    "data": None,
    "cached_at": 0,
    "cache_ttl": 900  # 15 minutes cache TTL
}
_models_cache = {"data": None, "cached_at": 0, "cache_ttl": 900}  # 15 minutes cache TTL

router = APIRouter()

@@ -42,18 +55,20 @@ router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]:
    """Get models from cache or fetch from LLM service if cache is stale"""
    current_time = time.time()

    # Check if cache is still valid
    if (_models_cache["data"] is not None and
        current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]):
    if (
        _models_cache["data"] is not None
        and current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]
    ):
        logger.debug("Returning cached models list")
        return _models_cache["data"]

    # Cache miss or stale - fetch from LLM service
    try:
        logger.debug("Fetching fresh models list from LLM service")
        model_infos = await llm_service.get_models()

        # Convert ModelInfo objects to dict format for compatibility
        models = []
        for model_info in model_infos:
@@ -63,32 +78,36 @@ async def get_cached_models() -> List[Dict[str, Any]]:
                "created": model_info.created or int(time.time()),
                "owned_by": model_info.owned_by,
                # Add frontend-expected fields
                "name": getattr(model_info, 'name', model_info.id),  # Use name if available, fall back to id
                "provider": getattr(model_info, 'provider', model_info.owned_by),  # Use provider if available, fall back to owned_by
                "name": getattr(
                    model_info, "name", model_info.id
                ),  # Use name if available, fall back to id
                "provider": getattr(
                    model_info, "provider", model_info.owned_by
                ),  # Use provider if available, fall back to owned_by
                "capabilities": model_info.capabilities,
                "context_window": model_info.context_window,
                "max_output_tokens": model_info.max_output_tokens,
                "supports_streaming": model_info.supports_streaming,
                "supports_function_calling": model_info.supports_function_calling
                "supports_function_calling": model_info.supports_function_calling,
            }
            # Include tasks field if present
            if model_info.tasks:
                model_dict["tasks"] = model_info.tasks
            models.append(model_dict)

        # Update cache
        _models_cache["data"] = models
        _models_cache["cached_at"] = current_time

        return models
    except Exception as e:
        logger.error(f"Failed to fetch models from LLM service: {e}")

        # Return stale cache if available, otherwise empty list
        if _models_cache["data"] is not None:
            logger.warning("Returning stale cached models due to fetch error")
            return _models_cache["data"]

        return []

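The cache above is a plain dict guarded by a timestamp comparison. A self-contained sketch of the same TTL logic in isolation (hypothetical entry, not tied to this module):

import time

cache = {"data": None, "cached_at": 0.0, "cache_ttl": 900}  # 15-minute TTL

def is_fresh(c: dict) -> bool:
    # Mirrors the staleness check in get_cached_models above.
    return c["data"] is not None and time.time() - c["cached_at"] < c["cache_ttl"]

cache["data"], cache["cached_at"] = ["gpt-x"], time.time()  # "gpt-x" is invented
assert is_fresh(cache)        # just written -> fresh
cache["cached_at"] -= 901     # simulate an expired entry
assert not is_fresh(cache)    # past the TTL -> stale
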
@@ -138,11 +157,12 @@ class ModelsResponse(BaseModel):
# Authentication: Public API endpoints should use require_api_key
# Internal API endpoints should use get_current_user from core.security


# Endpoints
@router.get("/models", response_model=ModelsResponse)
async def list_models(
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """List available models"""
    try:
@@ -155,33 +175,35 @@ async def list_models(
        if not await auth_service.check_scope_permission(context, "models.list"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions to list models"
                detail="Insufficient permissions to list models",
            )

        # Get models from cache or LLM service
        models = await get_cached_models()

        # Filter models based on API key permissions
        api_key = context.get("api_key")
        if api_key and api_key.allowed_models:
            models = [model for model in models if model.get("id") in api_key.allowed_models]
            models = [
                model for model in models if model.get("id") in api_key.allowed_models
            ]

        return ModelsResponse(data=models)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to list models"
            detail="Failed to list models",
        )


@router.post("/models/invalidate-cache")
async def invalidate_models_cache_endpoint(
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Invalidate the models cache (admin only)"""
    # Check for admin permissions
@@ -190,7 +212,7 @@ async def invalidate_models_cache_endpoint(
    if not user or not user.get("is_superuser"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin privileges required"
            detail="Admin privileges required",
        )
    else:
        # For API key users, check admin permissions
@@ -198,9 +220,9 @@ async def invalidate_models_cache_endpoint(
        if not await auth_service.check_scope_permission(context, "admin.cache"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin permissions required to invalidate cache"
                detail="Admin permissions required to invalidate cache",
            )

    invalidate_models_cache()
    return {"message": "Models cache invalidated successfully"}

@@ -210,34 +232,38 @@ async def create_chat_completion(
    request_body: Request,
    chat_request: ChatCompletionRequest,
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Create a chat completion with budget enforcement"""
    try:
        auth_type = context.get("auth_type", "api_key")

        # Handle different authentication types
        if auth_type == "api_key":
            auth_service = APIKeyAuthService(db)

            # Check permissions
            if not await auth_service.check_scope_permission(context, "chat.completions"):
            if not await auth_service.check_scope_permission(
                context, "chat.completions"
            ):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Insufficient permissions for chat completions"
                    detail="Insufficient permissions for chat completions",
                )

            if not await auth_service.check_model_permission(context, chat_request.model):
            if not await auth_service.check_model_permission(
                context, chat_request.model
            ):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Model '{chat_request.model}' not allowed"
                    detail=f"Model '{chat_request.model}' not allowed",
                )

            api_key = context.get("api_key")
            if not api_key:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="API key information not available"
                    detail="API key information not available",
                )
        elif auth_type == "jwt":
            # For JWT authentication, we'll skip the detailed permission checks for now
@@ -246,15 +272,15 @@ async def create_chat_completion(
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="User information not available"
                    detail="User information not available",
                )
            api_key = None  # JWT users don't have API keys
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication type"
                detail="Invalid authentication type",
            )

        # Estimate token usage for budget checking
        messages_text = " ".join([msg.content for msg in chat_request.messages])
        estimated_tokens = len(messages_text.split()) * 1.3  # Rough token estimation
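The ~1.3 tokens-per-word heuristic above is only a pre-flight estimate for budget checks, not a real tokenizer count. A quick check of the arithmetic (hypothetical message):

# Rough pre-flight estimate, mirroring the heuristic above.
messages_text = "Summarize the quarterly report in three bullet points"
estimated_tokens = len(messages_text.split()) * 1.3  # 8 words -> ~10.4 tokens
estimated_tokens += 150                              # default completion allowance
print(int(estimated_tokens))                         # -> 160
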
@@ -262,31 +288,44 @@ async def create_chat_completion(
            estimated_tokens += chat_request.max_tokens
        else:
            estimated_tokens += 150  # Default response length estimate

        # Get a synchronous session for budget enforcement
        from app.db.database import SessionLocal

        sync_db = SessionLocal()

        try:
            # Atomic budget check and reservation (only for API key users)
            warnings = []
            reserved_budget_ids = []
            if auth_type == "api_key" and api_key:
                is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget(
                    sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions"
                (
                    is_allowed,
                    error_message,
                    budget_warnings,
                    budget_ids,
                ) = atomic_check_and_reserve_budget(
                    sync_db,
                    api_key,
                    chat_request.model,
                    int(estimated_tokens),
                    "chat/completions",
                )

                if not is_allowed:
                    raise HTTPException(
                        status_code=status.HTTP_402_PAYMENT_REQUIRED,
                        detail=f"Budget exceeded: {error_message}"
                        detail=f"Budget exceeded: {error_message}",
                    )
                warnings = budget_warnings
                reserved_budget_ids = budget_ids

            # Convert messages to LLM service format
            llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
            llm_messages = [
                LLMChatMessage(role=msg.role, content=msg.content)
                for msg in chat_request.messages
            ]

            # Create LLM service request
            llm_request = ChatRequest(
                model=chat_request.model,
@@ -299,12 +338,14 @@ async def create_chat_completion(
                stop=chat_request.stop,
                stream=chat_request.stream or False,
                user_id=str(context.get("user_id", "anonymous")),
                api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
                api_key_id=context.get("api_key_id", 0)
                if auth_type == "api_key"
                else 0,
            )

            # Make request to LLM service
            llm_response = await llm_service.create_chat_completion(llm_request)

            # Convert LLM service response to API format
            response = {
                "id": llm_response.id,
@@ -316,45 +357,56 @@ async def create_chat_completion(
                        "index": choice.index,
                        "message": {
                            "role": choice.message.role,
                            "content": choice.message.content
                            "content": choice.message.content,
                        },
                        "finish_reason": choice.finish_reason
                        "finish_reason": choice.finish_reason,
                    }
                    for choice in llm_response.choices
                ],
                "usage": {
                    "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
                    "completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
                    "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
                } if llm_response.usage else {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                    "prompt_tokens": llm_response.usage.prompt_tokens
                    if llm_response.usage
                    else 0,
                    "completion_tokens": llm_response.usage.completion_tokens
                    if llm_response.usage
                    else 0,
                    "total_tokens": llm_response.usage.total_tokens
                    if llm_response.usage
                    else 0,
                }
                if llm_response.usage
                else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            }

            # Calculate actual cost and update usage
            usage = response.get("usage", {})
            input_tokens = usage.get("prompt_tokens", 0)
            output_tokens = usage.get("completion_tokens", 0)
            total_tokens = usage.get("total_tokens", input_tokens + output_tokens)

            # Calculate accurate cost
            actual_cost_cents = CostCalculator.calculate_cost_cents(
                chat_request.model, input_tokens, output_tokens
            )

            # Finalize actual usage in budgets (only for API key users)
            if auth_type == "api_key" and api_key:
                atomic_finalize_usage(
                    sync_db, reserved_budget_ids, api_key, chat_request.model,
                    input_tokens, output_tokens, "chat/completions"
                    sync_db,
                    reserved_budget_ids,
                    api_key,
                    chat_request.model,
                    input_tokens,
                    output_tokens,
                    "chat/completions",
                )

            # Update API key usage statistics
            auth_service = APIKeyAuthService(db)
            await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
            await auth_service.update_usage_stats(
                context, total_tokens, actual_cost_cents
            )

            # Set analytics data for middleware
            set_analytics_data(
                model=chat_request.model,
@@ -363,55 +415,55 @@ async def create_chat_completion(
                total_tokens=total_tokens,
                cost_cents=actual_cost_cents,
                budget_ids=reserved_budget_ids,
                budget_warnings=warnings
                budget_warnings=warnings,
            )

            # Add budget warnings to response if any
            if warnings:
                response["budget_warnings"] = warnings

            return response

        finally:
            sync_db.close()

    except HTTPException:
        raise
    except SecurityError as e:
        logger.warning(f"Security error in chat completion: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Security validation failed: {e.message}"
            detail=f"Security validation failed: {e.message}",
        )
    except ValidationError as e:
        logger.warning(f"Validation error in chat completion: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Request validation failed: {e.message}"
            detail=f"Request validation failed: {e.message}",
        )
    except ProviderError as e:
        logger.error(f"Provider error in chat completion: {e}")
        if "rate limit" in str(e).lower():
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded"
                detail="Rate limit exceeded",
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LLM service temporarily unavailable"
                detail="LLM service temporarily unavailable",
            )
    except LLMError as e:
        logger.error(f"LLM service error in chat completion: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="LLM service error"
            detail="LLM service error",
        )
    except Exception as e:
        logger.error(f"Unexpected error creating chat completion: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create chat completion"
            detail="Failed to create chat completion",
        )

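The endpoint above reserves an estimated spend before calling the provider and settles the actual spend afterwards. A self-contained toy of that reserve-then-finalize accounting (amounts in cents; the real helpers in app.services.budget_enforcement also persist to the database and emit warnings, which is assumed, not shown):

class ToyBudget:
    def __init__(self, limit_cents: int):
        self.limit = limit_cents
        self.used = 0
        self.reserved = 0

    def reserve(self, estimate_cents: int) -> bool:
        # Refuse if the reservation would push committed + pending past the limit.
        if self.used + self.reserved + estimate_cents > self.limit:
            return False
        self.reserved += estimate_cents
        return True

    def finalize(self, estimate_cents: int, actual_cents: int) -> None:
        # Release the reservation and record what was actually spent.
        self.reserved -= estimate_cents
        self.used += actual_cents

b = ToyBudget(limit_cents=1_000)
assert b.reserve(300)
b.finalize(300, 240)
assert (b.used, b.reserved) == (240, 0)
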
@@ -419,62 +471,62 @@ async def create_chat_completion(
async def create_embedding(
    request: EmbeddingRequest,
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Create an embedding with budget enforcement"""
    try:
        auth_service = APIKeyAuthService(db)

        # Check permissions
        if not await auth_service.check_scope_permission(context, "embeddings.create"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions for embeddings"
                detail="Insufficient permissions for embeddings",
            )

        if not await auth_service.check_model_permission(context, request.model):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Model '{request.model}' not allowed"
                detail=f"Model '{request.model}' not allowed",
            )

        api_key = context.get("api_key")
        if not api_key:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="API key information not available"
                detail="API key information not available",
            )

        # Estimate token usage for budget checking
        estimated_tokens = len(request.input.split()) * 1.3  # Rough token estimation

        # Convert AsyncSession to Session for budget enforcement
        sync_db = Session(bind=db.bind.sync_engine)

        try:
            # Check budget compliance before making the request
            is_allowed, error_message, warnings = check_budget_for_request(
                sync_db, api_key, request.model, int(estimated_tokens), "embeddings"
            )

            if not is_allowed:
                raise HTTPException(
                    status_code=status.HTTP_402_PAYMENT_REQUIRED,
                    detail=f"Budget exceeded: {error_message}"
                    detail=f"Budget exceeded: {error_message}",
                )

            # Create LLM service request
            llm_request = LLMEmbeddingRequest(
                model=request.model,
                input=request.input,
                encoding_format=request.encoding_format,
                user_id=str(context["user_id"]),
                api_key_id=context["api_key_id"]
                api_key_id=context["api_key_id"],
            )

            # Make request to LLM service
            llm_response = await llm_service.create_embedding(llm_request)

            # Convert LLM service response to API format
            response = {
                "object": llm_response.object,
@@ -482,139 +534,142 @@ async def create_embedding(
                    {
                        "object": emb.object,
                        "index": emb.index,
                        "embedding": emb.embedding
                        "embedding": emb.embedding,
                    }
                    for emb in llm_response.data
                ],
                "model": llm_response.model,
                "usage": {
                    "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
                    "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
                } if llm_response.usage else {
                    "prompt_tokens": int(estimated_tokens),
                    "total_tokens": int(estimated_tokens)
                    "prompt_tokens": llm_response.usage.prompt_tokens
                    if llm_response.usage
                    else 0,
                    "total_tokens": llm_response.usage.total_tokens
                    if llm_response.usage
                    else 0,
                }
                if llm_response.usage
                else {
                    "prompt_tokens": int(estimated_tokens),
                    "total_tokens": int(estimated_tokens),
                },
            }

            # Calculate actual cost and update usage
            usage = response.get("usage", {})
            total_tokens = usage.get("total_tokens", int(estimated_tokens))

            # Calculate accurate cost (embeddings typically use input tokens only)
            actual_cost_cents = CostCalculator.calculate_cost_cents(
                request.model, total_tokens, 0
            )

            # Record actual usage in budgets
            record_request_usage(
                sync_db, api_key, request.model, total_tokens, 0, "embeddings"
            )

            # Update API key usage statistics
            await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents)
            await auth_service.update_usage_stats(
                context, total_tokens, actual_cost_cents
            )

            # Add budget warnings to response if any
            if warnings:
                response["budget_warnings"] = warnings

            return response

        finally:
            sync_db.close()

    except HTTPException:
        raise
    except SecurityError as e:
        logger.warning(f"Security error in embedding: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Security validation failed: {e.message}"
            detail=f"Security validation failed: {e.message}",
        )
    except ValidationError as e:
        logger.warning(f"Validation error in embedding: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Request validation failed: {e.message}"
            detail=f"Request validation failed: {e.message}",
        )
    except ProviderError as e:
        logger.error(f"Provider error in embedding: {e}")
        if "rate limit" in str(e).lower():
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded"
                detail="Rate limit exceeded",
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LLM service temporarily unavailable"
                detail="LLM service temporarily unavailable",
            )
    except LLMError as e:
        logger.error(f"LLM service error in embedding: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="LLM service error"
            detail="LLM service error",
        )
    except Exception as e:
        logger.error(f"Unexpected error creating embedding: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create embedding"
            detail="Failed to create embedding",
        )


@router.get("/health")
async def llm_health_check(
    context: Dict[str, Any] = Depends(require_api_key)
):
async def llm_health_check(context: Dict[str, Any] = Depends(require_api_key)):
    """Health check for the LLM service"""
    try:
        health_summary = llm_service.get_health_summary()
        provider_status = await llm_service.get_provider_status()

        # Determine overall health
        overall_status = "healthy"
        if health_summary["service_status"] != "healthy":
            overall_status = "degraded"

        for provider, status in provider_status.items():
            if status.status == "unavailable":
                overall_status = "degraded"
                break

        return {
            "status": overall_status,
            "service": "LLM Service",
            "service_status": health_summary,
            "provider_status": {name: {
                "status": status.status,
                "latency_ms": status.latency_ms,
                "error_message": status.error_message
            } for name, status in provider_status.items()},
            "provider_status": {
                name: {
                    "status": status.status,
                    "latency_ms": status.latency_ms,
                    "error_message": status.error_message,
                }
                for name, status in provider_status.items()
            },
            "user_id": context["user_id"],
            "api_key_name": context["api_key_name"]
            "api_key_name": context["api_key_name"],
        }
    except Exception as e:
        logger.error(f"LLM health check error: {e}")
        return {
            "status": "unhealthy",
            "service": "LLM Service",
            "error": str(e)
        }
        return {"status": "unhealthy", "service": "LLM Service", "error": str(e)}


@router.get("/usage")
async def get_usage_stats(
    context: Dict[str, Any] = Depends(require_api_key)
):
async def get_usage_stats(context: Dict[str, Any] = Depends(require_api_key)):
    """Get usage statistics for the API key"""
    try:
        api_key = context.get("api_key")
        if not api_key:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="API key information not available"
                detail="API key information not available",
            )

        return {
            "api_key_id": api_key.id,
            "api_key_name": api_key.name,
@@ -622,24 +677,26 @@ async def get_usage_stats(
            "total_tokens": api_key.total_tokens,
            "total_cost_cents": api_key.total_cost,
            "created_at": api_key.created_at.isoformat(),
            "last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
            "last_used_at": api_key.last_used_at.isoformat()
            if api_key.last_used_at
            else None,
            "rate_limits": {
                "per_minute": api_key.rate_limit_per_minute,
                "per_hour": api_key.rate_limit_per_hour,
                "per_day": api_key.rate_limit_per_day
                "per_day": api_key.rate_limit_per_day,
            },
            "permissions": api_key.permissions,
            "scopes": api_key.scopes,
            "allowed_models": api_key.allowed_models
            "allowed_models": api_key.allowed_models,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting usage stats: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get usage statistics"
            detail="Failed to get usage statistics",
        )


@@ -647,51 +704,48 @@ async def get_usage_stats(
async def get_budget_status(
    request: Request,
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get current budget status and usage analytics"""
    try:
        auth_type = context.get("auth_type", "api_key")

        # Check permissions based on auth type
        if auth_type == "api_key":
            auth_service = APIKeyAuthService(db)
            if not await auth_service.check_scope_permission(context, "budget.read"):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Insufficient permissions to read budget information"
                    detail="Insufficient permissions to read budget information",
                )

            api_key = context.get("api_key")
            if not api_key:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="API key information not available"
                    detail="API key information not available",
                )

            # Convert AsyncSession to Session for budget enforcement
            sync_db = Session(bind=db.bind.sync_engine)

            try:
                budget_service = BudgetEnforcementService(sync_db)
                budget_status = budget_service.get_budget_status(api_key)

                return {
                    "object": "budget_status",
                    "data": budget_status
                }
                return {"object": "budget_status", "data": budget_status}
            finally:
                sync_db.close()

        elif auth_type == "jwt":
            # For JWT authentication, return user-level budget information
            user = context.get("user")
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="User information not available"
                    detail="User information not available",
                )

            # Return basic budget info for JWT users
            return {
                "object": "budget_status",
@@ -702,23 +756,23 @@ async def get_budget_status(
                    "projections": {
                        "daily_burn_rate": 0.0,
                        "projected_monthly": 0.0,
                        "days_remaining": 30
                    }
                }
                        "days_remaining": 30,
                    },
                },
            }
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication type"
                detail="Invalid authentication type",
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting budget status: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get budget status"
            detail="Failed to get budget status",
        )


@@ -726,7 +780,7 @@ async def get_budget_status(
@router.get("/metrics")
async def get_llm_metrics(
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get LLM service metrics (admin only)"""
    try:
@@ -735,9 +789,9 @@ async def get_llm_metrics(
        if not await auth_service.check_scope_permission(context, "admin.metrics"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin permissions required to view metrics"
                detail="Admin permissions required to view metrics",
            )

        metrics = llm_service.get_metrics()
        return {
            "object": "llm_metrics",
@@ -745,27 +799,27 @@ async def get_llm_metrics(
            "total_requests": metrics.total_requests,
            "successful_requests": metrics.successful_requests,
            "failed_requests": metrics.failed_requests,
            "average_latency_ms": metrics.average_latency_ms,
            "average_risk_score": metrics.average_risk_score,
            "provider_metrics": metrics.provider_metrics,
            "last_updated": metrics.last_updated.isoformat()
        }
            "last_updated": metrics.last_updated.isoformat(),
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting LLM metrics: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get LLM metrics"
            detail="Failed to get LLM metrics",
        )


@router.get("/providers/status")
async def get_provider_status(
    context: Dict[str, Any] = Depends(require_api_key),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get the status of all LLM providers"""
    try:
@@ -773,9 +827,9 @@ async def get_provider_status(
        if not await auth_service.check_scope_permission(context, "admin.status"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin permissions required to view provider status"
                detail="Admin permissions required to view provider status",
            )

        provider_status = await llm_service.get_provider_status()
        return {
            "object": "provider_status",
@@ -787,17 +841,17 @@ async def get_provider_status(
                    "success_rate": status.success_rate,
                    "last_check": status.last_check.isoformat(),
                    "error_message": status.error_message,
                    "models_available": status.models_available
                    "models_available": status.models_available,
                }
                for name, status in provider_status.items()
            }
        }
            },
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting provider status: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get provider status"
        )
            detail="Failed to get provider status",
        )

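For reference, a hedged sketch of the payload shape the /providers/status handler above returns; provider name and values are invented for illustration:

# Hypothetical response shape (values not taken from the codebase):
example = {
    "object": "provider_status",
    "data": {
        "openai": {
            "status": "healthy",
            "latency_ms": 420,
            "success_rate": 0.998,
            "last_check": "2025-01-30T12:00:00+00:00",
            "error_message": None,
            "models_available": 12,
        }
    },
}
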
@@ -13,7 +13,12 @@ from app.db.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
|
||||
from app.services.llm.exceptions import (
|
||||
LLMError,
|
||||
ProviderError,
|
||||
SecurityError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.api.v1.llm import get_cached_models # Reuse the caching logic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,14 +40,12 @@ async def list_models(
|
||||
logger.error(f"Failed to list models: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve models"
|
||||
detail="Failed to retrieve models",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/providers/status")
|
||||
async def get_provider_status(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def get_provider_status(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get status of all LLM providers for authenticated users
|
||||
"""
|
||||
@@ -58,23 +61,21 @@ async def get_provider_status(
|
||||
"success_rate": status.success_rate,
|
||||
"last_check": status.last_check.isoformat(),
|
||||
"error_message": status.error_message,
|
||||
"models_available": status.models_available
|
||||
"models_available": status.models_available,
|
||||
}
|
||||
for name, status in provider_status.items()
|
||||
}
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get provider status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve provider status"
|
||||
detail="Failed to retrieve provider status",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def health_check(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get LLM service health status for authenticated users
|
||||
"""
|
||||
@@ -83,39 +84,35 @@ async def health_check(
|
||||
return {
|
||||
"status": health["status"],
|
||||
"providers": health.get("providers", {}),
|
||||
"timestamp": health.get("timestamp")
|
||||
"timestamp": health.get("timestamp"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Health check failed"
|
||||
detail="Health check failed",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
async def get_metrics(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
Get LLM service metrics for authenticated users
|
||||
"""
|
||||
try:
|
||||
metrics = await llm_service.get_metrics()
|
||||
return {
|
||||
"object": "metrics",
|
||||
"data": metrics
|
||||
}
|
||||
return {"object": "metrics", "data": metrics}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve metrics"
|
||||
detail="Failed to retrieve metrics",
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""Request model for chat completions"""
|
||||
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
|
||||
@@ -128,7 +125,7 @@ class ChatCompletionRequest(BaseModel):
async def create_chat_completion(
request: ChatCompletionRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""
Create chat completion for authenticated frontend users.
@@ -137,7 +134,7 @@ async def create_chat_completion(
try:
# Get user ID from JWT token context
user_id = str(current_user.get("id", current_user.get("sub", "0")))

# Convert request to LLM service format
# For internal use, we use a special api_key_id of 0 to indicate JWT auth
chat_request = ChatRequest(
@@ -151,15 +148,17 @@ async def create_chat_completion(
top_p=request.top_p,
stream=request.stream,
user_id=user_id,
api_key_id=0 # Special value for JWT-authenticated requests
api_key_id=0, # Special value for JWT-authenticated requests
)

# Log the request for debugging
logger.info(f"Internal chat completion request from user {current_user.get('id')}: model={request.model}")

logger.info(
f"Internal chat completion request from user {current_user.get('id')}: model={request.model}"
)

# Process the request through the LLM service
response = await llm_service.create_chat_completion(chat_request)

# Format the response to match OpenAI's structure
formatted_response = {
"id": response.id,
@@ -171,36 +170,39 @@ async def create_chat_completion(
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
"content": choice.message.content,
},
"finish_reason": choice.finish_reason
"finish_reason": choice.finish_reason,
}
for choice in response.choices
],
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0
} if response.usage else None
"completion_tokens": response.usage.completion_tokens
if response.usage
else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
}
if response.usage
else None,
}

return formatted_response

except ValidationError as e:
logger.error(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid request: {str(e)}"
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid request: {str(e)}"
)
except LLMError as e:
logger.error(f"LLM service error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"LLM service error: {str(e)}"
detail=f"LLM service error: {str(e)}",
)
except Exception as e:
logger.error(f"Unexpected error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process chat completion"
)
detail="Failed to process chat completion",
)
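The handler above maps exception types to HTTP statuses: ValidationError becomes 400, LLMError becomes 503, and anything unexpected becomes 500. The same pattern condensed into a reusable sketch (names taken from this file):

try:
    response = await llm_service.create_chat_completion(chat_request)
except ValidationError as e:  # malformed client input
    raise HTTPException(status_code=400, detail=f"Invalid request: {e}")
except LLMError as e:  # upstream provider failure
    raise HTTPException(status_code=503, detail=f"LLM service error: {e}")
except Exception:  # never leak internals on unknown errors
    raise HTTPException(status_code=500, detail="Failed to process chat completion")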
@@ -15,17 +15,17 @@ router = APIRouter()
async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Get list of all discovered modules with their status (enabled and disabled)"""
log_api_request("list_modules", {})

# Get all discovered modules including disabled ones
all_modules = module_manager.list_all_modules()

modules = []
for module_info in all_modules:
# Convert module_info to API format with status field
name = module_info["name"]
is_loaded = module_info["loaded"] # Module is actually loaded in memory
is_enabled = module_info["enabled"] # Module is enabled in config

# Determine status based on enabled + loaded state
if is_enabled and is_loaded:
status = "running"
@@ -33,40 +33,43 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
status = "error" # Enabled but failed to load
else: # not is_enabled (regardless of loaded state)
status = "standby" # Disabled

api_module = {
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"initialized": is_loaded,
"enabled": is_enabled,
"status": status # Add status field for frontend compatibility
"status": status, # Add status field for frontend compatibility
}

# Get module statistics if available and module is loaded
if module_info["loaded"] and module_info["name"] in module_manager.modules:
module_instance = module_manager.modules[module_info["name"]]
if hasattr(module_instance, "get_stats"):
try:
import asyncio

if asyncio.iscoroutinefunction(module_instance.get_stats):
stats = await module_instance.get_stats()
else:
stats = module_instance.get_stats()
api_module["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
api_module["stats"] = (
stats.__dict__ if hasattr(stats, "__dict__") else stats
)
except:
api_module["stats"] = {}

modules.append(api_module)

# Calculate stats
loaded_count = sum(1 for m in modules if m["initialized"] and m["enabled"])

return {
"total": len(modules),
"modules": modules,
"module_count": loaded_count,
"initialized": module_manager.initialized
"initialized": module_manager.initialized,
}
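The enabled/loaded pair collapses to a status string here and again in the consolidated status endpoint below. Extracted as a pure function, the mapping is (an illustrative sketch, not part of the diff):

def module_status(enabled: bool, loaded: bool) -> str:
    if enabled and loaded:
        return "running"  # enabled and successfully loaded
    if enabled:
        return "error"  # enabled but failed to load
    return "standby"  # disabled, regardless of loaded state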
@@ -74,20 +77,20 @@ async def list_modules(current_user: Dict[str, Any] = Depends(get_current_user))
async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Get comprehensive module status - CONSOLIDATED endpoint"""
log_api_request("get_modules_status", {})

# Get all discovered modules including disabled ones
all_modules = module_manager.list_all_modules()

modules_with_status = []
running_count = 0
standby_count = 0
failed_count = 0

for module_info in all_modules:
name = module_info["name"]
is_loaded = module_info["loaded"] # Module is actually loaded in memory
is_enabled = module_info["enabled"] # Module is enabled in config

# Determine status based on enabled + loaded state
if is_enabled and is_loaded:
status = "running"
@@ -98,7 +101,7 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
else: # not is_enabled (regardless of loaded state)
status = "standby" # Disabled
standby_count += 1

# Get module statistics if available and loaded
stats = {}
if is_loaded and name in module_manager.modules:
@@ -106,56 +109,68 @@ async def get_modules_status(current_user: Dict[str, Any] = Depends(get_current_
if hasattr(module_instance, "get_stats"):
try:
import asyncio

if asyncio.iscoroutinefunction(module_instance.get_stats):
stats_result = await module_instance.get_stats()
else:
stats_result = module_instance.get_stats()
stats = stats_result.__dict__ if hasattr(stats_result, "__dict__") else stats_result
stats = (
stats_result.__dict__
if hasattr(stats_result, "__dict__")
else stats_result
)
except:
stats = {}

modules_with_status.append({
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"status": status,
"enabled": is_enabled,
"loaded": is_loaded,
"stats": stats
})

modules_with_status.append(
{
"name": name,
"version": module_info["version"],
"description": module_info["description"],
"status": status,
"enabled": is_enabled,
"loaded": is_loaded,
"stats": stats,
}
)

return {
"modules": modules_with_status,
"total": len(modules_with_status),
"running": running_count,
"standby": standby_count,
"failed": failed_count,
"system_initialized": module_manager.initialized
"system_initialized": module_manager.initialized,
}
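Both listings repeat the same coroutine check around get_stats. The pattern generalizes to a small helper (illustrative; maybe_await is a name invented here, not defined in this codebase):

import asyncio
from typing import Any, Callable

async def maybe_await(fn: Callable[[], Any]) -> Any:
    # Await the call only when fn is declared as a coroutine function.
    if asyncio.iscoroutinefunction(fn):
        return await fn()
    return fn()

Each call site would then shrink to stats = await maybe_await(module.get_stats).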
@router.get("/{module_name}")
|
||||
async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_module_info(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed information about a specific module"""
|
||||
log_api_request("get_module_info", {"module_name": module_name})
|
||||
|
||||
|
||||
if module_name not in module_manager.modules:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
|
||||
|
||||
module = module_manager.modules[module_name]
|
||||
module_info = {
|
||||
"name": module_name,
|
||||
"version": getattr(module, "version", "1.0.0"),
|
||||
"description": getattr(module, "description", ""),
|
||||
"initialized": getattr(module, "initialized", False),
|
||||
"enabled": module_manager.module_configs.get(module_name, ModuleConfig(module_name)).enabled,
|
||||
"capabilities": []
|
||||
"enabled": module_manager.module_configs.get(
|
||||
module_name, ModuleConfig(module_name)
|
||||
).enabled,
|
||||
"capabilities": [],
|
||||
}
|
||||
|
||||
|
||||
# Get module capabilities
|
||||
if hasattr(module, "get_module_info"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module.get_module_info):
|
||||
info = await module.get_module_info()
|
||||
else:
|
||||
@@ -163,19 +178,22 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
|
||||
module_info.update(info)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Get module statistics
|
||||
if hasattr(module, "get_stats"):
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
if asyncio.iscoroutinefunction(module.get_stats):
|
||||
stats = await module.get_stats()
|
||||
else:
|
||||
stats = module.get_stats()
|
||||
module_info["stats"] = stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
module_info["stats"] = (
|
||||
stats.__dict__ if hasattr(stats, "__dict__") else stats
|
||||
)
|
||||
except:
|
||||
module_info["stats"] = {}
|
||||
|
||||
|
||||
# List available methods
|
||||
methods = []
|
||||
for attr_name in dir(module):
|
||||
@@ -183,57 +201,64 @@ async def get_module_info(module_name: str, current_user: Dict[str, Any] = Depen
|
||||
if callable(attr) and not attr_name.startswith("_"):
|
||||
methods.append(attr_name)
|
||||
module_info["methods"] = methods
|
||||
|
||||
|
||||
return module_info
|
||||
|
||||
|
||||
@router.post("/{module_name}/enable")
|
||||
async def enable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def enable_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Enable a module"""
|
||||
log_api_request("enable_module", {"module_name": module_name})
|
||||
|
||||
|
||||
if module_name not in module_manager.module_configs:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
|
||||
|
||||
# Enable the module in config
|
||||
config = module_manager.module_configs[module_name]
|
||||
config.enabled = True
|
||||
|
||||
|
||||
# Load the module if not already loaded
|
||||
if module_name not in module_manager.modules:
|
||||
try:
|
||||
await module_manager._load_module(module_name, config)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to enable module '{module_name}': {str(e)}")
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' enabled successfully",
|
||||
"enabled": True
|
||||
}
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to enable module '{module_name}': {str(e)}",
|
||||
)
|
||||
|
||||
return {"message": f"Module '{module_name}' enabled successfully", "enabled": True}
|
||||
|
||||
|
||||
@router.post("/{module_name}/disable")
|
||||
async def disable_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def disable_module(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Disable a module"""
|
||||
log_api_request("disable_module", {"module_name": module_name})
|
||||
|
||||
|
||||
if module_name not in module_manager.module_configs:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
|
||||
|
||||
# Disable the module in config
|
||||
config = module_manager.module_configs[module_name]
|
||||
config.enabled = False
|
||||
|
||||
|
||||
# Unload the module if loaded
|
||||
if module_name in module_manager.modules:
|
||||
try:
|
||||
await module_manager.unload_module(module_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to disable module '{module_name}': {str(e)}")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to disable module '{module_name}': {str(e)}",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Module '{module_name}' disabled successfully",
|
||||
"enabled": False
|
||||
"enabled": False,
|
||||
}
|
||||
|
||||
|
||||
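Driving the enable/disable pair from a client could look roughly like this; the host, router prefix, and token are assumptions, not values from this diff:

import httpx

def set_module_enabled(name: str, enabled: bool, token: str) -> dict:
    action = "enable" if enabled else "disable"
    resp = httpx.post(
        f"http://localhost:8000/api/v1/modules/{name}/{action}",
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()  # e.g. {"message": "...", "enabled": enabled}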
@@ -241,10 +266,10 @@ async def disable_module(module_name: str, current_user: Dict[str, Any] = Depend
async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_user)):
"""Reload all modules"""
log_api_request("reload_all_modules", {})

results = {}
failed_modules = []

for module_name in list(module_manager.modules.keys()):
try:
success = await module_manager.reload_module(module_name)
@@ -254,272 +279,316 @@ async def reload_all_modules(current_user: Dict[str, Any] = Depends(get_current_
except Exception as e:
results[module_name] = {"success": False, "error": str(e)}
failed_modules.append(module_name)

if failed_modules:
return {
"message": f"Reloaded {len(results) - len(failed_modules)}/{len(results)} modules successfully",
"success": False,
"results": results,
"failed_modules": failed_modules
"failed_modules": failed_modules,
}
else:
return {
"message": f"All {len(results)} modules reloaded successfully",
"success": True,
"results": results,
"failed_modules": []
"failed_modules": [],
}

@router.post("/{module_name}/reload")
async def reload_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def reload_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Reload a specific module"""
log_api_request("reload_module", {"module_name": module_name})

if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

success = await module_manager.reload_module(module_name)

if not success:
raise HTTPException(status_code=500, detail=f"Failed to reload module '{module_name}'")

raise HTTPException(
status_code=500, detail=f"Failed to reload module '{module_name}'"
)

return {
"message": f"Module '{module_name}' reloaded successfully",
"reloaded": True
"reloaded": True,
}

@router.post("/{module_name}/restart")
async def restart_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def restart_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Restart a specific module (alias for reload)"""
log_api_request("restart_module", {"module_name": module_name})

if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

success = await module_manager.reload_module(module_name)

if not success:
raise HTTPException(status_code=500, detail=f"Failed to restart module '{module_name}'")

raise HTTPException(
status_code=500, detail=f"Failed to restart module '{module_name}'"
)

return {
"message": f"Module '{module_name}' restarted successfully",
"restarted": True
"restarted": True,
}

@router.post("/{module_name}/start")
async def start_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def start_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Start a specific module (enable and load)"""
log_api_request("start_module", {"module_name": module_name})

if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

# Enable the module
config = module_manager.module_configs[module_name]
config.enabled = True

# Load the module if not already loaded
if module_name not in module_manager.modules:
await module_manager._load_module(module_name, config)

return {
"message": f"Module '{module_name}' started successfully",
"started": True
}

return {"message": f"Module '{module_name}' started successfully", "started": True}

@router.post("/{module_name}/stop")
async def stop_module(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def stop_module(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Stop a specific module (disable and unload)"""
log_api_request("stop_module", {"module_name": module_name})

if module_name not in module_manager.module_configs:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

# Disable the module
config = module_manager.module_configs[module_name]
config.enabled = False

# Unload the module if loaded
if module_name in module_manager.modules:
await module_manager.unload_module(module_name)

return {
"message": f"Module '{module_name}' stopped successfully",
"stopped": True
}

return {"message": f"Module '{module_name}' stopped successfully", "stopped": True}

@router.get("/{module_name}/stats")
async def get_module_stats(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_module_stats(
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get module statistics"""
log_api_request("get_module_stats", {"module_name": module_name})

if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

module = module_manager.modules[module_name]

if not hasattr(module, "get_stats"):
raise HTTPException(status_code=404, detail=f"Module '{module_name}' does not provide statistics")

raise HTTPException(
status_code=404,
detail=f"Module '{module_name}' does not provide statistics",
)

try:
import asyncio

if asyncio.iscoroutinefunction(module.get_stats):
stats = await module.get_stats()
else:
stats = module.get_stats()
return {
"module": module_name,
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats
"stats": stats.__dict__ if hasattr(stats, "__dict__") else stats,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to get statistics: {str(e)}"
)

@router.post("/{module_name}/execute")
async def execute_module_action(module_name: str, request_data: Dict[str, Any], current_user: Dict[str, Any] = Depends(get_current_user)):
async def execute_module_action(
module_name: str,
request_data: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Execute a module action through the interceptor pattern"""
log_api_request("execute_module_action", {"module_name": module_name, "action": request_data.get("action")})

log_api_request(
"execute_module_action",
{"module_name": module_name, "action": request_data.get("action")},
)

if module_name not in module_manager.modules:
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")

module = module_manager.modules[module_name]

# Check if module supports the new interceptor pattern
if hasattr(module, 'execute_with_interceptors'):
if hasattr(module, "execute_with_interceptors"):
try:
# Prepare context (would normally come from authentication middleware)
context = {
"user_id": "test_user", # Would come from authentication
"api_key_id": "test_api_key", # Would come from API key auth
"ip_address": "127.0.0.1", # Would come from request
"user_permissions": [f"modules:{module_name}:*"] # Would come from user/API key permissions
"user_permissions": [
f"modules:{module_name}:*"
], # Would come from user/API key permissions
}

# Execute through interceptor chain
response = await module.execute_with_interceptors(request_data, context)

return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": True
"interceptor_pattern": True,
}

except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")

raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)

# Fallback for legacy modules
else:
action = request_data.get("action", "execute")

if hasattr(module, action):
try:
method = getattr(module, action)
if callable(method):
import asyncio

if asyncio.iscoroutinefunction(method):
response = await method(request_data)
else:
response = method(request_data)

return {
"module": module_name,
"success": True,
"response": response,
"interceptor_pattern": False
"interceptor_pattern": False,
}
else:
raise HTTPException(status_code=400, detail=f"'{action}' is not callable")
raise HTTPException(
status_code=400, detail=f"'{action}' is not callable"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Module execution failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Module execution failed: {str(e)}"
)
else:
raise HTTPException(status_code=400, detail=f"Action '{action}' not supported by module '{module_name}'")
raise HTTPException(
status_code=400,
detail=f"Action '{action}' not supported by module '{module_name}'",
)
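The execute endpoint prefers the interceptor chain and falls back to direct attribute dispatch only for legacy modules. The control flow in miniature (a condensed sketch of the code above, assuming asyncio is imported):

if hasattr(module, "execute_with_interceptors"):
    # New-style modules: the chain applies auth, limits, and auditing.
    response = await module.execute_with_interceptors(request_data, context)
else:
    # Legacy modules: call the named action directly.
    method = getattr(module, request_data.get("action", "execute"))
    if asyncio.iscoroutinefunction(method):
        response = await method(request_data)
    else:
        response = method(request_data)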
@router.get("/{module_name}/config")
|
||||
async def get_module_config(module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def get_module_config(
|
||||
module_name: str, current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Get module configuration schema and current values"""
|
||||
log_api_request("get_module_config", {"module_name": module_name})
|
||||
|
||||
|
||||
from app.services.module_config_manager import module_config_manager
|
||||
from app.services.llm.service import llm_service
|
||||
import copy
|
||||
|
||||
|
||||
# Get module manifest and schema
|
||||
manifest = module_config_manager.get_module_manifest(module_name)
|
||||
if not manifest:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
|
||||
|
||||
schema = module_config_manager.get_module_schema(module_name)
|
||||
current_config = module_config_manager.get_module_config(module_name)
|
||||
|
||||
|
||||
# For Signal module, populate model options dynamically
|
||||
if module_name == "signal" and schema:
|
||||
try:
|
||||
# Get available models from LLM service
|
||||
models_data = await llm_service.get_models()
|
||||
model_ids = [model.id for model in models_data]
|
||||
|
||||
|
||||
if model_ids:
|
||||
# Create a copy of the schema to avoid modifying the original
|
||||
dynamic_schema = copy.deepcopy(schema)
|
||||
|
||||
|
||||
# Add enum options for the model field
|
||||
if "properties" in dynamic_schema and "model" in dynamic_schema["properties"]:
|
||||
if (
|
||||
"properties" in dynamic_schema
|
||||
and "model" in dynamic_schema["properties"]
|
||||
):
|
||||
dynamic_schema["properties"]["model"]["enum"] = model_ids
|
||||
# Set a sensible default if the current default isn't in the list
|
||||
current_default = dynamic_schema["properties"]["model"].get("default", "gpt-3.5-turbo")
|
||||
current_default = dynamic_schema["properties"]["model"].get(
|
||||
"default", "gpt-3.5-turbo"
|
||||
)
|
||||
if current_default not in model_ids and model_ids:
|
||||
dynamic_schema["properties"]["model"]["default"] = model_ids[0]
|
||||
|
||||
|
||||
schema = dynamic_schema
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# If we can't get models, log warning but continue with original schema
|
||||
logger.warning(f"Failed to get dynamic models for Signal config: {e}")
|
||||
|
||||
|
||||
return {
|
||||
"module": module_name,
|
||||
"description": manifest.description,
|
||||
"schema": schema,
|
||||
"current_config": current_config,
|
||||
"has_schema": schema is not None
|
||||
"has_schema": schema is not None,
|
||||
}
|
||||
|
||||
|
||||
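The dynamic-model handling deep-copies the schema before mutating it, so the cached original stays pristine. The same idea in isolation (an illustrative helper; with_model_enum is a name invented here):

import copy

def with_model_enum(schema: dict, model_ids: list) -> dict:
    patched = copy.deepcopy(schema)  # never mutate the shared schema object
    model_field = patched.get("properties", {}).get("model")
    if model_field is not None and model_ids:
        model_field["enum"] = model_ids
        if model_field.get("default") not in model_ids:
            model_field["default"] = model_ids[0]  # fall back to the first model
    return patched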
@router.post("/{module_name}/config")
|
||||
async def update_module_config(module_name: str, config: dict, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def update_module_config(
|
||||
module_name: str,
|
||||
config: dict,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update module configuration"""
|
||||
log_api_request("update_module_config", {"module_name": module_name})
|
||||
|
||||
|
||||
from app.services.module_config_manager import module_config_manager
|
||||
|
||||
|
||||
# Validate module exists
|
||||
manifest = module_config_manager.get_module_manifest(module_name)
|
||||
if not manifest:
|
||||
raise HTTPException(status_code=404, detail=f"Module '{module_name}' not found")
|
||||
|
||||
|
||||
try:
|
||||
# Save configuration
|
||||
success = await module_config_manager.save_module_config(module_name, config)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to save configuration")
|
||||
|
||||
|
||||
# Update module manager with new config
|
||||
success = await module_manager.update_module_config(module_name, config)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to apply configuration")
|
||||
|
||||
|
||||
return {
|
||||
"message": f"Configuration updated for module '{module_name}'",
|
||||
"config": config
|
||||
"config": config,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -12,9 +12,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db
|
||||
from app.services.api_key_auth import require_api_key
|
||||
from app.api.v1.llm import (
|
||||
get_cached_models, ModelsResponse, ModelInfo,
|
||||
ChatCompletionRequest, EmbeddingRequest, create_chat_completion as llm_chat_completion,
|
||||
create_embedding as llm_create_embedding
|
||||
get_cached_models,
|
||||
ModelsResponse,
|
||||
ModelInfo,
|
||||
ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
create_chat_completion as llm_chat_completion,
|
||||
create_embedding as llm_create_embedding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,8 +26,12 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def openai_error_response(message: str, error_type: str = "invalid_request_error",
|
||||
status_code: int = 400, code: str = None):
|
||||
def openai_error_response(
|
||||
message: str,
|
||||
error_type: str = "invalid_request_error",
|
||||
status_code: int = 400,
|
||||
code: str = None,
|
||||
):
|
||||
"""Create OpenAI-compatible error response"""
|
||||
error_data = {
|
||||
"error": {
|
||||
@@ -33,52 +41,42 @@ def openai_error_response(message: str, error_type: str = "invalid_request_error
|
||||
}
|
||||
if code:
|
||||
error_data["error"]["code"] = code
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=error_data
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=status_code, content=error_data)
|
||||
|
||||
|
||||
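With its defaults, openai_error_response produces a JSONResponse whose body follows the OpenAI error envelope; the exact field set inside "error" is partly elided in this hunk, so treat the shape below as illustrative:

{
    "error": {
        "message": "Internal server error",
        "type": "api_error"
    }
}

and a "code" key is attached only when one is passed in.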
@router.get("/models", response_model=ModelsResponse)
|
||||
async def list_models(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Lists the currently available models, and provides basic information about each one
|
||||
Lists the currently available models, and provides basic information about each one
|
||||
such as the owner and availability.
|
||||
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
GET /v1/models
|
||||
"""
|
||||
try:
|
||||
# Delegate to the existing LLM models endpoint
|
||||
from app.api.v1.llm import list_models as llm_list_models
|
||||
|
||||
return await llm_list_models(context, db)
|
||||
except HTTPException as e:
|
||||
# Convert FastAPI HTTPException to OpenAI format
|
||||
if e.status_code == 401:
|
||||
return openai_error_response(
|
||||
"Invalid authentication credentials",
|
||||
"authentication_error",
|
||||
401
|
||||
"Invalid authentication credentials", "authentication_error", 401
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
return openai_error_response(
|
||||
"Insufficient permissions",
|
||||
"permission_error",
|
||||
403
|
||||
"Insufficient permissions", "permission_error", 403
|
||||
)
|
||||
else:
|
||||
return openai_error_response(str(e.detail), "api_error", e.status_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in OpenAI models endpoint: {e}")
|
||||
return openai_error_response(
|
||||
"Internal server error",
|
||||
"api_error",
|
||||
500
|
||||
)
|
||||
return openai_error_response("Internal server error", "api_error", 500)
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
@@ -86,11 +84,11 @@ async def create_chat_completion(
|
||||
request_body: Request,
|
||||
chat_request: ChatCompletionRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create chat completion - OpenAI compatible endpoint
|
||||
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
POST /v1/chat/completions
|
||||
"""
|
||||
@@ -102,11 +100,11 @@ async def create_chat_completion(
|
||||
async def create_embedding(
|
||||
request: EmbeddingRequest,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create embedding - OpenAI compatible endpoint
|
||||
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
POST /v1/embeddings
|
||||
"""
|
||||
@@ -118,44 +116,46 @@ async def create_embedding(
|
||||
async def retrieve_model(
|
||||
model_id: str,
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve model information - OpenAI compatible endpoint
|
||||
|
||||
|
||||
This endpoint follows the exact OpenAI API specification:
|
||||
GET /v1/models/{model}
|
||||
"""
|
||||
try:
|
||||
# Get all models and find the specific one
|
||||
models = await get_cached_models()
|
||||
|
||||
|
||||
# Filter models based on API key permissions
|
||||
api_key = context.get("api_key")
|
||||
if api_key and api_key.allowed_models:
|
||||
models = [model for model in models if model.get("id") in api_key.allowed_models]
|
||||
|
||||
models = [
|
||||
model for model in models if model.get("id") in api_key.allowed_models
|
||||
]
|
||||
|
||||
# Find the specific model
|
||||
model = next((m for m in models if m.get("id") == model_id), None)
|
||||
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model '{model_id}' not found"
|
||||
detail=f"Model '{model_id}' not found",
|
||||
)
|
||||
|
||||
|
||||
return ModelInfo(
|
||||
id=model.get("id", model_id),
|
||||
object="model",
|
||||
created=model.get("created", 0),
|
||||
owned_by=model.get("owned_by", "system")
|
||||
owned_by=model.get("owned_by", "system"),
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving model {model_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model information"
|
||||
)
|
||||
detail="Failed to retrieve model information",
|
||||
)
|
||||
|
||||
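Because these routes mirror the OpenAI specification, the stock OpenAI Python client can talk to them once pointed at this server; the base_url and api_key below are deployment-specific placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="YOUR_API_KEY")
model = client.models.retrieve("gpt-3.5-turbo")  # GET /v1/models/{model}
print(model.id, model.owned_by)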
@@ -7,7 +7,11 @@ from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel

from app.services.permission_manager import permission_registry, Permission, PermissionScope
from app.services.permission_manager import (
permission_registry,
Permission,
PermissionScope,
)
from app.core.logging import get_logger
from app.core.security import get_current_user

@@ -77,7 +81,7 @@ async def get_available_permissions(namespace: Optional[str] = None):
"""Get all available permissions, optionally filtered by namespace"""
try:
permissions = permission_registry.get_available_permissions(namespace)

# Convert to response format
result = {}
for ns, perms in permissions.items():
@@ -86,18 +90,18 @@ async def get_available_permissions(namespace: Optional[str] = None):
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in perms
]

return result

except Exception as e:
logger.error(f"Error getting permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permissions: {str(e)}"
detail=f"Failed to get permissions: {str(e)}",
)

@@ -107,12 +111,12 @@ async def get_permission_hierarchy():
try:
hierarchy = permission_registry.get_permission_hierarchy()
return PermissionHierarchyResponse(hierarchy=hierarchy)

except Exception as e:
logger.error(f"Error getting permission hierarchy: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get permission hierarchy: {str(e)}"
detail=f"Failed to get permission hierarchy: {str(e)}",
)

@@ -120,44 +124,43 @@ async def get_permission_hierarchy():
async def validate_permissions(request: PermissionValidationRequest):
"""Validate a list of permissions"""
try:
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
return PermissionValidationResponse(**validation_result)

except Exception as e:
logger.error(f"Error validating permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate permissions: {str(e)}"
detail=f"Failed to validate permissions: {str(e)}",
)

@router.post("/permissions/check", response_model=PermissionCheckResponse)
async def check_permission(
request: PermissionCheckRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Check if user has a specific permission"""
try:
has_permission = permission_registry.check_permission(
request.user_permissions,
request.required_permission,
request.context
request.user_permissions, request.required_permission, request.context
)

matching_permissions = list(permission_registry.tree.get_matching_permissions(
request.user_permissions
))

matching_permissions = list(
permission_registry.tree.get_matching_permissions(request.user_permissions)
)

return PermissionCheckResponse(
has_permission=has_permission,
matching_permissions=matching_permissions
has_permission=has_permission, matching_permissions=matching_permissions
)

except Exception as e:
logger.error(f"Error checking permission: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check permission: {str(e)}"
detail=f"Failed to check permission: {str(e)}",
)
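A request/response pair for the check endpoint, using the request fields referenced above (the path depends on where this router is mounted, so the prefix is an assumption):

payload = {
    "user_permissions": ["modules:signal:*"],  # wildcard grants every action on the module
    "required_permission": "modules:signal:execute",
    "context": {},
}
# POST .../permissions/check -> {"has_permission": true, "matching_permissions": [...]}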
@@ -166,22 +169,22 @@ async def get_module_permissions(module_id: str):
"""Get permissions for a specific module"""
try:
permissions = permission_registry.get_module_permissions(module_id)

return [
PermissionResponse(
resource=perm.resource,
action=perm.action,
description=perm.description,
conditions=getattr(perm, 'conditions', None)
conditions=getattr(perm, "conditions", None),
)
for perm in permissions
]

except Exception as e:
logger.error(f"Error getting module permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get module permissions: {str(e)}"
detail=f"Failed to get module permissions: {str(e)}",
)

@@ -191,27 +194,28 @@ async def create_role(request: RoleRequest):
"""Create a custom role with specific permissions"""
try:
# Validate permissions first
validation_result = permission_registry.validate_permissions(request.permissions)
validation_result = permission_registry.validate_permissions(
request.permissions
)
if not validation_result["is_valid"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid permissions: {validation_result['invalid']}"
detail=f"Invalid permissions: {validation_result['invalid']}",
)

permission_registry.create_role(request.role_name, request.permissions)

return RoleResponse(
role_name=request.role_name,
permissions=request.permissions
role_name=request.role_name, permissions=request.permissions
)

except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create role: {str(e)}"
detail=f"Failed to create role: {str(e)}",
)

@@ -220,14 +224,17 @@ async def get_roles():
"""Get all available roles and their permissions"""
try:
# Combine default roles and custom roles
all_roles = {**permission_registry.default_roles, **permission_registry.role_permissions}
all_roles = {
**permission_registry.default_roles,
**permission_registry.role_permissions,
}
return all_roles

except Exception as e:
logger.error(f"Error getting roles: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get roles: {str(e)}"
detail=f"Failed to get roles: {str(e)}",
)

@@ -236,28 +243,25 @@ async def get_role(role_name: str):
"""Get a specific role and its permissions"""
try:
# Check default roles first, then custom roles
permissions = (permission_registry.role_permissions.get(role_name) or
permission_registry.default_roles.get(role_name))

permissions = permission_registry.role_permissions.get(
role_name
) or permission_registry.default_roles.get(role_name)

if permissions is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Role '{role_name}' not found"
detail=f"Role '{role_name}' not found",
)

return RoleResponse(
role_name=role_name,
permissions=permissions,
created=True
)

return RoleResponse(role_name=role_name, permissions=permissions, created=True)

except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting role: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get role: {str(e)}"
detail=f"Failed to get role: {str(e)}",
)

@@ -267,21 +271,20 @@ async def calculate_user_permissions(request: UserPermissionsRequest):
"""Calculate effective permissions for a user based on roles and custom permissions"""
try:
effective_permissions = permission_registry.get_user_permissions(
request.roles,
request.custom_permissions
request.roles, request.custom_permissions
)

return UserPermissionsResponse(
effective_permissions=effective_permissions,
roles=request.roles,
custom_permissions=request.custom_permissions or []
custom_permissions=request.custom_permissions or [],
)

except Exception as e:
logger.error(f"Error calculating user permissions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to calculate user permissions: {str(e)}"
detail=f"Failed to calculate user permissions: {str(e)}",
)

@@ -293,8 +296,10 @@ async def platform_health():
# Get permission system status
total_permissions = len(permission_registry.tree.permissions)
total_modules = len(permission_registry.module_permissions)
total_roles = len(permission_registry.default_roles) + len(permission_registry.role_permissions)

total_roles = len(permission_registry.default_roles) + len(
permission_registry.role_permissions
)

return {
"status": "healthy",
"service": "Confidential Empire Platform API",
@@ -302,16 +307,13 @@ async def platform_health():
"permission_system": {
"total_permissions": total_permissions,
"registered_modules": total_modules,
"available_roles": total_roles
}
"available_roles": total_roles,
},
}

except Exception as e:
logger.error(f"Error checking platform health: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
return {"status": "unhealthy", "error": str(e)}

@router.get("/metrics")
@@ -320,28 +322,29 @@ async def platform_metrics():
try:
# Get permission system metrics
namespaces = permission_registry.get_available_permissions()

metrics = {
"permissions": {
"total": len(permission_registry.tree.permissions),
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()}
"by_namespace": {ns: len(perms) for ns, perms in namespaces.items()},
},
"modules": {
"registered": len(permission_registry.module_permissions),
"names": list(permission_registry.module_permissions.keys())
"names": list(permission_registry.module_permissions.keys()),
},
"roles": {
"default": len(permission_registry.default_roles),
"custom": len(permission_registry.role_permissions),
"total": len(permission_registry.default_roles) + len(permission_registry.role_permissions)
}
"total": len(permission_registry.default_roles)
+ len(permission_registry.role_permissions),
},
}

return metrics

except Exception as e:
logger.error(f"Error getting platform metrics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get platform metrics: {str(e)}"
)
detail=f"Failed to get platform metrics: {str(e)}",
)
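Given the counters assembled above, a metrics response takes roughly this shape (values are illustrative):

{
    "permissions": {"total": 42, "by_namespace": {"modules": 30, "platform": 12}},
    "modules": {"registered": 2, "names": ["signal", "rag"]},
    "roles": {"default": 3, "custom": 1, "total": 4}
}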
@@ -46,79 +46,75 @@ async def discover_plugins(
category: str = "",
limit: int = 20,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Discover available plugins from repository"""
try:
tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None

tag_list = (
[tag.strip() for tag in tags.split(",") if tag.strip()] if tags else None
)

plugins = await plugin_discovery.search_available_plugins(
query=query,
tags=tag_list,
category=category if category else None,
limit=limit,
db=db
db=db,
)

return {
"plugins": plugins,
"count": len(plugins),
"query": query,
"filters": {
"tags": tag_list,
"category": category
}
"filters": {"tags": tag_list, "category": category},
}

except Exception as e:
logger.error(f"Plugin discovery failed: {e}")
raise HTTPException(status_code=500, detail=f"Discovery failed: {e}")

@router.get("/categories")
async def get_plugin_categories(current_user: Dict[str, Any] = Depends(get_current_user)):
async def get_plugin_categories(
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Get available plugin categories"""
try:
categories = await plugin_discovery.get_plugin_categories()
return {"categories": categories}

except Exception as e:
logger.error(f"Failed to get categories: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get categories: {e}")

@router.get("/installed")
async def get_installed_plugins(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Get user's installed plugins"""
try:
plugins = await plugin_discovery.get_installed_plugins(current_user["id"], db)
return {
"plugins": plugins,
"count": len(plugins)
}

return {"plugins": plugins, "count": len(plugins)}

except Exception as e:
logger.error(f"Failed to get installed plugins: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get installed plugins: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to get installed plugins: {e}"
)

@router.get("/updates")
async def check_plugin_updates(
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Check for available plugin updates"""
try:
updates = await plugin_discovery.get_plugin_updates(db)
return {
"updates": updates,
"count": len(updates)
}

return {"updates": updates, "count": len(updates)}

except Exception as e:
logger.error(f"Failed to check updates: {e}")
raise HTTPException(status_code=500, detail=f"Failed to check updates: {e}")
@@ -130,29 +126,32 @@ async def install_plugin(
request: PluginInstallRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from repository"""
try:
if request.source != "repository":
raise HTTPException(status_code=400, detail="Only repository installation supported via this endpoint")

raise HTTPException(
status_code=400,
detail="Only repository installation supported via this endpoint",
)

# Start installation in background
background_tasks.add_task(
install_plugin_background,
request.plugin_id,
request.version,
current_user["id"],
db
db,
)

return {
"status": "installation_started",
"plugin_id": request.plugin_id,
"version": request.version,
"message": "Plugin installation started in background"
"message": "Plugin installation started in background",
}

except Exception as e:
logger.error(f"Plugin installation failed: {e}")
raise HTTPException(status_code=500, detail=f"Installation failed: {e}")
@@ -163,38 +162,40 @@ async def install_plugin_from_file(
file: UploadFile = File(...),
background_tasks: BackgroundTasks = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
):
"""Install plugin from uploaded file"""
try:
# Validate file type
if not file.filename.endswith('.zip'):
if not file.filename.endswith(".zip"):
raise HTTPException(status_code=400, detail="Only ZIP files are supported")

# Save uploaded file
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_file:

with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name

try:
# Install plugin
result = await plugin_installer.install_plugin_from_file(
temp_file_path, current_user["id"], db
)

return {
"status": "installed",
"result": result,
"message": "Plugin installed successfully"
"message": "Plugin installed successfully",
}

finally:
# Cleanup temp file
import os

os.unlink(temp_file_path)

except Exception as e:
logger.error(f"File upload installation failed: {e}")
raise HTTPException(status_code=500, detail=f"Installation failed: {e}")
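The upload handler writes the archive with delete=False and removes it in a finally block: the file has to outlive the with statement because the installer reads it afterwards, yet it must also disappear even when installation raises. The pattern in isolation (handler is a hypothetical callable standing in for the installer):

import os
import tempfile

async def with_uploaded_zip(file, handler):
    # delete=False keeps the file on disk after the with block closes it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
        tmp.write(await file.read())
        path = tmp.name
    try:
        return await handler(path)
    finally:
        os.unlink(path)  # cleanup even when the handler raises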
@@ -205,20 +206,20 @@ async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
request: PluginUninstallRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Uninstall plugin"""
|
||||
try:
|
||||
result = await plugin_installer.uninstall_plugin(
|
||||
plugin_id, current_user["id"], db, request.keep_data
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "uninstalled",
|
||||
"result": result,
|
||||
"message": "Plugin uninstalled successfully"
|
||||
"message": "Plugin uninstalled successfully",
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Plugin uninstall failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Uninstall failed: {e}")
|
||||
@@ -229,28 +230,28 @@ async def uninstall_plugin(
|
||||
async def enable_plugin(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Enable plugin"""
|
||||
try:
|
||||
from app.models.plugin import Plugin
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
stmt = select(Plugin).where(Plugin.id == plugin_id)
|
||||
result = await db.execute(stmt)
|
||||
plugin = result.scalar_one_or_none()
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
|
||||
plugin.status = "enabled"
|
||||
await db.commit()
|
||||
|
||||
|
||||
return {
|
||||
"status": "enabled",
|
||||
"plugin_id": plugin_id,
|
||||
"message": "Plugin enabled successfully"
|
||||
"message": "Plugin enabled successfully",
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Plugin enable failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Enable failed: {e}")
|
||||
@@ -260,32 +261,32 @@ async def enable_plugin(

async def disable_plugin(
    plugin_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Disable plugin"""
    try:
        from app.models.plugin import Plugin
        from sqlalchemy import select

        stmt = select(Plugin).where(Plugin.id == plugin_id)
        result = await db.execute(stmt)
        plugin = result.scalar_one_or_none()
        if not plugin:
            raise HTTPException(status_code=404, detail="Plugin not found")

        # Unload if currently loaded
        if plugin_id in plugin_loader.loaded_plugins:
            await plugin_loader.unload_plugin(plugin_id)

        plugin.status = "disabled"
        await db.commit()

        return {
            "status": "disabled",
            "plugin_id": plugin_id,
-            "message": "Plugin disabled successfully"
+            "message": "Plugin disabled successfully",
        }

    except Exception as e:
        logger.error(f"Plugin disable failed: {e}")
        raise HTTPException(status_code=500, detail=f"Disable failed: {e}")
@@ -295,58 +296,62 @@ async def disable_plugin(

async def load_plugin(
    plugin_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Load plugin into runtime"""
    try:
        from app.models.plugin import Plugin
        from pathlib import Path
        from sqlalchemy import select

        stmt = select(Plugin).where(Plugin.id == plugin_id)
        result = await db.execute(stmt)
        plugin = result.scalar_one_or_none()
        if not plugin:
            raise HTTPException(status_code=404, detail="Plugin not found")

        if plugin.status != "enabled":
-            raise HTTPException(status_code=400, detail="Plugin must be enabled to load")
+            raise HTTPException(
+                status_code=400, detail="Plugin must be enabled to load"
+            )

        if plugin_id in plugin_loader.loaded_plugins:
            raise HTTPException(status_code=400, detail="Plugin already loaded")

        # Load plugin with proper context management
        plugin_dir = Path(plugin.plugin_dir)

        # Create plugin context for standardized interface
        plugin_context = plugin_context_manager.create_plugin_context(
            plugin_id=plugin_id,
            user_id=str(current_user.get("id", "unknown")),  # Use actual user ID
-            session_type="api_load"
+            session_type="api_load",
        )

        # Generate plugin token based on context
-        plugin_token = plugin_context_manager.generate_plugin_token(plugin_context["context_id"])
+        plugin_token = plugin_context_manager.generate_plugin_token(
+            plugin_context["context_id"]
+        )

        # Log plugin loading action
        plugin_context_manager.add_audit_trail_entry(
            plugin_context["context_id"],
            "plugin_load_via_api",
            {
                "plugin_dir": str(plugin_dir),
                "user_id": current_user.get("id", "unknown"),
-                "action": "load_plugin_with_sandbox"
-            }
+                "action": "load_plugin_with_sandbox",
+            },
        )

        await plugin_loader.load_plugin_with_sandbox(plugin_dir, plugin_token)

        return {
            "status": "loaded",
            "plugin_id": plugin_id,
-            "message": "Plugin loaded successfully"
+            "message": "Plugin loaded successfully",
        }

    except Exception as e:
        logger.error(f"Plugin load failed: {e}")
        raise HTTPException(status_code=500, detail=f"Load failed: {e}")
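Condensed, the load path in the hunk above is: look up the Plugin row, require status "enabled", build a per-load context, mint a token from that context, write an audit-trail entry, and only then hand the plugin directory and token to the sandboxed loader. A sketch of that sequence, with the context manager and loader passed in explicitly (they are assumed to expose exactly the calls the endpoint makes on the module-level plugin_context_manager and plugin_loader):

# Sketch of the load sequence used by the endpoint above.
from pathlib import Path
from typing import Any, Dict

async def load_enabled_plugin(plugin: Any, user_id: str, context_manager: Any, loader: Any) -> None:
    plugin_dir = Path(plugin.plugin_dir)
    # Per-load context, mirroring create_plugin_context(...) in the handler.
    context: Dict[str, Any] = context_manager.create_plugin_context(
        plugin_id=plugin.id,
        user_id=user_id,
        session_type="api_load",
    )
    token = context_manager.generate_plugin_token(context["context_id"])
    context_manager.add_audit_trail_entry(
        context["context_id"],
        "plugin_load_via_api",
        {
            "plugin_dir": str(plugin_dir),
            "user_id": user_id,
            "action": "load_plugin_with_sandbox",
        },
    )
    await loader.load_plugin_with_sandbox(plugin_dir, token)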
@@ -354,24 +359,23 @@ async def load_plugin(

@router.post("/{plugin_id}/unload")
async def unload_plugin(
-    plugin_id: str,
-    current_user: Dict[str, Any] = Depends(get_current_user)
+    plugin_id: str, current_user: Dict[str, Any] = Depends(get_current_user)
):
    """Unload plugin from runtime"""
    try:
        if plugin_id not in plugin_loader.loaded_plugins:
            raise HTTPException(status_code=404, detail="Plugin not loaded")

        success = await plugin_loader.unload_plugin(plugin_id)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to unload plugin")

        return {
            "status": "unloaded",
            "plugin_id": plugin_id,
-            "message": "Plugin unloaded successfully"
+            "message": "Plugin unloaded successfully",
        }

    except Exception as e:
        logger.error(f"Plugin unload failed: {e}")
        raise HTTPException(status_code=500, detail=f"Unload failed: {e}")
@@ -382,40 +386,38 @@ async def unload_plugin(

async def get_plugin_configuration(
    plugin_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Get plugin configuration for user with automatic decryption"""
    try:
        from app.services.plugin_configuration_manager import plugin_config_manager

        # Use the new configuration manager to get decrypted configuration
        config_data = await plugin_config_manager.get_plugin_configuration(
            plugin_id=plugin_id,
            user_id=current_user["id"],
            db=db,
-            decrypt_sensitive=False  # Don't decrypt sensitive data for API response
+            decrypt_sensitive=False,  # Don't decrypt sensitive data for API response
        )

        if config_data is not None:
            return {
                "plugin_id": plugin_id,
                "configuration": config_data,
-                "has_configuration": True
+                "has_configuration": True,
            }
        else:
            # Get default configuration from manifest
            resolved_config = await plugin_config_manager.get_resolved_configuration(
-                plugin_id=plugin_id,
-                user_id=current_user["id"],
-                db=db
+                plugin_id=plugin_id, user_id=current_user["id"], db=db
            )

            return {
                "plugin_id": plugin_id,
                "configuration": resolved_config,
-                "has_configuration": False
+                "has_configuration": False,
            }

    except Exception as e:
        logger.error(f"Failed to get plugin configuration: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get configuration: {e}")
@@ -426,17 +428,17 @@ async def save_plugin_configuration(
    plugin_id: str,
    config_request: Dict[str, Any],
    current_user: Dict[str, Any] = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Save plugin configuration for user with automatic encryption of sensitive fields"""
    try:
        from app.services.plugin_configuration_manager import plugin_config_manager

        # Extract configuration data and metadata
        config_data = config_request.get("configuration", {})
        config_name = config_request.get("name", "Default Configuration")
        config_description = config_request.get("description")

        # Use the new configuration manager to save with automatic encryption
        saved_config = await plugin_config_manager.save_plugin_configuration(
            plugin_id=plugin_id,
@@ -444,43 +446,47 @@ async def save_plugin_configuration(
            config_data=config_data,
            config_name=config_name,
            config_description=config_description,
-            db=db
+            db=db,
        )

        return {
            "status": "saved",
            "plugin_id": plugin_id,
            "configuration_id": str(saved_config.id),
-            "message": "Configuration saved successfully with automatic encryption of sensitive fields"
+            "message": "Configuration saved successfully with automatic encryption of sensitive fields",
        }

    except Exception as e:
        logger.error(f"Failed to save plugin configuration: {e}")
        await db.rollback()
-        raise HTTPException(status_code=500, detail=f"Failed to save configuration: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"Failed to save configuration: {e}"
+        )
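Both configuration endpoints above delegate encryption concerns to plugin_config_manager: reads hand back sensitive fields still encrypted (decrypt_sensitive=False), and writes encrypt them automatically before persisting. A hypothetical request body for the save handler, with key names taken from its config_request.get(...) calls (the route path itself is not visible in this diff, so only the shape is shown):

# Hypothetical payload for save_plugin_configuration; key names follow the
# handler's config_request.get(...) calls, all values are illustrative.
payload = {
    "name": "Default Configuration",           # stored as config_name
    "description": "Optional free-form text",  # stored as config_description
    "configuration": {
        "api_token": "secret-value",  # sensitive fields are encrypted on save
    },
}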
@router.get("/{plugin_id}/schema")
|
||||
async def get_plugin_configuration_schema(
|
||||
plugin_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get plugin configuration schema from manifest"""
|
||||
try:
|
||||
from app.services.plugin_configuration_manager import plugin_config_manager
|
||||
|
||||
|
||||
# Use the new configuration manager to get schema
|
||||
schema = await plugin_config_manager.get_plugin_configuration_schema(plugin_id, db)
|
||||
|
||||
schema = await plugin_config_manager.get_plugin_configuration_schema(
|
||||
plugin_id, db
|
||||
)
|
||||
|
||||
if not schema:
|
||||
raise HTTPException(status_code=404, detail=f"No configuration schema available for plugin '{plugin_id}'")
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No configuration schema available for plugin '{plugin_id}'",
|
||||
)
|
||||
|
||||
return {"plugin_id": plugin_id, "schema": schema}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -493,120 +499,129 @@ async def test_plugin_credentials(
|
||||
plugin_id: str,
|
||||
test_request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Test plugin credentials (currently supports Zammad)"""
|
||||
import httpx
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"Testing credentials for plugin {plugin_id}")
|
||||
|
||||
|
||||
# Get plugin from database to check its name
|
||||
from app.models.plugin import Plugin
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
stmt = select(Plugin).where(Plugin.id == plugin_id)
|
||||
result = await db.execute(stmt)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail=f"Plugin '{plugin_id}' not found")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Plugin '{plugin_id}' not found"
|
||||
)
|
||||
|
||||
# Check if this is a Zammad plugin
|
||||
if plugin.name.lower() != 'zammad':
|
||||
raise HTTPException(status_code=400, detail=f"Credential testing not supported for plugin '{plugin.name}'")
|
||||
|
||||
if plugin.name.lower() != "zammad":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Credential testing not supported for plugin '{plugin.name}'",
|
||||
)
|
||||
|
||||
# Extract credentials from request
|
||||
zammad_url = test_request.get('zammad_url')
|
||||
api_token = test_request.get('api_token')
|
||||
|
||||
zammad_url = test_request.get("zammad_url")
|
||||
api_token = test_request.get("api_token")
|
||||
|
||||
if not zammad_url or not api_token:
|
||||
raise HTTPException(status_code=400, detail="Both zammad_url and api_token are required")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Both zammad_url and api_token are required"
|
||||
)
|
||||
|
||||
# Clean up the URL (remove trailing slash)
|
||||
zammad_url = zammad_url.rstrip('/')
|
||||
|
||||
zammad_url = zammad_url.rstrip("/")
|
||||
|
||||
# Test credentials by making a read-only API call to Zammad
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try to get user info - this is a safe read-only operation
|
||||
test_url = f"{zammad_url}/api/v1/users/me"
|
||||
headers = {
|
||||
'Authorization': f'Token token={api_token}',
|
||||
'Content-Type': 'application/json'
|
||||
"Authorization": f"Token token={api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
response = await client.get(test_url, headers=headers)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
# Success - credentials are valid
|
||||
user_data = response.json()
|
||||
user_email = user_data.get('email', 'unknown')
|
||||
user_email = user_data.get("email", "unknown")
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Credentials verified! Connected as: {user_email}",
|
||||
"zammad_url": zammad_url,
|
||||
"user_info": {
|
||||
"email": user_email,
|
||||
"firstname": user_data.get('firstname', ''),
|
||||
"lastname": user_data.get('lastname', '')
|
||||
}
|
||||
"firstname": user_data.get("firstname", ""),
|
||||
"lastname": user_data.get("lastname", ""),
|
||||
},
|
||||
}
|
||||
elif response.status_code == 401:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Invalid API token. Please check your token and try again.",
|
||||
"error_code": "invalid_token"
|
||||
"error_code": "invalid_token",
|
||||
}
|
||||
elif response.status_code == 404:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Zammad URL not found. Please verify the URL is correct.",
|
||||
"error_code": "invalid_url"
|
||||
"error_code": "invalid_url",
|
||||
}
|
||||
else:
|
||||
error_text = ""
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_text = error_data.get('error', error_data.get('message', ''))
|
||||
error_text = error_data.get("error", error_data.get("message", ""))
|
||||
except:
|
||||
error_text = response.text[:200]
|
||||
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed (HTTP {response.status_code}): {error_text}",
|
||||
"error_code": "connection_failed"
|
||||
"error_code": "connection_failed",
|
||||
}
|
||||
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Connection timeout. Please check the Zammad URL and your network connection.",
|
||||
"error_code": "timeout"
|
||||
"error_code": "timeout",
|
||||
}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Could not connect to Zammad. Please verify the URL is correct and accessible.",
|
||||
"error_code": "connection_error"
|
||||
"error_code": "connection_error",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test plugin credentials: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Test failed: {str(e)}",
|
||||
"error_code": "unknown_error"
|
||||
"error_code": "unknown_error",
|
||||
}
|
||||
|
||||
|
||||
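All of the response shaping above reduces to a single read-only probe: GET {zammad_url}/api/v1/users/me with a `Token token=...` Authorization header, where 200 means valid credentials, 401 an invalid token, and 404 a wrong URL. Stripped to its essentials, the probe looks like this (a sketch that reuses only calls visible in the handler):

# Standalone sketch of the read-only Zammad probe performed above.
import httpx

async def zammad_credentials_ok(zammad_url: str, api_token: str) -> bool:
    url = f"{zammad_url.rstrip('/')}/api/v1/users/me"
    headers = {"Authorization": f"Token token={api_token}"}
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(url, headers=headers)
    return response.status_code == 200  # 401 = bad token, 404 = wrong URL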
# Background task for plugin installation
-async def install_plugin_background(plugin_id: str, version: str, user_id: str, db: AsyncSession):
+async def install_plugin_background(
+    plugin_id: str, version: str, user_id: str, db: AsyncSession
+):
    """Background task for plugin installation"""
    try:
        result = await plugin_installer.install_plugin_from_repository(
            plugin_id, version, user_id, db
        )
        logger.info(f"Background installation completed: {result}")

    except Exception as e:
        logger.error(f"Background installation failed: {e}")
        # TODO: Notify user of installation failure
@@ -17,7 +17,10 @@ from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.llm.service import llm_service
-from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
+from app.services.llm.models import (
+    ChatRequest as LLMChatRequest,
+    ChatMessage as LLMChatMessage,
+)

router = APIRouter()

@@ -59,13 +62,14 @@ class ImprovePromptRequest(BaseModel):

@router.get("/templates", response_model=List[PromptTemplateResponse])
async def list_prompt_templates(
-    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
    """Get all prompt templates"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
    log_api_request("list_prompt_templates", {"user_id": user_id})

    try:
        result = await db.execute(
            select(PromptTemplate)
@@ -73,7 +77,7 @@ async def list_prompt_templates(
            .order_by(PromptTemplate.name)
        )
        templates = result.scalars().all()

        template_list = []
        for template in templates:
            template_dict = {
@@ -85,28 +89,38 @@ async def list_prompt_templates(
                "is_default": template.is_default,
                "is_active": template.is_active,
                "version": template.version,
-                "created_at": template.created_at.isoformat() if template.created_at else None,
-                "updated_at": template.updated_at.isoformat() if template.updated_at else None
+                "created_at": template.created_at.isoformat()
+                if template.created_at
+                else None,
+                "updated_at": template.updated_at.isoformat()
+                if template.updated_at
+                else None,
            }
            template_list.append(template_dict)

        return template_list

    except Exception as e:
-        log_api_request("list_prompt_templates_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}")
+        log_api_request(
+            "list_prompt_templates_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to fetch prompt templates: {str(e)}"
+        )


@router.get("/templates/{type_key}", response_model=PromptTemplateResponse)
async def get_prompt_template(
    type_key: str,
    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Get a specific prompt template by type key"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
    log_api_request("get_prompt_template", {"user_id": user_id, "type_key": type_key})

    try:
        result = await db.execute(
            select(PromptTemplate)
@@ -114,10 +128,10 @@ async def get_prompt_template(
            .where(PromptTemplate.is_active == True)
        )
        template = result.scalar_one_or_none()

        if not template:
            raise HTTPException(status_code=404, detail="Prompt template not found")

        return {
            "id": template.id,
            "name": template.name,
@@ -127,15 +141,23 @@ async def get_prompt_template(
            "is_default": template.is_default,
            "is_active": template.is_active,
            "version": template.version,
-            "created_at": template.created_at.isoformat() if template.created_at else None,
-            "updated_at": template.updated_at.isoformat() if template.updated_at else None
+            "created_at": template.created_at.isoformat()
+            if template.created_at
+            else None,
+            "updated_at": template.updated_at.isoformat()
+            if template.updated_at
+            else None,
        }

    except HTTPException:
        raise
    except Exception as e:
-        log_api_request("get_prompt_template_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to fetch prompt template: {str(e)}")
+        log_api_request(
+            "get_prompt_template_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to fetch prompt template: {str(e)}"
+        )


@router.put("/templates/{type_key}")
@@ -143,16 +165,17 @@ async def update_prompt_template(
    type_key: str,
    request: PromptTemplateRequest,
    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Update a prompt template"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
-    log_api_request("update_prompt_template", {
-        "user_id": user_id,
-        "type_key": type_key,
-        "name": request.name
-    })
-
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
+    log_api_request(
+        "update_prompt_template",
+        {"user_id": user_id, "type_key": type_key, "name": request.name},
+    )

    try:
        # Get existing template
        result = await db.execute(
@@ -161,10 +184,10 @@ async def update_prompt_template(
            .where(PromptTemplate.is_active == True)
        )
        template = result.scalar_one_or_none()

        if not template:
            raise HTTPException(status_code=404, detail="Prompt template not found")

        # Update the template
        await db.execute(
            update(PromptTemplate)
@@ -175,19 +198,18 @@ async def update_prompt_template(
                system_prompt=request.system_prompt,
                is_active=request.is_active,
                version=template.version + 1,
-                updated_at=datetime.utcnow()
+                updated_at=datetime.utcnow(),
            )
        )

        await db.commit()

        # Return updated template
        updated_result = await db.execute(
-            select(PromptTemplate)
-            .where(PromptTemplate.type_key == type_key)
+            select(PromptTemplate).where(PromptTemplate.type_key == type_key)
        )
        updated_template = updated_result.scalar_one()

        return {
            "id": updated_template.id,
            "name": updated_template.name,
@@ -197,41 +219,52 @@ async def update_prompt_template(
            "is_default": updated_template.is_default,
            "is_active": updated_template.is_active,
            "version": updated_template.version,
-            "created_at": updated_template.created_at.isoformat() if updated_template.created_at else None,
-            "updated_at": updated_template.updated_at.isoformat() if updated_template.updated_at else None
+            "created_at": updated_template.created_at.isoformat()
+            if updated_template.created_at
+            else None,
+            "updated_at": updated_template.updated_at.isoformat()
+            if updated_template.updated_at
+            else None,
        }

    except HTTPException:
        raise
    except Exception as e:
        await db.rollback()
-        log_api_request("update_prompt_template_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to update prompt template: {str(e)}")
+        log_api_request(
+            "update_prompt_template_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to update prompt template: {str(e)}"
+        )


@router.post("/templates/create")
async def create_prompt_template(
    request: PromptTemplateRequest,
    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Create a new prompt template"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
-    log_api_request("create_prompt_template", {
-        "user_id": user_id,
-        "type_key": request.type_key,
-        "name": request.name
-    })
-
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
+    log_api_request(
+        "create_prompt_template",
+        {"user_id": user_id, "type_key": request.type_key, "name": request.name},
+    )

    try:
        # Check if template already exists
        existing_result = await db.execute(
-            select(PromptTemplate)
-            .where(PromptTemplate.type_key == request.type_key)
+            select(PromptTemplate).where(PromptTemplate.type_key == request.type_key)
        )
        if existing_result.scalar_one_or_none():
-            raise HTTPException(status_code=400, detail="Prompt template with this type key already exists")
+            raise HTTPException(
+                status_code=400,
+                detail="Prompt template with this type key already exists",
+            )

        # Create new template
        template = PromptTemplate(
            id=str(uuid.uuid4()),
@@ -243,13 +276,13 @@ async def create_prompt_template(
            is_active=request.is_active,
            version=1,
            created_at=datetime.utcnow(),
-            updated_at=datetime.utcnow()
+            updated_at=datetime.utcnow(),
        )

        db.add(template)
        await db.commit()
        await db.refresh(template)

        return {
            "id": template.id,
            "name": template.name,
@@ -259,27 +292,36 @@ async def create_prompt_template(
            "is_default": template.is_default,
            "is_active": template.is_active,
            "version": template.version,
-            "created_at": template.created_at.isoformat() if template.created_at else None,
-            "updated_at": template.updated_at.isoformat() if template.updated_at else None
+            "created_at": template.created_at.isoformat()
+            if template.created_at
+            else None,
+            "updated_at": template.updated_at.isoformat()
+            if template.updated_at
+            else None,
        }

    except HTTPException:
        raise
    except Exception as e:
        await db.rollback()
-        log_api_request("create_prompt_template_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to create prompt template: {str(e)}")
+        log_api_request(
+            "create_prompt_template_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to create prompt template: {str(e)}"
+        )


@router.get("/variables", response_model=List[PromptVariableResponse])
async def list_prompt_variables(
-    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
):
    """Get all available prompt variables"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
    log_api_request("list_prompt_variables", {"user_id": user_id})

    try:
        result = await db.execute(
            select(ChatbotPromptVariable)
@@ -287,7 +329,7 @@ async def list_prompt_variables(
            .order_by(ChatbotPromptVariable.variable_name)
        )
        variables = result.scalars().all()

        variable_list = []
        for variable in variables:
            variable_dict = {
@@ -295,27 +337,33 @@ async def list_prompt_variables(
                "variable_name": variable.variable_name,
                "description": variable.description,
                "example_value": variable.example_value,
-                "is_active": variable.is_active
+                "is_active": variable.is_active,
            }
            variable_list.append(variable_dict)

        return variable_list

    except Exception as e:
-        log_api_request("list_prompt_variables_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}")
+        log_api_request(
+            "list_prompt_variables_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to fetch prompt variables: {str(e)}"
+        )


@router.post("/templates/{type_key}/reset")
async def reset_prompt_template(
    type_key: str,
    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Reset a prompt template to its default"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
    log_api_request("reset_prompt_template", {"user_id": user_id, "type_key": type_key})

    # Define default prompts (same as in migration)
    default_prompts = {
        "assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
@@ -323,12 +371,12 @@ async def reset_prompt_template(
        "teacher": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
        "researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
        "creative_writer": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
-        "custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
+        "custom": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
    }

    if type_key not in default_prompts:
        raise HTTPException(status_code=404, detail="Unknown prompt template type")

    try:
        # Update the template to default
        await db.execute(
@@ -337,33 +385,39 @@ async def reset_prompt_template(
            .values(
                system_prompt=default_prompts[type_key],
                version=PromptTemplate.version + 1,
-                updated_at=datetime.utcnow()
+                updated_at=datetime.utcnow(),
            )
        )

        await db.commit()

        return {"message": "Prompt template reset to default successfully"}

    except Exception as e:
        await db.rollback()
-        log_api_request("reset_prompt_template_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to reset prompt template: {str(e)}")
+        log_api_request(
+            "reset_prompt_template_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to reset prompt template: {str(e)}"
+        )


@router.post("/improve")
async def improve_prompt_with_ai(
    request: ImprovePromptRequest,
    current_user: User = Depends(get_current_user),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
):
    """Improve a prompt using AI"""
-    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
-    log_api_request("improve_prompt_with_ai", {
-        "user_id": user_id,
-        "chatbot_type": request.chatbot_type
-    })
-
+    user_id = (
+        current_user.get("id") if isinstance(current_user, dict) else current_user.id
+    )
+    log_api_request(
+        "improve_prompt_with_ai",
+        {"user_id": user_id, "chatbot_type": request.chatbot_type},
+    )

    try:
        # Create system message for improvement
        system_message = """You are an expert prompt engineer. Your task is to improve the given prompt to make it more effective, clear, and specific for the intended chatbot type.
@@ -392,92 +446,100 @@ Please improve this prompt to make it more effective for a {request.chatbot_type

        messages = [
            {"role": "system", "content": system_message},
-            {"role": "user", "content": user_message}
+            {"role": "user", "content": user_message},
        ]

        # Get available models to use a default model
        models = await llm_service.get_models()
        if not models:
            raise HTTPException(status_code=503, detail="No LLM models available")

        # Use the first available model (you might want to make this configurable)
        default_model = models[0].id

        # Prepare the chat request for the new LLM service
        chat_request = LLMChatRequest(
            model=default_model,
-            messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
+            messages=[
+                LLMChatMessage(role=msg["role"], content=msg["content"])
+                for msg in messages
+            ],
            temperature=0.3,
            max_tokens=1000,
            user_id=str(user_id),
-            api_key_id=1  # Using default API key, you might want to make this dynamic
+            api_key_id=1,  # Using default API key, you might want to make this dynamic
        )

        # Make the AI call
        response = await llm_service.create_chat_completion(chat_request)

        # Extract the improved prompt from the response
        improved_prompt = response.choices[0].message.content.strip()

        return {
            "improved_prompt": improved_prompt,
            "original_prompt": request.current_prompt,
-            "model_used": default_model
+            "model_used": default_model,
        }

    except HTTPException:
        raise
    except Exception as e:
-        log_api_request("improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id})
-        raise HTTPException(status_code=500, detail=f"Failed to improve prompt: {str(e)}")
+        log_api_request(
+            "improve_prompt_with_ai_error", {"error": str(e), "user_id": user_id}
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to improve prompt: {str(e)}"
+        )
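Boiled down, the improve endpoint builds a two-message chat (a prompt-engineer system message plus the user's current prompt), picks the first model the service reports, and sends a single completion request. A condensed sketch of that call, using the LLMChatRequest/LLMChatMessage names imported at the top of this file and assuming it runs inside an async handler:

# Condensed sketch of the LLM call made by improve_prompt_with_ai above.
models = await llm_service.get_models()
chat_request = LLMChatRequest(
    model=models[0].id,  # first available model, as the handler does
    messages=[
        LLMChatMessage(role="system", content=system_message),
        LLMChatMessage(role="user", content=user_message),
    ],
    temperature=0.3,
    max_tokens=1000,
    user_id=str(user_id),
    api_key_id=1,  # default key; the handler itself flags this as a placeholder
)
response = await llm_service.create_chat_completion(chat_request)
improved_prompt = response.choices[0].message.content.strip()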
@router.post("/seed-defaults")
|
||||
async def seed_default_templates(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Seed default prompt templates for all chatbot types"""
|
||||
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
user_id = (
|
||||
current_user.get("id") if isinstance(current_user, dict) else current_user.id
|
||||
)
|
||||
log_api_request("seed_default_templates", {"user_id": user_id})
|
||||
|
||||
|
||||
# Define default prompts (same as in reset)
|
||||
default_prompts = {
|
||||
"assistant": {
|
||||
"name": "General Assistant",
|
||||
"description": "A helpful, accurate, and friendly AI assistant",
|
||||
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style."
|
||||
"prompt": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations. When you don't know something, say so clearly. Be professional but approachable in your communication style.",
|
||||
},
|
||||
"customer_support": {
|
||||
"name": "Customer Support Agent",
|
||||
"description": "Professional customer service representative focused on solving problems",
|
||||
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations."
|
||||
"prompt": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions. Always try to understand the customer's issue fully before providing solutions. Use the knowledge base to provide accurate information. When you cannot resolve an issue, explain clearly how the customer can escalate or get further help. Maintain a helpful and patient tone even in difficult situations.",
|
||||
},
|
||||
"teacher": {
|
||||
"name": "Educational Tutor",
|
||||
"description": "Patient and encouraging educational facilitator",
|
||||
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it."
|
||||
"prompt": "You are an experienced educational tutor and learning facilitator. Break down complex concepts into understandable, digestible parts. Use analogies, examples, and step-by-step explanations to help students learn. Encourage critical thinking through thoughtful questions. Be patient, supportive, and encouraging. Adapt your teaching style to different learning preferences. When a student makes mistakes, guide them to the correct answer rather than just providing it.",
|
||||
},
|
||||
"researcher": {
|
||||
"name": "Research Assistant",
|
||||
"description": "Thorough researcher focused on evidence-based information",
|
||||
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence."
|
||||
"prompt": "You are a thorough research assistant with a focus on accuracy and evidence-based information. Provide well-researched, factual information with sources when possible. Be thorough in your analysis and present multiple perspectives when relevant topics have different viewpoints. Always distinguish between established facts, current research, and opinions. When information is uncertain or contested, clearly communicate the level of confidence and supporting evidence.",
|
||||
},
|
||||
"creative_writer": {
|
||||
"name": "Creative Writing Mentor",
|
||||
"description": "Imaginative storytelling expert and writing coach",
|
||||
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles."
|
||||
"prompt": "You are an experienced creative writing mentor and storytelling expert. Help with brainstorming ideas, character development, plot structure, dialogue, and creative expression. Be imaginative and inspiring while providing constructive, actionable feedback. Encourage experimentation with different writing styles and techniques. When reviewing work, balance praise for strengths with specific suggestions for improvement. Help writers find their unique voice while mastering fundamental storytelling principles.",
|
||||
},
|
||||
"custom": {
|
||||
"name": "Custom Chatbot",
|
||||
"description": "Customizable AI assistant with user-defined behavior",
|
||||
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration."
|
||||
}
|
||||
"prompt": "You are a helpful AI assistant. Your personality, expertise, and behavior will be defined by the user through custom instructions. Follow the user's guidance on how to respond, what tone to use, and what role to play. Be adaptable and responsive to the specific needs and preferences outlined in your configuration.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
created_templates = []
|
||||
updated_templates = []
|
||||
|
||||
|
||||
try:
|
||||
for type_key, template_data in default_prompts.items():
|
||||
# Check if template already exists
|
||||
@@ -530,7 +592,9 @@ async def seed_default_templates(
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[PromptTemplate.type_key]
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
@@ -541,17 +605,21 @@ async def seed_default_templates(
|
||||
"prompt_template_seed_skipped",
|
||||
{"type_key": type_key, "reason": "already_exists"},
|
||||
)
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
return {
|
||||
"message": "Default templates seeded successfully",
|
||||
"created": created_templates,
|
||||
"updated": updated_templates,
|
||||
"total": len(created_templates) + len(updated_templates)
|
||||
"total": len(created_templates) + len(updated_templates),
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("seed_default_templates_error", {"error": str(e), "user_id": user_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to seed default templates: {str(e)}")
|
||||
log_api_request(
|
||||
"seed_default_templates_error", {"error": str(e), "user_id": user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to seed default templates: {str(e)}"
|
||||
)
|
||||
|
||||
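The `.on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])` call above is SQLAlchemy's PostgreSQL-dialect `INSERT ... ON CONFLICT DO NOTHING`, which is what makes re-running the seeding endpoint idempotent: rows whose type_key already exists are skipped and logged rather than duplicated. A minimal sketch of the statement shape, with illustrative column values:

# Idempotent-insert sketch (SQLAlchemy PostgreSQL dialect), as used above.
from sqlalchemy.dialects.postgresql import insert

stmt = (
    insert(PromptTemplate)
    .values(type_key="assistant", name="General Assistant")  # illustrative values
    .on_conflict_do_nothing(index_elements=[PromptTemplate.type_key])
)
result = await db.execute(stmt)
# A rowcount of 0 here indicates the row already existed and the insert was skipped.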
@@ -29,6 +29,7 @@ router = APIRouter(tags=["RAG"])

# Request/Response Models

+
class CollectionCreate(BaseModel):
    name: str
    description: Optional[str] = None
@@ -78,12 +79,13 @@ class StatsResponse(BaseModel):

# Collection Endpoints

+
@router.get("/collections", response_model=dict)
async def get_collections(
    skip: int = 0,
    limit: int = 100,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Get all RAG collections - live data directly from Qdrant (source of truth)"""
    try:
@@ -103,7 +105,7 @@ async def get_collections(
            "collections": paginated_collections,
            "total": len(collections),
            "total_documents": stats_data.get("total_documents", 0),
-            "total_size_bytes": stats_data.get("total_size_bytes", 0)
+            "total_size_bytes": stats_data.get("total_size_bytes", 0),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -113,20 +115,19 @@ async def get_collections(
async def create_collection(
    collection_data: CollectionCreate,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Create a new RAG collection"""
    try:
        rag_service = RAGService(db)
        collection = await rag_service.create_collection(
-            name=collection_data.name,
-            description=collection_data.description
+            name=collection_data.name, description=collection_data.description
        )

        return {
            "success": True,
            "collection": collection.to_dict(),
-            "message": "Collection created successfully"
+            "message": "Collection created successfully",
        }
    except APIException as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -136,8 +137,7 @@ async def create_collection(

@router.get("/stats", response_model=dict)
async def get_rag_stats(
-    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)
):
    """Get overall RAG statistics - live data directly from Qdrant"""
    try:
@@ -147,7 +147,11 @@ async def get_rag_stats(
        stats_data = await qdrant_stats_service.get_collections_stats()

        # Calculate active collections (collections with documents)
-        active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
+        active_collections = sum(
+            1
+            for col in stats_data.get("collections", [])
+            if col.get("document_count", 0) > 0
+        )

        # Calculate processing documents from database
        processing_docs = 0
@@ -156,7 +160,9 @@ async def get_rag_stats(
            from app.models.rag_document import RagDocument, ProcessingStatus

            result = await db.execute(
-                select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
+                select(RagDocument).where(
+                    RagDocument.status == ProcessingStatus.PROCESSING
+                )
            )
            processing_docs = len(result.scalars().all())
        except Exception:
@@ -167,22 +173,28 @@ async def get_rag_stats(
            "stats": {
                "collections": {
                    "total": stats_data.get("total_collections", 0),
-                    "active": active_collections
+                    "active": active_collections,
                },
                "documents": {
                    "total": stats_data.get("total_documents", 0),
                    "processing": processing_docs,
-                    "processed": stats_data.get("total_documents", 0)  # Indexed documents
+                    "processed": stats_data.get(
+                        "total_documents", 0
+                    ),  # Indexed documents
                },
                "storage": {
                    "total_size_bytes": stats_data.get("total_size_bytes", 0),
-                    "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
+                    "total_size_mb": round(
+                        stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2
+                    ),
                },
                "vectors": {
-                    "total": stats_data.get("total_documents", 0)  # Same as documents for RAG
+                    "total": stats_data.get(
+                        "total_documents", 0
+                    )  # Same as documents for RAG
                },
-                "last_updated": datetime.utcnow().isoformat()
-            }
+                "last_updated": datetime.utcnow().isoformat(),
+            },
        }

        return response_data
@@ -194,20 +206,17 @@ async def get_rag_stats(
async def get_collection(
    collection_id: int,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Get a specific collection"""
    try:
        rag_service = RAGService(db)
        collection = await rag_service.get_collection(collection_id)

        if not collection:
            raise HTTPException(status_code=404, detail="Collection not found")

-        return {
-            "success": True,
-            "collection": collection.to_dict()
-        }
+        return {"success": True, "collection": collection.to_dict()}
    except HTTPException:
        raise
    except Exception as e:
@@ -219,19 +228,20 @@ async def delete_collection(
    collection_id: int,
    cascade: bool = True,  # Default to cascade deletion for better UX
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Delete a collection and optionally all its documents"""
    try:
        rag_service = RAGService(db)
        success = await rag_service.delete_collection(collection_id, cascade=cascade)

        if not success:
            raise HTTPException(status_code=404, detail="Collection not found")

        return {
            "success": True,
-            "message": "Collection deleted successfully" + (" (with documents)" if cascade else "")
+            "message": "Collection deleted successfully"
+            + (" (with documents)" if cascade else ""),
        }
    except APIException as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -243,13 +253,14 @@ async def delete_collection(

# Document Endpoints

+
@router.get("/documents", response_model=dict)
async def get_documents(
    collection_id: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Get documents, optionally filtered by collection"""
    try:
@@ -260,11 +271,7 @@ async def get_documents(
        if collection_id.startswith("ext_"):
            # External collections exist only in Qdrant and have no documents in PostgreSQL
            # Return empty list since they don't have managed documents
-            return {
-                "success": True,
-                "documents": [],
-                "total": 0
-            }
+            return {"success": True, "documents": [], "total": 0}
        else:
            # Try to convert to integer for managed collections
            try:
@@ -272,29 +279,25 @@ async def get_documents(
            except (ValueError, TypeError):
                # Attempt to resolve by Qdrant collection name
                collection_row = await db.scalar(
-                    select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
+                    select(RagCollection).where(
+                        RagCollection.qdrant_collection_name == collection_id
+                    )
                )
                if collection_row:
                    collection_id_int = collection_row.id
                else:
                    # Unknown collection identifier; return empty result instead of erroring out
-                    return {
-                        "success": True,
-                        "documents": [],
-                        "total": 0
-                    }
+                    return {"success": True, "documents": [], "total": 0}

        rag_service = RAGService(db)
        documents = await rag_service.get_documents(
-            collection_id=collection_id_int,
-            skip=skip,
-            limit=limit
+            collection_id=collection_id_int, skip=skip, limit=limit
        )

        return {
            "success": True,
            "documents": [doc.to_dict() for doc in documents],
-            "total": len(documents)
+            "total": len(documents),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -305,13 +308,13 @@ async def upload_document(
    collection_id: str = Form(...),
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Upload and process a document"""
    try:
        # Validate file can be read before processing
        filename = file.filename or "unknown"
-        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
+        file_extension = filename.split(".")[-1].lower() if "." in filename else ""

        # Read file content once and use it for all validations
        file_content = await file.read()
@@ -324,50 +327,66 @@ async def upload_document(

        try:
            # Test file readability based on type
-            if file_extension == 'jsonl':
+            if file_extension == "jsonl":
                # Validate JSONL format - try to parse first few lines
                try:
-                    content_str = file_content.decode('utf-8')
-                    lines = content_str.strip().split('\n')[:5]  # Check first 5 lines
+                    content_str = file_content.decode("utf-8")
+                    lines = content_str.strip().split("\n")[:5]  # Check first 5 lines
                    import json

                    for i, line in enumerate(lines):
                        if line.strip():  # Skip empty lines
                            json.loads(line)  # Will raise JSONDecodeError if invalid
                except UnicodeDecodeError:
-                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+                    raise HTTPException(
+                        status_code=400, detail="File is not valid UTF-8 text"
+                    )
                except json.JSONDecodeError as e:
-                    raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
+                    raise HTTPException(
+                        status_code=400, detail=f"Invalid JSONL format: {str(e)}"
+                    )

-            elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
+            elif file_extension in ["txt", "md", "py", "js", "html", "css", "json"]:
                # Validate text files can be decoded
                try:
-                    file_content.decode('utf-8')
+                    file_content.decode("utf-8")
                except UnicodeDecodeError:
-                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+                    raise HTTPException(
+                        status_code=400, detail="File is not valid UTF-8 text"
+                    )

-            elif file_extension in ['pdf']:
+            elif file_extension in ["pdf"]:
                # For PDF files, just check if it starts with PDF signature
-                if not file_content.startswith(b'%PDF'):
-                    raise HTTPException(status_code=400, detail="Invalid PDF file format")
+                if not file_content.startswith(b"%PDF"):
+                    raise HTTPException(
+                        status_code=400, detail="Invalid PDF file format"
+                    )

-            elif file_extension in ['docx', 'xlsx', 'pptx']:
+            elif file_extension in ["docx", "xlsx", "pptx"]:
                # For Office documents, check ZIP signature
-                if not file_content.startswith(b'PK'):
-                    raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
+                if not file_content.startswith(b"PK"):
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Invalid {file_extension.upper()} file format",
+                    )

            # For other file types, we'll rely on the document processor

        except HTTPException:
            raise
        except Exception as e:
-            raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
+            raise HTTPException(
+                status_code=400, detail=f"File validation failed: {str(e)}"
+            )

        rag_service = RAGService(db)

        # Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
        collection_identifier = (collection_id or "").strip()
        if not collection_identifier:
-            raise HTTPException(status_code=400, detail="Collection identifier is required")
+            raise HTTPException(
+                status_code=400, detail="Collection identifier is required"
+            )

        resolved_collection_id: Optional[int] = None

@@ -379,7 +398,9 @@ async def upload_document(
            qdrant_name = qdrant_name[4:]

            try:
-                collection_record = await rag_service.ensure_collection_record(qdrant_name)
+                collection_record = await rag_service.ensure_collection_record(
+                    qdrant_name
+                )
            except Exception as ensure_error:
                raise HTTPException(status_code=500, detail=str(ensure_error))

@@ -392,13 +413,13 @@ async def upload_document(
            collection_id=resolved_collection_id,
            file_content=file_content,
            filename=filename,
-            content_type=file.content_type
+            content_type=file.content_type,
        )

        return {
            "success": True,
            "document": document.to_dict(),
-            "message": "Document uploaded and processing started"
+            "message": "Document uploaded and processing started",
        }
    except APIException as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
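The upload validation above never trusts the file extension alone: text formats must decode as UTF-8 (JSONL additionally needs its first lines to parse as JSON), PDFs must begin with the %PDF magic bytes, and docx/xlsx/pptx must begin with PK because OOXML documents are ZIP containers. The same checks condensed into one standalone helper (a sketch, not code from the repository):

# Consolidated sketch of the signature checks used by upload_document above.
def content_matches_extension(file_extension: str, file_content: bytes) -> bool:
    if file_extension in {"txt", "md", "py", "js", "html", "css", "json", "jsonl"}:
        try:
            file_content.decode("utf-8")
            return True
        except UnicodeDecodeError:
            return False
    if file_extension == "pdf":
        return file_content.startswith(b"%PDF")
    if file_extension in {"docx", "xlsx", "pptx"}:
        return file_content.startswith(b"PK")  # OOXML files are ZIP archives
    return True  # other types are deferred to the document processor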
@@ -412,20 +433,17 @@ async def upload_document(
async def get_document(
    document_id: int,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Get a specific document"""
    try:
        rag_service = RAGService(db)
        document = await rag_service.get_document(document_id)

        if not document:
            raise HTTPException(status_code=404, detail="Document not found")

-        return {
-            "success": True,
-            "document": document.to_dict()
-        }
+        return {"success": True, "document": document.to_dict()}
    except HTTPException:
        raise
    except Exception as e:
@@ -436,20 +454,17 @@ async def get_document(
async def delete_document(
    document_id: int,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Delete a document"""
    try:
        rag_service = RAGService(db)
        success = await rag_service.delete_document(document_id)

        if not success:
            raise HTTPException(status_code=404, detail="Document not found")

-        return {
-            "success": True,
-            "message": "Document deleted successfully"
-        }
+        return {"success": True, "message": "Document deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
@@ -460,13 +475,13 @@ async def delete_document(
async def reprocess_document(
    document_id: int,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Restart processing for a stuck or failed document"""
    try:
        rag_service = RAGService(db)
        success = await rag_service.reprocess_document(document_id)

        if not success:
            # Get document to check if it exists and its current status
            document = await rag_service.get_document(document_id)
@@ -474,13 +489,13 @@ async def reprocess_document(
                raise HTTPException(status_code=404, detail="Document not found")
            else:
                raise HTTPException(
-                    status_code=400,
-                    detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed."
+                    status_code=400,
+                    detail=f"Cannot reprocess document with status '{document.status}'. Only 'processing' or 'error' documents can be reprocessed.",
                )

        return {
            "success": True,
-            "message": "Document reprocessing started successfully"
+            "message": "Document reprocessing started successfully",
        }
    except APIException as e:
        raise HTTPException(status_code=e.status_code, detail=e.detail)
@@ -494,22 +509,24 @@ async def reprocess_document(
async def download_document(
    document_id: int,
    db: AsyncSession = Depends(get_db),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
):
    """Download the original document file"""
    try:
        rag_service = RAGService(db)
        result = await rag_service.download_document(document_id)

        if not result:
-            raise HTTPException(status_code=404, detail="Document not found or file not available")
+            raise HTTPException(
+                status_code=404, detail="Document not found or file not available"
+            )

        content, filename, mime_type = result

        return StreamingResponse(
            io.BytesIO(content),
            media_type=mime_type,
-            headers={"Content-Disposition": f"attachment; filename={filename}"}
+            headers={"Content-Disposition": f"attachment; filename={filename}"},
        )
    except HTTPException:
        raise
@@ -517,9 +534,9 @@ async def download_document(
        raise HTTPException(status_code=500, detail=str(e))


# Debug Endpoints

+
@router.post("/debug/search")
async def search_with_debug(
    query: str,
@@ -527,13 +544,13 @@ async def search_with_debug(
    score_threshold: float = 0.3,
    collection_name: str = None,
    config: Dict[str, Any] = None,
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Enhanced search with comprehensive debug information
"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
rag_module = module_manager.modules.get("rag")
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
@@ -567,7 +584,7 @@ async def search_with_debug(
|
||||
query,
|
||||
max_results=max_results,
|
||||
score_threshold=score_threshold,
|
||||
collection_name=collection_name
|
||||
collection_name=collection_name,
|
||||
)
|
||||
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
|
||||
|
||||
@@ -575,22 +592,23 @@ async def search_with_debug(
|
||||
scores = [r.score for r in results if r.score is not None]
|
||||
if scores:
|
||||
import statistics
|
||||
|
||||
debug_info["score_stats"] = {
|
||||
"min": min(scores),
|
||||
"max": max(scores),
|
||||
"avg": statistics.mean(scores),
|
||||
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
|
||||
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0,
|
||||
}
|
||||
|
||||
# Get collection statistics
|
||||
try:
|
||||
from qdrant_client.http.models import Filter
|
||||
|
||||
collection_name = collection_name or rag_module.default_collection_name
|
||||
|
||||
# Count total documents
|
||||
count_result = rag_module.qdrant_client.count(
|
||||
collection_name=collection_name,
|
||||
count_filter=Filter(must=[])
|
||||
collection_name=collection_name, count_filter=Filter(must=[])
|
||||
)
|
||||
total_points = count_result.count
|
||||
|
||||
@@ -599,7 +617,7 @@ async def search_with_debug(
|
||||
collection_name=collection_name,
|
||||
limit=1000, # Sample for stats
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
unique_docs = set()
|
||||
@@ -618,7 +636,7 @@ async def search_with_debug(
|
||||
debug_info["collection_stats"] = {
|
||||
"total_documents": len(unique_docs),
|
||||
"total_chunks": total_points,
|
||||
"languages": sorted(list(languages))
|
||||
"languages": sorted(list(languages)),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -631,16 +649,18 @@ async def search_with_debug(
|
||||
"document": {
|
||||
"id": result.document.id,
|
||||
"content": result.document.content,
|
||||
"metadata": result.document.metadata
|
||||
"metadata": result.document.metadata,
|
||||
},
|
||||
"score": result.score,
|
||||
"debug_info": {}
|
||||
"debug_info": {},
|
||||
}
|
||||
|
||||
# Add hybrid search debug info if available
|
||||
metadata = result.document.metadata or {}
|
||||
if "_vector_score" in metadata:
|
||||
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
|
||||
enhanced_result["debug_info"]["vector_score"] = metadata[
|
||||
"_vector_score"
|
||||
]
|
||||
if "_bm25_score" in metadata:
|
||||
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
|
||||
|
||||
@@ -652,7 +672,7 @@ async def search_with_debug(
|
||||
"results": enhanced_results,
|
||||
"debug_info": debug_info,
|
||||
"search_time_ms": search_time,
|
||||
"timestamp": start_time.isoformat()
|
||||
"timestamp": start_time.isoformat(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -661,17 +681,17 @@ async def search_with_debug(
|
||||
|
||||
finally:
|
||||
# Restore original config if modified
|
||||
if config and 'original_config' in locals():
|
||||
if config and "original_config" in locals():
|
||||
rag_module.config = original_config
|
||||
|
||||
|
||||
@router.get("/debug/config")
|
||||
async def get_current_config(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current RAG configuration"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
rag_module = module_manager.modules.get("rag")
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
@@ -679,5 +699,5 @@ async def get_current_config(
|
||||
"config": rag_module.config,
|
||||
"embedding_model": rag_module.embedding_model,
|
||||
"enabled": rag_module.enabled,
|
||||
"collections": await rag_module._get_collections_safely()
|
||||
"collections": await rag_module._get_collections_safely(),
|
||||
}
|
||||
|
||||
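The upload path above rejects files whose leading bytes don't match the declared extension (PDFs must start with %PDF, the OOXML formats with the ZIP signature PK). A standalone sketch of that magic-byte check; the helper name and lookup table are hypothetical, not part of this commit:

# Hypothetical sketch of the magic-byte validation used by upload_document.
MAGIC_BYTES = {
    "pdf": b"%PDF",   # PDF header
    "docx": b"PK",    # OOXML documents are ZIP containers
    "xlsx": b"PK",
    "pptx": b"PK",
}

def has_valid_signature(file_content: bytes, file_extension: str) -> bool:
    """True when the file starts with the expected signature,
    or when the extension has no registered signature (defer to the processor)."""
    expected = MAGIC_BYTES.get(file_extension.lower())
    return expected is None or file_content.startswith(expected)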
File diff suppressed because it is too large
@@ -85,16 +85,16 @@ async def list_users(
    is_active: Optional[bool] = Query(None),
    search: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """List all users with pagination and filtering"""


    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:users:read")


    # Build query
    query = select(User)


    # Apply filters
    if role:
        query = query.where(User.role == role)
@@ -102,38 +102,42 @@ async def list_users(
        query = query.where(User.is_active == is_active)
    if search:
        query = query.where(
            (User.username.ilike(f"%{search}%")) |
            (User.email.ilike(f"%{search}%")) |
            (User.full_name.ilike(f"%{search}%"))
            (User.username.ilike(f"%{search}%"))
            | (User.email.ilike(f"%{search}%"))
            | (User.full_name.ilike(f"%{search}%"))
        )


    # Get total count
    total_query = select(User.id).select_from(query.subquery())
    total_result = await db.execute(total_query)
    total = len(total_result.fetchall())


    # Apply pagination
    offset = (page - 1) * size
    query = query.offset(offset).limit(size)


    # Execute query
    result = await db.execute(query)
    users = result.scalars().all()


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="list_users",
        resource_type="user",
        details={"page": page, "size": size, "filters": {"role": role, "is_active": is_active, "search": search}}
        details={
            "page": page,
            "size": size,
            "filters": {"role": role, "is_active": is_active, "search": search},
        },
    )


    return UserListResponse(
        users=[UserResponse.model_validate(user) for user in users],
        total=total,
        page=page,
        size=size
        size=size,
    )


@@ -141,34 +145,33 @@ async def list_users(
async def get_user(
    user_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get user by ID"""


    # Check permissions (users can view their own profile)
    if int(user_id) != current_user['id']:
    if int(user_id) != current_user["id"]:
        require_permission(current_user.get("permissions", []), "platform:users:read")


    # Get user
    query = select(User).where(User.id == int(user_id))
    result = await db.execute(query)
    user = result.scalar_one_or_none()


    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="get_user",
        resource_type="user",
        resource_id=user_id
        resource_id=user_id,
    )


    return UserResponse.model_validate(user)


@@ -176,26 +179,26 @@ async def get_user(
async def create_user(
    user_data: UserCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Create a new user"""


    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:users:create")


    # Check if user already exists
    query = select(User).where(
        (User.username == user_data.username) | (User.email == user_data.email)
    )
    result = await db.execute(query)
    existing_user = result.scalar_one_or_none()


    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User with this username or email already exists"
            detail="User with this username or email already exists",
        )


    # Create user
    hashed_password = get_password_hash(user_data.password)
    new_user = User(
@@ -204,25 +207,29 @@ async def create_user(
        full_name=user_data.full_name,
        hashed_password=hashed_password,
        role=user_data.role,
        is_active=user_data.is_active
        is_active=user_data.is_active,
    )


    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="create_user",
        resource_type="user",
        resource_id=str(new_user.id),
        details={"username": user_data.username, "email": user_data.email, "role": user_data.role}
        details={
            "username": user_data.username,
            "email": user_data.email,
            "role": user_data.role,
        },
    )


    logger.info(f"User created: {new_user.username} by {current_user['username']}")


    return UserResponse.model_validate(new_user)


@@ -231,26 +238,25 @@ async def update_user(
    user_id: str,
    user_data: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Update user"""


    # Check permissions (users can update their own profile with limited fields)
    is_self_update = int(user_id) == current_user['id']
    is_self_update = int(user_id) == current_user["id"]
    if not is_self_update:
        require_permission(current_user.get("permissions", []), "platform:users:update")


    # Get user
    query = select(User).where(User.id == int(user_id))
    result = await db.execute(query)
    user = result.scalar_one_or_none()


    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )


    # For self-updates, restrict what can be changed
    if is_self_update:
        allowed_fields = {"username", "email", "full_name"}
@@ -259,41 +265,41 @@ async def update_user(
        if restricted_fields:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Cannot update fields: {restricted_fields}"
                detail=f"Cannot update fields: {restricted_fields}",
            )


    # Store original values for audit
    original_values = {
        "username": user.username,
        "email": user.email,
        "role": user.role,
        "is_active": user.is_active
        "is_active": user.is_active,
    }


    # Update user fields
    update_data = user_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(user, field, value)


    await db.commit()
    await db.refresh(user)


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="update_user",
        resource_type="user",
        resource_id=user_id,
        details={
            "updated_fields": list(update_data.keys()),
            "before_values": original_values,
            "after_values": {k: getattr(user, k) for k in update_data.keys()}
        }
            "after_values": {k: getattr(user, k) for k in update_data.keys()},
        },
    )


    logger.info(f"User updated: {user.username} by {current_user['username']}")


    return UserResponse.model_validate(user)


@@ -301,47 +307,46 @@ async def update_user(
async def delete_user(
    user_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Delete user (soft delete by deactivating)"""


    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:users:delete")


    # Prevent self-deletion
    if int(user_id) == current_user['id']:
    if int(user_id) == current_user["id"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete your own account"
            detail="Cannot delete your own account",
        )


    # Get user
    query = select(User).where(User.id == int(user_id))
    result = await db.execute(query)
    user = result.scalar_one_or_none()


    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )


    # Soft delete by deactivating
    user.is_active = False
    await db.commit()


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="delete_user",
        resource_type="user",
        resource_id=user_id,
        details={"username": user.username, "email": user.email}
        details={"username": user.username, "email": user.email},
    )


    logger.info(f"User deleted: {user.username} by {current_user['username']}")


    return {"message": "User deleted successfully"}


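The password endpoints that follow reuse the self-or-admin gate seen throughout this file: a user may act on their own record, anyone else needs the platform permission. Isolated as a sketch (same names as in the diff):

# Sketch of the self-or-admin gate pattern used by these handlers.
is_self_update = int(user_id) == current_user["id"]
if not is_self_update:
    # Only users holding the platform-level permission may touch other accounts.
    require_permission(current_user.get("permissions", []), "platform:users:update")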
@@ -350,50 +355,51 @@ async def change_password(
    user_id: str,
    password_data: PasswordChangeRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Change user password"""


    # Users can only change their own password, or admins can change any password
    is_self_update = int(user_id) == current_user['id']
    is_self_update = int(user_id) == current_user["id"]
    if not is_self_update:
        require_permission(current_user.get("permissions", []), "platform:users:update")


    # Get user
    query = select(User).where(User.id == int(user_id))
    result = await db.execute(query)
    user = result.scalar_one_or_none()


    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )


    # For self-updates, verify current password
    if is_self_update:
        if not verify_password(password_data.current_password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect"
                detail="Current password is incorrect",
            )


    # Update password
    user.hashed_password = get_password_hash(password_data.new_password)
    await db.commit()


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="change_password",
        resource_type="user",
        resource_id=user_id,
        details={"target_user": user.username}
        details={"target_user": user.username},
    )

    logger.info(f"Password changed for user: {user.username} by {current_user['username']}")


    logger.info(
        f"Password changed for user: {user.username} by {current_user['username']}"
    )

    return {"message": "Password changed successfully"}


@@ -402,40 +408,41 @@ async def reset_password(
    user_id: str,
    password_data: PasswordResetRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Reset user password (admin only)"""


    # Check permissions
    require_permission(current_user.get("permissions", []), "platform:users:update")


    # Get user
    query = select(User).where(User.id == int(user_id))
    result = await db.execute(query)
    user = result.scalar_one_or_none()


    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )


    # Reset password
    user.hashed_password = get_password_hash(password_data.new_password)
    await db.commit()


    # Log audit event
    await log_audit_event(
        db=db,
        user_id=current_user['id'],
        user_id=current_user["id"],
        action="reset_password",
        resource_type="user",
        resource_id=user_id,
        details={"target_user": user.username}
        details={"target_user": user.username},
    )

    logger.info(f"Password reset for user: {user.username} by {current_user['username']}")


    logger.info(
        f"Password reset for user: {user.username} by {current_user['username']}"
    )

    return {"message": "Password reset successfully"}


@@ -443,20 +450,22 @@ async def reset_password(
async def get_user_api_keys(
    user_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
):
    """Get API keys for a user"""


    # Check permissions (users can view their own API keys)
    is_self_request = int(user_id) == current_user['id']
    is_self_request = int(user_id) == current_user["id"]
    if not is_self_request:
        require_permission(current_user.get("permissions", []), "platform:api-keys:read")

        require_permission(
            current_user.get("permissions", []), "platform:api-keys:read"
        )

    # Get API keys
    query = select(APIKey).where(APIKey.user_id == int(user_id))
    result = await db.execute(query)
    api_keys = result.scalars().all()


    # Return safe representation (no key values)
    return [
        {
@@ -466,8 +475,12 @@ async def get_user_api_keys(
            "scopes": api_key.scopes,
            "is_active": api_key.is_active,
            "created_at": api_key.created_at.isoformat(),
            "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
            "last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None
            "expires_at": api_key.expires_at.isoformat()
            if api_key.expires_at
            else None,
            "last_used_at": api_key.last_used_at.isoformat()
            if api_key.last_used_at
            else None,
        }
        for api_key in api_keys
    ]
    ]

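Note that the API-key listing above returns metadata only, never key material, with datetimes serialized via isoformat() and missing values preserved as None. Condensed into a sketch (the helper name is hypothetical):

def api_key_public_view(api_key) -> dict:
    """Metadata-only view of an APIKey row; the secret value is never exposed."""
    return {
        "name": api_key.name,
        "is_active": api_key.is_active,
        "created_at": api_key.created_at.isoformat(),
        "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
    }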
@@ -1,3 +1,3 @@
"""
Core package
"""
"""

@@ -19,24 +19,19 @@ logger = logging.getLogger(__name__)

class CoreCacheService:
    """Core Redis-based cache service for system-wide caching"""


    def __init__(self):
        self.redis_pool: Optional[ConnectionPool] = None
        self.redis_client: Optional[Redis] = None
        self.enabled = False
        self.stats = {
            "hits": 0,
            "misses": 0,
            "errors": 0,
            "total_requests": 0
        }

        self.stats = {"hits": 0, "misses": 0, "errors": 0, "total_requests": 0}

    async def initialize(self):
        """Initialize the core cache service with connection pool"""
        try:
            # Create Redis connection pool for better resource management
            redis_url = getattr(settings, 'REDIS_URL', 'redis://localhost:6379/0')

            redis_url = getattr(settings, "REDIS_URL", "redis://localhost:6379/0")

            self.redis_pool = ConnectionPool.from_url(
                redis_url,
                encoding="utf-8",
@@ -45,141 +40,145 @@ class CoreCacheService:
                socket_timeout=5,
                retry_on_timeout=True,
                max_connections=20,  # Shared pool for all cache operations
                health_check_interval=30
                health_check_interval=30,
            )


            self.redis_client = Redis(connection_pool=self.redis_pool)


            # Test connection
            await self.redis_client.ping()


            self.enabled = True
            logger.info("Core cache service initialized with Redis connection pool")


        except Exception as e:
            logger.error(f"Failed to initialize core cache service: {e}")
            self.enabled = False
            raise


    async def cleanup(self):
        """Cleanup cache resources"""
        if self.redis_client:
            await self.redis_client.close()
            self.redis_client = None


        if self.redis_pool:
            await self.redis_pool.disconnect()
            self.redis_pool = None


        self.enabled = False
        logger.info("Core cache service cleaned up")


    def _get_cache_key(self, key: str, prefix: str = "core") -> str:
        """Generate cache key with prefix"""
        return f"{prefix}:{key}"


    async def get(self, key: str, default: Any = None, prefix: str = "core") -> Any:
        """Get value from cache"""
        if not self.enabled:
            return default


        try:
            cache_key = self._get_cache_key(key, prefix)
            value = await self.redis_client.get(cache_key)


            if value is None:
                self.stats["misses"] += 1
                return default


            self.stats["hits"] += 1
            self.stats["total_requests"] += 1


            # Try to deserialize JSON
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value


        except Exception as e:
            logger.error(f"Cache get error for key {key}: {e}")
            self.stats["errors"] += 1
            return default

    async def set(self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:

    async def set(
        self, key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
    ) -> bool:
        """Set value in cache"""
        if not self.enabled:
            return False


        try:
            cache_key = self._get_cache_key(key, prefix)
            ttl = ttl or 3600  # Default 1 hour TTL


            # Serialize complex objects as JSON
            if isinstance(value, (dict, list, tuple)):
                value = json.dumps(value)


            await self.redis_client.setex(cache_key, ttl, value)
            return True


        except Exception as e:
            logger.error(f"Cache set error for key {key}: {e}")
            self.stats["errors"] += 1
            return False


    async def delete(self, key: str, prefix: str = "core") -> bool:
        """Delete key from cache"""
        if not self.enabled:
            return False


        try:
            cache_key = self._get_cache_key(key, prefix)
            result = await self.redis_client.delete(cache_key)
            return result > 0


        except Exception as e:
            logger.error(f"Cache delete error for key {key}: {e}")
            self.stats["errors"] += 1
            return False


    async def exists(self, key: str, prefix: str = "core") -> bool:
        """Check if key exists in cache"""
        if not self.enabled:
            return False


        try:
            cache_key = self._get_cache_key(key, prefix)
            return await self.redis_client.exists(cache_key) > 0


        except Exception as e:
            logger.error(f"Cache exists error for key {key}: {e}")
            self.stats["errors"] += 1
            return False


    async def clear_pattern(self, pattern: str, prefix: str = "core") -> int:
        """Clear keys matching pattern"""
        if not self.enabled:
            return 0


        try:
            cache_pattern = self._get_cache_key(pattern, prefix)
            keys = await self.redis_client.keys(cache_pattern)
            if keys:
                return await self.redis_client.delete(*keys)
            return 0


        except Exception as e:
            logger.error(f"Cache clear pattern error for pattern {pattern}: {e}")
            self.stats["errors"] += 1
            return 0

    async def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core") -> int:

    async def increment(
        self, key: str, amount: int = 1, ttl: Optional[int] = None, prefix: str = "core"
    ) -> int:
        """Increment counter with optional TTL"""
        if not self.enabled:
            return 0


        try:
            cache_key = self._get_cache_key(key, prefix)


            # Use pipeline for atomic increment + expire
            async with self.redis_client.pipeline() as pipe:
                await pipe.incr(cache_key, amount)
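As shown in get() and set() above, the service stores dicts, lists, and tuples as JSON strings and tries json.loads on the way out, falling back to the raw string. The expected round-trip, sketched against that interface (the instance name `cache` is assumed, not from the commit):

# Assumes an initialized CoreCacheService instance named `cache`.
await cache.set("session:42", {"user_id": 42, "role": "admin"}, ttl=600)
session = await cache.get("session:42")   # -> {"user_id": 42, "role": "admin"}
missing = await cache.get("session:0", default={})  # miss -> {} and stats["misses"] += 1
# Note: tuples come back as lists, since JSON has no tuple type.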
@@ -187,93 +186,118 @@ class CoreCacheService:
                    await pipe.expire(cache_key, ttl)
                results = await pipe.execute()
                return results[0]


        except Exception as e:
            logger.error(f"Cache increment error for key {key}: {e}")
            self.stats["errors"] += 1
            return 0


    async def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics"""
        stats = self.stats.copy()


        if self.enabled:
            try:
                info = await self.redis_client.info()
                stats.update({
                    "redis_memory_used": info.get("used_memory_human", "N/A"),
                    "redis_connected_clients": info.get("connected_clients", 0),
                    "redis_total_commands": info.get("total_commands_processed", 0),
                    "redis_keyspace_hits": info.get("keyspace_hits", 0),
                    "redis_keyspace_misses": info.get("keyspace_misses", 0),
                    "connection_pool_size": self.redis_pool.connection_pool_size if self.redis_pool else 0,
                    "hit_rate": round(
                        (stats["hits"] / stats["total_requests"]) * 100, 2
                    ) if stats["total_requests"] > 0 else 0,
                    "enabled": True
                })
                stats.update(
                    {
                        "redis_memory_used": info.get("used_memory_human", "N/A"),
                        "redis_connected_clients": info.get("connected_clients", 0),
                        "redis_total_commands": info.get("total_commands_processed", 0),
                        "redis_keyspace_hits": info.get("keyspace_hits", 0),
                        "redis_keyspace_misses": info.get("keyspace_misses", 0),
                        "connection_pool_size": self.redis_pool.connection_pool_size
                        if self.redis_pool
                        else 0,
                        "hit_rate": round(
                            (stats["hits"] / stats["total_requests"]) * 100, 2
                        )
                        if stats["total_requests"] > 0
                        else 0,
                        "enabled": True,
                    }
                )
            except Exception as e:
                logger.error(f"Error getting Redis stats: {e}")
                stats["enabled"] = False
        else:
            stats["enabled"] = False


        return stats


    @asynccontextmanager
    async def pipeline(self):
        """Context manager for Redis pipeline operations"""
        if not self.enabled:
            yield None
            return


        async with self.redis_client.pipeline() as pipe:
            yield pipe


    # Specialized caching methods for common use cases

    async def cache_api_key(self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300) -> bool:

    async def cache_api_key(
        self, key_prefix: str, api_key_data: Dict[str, Any], ttl: int = 300
    ) -> bool:
        """Cache API key data for authentication"""
        return await self.set(key_prefix, api_key_data, ttl, prefix="auth")


    async def get_cached_api_key(self, key_prefix: str) -> Optional[Dict[str, Any]]:
        """Get cached API key data"""
        return await self.get(key_prefix, prefix="auth")


    async def invalidate_api_key(self, key_prefix: str) -> bool:
        """Invalidate cached API key"""
        return await self.delete(key_prefix, prefix="auth")

    async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool, ttl: int = 300) -> bool:

    async def cache_verification_result(
        self,
        api_key: str,
        key_prefix: str,
        key_hash: str,
        is_valid: bool,
        ttl: int = 300,
    ) -> bool:
        """Cache API key verification result to avoid expensive bcrypt operations"""
        verification_data = {
            "key_hash": key_hash,
            "is_valid": is_valid,
            "timestamp": datetime.utcnow().isoformat()
            "timestamp": datetime.utcnow().isoformat(),
        }
        return await self.set(f"verify:{key_prefix}", verification_data, ttl, prefix="auth")

    async def get_cached_verification(self, key_prefix: str) -> Optional[Dict[str, Any]]:
        return await self.set(
            f"verify:{key_prefix}", verification_data, ttl, prefix="auth"
        )

    async def get_cached_verification(
        self, key_prefix: str
    ) -> Optional[Dict[str, Any]]:
        """Get cached verification result"""
        return await self.get(f"verify:{key_prefix}", prefix="auth")

    async def cache_rate_limit(self, identifier: str, window_seconds: int, limit: int, current_count: int = 1) -> Dict[str, Any]:

    async def cache_rate_limit(
        self, identifier: str, window_seconds: int, limit: int, current_count: int = 1
    ) -> Dict[str, Any]:
        """Cache and track rate limit state"""
        key = f"rate_limit:{identifier}:{window_seconds}"


        try:
            # Use atomic increment with expiry
            count = await self.increment(key, current_count, window_seconds, prefix="rate")

            count = await self.increment(
                key, current_count, window_seconds, prefix="rate"
            )

            remaining = max(0, limit - count)
            reset_time = int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp())

            reset_time = int(
                (datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
            )

            return {
                "count": count,
                "limit": limit,
                "remaining": remaining,
                "reset_time": reset_time,
                "exceeded": count > limit
                "exceeded": count > limit,
            }
        except Exception as e:
            logger.error(f"Rate limit cache error: {e}")
@@ -282,8 +306,10 @@ class CoreCacheService:
                "count": 0,
                "limit": limit,
                "remaining": limit,
                "reset_time": int((datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()),
                "exceeded": False
                "reset_time": int(
                    (datetime.utcnow() + timedelta(seconds=window_seconds)).timestamp()
                ),
                "exceeded": False,
            }


@@ -297,7 +323,9 @@ async def get(key: str, default: Any = None, prefix: str = "core") -> Any:
    return await core_cache.get(key, default, prefix)


async def set(key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core") -> bool:
async def set(
    key: str, value: Any, ttl: Optional[int] = None, prefix: str = "core"
) -> bool:
    """Set value in core cache"""
    return await core_cache.set(key, value, ttl, prefix)

@@ -319,4 +347,4 @@ async def clear_pattern(pattern: str, prefix: str = "core") -> int:

async def get_stats() -> Dict[str, Any]:
    """Get core cache statistics"""
    return await core_cache.get_stats()
    return await core_cache.get_stats()

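cache_rate_limit above implements a fixed-window counter: one atomic INCR+EXPIRE per call, with exceeded flipping once the count passes the limit, and a fail-open fallback (count 0, exceeded False) when Redis errors out. A sketch of how a caller might consume it (the module-level instance name is assumed):

state = await core_cache.cache_rate_limit(
    identifier="api_key:abc123", window_seconds=60, limit=20
)
if state["exceeded"]:
    raise HTTPException(
        status_code=429,
        detail=f"Rate limit exceeded; retry after {state['reset_time']}",
    )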
@@ -10,7 +10,7 @@ from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    """Application settings"""


    # Application
    APP_NAME: str = os.getenv("APP_NAME", "Enclava")
    APP_DEBUG: bool = os.getenv("APP_DEBUG", "False").lower() == "true"
@@ -19,131 +19,188 @@ class Settings(BaseSettings):
    APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
    BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
    FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))


    # Detailed logging for LLM interactions
    LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"  # Set to True to log prompts and context sent to LLM

    LOG_LLM_PROMPTS: bool = (
        os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"
    )  # Set to True to log prompts and context sent to LLM

    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL")


    # Redis
    REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379")


    # Security
    JWT_SECRET: str = os.getenv("JWT_SECRET")
    JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM", "HS256")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))  # 24 hours
    REFRESH_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080"))  # 7 days
    SESSION_EXPIRE_MINUTES: int = int(os.getenv("SESSION_EXPIRE_MINUTES", "1440"))  # 24 hours
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
        os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
    )  # 24 hours
    REFRESH_TOKEN_EXPIRE_MINUTES: int = int(
        os.getenv("REFRESH_TOKEN_EXPIRE_MINUTES", "10080")
    )  # 7 days
    SESSION_EXPIRE_MINUTES: int = int(
        os.getenv("SESSION_EXPIRE_MINUTES", "1440")
    )  # 24 hours
    API_KEY_PREFIX: str = os.getenv("API_KEY_PREFIX", "en_")
    BCRYPT_ROUNDS: int = int(os.getenv("BCRYPT_ROUNDS", "6"))  # Bcrypt work factor - lower for production performance

    BCRYPT_ROUNDS: int = int(
        os.getenv("BCRYPT_ROUNDS", "6")
    )  # Bcrypt work factor - lower for production performance

    # Admin user provisioning (used only on first startup)
    ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL")
    ADMIN_PASSWORD: str = os.getenv("ADMIN_PASSWORD")


    # Base URL for deriving CORS origins
    BASE_URL: str = os.getenv("BASE_URL", "localhost")

    @field_validator('CORS_ORIGINS', mode='before')

    @field_validator("CORS_ORIGINS", mode="before")
    @classmethod
    def derive_cors_origins(cls, v, info):
        """Derive CORS origins from BASE_URL if not explicitly set"""
        if v is None:
            base_url = info.data.get('BASE_URL', 'localhost')
            base_url = info.data.get("BASE_URL", "localhost")
            # Support both HTTP and HTTPS for production environments
            return [f"http://{base_url}", f"https://{base_url}"]
        return v if isinstance(v, list) else [v]


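The validator above only fires when CORS_ORIGINS is unset, expanding BASE_URL into an http and an https origin; an explicit scalar value is normalized to a one-element list. Illustrative values, not from the commit:

# BASE_URL=example.com with CORS_ORIGINS unset yields
#   ["http://example.com", "https://example.com"]
# while CORS_ORIGINS="https://app.example.com" is wrapped as
#   ["https://app.example.com"]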
    # CORS origins (derived from BASE_URL)
    CORS_ORIGINS: Optional[List[str]] = None


    # LLM Service Configuration (replaced LiteLLM)
    # LLM service configuration is now handled in app/services/llm/config.py


    # LLM Service Security (removed encryption - credentials handled by proxy)


    # Plugin System Security
    PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv("PLUGIN_ENCRYPTION_KEY")  # Key for encrypting plugin secrets and configurations

    PLUGIN_ENCRYPTION_KEY: Optional[str] = os.getenv(
        "PLUGIN_ENCRYPTION_KEY"
    )  # Key for encrypting plugin secrets and configurations

    # API Keys for LLM providers
    OPENAI_API_KEY: Optional[str] = os.getenv("OPENAI_API_KEY")
    ANTHROPIC_API_KEY: Optional[str] = os.getenv("ANTHROPIC_API_KEY")
    GOOGLE_API_KEY: Optional[str] = os.getenv("GOOGLE_API_KEY")
    PRIVATEMODE_API_KEY: Optional[str] = os.getenv("PRIVATEMODE_API_KEY")
    PRIVATEMODE_PROXY_URL: str = os.getenv("PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1")

    PRIVATEMODE_PROXY_URL: str = os.getenv(
        "PRIVATEMODE_PROXY_URL", "http://privatemode-proxy:8080/v1"
    )

    # Qdrant
    QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
    QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
    QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
    QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")



    # Rate Limiting Configuration


    # PrivateMode Standard tier limits (organization-level, not per user)
    # These are shared across all API keys and users in the organization
    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(
        os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")
    )
    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(
        os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200")
    )
    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(
        os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000")
    )
    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(
        os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000")
    )

    # Per-user limits (additional protection on top of organization limits)
    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(
        os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")
    )  # Match PrivateMode
    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(
        os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200")
    )

    # API key users (programmatic access)
    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(
        os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")
    )  # Match PrivateMode
    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(
        os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200")
    )

    # Premium/Enterprise API keys
    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))

    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(
        os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")
    )  # Match PrivateMode
    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(
        os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")
    )

    # Request Size Limits
    API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760"))  # 10MB
    API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800"))  # 50MB for premium

    API_MAX_REQUEST_BODY_SIZE: int = int(
        os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")
    )  # 10MB
    API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(
        os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")
    )  # 50MB for premium

    # IP Security


    # Security Headers
    API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")

    API_CSP_HEADER: str = os.getenv(
        "API_CSP_HEADER",
        "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
    )

    # Monitoring
    PROMETHEUS_ENABLED: bool = os.getenv("PROMETHEUS_ENABLED", "True").lower() == "true"
    PROMETHEUS_PORT: int = int(os.getenv("PROMETHEUS_PORT", "9090"))


    # File uploads
    MAX_UPLOAD_SIZE: int = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB


    # Module configuration
    MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")

    # RAG Embedding Configuration
    RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
    RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(
        os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12")
    )
    RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
    RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
    RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
    RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
    RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
    RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
    RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
    RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv(
        "RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
    )
    RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(
        os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0")
    )
    RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(
        os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")
    )
    RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = (
        os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
    )
    RAG_WARN_ON_FALLBACK: bool = (
        os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
    )
    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
    RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
    RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
    RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(
        os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")
    )
    RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(
        os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")
    )
    RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))


    # Plugin configuration
    PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
    PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
    PLUGIN_REPOSITORY_URL: str = os.getenv("PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com")

    PLUGIN_REPOSITORY_URL: str = os.getenv(
        "PLUGIN_REPOSITORY_URL", "https://plugins.enclava.com"
    )

    # Logging
    LOG_FORMAT: str = os.getenv("LOG_FORMAT", "json")
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")



    model_config = {
        "env_file": ".env",
        "case_sensitive": True,

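Several of the settings above are stored as delimited strings, e.g. RAG_EMBEDDING_RETRY_DELAYS = "1,2,4,8,16", which consumers must split themselves. A sketch of that parse; the helper name is hypothetical:

def parse_retry_delays(raw: str) -> list:
    """Turn "1,2,4,8,16" into [1.0, 2.0, 4.0, 8.0, 16.0]."""
    return [float(part) for part in raw.split(",") if part.strip()]

delays = parse_retry_delays(settings.RAG_EMBEDDING_RETRY_DELAYS)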
@@ -13,7 +13,7 @@ from app.core.config import settings

def setup_logging() -> None:
    """Setup structured logging"""


    # Configure structlog
    structlog.configure(
        processors=[
@@ -24,21 +24,23 @@ def setup_logging() -> None:
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
            structlog.processors.UnicodeDecoder(),
            structlog.processors.JSONRenderer() if settings.LOG_FORMAT == "json" else structlog.dev.ConsoleRenderer(),
            structlog.processors.JSONRenderer()
            if settings.LOG_FORMAT == "json"
            else structlog.dev.ConsoleRenderer(),
        ],
        context_class=dict,
        logger_factory=LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )


    # Configure standard logging
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=getattr(logging, settings.LOG_LEVEL.upper()),
    )


    # Set specific loggers
    logging.getLogger("uvicorn").setLevel(logging.WARNING)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
@@ -52,17 +54,17 @@ def get_logger(name: str) -> structlog.stdlib.BoundLogger:

class RequestContextFilter(logging.Filter):
    """Add request context to log records"""


    def filter(self, record: logging.LogRecord) -> bool:
        # Add request context if available
        from contextvars import ContextVar


        request_id: ContextVar[str] = ContextVar("request_id", default="")
        user_id: ContextVar[str] = ContextVar("user_id", default="")


        record.request_id = request_id.get()
        record.user_id = user_id.get()


        return True


@@ -77,7 +79,7 @@ def log_request(
) -> None:
    """Log HTTP request"""
    logger = get_logger("api.request")


    log_data = {
        "method": method,
        "path": path,
@@ -87,7 +89,7 @@
        "request_id": request_id,
        **kwargs,
    }


    if status_code >= 500:
        logger.error("Request failed", **log_data)
    elif status_code >= 400:
@@ -105,7 +107,7 @@ def log_security_event(
) -> None:
    """Log security event"""
    logger = get_logger("security")


    log_data = {
        "event_type": event_type,
        "user_id": user_id,
@@ -113,7 +115,7 @@ def log_security_event(
        "details": details or {},
        **kwargs,
    }


    logger.warning("Security event", **log_data)


@@ -125,14 +127,14 @@ def log_module_event(
) -> None:
    """Log module event"""
    logger = get_logger("module")


    log_data = {
        "module_id": module_id,
        "event_type": event_type,
        "details": details or {},
        **kwargs,
    }


    logger.info("Module event", **log_data)


@@ -143,11 +145,11 @@ def log_api_request(
) -> None:
    """Log API request for modules endpoints"""
    logger = get_logger("api.module")


    log_data = {
        "endpoint": endpoint,
        "params": params or {},
        **kwargs,
    }

    logger.info("API request", **log_data)

    logger.info("API request", **log_data)

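With LOG_FORMAT=json, setup_logging() routes every bound logger through JSONRenderer, so the helpers above emit one JSON object per event with the bound keys inlined. A minimal usage sketch under those assumptions:

setup_logging()
logger = get_logger("api.request")
logger.info("Request completed", method="GET", path="/health", status_code=200)
# With LOG_FORMAT=json this prints a single JSON object containing those keys.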
367
backend/app/core/permissions.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Permissions Module
|
||||
Role-based access control decorators and utilities
|
||||
"""
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Union, Callable
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def require_permission(
|
||||
user: User, permission: str, resource_id: Optional[Union[str, int]] = None
|
||||
):
|
||||
"""
|
||||
Check if user has the required permission
|
||||
|
||||
Args:
|
||||
user: User object from dependency injection
|
||||
permission: Required permission string
|
||||
resource_id: Optional resource ID for resource-specific permissions
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have the required permission
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Account is not active"
|
||||
)
|
||||
|
||||
# Check if account is locked
|
||||
if user.account_locked:
|
||||
if user.account_locked_until and user.account_locked_until > datetime.utcnow():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is temporarily locked",
|
||||
)
|
||||
else:
|
||||
# Unlock account if lock period has expired
|
||||
user.unlock_account()
|
||||
|
||||
# Check permission
|
||||
if not user.has_permission(permission):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{permission}' required",
|
||||
)
|
||||
|
||||
|
||||
def require_permissions(permissions: List[str], require_all: bool = True):
|
||||
"""
|
||||
Decorator to require multiple permissions
|
||||
|
||||
Args:
|
||||
permissions: List of required permissions
|
||||
require_all: If True, user must have all permissions. If False, any one permission is sufficient
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user from kwargs (assuming it's passed as a dependency)
|
||||
user = None
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, User):
|
||||
user = value
|
||||
break
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
|
||||
# Check permissions
|
||||
if require_all:
|
||||
# User must have all permissions
|
||||
for permission in permissions:
|
||||
if not user.has_permission(permission):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"All of the following permissions required: {', '.join(permissions)}",
|
||||
)
|
||||
else:
|
||||
# User needs at least one permission
|
||||
if not any(
|
||||
user.has_permission(permission) for permission in permissions
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"At least one of the following permissions required: {', '.join(permissions)}",
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_role(role_names: Union[str, List[str]], require_all: bool = True):
    """
    Decorator to require specific roles

    Args:
        role_names: Required role name(s)
        require_all: If True, user must have all roles. If False, any one role is sufficient
    """
    if isinstance(role_names, str):
        role_names = [role_names]

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Extract user from kwargs
            user = None
            for key, value in kwargs.items():
                if isinstance(value, User):
                    user = value
                    break

            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required",
                )

            # Check roles
            user_role_names = []
            if user.role:
                user_role_names.append(user.role.name)

            if require_all:
                # User must have all roles
                for role_name in role_names:
                    if role_name not in user_role_names:
                        raise HTTPException(
                            status_code=status.HTTP_403_FORBIDDEN,
                            detail=f"All of the following roles required: {', '.join(role_names)}",
                        )
            else:
                # User needs at least one role
                if not any(role_name in user_role_names for role_name in role_names):
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"At least one of the following roles required: {', '.join(role_names)}",
                    )

            return await func(*args, **kwargs)

        return wrapper

    return decorator


def require_minimum_role(minimum_role_level: str):
    """
    Decorator to require minimum role level based on hierarchy

    Args:
        minimum_role_level: Minimum required role level
    """
    role_hierarchy = {"read_only": 1, "user": 2, "admin": 3, "super_admin": 4}

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Extract user from kwargs
            user = None
            for key, value in kwargs.items():
                if isinstance(value, User):
                    user = value
                    break

            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required",
                )

            # Superusers bypass role checks
            if user.is_superuser:
                return await func(*args, **kwargs)

            # Check role level
            if not user.role:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Minimum role level '{minimum_role_level}' required",
                )

            user_level = role_hierarchy.get(user.role.level, 0)
            required_level = role_hierarchy.get(minimum_role_level, 0)

            if user_level < required_level:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Minimum role level '{minimum_role_level}' required",
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
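Continuing the illustrative sketch above: the hierarchy means an "admin" requirement also admits "super_admin", and superusers skip the check entirely.

@router.delete("/users/{user_id}")
@require_minimum_role("admin")
async def remove_user(user_id: int, current_user: User = Depends(get_current_user_model)):
    # Passes when user.role.level is "admin" or "super_admin".
    return {"removed": user_id}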
def check_resource_permission(
    user: User, resource_type: str, resource_id: Union[str, int], action: str
) -> bool:
    """
    Check if user has permission to perform action on specific resource

    Args:
        user: User object
        resource_type: Type of resource (e.g., 'user', 'budget', 'api_key')
        resource_id: ID of the resource
        action: Action to perform (e.g., 'read', 'update', 'delete')

    Returns:
        bool: True if user has permission, False otherwise
    """
    # Superusers can do anything
    if user.is_superuser:
        return True

    # Check basic permissions
    permission = f"{action}_{resource_type}"
    if user.has_permission(permission):
        return True

    # Check own resource permissions
    if resource_type == "user" and str(resource_id) == str(user.id):
        if user.has_permission(f"{action}_own"):
            return True

    # Check role-based resource access
    if user.role:
        # Admins can manage all users
        if resource_type == "user" and user.role.level in ["admin", "super_admin"]:
            return True

        # Users with budget permissions can manage budgets
        if resource_type == "budget" and user.role.can_manage_budgets:
            return True

    return False
def require_resource_permission(
    resource_type: str, resource_id_param: str = "resource_id", action: str = "read"
):
    """
    Decorator to require permission for specific resource

    Args:
        resource_type: Type of resource
        resource_id_param: Name of parameter containing resource ID
        action: Action to perform
    """

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Extract user and resource ID from kwargs
            user = None
            resource_id = None

            for key, value in kwargs.items():
                if isinstance(value, User):
                    user = value
                elif key == resource_id_param:
                    resource_id = value

            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required",
                )

            if resource_id is None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Resource ID not provided",
                )

            # Check resource permission
            if not check_resource_permission(user, resource_type, resource_id, action):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Permission '{action}_{resource_type}' required",
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
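An illustrative route for the decorator above; note that resource_id_param must match the path parameter name so the wrapper can find the ID in kwargs:

@router.get("/budgets/{budget_id}")
@require_resource_permission("budget", resource_id_param="budget_id", action="read")
async def read_budget(budget_id: int, current_user: User = Depends(get_current_user_model)):
    return {"budget_id": budget_id}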
def check_budget_permission(user: User, budget_id: int, action: str) -> bool:
    """
    Check if user has permission to perform action on specific budget

    Args:
        user: User object
        budget_id: ID of the budget
        action: Action to perform

    Returns:
        bool: True if user has permission, False otherwise
    """
    # Superusers can do anything
    if user.is_superuser:
        return True

    # Check if user owns the budget
    for budget in user.budgets:
        if budget.id == budget_id and budget.is_active:
            return user.has_permission(f"{action}_own")

    # Check if user can manage all budgets
    if user.has_permission(f"{action}_all") or (
        user.role and user.role.can_manage_budgets
    ):
        return True

    return False
def check_api_key_permission(user: User, api_key_id: int, action: str) -> bool:
    """
    Check if user has permission to perform action on specific API key

    Args:
        user: User object
        api_key_id: ID of the API key
        action: Action to perform

    Returns:
        bool: True if user has permission, False otherwise
    """
    # Superusers can do anything
    if user.is_superuser:
        return True

    # Check if user owns the API key
    for api_key in user.api_keys:
        if api_key.id == api_key_id:
            return user.has_permission(f"{action}_own")

    # Check if user can manage all API keys
    if user.has_permission(f"{action}_all"):
        return True

    return False
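The two helpers above can also be called directly from service code; a hedged sketch, assuming `user` is a loaded ORM User with its role and relationships populated:

# `user`, and the IDs below, are illustrative assumptions.
if not check_budget_permission(user, budget_id=42, action="update"):
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="update denied")

if check_api_key_permission(user, api_key_id=7, action="delete"):
    ...  # proceed with key deletion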
@@ -24,34 +24,39 @@ logger = logging.getLogger(__name__)

# Password hashing
# Use a lower work factor for better performance in production
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__rounds=settings.BCRYPT_ROUNDS
    schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=settings.BCRYPT_ROUNDS
)

# JWT token handling
security = HTTPBearer()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash"""
    import time

    start_time = time.time()
    logger.info(f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")
    logger.info(
        f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}"
    )

    try:
        # Run password verification in a thread with timeout
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(pwd_context.verify, plain_password, hashed_password)
            future = executor.submit(
                pwd_context.verify, plain_password, hashed_password
            )
            result = future.result(timeout=5.0)  # 5 second timeout

        end_time = time.time()
        duration = end_time - start_time
        logger.info(f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}")
        logger.info(
            f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}"
        )

        if duration > 1:
            logger.warning(f"PASSWORD VERIFICATION TOOK TOO LONG: {duration:.3f}s")

        return result
    except concurrent.futures.TimeoutError:
        end_time = time.time()
@@ -61,87 +66,116 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time
        logger.error(f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}")
        logger.error(
            f"=== PASSWORD VERIFICATION FAILED === Duration: {duration:.3f}s, Error: {e}"
        )
        raise


def get_password_hash(password: str) -> str:
    """Generate password hash"""
    return pwd_context.hash(password)


def verify_api_key(plain_api_key: str, hashed_api_key: str) -> bool:
    """Verify an API key against its hash"""
    return pwd_context.verify(plain_api_key, hashed_api_key)


def get_api_key_hash(api_key: str) -> str:
    """Generate API key hash"""
    return pwd_context.hash(api_key)

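A quick round-trip sketch for the hashing helpers above; the plaintext values are illustrative, and the bcrypt cost is governed by settings.BCRYPT_ROUNDS:

hashed = get_password_hash("s3cret-example")   # bcrypt hash, safe to persist
assert verify_password("s3cret-example", hashed)
assert not verify_password("wrong-guess", hashed)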
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(
    data: Dict[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """Create JWT access token"""
    import time

    start_time = time.time()
    logger.info(f"=== CREATE ACCESS TOKEN START ===")

    try:
        to_encode = data.copy()
        if expires_delta:
            expire = datetime.utcnow() + expires_delta
        else:
            expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
            expire = datetime.utcnow() + timedelta(
                minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
            )

        to_encode.update({"exp": expire})
        logger.info(f"JWT encode start...")
        encode_start = time.time()
        encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
        encoded_jwt = jwt.encode(
            to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
        )
        encode_end = time.time()
        encode_duration = encode_end - encode_start

        end_time = time.time()
        total_duration = end_time - start_time

        # Log token creation details
        logger.info(f"Created access token for user {data.get('sub')}")
        logger.info(f"Token expires at: {expire.isoformat()} (UTC)")
        logger.info(f"Current UTC time: {datetime.utcnow().isoformat()}")
        logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
        logger.info(
            f"ACCESS_TOKEN_EXPIRE_MINUTES setting: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}"
        )
        logger.info(f"JWT encode duration: {encode_duration:.3f}s")
        logger.info(f"Total token creation duration: {total_duration:.3f}s")
        logger.info(f"=== CREATE ACCESS TOKEN END ===")

        return encoded_jwt
    except Exception as e:
        end_time = time.time()
        total_duration = end_time - start_time
        logger.error(f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}")
        logger.error(
            f"=== CREATE ACCESS TOKEN FAILED === Duration: {total_duration:.3f}s, Error: {e}"
        )
        raise

def create_refresh_token(data: Dict[str, Any]) -> str:
    """Create JWT refresh token"""
    to_encode = data.copy()
    expire = datetime.utcnow() + timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)
    expire = datetime.utcnow() + timedelta(
        minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
    )
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
    encoded_jwt = jwt.encode(
        to_encode, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
    )
    return encoded_jwt

def verify_token(token: str) -> Dict[str, Any]:
    """Verify JWT token and return payload"""
    try:
        # Log current time before verification
        current_time = datetime.utcnow()
        logger.info(f"Verifying token at: {current_time.isoformat()} (UTC)")

        # Decode without verification first to check expiration
        try:
            unverified_payload = jwt.get_unverified_claims(token)
            exp_timestamp = unverified_payload.get('exp')
            exp_timestamp = unverified_payload.get("exp")
            if exp_timestamp:
                exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=None)
                logger.info(f"Token expiration time: {exp_datetime.isoformat()} (UTC)")
                logger.info(f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds")
                logger.info(
                    f"Time until expiration: {(exp_datetime - current_time).total_seconds()} seconds"
                )
        except Exception as decode_error:
            logger.warning(f"Could not decode token for expiration check: {decode_error}")

        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
            logger.warning(
                f"Could not decode token for expiration check: {decode_error}"
            )

        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        logger.info(f"Token verified successfully for user {payload.get('sub')}")
        return payload
    except JWTError as e:
@@ -149,30 +183,32 @@ def verify_token(token: str) -> Dict[str, Any]:
        logger.warning(f"Current UTC time: {datetime.utcnow().isoformat()}")
        raise AuthenticationError("Invalid token")

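Taken together, create_access_token and verify_token round-trip as below; the claims and lifetime are illustrative values:

token = create_access_token({"sub": "123", "role": "user"}, expires_delta=timedelta(minutes=5))
payload = verify_token(token)
assert payload["sub"] == "123"  # the "exp" claim was added before encoding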
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
    db: AsyncSession = Depends(get_db),
) -> Dict[str, Any]:
    """Get current user from JWT token"""
    try:
        # Log server time for debugging clock sync issues
        server_time = datetime.utcnow()
        logger.info(f"get_current_user called at: {server_time.isoformat()} (UTC)")

        payload = verify_token(credentials.credentials)
        user_id: str = payload.get("sub")
        if user_id is None:
            raise AuthenticationError("Invalid token payload")

        # Load user from database
        from app.models.user import User
        from sqlalchemy import select
        from sqlalchemy.orm import selectinload

        # Query user from database
        stmt = select(User).where(User.id == int(user_id))
        stmt = select(User).options(selectinload(User.role)).where(User.id == int(user_id))
        result = await db.execute(stmt)
        user = result.scalar_one_or_none()

        if not user:
            # If user doesn't exist in DB but token is valid, create basic user info from token
            return {
@@ -181,49 +217,53 @@ async def get_current_user(
                "is_superuser": payload.get("is_superuser", False),
                "role": payload.get("role", "user"),
                "is_active": True,
                "permissions": []  # Default to empty list for permissions
                "permissions": [],  # Default to empty list for permissions
            }

        # Update last login
        user.update_last_login()
        await db.commit()

        # Calculate effective permissions using permission manager
        from app.services.permission_manager import permission_registry

        # Convert role string to list for permission calculation
        user_roles = [user.role] if user.role else []

        # Convert role to name for permission calculation
        user_roles = [user.role.name] if user.role else []

        # For super admin users, use only role-based permissions, ignore custom permissions
        # Custom permissions might contain legacy formats like ['*'] that don't work with new system
        # Custom permissions might contain legacy formats like ['*'] or dict formats
        custom_permissions = []
        if not user.is_superuser:
            # Only use custom permissions for non-superuser accounts
            if user.permissions:
                if isinstance(user.permissions, list):
                    custom_permissions = user.permissions

            # Support both list-based and dict-based custom permission formats
            raw_custom_perms = getattr(user, "custom_permissions", None)
            if raw_custom_perms:
                if isinstance(raw_custom_perms, list):
                    custom_permissions = raw_custom_perms
                elif isinstance(raw_custom_perms, dict):
                    granted = raw_custom_perms.get("granted")
                    if isinstance(granted, list):
                        custom_permissions = granted

        # Calculate effective permissions based on role and custom permissions
        effective_permissions = permission_registry.get_user_permissions(
            roles=user_roles,
            custom_permissions=custom_permissions
            roles=user_roles, custom_permissions=custom_permissions
        )

        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "is_superuser": user.is_superuser,
            "is_active": user.is_active,
            "role": user.role,
            "role": user.role.name if user.role else None,
            "permissions": effective_permissions,  # Use calculated permissions
            "user_obj": user  # Include full user object for other operations
            "user_obj": user,  # Include full user object for other operations
        }
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        raise AuthenticationError("Could not validate credentials")


async def get_current_active_user(
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -233,6 +273,7 @@ async def get_current_active_user(
        raise AuthenticationError("User account is inactive")
    return current_user


async def get_current_superuser(
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
@@ -241,99 +282,120 @@ async def get_current_superuser(
        raise AuthorizationError("Insufficient privileges")
    return current_user

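An illustrative endpoint consuming the dict shape returned above; the router and route are stand-ins:

@router.get("/me")
async def whoami(current_user: Dict[str, Any] = Depends(get_current_user)):
    # "permissions" carries the effective permissions computed above;
    # "user_obj" carries the ORM instance when the user exists in the DB.
    return {
        "email": current_user["email"],
        "role": current_user["role"],
        "permissions": current_user["permissions"],
    }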
def generate_api_key() -> str:
    """Generate a new API key"""
    import secrets
    import string

    # Generate random string
    alphabet = string.ascii_letters + string.digits
    api_key = ''.join(secrets.choice(alphabet) for _ in range(32))

    api_key = "".join(secrets.choice(alphabet) for _ in range(32))

    return f"{settings.API_KEY_PREFIX}{api_key}"


def hash_api_key(api_key: str) -> str:
    """Hash API key for storage"""
    return get_password_hash(api_key)


def verify_api_key(api_key: str, hashed_key: str) -> bool:
    """Verify API key against hash"""
    return verify_password(api_key, hashed_key)

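A hedged provisioning sketch tying the three helpers together; only the hash and the 8-character prefix (used for lookup in get_api_key_user below) would be persisted:

raw_key = generate_api_key()          # settings.API_KEY_PREFIX + 32 random chars
stored_hash = hash_api_key(raw_key)   # bcrypt hash, safe to store
key_prefix = raw_key[:8]              # lookup column, per the query below
assert verify_api_key(raw_key, stored_hash)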
async def get_api_key_user(
    request: Request,
    db: AsyncSession = Depends(get_db)
    request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
    """Get user from API key"""
    api_key = request.headers.get("X-API-Key")
    if not api_key:
        return None

    # Implement API key lookup in database
    from app.models.api_key import APIKey
    from app.models.user import User
    from sqlalchemy import select

    try:
        # Extract key prefix for lookup
        if len(api_key) < 8:
            return None

        key_prefix = api_key[:8]

        # Query API key from database
        stmt = select(APIKey).join(User).where(
            APIKey.key_prefix == key_prefix,
            APIKey.is_active == True,
            User.is_active == True
        stmt = (
            select(APIKey)
            .join(User)
            .where(
                APIKey.key_prefix == key_prefix,
                APIKey.is_active == True,
                User.is_active == True,
            )
        )
        result = await db.execute(stmt)
        db_api_key = result.scalar_one_or_none()

        if not db_api_key:
            return None

        # Verify the API key hash
        if not verify_api_key(api_key, db_api_key.key_hash):
            return None

        # Check if key is valid (not expired)
        if not db_api_key.is_valid():
            return None

        # Update last used timestamp
        db_api_key.last_used_at = datetime.utcnow()
        await db.commit()

        # Load associated user
        user_stmt = select(User).where(User.id == db_api_key.user_id)
        user_stmt = select(User).options(selectinload(User.role)).where(User.id == db_api_key.user_id)
        user_result = await db.execute(user_stmt)
        user = user_result.scalar_one_or_none()

        if not user or not user.is_active:
            return None

        # Calculate effective permissions using permission manager
        from app.services.permission_manager import permission_registry

        # Convert role string to list for permission calculation
        user_roles = [user.role] if user.role else []

        # Use API key specific permissions if available, otherwise use user permissions

        # Convert role to name for permission calculation
        user_roles = [user.role.name] if user.role else []

        # Use API key specific permissions if available
        api_key_permissions = db_api_key.permissions if db_api_key.permissions else []

        # Get custom permissions from database (convert dict to list if needed)
        custom_permissions = api_key_permissions
        if user.permissions:
            if isinstance(user.permissions, list):
                custom_permissions.extend(user.permissions)

        # Normalize permissions into a flat list of granted permission strings
        custom_permissions: list[str] = []

        # Handle API key permissions that may be stored as list or dict
        if isinstance(api_key_permissions, list):
            custom_permissions.extend(api_key_permissions)
        elif isinstance(api_key_permissions, dict):
            api_granted = api_key_permissions.get("granted")
            if isinstance(api_granted, list):
                custom_permissions.extend(api_granted)

        # Merge in user-level custom permissions for non-superusers
        raw_user_custom = getattr(user, "custom_permissions", None)
        if raw_user_custom and not user.is_superuser:
            if isinstance(raw_user_custom, list):
                custom_permissions.extend(raw_user_custom)
            elif isinstance(raw_user_custom, dict):
                user_granted = raw_user_custom.get("granted")
                if isinstance(user_granted, list):
                    custom_permissions.extend(user_granted)

        # Calculate effective permissions based on role and custom permissions
        effective_permissions = permission_registry.get_user_permissions(
            roles=user_roles,
            custom_permissions=custom_permissions
            roles=user_roles, custom_permissions=custom_permissions
        )

        return {
            "id": user.id,
            "email": user.email,
@@ -344,73 +406,80 @@ async def get_api_key_user(
            "permissions": effective_permissions,
            "api_key": db_api_key,
            "user_obj": user,
            "auth_type": "api_key"
            "auth_type": "api_key",
        }
    except Exception as e:
        logger.error(f"API key lookup error: {e}")
        return None

class RequiresPermission:
    """Dependency class for permission checking"""

    def __init__(self, permission: str):
        self.permission = permission

    def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
        # Implement permission checking
        # Check if user is superuser (has all permissions)
        if current_user.get("is_superuser", False):
            return current_user

        # Check role-based permissions
        role = current_user.get("role", "user")
        role_permissions = {
            "user": ["read_own", "create_own", "update_own"],
            "admin": ["read_all", "create_all", "update_all", "delete_own"],
            "super_admin": ["read_all", "create_all", "update_all", "delete_all", "manage_users", "manage_modules"]
            "super_admin": [
                "read_all",
                "create_all",
                "update_all",
                "delete_all",
                "manage_users",
                "manage_modules",
            ],
        }

        if role in role_permissions and self.permission in role_permissions[role]:
            return current_user

        # Check custom permissions
        user_permissions = current_user.get("permissions", {})
        if self.permission in user_permissions:
            return current_user

        # If user has access to full user object, use the model's has_permission method
        user_obj = current_user.get("user_obj")
        if user_obj and hasattr(user_obj, "has_permission"):
            if user_obj.has_permission(self.permission):
                return current_user

        raise AuthorizationError(f"Permission '{self.permission}' required")

class RequiresRole:
    """Dependency class for role checking"""

    def __init__(self, role: str):
        self.role = role

    def __call__(self, current_user: Dict[str, Any] = Depends(get_current_user)):
        # Implement role checking
        # Superusers have access to everything
        if current_user.get("is_superuser", False):
            return current_user

        user_role = current_user.get("role", "user")

        # Define role hierarchy
        role_hierarchy = {
            "user": 1,
            "admin": 2,
            "super_admin": 3
        }

        role_hierarchy = {"user": 1, "admin": 2, "super_admin": 3}

        required_level = role_hierarchy.get(self.role, 0)
        user_level = role_hierarchy.get(user_role, 0)

        if user_level >= required_level:
            return current_user

        raise AuthorizationError(f"Role '{self.role}' required, but user has role '{user_role}'")

        raise AuthorizationError(
            f"Role '{self.role}' required, but user has role '{user_role}'"
        )
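A sketch of wiring the two dependency classes into routes; since each instance is callable, it can be passed straight to Depends (the routes shown are illustrative):

from fastapi import APIRouter, Depends

router = APIRouter()


@router.get("/admin/settings", dependencies=[Depends(RequiresRole("admin"))])
async def admin_settings():
    return {"ok": True}


@router.post("/modules", dependencies=[Depends(RequiresPermission("manage_modules"))])
async def create_module():
    return {"created": True}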
@@ -1,3 +1,3 @@
"""
Database package
"""
"""
@@ -20,10 +20,10 @@ engine = create_async_engine(
    echo=settings.APP_DEBUG,
    future=True,
    pool_pre_ping=True,
    pool_size=50,  # Increased from 20 for better concurrency
    max_overflow=100,  # Increased from 30 for burst capacity
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    pool_size=50,  # Increased from 20 for better concurrency
    max_overflow=100,  # Increased from 30 for burst capacity
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    connect_args={
        "timeout": 5,
        "command_timeout": 5,
@@ -46,10 +46,10 @@ sync_engine = create_engine(
    echo=settings.APP_DEBUG,
    future=True,
    pool_pre_ping=True,
    pool_size=25,  # Increased from 10 for better performance
    max_overflow=50,  # Increased from 20 for burst capacity
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    pool_size=25,  # Increased from 10 for better performance
    max_overflow=50,  # Increased from 20 for burst capacity
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    connect_args={
        "connect_timeout": 5,
        "application_name": "enclava_backend_sync",
@@ -72,11 +72,12 @@ metadata = MetaData()

async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Get database session"""
    import time

    start_time = time.time()
    request_id = f"db_{int(time.time() * 1000)}"

    logger.info(f"[{request_id}] === DATABASE SESSION START ===")

    try:
        logger.info(f"[{request_id}] Creating database session...")
        async with async_session_factory() as session:
@@ -86,7 +87,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
            except Exception as e:
                # Only log if there's an actual error, not normal operation
                if str(e).strip():  # Only log if error message exists
                    logger.error(f"[{request_id}] Database session error: {str(e)}", exc_info=True)
                    logger.error(
                        f"[{request_id}] Database session error: {str(e)}",
                        exc_info=True,
                    )
                await session.rollback()
                raise
            finally:
@@ -94,9 +98,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
                await session.close()
                close_time = time.time() - close_start
                total_time = time.time() - start_time
                logger.info(f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s")
                logger.info(
                    f"[{request_id}] Database session closed. Close time: {close_time:.3f}s, Total time: {total_time:.3f}s"
                )
    except Exception as e:
        logger.error(f"[{request_id}] Failed to create database session: {e}", exc_info=True)
        logger.error(
            f"[{request_id}] Failed to create database session: {e}", exc_info=True
        )
        raise


@@ -106,44 +114,82 @@ async def init_db():
        async with engine.begin() as conn:
            # Import all models to ensure they're registered
            from app.models.user import User
            from app.models.role import Role
            from app.models.api_key import APIKey
            from app.models.usage_tracking import UsageTracking

            # Import additional models - these are available
            try:
                from app.models.budget import Budget
            except ImportError:
                logger.warning("Budget model not available yet")

            try:
                from app.models.audit_log import AuditLog
            except ImportError:
                logger.warning("AuditLog model not available yet")

            try:
                from app.models.module import Module
            except ImportError:
                logger.warning("Module model not available yet")

            # Tables are now created via migration container - no need to create here
            # await conn.run_sync(Base.metadata.create_all)  # DISABLED - migrations handle this

        # Create default roles if they don't exist
        await create_default_roles()

        # Create default admin user if no admin exists
        await create_default_admin()

        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {e}")
        raise


async def create_default_roles():
    """Create default roles if they don't exist"""
    from app.models.role import Role, RoleLevel
    from sqlalchemy import select
    from sqlalchemy.exc import SQLAlchemyError

    try:
        async with async_session_factory() as session:
            # Check if any roles exist
            stmt = select(Role).limit(1)
            result = await session.execute(stmt)
            existing_role = result.scalar_one_or_none()

            if existing_role:
                logger.info("Roles already exist - skipping default role creation")
                return

            # Create default roles using the Role.create_default_roles class method
            default_roles = Role.create_default_roles()

            for role in default_roles:
                session.add(role)

            await session.commit()

            logger.info("Created default roles: read_only, user, admin, super_admin")

    except SQLAlchemyError as e:
        logger.error(f"Failed to create default roles due to database error: {e}")
        raise

async def create_default_admin():
    """Create default admin user if user with ADMIN_EMAIL doesn't exist"""
    from app.models.user import User
    from app.models.role import Role
    from app.core.security import get_password_hash
    from app.core.config import settings
    from sqlalchemy import select
    from sqlalchemy.exc import SQLAlchemyError

    try:
        admin_email = settings.ADMIN_EMAIL
        admin_password = settings.ADMIN_PASSWORD
@@ -151,42 +197,61 @@ async def create_default_admin():
        if not admin_email or not admin_password:
            logger.info("Admin bootstrap skipped: ADMIN_EMAIL or ADMIN_PASSWORD unset")
            return

        async with async_session_factory() as session:
            # Check if user with ADMIN_EMAIL exists
            stmt = select(User).where(User.email == admin_email)
            result = await session.execute(stmt)
            existing_user = result.scalar_one_or_none()

            if existing_user:
                logger.info(f"User with email {admin_email} already exists - skipping admin creation")
                logger.info(
                    f"User with email {admin_email} already exists - skipping admin creation"
                )
                return

            # Get the super_admin role
            stmt = select(Role).where(Role.name == "super_admin")
            result = await session.execute(stmt)
            super_admin_role = result.scalar_one_or_none()

            if not super_admin_role:
                logger.error("Super admin role not found - cannot create admin user")
                return

            # Create admin user from environment variables
            # Generate username from email (part before @)
            admin_username = admin_email.split('@')[0]

            admin_username = admin_email.split("@")[0]

            admin_user = User.create_default_admin(
                email=admin_email,
                username=admin_username,
                password_hash=get_password_hash(admin_password)
                password_hash=get_password_hash(admin_password),
            )

            # Assign the super_admin role
            admin_user.role_id = super_admin_role.id

            session.add(admin_user)
            await session.commit()

            logger.warning("=" * 60)
            logger.warning("ADMIN USER CREATED FROM ENVIRONMENT")
            logger.warning(f"Email: {admin_email}")
            logger.warning(f"Username: {admin_username}")
            logger.warning("Password: [Set via ADMIN_PASSWORD - only used on first creation]")
            logger.warning("Role: Super Administrator")
            logger.warning(
                "Password: [Set via ADMIN_PASSWORD - only used on first creation]"
            )
            logger.warning("PLEASE CHANGE THE PASSWORD AFTER FIRST LOGIN")
            logger.warning("=" * 60)

    except SQLAlchemyError as e:
        logger.error(f"Failed to create default admin user due to database error: {e}")
    except AttributeError as e:
        logger.error(f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'")
        logger.error(
            f"Failed to create default admin user: invalid ADMIN_EMAIL '{settings.ADMIN_EMAIL}'"
        )
    except Exception as e:
        logger.error(f"Failed to create default admin user: {e}")
        # Don't raise here as this shouldn't block the application startup
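For reference, a hedged sketch of how the pieces above are typically consumed: init_db runs once at startup (see the lifespan hunks below), while request handlers borrow sessions through the get_db dependency. The route below is an illustrative stand-in:

from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter()


@router.get("/healthz/db")
async def db_health(db: AsyncSession = Depends(get_db)):
    await db.execute(text("SELECT 1"))  # session is closed by get_db's finally block
    return {"db": "ok"}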
@@ -59,7 +59,10 @@ async def _check_redis_startup():
        duration = time.perf_counter() - start
        logger.info(
            "Startup Redis check succeeded",
            extra={"redis_url": settings.REDIS_URL, "duration_seconds": round(duration, 3)},
            extra={
                "redis_url": settings.REDIS_URL,
                "duration_seconds": round(duration, 3),
            },
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning(
@@ -104,9 +107,10 @@ async def lifespan(app: FastAPI):
    """
    logger.info("Starting Enclava platform...")
    background_tasks = []

    # Initialize core cache service (before database to provide caching for auth)
    from app.core.cache import core_cache

    try:
        await core_cache.initialize()
        logger.info("Core cache service initialized successfully")
@@ -122,12 +126,13 @@ async def lifespan(app: FastAPI):

    # Initialize database
    await init_db()

    # Initialize config manager
    await init_config_manager()

    # Ensure platform permissions are registered before module discovery
    from app.services.permission_manager import permission_registry

    permission_registry.register_platform_permissions()

    # Initialize LLM service (needed by RAG module) concurrently
@@ -153,40 +158,45 @@ async def lifespan(app: FastAPI):
        await module_manager.initialize(app)
        app.state.module_manager = module_manager
        logger.info("Module manager initialized successfully")

    # Initialize document processor
    from app.services.document_processor import document_processor

    try:
        await document_processor.start()
        app.state.document_processor = document_processor
    except Exception as exc:
        logger.error(f"Document processor failed to start: {exc}")
        app.state.document_processor = None

    # Setup metrics
    try:
        setup_metrics(app)
    except Exception as exc:
        logger.warning(f"Metrics setup failed: {exc}")

    # Start background audit worker
    from app.services.audit_service import start_audit_worker

    try:
        start_audit_worker()
    except Exception as exc:
        logger.warning(f"Audit worker failed to start: {exc}")

    # Initialize plugin auto-discovery service concurrently
    async def initialize_plugins():
        from app.services.plugin_autodiscovery import initialize_plugin_autodiscovery

        try:
            discovery_results = await initialize_plugin_autodiscovery()
            app.state.plugin_discovery_results = discovery_results
            logger.info(f"Plugin auto-discovery completed: {discovery_results.get('summary')}")
            logger.info(
                f"Plugin auto-discovery completed: {discovery_results.get('summary')}"
            )
        except Exception as exc:
            logger.warning(f"Plugin auto-discovery failed: {exc}")
            app.state.plugin_discovery_results = {"error": str(exc)}

    background_tasks.append(asyncio.create_task(initialize_plugins()))

    if background_tasks:
@@ -194,9 +204,9 @@ async def lifespan(app: FastAPI):
        for result in results:
            if isinstance(result, Exception):
                logger.warning(f"Background startup task failed: {result}")

    logger.info("Platform started successfully")

    try:
        yield
    finally:
@@ -205,6 +215,7 @@ async def lifespan(app: FastAPI):

        # Cleanup embedding service HTTP sessions
        from app.services.embedding_service import embedding_service

        try:
            await embedding_service.cleanup()
            logger.info("Embedding service cleaned up successfully")
@@ -213,14 +224,16 @@ async def lifespan(app: FastAPI):

        # Close core cache service
        from app.core.cache import core_cache

        await core_cache.cleanup()

        # Close Redis connection for cached API key service
        from app.services.cached_api_key import cached_api_key_service

        await cached_api_key_service.close()

        # Stop document processor
        processor = getattr(app.state, 'document_processor', None)
        processor = getattr(app.state, "document_processor", None)
        if processor:
            await processor.stop()

@@ -297,10 +310,12 @@ async def validation_exception_handler(request, exc: RequestValidationError):
            "type": error.get("type", ""),
            "location": error.get("loc", []),
            "message": error.get("msg", ""),
            "input": str(error.get("input", "")) if error.get("input") is not None else None
            "input": str(error.get("input", ""))
            if error.get("input") is not None
            else None,
        }
        errors.append(error_dict)

    return JSONResponse(
        status_code=422,
        content={
@@ -326,7 +341,7 @@ async def general_exception_handler(request, exc: Exception):
# Include Internal API routes (for frontend)
app.include_router(internal_api_router, prefix="/api-internal/v1")

# Include Public API routes (for external clients)
# Include Public API routes (for external clients)
app.include_router(public_api_router, prefix="/api/v1")

# OpenAI-compatible routes are now included in public API router at /api/v1/
@@ -357,7 +372,7 @@ async def root():

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.APP_HOST,
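The lifespan above is attached when the application object is constructed; the actual FastAPI(...) call sits outside these hunks, but the shape is roughly:

app = FastAPI(title="Enclava", lifespan=lifespan)  # title is illustrative; startup/shutdown run inside lifespan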
@@ -17,24 +17,29 @@ from app.db.database import get_db

logger = get_logger(__name__)

# Context variable to pass analytics data from endpoints to middleware
analytics_context: ContextVar[dict] = ContextVar('analytics_context', default={})
analytics_context: ContextVar[dict] = ContextVar("analytics_context", default={})


class AnalyticsMiddleware(BaseHTTPMiddleware):
    """Middleware to automatically track all requests for analytics"""

    async def dispatch(self, request: Request, call_next):
        # Start timing
        start_time = time.time()

        # Skip analytics for health checks and static files
        if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
        if request.url.path in [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
        ] or request.url.path.startswith("/static"):
            return await call_next(request)

        # Get user info if available from token
        user_id = None
        api_key_id = None

        try:
            authorization = request.headers.get("Authorization")
            if authorization and authorization.startswith("Bearer "):
@@ -42,6 +47,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
                # Try to extract user info from token without full validation
                # This is a lightweight check for analytics purposes
                from app.core.security import verify_token

                try:
                    payload = verify_token(token)
                    user_id = int(payload.get("sub"))
@@ -51,7 +57,7 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
        except Exception:
            # Don't let analytics break the request
            pass

        # Get client IP
        client_ip = request.client.host if request.client else None
        if not client_ip:
@@ -59,17 +65,17 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
            client_ip = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
            if not client_ip:
                client_ip = request.headers.get("X-Real-IP", "unknown")

        # Get user agent
        user_agent = request.headers.get("User-Agent", "")

        # Get request size
        request_size = int(request.headers.get("Content-Length", 0))

        # Process the request
        response = None
        error_message = None

        try:
            response = await call_next(request)
        except Exception as e:
@@ -77,21 +83,21 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
            error_message = str(e)
            response = JSONResponse(
                status_code=500,
                content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
                content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
            )

        # Calculate timing
        end_time = time.time()
        response_time = (end_time - start_time) * 1000  # Convert to milliseconds

        # Get response size
        response_size = 0
        if hasattr(response, 'body'):
        if hasattr(response, "body"):
            response_size = len(response.body) if response.body else 0

        # Get analytics data from context (set by endpoints)
        context_data = analytics_context.get({})

        # Create analytics event
        event = RequestEvent(
            timestamp=datetime.utcnow(),
@@ -107,26 +113,29 @@ class AnalyticsMiddleware(BaseHTTPMiddleware):
            response_size=response_size,
            error_message=error_message,
            # Token/cost info populated by LLM endpoints via context
            model=context_data.get('model'),
            request_tokens=context_data.get('request_tokens', 0),
            response_tokens=context_data.get('response_tokens', 0),
            total_tokens=context_data.get('total_tokens', 0),
            cost_cents=context_data.get('cost_cents', 0),
            budget_ids=context_data.get('budget_ids', []),
            budget_warnings=context_data.get('budget_warnings', [])
            model=context_data.get("model"),
            request_tokens=context_data.get("request_tokens", 0),
            response_tokens=context_data.get("response_tokens", 0),
            total_tokens=context_data.get("total_tokens", 0),
            cost_cents=context_data.get("cost_cents", 0),
            budget_ids=context_data.get("budget_ids", []),
            budget_warnings=context_data.get("budget_warnings", []),
        )

        # Track the event
        try:
            from app.services.analytics import analytics_service

            if analytics_service is not None:
                await analytics_service.track_request(event)
            else:
                logger.warning("Analytics service not initialized, skipping event tracking")
                logger.warning(
                    "Analytics service not initialized, skipping event tracking"
                )
        except Exception as e:
            logger.error(f"Failed to track analytics event: {e}")
            # Don't let analytics failures break the request

        return response
@@ -140,4 +149,4 @@ def set_analytics_data(**kwargs):

def setup_analytics_middleware(app):
    """Add analytics middleware to the FastAPI app"""
    app.add_middleware(AnalyticsMiddleware)
    logger.info("Analytics middleware configured")
    logger.info("Analytics middleware configured")
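Endpoints feed the middleware through the context variable via set_analytics_data; an illustrative LLM handler would populate the same keys the RequestEvent reads above (all values below are made up):

set_analytics_data(
    model="gpt-4o-mini",      # illustrative model name
    request_tokens=812,
    response_tokens=214,
    total_tokens=1026,
    cost_cents=3,
    budget_ids=[1],
    budget_warnings=[],
)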
393
backend/app/middleware/audit_middleware.py
Normal file
@@ -0,0 +1,393 @@
"""
|
||||
Audit Logging Middleware
|
||||
Automatically logs user actions and system events
|
||||
"""
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.audit_log import AuditLog, AuditAction, AuditSeverity
|
||||
from app.db.database import get_db_session
|
||||
from app.core.security import verify_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to automatically log user actions and API calls"""
|
||||
|
||||
def __init__(self, app, exclude_paths: Optional[list] = None):
|
||||
super().__init__(app)
|
||||
# Paths to exclude from audit logging
|
||||
self.exclude_paths = exclude_paths or [
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/static",
|
||||
"/favicon.ico",
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip audit logging for excluded paths
|
||||
if any(request.url.path.startswith(path) for path in self.exclude_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip audit logging for health checks and static assets
|
||||
if request.url.path in ["/", "/health"] or "/static/" in request.url.path:
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Extract user information from request
|
||||
user_info = await self._extract_user_info(request)
|
||||
|
||||
# Prepare audit data
|
||||
audit_data = {
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"query_params": dict(request.query_params),
|
||||
"ip_address": self._get_client_ip(request),
|
||||
"user_agent": request.headers.get("user-agent"),
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate response time
|
||||
process_time = time.time() - start_time
|
||||
audit_data["response_time"] = round(process_time * 1000, 2) # milliseconds
|
||||
audit_data["status_code"] = response.status_code
|
||||
audit_data["success"] = 200 <= response.status_code < 400
|
||||
|
||||
# Log the audit event asynchronously
|
||||
try:
|
||||
await self._log_audit_event(user_info, audit_data, request)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
# Don't fail the request if audit logging fails
|
||||
|
||||
return response
|
||||
|
||||
    async def _extract_user_info(self, request: Request) -> Optional[Dict[str, Any]]:
        """Extract user information from request headers"""
        try:
            # Try to get user info from Authorization header
            auth_header = request.headers.get("authorization")
            if auth_header and auth_header.startswith("Bearer "):
                token = auth_header.split(" ")[1]
                payload = verify_token(token)
                return {
                    "user_id": int(payload.get("sub")) if payload.get("sub") else None,
                    "email": payload.get("email"),
                    "is_superuser": payload.get("is_superuser", False),
                    "role": payload.get("role"),
                }
        except Exception:
            # If token verification fails, continue without user info
            pass

        # Try to get user info from API key header
        api_key = request.headers.get("x-api-key")
        if api_key:
            # Would need to implement API key lookup here
            # For now, just indicate it's an API key request
            return {
                "user_id": None,
                "email": "api_key_user",
                "is_superuser": False,
                "role": "api_user",
                "auth_type": "api_key",
            }

        return None

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address with proxy support"""
        # Check for forwarded headers first (for reverse proxy setups)
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # Take the first IP in the chain
            return forwarded_for.split(",")[0].strip()

        forwarded = request.headers.get("x-forwarded")
        if forwarded:
            return forwarded

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fall back to direct client IP
        return request.client.host if request.client else "unknown"

    async def _log_audit_event(
        self,
        user_info: Optional[Dict[str, Any]],
        audit_data: Dict[str, Any],
        request: Request
    ):
        """Log the audit event to database"""

        # Determine action based on HTTP method and path
        action = self._determine_action(request.method, request.url.path)

        # Determine resource type and ID from path
        resource_type, resource_id = self._parse_resource_from_path(request.url.path)

        # Create description
        description = self._create_description(request.method, request.url.path, audit_data["success"])

        # Determine severity
        severity = self._determine_severity(request.method, audit_data["status_code"], request.url.path)

        # Create audit log entry
        try:
            async with get_db_session() as db:
                audit_log = AuditLog(
                    user_id=user_info.get("user_id") if user_info else None,
                    action=action,
                    resource_type=resource_type,
                    resource_id=resource_id,
                    description=description,
                    details={
                        "request": {
                            "method": audit_data["method"],
                            "path": audit_data["path"],
                            "query_params": audit_data["query_params"],
                            "response_time_ms": audit_data["response_time"],
                        },
                        "user_info": user_info,
                    },
                    ip_address=audit_data["ip_address"],
                    user_agent=audit_data["user_agent"],
                    severity=severity,
                    category=self._determine_category(request.url.path),
                    success=audit_data["success"],
                    tags=self._generate_tags(request.method, request.url.path),
                )

                db.add(audit_log)
                await db.commit()

        except Exception as e:
            logger.error(f"Failed to save audit log to database: {e}")
            # Could implement fallback logging to file here

    def _determine_action(self, method: str, path: str) -> str:
        """Determine action type from HTTP method and path"""
        method = method.upper()

        if method == "GET":
            return AuditAction.READ
        elif method == "POST":
            if "login" in path.lower():
                return AuditAction.LOGIN
            elif "logout" in path.lower():
                return AuditAction.LOGOUT
            else:
                return AuditAction.CREATE
        elif method == "PUT" or method == "PATCH":
            return AuditAction.UPDATE
        elif method == "DELETE":
            return AuditAction.DELETE
        else:
            return method.lower()

    def _parse_resource_from_path(self, path: str) -> tuple[str, Optional[str]]:
        """Parse resource type and ID from URL path"""
        path_parts = path.strip("/").split("/")

        # Skip API version prefix
        if path_parts and path_parts[0] in ["api", "api-internal"]:
            path_parts = path_parts[2:]  # Skip 'api' and 'v1'

        if not path_parts:
            return "system", None

        resource_type = path_parts[0]
        resource_id = None

        # Try to find numeric ID in path
        for part in path_parts[1:]:
            if part.isdigit():
                resource_id = part
                break

        return resource_type, resource_id

    def _create_description(self, method: str, path: str, success: bool) -> str:
        """Create human-readable description of the action"""
        action_verbs = {
            "GET": "accessed" if success else "attempted to access",
            "POST": "created" if success else "attempted to create",
            "PUT": "updated" if success else "attempted to update",
            "PATCH": "modified" if success else "attempted to modify",
            "DELETE": "deleted" if success else "attempted to delete",
        }

        verb = action_verbs.get(method, method.lower())
        resource = path.strip("/").split("/")[-1] if "/" in path else path

        return f"User {verb} {resource}"

    def _determine_severity(self, method: str, status_code: int, path: str) -> str:
        """Determine severity level based on action and outcome"""

        # Critical operations
        if any(keyword in path.lower() for keyword in ["delete", "password", "admin", "key"]):
            return AuditSeverity.HIGH

        # Failed operations
        if status_code >= 400:
            if status_code >= 500:
                return AuditSeverity.CRITICAL
            elif status_code in [401, 403]:
                return AuditSeverity.HIGH
            else:
                return AuditSeverity.MEDIUM

        # Write operations
        if method in ["POST", "PUT", "PATCH", "DELETE"]:
            return AuditSeverity.MEDIUM

        # Read operations
        return AuditSeverity.LOW

    def _determine_category(self, path: str) -> str:
        """Determine category based on path"""
        path = path.lower()

        if any(keyword in path for keyword in ["auth", "login", "logout", "token"]):
            return "authentication"
        elif any(keyword in path for keyword in ["user", "admin", "role", "permission"]):
            return "user_management"
        elif any(keyword in path for keyword in ["api-key", "key"]):
            return "security"
        elif any(keyword in path for keyword in ["budget", "billing", "usage"]):
            return "financial"
        elif any(keyword in path for keyword in ["audit", "log"]):
            return "audit"
        elif any(keyword in path for keyword in ["setting", "config"]):
            return "configuration"
        else:
            return "general"

    def _generate_tags(self, method: str, path: str) -> list[str]:
        """Generate tags for the audit log"""
        tags = [method.lower()]

        path_parts = path.strip("/").split("/")
        if path_parts:
            tags.append(path_parts[0])

        # Add special tags
        if "admin" in path.lower():
            tags.append("admin_action")
        if any(keyword in path.lower() for keyword in ["password", "auth", "login"]):
            tags.append("security_action")

        return tags

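A registration sketch for the middleware above; the app object is an illustrative stand-in, and the class defaults apply when exclude_paths is omitted:

from fastapi import FastAPI

app = FastAPI()
app.add_middleware(
    AuditLoggingMiddleware,
    exclude_paths=["/docs", "/redoc", "/health", "/metrics"],  # subset of the defaults
)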
class LoginAuditMiddleware(BaseHTTPMiddleware):
    """Specialized middleware for login/logout events"""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Only process auth-related endpoints
        if not any(path in request.url.path for path in ["/auth/login", "/auth/logout", "/auth/refresh"]):
            return await call_next(request)

        start_time = time.time()

        # Store request body for login attempts
        request_body = None
        if request.method == "POST" and "/login" in request.url.path:
            try:
                body = await request.body()
                if body:
                    request_body = json.loads(body.decode())
                    # Re-create request with body for downstream processing
                    from starlette.requests import Request as StarletteRequest
                    from io import BytesIO
                    request._body = body
            except Exception as e:
                logger.warning(f"Failed to parse login request body: {e}")

        response = await call_next(request)

        # Log login/logout events
        try:
            await self._log_auth_event(request, response, request_body, time.time() - start_time)
        except Exception as e:
            logger.error(f"Failed to log auth event: {e}")

        return response

    async def _log_auth_event(self, request: Request, response: Response, request_body: dict, process_time: float):
        """Log authentication events"""

        success = 200 <= response.status_code < 300

        if "/login" in request.url.path:
            # Extract email/username from request
            identifier = None
            if request_body:
                identifier = request_body.get("email") or request_body.get("username")

            # For successful logins, we could extract user_id from response
            # For now, we'll use the identifier

            async with get_db_session() as db:
                audit_log = AuditLog.create_login_event(
                    user_id=None,  # Would need to extract from response for successful logins
                    success=success,
                    ip_address=self._get_client_ip(request),
                    user_agent=request.headers.get("user-agent"),
                    error_message=f"HTTP {response.status_code}" if not success else None,
                )

                # Add additional details
                audit_log.details.update({
                    "identifier": identifier,
                    "response_time_ms": round(process_time * 1000, 2),
                })

                db.add(audit_log)
                await db.commit()

        elif "/logout" in request.url.path:
            # Extract user info from token if available
            user_id = None
            try:
                auth_header = request.headers.get("authorization")
                if auth_header and auth_header.startswith("Bearer "):
                    token = auth_header.split(" ")[1]
                    payload = verify_token(token)
                    user_id = int(payload.get("sub")) if payload.get("sub") else None
            except Exception:
                pass

            async with get_db_session() as db:
                audit_log = AuditLog.create_logout_event(
                    user_id=user_id,
                    session_id=None,  # Could extract from token if stored
                )

                db.add(audit_log)
                await db.commit()

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address with proxy support"""
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        return request.client.host if request.client else "unknown"
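A minimal sketch of wiring this middleware into an application (assumes a standard FastAPI app object; illustration only, not part of this commit):

    from fastapi import FastAPI

    app = FastAPI()
    app.add_middleware(LoginAuditMiddleware)  # BaseHTTPMiddleware subclasses are registered by class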
@@ -23,8 +23,12 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
        request_id = str(uuid4())

        # Skip debugging for health checks and static files
        if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or \
           request.url.path.startswith("/static"):
        if request.url.path in [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
        ] or request.url.path.startswith("/static"):
            return await call_next(request)

        # Log request details
@@ -37,7 +41,7 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
            try:
                request_body = json.loads(body_bytes)
            except json.JSONDecodeError:
                request_body = body_bytes.decode('utf-8', errors='replace')
                request_body = body_bytes.decode("utf-8", errors="replace")
            # Restore body for downstream processing
            request._body = body_bytes
        except Exception:
@@ -45,8 +49,9 @@ class DebuggingMiddleware(BaseHTTPMiddleware):

        # Extract headers we care about
        headers_to_log = {
            "authorization": request.headers.get("Authorization", "")[:50] + "..." if
                             request.headers.get("Authorization") else None,
            "authorization": request.headers.get("Authorization", "")[:50] + "..."
            if request.headers.get("Authorization")
            else None,
            "content-type": request.headers.get("Content-Type"),
            "user-agent": request.headers.get("User-Agent"),
            "x-forwarded-for": request.headers.get("X-Forwarded-For"),
@@ -54,17 +59,20 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
        }

        # Log request
        logger.info("=== API REQUEST DEBUG ===", extra={
            "request_id": request_id,
            "method": request.method,
            "url": str(request.url),
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "headers": {k: v for k, v in headers_to_log.items() if v is not None},
            "body": request_body,
            "client_ip": request.client.host if request.client else None,
            "timestamp": datetime.utcnow().isoformat()
        })
        logger.info(
            "=== API REQUEST DEBUG ===",
            extra={
                "request_id": request_id,
                "method": request.method,
                "url": str(request.url),
                "path": request.url.path,
                "query_params": dict(request.query_params),
                "headers": {k: v for k, v in headers_to_log.items() if v is not None},
                "body": request_body,
                "client_ip": request.client.host if request.client else None,
                "timestamp": datetime.utcnow().isoformat(),
            },
        )

        # Process the request
        start_time = time.time()
@@ -73,33 +81,43 @@ class DebuggingMiddleware(BaseHTTPMiddleware):

        # Add timeout detection
        try:
            logger.info(f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}")
            logger.info(
                f"=== START PROCESSING REQUEST === {request_id} at {datetime.utcnow().isoformat()}"
            )
            logger.info(f"Request path: {request.url.path}")
            logger.info(f"Request method: {request.method}")

            # Check if this is the login endpoint
            if request.url.path == "/api-internal/v1/auth/login" and request.method == "POST":
            if (
                request.url.path == "/api-internal/v1/auth/login"
                and request.method == "POST"
            ):
                logger.info(f"=== LOGIN REQUEST DETECTED === {request_id}")

            response = await call_next(request)
            logger.info(f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}")
            logger.info(
                f"=== REQUEST COMPLETED === {request_id} at {datetime.utcnow().isoformat()}"
            )

            # Capture response body for successful JSON responses
            if response.status_code < 400 and isinstance(response, JSONResponse):
                try:
                    response_body = json.loads(response.body.decode('utf-8'))
                    response_body = json.loads(response.body.decode("utf-8"))
                except:
                    response_body = "[Failed to decode response body]"

        except Exception as e:
            logger.error(f"Request processing failed: {str(e)}", extra={
                "request_id": request_id,
                "error": str(e),
                "error_type": type(e).__name__
            })
            logger.error(
                f"Request processing failed: {str(e)}",
                extra={
                    "request_id": request_id,
                    "error": str(e),
                    "error_type": type(e).__name__,
                },
            )
            response = JSONResponse(
                status_code=500,
                content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
                content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
            )

        # Calculate timing
@@ -107,14 +125,17 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
        duration = (end_time - start_time) * 1000  # milliseconds

        # Log response
        logger.info("=== API RESPONSE DEBUG ===", extra={
            "request_id": request_id,
            "status_code": response.status_code,
            "duration_ms": round(duration, 2),
            "response_body": response_body,
            "response_headers": dict(response.headers),
            "timestamp": datetime.utcnow().isoformat()
        })
        logger.info(
            "=== API RESPONSE DEBUG ===",
            extra={
                "request_id": request_id,
                "status_code": response.status_code,
                "duration_ms": round(duration, 2),
                "response_body": response_body,
                "response_headers": dict(response.headers),
                "timestamp": datetime.utcnow().isoformat(),
            },
        )

        return response

@@ -122,4 +143,4 @@ class DebuggingMiddleware(BaseHTTPMiddleware):
def setup_debugging_middleware(app):
    """Add debugging middleware to the FastAPI app"""
    app.add_middleware(DebuggingMiddleware)
    logger.info("Debugging middleware configured")
    logger.info("Debugging middleware configured")
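Both middlewares lean on Starlette's request-body cache for their capture-and-restore trick; a minimal sketch of the pattern in isolation (`request._body` is a private Starlette attribute, so this is fragile by design):

    async def peek_body(request) -> bytes:
        body_bytes = await request.body()  # drains the stream and caches the bytes
        request._body = body_bytes         # re-prime the cache so downstream handlers can call body() again
        return body_bytes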
@@ -9,28 +9,63 @@ from .budget import Budget
from .audit_log import AuditLog
from .rag_collection import RagCollection
from .rag_document import RagDocument
from .chatbot import ChatbotInstance, ChatbotConversation, ChatbotMessage, ChatbotAnalytics
from .chatbot import (
    ChatbotInstance,
    ChatbotConversation,
    ChatbotMessage,
    ChatbotAnalytics,
)
from .prompt_template import PromptTemplate, ChatbotPromptVariable
from .plugin import Plugin, PluginConfiguration, PluginInstance, PluginAuditLog, PluginCronJob, PluginAPIGateway
from .plugin import (
    Plugin,
    PluginConfiguration,
    PluginInstance,
    PluginAuditLog,
    PluginCronJob,
    PluginAPIGateway,
)
from .role import Role, RoleLevel
from .tool import Tool, ToolExecution, ToolCategory, ToolType, ToolStatus
from .notification import (
    Notification,
    NotificationTemplate,
    NotificationChannel,
    NotificationType,
    NotificationPriority,
    NotificationStatus,
)

__all__ = [
    "User",
    "APIKey",
    "UsageTracking",
    "Budget",
    "User",
    "APIKey",
    "UsageTracking",
    "Budget",
    "AuditLog",
    "RagCollection",
    "RagCollection",
    "RagDocument",
    "ChatbotInstance",
    "ChatbotConversation",
    "ChatbotConversation",
    "ChatbotMessage",
    "ChatbotAnalytics",
    "PromptTemplate",
    "ChatbotPromptVariable",
    "Plugin",
    "PluginConfiguration",
    "PluginInstance",
    "PluginInstance",
    "PluginAuditLog",
    "PluginCronJob",
    "PluginAPIGateway"
]
    "PluginAPIGateway",
    "Role",
    "RoleLevel",
    "Tool",
    "ToolExecution",
    "ToolCategory",
    "ToolType",
    "ToolStatus",
    "Notification",
    "NotificationTemplate",
    "NotificationChannel",
    "NotificationType",
    "NotificationPriority",
    "NotificationStatus",
]
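With these re-exports in place, callers import models from the package root (illustrative):

    from app.models import APIKey, AuditLog, Budget, Role, Tool, Notification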
@@ -3,73 +3,94 @@ API Key model
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    Boolean,
    Text,
    JSON,
    ForeignKey,
)
from sqlalchemy.orm import relationship
from app.db.database import Base


class APIKey(Base):
    """API Key model for authentication and access control"""

    __tablename__ = "api_keys"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)  # Human-readable name for the API key
    key_hash = Column(String, unique=True, index=True, nullable=False)  # Hashed API key
    key_prefix = Column(String, index=True, nullable=False)  # First 8 characters for identification
    key_prefix = Column(
        String, index=True, nullable=False
    )  # First 8 characters for identification

    # User relationship
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    user = relationship("User", back_populates="api_keys")

    # Related data relationships
    budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
    usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
    budgets = relationship(
        "Budget", back_populates="api_key", cascade="all, delete-orphan"
    )
    usage_tracking = relationship(
        "UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
    )

    # Key status and permissions
    is_active = Column(Boolean, default=True)
    permissions = Column(JSON, default=dict)  # Specific permissions for this key
    scopes = Column(JSON, default=list)  # OAuth-like scopes

    # Usage limits
    rate_limit_per_minute = Column(Integer, default=60)  # Requests per minute
    rate_limit_per_hour = Column(Integer, default=3600)  # Requests per hour
    rate_limit_per_day = Column(Integer, default=86400)  # Requests per day

    # Allowed resources
    allowed_models = Column(JSON, default=list)  # List of allowed LLM models
    allowed_endpoints = Column(JSON, default=list)  # List of allowed API endpoints
    allowed_ips = Column(JSON, default=list)  # IP whitelist
    allowed_chatbots = Column(JSON, default=list)  # List of allowed chatbot IDs for chatbot-specific keys
    allowed_chatbots = Column(
        JSON, default=list
    )  # List of allowed chatbot IDs for chatbot-specific keys

    # Budget configuration
    is_unlimited = Column(Boolean, default=True)  # Unlimited budget flag
    budget_limit_cents = Column(Integer, nullable=True)  # Budget limit in cents
    budget_type = Column(String, nullable=True)  # "total" or "monthly"

    # Metadata
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)  # For organizing keys

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_used_at = Column(DateTime, nullable=True)
    expires_at = Column(DateTime, nullable=True)  # Optional expiration

    # Usage tracking
    total_requests = Column(Integer, default=0)
    total_tokens = Column(Integer, default=0)
    total_cost = Column(Integer, default=0)  # In cents

    # Relationships
    usage_tracking = relationship("UsageTracking", back_populates="api_key", cascade="all, delete-orphan")
    budgets = relationship("Budget", back_populates="api_key", cascade="all, delete-orphan")
    usage_tracking = relationship(
        "UsageTracking", back_populates="api_key", cascade="all, delete-orphan"
    )
    budgets = relationship(
        "Budget", back_populates="api_key", cascade="all, delete-orphan"
    )
    plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")

    def __repr__(self):
        return f"<APIKey(id={self.id}, name='{self.name}', user_id={self.user_id})>"

    def to_dict(self, include_sensitive: bool = False):
        """Convert API key to dictionary for API responses"""
        data = {
@@ -91,138 +112,142 @@ class APIKey(Base):
            "tags": self.tags,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
            "last_used_at": self.last_used_at.isoformat()
            if self.last_used_at
            else None,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "total_requests": self.total_requests,
            "total_tokens": self.total_tokens,
            "total_cost_cents": self.total_cost,
            "is_unlimited": self.is_unlimited,
            "budget_limit": self.budget_limit_cents,  # Map to budget_limit for API response
            "budget_type": self.budget_type
            "budget_type": self.budget_type,
        }

        if include_sensitive:
            data["key_hash"] = self.key_hash

        return data

    def is_expired(self) -> bool:
        """Check if the API key has expired"""
        if self.expires_at is None:
            return False
        return datetime.utcnow() > self.expires_at

    def is_valid(self) -> bool:
        """Check if the API key is valid and active"""
        return self.is_active and not self.is_expired()

    def has_permission(self, permission: str) -> bool:
        """Check if the API key has a specific permission"""
        return permission in self.permissions

    def has_scope(self, scope: str) -> bool:
        """Check if the API key has a specific scope"""
        return scope in self.scopes

    def can_access_model(self, model_name: str) -> bool:
        """Check if the API key can access a specific model"""
        if not self.allowed_models:  # Empty list means all models allowed
            return True
        return model_name in self.allowed_models

    def can_access_endpoint(self, endpoint: str) -> bool:
        """Check if the API key can access a specific endpoint"""
        if not self.allowed_endpoints:  # Empty list means all endpoints allowed
            return True
        return endpoint in self.allowed_endpoints

    def can_access_from_ip(self, ip_address: str) -> bool:
        """Check if the API key can be used from a specific IP"""
        if not self.allowed_ips:  # Empty list means all IPs allowed
            return True
        return ip_address in self.allowed_ips

    def can_access_chatbot(self, chatbot_id: str) -> bool:
        """Check if the API key can access a specific chatbot"""
        if not self.allowed_chatbots:  # Empty list means all chatbots allowed
            return True
        return chatbot_id in self.allowed_chatbots

    def update_usage(self, tokens_used: int = 0, cost_cents: int = 0):
        """Update usage statistics"""
        self.total_requests += 1
        self.total_tokens += tokens_used
        self.total_cost += cost_cents
        self.last_used_at = datetime.utcnow()

    def set_expiration(self, days: int):
        """Set expiration date in days from now"""
        self.expires_at = datetime.utcnow() + timedelta(days=days)

    def extend_expiration(self, days: int):
        """Extend expiration date by specified days"""
        if self.expires_at is None:
            self.expires_at = datetime.utcnow() + timedelta(days=days)
        else:
            self.expires_at = self.expires_at + timedelta(days=days)

    def revoke(self):
        """Revoke the API key"""
        self.is_active = False
        self.updated_at = datetime.utcnow()

    def add_scope(self, scope: str):
        """Add a scope to the API key"""
        if scope not in self.scopes:
            self.scopes.append(scope)

    def remove_scope(self, scope: str):
        """Remove a scope from the API key"""
        if scope in self.scopes:
            self.scopes.remove(scope)

    def add_allowed_model(self, model_name: str):
        """Add an allowed model"""
        if model_name not in self.allowed_models:
            self.allowed_models.append(model_name)

    def remove_allowed_model(self, model_name: str):
        """Remove an allowed model"""
        if model_name in self.allowed_models:
            self.allowed_models.remove(model_name)

    def add_allowed_endpoint(self, endpoint: str):
        """Add an allowed endpoint"""
        if endpoint not in self.allowed_endpoints:
            self.allowed_endpoints.append(endpoint)

    def remove_allowed_endpoint(self, endpoint: str):
        """Remove an allowed endpoint"""
        if endpoint in self.allowed_endpoints:
            self.allowed_endpoints.remove(endpoint)

    def add_allowed_ip(self, ip_address: str):
        """Add an allowed IP address"""
        if ip_address not in self.allowed_ips:
            self.allowed_ips.append(ip_address)

    def remove_allowed_ip(self, ip_address: str):
        """Remove an allowed IP address"""
        if ip_address in self.allowed_ips:
            self.allowed_ips.remove(ip_address)

    def add_allowed_chatbot(self, chatbot_id: str):
        """Add an allowed chatbot"""
        if chatbot_id not in self.allowed_chatbots:
            self.allowed_chatbots.append(chatbot_id)

    def remove_allowed_chatbot(self, chatbot_id: str):
        """Remove an allowed chatbot"""
        if chatbot_id in self.allowed_chatbots:
            self.allowed_chatbots.remove(chatbot_id)

    @classmethod
    def create_default_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str) -> "APIKey":
    def create_default_key(
        cls, user_id: int, name: str, key_hash: str, key_prefix: str
    ) -> "APIKey":
        """Create a default API key with standard permissions"""
        return cls(
            name=name,
@@ -230,17 +255,8 @@ class APIKey(Base):
            key_prefix=key_prefix,
            user_id=user_id,
            is_active=True,
            permissions={
                "read": True,
                "write": True,
                "chat": True,
                "embeddings": True
            },
            scopes=[
                "chat.completions",
                "embeddings.create",
                "models.list"
            ],
            permissions={"read": True, "write": True, "chat": True, "embeddings": True},
            scopes=["chat.completions", "embeddings.create", "models.list"],
            rate_limit_per_minute=60,
            rate_limit_per_hour=3600,
            rate_limit_per_day=86400,
@@ -248,12 +264,19 @@ class APIKey(Base):
            allowed_endpoints=[],  # All endpoints allowed by default
            allowed_ips=[],  # All IPs allowed by default
            description="Default API key with standard permissions",
            tags=["default"]
            tags=["default"],
        )

    @classmethod
    def create_restricted_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
                              models: List[str], endpoints: List[str]) -> "APIKey":
    def create_restricted_key(
        cls,
        user_id: int,
        name: str,
        key_hash: str,
        key_prefix: str,
        models: List[str],
        endpoints: List[str],
    ) -> "APIKey":
        """Create a restricted API key with limited permissions"""
        return cls(
            name=name,
@@ -261,13 +284,8 @@ class APIKey(Base):
            key_prefix=key_prefix,
            user_id=user_id,
            is_active=True,
            permissions={
                "read": True,
                "chat": True
            },
            scopes=[
                "chat.completions"
            ],
            permissions={"read": True, "chat": True},
            scopes=["chat.completions"],
            rate_limit_per_minute=30,
            rate_limit_per_hour=1800,
            rate_limit_per_day=43200,
@@ -275,12 +293,19 @@ class APIKey(Base):
            allowed_endpoints=endpoints,
            allowed_ips=[],
            description="Restricted API key with limited permissions",
            tags=["restricted"]
            tags=["restricted"],
        )

    @classmethod
    def create_chatbot_key(cls, user_id: int, name: str, key_hash: str, key_prefix: str,
                           chatbot_id: str, chatbot_name: str) -> "APIKey":
    def create_chatbot_key(
        cls,
        user_id: int,
        name: str,
        key_hash: str,
        key_prefix: str,
        chatbot_id: str,
        chatbot_name: str,
    ) -> "APIKey":
        """Create a chatbot-specific API key"""
        return cls(
            name=name,
@@ -288,22 +313,18 @@ class APIKey(Base):
            key_prefix=key_prefix,
            user_id=user_id,
            is_active=True,
            permissions={
                "chatbot": True
            },
            scopes=[
                "chatbot.chat"
            ],
            permissions={"chatbot": True},
            scopes=["chatbot.chat"],
            rate_limit_per_minute=100,
            rate_limit_per_hour=6000,
            rate_limit_per_day=144000,
            allowed_models=[],  # Will use chatbot's configured model
            allowed_endpoints=[
                f"/api/v1/chatbot/external/{chatbot_id}/chat",
                f"/api/v1/chatbot/external/{chatbot_id}/chat/completions"
                f"/api/v1/chatbot/external/{chatbot_id}/chat/completions",
            ],
            allowed_ips=[],
            allowed_chatbots=[chatbot_id],
            description=f"API key for chatbot: {chatbot_name}",
            tags=["chatbot", f"chatbot-{chatbot_id}"]
        )
            tags=["chatbot", f"chatbot-{chatbot_id}"],
        )
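A short usage sketch of the factories and guard methods above (all values hypothetical; hashing and persistence elided):

    key = APIKey.create_chatbot_key(
        user_id=1,
        name="support-bot key",
        key_hash="<hash of the raw secret>",  # hashing scheme assumed
        key_prefix="sk-abc12",                # hypothetical prefix
        chatbot_id="3f2a0c9e",
        chatbot_name="Support Bot",
    )
    assert key.can_access_chatbot("3f2a0c9e")
    if key.is_valid():
        key.update_usage(tokens_used=1200, cost_cents=4)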
@@ -3,7 +3,16 @@ Audit log model for tracking system events and user actions
"""
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey, Text, Boolean
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    JSON,
    ForeignKey,
    Text,
    Boolean,
)
from sqlalchemy.orm import relationship
from app.db.database import Base
from enum import Enum
@@ -11,6 +20,7 @@ from enum import Enum

class AuditAction(str, Enum):
    """Audit action types"""

    CREATE = "create"
    READ = "read"
    UPDATE = "update"
@@ -32,6 +42,7 @@ class AuditAction(str, Enum):

class AuditSeverity(str, Enum):
    """Audit severity levels"""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
@@ -40,52 +51,58 @@ class AuditSeverity(str, Enum):

class AuditLog(Base):
    """Audit log model for tracking system events and user actions"""

    __tablename__ = "audit_logs"

    id = Column(Integer, primary_key=True, index=True)

    # User relationship (nullable for system events)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
    user = relationship("User", back_populates="audit_logs")

    # Event details
    action = Column(String, nullable=False)
    resource_type = Column(String, nullable=False)  # user, api_key, budget, module, etc.
    resource_type = Column(
        String, nullable=False
    )  # user, api_key, budget, module, etc.
    resource_id = Column(String, nullable=True)  # ID of the affected resource

    # Event description and details
    description = Column(Text, nullable=False)
    details = Column(JSON, default=dict)  # Additional event details

    # Request context
    ip_address = Column(String, nullable=True)
    user_agent = Column(String, nullable=True)
    session_id = Column(String, nullable=True)
    request_id = Column(String, nullable=True)

    # Event classification
    severity = Column(String, default=AuditSeverity.LOW)
    category = Column(String, nullable=True)  # security, access, data, system

    # Success/failure tracking
    success = Column(Boolean, default=True)
    error_message = Column(Text, nullable=True)

    # Additional metadata
    tags = Column(JSON, default=list)
    audit_metadata = Column("metadata", JSON, default=dict)  # Map to 'metadata' column in DB
    audit_metadata = Column(
        "metadata", JSON, default=dict
    )  # Map to 'metadata' column in DB

    # Before/after values for data changes
    old_values = Column(JSON, nullable=True)
    new_values = Column(JSON, nullable=True)

    # Timestamp
    created_at = Column(DateTime, default=datetime.utcnow, index=True)

    def __repr__(self):
        return f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
        return (
            f"<AuditLog(id={self.id}, action='{self.action}', user_id={self.user_id})>"
        )

    def to_dict(self):
        """Convert audit log to dictionary for API responses"""
        return {
@@ -108,9 +125,9 @@ class AuditLog(Base):
            "metadata": self.audit_metadata,
            "old_values": self.old_values,
            "new_values": self.new_values,
            "created_at": self.created_at.isoformat() if self.created_at else None
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    def is_security_event(self) -> bool:
        """Check if this is a security-related event"""
        security_actions = [
@@ -120,39 +137,45 @@ class AuditLog(Base):
            AuditAction.API_KEY_DELETE,
            AuditAction.PERMISSION_GRANT,
            AuditAction.PERMISSION_REVOKE,
            AuditAction.SECURITY_EVENT
            AuditAction.SECURITY_EVENT,
        ]
        return self.action in security_actions or self.category == "security"

    def is_high_severity(self) -> bool:
        """Check if this is a high severity event"""
        return self.severity in [AuditSeverity.HIGH, AuditSeverity.CRITICAL]

    def add_tag(self, tag: str):
        """Add a tag to the audit log"""
        if tag not in self.tags:
            self.tags.append(tag)

    def remove_tag(self, tag: str):
        """Remove a tag from the audit log"""
        if tag in self.tags:
            self.tags.remove(tag)

    def update_metadata(self, key: str, value: Any):
        """Update metadata"""
        if self.audit_metadata is None:
            self.audit_metadata = {}
        self.audit_metadata[key] = value

    def set_before_after(self, old_values: Dict[str, Any], new_values: Dict[str, Any]):
        """Set before and after values for data changes"""
        self.old_values = old_values
        self.new_values = new_values

    @classmethod
    def create_login_event(cls, user_id: int, success: bool = True,
                           ip_address: str = None, user_agent: str = None,
                           session_id: str = None, error_message: str = None) -> "AuditLog":
    def create_login_event(
        cls,
        user_id: int,
        success: bool = True,
        ip_address: str = None,
        user_agent: str = None,
        session_id: str = None,
        error_message: str = None,
    ) -> "AuditLog":
        """Create a login audit event"""
        return cls(
            user_id=user_id,
@@ -160,10 +183,7 @@ class AuditLog(Base):
            resource_type="user",
            resource_id=str(user_id),
            description=f"User login {'successful' if success else 'failed'}",
            details={
                "login_method": "password",
                "success": success
            },
            details={"login_method": "password", "success": success},
            ip_address=ip_address,
            user_agent=user_agent,
            session_id=session_id,
@@ -171,9 +191,9 @@ class AuditLog(Base):
            category="security",
            success=success,
            error_message=error_message,
            tags=["authentication", "login"]
            tags=["authentication", "login"],
        )

    @classmethod
    def create_logout_event(cls, user_id: int, session_id: str = None) -> "AuditLog":
        """Create a logout audit event"""
@@ -183,20 +203,24 @@ class AuditLog(Base):
            resource_type="user",
            resource_id=str(user_id),
            description="User logout",
            details={
                "logout_method": "manual"
            },
            details={"logout_method": "manual"},
            session_id=session_id,
            severity=AuditSeverity.LOW,
            category="security",
            success=True,
            tags=["authentication", "logout"]
            tags=["authentication", "logout"],
        )

    @classmethod
    def create_api_key_event(cls, user_id: int, action: str, api_key_id: int,
                             api_key_name: str, success: bool = True,
                             error_message: str = None) -> "AuditLog":
    def create_api_key_event(
        cls,
        user_id: int,
        action: str,
        api_key_id: int,
        api_key_name: str,
        success: bool = True,
        error_message: str = None,
    ) -> "AuditLog":
        """Create an API key audit event"""
        return cls(
            user_id=user_id,
@@ -204,21 +228,24 @@ class AuditLog(Base):
            resource_type="api_key",
            resource_id=str(api_key_id),
            description=f"API key {action}: {api_key_name}",
            details={
                "api_key_name": api_key_name,
                "action": action
            },
            details={"api_key_name": api_key_name, "action": action},
            severity=AuditSeverity.MEDIUM,
            category="security",
            success=success,
            error_message=error_message,
            tags=["api_key", action]
            tags=["api_key", action],
        )

    @classmethod
    def create_budget_event(cls, user_id: int, action: str, budget_id: int,
                            budget_name: str, details: Dict[str, Any] = None,
                            success: bool = True) -> "AuditLog":
    def create_budget_event(
        cls,
        user_id: int,
        action: str,
        budget_id: int,
        budget_name: str,
        details: Dict[str, Any] = None,
        success: bool = True,
    ) -> "AuditLog":
        """Create a budget audit event"""
        return cls(
            user_id=user_id,
@@ -227,16 +254,24 @@ class AuditLog(Base):
            resource_id=str(budget_id),
            description=f"Budget {action}: {budget_name}",
            details=details or {},
            severity=AuditSeverity.MEDIUM if action == AuditAction.BUDGET_EXCEED else AuditSeverity.LOW,
            severity=AuditSeverity.MEDIUM
            if action == AuditAction.BUDGET_EXCEED
            else AuditSeverity.LOW,
            category="financial",
            success=success,
            tags=["budget", action]
            tags=["budget", action],
        )

    @classmethod
    def create_module_event(cls, user_id: int, action: str, module_name: str,
                            success: bool = True, error_message: str = None,
                            details: Dict[str, Any] = None) -> "AuditLog":
    def create_module_event(
        cls,
        user_id: int,
        action: str,
        module_name: str,
        success: bool = True,
        error_message: str = None,
        details: Dict[str, Any] = None,
    ) -> "AuditLog":
        """Create a module audit event"""
        return cls(
            user_id=user_id,
@@ -249,12 +284,18 @@ class AuditLog(Base):
            category="system",
            success=success,
            error_message=error_message,
            tags=["module", action]
            tags=["module", action],
        )

    @classmethod
    def create_permission_event(cls, user_id: int, action: str, target_user_id: int,
                                permission: str, success: bool = True) -> "AuditLog":
    def create_permission_event(
        cls,
        user_id: int,
        action: str,
        target_user_id: int,
        permission: str,
        success: bool = True,
    ) -> "AuditLog":
        """Create a permission audit event"""
        return cls(
            user_id=user_id,
@@ -262,21 +303,23 @@ class AuditLog(Base):
            resource_type="permission",
            resource_id=str(target_user_id),
            description=f"Permission {action}: {permission} for user {target_user_id}",
            details={
                "permission": permission,
                "target_user_id": target_user_id
            },
            details={"permission": permission, "target_user_id": target_user_id},
            severity=AuditSeverity.HIGH,
            category="security",
            success=success,
            tags=["permission", action]
            tags=["permission", action],
        )

    @classmethod
    def create_security_event(cls, user_id: int, event_type: str, description: str,
                              severity: str = AuditSeverity.HIGH,
                              details: Dict[str, Any] = None,
                              ip_address: str = None) -> "AuditLog":
    def create_security_event(
        cls,
        user_id: int,
        event_type: str,
        description: str,
        severity: str = AuditSeverity.HIGH,
        details: Dict[str, Any] = None,
        ip_address: str = None,
    ) -> "AuditLog":
        """Create a security audit event"""
        return cls(
            user_id=user_id,
@@ -289,15 +332,19 @@ class AuditLog(Base):
            severity=severity,
            category="security",
            success=False,  # Security events are typically failures
            tags=["security", event_type]
            tags=["security", event_type],
        )

    @classmethod
    def create_system_event(cls, action: str, description: str,
                            resource_type: str = "system",
                            resource_id: str = None,
                            severity: str = AuditSeverity.LOW,
                            details: Dict[str, Any] = None) -> "AuditLog":
    def create_system_event(
        cls,
        action: str,
        description: str,
        resource_type: str = "system",
        resource_id: str = None,
        severity: str = AuditSeverity.LOW,
        details: Dict[str, Any] = None,
    ) -> "AuditLog":
        """Create a system audit event"""
        return cls(
            user_id=None,  # System events don't have a user
@@ -309,14 +356,20 @@ class AuditLog(Base):
            severity=severity,
            category="system",
            success=True,
            tags=["system", action]
            tags=["system", action],
        )

    @classmethod
    def create_data_change_event(cls, user_id: int, action: str, resource_type: str,
                                 resource_id: str, description: str,
                                 old_values: Dict[str, Any],
                                 new_values: Dict[str, Any]) -> "AuditLog":
    def create_data_change_event(
        cls,
        user_id: int,
        action: str,
        resource_type: str,
        resource_id: str,
        description: str,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> "AuditLog":
        """Create a data change audit event"""
        return cls(
            user_id=user_id,
@@ -329,9 +382,9 @@ class AuditLog(Base):
            severity=AuditSeverity.LOW,
            category="data",
            success=True,
            tags=["data_change", action]
            tags=["data_change", action],
        )

    def get_summary(self) -> Dict[str, Any]:
        """Get a summary of the audit log"""
        return {
@@ -342,5 +395,5 @@ class AuditLog(Base):
            "severity": self.severity,
            "success": self.success,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "user_id": self.user_id
        }
            "user_id": self.user_id,
        }
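A sketch of the factory-then-persist flow these classmethods are built for (hypothetical session object `db`):

    log = AuditLog.create_login_event(
        user_id=42,
        success=False,
        ip_address="203.0.113.7",
        error_message="HTTP 401",
    )
    log.add_tag("suspicious")
    db.add(log)
    db.commit()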
@@ -5,13 +5,24 @@ Budget model for managing spending limits and cost control
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from enum import Enum
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON, ForeignKey, Float
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    Boolean,
    Text,
    JSON,
    ForeignKey,
    Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base


class BudgetType(str, Enum):
    """Budget type enumeration"""

    USER = "user"
    API_KEY = "api_key"
    GLOBAL = "global"
@@ -19,6 +30,7 @@ class BudgetType(str, Enum):

class BudgetPeriod(str, Enum):
    """Budget period types"""

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
@@ -28,67 +40,85 @@ class BudgetPeriod(str, Enum):

class Budget(Base):
    """Budget model for setting and managing spending limits"""

    __tablename__ = "budgets"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)  # Human-readable name for the budget

    # User and API Key relationships
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    user = relationship("User", back_populates="budgets")

    api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=True)  # Optional: specific to an API key
    api_key_id = Column(
        Integer, ForeignKey("api_keys.id"), nullable=True
    )  # Optional: specific to an API key
    api_key = relationship("APIKey", back_populates="budgets")

    # Usage tracking relationship
    usage_tracking = relationship("UsageTracking", back_populates="budget", cascade="all, delete-orphan")
    usage_tracking = relationship(
        "UsageTracking", back_populates="budget", cascade="all, delete-orphan"
    )

    # Budget limits (in cents)
    limit_cents = Column(Integer, nullable=False)  # Maximum spend limit
    warning_threshold_cents = Column(Integer, nullable=True)  # Warning threshold (e.g., 80% of limit)
    warning_threshold_cents = Column(
        Integer, nullable=True
    )  # Warning threshold (e.g., 80% of limit)

    # Time period settings
    period_type = Column(String, nullable=False, default="monthly")  # daily, weekly, monthly, yearly, custom
    period_type = Column(
        String, nullable=False, default="monthly"
    )  # daily, weekly, monthly, yearly, custom
    period_start = Column(DateTime, nullable=False)  # Start of current period
    period_end = Column(DateTime, nullable=False)  # End of current period

    # Current usage (in cents)
    current_usage_cents = Column(Integer, default=0)  # Spent in current period

    # Budget status
    is_active = Column(Boolean, default=True)
    is_exceeded = Column(Boolean, default=False)
    is_warning_sent = Column(Boolean, default=False)

    # Enforcement settings
    enforce_hard_limit = Column(Boolean, default=True)  # Block requests when limit exceeded
    enforce_hard_limit = Column(
        Boolean, default=True
    )  # Block requests when limit exceeded
    enforce_warning = Column(Boolean, default=True)  # Send warnings at threshold

    # Allowed resources (optional filters)
    allowed_models = Column(JSON, default=list)  # Specific models this budget applies to
    allowed_endpoints = Column(JSON, default=list)  # Specific endpoints this budget applies to
    allowed_models = Column(
        JSON, default=list
    )  # Specific models this budget applies to
    allowed_endpoints = Column(
        JSON, default=list
    )  # Specific endpoints this budget applies to

    # Metadata
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)
    currency = Column(String, default="USD")

    # Auto-renewal settings
    auto_renew = Column(Boolean, default=True)  # Automatically renew budget for next period
    rollover_unused = Column(Boolean, default=False)  # Rollover unused budget to next period
    auto_renew = Column(
        Boolean, default=True
    )  # Automatically renew budget for next period
    rollover_unused = Column(
        Boolean, default=False
    )  # Rollover unused budget to next period

    # Notification settings
    notification_settings = Column(JSON, default=dict)  # Email, webhook, etc.

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_reset_at = Column(DateTime, nullable=True)  # Last time budget was reset

    def __repr__(self):
        return f"<Budget(id={self.id}, name='{self.name}', user_id={self.user_id}, limit=${self.limit_cents/100:.2f})>"

    def to_dict(self):
        """Convert budget to dictionary for API responses"""
        return {
@@ -99,15 +129,23 @@ class Budget(Base):
            "limit_cents": self.limit_cents,
            "limit_dollars": self.limit_cents / 100,
            "warning_threshold_cents": self.warning_threshold_cents,
            "warning_threshold_dollars": self.warning_threshold_cents / 100 if self.warning_threshold_cents else None,
            "warning_threshold_dollars": self.warning_threshold_cents / 100
            if self.warning_threshold_cents
            else None,
            "period_type": self.period_type,
            "period_start": self.period_start.isoformat() if self.period_start else None,
            "period_start": self.period_start.isoformat()
            if self.period_start
            else None,
            "period_end": self.period_end.isoformat() if self.period_end else None,
            "current_usage_cents": self.current_usage_cents,
            "current_usage_dollars": self.current_usage_cents / 100,
            "remaining_cents": max(0, self.limit_cents - self.current_usage_cents),
            "remaining_dollars": max(0, (self.limit_cents - self.current_usage_cents) / 100),
            "usage_percentage": (self.current_usage_cents / self.limit_cents * 100) if self.limit_cents > 0 else 0,
            "remaining_dollars": max(
                0, (self.limit_cents - self.current_usage_cents) / 100
            ),
            "usage_percentage": (self.current_usage_cents / self.limit_cents * 100)
            if self.limit_cents > 0
            else 0,
            "is_active": self.is_active,
            "is_exceeded": self.is_exceeded,
            "is_warning_sent": self.is_warning_sent,
@@ -123,62 +161,67 @@ class Budget(Base):
            "notification_settings": self.notification_settings,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_reset_at": self.last_reset_at.isoformat() if self.last_reset_at else None
            "last_reset_at": self.last_reset_at.isoformat()
            if self.last_reset_at
            else None,
        }

    def is_in_period(self) -> bool:
        """Check if current time is within budget period"""
        now = datetime.utcnow()
        return self.period_start <= now <= self.period_end

    def is_expired(self) -> bool:
        """Check if budget period has expired"""
        return datetime.utcnow() > self.period_end

    def can_spend(self, amount_cents: int) -> bool:
        """Check if spending amount is within budget"""
        if not self.is_active or not self.is_in_period():
            return False

        if not self.enforce_hard_limit:
            return True

        return (self.current_usage_cents + amount_cents) <= self.limit_cents

    def would_exceed_warning(self, amount_cents: int) -> bool:
        """Check if spending amount would exceed warning threshold"""
        if not self.warning_threshold_cents:
            return False

        return (self.current_usage_cents + amount_cents) >= self.warning_threshold_cents

    def add_usage(self, amount_cents: int):
        """Add usage to current budget"""
        self.current_usage_cents += amount_cents

        # Check if budget is exceeded
        if self.current_usage_cents >= self.limit_cents:
            self.is_exceeded = True

        # Check if warning threshold is reached
        if self.warning_threshold_cents and self.current_usage_cents >= self.warning_threshold_cents:
        if (
            self.warning_threshold_cents
            and self.current_usage_cents >= self.warning_threshold_cents
        ):
            if not self.is_warning_sent:
                self.is_warning_sent = True

        self.updated_at = datetime.utcnow()

    def reset_period(self):
        """Reset budget for new period"""
        if self.rollover_unused and self.current_usage_cents < self.limit_cents:
            # Rollover unused budget
            unused_amount = self.limit_cents - self.current_usage_cents
            self.limit_cents += unused_amount

        self.current_usage_cents = 0
        self.is_exceeded = False
        self.is_warning_sent = False
        self.last_reset_at = datetime.utcnow()

        # Calculate next period
        if self.period_type == "daily":
            self.period_start = self.period_end
@@ -190,39 +233,43 @@ class Budget(Base):
            self.period_start = self.period_end
            # Handle month boundaries properly
            if self.period_start.month == 12:
                next_month = self.period_start.replace(year=self.period_start.year + 1, month=1)
                next_month = self.period_start.replace(
                    year=self.period_start.year + 1, month=1
                )
            else:
                next_month = self.period_start.replace(month=self.period_start.month + 1)
                next_month = self.period_start.replace(
                    month=self.period_start.month + 1
                )
            self.period_end = next_month
        elif self.period_type == "yearly":
            self.period_start = self.period_end
            self.period_end = self.period_start.replace(year=self.period_start.year + 1)

        self.updated_at = datetime.utcnow()

    def get_period_days_remaining(self) -> int:
        """Get number of days remaining in current period"""
        if self.is_expired():
            return 0
        return (self.period_end - datetime.utcnow()).days

    def get_daily_burn_rate(self) -> float:
        """Get average daily spend rate in current period"""
        if not self.is_in_period():
            return 0.0

        days_elapsed = (datetime.utcnow() - self.period_start).days
        if days_elapsed == 0:
            days_elapsed = 1  # Avoid division by zero

        return self.current_usage_cents / days_elapsed / 100  # Return in dollars

    def get_projected_spend(self) -> float:
        """Get projected spend for entire period based on current burn rate"""
        daily_burn = self.get_daily_burn_rate()
        total_period_days = (self.period_end - self.period_start).days
        return daily_burn * total_period_days

    @classmethod
    def create_monthly_budget(
        cls,
@@ -230,7 +277,7 @@ class Budget(Base):
        name: str,
        limit_dollars: float,
        api_key_id: Optional[int] = None,
        warning_threshold_percentage: float = 0.8
        warning_threshold_percentage: float = 0.8,
    ) -> "Budget":
        """Create a monthly budget"""
        now = datetime.utcnow()
@@ -241,10 +288,10 @@ class Budget(Base):
            period_end = period_start.replace(year=now.year + 1, month=1)
        else:
            period_end = period_start.replace(month=now.month + 1)

        limit_cents = int(limit_dollars * 100)
        warning_threshold_cents = int(limit_cents * warning_threshold_percentage)

        return cls(
            name=name,
            user_id=user_id,
@@ -258,28 +305,25 @@ class Budget(Base):
            enforce_hard_limit=True,
            enforce_warning=True,
            auto_renew=True,
            notification_settings={
                "email_on_warning": True,
                "email_on_exceeded": True
            }
            notification_settings={"email_on_warning": True, "email_on_exceeded": True},
        )

    @classmethod
    def create_daily_budget(
        cls,
        user_id: int,
        name: str,
        limit_dollars: float,
        api_key_id: Optional[int] = None
        api_key_id: Optional[int] = None,
    ) -> "Budget":
        """Create a daily budget"""
        now = datetime.utcnow()
        period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        period_end = period_start + timedelta(days=1)

        limit_cents = int(limit_dollars * 100)
        warning_threshold_cents = int(limit_cents * 0.8)  # 80% warning threshold

        return cls(
            name=name,
            user_id=user_id,
@@ -292,5 +336,5 @@ class Budget(Base):
            is_active=True,
            enforce_hard_limit=True,
            enforce_warning=True,
            auto_renew=True
        )
            auto_renew=True,
        )
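A brief sketch of the spend-control loop these helpers support (hypothetical amounts; enforcement policy assumed):

    budget = Budget.create_monthly_budget(user_id=1, name="LLM spend", limit_dollars=50.0)
    cost_cents = 125  # $1.25 for one request
    if budget.can_spend(cost_cents):
        budget.add_usage(cost_cents)  # flips is_exceeded / is_warning_sent as thresholds are crossed
    else:
        ...  # reject the request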
@@ -1,7 +1,16 @@
|
||||
"""
|
||||
Database models for chatbot module
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
Boolean,
|
||||
DateTime,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
@@ -9,102 +18,115 @@ import uuid
|
||||
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class ChatbotInstance(Base):
|
||||
"""Configured chatbot instance"""
|
||||
|
||||
__tablename__ = "chatbot_instances"
|
||||
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
|
||||
|
||||
# Configuration stored as JSON
|
||||
config = Column(JSON, nullable=False)
|
||||
|
||||
|
||||
# Metadata
|
||||
created_by = Column(String, nullable=False) # User ID
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
|
||||
# Relationships
|
||||
conversations = relationship("ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan")
|
||||
|
||||
conversations = relationship(
|
||||
"ChatbotConversation", back_populates="chatbot", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotInstance(id='{self.id}', name='{self.name}')>"
|
||||
|
||||
|
||||
class ChatbotConversation(Base):
|
||||
"""Conversation state and history"""
|
||||
|
||||
__tablename__ = "chatbot_conversations"
|
||||
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
|
||||
user_id = Column(String, nullable=False) # User ID
|
||||
|
||||
|
||||
# Conversation metadata
|
||||
title = Column(String(255)) # Auto-generated or user-defined title
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
|
||||
# Conversation context and settings
|
||||
context_data = Column(JSON, default=dict) # Additional context
|
||||
|
||||
|
||||
# Relationships
|
||||
chatbot = relationship("ChatbotInstance", back_populates="conversations")
|
||||
messages = relationship("ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
messages = relationship(
|
||||
"ChatbotMessage", back_populates="conversation", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotConversation(id='{self.id}', chatbot_id='{self.chatbot_id}')>"
|
||||
|
||||
|
||||
class ChatbotMessage(Base):
|
||||
"""Individual chat messages in conversations"""
|
||||
|
||||
__tablename__ = "chatbot_messages"
|
||||
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
conversation_id = Column(String, ForeignKey("chatbot_conversations.id"), nullable=False)
|
||||
|
||||
conversation_id = Column(
|
||||
String, ForeignKey("chatbot_conversations.id"), nullable=False
|
||||
)
|
||||
|
||||
# Message content
|
||||
role = Column(String(20), nullable=False) # 'user', 'assistant', 'system'
|
||||
content = Column(Text, nullable=False)
|
||||
|
||||
|
||||
# Metadata
|
||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
||||
message_metadata = Column(JSON, default=dict) # Token counts, model used, etc.
|
||||
|
||||
|
||||
# RAG sources if applicable
|
||||
sources = Column(JSON) # RAG sources used for this message
|
||||
|
||||
|
||||
# Relationships
|
||||
conversation = relationship("ChatbotConversation", back_populates="messages")
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotMessage(id='{self.id}', role='{self.role}')>"
|
||||
|
||||
|
||||
class ChatbotAnalytics(Base):
|
||||
"""Analytics and metrics for chatbot usage"""
|
||||
|
||||
__tablename__ = "chatbot_analytics"
|
||||
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chatbot_id = Column(String, ForeignKey("chatbot_instances.id"), nullable=False)
|
||||
user_id = Column(String, nullable=False)
|
||||
|
||||
|
||||
# Event tracking
|
||||
event_type = Column(String(50), nullable=False) # 'message_sent', 'response_generated', etc.
|
||||
event_type = Column(
|
||||
String(50), nullable=False
|
||||
) # 'message_sent', 'response_generated', etc.
|
||||
event_data = Column(JSON, default=dict)
|
||||
|
||||
|
||||
# Performance metrics
|
||||
response_time_ms = Column(Integer)
|
||||
token_count = Column(Integer)
|
||||
cost_cents = Column(Integer)
|
||||
|
||||
|
||||
# Context
|
||||
model_used = Column(String(100))
|
||||
rag_used = Column(Boolean, default=False)
|
||||
|
||||
|
||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ChatbotAnalytics(id={self.id}, event_type='{self.event_type}')>"
|
||||
return f"<ChatbotAnalytics(id={self.id}, event_type='{self.event_type}')>"

@@ -10,6 +10,7 @@ from enum import Enum


class ModuleStatus(str, Enum):
    """Module status types"""

    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"
@@ -19,6 +20,7 @@ class ModuleStatus(str, Enum):


class ModuleType(str, Enum):
    """Module type categories"""

    CORE = "core"
    INTERCEPTOR = "interceptor"
    ANALYTICS = "analytics"
@@ -30,75 +32,81 @@ class ModuleType(str, Enum):
class Module(Base):
    """Module model for tracking installed modules and their configurations"""

    __tablename__ = "modules"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, unique=True, index=True, nullable=False)
    display_name = Column(String, nullable=False)
    description = Column(Text, nullable=True)

    # Module classification
    module_type = Column(String, default=ModuleType.CUSTOM)
    category = Column(String, nullable=True)  # cache, rag, analytics, etc.

    # Module details
    version = Column(String, nullable=False)
    author = Column(String, nullable=True)
    license = Column(String, nullable=True)

    # Module status
    status = Column(String, default=ModuleStatus.INACTIVE)
    is_enabled = Column(Boolean, default=False)
    is_core = Column(Boolean, default=False)  # Core modules cannot be disabled

    # Configuration
    config_schema = Column(JSON, default=dict)  # JSON schema for configuration
    config_values = Column(JSON, default=dict)  # Current configuration values
    default_config = Column(JSON, default=dict)  # Default configuration

    # Dependencies
    dependencies = Column(JSON, default=list)  # List of module dependencies
    conflicts = Column(JSON, default=list)  # List of conflicting modules

    # Installation details
    install_path = Column(String, nullable=True)
    entry_point = Column(String, nullable=True)  # Main module entry point

    # Interceptor configuration
    interceptor_chains = Column(
        JSON, default=list
    )  # Which chains this module hooks into
    execution_order = Column(Integer, default=100)  # Order in interceptor chain

    # API endpoints
    api_endpoints = Column(
        JSON, default=list
    )  # List of API endpoints this module provides

    # Permissions and security
    required_permissions = Column(
        JSON, default=list
    )  # Permissions required to use this module
    security_level = Column(String, default="low")  # low, medium, high, critical

    # Metadata
    tags = Column(JSON, default=list)
    module_metadata = Column(JSON, default=dict)

    # Runtime information
    last_error = Column(Text, nullable=True)
    error_count = Column(Integer, default=0)
    last_started = Column(DateTime, nullable=True)
    last_stopped = Column(DateTime, nullable=True)

    # Statistics
    request_count = Column(Integer, default=0)
    success_count = Column(Integer, default=0)
    error_count_runtime = Column(Integer, default=0)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    installed_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        return f"<Module(id={self.id}, name='{self.name}', status='{self.status}')>"

    def to_dict(self):
        """Convert module to dictionary for API responses"""
        return {
@@ -130,81 +138,87 @@ class Module(Base):
            "metadata": self.module_metadata,
            "last_error": self.last_error,
            "error_count": self.error_count,
            "last_started": self.last_started.isoformat()
            if self.last_started
            else None,
            "last_stopped": self.last_stopped.isoformat()
            if self.last_stopped
            else None,
            "request_count": self.request_count,
            "success_count": self.success_count,
            "error_count_runtime": self.error_count_runtime,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "installed_at": self.installed_at.isoformat()
            if self.installed_at
            else None,
            "success_rate": self.get_success_rate(),
            "uptime": self.get_uptime_seconds() if self.is_running() else 0,
        }

    def is_running(self) -> bool:
        """Check if module is currently running"""
        return self.status == ModuleStatus.ACTIVE

    def is_healthy(self) -> bool:
        """Check if module is healthy (running without recent errors)"""
        return self.is_running() and self.error_count_runtime == 0

    def get_success_rate(self) -> float:
        """Get success rate as percentage"""
        if self.request_count == 0:
            return 100.0
        return (self.success_count / self.request_count) * 100

    def get_uptime_seconds(self) -> int:
        """Get uptime in seconds"""
        if not self.last_started:
            return 0
        return int((datetime.utcnow() - self.last_started).total_seconds())

    def can_be_disabled(self) -> bool:
        """Check if module can be disabled"""
        return not self.is_core

    def has_dependency(self, module_name: str) -> bool:
        """Check if module has a specific dependency"""
        return module_name in self.dependencies

    def conflicts_with(self, module_name: str) -> bool:
        """Check if module conflicts with another module"""
        return module_name in self.conflicts

    def requires_permission(self, permission: str) -> bool:
        """Check if module requires a specific permission"""
        return permission in self.required_permissions

    def hooks_into_chain(self, chain_name: str) -> bool:
        """Check if module hooks into a specific interceptor chain"""
        return chain_name in self.interceptor_chains

    def provides_endpoint(self, endpoint: str) -> bool:
        """Check if module provides a specific API endpoint"""
        return endpoint in self.api_endpoints

    def update_config(self, config_updates: Dict[str, Any]):
        """Update module configuration"""
        if self.config_values is None:
            self.config_values = {}
        self.config_values.update(config_updates)
        self.updated_at = datetime.utcnow()

    def reset_config(self):
        """Reset configuration to default values"""
        self.config_values = self.default_config.copy() if self.default_config else {}
        self.updated_at = datetime.utcnow()
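
    # Editor's note: a hedged sketch, not part of the original commit. Plain JSON
    # columns are not mutation-tracked, so the in-place dict update performed by
    # update_config() may need to be flagged explicitly; `session` is illustrative.
    def _example_reconfigure(self, session):
        """Sketch: merge a config override and make sure the JSON change is flushed."""
        from sqlalchemy.orm.attributes import flag_modified

        self.update_config({"ttl": 7200})
        flag_modified(self, "config_values")  # mark the JSON column dirty
        session.commit()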

    def enable(self):
        """Enable the module"""
        if self.status != ModuleStatus.ERROR:
            self.is_enabled = True
            self.status = ModuleStatus.LOADING
            self.updated_at = datetime.utcnow()

    def disable(self):
        """Disable the module"""
        if self.can_be_disabled():
@@ -212,20 +226,20 @@ class Module(Base):
            self.status = ModuleStatus.DISABLED
            self.last_stopped = datetime.utcnow()
            self.updated_at = datetime.utcnow()

    def start(self):
        """Start the module"""
        self.status = ModuleStatus.ACTIVE
        self.last_started = datetime.utcnow()
        self.last_error = None
        self.updated_at = datetime.utcnow()

    def stop(self):
        """Stop the module"""
        self.status = ModuleStatus.INACTIVE
        self.last_stopped = datetime.utcnow()
        self.updated_at = datetime.utcnow()

    def set_error(self, error_message: str):
        """Set module error status"""
        self.status = ModuleStatus.ERROR
@@ -233,13 +247,13 @@ class Module(Base):
        self.error_count += 1
        self.error_count_runtime += 1
        self.updated_at = datetime.utcnow()

    def clear_error(self):
        """Clear error status"""
        self.last_error = None
        self.error_count_runtime = 0
        self.updated_at = datetime.utcnow()

    def record_request(self, success: bool = True):
        """Record a request to this module"""
        self.request_count += 1
@@ -247,76 +261,82 @@ class Module(Base):
            self.success_count += 1
        else:
            self.error_count_runtime += 1

    def add_tag(self, tag: str):
        """Add a tag to the module"""
        if tag not in self.tags:
            self.tags.append(tag)

    def remove_tag(self, tag: str):
        """Remove a tag from the module"""
        if tag in self.tags:
            self.tags.remove(tag)

    def update_metadata(self, key: str, value: Any):
        """Update metadata"""
        if self.module_metadata is None:
            self.module_metadata = {}
        self.module_metadata[key] = value

    def add_dependency(self, module_name: str):
        """Add a dependency"""
        if module_name not in self.dependencies:
            self.dependencies.append(module_name)

    def remove_dependency(self, module_name: str):
        """Remove a dependency"""
        if module_name in self.dependencies:
            self.dependencies.remove(module_name)

    def add_conflict(self, module_name: str):
        """Add a conflict"""
        if module_name not in self.conflicts:
            self.conflicts.append(module_name)

    def remove_conflict(self, module_name: str):
        """Remove a conflict"""
        if module_name in self.conflicts:
            self.conflicts.remove(module_name)

    def add_interceptor_chain(self, chain_name: str):
        """Add an interceptor chain"""
        if chain_name not in self.interceptor_chains:
            self.interceptor_chains.append(chain_name)

    def remove_interceptor_chain(self, chain_name: str):
        """Remove an interceptor chain"""
        if chain_name in self.interceptor_chains:
            self.interceptor_chains.remove(chain_name)

    def add_api_endpoint(self, endpoint: str):
        """Add an API endpoint"""
        if endpoint not in self.api_endpoints:
            self.api_endpoints.append(endpoint)

    def remove_api_endpoint(self, endpoint: str):
        """Remove an API endpoint"""
        if endpoint in self.api_endpoints:
            self.api_endpoints.remove(endpoint)

    def add_required_permission(self, permission: str):
        """Add a required permission"""
        if permission not in self.required_permissions:
            self.required_permissions.append(permission)

    def remove_required_permission(self, permission: str):
        """Remove a required permission"""
        if permission in self.required_permissions:
            self.required_permissions.remove(permission)

    @classmethod
    def create_core_module(
        cls,
        name: str,
        display_name: str,
        description: str,
        version: str,
        entry_point: str,
    ) -> "Module":
        """Create a core module"""
        return cls(
            name=name,
@@ -341,9 +361,9 @@ class Module(Base):
            required_permissions=[],
            security_level="high",
            tags=["core"],
            module_metadata={},
        )

    @classmethod
    def create_cache_module(cls) -> "Module":
        """Create the cache module"""
@@ -365,20 +385,12 @@ class Module(Base):
                "properties": {
                    "provider": {"type": "string", "enum": ["redis"]},
                    "ttl": {"type": "integer", "minimum": 60},
                    "max_size": {"type": "integer", "minimum": 1000},
                },
                "required": ["provider", "ttl"],
            },
            config_values={"provider": "redis", "ttl": 3600, "max_size": 10000},
            default_config={"provider": "redis", "ttl": 3600, "max_size": 10000},
            dependencies=[],
            conflicts=[],
            interceptor_chains=["pre_request", "post_response"],
@@ -387,9 +399,9 @@ class Module(Base):
            required_permissions=["cache.read", "cache.write"],
            security_level="low",
            tags=["cache", "performance"],
            module_metadata={},
        )

    @classmethod
    def create_rag_module(cls) -> "Module":
        """Create the RAG module"""
@@ -412,21 +424,21 @@ class Module(Base):
                    "vector_db": {"type": "string", "enum": ["qdrant"]},
                    "embedding_model": {"type": "string"},
                    "chunk_size": {"type": "integer", "minimum": 100},
                    "max_results": {"type": "integer", "minimum": 1},
                },
                "required": ["vector_db", "embedding_model"],
            },
            config_values={
                "vector_db": "qdrant",
                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                "chunk_size": 512,
                "max_results": 10,
            },
            default_config={
                "vector_db": "qdrant",
                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
                "chunk_size": 512,
                "max_results": 10,
            },
            dependencies=[],
            conflicts=[],
@@ -436,9 +448,9 @@ class Module(Base):
            required_permissions=["rag.read", "rag.write"],
            security_level="medium",
            tags=["rag", "ai", "search"],
            module_metadata={},
        )

    @classmethod
    def create_analytics_module(cls) -> "Module":
        """Create the analytics module"""
@@ -460,19 +472,19 @@ class Module(Base):
                    "track_requests": {"type": "boolean"},
                    "track_responses": {"type": "boolean"},
                    "retention_days": {"type": "integer", "minimum": 1},
                },
                "required": ["track_requests", "track_responses"],
            },
            config_values={
                "track_requests": True,
                "track_responses": True,
                "retention_days": 30,
            },
            default_config={
                "track_requests": True,
                "track_responses": True,
                "retention_days": 30,
            },
            dependencies=[],
            conflicts=[],
@@ -482,9 +494,9 @@ class Module(Base):
            required_permissions=["analytics.read"],
            security_level="low",
            tags=["analytics", "monitoring"],
            module_metadata={},
        )

    def get_health_status(self) -> Dict[str, Any]:
        """Get health status of the module"""
        return {
@@ -495,5 +507,7 @@ class Module(Base):
            "uptime_seconds": self.get_uptime_seconds() if self.is_running() else 0,
            "last_error": self.last_error,
            "error_count": self.error_count_runtime,
            "last_started": self.last_started.isoformat()
            if self.last_started
            else None,
        }
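

# Editor's note: a hedged lifecycle sketch, not part of the original commit;
# session handling is illustrative.
def _example_install_cache_module(session) -> Module:
    """Sketch: create the built-in cache module, switch it on, and record traffic."""
    module = Module.create_cache_module()
    module.enable()  # leaves an ERROR state untouched, otherwise marks LOADING
    module.start()  # marks ACTIVE and stamps last_started
    module.record_request(success=True)
    assert module.get_success_rate() == 100.0
    session.add(module)
    session.commit()
    return module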

295
backend/app/models/notification.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Notification models for multi-channel communication
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Boolean,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
"""Notification types"""
|
||||
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
SLACK = "slack"
|
||||
DISCORD = "discord"
|
||||
SMS = "sms"
|
||||
PUSH = "push"
|
||||
|
||||
|
||||
class NotificationPriority(str, Enum):
|
||||
"""Notification priority levels"""
|
||||
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
URGENT = "urgent"
|
||||
|
||||
|
||||
class NotificationStatus(str, Enum):
|
||||
"""Notification delivery status"""
|
||||
|
||||
PENDING = "pending"
|
||||
SENT = "sent"
|
||||
DELIVERED = "delivered"
|
||||
FAILED = "failed"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class NotificationTemplate(Base):
|
||||
"""Notification template model"""
|
||||
|
||||
__tablename__ = "notification_templates"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False, unique=True)
|
||||
display_name = Column(String(200), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Template content
|
||||
notification_type = Column(String(20), nullable=False) # NotificationType enum
|
||||
subject_template = Column(Text, nullable=True) # For email/messages
|
||||
body_template = Column(Text, nullable=False) # Main content
|
||||
html_template = Column(Text, nullable=True) # HTML version for email
|
||||
|
||||
# Configuration
|
||||
default_priority = Column(String(20), default=NotificationPriority.NORMAL)
|
||||
variables = Column(JSON, default=dict) # Expected template variables
|
||||
template_metadata = Column(JSON, default=dict) # Additional configuration
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
notifications = relationship("Notification", back_populates="template")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NotificationTemplate(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert template to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"notification_type": self.notification_type,
|
||||
"subject_template": self.subject_template,
|
||||
"body_template": self.body_template,
|
||||
"html_template": self.html_template,
|
||||
"default_priority": self.default_priority,
|
||||
"variables": self.variables,
|
||||
"template_metadata": self.template_metadata,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class NotificationChannel(Base):
|
||||
"""Notification channel configuration"""
|
||||
|
||||
__tablename__ = "notification_channels"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False)
|
||||
display_name = Column(String(200), nullable=False)
|
||||
notification_type = Column(String(20), nullable=False) # NotificationType enum
|
||||
|
||||
# Channel configuration
|
||||
config = Column(JSON, nullable=False) # Channel-specific settings
|
||||
credentials = Column(JSON, nullable=True) # Encrypted credentials
|
||||
|
||||
# Settings
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
rate_limit = Column(Integer, default=100) # Messages per minute
|
||||
retry_count = Column(Integer, default=3)
|
||||
retry_delay_minutes = Column(Integer, default=5)
|
||||
|
||||
# Health monitoring
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
success_count = Column(Integer, default=0)
|
||||
failure_count = Column(Integer, default=0)
|
||||
last_error = Column(Text, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
notifications = relationship("Notification", back_populates="channel")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NotificationChannel(id={self.id}, name='{self.name}', type='{self.notification_type}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert channel to dictionary (excluding sensitive credentials)"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"notification_type": self.notification_type,
|
||||
"config": self.config,
|
||||
"is_active": self.is_active,
|
||||
"is_default": self.is_default,
|
||||
"rate_limit": self.rate_limit,
|
||||
"retry_count": self.retry_count,
|
||||
"retry_delay_minutes": self.retry_delay_minutes,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"success_count": self.success_count,
|
||||
"failure_count": self.failure_count,
|
||||
"last_error": self.last_error,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def update_stats(self, success: bool, error_message: Optional[str] = None):
|
||||
"""Update channel statistics"""
|
||||
self.last_used_at = datetime.utcnow()
|
||||
if success:
|
||||
self.success_count += 1
|
||||
self.last_error = None
|
||||
else:
|
||||
self.failure_count += 1
|
||||
self.last_error = error_message
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
"""Individual notification instance"""
|
||||
|
||||
__tablename__ = "notifications"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Content
|
||||
subject = Column(String(500), nullable=True)
|
||||
body = Column(Text, nullable=False)
|
||||
html_body = Column(Text, nullable=True)
|
||||
|
||||
# Recipients
|
||||
recipients = Column(JSON, nullable=False) # List of recipient addresses/IDs
|
||||
cc_recipients = Column(JSON, nullable=True) # CC recipients (for email)
|
||||
bcc_recipients = Column(JSON, nullable=True) # BCC recipients (for email)
|
||||
|
||||
# Configuration
|
||||
priority = Column(String(20), default=NotificationPriority.NORMAL)
|
||||
scheduled_at = Column(DateTime, nullable=True) # For scheduled delivery
|
||||
expires_at = Column(DateTime, nullable=True) # Expiration time
|
||||
|
||||
# References
|
||||
template_id = Column(
|
||||
Integer, ForeignKey("notification_templates.id"), nullable=True
|
||||
)
|
||||
channel_id = Column(Integer, ForeignKey("notification_channels.id"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Triggering user
|
||||
|
||||
# Status tracking
|
||||
status = Column(String(20), default=NotificationStatus.PENDING)
|
||||
attempts = Column(Integer, default=0)
|
||||
max_attempts = Column(Integer, default=3)
|
||||
|
||||
# Delivery tracking
|
||||
sent_at = Column(DateTime, nullable=True)
|
||||
delivered_at = Column(DateTime, nullable=True)
|
||||
failed_at = Column(DateTime, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# External references
|
||||
external_id = Column(String(200), nullable=True) # Provider message ID
|
||||
callback_url = Column(String(500), nullable=True) # Delivery callback
|
||||
|
||||
# Metadata
|
||||
notification_metadata = Column(JSON, default=dict)
|
||||
tags = Column(JSON, default=list)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
template = relationship("NotificationTemplate", back_populates="notifications")
|
||||
channel = relationship("NotificationChannel", back_populates="notifications")
|
||||
user = relationship("User", back_populates="notifications")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Notification(id={self.id}, status='{self.status}', channel='{self.channel.name if self.channel else 'unknown'}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert notification to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"body": self.body,
|
||||
"html_body": self.html_body,
|
||||
"recipients": self.recipients,
|
||||
"cc_recipients": self.cc_recipients,
|
||||
"bcc_recipients": self.bcc_recipients,
|
||||
"priority": self.priority,
|
||||
"scheduled_at": self.scheduled_at.isoformat()
|
||||
if self.scheduled_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"template_id": self.template_id,
|
||||
"channel_id": self.channel_id,
|
||||
"user_id": self.user_id,
|
||||
"status": self.status,
|
||||
"attempts": self.attempts,
|
||||
"max_attempts": self.max_attempts,
|
||||
"sent_at": self.sent_at.isoformat() if self.sent_at else None,
|
||||
"delivered_at": self.delivered_at.isoformat()
|
||||
if self.delivered_at
|
||||
else None,
|
||||
"failed_at": self.failed_at.isoformat() if self.failed_at else None,
|
||||
"error_message": self.error_message,
|
||||
"external_id": self.external_id,
|
||||
"callback_url": self.callback_url,
|
||||
"notification_metadata": self.notification_metadata,
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def mark_sent(self, external_id: Optional[str] = None):
|
||||
"""Mark notification as sent"""
|
||||
self.status = NotificationStatus.SENT
|
||||
self.sent_at = datetime.utcnow()
|
||||
self.external_id = external_id
|
||||
|
||||
def mark_delivered(self):
|
||||
"""Mark notification as delivered"""
|
||||
self.status = NotificationStatus.DELIVERED
|
||||
self.delivered_at = datetime.utcnow()
|
||||
|
||||
def mark_failed(self, error_message: str):
|
||||
"""Mark notification as failed"""
|
||||
self.status = NotificationStatus.FAILED
|
||||
self.failed_at = datetime.utcnow()
|
||||
self.error_message = error_message
|
||||
self.attempts += 1
|
||||
|
||||
def can_retry(self) -> bool:
|
||||
"""Check if notification can be retried"""
|
||||
return (
|
||||
self.status in [NotificationStatus.FAILED, NotificationStatus.RETRY]
|
||||
and self.attempts < self.max_attempts
|
||||
and (self.expires_at is None or self.expires_at > datetime.utcnow())
|
||||
)
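

# Editor's note: a hedged delivery sketch, not part of the original commit;
# `send_via_channel` is a hypothetical transport callable.
def _example_deliver(notification: Notification, send_via_channel) -> None:
    """Sketch: attempt delivery once and record the outcome on the model."""
    try:
        provider_id = send_via_channel(notification.channel, notification)
        notification.mark_sent(external_id=provider_id)
    except Exception as exc:  # broad catch only for the sketch
        notification.mark_failed(str(exc))
        if notification.can_retry():
            notification.status = NotificationStatus.RETRY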

@@ -2,7 +2,17 @@
Plugin System Database Models
Defines the database schema for the isolated plugin architecture
"""
from sqlalchemy import (
    Column,
    Integer,
    String,
    Text,
    DateTime,
    Boolean,
    JSON,
    ForeignKey,
    Index,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -13,13 +23,16 @@ from app.db.database import Base


class Plugin(Base):
    """Plugin registry - tracks all installed plugins"""

    __tablename__ = "plugins"

    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(100), unique=True, nullable=False, index=True)
    slug = Column(
        String(100), unique=True, nullable=False, index=True
    )  # URL-safe identifier

    # Metadata
    display_name = Column(String(200), nullable=False)
    description = Column(Text)
@@ -27,65 +40,74 @@ class Plugin(Base):
    author = Column(String(200))
    homepage = Column(String(500))
    repository = Column(String(500))

    # Plugin file information
    package_path = Column(String(500), nullable=False)  # Path to plugin package
    manifest_hash = Column(String(64), nullable=False)  # SHA256 of manifest file
    package_hash = Column(String(64), nullable=False)  # SHA256 of plugin package

    # Status and lifecycle
    status = Column(String(20), nullable=False, default="installed", index=True)
    # Statuses: installing, installed, enabled, disabled, error, uninstalling
    enabled = Column(Boolean, default=False, nullable=False, index=True)
    auto_enable = Column(Boolean, default=False, nullable=False)

    # Installation tracking
    installed_at = Column(DateTime, nullable=False, default=func.now())
    enabled_at = Column(DateTime)
    last_updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    installed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Configuration and requirements
    manifest_data = Column(JSON)  # Complete plugin.yaml content
    config_schema = Column(JSON)  # JSON schema for plugin configuration
    default_config = Column(JSON)  # Default configuration values

    # Security and permissions
    required_permissions = Column(JSON)  # List of required permission scopes
    api_scopes = Column(JSON)  # Required API access scopes
    resource_limits = Column(JSON)  # Memory, CPU, storage limits

    # Database isolation
    database_name = Column(String(100), unique=True)  # Isolated database name
    database_url = Column(String(1000))  # Connection string for plugin database

    # Error tracking
    last_error = Column(Text)
    error_count = Column(Integer, default=0)
    last_error_at = Column(DateTime)

    # Relationships
    installed_by_user = relationship("User", back_populates="installed_plugins")
    configurations = relationship(
        "PluginConfiguration", back_populates="plugin", cascade="all, delete-orphan"
    )
    instances = relationship(
        "PluginInstance", back_populates="plugin", cascade="all, delete-orphan"
    )
    audit_logs = relationship(
        "PluginAuditLog", back_populates="plugin", cascade="all, delete-orphan"
    )
    cron_jobs = relationship(
        "PluginCronJob", back_populates="plugin", cascade="all, delete-orphan"
    )

    # Indexes for performance
    __table_args__ = (
        Index("idx_plugin_status_enabled", "status", "enabled"),
        Index("idx_plugin_user_status", "installed_by_user_id", "status"),
    )
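

# Editor's note: a hedged query sketch, not part of the original commit; it
# assumes a synchronous Session and exercises the status/enabled composite index.
def _example_enabled_plugins(session):
    """Sketch: list plugins that are installed and currently switched on."""
    return (
        session.query(Plugin)
        .filter(Plugin.status == "enabled", Plugin.enabled.is_(True))
        .all()
    )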

class PluginConfiguration(Base):
    """Plugin configuration instances - per user/environment configs"""

    __tablename__ = "plugin_configurations"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Configuration data
    name = Column(String(200), nullable=False)  # Human-readable config name
    description = Column(Text)
@@ -94,133 +116,140 @@ class PluginConfiguration(Base):
    schema_version = Column(String(50))  # Schema version for migration support
    is_active = Column(Boolean, default=False, nullable=False)
    is_default = Column(Boolean, default=False, nullable=False)

    # Metadata
    created_at = Column(DateTime, nullable=False, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Relationships
    plugin = relationship("Plugin", back_populates="configurations")
    user = relationship("User", foreign_keys=[user_id])
    created_by_user = relationship("User", foreign_keys=[created_by_user_id])

    # Constraints
    __table_args__ = (
        Index("idx_plugin_config_user_active", "plugin_id", "user_id", "is_active"),
    )


class PluginInstance(Base):
    """Plugin runtime instances - tracks running plugin processes"""

    __tablename__ = "plugin_instances"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
    configuration_id = Column(
        UUID(as_uuid=True), ForeignKey("plugin_configurations.id")
    )

    # Runtime information
    instance_name = Column(String(200), nullable=False)
    process_id = Column(String(100))  # Docker container ID or process ID
    status = Column(String(20), nullable=False, default="starting", index=True)
    # Statuses: starting, running, stopping, stopped, error, crashed

    # Performance tracking
    start_time = Column(DateTime, nullable=False, default=func.now())
    last_heartbeat = Column(DateTime, default=func.now())
    stop_time = Column(DateTime)
    restart_count = Column(Integer, default=0)

    # Resource usage
    memory_usage_mb = Column(Integer)
    cpu_usage_percent = Column(Integer)

    # Health monitoring
    health_status = Column(String(20), default="unknown")  # healthy, warning, critical
    health_message = Column(Text)
    last_health_check = Column(DateTime)

    # Error tracking
    last_error = Column(Text)
    error_count = Column(Integer, default=0)

    # Relationships
    plugin = relationship("Plugin", back_populates="instances")
    configuration = relationship("PluginConfiguration")

    __table_args__ = (Index("idx_plugin_instance_status", "plugin_id", "status"),)


class PluginAuditLog(Base):
    """Audit logging for all plugin activities"""

    __tablename__ = "plugin_audit_logs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
    instance_id = Column(UUID(as_uuid=True), ForeignKey("plugin_instances.id"))

    # Event details
    event_type = Column(
        String(50), nullable=False, index=True
    )  # api_call, config_change, error, etc.
    action = Column(String(100), nullable=False)
    resource = Column(String(200))  # Resource being accessed

    # Context information
    user_id = Column(Integer, ForeignKey("users.id"))
    api_key_id = Column(Integer, ForeignKey("api_keys.id"))
    ip_address = Column(String(45))  # IPv4 or IPv6
    user_agent = Column(String(500))

    # Request/response data
    request_data = Column(JSON)  # Sanitized request data
    response_status = Column(Integer)
    response_data = Column(JSON)  # Sanitized response data

    # Performance metrics
    duration_ms = Column(Integer)

    # Status and errors
    success = Column(Boolean, nullable=False, index=True)
    error_message = Column(Text)

    # Timestamps
    timestamp = Column(DateTime, nullable=False, default=func.now(), index=True)

    # Relationships
    plugin = relationship("Plugin", back_populates="audit_logs")
    instance = relationship("PluginInstance")
    user = relationship("User")
    api_key = relationship("APIKey")

    __table_args__ = (
        Index("idx_plugin_audit_plugin_time", "plugin_id", "timestamp"),
        Index("idx_plugin_audit_user_time", "user_id", "timestamp"),
        Index("idx_plugin_audit_event_type", "event_type", "timestamp"),
    )


class PluginCronJob(Base):
    """Plugin scheduled jobs and cron tasks"""

    __tablename__ = "plugin_cron_jobs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)

    # Job identification
    job_name = Column(String(200), nullable=False)
    job_id = Column(
        String(100), nullable=False, unique=True, index=True
    )  # Unique scheduler ID

    # Schedule configuration
    schedule = Column(String(100), nullable=False)  # Cron expression
    timezone = Column(String(50), default="UTC")
    enabled = Column(Boolean, default=True, nullable=False, index=True)

    # Job details
    description = Column(Text)
    function_name = Column(String(200), nullable=False)  # Plugin function to call
    job_data = Column(JSON)  # Parameters for the job function

    # Execution tracking
    last_run_at = Column(DateTime)
    next_run_at = Column(DateTime, index=True)
@@ -228,65 +257,72 @@ class PluginCronJob(Base):
    run_count = Column(Integer, default=0)
    success_count = Column(Integer, default=0)
    error_count = Column(Integer, default=0)

    # Error handling
    last_error = Column(Text)
    last_error_at = Column(DateTime)
    max_retries = Column(Integer, default=3)
    retry_delay_seconds = Column(Integer, default=60)

    # Lifecycle
    created_at = Column(DateTime, nullable=False, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Relationships
    plugin = relationship("Plugin", back_populates="cron_jobs")
    created_by_user = relationship("User")

    __table_args__ = (
        Index("idx_plugin_cron_next_run", "enabled", "next_run_at"),
        Index("idx_plugin_cron_plugin", "plugin_id", "enabled"),
    )


class PluginAPIGateway(Base):
    """API gateway configuration for plugin routing"""

    __tablename__ = "plugin_api_gateways"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(
        UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False, unique=True
    )

    # API routing configuration
    base_path = Column(
        String(200), nullable=False, unique=True
    )  # /api/v1/plugins/zammad
    internal_url = Column(String(500), nullable=False)  # http://plugin-zammad:8000

    # Security settings
    require_authentication = Column(Boolean, default=True, nullable=False)
    allowed_methods = Column(
        JSON, default=["GET", "POST", "PUT", "DELETE"]
    )  # HTTP methods
    rate_limit_per_minute = Column(Integer, default=60)
    rate_limit_per_hour = Column(Integer, default=1000)

    # CORS settings
    cors_enabled = Column(Boolean, default=True, nullable=False)
    cors_origins = Column(JSON, default=["*"])
    cors_methods = Column(JSON, default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    cors_headers = Column(JSON, default=["*"])

    # Circuit breaker settings
    circuit_breaker_enabled = Column(Boolean, default=True, nullable=False)
    failure_threshold = Column(Integer, default=5)
    recovery_timeout_seconds = Column(Integer, default=60)

    # Monitoring
    enabled = Column(Boolean, default=True, nullable=False, index=True)
    last_health_check = Column(DateTime)
    health_status = Column(String(20), default="unknown")  # healthy, unhealthy, timeout

    # Timestamps
    created_at = Column(DateTime, nullable=False, default=func.now())
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())

    # Relationships
    plugin = relationship("Plugin")

@@ -303,36 +339,42 @@ Add to APIKey model:
plugin_audit_logs = relationship("PluginAuditLog", back_populates="api_key")
"""


class PluginPermission(Base):
    """Plugin permission grants - tracks user permissions for plugins"""

    __tablename__ = "plugin_permissions"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    plugin_id = Column(UUID(as_uuid=True), ForeignKey("plugins.id"), nullable=False)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Permission details
    permission_name = Column(
        String(200), nullable=False
    )  # e.g., 'chatbot:invoke', 'rag:query'
    granted = Column(
        Boolean, default=True, nullable=False
    )  # True=granted, False=revoked

    # Grant/revoke tracking
    granted_at = Column(DateTime, nullable=False, default=func.now())
    granted_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    revoked_at = Column(DateTime)
    revoked_by_user_id = Column(Integer, ForeignKey("users.id"))

    # Metadata
    reason = Column(Text)  # Reason for grant/revocation
    expires_at = Column(DateTime)  # Optional expiration time

    # Relationships
    plugin = relationship("Plugin")
    user = relationship("User", foreign_keys=[user_id])
    granted_by_user = relationship("User", foreign_keys=[granted_by_user_id])
    revoked_by_user = relationship("User", foreign_keys=[revoked_by_user_id])

    __table_args__ = (
        Index("idx_plugin_permission_user_plugin", "user_id", "plugin_id"),
        Index("idx_plugin_permission_plugin_name", "plugin_id", "permission_name"),
        Index("idx_plugin_permission_active", "plugin_id", "user_id", "granted"),
    )
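

# Editor's note: a hedged permission-check sketch, not part of the original
# commit; expiry and revocation semantics follow the columns above.
def _example_permission_is_active(perm: PluginPermission) -> bool:
    """Sketch: a grant counts only while not revoked and not expired."""
    from datetime import datetime

    if not perm.granted or perm.revoked_at is not None:
        return False
    return perm.expires_at is None or perm.expires_at > datetime.utcnow()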

@@ -10,33 +10,41 @@ from datetime import datetime


class PromptTemplate(Base):
    """Editable prompt templates for different chatbot types"""

    __tablename__ = "prompt_templates"

    id = Column(String, primary_key=True, index=True)
    name = Column(String(255), nullable=False, index=True)  # Human readable name
    type_key = Column(
        String(100), nullable=False, unique=True, index=True
    )  # assistant, customer_support, etc.
    description = Column(Text, nullable=True)
    system_prompt = Column(Text, nullable=False)
    is_default = Column(Boolean, default=True, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)
    version = Column(Integer, default=1, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    def __repr__(self):
        return f"<PromptTemplate(type_key='{self.type_key}', name='{self.name}')>"

class ChatbotPromptVariable(Base):
    """Available variables that can be used in prompts"""

    __tablename__ = "prompt_variables"

    id = Column(String, primary_key=True, index=True)
    variable_name = Column(
        String(100), nullable=False, unique=True, index=True
    )  # {user_name}, {context}, etc.
    description = Column(Text, nullable=True)
    example_value = Column(String(500), nullable=True)
    is_active = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        return f"<PromptVariable(name='{self.variable_name}')>"

@@ -15,23 +15,36 @@ class RagCollection(Base):
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False, index=True)
    description = Column(Text, nullable=True)
    qdrant_collection_name = Column(
        String(255), nullable=False, unique=True, index=True
    )

    # Metadata
    document_count = Column(Integer, default=0, nullable=False)
    size_bytes = Column(BigInteger, default=0, nullable=False)
    vector_count = Column(Integer, default=0, nullable=False)

    # Status tracking
    status = Column(
        String(50), default="active", nullable=False
    )  # 'active', 'indexing', 'error'
    is_active = Column(Boolean, default=True, nullable=False)

    # Timestamps
    created_at = Column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    documents = relationship(
        "RagDocument", back_populates="collection", cascade="all, delete-orphan"
    )

    def to_dict(self):
        """Convert model to dictionary for API responses"""
@@ -45,8 +58,8 @@ class RagCollection(Base):
            "status": self.status,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "is_active": self.is_active,
        }

    def __repr__(self):
        return f"<RagCollection(id={self.id}, name='{self.name}', documents={self.document_count})>"

@@ -3,7 +3,17 @@ RAG Document Model
Represents documents within RAG collections
"""

from sqlalchemy import (
    Column,
    Integer,
    String,
    Text,
    DateTime,
    Boolean,
    BigInteger,
    ForeignKey,
    JSON,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
@@ -13,11 +23,16 @@ class RagDocument(Base):
    __tablename__ = "rag_documents"

    id = Column(Integer, primary_key=True, index=True)

    # Collection relationship
    collection_id = Column(
        Integer,
        ForeignKey("rag_collections.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    collection = relationship("RagCollection", back_populates="documents")

    # File information
    filename = Column(String(255), nullable=False)  # sanitized filename for storage
    original_filename = Column(String(255), nullable=False)  # user's original filename
@@ -25,29 +40,44 @@ class RagDocument(Base):
    file_type = Column(String(50), nullable=False)  # pdf, docx, txt, etc.
    file_size = Column(BigInteger, nullable=False)  # file size in bytes
    mime_type = Column(String(100), nullable=True)

    # Processing status
    status = Column(
        String(50), default="processing", nullable=False
    )  # 'processing', 'processed', 'error', 'indexed'
    processing_error = Column(Text, nullable=True)

    # Content information
    converted_content = Column(Text, nullable=True)  # markdown converted content
    word_count = Column(Integer, default=0, nullable=False)
    character_count = Column(Integer, default=0, nullable=False)

    # Vector information
    vector_count = Column(
        Integer, default=0, nullable=False
    )  # number of chunks/vectors created
    chunk_size = Column(
        Integer, default=1000, nullable=False
    )  # chunk size used for vectorization

    # Metadata extracted from document
    document_metadata = Column(
        JSON, nullable=True
    )  # language, entities, keywords, etc.

    # Processing timestamps
    created_at = Column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    processed_at = Column(DateTime(timezone=True), nullable=True)
    indexed_at = Column(DateTime(timezone=True), nullable=True)
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Soft delete
    is_deleted = Column(Boolean, default=False, nullable=False)
    deleted_at = Column(DateTime(timezone=True), nullable=True)
@@ -72,11 +102,13 @@ class RagDocument(Base):
            "chunk_size": self.chunk_size,
            "metadata": self.document_metadata or {},
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "processed_at": self.processed_at.isoformat()
            if self.processed_at
            else None,
            "indexed_at": self.indexed_at.isoformat() if self.indexed_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "is_deleted": self.is_deleted,
        }

    def __repr__(self):
        return f"<RagDocument(id={self.id}, filename='{self.original_filename}', status='{self.status}')>"

158
backend/app/models/role.py
Normal file
@@ -0,0 +1,158 @@
"""
|
||||
Role model for hierarchical permission management
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class RoleLevel(str, Enum):
|
||||
"""Role hierarchy levels"""
|
||||
|
||||
READ_ONLY = "read_only" # Level 1: Can only view
|
||||
USER = "user" # Level 2: Can create and manage own resources
|
||||
ADMIN = "admin" # Level 3: Can manage users and settings
|
||||
SUPER_ADMIN = "super_admin" # Level 4: Full system access
|
||||
|
||||
|
||||
class Role(Base):
|
||||
"""Role model with hierarchical permissions"""
|
||||
|
||||
__tablename__ = "roles"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(50), unique=True, nullable=False)
|
||||
display_name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
level = Column(String(20), nullable=False) # RoleLevel enum
|
||||
|
||||
# Permissions configuration
|
||||
permissions = Column(JSON, default=dict) # Granular permissions
|
||||
can_manage_users = Column(Boolean, default=False)
|
||||
can_manage_budgets = Column(Boolean, default=False)
|
||||
can_view_reports = Column(Boolean, default=False)
|
||||
can_manage_tools = Column(Boolean, default=False)
|
||||
|
||||
# Role hierarchy
|
||||
inherits_from = Column(JSON, default=list) # List of parent role names
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_system_role = Column(Boolean, default=False) # System roles cannot be deleted
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
users = relationship("User", back_populates="role")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Role(id={self.id}, name='{self.name}', level='{self.level}')>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert role to dictionary"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"level": self.level,
|
||||
"permissions": self.permissions,
|
||||
"can_manage_users": self.can_manage_users,
|
||||
"can_manage_budgets": self.can_manage_budgets,
|
||||
"can_view_reports": self.can_view_reports,
|
||||
"can_manage_tools": self.can_manage_tools,
|
||||
"inherits_from": self.inherits_from,
|
||||
"is_active": self.is_active,
|
||||
"is_system_role": self.is_system_role,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
def has_permission(self, permission: str) -> bool:
|
||||
"""Check if role has a specific permission"""
|
||||
if self.level == RoleLevel.SUPER_ADMIN:
|
||||
return True
|
||||
|
||||
# Check direct permissions
|
||||
if permission in self.permissions.get("granted", []):
|
||||
return True
|
||||
|
||||
# Check denied permissions
|
||||
if permission in self.permissions.get("denied", []):
|
||||
return False
|
||||
|
||||
# Check inherited permissions (simplified)
|
||||
for parent_role in self.inherits_from:
|
||||
# This would require recursive checking in a real implementation
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_default_roles(cls):
|
||||
"""Create default system roles"""
|
||||
roles = [
|
||||
cls(
|
||||
name="read_only",
|
||||
display_name="Read Only",
|
||||
description="Can view own data only",
|
||||
level=RoleLevel.READ_ONLY,
|
||||
permissions={
|
||||
"granted": ["read_own"],
|
||||
"denied": ["create", "update", "delete"],
|
||||
},
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="user",
|
||||
display_name="User",
|
||||
description="Standard user with full access to own resources",
|
||||
level=RoleLevel.USER,
|
||||
permissions={
|
||||
"granted": ["read_own", "create_own", "update_own", "delete_own"],
|
||||
"denied": ["manage_users", "manage_all"],
|
||||
},
|
||||
inherits_from=["read_only"],
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="admin",
|
||||
display_name="Administrator",
|
||||
description="Can manage users and view reports",
|
||||
level=RoleLevel.ADMIN,
|
||||
permissions={
|
||||
"granted": [
|
||||
"read_all",
|
||||
"create_all",
|
||||
"update_all",
|
||||
"manage_users",
|
||||
"view_reports",
|
||||
],
|
||||
"denied": ["system_settings"],
|
||||
},
|
||||
inherits_from=["user"],
|
||||
can_manage_users=True,
|
||||
can_manage_budgets=True,
|
||||
can_view_reports=True,
|
||||
is_system_role=True,
|
||||
),
|
||||
cls(
|
||||
name="super_admin",
|
||||
display_name="Super Administrator",
|
||||
description="Full system access",
|
||||
level=RoleLevel.SUPER_ADMIN,
|
||||
permissions={"granted": ["*"]}, # All permissions
|
||||
inherits_from=["admin"],
|
||||
can_manage_users=True,
|
||||
can_manage_budgets=True,
|
||||
can_view_reports=True,
|
||||
can_manage_tools=True,
|
||||
is_system_role=True,
|
||||
),
|
||||
]
|
||||
return roles
|
||||
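A quick sketch of how the defaults and the permission check above are meant to be used. This is illustrative only; the `db` session and `seed_roles` helper are assumptions, not part of this diff.

# Hypothetical seeding/usage example for the Role model above.
# Assumes a synchronous SQLAlchemy session named `db`.
from app.models.role import Role

def seed_roles(db):
    for role in Role.create_default_roles():
        db.add(role)  # inserts read_only, user, admin, super_admin
    db.commit()

admin = Role(
    name="admin",
    display_name="Administrator",
    level="admin",
    permissions={"granted": ["manage_users"], "denied": ["system_settings"]},
    inherits_from=["user"],
)
assert admin.has_permission("manage_users") is True
assert admin.has_permission("system_settings") is False  # deny overrides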
backend/app/models/tool.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""
Tool model for custom tool execution
"""
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    Boolean,
    Text,
    JSON,
    ForeignKey,
    Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base


class ToolType(str, Enum):
    """Tool execution types"""

    PYTHON = "python"
    BASH = "bash"
    DOCKER = "docker"
    API = "api"
    CUSTOM = "custom"


class ToolStatus(str, Enum):
    """Tool execution status"""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
    CANCELLED = "cancelled"


class Tool(Base):
    """Tool definition model"""

    __tablename__ = "tools"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False, index=True)
    display_name = Column(String(200), nullable=False)
    description = Column(Text, nullable=True)

    # Tool configuration
    tool_type = Column(String(20), nullable=False)  # ToolType enum
    code = Column(Text, nullable=False)  # Tool implementation code
    parameters_schema = Column(JSON, default=dict)  # JSON schema for parameters
    return_schema = Column(JSON, default=dict)  # Expected return format

    # Execution settings
    timeout_seconds = Column(Integer, default=30)
    max_memory_mb = Column(Integer, default=256)
    max_cpu_seconds = Column(Float, default=10.0)

    # Docker settings (for docker type tools)
    docker_image = Column(String(200), nullable=True)
    docker_command = Column(Text, nullable=True)

    # Access control
    is_public = Column(Boolean, default=False)  # Public tools available to all users
    is_approved = Column(Boolean, default=False)  # Admin approved for security
    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Categories and tags
    category = Column(String(50), nullable=True)
    tags = Column(JSON, default=list)

    # Usage tracking
    usage_count = Column(Integer, default=0)
    last_used_at = Column(DateTime, nullable=True)

    # Status
    is_active = Column(Boolean, default=True)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    created_by = relationship("User", back_populates="created_tools")
    executions = relationship(
        "ToolExecution", back_populates="tool", cascade="all, delete-orphan"
    )

    def __repr__(self):
        return f"<Tool(id={self.id}, name='{self.name}', type='{self.tool_type}')>"

    def to_dict(self):
        """Convert tool to dictionary"""
        return {
            "id": self.id,
            "name": self.name,
            "display_name": self.display_name,
            "description": self.description,
            "tool_type": self.tool_type,
            "parameters_schema": self.parameters_schema,
            "return_schema": self.return_schema,
            "timeout_seconds": self.timeout_seconds,
            "max_memory_mb": self.max_memory_mb,
            "max_cpu_seconds": self.max_cpu_seconds,
            "docker_image": self.docker_image,
            "is_public": self.is_public,
            "is_approved": self.is_approved,
            "created_by_user_id": self.created_by_user_id,
            "category": self.category,
            "tags": self.tags,
            "usage_count": self.usage_count,
            "last_used_at": self.last_used_at.isoformat()
            if self.last_used_at
            else None,
            "is_active": self.is_active,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    def increment_usage(self):
        """Increment usage count and update last used timestamp"""
        self.usage_count += 1
        self.last_used_at = datetime.utcnow()

    def can_be_used_by(self, user) -> bool:
        """Check if user can use this tool"""
        # Tool creator can always use their tools
        if self.created_by_user_id == user.id:
            return True

        # Public and approved tools can be used by anyone
        if self.is_public and self.is_approved:
            return True

        # Admin users can use any tool
        if user.has_permission("manage_tools"):
            return True

        return False


class ToolExecution(Base):
    """Tool execution instance model"""

    __tablename__ = "tool_executions"

    id = Column(Integer, primary_key=True, index=True)

    # Tool and user references
    tool_id = Column(Integer, ForeignKey("tools.id"), nullable=False)
    executed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Execution details
    parameters = Column(JSON, default=dict)  # Input parameters
    status = Column(String(20), nullable=False, default=ToolStatus.PENDING)

    # Results
    output = Column(Text, nullable=True)  # Tool output
    error_message = Column(Text, nullable=True)  # Error details if failed
    return_code = Column(Integer, nullable=True)  # Exit code

    # Resource usage
    execution_time_ms = Column(Integer, nullable=True)  # Actual execution time
    memory_used_mb = Column(Float, nullable=True)  # Peak memory usage
    cpu_time_ms = Column(Integer, nullable=True)  # CPU time used

    # Docker execution details
    container_id = Column(String(100), nullable=True)  # Docker container ID
    docker_logs = Column(Text, nullable=True)  # Docker container logs

    # Timestamps
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationships
    tool = relationship("Tool", back_populates="executions")
    executed_by = relationship("User", back_populates="tool_executions")

    def __repr__(self):
        return f"<ToolExecution(id={self.id}, tool_id={self.tool_id}, status='{self.status}')>"

    def to_dict(self):
        """Convert execution to dictionary"""
        return {
            "id": self.id,
            "tool_id": self.tool_id,
            "tool_name": self.tool.name if self.tool else None,
            "executed_by_user_id": self.executed_by_user_id,
            "parameters": self.parameters,
            "status": self.status,
            "output": self.output,
            "error_message": self.error_message,
            "return_code": self.return_code,
            "execution_time_ms": self.execution_time_ms,
            "memory_used_mb": self.memory_used_mb,
            "cpu_time_ms": self.cpu_time_ms,
            "container_id": self.container_id,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat()
            if self.completed_at
            else None,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @property
    def duration_seconds(self) -> float:
        """Calculate execution duration in seconds"""
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return 0.0

    def is_running(self) -> bool:
        """Check if execution is currently running"""
        return self.status in [ToolStatus.PENDING, ToolStatus.RUNNING]

    def is_finished(self) -> bool:
        """Check if execution is finished (success or failure)"""
        return self.status in [
            ToolStatus.COMPLETED,
            ToolStatus.FAILED,
            ToolStatus.TIMEOUT,
            ToolStatus.CANCELLED,
        ]


class ToolCategory(Base):
    """Tool category for organization"""

    __tablename__ = "tool_categories"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), unique=True, nullable=False)
    display_name = Column(String(100), nullable=False)
    description = Column(Text, nullable=True)

    # Visual
    icon = Column(String(50), nullable=True)  # Icon name
    color = Column(String(20), nullable=True)  # Color code

    # Ordering
    sort_order = Column(Integer, default=0)

    # Status
    is_active = Column(Boolean, default=True)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<ToolCategory(id={self.id}, name='{self.name}')>"

    def to_dict(self):
        """Convert category to dictionary"""
        return {
            "id": self.id,
            "name": self.name,
            "display_name": self.display_name,
            "description": self.description,
            "icon": self.icon,
            "color": self.color,
            "sort_order": self.sort_order,
            "is_active": self.is_active,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }
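A minimal sketch of the access-control flow around Tool. The `user`, `db`, and sandbox runner names here are assumptions for illustration, not part of this diff.

# Hypothetical gate before executing a tool; the actual sandbox runner is
# wired in elsewhere in the platform.
from app.models.tool import Tool, ToolExecution, ToolStatus

def start_tool_execution(tool: Tool, user, parameters: dict, db) -> ToolExecution:
    if not tool.can_be_used_by(user):
        raise PermissionError(f"user {user.id} may not run tool {tool.name}")
    execution = ToolExecution(
        tool_id=tool.id,
        executed_by_user_id=user.id,
        parameters=parameters,
        status=ToolStatus.RUNNING,
    )
    tool.increment_usage()  # bumps usage_count and last_used_at
    db.add(execution)
    db.commit()
    return execution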
@@ -4,63 +4,73 @@ Usage Tracking model for API key usage statistics

from datetime import datetime
from typing import Optional
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    Boolean,
    Text,
    JSON,
    ForeignKey,
    Float,
)
from sqlalchemy.orm import relationship
from app.db.database import Base


class UsageTracking(Base):
    """Usage tracking model for detailed API key usage statistics"""

    __tablename__ = "usage_tracking"

    id = Column(Integer, primary_key=True, index=True)

    # API Key relationship
    api_key_id = Column(Integer, ForeignKey("api_keys.id"), nullable=False)
    api_key = relationship("APIKey", back_populates="usage_tracking")

    # User relationship
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    user = relationship("User", back_populates="usage_tracking")

    # Budget relationship (optional)
    budget_id = Column(Integer, ForeignKey("budgets.id"), nullable=True)
    budget = relationship("Budget", back_populates="usage_tracking")

    # Request information
    endpoint = Column(String, nullable=False)  # API endpoint used
    method = Column(String, nullable=False)  # HTTP method
    model = Column(String, nullable=True)  # Model used (if applicable)

    # Usage metrics
    request_tokens = Column(Integer, default=0)  # Input tokens
    response_tokens = Column(Integer, default=0)  # Output tokens
    total_tokens = Column(Integer, default=0)  # Total tokens used

    # Cost tracking
    cost_cents = Column(Integer, default=0)  # Cost in cents
    cost_currency = Column(String, default="USD")  # Currency

    # Response information
    response_time_ms = Column(Integer, nullable=True)  # Response time in milliseconds
    status_code = Column(Integer, nullable=True)  # HTTP status code

    # Request metadata
    request_id = Column(String, nullable=True)  # Unique request identifier
    session_id = Column(String, nullable=True)  # Session identifier
    ip_address = Column(String, nullable=True)  # Client IP address
    user_agent = Column(String, nullable=True)  # User agent

    # Additional metadata
    request_metadata = Column(JSON, default=dict)  # Additional request metadata

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        return f"<UsageTracking(id={self.id}, api_key_id={self.api_key_id}, endpoint='{self.endpoint}')>"

    def to_dict(self):
        """Convert usage tracking to dictionary for API responses"""
        return {
@@ -82,9 +92,9 @@ class UsageTracking(Base):
            "ip_address": self.ip_address,
            "user_agent": self.user_agent,
            "request_metadata": self.request_metadata,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @classmethod
    def create_tracking_record(
        cls,
@@ -102,7 +112,7 @@ class UsageTracking(Base):
        session_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        request_metadata: Optional[dict] = None,
    ) -> "UsageTracking":
        """Create a new usage tracking record"""
        return cls(
@@ -121,5 +131,5 @@ class UsageTracking(Base):
            session_id=session_id,
            ip_address=ip_address,
            user_agent=user_agent,
            request_metadata=request_metadata or {},
        )
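For context, a sketch of how a request handler might record usage with this model. The exact keyword list of create_tracking_record sits in lines this diff elides, so the arguments below are assumptions mirrored from the visible columns.

# Hypothetical handler-side usage; `db` is an assumed session.
from app.models.usage_tracking import UsageTracking

record = UsageTracking.create_tracking_record(
    api_key_id=42,
    user_id=7,
    endpoint="/v1/chat/completions",
    method="POST",
    request_tokens=850,
    response_tokens=120,
    total_tokens=970,
    cost_cents=3,
    status_code=200,
)
db.add(record)
db.commit()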
@@ -4,75 +4,119 @@ User model

from datetime import datetime
from typing import Optional, List
from enum import Enum
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    Boolean,
    Text,
    JSON,
    ForeignKey,
    Numeric,
)
from sqlalchemy.orm import relationship
from sqlalchemy import inspect as sa_inspect
from app.db.database import Base
from decimal import Decimal


class User(Base):
    """User model for authentication and user management"""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    email = Column(String, unique=True, index=True, nullable=False)
    username = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    full_name = Column(String, nullable=True)

    # Account status
    is_active = Column(Boolean, default=True)
    is_verified = Column(Boolean, default=False)
    is_superuser = Column(Boolean, default=False)  # Legacy field for compatibility

    # Role-based access control (using new Role model)
    role_id = Column(Integer, ForeignKey("roles.id"), nullable=True)
    custom_permissions = Column(JSON, default=dict)  # Custom permissions override

    # Account management
    account_locked = Column(Boolean, default=False)
    account_locked_until = Column(DateTime, nullable=True)
    failed_login_attempts = Column(Integer, default=0)
    last_failed_login = Column(DateTime, nullable=True)
    force_password_change = Column(Boolean, default=False)

    # Profile information
    avatar_url = Column(String, nullable=True)
    bio = Column(Text, nullable=True)
    company = Column(String, nullable=True)
    website = Column(String, nullable=True)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_login = Column(DateTime, nullable=True)

    # Settings
    preferences = Column(JSON, default=dict)
    notification_settings = Column(JSON, default=dict)

    # Relationships
    role = relationship("Role", back_populates="users")
    api_keys = relationship(
        "APIKey", back_populates="user", cascade="all, delete-orphan"
    )
    usage_tracking = relationship(
        "UsageTracking", back_populates="user", cascade="all, delete-orphan"
    )
    budgets = relationship(
        "Budget", back_populates="user", cascade="all, delete-orphan"
    )
    audit_logs = relationship(
        "AuditLog", back_populates="user", cascade="all, delete-orphan"
    )
    installed_plugins = relationship("Plugin", back_populates="installed_by_user")
    created_tools = relationship(
        "Tool", back_populates="created_by", cascade="all, delete-orphan"
    )
    tool_executions = relationship(
        "ToolExecution", back_populates="executed_by", cascade="all, delete-orphan"
    )
    notifications = relationship(
        "Notification", back_populates="user", cascade="all, delete-orphan"
    )

    def __repr__(self):
        return f"<User(id={self.id}, email='{self.email}', username='{self.username}')>"

    def to_dict(self):
        """Convert user to dictionary for API responses"""
        # Check if role relationship is loaded to avoid lazy loading in async context
        inspector = sa_inspect(self)
        role_loaded = "role" not in inspector.unloaded

        return {
            "id": self.id,
            "email": self.email,
            "username": self.username,
            "full_name": self.full_name,
            "is_active": self.is_active,
            "is_verified": self.is_verified,
            "is_superuser": self.is_superuser,
            "role_id": self.role_id,
            "role": self.role.to_dict() if role_loaded and self.role else None,
            "custom_permissions": self.custom_permissions,
            "account_locked": self.account_locked,
            "account_locked_until": self.account_locked_until.isoformat()
            if self.account_locked_until
            else None,
            "failed_login_attempts": self.failed_login_attempts,
            "last_failed_login": self.last_failed_login.isoformat()
            if self.last_failed_login
            else None,
            "force_password_change": self.force_password_change,
            "avatar_url": self.avatar_url,
            "bio": self.bio,
            "company": self.company,
@@ -81,54 +125,157 @@ class User(Base):
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_login": self.last_login.isoformat() if self.last_login else None,
            "preferences": self.preferences,
            "notification_settings": self.notification_settings,
        }

    def has_permission(self, permission: str) -> bool:
        """Check if user has a specific permission using role hierarchy"""
        if self.is_superuser:
            return True

        # Check custom permissions first (override)
        if permission in self.custom_permissions.get("denied", []):
            return False
        if permission in self.custom_permissions.get("granted", []):
            return True

        # Check role permissions if user has a role assigned
        if self.role:
            return self.role.has_permission(permission)

        return False

    def can_access_module(self, module_name: str) -> bool:
        """Check if user can access a specific module"""
        if self.is_superuser:
            return True

        # Check custom permissions first
        module_permissions = self.custom_permissions.get("modules", {})
        if module_name in module_permissions:
            return module_permissions[module_name]

        # Check role permissions
        if self.role:
            # For admin roles, allow all modules
            if self.role.level in ["admin", "super_admin"]:
                return True
            # For regular users, check module access
            elif self.role.level == "user":
                return True  # Basic users can access all modules
            # For read-only users, limit access
            elif self.role.level == "read_only":
                return module_name in ["chatbot", "analytics"]  # Only certain modules

        return False

    def update_last_login(self):
        """Update the last login timestamp"""
        self.last_login = datetime.utcnow()

    def update_preferences(self, preferences: dict):
        """Update user preferences"""
        if self.preferences is None:
            self.preferences = {}
        self.preferences.update(preferences)

    def update_notification_settings(self, settings: dict):
        """Update notification settings"""
        if self.notification_settings is None:
            self.notification_settings = {}
        self.notification_settings.update(settings)

    def get_effective_permissions(self) -> dict:
        """Get all effective permissions combining role and custom permissions"""
        permissions = {"granted": set(), "denied": set()}

        # Start with role permissions
        if self.role:
            role_perms = self.role.permissions
            permissions["granted"].update(role_perms.get("granted", []))
            permissions["denied"].update(role_perms.get("denied", []))

        # Apply custom permissions (override role permissions)
        permissions["granted"].update(self.custom_permissions.get("granted", []))
        permissions["denied"].update(self.custom_permissions.get("denied", []))

        # Remove any denied permissions from granted
        permissions["granted"] -= permissions["denied"]

        return {
            "granted": list(permissions["granted"]),
            "denied": list(permissions["denied"]),
        }

    def can_create_api_key(self) -> bool:
        """Check if user can create API keys based on role and limits"""
        if not self.is_active or self.account_locked:
            return False

        # Check permission
        if not self.has_permission("create_api_key"):
            return False

        # Check if user has reached their API key limit
        current_keys = [key for key in self.api_keys if key.is_active]
        max_keys = (
            self.role.permissions.get("limits", {}).get("max_api_keys", 5)
            if self.role
            else 5
        )

        return len(current_keys) < max_keys

    def can_create_tool(self) -> bool:
        """Check if user can create custom tools"""
        return (
            self.is_active
            and not self.account_locked
            and self.has_permission("create_tool")
        )

    def is_budget_exceeded(self) -> bool:
        """Check if user has exceeded their budget limits"""
        if not self.budgets:
            return False

        active_budget = next((b for b in self.budgets if b.is_active), None)
        if not active_budget:
            return False

        return active_budget.current_usage > active_budget.limit

    def lock_account(self, duration_hours: int = 24):
        """Lock user account for specified duration"""
        from datetime import timedelta

        self.account_locked = True
        self.account_locked_until = datetime.utcnow() + timedelta(hours=duration_hours)

    def unlock_account(self):
        """Unlock user account"""
        self.account_locked = False
        self.account_locked_until = None
        self.failed_login_attempts = 0

    def record_failed_login(self):
        """Record a failed login attempt"""
        self.failed_login_attempts += 1
        self.last_failed_login = datetime.utcnow()

        # Lock account after 5 failed attempts
        if self.failed_login_attempts >= 5:
            self.lock_account(24)  # Lock for 24 hours

    def reset_failed_logins(self):
        """Reset failed login counter"""
        self.failed_login_attempts = 0
        self.last_failed_login = None

    @classmethod
    def create_default_admin(
        cls, email: str, username: str, password_hash: str
    ) -> "User":
        """Create a default admin user"""
        return cls(
            email=email,
@@ -136,24 +283,16 @@ class User(Base):
            hashed_password=password_hash,
            full_name="System Administrator",
            is_active=True,
            is_superuser=True,  # Legacy compatibility
            is_verified=True,
            # Note: role_id will be set after role is created in init_db
            custom_permissions={
                "modules": {"cache": True, "analytics": True, "rag": True}
            },
            preferences={"theme": "dark", "language": "en", "timezone": "UTC"},
            notification_settings={
                "email_notifications": True,
                "security_alerts": True,
                "system_updates": True,
            },
        )
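A short sketch of how the lockout and permission helpers above compose. This is an in-memory walkthrough that skips the database, so column defaults (which only apply at flush time) are set explicitly.

# Hypothetical walkthrough of the User lockout + permission flow.
from app.models.user import User

u = User(
    email="a@example.com",
    username="a",
    hashed_password="x",
    is_active=True,            # set explicitly: Column defaults apply at flush
    account_locked=False,
    failed_login_attempts=0,
    custom_permissions={"granted": ["create_tool"], "denied": []},
)
for _ in range(5):
    u.record_failed_login()    # the fifth attempt triggers lock_account(24)
assert u.account_locked is True
u.unlock_account()             # clears the lock and the failure counter
assert u.can_create_tool() is True  # active, unlocked, permission granted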
@@ -15,7 +15,4 @@ __version__ = "1.0.0"
__author__ = "Enclava Team"

# Export main classes for easy importing
__all__ = ["ChatbotModule", "create_module"]

(File diff suppressed because it is too large)
@@ -10,7 +10,7 @@ from typing import Dict, Optional, Any
import logging

# Import all modules
from .rag.main import RAGModule
from .chatbot.main import ChatbotModule, create_module as create_chatbot_module
from .workflow.main import WorkflowModule

@@ -19,11 +19,11 @@ from app.services.litellm_client import LiteLLMClient

# Import protocols for type safety
from .protocols import (
    RAGServiceProtocol,
    ChatbotServiceProtocol,
    LiteLLMClientProtocol,
    WorkflowServiceProtocol,
    ServiceRegistry,
)

logger = logging.getLogger(__name__)

@@ -31,113 +31,119 @@ logger = logging.getLogger(__name__)

class ModuleFactory:
    """Factory for creating and wiring module dependencies"""

    def __init__(self):
        self.modules: Dict[str, Any] = {}
        self.initialized = False

    async def create_all_modules(
        self, config: Optional[Dict[str, Any]] = None
    ) -> ServiceRegistry:
        """
        Create all modules with proper dependency injection

        Args:
            config: Optional configuration for modules

        Returns:
            Dictionary of created modules with their dependencies wired
        """
        config = config or {}

        logger.info("Creating modules with dependency injection...")

        # Step 1: Create LiteLLM client (shared dependency)
        litellm_client = LiteLLMClient()

        # Step 2: Create RAG module (no dependencies on other modules)
        rag_module = RAGModule(config=config.get("rag", {}))

        # Step 3: Create chatbot module with RAG dependency
        chatbot_module = create_chatbot_module(
            litellm_client=litellm_client,
            rag_service=rag_module,  # RAG module implements RAGServiceProtocol
        )

        # Step 4: Create workflow module with chatbot dependency
        workflow_module = WorkflowModule(
            chatbot_service=chatbot_module  # Chatbot module implements ChatbotServiceProtocol
        )

        # Store all modules
        modules = {
            "rag": rag_module,
            "chatbot": chatbot_module,
            "workflow": workflow_module,
        }

        logger.info(f"Created {len(modules)} modules with dependencies wired")

        # Initialize all modules
        await self._initialize_modules(modules, config)

        self.modules = modules
        self.initialized = True

        return modules

    async def _initialize_modules(
        self, modules: Dict[str, Any], config: Dict[str, Any]
    ):
        """Initialize all modules in dependency order"""

        # Initialize in dependency order (modules with no deps first)
        initialization_order = [
            ("rag", modules["rag"]),
            ("chatbot", modules["chatbot"]),  # Depends on RAG
            ("workflow", modules["workflow"]),  # Depends on Chatbot
        ]

        for module_name, module in initialization_order:
            try:
                logger.info(f"Initializing {module_name} module...")
                module_config = config.get(module_name, {})

                # Different modules have different initialization patterns
                if hasattr(module, "initialize"):
                    if module_name == "rag":
                        await module.initialize()
                    else:
                        await module.initialize(**module_config)

                logger.info(f"✅ {module_name} module initialized successfully")

            except Exception as e:
                logger.error(f"❌ Failed to initialize {module_name} module: {e}")
                raise RuntimeError(
                    f"Module initialization failed: {module_name}"
                ) from e

    async def cleanup_all_modules(self):
        """Cleanup all modules in reverse dependency order"""
        if not self.initialized:
            return

        # Cleanup in reverse order
        cleanup_order = ["workflow", "chatbot", "rag"]

        for module_name in cleanup_order:
            if module_name in self.modules:
                try:
                    logger.info(f"Cleaning up {module_name} module...")
                    module = self.modules[module_name]
                    if hasattr(module, "cleanup"):
                        await module.cleanup()
                    logger.info(f"✅ {module_name} module cleaned up")
                except Exception as e:
                    logger.error(f"❌ Error cleaning up {module_name}: {e}")

        self.modules.clear()
        self.initialized = False

    def get_module(self, name: str) -> Optional[Any]:
        """Get a module by name"""
        return self.modules.get(name)

    def is_initialized(self) -> bool:
        """Check if factory is initialized"""
        return self.initialized

@@ -174,13 +180,16 @@ def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
    return RAGModule(config=config or {})


def create_chatbot_with_rag(
    rag_service: RAGServiceProtocol, litellm_client: LiteLLMClientProtocol
) -> ChatbotModule:
    """Create chatbot module with RAG dependency"""
    return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)


def create_workflow_with_chatbot(
    chatbot_service: ChatbotServiceProtocol,
) -> WorkflowModule:
    """Create workflow module with chatbot dependency"""
    return WorkflowModule(chatbot_service=chatbot_service)

@@ -188,38 +197,38 @@ def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> Wor

# Module registry for backward compatibility
class ModuleRegistry:
    """Registry that provides access to modules (for backward compatibility)"""

    def __init__(self, factory: ModuleFactory):
        self._factory = factory

    @property
    def modules(self) -> Dict[str, Any]:
        """Get all modules (compatible with existing module_manager interface)"""
        return self._factory.modules

    def get(self, name: str) -> Optional[Any]:
        """Get module by name"""
        return self._factory.get_module(name)

    def __getitem__(self, name: str) -> Any:
        """Support dictionary-style access"""
        module = self.get(name)
        if module is None:
            raise KeyError(f"Module '{name}' not found")
        return module

    def keys(self):
        """Get module names"""
        return self._factory.modules.keys()

    def values(self):
        """Get module instances"""
        return self._factory.modules.values()

    def items(self):
        """Get module name-instance pairs"""
        return self._factory.modules.items()


# Create registry instance for backward compatibility
module_registry = ModuleRegistry(module_factory)
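For orientation, a sketch of how an application startup hook might drive this factory. The framework choice and import path are assumptions; the real `module_factory` singleton is defined in a part of the file this diff suppresses.

# Hypothetical FastAPI lifespan wiring for ModuleFactory (assumed import path).
from contextlib import asynccontextmanager
from fastapi import FastAPI

from app.modules.factory import ModuleFactory

factory = ModuleFactory()

@asynccontextmanager
async def lifespan(app: FastAPI):
    modules = await factory.create_all_modules({"rag": {"top_k": 5}})
    app.state.modules = modules  # rag -> chatbot -> workflow, wired in order
    yield
    await factory.cleanup_all_modules()  # reverse order: workflow, chatbot, rag

app = FastAPI(lifespan=lifespan)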
@@ -12,44 +12,48 @@ from abc import abstractmethod

class RAGServiceProtocol(Protocol):
    """Protocol for RAG (Retrieval-Augmented Generation) service interface"""

    @abstractmethod
    async def search(
        self, query: str, collection_name: str, top_k: int
    ) -> Dict[str, Any]:
        """
        Search for relevant documents

        Args:
            query: Search query string
            collection_name: Name of the collection to search in
            top_k: Number of top results to return

        Returns:
            Dictionary containing search results with 'results' key
        """
        ...

    @abstractmethod
    async def index_document(
        self, content: str, metadata: Dict[str, Any] = None
    ) -> str:
        """
        Index a document in the vector database

        Args:
            content: Document content to index
            metadata: Optional metadata for the document

        Returns:
            Document ID
        """
        ...

    @abstractmethod
    async def delete_document(self, document_id: str) -> bool:
        """
        Delete a document from the vector database

        Args:
            document_id: ID of document to delete

        Returns:
            True if successfully deleted
        """
@@ -58,32 +62,32 @@ class RAGServiceProtocol(Protocol):

class ChatbotServiceProtocol(Protocol):
    """Protocol for Chatbot service interface"""

    @abstractmethod
    async def chat_completion(self, request: Any, user_id: str, db: Any) -> Any:
        """
        Generate chat completion response

        Args:
            request: Chat request object
            user_id: ID of the user making the request
            db: Database session

        Returns:
            Chat response object
        """
        ...

    @abstractmethod
    async def create_chatbot(self, config: Any, user_id: str, db: Any) -> Any:
        """
        Create a new chatbot instance

        Args:
            config: Chatbot configuration
            user_id: ID of the user creating the chatbot
            db: Database session

        Returns:
            Created chatbot instance
        """
@@ -92,35 +96,43 @@ class ChatbotServiceProtocol(Protocol):

class LiteLLMClientProtocol(Protocol):
    """Protocol for LiteLLM client interface"""

    @abstractmethod
    async def completion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        """
        Create a completion using the specified model

        Args:
            model: Model name to use
            messages: List of messages for the conversation
            **kwargs: Additional parameters for the completion

        Returns:
            Completion response object
        """
        ...

    @abstractmethod
    async def create_chat_completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        user_id: str,
        api_key_id: str,
        **kwargs,
    ) -> Any:
        """
        Create a chat completion with user tracking

        Args:
            model: Model name to use
            messages: List of messages for the conversation
            user_id: ID of the user making the request
            api_key_id: API key identifier
            **kwargs: Additional parameters

        Returns:
            Chat completion response
        """
@@ -129,44 +141,44 @@ class LiteLLMClientProtocol(Protocol):

class CacheServiceProtocol(Protocol):
    """Protocol for Cache service interface"""

    @abstractmethod
    async def get(self, key: str, default: Any = None) -> Any:
        """
        Get value from cache

        Args:
            key: Cache key
            default: Default value if key not found

        Returns:
            Cached value or default
        """
        ...

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Set value in cache

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds

        Returns:
            True if successfully cached
        """
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """
        Delete key from cache

        Args:
            key: Cache key to delete

        Returns:
            True if successfully deleted
        """
@@ -175,28 +187,28 @@ class CacheServiceProtocol(Protocol):

class SecurityServiceProtocol(Protocol):
    """Protocol for Security service interface"""

    @abstractmethod
    async def analyze_request(self, request: Any) -> Any:
        """
        Perform security analysis on a request

        Args:
            request: Request object to analyze

        Returns:
            Security analysis result
        """
        ...

    @abstractmethod
    async def validate_request(self, request: Any) -> bool:
        """
        Validate request for security compliance

        Args:
            request: Request object to validate

        Returns:
            True if request is valid/safe
        """
@@ -205,29 +217,31 @@ class SecurityServiceProtocol(Protocol):

class WorkflowServiceProtocol(Protocol):
    """Protocol for Workflow service interface"""

    @abstractmethod
    async def execute_workflow(
        self, workflow: Any, input_data: Dict[str, Any] = None
    ) -> Any:
        """
        Execute a workflow definition

        Args:
            workflow: Workflow definition to execute
            input_data: Optional input data for the workflow

        Returns:
            Workflow execution result
        """
        ...

    @abstractmethod
    async def get_execution(self, execution_id: str) -> Any:
        """
        Get workflow execution status

        Args:
            execution_id: ID of the execution to retrieve

        Returns:
            Execution status object
        """
@@ -236,17 +250,17 @@ class WorkflowServiceProtocol(Protocol):

class ModuleServiceProtocol(Protocol):
    """Base protocol for all module services"""

    @abstractmethod
    async def initialize(self, **kwargs) -> None:
        """Initialize the module"""
        ...

    @abstractmethod
    async def cleanup(self) -> None:
        """Cleanup module resources"""
        ...

    @abstractmethod
    def get_required_permissions(self) -> List[Any]:
        """Get required permissions for this module"""
@@ -255,4 +269,4 @@ class ModuleServiceProtocol(Protocol):

# Type aliases for common service combinations
ServiceRegistry = Dict[str, ModuleServiceProtocol]
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]
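A minimal sketch of a test double satisfying RAGServiceProtocol; since Protocol classes are structural, no inheritance is needed. The class and its storage are entirely hypothetical.

# Hypothetical in-memory stand-in for RAGServiceProtocol.
from typing import Any, Dict
import uuid

class InMemoryRAG:
    def __init__(self):
        self.docs: Dict[str, str] = {}

    async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
        # Naive substring match in place of vector similarity
        hits = [d for d in self.docs.values() if query.lower() in d.lower()]
        return {"results": hits[:top_k]}

    async def index_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
        doc_id = str(uuid.uuid4())
        self.docs[doc_id] = content
        return doc_id

    async def delete_document(self, document_id: str) -> bool:
        return self.docs.pop(document_id, None) is not None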
@@ -3,4 +3,4 @@ RAG (Retrieval-Augmented Generation) module for Confidential Empire platform
"""
from .main import RAGModule

__all__ = ["RAGModule"]

(File diff suppressed because it is too large)
@@ -13,77 +13,102 @@ from pathlib import Path

class PluginRuntimeSpec(BaseModel):
    """Plugin runtime requirements and dependencies"""

    python_version: str = Field("3.11", description="Required Python version")
    dependencies: List[str] = Field(
        default_factory=list, description="Required Python packages"
    )
    environment_variables: Dict[str, str] = Field(
        default_factory=dict, description="Required environment variables"
    )

    @validator("python_version")
    def validate_python_version(cls, v):
        if not v.startswith(("3.9", "3.10", "3.11", "3.12")):
            raise ValueError("Python version must be 3.9, 3.10, 3.11, or 3.12")
        return v


class PluginPermissions(BaseModel):
    """Plugin permission specifications"""

    platform_apis: List[str] = Field(
        default_factory=list, description="Platform API access scopes"
    )
    plugin_scopes: List[str] = Field(
        default_factory=list, description="Plugin-specific permission scopes"
    )
    external_domains: List[str] = Field(
        default_factory=list, description="Allowed external domains"
    )

    @validator("platform_apis")
    def validate_platform_apis(cls, v):
        allowed_apis = [
            "chatbot:invoke",
            "chatbot:manage",
            "chatbot:read",
            "rag:query",
            "rag:manage",
            "rag:read",
            "llm:completion",
            "llm:embeddings",
            "llm:models",
            "workflow:execute",
            "workflow:read",
            "cache:read",
            "cache:write",
        ]
        for api in v:
            if api not in allowed_apis and not api.endswith(":*"):
                raise ValueError(f"Invalid platform API scope: {api}")
        return v


class PluginDatabaseSpec(BaseModel):
    """Plugin database configuration"""

    schema: str = Field(..., description="Database schema name")
    migrations_path: str = Field("./migrations", description="Path to migration files")
    auto_migrate: bool = Field(True, description="Auto-run migrations on startup")

    @validator("schema")
    def validate_schema_name(cls, v):
        if not v.startswith("plugin_"):
            raise ValueError('Database schema must start with "plugin_"')
        if not v.replace("plugin_", "").replace("_", "").isalnum():
            raise ValueError(
                "Schema name must contain only alphanumeric characters and underscores"
            )
        return v


class PluginAPIEndpoint(BaseModel):
    """Plugin API endpoint specification"""

    path: str = Field(..., description="API endpoint path")
    methods: List[str] = Field(default=["GET"], description="Allowed HTTP methods")
    description: str = Field("", description="Endpoint description")
    auth_required: bool = Field(True, description="Whether authentication is required")

    @validator("methods")
    def validate_methods(cls, v):
        allowed_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
        for method in v:
            if method not in allowed_methods:
                raise ValueError(f"Invalid HTTP method: {method}")
        return v

    @validator("path")
    def validate_path(cls, v):
        if not v.startswith("/"):
            raise ValueError('API path must start with "/"')
        return v


class PluginCronJob(BaseModel):
    """Plugin scheduled job specification"""

    name: str = Field(..., description="Job name")
    schedule: str = Field(..., description="Cron expression")
    function: str = Field(..., description="Function to execute")
@@ -91,41 +116,56 @@ class PluginCronJob(BaseModel):
    enabled: bool = Field(True, description="Whether job is enabled by default")
    timeout_seconds: int = Field(300, description="Job timeout in seconds")
    max_retries: int = Field(3, description="Maximum retry attempts")

    @validator("schedule")
    def validate_cron_expression(cls, v):
        # Basic cron validation - should have 5 parts
        parts = v.split()
        if len(parts) != 5:
            raise ValueError(
                "Cron expression must have 5 parts (minute hour day month weekday)"
            )
        return v


class PluginUIConfig(BaseModel):
    """Plugin UI configuration"""

    configuration_schema: str = Field(
        "./config_schema.json", description="JSON schema for configuration"
    )
    ui_components: str = Field("./ui/components", description="Path to UI components")
    pages: List[Dict[str, str]] = Field(
        default_factory=list, description="Plugin pages"
    )

    @validator("pages")
    def validate_pages(cls, v):
        required_fields = ["name", "path", "component"]
        for page in v:
            for field in required_fields:
                if field not in page:
                    raise ValueError(f"Page must have {field} field")
        return v


class PluginExternalServices(BaseModel):
    """Plugin external service configuration"""

    allowed_domains: List[str] = Field(
        default_factory=list, description="Allowed external domains"
    )
    webhooks: List[Dict[str, str]] = Field(
        default_factory=list, description="Webhook configurations"
    )
    rate_limits: Dict[str, int] = Field(
        default_factory=dict, description="Rate limits per domain"
    )


class PluginMetadata(BaseModel):
    """Plugin metadata information"""

    name: str = Field(..., description="Plugin name (must be unique)")
    version: str = Field(..., description="Plugin version (semantic versioning)")
    description: str = Field(..., description="Plugin description")
@@ -133,58 +173,78 @@ class PluginMetadata(BaseModel):
    license: str = Field("MIT", description="Plugin license")
    homepage: Optional[HttpUrl] = Field(None, description="Plugin homepage URL")
    repository: Optional[HttpUrl] = Field(None, description="Plugin repository URL")
    tags: List[str] = Field(
        default_factory=list, description="Plugin tags for discovery"
    )

    @validator("name")
    def validate_name(cls, v):
        if not v.replace("-", "").replace("_", "").isalnum():
            raise ValueError(
                "Plugin name must contain only alphanumeric characters, hyphens, and underscores"
            )
        if len(v) < 3 or len(v) > 50:
            raise ValueError("Plugin name must be between 3 and 50 characters")
        return v.lower()

    @validator("version")
    def validate_version(cls, v):
        # Basic semantic versioning validation
        parts = v.split(".")
        if len(parts) != 3:
            raise ValueError("Version must follow semantic versioning (x.y.z)")
        for part in parts:
            if not part.isdigit():
                raise ValueError("Version parts must be numeric")
        return v


class PluginManifest(BaseModel):
    """Complete plugin manifest specification"""

    apiVersion: str = Field("v1", description="Manifest API version")
    kind: str = Field("Plugin", description="Resource kind")
    metadata: PluginMetadata = Field(..., description="Plugin metadata")
    spec: "PluginSpec" = Field(..., description="Plugin specification")

    @validator("apiVersion")
    def validate_api_version(cls, v):
        if v not in ["v1"]:
            raise ValueError("Unsupported API version")
        return v

    @validator("kind")
    def validate_kind(cls, v):
        if v != "Plugin":
            raise ValueError('Kind must be "Plugin"')
        return v


class PluginSpec(BaseModel):
    """Plugin specification details"""

    runtime: PluginRuntimeSpec = Field(
        default_factory=PluginRuntimeSpec, description="Runtime requirements"
    )
    permissions: PluginPermissions = Field(
        default_factory=PluginPermissions, description="Permission requirements"
    )
    database: Optional[PluginDatabaseSpec] = Field(
        None, description="Database configuration"
    )
    api_endpoints: List[PluginAPIEndpoint] = Field(
        default_factory=list, description="API endpoints"
    )
    cron_jobs: List[PluginCronJob] = Field(
        default_factory=list, description="Scheduled jobs"
    )
    ui_config: Optional[PluginUIConfig] = Field(None, description="UI configuration")
    external_services: Optional[PluginExternalServices] = Field(
        None, description="External service configuration"
    )
    config_schema: Dict[str, Any] = Field(
        default_factory=dict, description="Plugin configuration JSON schema"
    )
|
||||
external_services: Optional[PluginExternalServices] = Field(
|
||||
None, description="External service configuration"
|
||||
)
|
||||
config_schema: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Plugin configuration JSON schema"
|
||||
)
|
||||
|
||||
|
||||
# Update forward reference
|
||||
@@ -193,111 +253,108 @@ PluginManifest.model_rebuild()
|
||||
|
||||
class PluginManifestValidator:
|
||||
"""Plugin manifest validation and parsing utilities"""
|
||||
|
||||
REQUIRED_FILES = [
|
||||
'manifest.yaml',
|
||||
'main.py',
|
||||
'requirements.txt'
|
||||
]
|
||||
|
||||
|
||||
REQUIRED_FILES = ["manifest.yaml", "main.py", "requirements.txt"]
|
||||
|
||||
OPTIONAL_FILES = [
|
||||
'config_schema.json',
|
||||
'README.md',
|
||||
'ui/components',
|
||||
'migrations',
|
||||
'tests'
|
||||
"config_schema.json",
|
||||
"README.md",
|
||||
"ui/components",
|
||||
"migrations",
|
||||
"tests",
|
||||
]
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, manifest_path: Union[str, Path]) -> PluginManifest:
|
||||
"""Load and validate plugin manifest from YAML file"""
|
||||
manifest_path = Path(manifest_path)
|
||||
|
||||
|
||||
if not manifest_path.exists():
|
||||
raise FileNotFoundError(f"Manifest file not found: {manifest_path}")
|
||||
|
||||
|
||||
try:
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest_data = yaml.safe_load(f)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML in manifest file: {e}")
|
||||
|
||||
|
||||
try:
|
||||
manifest = PluginManifest(**manifest_data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid manifest structure: {e}")
|
||||
|
||||
|
||||
# Additional validation
|
||||
cls._validate_plugin_structure(manifest_path.parent, manifest)
|
||||
|
||||
|
||||
return manifest
|
||||
|
||||
|
||||
@classmethod
|
||||
def _validate_plugin_structure(cls, plugin_dir: Path, manifest: PluginManifest):
|
||||
"""Validate plugin directory structure and required files"""
|
||||
|
||||
|
||||
# Check required files
|
||||
for required_file in cls.REQUIRED_FILES:
|
||||
file_path = plugin_dir / required_file
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Required file missing: {required_file}")
|
||||
|
||||
|
||||
# Validate main.py contains plugin class
|
||||
main_py_path = plugin_dir / 'main.py'
|
||||
with open(main_py_path, 'r', encoding='utf-8') as f:
|
||||
main_py_path = plugin_dir / "main.py"
|
||||
with open(main_py_path, "r", encoding="utf-8") as f:
|
||||
main_content = f.read()
|
||||
|
||||
if 'BasePlugin' not in main_content:
|
||||
|
||||
if "BasePlugin" not in main_content:
|
||||
raise ValueError("main.py must contain a class inheriting from BasePlugin")
|
||||
|
||||
|
||||
# Validate requirements.txt format
|
||||
requirements_path = plugin_dir / 'requirements.txt'
|
||||
with open(requirements_path, 'r', encoding='utf-8') as f:
|
||||
requirements_path = plugin_dir / "requirements.txt"
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
requirements = f.read().strip()
|
||||
|
||||
if requirements and not all(line.strip() for line in requirements.split('\n')):
|
||||
|
||||
if requirements and not all(line.strip() for line in requirements.split("\n")):
|
||||
raise ValueError("Invalid requirements.txt format")
|
||||
|
||||
|
||||
# Validate config schema if specified
|
||||
if manifest.spec.ui_config and manifest.spec.ui_config.configuration_schema:
|
||||
schema_path = plugin_dir / manifest.spec.ui_config.configuration_schema
|
||||
if schema_path.exists():
|
||||
try:
|
||||
import json
|
||||
with open(schema_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(schema_path, "r", encoding="utf-8") as f:
|
||||
json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e}")
|
||||
|
||||
|
||||
# Validate migrations if database is specified
|
||||
if manifest.spec.database:
|
||||
migrations_path = plugin_dir / manifest.spec.database.migrations_path
|
||||
if migrations_path.exists() and not migrations_path.is_dir():
|
||||
raise ValueError("Migrations path must be a directory")
|
||||
|
||||
|
||||
@classmethod
|
||||
def validate_plugin_compatibility(cls, manifest: PluginManifest) -> Dict[str, Any]:
|
||||
"""Validate plugin compatibility with platform"""
|
||||
|
||||
|
||||
compatibility_report = {
|
||||
"compatible": True,
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
"platform_version": "1.0.0"
|
||||
"platform_version": "1.0.0",
|
||||
}
|
||||
|
||||
|
||||
# Check platform API compatibility
|
||||
unsupported_apis = []
|
||||
for api in manifest.spec.permissions.platform_apis:
|
||||
if not cls._is_platform_api_supported(api):
|
||||
unsupported_apis.append(api)
|
||||
|
||||
|
||||
if unsupported_apis:
|
||||
compatibility_report["errors"].append(
|
||||
f"Unsupported platform APIs: {', '.join(unsupported_apis)}"
|
||||
)
|
||||
compatibility_report["compatible"] = False
|
||||
|
||||
|
||||
# Check Python version compatibility
|
||||
required_version = manifest.spec.runtime.python_version
|
||||
if not cls._is_python_version_supported(required_version):
|
||||
@@ -305,63 +362,82 @@ class PluginManifestValidator:
|
||||
f"Unsupported Python version: {required_version}"
|
||||
)
|
||||
compatibility_report["compatible"] = False
|
||||
|
||||
|
||||
# Check dependency compatibility
|
||||
for dependency in manifest.spec.runtime.dependencies:
|
||||
if cls._is_dependency_conflicting(dependency):
|
||||
compatibility_report["warnings"].append(
|
||||
f"Potential dependency conflict: {dependency}"
|
||||
)
|
||||
|
||||
|
||||
return compatibility_report
|
||||
|
||||
|
||||
@classmethod
|
||||
def _is_platform_api_supported(cls, api: str) -> bool:
|
||||
"""Check if platform API is supported"""
|
||||
supported_apis = [
|
||||
'chatbot:invoke', 'chatbot:manage', 'chatbot:read',
|
||||
'rag:query', 'rag:manage', 'rag:read',
|
||||
'llm:completion', 'llm:embeddings', 'llm:models',
|
||||
'workflow:execute', 'workflow:read',
|
||||
'cache:read', 'cache:write'
|
||||
"chatbot:invoke",
|
||||
"chatbot:manage",
|
||||
"chatbot:read",
|
||||
"rag:query",
|
||||
"rag:manage",
|
||||
"rag:read",
|
||||
"llm:completion",
|
||||
"llm:embeddings",
|
||||
"llm:models",
|
||||
"workflow:execute",
|
||||
"workflow:read",
|
||||
"cache:read",
|
||||
"cache:write",
|
||||
]
|
||||
|
||||
|
||||
# Support wildcard permissions
|
||||
if api.endswith(':*'):
|
||||
if api.endswith(":*"):
|
||||
base_api = api[:-2]
|
||||
return any(supported.startswith(base_api + ':') for supported in supported_apis)
|
||||
|
||||
return any(
|
||||
supported.startswith(base_api + ":") for supported in supported_apis
|
||||
)
|
||||
|
||||
return api in supported_apis
|
||||
|
||||
|
||||
@classmethod
|
||||
def _is_python_version_supported(cls, version: str) -> bool:
|
||||
"""Check if Python version is supported"""
|
||||
supported_versions = ['3.9', '3.10', '3.11', '3.12']
|
||||
supported_versions = ["3.9", "3.10", "3.11", "3.12"]
|
||||
return any(version.startswith(v) for v in supported_versions)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _is_dependency_conflicting(cls, dependency: str) -> bool:
|
||||
"""Check if dependency might conflict with platform"""
|
||||
# Extract package name (before ==, >=, etc.)
|
||||
package_name = dependency.split('==')[0].split('>=')[0].split('<=')[0].split('>')[0].split('<')[0].strip()
|
||||
|
||||
package_name = (
|
||||
dependency.split("==")[0]
|
||||
.split(">=")[0]
|
||||
.split("<=")[0]
|
||||
.split(">")[0]
|
||||
.split("<")[0]
|
||||
.strip()
|
||||
)
|
||||
|
||||
# Known conflicting packages
|
||||
conflicting_packages = [
|
||||
'sqlalchemy', # Platform uses specific version
|
||||
'fastapi', # Platform uses specific version
|
||||
'pydantic', # Platform uses specific version
|
||||
'alembic' # Platform migration system
|
||||
"sqlalchemy", # Platform uses specific version
|
||||
"fastapi", # Platform uses specific version
|
||||
"pydantic", # Platform uses specific version
|
||||
"alembic", # Platform migration system
|
||||
]
|
||||
|
||||
|
||||
return package_name.lower() in conflicting_packages
|
||||
|
||||
|
||||
@classmethod
|
||||
def generate_manifest_hash(cls, manifest: PluginManifest) -> str:
|
||||
"""Generate hash for manifest content verification"""
|
||||
manifest_dict = manifest.dict()
|
||||
manifest_str = yaml.dump(manifest_dict, sort_keys=True, default_flow_style=False)
|
||||
return hashlib.sha256(manifest_str.encode('utf-8')).hexdigest()
|
||||
|
||||
manifest_str = yaml.dump(
|
||||
manifest_dict, sort_keys=True, default_flow_style=False
|
||||
)
|
||||
return hashlib.sha256(manifest_str.encode("utf-8")).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def create_example_manifest(cls, plugin_name: str) -> PluginManifest:
|
||||
"""Create an example plugin manifest for development"""
|
||||
@@ -372,29 +448,25 @@ class PluginManifestValidator:
|
||||
description=f"Example {plugin_name} plugin for Enclava platform",
|
||||
author="Enclava Team",
|
||||
license="MIT",
|
||||
tags=["integration", "example"]
|
||||
tags=["integration", "example"],
|
||||
),
|
||||
spec=PluginSpec(
|
||||
runtime=PluginRuntimeSpec(
|
||||
python_version="3.11",
|
||||
dependencies=[
|
||||
"aiohttp>=3.8.0",
|
||||
"pydantic>=2.0.0"
|
||||
]
|
||||
dependencies=["aiohttp>=3.8.0", "pydantic>=2.0.0"],
|
||||
),
|
||||
permissions=PluginPermissions(
|
||||
platform_apis=["chatbot:invoke", "rag:query"],
|
||||
plugin_scopes=["read", "write"]
|
||||
plugin_scopes=["read", "write"],
|
||||
),
|
||||
database=PluginDatabaseSpec(
|
||||
schema=f"plugin_{plugin_name}",
|
||||
migrations_path="./migrations"
|
||||
schema=f"plugin_{plugin_name}", migrations_path="./migrations"
|
||||
),
|
||||
api_endpoints=[
|
||||
PluginAPIEndpoint(
|
||||
path="/status",
|
||||
methods=["GET"],
|
||||
description="Plugin health status"
|
||||
description="Plugin health status",
|
||||
)
|
||||
],
|
||||
ui_config=PluginUIConfig(
|
||||
@@ -403,11 +475,11 @@ class PluginManifestValidator:
|
||||
{
|
||||
"name": "dashboard",
|
||||
"path": f"/plugins/{plugin_name}",
|
||||
"component": f"{plugin_name.title()}Dashboard"
|
||||
"component": f"{plugin_name.title()}Dashboard",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -417,20 +489,20 @@ def validate_manifest_file(manifest_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
manifest = PluginManifestValidator.load_from_file(manifest_path)
|
||||
compatibility = PluginManifestValidator.validate_plugin_compatibility(manifest)
|
||||
manifest_hash = PluginManifestValidator.generate_manifest_hash(manifest)
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"manifest": manifest,
|
||||
"compatibility": compatibility,
|
||||
"hash": manifest_hash,
|
||||
"errors": []
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"manifest": None,
|
||||
"compatibility": None,
|
||||
"hash": None,
|
||||
"errors": [str(e)]
|
||||
}
|
||||
"errors": [str(e)],
|
||||
}
|
||||
|
||||
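
Usage sketch for the validator refactored above. The import path is an
assumption; validate_manifest_file and its return keys come from the diff:

    from app.schemas.plugin_manifest import validate_manifest_file  # path assumed

    report = validate_manifest_file("plugins/my-plugin/manifest.yaml")
    if report["valid"]:
        print("hash:", report["hash"])
        print("compatible:", report["compatibility"]["compatible"])
    else:
        print("errors:", report["errors"])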
367
backend/app/schemas/role.py
Normal file
@@ -0,0 +1,367 @@
"""
Role Management Schemas
Pydantic models for role management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, validator


class RoleBase(BaseModel):
    """Base role schema"""

    name: str
    display_name: str
    description: Optional[str] = None
    level: str = "user"
    permissions: Dict[str, Any] = {}
    can_manage_users: bool = False
    can_manage_budgets: bool = False
    can_view_reports: bool = False
    can_manage_tools: bool = False
    inherits_from: List[str] = []
    is_active: bool = True

    @validator("name")
    def validate_name(cls, v):
        if len(v) < 2:
            raise ValueError("Role name must be at least 2 characters long")
        if len(v) > 50:
            raise ValueError("Role name must be less than 50 characters long")
        if not v.replace("_", "").isalnum():
            raise ValueError(
                "Role name must contain only alphanumeric characters and underscores"
            )
        return v.lower()

    @validator("display_name")
    def validate_display_name(cls, v):
        if len(v) < 2:
            raise ValueError("Display name must be at least 2 characters long")
        if len(v) > 100:
            raise ValueError("Display name must be less than 100 characters long")
        return v

    @validator("level")
    def validate_level(cls, v):
        valid_levels = ["read_only", "user", "admin", "super_admin"]
        if v not in valid_levels:
            raise ValueError(f'Level must be one of: {", ".join(valid_levels)}')
        return v


class RoleCreate(RoleBase):
    """Schema for creating a role"""

    is_system_role: bool = False


class RoleUpdate(BaseModel):
    """Schema for updating a role"""

    display_name: Optional[str] = None
    description: Optional[str] = None
    permissions: Optional[Dict[str, Any]] = None
    can_manage_users: Optional[bool] = None
    can_manage_budgets: Optional[bool] = None
    can_view_reports: Optional[bool] = None
    can_manage_tools: Optional[bool] = None
    is_active: Optional[bool] = None

    @validator("display_name")
    def validate_display_name(cls, v):
        if v is not None:
            if len(v) < 2:
                raise ValueError("Display name must be at least 2 characters long")
            if len(v) > 100:
                raise ValueError("Display name must be less than 100 characters long")
        return v


class RoleResponse(BaseModel):
    """Role response schema"""

    id: int
    name: str
    display_name: str
    description: Optional[str]
    level: str
    permissions: Dict[str, Any]
    can_manage_users: bool
    can_manage_budgets: bool
    can_view_reports: bool
    can_manage_tools: bool
    inherits_from: List[str]
    is_active: bool
    is_system_role: bool
    created_at: Optional[datetime]
    updated_at: Optional[datetime]
    user_count: Optional[int] = 0  # Number of users with this role

    class Config:
        from_attributes = True

    @classmethod
    def from_orm(cls, obj):
        """Create response from ORM object"""
        data = obj.to_dict()
        # Add user count if available
        if hasattr(obj, "users"):
            data["user_count"] = len([u for u in obj.users if u.is_active])
        return cls(**data)


class RoleListResponse(BaseModel):
    """Role list response schema"""

    roles: List[RoleResponse]
    total: int
    skip: int
    limit: int


class RoleAssignmentRequest(BaseModel):
    """Schema for role assignment"""

    role_id: int

    @validator("role_id")
    def validate_role_id(cls, v):
        if v <= 0:
            raise ValueError("Role ID must be a positive integer")
        return v


class RoleBulkAction(BaseModel):
    """Schema for bulk role actions"""

    role_ids: List[int]
    action: str  # activate, deactivate, delete
    action_data: Optional[Dict[str, Any]] = None

    @validator("action")
    def validate_action(cls, v):
        valid_actions = ["activate", "deactivate", "delete"]
        if v not in valid_actions:
            raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
        return v

    @validator("role_ids")
    def validate_role_ids(cls, v):
        if not v:
            raise ValueError("At least one role ID must be provided")
        if len(v) > 50:
            raise ValueError("Cannot perform bulk action on more than 50 roles at once")
        return v


class RoleStatistics(BaseModel):
    """Role statistics schema"""

    total_roles: int
    active_roles: int
    system_roles: int
    roles_by_level: Dict[str, int]
    roles_with_users: int
    unused_roles: int


class RolePermission(BaseModel):
    """Individual role permission schema"""

    permission: str
    granted: bool = True
    description: Optional[str] = None
class RolePermissionTemplate(BaseModel):
    """Role permission template schema"""

    name: str
    display_name: str
    description: str
    level: str
    permissions: List[RolePermission]
    can_manage_users: bool = False
    can_manage_budgets: bool = False
    can_view_reports: bool = False
    can_manage_tools: bool = False
    # Declared so ROLE_TEMPLATES below can set it; the original schema omitted
    # this field even though the templates pass it, so it was silently dropped.
    inherits_from: List[str] = []


class RoleHierarchy(BaseModel):
    """Role hierarchy schema"""

    role: RoleResponse
    parent_roles: List[RoleResponse] = []
    child_roles: List[RoleResponse] = []
    effective_permissions: Dict[str, Any]


class RoleComparison(BaseModel):
    """Role comparison schema"""

    role1: RoleResponse
    role2: RoleResponse
    common_permissions: List[str]
    unique_to_role1: List[str]
    unique_to_role2: List[str]


class RoleUsage(BaseModel):
    """Role usage statistics schema"""

    role_id: int
    role_name: str
    user_count: int
    active_user_count: int
    last_assigned: Optional[datetime]
    usage_trend: Dict[str, int]  # Monthly usage data


class RoleSearchFilter(BaseModel):
    """Role search filter schema"""

    search: Optional[str] = None
    level: Optional[str] = None
    is_active: Optional[bool] = None
    is_system_role: Optional[bool] = None
    has_users: Optional[bool] = None
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None


# Predefined permission templates
ROLE_TEMPLATES = {
    "read_only": RolePermissionTemplate(
        name="read_only",
        display_name="Read Only",
        description="Can view own data only",
        level="read_only",
        permissions=[
            RolePermission(
                permission="read_own", granted=True, description="Read own data"
            ),
            RolePermission(
                permission="create", granted=False, description="Create new resources"
            ),
            RolePermission(
                permission="update",
                granted=False,
                description="Update existing resources",
            ),
            RolePermission(
                permission="delete", granted=False, description="Delete resources"
            ),
        ],
    ),
    "user": RolePermissionTemplate(
        name="user",
        display_name="User",
        description="Standard user with full access to own resources",
        level="user",
        permissions=[
            RolePermission(
                permission="read_own", granted=True, description="Read own data"
            ),
            RolePermission(
                permission="create_own",
                granted=True,
                description="Create own resources",
            ),
            RolePermission(
                permission="update_own",
                granted=True,
                description="Update own resources",
            ),
            RolePermission(
                permission="delete_own",
                granted=True,
                description="Delete own resources",
            ),
            RolePermission(
                permission="manage_users",
                granted=False,
                description="Manage other users",
            ),
            RolePermission(
                permission="manage_all",
                granted=False,
                description="Manage all resources",
            ),
        ],
        inherits_from=["read_only"],
    ),
    "admin": RolePermissionTemplate(
        name="admin",
        display_name="Administrator",
        description="Can manage users and view reports",
        level="admin",
        permissions=[
            RolePermission(
                permission="read_all", granted=True, description="Read all data"
            ),
            RolePermission(
                permission="create_all",
                granted=True,
                description="Create any resources",
            ),
            RolePermission(
                permission="update_all",
                granted=True,
                description="Update any resources",
            ),
            RolePermission(
                permission="manage_users", granted=True, description="Manage users"
            ),
            RolePermission(
                permission="view_reports",
                granted=True,
                description="View system reports",
            ),
            RolePermission(
                permission="system_settings",
                granted=False,
                description="Modify system settings",
            ),
        ],
        can_manage_users=True,
        can_view_reports=True,
        inherits_from=["user"],
    ),
    "super_admin": RolePermissionTemplate(
        name="super_admin",
        display_name="Super Administrator",
        description="Full system access",
        level="super_admin",
        permissions=[
            RolePermission(permission="*", granted=True, description="All permissions")
        ],
        can_manage_users=True,
        can_manage_budgets=True,
        can_view_reports=True,
        can_manage_tools=True,
        inherits_from=["admin"],
    ),
}


# Common permission definitions
COMMON_PERMISSIONS = {
    "read_own": "Read own data and resources",
    "read_all": "Read all data and resources",
    "create_own": "Create own resources",
    "create_all": "Create any resources",
    "update_own": "Update own resources",
    "update_all": "Update any resources",
    "delete_own": "Delete own resources",
    "delete_all": "Delete any resources",
    "manage_users": "Manage user accounts",
    "manage_roles": "Manage role assignments",
    "manage_budgets": "Manage budget settings",
    "view_reports": "View system reports",
    "manage_tools": "Manage custom tools",
    "system_settings": "Modify system settings",
    "create_api_key": "Create API keys",
    "create_tool": "Create custom tools",
    "export_data": "Export system data",
}
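
A minimal sketch of seeding a role from one of the predefined templates above
(the import path is assumed; the schema and template names come from the diff):

    from app.schemas.role import ROLE_TEMPLATES, RoleCreate  # path assumed

    template = ROLE_TEMPLATES["admin"]
    role = RoleCreate(
        name=template.name,
        display_name=template.display_name,
        description=template.description,
        level=template.level,
        permissions={p.permission: p.granted for p in template.permissions},
        can_manage_users=template.can_manage_users,
        can_view_reports=template.can_view_reports,
    )
    print(role.name)  # "admin" -- the name validator lower-cases input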
219
backend/app/schemas/tool.py
Normal file
@@ -0,0 +1,219 @@
"""
Tool schemas for API requests and responses
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, validator

# Tool Creation and Update Schemas


class ToolCreate(BaseModel):
    """Schema for creating a new tool"""

    name: str = Field(..., min_length=1, max_length=100, description="Unique tool name")
    display_name: str = Field(
        ..., min_length=1, max_length=200, description="Display name for the tool"
    )
    description: Optional[str] = Field(None, description="Tool description")
    tool_type: str = Field(..., description="Tool type (python, bash, docker, api)")
    code: str = Field(..., min_length=1, description="Tool implementation code")
    parameters_schema: Optional[Dict[str, Any]] = Field(
        None, description="JSON schema for parameters"
    )
    return_schema: Optional[Dict[str, Any]] = Field(
        None, description="Expected return format schema"
    )
    timeout_seconds: Optional[int] = Field(
        30, ge=1, le=300, description="Execution timeout in seconds"
    )
    max_memory_mb: Optional[int] = Field(
        256, ge=1, le=1024, description="Maximum memory in MB"
    )
    max_cpu_seconds: Optional[float] = Field(
        10.0, ge=0.1, le=60.0, description="Maximum CPU time in seconds"
    )
    docker_image: Optional[str] = Field(
        None, max_length=200, description="Docker image for execution"
    )
    docker_command: Optional[str] = Field(None, description="Docker command to run")
    category: Optional[str] = Field(None, max_length=50, description="Tool category")
    tags: Optional[List[str]] = Field(None, description="Tool tags")
    is_public: Optional[bool] = Field(False, description="Whether tool is public")

    @validator("tool_type")
    def validate_tool_type(cls, v):
        valid_types = ["python", "bash", "docker", "api", "custom"]
        if v not in valid_types:
            raise ValueError(f"Tool type must be one of: {valid_types}")
        return v

    @validator("tags")
    def validate_tags(cls, v):
        if v is not None and len(v) > 10:
            raise ValueError("Maximum 10 tags allowed")
        return v


class ToolUpdate(BaseModel):
    """Schema for updating a tool"""

    display_name: Optional[str] = Field(None, min_length=1, max_length=200)
    description: Optional[str] = None
    code: Optional[str] = Field(None, min_length=1)
    parameters_schema: Optional[Dict[str, Any]] = None
    return_schema: Optional[Dict[str, Any]] = None
    timeout_seconds: Optional[int] = Field(None, ge=1, le=300)
    max_memory_mb: Optional[int] = Field(None, ge=1, le=1024)
    max_cpu_seconds: Optional[float] = Field(None, ge=0.1, le=60.0)
    docker_image: Optional[str] = Field(None, max_length=200)
    docker_command: Optional[str] = None
    category: Optional[str] = Field(None, max_length=50)
    tags: Optional[List[str]] = None
    is_public: Optional[bool] = None
    is_active: Optional[bool] = None

    @validator("tags")
    def validate_tags(cls, v):
        if v is not None and len(v) > 10:
            raise ValueError("Maximum 10 tags allowed")
        return v


# Tool Response Schemas


class ToolResponse(BaseModel):
    """Schema for tool response"""

    id: int
    name: str
    display_name: str
    description: Optional[str]
    tool_type: str
    parameters_schema: Dict[str, Any]
    return_schema: Dict[str, Any]
    timeout_seconds: int
    max_memory_mb: int
    max_cpu_seconds: float
    docker_image: Optional[str]
    is_public: bool
    is_approved: bool
    created_by_user_id: int
    category: Optional[str]
    tags: List[str]
    usage_count: int
    last_used_at: Optional[datetime]
    is_active: bool
    created_at: Optional[datetime]
    updated_at: Optional[datetime]

    class Config:
        from_attributes = True


class ToolListResponse(BaseModel):
    """Schema for tool list response"""

    tools: List[ToolResponse]
    total: int
    skip: int
    limit: int


# Tool Execution Schemas


class ToolExecutionCreate(BaseModel):
    """Schema for creating a tool execution"""

    parameters: Dict[str, Any] = Field(..., description="Parameters for tool execution")
    timeout_override: Optional[int] = Field(
        None, ge=1, le=300, description="Override default timeout"
    )


class ToolExecutionResponse(BaseModel):
    """Schema for tool execution response"""

    id: int
    tool_id: int
    tool_name: Optional[str]
    executed_by_user_id: int
    parameters: Dict[str, Any]
    status: str
    output: Optional[str]
    error_message: Optional[str]
    return_code: Optional[int]
    execution_time_ms: Optional[int]
    memory_used_mb: Optional[float]
    cpu_time_ms: Optional[int]
    container_id: Optional[str]
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    created_at: Optional[datetime]

    class Config:
        from_attributes = True


class ToolExecutionListResponse(BaseModel):
    """Schema for tool execution list response"""

    executions: List[ToolExecutionResponse]
    total: int
    skip: int
    limit: int


# Tool Category Schemas


class ToolCategoryCreate(BaseModel):
    """Schema for creating a tool category"""

    name: str = Field(
        ..., min_length=1, max_length=50, description="Unique category name"
    )
    display_name: str = Field(
        ..., min_length=1, max_length=100, description="Display name"
    )
    description: Optional[str] = Field(None, description="Category description")
    icon: Optional[str] = Field(None, max_length=50, description="Icon name")
    color: Optional[str] = Field(None, max_length=20, description="Color code")
    sort_order: Optional[int] = Field(0, description="Sort order")


class ToolCategoryResponse(BaseModel):
    """Schema for tool category response"""

    id: int
    name: str
    display_name: str
    description: Optional[str]
    icon: Optional[str]
    color: Optional[str]
    sort_order: int
    is_active: bool
    created_at: Optional[datetime]
    updated_at: Optional[datetime]

    class Config:
        from_attributes = True


# Statistics Schema


class ToolStatisticsResponse(BaseModel):
    """Schema for tool statistics response"""

    total_tools: int
    public_tools: int
    tools_by_type: Dict[str, int]
    total_executions: int
    executions_by_status: Dict[str, int]
    recent_executions: int
    top_tools: List[Dict[str, Any]]
    user_tools: Optional[int] = None
    user_executions: Optional[int] = None
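
A quick sketch of the ToolCreate validators in action (import path assumed;
field names and constraints are from the schema above):

    from pydantic import ValidationError
    from app.schemas.tool import ToolCreate  # path assumed

    try:
        tool = ToolCreate(
            name="word-count",
            display_name="Word Count",
            tool_type="python",
            code="def run(text): return len(text.split())",
            tags=["text", "utility"],
        )
    except ValidationError as e:
        print(e)  # raised for an unknown tool_type or more than 10 tags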
68
backend/app/schemas/tool_calling.py
Normal file
@@ -0,0 +1,68 @@
"""
Tool calling schemas for API requests and responses
"""
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field


class ToolExecutionRequest(BaseModel):
    """Schema for executing a tool by name"""

    tool_name: str = Field(..., description="Name of the tool to execute")
    parameters: Dict[str, Any] = Field(
        default_factory=dict, description="Parameters for tool execution"
    )


class ToolCallResponse(BaseModel):
    """Schema for tool call response"""

    success: bool = Field(..., description="Whether the tool call was successful")
    result: Optional[Dict[str, Any]] = Field(None, description="Tool execution result")
    error: Optional[str] = Field(None, description="Error message if failed")


class ToolValidationRequest(BaseModel):
    """Schema for validating tool availability"""

    tool_names: List[str] = Field(..., description="List of tool names to validate")


class ToolValidationResponse(BaseModel):
    """Schema for tool validation response"""

    tool_availability: Dict[str, bool] = Field(
        ..., description="Tool name to availability mapping"
    )


class ToolHistoryItem(BaseModel):
    """Schema for tool execution history item"""

    id: int = Field(..., description="Execution ID")
    tool_name: str = Field(..., description="Tool name")
    parameters: Dict[str, Any] = Field(..., description="Execution parameters")
    status: str = Field(..., description="Execution status")
    output: Optional[str] = Field(None, description="Tool output")
    error_message: Optional[str] = Field(None, description="Error message if failed")
    execution_time_ms: Optional[int] = Field(
        None, description="Execution time in milliseconds"
    )
    created_at: Optional[str] = Field(None, description="Creation timestamp")
    completed_at: Optional[str] = Field(None, description="Completion timestamp")


class ToolHistoryResponse(BaseModel):
    """Schema for tool execution history response"""

    history: List[ToolHistoryItem] = Field(..., description="Tool execution history")
    total: int = Field(..., description="Total number of history items")


class ToolCallRequest(BaseModel):
    """Schema for tool call request (placeholder for future use)"""

    message: str = Field(..., description="Chat message")
    tools: Optional[List[str]] = Field(
        None, description="Specific tools to make available"
    )
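
A request/response round trip for these schemas -- a sketch, with the import
path assumed and illustrative values:

    from app.schemas.tool_calling import ToolExecutionRequest, ToolCallResponse

    req = ToolExecutionRequest(tool_name="word-count", parameters={"text": "a b c"})
    resp = ToolCallResponse(success=True, result={"count": 3})
    print(req.dict(), resp.dict())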
260
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,260 @@
"""
User Management Schemas
Pydantic models for user management API
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, EmailStr, validator


class UserBase(BaseModel):
    """Base user schema"""

    email: EmailStr
    username: str
    full_name: Optional[str] = None
    is_active: bool = True
    is_verified: bool = False

    @validator("username")
    def validate_username(cls, v):
        if len(v) < 3:
            raise ValueError("Username must be at least 3 characters long")
        if len(v) > 50:
            raise ValueError("Username must be less than 50 characters long")
        return v

    @validator("email")
    def validate_email(cls, v):
        if len(v) > 255:
            raise ValueError("Email must be less than 255 characters long")
        return v


class UserCreate(UserBase):
    """Schema for creating a user"""

    password: str
    role_id: Optional[int] = None
    custom_permissions: Dict[str, Any] = {}

    @validator("password")
    def validate_password(cls, v):
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters long")
        return v


class UserUpdate(BaseModel):
    """Schema for updating a user"""

    email: Optional[EmailStr] = None
    username: Optional[str] = None
    full_name: Optional[str] = None
    role_id: Optional[int] = None
    custom_permissions: Optional[Dict[str, Any]] = None
    is_active: Optional[bool] = None
    is_verified: Optional[bool] = None

    @validator("username")
    def validate_username(cls, v):
        if v is not None:
            if len(v) < 3:
                raise ValueError("Username must be at least 3 characters long")
            if len(v) > 50:
                raise ValueError("Username must be less than 50 characters long")
        return v


class PasswordChange(BaseModel):
    """Schema for changing password"""

    current_password: Optional[str] = None
    new_password: str
    confirm_password: str

    @validator("new_password")
    def validate_new_password(cls, v):
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters long")
        return v

    @validator("confirm_password")
    def passwords_match(cls, v, values):
        if "new_password" in values and v != values["new_password"]:
            raise ValueError("Passwords do not match")
        return v


class RoleInfo(BaseModel):
    """Role information schema"""

    id: int
    name: str
    display_name: str
    level: str
    permissions: Dict[str, Any]

    class Config:
        from_attributes = True


class UserResponse(BaseModel):
    """User response schema"""

    id: int
    email: str
    username: str
    full_name: Optional[str]
    is_active: bool
    is_verified: bool
    is_superuser: bool
    role_id: Optional[int]
    role: Optional[RoleInfo]
    custom_permissions: Dict[str, Any]
    account_locked: Optional[bool] = False
    account_locked_until: Optional[datetime]
    failed_login_attempts: Optional[int] = 0
    last_failed_login: Optional[datetime]
    avatar_url: Optional[str]
    bio: Optional[str]
    company: Optional[str]
    website: Optional[str]
    created_at: Optional[datetime]
    updated_at: Optional[datetime]
    last_login: Optional[datetime]
    preferences: Dict[str, Any]
    notification_settings: Dict[str, Any]

    class Config:
        from_attributes = True

    @classmethod
    def from_orm(cls, obj):
        """Create response from ORM object with proper role handling"""
        data = obj.to_dict()
        if obj.role:
            data["role"] = RoleInfo.from_orm(obj.role)
        return cls(**data)


class UserListResponse(BaseModel):
    """User list response schema"""

    users: List[UserResponse]
    total: int
    skip: int
    limit: int


class AccountLockResponse(BaseModel):
    """Account lock response schema"""

    user_id: int
    is_locked: bool
    locked_until: Optional[datetime]
    message: str


class UserProfileUpdate(BaseModel):
    """Schema for user profile updates (limited fields)"""

    full_name: Optional[str] = None
    avatar_url: Optional[str] = None
    bio: Optional[str] = None
    company: Optional[str] = None
    website: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None
    notification_settings: Optional[Dict[str, Any]] = None


class UserPreferences(BaseModel):
    """User preferences schema"""

    theme: Optional[str] = "light"
    language: Optional[str] = "en"
    timezone: Optional[str] = "UTC"
    email_notifications: Optional[bool] = True
    security_alerts: Optional[bool] = True
    system_updates: Optional[bool] = True


class UserSearchFilter(BaseModel):
    """User search filter schema"""

    search: Optional[str] = None
    role_id: Optional[int] = None
    is_active: Optional[bool] = None
    is_verified: Optional[bool] = None
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None
    last_login_after: Optional[datetime] = None
    last_login_before: Optional[datetime] = None


class UserBulkAction(BaseModel):
    """Schema for bulk user actions"""

    user_ids: List[int]
    action: str  # activate, deactivate, lock, unlock, assign_role, remove_role
    action_data: Optional[Dict[str, Any]] = None

    @validator("action")
    def validate_action(cls, v):
        valid_actions = [
            "activate",
            "deactivate",
            "lock",
            "unlock",
            "assign_role",
            "remove_role",
        ]
        if v not in valid_actions:
            raise ValueError(f'Action must be one of: {", ".join(valid_actions)}')
        return v

    @validator("user_ids")
    def validate_user_ids(cls, v):
        if not v:
            raise ValueError("At least one user ID must be provided")
        if len(v) > 100:
            raise ValueError(
                "Cannot perform bulk action on more than 100 users at once"
            )
        return v


class UserStatistics(BaseModel):
    """User statistics schema"""

    total_users: int
    active_users: int
    verified_users: int
    locked_users: int
    users_by_role: Dict[str, int]
    recent_registrations: int
    registrations_by_month: Dict[str, int]


class UserActivity(BaseModel):
    """User activity schema"""

    user_id: int
    action: str
    resource_type: Optional[str] = None
    resource_id: Optional[int] = None
    details: Optional[Dict[str, Any]] = None
    timestamp: datetime
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None


class UserActivityFilter(BaseModel):
    """User activity filter schema"""

    user_id: Optional[int] = None
    action: Optional[str] = None
    resource_type: Optional[str] = None
    date_from: Optional[datetime] = None
    date_to: Optional[datetime] = None
    ip_address: Optional[str] = None
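
The confirm_password cross-field check above can be exercised directly -- a
sketch with an assumed import path:

    from pydantic import ValidationError
    from app.schemas.user import PasswordChange  # path assumed

    try:
        PasswordChange(new_password="s3cret-pass", confirm_password="different")
    except ValidationError as e:
        print(e)  # "Passwords do not match"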
@@ -1,3 +1,3 @@
"""
Services package
"""
"""
File diff suppressed because it is too large
@@ -23,40 +23,46 @@ logger = logging.getLogger(__name__)

class APIKeyAuthService:
    """Service for API key authentication and validation"""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:

    async def validate_api_key(
        self, api_key: str, request: Request
    ) -> Optional[Dict[str, Any]]:
        """Validate API key and return user context using Redis cache for performance"""
        try:
            if not api_key:
                return None

            # Extract key prefix for lookup
            if len(api_key) < 8:
logger.warning(f"Invalid API key format: too short")
|
||||
                return None

            key_prefix = api_key[:8]

            # Try cached verification first
            cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)

            cached_verification = await cached_api_key_service.verify_api_key_cached(
                api_key, key_prefix
            )

            # Get API key data from cache or database
            context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)

            context = await cached_api_key_service.get_cached_api_key(
                key_prefix, self.db
            )

            if not context:
                logger.warning(f"API key not found: {key_prefix}")
                return None

            api_key_obj = context["api_key"]

            # If not in verification cache, verify and cache the result
            if not cached_verification:
                # Get the actual key hash for verification (this should be in the cached context)
                db_api_key = None
                if not hasattr(api_key_obj, 'key_hash'):
                if not hasattr(api_key_obj, "key_hash"):
                    # Fallback: fetch full API key from database for hash
                    stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
                    result = await self.db.execute(stmt)
@@ -66,76 +72,85 @@ class APIKeyAuthService:
                    key_hash = db_api_key.key_hash
                else:
                    key_hash = api_key_obj.key_hash

                # Verify the API key hash
                if not verify_api_key(api_key, key_hash):
                    logger.warning(f"Invalid API key hash: {key_prefix}")
                    return None

                # Cache successful verification
                await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)

                await cached_api_key_service.cache_verification_result(
                    api_key, key_prefix, key_hash, True
                )

            # Check if key is valid (expiry, active status)
            if not api_key_obj.is_valid():
                logger.warning(f"API key expired or inactive: {key_prefix}")
                # Invalidate cache for expired keys
                await cached_api_key_service.invalidate_api_key_cache(key_prefix)
                return None

            # Check IP restrictions
            client_ip = request.client.host if request.client else "unknown"
            if not api_key_obj.can_access_from_ip(client_ip):
                logger.warning(f"IP not allowed for API key {key_prefix}: {client_ip}")
                return None

            # Update last used timestamp asynchronously (performance optimization)
            await cached_api_key_service.update_last_used(context["api_key_id"], self.db)

            await cached_api_key_service.update_last_used(
                context["api_key_id"], self.db
            )

            return context

        except Exception as e:
            logger.error(f"API key validation error: {e}")
            return None

    async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:

    async def check_endpoint_permission(
        self, context: Dict[str, Any], endpoint: str
    ) -> bool:
        """Check if API key has permission to access endpoint"""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.can_access_endpoint(endpoint)

    async def check_model_permission(self, context: Dict[str, Any], model: str) -> bool:
        """Check if API key has permission to access model"""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.can_access_model(model)

    async def check_scope_permission(self, context: Dict[str, Any], scope: str) -> bool:
        """Check if API key has required scope"""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.has_scope(scope)

    async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0):

    async def update_usage_stats(
        self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0
    ):
        """Update API key usage statistics"""
        try:
            api_key: APIKey = context.get("api_key")
            if api_key:
                api_key.update_usage(tokens_used, cost_cents)
                await self.db.commit()
                logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
                logger.info(
                    f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents"
                )
        except Exception as e:
            logger.error(f"Failed to update usage stats: {e}")


async def get_api_key_context(
    request: Request,
    db: AsyncSession = Depends(get_db)
    request: Request, db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
    """Dependency to get API key context from request"""
    auth_service = APIKeyAuthService(db)
@@ -170,7 +185,7 @@ async def require_api_key(
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Valid API key required",
            headers={"WWW-Authenticate": "Bearer"}
            headers={"WWW-Authenticate": "Bearer"},
        )
    return context

@@ -180,19 +195,19 @@ async def get_current_api_key_user(
) -> tuple:
    """
    Dependency that returns current user and API key as a tuple

    Returns:
        tuple: (user, api_key)
    """
    user = context.get("user")
    api_key = context.get("api_key")

    if not user or not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User or API key not found in context"
            detail="User or API key not found in context",
        )

    return user, api_key

@@ -201,48 +216,48 @@ async def get_api_key_auth(
) -> APIKey:
    """
    Dependency that returns the authenticated API key object

    Returns:
        APIKey: The authenticated API key object
    """
    api_key = context.get("api_key")

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key not found in context"
            detail="API key not found in context",
        )

    return api_key


class RequireScope:
    """Dependency class for scope checking"""

    def __init__(self, scope: str):
        self.scope = scope

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        auth_service = APIKeyAuthService(context.get("db"))
        if not await auth_service.check_scope_permission(context, self.scope):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Scope '{self.scope}' required"
                detail=f"Scope '{self.scope}' required",
            )
        return context


class RequireModel:
    """Dependency class for model access checking"""

    def __init__(self, model: str):
        self.model = model

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        auth_service = APIKeyAuthService(context.get("db"))
        if not await auth_service.check_model_permission(context, self.model):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Model '{self.model}' not allowed"
                detail=f"Model '{self.model}' not allowed",
            )
        return context
        return context
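
Sketch of guarding a route with the RequireScope dependency defined above.
The router, path, and scope string are illustrative, not from the repo:

    from fastapi import APIRouter, Depends

    router = APIRouter()

    @router.get("/v1/reports")
    async def list_reports(context: dict = Depends(RequireScope("reports:read"))):
        # context is the dict built by require_api_key (user + api_key)
        return {"user_id": context["user"].id}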
@@ -20,17 +20,17 @@ _audit_worker_started = False

async def _audit_worker():
    """Background worker to process audit events"""
    from app.db.database import async_session_factory

    logger.info("Audit worker started")

    while True:
        try:
            # Get audit event from queue
            audit_data = await _audit_queue.get()

            if audit_data is None:  # Shutdown signal
                break

            # Process the audit event in a separate database session
            async with async_session_factory() as db:
                try:
@@ -41,9 +41,9 @@ async def _audit_worker():
                except Exception as e:
                    logger.error(f"Failed to write audit log in background: {e}")
                    await db.rollback()

            _audit_queue.task_done()

        except Exception as e:
            logger.error(f"Audit worker error: {e}")
            await asyncio.sleep(1)  # Brief pause before retrying
@@ -68,11 +68,11 @@ async def log_audit_event_async(
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None,
    success: bool = True,
    severity: str = "info"
    severity: str = "info",
):
    """
    Log an audit event asynchronously (non-blocking)

    This function queues the audit event for background processing,
    so it doesn't block the main request flow.
    """
@@ -80,11 +80,11 @@ async def log_audit_event_async(
    # Ensure audit worker is started
    if not _audit_worker_started:
        start_audit_worker()

    audit_details = details or {}
    if api_key_id:
        audit_details["api_key_id"] = api_key_id

    audit_data = {
        "user_id": user_id,
        "action": action,
@@ -96,16 +96,16 @@ async def log_audit_event_async(
        "user_agent": user_agent,
        "success": success,
        "severity": severity,
        "created_at": datetime.utcnow()
        "created_at": datetime.utcnow(),
    }

    # Queue the audit event (non-blocking)
    try:
        _audit_queue.put_nowait(audit_data)
        logger.debug(f"Audit event queued: {action} on {resource_type}")
    except asyncio.QueueFull:
        logger.warning("Audit queue full, dropping audit event")

    except Exception as e:
        logger.error(f"Failed to queue audit event: {e}")
        # Don't raise - audit failures shouldn't break main operations
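
Fire-and-forget usage sketch for log_audit_event_async; the argument names
match the signature above, the values are illustrative:

    await log_audit_event_async(
        user_id=42,
        action="api_key.created",
        resource_type="api_key",
        ip_address="203.0.113.7",
        success=True,
        severity="info",
    )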
@@ -122,11 +122,11 @@ async def log_audit_event(
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None,
    success: bool = True,
    severity: str = "info"
    severity: str = "info",
):
    """
    Log an audit event to the database

    Args:
        db: Database session
        user_id: ID of the user performing the action
@@ -140,12 +140,12 @@ async def log_audit_event(
        success: Whether the action was successful
        severity: Severity level (info, warning, error, critical)
    """

    try:
        audit_details = details or {}
        if api_key_id:
            audit_details["api_key_id"] = api_key_id

        audit_log = AuditLog(
            user_id=user_id,
            action=action,
@@ -157,14 +157,16 @@ async def log_audit_event(
            user_agent=user_agent,
            success=success,
            severity=severity,
            created_at=datetime.utcnow()
            created_at=datetime.utcnow(),
        )

        db.add(audit_log)
        await db.flush()  # Don't commit here, let the caller control the transaction

        logger.debug(f"Audit event logged: {action} on {resource_type} by user {user_id}")

        logger.debug(
            f"Audit event logged: {action} on {resource_type} by user {user_id}"
        )

    except Exception as e:
        logger.error(f"Failed to log audit event: {e}")
        # Don't raise here as audit logging shouldn't break the main operation
@@ -179,11 +181,11 @@ async def get_audit_logs(
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = 100,
    offset: int = 0
    offset: int = 0,
):
    """
    Query audit logs with filtering

    Args:
        db: Database session
        user_id: Filter by user ID
@@ -194,16 +196,16 @@ async def get_audit_logs(
        end_date: Filter by end date
        limit: Maximum number of results
        offset: Number of results to skip

    Returns:
        List of audit log entries
    """

    from sqlalchemy import select, and_

    query = select(AuditLog)
    conditions = []

    if user_id:
        conditions.append(AuditLog.user_id == user_id)
    if action:
@@ -216,13 +218,13 @@ async def get_audit_logs(
        conditions.append(AuditLog.created_at >= start_date)
    if end_date:
        conditions.append(AuditLog.created_at <= end_date)

    if conditions:
        query = query.where(and_(*conditions))

    query = query.order_by(AuditLog.created_at.desc())
    query = query.offset(offset).limit(limit)

    result = await db.execute(query)
    return result.scalars().all()
@@ -230,68 +232,80 @@ async def get_audit_logs(
|
||||
async def get_audit_stats(
|
||||
db: AsyncSession,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None,
|
||||
):
|
||||
"""
|
||||
Get audit statistics
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
start_date: Start date for statistics
|
||||
end_date: End date for statistics
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with audit statistics
|
||||
"""
|
||||
|
||||
|
||||
from sqlalchemy import select, func, and_
|
||||
|
||||
|
||||
conditions = []
|
||||
if start_date:
|
||||
conditions.append(AuditLog.created_at >= start_date)
|
||||
if end_date:
|
||||
conditions.append(AuditLog.created_at <= end_date)
|
||||
|
||||
|
||||
# Total events
|
||||
total_query = select(func.count(AuditLog.id))
|
||||
if conditions:
|
||||
total_query = total_query.where(and_(*conditions))
|
||||
total_result = await db.execute(total_query)
|
||||
total_events = total_result.scalar()
|
||||
|
||||
|
||||
# Events by action
|
||||
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(AuditLog.action)
|
||||
action_query = select(AuditLog.action, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.action
|
||||
)
|
||||
if conditions:
|
||||
action_query = action_query.where(and_(*conditions))
|
||||
action_result = await db.execute(action_query)
|
||||
events_by_action = dict(action_result.fetchall())
|
||||
|
||||
|
||||
# Events by resource type
|
||||
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(AuditLog.resource_type)
|
||||
resource_query = select(AuditLog.resource_type, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.resource_type
|
||||
)
|
||||
if conditions:
|
||||
resource_query = resource_query.where(and_(*conditions))
|
||||
resource_result = await db.execute(resource_query)
|
||||
events_by_resource = dict(resource_result.fetchall())
|
||||
|
||||
|
||||
# Events by severity
|
||||
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(AuditLog.severity)
|
||||
severity_query = select(AuditLog.severity, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.severity
|
||||
)
|
||||
if conditions:
|
||||
severity_query = severity_query.where(and_(*conditions))
|
||||
severity_result = await db.execute(severity_query)
|
||||
events_by_severity = dict(severity_result.fetchall())
|
||||
|
||||
|
||||
# Success rate
|
||||
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(AuditLog.success)
|
||||
success_query = select(AuditLog.success, func.count(AuditLog.id)).group_by(
|
||||
AuditLog.success
|
||||
)
|
||||
if conditions:
|
||||
success_query = success_query.where(and_(*conditions))
|
||||
success_result = await db.execute(success_query)
|
||||
success_stats = dict(success_result.fetchall())
|
||||
|
||||
|
||||
return {
|
||||
"total_events": total_events,
|
||||
"events_by_action": events_by_action,
|
||||
"events_by_resource_type": events_by_resource,
|
||||
"events_by_severity": events_by_severity,
|
||||
"success_rate": success_stats.get(True, 0) / total_events if total_events > 0 else 0,
|
||||
"failure_rate": success_stats.get(False, 0) / total_events if total_events > 0 else 0
|
||||
}
|
||||
"success_rate": success_stats.get(True, 0) / total_events
|
||||
if total_events > 0
|
||||
else 0,
|
||||
"failure_rate": success_stats.get(False, 0) / total_events
|
||||
if total_events > 0
|
||||
else 0,
|
||||
}
|
||||
|
||||
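For reference, the success and failure rates returned above are plain fractions of total_events. A small consumption sketch, assuming an AsyncSession is already available:

# Sketch: consuming get_audit_stats; the `db` argument is assumed
# to be an AsyncSession from this app's database layer.
async def report(db) -> None:
    stats = await get_audit_stats(db)
    # e.g. 95 successful events out of 100 -> success_rate == 0.95
    print(f"{stats['total_events']} events, "
          f"{stats['success_rate']:.1%} success, "
          f"{stats['failure_rate']:.1%} failure")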
@@ -22,10 +22,11 @@ logger = get_logger(__name__)
@dataclass
class Permission:
"""Represents a module permission"""

resource: str
action: str
description: str

def __str__(self) -> str:
return f"{self.resource}:{self.action}"

@@ -33,26 +34,28 @@ class Permission:
@dataclass
class ModuleMetrics:
"""Module performance metrics"""

requests_processed: int = 0
average_response_time: float = 0.0
error_rate: float = 0.0
last_activity: Optional[str] = None
total_errors: int = 0
uptime_start: float = 0.0

def __post_init__(self):
if self.uptime_start == 0.0:
self.uptime_start = time.time()

@dataclass
@dataclass
class ModuleHealth:
"""Module health status"""

status: str = "healthy"  # healthy, warning, error
message: str = "Module is functioning normally"
uptime: float = 0.0
last_check: float = 0.0

def __post_init__(self):
if self.last_check == 0.0:
self.last_check = time.time()
@@ -60,53 +63,55 @@ class ModuleHealth:

class BaseModule(ABC):
"""Base class for all modules with interceptor pattern support"""

def __init__(self, module_id: str, config: Dict[str, Any] = None):
self.module_id = module_id
self.config = config or {}
self.metrics = ModuleMetrics()
self.health = ModuleHealth()
self.initialized = False
self.interceptors: List['ModuleInterceptor'] = []

self.interceptors: List["ModuleInterceptor"] = []

# Register default interceptors
self._register_default_interceptors()

@abstractmethod
async def initialize(self) -> None:
"""Initialize the module"""
pass

@abstractmethod
async def cleanup(self) -> None:
"""Cleanup module resources"""
pass

@abstractmethod
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
return []

@abstractmethod
async def process_request(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
async def process_request(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Process a module request"""
pass

def get_health(self) -> ModuleHealth:
"""Get current module health status"""
self.health.uptime = time.time() - self.metrics.uptime_start
self.health.last_check = time.time()
return self.health

def get_metrics(self) -> ModuleMetrics:
"""Get current module metrics"""
return self.metrics

def check_access(self, user_permissions: List[str], action: str) -> bool:
"""Check if user can perform action on this module"""
required = f"modules:{self.module_id}:{action}"
return permission_registry.check_permission(user_permissions, required)

def _register_default_interceptors(self):
"""Register default interceptors for all modules"""
self.interceptors = [
@@ -115,47 +120,49 @@ class BaseModule(ABC):
ValidationInterceptor(),
MetricsInterceptor(self),
SecurityInterceptor(),
AuditInterceptor(self)
AuditInterceptor(self),
]

async def execute_with_interceptors(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:

async def execute_with_interceptors(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Execute request through interceptor chain"""
start_time = time.time()

try:
# Pre-processing interceptors
for interceptor in self.interceptors:
request, context = await interceptor.pre_process(request, context)

# Execute main module logic
response = await self.process_request(request, context)

# Post-processing interceptors (in reverse order)
for interceptor in reversed(self.interceptors):
response = await interceptor.post_process(request, context, response)

# Update metrics
self._update_metrics(start_time, success=True)

return response

except Exception as e:
# Update error metrics
self._update_metrics(start_time, success=False, error=str(e))

# Error handling interceptors
for interceptor in reversed(self.interceptors):
if hasattr(interceptor, 'handle_error'):
if hasattr(interceptor, "handle_error"):
await interceptor.handle_error(request, context, e)

raise

def _update_metrics(self, start_time: float, success: bool, error: str = None):
"""Update module metrics"""
duration = time.time() - start_time

self.metrics.requests_processed += 1

# Update average response time
if self.metrics.requests_processed == 1:
self.metrics.average_response_time = duration
@@ -165,94 +172,118 @@ class BaseModule(ABC):
self.metrics.average_response_time = (
alpha * duration + (1 - alpha) * self.metrics.average_response_time
)

if not success:
self.metrics.total_errors += 1
self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
self.metrics.error_rate = (
self.metrics.total_errors / self.metrics.requests_processed
)

# Update health status based on error rate
if self.metrics.error_rate > 0.1:  # More than 10% error rate
self.health.status = "error"
self.health.message = f"High error rate: {self.metrics.error_rate:.2%}"
elif self.metrics.error_rate > 0.05:  # More than 5% error rate
self.health.status = "warning"
self.health.message = f"Elevated error rate: {self.metrics.error_rate:.2%}"
self.health.message = (
f"Elevated error rate: {self.metrics.error_rate:.2%}"
)
else:
self.metrics.error_rate = self.metrics.total_errors / self.metrics.requests_processed
self.metrics.error_rate = (
self.metrics.total_errors / self.metrics.requests_processed
)
if self.metrics.error_rate <= 0.05:
self.health.status = "healthy"
self.health.message = "Module is functioning normally"

self.metrics.last_activity = time.strftime("%Y-%m-%d %H:%M:%S")
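Aside: average_response_time above is an exponential moving average, blending each new sample into the running value. The hunk does not show where alpha is defined, so the value below is an assumption:

# Exponential moving average, as used for average_response_time.
# alpha = 0.1 is an assumed smoothing factor (not shown in the hunk).
def update_ema(avg: float, sample: float, alpha: float = 0.1) -> float:
    return alpha * sample + (1 - alpha) * avg

avg = 0.200
for sample in (0.250, 0.300, 0.220):
    avg = update_ema(avg, sample)
print(round(avg, 3))  # ~0.215: recent samples nudge the average upward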
class ModuleInterceptor(ABC):
"""Base class for module interceptors"""

@abstractmethod
async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Pre-process the request"""
return request, context

@abstractmethod
async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:
async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
"""Post-process the response"""
return response

class AuthenticationInterceptor(ModuleInterceptor):
"""Handles authentication for module requests"""

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Check if user is authenticated (context should contain user info from API auth)
if not context.get("user_id") and not context.get("api_key_id"):
raise AuthenticationError("Authentication required for module access")

return request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
return response

class PermissionInterceptor(ModuleInterceptor):
"""Handles permission checking for module requests"""

def __init__(self, module: BaseModule):
self.module = module

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
action = request.get("action", "execute")
user_permissions = context.get("user_permissions", [])

if not self.module.check_access(user_permissions, action):
raise AuthenticationError(f"Insufficient permissions for module action: {action}")

raise AuthenticationError(
f"Insufficient permissions for module action: {action}"
)

return request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
return response

class ValidationInterceptor(ModuleInterceptor):
"""Handles request validation and sanitization"""

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Sanitize request data
sanitized_request = self._sanitize_request(request)

# Validate request structure
self._validate_request(sanitized_request)

return sanitized_request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Sanitize response data
return self._sanitize_response(response)

def _sanitize_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Remove potentially dangerous content from request"""
sanitized = copy.deepcopy(request)

# Define dangerous patterns
dangerous_patterns = [
r"<script[^>]*>.*?</script>",
@@ -264,19 +295,21 @@ class ValidationInterceptor(ModuleInterceptor):
r"eval\s*\(",
r"Function\s*\(",
]

def sanitize_value(value):
if isinstance(value, str):
# Remove dangerous patterns
for pattern in dangerous_patterns:
value = re.sub(pattern, "", value, flags=re.IGNORECASE)

# Limit string length
max_length = 10000
if len(value) > max_length:
value = value[:max_length]
logger.warning(f"Truncated long string in request: {len(value)} chars")

logger.warning(
f"Truncated long string in request: {len(value)} chars"
)

return value
elif isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()}
@@ -284,124 +317,159 @@ class ValidationInterceptor(ModuleInterceptor):
return [sanitize_value(item) for item in value]
else:
return value

return sanitize_value(sanitized)

def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize response data"""
# Similar sanitization for responses
return self._sanitize_request(response)

def _validate_request(self, request: Dict[str, Any]):
"""Validate request structure"""
# Check for required fields
if not isinstance(request, dict):
raise ValidationError("Request must be a dictionary")

# Check request size
request_str = json.dumps(request)
max_size = 10 * 1024 * 1024  # 10MB
if len(request_str.encode()) > max_size:
raise ValidationError(f"Request size exceeds maximum allowed ({max_size} bytes)")
raise ValidationError(
f"Request size exceeds maximum allowed ({max_size} bytes)"
)
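To make the interceptor contract concrete, a minimal custom interceptor against the ModuleInterceptor base above; the class name and the latency field it adds are illustrative.

# Sketch: a pass-through interceptor that stamps response latency.
import time
from typing import Any, Dict, Tuple

class TimingInterceptor(ModuleInterceptor):
    async def pre_process(
        self, request: Dict[str, Any], context: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        context["_t0"] = time.time()  # stash the start time
        return request, context

    async def post_process(
        self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
    ) -> Dict[str, Any]:
        elapsed = time.time() - context.get("_t0", time.time())
        response["elapsed_ms"] = int(elapsed * 1000)
        return response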
class MetricsInterceptor(ModuleInterceptor):
"""Handles metrics collection for module requests"""

def __init__(self, module: BaseModule):
self.module = module

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_metrics_start_time"] = time.time()
return request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Metrics are updated in the base module execute_with_interceptors method
return response

class SecurityInterceptor(ModuleInterceptor):
"""Handles security-related processing"""

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Add security headers to context
context["security_headers"] = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
}

# Check for suspicious patterns
self._check_security_patterns(request)

return request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
# Remove any sensitive information from response
return self._remove_sensitive_data(response)

def _check_security_patterns(self, request: Dict[str, Any]):
"""Check for suspicious security patterns"""
request_str = json.dumps(request).lower()

suspicious_patterns = [
"union select", "drop table", "insert into", "delete from",
"script>", "javascript:", "eval(", "expression(",
"../", "..\\", "file://", "ftp://",
"union select",
"drop table",
"insert into",
"delete from",
"script>",
"javascript:",
"eval(",
"expression(",
"../",
"..\\",
"file://",
"ftp://",
]

for pattern in suspicious_patterns:
if pattern in request_str:
logger.warning(f"Suspicious pattern detected in request: {pattern}")
# Could implement additional security measures here

def _remove_sensitive_data(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Remove sensitive data from response"""
sensitive_keys = ["password", "secret", "token", "key", "private"]

def clean_dict(obj):
if isinstance(obj, dict):
return {
k: "***REDACTED***" if any(sk in k.lower() for sk in sensitive_keys) else clean_dict(v)
k: "***REDACTED***"
if any(sk in k.lower() for sk in sensitive_keys)
else clean_dict(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [clean_dict(item) for item in obj]
else:
return obj

return clean_dict(response)

class AuditInterceptor(ModuleInterceptor):
"""Handles audit logging for module requests"""

def __init__(self, module: BaseModule):
self.module = module

async def pre_process(self, request: Dict[str, Any], context: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:

async def pre_process(
self, request: Dict[str, Any], context: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
context["_audit_start_time"] = time.time()
context["_audit_request_hash"] = self._hash_request(request)
return request, context

async def post_process(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]) -> Dict[str, Any]:

async def post_process(
self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any]
) -> Dict[str, Any]:
await self._log_audit_event(request, context, response, success=True)
return response

async def handle_error(self, request: Dict[str, Any], context: Dict[str, Any], error: Exception):

async def handle_error(
self, request: Dict[str, Any], context: Dict[str, Any], error: Exception
):
"""Handle error logging"""
await self._log_audit_event(request, context, {"error": str(error)}, success=False)

await self._log_audit_event(
request, context, {"error": str(error)}, success=False
)

def _hash_request(self, request: Dict[str, Any]) -> str:
"""Create a hash of the request for audit purposes"""
request_str = json.dumps(request, sort_keys=True)
return hashlib.sha256(request_str.encode()).hexdigest()[:16]

async def _log_audit_event(self, request: Dict[str, Any], context: Dict[str, Any], response: Dict[str, Any], success: bool):

async def _log_audit_event(
self,
request: Dict[str, Any],
context: Dict[str, Any],
response: Dict[str, Any],
success: bool,
):
"""Log audit event"""
duration = time.time() - context.get("_audit_start_time", time.time())

audit_data = {
"module_id": self.module.module_id,
"action": request.get("action", "execute"),
@@ -413,11 +481,11 @@ class AuditInterceptor(ModuleInterceptor):
"duration_ms": int(duration * 1000),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
}

if not success:
audit_data["error"] = response.get("error", "Unknown error")

# Log the audit event
logger.info(f"Module audit: {json.dumps(audit_data)}")

# Could also store in database for persistent audit trail

# Could also store in database for persistent audit trail
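Aside: a standalone illustration of the key-based redaction that _remove_sensitive_data performs above. The logic mirrors the hunk; the sample payload is made up.

# Sketch: recursive redaction of sensitive keys.
SENSITIVE_KEYS = ["password", "secret", "token", "key", "private"]

def clean(obj):
    if isinstance(obj, dict):
        return {
            k: "***REDACTED***"
            if any(sk in k.lower() for sk in SENSITIVE_KEYS)
            else clean(v)
            for k, v in obj.items()
        }
    if isinstance(obj, list):
        return [clean(i) for i in obj]
    return obj

print(clean({"api_key": "sk-123", "profile": {"name": "bob"}}))
# {'api_key': '***REDACTED***', 'profile': {'name': 'bob'}}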
@@ -28,6 +28,7 @@ from sqlalchemy.orm import Session
@dataclass
class PluginContext:
"""Plugin execution context with user and authentication info"""

user_id: Optional[str] = None
api_key_id: Optional[str] = None
user_permissions: List[str] = None
@@ -38,25 +39,29 @@ class PluginContext:

class PlatformAPIClient:
"""Secure client for plugins to access platform APIs"""

def __init__(self, plugin_id: str, plugin_token: str):
self.plugin_id = plugin_id
self.plugin_token = plugin_token
self.base_url = settings.INTERNAL_API_URL or "http://localhost:58000"
self.logger = get_logger(f"plugin.{plugin_id}.api_client")

async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:

async def _make_request(
self, method: str, endpoint: str, **kwargs
) -> Dict[str, Any]:
"""Make authenticated request to platform API"""
headers = kwargs.setdefault('headers', {})
headers.update({
'Authorization': f'Bearer {self.plugin_token}',
'X-Plugin-ID': self.plugin_id,
'X-Platform-Client': 'plugin',
'Content-Type': 'application/json'
})

headers = kwargs.setdefault("headers", {})
headers.update(
{
"Authorization": f"Bearer {self.plugin_token}",
"X-Plugin-ID": self.plugin_id,
"X-Platform-Client": "plugin",
"Content-Type": "application/json",
}
)

url = f"{self.base_url}{endpoint}"

try:
async with aiohttp.ClientSession() as session:
async with session.request(method, url, **kwargs) as response:
@@ -64,154 +69,162 @@ class PlatformAPIClient:
error_text = await response.text()
raise HTTPException(
status_code=response.status,
detail=f"Platform API error: {error_text}"
detail=f"Platform API error: {error_text}",
)

if response.content_type == 'application/json':

if response.content_type == "application/json":
return await response.json()
else:
return {"data": await response.text()}

except aiohttp.ClientError as e:
self.logger.error(f"Platform API client error: {e}")
raise HTTPException(
status_code=503,
detail=f"Platform API unavailable: {str(e)}"
status_code=503, detail=f"Platform API unavailable: {str(e)}"
)

async def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""GET request to platform API"""
return await self._make_request('GET', endpoint, **kwargs)

async def post(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
return await self._make_request("GET", endpoint, **kwargs)

async def post(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""POST request to platform API"""
if data:
kwargs['json'] = data
return await self._make_request('POST', endpoint, **kwargs)

async def put(self, endpoint: str, data: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
kwargs["json"] = data
return await self._make_request("POST", endpoint, **kwargs)

async def put(
self, endpoint: str, data: Dict[str, Any] = None, **kwargs
) -> Dict[str, Any]:
"""PUT request to platform API"""
if data:
kwargs['json'] = data
return await self._make_request('PUT', endpoint, **kwargs)

kwargs["json"] = data
return await self._make_request("PUT", endpoint, **kwargs)

async def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
"""DELETE request to platform API"""
return await self._make_request('DELETE', endpoint, **kwargs)

return await self._make_request("DELETE", endpoint, **kwargs)

# Platform-specific API methods
async def call_chatbot_api(self, chatbot_id: str, message: str,
context: Dict[str, Any] = None) -> Dict[str, Any]:
async def call_chatbot_api(
self, chatbot_id: str, message: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Consume platform chatbot API"""
return await self.post(
f"/api/v1/chatbot/external/{chatbot_id}/chat",
{
"message": message,
"context": context or {}
}
{"message": message, "context": context or {}},
)

async def call_llm_api(self, model: str, messages: List[Dict[str, Any]],
**kwargs) -> Dict[str, Any]:

async def call_llm_api(
self, model: str, messages: List[Dict[str, Any]], **kwargs
) -> Dict[str, Any]:
"""Consume platform LLM API"""
return await self.post(
"/api/v1/llm/chat/completions",
{
"model": model,
"messages": messages,
**kwargs
}
{"model": model, "messages": messages, **kwargs},
)

async def search_rag(self, collection: str, query: str,
top_k: int = 5) -> Dict[str, Any]:

async def search_rag(
self, collection: str, query: str, top_k: int = 5
) -> Dict[str, Any]:
"""Consume platform RAG API"""
return await self.post(
f"/api/v1/rag/collections/{collection}/search",
{
"query": query,
"top_k": top_k
}
{"query": query, "top_k": top_k},
)

async def get_embeddings(self, model: str, input_text: str) -> Dict[str, Any]:
"""Generate embeddings via platform API"""
return await self.post(
"/api/v1/llm/embeddings",
{
"model": model,
"input": input_text
}
"/api/v1/llm/embeddings", {"model": model, "input": input_text}
)
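As a usage sketch, here is how a plugin might consume the client above. The plugin id, token, and model name are placeholders, and the response shape is assumed to be OpenAI-compatible rather than confirmed by this hunk:

# Sketch: calling the platform LLM API from plugin code.
client = PlatformAPIClient("demo-plugin", plugin_token="<token>")

async def summarize(text: str) -> str:
    result = await client.call_llm_api(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": f"Summarize: {text}"}],
    )
    # Assumes an OpenAI-style response body.
    return result["choices"][0]["message"]["content"]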
class PluginConfigManager:
"""Manages plugin configuration with validation and encryption"""

def __init__(self, plugin_id: str):
self.plugin_id = plugin_id
self.logger = get_logger(f"plugin.{plugin_id}.config")

async def get_config(self, user_id: Optional[str] = None) -> Dict[str, Any]:
"""Get plugin configuration for user (or default)"""
try:
# Use dependency injection to get database session
from app.db.database import SessionLocal

db = SessionLocal()

try:
# Query for active configuration
query = db.query(PluginConfiguration).filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.is_active == True
PluginConfiguration.is_active == True,
)

if user_id:
# Get user-specific configuration
query = query.filter(PluginConfiguration.user_id == user_id)
else:
# Get default configuration (is_default=True)
query = query.filter(PluginConfiguration.is_default == True)

config = query.first()

if config:
self.logger.debug(f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"Retrieved configuration for plugin {self.plugin_id}, user {user_id}"
)
return config.config_data or {}
else:
self.logger.debug(f"No configuration found for plugin {self.plugin_id}, user {user_id}")
self.logger.debug(
f"No configuration found for plugin {self.plugin_id}, user {user_id}"
)
return {}

finally:
db.close()

except Exception as e:
self.logger.error(f"Failed to get configuration: {e}")
return {}

async def save_config(self, config: Dict[str, Any], user_id: str,
name: str = "Default Configuration",
description: str = None) -> bool:

async def save_config(
self,
config: Dict[str, Any],
user_id: str,
name: str = "Default Configuration",
description: str = None,
) -> bool:
"""Save plugin configuration for user"""
try:
from app.db.database import SessionLocal

db = SessionLocal()

try:
# Check if configuration already exists
existing_config = db.query(PluginConfiguration).filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
PluginConfiguration.name == name
).first()

existing_config = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
PluginConfiguration.name == name,
)
.first()
)

if existing_config:
# Update existing configuration
existing_config.config_data = config
existing_config.description = description
existing_config.is_active = True

self.logger.info(f"Updated configuration for plugin {self.plugin_id}, user {user_id}")

self.logger.info(
f"Updated configuration for plugin {self.plugin_id}, user {user_id}"
)
else:
# Create new configuration
new_config = PluginConfiguration(
@@ -222,40 +235,48 @@ class PluginConfigManager:
config_data=config,
is_active=True,
is_default=(name == "Default Configuration"),
created_by_user_id=user_id
created_by_user_id=user_id,
)

# If this is the first configuration for this user/plugin, make it default
existing_count = db.query(PluginConfiguration).filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id
).count()

existing_count = (
db.query(PluginConfiguration)
.filter(
PluginConfiguration.plugin_id == self.plugin_id,
PluginConfiguration.user_id == user_id,
)
.count()
)

if existing_count == 0:
new_config.is_default = True

db.add(new_config)
self.logger.info(f"Created new configuration for plugin {self.plugin_id}, user {user_id}")

self.logger.info(
f"Created new configuration for plugin {self.plugin_id}, user {user_id}"
)

db.commit()
return True

except Exception as e:
db.rollback()
self.logger.error(f"Database error saving configuration: {e}")
return False
finally:
db.close()

except Exception as e:
self.logger.error(f"Failed to save configuration: {e}")
return False

async def validate_config(self, config: Dict[str, Any],
schema: Dict[str, Any]) -> Tuple[bool, List[str]]:

async def validate_config(
self, config: Dict[str, Any], schema: Dict[str, Any]
) -> Tuple[bool, List[str]]:
"""Validate configuration against JSON schema"""
try:
import jsonschema

jsonschema.validate(config, schema)
return True, []
except jsonschema.ValidationError as e:
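For reference, validate_config above delegates to jsonschema; a small sketch with an illustrative schema:

# Sketch: schema validation for a plugin config.
async def check(manager: PluginConfigManager) -> None:
    schema = {
        "type": "object",
        "properties": {"top_k": {"type": "integer", "minimum": 1}},
        "required": ["top_k"],
    }
    ok, errors = await manager.validate_config({"top_k": 5}, schema)
    assert ok and errors == []  # an invalid config returns (False, [messages])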
@@ -266,45 +287,52 @@ class PluginConfigManager:

class PluginLogger:
"""Plugin-specific logger with security filtering"""

def __init__(self, plugin_id: str):
self.plugin_id = plugin_id
self.logger = get_logger(f"plugin.{plugin_id}")

# Sensitive data patterns to filter
self.sensitive_patterns = [
r'password', r'token', r'key', r'secret', r'api_key',
r'bearer', r'authorization', r'credential'
r"password",
r"token",
r"key",
r"secret",
r"api_key",
r"bearer",
r"authorization",
r"credential",
]

def _filter_sensitive_data(self, message: str) -> str:
"""Filter sensitive data from log messages"""
import re

filtered_message = message
for pattern in self.sensitive_patterns:
filtered_message = re.sub(
f'{pattern}[=:]\s*["\']?([^"\'\\s]+)["\']?',
f'{pattern}=***REDACTED***',
f"{pattern}[=:]\s*[\"']?([^\"'\\s]+)[\"']?",
f"{pattern}=***REDACTED***",
filtered_message,
flags=re.IGNORECASE
flags=re.IGNORECASE,
)
return filtered_message

def info(self, message: str, **kwargs):
"""Log info message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.info(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)

def warning(self, message: str, **kwargs):
"""Log warning message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.warning(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)

def error(self, message: str, **kwargs):
"""Log error message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
self.logger.error(f"[PLUGIN:{self.plugin_id}] {filtered_message}", **kwargs)

def debug(self, message: str, **kwargs):
"""Log debug message with sensitive data filtering"""
filtered_message = self._filter_sensitive_data(message)
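Aside: what the filter above does to a typical log line; the plugin id and message are made up, and the exact masking depends on the regex as written in the hunk.

# Sketch: values following key=... or key: ... are masked on emit.
log = PluginLogger("demo-plugin")
log.info('connected, api_key="sk-live-abc123"')
# emitted roughly as: [PLUGIN:demo-plugin] connected, api_key=***REDACTED***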
@@ -313,45 +341,45 @@ class PluginLogger:

class BasePlugin(ABC):
"""Base class for all Enclava plugins with security and isolation"""

def __init__(self, manifest: PluginManifest, plugin_token: str):
self.manifest = manifest
self.plugin_id = manifest.metadata.name
self.version = manifest.metadata.version

# Initialize plugin services
self.api_client = PlatformAPIClient(self.plugin_id, plugin_token)
self.config = PluginConfigManager(self.plugin_id)
self.logger = PluginLogger(self.plugin_id)

# Plugin state
self.initialized = False
self._startup_time = time.time()
self._request_count = 0
self._error_count = 0

self.logger.info(f"Plugin {self.plugin_id} v{self.version} instantiated")

@abstractmethod
def get_api_router(self) -> APIRouter:
"""Return FastAPI router for plugin endpoints"""
pass

@abstractmethod
async def initialize(self) -> bool:
"""Initialize plugin resources and connections"""
pass

@abstractmethod
async def cleanup(self) -> bool:
"""Cleanup plugin resources on shutdown"""
pass

async def health_check(self) -> Dict[str, Any]:
"""Plugin health status"""
uptime = time.time() - self._startup_time
error_rate = self._error_count / max(self._request_count, 1)

return {
"status": "healthy" if error_rate < 0.1 else "warning",
"plugin": self.plugin_id,
@@ -360,28 +388,28 @@ class BasePlugin(ABC):
"request_count": self._request_count,
"error_count": self._error_count,
"error_rate": round(error_rate, 3),
"initialized": self.initialized
"initialized": self.initialized,
}

async def get_configuration_schema(self) -> Dict[str, Any]:
"""Return JSON schema for plugin configuration"""
return self.manifest.spec.config_schema

async def execute_cron_job(self, job_name: str) -> bool:
"""Execute scheduled cron job"""
self.logger.info(f"Executing cron job: {job_name}")

# Find job in manifest
job_spec = None
for job in self.manifest.spec.cron_jobs:
if job.name == job_name:
job_spec = job
break

if not job_spec:
self.logger.error(f"Cron job not found: {job_name}")
return False

try:
# Get the function to execute
if hasattr(self, job_spec.function):
@@ -390,34 +418,37 @@ class BasePlugin(ABC):
result = await func()
else:
result = func()

self.logger.info(f"Cron job {job_name} completed successfully")
return bool(result)
else:
self.logger.error(f"Cron job function not found: {job_spec.function}")
return False

except Exception as e:
self.logger.error(f"Cron job {job_name} failed: {e}")
self._error_count += 1
return False

def get_auth_context(self) -> PluginContext:
"""Dependency to get authentication context in API endpoints"""

async def _get_context(request: Request) -> PluginContext:
# Extract authentication info from request
# This would be populated by the plugin API gateway
return PluginContext(
user_id=request.headers.get('X-User-ID'),
api_key_id=request.headers.get('X-API-Key-ID'),
user_permissions=request.headers.get('X-User-Permissions', '').split(','),
ip_address=request.headers.get('X-Real-IP'),
user_agent=request.headers.get('User-Agent'),
request_id=request.headers.get('X-Request-ID')
user_id=request.headers.get("X-User-ID"),
api_key_id=request.headers.get("X-API-Key-ID"),
user_permissions=request.headers.get("X-User-Permissions", "").split(
","
),
ip_address=request.headers.get("X-Real-IP"),
user_agent=request.headers.get("User-Agent"),
request_id=request.headers.get("X-Request-ID"),
)

return Depends(_get_context)

def _track_request(self, success: bool = True):
"""Track request metrics"""
self._request_count += 1
@@ -427,164 +458,206 @@ class BasePlugin(ABC):
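To ground the abstract contract above, a sketch of the smallest plugin that satisfies it; manifest construction is elided and the route is illustrative.

# Sketch: a minimal BasePlugin subclass.
from fastapi import APIRouter

class EchoPlugin(BasePlugin):
    def get_api_router(self) -> APIRouter:
        router = APIRouter()

        @router.get("/echo")
        async def echo(q: str) -> dict:
            self._track_request()  # feed the health_check metrics
            return {"echo": q}

        return router

    async def initialize(self) -> bool:
        return True  # acquire connections/resources here

    async def cleanup(self) -> bool:
        return True  # release them here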
class PluginSecurityManager:
"""Manages plugin security and isolation"""

BLOCKED_IMPORTS = {
# Core platform modules
'app.db', 'app.models', 'app.core', 'app.services',
'sqlalchemy', 'alembic',

"app.db",
"app.models",
"app.core",
"app.services",
"sqlalchemy",
"alembic",
# Security sensitive
'subprocess', 'eval', 'exec', 'compile', '__import__',
'os.system', 'os.popen', 'os.spawn',

"subprocess",
"eval",
"exec",
"compile",
"__import__",
"os.system",
"os.popen",
"os.spawn",
# System access
'socket', 'multiprocessing', 'threading'
"socket",
"multiprocessing",
"threading",
}

ALLOWED_IMPORTS = {
# Standard library
'asyncio', 'aiohttp', 'json', 'datetime', 'typing', 'pydantic',
'logging', 'time', 'uuid', 'hashlib', 'base64', 'pathlib',
're', 'urllib.parse', 'dataclasses', 'enum',

"asyncio",
"aiohttp",
"json",
"datetime",
"typing",
"pydantic",
"logging",
"time",
"uuid",
"hashlib",
"base64",
"pathlib",
"re",
"urllib.parse",
"dataclasses",
"enum",
# Approved third-party
'httpx', 'requests', 'pandas', 'numpy', 'yaml',

"httpx",
"requests",
"pandas",
"numpy",
"yaml",
# Plugin framework
'app.services.base_plugin', 'app.schemas.plugin_manifest'
"app.services.base_plugin",
"app.schemas.plugin_manifest",
}

@classmethod
def validate_plugin_import(cls, import_name: str) -> bool:
"""Validate if plugin can import a module"""
# Block dangerous imports
if any(import_name.startswith(blocked) for blocked in cls.BLOCKED_IMPORTS):
raise SecurityError(f"Import '{import_name}' not allowed in plugin environment")

raise SecurityError(
f"Import '{import_name}' not allowed in plugin environment"
)

# Allow explicit safe imports
if any(import_name.startswith(allowed) for allowed in cls.ALLOWED_IMPORTS):
return True

# Log potentially unsafe imports
logger = get_logger("plugin.security")
logger.warning(f"Potentially unsafe import in plugin: {import_name}")
return True

@classmethod
def create_plugin_sandbox(cls, plugin_id: str) -> Dict[str, Any]:
"""Create isolated environment for plugin execution"""
return {
'max_memory_mb': 128,
'max_cpu_percent': 25,
'max_disk_mb': 100,
'max_api_calls_per_minute': 100,
'allowed_domains': [],  # Will be populated from manifest
'network_timeout_seconds': 30
"max_memory_mb": 128,
"max_cpu_percent": 25,
"max_disk_mb": 100,
"max_api_calls_per_minute": 100,
"allowed_domains": [],  # Will be populated from manifest
"network_timeout_seconds": 30,
}
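Aside: how the allow/block lists above behave in practice, following the prefix-matching logic shown (the unknown package name is made up):

# Sketch: prefix match against BLOCKED_IMPORTS first, then ALLOWED_IMPORTS.
PluginSecurityManager.validate_plugin_import("aiohttp")     # True (allow-listed)
PluginSecurityManager.validate_plugin_import("subprocess")  # raises SecurityError
PluginSecurityManager.validate_plugin_import("left_pad")    # True, but logged as a warning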
class PluginLoader:
"""Loads and validates plugins from directories"""

def __init__(self):
self.logger = get_logger("plugin.loader")
self.loaded_plugins: Dict[str, BasePlugin] = {}

async def load_plugin(self, plugin_dir: Path, plugin_token: str) -> BasePlugin:
"""Load a plugin from a directory"""
self.logger.info(f"Loading plugin from: {plugin_dir}")

# Load and validate manifest
manifest_path = plugin_dir / "manifest.yaml"
validation_result = validate_manifest_file(manifest_path)

if not validation_result["valid"]:
raise ValidationError(f"Invalid plugin manifest: {validation_result['errors']}")

raise ValidationError(
f"Invalid plugin manifest: {validation_result['errors']}"
)

manifest = validation_result["manifest"]

# Check compatibility
compatibility = validation_result["compatibility"]
if not compatibility["compatible"]:
raise ValidationError(f"Plugin incompatible: {compatibility['errors']}")

# Load plugin module
main_py_path = plugin_dir / "main.py"
spec = importlib.util.spec_from_file_location(
f"plugin_{manifest.metadata.name}",
main_py_path
f"plugin_{manifest.metadata.name}", main_py_path
)

if not spec or not spec.loader:
raise ValidationError(f"Cannot load plugin module: {main_py_path}")

# Security check before loading
self._validate_plugin_security(main_py_path)

# Load module
plugin_module = importlib.util.module_from_spec(spec)

# Add to sys.modules to allow imports
sys.modules[spec.name] = plugin_module

try:
spec.loader.exec_module(plugin_module)
except Exception as e:
raise ValidationError(f"Failed to execute plugin module: {e}")

# Find plugin class
plugin_class = None
for attr_name in dir(plugin_module):
attr = getattr(plugin_module, attr_name)
if (isinstance(attr, type) and
issubclass(attr, BasePlugin) and
attr is not BasePlugin):
if (
isinstance(attr, type)
and issubclass(attr, BasePlugin)
and attr is not BasePlugin
):
plugin_class = attr
break

if not plugin_class:
raise ValidationError("Plugin must contain a class inheriting from BasePlugin")

raise ValidationError(
"Plugin must contain a class inheriting from BasePlugin"
)

# Instantiate plugin
plugin_instance = plugin_class(manifest, plugin_token)

# Initialize plugin
try:
await plugin_instance.initialize()
plugin_instance.initialized = True
except Exception as e:
raise ValidationError(f"Plugin initialization failed: {e}")

self.loaded_plugins[manifest.metadata.name] = plugin_instance
self.logger.info(f"Plugin {manifest.metadata.name} loaded successfully")

return plugin_instance

def _validate_plugin_security(self, main_py_path: Path):
"""Validate plugin code for security issues"""
with open(main_py_path, 'r', encoding='utf-8') as f:
with open(main_py_path, "r", encoding="utf-8") as f:
code_content = f.read()

# Check for dangerous patterns
dangerous_patterns = [
'eval(', 'exec(', 'compile(',
'subprocess.', 'os.system', 'os.popen',
'__import__', 'importlib.import_module',
'from app.db', 'from app.models',
'sqlalchemy', 'SessionLocal'
"eval(",
"exec(",
"compile(",
"subprocess.",
"os.system",
"os.popen",
"__import__",
"importlib.import_module",
"from app.db",
"from app.models",
"sqlalchemy",
"SessionLocal",
]

for pattern in dangerous_patterns:
if pattern in code_content:
raise SecurityError(f"Dangerous pattern detected in plugin code: {pattern}")

raise SecurityError(
f"Dangerous pattern detected in plugin code: {pattern}"
)

async def unload_plugin(self, plugin_id: str) -> bool:
"""Unload a plugin and cleanup resources"""
if plugin_id not in self.loaded_plugins:
return False

plugin = self.loaded_plugins[plugin_id]

try:
await plugin.cleanup()
del self.loaded_plugins[plugin_id]
@@ -593,11 +666,11 @@ class PluginLoader:
except Exception as e:
self.logger.error(f"Error unloading plugin {plugin_id}: {e}")
return False

def get_plugin(self, plugin_id: str) -> Optional[BasePlugin]:
"""Get loaded plugin by ID"""
return self.loaded_plugins.get(plugin_id)

def list_loaded_plugins(self) -> List[str]:
"""List all loaded plugin IDs"""
return list(self.loaded_plugins.keys())
return list(self.loaded_plugins.keys())
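For reference, the load/unload round trip from a caller's perspective; the plugin directory and token are placeholders.

# Sketch: loading a plugin directory, then unloading it.
from pathlib import Path

async def demo_loader() -> None:
    loader = PluginLoader()
    plugin = await loader.load_plugin(Path("plugins/echo"), plugin_token="<token>")
    print(loader.list_loaded_plugins())  # e.g. ['echo']
    await loader.unload_plugin(plugin.plugin_id)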
@@ -21,11 +21,13 @@ logger = get_logger(__name__)

class BudgetEnforcementError(Exception):
"""Custom exception for budget enforcement failures"""

pass

class BudgetExceededError(BudgetEnforcementError):
"""Exception raised when budget would be exceeded"""

def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -34,6 +36,7 @@ class BudgetExceededError(BudgetEnforcementError):

class BudgetWarningError(BudgetEnforcementError):
"""Exception raised when budget warning threshold is reached"""

def __init__(self, message: str, budget: Budget, requested_cost: int):
super().__init__(message)
self.budget = budget
@@ -42,6 +45,7 @@ class BudgetWarningError(BudgetEnforcementError):

class BudgetConcurrencyError(BudgetEnforcementError):
"""Exception raised when budget update fails due to concurrency"""

def __init__(self, message: str, retry_count: int = 0):
super().__init__(message)
self.retry_count = retry_count
@@ -49,6 +53,7 @@ class BudgetConcurrencyError(BudgetEnforcementError):

class BudgetAtomicError(BudgetEnforcementError):
"""Exception raised when atomic budget operation fails"""

def __init__(self, message: str, budget_id: int, requested_amount: int):
super().__init__(message)
self.budget_id = budget_id
@@ -57,84 +62,96 @@ class BudgetAtomicError(BudgetEnforcementError):

class BudgetEnforcementService:
"""Service for enforcing budget limits and tracking usage"""

def __init__(self, db: Session):
self.db = db
self.max_retries = 3
self.retry_delay_base = 0.1  # Base delay in seconds

def atomic_check_and_reserve_budget(
self,
api_key: APIKey,
model_name: str,
estimated_tokens: int,
endpoint: str = None
endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
"""
Atomically check budget compliance and reserve spending

Returns:
Tuple of (is_allowed, error_message, warnings, reserved_budget_ids)
"""
estimated_cost = estimate_request_cost(model_name, estimated_tokens)
budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

if not budgets:
logger.debug(f"No applicable budgets found for API key {api_key.id}")
return True, None, [], []

# Try atomic reservation with retries
for attempt in range(self.max_retries):
try:
return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
return self._attempt_atomic_reservation(
budgets, estimated_cost, api_key.id, attempt
)
except BudgetConcurrencyError as e:
if attempt == self.max_retries - 1:
logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
return False, f"Budget check temporarily unavailable (concurrency limit)", [], []

logger.error(
f"Atomic budget reservation failed after {self.max_retries} attempts: {e}"
)
return (
False,
f"Budget check temporarily unavailable (concurrency limit)",
[],
[],
)

# Exponential backoff with jitter
delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
delay = self.retry_delay_base * (2**attempt) + random.uniform(0, 0.1)
time.sleep(delay)
logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
logger.info(
f"Retrying atomic budget reservation (attempt {attempt + 2})"
)
except Exception as e:
logger.error(f"Unexpected error in atomic budget reservation: {e}")
return False, f"Budget check failed: {str(e)}", [], []

return False, "Budget check failed after maximum retries", [], []
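Aside: the reserve-then-settle flow a caller would follow with the method above. The model name and token estimate are illustrative, and settling the actual cost afterwards is assumed to happen elsewhere in the service.

# Sketch: guarding an LLM call with an atomic budget reservation.
def guarded_call(db, api_key) -> None:
    service = BudgetEnforcementService(db)
    allowed, error, warnings, reserved = service.atomic_check_and_reserve_budget(
        api_key=api_key,
        model_name="gpt-4o-mini",  # illustrative
        estimated_tokens=1500,
    )
    for w in warnings:
        logger.info(w["message"])  # surface soft-limit warnings
    if not allowed:
        ...  # map `error` to an HTTP 402/429 response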
    def _attempt_atomic_reservation(
        self,
        budgets: List[Budget],
        estimated_cost: int,
        api_key_id: int,
        attempt: int
        self, budgets: List[Budget], estimated_cost: int, api_key_id: int, attempt: int
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
        """Attempt to atomically reserve budget across all applicable budgets"""
        warnings = []
        reserved_budget_ids = []

        try:
            # Begin transaction
            self.db.begin()

            for budget in budgets:
                # Lock budget row for update to prevent concurrent modifications
                locked_budget = self.db.query(Budget).filter(
                    Budget.id == budget.id
                ).with_for_update().first()

                locked_budget = (
                    self.db.query(Budget)
                    .filter(Budget.id == budget.id)
                    .with_for_update()
                    .first()
                )

                if not locked_budget:
                    raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)

                    raise BudgetAtomicError(
                        f"Budget {budget.id} not found", budget.id, estimated_cost
                    )

                # Reset budget if expired and auto-renew enabled
                if locked_budget.is_expired() and locked_budget.auto_renew:
                    self._reset_expired_budget(locked_budget)
                    self.db.flush()  # Ensure reset is applied before checking

                # Skip inactive or expired budgets
                if not locked_budget.is_active or locked_budget.is_expired():
                    continue

                # Check if request would exceed budget using atomic operation
                if not self._atomic_can_spend(locked_budget, estimated_cost):
                    error_msg = (
@@ -144,56 +161,74 @@ class BudgetEnforcementService:
                        f"Requested: ${estimated_cost/100:.4f}, "
                        f"Remaining: ${(locked_budget.limit_cents - locked_budget.current_usage_cents)/100:.2f}"
                    )
                    logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
                    logger.warning(
                        f"Budget exceeded for API key {api_key_id}: {error_msg}"
                    )
                    self.db.rollback()
                    return False, error_msg, warnings, []

                # Check warning threshold
                if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
                if (
                    locked_budget.would_exceed_warning(estimated_cost)
                    and not locked_budget.is_warning_sent
                ):
                    warning_msg = (
                        f"Budget '{locked_budget.name}' approaching limit. "
                        f"Usage will be ${(locked_budget.current_usage_cents + estimated_cost)/100:.2f} "
                        f"of ${locked_budget.limit_cents/100:.2f} "
                        f"({((locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100):.1f}%)"
                    )
                    warnings.append({
                        "type": "budget_warning",
                        "budget_id": locked_budget.id,
                        "budget_name": locked_budget.name,
                        "message": warning_msg,
                        "current_usage_cents": locked_budget.current_usage_cents + estimated_cost,
                        "limit_cents": locked_budget.limit_cents,
                        "usage_percentage": (locked_budget.current_usage_cents + estimated_cost) / locked_budget.limit_cents * 100
                    })
                    logger.info(f"Budget warning for API key {api_key_id}: {warning_msg}")

                    warnings.append(
                        {
                            "type": "budget_warning",
                            "budget_id": locked_budget.id,
                            "budget_name": locked_budget.name,
                            "message": warning_msg,
                            "current_usage_cents": locked_budget.current_usage_cents
                            + estimated_cost,
                            "limit_cents": locked_budget.limit_cents,
                            "usage_percentage": (
                                locked_budget.current_usage_cents + estimated_cost
                            )
                            / locked_budget.limit_cents
                            * 100,
                        }
                    )
                    logger.info(
                        f"Budget warning for API key {api_key_id}: {warning_msg}"
                    )

                # Reserve the budget (temporarily add estimated cost)
                self._atomic_reserve_usage(locked_budget, estimated_cost)
                reserved_budget_ids.append(locked_budget.id)

            # Commit the reservation
            self.db.commit()
            logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
            logger.debug(
                f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}"
            )
            return True, None, warnings, reserved_budget_ids

        except IntegrityError as e:
            self.db.rollback()
            raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt)
            raise BudgetConcurrencyError(
                f"Database integrity error during budget reservation: {e}", attempt
            )
        except Exception as e:
            self.db.rollback()
            logger.error(f"Error in atomic budget reservation: {e}")
            raise

    def _atomic_can_spend(self, budget: Budget, amount_cents: int) -> bool:
        """Atomically check if budget can accommodate spending"""
        if not budget.is_active or not budget.is_in_period():
            return False

        if not budget.enforce_hard_limit:
            return True

        return (budget.current_usage_cents + amount_cents) <= budget.limit_cents

    def _atomic_reserve_usage(self, budget: Budget, amount_cents: int):
        """Atomically reserve usage in budget (add to current usage)"""
        # Use database-level atomic update
@@ -203,26 +238,37 @@ class BudgetEnforcementService:
            .values(
                current_usage_cents=Budget.current_usage_cents + amount_cents,
                updated_at=datetime.utcnow(),
                is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
                is_exceeded=Budget.current_usage_cents + amount_cents
                >= Budget.limit_cents,
                is_warning_sent=(
                    Budget.is_warning_sent |
                    ((Budget.warning_threshold_cents.isnot(None)) &
                     (Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
                )
                    Budget.is_warning_sent
                    | (
                        (Budget.warning_threshold_cents.isnot(None))
                        & (
                            Budget.current_usage_cents + amount_cents
                            >= Budget.warning_threshold_cents
                        )
                    )
                ),
            )
        )

        if result.rowcount != 1:
            raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)

            raise BudgetAtomicError(
                f"Failed to update budget {budget.id}", budget.id, amount_cents
            )

        # Update the in-memory object to reflect changes
        budget.current_usage_cents += amount_cents
        budget.updated_at = datetime.utcnow()
        if budget.current_usage_cents >= budget.limit_cents:
            budget.is_exceeded = True
        if budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
        if (
            budget.warning_threshold_cents
            and budget.current_usage_cents >= budget.warning_threshold_cents
        ):
            budget.is_warning_sent = True

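# --- Editor's illustration (not part of the diff above) ---
# The core of _atomic_reserve_usage is a single database-level UPDATE, so the
# read-modify-write happens inside the database rather than in Python. A minimal
# sketch with SQLAlchemy; `Budget` is the ORM model from this file, and the
# session wiring is assumed.
from sqlalchemy import update


def reserve_usage(session, budget_id: int, amount_cents: int) -> None:
    result = session.execute(
        update(Budget)
        .where(Budget.id == budget_id)
        .values(current_usage_cents=Budget.current_usage_cents + amount_cents)
    )
    # rowcount != 1 means no row matched the WHERE clause: fail loudly rather
    # than silently under-counting usage.
    if result.rowcount != 1:
        raise RuntimeError(f"Failed to update budget {budget_id}")
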
    def atomic_finalize_usage(
        self,
        reserved_budget_ids: List[int],
@@ -230,11 +276,11 @@ class BudgetEnforcementService:
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: str = None
        endpoint: str = None,
    ) -> List[Budget]:
        """
        Finalize actual usage and adjust reservations

        Args:
            reserved_budget_ids: Budget IDs that had usage reserved
            api_key: API key that made the request
@@ -242,101 +288,110 @@ class BudgetEnforcementService:
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated
        """
        if not reserved_budget_ids:
            return []

        try:
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
            actual_cost = CostCalculator.calculate_cost_cents(
                model_name, input_tokens, output_tokens
            )
            updated_budgets = []

            # Begin transaction for finalization
            self.db.begin()

            for budget_id in reserved_budget_ids:
                # Lock budget for update
                budget = self.db.query(Budget).filter(
                    Budget.id == budget_id
                ).with_for_update().first()

                budget = (
                    self.db.query(Budget)
                    .filter(Budget.id == budget_id)
                    .with_for_update()
                    .first()
                )

                if not budget:
                    logger.warning(f"Budget {budget_id} not found during finalization")
                    continue

                if budget.is_active and budget.is_in_period():
                    # Calculate adjustment (actual cost - estimated cost already reserved)
                    # Note: We don't know the exact estimated cost that was reserved
                    # So we'll just set to actual cost (this is safe as we already reserved)
                    self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
                    self._atomic_set_actual_usage(
                        budget, actual_cost, input_tokens, output_tokens
                    )
                    updated_budgets.append(budget)

                    logger.debug(
                        f"Finalized usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )

            # Commit finalization
            self.db.commit()
            return updated_budgets

        except Exception as e:
            logger.error(f"Error finalizing budget usage: {e}")
            self.db.rollback()
            return []

    def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):

    def _atomic_set_actual_usage(
        self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int
    ):
        """Set the actual usage cost (replacing any reservation)"""
        # For simplicity, we'll just ensure the current usage reflects actual cost
        # In a more sophisticated system, you might track reservations separately
        # For now, the reservation system ensures we don't exceed limits
        # and the actual cost will be very close to estimated cost
        pass  # The reservation already added the estimated cost, actual cost adjustment is minimal

    def check_budget_compliance(
        self,
        api_key: APIKey,
        model_name: str,
        estimated_tokens: int,
        endpoint: str = None
        endpoint: str = None,
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
        """
        Check if a request complies with budget limits

        Args:
            api_key: API key making the request
            model_name: Model being used
            estimated_tokens: Estimated token usage
            endpoint: API endpoint being accessed

        Returns:
            Tuple of (is_allowed, error_message, warnings)
        """
        try:
            # Calculate estimated cost
            estimated_cost = estimate_request_cost(model_name, estimated_tokens)

            # Get applicable budgets
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

            if not budgets:
                logger.debug(f"No applicable budgets found for API key {api_key.id}")
                return True, None, []

            warnings = []

            # Check each budget
            for budget in budgets:
                # Reset budget if period expired and auto-renew is enabled
                if budget.is_expired() and budget.auto_renew:
                    self._reset_expired_budget(budget)

                # Skip inactive or expired budgets
                if not budget.is_active or budget.is_expired():
                    continue

                # Check if request would exceed budget
                if not budget.can_spend(estimated_cost):
                    error_msg = (
@@ -346,145 +401,160 @@ class BudgetEnforcementService:
                        f"Requested: ${estimated_cost/100:.4f}, "
                        f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
                    )
                    logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
                    logger.warning(
                        f"Budget exceeded for API key {api_key.id}: {error_msg}"
                    )
                    return False, error_msg, warnings

                # Check if request would trigger warning
                if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
                if (
                    budget.would_exceed_warning(estimated_cost)
                    and not budget.is_warning_sent
                ):
                    warning_msg = (
                        f"Budget '{budget.name}' approaching limit. "
                        f"Usage will be ${(budget.current_usage_cents + estimated_cost)/100:.2f} "
                        f"of ${budget.limit_cents/100:.2f} "
                        f"({((budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100):.1f}%)"
                    )
                    warnings.append({
                        "type": "budget_warning",
                        "budget_id": budget.id,
                        "budget_name": budget.name,
                        "message": warning_msg,
                        "current_usage_cents": budget.current_usage_cents + estimated_cost,
                        "limit_cents": budget.limit_cents,
                        "usage_percentage": (budget.current_usage_cents + estimated_cost) / budget.limit_cents * 100
                    })
                    logger.info(f"Budget warning for API key {api_key.id}: {warning_msg}")

                    warnings.append(
                        {
                            "type": "budget_warning",
                            "budget_id": budget.id,
                            "budget_name": budget.name,
                            "message": warning_msg,
                            "current_usage_cents": budget.current_usage_cents
                            + estimated_cost,
                            "limit_cents": budget.limit_cents,
                            "usage_percentage": (
                                budget.current_usage_cents + estimated_cost
                            )
                            / budget.limit_cents
                            * 100,
                        }
                    )
                    logger.info(
                        f"Budget warning for API key {api_key.id}: {warning_msg}"
                    )

            return True, None, warnings

        except Exception as e:
            logger.error(f"Error checking budget compliance: {e}")
            # Allow request on error to avoid blocking legitimate usage
            return True, None, []

    def record_usage(
        self,
        api_key: APIKey,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: str = None
        endpoint: str = None,
    ) -> List[Budget]:
        """
        Record actual usage against applicable budgets

        Args:
            api_key: API key that made the request
            model_name: Model that was used
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated
        """
        try:
            # Calculate actual cost
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)

            actual_cost = CostCalculator.calculate_cost_cents(
                model_name, input_tokens, output_tokens
            )

            # Get applicable budgets
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

            updated_budgets = []

            for budget in budgets:
                if budget.is_active and budget.is_in_period():
                    # Add usage to budget
                    budget.add_usage(actual_cost)
                    updated_budgets.append(budget)

                    logger.debug(
                        f"Recorded usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )

            # Commit changes
            self.db.commit()

            return updated_budgets

        except Exception as e:
            logger.error(f"Error recording budget usage: {e}")
            self.db.rollback()
            return []

    def _get_applicable_budgets(
        self,
        api_key: APIKey,
        model_name: str = None,
        endpoint: str = None
        self, api_key: APIKey, model_name: str = None, endpoint: str = None
    ) -> List[Budget]:
        """Get budgets that apply to the given request"""

        # Build query conditions
        conditions = [
            Budget.is_active == True,
            or_(
                and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)),  # User budget
                Budget.api_key_id == api_key.id  # API key specific budget
            )
                and_(
                    Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)
                ),  # User budget
                Budget.api_key_id == api_key.id,  # API key specific budget
            ),
        ]

        # Query budgets
        query = self.db.query(Budget).filter(and_(*conditions))
        budgets = query.all()

        # Filter budgets based on allowed models/endpoints
        applicable_budgets = []

        for budget in budgets:
            # Check model restrictions
            if model_name and budget.allowed_models:
                if model_name not in budget.allowed_models:
                    continue

            # Check endpoint restrictions
            if endpoint and budget.allowed_endpoints:
                if endpoint not in budget.allowed_endpoints:
                    continue

            applicable_budgets.append(budget)

        return applicable_budgets

    def _reset_expired_budget(self, budget: Budget):
        """Reset an expired budget for the next period"""
        try:
            budget.reset_period()
            self.db.commit()

            logger.info(
                f"Reset expired budget {budget.id} for new period: "
                f"{budget.period_start} to {budget.period_end}"
            )

        except Exception as e:
            logger.error(f"Error resetting expired budget {budget.id}: {e}")
            self.db.rollback()

    def get_budget_status(self, api_key: APIKey) -> Dict[str, Any]:
        """Get comprehensive budget status for an API key"""
        try:
            budgets = self._get_applicable_budgets(api_key)

            status = {
                "total_budgets": len(budgets),
                "active_budgets": 0,
@@ -492,44 +562,53 @@ class BudgetEnforcementService:
                "warning_budgets": 0,
                "total_limit_cents": 0,
                "total_usage_cents": 0,
                "budgets": []
                "budgets": [],
            }

            for budget in budgets:
                if not budget.is_active:
                    continue

                budget_info = budget.to_dict()
                budget_info.update({
                    "is_expired": budget.is_expired(),
                    "days_remaining": budget.get_period_days_remaining(),
                    "daily_burn_rate": budget.get_daily_burn_rate(),
                    "projected_spend": budget.get_projected_spend()
                })

                budget_info.update(
                    {
                        "is_expired": budget.is_expired(),
                        "days_remaining": budget.get_period_days_remaining(),
                        "daily_burn_rate": budget.get_daily_burn_rate(),
                        "projected_spend": budget.get_projected_spend(),
                    }
                )

                status["budgets"].append(budget_info)
                status["active_budgets"] += 1
                status["total_limit_cents"] += budget.limit_cents
                status["total_usage_cents"] += budget.current_usage_cents

                if budget.is_exceeded:
                    status["exceeded_budgets"] += 1
                elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
                elif (
                    budget.warning_threshold_cents
                    and budget.current_usage_cents >= budget.warning_threshold_cents
                ):
                    status["warning_budgets"] += 1

            # Calculate overall percentages
            if status["total_limit_cents"] > 0:
                status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
                status["overall_usage_percentage"] = (
                    status["total_usage_cents"] / status["total_limit_cents"]
                ) * 100
            else:
                status["overall_usage_percentage"] = 0

            status["total_limit_dollars"] = status["total_limit_cents"] / 100
            status["total_usage_dollars"] = status["total_usage_cents"] / 100
            status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
            status["total_remaining_cents"] = max(
                0, status["total_limit_cents"] - status["total_usage_cents"]
            )
            status["total_remaining_dollars"] = status["total_remaining_cents"] / 100

            return status

        except Exception as e:
            logger.error(f"Error getting budget status: {e}")
            return {
@@ -538,14 +617,11 @@ class BudgetEnforcementService:
                "active_budgets": 0,
                "exceeded_budgets": 0,
                "warning_budgets": 0,
                "budgets": []
                "budgets": [],
            }

    def create_default_user_budget(
        self,
        user_id: int,
        limit_dollars: float = 10.0,
        period_type: str = "monthly"
        self, user_id: int, limit_dollars: float = 10.0, period_type: str = "monthly"
    ) -> Budget:
        """Create a default budget for a new user"""
        try:
@@ -553,60 +629,69 @@ class BudgetEnforcementService:
                budget = Budget.create_monthly_budget(
                    user_id=user_id,
                    name="Default Monthly Budget",
                    limit_dollars=limit_dollars
                    limit_dollars=limit_dollars,
                )
            else:
                budget = Budget.create_daily_budget(
                    user_id=user_id,
                    name="Default Daily Budget",
                    limit_dollars=limit_dollars
                    limit_dollars=limit_dollars,
                )

            self.db.add(budget)
            self.db.commit()

            logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")

            logger.info(
                f"Created default budget for user {user_id}: ${limit_dollars} {period_type}"
            )

            return budget

        except Exception as e:
            logger.error(f"Error creating default budget: {e}")
            self.db.rollback()
            raise

    def check_and_reset_expired_budgets(self):
        """Background task to check and reset expired budgets"""
        try:
            expired_budgets = self.db.query(Budget).filter(
                and_(
                    Budget.is_active == True,
                    Budget.auto_renew == True,
                    Budget.period_end < datetime.utcnow()
            expired_budgets = (
                self.db.query(Budget)
                .filter(
                    and_(
                        Budget.is_active == True,
                        Budget.auto_renew == True,
                        Budget.period_end < datetime.utcnow(),
                    )
                )
            ).all()

                .all()
            )

            for budget in expired_budgets:
                self._reset_expired_budget(budget)

            logger.info(f"Reset {len(expired_budgets)} expired budgets")

        except Exception as e:
            logger.error(f"Error in budget reset task: {e}")


# Convenience functions


# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
    db: Session,
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
    endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
    """DEPRECATED: Convenience function to check budget compliance (race conditions possible)"""
    service = BudgetEnforcementService(db)
    return service.check_budget_compliance(api_key, model_name, estimated_tokens, endpoint)
    return service.check_budget_compliance(
        api_key, model_name, estimated_tokens, endpoint
    )


def record_request_usage(
@@ -615,11 +700,13 @@ def record_request_usage(
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
    endpoint: str = None,
) -> List[Budget]:
    """DEPRECATED: Convenience function to record actual usage (race conditions possible)"""
    service = BudgetEnforcementService(db)
    return service.record_usage(api_key, model_name, input_tokens, output_tokens, endpoint)
    return service.record_usage(
        api_key, model_name, input_tokens, output_tokens, endpoint
    )


# ATOMIC VERSIONS: Race-condition-free budget enforcement
@@ -628,11 +715,13 @@ def atomic_check_and_reserve_budget(
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
    endpoint: str = None,
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
    """Atomic convenience function to check budget compliance and reserve spending"""
    service = BudgetEnforcementService(db)
    return service.atomic_check_and_reserve_budget(api_key, model_name, estimated_tokens, endpoint)
    return service.atomic_check_and_reserve_budget(
        api_key, model_name, estimated_tokens, endpoint
    )


def atomic_finalize_usage(
@@ -642,8 +731,10 @@ def atomic_finalize_usage(
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
    endpoint: str = None,
) -> List[Budget]:
    """Atomic convenience function to finalize actual usage after request completion"""
    service = BudgetEnforcementService(db)
    return service.atomic_finalize_usage(reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint)
    return service.atomic_finalize_usage(
        reserved_budget_ids, api_key, model_name, input_tokens, output_tokens, endpoint
    )

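# --- Editor's illustration (not part of the diff above) ---
# Intended call pattern for the atomic convenience functions: reserve budget
# before the model call, finalize with actual token counts afterwards.
# `db`, `api_key`, and `call_model` are hypothetical placeholders, and the
# exact positional order of the leading arguments is assumed from the bodies above.
allowed, error, warnings, reserved_ids = atomic_check_and_reserve_budget(
    db, api_key, "gpt-4", estimated_tokens=1500, endpoint="/v1/chat/completions"
)
if not allowed:
    raise RuntimeError(error)
response = call_model()  # the actual LLM request happens here
atomic_finalize_usage(
    db,
    reserved_ids,
    api_key,
    "gpt-4",
    input_tokens=response.usage.prompt_tokens,
    output_tokens=response.usage.completion_tokens,
    endpoint="/v1/chat/completions",
)
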
@@ -23,17 +23,21 @@ logger = logging.getLogger(__name__)

class CachedAPIKeyService:
    """Core cache-backed API key caching service for performance optimization"""

    def __init__(self):
        self.cache_ttl = 300  # 5 minutes cache TTL
        self.verification_cache_ttl = 3600  # 1 hour for verification results
        logger.info("Cached API key service initialized with core cache backend")

    async def close(self):
        """Close method for compatibility - core cache handles its own lifecycle"""
        logger.info("Cached API key service close called - core cache handles lifecycle")

    async def get_cached_api_key(self, key_prefix: str, db: AsyncSession) -> Optional[Dict[str, Any]]:
        logger.info(
            "Cached API key service close called - core cache handles lifecycle"
        )

    async def get_cached_api_key(
        self, key_prefix: str, db: AsyncSession
    ) -> Optional[Dict[str, Any]]:
        """
        Get API key data from cache or database
        Returns: Dictionary with api_key, user, and api_key_id
@@ -43,59 +47,54 @@ class CachedAPIKeyService:
            cached_data = await core_cache.get_cached_api_key(key_prefix)
            if cached_data:
                logger.debug(f"API key cache hit for prefix: {key_prefix}")

                # Recreate APIKey object from cached data
                api_key_data = cached_data.get("api_key_data", {})
                user_data = cached_data.get("user_data", {})

                # Create APIKey instance
                api_key = APIKey(**api_key_data)

                # Create User instance

                # Create User instance
                user = User(**user_data)

                return {
                    "api_key": api_key,
                    "user": user,
                    "api_key_id": api_key_data.get("id")
                    "api_key_id": api_key_data.get("id"),
                }

            logger.debug(f"API key cache miss for prefix: {key_prefix}, fetching from database")

            logger.debug(
                f"API key cache miss for prefix: {key_prefix}, fetching from database"
            )

            # Cache miss - fetch from database with optimized query
            stmt = (
                select(APIKey, User)
                .join(User, APIKey.user_id == User.id)
                .options(
                    joinedload(APIKey.user),
                    joinedload(User.api_keys)
                )
                .options(joinedload(APIKey.user), joinedload(User.api_keys))
                .where(APIKey.key_prefix == key_prefix)
                .where(APIKey.is_active == True)
            )

            result = await db.execute(stmt)
            api_key_user = result.first()

            if not api_key_user:
                logger.debug(f"API key not found in database for prefix: {key_prefix}")
                return None

            api_key, user = api_key_user

            # Cache for future requests
            await self._cache_api_key_data(key_prefix, api_key, user)

            return {
                "api_key": api_key,
                "user": user,
                "api_key_id": api_key.id
            }

            return {"api_key": api_key, "user": user, "api_key_id": api_key.id}

        except Exception as e:
            logger.error(f"Error retrieving API key for prefix {key_prefix}: {e}")
            return None

    async def _cache_api_key_data(self, key_prefix: str, api_key: APIKey, user: User):
        """Cache API key and user data"""
        try:
@@ -118,17 +117,25 @@ class CachedAPIKeyService:
                    "allowed_ips": api_key.allowed_ips,
                    "description": api_key.description,
                    "tags": api_key.tags,
                    "created_at": api_key.created_at.isoformat() if api_key.created_at else None,
                    "updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
                    "last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
                    "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
                    "created_at": api_key.created_at.isoformat()
                    if api_key.created_at
                    else None,
                    "updated_at": api_key.updated_at.isoformat()
                    if api_key.updated_at
                    else None,
                    "last_used_at": api_key.last_used_at.isoformat()
                    if api_key.last_used_at
                    else None,
                    "expires_at": api_key.expires_at.isoformat()
                    if api_key.expires_at
                    else None,
                    "total_requests": api_key.total_requests,
                    "total_tokens": api_key.total_tokens,
                    "total_cost": api_key.total_cost,
                    "is_unlimited": api_key.is_unlimited,
                    "budget_limit_cents": api_key.budget_limit_cents,
                    "budget_type": api_key.budget_type,
                    "allowed_chatbots": api_key.allowed_chatbots
                    "allowed_chatbots": api_key.allowed_chatbots,
                },
                "user_data": {
                    "id": user.id,
@@ -137,20 +144,28 @@ class CachedAPIKeyService:
                    "is_active": user.is_active,
                    "is_superuser": user.is_superuser,
                    "role": user.role,
                    "created_at": user.created_at.isoformat() if user.created_at else None,
                    "updated_at": user.updated_at.isoformat() if user.updated_at else None,
                    "last_login": user.last_login.isoformat() if user.last_login else None
                    "created_at": user.created_at.isoformat()
                    if user.created_at
                    else None,
                    "updated_at": user.updated_at.isoformat()
                    if user.updated_at
                    else None,
                    "last_login": user.last_login.isoformat()
                    if user.last_login
                    else None,
                },
                "cached_at": datetime.utcnow().isoformat()
                "cached_at": datetime.utcnow().isoformat(),
            }

            await core_cache.cache_api_key(key_prefix, cache_data, self.cache_ttl)
            logger.debug(f"Cached API key data for prefix: {key_prefix}")

        except Exception as e:
            logger.error(f"Error caching API key data for prefix {key_prefix}: {e}")

    async def verify_api_key_cached(self, api_key: str, key_prefix: str) -> Optional[bool]:

    async def verify_api_key_cached(
        self, api_key: str, key_prefix: str
    ) -> Optional[bool]:
        """
        Verify API key using cached hash to avoid expensive bcrypt operations
        Returns: True if verified, False if invalid, None if not cached
@@ -158,73 +173,88 @@ class CachedAPIKeyService:
        try:
            # Check verification cache
            cached_verification = await core_cache.get_cached_verification(key_prefix)

            if cached_verification:
                # Check if cache is still valid (within TTL)
                cached_timestamp = datetime.fromisoformat(cached_verification["timestamp"])
                if datetime.utcnow() - cached_timestamp < timedelta(seconds=self.verification_cache_ttl):
                    logger.debug(f"API key verification cache hit for prefix: {key_prefix}")
                cached_timestamp = datetime.fromisoformat(
                    cached_verification["timestamp"]
                )
                if datetime.utcnow() - cached_timestamp < timedelta(
                    seconds=self.verification_cache_ttl
                ):
                    logger.debug(
                        f"API key verification cache hit for prefix: {key_prefix}"
                    )
                    return cached_verification.get("is_valid", False)

            return None  # Not cached or expired

        except Exception as e:
            logger.error(f"Error checking verification cache for prefix {key_prefix}: {e}")
            logger.error(
                f"Error checking verification cache for prefix {key_prefix}: {e}"
            )
            return None

    async def cache_verification_result(self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool):

    async def cache_verification_result(
        self, api_key: str, key_prefix: str, key_hash: str, is_valid: bool
    ):
        """Cache API key verification result to avoid expensive bcrypt operations"""
        try:
            await core_cache.cache_verification_result(api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl)
            await core_cache.cache_verification_result(
                api_key, key_prefix, key_hash, is_valid, self.verification_cache_ttl
            )
            logger.debug(f"Cached verification result for prefix: {key_prefix}")

        except Exception as e:
            logger.error(f"Error caching verification result for prefix {key_prefix}: {e}")

            logger.error(
                f"Error caching verification result for prefix {key_prefix}: {e}"
            )

    async def invalidate_api_key_cache(self, key_prefix: str):
        """Invalidate cached API key data"""
        try:
            await core_cache.invalidate_api_key(key_prefix)

            # Also invalidate verification cache
            verification_keys = await core_cache.clear_pattern(f"verify:{key_prefix}*", prefix="auth")

            verification_keys = await core_cache.clear_pattern(
                f"verify:{key_prefix}*", prefix="auth"
            )

            logger.debug(f"Invalidated cache for API key prefix: {key_prefix}")

        except Exception as e:
            logger.error(f"Error invalidating cache for prefix {key_prefix}: {e}")

    async def update_last_used(self, api_key_id: int, db: AsyncSession):
        """Update last used timestamp asynchronously for performance"""
        try:
            # Use core cache to track update requests to avoid database spam
            cache_key = f"last_used_update:{api_key_id}"

            # Check if we recently updated (within 5 minutes)
            last_update = await core_cache.get(cache_key, prefix="perf")
            if last_update:
                return  # Skip update if recent

            # Update database
            stmt = (
                select(APIKey)
                .where(APIKey.id == api_key_id)
            )
            stmt = select(APIKey).where(APIKey.id == api_key_id)
            result = await db.execute(stmt)
            api_key = result.scalar_one_or_none()

            if api_key:
                api_key.last_used_at = datetime.utcnow()
                await db.commit()

                # Cache that we updated to prevent spam
                await core_cache.set(cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf")

                await core_cache.set(
                    cache_key, datetime.utcnow().isoformat(), ttl=300, prefix="perf"
                )

                logger.debug(f"Updated last_used_at for API key {api_key_id}")

        except Exception as e:
            logger.error(f"Error updating last_used for API key {api_key_id}: {e}")

    async def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache performance statistics"""
        try:
@@ -232,16 +262,16 @@ class CachedAPIKeyService:
            return {
                "cache_backend": "core_cache",
                "cache_enabled": core_stats.get("enabled", False),
                "cache_stats": core_stats
                "cache_stats": core_stats,
            }
        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {
                "cache_backend": "core_cache",
                "cache_backend": "core_cache",
                "cache_enabled": False,
                "error": str(e)
                "error": str(e),
            }


# Global instance
cached_api_key_service = CachedAPIKeyService()
cached_api_key_service = CachedAPIKeyService()

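# --- Editor's illustration (not part of the diff above) ---
# The cache-aside pattern CachedAPIKeyService implements, reduced to its core:
# try the cache, fall back to the slow lookup on a miss, then populate the
# cache for the next caller. `cache` and `load_from_db` are hypothetical
# stand-ins for core_cache and the SQLAlchemy query above.
async def get_with_cache_aside(cache, load_from_db, key: str, ttl: int = 300):
    cached = await cache.get(key)
    if cached is not None:  # Hit: skip the database entirely
        return cached
    value = await load_from_db(key)  # Miss: one database round-trip
    if value is not None:
        await cache.set(key, value, ttl=ttl)  # Warm the cache for later requests
    return value
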
@@ -25,6 +25,7 @@ logger = get_logger(__name__)
@dataclass
class ConfigVersion:
    """Configuration version metadata"""

    version: str
    timestamp: datetime
    checksum: str
@@ -36,6 +37,7 @@ class ConfigVersion:
@dataclass
class ConfigSchema:
    """Configuration schema definition"""

    name: str
    required_fields: List[str]
    optional_fields: List[str]
@@ -46,6 +48,7 @@ class ConfigSchema:
@dataclass
class ConfigStats:
    """Configuration manager statistics"""

    total_configs: int
    active_watchers: int
    config_versions: int
@@ -57,42 +60,42 @@ class ConfigStats:

class ConfigWatcher(FileSystemEventHandler):
    """File system watcher for configuration changes"""

    def __init__(self, config_manager):
        self.config_manager = config_manager
        self.debounce_time = 1.0  # 1 second debounce
        self.last_modified = {}

    def on_modified(self, event):
        if event.is_directory:
            return

        path = event.src_path
        current_time = time.time()

        # Debounce rapid file changes
        if path in self.last_modified:
            if current_time - self.last_modified[path] < self.debounce_time:
                return

        self.last_modified[path] = current_time

        # Trigger hot reload for config files
        if path.endswith(('.json', '.yaml', '.yml', '.toml')):
        if path.endswith((".json", ".yaml", ".yml", ".toml")):
            # Schedule coroutine in a thread-safe way
            try:
                loop = asyncio.get_running_loop()
                loop.call_soon_threadsafe(
                    lambda: asyncio.create_task(self.config_manager.reload_config_file(path))
                    lambda: asyncio.create_task(
                        self.config_manager.reload_config_file(path)
                    )
                )
            except RuntimeError:
                # No running loop, schedule for later
                threading.Thread(
                    target=self._schedule_reload,
                    args=(path,),
                    daemon=True
                    target=self._schedule_reload, args=(path,), daemon=True
                ).start()

    def _schedule_reload(self, path: str):
        """Schedule reload in a new thread if no event loop is available"""
        try:
@@ -103,14 +106,14 @@ class ConfigWatcher(FileSystemEventHandler):

class ConfigManager:
    """Core configuration management system"""

    def __init__(self):
        self.configs: Dict[str, Dict[str, Any]] = {}
        self.schemas: Dict[str, ConfigSchema] = {}
        self.versions: Dict[str, List[ConfigVersion]] = {}
        self.watchers: Dict[str, Observer] = {}
        self.config_paths: Dict[str, Path] = {}
        self.environment = os.getenv('ENVIRONMENT', 'development')
        self.environment = os.getenv("ENVIRONMENT", "development")
        self.start_time = time.time()
        self.stats = ConfigStats(
            total_configs=0,
@@ -119,32 +122,32 @@ class ConfigManager:
            hot_reloads_performed=0,
            validation_errors=0,
            last_reload_time=datetime.now(),
            uptime=0
            uptime=0,
        )

        # Base configuration directories
        self.config_base_dir = Path("configs")
        self.config_base_dir.mkdir(exist_ok=True)

        # Environment-specific directory
        self.env_config_dir = self.config_base_dir / self.environment
        self.env_config_dir.mkdir(exist_ok=True)

        logger.info(f"ConfigManager initialized for environment: {self.environment}")

    def register_schema(self, name: str, schema: ConfigSchema):
        """Register a configuration schema for validation"""
        self.schemas[name] = schema
        logger.info(f"Registered configuration schema: {name}")

    def validate_config(self, name: str, config_data: Dict[str, Any]) -> bool:
        """Validate configuration against registered schema"""
        if name not in self.schemas:
            logger.debug(f"No schema registered for config: {name}")
            return True

        schema = self.schemas[name]

        try:
            # Check required fields
            for field in schema.required_fields:
@@ -152,189 +155,202 @@ class ConfigManager:
                    logger.error(f"Missing required field '{field}' in config '{name}'")
                    self.stats.validation_errors += 1
                    return False

            # Validate field types
            for field, expected_type in schema.field_types.items():
                if field in config_data:
                    if not isinstance(config_data[field], expected_type):
                        logger.error(f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}")
                        logger.error(
                            f"Invalid type for field '{field}' in config '{name}'. Expected {expected_type.__name__}"
                        )
                        self.stats.validation_errors += 1
                        return False

            # Run custom validators
            for field, validator in schema.validators.items():
                if field in config_data:
                    if not validator(config_data[field]):
                        logger.error(f"Validation failed for field '{field}' in config '{name}'")
                        logger.error(
                            f"Validation failed for field '{field}' in config '{name}'"
                        )
                        self.stats.validation_errors += 1
                        return False

            logger.debug(f"Configuration '{name}' passed validation")
            return True

        except Exception as e:
            logger.error(f"Error validating config '{name}': {str(e)}")
            self.stats.validation_errors += 1
            return False

    def _calculate_checksum(self, data: Dict[str, Any]) -> str:
        """Calculate checksum for configuration data"""
        json_str = json.dumps(data, sort_keys=True)
        return hashlib.sha256(json_str.encode()).hexdigest()

    def _create_version(self, name: str, config_data: Dict[str, Any], description: str = "Auto-save") -> ConfigVersion:

    def _create_version(
        self, name: str, config_data: Dict[str, Any], description: str = "Auto-save"
    ) -> ConfigVersion:
        """Create a new configuration version"""
        version_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        checksum = self._calculate_checksum(config_data)

        version = ConfigVersion(
            version=version_id,
            timestamp=datetime.now(),
            checksum=checksum,
            author=os.getenv('USER', 'system'),
            author=os.getenv("USER", "system"),
            description=description,
            config_data=config_data.copy()
            config_data=config_data.copy(),
        )

        if name not in self.versions:
            self.versions[name] = []

        self.versions[name].append(version)

        # Keep only last 10 versions
        if len(self.versions[name]) > 10:
            self.versions[name] = self.versions[name][-10:]

        self.stats.config_versions += 1
        logger.debug(f"Created version {version_id} for config '{name}'")
        return version

    async def set_config(self, name: str, config_data: Dict[str, Any],
                         description: str = "Manual update") -> bool:

    async def set_config(
        self, name: str, config_data: Dict[str, Any], description: str = "Manual update"
    ) -> bool:
        """Set configuration with validation and versioning"""
        try:
            # Validate configuration
            if not self.validate_config(name, config_data):
                return False

            # Create version before updating
            self._create_version(name, config_data, description)

            # Store configuration
            self.configs[name] = config_data.copy()
            self.stats.total_configs = len(self.configs)

            # Save to file
            await self._save_config_to_file(name, config_data)

            logger.info(f"Configuration '{name}' updated successfully")
            return True

        except Exception as e:
            logger.error(f"Error setting config '{name}': {str(e)}")
            return False

    async def get_config(self, name: str, default: Any = None) -> Any:
        """Get configuration value"""
        if name in self.configs:
            return self.configs[name]

        # Try to load from file if not in memory
        config_data = await self._load_config_from_file(name)
        if config_data is not None:
            self.configs[name] = config_data
            return config_data

        return default

    async def get_config_value(self, config_name: str, key: str, default: Any = None) -> Any:

    async def get_config_value(
        self, config_name: str, key: str, default: Any = None
    ) -> Any:
        """Get specific value from configuration"""
        config = await self.get_config(config_name)
        if config is None:
            return default

        keys = key.split('.')

        keys = key.split(".")
        value = config

        try:
            for k in keys:
                value = value[k]
            return value
        except (KeyError, TypeError):
            return default

    async def _save_config_to_file(self, name: str, config_data: Dict[str, Any]):
        """Save configuration to file"""
        file_path = self.env_config_dir / f"{name}.json"

        try:
            # Save as regular JSON
            with open(file_path, 'w') as f:
            with open(file_path, "w") as f:
                json.dump(config_data, f, indent=2)
            logger.debug(f"Saved config '{name}' to {file_path}")

            self.config_paths[name] = file_path

        except Exception as e:
            logger.error(f"Error saving config '{name}' to file: {str(e)}")
            raise

    async def _load_config_from_file(self, name: str) -> Optional[Dict[str, Any]]:
        """Load configuration from file"""
        file_path = self.env_config_dir / f"{name}.json"

        if not file_path.exists():
            return None

        try:
            # Load regular JSON
            with open(file_path, 'r') as f:
            with open(file_path, "r") as f:
                return json.load(f)

        except Exception as e:
            logger.error(f"Error loading config '{name}' from file: {str(e)}")
            return None

    async def reload_config_file(self, file_path: str):
        """Hot reload configuration from file change"""
        try:
            path = Path(file_path)
            config_name = path.stem

            # Load updated configuration
            if path.suffix == '.json':
                with open(path, 'r') as f:
            if path.suffix == ".json":
                with open(path, "r") as f:
                    new_config = json.load(f)
            elif path.suffix in ['.yaml', '.yml']:
                with open(path, 'r') as f:
            elif path.suffix in [".yaml", ".yml"]:
                with open(path, "r") as f:
                    new_config = yaml.safe_load(f)
            else:
                logger.warning(f"Unsupported config file format: {path.suffix}")
                return

            # Validate and update
            if self.validate_config(config_name, new_config):
                self.configs[config_name] = new_config
                self.stats.hot_reloads_performed += 1
                self.stats.last_reload_time = datetime.now()
                logger.info(f"Hot reloaded configuration '{config_name}' from {file_path}")
                logger.info(
                    f"Hot reloaded configuration '{config_name}' from {file_path}"
                )
            else:
                logger.error(f"Failed to hot reload '{config_name}' - validation failed")

                logger.error(
                    f"Failed to hot reload '{config_name}' - validation failed"
                )

        except Exception as e:
            logger.error(f"Error hot reloading config from {file_path}: {str(e)}")

    def get_stats(self) -> Dict[str, Any]:
        """Get configuration management statistics"""
        self.stats.uptime = time.time() - self.start_time
        return asdict(self.stats)

    async def cleanup(self):
        """Cleanup resources"""
        for watcher in self.watchers.values():
            watcher.stop()
            watcher.join()

        self.watchers.clear()
        logger.info("Configuration management cleanup completed")

@@ -355,37 +371,37 @@ async def init_config_manager():
    """Initialize the global config manager"""
    global config_manager
    config_manager = ConfigManager()

    # Register default schemas
    await _register_default_schemas()

    # Load default configurations
    await _load_default_configs()

    logger.info("Configuration manager initialized")


async def _register_default_schemas():
    """Register default configuration schemas"""
    manager = get_config_manager()

    # Database schema
    db_schema = ConfigSchema(
        name="database",
        required_fields=["host", "port", "name"],
        optional_fields=["username", "password", "ssl"],
        field_types={"host": str, "port": int, "name": str, "ssl": bool},
        validators={"port": lambda x: 1 <= x <= 65535}
        validators={"port": lambda x: 1 <= x <= 65535},
    )
    manager.register_schema("database", db_schema)

    # Cache schema
    cache_schema = ConfigSchema(
        name="cache",
        required_fields=["redis_url"],
        optional_fields=["timeout", "max_connections"],
        field_types={"redis_url": str, "timeout": int, "max_connections": int},
        validators={"timeout": lambda x: x > 0}
        validators={"timeout": lambda x: x > 0},
    )
    manager.register_schema("cache", cache_schema)

@@ -393,21 +409,21 @@ async def _register_default_schemas():
async def _load_default_configs():
    """Load default configurations"""
    manager = get_config_manager()

    default_configs = {
        "app": {
            "name": "Confidential Empire",
            "version": "1.0.0",
            "debug": manager.environment == "development",
            "log_level": "INFO",
            "timezone": "UTC"
            "timezone": "UTC",
        },
        "cache": {
            "redis_url": "redis://empire-redis:6379/0",
            "timeout": 30,
            "max_connections": 10
        }
            "max_connections": 10,
        },
    }

    for name, config in default_configs.items():
        await manager.set_config(name, config, description="Default configuration")
        await manager.set_config(name, config, description="Default configuration")

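# --- Editor's illustration (not part of the diff above) ---
# How the ConfigManager pieces fit together: register schemas at startup, store
# a validated and versioned config, then read values back (get_config_value
# walks nested dicts via dotted keys). Assumes the module-level helpers from
# this file; the config values are illustrative.
import asyncio


async def demo():
    await init_config_manager()
    manager = get_config_manager()
    ok = await manager.set_config(
        "database",
        {"host": "localhost", "port": 5432, "name": "enclava"},
        description="Example config",
    )
    assert ok  # False would mean schema validation rejected the data
    # Dotted keys traverse nested structures within a config, e.g. key="a.b.c"
    timeout = await manager.get_config_value("cache", "timeout", default=30)
    print(timeout)


asyncio.run(demo())
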
@@ -18,19 +18,19 @@ logger = logging.getLogger(__name__)

class ConversationService:
    """Service for managing chatbot conversations and message history"""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_or_create_conversation(
        self,
        chatbot_id: str,
        self,
        chatbot_id: str,
        user_id: str,
        conversation_id: Optional[str] = None,
        title: Optional[str] = None
        title: Optional[str] = None,
    ) -> ChatbotConversation:
        """Get existing conversation or create a new one"""

        # If conversation_id provided, try to get existing conversation
        if conversation_id:
            stmt = select(ChatbotConversation).where(
@@ -38,22 +38,24 @@ class ConversationService:
                    ChatbotConversation.id == conversation_id,
                    ChatbotConversation.chatbot_id == chatbot_id,
                    ChatbotConversation.user_id == user_id,
                    ChatbotConversation.is_active == True
                    ChatbotConversation.is_active == True,
                )
            )
            result = await self.db.execute(stmt)
            conversation = result.scalar_one_or_none()

            if conversation:
                logger.info(f"Found existing conversation {conversation_id}")
                return conversation
            else:
                logger.warning(f"Conversation {conversation_id} not found or not accessible")

                logger.warning(
                    f"Conversation {conversation_id} not found or not accessible"
                )

        # Create new conversation
        if not title:
            title = f"Chat {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"

        conversation = ChatbotConversation(
            chatbot_id=chatbot_id,
            user_id=user_id,
@@ -61,30 +63,29 @@ class ConversationService:
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
            is_active=True,
            context_data={}
            context_data={},
        )

        self.db.add(conversation)
        await self.db.commit()
        await self.db.refresh(conversation)

        logger.info(f"Created new conversation {conversation.id} for chatbot {chatbot_id}")

        logger.info(
            f"Created new conversation {conversation.id} for chatbot {chatbot_id}"
        )
        return conversation

    async def get_conversation_history(
        self,
        conversation_id: str,
        limit: int = 20,
        include_system: bool = False
        self, conversation_id: str, limit: int = 20, include_system: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Load conversation history for a conversation

        Args:
            conversation_id: ID of the conversation
            limit: Maximum number of messages to return (default 20)
            include_system: Whether to include system messages (default False)

        Returns:
            List of messages in chronological order (oldest first)
        """
@@ -93,185 +94,210 @@ class ConversationService:
            stmt = select(ChatbotMessage).where(
                ChatbotMessage.conversation_id == conversation_id
            )

            # Optionally exclude system messages
            if not include_system:
                stmt = stmt.where(ChatbotMessage.role != 'system')

                stmt = stmt.where(ChatbotMessage.role != "system")

            # Order by timestamp descending and limit
            stmt = stmt.order_by(desc(ChatbotMessage.timestamp)).limit(limit)

            result = await self.db.execute(stmt)
            messages = result.scalars().all()

            # Convert to list and reverse to get chronological order (oldest first)
            history = []
            for msg in reversed(messages):
                history.append({
                    "role": msg.role,
                    "content": msg.content,
                    "timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
                    "metadata": msg.message_metadata or {},
                    "sources": msg.sources
                })

            logger.info(f"Loaded {len(history)} messages for conversation {conversation_id}")
                history.append(
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp.isoformat()
                        if msg.timestamp
                        else None,
                        "metadata": msg.message_metadata or {},
                        "sources": msg.sources,
                    }
                )

            logger.info(
                f"Loaded {len(history)} messages for conversation {conversation_id}"
            )
            return history

        except Exception as e:
            logger.error(f"Failed to load conversation history for {conversation_id}: {e}")
            logger.error(
                f"Failed to load conversation history for {conversation_id}: {e}"
            )
            return []  # Return empty list on error to avoid breaking chat

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        sources: Optional[List[Dict[str, Any]]] = None
        sources: Optional[List[Dict[str, Any]]] = None,
    ) -> ChatbotMessage:
        """Add a message to a conversation"""

        if role not in ['user', 'assistant', 'system']:

        if role not in ["user", "assistant", "system"]:
            raise ValueError(f"Invalid message role: {role}")

        message = ChatbotMessage(
            conversation_id=conversation_id,
            role=role,
            content=content,
            timestamp=datetime.utcnow(),
            message_metadata=metadata or {},
            sources=sources
            sources=sources,
        )

        self.db.add(message)

        # Update conversation timestamp
        stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
        stmt = select(ChatbotConversation).where(
            ChatbotConversation.id == conversation_id
        )
        result = await self.db.execute(stmt)
        conversation = result.scalar_one_or_none()

        if conversation:
            conversation.updated_at = datetime.utcnow()

        await self.db.commit()
        await self.db.refresh(message)

        logger.info(f"Added {role} message to conversation {conversation_id}")
        return message

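# --- Editor's illustration (not part of the diff above) ---
# A typical round-trip with ConversationService: open or resume a conversation,
# persist the user's turn, then load recent history for the next prompt.
# `session` is an AsyncSession; the IDs and flow are illustrative, not from the diff.
async def handle_turn(session, chatbot_id: str, user_id: str, text: str):
    service = ConversationService(session)
    conversation = await service.get_or_create_conversation(chatbot_id, user_id)
    await service.add_message(conversation.id, role="user", content=text)
    # History is fetched newest-first in SQL, then reversed to oldest-first
    return await service.get_conversation_history(conversation.id, limit=20)
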
    async def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
        """Get statistics for a conversation"""

        # Count messages by role
        stmt = select(
            ChatbotMessage.role,
            func.count(ChatbotMessage.id).label('count')
        ).where(
            ChatbotMessage.conversation_id == conversation_id
        ).group_by(ChatbotMessage.role)

        stmt = (
            select(ChatbotMessage.role, func.count(ChatbotMessage.id).label("count"))
            .where(ChatbotMessage.conversation_id == conversation_id)
            .group_by(ChatbotMessage.role)
        )

        result = await self.db.execute(stmt)
        role_counts = {row.role: row.count for row in result}

        # Get conversation info
        stmt = select(ChatbotConversation).where(ChatbotConversation.id == conversation_id)
        stmt = select(ChatbotConversation).where(
            ChatbotConversation.id == conversation_id
        )
        result = await self.db.execute(stmt)
        conversation = result.scalar_one_or_none()

        if not conversation:
            raise APIException(status_code=404, error_code="CONVERSATION_NOT_FOUND")

        return {
            "conversation_id": conversation_id,
            "title": conversation.title,
            "created_at": conversation.created_at.isoformat() if conversation.created_at else None,
            "updated_at": conversation.updated_at.isoformat() if conversation.updated_at else None,
            "created_at": conversation.created_at.isoformat()
            if conversation.created_at
            else None,
            "updated_at": conversation.updated_at.isoformat()
            if conversation.updated_at
            else None,
            "total_messages": sum(role_counts.values()),
            "user_messages": role_counts.get('user', 0),
            "assistant_messages": role_counts.get('assistant', 0),
            "system_messages": role_counts.get('system', 0)
            "user_messages": role_counts.get("user", 0),
            "assistant_messages": role_counts.get("assistant", 0),
            "system_messages": role_counts.get("system", 0),
        }

    async def archive_old_conversations(self, days_inactive: int = 30) -> int:
        """Archive conversations that haven't been used in specified days"""

        cutoff_date = datetime.utcnow() - timedelta(days=days_inactive)

        # Find conversations to archive
        stmt = select(ChatbotConversation).where(
            and_(
                ChatbotConversation.updated_at < cutoff_date,
                ChatbotConversation.is_active == True
                ChatbotConversation.is_active == True,
            )
        )

        result = await self.db.execute(stmt)
        conversations = result.scalars().all()

        archived_count = 0
        for conversation in conversations:
            conversation.is_active = False
            archived_count += 1

        if archived_count > 0:
            await self.db.commit()
            logger.info(f"Archived {archived_count} inactive conversations")

        return archived_count

    async def delete_conversation(self, conversation_id: str, user_id: str) -> bool:
        """Delete a conversation and all its messages"""

        # Verify ownership
        stmt = select(ChatbotConversation).where(
            and_(
                ChatbotConversation.id == conversation_id,
                ChatbotConversation.user_id == user_id
        stmt = (
            select(ChatbotConversation)
            .where(
                and_(
                    ChatbotConversation.id == conversation_id,
                    ChatbotConversation.user_id == user_id,
                )
            )
        ).options(selectinload(ChatbotConversation.messages))

            .options(selectinload(ChatbotConversation.messages))
        )

        result = await self.db.execute(stmt)
        conversation = result.scalar_one_or_none()

        if not conversation:
            return False

        # Delete all messages first
        for message in conversation.messages:
            await self.db.delete(message)

        # Delete conversation
        await self.db.delete(conversation)
        await self.db.commit()

        logger.info(f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages")

        logger.info(
            f"Deleted conversation {conversation_id} with {len(conversation.messages)} messages"
        )
        return True

    async def get_user_conversations(
        self,
        user_id: str,
|
||||
self,
|
||||
user_id: str,
|
||||
chatbot_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
skip: int = 0
|
||||
skip: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get list of conversations for a user"""
|
||||
|
||||
|
||||
stmt = select(ChatbotConversation).where(
|
||||
and_(
|
||||
ChatbotConversation.user_id == user_id,
|
||||
ChatbotConversation.is_active == True
|
||||
ChatbotConversation.is_active == True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if chatbot_id:
|
||||
stmt = stmt.where(ChatbotConversation.chatbot_id == chatbot_id)
|
||||
|
||||
stmt = stmt.order_by(desc(ChatbotConversation.updated_at)).offset(skip).limit(limit)
|
||||
|
||||
|
||||
stmt = (
|
||||
stmt.order_by(desc(ChatbotConversation.updated_at))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
conversations = result.scalars().all()
|
||||
|
||||
|
||||
conversation_list = []
|
||||
for conv in conversations:
|
||||
# Get message count
|
||||
@@ -280,15 +306,21 @@ class ConversationService:
|
||||
)
|
||||
msg_count_result = await self.db.execute(msg_count_stmt)
|
||||
message_count = msg_count_result.scalar() or 0
|
||||
|
||||
conversation_list.append({
|
||||
"id": conv.id,
|
||||
"chatbot_id": conv.chatbot_id,
|
||||
"title": conv.title,
|
||||
"message_count": message_count,
|
||||
"created_at": conv.created_at.isoformat() if conv.created_at else None,
|
||||
"updated_at": conv.updated_at.isoformat() if conv.updated_at else None,
|
||||
"context_data": conv.context_data or {}
|
||||
})
|
||||
|
||||
return conversation_list
|
||||
|
||||
conversation_list.append(
|
||||
{
|
||||
"id": conv.id,
|
||||
"chatbot_id": conv.chatbot_id,
|
||||
"title": conv.title,
|
||||
"message_count": message_count,
|
||||
"created_at": conv.created_at.isoformat()
|
||||
if conv.created_at
|
||||
else None,
|
||||
"updated_at": conv.updated_at.isoformat()
|
||||
if conv.updated_at
|
||||
else None,
|
||||
"context_data": conv.context_data or {},
|
||||
}
|
||||
)
|
||||
|
||||
return conversation_list
|
||||
|
||||
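Read end to end, the service above supports a simple lifecycle; a minimal usage sketch (assuming the service is constructed with the async session it stores as self.db, and with purely illustrative IDs):

    # Sketch only: db_session, conversation and user IDs are illustrative.
    service = ConversationService(db_session)
    await service.add_message("conv-123", role="user", content="Hello!")
    stats = await service.get_conversation_stats("conv-123")
    recent = await service.get_user_conversations("user-123", limit=10)
    removed = await service.delete_conversation("conv-123", user_id="user-123")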
@@ -10,60 +10,64 @@ logger = get_logger(__name__)

class CostCalculator:
"""Service for calculating costs based on model usage and token consumption"""

# Model pricing in 1/10000ths of a dollar per 1000 tokens (input/output)
MODEL_PRICING = {
# OpenAI Models
"gpt-4": {"input": 300, "output": 600}, # $0.03/$0.06 per 1K tokens
"gpt-4-turbo": {"input": 100, "output": 300}, # $0.01/$0.03 per 1K tokens
"gpt-3.5-turbo": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens

# Anthropic Models
"claude-3-opus": {"input": 150, "output": 750}, # $0.015/$0.075 per 1K tokens
"claude-3-sonnet": {"input": 30, "output": 150}, # $0.003/$0.015 per 1K tokens
"claude-3-haiku": {"input": 25, "output": 125}, # $0.00025/$0.00125 per 1K tokens

"claude-3-haiku": {
"input": 25,
"output": 125,
}, # $0.00025/$0.00125 per 1K tokens
# Google Models
"gemini-pro": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens
"gemini-pro-vision": {"input": 5, "output": 15}, # $0.0005/$0.0015 per 1K tokens

"gemini-pro-vision": {
"input": 5,
"output": 15,
}, # $0.0005/$0.0015 per 1K tokens
# Privatemode.ai Models (estimated pricing)
"privatemode-llama-70b": {"input": 40, "output": 80}, # Estimated pricing
"privatemode-mixtral": {"input": 20, "output": 40}, # Estimated pricing

# Embedding Models
"text-embedding-ada-002": {"input": 1, "output": 0}, # $0.0001 per 1K tokens
"text-embedding-3-small": {"input": 2, "output": 0}, # $0.00002 per 1K tokens
"text-embedding-3-large": {"input": 13, "output": 0}, # $0.00013 per 1K tokens
}

# Default pricing for unknown models
DEFAULT_PRICING = {"input": 10, "output": 20} # $0.001/$0.002 per 1K tokens

@classmethod
def get_model_pricing(cls, model_name: str) -> Dict[str, int]:
"""Get pricing for a specific model"""
# Normalize model name (remove provider prefixes)
normalized_name = cls._normalize_model_name(model_name)

# Look up pricing
pricing = cls.MODEL_PRICING.get(normalized_name, cls.DEFAULT_PRICING)

logger.debug(f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}")

logger.debug(
f"Pricing for model '{model_name}' (normalized: '{normalized_name}'): {pricing}"
)
return pricing

@classmethod
def _normalize_model_name(cls, model_name: str) -> str:
"""Normalize model name by removing provider prefixes"""
# Remove common provider prefixes
prefixes = ["openai/", "anthropic/", "google/", "gemini/", "privatemode/"]

normalized = model_name.lower()
for prefix in prefixes:
if normalized.startswith(prefix):
normalized = normalized[len(prefix):]
normalized = normalized[len(prefix) :]
break

# Handle special cases
if "claude-3-opus-20240229" in normalized:
return "claude-3-opus"
@@ -75,91 +79,88 @@ class CostCalculator:
return "privatemode-llama-70b"
elif "mistralai/mixtral-8x7b-instruct" in normalized:
return "privatemode-mixtral"

return normalized

@classmethod
def calculate_cost_cents(
cls,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0
cls, model_name: str, input_tokens: int = 0, output_tokens: int = 0
) -> int:
"""
Calculate cost in cents for given token usage

Args:
model_name: Name of the LLM model
input_tokens: Number of input tokens used
output_tokens: Number of output tokens generated

Returns:
Total cost in cents
"""
pricing = cls.get_model_pricing(model_name)

# Calculate cost per token type
input_cost_cents = (input_tokens * pricing["input"]) // 1000
output_cost_cents = (output_tokens * pricing["output"]) // 1000

total_cost_cents = input_cost_cents + output_cost_cents

logger.debug(
f"Cost calculation for {model_name}: "
f"input_tokens={input_tokens} (${input_cost_cents/100:.4f}), "
f"output_tokens={output_tokens} (${output_cost_cents/100:.4f}), "
f"total=${total_cost_cents/100:.4f}"
)

return total_cost_cents
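A quick sanity check of the arithmetic above, reading the MODEL_PRICING comment ("1/10000ths of a dollar per 1000 tokens") literally:

    # gpt-4 is priced {"input": 300, "output": 600}:
    # 2000 input tokens:  2000 * 300 // 1000 = 600 units  (~$0.06)
    # 1000 output tokens: 1000 * 600 // 1000 = 600 units  (~$0.06)
    CostCalculator.calculate_cost_cents("gpt-4", input_tokens=2000, output_tokens=1000)  # -> 1200, i.e. ~$0.12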
@classmethod
def estimate_cost_cents(cls, model_name: str, estimated_tokens: int) -> int:
"""
Estimate cost for a request based on estimated total tokens
Assumes 70% input, 30% output token distribution

Args:
model_name: Name of the LLM model
estimated_tokens: Estimated total tokens for the request

Returns:
Estimated cost in cents
"""
input_tokens = int(estimated_tokens * 0.7) # 70% input
output_tokens = int(estimated_tokens * 0.3) # 30% output

return cls.calculate_cost_cents(model_name, input_tokens, output_tokens)
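Worked through for a 1,000-token estimate on gpt-3.5-turbo ({"input": 5, "output": 15}): input is 700 * 5 // 1000 = 3 units and output is 300 * 15 // 1000 = 4 units, so estimate_cost_cents returns 7 (~$0.0007 in the table's units).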
@classmethod
def get_cost_per_1k_tokens(cls, model_name: str) -> Dict[str, float]:
"""
Get cost per 1000 tokens in dollars for display purposes

Args:
model_name: Name of the LLM model

Returns:
Dictionary with input and output costs in dollars per 1K tokens
"""
pricing_cents = cls.get_model_pricing(model_name)

return {
"input": pricing_cents["input"] / 10000, # Convert 1/10000ths to dollars
"output": pricing_cents["output"] / 10000,
"currency": "USD"
"currency": "USD",
}

@classmethod
def get_all_model_pricing(cls) -> Dict[str, Dict[str, float]]:
"""Get pricing for all supported models in dollars"""
pricing_data = {}

for model_name in cls.MODEL_PRICING.keys():
pricing_data[model_name] = cls.get_cost_per_1k_tokens(model_name)

return pricing_data

@classmethod
def format_cost_display(cls, cost_cents: int) -> str:
"""Format cost in 1/1000ths of a dollar for display"""
@@ -172,7 +173,9 @@ class CostCalculator:

# Convenience functions for common operations
def calculate_request_cost(model_name: str, input_tokens: int, output_tokens: int) -> int:
def calculate_request_cost(
model_name: str, input_tokens: int, output_tokens: int
) -> int:
"""Calculate cost for a single request"""
return CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)

@@ -184,4 +187,4 @@ def estimate_request_cost(model_name: str, estimated_tokens: int) -> int:

def get_model_pricing_display(model_name: str) -> Dict[str, float]:
"""Get model pricing for display"""
return CostCalculator.get_cost_per_1k_tokens(model_name)
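For display, the same table round-trips to dollars; for example (a sketch):

    get_model_pricing_display("gpt-4")
    # -> {"input": 0.03, "output": 0.06, "currency": "USD"}   (300/10000 and 600/10000)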
@@ -33,12 +33,13 @@ class ProcessingStatus(str, Enum):
@dataclass
class ProcessingTask:
"""Document processing task"""

document_id: int
priority: int = 1
retry_count: int = 0
max_retries: int = 3
created_at: datetime = None

def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
@@ -46,7 +47,7 @@ class ProcessingTask:

class DocumentProcessor:
"""Async document processor with queue management"""

def __init__(self, max_workers: int = 3, max_queue_size: int = 100):
self.max_workers = max_workers
self.max_queue_size = max_queue_size
@@ -57,49 +58,49 @@ class DocumentProcessor:
"processed_count": 0,
"error_count": 0,
"queue_size": 0,
"active_workers": 0
"active_workers": 0,
}
self._rag_module = None
self._rag_module_lock = asyncio.Lock()

async def start(self):
"""Start the document processor"""
if self.running:
return

self.running = True
logger.info(f"Starting document processor with {self.max_workers} workers")

# Start worker tasks
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.workers.append(worker)

logger.info("Document processor started")

async def stop(self):
"""Stop the document processor"""
if not self.running:
return

self.running = False
logger.info("Stopping document processor...")

# Cancel all workers
for worker in self.workers:
worker.cancel()

# Wait for workers to finish
await asyncio.gather(*self.workers, return_exceptions=True)
self.workers.clear()

logger.info("Document processor stopped")

async def add_task(self, document_id: int, priority: int = 1) -> bool:
"""Add a document processing task to the queue"""
try:
task = ProcessingTask(document_id=document_id, priority=priority)

try:
await asyncio.wait_for(self.processing_queue.put(task), timeout=5.0)
except asyncio.TimeoutError:
@@ -108,47 +109,54 @@ class DocumentProcessor:
document_id,
)
return False

self.stats["queue_size"] = self.processing_queue.qsize()

logger.info(f"Added processing task for document {document_id} (priority: {priority})")

logger.info(
f"Added processing task for document {document_id} (priority: {priority})"
)
return True

except Exception as e:
logger.error(f"Failed to add processing task for document {document_id}: {e}")
logger.error(
f"Failed to add processing task for document {document_id}: {e}"
)
return False

async def _worker(self, worker_name: str):
"""Worker coroutine that processes documents"""
logger.info(f"Started worker: {worker_name}")

while self.running:
task: Optional[ProcessingTask] = None
try:
# Get task from queue (wait up to 1 second)
task = await asyncio.wait_for(
self.processing_queue.get(),
timeout=1.0
)

task = await asyncio.wait_for(self.processing_queue.get(), timeout=1.0)

self.stats["active_workers"] += 1
self.stats["queue_size"] = self.processing_queue.qsize()

logger.info(f"{worker_name}: Processing document {task.document_id}")

# Process the document
success = await self._process_document(task)

if success:
self.stats["processed_count"] += 1
logger.info(f"{worker_name}: Successfully processed document {task.document_id}")
logger.info(
f"{worker_name}: Successfully processed document {task.document_id}"
)
else:
# Retry logic
if task.retry_count < task.max_retries:
task.retry_count += 1
await asyncio.sleep(2 ** task.retry_count) # Exponential backoff
await asyncio.sleep(
2**task.retry_count
) # Exponential backoff
try:
await asyncio.wait_for(self.processing_queue.put(task), timeout=5.0)
await asyncio.wait_for(
self.processing_queue.put(task), timeout=5.0
)
except asyncio.TimeoutError:
logger.error(
"%s: Failed to requeue document %s due to saturated queue",
@@ -157,11 +165,15 @@ class DocumentProcessor:
)
self.stats["error_count"] += 1
continue
logger.warning(f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})")
logger.warning(
f"{worker_name}: Retrying document {task.document_id} (attempt {task.retry_count})"
)
else:
self.stats["error_count"] += 1
logger.error(f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries")

logger.error(
f"{worker_name}: Failed to process document {task.document_id} after {task.max_retries} retries"
)

except asyncio.TimeoutError:
# No tasks in queue, continue
continue
@@ -183,30 +195,34 @@ class DocumentProcessor:
async def _get_rag_module(self):
"""Resolve and cache the RAG module instance"""
async with self._rag_module_lock:
if self._rag_module and getattr(self._rag_module, 'enabled', False):
if self._rag_module and getattr(self._rag_module, "enabled", False):
return self._rag_module

if not module_manager.initialized:
await module_manager.initialize()

rag_module = module_manager.get_module('rag')
rag_module = module_manager.get_module("rag")

if not rag_module:
enabled = await module_manager.enable_module('rag')
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("Failed to enable RAG module")
rag_module = module_manager.get_module('rag')
rag_module = module_manager.get_module("rag")

if not rag_module:
raise RuntimeError("RAG module not available after enable attempt")

if not getattr(rag_module, 'enabled', True):
enabled = await module_manager.enable_module('rag')
if not getattr(rag_module, "enabled", True):
enabled = await module_manager.enable_module("rag")
if not enabled:
raise RuntimeError("RAG module is disabled and could not be re-enabled")
rag_module = module_manager.get_module('rag')
if not rag_module or not getattr(rag_module, 'enabled', True):
raise RuntimeError("RAG module is disabled and could not be re-enabled")
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)
rag_module = module_manager.get_module("rag")
if not rag_module or not getattr(rag_module, "enabled", True):
raise RuntimeError(
"RAG module is disabled and could not be re-enabled"
)

self._rag_module = rag_module
logger.info("DocumentProcessor cached RAG module instance for reuse")
@@ -216,6 +232,7 @@ class DocumentProcessor:
"""Process a single document"""
from datetime import datetime
from app.db.database import async_session_factory

async with async_session_factory() as session:
try:
# Get document from database
@@ -226,11 +243,11 @@ class DocumentProcessor:
)
result = await session.execute(stmt)
document = result.scalar_one_or_none()

if not document:
logger.error(f"Document {task.document_id} not found")
return False

# Update status to processing
document.status = ProcessingStatus.PROCESSING
await session.commit()
@@ -244,43 +261,62 @@ class DocumentProcessor:

if not rag_module or not rag_module.enabled:
raise Exception("RAG module not available or not enabled")

logger.info(f"RAG module loaded successfully for document {task.document_id}")

logger.info(
f"RAG module loaded successfully for document {task.document_id}"
)

# Read file content
logger.info(f"Reading file content for document {task.document_id}: {document.file_path}")
logger.info(
f"Reading file content for document {task.document_id}: {document.file_path}"
)
file_path = Path(document.file_path)
try:
file_content = await asyncio.to_thread(file_path.read_bytes)
except FileNotFoundError:
logger.error(f"File not found for document {task.document_id}: {document.file_path}")
logger.error(
f"File not found for document {task.document_id}: {document.file_path}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = "Document file not found on disk"
await session.commit()
return False
except Exception as exc:
logger.error(f"Failed reading file for document {task.document_id}: {exc}")
logger.error(
f"Failed reading file for document {task.document_id}: {exc}"
)
document.status = ProcessingStatus.ERROR
document.processing_error = f"Failed to read file: {exc}"
await session.commit()
return False

logger.info(f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes")

logger.info(
f"File content read successfully for document {task.document_id}, size: {len(file_content)} bytes"
)

# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
logger.info(
f"Starting document processing for document {task.document_id} with RAG module"
)

# Special handling for JSONL files - skip processing phase
if document.file_type == 'jsonl':
if document.file_type == "jsonl":
# For JSONL files, we don't need to process content here
# The optimized JSONL processor will handle everything during indexing
document.converted_content = f"JSONL file with {len(file_content)} bytes"
document.converted_content = (
f"JSONL file with {len(file_content)} bytes"
)
document.word_count = 0 # Will be updated during indexing
document.character_count = len(file_content)
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
document.document_metadata = {
"file_path": document.file_path,
"processed": "jsonl",
}
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
logger.info(
f"JSONL document {task.document_id} marked for optimized processing"
)
else:
# Standard processing for other file types
try:
@@ -289,11 +325,13 @@ class DocumentProcessor:
rag_module.process_document(
file_content,
document.original_filename,
{"file_path": document.file_path}
{"file_path": document.file_path},
),
timeout=300.0 # 5 minute timeout
timeout=300.0, # 5 minute timeout
)
logger.info(
f"Document processing completed for document {task.document_id}"
)
logger.info(f"Document processing completed for document {task.document_id}")

# Update document with processed content
document.converted_content = processed_doc.content
@@ -303,29 +341,35 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
logger.error(
f"Document processing timed out for document {task.document_id}"
)
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
logger.error(
f"Document processing failed for document {task.document_id}: {e}"
)
raise

# Index in RAG system using same RAG module
if rag_module and document.converted_content:
try:
logger.info(f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}")

logger.info(
f"Starting indexing for document {task.document_id} in collection {document.collection.qdrant_collection_name}"
)

# Index the document content in the correct Qdrant collection
doc_metadata = {
"collection_id": document.collection_id,
"document_id": document.id,
"filename": document.original_filename,
"file_type": document.file_type,
**document.document_metadata
**document.document_metadata,
}

# Use the correct Qdrant collection name for this document
# For JSONL files, we need to use the processed document flow
if document.file_type == 'jsonl':
if document.file_type == "jsonl":
# Create a ProcessedDocument for the JSONL processor
from app.modules.rag.main import ProcessedDocument
from datetime import datetime
@@ -333,7 +377,9 @@ class DocumentProcessor:

# Calculate file hash
processed_at = datetime.utcnow()
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
file_hash = hashlib.md5(
str(document.id).encode()
).hexdigest()

processed_doc = ProcessedDocument(
id=str(document.id),
@@ -341,12 +387,14 @@ class DocumentProcessor:
extracted_text="", # Will be filled by JSONL processor
metadata={
**doc_metadata,
"file_path": document.file_path
"file_path": document.file_path,
},
original_filename=document.original_filename,
file_type=document.file_type,
mime_type=document.mime_type,
language=document.document_metadata.get('language', 'EN'),
language=document.document_metadata.get(
"language", "EN"
),
word_count=0, # Will be updated during processing
sentence_count=0, # Will be updated during processing
entities=[],
@@ -354,16 +402,16 @@ class DocumentProcessor:
processing_time=0.0,
processed_at=processed_at,
file_hash=file_hash,
file_size=document.file_size
file_size=document.file_size,
)

# The JSONL processor will read the original file
await asyncio.wait_for(
rag_module.index_processed_document(
processed_doc=processed_doc,
collection_name=document.collection.qdrant_collection_name
collection_name=document.collection.qdrant_collection_name,
),
timeout=300.0 # 5 minute timeout for JSONL processing
timeout=300.0, # 5 minute timeout for JSONL processing
)
else:
# Use standard indexing for other file types
@@ -371,18 +419,22 @@ class DocumentProcessor:
rag_module.index_document(
content=document.converted_content,
metadata=doc_metadata,
collection_name=document.collection.qdrant_collection_name
collection_name=document.collection.qdrant_collection_name,
),
timeout=120.0 # 2 minute timeout for indexing
timeout=120.0, # 2 minute timeout for indexing
)

logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")

logger.info(
f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}"
)

# Update vector count (approximate)
document.vector_count = max(1, len(document.converted_content) // 1000)
document.vector_count = max(
1, len(document.converted_content) // 1000
)
document.status = ProcessingStatus.INDEXED
document.indexed_at = datetime.utcnow()

# Update collection stats
collection = document.collection
if collection and document.status == ProcessingStatus.INDEXED:
@@ -390,36 +442,38 @@ class DocumentProcessor:
collection.size_bytes += document.file_size
collection.vector_count += document.vector_count
collection.updated_at = datetime.utcnow()

except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
logger.error(
f"Failed to index document {task.document_id} in RAG: {e}"
)
# Mark as error since indexing failed
document.status = ProcessingStatus.ERROR
document.processing_error = f"Indexing failed: {str(e)}"
# Don't raise the exception to avoid retries on indexing failures

await session.commit()
return True

except Exception as e:
# Mark document as error
if 'document' in locals() and document:
if "document" in locals() and document:
document.status = ProcessingStatus.ERROR
document.processing_error = str(e)
await session.commit()

logger.error(f"Error processing document {task.document_id}: {e}")
return False

async def get_stats(self) -> Dict[str, Any]:
"""Get processor statistics"""
return {
**self.stats,
"running": self.running,
"worker_count": len(self.workers),
"queue_size": self.processing_queue.qsize()
"queue_size": self.processing_queue.qsize(),
}

async def get_queue_status(self) -> Dict[str, Any]:
"""Get detailed queue status"""
return {
@@ -427,7 +481,7 @@ class DocumentProcessor:
"max_queue_size": self.max_queue_size,
"queue_full": self.processing_queue.full(),
"active_workers": self.stats["active_workers"],
"max_workers": self.max_workers
"max_workers": self.max_workers,
}
|
||||
"""Service for generating text embeddings using a local transformer model"""
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None):
|
||||
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3")
|
||||
self.model_name = model_name or getattr(
|
||||
settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3"
|
||||
)
|
||||
self.dimension = 1024 # bge-m3 produces 1024-d vectors
|
||||
self.initialized = False
|
||||
self.local_model = None
|
||||
@@ -56,13 +58,15 @@ class EmbeddingService:
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
|
||||
logger.error(
|
||||
f"Failed to load local embedding model {self.model_name}: {exc}"
|
||||
)
|
||||
logger.warning("Falling back to random embeddings")
|
||||
self.local_model = None
|
||||
self.initialized = False
|
||||
self.backend = "fallback_random"
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(self, text: str) -> List[float]:
|
||||
"""Get embedding for a single text"""
|
||||
embeddings = await self.get_embeddings([text])
|
||||
@@ -102,13 +106,21 @@ class EmbeddingService:
|
||||
except Exception as exc:
|
||||
logger.error(f"Local embedding generation failed: {exc}")
|
||||
self.backend = "fallback_random"
|
||||
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
|
||||
return self._generate_fallback_embeddings(
|
||||
texts, duration=time.time() - start_time
|
||||
)
|
||||
|
||||
logger.warning("Local embedding model unavailable; using fallback random embeddings")
|
||||
logger.warning(
|
||||
"Local embedding model unavailable; using fallback random embeddings"
|
||||
)
|
||||
self.backend = "fallback_random"
|
||||
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
|
||||
|
||||
def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
|
||||
return self._generate_fallback_embeddings(
|
||||
texts, duration=time.time() - start_time
|
||||
)
|
||||
|
||||
def _generate_fallback_embeddings(
|
||||
self, texts: List[str], duration: float = None
|
||||
) -> List[List[float]]:
|
||||
"""Generate fallback random embeddings when model unavailable"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
@@ -124,30 +136,30 @@ class EmbeddingService:
|
||||
},
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def _generate_fallback_embedding(self, text: str) -> List[float]:
|
||||
"""Generate a single fallback embedding"""
|
||||
dimension = self.dimension or 1024
|
||||
# Use hash for reproducible random embeddings
|
||||
np.random.seed(hash(text) % 2**32)
|
||||
return np.random.random(dimension).tolist()
|
||||
|
||||
|
||||
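One caveat on the seeding trick above: Python salts str hash() per process (PYTHONHASHSEED), so these fallback vectors are only reproducible within a single run. A sketch of a cross-process-stable variant, not part of this commit:

    import hashlib
    import numpy as np

    def stable_fallback_embedding(text: str, dimension: int = 1024) -> list:
        # md5 is deterministic regardless of PYTHONHASHSEED
        seed = int.from_bytes(hashlib.md5(text.encode("utf-8")).digest()[:4], "big")
        rng = np.random.default_rng(seed)
        return rng.random(dimension).tolist()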
async def similarity(self, text1: str, text2: str) -> float:
"""Calculate cosine similarity between two texts"""
embeddings = await self.get_embeddings([text1, text2])

# Calculate cosine similarity
vec1 = np.array(embeddings[0])
vec2 = np.array(embeddings[1])

# Normalize vectors
vec1_norm = vec1 / np.linalg.norm(vec1)
vec2_norm = vec2 / np.linalg.norm(vec2)

# Calculate cosine similarity
similarity = np.dot(vec1_norm, vec2_norm)
return float(similarity)
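The normalize-then-dot steps above are plain cosine similarity; a tiny self-contained check with made-up vectors:

    import numpy as np

    a = np.array([1.0, 0.0, 1.0])
    b = np.array([1.0, 1.0, 0.0])
    cos = np.dot(a / np.linalg.norm(a), b / np.linalg.norm(b))
    # |a| = |b| = sqrt(2) and a.b = 1, so cos == 0.5
    assert abs(cos - 0.5) < 1e-9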
async def get_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics"""
return {
@@ -155,7 +167,7 @@ class EmbeddingService:
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": self.backend,
"initialized": self.initialized
"initialized": self.initialized,
}

async def cleanup(self):

@@ -21,17 +21,30 @@ class EnhancedEmbeddingService(EmbeddingService):
def __init__(self, model_name: Optional[str] = None):
super().__init__(model_name)
self.rate_limit_tracker = {
'requests_count': 0,
'window_start': time.time(),
'window_size': 60, # 1 minute window
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
'last_rate_limit_error': None
"requests_count": 0,
"window_start": time.time(),
"window_size": 60, # 1 minute window
"max_requests_per_minute": int(
getattr(settings, "RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", 12)
), # Configurable
"retry_delays": [
int(x)
for x in getattr(
settings, "RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16"
).split(",")
], # Exponential backoff
"delay_between_batches": float(
getattr(settings, "RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", 1.0)
),
"delay_per_request": float(
getattr(settings, "RAG_EMBEDDING_DELAY_PER_REQUEST", 0.5)
),
"last_rate_limit_error": None,
}
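The window fields above drive the reset arithmetic reported by get_stats further down; in isolation it is just (a sketch, with `tracker` standing in for the dict above):

    import time

    tracker = {"window_start": time.time() - 45, "window_size": 60}
    # 45 s into a 60 s window -> roughly 15 s until the request counter resets
    reset_in = max(0, tracker["window_start"] + tracker["window_size"] - time.time())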
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
async def get_embeddings_with_retry(
self, texts: List[str], max_retries: int = None
) -> tuple[List[List[float]], bool]:
"""
Get embeddings with retry bookkeeping.
"""
@@ -51,13 +64,20 @@ class EnhancedEmbeddingService(EmbeddingService):
return {
**base_stats,
"rate_limit_info": {
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
"window_reset_in_seconds": max(0,
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
"requests_in_current_window": self.rate_limit_tracker["requests_count"],
"max_requests_per_minute": self.rate_limit_tracker[
"max_requests_per_minute"
],
"window_reset_in_seconds": max(
0,
self.rate_limit_tracker["window_start"]
+ self.rate_limit_tracker["window_size"]
- time.time(),
),
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
}
"last_rate_limit_error": self.rate_limit_tracker[
"last_rate_limit_error"
],
},
}

@@ -14,6 +14,7 @@ from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.http.models import Batch

from app.modules.rag.main import ProcessedDocument

# from app.core.analytics import log_module_event # Analytics module not available

logger = logging.getLogger(__name__)
@@ -26,8 +27,13 @@ class JSONLProcessor:
self.rag_module = rag_module
self.config = rag_module.config

async def process_and_index_jsonl(self, collection_name: str, content: bytes,
filename: str, metadata: Dict[str, Any]) -> str:
async def process_and_index_jsonl(
self,
collection_name: str,
content: bytes,
filename: str,
metadata: Dict[str, Any],
) -> str:
"""Process and index a JSONL file efficiently

Processes each JSON line as a separate document to avoid
@@ -35,8 +41,8 @@ class JSONLProcessor:
"""
try:
# Decode content
jsonl_content = content.decode('utf-8', errors='replace')
lines = jsonl_content.strip().split('\n')
jsonl_content = content.decode("utf-8", errors="replace")
lines = jsonl_content.strip().split("\n")

logger.info(f"Processing JSONL file {filename} with {len(lines)} lines")

@@ -58,28 +64,38 @@ class JSONLProcessor:
batch_start,
base_doc_id,
filename,
metadata
metadata,
)

processed_count += len(batch_lines)

# Log progress
if processed_count % 50 == 0:
logger.info(f"Processed {processed_count}/{len(lines)} lines from {filename}")
logger.info(
f"Processed {processed_count}/{len(lines)} lines from {filename}"
)

# Small delay to prevent resource exhaustion
await asyncio.sleep(0.05)

logger.info(f"Successfully processed JSONL file {filename} with {len(lines)} lines")
logger.info(
f"Successfully processed JSONL file {filename} with {len(lines)} lines"
)
return base_doc_id

except Exception as e:
logger.error(f"Error processing JSONL file {filename}: {e}")
raise
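For orientation, a line in the helpjuice-style export handled below looks roughly like this (the field names match the payload handling in _process_jsonl_batch; the values are invented):

    {"id": "article_42", "payload": {"question": "How do I reset my password?", "answer": "Open Settings and choose Reset.", "language": "EN"}}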
async def _process_jsonl_batch(self, collection_name: str, lines: List[str],
start_idx: int, base_doc_id: str,
filename: str, metadata: Dict[str, Any]) -> None:
async def _process_jsonl_batch(
self,
collection_name: str,
lines: List[str],
start_idx: int,
base_doc_id: str,
filename: str,
metadata: Dict[str, Any],
) -> None:
"""Process a batch of JSONL lines"""
try:
points = []
@@ -98,14 +114,14 @@ class JSONLProcessor:
continue

# Handle helpjuice export format
if 'payload' in data and data['payload'] is not None:
payload = data['payload']
article_id = data.get('id', f'article_{line_idx}')
if "payload" in data and data["payload"] is not None:
payload = data["payload"]
article_id = data.get("id", f"article_{line_idx}")

# Extract Q&A
question = payload.get('question', '')
answer = payload.get('answer', '')
language = payload.get('language', 'EN')
question = payload.get("question", "")
answer = payload.get("answer", "")
language = payload.get("language", "EN")

if question or answer:
# Create Q&A content
@@ -120,25 +136,29 @@ class JSONLProcessor:
"line_number": line_idx,
"content_type": "qa_pair",
"question": question[:100], # Truncate for metadata
"processed_at": datetime.utcnow().isoformat()
"processed_at": datetime.utcnow().isoformat(),
}

# Generate single embedding for the Q&A pair
embeddings = await self.rag_module._generate_embeddings([content])
embeddings = await self.rag_module._generate_embeddings(
[content]
)

# Create point
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**doc_metadata,
"document_id": f"{base_doc_id}_{article_id}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
points.append(
PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**doc_metadata,
"document_id": f"{base_doc_id}_{article_id}",
"content": content,
"chunk_index": 0,
"chunk_count": 1,
},
)
)

# Handle generic JSON format
else:
@@ -146,43 +166,55 @@ class JSONLProcessor:

# For larger JSON objects, we might need to chunk
if len(content) > 1000:
chunks = self.rag_module._chunk_text(content, chunk_size=500)
embeddings = await self.rag_module._generate_embeddings(chunks)
chunks = self.rag_module._chunk_text(
content, chunk_size=500
)
embeddings = await self.rag_module._generate_embeddings(
chunks
)

for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
for i, (chunk, embedding) in enumerate(
zip(chunks, embeddings)
):
point_id = str(uuid.uuid4())
points.append(PointStruct(
points.append(
PointStruct(
id=point_id,
vector=embedding,
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks),
},
)
)
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings(
[content]
)
point_id = str(uuid.uuid4())
points.append(
PointStruct(
id=point_id,
vector=embedding,
vector=embeddings[0],
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks)
}
))
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings([content])
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
"content": content,
"chunk_index": 0,
"chunk_count": 1,
},
)
)

except json.JSONDecodeError as e:
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
@@ -194,8 +226,7 @@ class JSONLProcessor:
# Insert all points in this batch
if points:
self.rag_module.qdrant_client.upsert(
collection_name=collection_name,
points=points
collection_name=collection_name, points=points
)

# Update stats
@@ -208,4 +239,4 @@ class JSONLProcessor:

except Exception as e:
logger.error(f"Error processing JSONL batch: {e}")
raise

@@ -11,11 +11,11 @@ from .exceptions import LLMError, ProviderError, SecurityError

__all__ = [
"LLMService",
"ChatRequest",
"ChatResponse",
"EmbeddingRequest",
"EmbeddingResponse",
"LLMError",
"ProviderError",
"SecurityError"
]
"SecurityError",
]

@@ -15,30 +15,51 @@ from .models import ResilienceConfig

class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""

name: str = Field(..., description="Provider name")
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
provider_type: str = Field(
..., description="Provider type (e.g., 'openai', 'privatemode')"
)
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
default_model: Optional[str] = Field(None, description="Default model for this provider")
supported_models: List[str] = Field(default_factory=list, description="List of supported models")
capabilities: List[str] = Field(default_factory=list, description="Provider capabilities")
default_model: Optional[str] = Field(
None, description="Default model for this provider"
)
supported_models: List[str] = Field(
default_factory=list, description="List of supported models"
)
capabilities: List[str] = Field(
default_factory=list, description="Provider capabilities"
)
priority: int = Field(1, description="Provider priority (lower = higher priority)")

# Rate limiting
max_requests_per_minute: Optional[int] = Field(None, description="Max requests per minute")
max_requests_per_hour: Optional[int] = Field(None, description="Max requests per hour")

max_requests_per_minute: Optional[int] = Field(
None, description="Max requests per minute"
)
max_requests_per_hour: Optional[int] = Field(
None, description="Max requests per hour"
)

# Model-specific settings
supports_streaming: bool = Field(False, description="Whether provider supports streaming")
supports_function_calling: bool = Field(False, description="Whether provider supports function calling")
max_context_window: Optional[int] = Field(None, description="Maximum context window size")
supports_streaming: bool = Field(
False, description="Whether provider supports streaming"
)
supports_function_calling: bool = Field(
False, description="Whether provider supports function calling"
)
max_context_window: Optional[int] = Field(
None, description="Maximum context window size"
)
max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")

# Resilience configuration
resilience: ResilienceConfig = Field(default_factory=ResilienceConfig, description="Resilience settings")

@validator('priority')
resilience: ResilienceConfig = Field(
default_factory=ResilienceConfig, description="Resilience settings"
)

@validator("priority")
def validate_priority(cls, v):
if v < 1:
raise ValueError("Priority must be >= 1")
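A minimal instantiation sketch for the model above (values illustrative; name, provider_type, base_url and api_key_env_var are the required fields):

    cfg = ProviderConfig(
        name="openai",
        provider_type="openai",
        base_url="https://api.openai.com/v1",
        api_key_env_var="OPENAI_API_KEY",
        priority=2,  # the validator above rejects anything below 1
    )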
@@ -47,35 +68,48 @@ class ProviderConfig(BaseModel):

class LLMServiceConfig(BaseModel):
"""Main LLM service configuration"""

# Global settings
default_provider: str = Field("privatemode", description="Default provider to use")
enable_detailed_logging: bool = Field(False, description="Enable detailed request/response logging")
enable_detailed_logging: bool = Field(
False, description="Enable detailed request/response logging"
)
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")

enable_metrics_collection: bool = Field(
True, description="Enable metrics collection"
)

max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")

max_response_length: int = Field(
32000, ge=1000, description="Maximum response length"
)

# Performance settings
default_timeout_ms: int = Field(30000, ge=1000, le=300000, description="Default request timeout")
max_concurrent_requests: int = Field(100, ge=1, le=1000, description="Maximum concurrent requests")

default_timeout_ms: int = Field(
30000, ge=1000, le=300000, description="Default request timeout"
)
max_concurrent_requests: int = Field(
100, ge=1, le=1000, description="Maximum concurrent requests"
)

# Provider configurations
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
providers: Dict[str, ProviderConfig] = Field(
default_factory=dict, description="Provider configurations"
)

# Token rate limiting (organization-wide)
token_limits_per_minute: Dict[str, int] = Field(
default_factory=lambda: {
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000 # PrivateMode Standard tier
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000, # PrivateMode Standard tier
},
description="Token rate limits per minute (organization-wide)"
description="Token rate limits per minute (organization-wide)",
)

# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")

model_routing: Dict[str, str] = Field(
default_factory=dict, description="Model to provider routing"
)

def create_default_config(env_vars=None) -> LLMServiceConfig:
@@ -94,8 +128,8 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
supported_models=[], # Will be populated dynamically from proxy
capabilities=["chat", "embeddings", "tee"],
priority=1,
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
max_requests_per_hour=1200, # 20 req/min * 60 min
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
max_requests_per_hour=1200, # 20 req/min * 60 min
supports_streaming=True,
supports_function_calling=True,
max_context_window=128000,
@@ -105,13 +139,11 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000, # PrivateMode may be slower due to TEE
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=120000
)
circuit_breaker_reset_timeout_ms=120000,
),
)

providers: Dict[str, ProviderConfig] = {
"privatemode": privatemode_config
}

providers: Dict[str, ProviderConfig] = {"privatemode": privatemode_config}

if env.OPENAI_API_KEY:
providers["openai"] = ProviderConfig(
@@ -126,7 +158,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
"gpt-4o",
"gpt-3.5-turbo",
"text-embedding-3-large",
"text-embedding-3-small"
"text-embedding-3-small",
],
capabilities=["chat", "embeddings"],
priority=2,
@@ -139,10 +171,10 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=750,
timeout_ms=45000,
circuit_breaker_threshold=6,
circuit_breaker_reset_timeout_ms=60000
)
circuit_breaker_reset_timeout_ms=60000,
),
)

if env.ANTHROPIC_API_KEY:
providers["anthropic"] = ProviderConfig(
name="anthropic",
@@ -154,7 +186,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
supported_models=[
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
"claude-3-haiku-20240307",
],
capabilities=["chat"],
priority=3,
@@ -167,10 +199,10 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=60000,
circuit_breaker_threshold=5,
circuit_breaker_reset_timeout_ms=90000
)
circuit_breaker_reset_timeout_ms=90000,
),
)

if env.GOOGLE_API_KEY:
providers["google"] = ProviderConfig(
name="google",
@@ -181,7 +213,7 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
default_model="models/gemini-1.5-pro-latest",
supported_models=[
"models/gemini-1.5-pro-latest",
"models/gemini-1.5-flash-latest"
"models/gemini-1.5-flash-latest",
],
capabilities=["chat", "multimodal"],
priority=4,
@@ -194,169 +226,176 @@ def create_default_config(env_vars=None) -> LLMServiceConfig:
retry_delay_ms=1000,
timeout_ms=45000,
circuit_breaker_threshold=4,
circuit_breaker_reset_timeout_ms=60000
)
circuit_breaker_reset_timeout_ms=60000,
),
)

default_provider = next(
(name for name, provider in providers.items() if provider.enabled),
"privatemode"
"privatemode",
)

# Create main configuration
config = LLMServiceConfig(
default_provider=default_provider,
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
providers=providers,
model_routing={} # Will be populated dynamically from provider models
model_routing={}, # Will be populated dynamically from provider models
)

return config

@dataclass
class EnvironmentVariables:
"""Environment variables used by LLM service"""

# Provider API keys
PRIVATEMODE_API_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None

# Service settings
LOG_LLM_PROMPTS: bool = False

def __post_init__(self):
"""Load values from environment"""
self.PRIVATEMODE_API_KEY = os.getenv("PRIVATEMODE_API_KEY")
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
self.ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
self.GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
self.LOG_LLM_PROMPTS = os.getenv("LOG_LLM_PROMPTS", "false").lower() == "true"

def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for a specific provider"""
key_mapping = {
"privatemode": self.PRIVATEMODE_API_KEY,
"openai": self.OPENAI_API_KEY,
"anthropic": self.ANTHROPIC_API_KEY,
"google": self.GOOGLE_API_KEY
"google": self.GOOGLE_API_KEY,
}

return key_mapping.get(provider_name.lower())

def validate_required_keys(self, enabled_providers: List[str]) -> List[str]:
"""Validate that required API keys are present"""
missing_keys = []

for provider in enabled_providers:
if not self.get_api_key(provider):
missing_keys.append(f"{provider.upper()}_API_KEY")

return missing_keys

class ConfigurationManager:
"""Manages LLM service configuration"""

def __init__(self):
self._config: Optional[LLMServiceConfig] = None
self._env_vars = EnvironmentVariables()

def get_config(self) -> LLMServiceConfig:
"""Get current configuration"""
if self._config is None:
self._config = create_default_config(self._env_vars)
self._validate_configuration()

return self._config

def update_config(self, config: LLMServiceConfig):
"""Update configuration"""
self._config = config
self._validate_configuration()

def get_provider_config(self, provider_name: str) -> Optional[ProviderConfig]:
"""Get configuration for a specific provider"""
config = self.get_config()
return config.providers.get(provider_name)

def get_provider_for_model(self, model_name: str) -> Optional[str]:
"""Get provider name for a specific model"""
config = self.get_config()
return config.model_routing.get(model_name)

def get_enabled_providers(self) -> List[str]:
"""Get list of enabled providers"""
config = self.get_config()
return [name for name, provider in config.providers.items() if provider.enabled]

def get_api_key(self, provider_name: str) -> Optional[str]:
"""Get API key for provider"""
return self._env_vars.get_api_key(provider_name)

def _validate_configuration(self):
"""Validate current configuration"""
if not self._config:
return

# Check for enabled providers without API keys
enabled_providers = self.get_enabled_providers()
missing_keys = self._env_vars.validate_required_keys(enabled_providers)

if missing_keys:
import logging

logger = logging.getLogger(__name__)
logger.warning(f"Missing API keys for enabled providers: {', '.join(missing_keys)}")

logger.warning(
f"Missing API keys for enabled providers: {', '.join(missing_keys)}"
)

# Validate default provider is enabled
default_provider = self._config.default_provider
if default_provider not in enabled_providers:
raise ValueError(f"Default provider '{default_provider}' is not enabled")
|
||||
|
||||
|
||||
# Validate model routing points to enabled providers
|
||||
invalid_routes = []
|
||||
for model, provider in self._config.model_routing.items():
|
||||
if provider not in enabled_providers:
|
||||
invalid_routes.append(f"{model} -> {provider}")
|
||||
|
||||
|
||||
if invalid_routes:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Model routes point to disabled providers: {', '.join(invalid_routes)}")
|
||||
|
||||
logger.warning(
|
||||
f"Model routes point to disabled providers: {', '.join(invalid_routes)}"
|
||||
)
|
||||
|
||||
async def refresh_provider_models(self, provider_name: str, models: List[str]):
|
||||
"""Update supported models for a provider dynamically"""
|
||||
if not self._config:
|
||||
return
|
||||
|
||||
|
||||
provider_config = self._config.providers.get(provider_name)
|
||||
if not provider_config:
|
||||
return
|
||||
|
||||
|
||||
# Update supported models
|
||||
provider_config.supported_models = models
|
||||
|
||||
|
||||
# Update model routing - map all models to this provider
|
||||
for model in models:
|
||||
self._config.model_routing[model] = provider_name
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Updated {provider_name} with {len(models)} models: {models}")
|
||||
|
||||
|
||||
async def get_all_available_models(self) -> Dict[str, List[str]]:
|
||||
"""Get all available models grouped by provider"""
|
||||
config = self.get_config()
|
||||
models_by_provider = {}
|
||||
|
||||
|
||||
for provider_name, provider_config in config.providers.items():
|
||||
if provider_config.enabled:
|
||||
models_by_provider[provider_name] = provider_config.supported_models
|
||||
|
||||
|
||||
return models_by_provider
|
||||
|
||||
|
||||
def get_model_provider_mapping(self) -> Dict[str, str]:
|
||||
"""Get current model to provider mapping"""
|
||||
config = self.get_config()
|
||||
|
||||
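Taken together, these hunks make configuration lazy and self-validating: the first get_config() call builds the default config from environment variables and immediately runs _validate_configuration(). A minimal usage sketch; the import path is an assumption, everything else follows the code above:

# Usage sketch. The module path is assumed, not taken from the diff.
from app.services.llm.config import ConfigurationManager

manager = ConfigurationManager()
config = manager.get_config()          # built and validated once, then cached

print(manager.get_enabled_providers())            # e.g. ["privatemode", "openai"]
print(manager.get_provider_for_model("gpt-4o"))   # None until routing is populated

# Routing fills in dynamically once a provider reports its models:
# await manager.refresh_provider_models("openai", ["gpt-4o", "gpt-3.5-turbo"])
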
@@ -7,8 +7,10 @@ Custom exceptions for LLM service operations.

class LLMError(Exception):
    """Base exception for LLM service errors"""

    def __init__(self, message: str, error_code: str = "LLM_ERROR", details: dict = None):

    def __init__(
        self, message: str, error_code: str = "LLM_ERROR", details: dict = None
    ):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
@@ -17,46 +19,78 @@ class LLMError(Exception):

class ProviderError(LLMError):
    """Exception for LLM provider-specific errors"""

    def __init__(self, message: str, provider: str, error_code: str = "PROVIDER_ERROR", details: dict = None):

    def __init__(
        self,
        message: str,
        provider: str,
        error_code: str = "PROVIDER_ERROR",
        details: dict = None,
    ):
        super().__init__(message, error_code, details)
        self.provider = provider


class SecurityError(LLMError):
    """Exception for security-related errors"""

    def __init__(self, message: str, risk_score: float = 0.0, error_code: str = "SECURITY_ERROR", details: dict = None):

    def __init__(
        self,
        message: str,
        risk_score: float = 0.0,
        error_code: str = "SECURITY_ERROR",
        details: dict = None,
    ):
        super().__init__(message, error_code, details)
        self.risk_score = risk_score


class ConfigurationError(LLMError):
    """Exception for configuration-related errors"""

    def __init__(self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None):

    def __init__(
        self, message: str, error_code: str = "CONFIG_ERROR", details: dict = None
    ):
        super().__init__(message, error_code, details)


class RateLimitError(LLMError):
    """Exception for rate limiting errors"""

    def __init__(self, message: str, retry_after: int = None, error_code: str = "RATE_LIMIT_ERROR", details: dict = None):

    def __init__(
        self,
        message: str,
        retry_after: int = None,
        error_code: str = "RATE_LIMIT_ERROR",
        details: dict = None,
    ):
        super().__init__(message, error_code, details)
        self.retry_after = retry_after


class TimeoutError(LLMError):
    """Exception for timeout errors"""

    def __init__(self, message: str, timeout_duration: float = None, error_code: str = "TIMEOUT_ERROR", details: dict = None):

    def __init__(
        self,
        message: str,
        timeout_duration: float = None,
        error_code: str = "TIMEOUT_ERROR",
        details: dict = None,
    ):
        super().__init__(message, error_code, details)
        self.timeout_duration = timeout_duration


class ValidationError(LLMError):
    """Exception for request validation errors"""

    def __init__(self, message: str, field: str = None, error_code: str = "VALIDATION_ERROR", details: dict = None):

    def __init__(
        self,
        message: str,
        field: str = None,
        error_code: str = "VALIDATION_ERROR",
        details: dict = None,
    ):
        super().__init__(message, error_code, details)
        self.field = field
        self.field = field
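The hierarchy above standardizes message, error_code, and details on the base class, so callers can catch narrowly or broadly. A short illustrative sketch (the handler bodies are not from the codebase):

# Illustrative handling of the exceptions defined above.
try:
    raise RateLimitError(
        "Rate limit exceeded for privatemode",
        retry_after=30,
        details={"status_code": 429},
    )
except RateLimitError as e:
    print(e.error_code, e.retry_after)   # RATE_LIMIT_ERROR 30
except LLMError as e:
    # Base-class catch-all: message, error_code and details always exist.
    print(e.error_code, e.details)
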
@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
@dataclass
class RequestMetric:
    """Individual request metric"""

    timestamp: datetime
    provider: str
    model: str
@@ -35,26 +36,30 @@ class RequestMetric:

class MetricsCollector:
    """Collects and aggregates LLM service metrics"""

    def __init__(self, max_history_size: int = 10000):
        """
        Initialize metrics collector

        Args:
            max_history_size: Maximum number of metrics to keep in memory
        """
        self.max_history_size = max_history_size
        self._metrics: deque = deque(maxlen=max_history_size)
        self._provider_metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
        self._provider_metrics: Dict[str, deque] = defaultdict(
            lambda: deque(maxlen=1000)
        )
        self._lock = threading.RLock()

        # Aggregated metrics cache
        self._cache_timestamp: Optional[datetime] = None
        self._cached_metrics: Optional[LLMMetrics] = None
        self._cache_ttl_seconds = 60  # Cache for 1 minute

        logger.info(f"Metrics collector initialized with max history: {max_history_size}")

        logger.info(
            f"Metrics collector initialized with max history: {max_history_size}"
        )

    def record_request(
        self,
        provider: str,
@@ -66,7 +71,7 @@ class MetricsCollector:
        security_risk_score: float = 0.0,
        error_code: Optional[str] = None,
        user_id: Optional[str] = None,
        api_key_id: Optional[int] = None
        api_key_id: Optional[int] = None,
    ):
        """Record a request metric"""
        metric = RequestMetric(
@@ -80,64 +85,73 @@ class MetricsCollector:
            security_risk_score=security_risk_score,
            error_code=error_code,
            user_id=user_id,
            api_key_id=api_key_id
            api_key_id=api_key_id,
        )

        with self._lock:
            self._metrics.append(metric)
            self._provider_metrics[provider].append(metric)

            # Invalidate cache
            self._cached_metrics = None
            self._cache_timestamp = None

        # Log significant events
        if not success:
            logger.warning(f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}")
            logger.warning(
                f"Request failed: {provider}/{model} - {error_code or 'Unknown error'}"
            )
        elif security_risk_score > 0.6:
            logger.info(f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}")

            logger.info(
                f"High risk request: {provider}/{model} - risk score: {security_risk_score:.3f}"
            )

    def get_metrics(self, force_refresh: bool = False) -> LLMMetrics:
        """Get aggregated metrics"""
        with self._lock:
            # Check cache validity
            if (not force_refresh and
                self._cached_metrics and
                self._cache_timestamp and
                (datetime.utcnow() - self._cache_timestamp).total_seconds() < self._cache_ttl_seconds):
            if (
                not force_refresh
                and self._cached_metrics
                and self._cache_timestamp
                and (datetime.utcnow() - self._cache_timestamp).total_seconds()
                < self._cache_ttl_seconds
            ):
                return self._cached_metrics

            # Calculate fresh metrics
            metrics = self._calculate_metrics()

            # Cache results
            self._cached_metrics = metrics
            self._cache_timestamp = datetime.utcnow()

            return metrics

    def _calculate_metrics(self) -> LLMMetrics:
        """Calculate aggregated metrics from recorded data"""
        if not self._metrics:
            return LLMMetrics()

        total_requests = len(self._metrics)
        successful_requests = sum(1 for m in self._metrics if m.success)
        failed_requests = total_requests - successful_requests

        # Calculate averages
        latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
        risk_scores = [m.security_risk_score for m in self._metrics]

        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
        avg_risk_score = sum(risk_scores) / len(risk_scores) if risk_scores else 0.0

        # Provider-specific metrics
        provider_metrics = {}
        for provider, provider_data in self._provider_metrics.items():
            if provider_data:
                provider_metrics[provider] = self._calculate_provider_metrics(provider_data)

                provider_metrics[provider] = self._calculate_provider_metrics(
                    provider_data
                )

        return LLMMetrics(
            total_requests=total_requests,
            successful_requests=successful_requests,
@@ -145,48 +159,50 @@ class MetricsCollector:
            average_latency_ms=avg_latency,
            average_risk_score=avg_risk_score,
            provider_metrics=provider_metrics,
            last_updated=datetime.utcnow()
            last_updated=datetime.utcnow(),
        )

    def _calculate_provider_metrics(self, provider_data: deque) -> Dict[str, Any]:
        """Calculate metrics for a specific provider"""
        if not provider_data:
            return {}

        total = len(provider_data)
        successful = sum(1 for m in provider_data if m.success)
        failed = total - successful

        latencies = [m.latency_ms for m in provider_data if m.latency_ms > 0]
        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0

        # Token usage aggregation
        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0

        for metric in provider_data:
            if metric.token_usage:
                total_prompt_tokens += metric.token_usage.get("prompt_tokens", 0)
                total_completion_tokens += metric.token_usage.get("completion_tokens", 0)
                total_completion_tokens += metric.token_usage.get(
                    "completion_tokens", 0
                )
                total_tokens += metric.token_usage.get("total_tokens", 0)

        # Model distribution
        model_counts = defaultdict(int)
        for metric in provider_data:
            model_counts[metric.model] += 1

        # Request type distribution
        request_type_counts = defaultdict(int)
        for metric in provider_data:
            request_type_counts[metric.request_type] += 1

        # Error analysis
        error_counts = defaultdict(int)
        for metric in provider_data:
            if not metric.success and metric.error_code:
                error_counts[metric.error_code] += 1

        return {
            "total_requests": total,
            "successful_requests": successful,
@@ -198,51 +214,57 @@ class MetricsCollector:
                "total_completion_tokens": total_completion_tokens,
                "total_tokens": total_tokens,
                "avg_prompt_tokens": total_prompt_tokens / total if total > 0 else 0,
                "avg_completion_tokens": total_completion_tokens / successful if successful > 0 else 0
                "avg_completion_tokens": total_completion_tokens / successful
                if successful > 0
                else 0,
            },
            "model_distribution": dict(model_counts),
            "request_type_distribution": dict(request_type_counts),
            "error_distribution": dict(error_counts),
            "recent_requests": total
            "recent_requests": total,
        }

    def get_provider_metrics(self, provider: str) -> Optional[Dict[str, Any]]:
        """Get metrics for a specific provider"""
        with self._lock:
            if provider not in self._provider_metrics:
                return None

            return self._calculate_provider_metrics(self._provider_metrics[provider])

    def get_recent_metrics(self, minutes: int = 5) -> List[RequestMetric]:
        """Get metrics from the last N minutes"""
        cutoff_time = datetime.utcnow() - timedelta(minutes=minutes)

        with self._lock:
            return [m for m in self._metrics if m.timestamp >= cutoff_time]

    def get_error_metrics(self, hours: int = 1) -> Dict[str, int]:
        """Get error distribution from the last N hours"""
        cutoff_time = datetime.utcnow() - timedelta(hours=hours)
        error_counts = defaultdict(int)

        with self._lock:
            for metric in self._metrics:
                if metric.timestamp >= cutoff_time and not metric.success and metric.error_code:
                if (
                    metric.timestamp >= cutoff_time
                    and not metric.success
                    and metric.error_code
                ):
                    error_counts[metric.error_code] += 1

        return dict(error_counts)

    def get_performance_metrics(self, minutes: int = 15) -> Dict[str, Dict[str, float]]:
        """Get performance metrics by provider from the last N minutes"""
        cutoff_time = datetime.utcnow() - timedelta(minutes=minutes)
        provider_perf = defaultdict(list)

        with self._lock:
            for metric in self._metrics:
                if metric.timestamp >= cutoff_time and metric.success:
                    provider_perf[metric.provider].append(metric.latency_ms)

        performance = {}
        for provider, latencies in provider_perf.items():
            if latencies:
@@ -252,26 +274,26 @@ class MetricsCollector:
                "max_latency_ms": max(latencies),
                "p95_latency_ms": self._percentile(latencies, 95),
                "p99_latency_ms": self._percentile(latencies, 99),
                "request_count": len(latencies)
                "request_count": len(latencies),
            }

        return performance

    def _percentile(self, data: List[float], percentile: int) -> float:
        """Calculate percentile of a list of numbers"""
        if not data:
            return 0.0

        sorted_data = sorted(data)
        index = (percentile / 100.0) * (len(sorted_data) - 1)

        if index.is_integer():
            return sorted_data[int(index)]
        else:
            lower = sorted_data[int(index)]
            upper = sorted_data[int(index) + 1]
            return lower + (upper - lower) * (index - int(index))

    def clear_metrics(self):
        """Clear all metrics (use with caution)"""
        with self._lock:
@@ -279,20 +301,20 @@ class MetricsCollector:
            self._provider_metrics.clear()
            self._cached_metrics = None
            self._cache_timestamp = None

        logger.info("All metrics cleared")

    def get_health_summary(self) -> Dict[str, Any]:
        """Get a health summary for monitoring"""
        metrics = self.get_metrics()
        recent_metrics = self.get_recent_metrics(minutes=5)
        error_metrics = self.get_error_metrics(hours=1)

        # Calculate health scores
        total_recent = len(recent_metrics)
        successful_recent = sum(1 for m in recent_metrics if m.success)
        success_rate = successful_recent / total_recent if total_recent > 0 else 1.0

        # Determine health status
        if success_rate >= 0.95:
            health_status = "healthy"
@@ -300,18 +322,20 @@ class MetricsCollector:
            health_status = "degraded"
        else:
            health_status = "unhealthy"

        return {
            "health_status": health_status,
            "success_rate_5min": success_rate,
            "total_requests_5min": total_recent,
            "average_latency_ms": metrics.average_latency_ms,
            "error_count_1hour": sum(error_metrics.values()),
            "top_errors": dict(sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]),
            "top_errors": dict(
                sorted(error_metrics.items(), key=lambda x: x[1], reverse=True)[:5]
            ),
            "provider_count": len(metrics.provider_metrics),
            "last_updated": datetime.utcnow().isoformat()
            "last_updated": datetime.utcnow().isoformat(),
        }


# Global metrics collector instance
metrics_collector = MetricsCollector()
metrics_collector = MetricsCollector()
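MetricsCollector is guarded by an RLock and serves cached aggregates for up to 60 seconds. A hedged sketch of recording and reading metrics; the hunk elides part of the record_request signature, so the middle keywords here (request_type, success, latency_ms, token_usage) are assumptions inferred from RequestMetric:

# Sketch; some record_request keywords are assumed (see note above).
metrics_collector.record_request(
    provider="privatemode",
    model="example-model",
    request_type="chat",      # assumed parameter
    success=True,             # assumed parameter
    latency_ms=412.5,         # assumed parameter
    token_usage={"prompt_tokens": 120, "completion_tokens": 80, "total_tokens": 200},
    user_id="user_42",
    api_key_id=7,
)

summary = metrics_collector.get_health_summary()
print(summary["health_status"], summary["success_rate_5min"])
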
@@ -9,15 +9,30 @@ from pydantic import BaseModel, Field, validator
from datetime import datetime


class ToolCall(BaseModel):
    """Tool call in a message"""

    id: str = Field(..., description="Tool call identifier")
    type: str = Field("function", description="Tool call type")
    function: Dict[str, Any] = Field(..., description="Function call details")


class ChatMessage(BaseModel):
    """Individual chat message"""

    role: str = Field(..., description="Message role (system, user, assistant)")
    content: str = Field(..., description="Message content")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Optional message name")

    @validator('role')
    tool_calls: Optional[List[ToolCall]] = Field(
        None, description="Tool calls in this message"
    )
    tool_call_id: Optional[str] = Field(
        None, description="Tool call ID for tool responses"
    )

    @validator("role")
    def validate_role(cls, v):
        allowed_roles = {'system', 'user', 'assistant', 'function'}
        allowed_roles = {"system", "user", "assistant", "function", "tool"}
        if v not in allowed_roles:
            raise ValueError(f"Role must be one of {allowed_roles}")
        return v
@@ -25,21 +40,38 @@ class ChatMessage(BaseModel):

class ChatRequest(BaseModel):
    """Chat completion request"""

    model: str = Field(..., description="Model identifier")
    messages: List[ChatMessage] = Field(..., description="Chat messages")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter")
    temperature: Optional[float] = Field(
        0.7, ge=0.0, le=2.0, description="Sampling temperature"
    )
    max_tokens: Optional[int] = Field(
        None, ge=1, le=32000, description="Maximum tokens to generate"
    )
    top_p: Optional[float] = Field(
        1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter"
    )
    top_k: Optional[int] = Field(None, ge=1, description="Top-k sampling parameter")
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Frequency penalty")
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(
        0.0, ge=-2.0, le=2.0, description="Frequency penalty"
    )
    presence_penalty: Optional[float] = Field(
        0.0, ge=-2.0, le=2.0, description="Presence penalty"
    )
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    stream: Optional[bool] = Field(False, description="Stream response")
    tools: Optional[List[Dict[str, Any]]] = Field(
        None, description="Available tools for function calling"
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        None, description="Tool choice preference"
    )
    user_id: str = Field(..., description="User identifier")
    api_key_id: int = Field(..., description="API key identifier")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")

    @validator('messages')

    @validator("messages")
    def validate_messages(cls, v):
        if not v:
            raise ValueError("Messages cannot be empty")
@@ -48,6 +80,7 @@ class ChatRequest(BaseModel):

class TokenUsage(BaseModel):
    """Token usage information"""

    prompt_tokens: int = Field(..., description="Tokens in the prompt")
    completion_tokens: int = Field(..., description="Tokens in the completion")
    total_tokens: int = Field(..., description="Total tokens used")
@@ -55,13 +88,17 @@ class TokenUsage(BaseModel):

class ChatChoice(BaseModel):
    """Chat completion choice"""

    index: int = Field(..., description="Choice index")
    message: ChatMessage = Field(..., description="Generated message")
    finish_reason: Optional[str] = Field(None, description="Reason for completion finish")
    finish_reason: Optional[str] = Field(
        None, description="Reason for completion finish"
    )


class ChatResponse(BaseModel):
    """Chat completion response"""

    id: str = Field(..., description="Response identifier")
    object: str = Field("chat.completion", description="Object type")
    created: int = Field(..., description="Creation timestamp")
@@ -70,19 +107,28 @@ class ChatResponse(BaseModel):
    choices: List[ChatChoice] = Field(..., description="Generated choices")
    usage: Optional[TokenUsage] = Field(None, description="Token usage")
    system_fingerprint: Optional[str] = Field(None, description="System fingerprint")

    # Security fields maintained for backward compatibility
    security_check: Optional[bool] = Field(None, description="Whether security check passed")
    security_check: Optional[bool] = Field(
        None, description="Whether security check passed"
    )
    risk_score: Optional[float] = Field(None, description="Security risk score")
    detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
    detected_patterns: Optional[List[str]] = Field(
        None, description="Detected security patterns"
    )

    # Performance metrics
    latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
    provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
    latency_ms: Optional[float] = Field(
        None, description="Response latency in milliseconds"
    )
    provider_latency_ms: Optional[float] = Field(
        None, description="Provider-specific latency"
    )


class EmbeddingRequest(BaseModel):
    """Embedding generation request"""

    model: str = Field(..., description="Embedding model identifier")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    encoding_format: Optional[str] = Field("float", description="Encoding format")
@@ -90,20 +136,23 @@ class EmbeddingRequest(BaseModel):
    user_id: str = Field(..., description="User identifier")
    api_key_id: int = Field(..., description="API key identifier")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")

    @validator('input')

    @validator("input")
    def validate_input(cls, v):
        if isinstance(v, str):
            if not v.strip():
                raise ValueError("Input text cannot be empty")
        elif isinstance(v, list):
            if not v or not all(isinstance(item, str) and item.strip() for item in v):
                raise ValueError("Input list cannot be empty and must contain non-empty strings")
                raise ValueError(
                    "Input list cannot be empty and must contain non-empty strings"
                )
        return v


class EmbeddingData(BaseModel):
    """Single embedding data"""

    object: str = Field("embedding", description="Object type")
    index: int = Field(..., description="Embedding index")
    embedding: List[float] = Field(..., description="Embedding vector")
@@ -111,63 +160,98 @@ class EmbeddingData(BaseModel):

class EmbeddingResponse(BaseModel):
    """Embedding generation response"""

    object: str = Field("list", description="Object type")
    data: List[EmbeddingData] = Field(..., description="Embedding data")
    model: str = Field(..., description="Model used")
    provider: str = Field(..., description="Provider used")
    usage: Optional[TokenUsage] = Field(None, description="Token usage")

    # Security fields maintained for backward compatibility
    security_check: Optional[bool] = Field(None, description="Whether security check passed")
    security_check: Optional[bool] = Field(
        None, description="Whether security check passed"
    )
    risk_score: Optional[float] = Field(None, description="Security risk score")
    detected_patterns: Optional[List[str]] = Field(None, description="Detected security patterns")
    detected_patterns: Optional[List[str]] = Field(
        None, description="Detected security patterns"
    )

    # Performance metrics
    latency_ms: Optional[float] = Field(None, description="Response latency in milliseconds")
    provider_latency_ms: Optional[float] = Field(None, description="Provider-specific latency")
    latency_ms: Optional[float] = Field(
        None, description="Response latency in milliseconds"
    )
    provider_latency_ms: Optional[float] = Field(
        None, description="Provider-specific latency"
    )


class ModelInfo(BaseModel):
    """Model information"""

    id: str = Field(..., description="Model identifier")
    object: str = Field("model", description="Object type")
    created: Optional[int] = Field(None, description="Creation timestamp")
    owned_by: str = Field(..., description="Model owner")
    provider: str = Field(..., description="Provider name")
    capabilities: List[str] = Field(default_factory=list, description="Model capabilities")
    capabilities: List[str] = Field(
        default_factory=list, description="Model capabilities"
    )
    context_window: Optional[int] = Field(None, description="Context window size")
    max_output_tokens: Optional[int] = Field(None, description="Maximum output tokens")
    supports_streaming: bool = Field(False, description="Whether model supports streaming")
    supports_function_calling: bool = Field(False, description="Whether model supports function calling")
    tasks: Optional[List[str]] = Field(None, description="Model tasks (e.g., generate, embed, vision)")
    supports_streaming: bool = Field(
        False, description="Whether model supports streaming"
    )
    supports_function_calling: bool = Field(
        False, description="Whether model supports function calling"
    )
    tasks: Optional[List[str]] = Field(
        None, description="Model tasks (e.g., generate, embed, vision)"
    )


class ProviderStatus(BaseModel):
    """Provider health status"""

    provider: str = Field(..., description="Provider name")
    status: str = Field(..., description="Status (healthy, degraded, unavailable)")
    latency_ms: Optional[float] = Field(None, description="Average latency")
    success_rate: Optional[float] = Field(None, description="Success rate (0.0 to 1.0)")
    last_check: datetime = Field(..., description="Last health check timestamp")
    error_message: Optional[str] = Field(None, description="Error message if unhealthy")
    models_available: List[str] = Field(default_factory=list, description="Available models")
    models_available: List[str] = Field(
        default_factory=list, description="Available models"
    )


class LLMMetrics(BaseModel):
    """LLM service metrics"""

    total_requests: int = Field(0, description="Total requests processed")
    successful_requests: int = Field(0, description="Successful requests")
    failed_requests: int = Field(0, description="Failed requests")
    average_latency_ms: float = Field(0.0, description="Average response latency")
    provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
    last_updated: datetime = Field(default_factory=datetime.utcnow, description="Last metrics update")
    provider_metrics: Dict[str, Dict[str, Any]] = Field(
        default_factory=dict, description="Per-provider metrics"
    )
    last_updated: datetime = Field(
        default_factory=datetime.utcnow, description="Last metrics update"
    )


class ResilienceConfig(BaseModel):
    """Configuration for resilience patterns"""

    max_retries: int = Field(3, ge=0, le=10, description="Maximum retry attempts")
    retry_delay_ms: int = Field(1000, ge=100, le=30000, description="Initial retry delay")
    retry_exponential_base: float = Field(2.0, ge=1.1, le=5.0, description="Exponential backoff base")
    retry_delay_ms: int = Field(
        1000, ge=100, le=30000, description="Initial retry delay"
    )
    retry_exponential_base: float = Field(
        2.0, ge=1.1, le=5.0, description="Exponential backoff base"
    )
    timeout_ms: int = Field(30000, ge=1000, le=300000, description="Request timeout")
    circuit_breaker_threshold: int = Field(5, ge=1, le=50, description="Circuit breaker failure threshold")
    circuit_breaker_reset_timeout_ms: int = Field(60000, ge=10000, le=600000, description="Circuit breaker reset timeout")
    circuit_breaker_threshold: int = Field(
        5, ge=1, le=50, description="Circuit breaker failure threshold"
    )
    circuit_breaker_reset_timeout_ms: int = Field(
        60000, ge=10000, le=600000, description="Circuit breaker reset timeout"
    )
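Because content is now optional and "tool" is an allowed role, an assistant turn may carry only tool_calls, with the tool's answer following as a "tool" message. A construction sketch using the models above (values are illustrative):

# Sketch of the new tool-calling message shapes.
assistant_turn = ChatMessage(
    role="assistant",
    content=None,  # permitted now that content is Optional
    tool_calls=[
        ToolCall(
            id="call_1",
            type="function",
            function={"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        )
    ],
)

tool_turn = ChatMessage(
    role="tool",             # newly allowed by validate_role
    tool_call_id="call_1",
    content='{"temp_c": 21}',
)
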
@@ -7,4 +7,4 @@ Base provider interface and provider implementations.
from .base import BaseLLMProvider
from .privatemode import PrivateModeProvider

__all__ = ["BaseLLMProvider", "PrivateModeProvider"]
__all__ = ["BaseLLMProvider", "PrivateModeProvider"]
@@ -9,8 +9,12 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
import logging

from ..models import (
    ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ModelInfo, ProviderStatus
    ChatRequest,
    ChatResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelInfo,
    ProviderStatus,
)
from ..config import ProviderConfig

@@ -19,11 +23,11 @@ logger = logging.getLogger(__name__)

class BaseLLMProvider(ABC):
    """Abstract base class for LLM providers"""

    def __init__(self, config: ProviderConfig, api_key: str):
        """
        Initialize provider

        Args:
            config: Provider configuration
            api_key: Decrypted API key for the provider
@@ -32,112 +36,114 @@ class BaseLLMProvider(ABC):
        self.api_key = api_key
        self.name = config.name
        self._session = None

        logger.info(f"Initializing {self.name} provider")

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Get provider name"""
        pass

    @abstractmethod
    async def health_check(self) -> ProviderStatus:
        """
        Check provider health status

        Returns:
            ProviderStatus with current health information
        """
        pass

    @abstractmethod
    async def get_models(self) -> List[ModelInfo]:
        """
        Get list of available models

        Returns:
            List of available models with their capabilities
        """
        pass

    @abstractmethod
    async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
        """
        Create chat completion

        Args:
            request: Chat completion request

        Returns:
            Chat completion response

        Raises:
            ProviderError: If provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    @abstractmethod
    async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
    async def create_chat_completion_stream(
        self, request: ChatRequest
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Create streaming chat completion

        Args:
            request: Chat completion request with stream=True

        Yields:
            Streaming response chunks

        Raises:
            ProviderError: If provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    @abstractmethod
    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """
        Create embeddings

        Args:
            request: Embedding generation request

        Returns:
            Embedding response

        Raises:
            ProviderError: If provider-specific error occurs
            SecurityError: If security validation fails
            ValidationError: If request validation fails
        """
        pass

    async def initialize(self):
        """Initialize provider resources (override if needed)"""
        pass

    async def cleanup(self):
        """Cleanup provider resources"""
        if self._session and hasattr(self._session, 'close'):
        if self._session and hasattr(self._session, "close"):
            await self._session.close()
            logger.debug(f"Cleaned up session for {self.name} provider")

    def supports_model(self, model_name: str) -> bool:
        """Check if provider supports a specific model"""
        return model_name in self.config.supported_models

    def supports_capability(self, capability: str) -> bool:
        """Check if provider supports a specific capability"""
        return capability in self.config.capabilities

    def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
        """Get information about a specific model (override for provider-specific info)"""
        if not self.supports_model(model_name):
            return None

        return ModelInfo(
            id=model_name,
            object="model",
@@ -147,80 +153,89 @@ class BaseLLMProvider(ABC):
            context_window=self.config.max_context_window,
            max_output_tokens=self.config.max_output_tokens,
            supports_streaming=self.config.supports_streaming,
            supports_function_calling=self.config.supports_function_calling
            supports_function_calling=self.config.supports_function_calling,
        )

    def _validate_request(self, request: Any):
        """Base request validation (override for provider-specific validation)"""
        if hasattr(request, 'model') and not self.supports_model(request.model):
        if hasattr(request, "model") and not self.supports_model(request.model):
            from ..exceptions import ValidationError

            raise ValidationError(
                f"Model '{request.model}' not supported by provider '{self.name}'",
                field="model"
                field="model",
            )

    def _create_headers(self, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:

    def _create_headers(
        self, additional_headers: Optional[Dict[str, str]] = None
    ) -> Dict[str, str]:
        """Create HTTP headers for requests"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": f"Enclava-LLM-Service/{self.name}"
            "User-Agent": f"Enclava-LLM-Service/{self.name}",
        }

        if additional_headers:
            headers.update(additional_headers)

        return headers

    def _handle_http_error(self, status_code: int, response_text: str, provider_context: str = ""):

    def _handle_http_error(
        self, status_code: int, response_text: str, provider_context: str = ""
    ):
        """Handle HTTP errors consistently across providers"""
        from ..exceptions import ProviderError, RateLimitError, ValidationError

        context = f"{self.name} {provider_context}".strip()

        if status_code == 401:
            raise ProviderError(
                f"Authentication failed for {context}",
                provider=self.name,
                error_code="AUTHENTICATION_ERROR",
                details={"status_code": status_code, "response": response_text}
                details={"status_code": status_code, "response": response_text},
            )
        elif status_code == 403:
            raise ProviderError(
                f"Access forbidden for {context}",
                provider=self.name,
                error_code="AUTHORIZATION_ERROR",
                details={"status_code": status_code, "response": response_text}
                details={"status_code": status_code, "response": response_text},
            )
        elif status_code == 429:
            raise RateLimitError(
                f"Rate limit exceeded for {context}",
                error_code="RATE_LIMIT_ERROR",
                details={"status_code": status_code, "response": response_text, "provider": self.name}
                details={
                    "status_code": status_code,
                    "response": response_text,
                    "provider": self.name,
                },
            )
        elif status_code == 400:
            raise ValidationError(
                f"Bad request for {context}: {response_text}",
                error_code="BAD_REQUEST",
                details={"status_code": status_code, "response": response_text}
                details={"status_code": status_code, "response": response_text},
            )
        elif 500 <= status_code < 600:
            raise ProviderError(
                f"Server error for {context}: {response_text}",
                provider=self.name,
                error_code="SERVER_ERROR",
                details={"status_code": status_code, "response": response_text}
                details={"status_code": status_code, "response": response_text},
            )
        else:
            raise ProviderError(
                f"HTTP error {status_code} for {context}: {response_text}",
                provider=self.name,
                error_code="HTTP_ERROR",
                details={"status_code": status_code, "response": response_text}
                details={"status_code": status_code, "response": response_text},
            )

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})"
        return f"{self.__class__.__name__}(name={self.name}, enabled={self.config.enabled})"
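A minimal subclass sketch showing what the abstract base demands of a concrete provider; EchoProvider is hypothetical and exists only to illustrate the contract:

# Hypothetical provider illustrating the BaseLLMProvider contract.
from datetime import datetime

class EchoProvider(BaseLLMProvider):
    @property
    def provider_name(self) -> str:
        return "echo"

    async def health_check(self) -> ProviderStatus:
        return ProviderStatus(
            provider=self.provider_name,
            status="healthy",
            last_check=datetime.utcnow(),
        )

    async def get_models(self) -> List[ModelInfo]:
        return [self.get_model_info(m) for m in self.config.supported_models]

    async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
        self._validate_request(request)  # raises ValidationError for unsupported models
        raise NotImplementedError("sketch only")

    async def create_chat_completion_stream(self, request: ChatRequest):
        yield {}  # placeholder chunk

    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
        raise NotImplementedError("sketch only")
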
@@ -15,9 +15,16 @@ import aiohttp

from .base import BaseLLMProvider
from ..models import (
    ChatRequest, ChatResponse, ChatMessage, ChatChoice, TokenUsage,
    EmbeddingRequest, EmbeddingResponse, EmbeddingData,
    ModelInfo, ProviderStatus
    ChatRequest,
    ChatResponse,
    ChatMessage,
    ChatChoice,
    TokenUsage,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingData,
    ModelInfo,
    ProviderStatus,
)
from ..config import ProviderConfig
from ..exceptions import ProviderError, ValidationError, TimeoutError
@@ -27,22 +34,22 @@ logger = logging.getLogger(__name__)

class PrivateModeProvider(BaseLLMProvider):
    """PrivateMode.ai provider with TEE security"""

    def __init__(self, config: ProviderConfig, api_key: str):
        super().__init__(config, api_key)
        self.base_url = config.base_url.rstrip('/')
        self.base_url = config.base_url.rstrip("/")
        self._session: Optional[aiohttp.ClientSession] = None

        # TEE-specific settings
        self.verify_ssl = True  # Always verify SSL for security
        self.trust_env = False  # Don't trust environment proxy settings

        logger.info(f"PrivateMode provider initialized with base URL: {self.base_url}")

    @property
    def provider_name(self) -> str:
        return "privatemode"

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create HTTP session with security settings"""
        if self._session is None or self._session.closed:
@@ -52,45 +59,49 @@ class PrivateModeProvider(BaseLLMProvider):
                limit=100,  # Connection pool limit
                limit_per_host=50,
                ttl_dns_cache=300,  # DNS cache TTL
                use_dns_cache=True
                use_dns_cache=True,
            )

            # Create session with security headers
            timeout = aiohttp.ClientTimeout(total=self.config.resilience.timeout_ms / 1000.0)

            timeout = aiohttp.ClientTimeout(
                total=self.config.resilience.timeout_ms / 1000.0
            )

            self._session = aiohttp.ClientSession(
                connector=connector,
                timeout=timeout,
                headers=self._create_headers(),
                trust_env=False  # Don't trust environment variables
                trust_env=False,  # Don't trust environment variables
            )

            logger.debug("Created new secure HTTP session for PrivateMode")

        return self._session

    async def health_check(self) -> ProviderStatus:
        """Check PrivateMode.ai service health"""
        start_time = time.time()

        try:
            session = await self._get_session()

            # Use a lightweight endpoint for health check
            async with session.get(f"{self.base_url}/models") as response:
                latency = (time.time() - start_time) * 1000

                if response.status == 200:
                    models_data = await response.json()
                    models = [model.get("id", "") for model in models_data.get("data", [])]

                    models = [
                        model.get("id", "") for model in models_data.get("data", [])
                    ]

                    return ProviderStatus(
                        provider=self.provider_name,
                        status="healthy",
                        latency_ms=latency,
                        success_rate=1.0,
                        last_check=datetime.utcnow(),
                        models_available=models
                        models_available=models,
                    )
                else:
                    error_text = await response.text()
@@ -101,13 +112,13 @@ class PrivateModeProvider(BaseLLMProvider):
                        success_rate=0.0,
                        last_check=datetime.utcnow(),
                        error_message=f"HTTP {response.status}: {error_text}",
                        models_available=[]
                        models_available=[],
                    )

        except Exception as e:
            latency = (time.time() - start_time) * 1000
            logger.error(f"PrivateMode health check failed: {e}")

            return ProviderStatus(
                provider=self.provider_name,
                status="unavailable",
@@ -115,33 +126,33 @@ class PrivateModeProvider(BaseLLMProvider):
                success_rate=0.0,
                last_check=datetime.utcnow(),
                error_message=str(e),
                models_available=[]
                models_available=[],
            )

    async def get_models(self) -> List[ModelInfo]:
        """Get available models from PrivateMode.ai"""
        try:
            session = await self._get_session()

            async with session.get(f"{self.base_url}/models") as response:
                if response.status == 200:
                    data = await response.json()
                    models_data = data.get("data", [])

                    models = []
                    for model_data in models_data:
                        model_id = model_data.get("id", "")
                        if not model_id:
                            continue

                        # Extract all information directly from API response
                        # Determine capabilities based on tasks field
                        tasks = model_data.get("tasks", [])
                        capabilities = []

                        # All PrivateMode models have TEE capability
                        capabilities.append("tee")

                        # Add capabilities based on tasks
                        if "generate" in tasks:
                            capabilities.append("chat")
@@ -149,12 +160,14 @@ class PrivateModeProvider(BaseLLMProvider):
                            capabilities.append("embeddings")
                        if "vision" in tasks:
                            capabilities.append("vision")

                        # Check for function calling support in the API response
                        supports_function_calling = model_data.get("supports_function_calling", False)
                        supports_function_calling = model_data.get(
                            "supports_function_calling", False
                        )
                        if supports_function_calling:
                            capabilities.append("function_calling")

                        model_info = ModelInfo(
                            id=model_id,
                            object="model",
@@ -164,40 +177,44 @@ class PrivateModeProvider(BaseLLMProvider):
                            capabilities=capabilities,
                            context_window=model_data.get("context_window"),
                            max_output_tokens=model_data.get("max_output_tokens"),
                            supports_streaming=model_data.get("supports_streaming", True),
                            supports_streaming=model_data.get(
                                "supports_streaming", True
                            ),
                            supports_function_calling=supports_function_calling,
                            tasks=tasks  # Pass through tasks field from PrivateMode API
                            tasks=tasks,  # Pass through tasks field from PrivateMode API
                        )
                        models.append(model_info)

                    logger.info(f"Retrieved {len(models)} models from PrivateMode")
                    return models
                else:
                    error_text = await response.text()
                    self._handle_http_error(response.status, error_text, "models endpoint")
                    self._handle_http_error(
                        response.status, error_text, "models endpoint"
                    )
                    return []  # Never reached due to exception

        except Exception as e:
            if isinstance(e, ProviderError):
                raise

            logger.error(f"Failed to get models from PrivateMode: {e}")
            raise ProviderError(
                "Failed to retrieve models from PrivateMode",
                provider=self.provider_name,
                error_code="MODEL_RETRIEVAL_ERROR",
                details={"error": str(e)}
                details={"error": str(e)},
            )
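Both read paths above funnel HTTP failures through _handle_http_error, so call sites deal only in typed exceptions or status objects. A hedged call-site sketch, assuming a configured PrivateModeProvider instance:

# Sketch; `provider` is assumed to be a configured PrivateModeProvider.
async def probe(provider: PrivateModeProvider) -> None:
    status = await provider.health_check()    # returns a status object, never raises
    if status.status == "healthy":
        models = await provider.get_models()  # raises ProviderError on failure
        print([m.id for m in models if "embeddings" in m.capabilities])
    await provider.cleanup()                  # close the shared aiohttp session
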
    async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
        """Create chat completion via PrivateMode.ai"""
        self._validate_request(request)

        start_time = time.time()

        try:
            session = await self._get_session()

            # Prepare request payload
            payload = {
                "model": request.model,
@@ -205,14 +222,14 @@ class PrivateModeProvider(BaseLLMProvider):
                    {
                        "role": msg.role,
                        "content": msg.content,
                        **({"name": msg.name} if msg.name else {})
                        **({"name": msg.name} if msg.name else {}),
                    }
                    for msg in request.messages
                ],
                "temperature": request.temperature,
                "stream": False  # Non-streaming version
                "stream": False,  # Non-streaming version
            }

            # Add optional parameters
            if request.max_tokens is not None:
                payload["max_tokens"] = request.max_tokens
@@ -224,28 +241,27 @@ class PrivateModeProvider(BaseLLMProvider):
                payload["presence_penalty"] = request.presence_penalty
            if request.stop is not None:
                payload["stop"] = request.stop

            # Add user tracking
            payload["user"] = f"user_{request.user_id}"

            # Add metadata for TEE audit trail
            payload["metadata"] = {
                "user_id": request.user_id,
                "api_key_id": request.api_key_id,
                "timestamp": datetime.utcnow().isoformat(),
                "enclava_request_id": str(uuid.uuid4()),
                **(request.metadata or {})
                **(request.metadata or {}),
            }

            async with session.post(
                f"{self.base_url}/chat/completions",
                json=payload
                f"{self.base_url}/chat/completions", json=payload
            ) as response:
                provider_latency = (time.time() - start_time) * 1000

                if response.status == 200:
                    data = await response.json()

                    # Parse response
                    choices = []
                    for choice_data in data.get("choices", []):
@@ -254,20 +270,20 @@ class PrivateModeProvider(BaseLLMProvider):
                            index=choice_data.get("index", 0),
                            message=ChatMessage(
                                role=message_data.get("role", "assistant"),
                                content=message_data.get("content", "")
                                content=message_data.get("content", ""),
                            ),
                            finish_reason=choice_data.get("finish_reason")
                            finish_reason=choice_data.get("finish_reason"),
                        )
                        choices.append(choice)

                    # Parse token usage
                    usage_data = data.get("usage", {})
                    usage = TokenUsage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=usage_data.get("completion_tokens", 0),
                        total_tokens=usage_data.get("total_tokens", 0)
                        total_tokens=usage_data.get("total_tokens", 0),
                    )

                    # Create response
                    chat_response = ChatResponse(
                        id=data.get("id", str(uuid.uuid4())),
@@ -279,45 +295,51 @@ class PrivateModeProvider(BaseLLMProvider):
                        usage=usage,
                        system_fingerprint=data.get("system_fingerprint"),
                        security_check=True,  # Will be set by security manager
                        risk_score=0.0,  # Will be set by security manager
                        risk_score=0.0,  # Will be set by security manager
                        latency_ms=provider_latency,
                        provider_latency_ms=provider_latency
                        provider_latency_ms=provider_latency,
                    )

                    logger.debug(
                        f"PrivateMode chat completion successful in {provider_latency:.2f}ms"
                    )

                    logger.debug(f"PrivateMode chat completion successful in {provider_latency:.2f}ms")
                    return chat_response

                else:
                    error_text = await response.text()
                    self._handle_http_error(response.status, error_text, "chat completion")

                    self._handle_http_error(
                        response.status, error_text, "chat completion"
                    )

        except aiohttp.ClientError as e:
            logger.error(f"PrivateMode request error: {e}")
            raise ProviderError(
                "Network error communicating with PrivateMode",
                provider=self.provider_name,
                error_code="NETWORK_ERROR",
                details={"error": str(e)}
                details={"error": str(e)},
            )
        except Exception as e:
            if isinstance(e, (ProviderError, ValidationError)):
                raise

            logger.error(f"Unexpected error in PrivateMode chat completion: {e}")
            raise ProviderError(
                "Unexpected error during chat completion",
                provider=self.provider_name,
                error_code="UNEXPECTED_ERROR",
                details={"error": str(e)}
                details={"error": str(e)},
            )
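A call-site sketch for the non-streaming path; the model name and identifiers are placeholders, not values from the codebase:

# Sketch; model and identifiers are illustrative placeholders.
async def ask(provider: PrivateModeProvider) -> None:
    request = ChatRequest(
        model="some-supported-model",   # must appear in config.supported_models
        messages=[ChatMessage(role="user", content="Hello")],
        user_id="user_42",
        api_key_id=7,
    )
    response = await provider.create_chat_completion(request)
    print(response.choices[0].message.content, response.usage.total_tokens)
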
async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
|
||||
async def create_chat_completion_stream(
|
||||
self, request: ChatRequest
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Create streaming chat completion"""
|
||||
self._validate_request(request)
|
||||
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
|
||||
|
||||
# Prepare streaming payload
|
||||
payload = {
|
||||
"model": request.model,
|
||||
@@ -325,14 +347,14 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
**({"name": msg.name} if msg.name else {})
|
||||
**({"name": msg.name} if msg.name else {}),
|
||||
}
|
||||
for msg in request.messages
|
||||
],
|
||||
"temperature": request.temperature,
|
||||
"stream": True
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
|
||||
# Add optional parameters
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
@@ -344,100 +366,104 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
|
||||
# Add user tracking
|
||||
payload["user"] = f"user_{request.user_id}"
|
||||
|
||||
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload
|
||||
f"{self.base_url}/chat/completions", json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
|
||||
line = line.decode("utf-8").strip()
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # Remove "data: " prefix
|
||||
|
||||
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(data_str)
|
||||
yield chunk_data
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse streaming chunk: {data_str}")
|
||||
logger.warning(
|
||||
f"Failed to parse streaming chunk: {data_str}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
error_text = await response.text()
|
||||
self._handle_http_error(response.status, error_text, "streaming chat completion")
|
||||
|
||||
self._handle_http_error(
|
||||
response.status, error_text, "streaming chat completion"
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"PrivateMode streaming error: {e}")
|
||||
raise ProviderError(
|
||||
"Network error during streaming",
|
||||
provider=self.provider_name,
|
||||
error_code="STREAMING_ERROR",
|
||||
details={"error": str(e)}
|
||||
details={"error": str(e)},
|
||||
)
|
||||
|
||||
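The loop above is a line-oriented server-sent-events reader: each chunk arrives as a `data: <json>` line and the sentinel `data: [DONE]` ends the stream. A hedged sketch of what crosses the wire and how a caller might drain the generator, assuming OpenAI-style `delta` chunks (the payloads are illustrative):

    # Wire format the reader above expects, one SSE line per chunk:
    #   data: {"id": "chatcmpl-123", "choices": [{"delta": {"content": "Hel"}}]}
    #   data: {"id": "chatcmpl-123", "choices": [{"delta": {"content": "lo"}}]}
    #   data: [DONE]

    async def print_stream(provider, request):
        # Each yielded chunk is the already-parsed JSON dict.
        async for chunk in provider.create_chat_completion_stream(request):
            delta = (chunk.get("choices") or [{}])[0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)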
    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Create embeddings via PrivateMode.ai"""
        self._validate_request(request)

        start_time = time.time()

        try:
            session = await self._get_session()

            # Prepare embedding payload
            payload = {
                "model": request.model,
                "input": request.input,
-               "user": f"user_{request.user_id}"
+               "user": f"user_{request.user_id}",
            }

            # Add optional parameters
            if request.encoding_format:
                payload["encoding_format"] = request.encoding_format
            if request.dimensions:
                payload["dimensions"] = request.dimensions

            # Add metadata
            payload["metadata"] = {
                "user_id": request.user_id,
                "api_key_id": request.api_key_id,
                "timestamp": datetime.utcnow().isoformat(),
-               **(request.metadata or {})
+               **(request.metadata or {}),
            }

            async with session.post(
-               f"{self.base_url}/embeddings",
-               json=payload
+               f"{self.base_url}/embeddings", json=payload
            ) as response:
                provider_latency = (time.time() - start_time) * 1000

                if response.status == 200:
                    data = await response.json()

                    # Parse embedding data
                    embeddings = []
                    for emb_data in data.get("data", []):
                        embedding = EmbeddingData(
                            object="embedding",
                            index=emb_data.get("index", 0),
-                           embedding=emb_data.get("embedding", [])
+                           embedding=emb_data.get("embedding", []),
                        )
                        embeddings.append(embedding)

                    # Parse usage
                    usage_data = data.get("usage", {})
                    usage = TokenUsage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=0,  # No completion tokens for embeddings
-                       total_tokens=usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
+                       total_tokens=usage_data.get(
+                           "total_tokens", usage_data.get("prompt_tokens", 0)
+                       ),
                    )

                    return EmbeddingResponse(
                        object="list",
                        data=embeddings,
@@ -445,37 +471,39 @@ class PrivateModeProvider(BaseLLMProvider):
                        provider=self.provider_name,
                        usage=usage,
                        security_check=True,  # Will be set by security manager
-                       risk_score=0.0, # Will be set by security manager
+                       risk_score=0.0,  # Will be set by security manager
                        latency_ms=provider_latency,
-                       provider_latency_ms=provider_latency
+                       provider_latency_ms=provider_latency,
                    )

                else:
                    error_text = await response.text()
                    # Log the detailed error response from the provider
-                   logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
+                   logger.error(
+                       f"PrivateMode embedding error - Status {response.status}: {error_text}"
+                   )
                    self._handle_http_error(response.status, error_text, "embeddings")

        except aiohttp.ClientError as e:
            logger.error(f"PrivateMode embedding error: {e}")
            raise ProviderError(
                "Network error during embedding generation",
                provider=self.provider_name,
                error_code="EMBEDDING_ERROR",
-               details={"error": str(e)}
+               details={"error": str(e)},
            )
        except Exception as e:
            if isinstance(e, (ProviderError, ValidationError)):
                raise

            logger.error(f"Unexpected error in PrivateMode embedding: {e}")
            raise ProviderError(
                "Unexpected error during embedding generation",
                provider=self.provider_name,
                error_code="UNEXPECTED_ERROR",
                details={"error": str(e)},
            )
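One detail worth noting: embedding responses often omit `total_tokens`, so the usage parser falls back to `prompt_tokens` before defaulting to zero. A quick illustration of that chained fallback (values hypothetical):

    usage_data = {"prompt_tokens": 42}  # no "total_tokens" key in the response
    total = usage_data.get("total_tokens", usage_data.get("prompt_tokens", 0))
    assert total == 42  # falls back to prompt_tokens; 0 only if both are missing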
    async def cleanup(self):
        """Cleanup PrivateMode provider resources"""
        # Close HTTP session to prevent memory leaks
@@ -485,4 +513,4 @@ class PrivateModeProvider(BaseLLMProvider):
            logger.debug("Closed PrivateMode HTTP session")

        await super().cleanup()
-       logger.debug("PrivateMode provider cleanup completed")
+       logger.debug("PrivateMode provider cleanup completed")
@@ -20,14 +20,16 @@ logger = logging.getLogger(__name__)

class CircuitBreakerState(Enum):
    """Circuit breaker states"""
-   CLOSED = "closed"  # Normal operation
-   OPEN = "open"      # Failing, blocking requests
+
+   CLOSED = "closed"  # Normal operation
+   OPEN = "open"  # Failing, blocking requests
    HALF_OPEN = "half_open"  # Testing if service recovered


@dataclass
class CircuitBreakerStats:
    """Circuit breaker statistics"""
+
    failure_count: int = 0
    success_count: int = 0
    last_failure_time: Optional[datetime] = None
@@ -37,162 +39,186 @@ class CircuitBreakerStats:

class CircuitBreaker:
    """Circuit breaker implementation for provider resilience"""

    def __init__(self, config: ResilienceConfig, provider_name: str):
        self.config = config
        self.provider_name = provider_name
        self.state = CircuitBreakerState.CLOSED
        self.stats = CircuitBreakerStats()

    def can_execute(self) -> bool:
        """Check if request can be executed"""
        if self.state == CircuitBreakerState.CLOSED:
            return True

        if self.state == CircuitBreakerState.OPEN:
            # Check if reset timeout has passed
-           if (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
+           if (
+               datetime.utcnow() - self.stats.state_change_time
+           ).total_seconds() * 1000 > self.config.circuit_breaker_reset_timeout_ms:
                self._transition_to_half_open()
                return True
            return False

        if self.state == CircuitBreakerState.HALF_OPEN:
            return True

        return False

    def record_success(self):
        """Record successful request"""
        self.stats.success_count += 1
        self.stats.last_success_time = datetime.utcnow()

        if self.state == CircuitBreakerState.HALF_OPEN:
            self._transition_to_closed()
        elif self.state == CircuitBreakerState.CLOSED:
            # Reset failure count on success
            self.stats.failure_count = 0

-       logger.debug(f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}")
+       logger.debug(
+           f"Circuit breaker [{self.provider_name}]: Success recorded, state={self.state.value}"
+       )

    def record_failure(self):
        """Record failed request"""
        self.stats.failure_count += 1
        self.stats.last_failure_time = datetime.utcnow()

        if self.state == CircuitBreakerState.CLOSED:
            if self.stats.failure_count >= self.config.circuit_breaker_threshold:
                self._transition_to_open()
        elif self.state == CircuitBreakerState.HALF_OPEN:
            self._transition_to_open()

-       logger.warning(f"Circuit breaker [{self.provider_name}]: Failure recorded, "
-                      f"count={self.stats.failure_count}, state={self.state.value}")
+       logger.warning(
+           f"Circuit breaker [{self.provider_name}]: Failure recorded, "
+           f"count={self.stats.failure_count}, state={self.state.value}"
+       )

    def _transition_to_open(self):
        """Transition to OPEN state"""
        self.state = CircuitBreakerState.OPEN
        self.stats.state_change_time = datetime.utcnow()
-       logger.error(f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures")
+       logger.error(
+           f"Circuit breaker [{self.provider_name}]: OPENED after {self.stats.failure_count} failures"
+       )

    def _transition_to_half_open(self):
        """Transition to HALF_OPEN state"""
        self.state = CircuitBreakerState.HALF_OPEN
        self.stats.state_change_time = datetime.utcnow()
-       logger.info(f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing")
+       logger.info(
+           f"Circuit breaker [{self.provider_name}]: Transitioning to HALF_OPEN for testing"
+       )

    def _transition_to_closed(self):
        """Transition to CLOSED state"""
        self.state = CircuitBreakerState.CLOSED
        self.stats.state_change_time = datetime.utcnow()
        self.stats.failure_count = 0  # Reset failure count
-       logger.info(f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered")
+       logger.info(
+           f"Circuit breaker [{self.provider_name}]: CLOSED - service recovered"
+       )

    def get_stats(self) -> Dict[str, Any]:
        """Get circuit breaker statistics"""
        return {
            "state": self.state.value,
            "failure_count": self.stats.failure_count,
            "success_count": self.stats.success_count,
-           "last_failure_time": self.stats.last_failure_time.isoformat() if self.stats.last_failure_time else None,
-           "last_success_time": self.stats.last_success_time.isoformat() if self.stats.last_success_time else None,
+           "last_failure_time": self.stats.last_failure_time.isoformat()
+           if self.stats.last_failure_time
+           else None,
+           "last_success_time": self.stats.last_success_time.isoformat()
+           if self.stats.last_success_time
+           else None,
            "state_change_time": self.stats.state_change_time.isoformat(),
-           "time_in_current_state_ms": (datetime.utcnow() - self.stats.state_change_time).total_seconds() * 1000
+           "time_in_current_state_ms": (
+               datetime.utcnow() - self.stats.state_change_time
+           ).total_seconds()
+           * 1000,
        }
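The class above is a textbook three-state breaker: CLOSED counts failures, OPEN rejects calls until the reset timeout elapses, and HALF_OPEN admits a probe whose outcome decides the next state. A hedged sketch of the transition cycle; it assumes `ResilienceConfig` accepts these fields as keyword arguments (they are the fields the breaker reads above), and the values are illustrative, not the project defaults:

    config = ResilienceConfig(
        circuit_breaker_threshold=3,          # assumed constructor kwarg
        circuit_breaker_reset_timeout_ms=30_000,
    )
    cb = CircuitBreaker(config, "privatemode")

    for _ in range(3):
        cb.record_failure()                   # third failure trips CLOSED -> OPEN
    assert cb.state == CircuitBreakerState.OPEN
    assert not cb.can_execute()               # requests are blocked while OPEN

    # Once 30s pass, can_execute() flips the breaker to HALF_OPEN and admits
    # one trial request; record_success() then closes it, record_failure()
    # re-opens it.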

class RetryManager:
    """Manages retry logic with exponential backoff"""

    def __init__(self, config: ResilienceConfig):
        self.config = config

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        retryable_exceptions: tuple = (Exception,),
        non_retryable_exceptions: tuple = (RateLimitError,),
-       **kwargs
+       **kwargs,
    ) -> Any:
        """Execute function with retry logic"""
        last_exception = None

        for attempt in range(self.config.max_retries + 1):
            try:
                return await func(*args, **kwargs)

            except non_retryable_exceptions as e:
                logger.warning(f"Non-retryable exception on attempt {attempt + 1}: {e}")
                raise

            except retryable_exceptions as e:
                last_exception = e

                if attempt == self.config.max_retries:
-                   logger.error(f"All {self.config.max_retries + 1} attempts failed. Last error: {e}")
+                   logger.error(
+                       f"All {self.config.max_retries + 1} attempts failed. Last error: {e}"
+                   )
                    raise

                delay = self._calculate_delay(attempt)
-               logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms...")
+               logger.warning(
+                   f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}ms..."
+               )

                await asyncio.sleep(delay / 1000.0)

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        else:
            raise LLMError("Unexpected error in retry logic")

    def _calculate_delay(self, attempt: int) -> int:
        """Calculate delay for exponential backoff"""
-       delay = self.config.retry_delay_ms * (self.config.retry_exponential_base ** attempt)
+       delay = self.config.retry_delay_ms * (
+           self.config.retry_exponential_base**attempt
+       )

        # Add some jitter to prevent thundering herd
        import random

        jitter = random.uniform(0.8, 1.2)

        return int(delay * jitter)
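Concretely, the schedule is `retry_delay_ms * base**attempt`, scaled by a uniform ±20% jitter. Assuming `retry_delay_ms=1000` and `retry_exponential_base=2` (illustrative values, not necessarily the configured defaults), the nominal delays work out as follows:

    retry_delay_ms, base = 1000, 2
    for attempt in range(3):
        nominal = retry_delay_ms * base**attempt
        print(attempt, nominal)  # 0 -> 1000ms, 1 -> 2000ms, 2 -> 4000ms

    # With jitter drawn from [0.8, 1.2], attempt 2 actually sleeps anywhere
    # in 3200-4800ms, which staggers simultaneous retries from many callers.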

class TimeoutManager:
    """Manages request timeouts"""

    def __init__(self, config: ResilienceConfig):
        self.config = config

    async def execute_with_timeout(
-       self,
-       func: Callable,
-       *args,
-       timeout_override: Optional[int] = None,
-       **kwargs
+       self, func: Callable, *args, timeout_override: Optional[int] = None, **kwargs
    ) -> Any:
        """Execute function with timeout"""
        timeout_ms = timeout_override or self.config.timeout_ms
        timeout_seconds = timeout_ms / 1000.0

        try:
-           return await asyncio.wait_for(func(*args, **kwargs), timeout=timeout_seconds)
+           return await asyncio.wait_for(
+               func(*args, **kwargs), timeout=timeout_seconds
+           )
        except asyncio.TimeoutError:
            error_msg = f"Request timed out after {timeout_ms}ms"
            logger.error(error_msg)
@@ -201,14 +227,14 @@ class TimeoutManager:

class ResilienceManager:
    """Comprehensive resilience manager combining all patterns"""

    def __init__(self, config: ResilienceConfig, provider_name: str):
        self.config = config
        self.provider_name = provider_name
        self.circuit_breaker = CircuitBreaker(config, provider_name)
        self.retry_manager = RetryManager(config)
        self.timeout_manager = TimeoutManager(config)

    async def execute(
        self,
        func: Callable,
@@ -216,18 +242,18 @@ class ResilienceManager:
        retryable_exceptions: tuple = (Exception,),
        non_retryable_exceptions: tuple = (RateLimitError,),
        timeout_override: Optional[int] = None,
-       **kwargs
+       **kwargs,
    ) -> Any:
        """Execute function with full resilience patterns"""

        # Check circuit breaker
        if not self.circuit_breaker.can_execute():
            error_msg = f"Circuit breaker is OPEN for provider {self.provider_name}"
            logger.error(error_msg)
            raise LLMError(error_msg, error_code="CIRCUIT_BREAKER_OPEN")

        start_time = time.time()

        try:
            # Execute with timeout and retry
            result = await self.retry_manager.execute_with_retry(
@@ -237,30 +263,34 @@ class ResilienceManager:
                retryable_exceptions=retryable_exceptions,
                non_retryable_exceptions=non_retryable_exceptions,
                timeout_override=timeout_override,
-               **kwargs
+               **kwargs,
            )

            # Record success
            self.circuit_breaker.record_success()

            execution_time = (time.time() - start_time) * 1000
-           logger.debug(f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms")
+           logger.debug(
+               f"Resilient execution succeeded for {self.provider_name} in {execution_time:.2f}ms"
+           )

            return result

        except Exception as e:
            # Record failure
            self.circuit_breaker.record_failure()

            execution_time = (time.time() - start_time) * 1000
-           logger.error(f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}")
+           logger.error(
+               f"Resilient execution failed for {self.provider_name} after {execution_time:.2f}ms: {e}"
+           )

            raise

    def get_health_status(self) -> Dict[str, Any]:
        """Get comprehensive health status"""
        cb_stats = self.circuit_breaker.get_stats()

        # Determine overall health
        if cb_stats["state"] == "open":
            health = "unhealthy"
@@ -273,7 +303,7 @@ class ResilienceManager:
            health = "degraded"
        else:
            health = "healthy"

        return {
            "provider": self.provider_name,
            "health": health,
@@ -281,34 +311,37 @@ class ResilienceManager:
            "config": {
                "max_retries": self.config.max_retries,
                "timeout_ms": self.config.timeout_ms,
-               "circuit_breaker_threshold": self.config.circuit_breaker_threshold
-           }
+               "circuit_breaker_threshold": self.config.circuit_breaker_threshold,
+           },
        }
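Note the nesting order in `execute()`: the circuit breaker gates admission first, then `RetryManager` drives attempts, and (judging by the partially elided call) each attempt runs under the `TimeoutManager` deadline, so one logical call may spend up to roughly `(max_retries + 1) * timeout_ms` plus backoff sleeps before failing. A hedged usage sketch mirroring how the service layer below invokes it; the wrapped coroutine is a hypothetical stand-in:

    async def some_provider_coroutine(request):  # hypothetical stand-in
        ...

    async def call_with_resilience(request):
        manager = ResilienceManagerFactory.get_manager("privatemode")
        # Roughly: circuit-breaker check -> retry loop -> per-attempt timeout.
        return await manager.execute(
            some_provider_coroutine,
            request,
            retryable_exceptions=(ProviderError, TimeoutError),
            non_retryable_exceptions=(ValidationError,),
        )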

class ResilienceManagerFactory:
    """Factory for creating resilience managers"""

    _managers: Dict[str, ResilienceManager] = {}
    _default_config = ResilienceConfig()

    @classmethod
-   def get_manager(cls, provider_name: str, config: Optional[ResilienceConfig] = None) -> ResilienceManager:
+   def get_manager(
+       cls, provider_name: str, config: Optional[ResilienceConfig] = None
+   ) -> ResilienceManager:
        """Get or create resilience manager for provider"""
        if provider_name not in cls._managers:
            manager_config = config or cls._default_config
-           cls._managers[provider_name] = ResilienceManager(manager_config, provider_name)
+           cls._managers[provider_name] = ResilienceManager(
+               manager_config, provider_name
+           )

        return cls._managers[provider_name]

    @classmethod
    def get_all_health_status(cls) -> Dict[str, Dict[str, Any]]:
        """Get health status for all managed providers"""
        return {
-           name: manager.get_health_status()
-           for name, manager in cls._managers.items()
+           name: manager.get_health_status() for name, manager in cls._managers.items()
        }

    @classmethod
    def update_config(cls, provider_name: str, config: ResilienceConfig):
        """Update configuration for a specific provider"""
@@ -317,7 +350,7 @@ class ResilienceManagerFactory:
        cls._managers[provider_name].circuit_breaker.config = config
        cls._managers[provider_name].retry_manager.config = config
        cls._managers[provider_name].timeout_manager.config = config

    @classmethod
    def reset_circuit_breaker(cls, provider_name: str):
        """Manually reset circuit breaker for a provider"""
@@ -325,8 +358,8 @@ class ResilienceManagerFactory:
        manager = cls._managers[provider_name]
        manager.circuit_breaker._transition_to_closed()
        logger.info(f"Manually reset circuit breaker for {provider_name}")

    @classmethod
    def set_default_config(cls, config: ResilienceConfig):
        """Set default configuration for new managers"""
-       cls._default_config = config
+       cls._default_config = config
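`get_manager` is a per-provider singleton: the first call constructs the manager (optionally with an explicit config) and every later call returns the cached instance, so circuit-breaker state persists across requests. A quick sketch of that caching contract:

    a = ResilienceManagerFactory.get_manager("privatemode")
    b = ResilienceManagerFactory.get_manager("privatemode")
    assert a is b  # same cached ResilienceManager; breaker state is shared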
@@ -12,18 +12,28 @@ from typing import Dict, Any, Optional, List, AsyncGenerator
from datetime import datetime

from .models import (
-   ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
-   ModelInfo, ProviderStatus, LLMMetrics
+   ChatRequest,
+   ChatResponse,
+   EmbeddingRequest,
+   EmbeddingResponse,
+   ModelInfo,
+   ProviderStatus,
+   LLMMetrics,
)
from .config import config_manager, ProviderConfig
from ...core.config import settings

from .resilience import ResilienceManagerFactory

# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
-   LLMError, ProviderError, SecurityError, ConfigurationError,
-   ValidationError, TimeoutError
+   LLMError,
+   ProviderError,
+   SecurityError,
+   ConfigurationError,
+   ValidationError,
+   TimeoutError,
)

logger = logging.getLogger(__name__)
@@ -31,58 +41,64 @@ logger = logging.getLogger(__name__)

class LLMService:
    """Main LLM service coordinating all components"""

    def __init__(self):
        """Initialize LLM service"""
        self._providers: Dict[str, BaseLLMProvider] = {}
        self._initialized = False
        self._startup_time: Optional[datetime] = None

        logger.info("LLM Service initialized")

    async def initialize(self):
        """Initialize service and providers"""
        if self._initialized:
            logger.warning("LLM Service already initialized")
            return

        start_time = time.time()
        self._startup_time = datetime.utcnow()

        try:
            # Get configuration
            config = config_manager.get_config()
-           logger.info(f"Initializing LLM service with {len(config.providers)} configured providers")
+           logger.info(
+               f"Initializing LLM service with {len(config.providers)} configured providers"
+           )

            # Initialize enabled providers
            enabled_providers = config_manager.get_enabled_providers()
            if not enabled_providers:
                raise ConfigurationError("No enabled providers found")

            for provider_name in enabled_providers:
                await self._initialize_provider(provider_name)

            # Verify we have at least one working provider
            if not self._providers:
                raise ConfigurationError("No providers successfully initialized")

            # Verify default provider is available
            default_provider = config.default_provider
            if default_provider not in self._providers:
                available_providers = list(self._providers.keys())
-               logger.warning(f"Default provider '{default_provider}' not available, using '{available_providers[0]}'")
+               logger.warning(
+                   f"Default provider '{default_provider}' not available, using '{available_providers[0]}'"
+               )
                config.default_provider = available_providers[0]

            self._initialized = True
            initialization_time = (time.time() - start_time) * 1000
-           logger.info(f"LLM Service initialized successfully in {initialization_time:.2f}ms")
+           logger.info(
+               f"LLM Service initialized successfully in {initialization_time:.2f}ms"
+           )
            logger.info(f"Available providers: {list(self._providers.keys())}")

        except Exception as e:
            logger.error(f"Failed to initialize LLM service: {e}")
            raise ConfigurationError(f"LLM service initialization failed: {e}")
    async def _initialize_provider(self, provider_name: str):
        """Initialize a specific provider"""
        try:
@@ -90,101 +106,109 @@ class LLMService:
            if not provider_config or not provider_config.enabled:
                logger.warning(f"Provider '{provider_name}' not enabled, skipping")
                return

            # Get API key
            api_key = config_manager.get_api_key(provider_name)
            if not api_key:
                logger.error(f"No API key found for provider '{provider_name}'")
                return

            # Create provider instance
            provider = self._create_provider(provider_config, api_key)

            # Initialize provider
            await provider.initialize()

            # Test provider health
            health_status = await provider.health_check()
            if health_status.status == "unavailable":
-               logger.error(f"Provider '{provider_name}' failed health check: {health_status.error_message}")
+               logger.error(
+                   f"Provider '{provider_name}' failed health check: {health_status.error_message}"
+               )
                return

            # Register provider
            self._providers[provider_name] = provider
-           logger.info(f"Provider '{provider_name}' initialized successfully (status: {health_status.status})")
+           logger.info(
+               f"Provider '{provider_name}' initialized successfully (status: {health_status.status})"
+           )

            # Fetch and update models dynamically
            await self._refresh_provider_models(provider_name, provider)

        except Exception as e:
            logger.error(f"Failed to initialize provider '{provider_name}': {e}")

    def _create_provider(self, config: ProviderConfig, api_key: str) -> BaseLLMProvider:
        """Create provider instance based on configuration"""
        if config.name == "privatemode":
            return PrivateModeProvider(config, api_key)
        else:
            raise ConfigurationError(f"Unknown provider type: {config.name}")

-   async def _refresh_provider_models(self, provider_name: str, provider: BaseLLMProvider):
+   async def _refresh_provider_models(
+       self, provider_name: str, provider: BaseLLMProvider
+   ):
        """Fetch and update models dynamically from provider"""
        try:
            # Get models from provider
            models = await provider.get_models()
            model_ids = [model.id for model in models]

            # Update configuration
            await config_manager.refresh_provider_models(provider_name, model_ids)

-           logger.info(f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}")
+           logger.info(
+               f"Refreshed {len(model_ids)} models for provider '{provider_name}': {model_ids}"
+           )

        except Exception as e:
-           logger.error(f"Failed to refresh models for provider '{provider_name}': {e}")
+           logger.error(
+               f"Failed to refresh models for provider '{provider_name}': {e}"
+           )
    async def create_chat_completion(self, request: ChatRequest) -> ChatResponse:
        """Create chat completion with security and resilience"""
        if not self._initialized:
            await self.initialize()

        # Validate request
        if not request.messages:
            raise ValidationError("Messages cannot be empty", field="messages")

        risk_score = 0.0

        # Get provider for model
        provider_name = self._get_provider_for_model(request.model)
        provider = self._providers.get(provider_name)

        if not provider:
-           raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
+           raise ProviderError(
+               f"No available provider for model '{request.model}'",
+               provider=provider_name,
+           )

        # Execute with resilience
        resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
        start_time = time.time()

        try:
            response = await resilience_manager.execute(
                provider.create_chat_completion,
                request,
                retryable_exceptions=(ProviderError, TimeoutError),
-               non_retryable_exceptions=(ValidationError,)
+               non_retryable_exceptions=(ValidationError,),
            )

            # Record successful request - metrics disabled
            total_latency = (time.time() - start_time) * 1000

            return response

        except Exception as e:
            # Record failed request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
-           error_code = getattr(e, 'error_code', e.__class__.__name__)
+           error_code = getattr(e, "error_code", e.__class__.__name__)

            logger.exception(
                "Chat completion failed for provider %s (model=%s, latency=%.2fms, error=%s)",
@@ -194,38 +218,42 @@ class LLMService:
                error_code,
            )
            raise
-   async def create_chat_completion_stream(self, request: ChatRequest) -> AsyncGenerator[Dict[str, Any], None]:
+   async def create_chat_completion_stream(
+       self, request: ChatRequest
+   ) -> AsyncGenerator[Dict[str, Any], None]:
        """Create streaming chat completion"""
        if not self._initialized:
            await self.initialize()

        # Security validation disabled - always allow streaming requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
        provider = self._providers.get(provider_name)

        if not provider:
-           raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
+           raise ProviderError(
+               f"No available provider for model '{request.model}'",
+               provider=provider_name,
+           )

        # Execute streaming with resilience
        resilience_manager = ResilienceManagerFactory.get_manager(provider_name)

        try:
            async for chunk in await resilience_manager.execute(
                provider.create_chat_completion_stream,
                request,
                retryable_exceptions=(ProviderError, TimeoutError),
-               non_retryable_exceptions=(ValidationError,)
+               non_retryable_exceptions=(ValidationError,),
            ):
                yield chunk

        except Exception as e:
            # Record streaming failure - metrics disabled
-           error_code = getattr(e, 'error_code', e.__class__.__name__)
+           error_code = getattr(e, "error_code", e.__class__.__name__)
            logger.exception(
                "Streaming chat completion failed for provider %s (model=%s, error=%s)",
                provider_name,
@@ -233,46 +261,46 @@ class LLMService:
                error_code,
            )
            raise
    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Create embeddings with security and resilience"""
        if not self._initialized:
            await self.initialize()

        # Security validation disabled - always allow embedding requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
        provider = self._providers.get(provider_name)

        if not provider:
-           raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
+           raise ProviderError(
+               f"No available provider for model '{request.model}'",
+               provider=provider_name,
+           )

        # Execute with resilience
        resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
        start_time = time.time()

        try:
            response = await resilience_manager.execute(
                provider.create_embedding,
                request,
                retryable_exceptions=(ProviderError, TimeoutError),
-               non_retryable_exceptions=(ValidationError,)
+               non_retryable_exceptions=(ValidationError,),
            )

            # Record successful request - metrics disabled
            total_latency = (time.time() - start_time) * 1000

            return response

        except Exception as e:
            # Record failed request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
-           error_code = getattr(e, 'error_code', e.__class__.__name__)
+           error_code = getattr(e, "error_code", e.__class__.__name__)
            logger.exception(
                "Embedding request failed for provider %s (model=%s, latency=%.2fms, error=%s)",
                provider_name,
@@ -281,14 +309,14 @@ class LLMService:
                error_code,
            )
            raise
    async def get_models(self, provider_name: Optional[str] = None) -> List[ModelInfo]:
        """Get available models from all or specific provider"""
        if not self._initialized:
            await self.initialize()

        models = []

        if provider_name:
            # Get models from specific provider
            provider = self._providers.get(provider_name)
@@ -306,16 +334,16 @@ class LLMService:
                    models.extend(provider_models)
                except Exception as e:
                    logger.error(f"Failed to get models from {name}: {e}")

        return models

    async def get_provider_status(self) -> Dict[str, ProviderStatus]:
        """Get health status of all providers"""
        if not self._initialized:
            await self.initialize()

        status_dict = {}

        for name, provider in self._providers.items():
            try:
                status = await provider.health_check()
@@ -327,21 +355,18 @@ class LLMService:
                    status="unavailable",
                    last_check=datetime.utcnow(),
                    error_message=str(e),
-                   models_available=[]
+                   models_available=[],
                )

        return status_dict

    def get_metrics(self) -> LLMMetrics:
        """Get service metrics - metrics disabled"""
        # return metrics_collector.get_metrics()
        return LLMMetrics(
-           total_requests=0,
-           success_rate=0.0,
-           avg_latency_ms=0,
-           error_rates={}
+           total_requests=0, success_rate=0.0, avg_latency_ms=0, error_rates={}
        )

    def get_health_summary(self) -> Dict[str, Any]:
        """Get comprehensive health summary - metrics disabled"""
        # metrics_health = metrics_collector.get_health_summary()
@@ -349,40 +374,42 @@ class LLMService:

        return {
            "service_status": "healthy" if self._initialized else "initializing",
-           "startup_time": self._startup_time.isoformat() if self._startup_time else None,
+           "startup_time": self._startup_time.isoformat()
+           if self._startup_time
+           else None,
            "provider_count": len(self._providers),
            "active_providers": list(self._providers.keys()),
            "metrics": {"status": "disabled"},
-           "resilience": resilience_health
+           "resilience": resilience_health,
        }
    def _get_provider_for_model(self, model: str) -> str:
        """Get provider name for a model"""
        # Check model routing first
        provider_name = config_manager.get_provider_for_model(model)
        if provider_name and provider_name in self._providers:
            return provider_name

        # Fall back to providers that support the model
        for name, provider in self._providers.items():
            if provider.supports_model(model):
                return name

        # Use default provider as last resort
        config = config_manager.get_config()
        if config.default_provider in self._providers:
            return config.default_provider

        # If nothing else works, use first available provider
        if self._providers:
            return list(self._providers.keys())[0]

        raise ProviderError(f"No provider found for model '{model}'", provider="none")
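Model routing above resolves in four steps: explicit routing table, capability match via `supports_model`, configured default provider, then any provider at all. A condensed, standalone restatement of the same fallback chain, with the `config_manager` lookups stubbed as plain arguments purely for illustration:

    def resolve_provider(model, routing_table, providers, default):
        if routing_table.get(model) in providers:
            return routing_table[model]           # 1. explicit route
        for name, p in providers.items():
            if p.supports_model(model):
                return name                       # 2. capability match
        if default in providers:
            return default                        # 3. configured default
        if providers:
            return next(iter(providers))          # 4. anything available
        raise LookupError(f"No provider found for model '{model}'")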
    async def cleanup(self):
        """Cleanup service resources"""
        logger.info("Cleaning up LLM service")

        # Cleanup providers
        for name, provider in self._providers.items():
            try:
@@ -390,7 +417,7 @@ class LLMService:
                logger.debug(f"Cleaned up provider: {name}")
            except Exception as e:
                logger.error(f"Error cleaning up provider {name}: {e}")

        self._providers.clear()
        self._initialized = False
        logger.info("LLM service cleanup completed")
Some files were not shown because too many files have changed in this diff.