enclava/backend/tests/clients/openai_test_client.py
"""
OpenAI-compatible test client for verifying API compatibility.
"""
import openai
from openai import OpenAI
import asyncio
from typing import Optional, Dict, Any, List
import aiohttp
import json
class OpenAITestClient:
"""OpenAI client wrapper for testing Enclava compatibility"""
def __init__(self, base_url: str = "http://localhost:3001/api/v1", api_key: Optional[str] = None):
self.base_url = base_url
self.api_key = api_key or "test-api-key"
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
def list_models(self) -> List[Dict[str, Any]]:
"""Test /v1/models endpoint compatibility"""
try:
response = self.client.models.list()
return [model.model_dump() for model in response.data]
except Exception as e:
raise OpenAICompatibilityError(f"Models list failed: {e}")
def create_chat_completion(self,
model: str,
messages: List[Dict[str, str]],
stream: bool = False,
**kwargs) -> Dict[str, Any]:
"""Test chat completion endpoint compatibility"""
try:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
**kwargs
)
if stream:
return {"stream": response}
else:
return response.model_dump()
except Exception as e:
raise OpenAICompatibilityError(f"Chat completion failed: {e}")
def create_completion(self,
model: str,
prompt: str,
**kwargs) -> Dict[str, Any]:
"""Test legacy completion endpoint compatibility"""
try:
response = self.client.completions.create(
model=model,
prompt=prompt,
**kwargs
)
return response.model_dump()
except Exception as e:
raise OpenAICompatibilityError(f"Completion failed: {e}")
def create_embedding(self,
model: str,
input_text: str,
**kwargs) -> Dict[str, Any]:
"""Test embeddings endpoint compatibility"""
try:
response = self.client.embeddings.create(
model=model,
input=input_text,
**kwargs
)
return response.model_dump()
except Exception as e:
raise OpenAICompatibilityError(f"Embeddings failed: {e}")
def test_streaming_completion(self,
model: str,
messages: List[Dict[str, str]],
**kwargs) -> List[Dict[str, Any]]:
"""Test streaming chat completion"""
try:
stream = self.client.chat.completions.create(
model=model,
messages=messages,
stream=True,
**kwargs
)
chunks = []
for chunk in stream:
chunks.append(chunk.model_dump())
return chunks
except Exception as e:
raise OpenAICompatibilityError(f"Streaming completion failed: {e}")
def test_error_handling(self) -> Dict[str, Any]:
"""Test error response compatibility"""
test_cases = []
        # Test invalid model; the server may answer 400 or 404, so catch the
        # shared APIStatusError base class instead of BadRequestError only
        try:
            self.client.chat.completions.create(
                model="nonexistent-model",
                messages=[{"role": "user", "content": "test"}]
            )
        except openai.APIStatusError as e:
            test_cases.append({
                "test": "invalid_model",
                "error_type": type(e).__name__,
                "status_code": e.response.status_code,
                "error_body": e.response.text if hasattr(e.response, 'text') else str(e)
            })
# Test missing API key
try:
no_key_client = OpenAI(base_url=self.base_url, api_key="")
no_key_client.models.list()
except openai.AuthenticationError as e:
test_cases.append({
"test": "missing_api_key",
"error_type": type(e).__name__,
"status_code": e.response.status_code,
"error_body": e.response.text if hasattr(e.response, 'text') else str(e)
})
        # Test rate limiting (if implemented)
        try:
            for _ in range(100):  # Attempt to trigger rate limiting
                self.client.chat.completions.create(
                    model="test-model",
                    messages=[{"role": "user", "content": "test"}],
                    max_tokens=1
                )
        except openai.RateLimitError as e:
            test_cases.append({
                "test": "rate_limiting",
                "error_type": type(e).__name__,
                "status_code": e.response.status_code,
                "error_body": e.response.text if hasattr(e.response, 'text') else str(e)
            })
        except Exception:
            # A different error interrupted the loop; rate limiting was not observed
            test_cases.append({
                "test": "rate_limiting",
                "result": "no_rate_limit_triggered"
            })
        else:
            # All requests succeeded without hitting a limit
            test_cases.append({
                "test": "rate_limiting",
                "result": "no_rate_limit_triggered"
            })
return {"error_tests": test_cases}
class AsyncOpenAITestClient:
"""Async version of OpenAI test client for concurrent testing"""
def __init__(self, base_url: str = "http://localhost:3001/api/v1", api_key: Optional[str] = None):
self.base_url = base_url
self.api_key = api_key or "test-api-key"
async def test_concurrent_requests(self, num_requests: int = 10) -> List[Dict[str, Any]]:
"""Test concurrent API requests"""
async def make_request(session: aiohttp.ClientSession, request_id: int) -> Dict[str, Any]:
headers = {"Authorization": f"Bearer {self.api_key}"}
payload = {
"model": "test-model",
"messages": [{"role": "user", "content": f"Request {request_id}"}],
"max_tokens": 50
}
start_time = asyncio.get_event_loop().time()
try:
async with session.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=headers
) as response:
end_time = asyncio.get_event_loop().time()
response_data = await response.json()
return {
"request_id": request_id,
"status_code": response.status,
"response_time": end_time - start_time,
"success": response.status == 200,
"response": response_data if response.status == 200 else None,
"error": response_data if response.status != 200 else None
}
except Exception as e:
end_time = asyncio.get_event_loop().time()
return {
"request_id": request_id,
"status_code": None,
"response_time": end_time - start_time,
"success": False,
"error": str(e)
}
async with aiohttp.ClientSession() as session:
tasks = [make_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
return results
async def test_streaming_performance(self) -> Dict[str, Any]:
"""Test streaming response performance"""
headers = {"Authorization": f"Bearer {self.api_key}"}
payload = {
"model": "test-model",
"messages": [{"role": "user", "content": "Generate a long response about AI"}],
"stream": True,
"max_tokens": 500
}
chunk_times = []
chunks = []
start_time = asyncio.get_event_loop().time()
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=headers
) as response:
if response.status != 200:
return {"error": f"Request failed with status {response.status}"}
async for line in response.content:
if line:
current_time = asyncio.get_event_loop().time()
try:
# Parse SSE format
line_str = line.decode('utf-8').strip()
if line_str.startswith('data: '):
data_str = line_str[6:] # Remove 'data: ' prefix
if data_str != '[DONE]':
chunk_data = json.loads(data_str)
chunks.append(chunk_data)
chunk_times.append(current_time - start_time)
except json.JSONDecodeError:
continue
end_time = asyncio.get_event_loop().time()
        return {
            "total_time": end_time - start_time,
            "chunk_count": len(chunks),
            "chunk_times": chunk_times,
            # chunk_times holds elapsed time since the request started, so the
            # mean gap between chunks is the span divided by the number of gaps
            "avg_chunk_interval": (
                (chunk_times[-1] - chunk_times[0]) / (len(chunk_times) - 1)
                if len(chunk_times) > 1 else 0
            ),
            "first_chunk_time": chunk_times[0] if chunk_times else None
        }
class OpenAICompatibilityError(Exception):
"""Custom exception for OpenAI compatibility test failures"""
pass
def validate_openai_response_format(response: Dict[str, Any], endpoint_type: str) -> List[str]:
"""Validate response format matches OpenAI specification"""
errors = []
if endpoint_type == "chat_completion":
required_fields = ["id", "object", "created", "model", "choices"]
for field in required_fields:
if field not in response:
errors.append(f"Missing required field: {field}")
if "choices" in response and len(response["choices"]) > 0:
choice = response["choices"][0]
if "message" not in choice:
errors.append("Missing 'message' in choice")
elif "content" not in choice["message"]:
errors.append("Missing 'content' in message")
if "usage" in response:
usage_fields = ["prompt_tokens", "completion_tokens", "total_tokens"]
for field in usage_fields:
if field not in response["usage"]:
errors.append(f"Missing usage field: {field}")
elif endpoint_type == "models":
if not isinstance(response, list):
errors.append("Models response should be a list")
else:
for model in response:
model_fields = ["id", "object", "created", "owned_by"]
for field in model_fields:
if field not in model:
errors.append(f"Missing model field: {field}")
elif endpoint_type == "embeddings":
required_fields = ["object", "data", "model", "usage"]
for field in required_fields:
if field not in response:
errors.append(f"Missing required field: {field}")
if "data" in response and len(response["data"]) > 0:
embedding = response["data"][0]
if "embedding" not in embedding:
errors.append("Missing 'embedding' in data")
elif not isinstance(embedding["embedding"], list):
errors.append("Embedding should be a list of floats")
return errors
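
# A minimal, hedged smoke-test entry point. It assumes an Enclava instance is
# reachable at the default base URL and that ENCLAVA_TEST_API_KEY (an
# environment variable name introduced here for illustration) carries a valid
# key; neither assumption is confirmed by this file.
if __name__ == "__main__":
    import os

    api_key = os.environ.get("ENCLAVA_TEST_API_KEY", "test-api-key")

    # Synchronous checks: list models and validate the response shape
    sync_client = OpenAITestClient(api_key=api_key)
    models = sync_client.list_models()
    print(f"models: {[m.get('id') for m in models]}")
    print(f"model format errors: {validate_openai_response_format(models, 'models')}")

    # Async checks: fire a small batch of concurrent chat completions
    async_client = AsyncOpenAITestClient(api_key=api_key)
    concurrent_results = asyncio.run(async_client.test_concurrent_requests(num_requests=5))
    successes = sum(1 for r in concurrent_results if r["success"])
    print(f"concurrent requests: {successes}/{len(concurrent_results)} succeeded")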