mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
429 lines
16 KiB
Python
429 lines
16 KiB
Python
"""
|
|
API integration tests for chatbot sources with URL metadata.
|
|
|
|
Tests cover:
|
|
- Chatbot API returns sources with URLs
|
|
- Sources have all required fields
|
|
- Sources are sorted by relevance
|
|
- URL deduplication in chat response
|
|
"""
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
import json
|
|
from httpx import AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.modules.rag.main import RAGModule
|
|
from app.models.chatbot import ChatbotInstance
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_faq_jsonl_with_urls():
|
|
"""Sample FAQ JSONL with URLs for testing"""
|
|
return """{"id": "faq_pass", "payload": {"question": "How to reset my password?", "answer": "To reset your password, go to the login page and click 'Forgot Password'. You will receive an email with reset instructions.", "language": "EN", "url": "https://support.example.com/faq/password-reset"}}
|
|
{"id": "faq_2fa", "payload": {"question": "How to enable two-factor authentication?", "answer": "Two-factor authentication can be enabled in your account security settings. Go to Settings > Security > Two-Factor Authentication and follow the setup wizard.", "language": "EN", "url": "https://support.example.com/faq/2fa-setup"}}
|
|
{"id": "faq_hours", "payload": {"question": "What are your business hours?", "answer": "We are open Monday through Friday, 9:00 AM to 5:00 PM EST. We are closed on weekends and major holidays.", "language": "EN", "url": "https://support.example.com/faq/business-hours"}}
|
|
{"id": "faq_cancel", "payload": {"question": "How to cancel my subscription?", "answer": "You can cancel your subscription at any time from your account settings. Go to Settings > Billing > Cancel Subscription. Your access will continue until the end of your billing period.", "language": "EN", "url": "https://support.example.com/faq/cancel-subscription"}}"""
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def chatbot_with_rag(test_db: AsyncSession, test_user: dict, test_qdrant_collection: str, sample_faq_jsonl_with_urls: str):
|
|
"""Create a chatbot instance with RAG enabled and indexed documents"""
|
|
# Initialize RAG module
|
|
rag_module = RAGModule()
|
|
await rag_module.initialize()
|
|
rag_module.default_collection_name = test_qdrant_collection
|
|
|
|
# Process and index FAQ documents
|
|
file_content = sample_faq_jsonl_with_urls.encode("utf-8")
|
|
processed_doc = await rag_module.process_document(
|
|
file_data=file_content,
|
|
filename="support_faq.jsonl"
|
|
)
|
|
await rag_module.index_processed_document(processed_doc, collection_name=test_qdrant_collection)
|
|
|
|
# Create chatbot instance
|
|
chatbot = ChatbotInstance(
|
|
name="Support Bot",
|
|
chatbot_type="customer_support",
|
|
user_id=test_user["id"],
|
|
model="gpt-3.5-turbo",
|
|
system_prompt="You are a helpful support assistant.",
|
|
temperature=0.7,
|
|
max_tokens=500,
|
|
use_rag=True,
|
|
rag_collection=test_qdrant_collection,
|
|
rag_top_k=5,
|
|
rag_score_threshold=0.1,
|
|
is_active=True
|
|
)
|
|
|
|
test_db.add(chatbot)
|
|
await test_db.commit()
|
|
await test_db.refresh(chatbot)
|
|
|
|
yield chatbot
|
|
|
|
# Cleanup
|
|
await rag_module.cleanup()
|
|
|
|
|
|
class TestChatbotSourcesResponse:
|
|
"""Test chatbot API returns sources with URL metadata"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_chat_returns_sources(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that chat API returns sources array"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "How do I reset my password?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Verify response structure
|
|
assert "response" in data
|
|
assert "sources" in data
|
|
assert isinstance(data["sources"], list)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sources_contain_required_fields(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that sources contain all required fields"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "Tell me about password reset and two-factor authentication",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
if len(data["sources"]) > 0:
|
|
source = data["sources"][0]
|
|
|
|
# Required fields
|
|
assert "title" in source or "question" in source
|
|
assert "relevance_score" in source or "score" in source
|
|
|
|
# URL field (may be None for legacy documents)
|
|
if "url" in source:
|
|
assert source["url"] is None or isinstance(source["url"], str)
|
|
|
|
# Optional fields
|
|
if "language" in source:
|
|
assert isinstance(source["language"], str)
|
|
|
|
if "article_id" in source:
|
|
assert isinstance(source["article_id"], str)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sources_have_urls(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that sources contain URL metadata when available"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "How to enable two-factor authentication?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Should have at least one source with URL
|
|
sources_with_urls = [
|
|
s for s in data["sources"]
|
|
if s.get("url") and s["url"].startswith("http")
|
|
]
|
|
|
|
# At least some sources should have URLs (depending on RAG results)
|
|
assert len(sources_with_urls) >= 0 # Flexible assertion
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_url_format_validation(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that returned URLs are properly formatted"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "What are your business hours?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
for source in data["sources"]:
|
|
if source.get("url"):
|
|
url = source["url"]
|
|
# URL should be valid format
|
|
assert url.startswith("http://") or url.startswith("https://")
|
|
assert " " not in url # No spaces in URL
|
|
assert len(url) <= 2048 # Reasonable URL length
|
|
|
|
|
|
class TestSourcesSortedByRelevance:
|
|
"""Test that sources are sorted by relevance score"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sources_sorted_descending(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that sources are sorted by relevance score (highest first)"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "Tell me about account security and subscription management",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
if len(data["sources"]) > 1:
|
|
# Extract relevance scores
|
|
scores = []
|
|
for source in data["sources"]:
|
|
score = source.get("relevance_score") or source.get("score", 0)
|
|
scores.append(score)
|
|
|
|
# Verify sorted in descending order
|
|
assert scores == sorted(scores, reverse=True), "Sources should be sorted by relevance (highest first)"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_highest_relevance_first(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that most relevant source is first"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "How to reset password?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
if len(data["sources"]) > 0:
|
|
# First source should have highest score
|
|
first_score = data["sources"][0].get("relevance_score") or data["sources"][0].get("score", 0)
|
|
|
|
for source in data["sources"][1:]:
|
|
source_score = source.get("relevance_score") or source.get("score", 0)
|
|
assert first_score >= source_score, "First source should have highest relevance"
|
|
|
|
|
|
class TestURLDeduplicationInChatResponse:
|
|
"""Test URL deduplication in chat API responses"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_duplicate_urls_removed(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that duplicate URLs are deduplicated in response"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "Tell me everything about password security, 2FA, and account protection",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Extract URLs from sources
|
|
urls = [s.get("url") for s in data["sources"] if s.get("url")]
|
|
|
|
if len(urls) > 0:
|
|
# Check for duplicates
|
|
unique_urls = set(urls)
|
|
assert len(urls) == len(unique_urls), "Response should not contain duplicate URLs"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_highest_score_kept_for_duplicate_url(self, authenticated_client: AsyncClient, test_qdrant_collection: str):
|
|
"""Test that highest scoring document is kept when URLs are duplicated"""
|
|
# This would require setting up documents with duplicate URLs
|
|
# For now, we test the general behavior
|
|
pass # Implementation would depend on specific test data setup
|
|
|
|
|
|
class TestMixedSourcesWithAndWithoutURLs:
|
|
"""Test handling of mixed sources (some with URLs, some without)"""
|
|
|
|
@pytest_asyncio.fixture
|
|
async def chatbot_with_mixed_docs(self, test_db: AsyncSession, test_user: dict, test_qdrant_collection: str):
|
|
"""Create chatbot with mixed documents (with and without URLs)"""
|
|
mixed_jsonl = """{"id": "with_url", "payload": {"question": "How to login?", "answer": "Use your email and password to log in.", "language": "EN", "url": "https://support.example.com/faq/login"}}
|
|
{"id": "without_url", "payload": {"question": "Security best practices", "answer": "Always use strong passwords and enable 2FA.", "language": "EN"}}
|
|
{"id": "with_url2", "payload": {"question": "Account recovery", "answer": "Contact support for account recovery.", "language": "EN", "url": "https://support.example.com/faq/recovery"}}"""
|
|
|
|
# Initialize RAG and index documents
|
|
rag_module = RAGModule()
|
|
await rag_module.initialize()
|
|
rag_module.default_collection_name = test_qdrant_collection
|
|
|
|
file_content = mixed_jsonl.encode("utf-8")
|
|
processed_doc = await rag_module.process_document(
|
|
file_data=file_content,
|
|
filename="mixed_faq.jsonl"
|
|
)
|
|
await rag_module.index_processed_document(processed_doc, collection_name=test_qdrant_collection)
|
|
|
|
# Create chatbot
|
|
chatbot = ChatbotInstance(
|
|
name="Mixed Sources Bot",
|
|
chatbot_type="assistant",
|
|
user_id=test_user["id"],
|
|
model="gpt-3.5-turbo",
|
|
use_rag=True,
|
|
rag_collection=test_qdrant_collection,
|
|
rag_top_k=10,
|
|
rag_score_threshold=0.01,
|
|
is_active=True
|
|
)
|
|
|
|
test_db.add(chatbot)
|
|
await test_db.commit()
|
|
await test_db.refresh(chatbot)
|
|
|
|
yield chatbot
|
|
|
|
await rag_module.cleanup()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mixed_sources_response(self, authenticated_client: AsyncClient, chatbot_with_mixed_docs: ChatbotInstance):
|
|
"""Test that response handles mix of sources with and without URLs"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_mixed_docs.id}/chat",
|
|
json={
|
|
"message": "Tell me about login and security",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Should have sources
|
|
assert len(data["sources"]) >= 0
|
|
|
|
# Check that sources can have both URL and non-URL documents
|
|
with_urls = [s for s in data["sources"] if s.get("url")]
|
|
without_urls = [s for s in data["sources"] if not s.get("url")]
|
|
|
|
# Both types should be handled gracefully
|
|
for source in data["sources"]:
|
|
# All sources should have title/question
|
|
assert "title" in source or "question" in source
|
|
|
|
# URL is optional
|
|
if "url" in source and source["url"]:
|
|
assert isinstance(source["url"], str)
|
|
assert source["url"].startswith("http")
|
|
|
|
|
|
class TestSourcesEmptyState:
|
|
"""Test behavior when no sources are available"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_no_rag_sources(self, authenticated_client: AsyncClient, test_db: AsyncSession, test_user: dict):
|
|
"""Test chat response when RAG is disabled"""
|
|
# Create chatbot without RAG
|
|
chatbot = ChatbotInstance(
|
|
name="No RAG Bot",
|
|
chatbot_type="assistant",
|
|
user_id=test_user["id"],
|
|
model="gpt-3.5-turbo",
|
|
use_rag=False,
|
|
is_active=True
|
|
)
|
|
|
|
test_db.add(chatbot)
|
|
await test_db.commit()
|
|
await test_db.refresh(chatbot)
|
|
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot.id}/chat",
|
|
json={
|
|
"message": "Hello, how can you help?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Sources should be empty or not present
|
|
if "sources" in data:
|
|
assert isinstance(data["sources"], list)
|
|
assert len(data["sources"]) == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_no_matching_documents(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test response when query matches no documents"""
|
|
response = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "xyzabc123 nonexistent query zzzqqq",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
|
|
# Should have response even with no sources
|
|
assert "response" in data
|
|
|
|
# Sources may be empty
|
|
if "sources" in data:
|
|
assert isinstance(data["sources"], list)
|
|
|
|
|
|
class TestConversationContext:
|
|
"""Test that sources are maintained across conversation turns"""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sources_in_conversation(self, authenticated_client: AsyncClient, chatbot_with_rag: ChatbotInstance):
|
|
"""Test that sources are provided in multi-turn conversation"""
|
|
# First message
|
|
response1 = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "How do I reset my password?",
|
|
"conversation_id": None
|
|
}
|
|
)
|
|
|
|
assert response1.status_code == 200
|
|
data1 = response1.json()
|
|
conversation_id = data1.get("conversation_id")
|
|
|
|
assert conversation_id is not None
|
|
assert "sources" in data1
|
|
|
|
# Follow-up message in same conversation
|
|
response2 = await authenticated_client.post(
|
|
f"/api-internal/v1/chatbots/{chatbot_with_rag.id}/chat",
|
|
json={
|
|
"message": "What if I don't receive the reset email?",
|
|
"conversation_id": conversation_id
|
|
}
|
|
)
|
|
|
|
assert response2.status_code == 200
|
|
data2 = response2.json()
|
|
|
|
# Should still have sources in follow-up
|
|
assert "sources" in data2
|
|
assert isinstance(data2["sources"], list)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|