MockManager, mock_func in data.json (#39)

This commit is contained in:
Silen Naihin
2023-06-23 07:53:57 -04:00
committed by GitHub
parent 15c5469bb1
commit ffd1d15a0e
8 changed files with 68 additions and 14 deletions

View File

@@ -17,8 +17,9 @@ Input:
- **answer** (str): The raw text of ground truth answer
- **should_contain** (list): the exact strings that are required in the final answer
- **should_not_contain** (list): the exact strings that should not be in the final answer
- **files**: files that are used for retrieval
- **files**: files that are used for retrieval. Can specify a file here or an extension **TODO:** like .txt
- **difficulty**(str): the difficulty of this query. choices from
- **mock_func**: function to mock the agent's response. This is used for testing purposes
Example:

View File

@@ -16,6 +16,7 @@ class Challenge(BaseModel):
task: str
ground: Ground
difficulty: str
mock_func: Optional[str] = None
def serialize(self, path: str) -> None:
with open(path, "w") as file:

View File

@@ -7,5 +7,6 @@
"should_not_contain": ["New York", "Los Angeles", "San Francisco"],
"files": ["file_to_check.txt"]
},
"difficulty": "easy"
"difficulty": "easy",
"mock_func": "retrieval_1_mock"
}

View File

@@ -11,7 +11,7 @@ class TestRetrieval1(RetrievalChallenge):
@pytest.mark.parametrize(
"server_response",
[data.task],
[(data.task, data.mock_func)],
indirect=True,
)
@pytest.mark.retrieval

View File

@@ -2,9 +2,10 @@ import json
import os
import pytest
import shutil
from agbenchmark.mocks.tests.retrieval_manual import mock_retrieval
from agbenchmark.tests.regression.RegressionManager import RegressionManager
import requests
from requests.exceptions import RequestException
from agbenchmark.mocks.MockManager import MockManager
@pytest.fixture(scope="module")
@@ -33,15 +34,34 @@ def workspace(config):
@pytest.fixture(autouse=True)
def server_response(request, config):
task = request.param # The task is passed in indirectly
print(f"Server starting at {request.module}")
"""Calling to get a response"""
# request.param is either a (task, mock function name) tuple or a bare task;
# in the bare-task case no mock function name is available.
if isinstance(request.param, tuple):
task = request.param[0] # The task is passed in indirectly
mock_function_name = request.param[1]
else:
task = request.param
mock_function_name = None
# NOTE: the HTTP request to the agent server is commented out below, so the
# mock branch further down is what runs when a mock name is provided.
# print(f"Server starting at {request.module}")
# try:
# response = requests.post(
# f"{config['hostname']}:{config['port']}", data={"task": task}
# )
# assert (
# response.status_code == 200
# ), f"Request failed with status code {response.status_code}"
mock_retrieval(task, config["workspace"])
# response.raise_for_status() # This will raise an HTTPError if the status is 4xx or 5xx
# except RequestException:
# # If an exception occurs (could be connection, timeout, or HTTP errors), we use the mock
# A MockManager is built for this task and asked to run the mock by name;
# without a mock name nothing is executed and only a message is printed.
if mock_function_name:
mock_manager = MockManager(
task
) # workspace doesn't need to be passed in, stays the same
print("Server unavailable, using mock", mock_function_name)
mock_manager.delegate(mock_function_name)
else:
print("No mock provided")
# else:
# # This code is run if no exception occurred
# print(f"Request succeeded with status code {response.status_code}")
regression_txt = "agbenchmark/tests/regression/regression_tests.txt"

View File

@@ -0,0 +1,28 @@
import sys
import agbenchmark.mocks.tests.basic_mocks as basic_mocks
import agbenchmark.mocks.tests.retrieval_mocks as retrieval_mocks
class MockManager:
    """Run a mock function, selected by name, for a single task.

    The manager keeps the task text, a fixed mock workspace path and the list
    of modules that are searched for mock functions.
    """

    def __init__(self, task: str):
        """Store *task* and set the workspace path and searchable mock modules."""
        self.task = task
        self.workspace = "agbenchmark/mocks/workspace"
        self.modules = [basic_mocks, retrieval_mocks]

    def delegate(self, mock_function_name, *args, **kwargs):
        """Find the mock called *mock_function_name* and invoke it.

        Lookup order:
          1. a callable attribute of this manager, called with ``*args, **kwargs``;
          2. a callable in this module's globals;
          3. a callable attribute of one of ``self.modules`` (first match wins).

        Mocks found via 2 or 3 are called as
        ``func(self.task, self.workspace, *args, **kwargs)``.

        Raises:
            ValueError: if no callable with that name is found anywhere.
                This also covers the case where ``self.modules`` is non-empty
                but none of the modules defines the mock, so a misspelled
                mock name fails loudly instead of silently doing nothing.
        """
        # Only callables are dispatched, so data attributes such as ``task``
        # or ``workspace`` are never mistaken for mocks.
        own_method = getattr(self, mock_function_name, None)
        if callable(own_method):
            own_method(*args, **kwargs)
            return

        module_level = globals().get(mock_function_name)
        if callable(module_level):
            module_level(self.task, self.workspace, *args, **kwargs)
            return

        for module in self.modules:
            candidate = getattr(module, mock_function_name, None)
            if callable(candidate):
                candidate(self.task, self.workspace, *args, **kwargs)
                return

        raise ValueError(f"No such mock: {mock_function_name}")

View File

View File

@@ -2,7 +2,10 @@ from ..basic_gpt_agent import basic_gpt_agent
from agbenchmark.Challenge import Challenge
def mock_retrieval(task: str, workspace: str):
# TODO: Make it so that you can specify for tests to only run if their prerequisites are met.
# Prerequisites here would be writing to a file (basic_abilities test).
# Should also check if prerequisites exists in regression file
def retrieval_1_mock(task: str, workspace: str):
# Call the basic_gpt_agent to get a response.
response = basic_gpt_agent(task)