mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 15:34:36 +01:00
247 lines
7.1 KiB
Python
247 lines
7.1 KiB
Python
"""
|
|
Pytest configuration and shared fixtures for all tests.
|
|
"""
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import pytest
|
|
import pytest_asyncio
|
|
from pathlib import Path
|
|
from typing import AsyncGenerator, Generator
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
import aiohttp
|
|
from qdrant_client import QdrantClient
|
|
from httpx import AsyncClient
|
|
import uuid
|
|
|
|
# Add backend directory to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from app.db.database import Base, get_db
|
|
from app.core.config import settings
|
|
from app.main import app
|
|
|
|
|
|
# Connection string for the test database. It can be overridden through the
# TEST_DATABASE_URL environment variable; the default points at a database
# named separately from the application's own.
_DEFAULT_TEST_DATABASE_URL = (
    "postgresql+asyncpg://enclava_user:enclava_pass@localhost:5432/enclava_test_db"
)
TEST_DATABASE_URL = os.getenv("TEST_DATABASE_URL", _DEFAULT_TEST_DATABASE_URL)

# Async engine shared by the fixtures below. NullPool disables connection
# pooling, so every checkout opens a fresh connection (presumably to avoid
# reusing connections across event loops -- confirm).
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=NullPool,
    pool_pre_ping=True,
    echo=False,
)

