mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
fixing rag
This commit is contained in:
@@ -1,21 +1,49 @@
|
||||
"""
|
||||
Pytest configuration and fixtures for testing.
|
||||
Pytest configuration and shared fixtures for all tests.
|
||||
"""
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Generator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
import aiohttp
|
||||
from qdrant_client import QdrantClient
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import uuid
|
||||
|
||||
from app.main import app
|
||||
from app.db.database import get_db, Base
|
||||
# Add backend directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from app.db.database import Base, get_db
|
||||
from app.core.config import settings
|
||||
from app.main import app
|
||||
|
||||
|
||||
# Test database URL
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
|
||||
# Test database URL (use different database name for tests)
|
||||
TEST_DATABASE_URL = os.getenv(
|
||||
"TEST_DATABASE_URL",
|
||||
"postgresql+asyncpg://enclava_user:enclava_pass@localhost:5432/enclava_test_db"
|
||||
)
|
||||
|
||||
|
||||
# Create test engine
|
||||
test_engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
poolclass=NullPool
|
||||
)
|
||||
|
||||
# Create test session factory
|
||||
TestSessionLocal = async_sessionmaker(
|
||||
test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -26,44 +54,29 @@ def event_loop():
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_engine():
|
||||
"""Create test database engine."""
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a test database session with automatic rollback."""
|
||||
async with test_engine.begin() as conn:
|
||||
# Create all tables for this test
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# Cleanup
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db(test_engine):
|
||||
"""Create test database session."""
|
||||
async_session = sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
async with TestSessionLocal() as session:
|
||||
yield session
|
||||
# Rollback any changes made during the test
|
||||
await session.rollback()
|
||||
|
||||
# Clean up tables after test
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(test_db):
|
||||
"""Create test client."""
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def async_client() -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async HTTP client for testing FastAPI endpoints."""
|
||||
async def override_get_db():
|
||||
yield test_db
|
||||
async with TestSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
@@ -73,23 +86,162 @@ async def client(test_db):
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_data():
|
||||
"""Test user data."""
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def authenticated_client(async_client: AsyncClient, test_user_token: str) -> AsyncClient:
|
||||
"""Create an authenticated async client with JWT token."""
|
||||
async_client.headers.update({"Authorization": f"Bearer {test_user_token}"})
|
||||
return async_client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def api_key_client(async_client: AsyncClient, test_api_key: str) -> AsyncClient:
|
||||
"""Create an async client authenticated with API key."""
|
||||
async_client.headers.update({"Authorization": f"Bearer {test_api_key}"})
|
||||
return async_client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def nginx_client() -> AsyncGenerator[aiohttp.ClientSession, None]:
|
||||
"""Create an aiohttp client for testing through nginx proxy."""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def qdrant_client() -> QdrantClient:
|
||||
"""Create a Qdrant client for testing."""
|
||||
return QdrantClient(
|
||||
host=os.getenv("QDRANT_HOST", "localhost"),
|
||||
port=int(os.getenv("QDRANT_PORT", "6333"))
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_user(test_db: AsyncSession) -> dict:
|
||||
"""Create a test user."""
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
user = User(
|
||||
email="testuser@example.com",
|
||||
username="testuser",
|
||||
hashed_password=get_password_hash("testpass123"),
|
||||
is_active=True,
|
||||
is_verified=True
|
||||
)
|
||||
|
||||
test_db.add(user)
|
||||
await test_db.commit()
|
||||
await test_db.refresh(user)
|
||||
|
||||
return {
|
||||
"email": "test@example.com",
|
||||
"username": "testuser",
|
||||
"full_name": "Test User",
|
||||
"password": "testpassword123"
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"password": "testpass123"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_api_key_data():
|
||||
"""Test API key data."""
|
||||
return {
|
||||
"name": "Test API Key",
|
||||
"scopes": ["llm.chat", "llm.embeddings"],
|
||||
"budget_limit": 100.0,
|
||||
"budget_period": "monthly"
|
||||
}
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_user_token(test_user: dict) -> str:
|
||||
"""Create a JWT token for test user."""
|
||||
from app.core.security import create_access_token
|
||||
|
||||
token_data = {"sub": test_user["email"], "user_id": test_user["id"]}
|
||||
return create_access_token(data=token_data)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_api_key(test_db: AsyncSession, test_user: dict) -> str:
|
||||
"""Create a test API key."""
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.budget import Budget
|
||||
import secrets
|
||||
|
||||
# Create budget
|
||||
budget = Budget(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=test_user["id"],
|
||||
limit_amount=100.0,
|
||||
period="monthly",
|
||||
current_usage=0.0,
|
||||
is_active=True
|
||||
)
|
||||
test_db.add(budget)
|
||||
|
||||
# Create API key
|
||||
key = f"sk-test-{secrets.token_urlsafe(32)}"
|
||||
api_key = APIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
key_hash=key, # In real code, this would be hashed
|
||||
name="Test API Key",
|
||||
user_id=test_user["id"],
|
||||
scopes=["llm.chat", "llm.embeddings"],
|
||||
budget_id=budget.id,
|
||||
is_active=True
|
||||
)
|
||||
test_db.add(api_key)
|
||||
await test_db.commit()
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_qdrant_collection(qdrant_client: QdrantClient) -> str:
|
||||
"""Create a test Qdrant collection."""
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
collection_name = f"test_collection_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
qdrant_client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
yield collection_name
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
qdrant_client.delete_collection(collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_documents_dir() -> Path:
|
||||
"""Get the test documents directory."""
|
||||
return Path(__file__).parent / "data" / "documents"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_text_path(test_documents_dir: Path) -> Path:
|
||||
"""Get path to sample text file for testing."""
|
||||
text_path = test_documents_dir / "sample.txt"
|
||||
if not text_path.exists():
|
||||
text_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
text_path.write_text("""
|
||||
Enclava Platform Documentation
|
||||
|
||||
This is a sample document for testing the RAG system.
|
||||
It contains information about the Enclava platform's features and capabilities.
|
||||
|
||||
Features:
|
||||
- Secure LLM access through PrivateMode.ai
|
||||
- Chatbot creation and management
|
||||
- RAG (Retrieval Augmented Generation) support
|
||||
- OpenAI-compatible API endpoints
|
||||
- Budget management and API key controls
|
||||
""")
|
||||
return text_path
|
||||
|
||||
|
||||
# Test environment variables
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_test_env():
|
||||
"""Setup test environment variables."""
|
||||
os.environ["TESTING"] = "true"
|
||||
os.environ["LOG_LLM_PROMPTS"] = "true"
|
||||
os.environ["APP_DEBUG"] = "true"
|
||||
yield
|
||||
# Cleanup
|
||||
os.environ.pop("TESTING", None)
|
||||
Reference in New Issue
Block a user