adding ollama embeddings and expanding to metadata search

2025-10-23 06:58:34 +02:00
parent 5d964dfd54
commit 8b6d241921
11 changed files with 289 additions and 35 deletions

View File

@@ -55,8 +55,11 @@ async def debug_search(
     # Get configuration
     app_config = settings
 
-    # Initialize RAG module
-    rag_module = RAGModule(app_config)
+    # Initialize RAG module with BGE-M3 configuration
+    rag_config = {
+        "embedding_model": "BAAI/bge-m3"
+    }
+    rag_module = RAGModule(app_config, config=rag_config)
 
     # Get available collections if none specified
     if not collection_name:
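
Note that the per-request config takes precedence over the global default, since RAGModule merges the passed dict over its settings-derived config (see the RAGModule hunk further down). A minimal sketch of the two call styles, assuming the constructor signature shown in this diff:

# Uses the global default from settings.RAG_EMBEDDING_MODEL
rag_module = RAGModule(settings)

# Pins BGE-M3 for this request regardless of the global default
rag_module = RAGModule(settings, config={"embedding_model": "BAAI/bge-m3"})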

View File

@@ -129,7 +129,7 @@ class Settings(BaseSettings):
     RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
     RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
     RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
-    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
+    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "bge-m3")
     RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
     RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
     RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
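
These values are read from the environment once at startup, so the model swap can also be driven without code changes. A hedged sketch of the relevant .env entries, assuming the variable names above and that the matching model has been pulled into Ollama (ollama pull bge-m3):

RAG_EMBEDDING_MODEL=bge-m3
RAG_ALLOW_FALLBACK_EMBEDDINGS=True
RAG_WARN_ON_FALLBACK=True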

View File

@@ -846,9 +846,9 @@ class ChatbotModule(BaseModule):
logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'") logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")
# First check if this might be a direct Qdrant collection name # First check if this collection exists in Qdrant directly
# (e.g., starts with "ext_", "rag_", or contains specific patterns) # Qdrant is the source of truth for collections
if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier: if True: # Always check Qdrant first
# Check if this collection exists in Qdrant directly # Check if this collection exists in Qdrant directly
actual_collection_name = collection_identifier actual_collection_name = collection_identifier
# Remove "ext_" prefix if present # Remove "ext_" prefix if present
@@ -866,6 +866,10 @@ class ChatbotModule(BaseModule):
                 if actual_collection_name in collection_names:
                     logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
+
+                    # Auto-register the collection in the database if not found
+                    await self._auto_register_collection(actual_collection_name, db)
+
                     return actual_collection_name
             except Exception as e:
                 logger.warning(f"Error checking Qdrant collections: {e}")
@@ -905,6 +909,43 @@ class ChatbotModule(BaseModule):
logger.error(f"Traceback: {traceback.format_exc()}") logger.error(f"Traceback: {traceback.format_exc()}")
return None return None
async def _auto_register_collection(self, collection_name: str, db: Session) -> None:
"""Automatically register a Qdrant collection in the database"""
try:
from app.models.rag_collection import RagCollection
from sqlalchemy import select
# Check if already registered
stmt = select(RagCollection).where(
RagCollection.qdrant_collection_name == collection_name
)
result = db.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
logger.info(f"Collection '{collection_name}' already registered in database")
return
# Create a readable name from collection name
display_name = collection_name.replace("-", " ").replace("_", " ").title()
# Auto-register the collection
new_collection = RagCollection(
name=display_name,
qdrant_collection_name=collection_name,
description=f"Auto-discovered collection from Qdrant: {collection_name}",
is_active=True
)
db.add(new_collection)
db.commit()
logger.info(f"Auto-registered Qdrant collection '{collection_name}' in database")
except Exception as e:
logger.error(f"Failed to auto-register collection '{collection_name}': {e}")
# Don't re-raise - this should not block collection usage
# Required abstract methods from BaseModule # Required abstract methods from BaseModule
async def cleanup(self): async def cleanup(self):
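
For reference, the display name in _auto_register_collection is derived purely from the Qdrant collection name; a quick illustration of the transform:

>>> "ext_bitbox02_faq".replace("-", " ").replace("_", " ").title()
'Ext Bitbox02 Faq'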

View File

@@ -148,8 +148,8 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)
 
-        # Ensure embedding model configured (defaults to local BGE small)
-        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        # Ensure embedding model configured (defaults to local BGE-M3)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-m3')
         self.config.setdefault("embedding_model", default_embedding_model)
         self.default_embedding_model = default_embedding_model

View File

@@ -19,8 +19,8 @@ class EmbeddingService:
"""Service for generating text embeddings using a local transformer model""" """Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: Optional[str] = None): def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en") self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-m3")
self.dimension = 384 # bge-small produces 384-d vectors self.dimension = 1024 # bge-m3 produces 1024-d vectors
self.initialized = False self.initialized = False
self.local_model = None self.local_model = None
self.backend = "uninitialized" self.backend = "uninitialized"
@@ -127,7 +127,7 @@ class EmbeddingService:
     def _generate_fallback_embedding(self, text: str) -> List[float]:
         """Generate a single fallback embedding"""
-        dimension = self.dimension or 384
+        dimension = self.dimension or 1024
         # Use hash for reproducible random embeddings
         np.random.seed(hash(text) % 2**32)
         return np.random.random(dimension).tolist()
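
Caution: moving the default from bge-small (384-d) to bge-m3 (1024-d) means Qdrant collections created at 384 dimensions will reject the new vectors. A minimal pre-flight check, assuming the standard qdrant-client API (attribute paths can differ slightly across client versions):

from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")  # assumed local Qdrant
info = client.get_collection("my_collection")
size = info.config.params.vectors.size  # VectorParams.size for single-vector collections
if size != 1024:
    print(f"Collection is {size}-d; re-index before switching to bge-m3")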

View File

@@ -0,0 +1,170 @@
"""
Ollama Embedding Service
Provides text embedding functionality using Ollama locally
"""
import logging
from typing import List, Dict, Any, Optional
import numpy as np
import aiohttp
import asyncio
logger = logging.getLogger(__name__)
class OllamaEmbeddingService:
"""Service for generating text embeddings using Ollama"""
def __init__(self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"):
self.model_name = model_name
self.base_url = base_url
self.dimension = 1024 # bge-m3 dimension
self.initialized = False
self._session = None
async def initialize(self):
"""Initialize embedding service with Ollama"""
try:
# Create HTTP session
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=60)
)
# Test Ollama is running and model is available
async with self._session.get(f"{self.base_url}/api/tags") as resp:
if resp.status != 200:
logger.error(f"Ollama not responding at {self.base_url}")
return False
data = await resp.json()
models = [model['name'].split(':')[0] for model in data.get('models', [])]
if self.model_name not in models:
logger.error(f"Model {self.model_name} not found in Ollama. Available: {models}")
return False
# Test embedding generation
test_embedding = await self.get_embedding("test")
if not test_embedding or len(test_embedding) != self.dimension:
logger.error(f"Failed to generate test embedding with {self.model_name}")
return False
self.initialized = True
logger.info(f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})")
return True
except Exception as e:
logger.error(f"Failed to initialize Ollama embedding service: {e}")
return False
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding for a single text"""
embeddings = await self.get_embeddings([text])
return embeddings[0]
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using Ollama"""
if not self.initialized:
# Try to initialize if not done
if not await self.initialize():
logger.error("Ollama embedding service not available")
return self._generate_fallback_embeddings(texts)
try:
embeddings = []
# Process each text individually (Ollama API typically processes one at a time)
for text in texts:
try:
# Skip empty inputs
if not text.strip():
logger.debug("Empty input for embedding; using fallback vector")
embeddings.append(self._generate_fallback_embedding(text))
continue
# Call Ollama embedding API
async with self._session.post(
f"{self.base_url}/api/embeddings",
json={
"model": self.model_name,
"prompt": text
}
) as resp:
if resp.status != 200:
logger.error(f"Ollama embedding request failed: {resp.status}")
embeddings.append(self._generate_fallback_embedding(text))
continue
result = await resp.json()
if 'embedding' in result:
embedding = result['embedding']
if len(embedding) == self.dimension:
embeddings.append(embedding)
else:
logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}")
embeddings.append(self._generate_fallback_embedding(text))
else:
logger.error(f"No embedding in Ollama response for text: {text[:50]}...")
embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
logger.error(f"Error getting embedding from Ollama for text: {e}")
embeddings.append(self._generate_fallback_embedding(text))
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings with Ollama: {e}")
return self._generate_fallback_embeddings(texts)
def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Generate fallback random embeddings when Ollama unavailable"""
embeddings = []
for text in texts:
embeddings.append(self._generate_fallback_embedding(text))
return embeddings
def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension # 1024 for bge-m3
# Use hash for reproducible random embeddings
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
async def similarity(self, text1: str, text2: str) -> float:
"""Calculate cosine similarity between two texts"""
embeddings = await self.get_embeddings([text1, text2])
# Calculate cosine similarity
vec1 = np.array(embeddings[0])
vec2 = np.array(embeddings[1])
# Normalize vectors
vec1_norm = vec1 / np.linalg.norm(vec1)
vec2_norm = vec2 / np.linalg.norm(vec2)
# Calculate cosine similarity
similarity = np.dot(vec1_norm, vec2_norm)
return float(similarity)
async def get_stats(self) -> Dict[str, Any]:
"""Get embedding service statistics"""
return {
"model_name": self.model_name,
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": "Ollama",
"base_url": self.base_url,
"initialized": self.initialized
}
async def cleanup(self):
"""Cleanup resources"""
if self._session:
await self._session.close()
self.initialized = False
# Global Ollama embedding service instance
ollama_embedding_service = OllamaEmbeddingService()
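
A short usage sketch for the new service, assuming an Ollama daemon reachable at the Docker bridge address above with bge-m3 already pulled:

import asyncio
from app.services.ollama_embedding_service import ollama_embedding_service

async def main():
    if await ollama_embedding_service.initialize():
        vecs = await ollama_embedding_service.get_embeddings(["hello", "world"])
        print(len(vecs), len(vecs[0]))  # 2 1024
        print(await ollama_embedding_service.similarity("hello", "hi"))
    await ollama_embedding_service.cleanup()

asyncio.run(main())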

View File

@@ -568,7 +568,7 @@ class RAGService:
             # Create collection with proper vector configuration
             from app.services.embedding_service import embedding_service
-            vector_dimension = getattr(embedding_service, 'dimension', 384) or 384
+            vector_dimension = getattr(embedding_service, 'dimension', 1024) or 1024
 
             client.create_collection(
                 collection_name=collection_name,
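
The dimension feeds straight into Qdrant's vector configuration; a sketch of the call being guarded here, assuming the usual qdrant-client models (the full call is outside this hunk):

from qdrant_client.models import Distance, VectorParams

client.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=vector_dimension, distance=Distance.COSINE),  # 1024 for bge-m3
)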

View File

@@ -805,9 +805,9 @@ class ChatbotModule(BaseModule):
logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'") logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")
# First check if this might be a direct Qdrant collection name # First check if this collection exists in Qdrant directly
# (e.g., starts with "ext_", "rag_", or contains specific patterns) # Qdrant is the source of truth for collections
if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier: if True: # Always check Qdrant first
# Check if this collection exists in Qdrant directly # Check if this collection exists in Qdrant directly
actual_collection_name = collection_identifier actual_collection_name = collection_identifier
# Remove "ext_" prefix if present # Remove "ext_" prefix if present
@@ -825,6 +825,10 @@ class ChatbotModule(BaseModule):
                 if actual_collection_name in collection_names:
                     logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
+
+                    # Auto-register the collection in the database if not found
+                    await self._auto_register_collection(actual_collection_name, db)
+
                     return actual_collection_name
             except Exception as e:
                 logger.warning(f"Error checking Qdrant collections: {e}")
@@ -864,6 +868,43 @@ class ChatbotModule(BaseModule):
logger.error(f"Traceback: {traceback.format_exc()}") logger.error(f"Traceback: {traceback.format_exc()}")
return None return None
async def _auto_register_collection(self, collection_name: str, db: Session) -> None:
"""Automatically register a Qdrant collection in the database"""
try:
from app.models.rag_collection import RagCollection
from sqlalchemy import select
# Check if already registered
stmt = select(RagCollection).where(
RagCollection.qdrant_collection_name == collection_name
)
result = db.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
logger.info(f"Collection '{collection_name}' already registered in database")
return
# Create a readable name from collection name
display_name = collection_name.replace("-", " ").replace("_", " ").title()
# Auto-register the collection
new_collection = RagCollection(
name=display_name,
qdrant_collection_name=collection_name,
description=f"Auto-discovered collection from Qdrant: {collection_name}",
is_active=True
)
db.add(new_collection)
db.commit()
logger.info(f"Auto-registered Qdrant collection '{collection_name}' in database")
except Exception as e:
logger.error(f"Failed to auto-register collection '{collection_name}': {e}")
# Don't re-raise - this should not block collection usage
# Required abstract methods from BaseModule # Required abstract methods from BaseModule
async def cleanup(self): async def cleanup(self):

View File

@@ -148,8 +148,8 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)
 
-        # Ensure embedding model configured (defaults to local BGE small)
-        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        # Ensure embedding model configured (defaults to local BGE-M3)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'bge-m3')
         self.config.setdefault("embedding_model", default_embedding_model)
         self.default_embedding_model = default_embedding_model
@@ -431,20 +431,20 @@ class RAGModule(BaseModule):
     async def _initialize_embedding_model(self):
         """Initialize embedding model"""
-        from app.services.embedding_service import embedding_service
+        from app.services.ollama_embedding_service import ollama_embedding_service
 
         model_name = self.config.get("embedding_model", self.default_embedding_model)
-        embedding_service.model_name = model_name
+        ollama_embedding_service.model_name = model_name
 
         # Initialize the embedding service
-        success = await embedding_service.initialize()
+        success = await ollama_embedding_service.initialize()
 
         if success:
-            self.embedding_service = embedding_service
+            self.embedding_service = ollama_embedding_service
             logger.info(f"Successfully initialized embedding service with {model_name}")
             return {
                 "model_name": model_name,
-                "dimension": embedding_service.dimension or 384
+                "dimension": ollama_embedding_service.dimension or 1024
             }
         else:
             # Fallback to mock implementation
@@ -452,7 +452,7 @@ class RAGModule(BaseModule):
             self.embedding_service = None
             return {
                 "model_name": model_name,
-                "dimension": 384  # Default dimension matching local bge-small embeddings
+                "dimension": 1024  # Default dimension matching BGE-M3 embeddings
             }
 
     async def _initialize_content_processing(self):
@@ -596,7 +596,7 @@ class RAGModule(BaseModule):
             # Create collection with the current embedding dimension
             vector_dimension = self.embedding_model.get(
                 "dimension",
-                getattr(self.embedding_service, "dimension", 384) or 384
+                getattr(self.embedding_service, "dimension", 1024) or 1024
             )
 
             self.qdrant_client.create_collection(
@@ -664,7 +664,7 @@ class RAGModule(BaseModule):
         else:
             # Fallback to deterministic random embedding for consistency
             np.random.seed(hash(text) % 2**32)
-            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
+            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 1024) or 1024)
             return np.random.random(fallback_dim).tolist()
 
     async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
@@ -1617,11 +1617,11 @@ class RAGModule(BaseModule):
     # Special handling for collections with different vector dimensions
     SPECIAL_COLLECTIONS = {
         "bitbox02_faq_local": {
-            "dimension": 384,
+            "dimension": 1024,
             "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
         },
         "bitbox_local_rag": {
-            "dimension": 384,
+            "dimension": 1024,
             "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
         }
     }
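
A sketch of how a per-collection override table like this is typically consumed; the lookup helper below is hypothetical (the call site is outside this diff):

def embedding_dim_for(collection_name: str, default_dim: int = 1024) -> int:
    # Hypothetical helper: special-cased collections keep their own dimension/model
    special = SPECIAL_COLLECTIONS.get(collection_name, {})
    return special.get("dimension", default_dim)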

View File

@@ -3,7 +3,6 @@
 import { useState, useEffect, Suspense } from "react";
 export const dynamic = 'force-dynamic'
 import { useSearchParams } from "next/navigation";
-import { Suspense } from "react";
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
 import { Button } from "@/components/ui/button";
 import { Input } from "@/components/ui/input";

View File

@@ -1,7 +1,7 @@
"use client"; "use client";
import { useState, useEffect } from 'react'; import { useState, useEffect } from 'react';
import { useAuth } from '@/contexts/AuthContext'; import { useAuth } from '@/components/providers/auth-provider';
import { tokenManager } from '@/lib/token-manager'; import { tokenManager } from '@/lib/token-manager';
interface SearchResult { interface SearchResult {