From 8b6d24192144d123e44c405760c242aa1fe92104 Mon Sep 17 00:00:00 2001
From: Aljaz Ceru
Date: Thu, 23 Oct 2025 06:58:34 +0200
Subject: [PATCH] adding ollama embeddings and expanding to metadata search

---
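Note on migration (a sketch, not part of this patch): the default embedding
model moves from bge-small (384-d vectors) to bge-m3 (1024-d vectors) below.
Qdrant fixes a collection's vector size at creation time, so collections
indexed under the old model cannot accept the new embeddings; they have to be
recreated and re-indexed. A minimal sketch of that step, assuming
qdrant-client is installed; the URL and collection name are placeholders:

    # Hypothetical migration helper; not part of this patch.
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams

    client = QdrantClient(url="http://localhost:6333")  # adjust to your deployment
    collection = "bitbox_local_rag"  # example collection touched by this patch

    # Drop the old 384-d collection and recreate it at the bge-m3 dimension.
    if client.collection_exists(collection):
        client.delete_collection(collection)
    client.create_collection(
        collection_name=collection,
        vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
    )
    # Documents must then be re-embedded with bge-m3 and re-upserted.
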
 backend/app/api/rag_debug.py                 |   7 +-
 backend/app/core/config.py                   |   2 +-
 backend/app/modules/chatbot/main.py          |  49 ++++-
 backend/app/modules/rag/main.py              |   4 +-
 backend/app/services/embedding_service.py    |   6 +-
 .../app/services/ollama_embedding_service.py | 170 ++++++++++++++++++
 backend/app/services/rag_service.py          |   2 +-
 backend/modules/chatbot/main.py              |  51 +++++-
 backend/modules/rag/main.py                  |  30 ++--
 frontend/src/app/api-keys/page.tsx           |   1 -
 frontend/src/app/rag-demo/page.tsx           |   2 +-
 11 files changed, 289 insertions(+), 35 deletions(-)
 create mode 100644 backend/app/services/ollama_embedding_service.py

diff --git a/backend/app/api/rag_debug.py b/backend/app/api/rag_debug.py
index 75a81ce..977bff0 100644
--- a/backend/app/api/rag_debug.py
+++ b/backend/app/api/rag_debug.py
@@ -55,8 +55,11 @@ async def debug_search(
     # Get configuration
     app_config = settings

-    # Initialize RAG module
-    rag_module = RAGModule(app_config)
+    # Initialize RAG module with BGE-M3 configuration
+    rag_config = {
+        "embedding_model": "BAAI/bge-m3"
+    }
+    rag_module = RAGModule(app_config, config=rag_config)

     # Get available collections if none specified
     if not collection_name:
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index a091936..bf03f79 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -129,7 +129,7 @@ class Settings(BaseSettings):
     RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
     RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
     RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
-    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
+    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
     RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
     RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
     RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
diff --git a/backend/app/modules/chatbot/main.py b/backend/app/modules/chatbot/main.py
index 96ae1b2..bfb18f9 100644
--- a/backend/app/modules/chatbot/main.py
+++ b/backend/app/modules/chatbot/main.py
@@ -846,9 +846,9 @@ class ChatbotModule(BaseModule):

             logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")

-            # First check if this might be a direct Qdrant collection name
-            # (e.g., starts with "ext_", "rag_", or contains specific patterns)
-            if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier:
+            # First check if this collection exists in Qdrant directly
+            # Qdrant is the source of truth for collections
+            if True:  # Always check Qdrant first
                 # Check if this collection exists in Qdrant directly
                 actual_collection_name = collection_identifier
                 # Remove "ext_" prefix if present
@@ -866,6 +866,10 @@ class ChatbotModule(BaseModule):

                 if actual_collection_name in collection_names:
                     logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
+
+                    # Auto-register the collection in the database if not found
+                    await self._auto_register_collection(actual_collection_name, db)
+
                     return actual_collection_name
             except Exception as e:
                 logger.warning(f"Error checking Qdrant collections: {e}")
@@ -898,13 +902,50 @@ class ChatbotModule(BaseModule):

             else:
                 logger.warning(f"RAG collection '{collection_identifier}' not found in database (tried both ID and name)")
                 return None
-                
+
         except Exception as e:
             logger.error(f"Error looking up RAG collection '{collection_identifier}': {e}")
             import traceback
             logger.error(f"Traceback: {traceback.format_exc()}")
             return None

+    async def _auto_register_collection(self, collection_name: str, db: Session) -> None:
+        """Automatically register a Qdrant collection in the database"""
+        try:
+            from app.models.rag_collection import RagCollection
+            from sqlalchemy import select
+
+            # Check if already registered
+            stmt = select(RagCollection).where(
+                RagCollection.qdrant_collection_name == collection_name
+            )
+            result = db.execute(stmt)
+            existing = result.scalar_one_or_none()
+
+            if existing:
+                logger.info(f"Collection '{collection_name}' already registered in database")
+                return
+
+            # Create a readable name from collection name
+            display_name = collection_name.replace("-", " ").replace("_", " ").title()
+
+            # Auto-register the collection
+            new_collection = RagCollection(
+                name=display_name,
+                qdrant_collection_name=collection_name,
+                description=f"Auto-discovered collection from Qdrant: {collection_name}",
+                is_active=True
+            )
+
+            db.add(new_collection)
+            db.commit()
+
+            logger.info(f"Auto-registered Qdrant collection '{collection_name}' in database")
+
+        except Exception as e:
+            logger.error(f"Failed to auto-register collection '{collection_name}': {e}")
+            # Don't re-raise - this should not block collection usage
+
     # Required abstract methods from BaseModule
     async def cleanup(self):
diff --git a/backend/app/modules/rag/main.py b/backend/app/modules/rag/main.py
index b3c5ac3..4401280 100644
--- a/backend/app/modules/rag/main.py
+++ b/backend/app/modules/rag/main.py
@@ -148,8 +148,8 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)

-        # Ensure embedding model configured (defaults to local BGE small)
-        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        # Ensure embedding model configured (defaults to local BGE-M3)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-m3')
         self.config.setdefault("embedding_model", default_embedding_model)
         self.default_embedding_model = default_embedding_model

diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py
index 78198a9..bb15ccb 100644
--- a/backend/app/services/embedding_service.py
+++ b/backend/app/services/embedding_service.py
@@ -19,8 +19,8 @@ class EmbeddingService:
     """Service for generating text embeddings using a local transformer model"""

     def __init__(self, model_name: Optional[str] = None):
-        self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
-        self.dimension = 384  # bge-small produces 384-d vectors
+        self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3")
+        self.dimension = 1024  # bge-m3 produces 1024-d vectors
         self.initialized = False
         self.local_model = None
         self.backend = "uninitialized"
@@ -127,7 +127,7 @@ class EmbeddingService:

     def _generate_fallback_embedding(self, text: str) -> List[float]:
         """Generate a single fallback embedding"""
-        dimension = self.dimension or 384
+        dimension = self.dimension or 1024
         # Use hash for reproducible random embeddings
         np.random.seed(hash(text) % 2**32)
         return np.random.random(dimension).tolist()
diff --git a/backend/app/services/ollama_embedding_service.py b/backend/app/services/ollama_embedding_service.py
new file mode 100644
index 0000000..5ea460f
--- /dev/null
+++ b/backend/app/services/ollama_embedding_service.py
@@ -0,0 +1,170 @@
+"""
+Ollama Embedding Service
+Provides text embedding functionality using Ollama locally
+"""
+
+import logging
+from typing import List, Dict, Any, Optional
+import numpy as np
+import aiohttp
+import asyncio
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingService:
+    """Service for generating text embeddings using Ollama"""
+
+    def __init__(self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"):
+        self.model_name = model_name
+        self.base_url = base_url
+        self.dimension = 1024  # bge-m3 dimension
+        self.initialized = False
+        self._session = None
+
+    async def initialize(self):
+        """Initialize embedding service with Ollama"""
+        try:
+            # Create HTTP session
+            self._session = aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=60)
+            )
+
+            # Test Ollama is running and model is available
+            async with self._session.get(f"{self.base_url}/api/tags") as resp:
+                if resp.status != 200:
+                    logger.error(f"Ollama not responding at {self.base_url}")
+                    return False
+
+                data = await resp.json()
+                models = [model['name'].split(':')[0] for model in data.get('models', [])]
+
+                if self.model_name not in models:
+                    logger.error(f"Model {self.model_name} not found in Ollama. Available: {models}")
+                    return False
+
+            # Test embedding generation
+            test_embedding = await self.get_embedding("test")
+            if not test_embedding or len(test_embedding) != self.dimension:
+                logger.error(f"Failed to generate test embedding with {self.model_name}")
+                return False
+
+            self.initialized = True
+            logger.info(f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Ollama embedding service: {e}")
+            return False
+
+    async def get_embedding(self, text: str) -> List[float]:
+        """Get embedding for a single text"""
+        embeddings = await self.get_embeddings([text])
+        return embeddings[0]
+
+    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get embeddings for multiple texts using Ollama"""
+        if not self.initialized:
+            # Try to initialize if not done
+            if not await self.initialize():
+                logger.error("Ollama embedding service not available")
+                return self._generate_fallback_embeddings(texts)
+
+        try:
+            embeddings = []
+
+            # Process each text individually (Ollama API typically processes one at a time)
+            for text in texts:
+                try:
+                    # Skip empty inputs
+                    if not text.strip():
+                        logger.debug("Empty input for embedding; using fallback vector")
+                        embeddings.append(self._generate_fallback_embedding(text))
+                        continue
+
+                    # Call Ollama embedding API
+                    async with self._session.post(
+                        f"{self.base_url}/api/embeddings",
+                        json={
+                            "model": self.model_name,
+                            "prompt": text
+                        }
+                    ) as resp:
+                        if resp.status != 200:
+                            logger.error(f"Ollama embedding request failed: {resp.status}")
+                            embeddings.append(self._generate_fallback_embedding(text))
+                            continue
+
+                        result = await resp.json()
+
+                        if 'embedding' in result:
+                            embedding = result['embedding']
+                            if len(embedding) == self.dimension:
+                                embeddings.append(embedding)
+                            else:
+                                logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}")
+                                embeddings.append(self._generate_fallback_embedding(text))
+                        else:
+                            logger.error(f"No embedding in Ollama response for text: {text[:50]}...")
+                            embeddings.append(self._generate_fallback_embedding(text))
+
+                except Exception as e:
+                    logger.error(f"Error getting embedding from Ollama for text: {e}")
+                    embeddings.append(self._generate_fallback_embedding(text))
+
+            return embeddings
+
+        except Exception as e:
+            logger.error(f"Error generating embeddings with Ollama: {e}")
+            return self._generate_fallback_embeddings(texts)
+
+    def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Generate fallback random embeddings when Ollama unavailable"""
+        embeddings = []
+        for text in texts:
+            embeddings.append(self._generate_fallback_embedding(text))
+        return embeddings
+
+    def _generate_fallback_embedding(self, text: str) -> List[float]:
+        """Generate a single fallback embedding"""
+        dimension = self.dimension  # 1024 for bge-m3
+        # Use hash for reproducible random embeddings
+        np.random.seed(hash(text) % 2**32)
+        return np.random.random(dimension).tolist()
+
+    async def similarity(self, text1: str, text2: str) -> float:
+        """Calculate cosine similarity between two texts"""
+        embeddings = await self.get_embeddings([text1, text2])
+
+        # Calculate cosine similarity
+        vec1 = np.array(embeddings[0])
+        vec2 = np.array(embeddings[1])
+
+        # Normalize vectors
+        vec1_norm = vec1 / np.linalg.norm(vec1)
+        vec2_norm = vec2 / np.linalg.norm(vec2)
+
+        # Calculate cosine similarity
+        similarity = np.dot(vec1_norm, vec2_norm)
+        return float(similarity)
+
+    async def get_stats(self) -> Dict[str, Any]:
+        """Get embedding service statistics"""
+        return {
+            "model_name": self.model_name,
+            "model_loaded": self.initialized,
+            "dimension": self.dimension,
+            "backend": "Ollama",
+            "base_url": self.base_url,
+            "initialized": self.initialized
+        }
+
+    async def cleanup(self):
+        """Cleanup resources"""
+        if self._session:
+            await self._session.close()
+        self.initialized = False
+
+
+# Global Ollama embedding service instance
+ollama_embedding_service = OllamaEmbeddingService()
\ No newline at end of file
diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py
index 74443d4..d73e84a 100644
--- a/backend/app/services/rag_service.py
+++ b/backend/app/services/rag_service.py
@@ -568,7 +568,7 @@ class RAGService:

             # Create collection with proper vector configuration
             from app.services.embedding_service import embedding_service
-            vector_dimension = getattr(embedding_service, 'dimension', 384) or 384
+            vector_dimension = getattr(embedding_service, 'dimension', 1024) or 1024

             client.create_collection(
                 collection_name=collection_name,
diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py
index e378414..0dd46f9 100644
--- a/backend/modules/chatbot/main.py
+++ b/backend/modules/chatbot/main.py
@@ -805,9 +805,9 @@ class ChatbotModule(BaseModule):

             logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")

-            # First check if this might be a direct Qdrant collection name
-            # (e.g., starts with "ext_", "rag_", or contains specific patterns)
-            if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier:
+            # First check if this collection exists in Qdrant directly
+            # Qdrant is the source of truth for collections
+            if True:  # Always check Qdrant first
                 # Check if this collection exists in Qdrant directly
                 actual_collection_name = collection_identifier
                 # Remove "ext_" prefix if present
@@ -825,6 +825,10 @@ class ChatbotModule(BaseModule):

                 if actual_collection_name in collection_names:
                     logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
+
+                    # Auto-register the collection in the database if not found
+                    await self._auto_register_collection(actual_collection_name, db)
+
                     return actual_collection_name
             except Exception as e:
                 logger.warning(f"Error checking Qdrant collections: {e}")
@@ -857,13 +861,50 @@ class ChatbotModule(BaseModule):

             else:
                 logger.warning(f"RAG collection '{collection_identifier}' not found in database (tried both ID and name)")
                 return None
-                
+
         except Exception as e:
             logger.error(f"Error looking up RAG collection '{collection_identifier}': {e}")
             import traceback
             logger.error(f"Traceback: {traceback.format_exc()}")
             return None

+    async def _auto_register_collection(self, collection_name: str, db: Session) -> None:
+        """Automatically register a Qdrant collection in the database"""
+        try:
+            from app.models.rag_collection import RagCollection
+            from sqlalchemy import select
+
+            # Check if already registered
+            stmt = select(RagCollection).where(
+                RagCollection.qdrant_collection_name == collection_name
+            )
+            result = db.execute(stmt)
+            existing = result.scalar_one_or_none()
+
+            if existing:
+                logger.info(f"Collection '{collection_name}' already registered in database")
+                return
+
+            # Create a readable name from collection name
+            display_name = collection_name.replace("-", " ").replace("_", " ").title()
+
+            # Auto-register the collection
+            new_collection = RagCollection(
+                name=display_name,
+                qdrant_collection_name=collection_name,
+                description=f"Auto-discovered collection from Qdrant: {collection_name}",
+                is_active=True
+            )
+
+            db.add(new_collection)
+            db.commit()
+
+            logger.info(f"Auto-registered Qdrant collection '{collection_name}' in database")
+
+        except Exception as e:
+            logger.error(f"Failed to auto-register collection '{collection_name}': {e}")
+            # Don't re-raise - this should not block collection usage
+
     # Required abstract methods from BaseModule
     async def cleanup(self):
@@ -905,4 +946,4 @@ def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotMo
     return ChatbotModule(rag_service=rag_service)

 # Create module instance (dependencies will be injected via factory)
-chatbot_module = ChatbotModule()
\ No newline at end of file
+chatbot_module = ChatbotModule()
diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py
index 8bebd9e..c3035f0 100644
--- a/backend/modules/rag/main.py
+++ b/backend/modules/rag/main.py
@@ -148,8 +148,8 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)

-        # Ensure embedding model configured (defaults to local BGE small)
-        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        # Ensure embedding model configured (defaults to local BGE-M3)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'bge-m3')
         self.config.setdefault("embedding_model", default_embedding_model)
         self.default_embedding_model = default_embedding_model

@@ -431,20 +431,20 @@ class RAGModule(BaseModule):

     async def _initialize_embedding_model(self):
         """Initialize embedding model"""
-        from app.services.embedding_service import embedding_service
-
+        from app.services.ollama_embedding_service import ollama_embedding_service
+
         model_name = self.config.get("embedding_model", self.default_embedding_model)
-        embedding_service.model_name = model_name
-
+        ollama_embedding_service.model_name = model_name
+
         # Initialize the embedding service
-        success = await embedding_service.initialize()
-
+        success = await ollama_embedding_service.initialize()
+
         if success:
-            self.embedding_service = embedding_service
+            self.embedding_service = ollama_embedding_service
             logger.info(f"Successfully initialized embedding service with {model_name}")
             return {
                 "model_name": model_name,
-                "dimension": embedding_service.dimension or 384
+                "dimension": ollama_embedding_service.dimension or 1024
             }
         else:
             # Fallback to mock implementation
@@ -452,7 +452,7 @@ class RAGModule(BaseModule):
             self.embedding_service = None
             return {
                 "model_name": model_name,
-                "dimension": 384  # Default dimension matching local bge-small embeddings
+                "dimension": 1024  # Default dimension matching BGE-M3 embeddings
             }

     async def _initialize_content_processing(self):
@@ -596,7 +596,7 @@ class RAGModule(BaseModule):
         # Create collection with the current embedding dimension
         vector_dimension = self.embedding_model.get(
             "dimension",
-            getattr(self.embedding_service, "dimension", 384) or 384
+            getattr(self.embedding_service, "dimension", 1024) or 1024
         )

         self.qdrant_client.create_collection(
@@ -664,7 +664,7 @@ class RAGModule(BaseModule):
         else:
             # Fallback to deterministic random embedding for consistency
             np.random.seed(hash(text) % 2**32)
-            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
+            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 1024) or 1024)
             return np.random.random(fallback_dim).tolist()

     async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
@@ -1617,11 +1617,11 @@ class RAGModule(BaseModule):
         # Special handling for collections with different vector dimensions
         SPECIAL_COLLECTIONS = {
             "bitbox02_faq_local": {
-                "dimension": 384,
+                "dimension": 1024,
                 "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
             },
             "bitbox_local_rag": {
-                "dimension": 384,
+                "dimension": 1024,
                 "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
             }
         }
diff --git a/frontend/src/app/api-keys/page.tsx b/frontend/src/app/api-keys/page.tsx
index f60ac97..d40b437 100644
--- a/frontend/src/app/api-keys/page.tsx
+++ b/frontend/src/app/api-keys/page.tsx
@@ -3,7 +3,6 @@ import { useState, useEffect, Suspense } from "react";
 export const dynamic = 'force-dynamic'

 import { useSearchParams } from "next/navigation";
-import { Suspense } from "react";
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
 import { Button } from "@/components/ui/button";
 import { Input } from "@/components/ui/input";
diff --git a/frontend/src/app/rag-demo/page.tsx b/frontend/src/app/rag-demo/page.tsx
index b07c7ba..77e3db8 100644
--- a/frontend/src/app/rag-demo/page.tsx
+++ b/frontend/src/app/rag-demo/page.tsx
@@ -1,7 +1,7 @@
 "use client";

 import { useState, useEffect } from 'react';
-import { useAuth } from '@/contexts/AuthContext';
+import { useAuth } from '@/components/providers/auth-provider';
 import { tokenManager } from '@/lib/token-manager';

 interface SearchResult {
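
Usage sketch (not part of the diff): a minimal smoke test for the new
OllamaEmbeddingService, assuming Ollama is reachable at the base_url
configured above and the bge-m3 model has been pulled (ollama pull bge-m3);
the sample strings are placeholders.

    import asyncio

    from app.services.ollama_embedding_service import OllamaEmbeddingService

    async def main() -> None:
        service = OllamaEmbeddingService(model_name="bge-m3")
        if not await service.initialize():
            raise SystemExit("Ollama unavailable or bge-m3 not pulled")

        # Single embedding; bge-m3 should yield a 1024-d vector.
        vector = await service.get_embedding("BitBox02 backup and recovery")
        print(len(vector))

        # Cosine similarity between two short texts.
        score = await service.similarity("hardware wallet", "cold storage device")
        print(f"cosine similarity: {score:.3f}")

        await service.cleanup()

    asyncio.run(main())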