mirror of https://github.com/aljazceru/enclava.git
working chatbot, rag weird
@@ -16,6 +16,7 @@ from .models import (
     ModelInfo, ProviderStatus, LLMMetrics
 )
 from .config import config_manager, ProviderConfig
+from ...core.config import settings
 from .security import security_manager
 from .resilience import ResilienceManagerFactory
 from .metrics import metrics_collector
@@ -149,19 +150,17 @@ class LLMService:
         if not request.messages:
             raise ValidationError("Messages cannot be empty", field="messages")

-        # Security validation
-        # Chatbot and RAG system requests should have relaxed security validation
-        is_system_request = (
-            request.user_id == "rag_system" or
-            request.user_id == "chatbot_user" or
-            str(request.user_id).startswith("chatbot_")
-        )
-
+        # Security validation (only if enabled)
         messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-        is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
-
-        if not is_safe and not is_system_request:
-            # Log security violation for regular user requests
+
+        if settings.API_SECURITY_ENABLED:
+            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
+        else:
+            # Security disabled - always safe
+            is_safe, risk_score, detected_patterns = True, 0.0, []
+
+        if not is_safe:
+            # Log security violation
             security_manager.create_audit_log(
                 user_id=request.user_id,
                 api_key_id=request.api_key_id,
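For illustration, a minimal runnable sketch of the flag-gated validation this hunk converges on. Settings, validate_prompt_security, and SecurityError below are simplified stand-ins for the real enclava modules (assumptions, not the project's API); only the control flow and the True, 0.0, [] fallback triple come from the diff.

# Stand-in for app settings; the real flag is settings.API_SECURITY_ENABLED.
class Settings:
    API_SECURITY_ENABLED = True

settings = Settings()

class SecurityError(Exception):
    """Simplified stand-in for the service's SecurityError."""
    def __init__(self, message, risk_score=0.0, details=None):
        super().__init__(message)
        self.risk_score = risk_score
        self.details = details or {}

def validate_prompt_security(messages):
    # Toy validator: flag an obvious prompt-injection marker.
    risky = any("ignore previous instructions" in m["content"].lower() for m in messages)
    patterns = [{"pattern": "ignore previous instructions"}] if risky else []
    return (not risky, 0.9 if risky else 0.0, patterns)

def check_request(messages):
    # Same shape as the diff: validate only when the flag is on,
    # otherwise fall back to the always-safe sentinel triple.
    if settings.API_SECURITY_ENABLED:
        is_safe, risk_score, detected_patterns = validate_prompt_security(messages)
    else:
        # Security disabled - always safe
        is_safe, risk_score, detected_patterns = True, 0.0, []

    if not is_safe:
        raise SecurityError(
            "Request blocked due to security concerns",
            risk_score=risk_score,
            details={"detected_patterns": detected_patterns},
        )

check_request([{"role": "user", "content": "hello"}])  # passes silently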
@@ -171,7 +170,7 @@ class LLMService:
                 risk_score=risk_score,
                 detected_patterns=[p.get("pattern", "") for p in detected_patterns]
             )
-
+
             # Record blocked request
             metrics_collector.record_request(
                 provider="security",
@@ -184,18 +183,12 @@ class LLMService:
                 user_id=request.user_id,
                 api_key_id=request.api_key_id
             )
-
+
             raise SecurityError(
                 "Request blocked due to security concerns",
                 risk_score=risk_score,
                 details={"detected_patterns": detected_patterns}
             )
-        elif not is_safe and is_system_request:
-            # For system requests (chatbot/RAG), log but don't block
-            logger.info(f"System request contains security patterns (risk_score={risk_score:.2f}) but allowing due to system context")
-            if detected_patterns:
-                logger.info(f"Detected patterns: {[p.get('pattern', 'unknown') for p in detected_patterns]}")
-            # Allow system requests regardless of security patterns

         # Get provider for model
         provider_name = self._get_provider_for_model(request.model)
@@ -317,25 +310,20 @@ class LLMService:
         await self.initialize()

         # Security validation (same as non-streaming)
-        # Chatbot and RAG system requests should have relaxed security validation
-        is_system_request = (
-            request.user_id == "rag_system" or
-            request.user_id == "chatbot_user" or
-            str(request.user_id).startswith("chatbot_")
-        )
-
         messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-        is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
-
-        if not is_safe and not is_system_request:
+
+        if settings.API_SECURITY_ENABLED:
+            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
+        else:
+            # Security disabled - always safe
+            is_safe, risk_score, detected_patterns = True, 0.0, []
+
+        if not is_safe:
             raise SecurityError(
                 "Streaming request blocked due to security concerns",
                 risk_score=risk_score,
                 details={"detected_patterns": detected_patterns}
             )
-        elif not is_safe and is_system_request:
-            # For system requests (chatbot/RAG), log but don't block
-            logger.info(f"System streaming request contains security patterns (risk_score={risk_score:.2f}) but allowing due to system context")

         # Get provider
         provider_name = self._get_provider_for_model(request.model)
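The streaming hunk keeps the check ahead of the first yielded chunk, so a blocked request raises immediately rather than failing mid-stream. A hedged sketch of that ordering with a stubbed provider; everything here except the validate-before-yield ordering is an assumption.

import asyncio

class SecurityError(Exception):
    pass

async def fake_provider_stream(messages):
    # Stub for the real provider's streaming response.
    for chunk in ["Hel", "lo"]:
        yield chunk

async def create_chat_stream(messages, security_enabled=True):
    # Validation runs once, before any chunk is produced.
    if security_enabled:
        is_safe = "ignore previous instructions" not in messages[-1]["content"].lower()
    else:
        is_safe = True  # security disabled - always safe
    if not is_safe:
        raise SecurityError("Streaming request blocked due to security concerns")
    async for chunk in fake_provider_stream(messages):
        yield chunk

async def main():
    async for chunk in create_chat_stream([{"role": "user", "content": "hi"}]):
        print(chunk, end="")

asyncio.run(main())  # prints "Hello"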
@@ -378,33 +366,22 @@ class LLMService:
         await self.initialize()

         # Security validation for embedding input
-        # RAG system requests (document embedding) should use relaxed security validation
-        is_rag_system = request.user_id == "rag_system"
-
-        if not is_rag_system:
-            # Apply normal security validation for user-generated embedding requests
-            input_text = request.input if isinstance(request.input, str) else " ".join(request.input)
+        input_text = request.input if isinstance(request.input, str) else " ".join(request.input)
+
+        if settings.API_SECURITY_ENABLED:
             is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
                 {"role": "user", "content": input_text}
             ])
-
-            if not is_safe:
-                raise SecurityError(
-                    "Embedding request blocked due to security concerns",
-                    risk_score=risk_score,
-                    details={"detected_patterns": detected_patterns}
-                )
         else:
-            # For RAG system requests, log but don't block (document content can contain legitimate text that triggers patterns)
-            input_text = request.input if isinstance(request.input, str) else " ".join(request.input)
-            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
-                {"role": "user", "content": input_text}
-            ])
-
-            if detected_patterns:
-                logger.info(f"RAG document embedding contains security patterns (risk_score={risk_score:.2f}) but allowing due to document context")
-
-            # Allow RAG system requests regardless of security patterns
+            # Security disabled - always safe
+            is_safe, risk_score, detected_patterns = True, 0.0, []
+
+        if not is_safe:
+            raise SecurityError(
+                "Embedding request blocked due to security concerns",
+                risk_score=risk_score,
+                details={"detected_patterns": detected_patterns}
+            )

         # Get provider
         provider_name = self._get_provider_for_model(request.model)
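The embedding hunk normalizes the request input, which may be a single string or a list of strings, into one text before validating it. A small sketch of just that normalization; the helper name is invented for illustration, while the expression mirrors the diff.

from typing import List, Union

def normalize_embedding_input(input_data: Union[str, List[str]]) -> str:
    # Mirrors the diff's expression:
    # request.input if isinstance(request.input, str) else " ".join(request.input)
    return input_data if isinstance(input_data, str) else " ".join(input_data)

assert normalize_embedding_input("one text") == "one text"
assert normalize_embedding_input(["chunk a", "chunk b"]) == "chunk a chunk b"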