Troubleshoot RAG memory leak

2025-10-06 14:42:07 +02:00
parent 3e841d0e42
commit bae86fb5a2
12 changed files with 295 additions and 92 deletions

View File

@@ -345,7 +345,12 @@ async def refresh_token(
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
except HTTPException:
# Re-raise HTTPException without modification
raise
except Exception as e:
# Log the actual error for debugging
logger.error(f"Refresh token error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"

View File

@@ -288,14 +288,18 @@ async def get_documents(
@router.post("/documents", response_model=dict)
async def upload_document(
collection_id: int = Form(...),
collection_id: str = Form(...),
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Upload and process a document"""
try:
# Read file content
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
# Read file content once and use it for all validations
file_content = await file.read()
if len(file_content) == 0:
@@ -304,10 +308,6 @@ async def upload_document(
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
try:
# Test file readability based on type
if file_extension == 'jsonl':
@@ -349,8 +349,33 @@ async def upload_document(
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
rag_service = RAGService(db)
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
collection_identifier = (collection_id or "").strip()
if not collection_identifier:
raise HTTPException(status_code=400, detail="Collection identifier is required")
resolved_collection_id: Optional[int] = None
if collection_identifier.isdigit():
resolved_collection_id = int(collection_identifier)
else:
qdrant_name = collection_identifier
if qdrant_name.startswith("ext_"):
qdrant_name = qdrant_name[4:]
try:
collection_record = await rag_service.ensure_collection_record(qdrant_name)
except Exception as ensure_error:
raise HTTPException(status_code=500, detail=str(ensure_error))
resolved_collection_id = collection_record.id
if resolved_collection_id is None:
raise HTTPException(status_code=400, detail="Invalid collection identifier")
document = await rag_service.upload_document(
collection_id=collection_id,
collection_id=resolved_collection_id,
file_content=file_content,
filename=filename,
content_type=file.content_type
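
The resolution logic above accepts either a numeric database id or a Qdrant collection name, optionally prefixed with "ext_" for external collections. Condensed into a standalone helper, as a sketch that assumes the `ensure_collection_record` method added later in this commit:

```python
from fastapi import HTTPException

async def resolve_collection_id(raw: str, rag_service) -> int:
    """Map a form value to a numeric collection id."""
    identifier = (raw or "").strip()
    if not identifier:
        raise HTTPException(status_code=400, detail="Collection identifier is required")
    if identifier.isdigit():
        return int(identifier)  # already a database id
    # Treat anything else as a Qdrant collection name; strip the "ext_" marker.
    qdrant_name = identifier[4:] if identifier.startswith("ext_") else identifier
    record = await rag_service.ensure_collection_record(qdrant_name)
    return record.id
```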

View File

@@ -682,10 +682,13 @@ class RAGModule(BaseModule):
chunks.append(chunk_text)
# Move to next chunk with overlap
start_idx = end_idx - chunk_overlap
# Ensure we make progress and don't loop infinitely
start_idx += chunk_size - chunk_overlap
if start_idx >= len(tokens):
break
# Ensure progress (in case overlap >= chunk_size)
if start_idx >= end_idx:
# Safety check to prevent infinite loop
if start_idx <= end_idx - chunk_size:
start_idx = end_idx
return chunks

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import selectinload
from app.db.database import get_db
from app.models.rag_document import RagDocument
from app.models.rag_collection import RagCollection
from app.services.module_manager import module_manager
logger = logging.getLogger(__name__)
@@ -57,6 +58,8 @@ class DocumentProcessor:
"queue_size": 0,
"active_workers": 0
}
self._rag_module = None
self._rag_module_lock = asyncio.Lock()
async def start(self):
"""Start the document processor"""
@@ -157,9 +160,33 @@ class DocumentProcessor:
self.stats["active_workers"] -= 1
logger.error(f"{worker_name}: Unexpected error: {e}")
await asyncio.sleep(1) # Brief pause before continuing
logger.info(f"Worker stopped: {worker_name}")
async def _get_rag_module(self):
"""Resolve and cache the RAG module instance"""
async with self._rag_module_lock:
if self._rag_module and getattr(self._rag_module, 'enabled', False):
return self._rag_module
if not module_manager.initialized:
await module_manager.initialize()
rag_module = module_manager.modules.get('rag')
if not rag_module or not getattr(rag_module, 'enabled', False):
enabled = await module_manager.enable_module('rag')
if not enabled:
raise Exception("Failed to enable RAG module")
rag_module = module_manager.modules.get('rag')
if not rag_module or not getattr(rag_module, 'enabled', False):
raise Exception("RAG module not available or not enabled")
self._rag_module = rag_module
logger.info("DocumentProcessor cached RAG module instance for reuse")
return self._rag_module
async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document"""
from datetime import datetime
@@ -182,19 +209,10 @@ class DocumentProcessor:
# Update status to processing
document.status = ProcessingStatus.PROCESSING
await session.commit()
# Get RAG module for processing
try:
# Import RAG module and initialize it properly
from modules.rag.main import RAGModule
from app.core.config import settings
# Create and initialize RAG module instance
rag_module = RAGModule(settings)
init_result = await rag_module.initialize()
if not rag_module.enabled:
raise Exception("Failed to enable RAG module")
rag_module = await self._get_rag_module()
except Exception as e:
logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}")
@@ -376,4 +394,4 @@ class DocumentProcessor:
# Global document processor instance
document_processor = DocumentProcessor()
document_processor = DocumentProcessor()
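
This is the core of the leak fix: previously every processed document constructed and initialized a fresh `RAGModule` (each loading its own models and clients), and none of them were ever torn down. `_get_rag_module` replaces that with one cached instance guarded by an `asyncio.Lock`. The shape of the pattern, reduced to a sketch with an illustrative `init_module`:

```python
import asyncio

class Processor:
    def __init__(self):
        self._module = None
        self._lock = asyncio.Lock()

    async def _get_module(self):
        # Serialize initialization so concurrent workers can't each build
        # their own instance; afterwards everyone reuses the cached one.
        async with self._lock:
            if self._module is not None:
                return self._module
            self._module = await init_module()  # hypothetical expensive setup
            return self._module
```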

View File

@@ -175,6 +175,15 @@ class EmbeddingService:
async def cleanup(self):
"""Cleanup resources"""
# Cleanup LLM service to prevent memory leaks
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")
self.initialized = False

View File

@@ -146,8 +146,8 @@ class EnhancedEmbeddingService(EmbeddingService):
self._dimension_confirmed = True
else:
raise ValueError("Empty embedding in response")
else:
raise ValueError("Invalid response structure")
else:
raise ValueError("Invalid response structure")
# Count this successful request and optionally delay between requests
self._update_rate_limit_tracker(success=True)

View File

@@ -478,5 +478,11 @@ class PrivateModeProvider(BaseLLMProvider):
async def cleanup(self):
"""Cleanup PrivateMode provider resources"""
# Close HTTP session to prevent memory leaks
if self._session and not self._session.closed:
await self._session.close()
self._session = None
logger.debug("Closed PrivateMode HTTP session")
await super().cleanup()
logger.debug("PrivateMode provider cleanup completed")

View File

@@ -81,7 +81,55 @@ class RAGService:
RagCollection.is_active == True
)
return await self.db.scalar(stmt)
async def get_collection_by_qdrant_name(self, qdrant_collection_name: str) -> Optional[RagCollection]:
"""Get a collection using its Qdrant collection name"""
stmt = select(RagCollection).where(
RagCollection.qdrant_collection_name == qdrant_collection_name
)
return await self.db.scalar(stmt)
async def ensure_collection_record(self, qdrant_collection_name: str) -> RagCollection:
"""Ensure we have a managed record for a given Qdrant collection"""
existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
if existing:
return existing
# Create a friendly name from the Qdrant collection identifier
friendly_name = qdrant_collection_name
try:
if qdrant_collection_name.startswith("rag_"):
trimmed = qdrant_collection_name[4:]
parts = [part for part in trimmed.split("_") if part]
if parts:
friendly_name = " ".join(parts).title()
except Exception:
# Fall back to original identifier on any parsing issues
friendly_name = qdrant_collection_name
collection = RagCollection(
name=friendly_name,
description=f"Synced from Qdrant collection '{qdrant_collection_name}'",
qdrant_collection_name=qdrant_collection_name,
status='active',
is_active=True
)
self.db.add(collection)
try:
await self.db.commit()
except Exception:
await self.db.rollback()
# Another request might have created the collection concurrently; fetch again
existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
if existing:
return existing
raise
await self.db.refresh(collection)
return collection
async def get_all_collections(self, skip: int = 0, limit: int = 100) -> List[dict]:
"""Get all collections from Qdrant (source of truth) with additional metadata from PostgreSQL."""
logger.info("Getting all RAG collections from Qdrant (source of truth)")

View File

@@ -664,28 +664,36 @@ class RAGModule(BaseModule):
chunk_size = chunk_size or self.config.get("chunk_size", 300)
chunk_overlap = self.config.get("chunk_overlap", 50)
# Tokenize text
# Ensure sane values to avoid infinite loops on very short docs
chunk_size = max(1, chunk_size)
if chunk_overlap >= chunk_size:
chunk_overlap = max(0, chunk_size - 1)
tokens = self.tokenizer.encode(text)
if not tokens:
return []
# Split into chunks with overlap
chunks = []
chunks: List[str] = []
len_tokens = len(tokens)
start_idx = 0
step = max(1, chunk_size - chunk_overlap)
while start_idx < len(tokens):
end_idx = min(start_idx + chunk_size, len(tokens))
while start_idx < len_tokens:
end_idx = min(start_idx + chunk_size, len_tokens)
chunk_tokens = tokens[start_idx:end_idx]
if not chunk_tokens:
break
chunk_text = self.tokenizer.decode(chunk_tokens)
# Only add non-empty chunks
if chunk_text.strip():
chunks.append(chunk_text)
# Move to next chunk with overlap
start_idx = end_idx - chunk_overlap
if end_idx >= len_tokens:
break
# Ensure progress (in case overlap >= chunk_size)
if start_idx >= end_idx:
start_idx = end_idx
start_idx += step
return chunks
@@ -1962,4 +1970,4 @@ async def delete_collection(collection_name: str) -> bool:
async def get_supported_types() -> List[str]:
"""Get list of supported file types"""
return list(rag_module.supported_types.keys())
return list(rag_module.supported_types.keys())
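
The corrected loop advances by a fixed `step = chunk_size - chunk_overlap` (clamped to at least 1), so with the defaults of 300 and 50 the chunks start at token 0, 250, 500, and so on, each sharing 50 tokens with its predecessor. A self-contained sketch of just the boundary arithmetic:

```python
def chunk_indices(n_tokens: int, chunk_size: int = 300, chunk_overlap: int = 50):
    chunk_size = max(1, chunk_size)
    if chunk_overlap >= chunk_size:
        chunk_overlap = max(0, chunk_size - 1)  # guarantee forward progress
    step = max(1, chunk_size - chunk_overlap)
    spans, start = [], 0
    while start < n_tokens:
        end = min(start + chunk_size, n_tokens)
        spans.append((start, end))
        if end >= n_tokens:
            break
        start += step
    return spans

# With 700 tokens, size 300, overlap 50: [(0, 300), (250, 550), (500, 700)]
print(chunk_indices(700))
```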

View File

@@ -42,34 +42,57 @@ export async function GET(request: NextRequest) {
export async function POST(request: NextRequest) {
try {
const formData = await request.formData()
console.log('=== Document Upload API Route Called ===')
// Get auth token from request headers
const authHeader = request.headers.get('authorization')
console.log('Auth header present:', !!authHeader)
// Forward the FormData directly to backend
const backendResponse = await fetch(`${BACKEND_URL}/api/rag/documents`, {
method: 'POST',
headers: {
...(authHeader && { 'Authorization': authHeader }),
// Don't set Content-Type for FormData - let the browser set it with boundary
},
body: formData,
})
// Get the original content type (includes multipart boundary)
const contentType = request.headers.get('content-type')
console.log('Original content type:', contentType)
if (!backendResponse.ok) {
const errorData = await backendResponse.json().catch(() => ({ error: 'Unknown error' }))
return NextResponse.json(
{ success: false, error: errorData.detail || errorData.error || 'Failed to upload document' },
{ status: backendResponse.status }
)
// Try to forward the original request body directly
try {
console.log('Request body type:', typeof request.body)
console.log('Request body:', request.body)
// Forward directly to backend with original body stream
console.log('Sending request to backend:', `${BACKEND_URL}/api/rag/documents`)
const backendResponse = await fetch(`${BACKEND_URL}/api/rag/documents`, {
method: 'POST',
headers: {
...(authHeader && { 'Authorization': authHeader }),
...(contentType && { 'Content-Type': contentType }),
},
body: request.body,
})
console.log('Backend response status:', backendResponse.status, backendResponse.statusText)
if (!backendResponse.ok) {
const errorData = await backendResponse.json().catch(() => ({ error: 'Unknown error' }))
console.log('Backend error response:', errorData)
return NextResponse.json(
{ success: false, error: errorData.detail || errorData.error || 'Failed to upload document' },
{ status: backendResponse.status }
)
}
const data = await backendResponse.json()
console.log('Backend success response:', data)
return NextResponse.json(data)
} catch (bodyError) {
console.error('Error reading request body:', bodyError)
throw bodyError
}
const data = await backendResponse.json()
return NextResponse.json(data)
} catch (error) {
console.error('Document upload error:', error)
console.error('Error stack:', error.stack)
return NextResponse.json(
{ success: false, error: 'Failed to upload document' },
{ success: false, error: 'Failed to upload document: ' + error.message },
{ status: 500 }
)
}
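
The rewritten route stops re-parsing the upload: it streams `request.body` through to the backend and reuses the original Content-Type header, whose multipart boundary parameter must match the raw bytes. The earlier version called `request.formData()` and then sent a new FormData, which consumed the body and produced a fresh boundary. For illustration only, the same forwarding idea in Python, assuming httpx and a Starlette-style request object:

```python
import httpx

async def proxy_upload(request, backend_url: str):
    # Reuse the incoming Authorization and Content-Type headers; the
    # multipart boundary in Content-Type must match the body bytes.
    headers = {}
    if auth := request.headers.get("authorization"):
        headers["Authorization"] = auth
    if ctype := request.headers.get("content-type"):
        headers["Content-Type"] = ctype
    async with httpx.AsyncClient() as client:
        # Stream the raw body through without buffering or re-encoding it.
        return await client.post(backend_url, headers=headers,
                                 content=request.stream())
```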

View File

@@ -1,6 +1,6 @@
"use client"
import { useState, useRef } from "react"
import { useState, useRef, useEffect } from "react"
import { Button } from "@/components/ui/button"
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
import { Input } from "@/components/ui/input"
@@ -16,6 +16,8 @@ import { uploadFile } from "@/lib/file-download"
interface Collection {
id: string
name: string
is_managed?: boolean
source?: string
}
interface DocumentUploadProps {
@@ -39,13 +41,31 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
const fileInputRef = useRef<HTMLInputElement>(null)
const { toast } = useToast()
const isUploadableCollection = (collectionId: string | null) => {
if (!collectionId) return false
return collections.some(collection => collection.id === collectionId)
}
useEffect(() => {
if (selectedCollection && isUploadableCollection(selectedCollection)) {
setTargetCollection(selectedCollection)
return
}
if (!targetCollection || !isUploadableCollection(targetCollection)) {
setTargetCollection(collections[0]?.id || "")
}
}, [selectedCollection, collections, targetCollection])
const supportedTypes = [
".pdf", ".docx", ".doc", ".xlsx", ".xls", ".txt", ".md", ".html", ".json", ".csv"
]
const handleFileSelect = (files: FileList | null) => {
console.log('handleFileSelect called with:', files)
if (!files || files.length === 0) return
if (!targetCollection) {
toast({
title: "Error",
@@ -62,23 +82,25 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
id: Math.random().toString(36).substr(2, 9)
}))
console.log('Processing files:', newFiles.map(f => ({ name: f.file.name, size: f.file.size })))
setUploadingFiles(prev => [...prev, ...newFiles])
// Process each file
newFiles.forEach(uploadFile => {
uploadDocument(uploadFile)
})
// Process files sequentially to avoid body consumption conflicts
processFilesSequentially(newFiles)
}
const processFilesSequentially = async (files: UploadingFile[]) => {
for (const uploadingFile of files) {
await uploadDocument(uploadingFile)
}
}
const uploadDocument = async (uploadingFile: UploadingFile) => {
try {
const formData = new FormData()
formData.append('file', uploadingFile.file)
formData.append('collection_id', targetCollection)
// Simulate upload progress
const updateProgress = (progress: number) => {
setUploadingFiles(prev =>
setUploadingFiles(prev =>
prev.map(f => f.id === uploadingFile.id ? { ...f, progress } : f)
)
}
@@ -90,12 +112,9 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
await new Promise(resolve => setTimeout(resolve, 200))
updateProgress(60)
await uploadFile(
uploadingFile.file,
'/api-internal/v1/rag/documents',
(progress) => updateProgress(progress),
{ collection_id: targetCollection }
)
await uploadFile('/api-internal/v1/rag/documents', uploadingFile.file, {
collection_id: targetCollection,
})
updateProgress(80)
updateProgress(90)
@@ -211,6 +230,7 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
{collections.map((collection) => (
<SelectItem key={collection.id} value={collection.id}>
{collection.name}
{collection.is_managed === false ? ' (external)' : ''}
</SelectItem>
))}
</SelectContent>
@@ -316,4 +336,4 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
)}
</div>
)
}
}

View File

@@ -29,23 +29,61 @@ export async function downloadFile(path: string, filename: string, params?: URLS
}
export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) {
if (typeof path !== 'string' || path.length === 0) {
throw new TypeError('uploadFile path must be a non-empty string')
}
// Ensure path starts with / and construct full URL
const cleanPath = path.startsWith('/') ? path : `/${path}`
const url = `${typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000'}${cleanPath}`
// Debug logging
console.log('uploadFile called with:', { path: cleanPath, url, fileName: file.name, fileSize: file.size })
const form = new FormData()
form.append('file', file)
if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
if (extraFields) {
console.log('Adding extra fields:', extraFields)
Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
}
const token = await tokenManager.getAccessToken()
const res = await fetch(path, {
method: 'POST',
headers: {
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
body: form,
})
if (!res.ok) {
let details: any
try { details = await res.json() } catch { details = await res.text() }
throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed'))
}
return await res.json()
}
console.log('Making request to:', url)
try {
const res = await fetch(url, {
method: 'POST',
headers: {
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
body: form,
})
console.log('Response status:', res.status, res.statusText)
if (!res.ok) {
const rawError = await res.text()
console.log('Error payload:', rawError)
let details: any
try {
details = rawError ? JSON.parse(rawError) : undefined
} catch {
details = rawError
}
const message = typeof details === 'string'
? details
: details?.detail || details?.error || `Upload failed (${res.status})`
throw new Error(message)
}
const result = await res.json()
console.log('Upload successful:', result)
return result
} catch (error) {
console.error('Upload failed:', error)
throw error
}
}