Troubleshoot RAG memory leak

This commit is contained in:
2025-10-06 14:42:07 +02:00
parent 3e841d0e42
commit bae86fb5a2
12 changed files with 295 additions and 92 deletions

View File

@@ -345,7 +345,12 @@ async def refresh_token(
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
) )
except HTTPException:
# Re-raise HTTPException without modification
raise
except Exception as e: except Exception as e:
# Log the actual error for debugging
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token" detail="Invalid refresh token"

View File

@@ -288,14 +288,18 @@ async def get_documents(
@router.post("/documents", response_model=dict) @router.post("/documents", response_model=dict)
async def upload_document( async def upload_document(
collection_id: int = Form(...), collection_id: str = Form(...),
file: UploadFile = File(...), file: UploadFile = File(...),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
"""Upload and process a document""" """Upload and process a document"""
try: try:
# Read file content # Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
# Read file content once and use it for all validations
file_content = await file.read() file_content = await file.read()
if len(file_content) == 0: if len(file_content) == 0:
@@ -304,10 +308,6 @@ async def upload_document(
if len(file_content) > 50 * 1024 * 1024: # 50MB limit if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)") raise HTTPException(status_code=400, detail="File too large (max 50MB)")
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
try: try:
# Test file readability based on type # Test file readability based on type
if file_extension == 'jsonl': if file_extension == 'jsonl':
@@ -349,8 +349,33 @@ async def upload_document(
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}") raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
rag_service = RAGService(db) rag_service = RAGService(db)
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
collection_identifier = (collection_id or "").strip()
if not collection_identifier:
raise HTTPException(status_code=400, detail="Collection identifier is required")
resolved_collection_id: Optional[int] = None
if collection_identifier.isdigit():
resolved_collection_id = int(collection_identifier)
else:
qdrant_name = collection_identifier
if qdrant_name.startswith("ext_"):
qdrant_name = qdrant_name[4:]
try:
collection_record = await rag_service.ensure_collection_record(qdrant_name)
except Exception as ensure_error:
raise HTTPException(status_code=500, detail=str(ensure_error))
resolved_collection_id = collection_record.id
if resolved_collection_id is None:
raise HTTPException(status_code=400, detail="Invalid collection identifier")
document = await rag_service.upload_document( document = await rag_service.upload_document(
collection_id=collection_id, collection_id=resolved_collection_id,
file_content=file_content, file_content=file_content,
filename=filename, filename=filename,
content_type=file.content_type content_type=file.content_type

View File

@@ -682,10 +682,13 @@ class RAGModule(BaseModule):
chunks.append(chunk_text) chunks.append(chunk_text)
# Move to next chunk with overlap # Move to next chunk with overlap
start_idx = end_idx - chunk_overlap # Ensure we make progress and don't loop infinitely
start_idx += chunk_size - chunk_overlap
if start_idx >= len(tokens):
break
# Ensure progress (in case overlap >= chunk_size) # Safety check to prevent infinite loop
if start_idx >= end_idx: if start_idx <= end_idx - chunk_size:
start_idx = end_idx start_idx = end_idx
return chunks return chunks

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import selectinload
from app.db.database import get_db from app.db.database import get_db
from app.models.rag_document import RagDocument from app.models.rag_document import RagDocument
from app.models.rag_collection import RagCollection from app.models.rag_collection import RagCollection
from app.services.module_manager import module_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -57,6 +58,8 @@ class DocumentProcessor:
"queue_size": 0, "queue_size": 0,
"active_workers": 0 "active_workers": 0
} }
self._rag_module = None
self._rag_module_lock = asyncio.Lock()
async def start(self): async def start(self):
"""Start the document processor""" """Start the document processor"""
@@ -160,6 +163,30 @@ class DocumentProcessor:
logger.info(f"Worker stopped: {worker_name}") logger.info(f"Worker stopped: {worker_name}")
async def _get_rag_module(self):
    """Resolve and cache the RAG module instance.

    Returns the cached instance while it is still enabled; otherwise
    (re)initializes the module manager, enables the 'rag' module if
    needed, and caches the resolved instance for subsequent calls.

    Raises:
        Exception: if the RAG module cannot be enabled or resolved.
    """
    # Serialize resolution so concurrent workers do not race to
    # initialize the module manager or enable the module twice.
    async with self._rag_module_lock:
        # Fast path: reuse the cached module while it is still enabled.
        if self._rag_module and getattr(self._rag_module, 'enabled', False):
            return self._rag_module
        if not module_manager.initialized:
            await module_manager.initialize()
        rag_module = module_manager.modules.get('rag')
        if not rag_module or not getattr(rag_module, 'enabled', False):
            # Attempt to enable the module, then re-fetch the instance.
            enabled = await module_manager.enable_module('rag')
            if not enabled:
                raise Exception("Failed to enable RAG module")
            rag_module = module_manager.modules.get('rag')
        if not rag_module or not getattr(rag_module, 'enabled', False):
            raise Exception("RAG module not available or not enabled")
        # Cache for reuse; this replaces the previous per-document
        # RAGModule construction (presumably the source of the leak
        # this change targets — confirm with memory profiling).
        self._rag_module = rag_module
        logger.info("DocumentProcessor cached RAG module instance for reuse")
        return self._rag_module
async def _process_document(self, task: ProcessingTask) -> bool: async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document""" """Process a single document"""
from datetime import datetime from datetime import datetime
@@ -185,16 +212,7 @@ class DocumentProcessor:
# Get RAG module for processing # Get RAG module for processing
try: try:
# Import RAG module and initialize it properly rag_module = await self._get_rag_module()
from modules.rag.main import RAGModule
from app.core.config import settings
# Create and initialize RAG module instance
rag_module = RAGModule(settings)
init_result = await rag_module.initialize()
if not rag_module.enabled:
raise Exception("Failed to enable RAG module")
except Exception as e: except Exception as e:
logger.error(f"Failed to get RAG module: {e}") logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}") raise Exception(f"RAG module not available: {e}")

View File

@@ -175,6 +175,15 @@ class EmbeddingService:
async def cleanup(self): async def cleanup(self):
"""Cleanup resources""" """Cleanup resources"""
# Cleanup LLM service to prevent memory leaks
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")
self.initialized = False self.initialized = False

View File

@@ -478,5 +478,11 @@ class PrivateModeProvider(BaseLLMProvider):
async def cleanup(self): async def cleanup(self):
"""Cleanup PrivateMode provider resources""" """Cleanup PrivateMode provider resources"""
# Close HTTP session to prevent memory leaks
if self._session and not self._session.closed:
await self._session.close()
self._session = None
logger.debug("Closed PrivateMode HTTP session")
await super().cleanup() await super().cleanup()
logger.debug("PrivateMode provider cleanup completed") logger.debug("PrivateMode provider cleanup completed")

View File

@@ -82,6 +82,54 @@ class RAGService:
) )
return await self.db.scalar(stmt) return await self.db.scalar(stmt)
async def get_collection_by_qdrant_name(self, qdrant_collection_name: str) -> Optional[RagCollection]:
    """Look up the managed collection record matching a Qdrant collection name.

    Returns None when no record exists for that name.
    """
    query = (
        select(RagCollection)
        .where(RagCollection.qdrant_collection_name == qdrant_collection_name)
    )
    return await self.db.scalar(query)
async def ensure_collection_record(self, qdrant_collection_name: str) -> RagCollection:
    """Ensure we have a managed record for a given Qdrant collection.

    Returns the existing record when one is present; otherwise creates a
    new one whose display name is derived from the Qdrant identifier
    (e.g. "rag_my_docs" -> "My Docs"). Concurrent creation is tolerated:
    if the commit fails, we roll back and re-check whether another
    request created the record before re-raising.
    """
    existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
    if existing:
        return existing
    # Create a friendly name from the Qdrant collection identifier:
    # strip the "rag_" prefix and title-case the underscore-separated parts.
    friendly_name = qdrant_collection_name
    try:
        if qdrant_collection_name.startswith("rag_"):
            trimmed = qdrant_collection_name[4:]
            parts = [part for part in trimmed.split("_") if part]
            if parts:
                friendly_name = " ".join(parts).title()
    except Exception:
        # Fall back to original identifier on any parsing issues
        friendly_name = qdrant_collection_name
    collection = RagCollection(
        name=friendly_name,
        description=f"Synced from Qdrant collection '{qdrant_collection_name}'",
        qdrant_collection_name=qdrant_collection_name,
        status='active',
        is_active=True
    )
    self.db.add(collection)
    try:
        await self.db.commit()
    except Exception:
        # Commit failed — roll back before touching the session again.
        await self.db.rollback()
        # Another request might have created the collection concurrently; fetch again
        existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
        if existing:
            return existing
        raise
    # Reload server-generated fields (e.g. the primary key) after commit.
    await self.db.refresh(collection)
    return collection
async def get_all_collections(self, skip: int = 0, limit: int = 100) -> List[dict]: async def get_all_collections(self, skip: int = 0, limit: int = 100) -> List[dict]:
"""Get all collections from Qdrant (source of truth) with additional metadata from PostgreSQL.""" """Get all collections from Qdrant (source of truth) with additional metadata from PostgreSQL."""
logger.info("Getting all RAG collections from Qdrant (source of truth)") logger.info("Getting all RAG collections from Qdrant (source of truth)")

View File

@@ -664,28 +664,36 @@ class RAGModule(BaseModule):
chunk_size = chunk_size or self.config.get("chunk_size", 300) chunk_size = chunk_size or self.config.get("chunk_size", 300)
chunk_overlap = self.config.get("chunk_overlap", 50) chunk_overlap = self.config.get("chunk_overlap", 50)
# Tokenize text # Ensure sane values to avoid infinite loops on very short docs
chunk_size = max(1, chunk_size)
if chunk_overlap >= chunk_size:
chunk_overlap = max(0, chunk_size - 1)
tokens = self.tokenizer.encode(text) tokens = self.tokenizer.encode(text)
if not tokens:
return []
# Split into chunks with overlap chunks: List[str] = []
chunks = [] len_tokens = len(tokens)
start_idx = 0 start_idx = 0
step = max(1, chunk_size - chunk_overlap)
while start_idx < len(tokens): while start_idx < len_tokens:
end_idx = min(start_idx + chunk_size, len(tokens)) end_idx = min(start_idx + chunk_size, len_tokens)
chunk_tokens = tokens[start_idx:end_idx] chunk_tokens = tokens[start_idx:end_idx]
if not chunk_tokens:
break
chunk_text = self.tokenizer.decode(chunk_tokens) chunk_text = self.tokenizer.decode(chunk_tokens)
# Only add non-empty chunks
if chunk_text.strip(): if chunk_text.strip():
chunks.append(chunk_text) chunks.append(chunk_text)
# Move to next chunk with overlap if end_idx >= len_tokens:
start_idx = end_idx - chunk_overlap break
# Ensure progress (in case overlap >= chunk_size) start_idx += step
if start_idx >= end_idx:
start_idx = end_idx
return chunks return chunks

View File

@@ -42,23 +42,37 @@ export async function GET(request: NextRequest) {
export async function POST(request: NextRequest) { export async function POST(request: NextRequest) {
try { try {
const formData = await request.formData() console.log('=== Document Upload API Route Called ===')
// Get auth token from request headers // Get auth token from request headers
const authHeader = request.headers.get('authorization') const authHeader = request.headers.get('authorization')
console.log('Auth header present:', !!authHeader)
// Forward the FormData directly to backend // Get the original content type (includes multipart boundary)
const contentType = request.headers.get('content-type')
console.log('Original content type:', contentType)
// Try to forward the original request body directly
try {
console.log('Request body type:', typeof request.body)
console.log('Request body:', request.body)
// Forward directly to backend with original body stream
console.log('Sending request to backend:', `${BACKEND_URL}/api/rag/documents`)
const backendResponse = await fetch(`${BACKEND_URL}/api/rag/documents`, { const backendResponse = await fetch(`${BACKEND_URL}/api/rag/documents`, {
method: 'POST', method: 'POST',
headers: { headers: {
...(authHeader && { 'Authorization': authHeader }), ...(authHeader && { 'Authorization': authHeader }),
// Don't set Content-Type for FormData - let the browser set it with boundary ...(contentType && { 'Content-Type': contentType }),
}, },
body: formData, body: request.body,
}) })
console.log('Backend response status:', backendResponse.status, backendResponse.statusText)
if (!backendResponse.ok) { if (!backendResponse.ok) {
const errorData = await backendResponse.json().catch(() => ({ error: 'Unknown error' })) const errorData = await backendResponse.json().catch(() => ({ error: 'Unknown error' }))
console.log('Backend error response:', errorData)
return NextResponse.json( return NextResponse.json(
{ success: false, error: errorData.detail || errorData.error || 'Failed to upload document' }, { success: false, error: errorData.detail || errorData.error || 'Failed to upload document' },
{ status: backendResponse.status } { status: backendResponse.status }
@@ -66,10 +80,19 @@ export async function POST(request: NextRequest) {
} }
const data = await backendResponse.json() const data = await backendResponse.json()
console.log('Backend success response:', data)
return NextResponse.json(data) return NextResponse.json(data)
} catch (bodyError) {
console.error('Error reading request body:', bodyError)
throw bodyError
}
} catch (error) { } catch (error) {
console.error('Document upload error:', error)
console.error('Error stack:', error.stack)
return NextResponse.json( return NextResponse.json(
{ success: false, error: 'Failed to upload document' }, { success: false, error: 'Failed to upload document: ' + error.message },
{ status: 500 } { status: 500 }
) )
} }

View File

@@ -1,6 +1,6 @@
"use client" "use client"
import { useState, useRef } from "react" import { useState, useRef, useEffect } from "react"
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button"
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card" import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
import { Input } from "@/components/ui/input" import { Input } from "@/components/ui/input"
@@ -16,6 +16,8 @@ import { uploadFile } from "@/lib/file-download"
interface Collection { interface Collection {
id: string id: string
name: string name: string
is_managed?: boolean
source?: string
} }
interface DocumentUploadProps { interface DocumentUploadProps {
@@ -39,11 +41,29 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
const fileInputRef = useRef<HTMLInputElement>(null) const fileInputRef = useRef<HTMLInputElement>(null)
const { toast } = useToast() const { toast } = useToast()
const isUploadableCollection = (collectionId: string | null) => {
if (!collectionId) return false
return collections.some(collection => collection.id === collectionId)
}
useEffect(() => {
if (selectedCollection && isUploadableCollection(selectedCollection)) {
setTargetCollection(selectedCollection)
return
}
if (!targetCollection || !isUploadableCollection(targetCollection)) {
setTargetCollection(collections[0]?.id || "")
}
}, [selectedCollection, collections, targetCollection])
const supportedTypes = [ const supportedTypes = [
".pdf", ".docx", ".doc", ".xlsx", ".xls", ".txt", ".md", ".html", ".json", ".csv" ".pdf", ".docx", ".doc", ".xlsx", ".xls", ".txt", ".md", ".html", ".json", ".csv"
] ]
const handleFileSelect = (files: FileList | null) => { const handleFileSelect = (files: FileList | null) => {
console.log('handleFileSelect called with:', files)
if (!files || files.length === 0) return if (!files || files.length === 0) return
if (!targetCollection) { if (!targetCollection) {
@@ -62,20 +82,22 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
id: Math.random().toString(36).substr(2, 9) id: Math.random().toString(36).substr(2, 9)
})) }))
console.log('Processing files:', newFiles.map(f => ({ name: f.file.name, size: f.file.size })))
setUploadingFiles(prev => [...prev, ...newFiles]) setUploadingFiles(prev => [...prev, ...newFiles])
// Process each file // Process files sequentially to avoid body consumption conflicts
newFiles.forEach(uploadFile => { processFilesSequentially(newFiles)
uploadDocument(uploadFile) }
})
const processFilesSequentially = async (files: UploadingFile[]) => {
for (const uploadingFile of files) {
await uploadDocument(uploadingFile)
}
} }
const uploadDocument = async (uploadingFile: UploadingFile) => { const uploadDocument = async (uploadingFile: UploadingFile) => {
try { try {
const formData = new FormData()
formData.append('file', uploadingFile.file)
formData.append('collection_id', targetCollection)
// Simulate upload progress // Simulate upload progress
const updateProgress = (progress: number) => { const updateProgress = (progress: number) => {
setUploadingFiles(prev => setUploadingFiles(prev =>
@@ -90,12 +112,9 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
await new Promise(resolve => setTimeout(resolve, 200)) await new Promise(resolve => setTimeout(resolve, 200))
updateProgress(60) updateProgress(60)
await uploadFile( await uploadFile('/api-internal/v1/rag/documents', uploadingFile.file, {
uploadingFile.file, collection_id: targetCollection,
'/api-internal/v1/rag/documents', })
(progress) => updateProgress(progress),
{ collection_id: targetCollection }
)
updateProgress(80) updateProgress(80)
updateProgress(90) updateProgress(90)
@@ -211,6 +230,7 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
{collections.map((collection) => ( {collections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}> <SelectItem key={collection.id} value={collection.id}>
{collection.name} {collection.name}
{collection.is_managed === false ? ' (external)' : ''}
</SelectItem> </SelectItem>
))} ))}
</SelectContent> </SelectContent>

View File

@@ -29,23 +29,61 @@ export async function downloadFile(path: string, filename: string, params?: URLS
} }
export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) { export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) {
if (typeof path !== 'string' || path.length === 0) {
throw new TypeError('uploadFile path must be a non-empty string')
}
// Ensure path starts with / and construct full URL
const cleanPath = path.startsWith('/') ? path : `/${path}`
const url = `${typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000'}${cleanPath}`
// Debug logging
console.log('uploadFile called with:', { path: cleanPath, url, fileName: file.name, fileSize: file.size })
const form = new FormData() const form = new FormData()
form.append('file', file) form.append('file', file)
if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v)) if (extraFields) {
console.log('Adding extra fields:', extraFields)
Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
}
const token = await tokenManager.getAccessToken() const token = await tokenManager.getAccessToken()
const res = await fetch(path, { console.log('Making request to:', url)
try {
const res = await fetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
...(token ? { Authorization: `Bearer ${token}` } : {}), ...(token ? { Authorization: `Bearer ${token}` } : {}),
}, },
body: form, body: form,
}) })
if (!res.ok) {
let details: any
try { details = await res.json() } catch { details = await res.text() }
throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed'))
}
return await res.json()
}
console.log('Response status:', res.status, res.statusText)
if (!res.ok) {
const rawError = await res.text()
console.log('Error payload:', rawError)
let details: any
try {
details = rawError ? JSON.parse(rawError) : undefined
} catch {
details = rawError
}
const message = typeof details === 'string'
? details
: details?.detail || details?.error || `Upload failed (${res.status})`
throw new Error(message)
}
const result = await res.json()
console.log('Upload successful:', result)
return result
} catch (error) {
console.error('Upload failed:', error)
throw error
}
}