# Session factory bound to the test engine. expire_on_commit=False keeps ORM
# attributes readable after a commit without another round-trip.
TestSessionLocal = async_sessionmaker(
    test_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop shared by the whole test session.

    The loop comes from the current event-loop policy and is closed once
    every test in the session has run.
    """
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def test_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession on a freshly created schema, dropping it afterwards.

    Every table in ``Base.metadata`` is created before the test and dropped
    after it. Work the test left uncommitted is rolled back on teardown. Data
    that was *committed* (e.g. by fixtures that call ``commit()``) is not
    undone by that rollback; it is removed by the final ``drop_all``.
    """
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    try:
        async with TestSessionLocal() as session:
            yield session
            # Discard anything the test left uncommitted.
            await session.rollback()
    finally:
        # Always drop the schema, even if the rollback or session close
        # raised. Otherwise committed rows would survive into the next test
        # (create_all does not clear existing tables).
        async with test_engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def async_client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx AsyncClient that calls the FastAPI app in-process.

    Requests are routed through an ASGI transport to ``http://test``; no
    network is involved. The app's ``get_db`` dependency is overridden so
    each request gets its own session from ``TestSessionLocal``. All
    dependency overrides are cleared on teardown.
    """
    # httpx deprecated AsyncClient(app=...) in 0.27 and removed it in 0.28;
    # an explicit ASGITransport is the supported replacement.
    from httpx import ASGITransport

    async def override_get_db():
        async with TestSessionLocal() as session:
            yield session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Runs even if client construction fails, so the override cannot
        # leak into later tests.
        app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def authenticated_client(async_client: AsyncClient, test_user_token: str) -> AsyncClient:
    """Return ``async_client`` carrying the test user's JWT as a Bearer token.

    The Authorization header is set in place on the shared client, so every
    later request through that client is authenticated.
    """
    auth_header = {"Authorization": f"Bearer {test_user_token}"}
    async_client.headers.update(auth_header)
    return async_client
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def api_key_client(async_client: AsyncClient, test_api_key: str) -> AsyncClient:
    """Return ``async_client`` authenticated with the raw test API key.

    The key is sent as a Bearer token in the Authorization header. The header
    is set in place on the shared client.
    """
    auth_header = {"Authorization": f"Bearer {test_api_key}"}
    async_client.headers.update(auth_header)
    return async_client
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def nginx_client() -> AsyncGenerator[aiohttp.ClientSession, None]:
    """Yield a plain aiohttp session for requests sent through the nginx proxy.

    The session is closed automatically when the test finishes.
    """
    http_session = aiohttp.ClientSession()
    async with http_session:
        yield http_session
|
|
|
|
|
|
@pytest.fixture(scope="function")
def qdrant_client() -> Generator[QdrantClient, None, None]:
    """Yield a Qdrant client for the test and close it afterwards.

    Host and port are read from ``QDRANT_HOST`` and ``QDRANT_PORT``; they
    default to ``localhost`` and ``6333``.
    """
    client = QdrantClient(
        host=os.getenv("QDRANT_HOST", "localhost"),
        port=int(os.getenv("QDRANT_PORT", "6333")),
    )
    yield client
    # Release the client's connections. Previously a new client per test was
    # never closed. Older qdrant-client releases may lack close(), hence the
    # guard.
    close = getattr(client, "close", None)
    if callable(close):
        close()
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def test_user(test_db: AsyncSession) -> dict:
    """Insert an active, verified user and return its credentials.

    The returned dict holds the user's id (as a string), email and username,
    plus the plaintext password so tests can log in as this user.
    """
    from app.models.user import User
    from app.core.security import get_password_hash

    plain_password = "testpass123"
    new_user = User(
        email="testuser@example.com",
        username="testuser",
        hashed_password=get_password_hash(plain_password),
        is_active=True,
        is_verified=True,
    )

    test_db.add(new_user)
    await test_db.commit()
    # Reload so server-generated columns such as the id are populated.
    await test_db.refresh(new_user)

    return dict(
        id=str(new_user.id),
        email=new_user.email,
        username=new_user.username,
        password=plain_password,
    )
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def test_user_token(test_user: dict) -> str:
    """Issue an access token for ``test_user``.

    The token's ``sub`` claim is the user's email and ``user_id`` is the
    user's id.
    """
    from app.core.security import create_access_token

    claims = {
        "sub": test_user["email"],
        "user_id": test_user["id"],
    }
    return create_access_token(data=claims)
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def test_api_key(test_db: AsyncSession, test_user: dict) -> str:
    """Insert a budget and an API key owned by ``test_user``; return the raw key.

    The key has the form ``sk-test-<random url-safe token>``. It grants the
    ``llm.chat`` and ``llm.embeddings`` scopes and is linked to an active
    monthly budget with a 100.0 limit.
    """
    from app.models.api_key import APIKey
    from app.models.budget import Budget
    import secrets

    owner_id = test_user["id"]

    budget = Budget(
        id=str(uuid.uuid4()),
        user_id=owner_id,
        limit_amount=100.0,
        period="monthly",
        current_usage=0.0,
        is_active=True,
    )
    test_db.add(budget)

    raw_key = "sk-test-" + secrets.token_urlsafe(32)
    # NOTE(review): key_hash receives the plaintext key, not a hash. If the
    # API-key auth layer hashes incoming keys before lookup, requests made
    # with this key will not authenticate. Confirm against the auth code.
    test_db.add(
        APIKey(
            id=str(uuid.uuid4()),
            key_hash=raw_key,
            name="Test API Key",
            user_id=owner_id,
            scopes=["llm.chat", "llm.embeddings"],
            budget_id=budget.id,
            is_active=True,
        )
    )
    await test_db.commit()

    return raw_key
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def test_qdrant_collection(qdrant_client: QdrantClient) -> AsyncGenerator[str, None]:
    """Create a uniquely named Qdrant collection, yield its name, then delete it.

    The collection is configured for 1536-dimensional vectors compared with
    cosine distance. Deletion is best-effort: a failure is logged rather
    than raised, so cleanup problems never fail the test itself.
    """
    import logging

    from qdrant_client.models import Distance, VectorParams

    collection_name = f"test_collection_{uuid.uuid4().hex[:8]}"

    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )

    yield collection_name

    # Previously failures here were silently swallowed, hiding leftover
    # collections. Keep the cleanup best-effort, but make it visible.
    try:
        qdrant_client.delete_collection(collection_name)
    except Exception:
        logging.getLogger(__name__).warning(
            "Failed to delete test Qdrant collection %s", collection_name, exc_info=True
        )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_documents_dir() -> Path:
    """Return the ``data/documents`` directory located next to this file."""
    tests_root = Path(__file__).parent
    return tests_root.joinpath("data", "documents")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def sample_text_path(test_documents_dir: Path) -> Path:
    """Return the path to ``sample.txt``, creating the file on first use.

    If the file is missing, its parent directories are created and a short
    plain-text document about the Enclava platform is written. An existing
    file is left untouched.
    """
    sample = test_documents_dir / "sample.txt"
    if sample.exists():
        return sample

    sample.parent.mkdir(parents=True, exist_ok=True)
    sample.write_text(
        "\n".join(
            [
                "",
                "Enclava Platform Documentation",
                "",
                "This is a sample document for testing the RAG system.",
                "It contains information about the Enclava platform's features and capabilities.",
                "",
                "Features:",
                "- Secure LLM access through PrivateMode.ai",
                "- Chatbot creation and management",
                "- RAG (Retrieval Augmented Generation) support",
                "- OpenAI-compatible API endpoints",
                "- Budget management and API key controls",
                "",
            ]
        )
    )
    return sample
|
|
|
|
|
|
# Test environment variables
@pytest.fixture(scope="session", autouse=True)
def setup_test_env():
    """Set test-mode environment variables for the whole session.

    ``TESTING``, ``LOG_LLM_PROMPTS`` and ``APP_DEBUG`` are set to ``"true"``.
    On teardown each variable is restored to its pre-session value, or
    removed if it was not set before. Previously only ``TESTING`` was
    removed, and pre-existing values were lost.

    NOTE(review): ``settings`` and ``app`` are imported at module import time,
    before this fixture runs. Any environment values read during that import
    will not see these variables; confirm whether settings reads them lazily.
    """
    overrides = {
        "TESTING": "true",
        "LOG_LLM_PROMPTS": "true",
        "APP_DEBUG": "true",
    }
    # Snapshot prior values so teardown puts the environment back exactly.
    previous = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for name, old_value in previous.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value