Files
enclava/backend/tests/conftest.py
2025-08-25 17:13:15 +02:00

247 lines
7.1 KiB
Python

"""
Pytest configuration and shared fixtures for all tests.
"""
import os
import sys
import asyncio
import pytest
import pytest_asyncio
from pathlib import Path
from typing import AsyncGenerator, Generator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.pool import NullPool
import aiohttp
from qdrant_client import QdrantClient
from httpx import AsyncClient
import uuid
# Add backend directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from app.db.database import Base, get_db
from app.core.config import settings
from app.main import app
# Database URL used by the whole test run. Override it with the
# TEST_DATABASE_URL environment variable; the default points at a dedicated
# ``enclava_test_db`` database rather than the application database.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://enclava_user:enclava_pass@localhost:5432/enclava_test_db",
)

# Async engine shared by every fixture below. NullPool means connections are
# opened per checkout and never pooled (presumably so no connection is reused
# across event loops — confirm if the loop setup changes).
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=NullPool,
    pool_pre_ping=True,
    echo=False,
)

# Factory for AsyncSession objects bound to ``test_engine``. Attributes are
# not expired on commit, so ORM objects stay readable after ``commit()``.
TestSessionLocal = async_sessionmaker(
    bind=test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop shared by all async tests in the session.

    The loop is built from the current event-loop policy and is closed once
    the test session ends.

    NOTE(review): recent pytest-asyncio releases deprecate overriding this
    fixture in favour of ``loop_scope`` configuration — confirm against the
    pinned pytest-asyncio version.
    """
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
@pytest_asyncio.fixture(scope="function")
async def test_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session bound to a freshly created test schema.

    Every table in ``Base.metadata`` is created before the test and dropped
    afterwards. Work the test left uncommitted is rolled back at teardown.
    Data that was *committed* (for example by the ``test_user`` fixture) is
    only removed by the final ``drop_all``. That is why the drop runs in a
    ``finally`` block.
    """
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        async with TestSessionLocal() as session:
            yield session
            # Discard anything the test left uncommitted.
            await session.rollback()
    finally:
        # Always drop the schema, even if the rollback or session close
        # raised. Otherwise committed rows survive into the next test,
        # because create_all skips tables that already exist.
        async with test_engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture(scope="function")
async def async_client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx client that calls the FastAPI app in-process.

    The app's ``get_db`` dependency is overridden to hand out sessions from
    ``TestSessionLocal``. All dependency overrides are cleared when the
    fixture finishes.
    """
    # ``AsyncClient(app=...)`` was deprecated in httpx 0.27 and removed in
    # 0.28. Passing an explicit ASGI transport works on old and new releases.
    from httpx import ASGITransport

    async def override_get_db():
        async with TestSessionLocal() as session:
            yield session

    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            yield client
    finally:
        # Reset even if the client failed to open or close, so overrides do
        # not leak into later tests.
        app.dependency_overrides.clear()
@pytest_asyncio.fixture(scope="function")
async def authenticated_client(async_client: AsyncClient, test_user_token: str) -> AsyncClient:
    """Return ``async_client`` carrying the test user's JWT as a Bearer token.

    The header is set on the same ``async_client`` instance other fixtures
    receive. A fixture that also sets ``Authorization`` in the same test will
    overwrite it (or be overwritten), depending on setup order.
    """
    bearer_value = f"Bearer {test_user_token}"
    auth_header = {"Authorization": bearer_value}
    async_client.headers.update(auth_header)
    return async_client
@pytest_asyncio.fixture(scope="function")
async def api_key_client(async_client: AsyncClient, test_api_key: str) -> AsyncClient:
    """Return ``async_client`` sending the test API key as a Bearer token.

    Like ``authenticated_client``, this mutates the shared ``async_client``
    headers in place.
    """
    bearer_value = f"Bearer {test_api_key}"
    auth_header = {"Authorization": bearer_value}
    async_client.headers.update(auth_header)
    return async_client
@pytest_asyncio.fixture(scope="function")
async def nginx_client() -> AsyncGenerator[aiohttp.ClientSession, None]:
    """Yield an ``aiohttp`` session for tests that go through the nginx proxy.

    The session is closed by its context manager when the test finishes.
    """
    async with aiohttp.ClientSession() as http_session:
        yield http_session
@pytest.fixture(scope="function")
def qdrant_client() -> Generator[QdrantClient, None, None]:
    """Yield a Qdrant client for one test and close it afterwards.

    Connection settings come from ``QDRANT_HOST`` (default ``localhost``)
    and ``QDRANT_PORT`` (default ``6333``).
    """
    client = QdrantClient(
        host=os.getenv("QDRANT_HOST", "localhost"),
        port=int(os.getenv("QDRANT_PORT", "6333")),
    )
    yield client
    # Release the client's connections instead of leaking one client per
    # test. ``close`` is not present on older qdrant-client releases, so
    # look it up defensively.
    close = getattr(client, "close", None)
    if callable(close):
        close()
@pytest_asyncio.fixture(scope="function")
async def test_user(test_db: AsyncSession) -> dict:
    """Insert an active, verified user and return its login details.

    Returns a dict with ``id`` (stringified), ``email``, ``username`` and
    the plaintext ``password``, so tests can authenticate as this user.
    """
    from app.models.user import User
    from app.core.security import get_password_hash

    plain_password = "testpass123"
    new_user = User(
        email="testuser@example.com",
        username="testuser",
        hashed_password=get_password_hash(plain_password),
        is_active=True,
        is_verified=True,
    )
    test_db.add(new_user)
    await test_db.commit()
    # Reload the row from the database after the commit.
    await test_db.refresh(new_user)

    return dict(
        id=str(new_user.id),
        email=new_user.email,
        username=new_user.username,
        password=plain_password,
    )
@pytest_asyncio.fixture(scope="function")
async def test_user_token(test_user: dict) -> str:
    """Issue a JWT access token for ``test_user``.

    The token payload carries the user's e-mail as ``sub`` and the user id
    as ``user_id``.
    """
    from app.core.security import create_access_token

    claims = {
        "sub": test_user["email"],
        "user_id": test_user["id"],
    }
    return create_access_token(data=claims)
@pytest_asyncio.fixture(scope="function")
async def test_api_key(test_db: AsyncSession, test_user: dict) -> str:
    """Insert a budget and an API key for ``test_user``; return the raw key.

    The budget is monthly with a 100.0 limit and 0.0 current usage. The key
    has the ``llm.chat`` and ``llm.embeddings`` scopes and is linked to that
    budget.

    NOTE(review): ``key_hash`` is set to the raw key, not a hash. If the app
    hashes incoming keys before lookup, requests made with this key will not
    authenticate — confirm against the API-key verification code.
    """
    from app.models.api_key import APIKey
    from app.models.budget import Budget
    import secrets

    owner_id = test_user["id"]

    budget = Budget(
        id=str(uuid.uuid4()),
        user_id=owner_id,
        limit_amount=100.0,
        period="monthly",
        current_usage=0.0,
        is_active=True,
    )
    test_db.add(budget)

    raw_key = "sk-test-" + secrets.token_urlsafe(32)
    test_db.add(
        APIKey(
            id=str(uuid.uuid4()),
            key_hash=raw_key,  # stored unhashed in tests (see NOTE above)
            name="Test API Key",
            user_id=owner_id,
            scopes=["llm.chat", "llm.embeddings"],
            budget_id=budget.id,
            is_active=True,
        )
    )
    await test_db.commit()
    return raw_key
@pytest_asyncio.fixture(scope="function")
async def test_qdrant_collection(qdrant_client: QdrantClient) -> AsyncGenerator[str, None]:
    """Create a uniquely named Qdrant collection and delete it after the test.

    The collection stores 1536-dimensional vectors compared by cosine
    distance. The fixture yields the collection name.
    """
    import logging

    from qdrant_client.models import Distance, VectorParams

    collection_name = f"test_collection_{uuid.uuid4().hex[:8]}"
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )
    yield collection_name
    # Cleanup stays best-effort: a failure must not fail the test, but it is
    # logged rather than swallowed, so leaked collections are visible.
    try:
        qdrant_client.delete_collection(collection_name)
    except Exception:  # noqa: BLE001 - deliberate best-effort teardown
        logging.getLogger(__name__).warning(
            "Failed to delete test Qdrant collection %r",
            collection_name,
            exc_info=True,
        )
@pytest.fixture(scope="session")
def test_documents_dir() -> Path:
    """Return the ``data/documents`` directory next to this conftest file."""
    conftest_dir = Path(__file__).parent
    return conftest_dir.joinpath("data", "documents")
@pytest.fixture(scope="session")
def sample_text_path(test_documents_dir: Path) -> Path:
    """Return the path to ``sample.txt``, creating it if it does not exist.

    When the file is missing, its parent directories are created and a short
    description of the Enclava platform is written to it. An existing file
    is left untouched.
    """
    text_path = test_documents_dir / "sample.txt"
    if text_path.exists():
        return text_path

    text_path.parent.mkdir(parents=True, exist_ok=True)
    text_path.write_text("""
Enclava Platform Documentation
This is a sample document for testing the RAG system.
It contains information about the Enclava platform's features and capabilities.
Features:
- Secure LLM access through PrivateMode.ai
- Chatbot creation and management
- RAG (Retrieval Augmented Generation) support
- OpenAI-compatible API endpoints
- Budget management and API key controls
""")
    return text_path
# Test environment variables
@pytest.fixture(scope="session", autouse=True)
def setup_test_env():
    """Set test-only environment variables for the whole test session.

    ``TESTING``, ``LOG_LLM_PROMPTS`` and ``APP_DEBUG`` are set to ``"true"``.
    When the session ends, each one is restored to its previous value, or
    removed if it was previously unset.

    NOTE(review): ``app.core.config.settings`` is imported at module load,
    before this fixture runs. If settings read these variables only at
    import time, setting them here will not affect them — confirm.
    """
    overrides = {
        "TESTING": "true",
        "LOG_LLM_PROMPTS": "true",
        "APP_DEBUG": "true",
    }
    # Remember prior values so pre-existing settings are not clobbered.
    previous = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    yield
    # Cleanup: restore every variable touched above, not only TESTING.
    for name, value in previous.items():
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value