mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
Troubleshoot RAG memory leak
This commit is contained in:
@@ -288,14 +288,18 @@ async def get_documents(
|
||||
|
||||
@router.post("/documents", response_model=dict)
|
||||
async def upload_document(
|
||||
collection_id: int = Form(...),
|
||||
collection_id: str = Form(...),
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Upload and process a document"""
|
||||
try:
|
||||
# Read file content
|
||||
# Validate file can be read before processing
|
||||
filename = file.filename or "unknown"
|
||||
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
|
||||
|
||||
# Read file content once and use it for all validations
|
||||
file_content = await file.read()
|
||||
|
||||
if len(file_content) == 0:
|
||||
@@ -304,10 +308,6 @@ async def upload_document(
|
||||
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
|
||||
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
|
||||
|
||||
# Validate file can be read before processing
|
||||
filename = file.filename or "unknown"
|
||||
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
|
||||
|
||||
try:
|
||||
# Test file readability based on type
|
||||
if file_extension == 'jsonl':
|
||||
@@ -349,8 +349,33 @@ async def upload_document(
|
||||
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
|
||||
|
||||
rag_service = RAGService(db)
|
||||
|
||||
# Resolve collection identifier (supports both numeric IDs and Qdrant collection names)
|
||||
collection_identifier = (collection_id or "").strip()
|
||||
if not collection_identifier:
|
||||
raise HTTPException(status_code=400, detail="Collection identifier is required")
|
||||
|
||||
resolved_collection_id: Optional[int] = None
|
||||
|
||||
if collection_identifier.isdigit():
|
||||
resolved_collection_id = int(collection_identifier)
|
||||
else:
|
||||
qdrant_name = collection_identifier
|
||||
if qdrant_name.startswith("ext_"):
|
||||
qdrant_name = qdrant_name[4:]
|
||||
|
||||
try:
|
||||
collection_record = await rag_service.ensure_collection_record(qdrant_name)
|
||||
except Exception as ensure_error:
|
||||
raise HTTPException(status_code=500, detail=str(ensure_error))
|
||||
|
||||
resolved_collection_id = collection_record.id
|
||||
|
||||
if resolved_collection_id is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid collection identifier")
|
||||
|
||||
document = await rag_service.upload_document(
|
||||
collection_id=collection_id,
|
||||
collection_id=resolved_collection_id,
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
content_type=file.content_type
|
||||
|
||||
Reference in New Issue
Block a user