From 25778ab94e836446ff48d80957827273857711d4 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Thu, 18 Sep 2025 15:37:00 +0200 Subject: [PATCH 01/13] fix --- .gitignore | 108 ------------------- backend/app/api/v1/prompt_templates.py | 49 ++++++--- frontend/src/lib/api-client.ts | 96 +++++++++++++++++ frontend/src/lib/config.ts | 15 +++ frontend/src/lib/file-download.ts | 51 +++++++++ frontend/src/lib/id-utils.ts | 16 +++ frontend/src/lib/proxy-auth.ts | 31 ++++++ frontend/src/lib/token-manager.ts | 141 +++++++++++++++++++++++++ frontend/src/lib/utils.ts | 8 ++ 9 files changed, 392 insertions(+), 123 deletions(-) create mode 100644 frontend/src/lib/api-client.ts create mode 100644 frontend/src/lib/config.ts create mode 100644 frontend/src/lib/file-download.ts create mode 100644 frontend/src/lib/id-utils.ts create mode 100644 frontend/src/lib/proxy-auth.ts create mode 100644 frontend/src/lib/token-manager.ts create mode 100644 frontend/src/lib/utils.ts diff --git a/.gitignore b/.gitignore index 1719a7d..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1,108 +0,0 @@ -*.backup -backend/storage/rag_documents/* -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# Virtual environments -venv/ -env/ -ENV/ -env.bak/ -venv.bak/ -.venv/ - -# IDE -.vscode/ -.idea/ -*.swp -*.swo -*~ - -# OS -.DS_Store -.DS_Store? -._* -.Spotlight-V100 -.Trashes -ehthumbs.db -Thumbs.db - -# Node.js -node_modules/ -npm-debug.log* -yarn-debug.log* -yarn-error.log* -.npm -.eslintcache -.next/ -.nuxt/ -out/ -dist/ - -# Environment variables -.env -.env.local -.env.development.local -.env.test.local -.env.production.local - -# Docker -*.log -docker-compose.override.yml - -# Database -*.db -*.sqlite - -# Redis -dump.rdb - -# Logs -logs/ -*.log - -# Coverage -coverage/ -.coverage -.nyc_output - -# Cache -.cache/ -.pytest_cache/ -.mypy_cache/ -.ruff_cache/ - -# Temporary files -*.tmp -*.temp -.tmp/ - -# Security - sensitive files -backend/.config_encryption_key -*.key -*.pem -*.crt - -# Generated files -backend/performance_report.json -performance_report*.json diff --git a/backend/app/api/v1/prompt_templates.py b/backend/app/api/v1/prompt_templates.py index 2f58a65..6612149 100644 --- a/backend/app/api/v1/prompt_templates.py +++ b/backend/app/api/v1/prompt_templates.py @@ -484,7 +484,7 @@ async def seed_default_templates( select(PromptTemplate).where(PromptTemplate.type_key == type_key) ) existing_template = existing.scalar_one_or_none() - + if existing_template: # Only update if it's still the default (version 1) if existing_template.version == 1 and existing_template.is_default: @@ -494,21 +494,40 @@ async def seed_default_templates( existing_template.updated_at = datetime.utcnow() updated_templates.append(type_key) else: - # Create new template - new_template = PromptTemplate( - id=str(uuid.uuid4()), - name=template_data["name"], - type_key=type_key, - description=template_data["description"], - system_prompt=template_data["prompt"], - is_default=True, - is_active=True, - version=1, - created_at=datetime.utcnow(), - updated_at=datetime.utcnow() + # Check if any inactive template exists with this type_key + inactive_result = await db.execute( + select(PromptTemplate) + .where(PromptTemplate.type_key == type_key) + .where(PromptTemplate.is_active == False) ) - db.add(new_template) - created_templates.append(type_key) + inactive_template = inactive_result.scalar_one_or_none() + + if inactive_template: + # Reactivate the inactive template + inactive_template.is_active = True + inactive_template.name = template_data["name"] + inactive_template.description = template_data["description"] + inactive_template.system_prompt = template_data["prompt"] + inactive_template.is_default = True + inactive_template.version = 1 + inactive_template.updated_at = datetime.utcnow() + updated_templates.append(type_key) + else: + # Create new template + new_template = PromptTemplate( + id=str(uuid.uuid4()), + name=template_data["name"], + type_key=type_key, + description=template_data["description"], + system_prompt=template_data["prompt"], + is_default=True, + is_active=True, + version=1, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + db.add(new_template) + created_templates.append(type_key) await db.commit() diff --git a/frontend/src/lib/api-client.ts b/frontend/src/lib/api-client.ts new file mode 100644 index 0000000..126f733 --- /dev/null +++ b/frontend/src/lib/api-client.ts @@ -0,0 +1,96 @@ +export interface AppError extends Error { + code: 'UNAUTHORIZED' | 'NETWORK_ERROR' | 'VALIDATION_ERROR' | 'NOT_FOUND' | 'FORBIDDEN' | 'TIMEOUT' | 'UNKNOWN' + status?: number + details?: any +} + +function makeError(message: string, code: AppError['code'], status?: number, details?: any): AppError { + const err = new Error(message) as AppError + err.code = code + err.status = status + err.details = details + return err +} + +async function getAuthHeader(): Promise> { + try { + const { tokenManager } = await import('./token-manager') + const token = await tokenManager.getAccessToken() + return token ? { Authorization: `Bearer ${token}` } : {} + } catch { + return {} + } +} + +async function request(method: string, url: string, body?: any, extraInit?: RequestInit): Promise { + try { + const headers: Record = { + 'Accept': 'application/json', + ...(method !== 'GET' && method !== 'HEAD' ? { 'Content-Type': 'application/json' } : {}), + ...(await getAuthHeader()), + ...(extraInit?.headers as Record | undefined), + } + + const res = await fetch(url, { + method, + headers, + body: body != null && method !== 'GET' && method !== 'HEAD' ? JSON.stringify(body) : undefined, + ...extraInit, + }) + + if (!res.ok) { + let details: any = undefined + try { details = await res.json() } catch { details = await res.text() } + const status = res.status + if (status === 401) throw makeError('Unauthorized', 'UNAUTHORIZED', status, details) + if (status === 403) throw makeError('Forbidden', 'FORBIDDEN', status, details) + if (status === 404) throw makeError('Not found', 'NOT_FOUND', status, details) + if (status === 400) throw makeError('Validation error', 'VALIDATION_ERROR', status, details) + throw makeError('Request failed', 'UNKNOWN', status, details) + } + + const contentType = res.headers.get('content-type') || '' + if (contentType.includes('application/json')) { + return (await res.json()) as T + } + // @ts-expect-error allow non-json generic + return (await res.text()) as T + } catch (e: any) { + if (e?.code) throw e + if (e?.name === 'AbortError') throw makeError('Request timed out', 'TIMEOUT') + throw makeError(e?.message || 'Network error', 'NETWORK_ERROR') + } +} + +export const apiClient = { + get: (url: string, init?: RequestInit) => request('GET', url, undefined, init), + post: (url: string, body?: any, init?: RequestInit) => request('POST', url, body, init), + put: (url: string, body?: any, init?: RequestInit) => request('PUT', url, body, init), + delete: (url: string, init?: RequestInit) => request('DELETE', url, undefined, init), +} + +export const chatbotApi = { + async listChatbots() { + try { + return await apiClient.get('/api-internal/v1/chatbot/list') + } catch { + return await apiClient.get('/api-internal/v1/chatbot/instances') + } + }, + createChatbot(config: any) { + return apiClient.post('/api-internal/v1/chatbot/create', config) + }, + updateChatbot(id: string, config: any) { + return apiClient.put(`/api-internal/v1/chatbot/update/${encodeURIComponent(id)}`, config) + }, + deleteChatbot(id: string) { + return apiClient.delete(`/api-internal/v1/chatbot/delete/${encodeURIComponent(id)}`) + }, + sendMessage(chatbotId: string, message: string, conversationId?: string, history?: Array<{role: string; content: string}>) { + const body: any = { chatbot_id: chatbotId, message } + if (conversationId) body.conversation_id = conversationId + if (history) body.history = history + return apiClient.post('/api-internal/v1/chatbot/chat', body) + }, +} + diff --git a/frontend/src/lib/config.ts b/frontend/src/lib/config.ts new file mode 100644 index 0000000..e194a3f --- /dev/null +++ b/frontend/src/lib/config.ts @@ -0,0 +1,15 @@ +export const config = { + getPublicApiUrl(): string { + if (typeof process !== 'undefined' && process.env.NEXT_PUBLIC_BASE_URL) { + return process.env.NEXT_PUBLIC_BASE_URL + } + if (typeof window !== 'undefined') { + return window.location.origin + } + return 'http://localhost:3000' + }, + getAppName(): string { + return process.env.NEXT_PUBLIC_APP_NAME || 'Enclava' + }, +} + diff --git a/frontend/src/lib/file-download.ts b/frontend/src/lib/file-download.ts new file mode 100644 index 0000000..137b274 --- /dev/null +++ b/frontend/src/lib/file-download.ts @@ -0,0 +1,51 @@ +import { tokenManager } from './token-manager' + +export async function downloadFile(path: string, filename: string, params?: URLSearchParams | Record) { + const url = new URL(path, typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000') + if (params) { + const p = params instanceof URLSearchParams ? params : new URLSearchParams(params) + p.forEach((v, k) => url.searchParams.set(k, v)) + } + + const token = await tokenManager.getAccessToken() + const res = await fetch(url.toString(), { + headers: { + ...(token ? { Authorization: `Bearer ${token}` } : {}), + }, + }) + if (!res.ok) throw new Error(`Failed to download file (${res.status})`) + const blob = await res.blob() + + if (typeof window !== 'undefined') { + const link = document.createElement('a') + const href = URL.createObjectURL(blob) + link.href = href + link.download = filename + document.body.appendChild(link) + link.click() + link.remove() + URL.revokeObjectURL(href) + } +} + +export async function uploadFile(path: string, file: File, extraFields?: Record) { + const form = new FormData() + form.append('file', file) + if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v)) + + const token = await tokenManager.getAccessToken() + const res = await fetch(path, { + method: 'POST', + headers: { + ...(token ? { Authorization: `Bearer ${token}` } : {}), + }, + body: form, + }) + if (!res.ok) { + let details: any + try { details = await res.json() } catch { details = await res.text() } + throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed')) + } + return await res.json() +} + diff --git a/frontend/src/lib/id-utils.ts b/frontend/src/lib/id-utils.ts new file mode 100644 index 0000000..d511e03 --- /dev/null +++ b/frontend/src/lib/id-utils.ts @@ -0,0 +1,16 @@ +export function generateId(prefix = "id"): string { + const rand = Math.random().toString(36).slice(2, 10) + return `${prefix}_${rand}` +} + +export function generateShortId(prefix = "id"): string { + const rand = Math.random().toString(36).slice(2, 7) + return `${prefix}_${rand}` +} + +export function generateTimestampId(prefix = "id"): string { + const ts = Date.now() + const rand = Math.floor(Math.random() * 1000).toString().padStart(3, '0') + return `${prefix}_${ts}_${rand}` +} + diff --git a/frontend/src/lib/proxy-auth.ts b/frontend/src/lib/proxy-auth.ts new file mode 100644 index 0000000..bbf7109 --- /dev/null +++ b/frontend/src/lib/proxy-auth.ts @@ -0,0 +1,31 @@ +const BACKEND_URL = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}` + +function mapPath(path: string): string { + // Convert '/api-internal/..' to backend '/api/..' + if (path.startsWith('/api-internal/')) { + return path.replace('/api-internal/', '/api/') + } + return path +} + +export async function proxyRequest(path: string, init?: RequestInit): Promise { + const url = `${BACKEND_URL}${mapPath(path)}` + const headers: Record = { + 'Content-Type': 'application/json', + ...(init?.headers as Record | undefined), + } + return fetch(url, { ...init, headers }) +} + +export async function handleProxyResponse(response: Response, defaultMessage = 'Request failed'): Promise { + if (!response.ok) { + let details: any + try { details = await response.json() } catch { details = await response.text() } + throw new Error(typeof details === 'string' ? `${defaultMessage}: ${details}` : (details?.error || defaultMessage)) + } + const contentType = response.headers.get('content-type') || '' + if (contentType.includes('application/json')) return (await response.json()) as T + // @ts-ignore allow non-json + return (await response.text()) as T +} + diff --git a/frontend/src/lib/token-manager.ts b/frontend/src/lib/token-manager.ts new file mode 100644 index 0000000..b4699df --- /dev/null +++ b/frontend/src/lib/token-manager.ts @@ -0,0 +1,141 @@ +type Listener = (...args: any[]) => void + +class SimpleEmitter { + private listeners = new Map>() + + on(event: string, listener: Listener) { + if (!this.listeners.has(event)) this.listeners.set(event, new Set()) + this.listeners.get(event)!.add(listener) + } + + off(event: string, listener: Listener) { + this.listeners.get(event)?.delete(listener) + } + + emit(event: string, ...args: any[]) { + this.listeners.get(event)?.forEach(l => l(...args)) + } +} + +interface StoredTokens { + access_token: string + refresh_token: string + access_expires_at: number // epoch ms + refresh_expires_at?: number // epoch ms +} + +const ACCESS_LIFETIME_FALLBACK_MS = 30 * 60 * 1000 // 30 minutes +const REFRESH_LIFETIME_FALLBACK_MS = 7 * 24 * 60 * 60 * 1000 // 7 days + +function now() { return Date.now() } + +function readTokens(): StoredTokens | null { + if (typeof window === 'undefined') return null + try { + const raw = window.localStorage.getItem('auth_tokens') + return raw ? JSON.parse(raw) as StoredTokens : null + } catch { + return null + } +} + +function writeTokens(tokens: StoredTokens | null) { + if (typeof window === 'undefined') return + if (tokens) { + window.localStorage.setItem('auth_tokens', JSON.stringify(tokens)) + } else { + window.localStorage.removeItem('auth_tokens') + } +} + +class TokenManager extends SimpleEmitter { + private refreshTimer: ReturnType | null = null + + isAuthenticated(): boolean { + const t = readTokens() + return !!t && t.access_expires_at > now() + } + + getTokenExpiry(): Date | null { + const t = readTokens() + return t ? new Date(t.access_expires_at) : null + } + + getRefreshTokenExpiry(): Date | null { + const t = readTokens() + return t?.refresh_expires_at ? new Date(t.refresh_expires_at) : null + } + + setTokens(accessToken: string, refreshToken: string, expiresInSeconds?: number) { + const access_expires_at = now() + (expiresInSeconds ? expiresInSeconds * 1000 : ACCESS_LIFETIME_FALLBACK_MS) + const refresh_expires_at = now() + REFRESH_LIFETIME_FALLBACK_MS + const tokens: StoredTokens = { + access_token: accessToken, + refresh_token: refreshToken, + access_expires_at, + refresh_expires_at, + } + writeTokens(tokens) + this.scheduleRefresh() + this.emit('tokensUpdated') + } + + clearTokens() { + if (this.refreshTimer) { + clearTimeout(this.refreshTimer) + this.refreshTimer = null + } + writeTokens(null) + this.emit('tokensCleared') + } + + logout() { + this.clearTokens() + this.emit('logout') + } + + private scheduleRefresh() { + if (typeof window === 'undefined') return + const t = readTokens() + if (!t) return + if (this.refreshTimer) clearTimeout(this.refreshTimer) + const msUntilRefresh = Math.max(5_000, t.access_expires_at - now() - 60_000) // 1 minute before expiry + this.refreshTimer = setTimeout(() => { + this.refreshAccessToken().catch(() => { + this.emit('sessionExpired', 'refresh_failed') + this.clearTokens() + }) + }, msUntilRefresh) + } + + async getAccessToken(): Promise { + const t = readTokens() + if (!t) return null + if (t.access_expires_at - now() > 10_000) return t.access_token + try { + await this.refreshAccessToken() + return readTokens()?.access_token || null + } catch { + this.emit('sessionExpired', 'expired') + this.clearTokens() + return null + } + } + + private async refreshAccessToken(): Promise { + const t = readTokens() + if (!t?.refresh_token) throw new Error('No refresh token') + const res = await fetch('/api-internal/v1/auth/refresh', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ refresh_token: t.refresh_token }), + }) + if (!res.ok) throw new Error('Refresh failed') + const data = await res.json() + const expiresIn = data.expires_in as number | undefined + this.setTokens(data.access_token, data.refresh_token || t.refresh_token, expiresIn) + } +} + +export const tokenManager = new TokenManager() + diff --git a/frontend/src/lib/utils.ts b/frontend/src/lib/utils.ts new file mode 100644 index 0000000..02aca5c --- /dev/null +++ b/frontend/src/lib/utils.ts @@ -0,0 +1,8 @@ +import { type ClassValue } from 'clsx' +import { clsx } from 'clsx' +import { twMerge } from 'tailwind-merge' + +export function cn(...inputs: ClassValue[]) { + return twMerge(clsx(inputs)) +} + From 0c20de4ca1e9948f7e91be1b4765bcdedde9451f Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Fri, 19 Sep 2025 20:34:51 +0200 Subject: [PATCH 02/13] working chatbot, rag weird --- backend/app/api/v1/chatbot.py | 60 ++------ backend/app/services/llm/security.py | 142 ++++++++++++------ backend/app/services/llm/service.py | 89 ++++------- backend/modules/chatbot/main.py | 57 +++++-- frontend/package-lock.json | 6 +- frontend/src/app/api-keys/page.tsx | 11 +- .../src/components/chatbot/ChatInterface.tsx | 32 ++-- .../src/components/chatbot/ChatbotManager.tsx | 1 + frontend/src/lib/api-client.ts | 24 ++- 9 files changed, 230 insertions(+), 192 deletions(-) diff --git a/backend/app/api/v1/chatbot.py b/backend/app/api/v1/chatbot.py index 1f2f579..20f03dc 100644 --- a/backend/app/api/v1/chatbot.py +++ b/backend/app/api/v1/chatbot.py @@ -275,14 +275,14 @@ async def chat_with_chatbot( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ): - """Send a message to a chatbot and get a response""" + """Send a message to a chatbot and get a response (without persisting conversation)""" user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id log_api_request("chat_with_chatbot", { "user_id": user_id, "chatbot_id": chatbot_id, "message_length": len(request.message) }) - + try: # Get the chatbot instance result = await db.execute( @@ -291,74 +291,40 @@ async def chat_with_chatbot( .where(ChatbotInstance.created_by == str(user_id)) ) chatbot = result.scalar_one_or_none() - + if not chatbot: raise HTTPException(status_code=404, detail="Chatbot not found") - + if not chatbot.is_active: raise HTTPException(status_code=400, detail="Chatbot is not active") - - # Initialize conversation service - conversation_service = ConversationService(db) - - # Get or create conversation - conversation = await conversation_service.get_or_create_conversation( - chatbot_id=chatbot_id, - user_id=str(user_id), - conversation_id=request.conversation_id - ) - - # Add user message to conversation - await conversation_service.add_message( - conversation_id=conversation.id, - role="user", - content=request.message, - metadata={} - ) - + # Get chatbot module and generate response try: chatbot_module = module_manager.modules.get("chatbot") if not chatbot_module: raise HTTPException(status_code=500, detail="Chatbot module not available") - - # Load conversation history for context - conversation_history = await conversation_service.get_conversation_history( - conversation_id=conversation.id, - limit=chatbot.config.get('memory_length', 10), - include_system=False - ) - - # Use the chatbot module to generate a response + + # Use the chatbot module to generate a response (without persisting) response_data = await chatbot_module.chat( chatbot_config=chatbot.config, message=request.message, - conversation_history=conversation_history, + conversation_history=[], # Empty history for test chat user_id=str(user_id) ) - + response_content = response_data.get("response", "I'm sorry, I couldn't generate a response.") - + except Exception as e: # Use fallback response fallback_responses = chatbot.config.get("fallback_responses", [ "I'm sorry, I'm having trouble processing your request right now." ]) response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request." - - # Save assistant message using conversation service - assistant_message = await conversation_service.add_message( - conversation_id=conversation.id, - role="assistant", - content=response_content, - metadata={}, - sources=response_data.get("sources") - ) - + + # Return response without conversation ID (since we're not persisting) return { - "conversation_id": conversation.id, "response": response_content, - "timestamp": assistant_message.timestamp.isoformat() + "sources": response_data.get("sources") } except HTTPException: diff --git a/backend/app/services/llm/security.py b/backend/app/services/llm/security.py index 0d24f62..8aa37be 100644 --- a/backend/app/services/llm/security.py +++ b/backend/app/services/llm/security.py @@ -29,7 +29,7 @@ class SecurityManager: """Setup patterns for prompt injection detection""" self.injection_patterns = [ # Direct instruction injection - r"(?i)(ignore|forget|disregard|override)\s+(previous|all|above|prior)\s+(instructions|rules|prompts)", + r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)", r"(?i)(new|updated|different)\s+(instructions|rules|system)", r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)", r"(?i)pretend\s+(to\s+be|you\s+are)", @@ -61,12 +61,12 @@ class SecurityManager: r"(?i)base64\s*:", r"(?i)hex\s*:", r"(?i)unicode\s*:", - r"[A-Za-z0-9+/]{20,}={0,2}", # Potential base64 + r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences) - # SQL injection patterns (for system prompts) - r"(?i)(union|select|insert|update|delete|drop|create)\s+", - r"(?i)(or|and)\s+1\s*=\s*1", - r"(?i)';?\s*(drop|delete|insert)", + # SQL injection patterns (more specific to reduce false positives) + r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s", + r"(?i)(or|and)\s+\d+\s*=\s*\d+", + r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)", # Command injection patterns r"(?i)(exec|eval|system|shell|cmd)\s*\(", @@ -88,23 +88,27 @@ class SecurityManager: def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]: """ Validate messages for prompt injection attempts - + Returns: Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns) """ detected_patterns = [] total_risk = 0.0 - + + # Check if this is a system/RAG request + is_system_request = self._is_system_request(messages) + for message in messages: content = message.get("content", "") if not content: continue - - # Check against injection patterns + + # Check against injection patterns with context awareness for i, pattern in enumerate(self.compiled_patterns): matches = pattern.findall(content) if matches: - pattern_risk = self._calculate_pattern_risk(i, matches) + # Apply context-aware risk calculation + pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request) total_risk += pattern_risk detected_patterns.append({ "pattern_index": i, @@ -112,57 +116,97 @@ class SecurityManager: "matches": matches, "risk": pattern_risk }) - - # Additional security checks - total_risk += self._check_message_characteristics(content) - + + # Additional security checks with context awareness + total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request) + # Normalize risk score (0.0 to 1.0) risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0) - is_safe = risk_score < settings.API_SECURITY_RISK_THRESHOLD - + # Never block - always return True for is_safe + is_safe = True + if detected_patterns: - logger.warning(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score}") - + logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})") + return is_safe, risk_score, detected_patterns - def _calculate_pattern_risk(self, pattern_index: int, matches: List) -> float: - """Calculate risk score for a detected pattern""" + def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float: + """Calculate risk score for a detected pattern with context awareness""" # Different patterns have different risk levels - high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 16, 22, 23, 24] # System manipulation, jailbreak + high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction - + + # Base risk score base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3 - - # Increase risk based on number of matches - match_multiplier = min(1.0 + (len(matches) - 1) * 0.2, 2.0) - + + # Apply context-specific risk reduction + if is_system_request or role == "system": + # Reduce risk for system messages and RAG content + if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode) + base_risk *= 0.2 # Reduce encoding risk by 80% for system content + elif pattern_index in [17, 18, 19]: # SQL patterns + base_risk *= 0.3 # Reduce SQL risk by 70% for system content + else: + base_risk *= 0.6 # Reduce other risks by 40% for system content + + # Increase risk based on number of matches, but cap it + match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier + return base_risk * match_multiplier - def _check_message_characteristics(self, content: str) -> float: - """Check message characteristics for additional risk factors""" + def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float: + """Check message characteristics for additional risk factors with context awareness""" risk = 0.0 - - # Excessive length (potential stuffing attack) - if len(content) > 10000: - risk += 0.3 - - # High ratio of special characters + + # Excessive length (potential stuffing attack) - less restrictive for system content + length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content + if len(content) > length_threshold: + risk += 0.1 if is_system_request else 0.3 + + # High ratio of special characters - more lenient for system content special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace()) - if len(content) > 0 and special_chars / len(content) > 0.5: - risk += 0.4 - - # Multiple encoding indicators + if len(content) > 0: + char_ratio = special_chars / len(content) + threshold = 0.8 if is_system_request else 0.5 + if char_ratio > threshold: + risk += 0.2 if is_system_request else 0.4 + + # Multiple encoding indicators - reduced risk for system content encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"] found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower()) if found_encodings > 1: - risk += 0.3 - - # Excessive newlines or formatting (potential formatting attacks) - if content.count('\n') > 50 or content.count('\\n') > 50: - risk += 0.2 - + risk += 0.1 if is_system_request else 0.3 + + # Excessive newlines or formatting - more lenient for system content + newline_threshold = 200 if is_system_request else 50 + if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold: + risk += 0.1 if is_system_request else 0.2 + return risk - + + def _is_system_request(self, messages: List[Dict[str, str]]) -> bool: + """Determine if this is a system/RAG request""" + if not messages: + return False + + # Check for system messages + for message in messages: + if message.get("role") == "system": + return True + + # Check message content for RAG indicators + for message in messages: + content = message.get("content", "") + if ("document:" in content.lower() or + "context:" in content.lower() or + "source:" in content.lower() or + "retrieved:" in content.lower() or + "citation:" in content.lower() or + "reference:" in content.lower()): + return True + + return False + def create_audit_log( self, user_id: str, @@ -195,11 +239,11 @@ class SecurityManager: audit_hash = self._create_audit_hash(audit_entry) audit_entry["audit_hash"] = audit_hash - # Log based on risk level + # Log based on risk level (never block, only log) if risk_score >= settings.API_SECURITY_RISK_THRESHOLD: - logger.error(f"HIGH RISK LLM REQUEST BLOCKED: {json.dumps(audit_entry)}") + logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}") elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD: - logger.warning(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}") + logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}") else: logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}") diff --git a/backend/app/services/llm/service.py b/backend/app/services/llm/service.py index fae28fd..bb8e683 100644 --- a/backend/app/services/llm/service.py +++ b/backend/app/services/llm/service.py @@ -16,6 +16,7 @@ from .models import ( ModelInfo, ProviderStatus, LLMMetrics ) from .config import config_manager, ProviderConfig +from ...core.config import settings from .security import security_manager from .resilience import ResilienceManagerFactory from .metrics import metrics_collector @@ -149,19 +150,17 @@ class LLMService: if not request.messages: raise ValidationError("Messages cannot be empty", field="messages") - # Security validation - # Chatbot and RAG system requests should have relaxed security validation - is_system_request = ( - request.user_id == "rag_system" or - request.user_id == "chatbot_user" or - str(request.user_id).startswith("chatbot_") - ) - + # Security validation (only if enabled) messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) - - if not is_safe and not is_system_request: - # Log security violation for regular user requests + + if settings.API_SECURITY_ENABLED: + is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) + else: + # Security disabled - always safe + is_safe, risk_score, detected_patterns = True, 0.0, [] + + if not is_safe: + # Log security violation security_manager.create_audit_log( user_id=request.user_id, api_key_id=request.api_key_id, @@ -171,7 +170,7 @@ class LLMService: risk_score=risk_score, detected_patterns=[p.get("pattern", "") for p in detected_patterns] ) - + # Record blocked request metrics_collector.record_request( provider="security", @@ -184,18 +183,12 @@ class LLMService: user_id=request.user_id, api_key_id=request.api_key_id ) - + raise SecurityError( "Request blocked due to security concerns", risk_score=risk_score, details={"detected_patterns": detected_patterns} ) - elif not is_safe and is_system_request: - # For system requests (chatbot/RAG), log but don't block - logger.info(f"System request contains security patterns (risk_score={risk_score:.2f}) but allowing due to system context") - if detected_patterns: - logger.info(f"Detected patterns: {[p.get('pattern', 'unknown') for p in detected_patterns]}") - # Allow system requests regardless of security patterns # Get provider for model provider_name = self._get_provider_for_model(request.model) @@ -317,25 +310,20 @@ class LLMService: await self.initialize() # Security validation (same as non-streaming) - # Chatbot and RAG system requests should have relaxed security validation - is_system_request = ( - request.user_id == "rag_system" or - request.user_id == "chatbot_user" or - str(request.user_id).startswith("chatbot_") - ) - messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) - - if not is_safe and not is_system_request: + + if settings.API_SECURITY_ENABLED: + is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) + else: + # Security disabled - always safe + is_safe, risk_score, detected_patterns = True, 0.0, [] + + if not is_safe: raise SecurityError( "Streaming request blocked due to security concerns", risk_score=risk_score, details={"detected_patterns": detected_patterns} ) - elif not is_safe and is_system_request: - # For system requests (chatbot/RAG), log but don't block - logger.info(f"System streaming request contains security patterns (risk_score={risk_score:.2f}) but allowing due to system context") # Get provider provider_name = self._get_provider_for_model(request.model) @@ -378,33 +366,22 @@ class LLMService: await self.initialize() # Security validation for embedding input - # RAG system requests (document embedding) should use relaxed security validation - is_rag_system = request.user_id == "rag_system" - - if not is_rag_system: - # Apply normal security validation for user-generated embedding requests - input_text = request.input if isinstance(request.input, str) else " ".join(request.input) + input_text = request.input if isinstance(request.input, str) else " ".join(request.input) + + if settings.API_SECURITY_ENABLED: is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([ {"role": "user", "content": input_text} ]) - - if not is_safe: - raise SecurityError( - "Embedding request blocked due to security concerns", - risk_score=risk_score, - details={"detected_patterns": detected_patterns} - ) else: - # For RAG system requests, log but don't block (document content can contain legitimate text that triggers patterns) - input_text = request.input if isinstance(request.input, str) else " ".join(request.input) - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([ - {"role": "user", "content": input_text} - ]) - - if detected_patterns: - logger.info(f"RAG document embedding contains security patterns (risk_score={risk_score:.2f}) but allowing due to document context") - - # Allow RAG system requests regardless of security patterns + # Security disabled - always safe + is_safe, risk_score, detected_patterns = True, 0.0, [] + + if not is_safe: + raise SecurityError( + "Embedding request blocked due to security concerns", + risk_score=risk_score, + details={"detected_patterns": detected_patterns} + ) # Get provider provider_name = self._get_provider_for_model(request.model) diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py index 38ee9ec..6f42f09 100644 --- a/backend/modules/chatbot/main.py +++ b/backend/modules/chatbot/main.py @@ -265,6 +265,7 @@ class ChatbotModule(BaseModule): async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse: """Generate chat completion response""" + logger.info("=== CHAT COMPLETION METHOD CALLED ===") # Get chatbot configuration from database db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first() @@ -363,10 +364,11 @@ class ChatbotModule(BaseModule): metadata={"error": str(e), "fallback": True} ) - async def _generate_response(self, message: str, db_messages: List[DBMessage], + async def _generate_response(self, message: str, db_messages: List[DBMessage], config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]: """Generate response using LLM with optional RAG""" - + logger.info("=== _generate_response METHOD CALLED ===") + # Lazy load dependencies if not available await self._ensure_dependencies() @@ -426,6 +428,11 @@ class ChatbotModule(BaseModule): logger.warning(f"RAG search traceback: {traceback.format_exc()}") # Build conversation context (includes the current message from db_messages) + logger.info(f"=== CRITICAL DEBUG ===") + logger.info(f"rag_context length: {len(rag_context)}") + logger.info(f"rag_context empty: {not rag_context}") + logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}") + logger.info(f"=== END CRITICAL DEBUG ===") messages = self._build_conversation_messages(db_messages, config, rag_context, context) # Note: Current user message is already included in db_messages from the query @@ -511,32 +518,38 @@ class ChatbotModule(BaseModule): # Return fallback if available return "I'm currently unable to process your request. Please try again later.", None - def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, + def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]: """Build messages array for LLM completion""" - + messages = [] - - # System prompt + logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}") + + # System prompt - keep it clean without RAG context system_prompt = config.system_prompt - if rag_context: - system_prompt += rag_context if context and context.get('additional_instructions'): - system_prompt += f"\\n\\nAdditional instructions: {context['additional_instructions']}" - + system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}" + messages.append({"role": "system", "content": system_prompt}) - + logger.info(f"Building messages from {len(db_messages)} database messages") - + # Conversation history (messages are already limited by memory_length in the query) # Reverse to get chronological order # Include ALL messages - the current user message is needed for the LLM to respond! for idx, msg in enumerate(reversed(db_messages)): logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...") if msg.role in ["user", "assistant"]: + # For user messages, prepend RAG context if available + content = msg.content + if msg.role == "user" and rag_context and idx == 0: + # Add RAG context to the current user message (first in reversed order) + content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}" + logger.info("Added RAG context to user message") + messages.append({ "role": msg.role, - "content": msg.content + "content": content }) logger.info(f"Added message with role {msg.role} to LLM messages") else: @@ -677,9 +690,10 @@ class ChatbotModule(BaseModule): return router # API Compatibility Methods - async def chat(self, chatbot_config: Dict[str, Any], message: str, + async def chat(self, chatbot_config: Dict[str, Any], message: str, conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]: """Chat method for API compatibility""" + logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===") logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}") # Lazy load dependencies @@ -709,9 +723,20 @@ class ChatbotModule(BaseModule): fallback_responses=chatbot_config.get("fallback_responses", []) ) - # Generate response using internal method with empty message history + # For API compatibility, create a temporary DBMessage for the current message + # so RAG context can be properly added + from app.models.chatbot import ChatbotMessage as DBMessage + + # Create a temporary user message with the current message + temp_user_message = DBMessage( + conversation_id="temp_conversation", + role=MessageRole.USER.value, + content=message + ) + + # Generate response using internal method with the current message included response_content, sources = await self._generate_response( - message, [], config, None, db + message, [temp_user_message], config, None, db ) return { diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 5f8f91b..b5af4b6 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -2613,9 +2613,9 @@ } }, "node_modules/axios": { - "version": "1.11.0", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.11.0.tgz", - "integrity": "sha512-1Lx3WLFQWm3ooKDYZD1eXmoGO9fxYQjrycfHFC8P0sCfQVXyROp0p9PFWBehewBOdCwHc+f/b8I0fMto5eSfwA==", + "version": "1.12.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz", + "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==", "license": "MIT", "dependencies": { "follow-redirects": "^1.15.6", diff --git a/frontend/src/app/api-keys/page.tsx b/frontend/src/app/api-keys/page.tsx index e62e2c7..da5bc62 100644 --- a/frontend/src/app/api-keys/page.tsx +++ b/frontend/src/app/api-keys/page.tsx @@ -2,6 +2,7 @@ import { useState, useEffect } from "react"; import { useSearchParams } from "next/navigation"; +import { Suspense } from "react"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -93,7 +94,7 @@ const PERMISSION_OPTIONS = [ { value: "llm:embeddings", label: "LLM Embeddings" }, ]; -export default function ApiKeysPage() { +function ApiKeysContent() { const { toast } = useToast(); const searchParams = useSearchParams(); const [apiKeys, setApiKeys] = useState([]); @@ -905,4 +906,12 @@ export default function ApiKeysPage() { ); +} + +export default function ApiKeysPage() { + return ( + Loading API keys...}> + + + ); } \ No newline at end of file diff --git a/frontend/src/components/chatbot/ChatInterface.tsx b/frontend/src/components/chatbot/ChatInterface.tsx index fb65658..4e1fe8d 100644 --- a/frontend/src/components/chatbot/ChatInterface.tsx +++ b/frontend/src/components/chatbot/ChatInterface.tsx @@ -87,9 +87,8 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface const [messages, setMessages] = useState([]) const [input, setInput] = useState("") const [isLoading, setIsLoading] = useState(false) - const [conversationId, setConversationId] = useState(null) const scrollAreaRef = useRef(null) - const { toast } = useToast() + const { success: toastSuccess, error: toastError } = useToast() const scrollToBottom = useCallback(() => { if (scrollAreaRef.current) { @@ -120,23 +119,20 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface setIsLoading(true) try { - // Build conversation history in OpenAI format + let data: any + + // Use internal API const conversationHistory = messages.map(msg => ({ role: msg.role, content: msg.content })) - - const data = await chatbotApi.sendMessage( + + data = await chatbotApi.sendMessage( chatbotId, messageToSend, - conversationId || undefined, + undefined, // No conversation ID conversationHistory ) - - // Update conversation ID if it's a new conversation - if (!conversationId && data.conversation_id) { - setConversationId(data.conversation_id) - } const assistantMessage: ChatMessage = { id: data.message_id || generateTimestampId('msg'), @@ -153,16 +149,16 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface // More specific error handling if (appError.code === 'UNAUTHORIZED') { - toast.error("Authentication Required", "Please log in to continue chatting.") + toastError("Authentication Required", "Please log in to continue chatting.") } else if (appError.code === 'NETWORK_ERROR') { - toast.error("Connection Error", "Please check your internet connection and try again.") + toastError("Connection Error", "Please check your internet connection and try again.") } else { - toast.error("Message Failed", appError.message || "Failed to send message. Please try again.") + toastError("Message Failed", appError.message || "Failed to send message. Please try again.") } } finally { setIsLoading(false) } - }, [input, isLoading, chatbotId, conversationId, messages, toast]) + }, [input, isLoading, chatbotId, messages, toastError]) const handleKeyPress = useCallback((e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -174,11 +170,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface const copyMessage = useCallback(async (content: string) => { try { await navigator.clipboard.writeText(content) - toast.success("Copied", "Message copied to clipboard") + toastSuccess("Copied", "Message copied to clipboard") } catch (error) { - toast.error("Copy Failed", "Unable to copy message to clipboard") + toastError("Copy Failed", "Unable to copy message to clipboard") } - }, [toast]) + }, [toastSuccess, toastError]) const formatTime = useCallback((date: Date) => { return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) diff --git a/frontend/src/components/chatbot/ChatbotManager.tsx b/frontend/src/components/chatbot/ChatbotManager.tsx index 9178c66..6221d4e 100644 --- a/frontend/src/components/chatbot/ChatbotManager.tsx +++ b/frontend/src/components/chatbot/ChatbotManager.tsx @@ -138,6 +138,7 @@ export function ChatbotManager() { const [editingChatbot, setEditingChatbot] = useState(null) const [showChatInterface, setShowChatInterface] = useState(false) const [testingChatbot, setTestingChatbot] = useState(null) + const [chatbotApiKeys, setChatbotApiKeys] = useState>({}) const { toast } = useToast() // New chatbot form state diff --git a/frontend/src/lib/api-client.ts b/frontend/src/lib/api-client.ts index 126f733..4df0089 100644 --- a/frontend/src/lib/api-client.ts +++ b/frontend/src/lib/api-client.ts @@ -86,11 +86,31 @@ export const chatbotApi = { deleteChatbot(id: string) { return apiClient.delete(`/api-internal/v1/chatbot/delete/${encodeURIComponent(id)}`) }, + // Legacy method with JWT auth (to be deprecated) sendMessage(chatbotId: string, message: string, conversationId?: string, history?: Array<{role: string; content: string}>) { - const body: any = { chatbot_id: chatbotId, message } + const body: any = { message } if (conversationId) body.conversation_id = conversationId if (history) body.history = history - return apiClient.post('/api-internal/v1/chatbot/chat', body) + return apiClient.post(`/api-internal/v1/chatbot/chat/${encodeURIComponent(chatbotId)}`, body) }, + // OpenAI-compatible chatbot API with API key auth + sendOpenAIChatMessage(chatbotId: string, messages: Array<{role: string; content: string}>, apiKey: string, options?: { + temperature?: number + max_tokens?: number + stream?: boolean + }) { + const body: any = { + messages, + ...options + } + return fetch(`/api/v1/chatbot/external/${encodeURIComponent(chatbotId)}/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${apiKey}` + }, + body: JSON.stringify(body) + }).then(res => res.json()) + } } From f58a76ac596439bf6f69963fa278fd7d47f2d288 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Sun, 21 Sep 2025 06:49:55 +0200 Subject: [PATCH 03/13] ratelimiting and rag --- backend/app/core/config.py | 27 +- backend/app/main.py | 4 + backend/app/middleware/rate_limiting.py | 234 +++++++++++------- backend/app/middleware/security.py | 30 +-- .../services/enhanced_embedding_service.py | 201 +++++++++++++++ backend/app/services/llm/config.py | 15 +- backend/modules/rag/main.py | 29 ++- 7 files changed, 410 insertions(+), 130 deletions(-) create mode 100644 backend/app/services/enhanced_embedding_service.py diff --git a/backend/app/core/config.py b/backend/app/core/config.py index c5cb8c3..f3ac614 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -82,18 +82,25 @@ class Settings(BaseSettings): # Rate Limiting Configuration API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true" - - # Authenticated users (JWT token) - API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300")) - API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000")) - + + # PrivateMode Standard tier limits (organization-level, not per user) + # These are shared across all API keys and users in the organization + PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")) + PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200")) + PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000")) + PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000")) + + # Per-user limits (additional protection on top of organization limits) + API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")) # Match PrivateMode + API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200")) + # API key users (programmatic access) - API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000")) - API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000")) - + API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")) # Match PrivateMode + API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200")) + # Premium/Enterprise API keys - API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000")) - API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000")) + API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode + API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")) # Security Thresholds API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score diff --git a/backend/app/main.py b/backend/app/main.py index e0466e6..40d51a3 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -139,6 +139,10 @@ setup_analytics_middleware(app) from app.middleware.security import setup_security_middleware setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED) +# Add rate limiting middleware only for specific endpoints +from app.middleware.rate_limiting import RateLimitMiddleware +app.add_middleware(RateLimitMiddleware) + # Exception handlers @app.exception_handler(CustomHTTPException) diff --git a/backend/app/middleware/rate_limiting.py b/backend/app/middleware/rate_limiting.py index 611a67a..f6e1901 100644 --- a/backend/app/middleware/rate_limiting.py +++ b/backend/app/middleware/rate_limiting.py @@ -7,6 +7,7 @@ import redis from typing import Dict, Optional from fastapi import Request, HTTPException, status from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware import asyncio from datetime import datetime, timedelta @@ -155,96 +156,153 @@ class RateLimiter: rate_limiter = RateLimiter() -async def rate_limit_middleware(request: Request, call_next): - """ - Rate limiting middleware for FastAPI - """ - - # Skip rate limiting for health checks and static files - if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]: +class RateLimitMiddleware(BaseHTTPMiddleware): + """Rate limiting middleware for FastAPI""" + + def __init__(self, app): + super().__init__(app) + self.rate_limiter = RateLimiter() + logger.info("RateLimitMiddleware initialized") + + async def dispatch(self, request: Request, call_next): + """Process request through rate limiting""" + + # Skip rate limiting if disabled in settings + if not settings.API_RATE_LIMITING_ENABLED: + response = await call_next(request) + return response + + # Skip rate limiting for all internal API endpoints (platform operations) + if request.url.path.startswith("/api-internal/v1/"): + response = await call_next(request) + return response + + # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service) + # Skip for all other endpoints + if not (request.url.path.startswith("/api/v1/chat/completions") or + request.url.path.startswith("/api/v1/embeddings") or + request.url.path.startswith("/api/v1/models") or + request.url.path.startswith("/api/v1/llm/")): + response = await call_next(request) + return response + + # Skip rate limiting for health checks and static files + if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]: + response = await call_next(request) + return response + + # Get client IP + client_ip = request.client.host + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + client_ip = forwarded_for.split(",")[0].strip() + + # Check for API key in headers + api_key = None + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + api_key = auth_header[7:] + elif request.headers.get("X-API-Key"): + api_key = request.headers.get("X-API-Key") + + # Determine rate limiting strategy + headers = {} + is_allowed = True + + if api_key: + # API key-based rate limiting + api_key_key = f"api_key:{api_key}" + + # First check organization-wide limits (PrivateMode limits are org-wide) + org_key = "organization:privatemode" + + # Check organization per-minute limit + org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit( + org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute" + ) + + # Check organization per-hour limit + org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit( + org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour" + ) + + # If organization limits are exceeded, return 429 + if not (org_allowed_minute and org_allowed_hour): + logger.warning(f"Organization rate limit exceeded for {org_key}") + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Organization rate limit exceeded"}, + headers=org_headers_minute + ) + + # Then check per-API key limits + limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE + limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR + + # Check per-minute limit + is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit( + api_key_key, limit_per_minute, 60, "minute" + ) + + # Check per-hour limit + is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit( + api_key_key, limit_per_hour, 3600, "hour" + ) + + is_allowed = is_allowed_minute and is_allowed_hour + headers = headers_minute # Use minute headers for response + + else: + # IP-based rate limiting for unauthenticated requests + rate_limit_key = f"ip:{client_ip}" + + # More restrictive limits for unauthenticated requests + limit_per_minute = 20 # Hardcoded for unauthenticated users + limit_per_hour = 100 + + # Check per-minute limit + is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit( + rate_limit_key, limit_per_minute, 60, "minute" + ) + + # Check per-hour limit + is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit( + rate_limit_key, limit_per_hour, 3600, "hour" + ) + + is_allowed = is_allowed_minute and is_allowed_hour + headers = headers_minute # Use minute headers for response + + # If rate limit exceeded, return 429 + if not is_allowed: + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={ + "error": "RATE_LIMIT_EXCEEDED", + "message": "Rate limit exceeded. Please try again later.", + "details": { + "limit": headers["X-RateLimit-Limit"], + "reset_time": headers["X-RateLimit-Reset"] + } + }, + headers={k: str(v) for k, v in headers.items()} + ) + + # Continue with request response = await call_next(request) + + # Add rate limit headers to response + for key, value in headers.items(): + response.headers[key] = str(value) + return response - - # Get client IP - client_ip = request.client.host - forwarded_for = request.headers.get("X-Forwarded-For") - if forwarded_for: - client_ip = forwarded_for.split(",")[0].strip() - - # Check for API key in headers - api_key = None - auth_header = request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - api_key = auth_header[7:] - elif request.headers.get("X-API-Key"): - api_key = request.headers.get("X-API-Key") - - # Determine rate limiting strategy - if api_key: - # API key-based rate limiting - rate_limit_key = f"api_key:{api_key}" - - # Get API key limits from database (simplified - would implement proper lookup) - limit_per_minute = 100 # Default limit - limit_per_hour = 1000 # Default limit - - # Check per-minute limit - is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit( - rate_limit_key, limit_per_minute, 60, "minute" - ) - - # Check per-hour limit - is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit( - rate_limit_key, limit_per_hour, 3600, "hour" - ) - - is_allowed = is_allowed_minute and is_allowed_hour - headers = headers_minute # Use minute headers for response - - else: - # IP-based rate limiting for unauthenticated requests - rate_limit_key = f"ip:{client_ip}" - - # More restrictive limits for unauthenticated requests - limit_per_minute = 20 - limit_per_hour = 100 - - # Check per-minute limit - is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit( - rate_limit_key, limit_per_minute, 60, "minute" - ) - - # Check per-hour limit - is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit( - rate_limit_key, limit_per_hour, 3600, "hour" - ) - - is_allowed = is_allowed_minute and is_allowed_hour - headers = headers_minute # Use minute headers for response - - # If rate limit exceeded, return 429 - if not is_allowed: - return JSONResponse( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - content={ - "error": "RATE_LIMIT_EXCEEDED", - "message": "Rate limit exceeded. Please try again later.", - "details": { - "limit": headers["X-RateLimit-Limit"], - "reset_time": headers["X-RateLimit-Reset"] - } - }, - headers={k: str(v) for k, v in headers.items()} - ) - - # Continue with request - response = await call_next(request) - - # Add rate limit headers to response - for key, value in headers.items(): - response.headers[key] = str(value) - - return response + + +# Keep the old function for backward compatibility +async def rate_limit_middleware(request: Request, call_next): + """Legacy function - use RateLimitMiddleware class instead""" + middleware = RateLimitMiddleware(None) + return await middleware.dispatch(request, call_next) class RateLimitExceeded(HTTPException): diff --git a/backend/app/middleware/security.py b/backend/app/middleware/security.py index 6efc1f4..57d2ebe 100644 --- a/backend/app/middleware/security.py +++ b/backend/app/middleware/security.py @@ -61,12 +61,12 @@ class SecurityMiddleware(BaseHTTPMiddleware): if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD): await self._log_security_event(request, analysis) - # Check if request should be blocked - if analysis.should_block: + # Check if request should be blocked (excluding rate limiting) + if analysis.should_block and not analysis.rate_limit_exceeded: threat_detection_service.stats['threats_blocked'] += 1 logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: " f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}") - + # Return security block response return self._create_block_response(analysis) @@ -136,17 +136,13 @@ class SecurityMiddleware(BaseHTTPMiddleware): """Create response for blocked requests""" # Determine status code based on threat type status_code = 403 # Forbidden by default - - # Rate limiting gets 429 - if analysis.rate_limit_exceeded: - status_code = 429 - + # Critical threats get 403 for threat in analysis.threats: if threat.threat_type in ["command_injection", "sql_injection"]: status_code = 403 break - + response_data = { "error": "Security Policy Violation", "message": "Request blocked due to security policy violation", @@ -155,24 +151,12 @@ class SecurityMiddleware(BaseHTTPMiddleware): "threat_count": len(analysis.threats), "recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations } - - # Add rate limiting info if applicable - if analysis.rate_limit_exceeded: - response_data["error"] = "Rate Limit Exceeded" - response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user" - response_data["retry_after"] = "60" # Suggest retry after 60 seconds - + response = JSONResponse( content=response_data, status_code=status_code ) - - # Add rate limiting headers - if analysis.rate_limit_exceeded: - response.headers["Retry-After"] = "60" - response.headers["X-RateLimit-Limit"] = "See API documentation" - response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60)) - + return response def _add_security_headers(self, response: Response) -> Response: diff --git a/backend/app/services/enhanced_embedding_service.py b/backend/app/services/enhanced_embedding_service.py new file mode 100644 index 0000000..284773f --- /dev/null +++ b/backend/app/services/enhanced_embedding_service.py @@ -0,0 +1,201 @@ +# Enhanced Embedding Service with Rate Limiting Handling +""" +Enhanced embedding service with robust rate limiting and retry logic +""" + +import asyncio +import logging +import time +from typing import List, Dict, Any, Optional +import numpy as np +from datetime import datetime, timedelta + +from .embedding_service import EmbeddingService +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class EnhancedEmbeddingService(EmbeddingService): + """Enhanced embedding service with rate limiting handling""" + + def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"): + super().__init__(model_name) + self.rate_limit_tracker = { + 'requests_count': 0, + 'window_start': time.time(), + 'window_size': 60, # 1 minute window + 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable + 'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff + 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)), + 'last_rate_limit_error': None + } + + async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]: + """ + Get embeddings with rate limiting and retry logic + """ + if max_retries is None: + max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3)) + + batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5)) + + if not self.initialized: + logger.warning("Embedding service not initialized, using fallback") + return self._generate_fallback_embeddings(texts), False + + embeddings = [] + success = True + + for i in range(0, len(texts), batch_size): + batch = texts[i:i+batch_size] + batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries) + embeddings.extend(batch_embeddings) + success = success and batch_success + + # Add delay between batches to avoid rate limiting + if i + batch_size < len(texts): + delay = self.rate_limit_tracker['delay_between_batches'] + await asyncio.sleep(delay) # Configurable delay between batches + + return embeddings, success + + async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]: + """Get embeddings for a batch with retry logic""" + last_error = None + + for attempt in range(max_retries + 1): + try: + # Check rate limit before making request + if self._is_rate_limited(): + delay = self._get_rate_limit_delay() + logger.warning(f"Rate limit detected, waiting {delay} seconds") + await asyncio.sleep(delay) + continue + + # Make the request + embeddings = await self._get_embeddings_batch_impl(texts) + + # Update rate limit tracker on success + self._update_rate_limit_tracker(success=True) + + return embeddings, True + + except Exception as e: + last_error = e + error_msg = str(e).lower() + + # Check if it's a rate limit error + if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']): + logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}") + self._update_rate_limit_tracker(success=False) + + if attempt < max_retries: + delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)] + logger.info(f"Retrying in {delay} seconds...") + await asyncio.sleep(delay) + continue + else: + logger.error(f"Max retries exceeded for rate limit, using fallback embeddings") + return self._generate_fallback_embeddings(texts), False + else: + # Non-rate-limit error + logger.error(f"Error generating embeddings: {e}") + if attempt < max_retries: + delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)] + await asyncio.sleep(delay) + else: + logger.error("Max retries exceeded, using fallback embeddings") + return self._generate_fallback_embeddings(texts), False + + # If we get here, all retries failed + logger.error(f"All retries failed, last error: {last_error}") + return self._generate_fallback_embeddings(texts), False + + async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]: + """Implementation of getting embeddings for a batch""" + from app.services.llm.service import llm_service + from app.services.llm.models import EmbeddingRequest + + embeddings = [] + + for text in texts: + # Truncate text if needed + max_chars = 1600 + truncated_text = text[:max_chars] if len(text) > max_chars else text + + llm_request = EmbeddingRequest( + model=self.model_name, + input=truncated_text, + user_id="rag_system", + api_key_id=0 + ) + + response = await llm_service.create_embedding(llm_request) + + if response.data and len(response.data) > 0: + embedding = response.data[0].embedding + if embedding: + embeddings.append(embedding) + if not hasattr(self, '_dimension_confirmed'): + self.dimension = len(embedding) + self._dimension_confirmed = True + else: + raise ValueError("Empty embedding in response") + else: + raise ValueError("Invalid response structure") + + return embeddings + + def _is_rate_limited(self) -> bool: + """Check if we're currently rate limited""" + now = time.time() + window_start = self.rate_limit_tracker['window_start'] + + # Reset window if it's expired + if now - window_start > self.rate_limit_tracker['window_size']: + self.rate_limit_tracker['requests_count'] = 0 + self.rate_limit_tracker['window_start'] = now + return False + + # Check if we've exceeded the limit + return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute'] + + def _get_rate_limit_delay(self) -> float: + """Get delay to wait for rate limit reset""" + now = time.time() + window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] + return max(0, window_end - now) + + def _update_rate_limit_tracker(self, success: bool): + """Update the rate limit tracker""" + now = time.time() + + # Reset window if it's expired + if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']: + self.rate_limit_tracker['requests_count'] = 0 + self.rate_limit_tracker['window_start'] = now + + # Increment counter on successful requests + if success: + self.rate_limit_tracker['requests_count'] += 1 + + async def get_embedding_stats(self) -> Dict[str, Any]: + """Get embedding service statistics including rate limiting info""" + base_stats = await self.get_stats() + + return { + **base_stats, + "rate_limit_info": { + "requests_in_current_window": self.rate_limit_tracker['requests_count'], + "max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'], + "window_reset_in_seconds": max(0, + self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time() + ), + "last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error'] + } + } + + +# Global enhanced embedding service instance +enhanced_embedding_service = EnhancedEmbeddingService() \ No newline at end of file diff --git a/backend/app/services/llm/config.py b/backend/app/services/llm/config.py index 8ac8fb8..61a8576 100644 --- a/backend/app/services/llm/config.py +++ b/backend/app/services/llm/config.py @@ -65,7 +65,16 @@ class LLMServiceConfig(BaseModel): # Provider configurations providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations") - + + # Token rate limiting (organization-wide) + token_limits_per_minute: Dict[str, int] = Field( + default_factory=lambda: { + "prompt_tokens": 20000, # PrivateMode Standard tier + "completion_tokens": 10000 # PrivateMode Standard tier + }, + description="Token rate limits per minute (organization-wide)" + ) + # Model routing (model_name -> provider_name) model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing") @@ -91,8 +100,8 @@ def create_default_config() -> LLMServiceConfig: supported_models=[], # Will be populated dynamically from proxy capabilities=["chat", "embeddings", "tee"], priority=1, - max_requests_per_minute=100, - max_requests_per_hour=2000, + max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min + max_requests_per_hour=1200, # 20 req/min * 60 min supports_streaming=True, supports_function_calling=True, max_context_window=128000, diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index 1871b0d..b6c90b7 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -60,6 +60,7 @@ import tiktoken from app.core.config import settings from app.core.logging import log_module_event from app.services.base_module import BaseModule, Permission +from app.services.enhanced_embedding_service import enhanced_embedding_service @dataclass @@ -1125,9 +1126,17 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(content) - # Generate embeddings for all chunks in batch (more efficient) - embeddings = await self._generate_embeddings(chunks) - + # Generate embeddings with enhanced rate limiting handling + embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) + + # Log if fallback embeddings were used + if not success: + logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded") + log_module_event("rag", "fallback_embeddings_used", { + "document_id": doc_id, + "content_preview": content[:100] + "..." if len(content) > 100 else content + }) + # Create document points points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): @@ -1188,9 +1197,17 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(processed_doc.content) - # Generate embeddings for all chunks in batch (more efficient) - embeddings = await self._generate_embeddings(chunks) - + # Generate embeddings with enhanced rate limiting handling + embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) + + # Log if fallback embeddings were used + if not success: + logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded") + log_module_event("rag", "fallback_embeddings_used", { + "document_id": processed_doc.id, + "filename": processed_doc.original_filename + }) + # Create document points with enhanced metadata points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): From a2ee959ec951b0ecedd5112b7e467dc237b7a504 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Sun, 21 Sep 2025 18:44:02 +0200 Subject: [PATCH 04/13] rag improvements --- .../app/services/llm/token_rate_limiter.py | 153 +++++++++++ backend/modules/chatbot/main.py | 79 +++--- backend/modules/rag/main.py | 255 +++++++++++++++--- 3 files changed, 401 insertions(+), 86 deletions(-) create mode 100644 backend/app/services/llm/token_rate_limiter.py diff --git a/backend/app/services/llm/token_rate_limiter.py b/backend/app/services/llm/token_rate_limiter.py new file mode 100644 index 0000000..2338a03 --- /dev/null +++ b/backend/app/services/llm/token_rate_limiter.py @@ -0,0 +1,153 @@ +""" +Token-based rate limiting for LLM service +""" + +import time +import redis +from typing import Dict, Optional, Tuple +from datetime import datetime, timedelta +from ..core.config import settings +from ..core.logging import get_logger + +logger = get_logger(__name__) + + +class TokenRateLimiter: + """Token-based rate limiting implementation""" + + def __init__(self): + try: + self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) + self.redis_client.ping() + logger.info("Token rate limiter initialized with Redis backend") + except Exception as e: + logger.warning(f"Redis not available for token rate limiting: {e}") + self.redis_client = None + # Fall back to in-memory rate limiting + self.in_memory_store = {} + logger.info("Token rate limiter using in-memory fallback") + + async def check_token_limits( + self, + provider: str, + prompt_tokens: int, + completion_tokens: int = 0 + ) -> Tuple[bool, Dict[str, str]]: + """ + Check if token usage is within limits + + Args: + provider: Provider name (e.g., "privatemode") + prompt_tokens: Number of prompt tokens to use + completion_tokens: Number of completion tokens to use + + Returns: + Tuple of (is_allowed, headers) + """ + # Get token limits from configuration + from .config import get_config + config = get_config() + token_limits = config.token_limits_per_minute + + # Check organization-wide limits + org_key = f"tokens:org:{provider}" + + # Get current usage + current_usage = await self._get_token_usage(org_key) + + # Calculate new usage + new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens + new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens + + # Check limits + prompt_limit = token_limits.get("prompt_tokens", 20000) + completion_limit = token_limits.get("completion_tokens", 10000) + + is_allowed = ( + new_prompt_tokens <= prompt_limit and + new_completion_tokens <= completion_limit + ) + + if is_allowed: + # Update usage + await self._update_token_usage(org_key, prompt_tokens, completion_tokens) + logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, " + f"{new_completion_tokens}/{completion_limit} completion") + + # Calculate remaining tokens + remaining_prompt = max(0, prompt_limit - new_prompt_tokens) + remaining_completion = max(0, completion_limit - new_completion_tokens) + + # Create headers + headers = { + "X-TokenLimit-Prompt-Remaining": str(remaining_prompt), + "X-TokenLimit-Completion-Remaining": str(remaining_completion), + "X-TokenLimit-Prompt-Limit": str(prompt_limit), + "X-TokenLimit-Completion-Limit": str(completion_limit), + "X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute + } + + if not is_allowed: + logger.warning(f"Token rate limit exceeded for {provider}. " + f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. " + f"Current: {current_usage}") + + return is_allowed, headers + + async def _get_token_usage(self, key: str) -> Dict[str, int]: + """Get current token usage""" + if self.redis_client: + try: + data = self.redis_client.hgetall(key) + if data: + return { + "prompt_tokens": int(data.get("prompt_tokens", 0)), + "completion_tokens": int(data.get("completion_tokens", 0)), + "updated_at": float(data.get("updated_at", time.time())) + } + except Exception as e: + logger.error(f"Error getting token usage from Redis: {e}") + + # Fallback to in-memory + return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0}) + + async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int): + """Update token usage""" + if self.redis_client: + try: + pipe = self.redis_client.pipeline() + pipe.hincrby(key, "prompt_tokens", prompt_tokens) + pipe.hincrby(key, "completion_tokens", completion_tokens) + pipe.hset(key, "updated_at", time.time()) + pipe.expire(key, 60) # Expire after 1 minute + pipe.execute() + except Exception as e: + logger.error(f"Error updating token usage in Redis: {e}") + # Fallback to in-memory + self._update_in_memory(key, prompt_tokens, completion_tokens) + else: + self._update_in_memory(key, prompt_tokens, completion_tokens) + + def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int): + """Update in-memory token usage""" + if key not in self.in_memory_store: + self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0} + + self.in_memory_store[key]["prompt_tokens"] += prompt_tokens + self.in_memory_store[key]["completion_tokens"] += completion_tokens + self.in_memory_store[key]["updated_at"] = time.time() + + def cleanup_expired(self): + """Clean up expired entries (for in-memory store)""" + if not self.redis_client: + current_time = time.time() + expired_keys = [ + key for key, data in self.in_memory_store.items() + if current_time - data.get("updated_at", 0) > 60 + ] + for key in expired_keys: + del self.in_memory_store[key] + + +# Global token rate limiter instance +token_rate_limiter = TokenRateLimiter() \ No newline at end of file diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py index 6f42f09..5ab62c7 100644 --- a/backend/modules/chatbot/main.py +++ b/backend/modules/chatbot/main.py @@ -265,7 +265,6 @@ class ChatbotModule(BaseModule): async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse: """Generate chat completion response""" - logger.info("=== CHAT COMPLETION METHOD CALLED ===") # Get chatbot configuration from database db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first() @@ -364,11 +363,10 @@ class ChatbotModule(BaseModule): metadata={"error": str(e), "fallback": True} ) - async def _generate_response(self, message: str, db_messages: List[DBMessage], + async def _generate_response(self, message: str, db_messages: List[DBMessage], config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]: """Generate response using LLM with optional RAG""" - logger.info("=== _generate_response METHOD CALLED ===") - + # Lazy load dependencies if not available await self._ensure_dependencies() @@ -397,8 +395,8 @@ class ChatbotModule(BaseModule): for i, result in enumerate(rag_results)] # Build full RAG context from all results - rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([ - f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results) + rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([ + f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results) ]) # Detailed RAG logging - ALWAYS log for debugging @@ -407,14 +405,14 @@ class ChatbotModule(BaseModule): logger.info(f"Collection: {qdrant_collection_name}") logger.info(f"Number of results: {len(rag_results)}") for i, result in enumerate(rag_results): - logger.info(f"\\n--- RAG Result {i+1} ---") + logger.info(f"\n--- RAG Result {i+1} ---") logger.info(f"Score: {getattr(result, 'score', 'N/A')}") logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}") logger.info(f"Full Content ({len(result.document.content)} chars):") logger.info(f"{result.document.content}") if hasattr(result.document, 'metadata'): logger.info(f"Metadata: {result.document.metadata}") - logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===") + logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===") logger.info(rag_context) logger.info("=== END RAG SEARCH RESULTS ===") else: @@ -428,11 +426,6 @@ class ChatbotModule(BaseModule): logger.warning(f"RAG search traceback: {traceback.format_exc()}") # Build conversation context (includes the current message from db_messages) - logger.info(f"=== CRITICAL DEBUG ===") - logger.info(f"rag_context length: {len(rag_context)}") - logger.info(f"rag_context empty: {not rag_context}") - logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}") - logger.info(f"=== END CRITICAL DEBUG ===") messages = self._build_conversation_messages(db_messages, config, rag_context, context) # Note: Current user message is already included in db_messages from the query @@ -452,9 +445,9 @@ class ChatbotModule(BaseModule): if config.use_rag and rag_context: logger.info(f"RAG context added: {len(rag_context)} characters") logger.info(f"RAG sources: {len(sources) if sources else 0} documents") - logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===") + logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===") for i, msg in enumerate(messages): - logger.info(f"\\n--- Message {i+1} ---") + logger.info(f"\n--- Message {i+1} ---") logger.info(f"Role: {msg['role']}") logger.info(f"Content ({len(msg['content'])} chars):") # Truncate long content for logging (full RAG context can be very long) @@ -518,38 +511,34 @@ class ChatbotModule(BaseModule): # Return fallback if available return "I'm currently unable to process your request. Please try again later.", None - def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, + def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]: """Build messages array for LLM completion""" - + messages = [] - logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}") - - # System prompt - keep it clean without RAG context + + # System prompt system_prompt = config.system_prompt + if rag_context: + # Add explicit instruction to use RAG context + system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \ + "This information is directly relevant to their query and should be your primary source:\n" + rag_context if context and context.get('additional_instructions'): system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}" - + messages.append({"role": "system", "content": system_prompt}) - + logger.info(f"Building messages from {len(db_messages)} database messages") - + # Conversation history (messages are already limited by memory_length in the query) # Reverse to get chronological order # Include ALL messages - the current user message is needed for the LLM to respond! for idx, msg in enumerate(reversed(db_messages)): logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...") if msg.role in ["user", "assistant"]: - # For user messages, prepend RAG context if available - content = msg.content - if msg.role == "user" and rag_context and idx == 0: - # Add RAG context to the current user message (first in reversed order) - content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}" - logger.info("Added RAG context to user message") - messages.append({ "role": msg.role, - "content": content + "content": msg.content }) logger.info(f"Added message with role {msg.role} to LLM messages") else: @@ -690,10 +679,9 @@ class ChatbotModule(BaseModule): return router # API Compatibility Methods - async def chat(self, chatbot_config: Dict[str, Any], message: str, + async def chat(self, chatbot_config: Dict[str, Any], message: str, conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]: """Chat method for API compatibility""" - logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===") logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}") # Lazy load dependencies @@ -723,20 +711,21 @@ class ChatbotModule(BaseModule): fallback_responses=chatbot_config.get("fallback_responses", []) ) - # For API compatibility, create a temporary DBMessage for the current message - # so RAG context can be properly added - from app.models.chatbot import ChatbotMessage as DBMessage + # Generate response using internal method + # Create a temporary message object for the current user message + temp_messages = [ + DBMessage( + id=0, + conversation_id=0, + role="user", + content=message, + timestamp=datetime.utcnow(), + metadata={} + ) + ] - # Create a temporary user message with the current message - temp_user_message = DBMessage( - conversation_id="temp_conversation", - role=MessageRole.USER.value, - content=message - ) - - # Generate response using internal method with the current message included response_content, sources = await self._generate_response( - message, [temp_user_message], config, None, db + message, temp_messages, config, None, db ) return { diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index b6c90b7..7d75fbd 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -53,14 +53,13 @@ except ImportError: PYTHON_DOCX_AVAILABLE = False from qdrant_client import QdrantClient -from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue +from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue from qdrant_client.http import models import tiktoken from app.core.config import settings from app.core.logging import log_module_event from app.services.base_module import BaseModule, Permission -from app.services.enhanced_embedding_service import enhanced_embedding_service @dataclass @@ -134,6 +133,19 @@ class RAGModule(BaseModule): self.embedding_model = None self.embedding_service = None self.tokenizer = None + + # Set improved default configuration + self.config = { + "chunk_size": 300, # Reduced from 400 for better precision + "chunk_overlap": 50, # Added overlap for context preservation + "max_results": 10, + "score_threshold": 0.3, # Increased from 0.0 to filter low-quality results + "enable_hybrid": True, # Enable hybrid search (vector + BM25) + "hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring + } + # Update with any provided config + if config: + self.config.update(config) # Content processing components self.nlp_model = None @@ -640,19 +652,33 @@ class RAGModule(BaseModule): return embeddings def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]: - """Split text into chunks""" - chunk_size = chunk_size or self.config.get("chunk_size", 400) - + """Split text into overlapping chunks for better context preservation""" + chunk_size = chunk_size or self.config.get("chunk_size", 300) + chunk_overlap = self.config.get("chunk_overlap", 50) + # Tokenize text tokens = self.tokenizer.encode(text) - - # Split into chunks + + # Split into chunks with overlap chunks = [] - for i in range(0, len(tokens), chunk_size): - chunk_tokens = tokens[i:i + chunk_size] + start_idx = 0 + + while start_idx < len(tokens): + end_idx = min(start_idx + chunk_size, len(tokens)) + chunk_tokens = tokens[start_idx:end_idx] chunk_text = self.tokenizer.decode(chunk_tokens) - chunks.append(chunk_text) - + + # Only add non-empty chunks + if chunk_text.strip(): + chunks.append(chunk_text) + + # Move to next chunk with overlap + start_idx = end_idx - chunk_overlap + + # Ensure progress (in case overlap >= chunk_size) + if start_idx >= end_idx: + start_idx = end_idx + return chunks async def _process_text(self, content: bytes, filename: str) -> str: @@ -1126,17 +1152,9 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(content) - # Generate embeddings with enhanced rate limiting handling - embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) - - # Log if fallback embeddings were used - if not success: - logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded") - log_module_event("rag", "fallback_embeddings_used", { - "document_id": doc_id, - "content_preview": content[:100] + "..." if len(content) > 100 else content - }) - + # Generate embeddings for all chunks in batch (more efficient) + embeddings = await self._generate_embeddings(chunks) + # Create document points points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): @@ -1197,17 +1215,9 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(processed_doc.content) - # Generate embeddings with enhanced rate limiting handling - embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) - - # Log if fallback embeddings were used - if not success: - logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded") - log_module_event("rag", "fallback_embeddings_used", { - "document_id": processed_doc.id, - "filename": processed_doc.original_filename - }) - + # Generate embeddings for all chunks in batch (more efficient) + embeddings = await self._generate_embeddings(chunks) + # Create document points with enhanced metadata points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): @@ -1277,6 +1287,154 @@ class RAGModule(BaseModule): except Exception: return False + async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float], + query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]: + """Perform hybrid search combining vector similarity and BM25 scoring""" + + # Preprocess query for BM25 + query_terms = self._preprocess_text_for_bm25(query) + + # Get all documents from the collection (for BM25 scoring) + # Note: In production, you'd want to optimize this with a proper BM25 index + scroll_filter = query_filter or Filter() + all_points = [] + + # Use scroll to get all points + offset = None + batch_size = 100 + while True: + search_result = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=scroll_filter, + limit=batch_size, + offset=offset, + with_payload=True, + with_vectors=False + ) + + points = search_result[0] + all_points.extend(points) + + if len(points) < batch_size: + break + + offset = points[-1].id + + # Calculate BM25 scores for each document + bm25_scores = {} + for point in all_points: + doc_id = point.payload.get("document_id", "") + content = point.payload.get("content", "") + + # Calculate BM25 score + bm25_score = self._calculate_bm25_score(query_terms, content) + bm25_scores[doc_id] = bm25_score + + # Perform vector search + vector_results = self.qdrant_client.search( + collection_name=collection_name, + query_vector=query_vector, + query_filter=query_filter, + limit=limit * 2, # Get more results for re-ranking + score_threshold=score_threshold / 2 # Lower threshold for initial search + ) + + # Combine scores + hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3}) + vector_weight = hybrid_weights.get("vector", 0.7) + bm25_weight = hybrid_weights.get("bm25", 0.3) + + # Create hybrid results + hybrid_results = [] + for result in vector_results: + doc_id = result.payload.get("document_id", "") + vector_score = result.score + bm25_score = bm25_scores.get(doc_id, 0.0) + + # Normalize scores (simple min-max normalization) + vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0 + bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1 + + # Calculate hybrid score + hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm) + + # Create new point with hybrid score + hybrid_point = ScoredPoint( + id=result.id, + payload=result.payload, + score=hybrid_score, + vector=result.vector, + shard_key=None, + order_value=None + ) + hybrid_results.append(hybrid_point) + + # Sort by hybrid score and apply final threshold + hybrid_results.sort(key=lambda x: x.score, reverse=True) + final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit] + + logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results") + return final_results + + def _preprocess_text_for_bm25(self, text: str) -> List[str]: + """Preprocess text for BM25 scoring""" + if not NLTK_AVAILABLE: + return text.lower().split() + + try: + # Tokenize + tokens = word_tokenize(text.lower()) + + # Remove stopwords and non-alphabetic tokens + stop_words = set(stopwords.words('english')) + filtered_tokens = [ + token for token in tokens + if token.isalpha() and token not in stop_words and len(token) > 2 + ] + + return filtered_tokens + except: + # Fallback to simple splitting + return text.lower().split() + + def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float: + """Calculate BM25 score for a document against query terms""" + if not query_terms: + return 0.0 + + # Preprocess document + doc_terms = self._preprocess_text_for_bm25(document) + if not doc_terms: + return 0.0 + + # Calculate term frequencies + doc_len = len(doc_terms) + avg_doc_len = 300 # Average document length (configurable) + + # BM25 parameters + k1 = 1.2 # Controls term frequency saturation + b = 0.75 # Controls document length normalization + + score = 0.0 + + # Calculate IDF for each query term + for term in set(query_terms): + # Term frequency in document + tf = doc_terms.count(term) + + # Simple IDF (log(N/n) + 1) + # In production, you'd use the actual document frequency + idf = 2.0 # Simplified IDF + + # BM25 formula + numerator = tf * (k1 + 1) + denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len)) + + score += idf * (numerator / denominator) + + # Normalize score to 0-1 range + return min(score / 10.0, 1.0) # Simple normalization + async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: """Search for relevant documents""" if not self.enabled: @@ -1314,14 +1472,29 @@ class RAGModule(BaseModule): logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}") logger.info(f"Embedding service available: {self.embedding_service is not None}") - # Search in Qdrant - search_results = self.qdrant_client.search( - collection_name=collection_name, - query_vector=query_embedding, - query_filter=search_filter, - limit=max_results, - score_threshold=0.0 # Lowered from 0.5 to see all results including low scores - ) + # Check if hybrid search is enabled + enable_hybrid = self.config.get("enable_hybrid", False) + score_threshold = self.config.get("score_threshold", 0.3) + + if enable_hybrid and NLTK_AVAILABLE: + # Perform hybrid search (vector + BM25) + search_results = await self._hybrid_search( + collection_name=collection_name, + query=query, + query_vector=query_embedding, + query_filter=search_filter, + limit=max_results, + score_threshold=score_threshold + ) + else: + # Pure vector search with improved threshold + search_results = self.qdrant_client.search( + collection_name=collection_name, + query_vector=query_embedding, + query_filter=search_filter, + limit=max_results, + score_threshold=score_threshold + ) logger.info(f"Raw search results count: {len(search_results)}") From 361c016da4467b2b1bda424e1fe1523c499f072a Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Mon, 22 Sep 2025 11:42:40 +0200 Subject: [PATCH 05/13] chatbot rag testing --- backend/app/api/v1/chatbot.py | 47 ++++++++++++++++++++------------- backend/modules/chatbot/main.py | 4 ++- docker-compose.yml | 11 +++++--- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/backend/app/api/v1/chatbot.py b/backend/app/api/v1/chatbot.py index 20f03dc..f738b05 100644 --- a/backend/app/api/v1/chatbot.py +++ b/backend/app/api/v1/chatbot.py @@ -32,12 +32,28 @@ class ChatbotCreateRequest(BaseModel): use_rag: bool = False rag_collection: Optional[str] = None rag_top_k: int = 5 + rag_score_threshold: float = 0.02 # Lowered from default 0.3 to allow more results temperature: float = 0.7 max_tokens: int = 1000 memory_length: int = 10 fallback_responses: List[str] = [] +class ChatbotUpdateRequest(BaseModel): + name: Optional[str] = None + chatbot_type: Optional[str] = None + model: Optional[str] = None + system_prompt: Optional[str] = None + use_rag: Optional[bool] = None + rag_collection: Optional[str] = None + rag_top_k: Optional[int] = None + rag_score_threshold: Optional[float] = None + temperature: Optional[float] = None + max_tokens: Optional[int] = None + memory_length: Optional[int] = None + fallback_responses: Optional[List[str]] = None + + class ChatRequest(BaseModel): message: str conversation_id: Optional[str] = None @@ -190,7 +206,7 @@ async def create_chatbot( @router.put("/update/{chatbot_id}") async def update_chatbot( chatbot_id: str, - request: ChatbotCreateRequest, + request: ChatbotUpdateRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ): @@ -214,28 +230,23 @@ async def update_chatbot( if not chatbot: raise HTTPException(status_code=404, detail="Chatbot not found or access denied") - # Update chatbot configuration - config = { - "name": request.name, - "chatbot_type": request.chatbot_type, - "model": request.model, - "system_prompt": request.system_prompt, - "use_rag": request.use_rag, - "rag_collection": request.rag_collection, - "rag_top_k": request.rag_top_k, - "temperature": request.temperature, - "max_tokens": request.max_tokens, - "memory_length": request.memory_length, - "fallback_responses": request.fallback_responses - } - + # Get existing config + existing_config = chatbot.config.copy() if chatbot.config else {} + + # Update only the fields that are provided in the request + update_data = request.dict(exclude_unset=True) + + # Merge with existing config, preserving unset values + for key, value in update_data.items(): + existing_config[key] = value + # Update the chatbot await db.execute( update(ChatbotInstance) .where(ChatbotInstance.id == chatbot_id) .values( - name=request.name, - config=config, + name=existing_config.get("name", chatbot.name), + config=existing_config, updated_at=datetime.utcnow() ) ) diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py index 5ab62c7..e378414 100644 --- a/backend/modules/chatbot/main.py +++ b/backend/modules/chatbot/main.py @@ -69,6 +69,7 @@ class ChatbotConfig: memory_length: int = 10 # Number of previous messages to remember use_rag: bool = False rag_top_k: int = 5 + rag_score_threshold: float = 0.02 # Lowered from default 0.3 to allow more results fallback_responses: List[str] = None def __post_init__(self): @@ -386,7 +387,8 @@ class ChatbotModule(BaseModule): rag_results = await self.rag_module.search_documents( query=message, max_results=config.rag_top_k, - collection_name=qdrant_collection_name + collection_name=qdrant_collection_name, + score_threshold=config.rag_score_threshold ) if rag_results: diff --git a/docker-compose.yml b/docker-compose.yml index 8210ba9..badc278 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,16 +15,18 @@ services: # Database migration service - runs once to apply migrations enclava-migrate: - build: + build: context: ./backend dockerfile: Dockerfile environment: - DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db + - JWT_SECRET=${JWT_SECRET:-your-jwt-secret-here} depends_on: - enclava-postgres command: ["/usr/local/bin/migrate.sh"] volumes: - ./backend:/app + - ./.env:/app/.env networks: - enclava-net restart: "no" # Run once and exit @@ -63,9 +65,9 @@ services: enclava-frontend: image: node:18-alpine working_dir: /app - command: sh -c "npm install && npm run dev" + command: sh -c "npm ci --ignore-scripts && npm run dev" environment: - # Required base URL (derives APP/API/WS URLs) + # Required base URL (derives APP/API/WS URLs) - BASE_URL=${BASE_URL} - NEXT_PUBLIC_BASE_URL=${BASE_URL} # Docker internal ports @@ -79,7 +81,7 @@ services: - "3002:3000" # Direct frontend access for development volumes: - ./frontend:/app - - /app/node_modules + - enclava-frontend-node-modules:/app/node_modules networks: - enclava-net restart: unless-stopped @@ -148,6 +150,7 @@ volumes: enclava-postgres-data: enclava-redis-data: enclava-qdrant-data: + enclava-frontend-node-modules: # enclava-ollama-data: networks: From a8fe7d6d29bef10f74e9a201c64224afc9e39d0b Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Mon, 22 Sep 2025 11:47:09 +0200 Subject: [PATCH 06/13] Backup before security middleware removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .env | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++ backend/.env | 0 2 files changed, 152 insertions(+) create mode 100644 .env create mode 100644 backend/.env diff --git a/.env b/.env new file mode 100644 index 0000000..9e074ae --- /dev/null +++ b/.env @@ -0,0 +1,152 @@ +# =================================== +# ENCLAVA MINIMAL CONFIGURATION +# =================================== +# Only essential environment variables that CANNOT have defaults +# Other settings should be configurable through the app UI + +# =================================== +# INFRASTRUCTURE (Required) +# =================================== +DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db +REDIS_URL=redis://enclava-redis:6379 + +# =================================== +# SECURITY CRITICAL (Required) +# =================================== +JWT_SECRET=your-super-secret-jwt-key-here-change-in-production +PRIVATEMODE_API_KEY=dfaea90e-df15-48d4-94ff-5ee243b846bb + +# Admin user (created on first startup only) +ADMIN_EMAIL=admin@example.com +ADMIN_PASSWORD=admin123 +API_RATE_LIMITING_ENABLED=false +# =================================== +# ADDITIONAL SECURITY SETTINGS (Optional but recommended) +# =================================== +# JWT Algorithm (default: HS256) +# JWT_ALGORITHM=HS256 + +# Token expiration times (in minutes) +# ACCESS_TOKEN_EXPIRE_MINUTES=30 +# REFRESH_TOKEN_EXPIRE_MINUTES=10080 +# SESSION_EXPIRE_MINUTES=1440 + +# API Key prefix (default: en_) +# API_KEY_PREFIX=en_ + +# Security thresholds (0.0-1.0) +# API_SECURITY_RISK_THRESHOLD=0.8 +# API_SECURITY_WARNING_THRESHOLD=0.6 +# API_SECURITY_ANOMALY_THRESHOLD=0.7 + +# IP security (comma-separated for multiple IPs) +# API_BLOCKED_IPS= +# API_ALLOWED_IPS= + +# =================================== +# APPLICATION BASE URL (Required - derives all URLs and CORS) +# =================================== +BASE_URL=localhost +# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost +# Backend derives: CORS_ORIGINS=["http://localhost"] + +# =================================== +# DOCKER NETWORKING (Required for containers) +# =================================== +BACKEND_INTERNAL_PORT=8000 +FRONTEND_INTERNAL_PORT=3000 +# Hosts are fixed: enclava-backend, enclava-frontend +# Upstreams derive: enclava-backend:8000, enclava-frontend:3000 + +# =================================== +# QDRANT (Required for RAG) +# =================================== +QDRANT_HOST=enclava-qdrant +QDRANT_PORT=6333 +QDRANT_URL=http://enclava-qdrant:6333 + +# =================================== +# OPTIONAL PRIVATEMODE SETTINGS (Have defaults) +# =================================== +# PRIVATEMODE_CACHE_MODE=none # Optional: defaults to 'none' +# PRIVATEMODE_CACHE_SALT= # Optional: defaults to empty + +# =================================== +# OPTIONAL CONFIGURATION (All have sensible defaults) +# =================================== + +# Application Settings +# APP_NAME=Enclava +# APP_DEBUG=false +# APP_LOG_LEVEL=INFO +# APP_HOST=0.0.0.0 +# APP_PORT=8000 + +# Security Features +API_SECURITY_ENABLED=false +# API_THREAT_DETECTION_ENABLED=true +# API_IP_REPUTATION_ENABLED=true +# API_ANOMALY_DETECTION_ENABLED=true +API_RATE_LIMITING_ENABLED=false +# API_SECURITY_HEADERS_ENABLED=true + +# Content Security Policy +# API_CSP_HEADER=default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline' + +# Rate Limiting (requests per minute/hour) +# API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE=300 +# API_RATE_LIMIT_AUTHENTICATED_PER_HOUR=5000 +# API_RATE_LIMIT_API_KEY_PER_MINUTE=1000 +# API_RATE_LIMIT_API_KEY_PER_HOUR=20000 +# API_RATE_LIMIT_PREMIUM_PER_MINUTE=5000 +# API_RATE_LIMIT_PREMIUM_PER_HOUR=100000 + +# Request Size Limits (in bytes) +# API_MAX_REQUEST_BODY_SIZE=10485760 # 10MB +# API_MAX_REQUEST_BODY_SIZE_PREMIUM=52428800 # 50MB +# MAX_UPLOAD_SIZE=10485760 # 10MB + +# Monitoring +# PROMETHEUS_ENABLED=true +# PROMETHEUS_PORT=9090 + +# Logging +# LOG_FORMAT=json +# LOG_LEVEL=INFO +# LOG_LLM_PROMPTS=false + +# Module Configuration +# MODULES_CONFIG_PATH=config/modules.yaml + +# Plugin Configuration +# PLUGINS_DIR=/plugins +# PLUGINS_CONFIG_PATH=config/plugins.yaml +# PLUGIN_REPOSITORY_URL=https://plugins.enclava.com +# PLUGIN_ENCRYPTION_KEY= + +# =================================== +# RAG EMBEDDING ENHANCED SETTINGS +# =================================== +# Enhanced embedding service configuration +RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=60 +RAG_EMBEDDING_BATCH_SIZE=5 +RAG_EMBEDDING_RETRY_COUNT=3 +RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16 +RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=0.5 + +# Fallback embedding behavior +RAG_ALLOW_FALLBACK_EMBEDDINGS=true +RAG_WARN_ON_FALLBACK=true + +# Processing timeouts (in seconds) +RAG_DOCUMENT_PROCESSING_TIMEOUT=300 +RAG_EMBEDDING_GENERATION_TIMEOUT=120 +RAG_INDEXING_TIMEOUT=120 + +# =================================== +# SUMMARY +# =================================== +# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL +# Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT +# Optional: All other settings have secure defaults +# =================================== diff --git a/backend/.env b/backend/.env new file mode 100644 index 0000000..e69de29 From 95d5b3a443cad87b23c3cd1da93387f0c7e20d3a Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Mon, 22 Sep 2025 11:48:11 +0200 Subject: [PATCH 07/13] Remove security and rate limiting middleware from backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Removed security middleware setup from main.py - Disabled security middleware functionality - Removed rate limiting middleware setup 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- backend/app/main.py | 8 +--- backend/app/middleware/security.py | 74 +++--------------------------- 2 files changed, 9 insertions(+), 73 deletions(-) diff --git a/backend/app/main.py b/backend/app/main.py index 40d51a3..8bea827 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -135,13 +135,9 @@ app.add_middleware( # Add analytics middleware setup_analytics_middleware(app) -# Add security middleware -from app.middleware.security import setup_security_middleware -setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED) +# Security middleware disabled - handled externally -# Add rate limiting middleware only for specific endpoints -from app.middleware.rate_limiting import RateLimitMiddleware -app.add_middleware(RateLimitMiddleware) +# Rate limiting middleware disabled - handled externally # Exception handlers diff --git a/backend/app/middleware/security.py b/backend/app/middleware/security.py index 57d2ebe..c7b7952 100644 --- a/backend/app/middleware/security.py +++ b/backend/app/middleware/security.py @@ -18,77 +18,17 @@ logger = get_logger(__name__) class SecurityMiddleware(BaseHTTPMiddleware): - """Security middleware for threat detection and request filtering""" - + """Security middleware for threat detection and request filtering - DISABLED""" + def __init__(self, app, enabled: bool = True): super().__init__(app) - self.enabled = enabled and settings.API_SECURITY_ENABLED - logger.info(f"SecurityMiddleware initialized, enabled: {self.enabled}") + self.enabled = False # Force disable regardless of settings + logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)") async def dispatch(self, request: Request, call_next: Callable) -> Response: - """Process request through security analysis""" - if not self.enabled: - # Security disabled, pass through - return await call_next(request) - - # Skip security analysis for certain endpoints - if self._should_skip_security(request): - response = await call_next(request) - return self._add_security_headers(response) - - # Simple authentication check - drop requests without valid auth - if not self._has_valid_auth(request): - return JSONResponse( - content={"error": "Authentication required", "message": "Valid API key or authentication token required"}, - status_code=401, - headers={"WWW-Authenticate": "Bearer"} - ) - - try: - # Get user context if available - user_context = getattr(request.state, 'user', None) - - # Perform security analysis - start_time = time.time() - analysis = await threat_detection_service.analyze_request(request, user_context) - analysis_time = time.time() - start_time - - # Store analysis in request state for later use - request.state.security_analysis = analysis - - # Log security events (only for significant threats to reduce false positive noise) - # Only log if: being blocked OR risk score above warning threshold (0.6) - if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD): - await self._log_security_event(request, analysis) - - # Check if request should be blocked (excluding rate limiting) - if analysis.should_block and not analysis.rate_limit_exceeded: - threat_detection_service.stats['threats_blocked'] += 1 - logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: " - f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}") - - # Return security block response - return self._create_block_response(analysis) - - # Log warnings for medium-risk requests - if analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD: - logger.warning(f"High-risk request detected from {request.client.host if request.client else 'unknown'}: " - f"risk_score={analysis.risk_score:.3f}, auth_level={analysis.auth_level.value}") - - # Continue with request processing - response = await call_next(request) - - # Add security headers and metrics - response = self._add_security_headers(response) - response = self._add_security_metrics(response, analysis, analysis_time) - - return response - - except Exception as e: - logger.error(f"Security middleware error: {e}") - # Continue with request on security middleware errors to avoid breaking the app - response = await call_next(request) - return self._add_security_headers(response) + """Process request through security analysis - DISABLED""" + # Security disabled, always pass through + return await call_next(request) def _should_skip_security(self, request: Request) -> bool: """Determine if security analysis should be skipped for this request""" From 354b43494dd1e3dcdfea32defec17418ca53bfa9 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Mon, 22 Sep 2025 11:49:13 +0200 Subject: [PATCH 08/13] Add verification script for security middleware removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Script verifies: - Environment settings are correct - Python syntax is valid - Docker configuration exists - No security import errors 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- verify_security_removal.py | 166 +++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 verify_security_removal.py diff --git a/verify_security_removal.py b/verify_security_removal.py new file mode 100644 index 0000000..81c3e9b --- /dev/null +++ b/verify_security_removal.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +""" +Verification script for security middleware removal +""" +import subprocess +import sys +import time + +def run_command(cmd, cwd=None): + """Run a command and return the result""" + try: + result = subprocess.run( + cmd, + shell=True, + capture_output=True, + text=True, + cwd=cwd, + timeout=30 + ) + return result.returncode, result.stdout, result.stderr + except subprocess.TimeoutExpired: + return -1, "", "Command timed out" + +def test_backend_syntax(): + """Test if backend Python files have valid syntax""" + print("🔍 Testing backend Python syntax...") + + # Check main.py + code, stdout, stderr = run_command("python3 -m py_compile app/main.py", cwd="backend") + if code == 0: + print("✅ main.py syntax OK") + else: + print(f"❌ main.py syntax error: {stderr}") + return False + + # Check security middleware + code, stdout, stderr = run_command("python3 -m py_compile app/middleware/security.py", cwd="backend") + if code == 0: + print("✅ security.py syntax OK") + else: + print(f"❌ security.py syntax error: {stderr}") + return False + + return True + +def test_docker_build(): + """Test if Docker can build the backend service""" + print("\n🐳 Testing Docker backend build...") + + # Just check if the Dockerfile exists and is readable + try: + with open("backend/Dockerfile", "r") as f: + content = f.read() + if "FROM" in content and "python" in content: + print("✅ Dockerfile exists and looks valid") + return True + else: + print("❌ Dockerfile appears invalid") + return False + except FileNotFoundError: + print("❌ Dockerfile not found") + return False + +def test_env_settings(): + """Test if environment settings are correct""" + print("\n⚙️ Testing environment settings...") + + try: + with open(".env", "r") as f: + env_content = f.read() + + if "API_SECURITY_ENABLED=false" in env_content: + print("✅ Security is disabled in .env") + else: + print("❌ Security is not disabled in .env") + return False + + if "API_RATE_LIMITING_ENABLED=false" in env_content: + print("✅ Rate limiting is disabled in .env") + else: + print("❌ Rate limiting is not disabled in .env") + return False + + return True + except FileNotFoundError: + print("❌ .env file not found") + return False + +def test_imports(): + """Test if the main application can be imported without security dependencies""" + print("\n📦 Testing import dependencies...") + + # Create a minimal test script + test_script = """ +import sys +sys.path.insert(0, 'backend') + +try: + # Test if we can create the app without security middleware + from app.main import app + print("✅ App can be imported successfully") +except ImportError as e: + print(f"❌ Import error: {e}") + sys.exit(1) +except Exception as e: + print(f"❌ Other error: {e}") + sys.exit(1) +""" + + # Save test script + with open("test_import.py", "w") as f: + f.write(test_script) + + # Run test (will likely fail due to missing dependencies, but should not fail due to security imports) + code, stdout, stderr = run_command("python3 test_import.py") + + # Clean up + import os + os.remove("test_import.py") + + # We expect this to fail due to missing FastAPI, but not due to security imports + if "security" in stderr.lower() and "No module named" not in stderr: + print("❌ Security import errors detected") + return False + else: + print("✅ No security import errors detected") + return True + +def main(): + """Run all verification tests""" + print("🚀 Starting verification of security middleware removal...\n") + + tests = [ + ("Environment Settings", test_env_settings), + ("Python Syntax", test_backend_syntax), + ("Docker Configuration", test_docker_build), + ("Import Dependencies", test_imports), + ] + + results = [] + for test_name, test_func in tests: + print(f"\n--- {test_name} ---") + result = test_func() + results.append((test_name, result)) + + # Print summary + print("\n" + "="*50) + print("📊 VERIFICATION SUMMARY") + print("="*50) + + for test_name, result in results: + status = "✅ PASS" if result else "❌ FAIL" + print(f"{test_name}: {status}") + + all_passed = all(result for _, result in results) + + if all_passed: + print("\n🎉 All tests passed! Security middleware has been successfully removed.") + else: + print("\n⚠️ Some tests failed. Please review the issues above.") + + return all_passed + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) \ No newline at end of file From f8d127ff4211f79f11ceeb77d63b436e65dbe205 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Tue, 23 Sep 2025 15:26:54 +0200 Subject: [PATCH 09/13] rag improvements --- .env | 2 +- .env.example | 12 +- backend/app/api/internal_v1/__init__.py | 7 +- backend/app/api/v1/__init__.py | 3 - backend/app/api/v1/llm.py | 3 +- backend/app/api/v1/rag.py | 308 +++++++- backend/app/api/v1/security.py | 251 ------ backend/app/api/v1/settings.py | 4 - backend/app/core/config.py | 44 +- backend/app/core/threat_detection.py | 744 ------------------ backend/app/main.py | 12 +- backend/app/middleware/rate_limiting.py | 371 --------- backend/app/middleware/security.py | 210 ----- backend/app/services/document_processor.py | 148 +++- backend/app/services/embedding_service.py | 24 +- .../services/enhanced_embedding_service.py | 28 +- backend/app/services/llm/config.py | 14 +- backend/app/services/llm/metrics.py | 2 - backend/app/services/llm/models.py | 1 - .../app/services/llm/providers/privatemode.py | 2 + backend/app/services/llm/security.py | 325 -------- backend/app/services/llm/service.py | 286 ++----- .../app/services/llm/token_rate_limiter.py | 153 ---- backend/app/services/rag_service.py | 9 +- backend/modules/rag/main.py | 131 ++- frontend/src/app/api/auth/login/route.ts | 2 +- frontend/src/app/rag/page.tsx | 25 +- .../src/components/rag/document-browser.tsx | 55 +- frontend/src/components/ui/navigation.tsx | 5 +- nginx/nginx.conf | 64 +- 30 files changed, 817 insertions(+), 2428 deletions(-) delete mode 100644 backend/app/api/v1/security.py delete mode 100644 backend/app/core/threat_detection.py delete mode 100644 backend/app/middleware/rate_limiting.py delete mode 100644 backend/app/middleware/security.py delete mode 100644 backend/app/services/llm/security.py delete mode 100644 backend/app/services/llm/token_rate_limiter.py diff --git a/.env b/.env index 9e074ae..b8d34af 100644 --- a/.env +++ b/.env @@ -46,7 +46,7 @@ API_RATE_LIMITING_ENABLED=false # =================================== # APPLICATION BASE URL (Required - derives all URLs and CORS) # =================================== -BASE_URL=localhost +BASE_URL=localhost:80 # Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost # Backend derives: CORS_ORIGINS=["http://localhost"] diff --git a/.env.example b/.env.example index cf6d8f1..b9dd120 100644 --- a/.env.example +++ b/.env.example @@ -65,6 +65,16 @@ QDRANT_HOST=enclava-qdrant QDRANT_PORT=6333 QDRANT_URL=http://enclava-qdrant:6333 +# =================================== +# RAG EMBEDDING CONFIGURATION (Optional overrides) +# =================================== +# These control embedding throughput to avoid provider 429s. +# Defaults are conservative; uncomment to override. +# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12 +# RAG_EMBEDDING_BATCH_SIZE=3 +# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0 # seconds +# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5 # seconds + # =================================== # OPTIONAL PRIVATEMODE SETTINGS (Have defaults) # =================================== @@ -130,4 +140,4 @@ QDRANT_URL=http://enclava-qdrant:6333 # Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL # Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT # Optional: All other settings have secure defaults -# =================================== \ No newline at end of file +# =================================== diff --git a/backend/app/api/internal_v1/__init__.py b/backend/app/api/internal_v1/__init__.py index 97e8510..29af4ab 100644 --- a/backend/app/api/internal_v1/__init__.py +++ b/backend/app/api/internal_v1/__init__.py @@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router from ..v1.settings import router as settings_router from ..v1.analytics import router as analytics_router from ..v1.rag import router as rag_router +from ..rag_debug import router as rag_debug_router from ..v1.prompt_templates import router as prompt_templates_router -from ..v1.security import router as security_router from ..v1.plugin_registry import router as plugin_registry_router from ..v1.platform import router as platform_router from ..v1.llm_internal import router as llm_internal_router @@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[ # Include RAG routes (frontend RAG document management) internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"]) +# Include RAG debug routes (for demo and debugging) +internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"]) + # Include prompt template routes (frontend prompt template management) internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"]) -# Include security routes (frontend security settings) -internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"]) # Include plugin registry routes (frontend plugin management) internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]) diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py index 6f66641..f9412e4 100644 --- a/backend/app/api/v1/__init__.py +++ b/backend/app/api/v1/__init__.py @@ -16,7 +16,6 @@ from .analytics import router as analytics_router from .rag import router as rag_router from .chatbot import router as chatbot_router from .prompt_templates import router as prompt_templates_router -from .security import router as security_router from .plugin_registry import router as plugin_registry_router # Create main API router @@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"]) # Include prompt template routes api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]) -# Include security routes -api_router.include_router(security_router, prefix="/security", tags=["security"]) # Include plugin registry routes diff --git a/backend/app/api/v1/llm.py b/backend/app/api/v1/llm.py index c30d797..5fdc20c 100644 --- a/backend/app/api/v1/llm.py +++ b/backend/app/api/v1/llm.py @@ -745,8 +745,7 @@ async def get_llm_metrics( "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, - "security_blocked_requests": metrics.security_blocked_requests, - "average_latency_ms": metrics.average_latency_ms, + "average_latency_ms": metrics.average_latency_ms, "average_risk_score": metrics.average_risk_score, "provider_metrics": metrics.provider_metrics, "last_updated": metrics.last_updated.isoformat() diff --git a/backend/app/api/v1/rag.py b/backend/app/api/v1/rag.py index b5d00cf..0e65c2f 100644 --- a/backend/app/api/v1/rag.py +++ b/backend/app/api/v1/rag.py @@ -3,12 +3,14 @@ RAG API Endpoints Provides REST API for RAG (Retrieval Augmented Generation) operations """ -from typing import List, Optional +from typing import List, Optional, Dict, Any from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from pydantic import BaseModel import io +import asyncio +from datetime import datetime from app.db.database import get_db from app.core.security import get_current_user @@ -16,6 +18,9 @@ from app.models.user import User from app.services.rag_service import RAGService from app.utils.exceptions import APIException +# Import RAG module from module manager +from app.services.module_manager import module_manager + router = APIRouter(tags=["RAG"]) @@ -78,14 +83,25 @@ async def get_collections( db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user) ): - """Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata""" + """Get all RAG collections - live data directly from Qdrant (source of truth)""" try: - rag_service = RAGService(db) - collections_data = await rag_service.get_all_collections(skip=skip, limit=limit) + from app.services.qdrant_stats_service import qdrant_stats_service + + # Get live stats from Qdrant + stats_data = await qdrant_stats_service.get_collections_stats() + collections = stats_data.get("collections", []) + + # Apply pagination + start_idx = skip + end_idx = skip + limit + paginated_collections = collections[start_idx:end_idx] + return { "success": True, - "collections": collections_data, - "total": len(collections_data) + "collections": paginated_collections, + "total": len(collections), + "total_documents": stats_data.get("total_documents", 0), + "total_size_bytes": stats_data.get("total_size_bytes", 0) } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -116,6 +132,62 @@ async def create_collection( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/stats", response_model=dict) +async def get_rag_stats( + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get overall RAG statistics - live data directly from Qdrant""" + try: + from app.services.qdrant_stats_service import qdrant_stats_service + + # Get live stats from Qdrant + stats_data = await qdrant_stats_service.get_collections_stats() + + # Calculate active collections (collections with documents) + active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0) + + # Calculate processing documents from database + processing_docs = 0 + try: + from sqlalchemy import select + from app.models.rag_document import RagDocument, ProcessingStatus + + result = await db.execute( + select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING) + ) + processing_docs = len(result.scalars().all()) + except Exception: + pass # If database query fails, default to 0 + + response_data = { + "success": True, + "stats": { + "collections": { + "total": stats_data.get("total_collections", 0), + "active": active_collections + }, + "documents": { + "total": stats_data.get("total_documents", 0), + "processing": processing_docs, + "processed": stats_data.get("total_documents", 0) # Indexed documents + }, + "storage": { + "total_size_bytes": stats_data.get("total_size_bytes", 0), + "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2) + }, + "vectors": { + "total": stats_data.get("total_documents", 0) # Same as documents for RAG + }, + "last_updated": datetime.utcnow().isoformat() + } + } + + return response_data + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/collections/{collection_id}", response_model=dict) async def get_collection( collection_id: int, @@ -225,21 +297,65 @@ async def upload_document( try: # Read file content file_content = await file.read() - + if len(file_content) == 0: raise HTTPException(status_code=400, detail="Empty file uploaded") - + if len(file_content) > 50 * 1024 * 1024: # 50MB limit raise HTTPException(status_code=400, detail="File too large (max 50MB)") - + + # Validate file can be read before processing + filename = file.filename or "unknown" + file_extension = filename.split('.')[-1].lower() if '.' in filename else '' + + try: + # Test file readability based on type + if file_extension == 'jsonl': + # Validate JSONL format - try to parse first few lines + try: + content_str = file_content.decode('utf-8') + lines = content_str.strip().split('\n')[:5] # Check first 5 lines + import json + for i, line in enumerate(lines): + if line.strip(): # Skip empty lines + json.loads(line) # Will raise JSONDecodeError if invalid + except UnicodeDecodeError: + raise HTTPException(status_code=400, detail="File is not valid UTF-8 text") + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}") + + elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']: + # Validate text files can be decoded + try: + file_content.decode('utf-8') + except UnicodeDecodeError: + raise HTTPException(status_code=400, detail="File is not valid UTF-8 text") + + elif file_extension in ['pdf']: + # For PDF files, just check if it starts with PDF signature + if not file_content.startswith(b'%PDF'): + raise HTTPException(status_code=400, detail="Invalid PDF file format") + + elif file_extension in ['docx', 'xlsx', 'pptx']: + # For Office documents, check ZIP signature + if not file_content.startswith(b'PK'): + raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format") + + # For other file types, we'll rely on the document processor + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}") + rag_service = RAGService(db) document = await rag_service.upload_document( collection_id=collection_id, file_content=file_content, - filename=file.filename or "unknown", + filename=filename, content_type=file.content_type ) - + return { "success": True, "document": document.to_dict(), @@ -362,21 +478,167 @@ async def download_document( raise HTTPException(status_code=500, detail=str(e)) -# Stats Endpoint -@router.get("/stats", response_model=dict) -async def get_rag_stats( - db: AsyncSession = Depends(get_db), +# Debug Endpoints + +@router.post("/debug/search") +async def search_with_debug( + query: str, + max_results: int = 10, + score_threshold: float = 0.3, + collection_name: str = None, + config: Dict[str, Any] = None, current_user: User = Depends(get_current_user) -): - """Get RAG system statistics""" +) -> Dict[str, Any]: + """ + Enhanced search with comprehensive debug information + """ + # Get RAG module from module manager + rag_module = module_manager.modules.get('rag') + if not rag_module or not rag_module.enabled: + raise HTTPException(status_code=503, detail="RAG module not initialized") + + debug_info = {} + start_time = datetime.utcnow() + try: - rag_service = RAGService(db) - stats = await rag_service.get_stats() - + # Apply configuration if provided + if config: + # Update RAG config temporarily + original_config = rag_module.config.copy() + rag_module.config.update(config) + + # Generate query embedding (with or without prefix) + if config and config.get("use_query_prefix"): + optimized_query = f"query: {query}" + else: + optimized_query = query + + query_embedding = await rag_module._generate_embedding(optimized_query) + + # Store embedding info for debug + if config and config.get("debug", {}).get("show_embeddings"): + debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions + debug_info["embedding_dimension"] = len(query_embedding) + debug_info["optimized_query"] = optimized_query + + # Perform search + search_start = asyncio.get_event_loop().time() + results = await rag_module.search_documents( + query, + max_results=max_results, + score_threshold=score_threshold, + collection_name=collection_name + ) + search_time = (asyncio.get_event_loop().time() - search_start) * 1000 + + # Calculate score statistics + scores = [r.score for r in results if r.score is not None] + if scores: + import statistics + debug_info["score_stats"] = { + "min": min(scores), + "max": max(scores), + "avg": statistics.mean(scores), + "stddev": statistics.stdev(scores) if len(scores) > 1 else 0 + } + + # Get collection statistics + try: + from qdrant_client.http.models import Filter + collection_name = collection_name or rag_module.default_collection_name + + # Count total documents + count_result = rag_module.qdrant_client.count( + collection_name=collection_name, + count_filter=Filter(must=[]) + ) + total_points = count_result.count + + # Get unique documents and languages + scroll_result = rag_module.qdrant_client.scroll( + collection_name=collection_name, + limit=1000, # Sample for stats + with_payload=True, + with_vectors=False + ) + + unique_docs = set() + languages = set() + + for point in scroll_result[0]: + payload = point.payload or {} + doc_id = payload.get("document_id") + if doc_id: + unique_docs.add(doc_id) + + language = payload.get("language") + if language: + languages.add(language) + + debug_info["collection_stats"] = { + "total_documents": len(unique_docs), + "total_chunks": total_points, + "languages": sorted(list(languages)) + } + + except Exception as e: + debug_info["collection_stats_error"] = str(e) + + # Enhance results with debug info + enhanced_results = [] + for result in results: + enhanced_result = { + "document": { + "id": result.document.id, + "content": result.document.content, + "metadata": result.document.metadata + }, + "score": result.score, + "debug_info": {} + } + + # Add hybrid search debug info if available + metadata = result.document.metadata or {} + if "_vector_score" in metadata: + enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"] + if "_bm25_score" in metadata: + enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"] + + enhanced_results.append(enhanced_result) + + # Note: Analytics logging disabled (module not available) + return { - "success": True, - "stats": stats + "results": enhanced_results, + "debug_info": debug_info, + "search_time_ms": search_time, + "timestamp": start_time.isoformat() } + except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + # Note: Analytics logging disabled (module not available) + raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") + + finally: + # Restore original config if modified + if config and 'original_config' in locals(): + rag_module.config = original_config + + +@router.get("/debug/config") +async def get_current_config( + current_user: User = Depends(get_current_user) +) -> Dict[str, Any]: + """Get current RAG configuration""" + # Get RAG module from module manager + rag_module = module_manager.modules.get('rag') + if not rag_module or not rag_module.enabled: + raise HTTPException(status_code=503, detail="RAG module not initialized") + + return { + "config": rag_module.config, + "embedding_model": rag_module.embedding_model, + "enabled": rag_module.enabled, + "collections": await rag_module._get_collections_safely() + } diff --git a/backend/app/api/v1/security.py b/backend/app/api/v1/security.py deleted file mode 100644 index 838dd6f..0000000 --- a/backend/app/api/v1/security.py +++ /dev/null @@ -1,251 +0,0 @@ -""" -Security API endpoints for monitoring and configuration -""" - -from typing import Dict, Any, List, Optional -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel, Field - -from app.core.security import get_current_active_user, RequiresRole -from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score -from app.core.config import settings -from app.core.logging import get_logger - -logger = get_logger(__name__) - -router = APIRouter(tags=["security"]) - - -# Pydantic models for API responses -class SecurityStatsResponse(BaseModel): - """Security statistics response model""" - total_requests_analyzed: int - threats_detected: int - threats_blocked: int - anomalies_detected: int - rate_limits_exceeded: int - avg_analysis_time: float - threat_types: Dict[str, int] - threat_levels: Dict[str, int] - top_attacking_ips: List[tuple] - security_enabled: bool - threat_detection_enabled: bool - rate_limiting_enabled: bool - - -class SecurityConfigResponse(BaseModel): - """Security configuration response model""" - security_enabled: bool = Field(description="Overall security system enabled") - threat_detection_enabled: bool = Field(description="Threat detection analysis enabled") - rate_limiting_enabled: bool = Field(description="Rate limiting enabled") - ip_reputation_enabled: bool = Field(description="IP reputation checking enabled") - anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled") - security_headers_enabled: bool = Field(description="Security headers enabled") - - # Rate limiting settings - unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute") - authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute") - api_key_per_minute: int = Field(description="Rate limit for API key users per minute") - premium_per_minute: int = Field(description="Rate limit for premium users per minute") - - # Security thresholds - risk_threshold: float = Field(description="Risk score threshold for blocking requests") - warning_threshold: float = Field(description="Risk score threshold for warnings") - anomaly_threshold: float = Field(description="Anomaly severity threshold") - - # IP settings - blocked_ips: List[str] = Field(description="List of blocked IP addresses") - allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)") - - -class RateLimitInfoResponse(BaseModel): - """Rate limit information for current request""" - auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)") - current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level") - remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)") - - -@router.get("/stats", response_model=SecurityStatsResponse) -async def get_security_statistics( - current_user: Dict[str, Any] = Depends(RequiresRole("admin")) -): - """ - Get security system statistics - - Requires admin role. Returns comprehensive statistics about: - - Request analysis counts - - Threat detection results - - Rate limiting enforcement - - Top attacking IPs - - Performance metrics - """ - try: - stats = get_security_stats() - return SecurityStatsResponse(**stats) - except Exception as e: - logger.error(f"Error getting security stats: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve security statistics" - ) - - -@router.get("/config", response_model=SecurityConfigResponse) -async def get_security_config( - current_user: Dict[str, Any] = Depends(RequiresRole("admin")) -): - """ - Get current security configuration - - Requires admin role. Returns current security settings including: - - Feature enablement flags - - Rate limiting thresholds - - Security thresholds - - IP allowlists/blocklists - """ - return SecurityConfigResponse( - security_enabled=settings.API_SECURITY_ENABLED, - threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED, - rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED, - ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED, - anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED, - security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED, - - unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE, - authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, - api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, - premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, - - risk_threshold=settings.API_SECURITY_RISK_THRESHOLD, - warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD, - anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD, - - blocked_ips=settings.API_BLOCKED_IPS, - allowed_ips=settings.API_ALLOWED_IPS - ) - - -@router.get("/status") -async def get_security_status( - request: Request, - current_user: Dict[str, Any] = Depends(get_current_active_user) -): - """ - Get security status for current request - - Returns information about the security analysis of the current request: - - Authentication level - - Risk score (if available) - - Rate limiting status - """ - auth_level = get_request_auth_level(request) - risk_score = get_request_risk_score(request) - - # Get rate limits for current auth level - from app.core.threat_detection import AuthLevel - try: - auth_enum = AuthLevel(auth_level) - from app.core.threat_detection import threat_detection_service - minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum) - - rate_limit_info = RateLimitInfoResponse( - auth_level=auth_level, - current_limits={ - "per_minute": minute_limit, - "per_hour": hour_limit - }, - remaining_requests=None # We don't track remaining requests in current implementation - ) - except ValueError: - rate_limit_info = RateLimitInfoResponse( - auth_level=auth_level, - current_limits={}, - remaining_requests=None - ) - - return { - "security_enabled": settings.API_SECURITY_ENABLED, - "auth_level": auth_level, - "risk_score": round(risk_score, 3) if risk_score > 0 else None, - "rate_limit_info": rate_limit_info.dict(), - "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED - } - - -@router.post("/test") -async def test_security_analysis( - request: Request, - current_user: Dict[str, Any] = Depends(RequiresRole("admin")) -): - """ - Test security analysis on current request - - Requires admin role. Manually triggers security analysis on the current request - and returns detailed results. Useful for testing security rules and thresholds. - """ - try: - from app.middleware.security import analyze_request_security - - analysis = await analyze_request_security(request, current_user) - - return { - "analysis_complete": True, - "is_threat": analysis.is_threat, - "risk_score": round(analysis.risk_score, 3), - "auth_level": analysis.auth_level.value, - "should_block": analysis.should_block, - "rate_limit_exceeded": analysis.rate_limit_exceeded, - "threat_count": len(analysis.threats), - "threats": [ - { - "type": threat.threat_type, - "level": threat.level.value, - "confidence": round(threat.confidence, 3), - "description": threat.description, - "mitigation": threat.mitigation - } - for threat in analysis.threats - ], - "recommendations": analysis.recommendations - } - except Exception as e: - logger.error(f"Error in security analysis test: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to perform security analysis test" - ) - - -@router.get("/health") -async def security_health_check(): - """ - Security system health check - - Public endpoint that returns the health status of the security system. - Does not require authentication. - """ - try: - stats = get_security_stats() - - # Basic health checks - is_healthy = ( - settings.API_SECURITY_ENABLED and - stats.get("total_requests_analyzed", 0) >= 0 and - stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second - ) - - return { - "status": "healthy" if is_healthy else "degraded", - "security_enabled": settings.API_SECURITY_ENABLED, - "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED, - "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED, - "avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2), - "total_requests_analyzed": stats.get("total_requests_analyzed", 0) - } - except Exception as e: - logger.error(f"Security health check failed: {e}") - return { - "status": "unhealthy", - "error": "Security system error", - "security_enabled": settings.API_SECURITY_ENABLED - } \ No newline at end of file diff --git a/backend/app/api/v1/settings.py b/backend/app/api/v1/settings.py index 8595ad6..4b97e25 100644 --- a/backend/app/api/v1/settings.py +++ b/backend/app/api/v1/settings.py @@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = { "api": { # Security Settings "security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"}, - "threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"}, "rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"}, "ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"}, "anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"}, @@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = { "rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"}, # Security Thresholds - "security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"}, "security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"}, "anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"}, @@ -601,7 +599,6 @@ async def reset_to_defaults( "api": { # Security Settings "security_enabled": {"value": True, "type": "boolean"}, - "threat_detection_enabled": {"value": True, "type": "boolean"}, "rate_limiting_enabled": {"value": True, "type": "boolean"}, "ip_reputation_enabled": {"value": True, "type": "boolean"}, "anomaly_detection_enabled": {"value": True, "type": "boolean"}, @@ -616,7 +613,6 @@ async def reset_to_defaults( "rate_limit_premium_per_hour": {"value": 100000, "type": "integer"}, # Security Thresholds - "security_risk_threshold": {"value": 0.8, "type": "float"}, "security_warning_threshold": {"value": 0.6, "type": "float"}, "anomaly_threshold": {"value": 0.7, "type": "float"}, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index f3ac614..7d53387 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -17,6 +17,8 @@ class Settings(BaseSettings): APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO") APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0") APP_PORT: int = int(os.getenv("APP_PORT", "8000")) + BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000")) + FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000")) # Detailed logging for LLM interactions LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM @@ -73,16 +75,11 @@ class Settings(BaseSettings): QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost") QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333")) QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY") + QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333") - # API & Security Settings - API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true" - API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true" - API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true" - API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true" - + # Rate Limiting Configuration - API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true" - + # PrivateMode Standard tier limits (organization-level, not per user) # These are shared across all API keys and users in the organization PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20")) @@ -101,23 +98,14 @@ class Settings(BaseSettings): # Premium/Enterprise API keys API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")) - - # Security Thresholds - API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score - API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6")) # Log warnings above this threshold - API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7")) # Flag anomalies above this threshold - + # Request Size Limits API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium # IP Security - API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else [] - API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else [] - API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600")) # 1 hour # Security Headers - API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true" API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'") # Monitoring @@ -129,6 +117,19 @@ class Settings(BaseSettings): # Module configuration MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml") + + # RAG Embedding Configuration + RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12")) + RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3")) + RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3")) + RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16") + RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0")) + RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")) + RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true" + RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true" + RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")) + RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")) + RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120")) # Plugin configuration PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins") @@ -142,9 +143,12 @@ class Settings(BaseSettings): model_config = { "env_file": ".env", - "case_sensitive": True + "case_sensitive": True, + # Ignore unknown environment variables to avoid validation errors + # when optional/deprecated flags are present in .env + "extra": "ignore", } # Global settings instance -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/backend/app/core/threat_detection.py b/backend/app/core/threat_detection.py deleted file mode 100644 index cac2c7b..0000000 --- a/backend/app/core/threat_detection.py +++ /dev/null @@ -1,744 +0,0 @@ -""" -Core threat detection and security analysis for the platform -""" - -import re -import time -from collections import defaultdict, deque -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from enum import Enum -from typing import Dict, List, Optional, Set, Tuple, Any, Union -from urllib.parse import unquote - -from fastapi import Request -from app.core.config import settings -from app.core.logging import get_logger - -logger = get_logger(__name__) - - -class ThreatLevel(Enum): - """Threat severity levels""" - LOW = "low" - MEDIUM = "medium" - HIGH = "high" - CRITICAL = "critical" - - -class AuthLevel(Enum): - """Authentication levels for rate limiting""" - AUTHENTICATED = "authenticated" - API_KEY = "api_key" - PREMIUM = "premium" - - -@dataclass -class SecurityThreat: - """Security threat detection result""" - threat_type: str - level: ThreatLevel - confidence: float - description: str - source_ip: str - user_agent: Optional[str] = None - request_path: Optional[str] = None - payload: Optional[str] = None - timestamp: datetime = field(default_factory=datetime.utcnow) - mitigation: Optional[str] = None - - -@dataclass -class SecurityAnalysis: - """Comprehensive security analysis result""" - is_threat: bool - threats: List[SecurityThreat] - risk_score: float - recommendations: List[str] - auth_level: AuthLevel - rate_limit_exceeded: bool - should_block: bool - timestamp: datetime = field(default_factory=datetime.utcnow) - - -@dataclass -class RateLimitInfo: - """Rate limiting information""" - auth_level: AuthLevel - requests_per_minute: int - requests_per_hour: int - minute_limit: int - hour_limit: int - exceeded: bool - - -@dataclass -class AnomalyDetection: - """Anomaly detection result""" - is_anomaly: bool - anomaly_type: str - severity: float - details: Dict[str, Any] - baseline_value: Optional[float] = None - current_value: Optional[float] = None - - -class ThreatDetectionService: - """Core threat detection and security analysis service""" - - def __init__(self): - self.name = "threat_detection" - - # Statistics - self.stats = { - 'total_requests_analyzed': 0, - 'threats_detected': 0, - 'threats_blocked': 0, - 'anomalies_detected': 0, - 'rate_limits_exceeded': 0, - 'total_analysis_time': 0, - 'threat_types': defaultdict(int), - 'threat_levels': defaultdict(int), - 'attacking_ips': defaultdict(int) - } - - # Threat detection patterns - self.sql_injection_patterns = [ - r"(\bunion\b.*\bselect\b)", - r"(\bselect\b.*\bfrom\b)", - r"(\binsert\b.*\binto\b)", - r"(\bupdate\b.*\bset\b)", - r"(\bdelete\b.*\bfrom\b)", - r"(\bdrop\b.*\btable\b)", - r"(\bor\b.*\b1\s*=\s*1\b)", - r"(\band\b.*\b1\s*=\s*1\b)", - r"(\bexec\b.*\bxp_\w+)", - r"(\bsp_\w+)", - r"(\bsleep\b\s*\(\s*\d+\s*\))", - r"(\bwaitfor\b.*\bdelay\b)", - r"(\bbenchmark\b\s*\(\s*\d+)", - r"(\bload_file\b\s*\()", - r"(\binto\b.*\boutfile\b)" - ] - - self.xss_patterns = [ - r"]*>.*?", - r"]*>.*?", - r"]*>.*?", - r"]*>.*?", - r"]*>", - r"]*>", - r"javascript:", - r"vbscript:", - r"on\w+\s*=", - r"style\s*=.*expression", - r"style\s*=.*javascript" - ] - - self.path_traversal_patterns = [ - r"\.\.\/", - r"\.\.\\", - r"%2e%2e%2f", - r"%2e%2e%5c", - r"..%2f", - r"..%5c", - r"%252e%252e%252f", - r"%252e%252e%255c" - ] - - self.command_injection_patterns = [ - r";\s*cat\s+", - r";\s*ls\s+", - r";\s*pwd\s*", - r";\s*whoami\s*", - r";\s*id\s*", - r";\s*uname\s*", - r";\s*ps\s+", - r";\s*netstat\s+", - r";\s*wget\s+", - r";\s*curl\s+", - r"\|\s*cat\s+", - r"\|\s*ls\s+", - r"&&\s*cat\s+", - r"&&\s*ls\s+" - ] - - self.suspicious_ua_patterns = [ - r"sqlmap", - r"nikto", - r"nmap", - r"masscan", - r"zap", - r"burp", - r"w3af", - r"acunetix", - r"nessus", - r"openvas", - r"metasploit" - ] - - # Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked) - self.rate_limits = { - AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}), - AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}), - AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}) - } - - # Anomaly detection - self.request_history = deque(maxlen=1000) - self.ip_history = defaultdict(lambda: deque(maxlen=100)) - self.endpoint_history = defaultdict(lambda: deque(maxlen=100)) - - # Blocked and allowed IPs - self.blocked_ips = set(settings.API_BLOCKED_IPS) - self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None - - # IP reputation cache - self.ip_reputation_cache = {} - self.cache_expiry = {} - - # Compile patterns for performance - self._compile_patterns() - - logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, " - f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}") - - def _compile_patterns(self): - """Compile regex patterns for better performance""" - try: - self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns] - self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns] - self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns] - self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns] - self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns] - except re.error as e: - logger.error(f"Failed to compile security patterns: {e}") - # Fallback to empty lists to prevent crashes - self.compiled_sql_patterns = [] - self.compiled_xss_patterns = [] - self.compiled_path_patterns = [] - self.compiled_cmd_patterns = [] - self.compiled_ua_patterns = [] - - def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel: - """Determine authentication level for rate limiting""" - # Check if request has API key authentication - if hasattr(request.state, 'api_key_context') and request.state.api_key_context: - api_key = request.state.api_key_context.get('api_key') - if api_key and hasattr(api_key, 'tier'): - # Check for premium tier - if api_key.tier in ['premium', 'enterprise']: - return AuthLevel.PREMIUM - return AuthLevel.API_KEY - - # Check for JWT authentication - if user_context or hasattr(request.state, 'user'): - return AuthLevel.AUTHENTICATED - - # Check Authorization header for API key - auth_header = request.headers.get("Authorization", "") - api_key_header = request.headers.get("X-API-Key", "") - if auth_header.startswith("Bearer ") or api_key_header: - return AuthLevel.API_KEY - - # Default to authenticated since unauthenticated requests are blocked at middleware - return AuthLevel.AUTHENTICATED - - def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]: - """Get rate limits for authentication level""" - if not settings.API_RATE_LIMITING_ENABLED: - return float('inf'), float('inf') - - if auth_level == AuthLevel.AUTHENTICATED: - return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR) - elif auth_level == AuthLevel.API_KEY: - return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR) - elif auth_level == AuthLevel.PREMIUM: - return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR) - else: - # Fallback to authenticated limits - return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR) - - def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo: - """Check if request exceeds rate limits""" - minute_limit, hour_limit = self.get_rate_limits(auth_level) - current_time = time.time() - - # Get or create tracking for this auth level - if auth_level not in self.rate_limits: - # This shouldn't happen, but handle gracefully - return RateLimitInfo( - auth_level=auth_level, - requests_per_minute=0, - requests_per_hour=0, - minute_limit=minute_limit, - hour_limit=hour_limit, - exceeded=False - ) - - ip_limits = self.rate_limits[auth_level][client_ip] - - # Clean old entries - minute_ago = current_time - 60 - hour_ago = current_time - 3600 - - while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago: - ip_limits['minute'].popleft() - - while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago: - ip_limits['hour'].popleft() - - # Check current counts - requests_per_minute = len(ip_limits['minute']) - requests_per_hour = len(ip_limits['hour']) - - # Check if limits exceeded - exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit) - - # Add current request to tracking - if not exceeded: - ip_limits['minute'].append(current_time) - ip_limits['hour'].append(current_time) - - return RateLimitInfo( - auth_level=auth_level, - requests_per_minute=requests_per_minute, - requests_per_hour=requests_per_hour, - minute_limit=minute_limit, - hour_limit=hour_limit, - exceeded=exceeded - ) - - async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis: - """Perform comprehensive security analysis on a request""" - start_time = time.time() - - try: - client_ip = request.client.host if request.client else "unknown" - user_agent = request.headers.get("user-agent", "") - path = str(request.url.path) - method = request.method - - # Determine authentication level - auth_level = self.determine_auth_level(request, user_context) - - # Check IP allowlist/blocklist first - if self.allowed_ips and client_ip not in self.allowed_ips: - threat = SecurityThreat( - threat_type="ip_not_allowed", - level=ThreatLevel.HIGH, - confidence=1.0, - description=f"IP {client_ip} not in allowlist", - source_ip=client_ip, - mitigation="Add IP to allowlist or remove IP restrictions" - ) - return SecurityAnalysis( - is_threat=True, - threats=[threat], - risk_score=1.0, - recommendations=["Block request immediately"], - auth_level=auth_level, - rate_limit_exceeded=False, - should_block=True - ) - - if client_ip in self.blocked_ips: - threat = SecurityThreat( - threat_type="ip_blocked", - level=ThreatLevel.CRITICAL, - confidence=1.0, - description=f"IP {client_ip} is blocked", - source_ip=client_ip, - mitigation="Remove IP from blocklist if legitimate" - ) - return SecurityAnalysis( - is_threat=True, - threats=[threat], - risk_score=1.0, - recommendations=["Block request immediately"], - auth_level=auth_level, - rate_limit_exceeded=False, - should_block=True - ) - - # Check rate limiting - rate_limit_info = self.check_rate_limit(client_ip, auth_level) - if rate_limit_info.exceeded: - self.stats['rate_limits_exceeded'] += 1 - threat = SecurityThreat( - threat_type="rate_limit_exceeded", - level=ThreatLevel.MEDIUM, - confidence=0.9, - description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr", - source_ip=client_ip, - mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr" - ) - return SecurityAnalysis( - is_threat=True, - threats=[threat], - risk_score=0.7, - recommendations=[f"Rate limit exceeded for {auth_level.value} user"], - auth_level=auth_level, - rate_limit_exceeded=True, - should_block=True - ) - - # Skip threat detection if disabled - if not settings.API_THREAT_DETECTION_ENABLED: - return SecurityAnalysis( - is_threat=False, - threats=[], - risk_score=0.0, - recommendations=[], - auth_level=auth_level, - rate_limit_exceeded=False, - should_block=False - ) - - # Collect request data for threat analysis - query_params = str(request.query_params) - headers = dict(request.headers) - - # Try to get body content safely - body_content = "" - try: - if hasattr(request, '_body') and request._body: - body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body) - except: - pass - - threats = [] - - # Analyze for various threats - threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip)) - threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip)) - threats.extend(await self._detect_path_traversal(path, query_params, client_ip)) - threats.extend(await self._detect_command_injection(query_params, body_content, client_ip)) - threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip)) - - # Anomaly detection if enabled - if settings.API_ANOMALY_DETECTION_ENABLED: - anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content)) - if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD: - threat = SecurityThreat( - threat_type=f"anomaly_{anomaly.anomaly_type}", - level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW, - confidence=anomaly.severity, - description=f"Anomalous behavior detected: {anomaly.details}", - source_ip=client_ip, - user_agent=user_agent, - request_path=path - ) - threats.append(threat) - - # Calculate risk score - risk_score = self._calculate_risk_score(threats) - - # Determine if request should be blocked - should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD - - # Generate recommendations - recommendations = self._generate_recommendations(threats, risk_score, auth_level) - - # Update statistics - self._update_stats(threats, time.time() - start_time) - - return SecurityAnalysis( - is_threat=len(threats) > 0, - threats=threats, - risk_score=risk_score, - recommendations=recommendations, - auth_level=auth_level, - rate_limit_exceeded=False, - should_block=should_block - ) - - except Exception as e: - logger.error(f"Error in threat analysis: {e}") - return SecurityAnalysis( - is_threat=False, - threats=[], - risk_score=0.0, - recommendations=["Error occurred during security analysis"], - auth_level=AuthLevel.AUTHENTICATED, - rate_limit_exceeded=False, - should_block=False - ) - - async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]: - """Detect SQL injection attempts""" - threats = [] - content_to_check = f"{query_params} {body_content} {path}".lower() - - for pattern in self.compiled_sql_patterns: - if pattern.search(content_to_check): - threat = SecurityThreat( - threat_type="sql_injection", - level=ThreatLevel.HIGH, - confidence=0.85, - description="Potential SQL injection attempt detected", - source_ip=client_ip, - payload=pattern.pattern, - mitigation="Block request, sanitize input, use parameterized queries" - ) - threats.append(threat) - break # Don't duplicate for multiple patterns - - return threats - - async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]: - """Detect XSS attempts""" - threats = [] - content_to_check = f"{query_params} {body_content}".lower() - - # Check headers for XSS - for header_name, header_value in headers.items(): - content_to_check += f" {header_value}".lower() - - for pattern in self.compiled_xss_patterns: - if pattern.search(content_to_check): - threat = SecurityThreat( - threat_type="xss", - level=ThreatLevel.HIGH, - confidence=0.80, - description="Potential XSS attack detected", - source_ip=client_ip, - payload=pattern.pattern, - mitigation="Block request, sanitize input, implement CSP headers" - ) - threats.append(threat) - break - - return threats - - async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]: - """Detect path traversal attempts""" - threats = [] - content_to_check = f"{path} {query_params}".lower() - decoded_content = unquote(content_to_check) - - for pattern in self.compiled_path_patterns: - if pattern.search(content_to_check) or pattern.search(decoded_content): - threat = SecurityThreat( - threat_type="path_traversal", - level=ThreatLevel.HIGH, - confidence=0.90, - description="Path traversal attempt detected", - source_ip=client_ip, - request_path=path, - mitigation="Block request, validate file paths, implement access controls" - ) - threats.append(threat) - break - - return threats - - async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]: - """Detect command injection attempts""" - threats = [] - content_to_check = f"{query_params} {body_content}".lower() - - for pattern in self.compiled_cmd_patterns: - if pattern.search(content_to_check): - threat = SecurityThreat( - threat_type="command_injection", - level=ThreatLevel.CRITICAL, - confidence=0.95, - description="Command injection attempt detected", - source_ip=client_ip, - payload=pattern.pattern, - mitigation="Block request immediately, sanitize input, disable shell execution" - ) - threats.append(threat) - break - - return threats - - async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]: - """Detect suspicious patterns in headers and user agent""" - threats = [] - - # Check for suspicious user agents - ua_lower = user_agent.lower() - for pattern in self.compiled_ua_patterns: - if pattern.search(ua_lower): - threat = SecurityThreat( - threat_type="suspicious_user_agent", - level=ThreatLevel.HIGH, - confidence=0.85, - description=f"Suspicious user agent detected: {pattern.pattern}", - source_ip=client_ip, - user_agent=user_agent, - mitigation="Block request, monitor IP for further activity" - ) - threats.append(threat) - break - - # Check for suspicious headers - if "x-forwarded-for" in headers and "x-real-ip" in headers: - # Potential header manipulation - threat = SecurityThreat( - threat_type="header_manipulation", - level=ThreatLevel.LOW, - confidence=0.30, - description="Potential IP header manipulation detected", - source_ip=client_ip, - mitigation="Validate proxy headers, implement IP whitelisting" - ) - threats.append(threat) - - return threats - - async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection: - """Detect anomalous behavior patterns""" - try: - # Request size anomaly - max_size = settings.API_MAX_REQUEST_BODY_SIZE - if body_size > max_size: - return AnomalyDetection( - is_anomaly=True, - anomaly_type="request_size", - severity=0.8, - details={"body_size": body_size, "threshold": max_size}, - current_value=body_size, - baseline_value=max_size // 10 - ) - - # Unusual endpoint access - if path.startswith("/admin") or path.startswith("/api/admin"): - return AnomalyDetection( - is_anomaly=True, - anomaly_type="sensitive_endpoint", - severity=0.6, - details={"path": path, "reason": "admin endpoint access"}, - current_value=1.0, - baseline_value=0.0 - ) - - # IP request frequency anomaly - current_time = time.time() - ip_requests = self.ip_history[client_ip] - - # Clean old entries (last 5 minutes) - five_minutes_ago = current_time - 300 - while ip_requests and ip_requests[0] < five_minutes_ago: - ip_requests.popleft() - - ip_requests.append(current_time) - - if len(ip_requests) > 100: # More than 100 requests in 5 minutes - return AnomalyDetection( - is_anomaly=True, - anomaly_type="request_frequency", - severity=0.7, - details={"requests_5min": len(ip_requests), "threshold": 100}, - current_value=len(ip_requests), - baseline_value=10 # 10 requests baseline - ) - - return AnomalyDetection( - is_anomaly=False, - anomaly_type="none", - severity=0.0, - details={} - ) - - except Exception as e: - logger.error(f"Error in anomaly detection: {e}") - return AnomalyDetection( - is_anomaly=False, - anomaly_type="error", - severity=0.0, - details={"error": str(e)} - ) - - def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float: - """Calculate overall risk score based on threats""" - if not threats: - return 0.0 - - score = 0.0 - for threat in threats: - level_multiplier = { - ThreatLevel.LOW: 0.25, - ThreatLevel.MEDIUM: 0.5, - ThreatLevel.HIGH: 0.75, - ThreatLevel.CRITICAL: 1.0 - } - score += threat.confidence * level_multiplier.get(threat.level, 0.5) - - # Normalize to 0-1 range - return min(score / len(threats), 1.0) - - def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]: - """Generate security recommendations based on analysis""" - recommendations = [] - - if risk_score >= settings.API_SECURITY_RISK_THRESHOLD: - recommendations.append("CRITICAL: Block this request immediately") - elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD: - recommendations.append("HIGH: Consider blocking or rate limiting this IP") - elif risk_score > 0.4: - recommendations.append("MEDIUM: Monitor this IP closely") - - threat_types = {threat.threat_type for threat in threats} - - if "sql_injection" in threat_types: - recommendations.append("Implement parameterized queries and input validation") - - if "xss" in threat_types: - recommendations.append("Implement Content Security Policy (CSP) headers") - - if "command_injection" in threat_types: - recommendations.append("Disable shell execution and validate all inputs") - - if "path_traversal" in threat_types: - recommendations.append("Implement proper file path validation and access controls") - - if "rate_limit_exceeded" in threat_types: - recommendations.append(f"Rate limiting active for {auth_level.value} user") - - if not recommendations: - recommendations.append("No immediate action required, continue monitoring") - - return recommendations - - def _update_stats(self, threats: List[SecurityThreat], analysis_time: float): - """Update service statistics""" - self.stats['total_requests_analyzed'] += 1 - self.stats['total_analysis_time'] += analysis_time - - if threats: - self.stats['threats_detected'] += len(threats) - for threat in threats: - self.stats['threat_types'][threat.threat_type] += 1 - self.stats['threat_levels'][threat.level.value] += 1 - if threat.source_ip: - self.stats['attacking_ips'][threat.source_ip] += 1 - - def get_stats(self) -> Dict[str, Any]: - """Get service statistics""" - avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed'] - if self.stats['total_requests_analyzed'] > 0 else 0) - - # Get top attacking IPs - top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10] - - return { - "total_requests_analyzed": self.stats['total_requests_analyzed'], - "threats_detected": self.stats['threats_detected'], - "threats_blocked": self.stats['threats_blocked'], - "anomalies_detected": self.stats['anomalies_detected'], - "rate_limits_exceeded": self.stats['rate_limits_exceeded'], - "avg_analysis_time": avg_time, - "threat_types": dict(self.stats['threat_types']), - "threat_levels": dict(self.stats['threat_levels']), - "top_attacking_ips": top_ips, - "security_enabled": settings.API_SECURITY_ENABLED, - "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED, - "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED - } - - -# Global threat detection service instance -threat_detection_service = ThreatDetectionService() \ No newline at end of file diff --git a/backend/app/main.py b/backend/app/main.py index 8bea827..8c8b26d 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -52,10 +52,18 @@ async def lifespan(app: FastAPI): # Initialize config manager await init_config_manager() - + + # Initialize LLM service (needed by RAG module) + from app.services.llm.service import llm_service + try: + await llm_service.initialize() + logger.info("LLM service initialized successfully") + except Exception as e: + logger.warning(f"LLM service initialization failed: {e}") + # Initialize analytics service init_analytics_service() - + # Initialize module manager with FastAPI app for router registration await module_manager.initialize(app) app.state.module_manager = module_manager diff --git a/backend/app/middleware/rate_limiting.py b/backend/app/middleware/rate_limiting.py deleted file mode 100644 index f6e1901..0000000 --- a/backend/app/middleware/rate_limiting.py +++ /dev/null @@ -1,371 +0,0 @@ -""" -Rate limiting middleware -""" - -import time -import redis -from typing import Dict, Optional -from fastapi import Request, HTTPException, status -from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware -import asyncio -from datetime import datetime, timedelta - -from app.core.config import settings -from app.core.logging import get_logger - -logger = get_logger(__name__) - - -class RateLimiter: - """Rate limiting implementation using Redis""" - - def __init__(self): - try: - self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) - self.redis_client.ping() # Test connection - logger.info("Rate limiter initialized with Redis backend") - except Exception as e: - logger.warning(f"Redis not available for rate limiting: {e}") - self.redis_client = None - # Fall back to in-memory rate limiting - self.memory_store: Dict[str, Dict[str, float]] = {} - - async def check_rate_limit( - self, - key: str, - limit: int, - window_seconds: int, - identifier: str = "default" - ) -> tuple[bool, Dict[str, int]]: - """ - Check if request is within rate limit - - Args: - key: Rate limiting key (e.g., IP address, API key) - limit: Maximum number of requests allowed - window_seconds: Time window in seconds - identifier: Additional identifier for the rate limit - - Returns: - Tuple of (is_allowed, headers_dict) - """ - - full_key = f"rate_limit:{identifier}:{key}" - current_time = int(time.time()) - window_start = current_time - window_seconds - - if self.redis_client: - return await self._check_redis_rate_limit( - full_key, limit, window_seconds, current_time, window_start - ) - else: - return self._check_memory_rate_limit( - full_key, limit, window_seconds, current_time, window_start - ) - - async def _check_redis_rate_limit( - self, - key: str, - limit: int, - window_seconds: int, - current_time: int, - window_start: int - ) -> tuple[bool, Dict[str, int]]: - """Check rate limit using Redis""" - - pipe = self.redis_client.pipeline() - - # Remove old entries - pipe.zremrangebyscore(key, 0, window_start) - - # Count current requests in window - pipe.zcard(key) - - # Add current request - pipe.zadd(key, {str(current_time): current_time}) - - # Set expiration - pipe.expire(key, window_seconds + 1) - - results = pipe.execute() - current_requests = results[1] - - # Calculate remaining requests and reset time - remaining = max(0, limit - current_requests - 1) - reset_time = current_time + window_seconds - - headers = { - "X-RateLimit-Limit": limit, - "X-RateLimit-Remaining": remaining, - "X-RateLimit-Reset": reset_time, - "X-RateLimit-Window": window_seconds - } - - is_allowed = current_requests < limit - - if not is_allowed: - logger.warning(f"Rate limit exceeded for key: {key}") - - return is_allowed, headers - - def _check_memory_rate_limit( - self, - key: str, - limit: int, - window_seconds: int, - current_time: int, - window_start: int - ) -> tuple[bool, Dict[str, int]]: - """Check rate limit using in-memory storage""" - - if key not in self.memory_store: - self.memory_store[key] = {} - - # Clean old entries - store = self.memory_store[key] - keys_to_remove = [k for k, v in store.items() if v < window_start] - for k in keys_to_remove: - del store[k] - - current_requests = len(store) - - # Calculate remaining requests and reset time - remaining = max(0, limit - current_requests - 1) - reset_time = current_time + window_seconds - - headers = { - "X-RateLimit-Limit": limit, - "X-RateLimit-Remaining": remaining, - "X-RateLimit-Reset": reset_time, - "X-RateLimit-Window": window_seconds - } - - is_allowed = current_requests < limit - - if is_allowed: - # Add current request - store[str(current_time)] = current_time - else: - logger.warning(f"Rate limit exceeded for key: {key}") - - return is_allowed, headers - - -# Global rate limiter instance -rate_limiter = RateLimiter() - - -class RateLimitMiddleware(BaseHTTPMiddleware): - """Rate limiting middleware for FastAPI""" - - def __init__(self, app): - super().__init__(app) - self.rate_limiter = RateLimiter() - logger.info("RateLimitMiddleware initialized") - - async def dispatch(self, request: Request, call_next): - """Process request through rate limiting""" - - # Skip rate limiting if disabled in settings - if not settings.API_RATE_LIMITING_ENABLED: - response = await call_next(request) - return response - - # Skip rate limiting for all internal API endpoints (platform operations) - if request.url.path.startswith("/api-internal/v1/"): - response = await call_next(request) - return response - - # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service) - # Skip for all other endpoints - if not (request.url.path.startswith("/api/v1/chat/completions") or - request.url.path.startswith("/api/v1/embeddings") or - request.url.path.startswith("/api/v1/models") or - request.url.path.startswith("/api/v1/llm/")): - response = await call_next(request) - return response - - # Skip rate limiting for health checks and static files - if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]: - response = await call_next(request) - return response - - # Get client IP - client_ip = request.client.host - forwarded_for = request.headers.get("X-Forwarded-For") - if forwarded_for: - client_ip = forwarded_for.split(",")[0].strip() - - # Check for API key in headers - api_key = None - auth_header = request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - api_key = auth_header[7:] - elif request.headers.get("X-API-Key"): - api_key = request.headers.get("X-API-Key") - - # Determine rate limiting strategy - headers = {} - is_allowed = True - - if api_key: - # API key-based rate limiting - api_key_key = f"api_key:{api_key}" - - # First check organization-wide limits (PrivateMode limits are org-wide) - org_key = "organization:privatemode" - - # Check organization per-minute limit - org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit( - org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute" - ) - - # Check organization per-hour limit - org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit( - org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour" - ) - - # If organization limits are exceeded, return 429 - if not (org_allowed_minute and org_allowed_hour): - logger.warning(f"Organization rate limit exceeded for {org_key}") - return JSONResponse( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - content={"detail": "Organization rate limit exceeded"}, - headers=org_headers_minute - ) - - # Then check per-API key limits - limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE - limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR - - # Check per-minute limit - is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit( - api_key_key, limit_per_minute, 60, "minute" - ) - - # Check per-hour limit - is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit( - api_key_key, limit_per_hour, 3600, "hour" - ) - - is_allowed = is_allowed_minute and is_allowed_hour - headers = headers_minute # Use minute headers for response - - else: - # IP-based rate limiting for unauthenticated requests - rate_limit_key = f"ip:{client_ip}" - - # More restrictive limits for unauthenticated requests - limit_per_minute = 20 # Hardcoded for unauthenticated users - limit_per_hour = 100 - - # Check per-minute limit - is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit( - rate_limit_key, limit_per_minute, 60, "minute" - ) - - # Check per-hour limit - is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit( - rate_limit_key, limit_per_hour, 3600, "hour" - ) - - is_allowed = is_allowed_minute and is_allowed_hour - headers = headers_minute # Use minute headers for response - - # If rate limit exceeded, return 429 - if not is_allowed: - return JSONResponse( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - content={ - "error": "RATE_LIMIT_EXCEEDED", - "message": "Rate limit exceeded. Please try again later.", - "details": { - "limit": headers["X-RateLimit-Limit"], - "reset_time": headers["X-RateLimit-Reset"] - } - }, - headers={k: str(v) for k, v in headers.items()} - ) - - # Continue with request - response = await call_next(request) - - # Add rate limit headers to response - for key, value in headers.items(): - response.headers[key] = str(value) - - return response - - -# Keep the old function for backward compatibility -async def rate_limit_middleware(request: Request, call_next): - """Legacy function - use RateLimitMiddleware class instead""" - middleware = RateLimitMiddleware(None) - return await middleware.dispatch(request, call_next) - - -class RateLimitExceeded(HTTPException): - """Exception raised when rate limit is exceeded""" - - def __init__(self, limit: int, reset_time: int): - super().__init__( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}" - ) - - -# Decorator for applying rate limits to specific endpoints -def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000): - """ - Decorator to apply rate limiting to specific endpoints - - Args: - requests_per_minute: Maximum requests per minute - requests_per_hour: Maximum requests per hour - """ - def decorator(func): - async def wrapper(*args, **kwargs): - # This would be implemented to work with FastAPI dependencies - # For now, this is a placeholder for endpoint-specific rate limiting - return await func(*args, **kwargs) - return wrapper - return decorator - - -# Helper functions for different rate limiting strategies -async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool: - """Check rate limit for specific API key and endpoint""" - - # This would lookup API key specific limits from database - # For now, using default limits - key = f"api_key:{api_key}:endpoint:{endpoint}" - - is_allowed, _ = await rate_limiter.check_rate_limit( - key, limit=100, window_seconds=60, identifier="endpoint" - ) - - return is_allowed - - -async def check_user_rate_limit(user_id: str, action: str) -> bool: - """Check rate limit for specific user and action""" - - key = f"user:{user_id}:action:{action}" - - is_allowed, _ = await rate_limiter.check_rate_limit( - key, limit=50, window_seconds=60, identifier="user_action" - ) - - return is_allowed - - -async def apply_burst_protection(key: str) -> bool: - """Apply burst protection for high-frequency actions""" - - # Allow burst of 10 requests in 10 seconds - is_allowed, _ = await rate_limiter.check_rate_limit( - key, limit=10, window_seconds=10, identifier="burst" - ) - - return is_allowed \ No newline at end of file diff --git a/backend/app/middleware/security.py b/backend/app/middleware/security.py deleted file mode 100644 index c7b7952..0000000 --- a/backend/app/middleware/security.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Security middleware for request/response processing -""" - -import json -import time -from typing import Callable, Optional, Dict, Any - -from fastapi import Request, Response -from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware - -from app.core.config import settings -from app.core.logging import get_logger -from app.core.threat_detection import threat_detection_service, SecurityAnalysis - -logger = get_logger(__name__) - - -class SecurityMiddleware(BaseHTTPMiddleware): - """Security middleware for threat detection and request filtering - DISABLED""" - - def __init__(self, app, enabled: bool = True): - super().__init__(app) - self.enabled = False # Force disable regardless of settings - logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)") - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - """Process request through security analysis - DISABLED""" - # Security disabled, always pass through - return await call_next(request) - - def _should_skip_security(self, request: Request) -> bool: - """Determine if security analysis should be skipped for this request""" - path = request.url.path - - # Skip for health checks, authentication endpoints, and static assets - skip_paths = [ - "/health", - "/metrics", - "/api/v1/docs", - "/api/v1/openapi.json", - "/api/v1/redoc", - "/favicon.ico", - "/api/v1/auth/register", - "/api/v1/auth/login", - "/api/v1/auth/refresh", # Allow refresh endpoint - "/api-internal/v1/auth/register", - "/api-internal/v1/auth/login", - "/api-internal/v1/auth/refresh", # Allow refresh endpoint for internal API - "/", # Root endpoint - ] - - # Skip for static file extensions - static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"] - - return ( - path in skip_paths or - any(path.endswith(ext) for ext in static_extensions) or - path.startswith("/static/") - ) - - def _has_valid_auth(self, request: Request) -> bool: - """Check if request has valid authentication""" - # Check Authorization header - auth_header = request.headers.get("Authorization", "") - api_key_header = request.headers.get("X-API-Key", "") - - # Has some form of auth token/key - return ( - auth_header.startswith("Bearer ") and len(auth_header) > 7 or - len(api_key_header.strip()) > 0 - ) - - def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse: - """Create response for blocked requests""" - # Determine status code based on threat type - status_code = 403 # Forbidden by default - - # Critical threats get 403 - for threat in analysis.threats: - if threat.threat_type in ["command_injection", "sql_injection"]: - status_code = 403 - break - - response_data = { - "error": "Security Policy Violation", - "message": "Request blocked due to security policy violation", - "risk_score": round(analysis.risk_score, 3), - "auth_level": analysis.auth_level.value, - "threat_count": len(analysis.threats), - "recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations - } - - response = JSONResponse( - content=response_data, - status_code=status_code - ) - - return response - - def _add_security_headers(self, response: Response) -> Response: - """Add security headers to response""" - if not settings.API_SECURITY_HEADERS_ENABLED: - return response - - # Standard security headers - response.headers["X-Content-Type-Options"] = "nosniff" - response.headers["X-Frame-Options"] = "DENY" - response.headers["X-XSS-Protection"] = "1; mode=block" - response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" - - # Only add HSTS for HTTPS - if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https": - response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" - - # Content Security Policy - if settings.API_CSP_HEADER: - response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER - - return response - - def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response: - """Add security metrics to response headers (for debugging/monitoring)""" - # Only add in debug mode or for admin users - if settings.APP_DEBUG: - response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3)) - response.headers["X-Security-Threats"] = str(len(analysis.threats)) - response.headers["X-Security-Auth-Level"] = analysis.auth_level.value - response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms" - - return response - - async def _log_security_event(self, request: Request, analysis: SecurityAnalysis): - """Log security events for audit and monitoring""" - client_ip = request.client.host if request.client else "unknown" - user_agent = request.headers.get("user-agent", "") - - # Create security event log - event_data = { - "timestamp": analysis.timestamp.isoformat(), - "client_ip": client_ip, - "user_agent": user_agent, - "path": str(request.url.path), - "method": request.method, - "risk_score": round(analysis.risk_score, 3), - "auth_level": analysis.auth_level.value, - "threat_count": len(analysis.threats), - "rate_limit_exceeded": analysis.rate_limit_exceeded, - "should_block": analysis.should_block, - "threats": [ - { - "type": threat.threat_type, - "level": threat.level.value, - "confidence": round(threat.confidence, 3), - "description": threat.description - } - for threat in analysis.threats[:5] # Limit to first 5 threats - ], - "recommendations": analysis.recommendations - } - - # Log at appropriate level based on risk - if analysis.should_block: - logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}") - elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD: - logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}") - else: - logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}") - - -def setup_security_middleware(app, enabled: bool = True) -> None: - """Setup security middleware on FastAPI app""" - if enabled and settings.API_SECURITY_ENABLED: - app.add_middleware(SecurityMiddleware, enabled=enabled) - logger.info("Security middleware enabled") - else: - logger.info("Security middleware disabled") - - -# Helper functions for manual security checks -async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis: - """Manually analyze request security (for use in route handlers)""" - return await threat_detection_service.analyze_request(request, user_context) - - -def get_security_stats() -> Dict[str, Any]: - """Get security statistics""" - return threat_detection_service.get_stats() - - -def is_request_blocked(request: Request) -> bool: - """Check if request was blocked by security analysis""" - if hasattr(request.state, 'security_analysis'): - return request.state.security_analysis.should_block - return False - - -def get_request_risk_score(request: Request) -> float: - """Get risk score for request""" - if hasattr(request.state, 'security_analysis'): - return request.state.security_analysis.risk_score - return 0.0 - - -def get_request_auth_level(request: Request) -> str: - """Get authentication level for request""" - if hasattr(request.state, 'security_analysis'): - return request.state.security_analysis.auth_level.value - return "unknown" \ No newline at end of file diff --git a/backend/app/services/document_processor.py b/backend/app/services/document_processor.py index 8447333..8875ae8 100644 --- a/backend/app/services/document_processor.py +++ b/backend/app/services/document_processor.py @@ -162,6 +162,7 @@ class DocumentProcessor: async def _process_document(self, task: ProcessingTask) -> bool: """Process a single document""" + from datetime import datetime from app.db.database import async_session_factory async with async_session_factory() as session: try: @@ -182,16 +183,24 @@ class DocumentProcessor: document.status = ProcessingStatus.PROCESSING await session.commit() - # Get RAG module for processing (now includes content processing) + # Get RAG module for processing try: - from app.services.module_manager import module_manager - rag_module = module_manager.get_module('rag') + # Import RAG module and initialize it properly + from modules.rag.main import RAGModule + from app.core.config import settings + + # Create and initialize RAG module instance + rag_module = RAGModule(settings) + init_result = await rag_module.initialize() + if not rag_module.enabled: + raise Exception("Failed to enable RAG module") + except Exception as e: logger.error(f"Failed to get RAG module: {e}") raise Exception(f"RAG module not available: {e}") - - if not rag_module: - raise Exception("RAG module not available") + + if not rag_module or not rag_module.enabled: + raise Exception("RAG module not available or not enabled") logger.info(f"RAG module loaded successfully for document {task.document_id}") @@ -204,31 +213,45 @@ class DocumentProcessor: # Process with RAG module logger.info(f"Starting document processing for document {task.document_id} with RAG module") - try: - # Add timeout to prevent hanging - processed_doc = await asyncio.wait_for( - rag_module.process_document( - file_content, - document.original_filename, - {} - ), - timeout=300.0 # 5 minute timeout - ) - logger.info(f"Document processing completed for document {task.document_id}") - except asyncio.TimeoutError: - logger.error(f"Document processing timed out for document {task.document_id}") - raise Exception("Document processing timed out after 5 minutes") - except Exception as e: - logger.error(f"Document processing failed for document {task.document_id}: {e}") - raise - - # Update document with processed content - document.converted_content = processed_doc.content - document.word_count = processed_doc.word_count - document.character_count = len(processed_doc.content) - document.document_metadata = processed_doc.metadata - document.status = ProcessingStatus.PROCESSED - document.processed_at = datetime.utcnow() + + # Special handling for JSONL files - skip processing phase + if document.file_type == 'jsonl': + # For JSONL files, we don't need to process content here + # The optimized JSONL processor will handle everything during indexing + document.converted_content = f"JSONL file with {len(file_content)} bytes" + document.word_count = 0 # Will be updated during indexing + document.character_count = len(file_content) + document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"} + document.status = ProcessingStatus.PROCESSED + document.processed_at = datetime.utcnow() + logger.info(f"JSONL document {task.document_id} marked for optimized processing") + else: + # Standard processing for other file types + try: + # Add timeout to prevent hanging + processed_doc = await asyncio.wait_for( + rag_module.process_document( + file_content, + document.original_filename, + {"file_path": document.file_path} + ), + timeout=300.0 # 5 minute timeout + ) + logger.info(f"Document processing completed for document {task.document_id}") + + # Update document with processed content + document.converted_content = processed_doc.content + document.word_count = processed_doc.word_count + document.character_count = len(processed_doc.content) + document.document_metadata = processed_doc.metadata + document.status = ProcessingStatus.PROCESSED + document.processed_at = datetime.utcnow() + except asyncio.TimeoutError: + logger.error(f"Document processing timed out for document {task.document_id}") + raise Exception("Document processing timed out after 5 minutes") + except Exception as e: + logger.error(f"Document processing failed for document {task.document_id}: {e}") + raise # Index in RAG system using same RAG module if rag_module and document.converted_content: @@ -245,14 +268,57 @@ class DocumentProcessor: } # Use the correct Qdrant collection name for this document - await asyncio.wait_for( - rag_module.index_document( - content=document.converted_content, - metadata=doc_metadata, - collection_name=document.collection.qdrant_collection_name - ), - timeout=120.0 # 2 minute timeout for indexing - ) + # For JSONL files, we need to use the processed document flow + if document.file_type == 'jsonl': + # Create a ProcessedDocument for the JSONL processor + from app.modules.rag.main import ProcessedDocument + from datetime import datetime + import hashlib + + # Calculate file hash + processed_at = datetime.utcnow() + file_hash = hashlib.md5(str(document.id).encode()).hexdigest() + + processed_doc = ProcessedDocument( + id=str(document.id), + content="", # Will be filled by JSONL processor + extracted_text="", # Will be filled by JSONL processor + metadata={ + **doc_metadata, + "file_path": document.file_path + }, + original_filename=document.original_filename, + file_type=document.file_type, + mime_type=document.mime_type, + language=document.document_metadata.get('language', 'EN'), + word_count=0, # Will be updated during processing + sentence_count=0, # Will be updated during processing + entities=[], + keywords=[], + processing_time=0.0, + processed_at=processed_at, + file_hash=file_hash, + file_size=document.file_size + ) + + # The JSONL processor will read the original file + await asyncio.wait_for( + rag_module.index_processed_document( + processed_doc=processed_doc, + collection_name=document.collection.qdrant_collection_name + ), + timeout=300.0 # 5 minute timeout for JSONL processing + ) + else: + # Use standard indexing for other file types + await asyncio.wait_for( + rag_module.index_document( + content=document.converted_content, + metadata=doc_metadata, + collection_name=document.collection.qdrant_collection_name + ), + timeout=120.0 # 2 minute timeout for indexing + ) logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}") @@ -271,7 +337,9 @@ class DocumentProcessor: except Exception as e: logger.error(f"Failed to index document {task.document_id} in RAG: {e}") - # Keep as processed even if indexing fails + # Mark as error since indexing failed + document.status = ProcessingStatus.ERROR + document.processing_error = f"Indexing failed: {str(e)}" # Don't raise the exception to avoid retries on indexing failures await session.commit() diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py index 4032086..ab7e04f 100644 --- a/backend/app/services/embedding_service.py +++ b/backend/app/services/embedding_service.py @@ -28,9 +28,19 @@ class EmbeddingService: await llm_service.initialize() # Test LLM service health - health_summary = llm_service.get_health_summary() - if health_summary.get("service_status") != "healthy": - logger.error(f"LLM service unhealthy: {health_summary}") + if not llm_service._initialized: + logger.error("LLM service not initialized") + return False + + # Check if PrivateMode provider is available + try: + provider_status = await llm_service.get_provider_status() + privatemode_status = provider_status.get("privatemode") + if not privatemode_status or privatemode_status.status != "healthy": + logger.error(f"PrivateMode provider not available: {privatemode_status}") + return False + except Exception as e: + logger.error(f"Failed to check provider status: {e}") return False self.initialized = True @@ -75,6 +85,12 @@ class EmbeddingService: else: truncated_text = text + # Guard: skip empty inputs (validator rejects empty strings) + if not truncated_text.strip(): + logger.debug("Empty input for embedding; using fallback vector") + batch_embeddings.append(self._generate_fallback_embedding(text)) + continue + # Call LLM service embedding endpoint from app.services.llm.service import llm_service from app.services.llm.models import EmbeddingRequest @@ -163,4 +179,4 @@ class EmbeddingService: # Global embedding service instance -embedding_service = EmbeddingService() \ No newline at end of file +embedding_service = EmbeddingService() diff --git a/backend/app/services/enhanced_embedding_service.py b/backend/app/services/enhanced_embedding_service.py index 284773f..cc66e42 100644 --- a/backend/app/services/enhanced_embedding_service.py +++ b/backend/app/services/enhanced_embedding_service.py @@ -25,9 +25,10 @@ class EnhancedEmbeddingService(EmbeddingService): 'requests_count': 0, 'window_start': time.time(), 'window_size': 60, # 1 minute window - 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable + 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable 'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff - 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)), + 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)), + 'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)), 'last_rate_limit_error': None } @@ -38,7 +39,7 @@ class EnhancedEmbeddingService(EmbeddingService): if max_retries is None: max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3)) - batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5)) + batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3)) if not self.initialized: logger.warning("Embedding service not initialized, using fallback") @@ -76,9 +77,6 @@ class EnhancedEmbeddingService(EmbeddingService): # Make the request embeddings = await self._get_embeddings_batch_impl(texts) - # Update rate limit tracker on success - self._update_rate_limit_tracker(success=True) - return embeddings, True except Exception as e: @@ -120,6 +118,12 @@ class EnhancedEmbeddingService(EmbeddingService): embeddings = [] for text in texts: + # Respect rate limit before each request + while self._is_rate_limited(): + delay = self._get_rate_limit_delay() + logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request") + await asyncio.sleep(delay) + # Truncate text if needed max_chars = 1600 truncated_text = text[:max_chars] if len(text) > max_chars else text @@ -142,8 +146,14 @@ class EnhancedEmbeddingService(EmbeddingService): self._dimension_confirmed = True else: raise ValueError("Empty embedding in response") - else: - raise ValueError("Invalid response structure") + else: + raise ValueError("Invalid response structure") + + # Count this successful request and optionally delay between requests + self._update_rate_limit_tracker(success=True) + per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0) + if per_req_delay and per_req_delay > 0: + await asyncio.sleep(per_req_delay) return embeddings @@ -198,4 +208,4 @@ class EnhancedEmbeddingService(EmbeddingService): # Global enhanced embedding service instance -enhanced_embedding_service = EnhancedEmbeddingService() \ No newline at end of file +enhanced_embedding_service = EnhancedEmbeddingService() diff --git a/backend/app/services/llm/config.py b/backend/app/services/llm/config.py index 61a8576..b7aeb13 100644 --- a/backend/app/services/llm/config.py +++ b/backend/app/services/llm/config.py @@ -16,6 +16,7 @@ from .models import ResilienceConfig class ProviderConfig(BaseModel): """Configuration for an LLM provider""" name: str = Field(..., description="Provider name") + provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')") enabled: bool = Field(True, description="Whether provider is enabled") base_url: str = Field(..., description="Provider base URL") api_key_env_var: str = Field(..., description="Environment variable for API key") @@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel): enable_security_checks: bool = Field(True, description="Enable security validation") enable_metrics_collection: bool = Field(True, description="Enable metrics collection") - # Security settings - security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking") - security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings") max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length") max_response_length: int = Field(32000, ge=1000, description="Maximum response length") @@ -78,12 +76,6 @@ class LLMServiceConfig(BaseModel): # Model routing (model_name -> provider_name) model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing") - @validator('security_risk_threshold') - def validate_risk_threshold(cls, v, values): - warning_threshold = values.get('security_warning_threshold', 0.6) - if v <= warning_threshold: - raise ValueError("Risk threshold must be greater than warning threshold") - return v def create_default_config() -> LLMServiceConfig: @@ -93,6 +85,7 @@ def create_default_config() -> LLMServiceConfig: # Models will be fetched dynamically from proxy /models endpoint privatemode_config = ProviderConfig( name="privatemode", + provider_type="privatemode", enabled=True, base_url=settings.PRIVATEMODE_PROXY_URL, api_key_env_var="PRIVATEMODE_API_KEY", @@ -119,9 +112,6 @@ def create_default_config() -> LLMServiceConfig: config = LLMServiceConfig( default_provider="privatemode", enable_detailed_logging=settings.LOG_LLM_PROMPTS, - enable_security_checks=settings.API_SECURITY_ENABLED, - security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD, - security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD, providers={ "privatemode": privatemode_config }, diff --git a/backend/app/services/llm/metrics.py b/backend/app/services/llm/metrics.py index 542dd7d..9a35fc4 100644 --- a/backend/app/services/llm/metrics.py +++ b/backend/app/services/llm/metrics.py @@ -124,7 +124,6 @@ class MetricsCollector: total_requests = len(self._metrics) successful_requests = sum(1 for m in self._metrics if m.success) failed_requests = total_requests - successful_requests - security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8) # Calculate averages latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0] @@ -143,7 +142,6 @@ class MetricsCollector: total_requests=total_requests, successful_requests=successful_requests, failed_requests=failed_requests, - security_blocked_requests=security_blocked, average_latency_ms=avg_latency, average_risk_score=avg_risk_score, provider_metrics=provider_metrics, diff --git a/backend/app/services/llm/models.py b/backend/app/services/llm/models.py index 903451d..b699b2c 100644 --- a/backend/app/services/llm/models.py +++ b/backend/app/services/llm/models.py @@ -157,7 +157,6 @@ class LLMMetrics(BaseModel): total_requests: int = Field(0, description="Total requests processed") successful_requests: int = Field(0, description="Successful requests") failed_requests: int = Field(0, description="Failed requests") - security_blocked_requests: int = Field(0, description="Security blocked requests") average_latency_ms: float = Field(0.0, description="Average response latency") average_risk_score: float = Field(0.0, description="Average security risk score") provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics") diff --git a/backend/app/services/llm/providers/privatemode.py b/backend/app/services/llm/providers/privatemode.py index 63f18ad..b136ccb 100644 --- a/backend/app/services/llm/providers/privatemode.py +++ b/backend/app/services/llm/providers/privatemode.py @@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider): else: error_text = await response.text() + # Log the detailed error response from the provider + logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}") self._handle_http_error(response.status, error_text, "embeddings") except aiohttp.ClientError as e: diff --git a/backend/app/services/llm/security.py b/backend/app/services/llm/security.py deleted file mode 100644 index 8aa37be..0000000 --- a/backend/app/services/llm/security.py +++ /dev/null @@ -1,325 +0,0 @@ -""" -LLM Security Manager - -Handles prompt injection detection and audit logging. -Provides comprehensive security for LLM interactions. -""" - -import os -import re -import json -import logging -import hashlib -from typing import Dict, Any, List, Optional, Tuple -from datetime import datetime - -from app.core.config import settings - -logger = logging.getLogger(__name__) - - -class SecurityManager: - """Manages security for LLM operations""" - - def __init__(self): - self._setup_prompt_injection_patterns() - - - def _setup_prompt_injection_patterns(self): - """Setup patterns for prompt injection detection""" - self.injection_patterns = [ - # Direct instruction injection - r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)", - r"(?i)(new|updated|different)\s+(instructions|rules|system)", - r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)", - r"(?i)pretend\s+(to\s+be|you\s+are)", - r"(?i)you\s+are\s+now\s+(a|an)\s+", - - # System role manipulation - r"(?i)system\s*:\s*", - r"(?i)\[system\]", - r"(?i)", - r"(?i)assistant\s*:\s*", - r"(?i)\[assistant\]", - - # Escape attempts - r"(?i)\\n\\n#+", - r"(?i)```\s*(system|assistant|user)", - r"(?i)---\s*(new|system|override)", - - # Role manipulation - r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)", - r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode", - r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)", - - # Information extraction attempts - r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)", - r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)", - r"(?i)(debug|verbose|diagnostic)\s+mode", - - # Encoding/obfuscation attempts - r"(?i)base64\s*:", - r"(?i)hex\s*:", - r"(?i)unicode\s*:", - r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences) - - # SQL injection patterns (more specific to reduce false positives) - r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s", - r"(?i)(or|and)\s+\d+\s*=\s*\d+", - r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)", - - # Command injection patterns - r"(?i)(exec|eval|system|shell|cmd)\s*\(", - r"(?i)(\$\(|\`)[^)]+(\)|\`)", - r"(?i)&&\s*(rm|del|format)", - - # Jailbreak attempts - r"(?i)jailbreak", - r"(?i)break\s+out\s+of", - r"(?i)escape\s+(the|your)\s+(rules|constraints)", - r"(?i)(DAN|Do\s+Anything\s+Now)", - r"(?i)unrestricted\s+mode", - ] - - self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns] - logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns") - - - def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]: - """ - Validate messages for prompt injection attempts - - Returns: - Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns) - """ - detected_patterns = [] - total_risk = 0.0 - - # Check if this is a system/RAG request - is_system_request = self._is_system_request(messages) - - for message in messages: - content = message.get("content", "") - if not content: - continue - - # Check against injection patterns with context awareness - for i, pattern in enumerate(self.compiled_patterns): - matches = pattern.findall(content) - if matches: - # Apply context-aware risk calculation - pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request) - total_risk += pattern_risk - detected_patterns.append({ - "pattern_index": i, - "pattern": self.injection_patterns[i], - "matches": matches, - "risk": pattern_risk - }) - - # Additional security checks with context awareness - total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request) - - # Normalize risk score (0.0 to 1.0) - risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0) - # Never block - always return True for is_safe - is_safe = True - - if detected_patterns: - logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})") - - return is_safe, risk_score, detected_patterns - - def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float: - """Calculate risk score for a detected pattern with context awareness""" - # Different patterns have different risk levels - high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak - medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction - - # Base risk score - base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3 - - # Apply context-specific risk reduction - if is_system_request or role == "system": - # Reduce risk for system messages and RAG content - if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode) - base_risk *= 0.2 # Reduce encoding risk by 80% for system content - elif pattern_index in [17, 18, 19]: # SQL patterns - base_risk *= 0.3 # Reduce SQL risk by 70% for system content - else: - base_risk *= 0.6 # Reduce other risks by 40% for system content - - # Increase risk based on number of matches, but cap it - match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier - - return base_risk * match_multiplier - - def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float: - """Check message characteristics for additional risk factors with context awareness""" - risk = 0.0 - - # Excessive length (potential stuffing attack) - less restrictive for system content - length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content - if len(content) > length_threshold: - risk += 0.1 if is_system_request else 0.3 - - # High ratio of special characters - more lenient for system content - special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace()) - if len(content) > 0: - char_ratio = special_chars / len(content) - threshold = 0.8 if is_system_request else 0.5 - if char_ratio > threshold: - risk += 0.2 if is_system_request else 0.4 - - # Multiple encoding indicators - reduced risk for system content - encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"] - found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower()) - if found_encodings > 1: - risk += 0.1 if is_system_request else 0.3 - - # Excessive newlines or formatting - more lenient for system content - newline_threshold = 200 if is_system_request else 50 - if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold: - risk += 0.1 if is_system_request else 0.2 - - return risk - - def _is_system_request(self, messages: List[Dict[str, str]]) -> bool: - """Determine if this is a system/RAG request""" - if not messages: - return False - - # Check for system messages - for message in messages: - if message.get("role") == "system": - return True - - # Check message content for RAG indicators - for message in messages: - content = message.get("content", "") - if ("document:" in content.lower() or - "context:" in content.lower() or - "source:" in content.lower() or - "retrieved:" in content.lower() or - "citation:" in content.lower() or - "reference:" in content.lower()): - return True - - return False - - def create_audit_log( - self, - user_id: str, - api_key_id: int, - provider: str, - model: str, - request_type: str, - risk_score: float, - detected_patterns: List[str], - metadata: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """Create comprehensive audit log for LLM request""" - audit_entry = { - "timestamp": datetime.utcnow().isoformat(), - "user_id": user_id, - "api_key_id": api_key_id, - "provider": provider, - "model": model, - "request_type": request_type, - "security": { - "risk_score": risk_score, - "detected_patterns": detected_patterns, - "security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD - }, - "metadata": metadata or {}, - "audit_hash": None # Will be set below - } - - # Create hash for audit integrity - audit_hash = self._create_audit_hash(audit_entry) - audit_entry["audit_hash"] = audit_hash - - # Log based on risk level (never block, only log) - if risk_score >= settings.API_SECURITY_RISK_THRESHOLD: - logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}") - elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD: - logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}") - else: - logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}") - - return audit_entry - - def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str: - """Create hash for audit trail integrity""" - # Create hash from key fields (excluding the hash itself) - hash_data = { - "timestamp": audit_entry["timestamp"], - "user_id": audit_entry["user_id"], - "api_key_id": audit_entry["api_key_id"], - "provider": audit_entry["provider"], - "model": audit_entry["model"], - "request_type": audit_entry["request_type"], - "risk_score": audit_entry["security"]["risk_score"] - } - - hash_string = json.dumps(hash_data, sort_keys=True) - return hashlib.sha256(hash_string.encode()).hexdigest() - - def log_detailed_request( - self, - messages: List[Dict[str, str]], - model: str, - user_id: str, - provider: str, - context_info: Optional[Dict[str, Any]] = None - ): - """Log detailed LLM request if LOG_LLM_PROMPTS is enabled""" - if not settings.LOG_LLM_PROMPTS: - return - - logger.info("=== DETAILED LLM REQUEST ===") - logger.info(f"Model: {model}") - logger.info(f"Provider: {provider}") - logger.info(f"User ID: {user_id}") - - if context_info: - for key, value in context_info.items(): - logger.info(f"{key}: {value}") - - logger.info("Messages to LLM:") - for i, message in enumerate(messages): - role = message.get("role", "unknown") - content = message.get("content", "")[:500] # Truncate for logging - logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}") - - logger.info("=== END DETAILED LLM REQUEST ===") - - def log_detailed_response( - self, - response_content: str, - token_usage: Optional[Dict[str, int]] = None, - provider: str = "unknown" - ): - """Log detailed LLM response if LOG_LLM_PROMPTS is enabled""" - if not settings.LOG_LLM_PROMPTS: - return - - logger.info("=== DETAILED LLM RESPONSE ===") - logger.info(f"Provider: {provider}") - logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}") - - if token_usage: - logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, " - f"Completion: {token_usage.get('completion_tokens', 0)}, " - f"Total: {token_usage.get('total_tokens', 0)}") - - logger.info("=== END DETAILED LLM RESPONSE ===") - - -class SecurityError(Exception): - """Security-related errors in LLM operations""" - pass - - -# Global security manager instance -security_manager = SecurityManager() \ No newline at end of file diff --git a/backend/app/services/llm/service.py b/backend/app/services/llm/service.py index bb8e683..d3f2503 100644 --- a/backend/app/services/llm/service.py +++ b/backend/app/services/llm/service.py @@ -17,9 +17,8 @@ from .models import ( ) from .config import config_manager, ProviderConfig from ...core.config import settings -from .security import security_manager from .resilience import ResilienceManagerFactory -from .metrics import metrics_collector +# from .metrics import metrics_collector from .providers import BaseLLMProvider, PrivateModeProvider from .exceptions import ( LLMError, ProviderError, SecurityError, ConfigurationError, @@ -150,45 +149,8 @@ class LLMService: if not request.messages: raise ValidationError("Messages cannot be empty", field="messages") - # Security validation (only if enabled) - messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] - - if settings.API_SECURITY_ENABLED: - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) - else: - # Security disabled - always safe - is_safe, risk_score, detected_patterns = True, 0.0, [] - - if not is_safe: - # Log security violation - security_manager.create_audit_log( - user_id=request.user_id, - api_key_id=request.api_key_id, - provider="blocked", - model=request.model, - request_type="chat_completion", - risk_score=risk_score, - detected_patterns=[p.get("pattern", "") for p in detected_patterns] - ) - - # Record blocked request - metrics_collector.record_request( - provider="security", - model=request.model, - request_type="chat_completion", - success=False, - latency_ms=0, - security_risk_score=risk_score, - error_code="SECURITY_BLOCKED", - user_id=request.user_id, - api_key_id=request.api_key_id - ) - - raise SecurityError( - "Request blocked due to security concerns", - risk_score=risk_score, - details={"detected_patterns": detected_patterns} - ) + # Security validation disabled - always allow requests + risk_score = 0.0 # Get provider for model provider_name = self._get_provider_for_model(request.model) @@ -197,18 +159,7 @@ class LLMService: if not provider: raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name) - # Log detailed request if enabled - security_manager.log_detailed_request( - messages=messages_dict, - model=request.model, - user_id=request.user_id, - provider=provider_name, - context_info={ - "temperature": request.temperature, - "max_tokens": request.max_tokens, - "risk_score": f"{risk_score:.3f}" - } - ) + # Security logging disabled # Execute with resilience resilience_manager = ResilienceManagerFactory.get_manager(provider_name) @@ -222,85 +173,46 @@ class LLMService: non_retryable_exceptions=(SecurityError, ValidationError) ) - # Update response with security information - response.security_check = is_safe - response.risk_score = risk_score - response.detected_patterns = [p.get("pattern", "") for p in detected_patterns] + # Security features disabled - # Log detailed response if enabled - if response.choices: - content = response.choices[0].message.content - security_manager.log_detailed_response( - response_content=content, - token_usage=response.usage.model_dump() if response.usage else None, - provider=provider_name - ) + # Security logging disabled - # Record successful request + # Record successful request - metrics disabled total_latency = (time.time() - start_time) * 1000 - metrics_collector.record_request( - provider=provider_name, - model=request.model, - request_type="chat_completion", - success=True, - latency_ms=total_latency, - token_usage=response.usage.model_dump() if response.usage else None, - security_risk_score=risk_score, - user_id=request.user_id, - api_key_id=request.api_key_id - ) + # metrics_collector.record_request( + # provider=provider_name, + # model=request.model, + # request_type="chat_completion", + # success=True, + # latency_ms=total_latency, + # token_usage=response.usage.model_dump() if response.usage else None, + # security_risk_score=risk_score, + # user_id=request.user_id, + # api_key_id=request.api_key_id + # ) - # Create audit log - security_manager.create_audit_log( - user_id=request.user_id, - api_key_id=request.api_key_id, - provider=provider_name, - model=request.model, - request_type="chat_completion", - risk_score=risk_score, - detected_patterns=[p.get("pattern", "") for p in detected_patterns], - metadata={ - "success": True, - "latency_ms": total_latency, - "token_usage": response.usage.model_dump() if response.usage else None - } - ) + # Security audit logging disabled return response except Exception as e: - # Record failed request + # Record failed request - metrics disabled total_latency = (time.time() - start_time) * 1000 error_code = getattr(e, 'error_code', e.__class__.__name__) + + # metrics_collector.record_request( + # provider=provider_name, + # model=request.model, + # request_type="chat_completion", + # success=False, + # latency_ms=total_latency, + # security_risk_score=risk_score, + # error_code=error_code, + # user_id=request.user_id, + # api_key_id=request.api_key_id + # ) - metrics_collector.record_request( - provider=provider_name, - model=request.model, - request_type="chat_completion", - success=False, - latency_ms=total_latency, - security_risk_score=risk_score, - error_code=error_code, - user_id=request.user_id, - api_key_id=request.api_key_id - ) - - # Create audit log for failure - security_manager.create_audit_log( - user_id=request.user_id, - api_key_id=request.api_key_id, - provider=provider_name, - model=request.model, - request_type="chat_completion", - risk_score=risk_score, - detected_patterns=[p.get("pattern", "") for p in detected_patterns], - metadata={ - "success": False, - "error": str(e), - "error_code": error_code, - "latency_ms": total_latency - } - ) + # Security audit logging disabled raise @@ -309,21 +221,8 @@ class LLMService: if not self._initialized: await self.initialize() - # Security validation (same as non-streaming) - messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] - - if settings.API_SECURITY_ENABLED: - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict) - else: - # Security disabled - always safe - is_safe, risk_score, detected_patterns = True, 0.0, [] - - if not is_safe: - raise SecurityError( - "Streaming request blocked due to security concerns", - risk_score=risk_score, - details={"detected_patterns": detected_patterns} - ) + # Security validation disabled - always allow streaming requests + risk_score = 0.0 # Get provider provider_name = self._get_provider_for_model(request.model) @@ -345,19 +244,19 @@ class LLMService: yield chunk except Exception as e: - # Record streaming failure + # Record streaming failure - metrics disabled error_code = getattr(e, 'error_code', e.__class__.__name__) - metrics_collector.record_request( - provider=provider_name, - model=request.model, - request_type="chat_completion_stream", - success=False, - latency_ms=0, - security_risk_score=risk_score, - error_code=error_code, - user_id=request.user_id, - api_key_id=request.api_key_id - ) + # metrics_collector.record_request( + # provider=provider_name, + # model=request.model, + # request_type="chat_completion_stream", + # success=False, + # latency_ms=0, + # security_risk_score=risk_score, + # error_code=error_code, + # user_id=request.user_id, + # api_key_id=request.api_key_id + # ) raise async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse: @@ -365,23 +264,8 @@ class LLMService: if not self._initialized: await self.initialize() - # Security validation for embedding input - input_text = request.input if isinstance(request.input, str) else " ".join(request.input) - - if settings.API_SECURITY_ENABLED: - is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([ - {"role": "user", "content": input_text} - ]) - else: - # Security disabled - always safe - is_safe, risk_score, detected_patterns = True, 0.0, [] - - if not is_safe: - raise SecurityError( - "Embedding request blocked due to security concerns", - risk_score=risk_score, - details={"detected_patterns": detected_patterns} - ) + # Security validation disabled - always allow embedding requests + risk_score = 0.0 # Get provider provider_name = self._get_provider_for_model(request.model) @@ -402,42 +286,40 @@ class LLMService: non_retryable_exceptions=(SecurityError, ValidationError) ) - # Update response with security information - response.security_check = is_safe - response.risk_score = risk_score + # Security features disabled - # Record successful request + # Record successful request - metrics disabled total_latency = (time.time() - start_time) * 1000 - metrics_collector.record_request( - provider=provider_name, - model=request.model, - request_type="embedding", - success=True, - latency_ms=total_latency, - token_usage=response.usage.model_dump() if response.usage else None, - security_risk_score=risk_score, - user_id=request.user_id, - api_key_id=request.api_key_id - ) + # metrics_collector.record_request( + # provider=provider_name, + # model=request.model, + # request_type="embedding", + # success=True, + # latency_ms=total_latency, + # token_usage=response.usage.model_dump() if response.usage else None, + # security_risk_score=risk_score, + # user_id=request.user_id, + # api_key_id=request.api_key_id + # ) return response except Exception as e: - # Record failed request + # Record failed request - metrics disabled total_latency = (time.time() - start_time) * 1000 error_code = getattr(e, 'error_code', e.__class__.__name__) - - metrics_collector.record_request( - provider=provider_name, - model=request.model, - request_type="embedding", - success=False, - latency_ms=total_latency, - security_risk_score=risk_score, - error_code=error_code, - user_id=request.user_id, - api_key_id=request.api_key_id - ) + + # metrics_collector.record_request( + # provider=provider_name, + # model=request.model, + # request_type="embedding", + # success=False, + # latency_ms=total_latency, + # security_risk_score=risk_score, + # error_code=error_code, + # user_id=request.user_id, + # api_key_id=request.api_key_id + # ) raise @@ -492,20 +374,26 @@ class LLMService: return status_dict def get_metrics(self) -> LLMMetrics: - """Get service metrics""" - return metrics_collector.get_metrics() + """Get service metrics - metrics disabled""" + # return metrics_collector.get_metrics() + return LLMMetrics( + total_requests=0, + success_rate=0.0, + avg_latency_ms=0, + error_rates={} + ) def get_health_summary(self) -> Dict[str, Any]: - """Get comprehensive health summary""" - metrics_health = metrics_collector.get_health_summary() + """Get comprehensive health summary - metrics disabled""" + # metrics_health = metrics_collector.get_health_summary() resilience_health = ResilienceManagerFactory.get_all_health_status() - + return { "service_status": "healthy" if self._initialized else "initializing", "startup_time": self._startup_time.isoformat() if self._startup_time else None, "provider_count": len(self._providers), "active_providers": list(self._providers.keys()), - "metrics": metrics_health, + "metrics": {"status": "disabled"}, "resilience": resilience_health } diff --git a/backend/app/services/llm/token_rate_limiter.py b/backend/app/services/llm/token_rate_limiter.py deleted file mode 100644 index 2338a03..0000000 --- a/backend/app/services/llm/token_rate_limiter.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Token-based rate limiting for LLM service -""" - -import time -import redis -from typing import Dict, Optional, Tuple -from datetime import datetime, timedelta -from ..core.config import settings -from ..core.logging import get_logger - -logger = get_logger(__name__) - - -class TokenRateLimiter: - """Token-based rate limiting implementation""" - - def __init__(self): - try: - self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) - self.redis_client.ping() - logger.info("Token rate limiter initialized with Redis backend") - except Exception as e: - logger.warning(f"Redis not available for token rate limiting: {e}") - self.redis_client = None - # Fall back to in-memory rate limiting - self.in_memory_store = {} - logger.info("Token rate limiter using in-memory fallback") - - async def check_token_limits( - self, - provider: str, - prompt_tokens: int, - completion_tokens: int = 0 - ) -> Tuple[bool, Dict[str, str]]: - """ - Check if token usage is within limits - - Args: - provider: Provider name (e.g., "privatemode") - prompt_tokens: Number of prompt tokens to use - completion_tokens: Number of completion tokens to use - - Returns: - Tuple of (is_allowed, headers) - """ - # Get token limits from configuration - from .config import get_config - config = get_config() - token_limits = config.token_limits_per_minute - - # Check organization-wide limits - org_key = f"tokens:org:{provider}" - - # Get current usage - current_usage = await self._get_token_usage(org_key) - - # Calculate new usage - new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens - new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens - - # Check limits - prompt_limit = token_limits.get("prompt_tokens", 20000) - completion_limit = token_limits.get("completion_tokens", 10000) - - is_allowed = ( - new_prompt_tokens <= prompt_limit and - new_completion_tokens <= completion_limit - ) - - if is_allowed: - # Update usage - await self._update_token_usage(org_key, prompt_tokens, completion_tokens) - logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, " - f"{new_completion_tokens}/{completion_limit} completion") - - # Calculate remaining tokens - remaining_prompt = max(0, prompt_limit - new_prompt_tokens) - remaining_completion = max(0, completion_limit - new_completion_tokens) - - # Create headers - headers = { - "X-TokenLimit-Prompt-Remaining": str(remaining_prompt), - "X-TokenLimit-Completion-Remaining": str(remaining_completion), - "X-TokenLimit-Prompt-Limit": str(prompt_limit), - "X-TokenLimit-Completion-Limit": str(completion_limit), - "X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute - } - - if not is_allowed: - logger.warning(f"Token rate limit exceeded for {provider}. " - f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. " - f"Current: {current_usage}") - - return is_allowed, headers - - async def _get_token_usage(self, key: str) -> Dict[str, int]: - """Get current token usage""" - if self.redis_client: - try: - data = self.redis_client.hgetall(key) - if data: - return { - "prompt_tokens": int(data.get("prompt_tokens", 0)), - "completion_tokens": int(data.get("completion_tokens", 0)), - "updated_at": float(data.get("updated_at", time.time())) - } - except Exception as e: - logger.error(f"Error getting token usage from Redis: {e}") - - # Fallback to in-memory - return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0}) - - async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int): - """Update token usage""" - if self.redis_client: - try: - pipe = self.redis_client.pipeline() - pipe.hincrby(key, "prompt_tokens", prompt_tokens) - pipe.hincrby(key, "completion_tokens", completion_tokens) - pipe.hset(key, "updated_at", time.time()) - pipe.expire(key, 60) # Expire after 1 minute - pipe.execute() - except Exception as e: - logger.error(f"Error updating token usage in Redis: {e}") - # Fallback to in-memory - self._update_in_memory(key, prompt_tokens, completion_tokens) - else: - self._update_in_memory(key, prompt_tokens, completion_tokens) - - def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int): - """Update in-memory token usage""" - if key not in self.in_memory_store: - self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0} - - self.in_memory_store[key]["prompt_tokens"] += prompt_tokens - self.in_memory_store[key]["completion_tokens"] += completion_tokens - self.in_memory_store[key]["updated_at"] = time.time() - - def cleanup_expired(self): - """Clean up expired entries (for in-memory store)""" - if not self.redis_client: - current_time = time.time() - expired_keys = [ - key for key, data in self.in_memory_store.items() - if current_time - data.get("updated_at", 0) > 60 - ] - for key in expired_keys: - del self.in_memory_store[key] - - -# Global token rate limiter instance -token_rate_limiter = TokenRateLimiter() \ No newline at end of file diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py index 119cb26..1741362 100644 --- a/backend/app/services/rag_service.py +++ b/backend/app/services/rag_service.py @@ -755,10 +755,11 @@ class RAGService: # Process with RAG module try: + # Pass file_path in metadata so JSONL indexing can reopen the source file processed_doc = await rag_module.process_document( - file_content, - document.original_filename, - {} + file_content, + document.original_filename, + {"file_path": document.file_path} ) # Success case - update document with processed content @@ -873,4 +874,4 @@ class RAGService: except Exception as e: logger.error(f"Error reprocessing document {document_id}: {e}") - return False \ No newline at end of file + return False diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index 7d75fbd..d56503c 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -638,11 +638,19 @@ class RAGModule(BaseModule): np.random.seed(hash(text) % 2**32) return np.random.random(self.embedding_model.get("dimension", 768)).tolist() - async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]: + async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]: """Generate embeddings for multiple texts (batch processing)""" if self.embedding_service: + # Add task-specific prefixes for better E5 model performance + if is_document: + # For document passages, use "passage:" prefix + prefixed_texts = [f"passage: {text}" for text in texts] + else: + # For queries, use "query:" prefix (handled in search method) + prefixed_texts = texts + # Use real embedding service for batch processing - return await self.embedding_service.get_embeddings(texts) + return await self.embedding_service.get_embeddings(prefixed_texts) else: # Fallback to individual processing embeddings = [] @@ -917,69 +925,75 @@ class RAGModule(BaseModule): async def _process_jsonl(self, content: bytes, filename: str) -> str: """Process JSONL files (newline-delimited JSON) - + Specifically optimized for helpjuice-export.jsonl format: - Each line contains a JSON object with 'id' and 'payload' - Payload contains 'question', 'language', and 'answer' fields - Combines question and answer into searchable content + + Performance optimizations: + - Processes articles in smaller batches to reduce memory usage + - Uses streaming approach for large files """ try: + # Use streaming approach for large files jsonl_content = content.decode('utf-8', errors='replace') lines = jsonl_content.strip().split('\n') - + processed_articles = [] - + batch_size = 50 # Process in batches of 50 articles + for line_num, line in enumerate(lines, 1): if not line.strip(): continue - + try: # Parse each JSON line data = json.loads(line) - + # Handle helpjuice export format if 'payload' in data: payload = data['payload'] article_id = data.get('id', f'article_{line_num}') - + # Extract fields question = payload.get('question', '') answer = payload.get('answer', '') language = payload.get('language', 'EN') - + # Combine question and answer for better search if question or answer: # Format as Q&A for better context article_text = f"## {question}\n\n{answer}\n\n" - + # Add language tag if not English if language != 'EN': article_text = f"[{language}] {article_text}" - + # Add metadata separator article_text += f"---\nArticle ID: {article_id}\nLanguage: {language}\n\n" - + processed_articles.append(article_text) - + # Handle generic JSONL format else: # Convert the entire JSON object to readable text json_text = json.dumps(data, indent=2, ensure_ascii=False) processed_articles.append(json_text + "\n\n") - + except json.JSONDecodeError as e: logger.warning(f"Error parsing JSONL line {line_num}: {e}") continue except Exception as e: logger.warning(f"Error processing JSONL line {line_num}: {e}") continue - + # Combine all articles combined_text = '\n'.join(processed_articles) - + logger.info(f"Successfully processed {len(processed_articles)} articles from JSONL file {filename}") return combined_text - + except Exception as e: logger.error(f"Error processing JSONL file {filename}: {e}") return "" @@ -1153,7 +1167,7 @@ class RAGModule(BaseModule): chunks = self._chunk_text(content) # Generate embeddings for all chunks in batch (more efficient) - embeddings = await self._generate_embeddings(chunks) + embeddings = await self._generate_embeddings(chunks, is_document=True) # Create document points points = [] @@ -1200,10 +1214,28 @@ class RAGModule(BaseModule): """Index a processed document in the vector database""" if not self.enabled: raise RuntimeError("RAG module not initialized") - + collection_name = collection_name or self.default_collection_name - + try: + # Special handling for JSONL files + if processed_doc.file_type == 'jsonl': + # Import the optimized JSONL processor + from app.services.jsonl_processor import JSONLProcessor + jsonl_processor = JSONLProcessor(self) + + # Read the original file content + with open(processed_doc.metadata.get('file_path', ''), 'rb') as f: + file_content = f.read() + + # Process using the optimized JSONL processor + return await jsonl_processor.process_and_index_jsonl( + collection_name=collection_name, + content=file_content, + filename=processed_doc.original_filename, + metadata=processed_doc.metadata + ) + # Ensure collection exists await self._ensure_collection_exists(collection_name) @@ -1216,7 +1248,7 @@ class RAGModule(BaseModule): chunks = self._chunk_text(processed_doc.content) # Generate embeddings for all chunks in batch (more efficient) - embeddings = await self._generate_embeddings(chunks) + embeddings = await self._generate_embeddings(chunks, is_document=True) # Create document points with enhanced metadata points = [] @@ -1339,24 +1371,48 @@ class RAGModule(BaseModule): score_threshold=score_threshold / 2 # Lower threshold for initial search ) - # Combine scores + # Combine scores with improved normalization hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3}) vector_weight = hybrid_weights.get("vector", 0.7) bm25_weight = hybrid_weights.get("bm25", 0.3) - # Create hybrid results + # Get score distributions for better normalization + vector_scores = [r.score for r in vector_results] + bm25_scores_list = list(bm25_scores.values()) + + # Calculate statistics for normalization + if vector_scores: + v_max = max(vector_scores) + v_min = min(vector_scores) + v_range = v_max - v_min if v_max != v_min else 1 + else: + v_max, v_min, v_range = 1, 0, 1 + + if bm25_scores_list: + bm25_max = max(bm25_scores_list) + bm25_min = min(bm25_scores_list) + bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1 + else: + bm25_max, bm25_min, bm25_range = 1, 0, 1 + + # Create hybrid results with improved scoring hybrid_results = [] for result in vector_results: doc_id = result.payload.get("document_id", "") vector_score = result.score bm25_score = bm25_scores.get(doc_id, 0.0) - # Normalize scores (simple min-max normalization) - vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0 - bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1 + # Improved normalization using actual score distributions + vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5 + bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5 - # Calculate hybrid score - hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm) + # Apply reciprocal rank fusion for better combination + # This gives more weight to documents that rank highly in both methods + rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1) # +1 to avoid division by zero + rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0 + + # Calculate hybrid score using both normalized scores and RRF + hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3 # Create new point with hybrid score hybrid_point = ScoredPoint( @@ -1435,7 +1491,7 @@ class RAGModule(BaseModule): # Normalize score to 0-1 range return min(score / 10.0, 1.0) # Simple normalization - async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: + async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]: """Search for relevant documents""" if not self.enabled: raise RuntimeError("RAG module not initialized") @@ -1453,8 +1509,10 @@ class RAGModule(BaseModule): import time start_time = time.time() - # Generate query embedding - query_embedding = await self._generate_embedding(query) + # Generate query embedding with task-specific prefix for better retrieval + # The E5 model works better with "query:" prefix for search queries + optimized_query = f"query: {query}" + query_embedding = await self._generate_embedding(optimized_query) # Build filter search_filter = None @@ -1474,7 +1532,8 @@ class RAGModule(BaseModule): # Check if hybrid search is enabled enable_hybrid = self.config.get("enable_hybrid", False) - score_threshold = self.config.get("score_threshold", 0.3) + # Use provided score_threshold or fall back to config + search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3) if enable_hybrid and NLTK_AVAILABLE: # Perform hybrid search (vector + BM25) @@ -1484,7 +1543,7 @@ class RAGModule(BaseModule): query_vector=query_embedding, query_filter=search_filter, limit=max_results, - score_threshold=score_threshold + score_threshold=search_score_threshold ) else: # Pure vector search with improved threshold @@ -1493,7 +1552,7 @@ class RAGModule(BaseModule): query_vector=query_embedding, query_filter=search_filter, limit=max_results, - score_threshold=score_threshold + score_threshold=search_score_threshold ) logger.info(f"Raw search results count: {len(search_results)}") @@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_ """Index a processed document""" return await rag_module.index_processed_document(processed_doc, collection_name) -async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: +async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]: """Search documents""" - return await rag_module.search_documents(query, max_results, filters, collection_name) + return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold) async def delete_document(document_id: str, collection_name: str = None) -> bool: """Delete a document""" diff --git a/frontend/src/app/api/auth/login/route.ts b/frontend/src/app/api/auth/login/route.ts index c32f93e..fefb7fe 100644 --- a/frontend/src/app/api/auth/login/route.ts +++ b/frontend/src/app/api/auth/login/route.ts @@ -7,7 +7,7 @@ export async function POST(request: NextRequest) { // Make request to backend auth endpoint without requiring existing auth const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}` - const url = `${baseUrl}/api/auth/login` + const url = `${baseUrl}/api-internal/v1/auth/login` const response = await fetch(url, { method: 'POST', diff --git a/frontend/src/app/rag/page.tsx b/frontend/src/app/rag/page.tsx index 87616c1..48ae013 100644 --- a/frontend/src/app/rag/page.tsx +++ b/frontend/src/app/rag/page.tsx @@ -85,8 +85,31 @@ function RAGPageContent() { const loadStats = async () => { try { const data = await apiClient.get('/api-internal/v1/rag/stats') - setStats(data.stats) + console.log('Stats API response:', data) + + // Check if the response has the expected structure + if (data && data.stats && data.stats.collections) { + console.log('✓ Stats has collections property') + setStats(data.stats) + } else { + console.error('✗ Invalid stats structure:', data) + // Set default empty stats to prevent error + setStats({ + collections: { total: 0, active: 0 }, + documents: { total: 0, processing: 0, processed: 0 }, + storage: { total_size_bytes: 0, total_size_mb: 0 }, + vectors: { total: 0 } + }) + } } catch (error) { + console.error('Error loading stats:', error) + // Set default empty stats on error + setStats({ + collections: { total: 0, active: 0 }, + documents: { total: 0, processing: 0, processed: 0 }, + storage: { total_size_bytes: 0, total_size_mb: 0 }, + vectors: { total: 0 } + }) } } diff --git a/frontend/src/components/rag/document-browser.tsx b/frontend/src/components/rag/document-browser.tsx index c3e643f..2643e9c 100644 --- a/frontend/src/components/rag/document-browser.tsx +++ b/frontend/src/components/rag/document-browser.tsx @@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge" import { Separator } from "@/components/ui/separator" import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog" import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog" -import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react" +import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react" import { useToast } from "@/hooks/use-toast" import { apiClient } from "@/lib/api-client" import { config } from "@/lib/config" @@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS const [filterStatus, setFilterStatus] = useState("all") const [selectedDocument, setSelectedDocument] = useState(null) const [deleting, setDeleting] = useState(null) + const [reprocessing, setReprocessing] = useState(null) const { toast } = useToast() useEffect(() => { @@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS } } + const handleReprocessDocument = async (documentId: string) => { + setReprocessing(documentId) + + try { + await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`) + + // Update the document status to processing in the UI + setDocuments(prev => prev.map(doc => + doc.id === documentId + ? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() } + : doc + )) + + toast({ + title: "Success", + description: "Document reprocessing started", + }) + + // Reload documents after a short delay to see status updates + setTimeout(() => { + loadDocuments() + }, 2000) + + } catch (error) { + const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document" + toast({ + title: "Error", + description: errorMessage.includes("Cannot reprocess document with status 'processed'") + ? "Cannot reprocess documents that are already processed" + : errorMessage, + variant: "destructive", + }) + } finally { + setReprocessing(null) + } + } + const formatFileSize = (bytes: number) => { if (bytes === 0) return '0 Bytes' const k = 1024 @@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS + + + ))} + + + {/* Search Box */} +
+
+ setQuery(e.target.value)} + onKeyPress={(e) => e.key === 'Enter' && performSearch()} + placeholder="Enter your search query..." + className="flex-1 px-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500" + /> + +
+ + {error && ( +
+ Error: {error} +
+ )} + + {/* Results Summary */} + {results.length > 0 && ( +
+

+ Found {results.length} results in {searchTime.toFixed(0)}ms + {config.enable_hybrid && ( + • Hybrid Search Enabled + )} +

+
+ )} +
+ + {/* Search Results */} +
+ {results.map((result, index) => ( +
+
+

Result {index + 1}

+ = 0.5 ? 'bg-green-100 text-green-800' : + result.score >= 0.3 ? 'bg-yellow-100 text-yellow-800' : + 'bg-red-100 text-red-800' + }`}> + Score: {result.score.toFixed(4)} + +
+ +
+ {result.document.content} +
+ + {/* Metadata */} +
+ {result.document.metadata.content_type && ( + Type: {result.document.metadata.content_type} + )} + {result.document.metadata.language && ( + Language: {result.document.metadata.language} + )} + {result.document.metadata.filename && ( + File: {result.document.metadata.filename} + )} + {result.document.metadata.chunk_index !== undefined && ( + + Chunk: {result.document.metadata.chunk_index + 1}/{result.document.metadata.chunk_count || '?'} + + )} +
+ + {/* Debug Details */} + {config.show_timing && result.debug_info && ( +
+

Debug Information:

+ {result.debug_info.vector_score !== undefined && ( +

Vector Score: {result.debug_info.vector_score.toFixed(4)}

+ )} + {result.debug_info.bm25_score !== undefined && ( +

BM25 Score: {result.debug_info.bm25_score.toFixed(4)}

+ )} + {result.document.metadata.question && ( +
+

Question: {result.document.metadata.question}

+
+ )} +
+ )} +
+ ))} +
+ + {/* Debug Section */} + {debugInfo && Object.keys(debugInfo).length > 0 && ( +
+

Debug Information

+ + {debugInfo.score_stats && ( +
+

Score Statistics:

+
+
Min: {debugInfo.score_stats.min?.toFixed(4)}
+
Max: {debugInfo.score_stats.max?.toFixed(4)}
+
Avg: {debugInfo.score_stats.avg?.toFixed(4)}
+
StdDev: {debugInfo.score_stats.stddev?.toFixed(4)}
+
+
+ )} + + {debugInfo.collection_stats && ( +
+

Collection Stats:

+
+

Total Documents: {debugInfo.collection_stats.total_documents}

+

Total Chunks: {debugInfo.collection_stats.total_chunks}

+

Languages: {debugInfo.collection_stats.languages?.join(', ')}

+
+
+ )} + + {debugInfo.query_embedding && config.show_embeddings && ( +
+

Query Embedding (first 10 dims):

+

+ [{debugInfo.query_embedding.slice(0, 10).map(x => x.toFixed(6)).join(', ')}...] +

+
+ )} +
+ )} + + + {/* Configuration Panel */} +
+
+

⚙️ Configuration

+ +
+ {/* Search Settings */} +
+

Search Settings

+
+
+ + updateConfig('max_results', parseInt(e.target.value))} + className="w-full" + /> +
+
+ + updateConfig('score_threshold', parseFloat(e.target.value))} + className="w-full" + /> +
+
+ + {collectionsLoading ? ( + + ) : ( + + )} +
+
+
+ + {/* Chunking Settings */} +
+

Chunking Settings

+
+
+ + updateConfig('chunk_size', parseInt(e.target.value))} + className="w-full" + /> +
+
+ + updateConfig('chunk_overlap', parseInt(e.target.value))} + className="w-full" + /> +
+
+
+ + {/* Hybrid Search */} +
+

Hybrid Search

+
+ + {config.enable_hybrid && ( + <> +
+ + updateConfig('vector_weight', parseFloat(e.target.value))} + className="w-full" + /> +
+
+ + updateConfig('bm25_weight', parseFloat(e.target.value))} + className="w-full" + /> +
+ + )} +
+
+ + {/* Debug Options */} +
+

Debug Options

+
+ + +
+
+
+
+
+ + + ); +} \ No newline at end of file From f3f5cca50b05dc5f6ea79a90023602de63084f58 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Wed, 1 Oct 2025 15:50:34 +0200 Subject: [PATCH 12/13] fixing rag --- .gitignore | 1 + backend/Dockerfile | 3 + backend/app/modules/chatbot/main.py | 13 ++++ backend/modules/rag/main.py | 52 ++++++++++++++-- backend/requirements.txt | 6 +- backend/scripts/import_jsonl.py | 92 +++++++++++++++++++++++++++++ 6 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 backend/scripts/import_jsonl.py diff --git a/.gitignore b/.gitignore index 6642c56..4abb4a1 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ frontend/.env.development backend/storage/ + # TypeScript *.tsbuildinfo diff --git a/backend/Dockerfile b/backend/Dockerfile index aaa4fe6..0cb709e 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -17,6 +17,9 @@ RUN apt-get update && apt-get install -y \ ffmpeg \ && rm -rf /var/lib/apt/lists/* +# Install CPU-only PyTorch and compatible numpy first (faster download) +RUN pip install --no-cache-dir torch==2.5.1+cpu torchaudio==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/torch_stable.html + # Copy requirements and install Python dependencies COPY requirements.txt . COPY tests/requirements-test.txt ./tests/ diff --git a/backend/app/modules/chatbot/main.py b/backend/app/modules/chatbot/main.py index 3f9b8dc..96ae1b2 100644 --- a/backend/app/modules/chatbot/main.py +++ b/backend/app/modules/chatbot/main.py @@ -453,9 +453,22 @@ class ChatbotModule(BaseModule): guardrails += ( "When asked about encryption or SD-card backups, do not claim that backups are encrypted unless the provided context explicitly uses wording like 'encrypt', 'encrypted', or 'encryption'. " "If such wording is absent, state clearly that the SD-card backup is not encrypted. " + "Product policy: For BitBox devices, microSD (SD card) backups are not encrypted; verification steps may require a recovery password, but that is not encryption. Do not conflate password entry with encryption. " ) extra_instructions["additional_instructions"] = guardrails + # Deterministic enforcement: if encryption question and RAG context does not explicitly + # contain encryption wording, return policy answer without calling the LLM. + ctx_lower = (rag_context or "").lower() + has_encryption_terms = any(k in ctx_lower for k in ["encrypt", "encrypted", "encryption", "decrypt", "decryption"]) + if is_encryption and not has_encryption_terms: + policy_answer = ( + "No. BitBox microSD (SD card) backups are not encrypted. " + "Verification may require entering a recovery password, but that does not encrypt the backup — " + "it only proves you have the correct credentials to restore. Keep the card and password secure." + ) + return policy_answer, sources + messages = self._build_conversation_messages(db_messages, config, rag_context, extra_instructions) # Note: Current user message is already included in db_messages from the query diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index d56503c..92f43f6 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -1495,8 +1495,16 @@ class RAGModule(BaseModule): """Search for relevant documents""" if not self.enabled: raise RuntimeError("RAG module not initialized") - + collection_name = collection_name or self.default_collection_name + + # Special handling for collections with different vector dimensions + SPECIAL_COLLECTIONS = { + "bitbox02_faq_local": { + "dimension": 384, + "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + } + } max_results = max_results or self.config.get("max_results", 10) # Check cache (include collection name in cache key) @@ -1510,9 +1518,24 @@ class RAGModule(BaseModule): start_time = time.time() # Generate query embedding with task-specific prefix for better retrieval - # The E5 model works better with "query:" prefix for search queries - optimized_query = f"query: {query}" - query_embedding = await self._generate_embedding(optimized_query) + try: + # Check if this is a special collection + if collection_name in SPECIAL_COLLECTIONS: + # Try to import sentence-transformers + import sentence_transformers + from sentence_transformers import SentenceTransformer + model = SentenceTransformer(SPECIAL_COLLECTIONS[collection_name]["model"]) + query_embedding = model.encode([query], normalize_embeddings=True)[0].tolist() + logger.info(f"Using {SPECIAL_COLLECTIONS[collection_name]['dimension']}-dim local model for {collection_name}") + else: + # The E5 model works better with "query:" prefix for search queries + optimized_query = f"query: {query}" + query_embedding = await self._generate_embedding(optimized_query) + except ImportError: + # Fallback to default embedding if sentence-transformers is not available + logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}") + optimized_query = f"query: {query}" + query_embedding = await self._generate_embedding(optimized_query) # Build filter search_filter = None @@ -1565,14 +1588,31 @@ class RAGModule(BaseModule): doc_id = result.payload.get("document_id") content = result.payload.get("content", "") score = result.score - + + # Generic content extraction for documents without a 'content' field + if not content: + # Build content from all text-based fields in the payload + # This makes the RAG module completely agnostic to document structure + text_fields = [] + for field, value in result.payload.items(): + # Skip system/metadata fields + if field not in ["document_id", "chunk_index", "chunk_count", "indexed_at", "processed_at", + "file_hash", "mime_type", "file_type", "created_at", "__collection_metadata__"]: + # Include any field that has a non-empty string value + if value and isinstance(value, str) and len(value.strip()) > 0: + text_fields.append(f"{field}: {value}") + + # Join all text fields to create content + if text_fields: + content = "\n\n".join(text_fields) + # Log each raw result for debugging logger.info(f"\n--- Raw Result {i+1} ---") logger.info(f"Score: {score}") logger.info(f"Document ID: {doc_id}") logger.info(f"Content preview (first 200 chars): {content[:200]}") logger.info(f"Metadata keys: {list(result.payload.keys())}") - + # Aggregate scores by document if doc_id in document_scores: document_scores[doc_id]["score"] = max(document_scores[doc_id]["score"], score) diff --git a/backend/requirements.txt b/backend/requirements.txt index c4ec167..b8fd274 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -46,6 +46,7 @@ qdrant-client==1.7.0 # Text Processing tiktoken==0.5.1 +numpy>=1.26.0 # Basic document processing (lightweight) markitdown==0.0.1a2 @@ -56,8 +57,9 @@ python-docx==1.1.0 # nltk==3.8.1 # spacy==3.7.2 -# Heavy ML dependencies (REMOVED - unused in codebase) -# sentence-transformers==2.6.1 # REMOVED - not used anywhere in codebase +# Heavy ML dependencies (sentence-transformers will be installed separately) +# Note: PyTorch is already installed in the base Docker image +sentence-transformers==2.6.1 # Added back - needed for bitbox02_faq_local collection # transformers==4.35.2 # REMOVED - already commented out # Configuration diff --git a/backend/scripts/import_jsonl.py b/backend/scripts/import_jsonl.py new file mode 100644 index 0000000..a932883 --- /dev/null +++ b/backend/scripts/import_jsonl.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +""" +Import a JSONL file into a Qdrant collection from inside the backend container. + +Usage (from host): + docker compose exec enclava-backend bash -lc \ + 'python /app/scripts/import_jsonl.py \ + --collection rag_test_import_859b1f01 \ + --file /app/_to_delete/helpjuice-export.jsonl' + +Notes: + - Runs fully inside the backend, so Docker service hostnames (e.g. enclava-qdrant) + and privatemode-proxy are reachable. + - Uses RAGModule + JSONLProcessor to embed/index each JSONL line. + - Creates the collection if missing (size=1024, cosine). +""" + +import argparse +import asyncio +import os +from datetime import datetime + + +async def import_jsonl(collection_name: str, file_path: str): + from qdrant_client import QdrantClient + from qdrant_client.models import Distance, VectorParams + from app.modules.rag.main import RAGModule + from app.services.jsonl_processor import JSONLProcessor + from app.core.config import settings + + if not os.path.exists(file_path): + raise SystemExit(f"File not found: {file_path}") + + # Ensure collection exists (inside container uses Docker DNS hostnames) + client = QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT) + collections = client.get_collections().collections + if not any(c.name == collection_name for c in collections): + client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + ) + print(f"Created Qdrant collection '{collection_name}' (size=1024, cosine)") + else: + print(f"Using existing Qdrant collection '{collection_name}'") + + # Initialize RAG + rag = RAGModule({ + "chunk_size": 300, + "chunk_overlap": 50, + "max_results": 10, + "score_threshold": 0.3, + "embedding_model": "intfloat/multilingual-e5-large-instruct", + }) + await rag.initialize() + + # Process JSONL + processor = JSONLProcessor(rag) + with open(file_path, "rb") as f: + content = f.read() + + doc_id = await processor.process_and_index_jsonl( + collection_name=collection_name, + content=content, + filename=os.path.basename(file_path), + metadata={ + "source": "jsonl_upload", + "upload_date": datetime.utcnow().isoformat(), + "file_path": os.path.abspath(file_path), + }, + ) + + # Report stats using safe HTTP method to avoid client parsing issues + try: + info = await rag._get_collection_info_safely(collection_name) + print(f"Import complete. Points: {info.get('points_count', 0)}, vector_size: {info.get('vector_size', 'n/a')}") + except Exception as e: + print(f"Import complete. (Could not fetch collection info safely: {e})") + await rag.cleanup() + return doc_id + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--collection", required=True, help="Qdrant collection name") + ap.add_argument("--file", required=True, help="Path inside container (e.g. /app/_to_delete/...).") + args = ap.parse_args() + + asyncio.run(import_jsonl(args.collection, args.file)) + + +if __name__ == "__main__": + main() From 8391dd5170e29f446b252b546c1726c979388cef Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Wed, 1 Oct 2025 17:05:04 +0200 Subject: [PATCH 13/13] vector size test --- backend/modules/rag/main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index 92f43f6..2672c44 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -1503,6 +1503,10 @@ class RAGModule(BaseModule): "bitbox02_faq_local": { "dimension": 384, "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + }, + "bitbox_local_rag": { + "dimension": 384, + "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" } } max_results = max_results or self.config.get("max_results", 10)