diff --git a/.github/workflows/add-cassettes.yml b/.github/workflows/add-cassettes.yml
new file mode 100644
index 00000000..90e4402d
--- /dev/null
+++ b/.github/workflows/add-cassettes.yml
@@ -0,0 +1,49 @@
+name: Merge and Commit Cassettes
+
+on:
+  pull_request_target:
+    types:
+      - closed
+
+jobs:
+  update-cassettes:
+    if: github.event.pull_request.merged == true
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # This is necessary to fetch all branches and tags
+
+      - name: Fetch all branches
+        run: git fetch --all
+
+      - name: Reset branch
+        run: |
+          git checkout ${{ github.event.pull_request.base.ref }}
+          git reset --hard origin/cassette-diff-${{ github.event.pull_request.number }}
+
+      - name: Create PR
+        id: create_pr
+        uses: peter-evans/create-pull-request@v5
+        with:
+          commit-message: Update cassettes
+          signoff: false
+          branch: cassette-diff-${{ github.event.pull_request.number }}
+          delete-branch: false
+          title: "Update cassettes"
+          body: "This PR updates the cassettes."
+          draft: false
+
+      - name: Check PR
+        run: |
+          echo "Pull Request Number - ${{ steps.create_pr.outputs.pull-request-number }}"
+          echo "Pull Request URL - ${{ steps.create_pr.outputs.pull-request-url }}"
+
+      - name: Comment PR URL in the current PR
+        uses: thollander/actions-comment-pull-request@v2
+        with:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          message: |
+            New pull request created for cassettes: [HERE](${{ steps.create_pr.outputs.pull-request-url }}). Please merge it asap.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f21a263e..e8830775 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -3,12 +3,12 @@ name: Python CI
 on:
   push:
     branches: [ master ]
-  pull_request:
+  pull_request_target:
     branches: [ master, stable ]
 
 concurrency:
   group: ${{ format('ci-{0}', github.head_ref && format('pr-{0}', github.event.pull_request.number) || github.sha) }}
-  cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+  cancel-in-progress: ${{ github.event_name == 'pull_request_target' }}
 
 jobs:
   lint:
@@ -19,6 +19,9 @@ jobs:
     steps:
       - name: Checkout repository
        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          ref: ${{ github.event.pull_request.head.ref }}
 
       - name: Set up Python ${{ env.min-python-version }}
        uses: actions/setup-python@v2
@@ -58,6 +61,9 @@ jobs:
     steps:
       - name: Check out repository
        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          ref: ${{ github.event.pull_request.head.ref }}
 
       - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
@@ -74,6 +80,20 @@ jobs:
         pytest -n auto --cov=autogpt --cov-report term-missing --cov-branch --cov-report xml --cov-report term
       env:
         CI: true
+        PROXY: ${{ vars.PROXY }}
+        AGENT_MODE: ${{ vars.AGENT_MODE }}
+        AGENT_TYPE: ${{ vars.AGENT_TYPE }}
 
     - name: Upload coverage reports to Codecov
       uses: codecov/codecov-action@v3
+
+    - name: Stage new files and commit
+      run: |
+        git add tests
+        git diff --cached --quiet && echo "No changes to commit" && exit 0
+        git config user.email "github-actions@github.com"
+        git config user.name "GitHub Actions"
+        git commit -m "Add new cassettes"
+        git checkout -b cassette-diff-${{ github.event.pull_request.number }}
+        git push -f origin cassette-diff-${{ github.event.pull_request.number }}
+        echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
diff --git a/autogpt/llm/api_manager.py b/autogpt/llm/api_manager.py
index 9143389e..a7777a2b 100644
--- a/autogpt/llm/api_manager.py
+++ b/autogpt/llm/api_manager.py
@@ -59,10 +59,11 @@ class ApiManager(metaclass=Singleton):
             max_tokens=max_tokens,
             api_key=cfg.openai_api_key,
         )
-        logger.debug(f"Response: {response}")
-        prompt_tokens = response.usage.prompt_tokens
-        completion_tokens = response.usage.completion_tokens
-        self.update_cost(prompt_tokens, completion_tokens, model)
+        if not hasattr(response, "error"):
+            logger.debug(f"Response: {response}")
+            prompt_tokens = response.usage.prompt_tokens
+            completion_tokens = response.usage.completion_tokens
+            self.update_cost(prompt_tokens, completion_tokens, model)
         return response
 
     def update_cost(self, prompt_tokens, completion_tokens, model):
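A note on the guard above: when the API hands back an error payload instead of a completion, the new `hasattr(response, "error")` check keeps `update_cost` from dereferencing `response.usage`. The sketch below is not part of the patch; it only illustrates the attribute mechanics this relies on, and why the `del mock_response.error` lines added to the tests later in this diff are necessary for `MagicMock` objects:

```python
# Standalone sketch of the hasattr() guard; the variable names are illustrative.
from unittest.mock import MagicMock

ok_response = MagicMock()
del ok_response.error  # a deleted mock attribute raises AttributeError on access
ok_response.usage.prompt_tokens = 10

error_response = MagicMock()  # a bare MagicMock auto-creates .error on first access

assert not hasattr(ok_response, "error")  # cost accounting runs for this one
assert hasattr(error_response, "error")   # cost accounting is skipped for this one
```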
diff --git a/autogpt/llm/llm_utils.py b/autogpt/llm/llm_utils.py
index a77bccbc..58b19735 100644
--- a/autogpt/llm/llm_utils.py
+++ b/autogpt/llm/llm_utils.py
@@ -181,7 +181,7 @@ def create_chat_completion(
                 )
                 warned_user = True
         except (APIError, Timeout) as e:
-            if e.http_status != 502:
+            if e.http_status != 502:
                 raise
             if attempt == num_retries - 1:
                 raise
diff --git a/tests/conftest.py b/tests/conftest.py
index da00058b..6e6f0ad3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 import pytest
@@ -9,6 +10,8 @@ from autogpt.workspace import Workspace
 
 pytest_plugins = ["tests.integration.agent_factory"]
 
+PROXY = os.environ.get("PROXY")
+
 
 @pytest.fixture()
 def workspace_root(tmp_path: Path) -> Path:
diff --git a/tests/integration/challenges/information_retrieval/test_information_retrieval_challenge_a.py b/tests/integration/challenges/information_retrieval/test_information_retrieval_challenge_a.py
index a5f8fb4c..b96e811a 100644
--- a/tests/integration/challenges/information_retrieval/test_information_retrieval_challenge_a.py
+++ b/tests/integration/challenges/information_retrieval/test_information_retrieval_challenge_a.py
@@ -26,7 +26,7 @@ def input_generator(input_sequence: list) -> Generator[str, None, None]:
 @requires_api_key("OPENAI_API_KEY")
 @run_multiple_times(3)
 def test_information_retrieval_challenge_a(
-    get_company_revenue_agent, monkeypatch
+    get_company_revenue_agent, monkeypatch, patched_api_requestor
 ) -> None:
     """
     Test the challenge_a function in a given agent by mocking user inputs and checking the output file content.
diff --git a/tests/integration/challenges/memory/test_memory_challenge_a.py b/tests/integration/challenges/memory/test_memory_challenge_a.py
index 895fc8fe..fb5876cd 100644
--- a/tests/integration/challenges/memory/test_memory_challenge_a.py
+++ b/tests/integration/challenges/memory/test_memory_challenge_a.py
@@ -13,7 +13,7 @@ MAX_LEVEL = 3
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
 def test_memory_challenge_a(
-    memory_management_agent: Agent, user_selected_level: int
+    memory_management_agent: Agent, user_selected_level: int, patched_api_requestor
 ) -> None:
     """
     The agent reads a file containing a task_id. Then, it reads a series of other files.
@@ -30,7 +30,7 @@ def test_memory_challenge_a(
     create_instructions_files(memory_management_agent, num_files, task_id)
 
     try:
-        run_interaction_loop(memory_management_agent, 180)
+        run_interaction_loop(memory_management_agent, 400)
     # catch system exit exceptions
     except SystemExit:
         file_path = str(memory_management_agent.workspace.get_path("output.txt"))
diff --git a/tests/integration/challenges/memory/test_memory_challenge_b.py b/tests/integration/challenges/memory/test_memory_challenge_b.py
index 628b4989..21d46b38 100644
--- a/tests/integration/challenges/memory/test_memory_challenge_b.py
+++ b/tests/integration/challenges/memory/test_memory_challenge_b.py
@@ -14,7 +14,7 @@ NOISE = 1000
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
 def test_memory_challenge_b(
-    memory_management_agent: Agent, user_selected_level: int
+    memory_management_agent: Agent, user_selected_level: int, patched_api_requestor
 ) -> None:
     """
     The agent reads a series of files, each containing a task_id and noise. After reading 'n' files,
diff --git a/tests/integration/challenges/memory/test_memory_challenge_c.py b/tests/integration/challenges/memory/test_memory_challenge_c.py
index edd3efe0..634a24a3 100644
--- a/tests/integration/challenges/memory/test_memory_challenge_c.py
+++ b/tests/integration/challenges/memory/test_memory_challenge_c.py
@@ -14,7 +14,7 @@ NOISE = 1000
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
 def test_memory_challenge_c(
-    memory_management_agent: Agent, user_selected_level: int
+    memory_management_agent: Agent, user_selected_level: int, patched_api_requestor
 ) -> None:
     """
     Instead of reading task Ids from files as with the previous challenges, the agent now must remember
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 00928702..dfb94d0e 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -1,7 +1,9 @@
 import os
 
+import openai
 import pytest
 
+from tests.conftest import PROXY
 from tests.vcr.vcr_filter import before_record_request, before_record_response
 
 
@@ -17,5 +19,37 @@ def vcr_config():
             "X-OpenAI-Client-User-Agent",
             "User-Agent",
         ],
-        "match_on": ["method", "uri", "body"],
+        "match_on": ["method", "body"],
     }
+
+
+def patch_api_base(requestor):
+    new_api_base = f"{PROXY}/v1"
+    requestor.api_base = new_api_base
+    return requestor
+
+
+@pytest.fixture
+def patched_api_requestor(mocker):
+    original_init = openai.api_requestor.APIRequestor.__init__
+    original_validate_headers = openai.api_requestor.APIRequestor._validate_headers
+
+    def patched_init(requestor, *args, **kwargs):
+        original_init(requestor, *args, **kwargs)
+        patch_api_base(requestor)
+
+    def patched_validate_headers(self, supplied_headers):
+        headers = original_validate_headers(self, supplied_headers)
+        headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
+        headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
+        return headers
+
+    if PROXY:
+        mocker.patch("openai.api_requestor.APIRequestor.__init__", new=patched_init)
+        mocker.patch.object(
+            openai.api_requestor.APIRequestor,
+            "_validate_headers",
+            new=patched_validate_headers,
+        )
+
+    return mocker
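For reference, this is how the new fixture is meant to be consumed; the test name below is hypothetical, but the pattern mirrors the signature changes made throughout this diff. When `PROXY` is unset the fixture patches nothing, so local runs keep talking to `api.openai.com` (or replaying cassettes):

```python
import pytest

from tests.utils import requires_api_key


@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_example(patched_api_requestor) -> None:  # hypothetical test
    # While PROXY is set, every openai.* call made in here is routed to
    # {PROXY}/v1 and tagged with the AGENT-MODE / AGENT-TYPE headers.
    ...
```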
diff --git a/tests/integration/goal_oriented/cassettes/test_write_file/test_write_file.yaml b/tests/integration/goal_oriented/cassettes/test_write_file/test_write_file.yaml
index 42160ea2..4aefdb23 100644
--- a/tests/integration/goal_oriented/cassettes/test_write_file/test_write_file.yaml
+++ b/tests/integration/goal_oriented/cassettes/test_write_file/test_write_file.yaml
@@ -573,4 +573,120 @@ interactions:
     status:
       code: 200
       message: OK
+- request:
+    body: '{"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "You
+      are write_to_file-GPT, an AI designed to use the write_to_file command to write
+      ''Hello World'' into a file named \"hello_world.txt\" and then use the task_complete
+      command to complete the task.\nYour decisions must always be made independently
+      without seeking user assistance. Play to your strengths as an LLM and pursue
+      simple strategies with no legal complications.\n\nGOALS:\n\n1. Use the write_to_file
+      command to write ''Hello World'' into a file named \"hello_world.txt\".\n2.
+      Use the task_complete command to complete the task.\n3. Do not use any other
+      commands.\n\n\nConstraints:\n1. ~4000 word limit for short term memory. Your
+      short term memory is short, so immediately save important information to files.\n2.
+      If you are unsure how you previously did something or want to recall past events,
+      thinking about similar events will help you remember.\n3. No user assistance\n4.
+      Exclusively use the commands listed in double quote e.g. \"command name\"\n\nCommands:\n1.
+      append_to_file: Append to file, args: \"filename\": \"<filename>\", \"text\":
+      \"<text>\"\n2. delete_file: Delete file, args: \"filename\": \"<filename>\"\n3.
+      list_files: List Files in Directory, args: \"directory\": \"<directory>\"\n4.
+      read_file: Read file, args: \"filename\": \"<filename>\"\n5. write_to_file:
+      Write to file, args: \"filename\": \"<filename>\", \"text\": \"<text>\"\n6.
+      delete_agent: Delete GPT Agent, args: \"key\": \"<key>\"\n7. get_hyperlinks:
+      Get text summary, args: \"url\": \"<url>\"\n8. get_text_summary: Get text summary,
+      args: \"url\": \"<url>\", \"question\": \"<question>\"\n9. list_agents: List
+      GPT Agents, args: () -> str\n10. message_agent: Message GPT Agent, args: \"key\":
+      \"<key>\", \"message\": \"<message>\"\n11. start_agent: Start GPT Agent, args:
+      \"name\": \"<name>\", \"task\": \"<short_task_desc>\", \"prompt\": \"<prompt>\"\n12.
+      task_complete: Task Complete (Shutdown), args: \"reason\": \"<reason>\"\n\nResources:\n1.
+      Internet access for searches and information gathering.\n2. Long Term memory
+      management.\n3. GPT-3.5 powered Agents for delegation of simple tasks.\n4. File
+      output.\n\nPerformance Evaluation:\n1. Continuously review and analyze your
+      actions to ensure you are performing to the best of your abilities.\n2. Constructively
+      self-criticize your big-picture behavior constantly.\n3. Reflect on past decisions
+      and strategies to refine your approach.\n4. Every command has a cost, so be
+      smart and efficient. Aim to complete tasks in the least number of steps.\n5.
+      Write all code to a file.\n\nYou should only respond in JSON format as described
+      below \nResponse Format: \n{\n \"thoughts\": {\n \"text\": \"thought\",\n \"reasoning\":
+      \"reasoning\",\n \"plan\": \"- short bulleted\\n- list that conveys\\n-
+      long-term plan\",\n \"criticism\": \"constructive self-criticism\",\n \"speak\":
+      \"thoughts summary to say to user\"\n },\n \"command\": {\n \"name\":
+      \"command name\",\n \"args\": {\n \"arg name\": \"value\"\n }\n }\n}
+      \nEnsure the response can be parsed by Python json.loads"}, {"role": "system",
+      "content": "The current time and date is Tue Jan 1 00:00:00 2000"}, {"role":
+      "user", "content": "Determine which next command to use, and respond using the
+      format specified above:"}], "temperature": 0, "max_tokens": 0}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '3405'
+      Content-Type:
+      - application/json
+    method: POST
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAA7yTT4/TMBDF73yK0Vx6cauUqt1trmhhK8QBBEKIoMrrTBvT2BPsCVu2yndfJemf
+        3SBOCK7zxu/9xh4f0OaYoim0GFeV46vX7u3i183N/Geup/7Lq9vVh/fXEt49PLwpNCrku+9k5Hhi
+        YthVJYlljwpNIC2UYzpdXM+Wy3kySxQ6zqnEFLeVjGeT+VjqcMfjZJZMUWEd9ZYwPWAV2FWyFt6R
+        j5heLRKFF+9z/WUyVSgsujyXlotpo9AUbA1FTL8e0FE82QYuCVPUMdoo2ksLyV7ItwMcMg8AkKEU
+        XG8LiRmmcCweBdpLW8xwBZ4oB2GoI4EUBPfBCq2F1xtbEhh2TvuuoRNgdEtlyfCZQ5mPwHph0NC1
+        eu0oh1HR6uv7Vp/IXkaTDNXT7EA6srd+2wN8LAhExx0E+lHbQBEc/UWago72j3PY2ImOo4CuqsBV
+        sFoINhxAilbVcTdErkrte9oxfPpP12SCFWtsdMN3Ih/r0DJogdX51QyHQEYuEf090F4uTMJda9sy
+        TIsV6d0p6d6W5b9chz64Uac1PZr+tqWtQ8/0DGKArsN2uOC90PZeLAYcz0yGn+LJTCfajvgInvkG
+        G4Ub620s1v0+Y4pRuEKF1ue0xzRpvjUvHgEAAP//AwDSj7qBhAQAAA==
+    headers:
+      CF-Cache-Status:
+      - DYNAMIC
+      CF-RAY:
+      - 7c6c3f8bcdd1cf87-SJC
+      Cache-Control:
+      - no-cache, must-revalidate
+      Connection:
+      - keep-alive
+      Content-Encoding:
+      - gzip
+      Content-Type:
+      - application/json
+      Date:
+      - Sat, 13 May 2023 16:24:06 GMT
+      Server:
+      - cloudflare
+      access-control-allow-origin:
+      - '*'
+      alt-svc:
+      - h3=":443"; ma=86400, h3-29=":443"; ma=86400
+      openai-model:
+      - gpt-3.5-turbo-0301
+      openai-organization:
+      - user-adtx4fhfg1qsiyzdoaxciooj
+      openai-processing-ms:
+      - '16269'
+      openai-version:
+      - '2020-10-01'
+      strict-transport-security:
+      - max-age=15724800; includeSubDomains
+      x-ratelimit-limit-requests:
+      - '3500'
+      x-ratelimit-limit-tokens:
+      - '90000'
+      x-ratelimit-remaining-requests:
+      - '3499'
+      x-ratelimit-remaining-tokens:
+      - '86496'
+      x-ratelimit-reset-requests:
+      - 17ms
+      x-ratelimit-reset-tokens:
+      - 2.336s
+      x-request-id:
+      - 8d3e6826e88e77fb2cbce01166ddc550
+    status:
+      code: 200
+      message: OK
 version: 1
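The cassette above records `api.openai.com` URIs even though CI traffic now flows through the proxy, which is why `vcr_config` earlier in this diff dropped `"uri"` from `match_on`. A minimal standalone illustration of that matching behavior, with an assumed cassette path:

```python
import vcr

# Match recorded interactions on method and body only, so the same cassette
# replays whether the live request went to api.openai.com or to PROXY.
relaxed_vcr = vcr.VCR(match_on=["method", "body"])

with relaxed_vcr.use_cassette("tests/example_cassette.yaml"):  # assumed path
    ...  # an identical method+body is served from the cassette
```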
diff --git a/tests/integration/goal_oriented/test_browse_website.py b/tests/integration/goal_oriented/test_browse_website.py
index ca433d80..3ce85689 100644
--- a/tests/integration/goal_oriented/test_browse_website.py
+++ b/tests/integration/goal_oriented/test_browse_website.py
@@ -8,7 +8,7 @@ from tests.utils import requires_api_key
 
 @requires_api_key("OPENAI_API_KEY")
 @pytest.mark.vcr
-def test_browse_website(browser_agent: Agent) -> None:
+def test_browse_website(browser_agent: Agent, patched_api_requestor) -> None:
     file_path = browser_agent.workspace.get_path("browse_website.txt")
     try:
         run_interaction_loop(browser_agent, 120)
diff --git a/tests/integration/goal_oriented/test_write_file.py b/tests/integration/goal_oriented/test_write_file.py
index da67235a..55db3f4a 100644
--- a/tests/integration/goal_oriented/test_write_file.py
+++ b/tests/integration/goal_oriented/test_write_file.py
@@ -1,3 +1,6 @@
+import os
+
+import openai
 import pytest
 
 from autogpt.agent import Agent
@@ -8,10 +11,10 @@ from tests.utils import requires_api_key
 
 @requires_api_key("OPENAI_API_KEY")
 @pytest.mark.vcr
-def test_write_file(writer_agent: Agent) -> None:
+def test_write_file(writer_agent: Agent, patched_api_requestor) -> None:
     file_path = str(writer_agent.workspace.get_path("hello_world.txt"))
     try:
-        run_interaction_loop(writer_agent, 40)
+        run_interaction_loop(writer_agent, 200)
     # catch system exit exceptions
     except SystemExit:  # the agent returns an exception when it shuts down
         content = read_file(file_path)
diff --git a/tests/unit/test_commands.py b/tests/integration/test_commands.py
similarity index 92%
rename from tests/unit/test_commands.py
rename to tests/integration/test_commands.py
index e3b874fb..59f63857 100644
--- a/tests/unit/test_commands.py
+++ b/tests/integration/test_commands.py
@@ -10,11 +10,11 @@ from tests.utils import requires_api_key
 @pytest.mark.vcr
 @pytest.mark.integration_test
 @requires_api_key("OPENAI_API_KEY")
-def test_make_agent() -> None:
+def test_make_agent(patched_api_requestor) -> None:
     """Test that an agent can be created"""
     # Use the mock agent manager to avoid creating a real agent
     with patch("openai.ChatCompletion.create") as mock:
         response = MagicMock()
         response.choices[0].messages[0].content = "Test message"
         response.usage.prompt_tokens = 1
         response.usage.completion_tokens = 1
diff --git a/tests/integration/test_llm_utils.py b/tests/integration/test_llm_utils.py
index 553d3699..fefc239c 100644
--- a/tests/integration/test_llm_utils.py
+++ b/tests/integration/test_llm_utils.py
@@ -41,7 +41,10 @@ def spy_create_embedding(mocker: MockerFixture):
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
 def test_get_ada_embedding(
-    config: Config, api_manager: ApiManager, spy_create_embedding: MagicMock
+    config: Config,
+    api_manager: ApiManager,
+    spy_create_embedding: MagicMock,
+    patched_api_requestor,
 ):
     token_cost = COSTS[config.embedding_model]["prompt"]
     llm_utils.get_ada_embedding("test")
diff --git a/tests/integration/test_local_cache.py b/tests/integration/test_local_cache.py
index 5200e026..808f119a 100644
--- a/tests/integration/test_local_cache.py
+++ b/tests/integration/test_local_cache.py
@@ -91,7 +91,7 @@ def test_get(LocalCache, config, mock_embed_with_ada):
 
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
-def test_get_relevant(LocalCache, config) -> None:
+def test_get_relevant(LocalCache, config, patched_api_requestor) -> None:
     cache = LocalCache(config)
     text1 = "Sample text 1"
     text2 = "Sample text 2"
diff --git a/tests/integration/test_memory_management.py b/tests/integration/test_memory_management.py
index c9ab9fc9..22ade7b0 100644
--- a/tests/integration/test_memory_management.py
+++ b/tests/integration/test_memory_management.py
@@ -52,7 +52,10 @@ Human Feedback:Command Result: Important Information."""
 @requires_api_key("OPENAI_API_KEY")
 @pytest.mark.vcr
 def test_save_memory_trimmed_from_context_window(
-    message_history_fixture, expected_permanent_memory, config: Config
+    message_history_fixture,
+    expected_permanent_memory,
+    config: Config,
+    patched_api_requestor,
 ):
     next_message_to_add_index = len(message_history_fixture) - 1
     memory = get_memory(config, init=True)
diff --git a/tests/integration/test_setup.py b/tests/integration/test_setup.py
index b649bb14..444d9474 100644
--- a/tests/integration/test_setup.py
+++ b/tests/integration/test_setup.py
@@ -13,7 +13,7 @@ from tests.utils import requires_api_key
 
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
-def test_generate_aiconfig_automatic_default():
+def test_generate_aiconfig_automatic_default(patched_api_requestor):
     user_inputs = [""]
     with patch("builtins.input", side_effect=user_inputs):
         ai_config = prompt_user()
@@ -26,7 +26,7 @@ def test_generate_aiconfig_automatic_default():
 
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
-def test_generate_aiconfig_automatic_typical():
+def test_generate_aiconfig_automatic_typical(patched_api_requestor):
     user_prompt = "Help me create a rock opera about cybernetic giraffes"
     ai_config = generate_aiconfig_automatic(user_prompt)
 
@@ -38,7 +38,7 @@ def test_generate_aiconfig_automatic_typical():
 
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
-def test_generate_aiconfig_automatic_fallback():
+def test_generate_aiconfig_automatic_fallback(patched_api_requestor):
     user_inputs = [
         "T&GF£OIBECC()!*",
         "Chef-GPT",
@@ -59,7 +59,7 @@ def test_generate_aiconfig_automatic_fallback():
 
 @pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
-def test_prompt_user_manual_mode():
+def test_prompt_user_manual_mode(patched_api_requestor):
     user_inputs = [
         "--manual",
         "Chef-GPT",
diff --git a/tests/test_api_manager.py b/tests/test_api_manager.py
index ba64a72f..3d0672c1 100644
--- a/tests/test_api_manager.py
+++ b/tests/test_api_manager.py
@@ -39,6 +39,7 @@ class TestApiManager:
 
         with patch("openai.ChatCompletion.create") as mock_create:
             mock_response = MagicMock()
+            del mock_response.error
             mock_response.usage.prompt_tokens = 10
             mock_response.usage.completion_tokens = 20
             mock_create.return_value = mock_response
@@ -55,6 +56,7 @@ class TestApiManager:
 
         with patch("openai.ChatCompletion.create") as mock_create:
             mock_response = MagicMock()
+            del mock_response.error
             mock_response.usage.prompt_tokens = 0
             mock_response.usage.completion_tokens = 0
             mock_create.return_value = mock_response
@@ -76,6 +78,7 @@ class TestApiManager:
 
         with patch("openai.ChatCompletion.create") as mock_create:
             mock_response = MagicMock()
+            del mock_response.error
             mock_response.usage.prompt_tokens = 10
             mock_response.usage.completion_tokens = 20
             mock_create.return_value = mock_response
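The `del mock_response.error` lines make the mocked responses look like successful completions to the new `hasattr` guard in `ApiManager` (see the sketch earlier in this diff), so the 10/20 token figures still reach `update_cost`. For intuition, that bookkeeping amounts to per-1k-token pricing; the prices below are placeholders, not the values used in `autogpt`:

```python
# Illustrative cost arithmetic with assumed prices ($ per 1k tokens).
prompt_tokens, completion_tokens = 10, 20
prompt_price, completion_price = 0.002, 0.002  # placeholder prices

total_cost = (
    prompt_tokens * prompt_price + completion_tokens * completion_price
) / 1000
assert round(total_cost, 6) == 0.00006
```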
diff --git a/tests/test_image_gen.py b/tests/test_image_gen.py
index b4eb99e0..136fb510 100644
--- a/tests/test_image_gen.py
+++ b/tests/test_image_gen.py
@@ -19,7 +19,7 @@ def image_size(request):
     reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
 )
 @requires_api_key("OPENAI_API_KEY")
-def test_dalle(config, workspace, image_size):
+def test_dalle(config, workspace, image_size, patched_api_requestor):
     """Test DALL-E image generation."""
     generate_and_validate(
         config,
diff --git a/tests/utils.py b/tests/utils.py
index 2a0d25d8..2603dfe4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,11 +24,10 @@ def requires_api_key(env_var):
     def decorator(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            if not os.environ.get(env_var) and env_var == "OPENAI_API_KEY":
-                with dummy_openai_api_key():
-                    return func(*args, **kwargs)
-            else:
-                return func(*args, **kwargs)
+            if env_var == "OPENAI_API_KEY" and not os.environ.get(env_var):
+                with dummy_openai_api_key():
+                    return func(*args, **kwargs)
+            return func(*args, **kwargs)
 
         return wrapper
 
diff --git a/tests/vcr/vcr_filter.py b/tests/vcr/vcr_filter.py
index 892b8021..3a58bee9 100644
--- a/tests/vcr/vcr_filter.py
+++ b/tests/vcr/vcr_filter.py
@@ -1,7 +1,11 @@
 import json
+import os
 import re
 from typing import Any, Dict, List
+from urllib.parse import urlparse, urlunparse
 
+from tests.conftest import PROXY
+
 REPLACEMENTS: List[Dict[str, str]] = [
     {
         "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
@@ -13,6 +17,19 @@ REPLACEMENTS: List[Dict[str, str]] = [
     },
 ]
 
+ALLOWED_HOSTNAMES: List[str] = [
+    "api.openai.com",
+    "localhost:50337",
+]
+
+if PROXY:
+    ALLOWED_HOSTNAMES.append(PROXY)
+    ORIGINAL_URL = PROXY
+else:
+    ORIGINAL_URL = "no_ci"
+
+NEW_URL = "api.openai.com"
+
 
 def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str:
     for replacement in replacements:
@@ -53,6 +70,8 @@ def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
 
 
 def before_record_request(request: Any) -> Any:
+    request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
+
     filtered_request = filter_hostnames(request)
     filtered_request_without_dynamic_data = replace_timestamp_in_request(
         filtered_request
@@ -60,14 +79,22 @@ def before_record_request(request: Any) -> Any:
     return filtered_request_without_dynamic_data
 
 
-def filter_hostnames(request: Any) -> Any:
-    allowed_hostnames: List[str] = [
-        "api.openai.com",
-        "localhost:50337",
-    ]
-    # Add your implementation here for filtering hostnames
-    if any(hostname in request.url for hostname in allowed_hostnames):
+def replace_request_hostname(request: Any, original_url: str, new_hostname: str) -> Any:
+    parsed_url = urlparse(request.uri)
+
+    if parsed_url.hostname and parsed_url.hostname in original_url:
+        new_path = parsed_url.path.replace("/proxy_function", "")
+        request.uri = urlunparse(
+            parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
+        )
+
+    return request
+
+
+def filter_hostnames(request: Any) -> Any:
+    # Drop any request that is not addressed to an allowed hostname.
+    if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
         return request
     else:
         return None
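To make the hostname rewrite above concrete, this is what `replace_request_hostname` effectively does to a proxied URI before it is recorded. The proxy address here is made up (the real `PROXY` value is CI configuration); `/proxy_function` mirrors the path prefix the function strips:

```python
from urllib.parse import urlparse, urlunparse

# Hypothetical proxy-shaped URI, for illustration only.
uri = "http://my-proxy.example.com/proxy_function/v1/chat/completions"

parsed = urlparse(uri)
rewritten = urlunparse(
    parsed._replace(
        netloc="api.openai.com",
        path=parsed.path.replace("/proxy_function", ""),
        scheme="https",
    )
)
assert rewritten == "https://api.openai.com/v1/chat/completions"
```

Normalizing the recorded URI back to the real OpenAI endpoint keeps the cassettes identical whether they were captured directly or through any proxy deployment.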