Mirror of https://github.com/aljazceru/Auto-GPT.git, last synced 2025-12-20 23:44:19 +01:00.
Fix the maximum context length issue by chunking (#3222)
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
import string
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from numpy.random import RandomState
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt.llm.llm_utils import get_ada_embedding
|
||||
from autogpt.config import Config
|
||||
from autogpt.llm import llm_utils
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
from autogpt.llm.modelsinfo import COSTS
|
||||
from tests.utils import requires_api_key
|
||||
|
||||
|
||||
@@ -16,10 +21,42 @@ def random_large_string():
|
||||
return "".join(random.choice(list(string.ascii_lowercase), size=n_characters))
|
||||
|
||||
|
||||
@pytest.fixture()
def api_manager(mocker: MockerFixture):
    """Yield an ApiManager whose usage counters start from zero.

    The counters (``total_prompt_tokens``, ``total_completion_tokens``,
    ``total_cost``) are patched onto the instance so each test observes only
    the usage it generates; pytest-mock reverts the patches on teardown.

    NOTE(review): the stray ``@pytest.mark.xfail`` decorator above this
    fixture was leftover diff residue — xfail has no effect on fixtures and
    belonged to a removed test, so it is dropped here.
    """
    api_manager = ApiManager()
    mocker.patch.multiple(
        api_manager,
        total_prompt_tokens=0,
        total_completion_tokens=0,
        total_cost=0,
    )
    yield api_manager
|
||||
|
||||
|
||||
@pytest.fixture()
def spy_create_embedding(mocker: MockerFixture):
    """Provide a spy wrapping ``llm_utils.create_embedding``.

    The spy lets tests assert how the real function was invoked (arguments,
    call count) while still executing it.
    """
    embedding_spy = mocker.spy(llm_utils, "create_embedding")
    return embedding_spy
|
||||
|
||||
|
||||
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding(
    config: Config, api_manager: ApiManager, spy_create_embedding: MagicMock
):
    """Embedding a short string records prompt tokens and cost on ApiManager."""
    cost_per_1k_tokens = COSTS[config.embedding_model]["prompt"]

    llm_utils.get_ada_embedding("test")

    # The helper must delegate to create_embedding with the configured model.
    spy_create_embedding.assert_called_once_with("test", model=config.embedding_model)

    prompt_tokens = api_manager.get_total_prompt_tokens()
    assert prompt_tokens == 1
    # Embeddings consume no completion tokens; cost tables are per 1k tokens.
    assert api_manager.get_total_completion_tokens() == 0
    assert api_manager.get_total_cost() == prompt_tokens * cost_per_1k_tokens / 1000
|
||||
|
||||
|
||||
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding_large_context(random_large_string):
    """Regression test: embedding a very large string must not raise.

    This test should be able to mock the openai call after we have a fix. We
    don't need to hit the API to test the logic of the function (so not using
    vcr). This is a quick regression test to document the issue.
    """
    # Call through the module, matching test_get_ada_embedding. The bare
    # get_ada_embedding(...) call that previously preceded this line was
    # leftover diff residue: its direct import was removed, so keeping it
    # would both double-call the API and raise NameError.
    llm_utils.get_ada_embedding(random_large_string)
|
||||
|
||||
Reference in a new issue · Block a user