swapping to local embeddings
@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
import io
import asyncio
@@ -15,6 +16,7 @@ from datetime import datetime
from app.db.database import get_db
from app.core.security import get_current_user
from app.models.user import User
from app.models.rag_collection import RagCollection
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException

@@ -268,7 +270,19 @@ async def get_documents(
try:
collection_id_int = int(collection_id)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="Invalid collection_id format")
# Attempt to resolve by Qdrant collection name
collection_row = await db.scalar(
select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
)
if collection_row:
collection_id_int = collection_row.id
else:
# Unknown collection identifier; return empty result instead of erroring out
return {
"success": True,
"documents": [],
"total": 0
}

rag_service = RAGService(db)
documents = await rag_service.get_documents(

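Illustrative only: with the resolver above, a client can pass either the numeric collection id or the Qdrant collection name, and an unknown identifier now yields an empty document list instead of an error. The route path, host, and auth header below are assumptions, not taken from the diff.

import httpx

BASE_URL = "http://localhost:8000"                      # hypothetical backend address
HEADERS = {"Authorization": "Bearer <token>"}           # hypothetical auth header

for collection_id in ("42", "rag_collection_42"):       # numeric id or Qdrant collection name
    resp = httpx.get(
        f"{BASE_URL}/api/v1/rag/documents",             # hypothetical route path
        params={"collection_id": collection_id},
        headers=HEADERS,
    )
    print(collection_id, resp.json().get("total", 0))
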
@@ -129,6 +129,7 @@ class Settings(BaseSettings):
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
@@ -154,4 +155,3 @@ class Settings(BaseSettings):

# Global settings instance
settings = Settings()

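A minimal sketch (assuming the standard os.getenv resolution shown above) of pointing the stack at a different local model by exporting the variables before the config module is first imported:

import os

os.environ["RAG_EMBEDDING_MODEL"] = "BAAI/bge-base-en-v1.5"   # hypothetical alternative local model
os.environ["RAG_ALLOW_FALLBACK_EMBEDDINGS"] = "false"          # disable the random-vector fallback

from app.core.config import Settings

settings = Settings()
assert settings.RAG_EMBEDDING_MODEL == "BAAI/bge-base-en-v1.5"
assert settings.RAG_ALLOW_FALLBACK_EMBEDDINGS is False
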
@@ -8,6 +8,7 @@ import json
import logging
import mimetypes
import re
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
from dataclasses import dataclass, asdict
@@ -147,6 +148,11 @@ class RAGModule(BaseModule):
if config:
self.config.update(config)

# Ensure embedding model configured (defaults to local BGE small)
default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
self.config.setdefault("embedding_model", default_embedding_model)
self.default_embedding_model = default_embedding_model

# Content processing components
self.nlp_model = None
self.lemmatizer = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types)
}
self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}

def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
self.initialized = True
log_module_event("rag", "initialized", {
"vector_db": self.config.get("vector_db", "qdrant"),
"embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
"embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
"chunk_size": self.config.get("chunk_size", 400),
"max_results": self.config.get("max_results", 10),
"supported_file_types": list(self.supported_types.keys()),
@@ -427,8 +434,7 @@ class RAGModule(BaseModule):
# Prefer enhanced embedding service (rate limiting + retry)
from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service

# Use intfloat/multilingual-e5-large-instruct for LLM service integration
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
model_name = self.config.get("embedding_model", self.default_embedding_model)
embedding_service.model_name = model_name

# Initialize the embedding service
@@ -439,7 +445,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}")
return {
"model_name": model_name,
"dimension": embedding_service.dimension or 768
"dimension": embedding_service.dimension or 384
}
else:
# Fallback to mock implementation
@@ -447,7 +453,7 @@ class RAGModule(BaseModule):
self.embedding_service = None
return {
"model_name": model_name,
"dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
"dimension": 384 # Default dimension matching local bge-small embeddings
}

async def _initialize_content_processing(self):
@@ -588,15 +594,36 @@ class RAGModule(BaseModule):
collection_names = await self._get_collections_safely()

if collection_name not in collection_names:
# Create collection
# Create collection with the current embedding dimension
vector_dimension = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)

self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=self.embedding_model.get("dimension", 768),
size=vector_dimension,
distance=Distance.COSINE
)
)
self.collection_vector_sizes[collection_name] = vector_dimension
log_module_event("rag", "collection_created", {"collection": collection_name})
else:
# Cache existing collection vector size for later alignment
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")

if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
except Exception as inner_error:
logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")

except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
@@ -632,12 +659,13 @@ class RAGModule(BaseModule):
async def _generate_embedding(self, text: str) -> List[float]:
"""Generate embedding for text"""
if self.embedding_service:
# Use real embedding service
return await self.embedding_service.get_embedding(text)
vector = await self.embedding_service.get_embedding(text)
return vector
else:
# Fallback to deterministic random embedding for consistency
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
return np.random.random(fallback_dim).tolist()

async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
@@ -651,15 +679,91 @@ class RAGModule(BaseModule):
prefixed_texts = texts

# Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(prefixed_texts)
backend = getattr(self.embedding_service, "backend", "unknown")
start_time = time.time()
logger.info(
"Embedding batch requested",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(prefixed_texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
duration = time.time() - start_time
logger.info(
"Embedding batch finished",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(embeddings),
"scope": "documents" if is_document else "queries",
"duration_sec": round(duration, 4)
},
)
return embeddings
else:
# Fallback to individual processing
logger.warning(
"Embedding service unavailable, falling back to per-item generation",
extra={
"count": len(texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = []
for text in texts:
embedding = await self._generate_embedding(text)
embeddings.append(embedding)
return embeddings

def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
"""Return the expected vector size for a collection, caching results."""
default_dim = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)

if not collection_name:
return default_dim

if collection_name in self.collection_vector_sizes:
return self.collection_vector_sizes[collection_name]

try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")

if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
return existing_size
except Exception as e:
logger.debug(f"Unable to determine vector size for {collection_name}: {e}")

self.collection_vector_sizes[collection_name] = default_dim
return default_dim

def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
"""Pad or truncate embeddings to match the target collection dimension."""
if vector is None:
return vector

target_dim = self._get_collection_vector_size(collection_name)
current_dim = len(vector)

if current_dim == target_dim:
return vector
if current_dim > target_dim:
return vector[:target_dim]
padding = [0.0] * (target_dim - current_dim)
return vector + padding

def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
"""Split text into overlapping chunks for better context preservation"""
chunk_size = chunk_size or self.config.get("chunk_size", 300)
@@ -1176,6 +1280,7 @@ class RAGModule(BaseModule):
# Create document points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())

chunk_metadata = {
@@ -1189,7 +1294,7 @@ class RAGModule(BaseModule):

points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))

@@ -1257,6 +1362,7 @@ class RAGModule(BaseModule):
# Create document points with enhanced metadata
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())

chunk_metadata = {
@@ -1280,7 +1386,7 @@ class RAGModule(BaseModule):

points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))

@@ -1514,9 +1620,9 @@ class RAGModule(BaseModule):
start_time = time.time()

# Generate query embedding with task-specific prefix for better retrieval
# The E5 model works better with "query:" prefix for search queries
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
query_embedding = self._align_embedding_dimension(query_embedding, collection_name)

# Build filter
search_filter = None

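The _align_embedding_dimension helper above is what lets 384-dimensional bge-small vectors be written into collections originally created for 1024-dimensional e5 embeddings. A standalone sketch of the same pad/truncate rule (the real method lives on RAGModule and reads the target size from Qdrant):

def align(vector: list[float], target_dim: int) -> list[float]:
    if len(vector) == target_dim:
        return vector
    if len(vector) > target_dim:
        return vector[:target_dim]                         # truncate oversized vectors
    return vector + [0.0] * (target_dim - len(vector))     # zero-pad undersized vectors

print(len(align([0.1] * 384, 1024)))   # 1024: padded to fit a legacy e5 collection
print(len(align([0.1] * 1024, 384)))   # 384: truncated to fit a bge-small collection

Zero-padding keeps upserts from failing, but vectors from different models are not semantically comparable, so mixed collections are best reindexed with the new model.
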
@@ -1,55 +1,66 @@
"""
Embedding Service
Provides text embedding functionality using LLM service
Provides local sentence-transformer embeddings (default: BAAI/bge-small-en).
Falls back to deterministic random vectors when the local model is unavailable.
"""

import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np

from app.core.config import settings

logger = logging.getLogger(__name__)


class EmbeddingService:
"""Service for generating text embeddings using LLM service"""
"""Service for generating text embeddings using a local transformer model"""

def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
self.model_name = model_name
self.dimension = 1024 # Actual dimension for intfloat/multilingual-e5-large-instruct
def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
self.dimension = 384 # bge-small produces 384-d vectors
self.initialized = False
self.local_model = None
self.backend = "uninitialized"

async def initialize(self):
"""Initialize the embedding service with LLM service"""
try:
from app.services.llm.service import llm_service
from sentence_transformers import SentenceTransformer

# Initialize LLM service if not already done
if not llm_service._initialized:
await llm_service.initialize()
loop = asyncio.get_running_loop()

# Test LLM service health
if not llm_service._initialized:
logger.error("LLM service not initialized")
return False

# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False
def load_model():
# Load model synchronously in a worker thread to avoid blocking event loop
return SentenceTransformer(self.model_name)

self.local_model = await loop.run_in_executor(None, load_model)
self.dimension = self.local_model.get_sentence_embedding_dimension()
self.initialized = True
logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
self.backend = "sentence_transformer"
logger.info(
"Embedding service initialized with local model %s (dimension: %s)",
self.model_name,
self.dimension,
)
return True

except Exception as e:
logger.error(f"Failed to initialize LLM embedding service: {e}")
logger.warning("Using fallback random embeddings")
except ImportError as exc:
logger.error("sentence-transformers not installed: %s", exc)
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False

except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False

async def get_embedding(self, text: str) -> List[float]:
@@ -59,90 +70,64 @@ class EmbeddingService:

async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using LLM service"""
if not self.initialized:
# Fallback to random embeddings if not initialized
logger.warning("LLM service not available, using random embeddings")
return self._generate_fallback_embeddings(texts)
start_time = time.time()

try:
embeddings = []
# Process texts in batches for efficiency
batch_size = 10
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
if self.local_model:
if not texts:
return []

# Process each text in the batch
batch_embeddings = []
for text in batch:
try:
# Truncate text if it's too long for the model's context window
# intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars
# Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
max_chars = 1600
if len(text) > max_chars:
truncated_text = text[:max_chars]
logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
else:
truncated_text = text
loop = asyncio.get_running_loop()

# Guard: skip empty inputs (validator rejects empty strings)
if not truncated_text.strip():
logger.debug("Empty input for embedding; using fallback vector")
batch_embeddings.append(self._generate_fallback_embedding(text))
continue
try:
embeddings = await loop.run_in_executor(
None,
lambda: self.local_model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=True,
),
)
duration = time.time() - start_time
logger.info(
"Embedding batch completed",
extra={
"backend": self.backend,
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration, 4),
},
)
return embeddings.tolist()
except Exception as exc:
logger.error(f"Local embedding generation failed: {exc}")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)

# Call LLM service embedding endpoint
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
logger.warning("Local embedding model unavailable; using fallback random embeddings")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)

llm_request = EmbeddingRequest(
model=self.model_name,
input=truncated_text,
user_id="rag_system",
api_key_id=0 # System API key
)

response = await llm_service.create_embedding(llm_request)

# Extract embedding from response
if response.data and len(response.data) > 0:
embedding = response.data[0].embedding
if embedding:
batch_embeddings.append(embedding)
# Update dimension based on actual embedding size
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
logger.info(f"Confirmed embedding dimension: {self.dimension}")
else:
logger.warning(f"No embedding in response for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
else:
logger.warning(f"Invalid response structure for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
logger.error(f"Error getting embedding for text: {e}")
batch_embeddings.append(self._generate_fallback_embedding(text))

embeddings.extend(batch_embeddings)

return embeddings

except Exception as e:
logger.error(f"Error generating embeddings with LLM service: {e}")
# Fallback to random embeddings
return self._generate_fallback_embeddings(texts)

def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
embeddings.append(self._generate_fallback_embedding(text))
logger.info(
"Embedding batch completed",
extra={
"backend": "fallback_random",
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration or 0.0, 4),
},
)
return embeddings

def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension or 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
dimension = self.dimension or 384
# Use hash for reproducible random embeddings
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
@@ -169,22 +154,15 @@ class EmbeddingService:
"model_name": self.model_name,
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": "LLM Service",
"backend": self.backend,
"initialized": self.initialized
}

async def cleanup(self):
"""Cleanup resources"""
# Cleanup LLM service to prevent memory leaks
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")

self.local_model = None
self.initialized = False
self.backend = "uninitialized"


# Global embedding service instance

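For reference, a minimal sketch of the local backend the rewritten service wraps, using the same sentence-transformers calls that appear in the diff (the input strings are just examples):

from sentence_transformers import SentenceTransformer   # pip install sentence-transformers

model = SentenceTransformer("BAAI/bge-small-en")         # default model from the new config
print(model.get_sentence_embedding_dimension())          # 384

vectors = model.encode(
    ["passage: local embeddings need no external API", "query: what changed in this commit?"],
    convert_to_numpy=True,
    normalize_embeddings=True,                            # unit-length vectors for cosine distance
)
print(vectors.shape)                                      # (2, 384)
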
@@ -1,14 +1,13 @@
# Enhanced Embedding Service with Rate Limiting Handling
"""
Enhanced embedding service with robust rate limiting and retry logic
Enhanced embedding service that adds basic retry semantics around the local
embedding generator. Since embeddings are fully local, rate-limiting metadata is
largely informational but retained for compatibility.
"""

import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from .embedding_service import EmbeddingService
from app.core.config import settings
@@ -17,9 +16,9 @@ logger = logging.getLogger(__name__)


class EnhancedEmbeddingService(EmbeddingService):
"""Enhanced embedding service with rate limiting handling"""
"""Enhanced embedding service with lightweight retry bookkeeping"""

def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
def __init__(self, model_name: Optional[str] = None):
super().__init__(model_name)
self.rate_limit_tracker = {
'requests_count': 0,
@@ -34,161 +33,16 @@ class EnhancedEmbeddingService(EmbeddingService):

async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
"""
Get embeddings with rate limiting and retry logic
Get embeddings with retry bookkeeping.
"""
if max_retries is None:
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))

batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))

if not self.initialized:
logger.warning("Embedding service not initialized, using fallback")
return self._generate_fallback_embeddings(texts), False

embeddings = []
success = True

for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
embeddings.extend(batch_embeddings)
success = success and batch_success

# Add delay between batches to avoid rate limiting
if i + batch_size < len(texts):
delay = self.rate_limit_tracker['delay_between_batches']
await asyncio.sleep(delay) # Configurable delay between batches

return embeddings, success

async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
"""Get embeddings for a batch with retry logic"""
last_error = None

for attempt in range(max_retries + 1):
try:
# Check rate limit before making request
if self._is_rate_limited():
delay = self._get_rate_limit_delay()
logger.warning(f"Rate limit detected, waiting {delay} seconds")
await asyncio.sleep(delay)
continue

# Make the request
embeddings = await self._get_embeddings_batch_impl(texts)

return embeddings, True

except Exception as e:
last_error = e
error_msg = str(e).lower()

# Check if it's a rate limit error
if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
self._update_rate_limit_tracker(success=False)

if attempt < max_retries:
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
continue
else:
logger.error(f"Max retries exceeded for rate limit, using fallback embeddings")
return self._generate_fallback_embeddings(texts), False
else:
# Non-rate-limit error
logger.error(f"Error generating embeddings: {e}")
if attempt < max_retries:
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
await asyncio.sleep(delay)
else:
logger.error("Max retries exceeded, using fallback embeddings")
return self._generate_fallback_embeddings(texts), False

# If we get here, all retries failed
logger.error(f"All retries failed, last error: {last_error}")
return self._generate_fallback_embeddings(texts), False

async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
"""Implementation of getting embeddings for a batch"""
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest

embeddings = []

for text in texts:
# Respect rate limit before each request
while self._is_rate_limited():
delay = self._get_rate_limit_delay()
logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
await asyncio.sleep(delay)

# Truncate text if needed
max_chars = 1600
truncated_text = text[:max_chars] if len(text) > max_chars else text

llm_request = EmbeddingRequest(
model=self.model_name,
input=truncated_text,
user_id="rag_system",
api_key_id=0
embeddings = await super().get_embeddings(texts)
success = self.local_model is not None
if not success:
logger.warning(
"Embedding service operating in fallback mode; consider installing the local model %s",
self.model_name,
)

response = await llm_service.create_embedding(llm_request)

if response.data and len(response.data) > 0:
embedding = response.data[0].embedding
if embedding:
embeddings.append(embedding)
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
else:
raise ValueError("Empty embedding in response")
else:
raise ValueError("Invalid response structure")

# Count this successful request and optionally delay between requests
self._update_rate_limit_tracker(success=True)
per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
if per_req_delay and per_req_delay > 0:
await asyncio.sleep(per_req_delay)

return embeddings

def _is_rate_limited(self) -> bool:
"""Check if we're currently rate limited"""
now = time.time()
window_start = self.rate_limit_tracker['window_start']

# Reset window if it's expired
if now - window_start > self.rate_limit_tracker['window_size']:
self.rate_limit_tracker['requests_count'] = 0
self.rate_limit_tracker['window_start'] = now
return False

# Check if we've exceeded the limit
return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']

def _get_rate_limit_delay(self) -> float:
"""Get delay to wait for rate limit reset"""
now = time.time()
window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
return max(0, window_end - now)

def _update_rate_limit_tracker(self, success: bool):
"""Update the rate limit tracker"""
now = time.time()

# Reset window if it's expired
if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
self.rate_limit_tracker['requests_count'] = 0
self.rate_limit_tracker['window_start'] = now

# Increment counter on successful requests
if success:
self.rate_limit_tracker['requests_count'] += 1
return embeddings, success

async def get_embedding_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics including rate limiting info"""

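A sketch of how the slimmed-down retry wrapper might be called (assumes the module exposes a global enhanced_embedding_service instance, as the import in the RAG module above suggests):

import asyncio

from app.services.enhanced_embedding_service import enhanced_embedding_service

async def embed_chunks(chunks: list[str]) -> list[list[float]]:
    await enhanced_embedding_service.initialize()
    vectors, ok = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
    if not ok:
        # Fallback vectors are deterministic random noise; flag these documents for reindexing.
        print("embeddings were generated in fallback mode")
    return vectors

asyncio.run(embed_chunks(["passage: hello world"]))
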
@@ -566,10 +566,14 @@ class RAGService:
logger.warning(f"Could not check existing collections: {e}")

# Create collection with proper vector configuration
from app.services.embedding_service import embedding_service

vector_dimension = getattr(embedding_service, 'dimension', 384) or 384

client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=1024, # Updated for multilingual-e5-large-instruct model
size=vector_dimension,
distance=Distance.COSINE
),
optimizers_config=models.OptimizersConfigDiff(

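The same collection bootstrap shown standalone with qdrant-client; the vector size now follows the embedding service instead of the hard-coded 1024 (connection details and collection name below are placeholders):

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(host="localhost", port=6333)        # placeholder connection details
vector_dimension = 384                                     # bge-small; the app reads this from the embedding service

client.create_collection(
    collection_name="rag_example",                         # placeholder collection name
    vectors_config=VectorParams(size=vector_dimension, distance=Distance.COSINE),
)
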
@@ -8,6 +8,7 @@ import json
import logging
import mimetypes
import re
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
from dataclasses import dataclass, asdict
@@ -147,6 +148,11 @@ class RAGModule(BaseModule):
if config:
self.config.update(config)

# Ensure embedding model configured (defaults to local BGE small)
default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
self.config.setdefault("embedding_model", default_embedding_model)
self.default_embedding_model = default_embedding_model

# Content processing components
self.nlp_model = None
self.lemmatizer = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types)
}
self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}

def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
self.initialized = True
log_module_event("rag", "initialized", {
"vector_db": self.config.get("vector_db", "qdrant"),
"embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
"embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
"chunk_size": self.config.get("chunk_size", 400),
"max_results": self.config.get("max_results", 10),
"supported_file_types": list(self.supported_types.keys()),
@@ -426,8 +433,7 @@ class RAGModule(BaseModule):
"""Initialize embedding model"""
from app.services.embedding_service import embedding_service

# Use intfloat/multilingual-e5-large-instruct for LLM service integration
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
model_name = self.config.get("embedding_model", self.default_embedding_model)
embedding_service.model_name = model_name

# Initialize the embedding service
@@ -438,7 +444,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}")
return {
"model_name": model_name,
"dimension": embedding_service.dimension or 768
"dimension": embedding_service.dimension or 384
}
else:
# Fallback to mock implementation
@@ -446,7 +452,7 @@ class RAGModule(BaseModule):
self.embedding_service = None
return {
"model_name": model_name,
"dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
"dimension": 384 # Default dimension matching local bge-small embeddings
}

async def _initialize_content_processing(self):
@@ -587,15 +593,36 @@ class RAGModule(BaseModule):
collection_names = await self._get_collections_safely()

if collection_name not in collection_names:
# Create collection
# Create collection with the current embedding dimension
vector_dimension = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)

self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=self.embedding_model.get("dimension", 768),
size=vector_dimension,
distance=Distance.COSINE
)
)
self.collection_vector_sizes[collection_name] = vector_dimension
log_module_event("rag", "collection_created", {"collection": collection_name})
else:
# Cache existing collection vector size for later alignment
try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")

if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
except Exception as inner_error:
logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")

except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
@@ -632,11 +659,13 @@ class RAGModule(BaseModule):
"""Generate embedding for text"""
if self.embedding_service:
# Use real embedding service
return await self.embedding_service.get_embedding(text)
vector = await self.embedding_service.get_embedding(text)
return vector
else:
# Fallback to deterministic random embedding for consistency
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
return np.random.random(fallback_dim).tolist()

async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
@@ -650,15 +679,92 @@ class RAGModule(BaseModule):
prefixed_texts = texts

# Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(prefixed_texts)
backend = getattr(self.embedding_service, "backend", "unknown")
start_time = time.time()
logger.info(
"Embedding batch requested",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(prefixed_texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
duration = time.time() - start_time
logger.info(
"Embedding batch finished",
extra={
"backend": backend,
"model": getattr(self.embedding_service, "model_name", "unknown"),
"count": len(embeddings),
"scope": "documents" if is_document else "queries",
"duration_sec": round(duration, 4)
},
)
return embeddings
else:
# Fallback to individual processing
logger.warning(
"Embedding service unavailable, falling back to per-item generation",
extra={
"count": len(texts),
"scope": "documents" if is_document else "queries"
},
)
embeddings = []
for text in texts:
embedding = await self._generate_embedding(text)
embeddings.append(embedding)
return embeddings

def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
"""Return the expected vector size for a collection, caching results."""
default_dim = self.embedding_model.get(
"dimension",
getattr(self.embedding_service, "dimension", 384) or 384
)

if not collection_name:
return default_dim

if collection_name in self.collection_vector_sizes:
return self.collection_vector_sizes[collection_name]

try:
info = self.qdrant_client.get_collection(collection_name)
vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
existing_size = None
if vectors_param is not None and hasattr(vectors_param, "size"):
existing_size = vectors_param.size
elif isinstance(vectors_param, dict):
existing_size = vectors_param.get("size")

if existing_size:
self.collection_vector_sizes[collection_name] = existing_size
return existing_size
except Exception as e:
logger.debug(f"Unable to determine vector size for {collection_name}: {e}")

self.collection_vector_sizes[collection_name] = default_dim
return default_dim

def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
"""Pad or truncate embeddings to match the target collection dimension."""
if vector is None:
return vector

target_dim = self._get_collection_vector_size(collection_name)
current_dim = len(vector)

if current_dim == target_dim:
return vector
if current_dim > target_dim:
return vector[:target_dim]
# Pad with zeros to reach the target dimension
padding = [0.0] * (target_dim - current_dim)
return vector + padding

def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
"""Split text into overlapping chunks for better context preservation"""
chunk_size = chunk_size or self.config.get("chunk_size", 300)
@@ -1180,6 +1286,7 @@ class RAGModule(BaseModule):
# Create document points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())

chunk_metadata = {
@@ -1193,7 +1300,7 @@ class RAGModule(BaseModule):

points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))

@@ -1261,6 +1368,7 @@ class RAGModule(BaseModule):
# Create document points with enhanced metadata
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
chunk_id = str(uuid.uuid4())

chunk_metadata = {
@@ -1284,7 +1392,7 @@ class RAGModule(BaseModule):

points.append(PointStruct(
id=chunk_id,
vector=embedding,
vector=aligned_embedding,
payload=chunk_metadata
))

@@ -1549,6 +1657,8 @@ class RAGModule(BaseModule):
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)

query_embedding = self._align_embedding_dimension(query_embedding, collection_name)

# Build filter
search_filter = None
if filters:

@@ -87,6 +87,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
const [messages, setMessages] = useState<ChatMessage[]>([])
const [input, setInput] = useState("")
const [isLoading, setIsLoading] = useState(false)
const [conversationId, setConversationId] = useState<string | undefined>(undefined)
const scrollAreaRef = useRef<HTMLDivElement>(null)
const { success: toastSuccess, error: toastError } = useToast()

@@ -103,6 +104,12 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
scrollToBottom()
}, [messages])

useEffect(() => {
// Reset conversation when switching chatbots
setMessages([])
setConversationId(undefined)
}, [chatbotId])

const sendMessage = useCallback(async () => {
if (!input.trim() || isLoading) return

@@ -119,12 +126,13 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
setIsLoading(true)

// Enhanced logging for debugging
const currentConversationId = conversationId
const debugInfo = {
chatbotId,
messageLength: messageToSend.length,
conversationId,
conversationId: currentConversationId ?? null,
timestamp: new Date().toISOString(),
messagesCount: messages.length
messagesCount: messages.length + 1
}
console.log('=== CHAT REQUEST DEBUG ===', debugInfo)

@@ -132,7 +140,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
let data: any

// Use internal API
const conversationHistory = messages.map(msg => ({
const conversationHistory = [...messages, userMessage].map(msg => ({
role: msg.role,
content: msg.content
}))
@@ -140,7 +148,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
data = await chatbotApi.sendMessage(
chatbotId,
messageToSend,
undefined, // No conversation ID
currentConversationId,
conversationHistory
)

@@ -154,6 +162,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface

setMessages(prev => [...prev, assistantMessage])

const newConversationId = data?.conversation_id || currentConversationId
if (newConversationId !== conversationId) {
setConversationId(newConversationId)
}

} catch (error) {
const appError = error as AppError

@@ -168,7 +181,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
} finally {
setIsLoading(false)
}
}, [input, isLoading, chatbotId, messages, toastError])
}, [input, isLoading, chatbotId, messages, toastError, conversationId])

const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {