Fix the maximum context length issue by chunking (#3222)

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
kinance
2023-05-02 03:13:24 +09:00
committed by GitHub
parent 0ef6f06462
commit 4767fe63d3
9 changed files with 1801 additions and 43 deletions

View File

@@ -1,9 +1,14 @@
import string
from unittest.mock import MagicMock
import pytest
from numpy.random import RandomState
from pytest_mock import MockerFixture
from autogpt.llm.llm_utils import get_ada_embedding
from autogpt.config import Config
from autogpt.llm import llm_utils
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.modelsinfo import COSTS
from tests.utils import requires_api_key
@@ -16,10 +21,42 @@ def random_large_string():
return "".join(random.choice(list(string.ascii_lowercase), size=n_characters))
@pytest.mark.xfail(reason="We have no mechanism for embedding large strings.")
@pytest.fixture()
def api_manager(mocker: MockerFixture):
    """Yield an ApiManager whose cumulative usage counters start at zero.

    ApiManager accumulates token/cost totals across calls; patching the
    counters per-test keeps tests independent of each other.
    """
    manager = ApiManager()
    # Zero out the running totals so each test observes only its own usage.
    mocker.patch.multiple(
        manager,
        total_prompt_tokens=0,
        total_completion_tokens=0,
        total_cost=0,
    )
    yield manager
@pytest.fixture()
def spy_create_embedding(mocker: MockerFixture):
    """Attach a spy to llm_utils.create_embedding so tests can assert on its calls."""
    spy = mocker.spy(llm_utils, "create_embedding")
    return spy
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding(
    config: Config, api_manager: ApiManager, spy_create_embedding: MagicMock
):
    """Embedding a short string forwards to create_embedding and records its cost."""
    token_cost = COSTS[config.embedding_model]["prompt"]

    llm_utils.get_ada_embedding("test")

    # The call must be delegated once, with the configured embedding model.
    spy_create_embedding.assert_called_once_with("test", model=config.embedding_model)

    prompt_tokens = api_manager.get_total_prompt_tokens()
    assert prompt_tokens == 1
    assert api_manager.get_total_completion_tokens() == 0
    # Costs are quoted per 1000 tokens in COSTS.
    assert api_manager.get_total_cost() == (prompt_tokens * token_cost) / 1000
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding_large_context(random_large_string):
    """Regression test: embedding a very large string must not exceed the
    model's maximum context length (the input should be chunked).

    This test could mock the OpenAI call once chunking is stable — the
    function's logic doesn't require hitting the API — but for now it is
    kept as a quick live regression test documenting the issue.
    """
    # Fix: the original contained a duplicate, unqualified call to
    # get_ada_embedding before this line, embedding the large string twice
    # (and doubling API usage). A single qualified call is sufficient.
    llm_utils.get_ada_embedding(random_large_string)