mirror of https://github.com/aljazceru/enclava.git
tshoot rag memory leak
@@ -345,7 +345,12 @@ async def refresh_token(
             expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
         )

+    except HTTPException:
+        # Re-raise HTTPException without modification
+        raise
     except Exception as e:
+        # Log the actual error for debugging
+        logger.error(f"Refresh token error: {str(e)}", exc_info=True)
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail="Invalid refresh token"
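The ordering of the two handlers is the actual fix here: FastAPI's HTTPException is a subclass of Exception, so without the dedicated re-raise clause the generic handler would swallow deliberate 4xx responses and rewrite them as a blanket 401. A minimal sketch of the pattern, assuming hypothetical verify_refresh_token and issue_access_token helpers (not names from this codebase):

import logging
from fastapi import HTTPException, status

logger = logging.getLogger(__name__)

async def refresh_token_sketch(token: str) -> dict:
    try:
        payload = verify_refresh_token(token)  # hypothetical helper; may itself raise HTTPException
        return {"access_token": issue_access_token(payload)}  # hypothetical helper
    except HTTPException:
        # Re-raise unchanged so a deliberate 400/403 keeps its status code
        raise
    except Exception as e:
        # Only unexpected failures are logged and collapsed into a generic 401
        logger.error(f"Refresh token error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        )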
@@ -288,14 +288,18 @@ async def get_documents(

 @router.post("/documents", response_model=dict)
 async def upload_document(
-    collection_id: int = Form(...),
+    collection_id: str = Form(...),
     file: UploadFile = File(...),
     db: AsyncSession = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
     """Upload and process a document"""
     try:
-        # Read file content
+        # Validate file can be read before processing
+        filename = file.filename or "unknown"
+        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
+
+        # Read file content once and use it for all validations
         file_content = await file.read()

         if len(file_content) == 0:
@@ -304,10 +308,6 @@ async def upload_document(
         if len(file_content) > 50 * 1024 * 1024:  # 50MB limit
             raise HTTPException(status_code=400, detail="File too large (max 50MB)")

-        # Validate file can be read before processing
-        filename = file.filename or "unknown"
-        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
-
         try:
             # Test file readability based on type
             if file_extension == 'jsonl':
@@ -349,8 +349,33 @@ async def upload_document(
             raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")

         rag_service = RAGService(db)
+
+        # Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
+        collection_identifier = (collection_id or "").strip()
+        if not collection_identifier:
+            raise HTTPException(status_code=400, detail="Collection identifier is required")
+
+        resolved_collection_id: Optional[int] = None
+
+        if collection_identifier.isdigit():
+            resolved_collection_id = int(collection_identifier)
+        else:
+            qdrant_name = collection_identifier
+            if qdrant_name.startswith("ext_"):
+                qdrant_name = qdrant_name[4:]
+
+            try:
+                collection_record = await rag_service.ensure_collection_record(qdrant_name)
+            except Exception as ensure_error:
+                raise HTTPException(status_code=500, detail=str(ensure_error))
+
+            resolved_collection_id = collection_record.id
+
+        if resolved_collection_id is None:
+            raise HTTPException(status_code=400, detail="Invalid collection identifier")
+
         document = await rag_service.upload_document(
-            collection_id=collection_id,
+            collection_id=resolved_collection_id,
             file_content=file_content,
             filename=filename,
             content_type=file.content_type
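For reference, the resolution rule this hunk implements: a purely numeric identifier is treated as a primary-key ID, anything else is treated as a Qdrant collection name (with an optional "ext_" prefix stripped) and lazily synced into a managed record. Condensed into a standalone helper, assuming a rag_service exposing the ensure_collection_record coroutine added later in this commit:

from fastapi import HTTPException

async def resolve_collection_id(collection_id: str, rag_service) -> int:
    """Map a user-supplied identifier to a numeric collection ID."""
    identifier = (collection_id or "").strip()
    if not identifier:
        raise HTTPException(status_code=400, detail="Collection identifier is required")

    if identifier.isdigit():
        return int(identifier)

    # Non-numeric values are Qdrant collection names; "ext_" marks external collections
    qdrant_name = identifier[4:] if identifier.startswith("ext_") else identifier
    record = await rag_service.ensure_collection_record(qdrant_name)
    return record.id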
@@ -682,10 +682,13 @@ class RAGModule(BaseModule):
                 chunks.append(chunk_text)

             # Move to next chunk with overlap
-            start_idx = end_idx - chunk_overlap
+            # Ensure we make progress and don't loop infinitely
+            start_idx += chunk_size - chunk_overlap
+            if start_idx >= len(tokens):
+                break

-            # Ensure progress (in case overlap >= chunk_size)
-            if start_idx >= end_idx:
+            # Safety check to prevent infinite loop
+            if start_idx <= end_idx - chunk_size:
                 start_idx = end_idx

         return chunks
@@ -16,6 +16,7 @@ from sqlalchemy.orm import selectinload
 from app.db.database import get_db
 from app.models.rag_document import RagDocument
+from app.models.rag_collection import RagCollection
 from app.services.module_manager import module_manager

 logger = logging.getLogger(__name__)
@@ -57,6 +58,8 @@ class DocumentProcessor:
             "queue_size": 0,
             "active_workers": 0
         }
+        self._rag_module = None
+        self._rag_module_lock = asyncio.Lock()

     async def start(self):
         """Start the document processor"""
@@ -160,6 +163,30 @@ class DocumentProcessor:

         logger.info(f"Worker stopped: {worker_name}")

+    async def _get_rag_module(self):
+        """Resolve and cache the RAG module instance"""
+        async with self._rag_module_lock:
+            if self._rag_module and getattr(self._rag_module, 'enabled', False):
+                return self._rag_module
+
+            if not module_manager.initialized:
+                await module_manager.initialize()
+
+            rag_module = module_manager.modules.get('rag')
+
+            if not rag_module or not getattr(rag_module, 'enabled', False):
+                enabled = await module_manager.enable_module('rag')
+                if not enabled:
+                    raise Exception("Failed to enable RAG module")
+                rag_module = module_manager.modules.get('rag')
+
+            if not rag_module or not getattr(rag_module, 'enabled', False):
+                raise Exception("RAG module not available or not enabled")
+
+            self._rag_module = rag_module
+            logger.info("DocumentProcessor cached RAG module instance for reuse")
+            return self._rag_module
+
     async def _process_document(self, task: ProcessingTask) -> bool:
         """Process a single document"""
         from datetime import datetime
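This is the heart of the leak fix: before, _process_document built a fresh RAGModule (tokenizer, embedding client, HTTP sessions) for every document and never released it. The hunk replaces that with one lazily created instance guarded by an asyncio.Lock so concurrent workers cannot race to initialize duplicates. A stripped-down, runnable sketch of the pattern with a stand-in ExpensiveModule (not the real RAGModule API):

import asyncio

class ExpensiveModule:
    """Stand-in for RAGModule: costly to build, safe to share."""
    async def initialize(self):
        await asyncio.sleep(0.1)  # simulate loading models/clients
        self.enabled = True

class Processor:
    def __init__(self):
        self._module = None
        self._lock = asyncio.Lock()

    async def _get_module(self) -> ExpensiveModule:
        async with self._lock:
            # Re-check under the lock: another worker may have won the race
            if self._module is None:
                module = ExpensiveModule()
                await module.initialize()
                self._module = module
            return self._module

async def main():
    p = Processor()
    # Ten concurrent workers still trigger exactly one initialization
    modules = await asyncio.gather(*(p._get_module() for _ in range(10)))
    assert all(m is modules[0] for m in modules)

asyncio.run(main())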
@@ -185,16 +212,7 @@ class DocumentProcessor:

         # Get RAG module for processing
         try:
-            # Import RAG module and initialize it properly
-            from modules.rag.main import RAGModule
-            from app.core.config import settings
-
-            # Create and initialize RAG module instance
-            rag_module = RAGModule(settings)
-            init_result = await rag_module.initialize()
-            if not rag_module.enabled:
-                raise Exception("Failed to enable RAG module")
-
+            rag_module = await self._get_rag_module()
         except Exception as e:
             logger.error(f"Failed to get RAG module: {e}")
             raise Exception(f"RAG module not available: {e}")
@@ -175,6 +175,15 @@ class EmbeddingService:

     async def cleanup(self):
         """Cleanup resources"""
+        # Cleanup LLM service to prevent memory leaks
+        try:
+            from .llm.service import llm_service
+            if llm_service._initialized:
+                await llm_service.cleanup()
+                logger.info("Cleaned up LLM service from embedding service")
+        except Exception as e:
+            logger.error(f"Error cleaning up LLM service: {e}")
+
         self.initialized = False
@@ -478,5 +478,11 @@ class PrivateModeProvider(BaseLLMProvider):

     async def cleanup(self):
         """Cleanup PrivateMode provider resources"""
+        # Close HTTP session to prevent memory leaks
+        if self._session and not self._session.closed:
+            await self._session.close()
+            self._session = None
+            logger.debug("Closed PrivateMode HTTP session")
+
         await super().cleanup()
         logger.debug("PrivateMode provider cleanup completed")
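The companion leak: a provider-owned HTTP client session that is opened lazily but never closed keeps its connection pool alive for the life of the process. Assuming the session is an aiohttp ClientSession (the diff never names the HTTP library, so treat that as an assumption), the guarded, idempotent close looks like this in isolation:

from typing import Optional

import aiohttp

class ProviderSketch:
    def __init__(self) -> None:
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        # Lazily create the session and reuse it across requests
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def cleanup(self) -> None:
        # Idempotent: safe to call even if no session was ever created
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None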
@@ -82,6 +82,54 @@ class RAGService:
         )
         return await self.db.scalar(stmt)

+    async def get_collection_by_qdrant_name(self, qdrant_collection_name: str) -> Optional[RagCollection]:
+        """Get a collection using its Qdrant collection name"""
+        stmt = select(RagCollection).where(
+            RagCollection.qdrant_collection_name == qdrant_collection_name
+        )
+        return await self.db.scalar(stmt)
+
+    async def ensure_collection_record(self, qdrant_collection_name: str) -> RagCollection:
+        """Ensure we have a managed record for a given Qdrant collection"""
+        existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
+        if existing:
+            return existing
+
+        # Create a friendly name from the Qdrant collection identifier
+        friendly_name = qdrant_collection_name
+        try:
+            if qdrant_collection_name.startswith("rag_"):
+                trimmed = qdrant_collection_name[4:]
+                parts = [part for part in trimmed.split("_") if part]
+                if parts:
+                    friendly_name = " ".join(parts).title()
+        except Exception:
+            # Fall back to original identifier on any parsing issues
+            friendly_name = qdrant_collection_name
+
+        collection = RagCollection(
+            name=friendly_name,
+            description=f"Synced from Qdrant collection '{qdrant_collection_name}'",
+            qdrant_collection_name=qdrant_collection_name,
+            status='active',
+            is_active=True
+        )
+
+        self.db.add(collection)
+
+        try:
+            await self.db.commit()
+        except Exception:
+            await self.db.rollback()
+            # Another request might have created the collection concurrently; fetch again
+            existing = await self.get_collection_by_qdrant_name(qdrant_collection_name)
+            if existing:
+                return existing
+            raise
+
+        await self.db.refresh(collection)
+        return collection
+
     async def get_all_collections(self, skip: int = 0, limit: int = 100) -> List[dict]:
         """Get all collections from Qdrant (source of truth) with additional metadata from PostgreSQL."""
         logger.info("Getting all RAG collections from Qdrant (source of truth)")
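The commit/rollback dance in ensure_collection_record is a standard race-tolerant get-or-create: if two uploads reference the same unknown Qdrant collection concurrently, one INSERT wins (ideally enforced by a unique constraint on qdrant_collection_name), the loser's commit raises, and the rollback path re-reads the winner's row instead of bubbling the error. Generalized sketch against SQLAlchemy's async API (a generic helper for illustration, not part of this codebase):

from sqlalchemy import select

async def get_or_create(db, model, lookup: dict, defaults: dict):
    """Get-or-create that tolerates a concurrent INSERT of the same row."""
    stmt = select(model).filter_by(**lookup)
    existing = await db.scalar(stmt)
    if existing:
        return existing

    instance = model(**lookup, **defaults)
    db.add(instance)
    try:
        await db.commit()
    except Exception:
        await db.rollback()
        # A concurrent writer may have inserted the row first; re-read before giving up
        existing = await db.scalar(stmt)
        if existing:
            return existing
        raise

    await db.refresh(instance)
    return instance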
@@ -664,28 +664,36 @@ class RAGModule(BaseModule):
         chunk_size = chunk_size or self.config.get("chunk_size", 300)
         chunk_overlap = self.config.get("chunk_overlap", 50)

-        # Tokenize text
+        # Ensure sane values to avoid infinite loops on very short docs
+        chunk_size = max(1, chunk_size)
+        if chunk_overlap >= chunk_size:
+            chunk_overlap = max(0, chunk_size - 1)
+
         tokens = self.tokenizer.encode(text)
+        if not tokens:
+            return []

         # Split into chunks with overlap
-        chunks = []
+        chunks: List[str] = []
+        len_tokens = len(tokens)
         start_idx = 0
+        step = max(1, chunk_size - chunk_overlap)

-        while start_idx < len(tokens):
-            end_idx = min(start_idx + chunk_size, len(tokens))
+        while start_idx < len_tokens:
+            end_idx = min(start_idx + chunk_size, len_tokens)
             chunk_tokens = tokens[start_idx:end_idx]

             if not chunk_tokens:
                 break

             chunk_text = self.tokenizer.decode(chunk_tokens)

             # Only add non-empty chunks
             if chunk_text.strip():
                 chunks.append(chunk_text)

-            # Move to next chunk with overlap
-            start_idx = end_idx - chunk_overlap
+            if end_idx >= len_tokens:
+                break

-            # Ensure progress (in case overlap >= chunk_size)
-            if start_idx >= end_idx:
-                start_idx = end_idx
+            start_idx += step

         return chunks
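Net effect of the rewrite: the loop advances by a step clamped to at least 1 token, so no overlap setting can stall it, and it breaks as soon as a chunk reaches the end of the token stream instead of emitting a trailing pure-overlap chunk. The same logic as a self-contained function, using plain token lists in place of the module's tokenizer:

from typing import List

def chunk_tokens(tokens: List[str], chunk_size: int = 300, chunk_overlap: int = 50) -> List[List[str]]:
    """Split tokens into overlapping chunks; the step is clamped positive to guarantee progress."""
    chunk_size = max(1, chunk_size)
    if chunk_overlap >= chunk_size:
        chunk_overlap = max(0, chunk_size - 1)
    step = max(1, chunk_size - chunk_overlap)

    chunks = []
    start = 0
    while start < len(tokens):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this chunk already reached the end; stop before a pure-overlap tail
        start += step
    return chunks

# Example: 10 tokens, size 4, overlap 2 -> windows start at 0, 2, 4, 6
print([len(c) for c in chunk_tokens(list("abcdefghij"), 4, 2)])  # [4, 4, 4, 4]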
@@ -42,23 +42,37 @@ export async function GET(request: NextRequest) {

 export async function POST(request: NextRequest) {
   try {
-    const formData = await request.formData()
     console.log('=== Document Upload API Route Called ===')

     // Get auth token from request headers
     const authHeader = request.headers.get('authorization')
     console.log('Auth header present:', !!authHeader)

-    // Forward the FormData directly to backend
+    // Get the original content type (includes multipart boundary)
+    const contentType = request.headers.get('content-type')
+    console.log('Original content type:', contentType)
+
+    // Try to forward the original request body directly
+    try {
+      console.log('Request body type:', typeof request.body)
+      console.log('Request body:', request.body)
+
+      // Forward directly to backend with original body stream
+      console.log('Sending request to backend:', `${BACKEND_URL}/api/rag/documents`)
       const backendResponse = await fetch(`${BACKEND_URL}/api/rag/documents`, {
         method: 'POST',
         headers: {
           ...(authHeader && { 'Authorization': authHeader }),
-          // Don't set Content-Type for FormData - let the browser set it with boundary
+          ...(contentType && { 'Content-Type': contentType }),
         },
-        body: formData,
+        body: request.body,
       })

       console.log('Backend response status:', backendResponse.status, backendResponse.statusText)

       if (!backendResponse.ok) {
         const errorData = await backendResponse.json().catch(() => ({ error: 'Unknown error' }))
         console.log('Backend error response:', errorData)
         return NextResponse.json(
           { success: false, error: errorData.detail || errorData.error || 'Failed to upload document' },
           { status: backendResponse.status }
@@ -66,10 +80,19 @@ export async function POST(request: NextRequest) {
       }

       const data = await backendResponse.json()
       console.log('Backend success response:', data)
       return NextResponse.json(data)

+    } catch (bodyError) {
+      console.error('Error reading request body:', bodyError)
+      throw bodyError
+    }
+
   } catch (error) {
     console.error('Document upload error:', error)
+    console.error('Error stack:', error.stack)
     return NextResponse.json(
-      { success: false, error: 'Failed to upload document' },
+      { success: false, error: 'Failed to upload document: ' + error.message },
       { status: 500 }
     )
   }
@@ -1,6 +1,6 @@
 "use client"

-import { useState, useRef } from "react"
+import { useState, useRef, useEffect } from "react"
 import { Button } from "@/components/ui/button"
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
 import { Input } from "@/components/ui/input"
@@ -16,6 +16,8 @@ import { uploadFile } from "@/lib/file-download"
 interface Collection {
   id: string
   name: string
+  is_managed?: boolean
+  source?: string
 }

 interface DocumentUploadProps {
@@ -39,11 +41,29 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
   const fileInputRef = useRef<HTMLInputElement>(null)
   const { toast } = useToast()

+  const isUploadableCollection = (collectionId: string | null) => {
+    if (!collectionId) return false
+    return collections.some(collection => collection.id === collectionId)
+  }
+
+  useEffect(() => {
+    if (selectedCollection && isUploadableCollection(selectedCollection)) {
+      setTargetCollection(selectedCollection)
+      return
+    }
+
+    if (!targetCollection || !isUploadableCollection(targetCollection)) {
+      setTargetCollection(collections[0]?.id || "")
+    }
+  }, [selectedCollection, collections, targetCollection])
+
   const supportedTypes = [
     ".pdf", ".docx", ".doc", ".xlsx", ".xls", ".txt", ".md", ".html", ".json", ".csv"
   ]

   const handleFileSelect = (files: FileList | null) => {
     console.log('handleFileSelect called with:', files)

     if (!files || files.length === 0) return

     if (!targetCollection) {
@@ -62,20 +82,22 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
       id: Math.random().toString(36).substr(2, 9)
     }))

     console.log('Processing files:', newFiles.map(f => ({ name: f.file.name, size: f.file.size })))

     setUploadingFiles(prev => [...prev, ...newFiles])

-    // Process each file
-    newFiles.forEach(uploadFile => {
-      uploadDocument(uploadFile)
-    })
+    // Process files sequentially to avoid body consumption conflicts
+    processFilesSequentially(newFiles)
   }

+  const processFilesSequentially = async (files: UploadingFile[]) => {
+    for (const uploadingFile of files) {
+      await uploadDocument(uploadingFile)
+    }
+  }
+
   const uploadDocument = async (uploadingFile: UploadingFile) => {
     try {
       const formData = new FormData()
       formData.append('file', uploadingFile.file)
       formData.append('collection_id', targetCollection)

       // Simulate upload progress
       const updateProgress = (progress: number) => {
         setUploadingFiles(prev =>
@@ -90,12 +112,9 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
       await new Promise(resolve => setTimeout(resolve, 200))
       updateProgress(60)

-      await uploadFile(
-        uploadingFile.file,
-        '/api-internal/v1/rag/documents',
-        (progress) => updateProgress(progress),
-        { collection_id: targetCollection }
-      )
+      await uploadFile('/api-internal/v1/rag/documents', uploadingFile.file, {
+        collection_id: targetCollection,
+      })

-      updateProgress(80)
+      updateProgress(90)
@@ -211,6 +230,7 @@ export function DocumentUpload({ collections, selectedCollection, onDocumentUplo
                 {collections.map((collection) => (
                   <SelectItem key={collection.id} value={collection.id}>
                     {collection.name}
+                    {collection.is_managed === false ? ' (external)' : ''}
                   </SelectItem>
                 ))}
               </SelectContent>
@@ -29,23 +29,61 @@ export async function downloadFile(path: string, filename: string, params?: URLS
 }

 export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) {
+  if (typeof path !== 'string' || path.length === 0) {
+    throw new TypeError('uploadFile path must be a non-empty string')
+  }
+
+  // Ensure path starts with / and construct full URL
+  const cleanPath = path.startsWith('/') ? path : `/${path}`
+  const url = `${typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000'}${cleanPath}`
+
+  // Debug logging
+  console.log('uploadFile called with:', { path: cleanPath, url, fileName: file.name, fileSize: file.size })
+
   const form = new FormData()
   form.append('file', file)
-  if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
+  if (extraFields) {
+    console.log('Adding extra fields:', extraFields)
+    Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
+  }

   const token = await tokenManager.getAccessToken()
-  const res = await fetch(path, {
+  console.log('Making request to:', url)
+
+  try {
+    const res = await fetch(url, {
       method: 'POST',
       headers: {
         ...(token ? { Authorization: `Bearer ${token}` } : {}),
       },
       body: form,
     })
-    if (!res.ok) {
-      let details: any
-      try { details = await res.json() } catch { details = await res.text() }
-      throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed'))
-    }
-    return await res.json()
-  }
+
+    console.log('Response status:', res.status, res.statusText)
+
+    if (!res.ok) {
+      const rawError = await res.text()
+      console.log('Error payload:', rawError)
+
+      let details: any
+      try {
+        details = rawError ? JSON.parse(rawError) : undefined
+      } catch {
+        details = rawError
+      }
+
+      const message = typeof details === 'string'
+        ? details
+        : details?.detail || details?.error || `Upload failed (${res.status})`
+
+      throw new Error(message)
+    }
+
+    const result = await res.json()
+    console.log('Upload successful:', result)
+    return result
+  } catch (error) {
+    console.error('Upload failed:', error)
+    throw error
+  }
+}