swapping to local embeddings

2025-10-06 17:28:55 +02:00
parent bae86fb5a2
commit 487d7d0cc9
8 changed files with 398 additions and 319 deletions

View File

@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
import io
import asyncio
@@ -15,6 +16,7 @@ from datetime import datetime
from app.db.database import get_db
from app.core.security import get_current_user
from app.models.user import User
from app.models.rag_collection import RagCollection
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException
@@ -268,7 +270,19 @@ async def get_documents(
try:
collection_id_int = int(collection_id)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="Invalid collection_id format")
# Attempt to resolve by Qdrant collection name
collection_row = await db.scalar(
select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
)
if collection_row:
collection_id_int = collection_row.id
else:
# Unknown collection identifier; return empty result instead of erroring out
return {
"success": True,
"documents": [],
"total": 0
}
rag_service = RAGService(db)
documents = await rag_service.get_documents(
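For context, the endpoint now tolerates both identifier forms instead of rejecting non-numeric values; a minimal sketch of the resolution rule, assuming a hypothetical Qdrant collection name:

# Hedged sketch; "kb_articles" and resolve_collection_id are illustrative, not part of the diff.
async def resolve_collection_id(db, collection_id: str):
    try:
        return int(collection_id)                      # "42" -> 42
    except (ValueError, TypeError):
        row = await db.scalar(
            select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
        )
        return row.id if row else None                 # None -> endpoint returns an empty document list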

View File

@@ -129,6 +129,7 @@ class Settings(BaseSettings):
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
@@ -154,4 +155,3 @@ class Settings(BaseSettings):
# Global settings instance
settings = Settings()
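For reference, a minimal sketch of how the new setting is picked up; the variable name comes from the diff above, the override value is just the default:

import os

# Must be set before app.core.config is first imported, since defaults are read via os.getenv
os.environ["RAG_EMBEDDING_MODEL"] = "BAAI/bge-small-en"   # any sentence-transformers model name should work here

from app.core.config import settings
assert settings.RAG_EMBEDDING_MODEL == "BAAI/bge-small-en"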

View File

@@ -8,6 +8,7 @@ import json
import logging
import mimetypes
import re
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
from dataclasses import dataclass, asdict
@@ -146,6 +147,11 @@ class RAGModule(BaseModule):
# Update with any provided config
if config:
self.config.update(config)
# Ensure embedding model configured (defaults to local BGE small)
default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
self.config.setdefault("embedding_model", default_embedding_model)
self.default_embedding_model = default_embedding_model
# Content processing components
self.nlp_model = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types)
}
self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
self.initialized = True
log_module_event("rag", "initialized", {
"vector_db": self.config.get("vector_db", "qdrant"),
"embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
"embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
"chunk_size": self.config.get("chunk_size", 400),
"max_results": self.config.get("max_results", 10),
"supported_file_types": list(self.supported_types.keys()),
@@ -427,8 +434,7 @@ class RAGModule(BaseModule):
# Prefer enhanced embedding service (rate limiting + retry)
from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service
# Use intfloat/multilingual-e5-large-instruct for LLM service integration
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
model_name = self.config.get("embedding_model", self.default_embedding_model)
embedding_service.model_name = model_name
# Initialize the embedding service
@@ -439,7 +445,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}")
return {
"model_name": model_name,
"dimension": embedding_service.dimension or 768
"dimension": embedding_service.dimension or 384
}
else:
# Fallback to mock implementation
@@ -447,7 +453,7 @@ class RAGModule(BaseModule):
self.embedding_service = None
return {
"model_name": model_name,
"dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
"dimension": 384 # Default dimension matching local bge-small embeddings
}
async def _initialize_content_processing(self):
@@ -588,16 +594,37 @@ class RAGModule(BaseModule):
collection_names = await self._get_collections_safely()
if collection_name not in collection_names:
# Create collection
# Create collection with the current embedding dimension
vector_dimension = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)
self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=self.embedding_model.get("dimension", 768),
size=vector_dimension,
distance=Distance.COSINE
)
)
self.collection_vector_sizes[collection_name] = vector_dimension
log_module_event("rag", "collection_created", {"collection": collection_name})
else:
# Cache existing collection vector size for later alignment
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")
if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
except Exception as inner_error:
logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
raise
@@ -632,12 +659,13 @@ class RAGModule(BaseModule):
async def _generate_embedding(self, text: str) -> List[float]:
"""Generate embedding for text"""
if self.embedding_service:
# Use real embedding service
return await self.embedding_service.get_embedding(text)
vector = await self.embedding_service.get_embedding(text)
return vector
else:
# Fallback to deterministic random embedding for consistency
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
return np.random.random(fallback_dim).tolist()
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
@@ -651,14 +679,90 @@ class RAGModule(BaseModule):
prefixed_texts = texts
# Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(prefixed_texts)
backend = getattr(self.embedding_service, "backend", "unknown")
start_time = time.time()
logger.info(
"Embedding batch requested",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(prefixed_texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
duration = time.time() - start_time
logger.info(
"Embedding batch finished",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(embeddings),
"scope": "documents" if is_document else "queries",
"duration_sec": round(duration, 4)
},
)
return embeddings
else:
# Fallback to individual processing
logger.warning(
"Embedding service unavailable, falling back to per-item generation",
extra={
"count": len(texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = []
for text in texts:
embedding = await self._generate_embedding(text)
embeddings.append(embedding)
return embeddings
def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
"""Return the expected vector size for a collection, caching results."""
default_dim = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)
if not collection_name:
return default_dim
if collection_name in self.collection_vector_sizes:
return self.collection_vector_sizes[collection_name]
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")
if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
return existing_size
except Exception as e:
logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
self.collection_vector_sizes[collection_name] = default_dim
return default_dim
def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
"""Pad or truncate embeddings to match the target collection dimension."""
if vector is None:
return vector
target_dim = self._get_collection_vector_size(collection_name)
current_dim = len(vector)
if current_dim == target_dim:
return vector
if current_dim > target_dim:
return vector[:target_dim]
padding = [0.0] * (target_dim - current_dim)
return vector + padding
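A quick illustration of the pad/truncate behaviour above (standalone sketch; rag_module is a hypothetical initialized instance, and the 1024-dimension collection stands in for one created by the earlier e5-large setup):

vec_384 = [0.1] * 384                                               # bge-small output size
aligned = rag_module._align_embedding_dimension(vec_384, "legacy_e5_collection")
# If that collection was created with 1024-d vectors, the embedding is zero-padded:
assert len(aligned) == 1024 and aligned[384:] == [0.0] * (1024 - 384)
# A vector wider than the collection target would be truncated instead.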
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
"""Split text into overlapping chunks for better context preservation"""
@@ -1176,6 +1280,7 @@ class RAGModule(BaseModule):
# Create document points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())
chunk_metadata = {
@@ -1189,7 +1294,7 @@ class RAGModule(BaseModule):
points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))
@@ -1257,6 +1362,7 @@ class RAGModule(BaseModule):
# Create document points with enhanced metadata
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())
chunk_metadata = {
@@ -1280,7 +1386,7 @@ class RAGModule(BaseModule):
points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))
@@ -1514,9 +1620,9 @@ class RAGModule(BaseModule):
start_time = time.time()
# Generate query embedding with task-specific prefix for better retrieval
# The E5 model works better with "query:" prefix for search queries
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
query_embedding = self._align_embedding_dimension(query_embedding, collection_name)
# Build filter
search_filter = None

View File

@@ -1,148 +1,133 @@
"""
Embedding Service
Provides text embedding functionality using LLM service
Provides local sentence-transformer embeddings (default: BAAI/bge-small-en).
Falls back to deterministic random vectors when the local model is unavailable.
"""
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from app.core.config import settings
logger = logging.getLogger(__name__)
class EmbeddingService:
"""Service for generating text embeddings using LLM service"""
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
self.model_name = model_name
self.dimension = 1024 # Actual dimension for intfloat/multilingual-e5-large-instruct
"""Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
self.dimension = 384 # bge-small produces 384-d vectors
self.initialized = False
self.local_model = None
self.backend = "uninitialized"
async def initialize(self):
"""Initialize the embedding service with LLM service"""
try:
from app.services.llm.service import llm_service
# Initialize LLM service if not already done
if not llm_service._initialized:
await llm_service.initialize()
# Test LLM service health
if not llm_service._initialized:
logger.error("LLM service not initialized")
return False
from sentence_transformers import SentenceTransformer
# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False
loop = asyncio.get_running_loop()
def load_model():
# Load model synchronously in a worker thread to avoid blocking event loop
return SentenceTransformer(self.model_name)
self.local_model = await loop.run_in_executor(None, load_model)
self.dimension = self.local_model.get_sentence_embedding_dimension()
self.initialized = True
logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
self.backend = "sentence_transformer"
logger.info(
"Embedding service initialized with local model %s (dimension: %s)",
self.model_name,
self.dimension,
)
return True
except Exception as e:
logger.error(f"Failed to initialize LLM embedding service: {e}")
logger.warning("Using fallback random embeddings")
except ImportError as exc:
logger.error("sentence-transformers not installed: %s", exc)
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding for a single text"""
embeddings = await self.get_embeddings([text])
return embeddings[0]
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using LLM service"""
if not self.initialized:
# Fallback to random embeddings if not initialized
logger.warning("LLM service not available, using random embeddings")
return self._generate_fallback_embeddings(texts)
try:
embeddings = []
# Process texts in batches for efficiency
batch_size = 10
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
# Process each text in the batch
batch_embeddings = []
for text in batch:
try:
# Truncate text if it's too long for the model's context window
# intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars
# Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
max_chars = 1600
if len(text) > max_chars:
truncated_text = text[:max_chars]
logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
else:
truncated_text = text
# Guard: skip empty inputs (validator rejects empty strings)
if not truncated_text.strip():
logger.debug("Empty input for embedding; using fallback vector")
batch_embeddings.append(self._generate_fallback_embedding(text))
continue
start_time = time.time()
# Call LLM service embedding endpoint
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
llm_request = EmbeddingRequest(
model=self.model_name,
input=truncated_text,
user_id="rag_system",
api_key_id=0 # System API key
)
response = await llm_service.create_embedding(llm_request)
# Extract embedding from response
if response.data and len(response.data) > 0:
embedding = response.data[0].embedding
if embedding:
batch_embeddings.append(embedding)
# Update dimension based on actual embedding size
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
logger.info(f"Confirmed embedding dimension: {self.dimension}")
else:
logger.warning(f"No embedding in response for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
else:
logger.warning(f"Invalid response structure for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
logger.error(f"Error getting embedding for text: {e}")
batch_embeddings.append(self._generate_fallback_embedding(text))
embeddings.extend(batch_embeddings)
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings with LLM service: {e}")
# Fallback to random embeddings
return self._generate_fallback_embeddings(texts)
start_time = time.time()
if self.local_model:
if not texts:
return []
loop = asyncio.get_running_loop()
try:
embeddings = await loop.run_in_executor(
None,
lambda: self.local_model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=True,
),
)
duration = time.time() - start_time
logger.info(
"Embedding batch completed",
extra={
"backend": self.backend,
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration, 4),
},
)
return embeddings.tolist()
except Exception as exc:
logger.error(f"Local embedding generation failed: {exc}")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
logger.warning("Local embedding model unavailable; using fallback random embeddings")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
def _generate_fallback_embeddings(self, texts: List[str], duration: Optional[float] = None) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
embeddings.append(self._generate_fallback_embedding(text))
logger.info(
"Embedding batch completed",
extra={
"backend": "fallback_random",
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration or 0.0, 4),
},
)
return embeddings
def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension or 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
dimension = self.dimension or 384
# Use hash for reproducible random embeddings
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
@@ -169,22 +154,15 @@ class EmbeddingService:
"model_name": self.model_name,
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": "LLM Service",
"backend": self.backend,
"initialized": self.initialized
}
async def cleanup(self):
"""Cleanup resources"""
# Cleanup LLM service to prevent memory leaks
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")
self.local_model = None
self.initialized = False
self.backend = "uninitialized"
# Global embedding service instance
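A minimal usage sketch of the rewritten service (assumes sentence-transformers is installed and the BGE model is downloadable or already cached):

import asyncio
from app.services.embedding_service import embedding_service   # module-level instance referenced elsewhere in the diff

async def demo():
    ok = await embedding_service.initialize()                   # loads the model in a worker thread
    vectors = await embedding_service.get_embeddings(["fully local embedding, no LLM call"])
    print(ok, embedding_service.backend, len(vectors[0]))       # expect: True sentence_transformer 384

asyncio.run(demo())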

View File

@@ -1,14 +1,13 @@
# Enhanced Embedding Service with Rate Limiting Handling
"""
Enhanced embedding service with robust rate limiting and retry logic
Enhanced embedding service that adds basic retry semantics around the local
embedding generator. Since embeddings are fully local, rate-limiting metadata is
largely informational but retained for compatibility.
"""
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from .embedding_service import EmbeddingService
from app.core.config import settings
@@ -17,9 +16,9 @@ logger = logging.getLogger(__name__)
class EnhancedEmbeddingService(EmbeddingService):
"""Enhanced embedding service with rate limiting handling"""
"""Enhanced embedding service with lightweight retry bookkeeping"""
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
def __init__(self, model_name: Optional[str] = None):
super().__init__(model_name)
self.rate_limit_tracker = {
'requests_count': 0,
@@ -34,161 +33,16 @@ class EnhancedEmbeddingService(EmbeddingService):
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
"""
Get embeddings with rate limiting and retry logic
Get embeddings with retry bookkeeping.
"""
if max_retries is None:
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
if not self.initialized:
logger.warning("Embedding service not initialized, using fallback")
return self._generate_fallback_embeddings(texts), False
embeddings = []
success = True
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
embeddings.extend(batch_embeddings)
success = success and batch_success
# Add delay between batches to avoid rate limiting
if i + batch_size < len(texts):
delay = self.rate_limit_tracker['delay_between_batches']
await asyncio.sleep(delay) # Configurable delay between batches
return embeddings, success
async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
"""Get embeddings for a batch with retry logic"""
last_error = None
for attempt in range(max_retries + 1):
try:
# Check rate limit before making request
if self._is_rate_limited():
delay = self._get_rate_limit_delay()
logger.warning(f"Rate limit detected, waiting {delay} seconds")
await asyncio.sleep(delay)
continue
# Make the request
embeddings = await self._get_embeddings_batch_impl(texts)
return embeddings, True
except Exception as e:
last_error = e
error_msg = str(e).lower()
# Check if it's a rate limit error
if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
self._update_rate_limit_tracker(success=False)
if attempt < max_retries:
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
continue
else:
logger.error(f"Max retries exceeded for rate limit, using fallback embeddings")
return self._generate_fallback_embeddings(texts), False
else:
# Non-rate-limit error
logger.error(f"Error generating embeddings: {e}")
if attempt < max_retries:
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
await asyncio.sleep(delay)
else:
logger.error("Max retries exceeded, using fallback embeddings")
return self._generate_fallback_embeddings(texts), False
# If we get here, all retries failed
logger.error(f"All retries failed, last error: {last_error}")
return self._generate_fallback_embeddings(texts), False
async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
"""Implementation of getting embeddings for a batch"""
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
embeddings = []
for text in texts:
# Respect rate limit before each request
while self._is_rate_limited():
delay = self._get_rate_limit_delay()
logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
await asyncio.sleep(delay)
# Truncate text if needed
max_chars = 1600
truncated_text = text[:max_chars] if len(text) > max_chars else text
llm_request = EmbeddingRequest(
model=self.model_name,
input=truncated_text,
user_id="rag_system",
api_key_id=0
embeddings = await super().get_embeddings(texts)
success = self.local_model is not None
if not success:
logger.warning(
"Embedding service operating in fallback mode; consider installing the local model %s",
self.model_name,
)
response = await llm_service.create_embedding(llm_request)
if response.data and len(response.data) > 0:
embedding = response.data[0].embedding
if embedding:
embeddings.append(embedding)
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
else:
raise ValueError("Empty embedding in response")
else:
raise ValueError("Invalid response structure")
# Count this successful request and optionally delay between requests
self._update_rate_limit_tracker(success=True)
per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
if per_req_delay and per_req_delay > 0:
await asyncio.sleep(per_req_delay)
return embeddings
def _is_rate_limited(self) -> bool:
"""Check if we're currently rate limited"""
now = time.time()
window_start = self.rate_limit_tracker['window_start']
# Reset window if it's expired
if now - window_start > self.rate_limit_tracker['window_size']:
self.rate_limit_tracker['requests_count'] = 0
self.rate_limit_tracker['window_start'] = now
return False
# Check if we've exceeded the limit
return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']
def _get_rate_limit_delay(self) -> float:
"""Get delay to wait for rate limit reset"""
now = time.time()
window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
return max(0, window_end - now)
def _update_rate_limit_tracker(self, success: bool):
"""Update the rate limit tracker"""
now = time.time()
# Reset window if it's expired
if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
self.rate_limit_tracker['requests_count'] = 0
self.rate_limit_tracker['window_start'] = now
# Increment counter on successful requests
if success:
self.rate_limit_tracker['requests_count'] += 1
return embeddings, success
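A short caller-side sketch of the simplified retry wrapper (chunk_texts and logger are hypothetical; enhanced_embedding_service is the module-level instance imported by the RAG module above):

embeddings, ok = await enhanced_embedding_service.get_embeddings_with_retry(chunk_texts)
if not ok:
    # success is False when the local model is unavailable and the deterministic
    # random fallback produced the vectors instead
    logger.warning("Embeddings came from the fallback path; check that %s is installed",
                   enhanced_embedding_service.model_name)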
async def get_embedding_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics including rate limiting info"""

View File

@@ -566,10 +566,14 @@ class RAGService:
logger.warning(f"Could not check existing collections: {e}")
# Create collection with proper vector configuration
from app.services.embedding_service import embedding_service
vector_dimension = getattr(embedding_service, 'dimension', 384) or 384
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=1024, # Updated for multilingual-e5-large-instruct model
size=vector_dimension,
distance=Distance.COSINE
),
optimizers_config=models.OptimizersConfigDiff(

View File

@@ -8,6 +8,7 @@ import json
import logging
import mimetypes
import re
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
from dataclasses import dataclass, asdict
@@ -146,6 +147,11 @@ class RAGModule(BaseModule):
# Update with any provided config
if config:
self.config.update(config)
# Ensure embedding model configured (defaults to local BGE small)
default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
self.config.setdefault("embedding_model", default_embedding_model)
self.default_embedding_model = default_embedding_model
# Content processing components
self.nlp_model = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types)
}
self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
self.initialized = True
log_module_event("rag", "initialized", {
"vector_db": self.config.get("vector_db", "qdrant"),
"embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
"embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
"chunk_size": self.config.get("chunk_size", 400),
"max_results": self.config.get("max_results", 10),
"supported_file_types": list(self.supported_types.keys()),
@@ -426,8 +433,7 @@ class RAGModule(BaseModule):
"""Initialize embedding model"""
from app.services.embedding_service import embedding_service
# Use intfloat/multilingual-e5-large-instruct for LLM service integration
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
model_name = self.config.get("embedding_model", self.default_embedding_model)
embedding_service.model_name = model_name
# Initialize the embedding service
@@ -438,7 +444,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}")
return {
"model_name": model_name,
"dimension": embedding_service.dimension or 768
"dimension": embedding_service.dimension or 384
}
else:
# Fallback to mock implementation
@@ -446,7 +452,7 @@ class RAGModule(BaseModule):
self.embedding_service = None
return {
"model_name": model_name,
"dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
"dimension": 384 # Default dimension matching local bge-small embeddings
}
async def _initialize_content_processing(self):
@@ -585,18 +591,39 @@ class RAGModule(BaseModule):
try:
# Use safe collection fetching to avoid Pydantic validation errors
collection_names = await self._get_collections_safely()
if collection_name not in collection_names:
# Create collection
# Create collection with the current embedding dimension
vector_dimension = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)
self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=self.embedding_model.get("dimension", 768),
size=vector_dimension,
distance=Distance.COSINE
)
)
self.collection_vector_sizes[collection_name] = vector_dimension
log_module_event("rag", "collection_created", {"collection": collection_name})
else:
# Cache existing collection vector size for later alignment
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")
if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
except Exception as inner_error:
logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
raise
@@ -632,11 +659,13 @@ class RAGModule(BaseModule):
"""Generate embedding for text"""
if self.embedding_service:
# Use real embedding service
return await self.embedding_service.get_embedding(text)
vector = await self.embedding_service.get_embedding(text)
return vector
else:
# Fallback to deterministic random embedding for consistency
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
return np.random.random(fallback_dim).tolist()
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
@@ -650,14 +679,91 @@ class RAGModule(BaseModule):
prefixed_texts = texts
# Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(prefixed_texts)
backend = getattr(self.embedding_service, "backend", "unknown")
start_time = time.time()
logger.info(
"Embedding batch requested",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(prefixed_texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
duration = time.time() - start_time
logger.info(
"Embedding batch finished",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(embeddings),
"scope": "documents" if is_document else "queries",
"duration_sec": round(duration, 4)
},
)
return embeddings
else:
# Fallback to individual processing
logger.warning(
"Embedding service unavailable, falling back to per-item generation",
extra={
"count": len(texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = []
for text in texts:
embedding = await self._generate_embedding(text)
embeddings.append(embedding)
return embeddings
def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
"""Return the expected vector size for a collection, caching results."""
default_dim = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)
if not collection_name:
return default_dim
if collection_name in self.collection_vector_sizes:
return self.collection_vector_sizes[collection_name]
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")
if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
return existing_size
except Exception as e:
logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
self.collection_vector_sizes[collection_name] = default_dim
return default_dim
def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
"""Pad or truncate embeddings to match the target collection dimension."""
if vector is None:
return vector
target_dim = self._get_collection_vector_size(collection_name)
current_dim = len(vector)
if current_dim == target_dim:
return vector
if current_dim > target_dim:
return vector[:target_dim]
# Pad with zeros to reach the target dimension
padding = [0.0] * (target_dim - current_dim)
return vector + padding
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
"""Split text into overlapping chunks for better context preservation"""
@@ -1180,8 +1286,9 @@ class RAGModule(BaseModule):
# Create document points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())
chunk_metadata = {
**metadata,
"document_id": doc_id,
@@ -1193,7 +1300,7 @@ class RAGModule(BaseModule):
points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))
@@ -1261,8 +1368,9 @@ class RAGModule(BaseModule):
# Create document points with enhanced metadata
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())
chunk_metadata = {
**processed_doc.metadata,
"document_id": processed_doc.id,
@@ -1284,7 +1392,7 @@ class RAGModule(BaseModule):
points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))
@@ -1548,6 +1656,8 @@ class RAGModule(BaseModule):
logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}")
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
query_embedding = self._align_embedding_dimension(query_embedding, collection_name)
# Build filter
search_filter = None

View File

@@ -87,6 +87,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
const [messages, setMessages] = useState<ChatMessage[]>([])
const [input, setInput] = useState("")
const [isLoading, setIsLoading] = useState(false)
const [conversationId, setConversationId] = useState<string | undefined>(undefined)
const scrollAreaRef = useRef<HTMLDivElement>(null)
const { success: toastSuccess, error: toastError } = useToast()
@@ -103,6 +104,12 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
scrollToBottom()
}, [messages])
useEffect(() => {
// Reset conversation when switching chatbots
setMessages([])
setConversationId(undefined)
}, [chatbotId])
const sendMessage = useCallback(async () => {
if (!input.trim() || isLoading) return
@@ -119,12 +126,13 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
setIsLoading(true)
// Enhanced logging for debugging
const currentConversationId = conversationId
const debugInfo = {
chatbotId,
messageLength: messageToSend.length,
conversationId,
conversationId: currentConversationId ?? null,
timestamp: new Date().toISOString(),
messagesCount: messages.length
messagesCount: messages.length + 1
}
console.log('=== CHAT REQUEST DEBUG ===', debugInfo)
@@ -132,7 +140,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
let data: any
// Use internal API
const conversationHistory = messages.map(msg => ({
const conversationHistory = [...messages, userMessage].map(msg => ({
role: msg.role,
content: msg.content
}))
@@ -140,7 +148,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
data = await chatbotApi.sendMessage(
chatbotId,
messageToSend,
undefined, // No conversation ID
currentConversationId,
conversationHistory
)
@@ -154,6 +162,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
setMessages(prev => [...prev, assistantMessage])
const newConversationId = data?.conversation_id || currentConversationId
if (newConversationId !== conversationId) {
setConversationId(newConversationId)
}
} catch (error) {
const appError = error as AppError
@@ -168,7 +181,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
} finally {
setIsLoading(false)
}
}, [input, isLoading, chatbotId, messages, toastError])
}, [input, isLoading, chatbotId, messages, toastError, conversationId])
const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {
@@ -358,4 +371,4 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
</CardContent>
</Card>
)
}
}