Fix the maximum context length issue by chunking (#3222)

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
This commit is contained in:
kinance
2023-05-02 03:13:24 +09:00
committed by GitHub
parent 0ef6f06462
commit 4767fe63d3
9 changed files with 1801 additions and 43 deletions

View File

@@ -49,6 +49,14 @@ OPENAI_API_KEY=your-openai-api-key
# FAST_TOKEN_LIMIT=4000 # FAST_TOKEN_LIMIT=4000
# SMART_TOKEN_LIMIT=8000 # SMART_TOKEN_LIMIT=8000
### EMBEDDINGS
## EMBEDDING_MODEL - Model to use for creating embeddings
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=text-embedding-ada-002
# EMBEDDING_TOKENIZER=cl100k_base
# EMBEDDING_TOKEN_LIMIT=8191
################################################################################ ################################################################################
### MEMORY ### MEMORY
################################################################################ ################################################################################

View File

@@ -35,6 +35,9 @@ class Config(metaclass=Singleton):
self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4") self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000)) self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 8000)) self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 8000))
self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002")
self.embedding_tokenizer = os.getenv("EMBEDDING_TOKENIZER", "cl100k_base")
self.embedding_token_limit = int(os.getenv("EMBEDDING_TOKEN_LIMIT", 8191))
self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000)) self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000))
self.browse_spacy_language_model = os.getenv( self.browse_spacy_language_model = os.getenv(
"BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm" "BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm"
@@ -216,6 +219,18 @@ class Config(metaclass=Singleton):
"""Set the smart token limit value.""" """Set the smart token limit value."""
self.smart_token_limit = value self.smart_token_limit = value
def set_embedding_model(self, value: str) -> None:
"""Set the model to use for creating embeddings."""
self.embedding_model = value
def set_embedding_tokenizer(self, value: str) -> None:
"""Set the tokenizer to use when creating embeddings."""
self.embedding_tokenizer = value
def set_embedding_token_limit(self, value: int) -> None:
"""Set the token limit for creating embeddings."""
self.embedding_token_limit = value
def set_browse_chunk_max_length(self, value: int) -> None: def set_browse_chunk_max_length(self, value: int) -> None:
"""Set the browse_website command chunk max length value.""" """Set the browse_website command chunk max length value."""
self.browse_chunk_max_length = value self.browse_chunk_max_length = value

View File

@@ -11,6 +11,7 @@ from autogpt.llm.base import (
from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context
from autogpt.llm.llm_utils import ( from autogpt.llm.llm_utils import (
call_ai_function, call_ai_function,
chunked_tokens,
create_chat_completion, create_chat_completion,
get_ada_embedding, get_ada_embedding,
) )
@@ -32,6 +33,7 @@ __all__ = [
"call_ai_function", "call_ai_function",
"create_chat_completion", "create_chat_completion",
"get_ada_embedding", "get_ada_embedding",
"chunked_tokens",
"COSTS", "COSTS",
"count_message_tokens", "count_message_tokens",
"count_string_tokens", "count_string_tokens",

View File

@@ -2,9 +2,12 @@ from __future__ import annotations
import functools import functools
import time import time
from itertools import islice
from typing import List, Optional from typing import List, Optional
import numpy as np
import openai import openai
import tiktoken
from colorama import Fore, Style from colorama import Fore, Style
from openai.error import APIError, RateLimitError, Timeout from openai.error import APIError, RateLimitError, Timeout
@@ -207,6 +210,23 @@ def create_chat_completion(
return resp return resp
def batched(iterable, n):
"""Batch data into tuples of length n. The last batch may be shorter."""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
def chunked_tokens(text, tokenizer_name, chunk_length):
tokenizer = tiktoken.get_encoding(tokenizer_name)
tokens = tokenizer.encode(text)
chunks_iterator = batched(tokens, chunk_length)
yield from chunks_iterator
def get_ada_embedding(text: str) -> List[float]: def get_ada_embedding(text: str) -> List[float]:
"""Get an embedding from the ada model. """Get an embedding from the ada model.
@@ -217,7 +237,7 @@ def get_ada_embedding(text: str) -> List[float]:
List[float]: The embedding. List[float]: The embedding.
""" """
cfg = Config() cfg = Config()
model = "text-embedding-ada-002" model = cfg.embedding_model
text = text.replace("\n", " ") text = text.replace("\n", " ")
if cfg.use_azure: if cfg.use_azure:
@@ -226,13 +246,7 @@ def get_ada_embedding(text: str) -> List[float]:
kwargs = {"model": model} kwargs = {"model": model}
embedding = create_embedding(text, **kwargs) embedding = create_embedding(text, **kwargs)
api_manager = ApiManager() return embedding
api_manager.update_cost(
prompt_tokens=embedding.usage.prompt_tokens,
completion_tokens=0,
model=model,
)
return embedding["data"][0]["embedding"]
@retry_openai_api() @retry_openai_api()
@@ -251,8 +265,31 @@ def create_embedding(
openai.Embedding: The embedding object. openai.Embedding: The embedding object.
""" """
cfg = Config() cfg = Config()
return openai.Embedding.create( chunk_embeddings = []
input=[text], chunk_lengths = []
api_key=cfg.openai_api_key, for chunk in chunked_tokens(
**kwargs, text,
) tokenizer_name=cfg.embedding_tokenizer,
chunk_length=cfg.embedding_token_limit,
):
embedding = openai.Embedding.create(
input=[chunk],
api_key=cfg.openai_api_key,
**kwargs,
)
api_manager = ApiManager()
api_manager.update_cost(
prompt_tokens=embedding.usage.prompt_tokens,
completion_tokens=0,
model=cfg.embedding_model,
)
chunk_embeddings.append(embedding["data"][0]["embedding"])
chunk_lengths.append(len(chunk))
# do weighted avg
chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)
chunk_embeddings = chunk_embeddings / np.linalg.norm(
chunk_embeddings
) # normalize the length to one
chunk_embeddings = chunk_embeddings.tolist()
return chunk_embeddings

View File

@@ -3,5 +3,8 @@ COSTS = {
"gpt-3.5-turbo-0301": {"prompt": 0.002, "completion": 0.002}, "gpt-3.5-turbo-0301": {"prompt": 0.002, "completion": 0.002},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06}, "gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4": {"prompt": 0.03, "completion": 0.06}, "gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
} }

View File

@@ -0,0 +1,168 @@
interactions:
- request:
body: '{"input": [[1985]], "model": "text-embedding-ada-002", "encoding_format":
"base64"}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '83'
Content-Type:
- application/json
method: POST
uri: https://api.openai.com/v1/embeddings
response:
body:
string: !!binary |
H4sIAAAAAAAAA1SaSxO6Orvl5++n2LWn9FsiIgl7xl3kkiAgYldXlyAiKHJNgJw6371L/6dOd08c
QAqV5HnWWr/kP/71119/t1ld5NPf//z197sap7//x/fa/Tbd/v7nr//5r7/++uuv//h9/n8jiyYr
7vfqU/6G/25Wn3ux/P3PX/x/X/m/g/756++DcanIKGZXd/X5SZTuBpdQW6+lYQqkfQ5NFN8pSgIw
zBGFBLyF+0x9ZATRzFVcC+xx3aJPfw3cJaBFA2RnibBlW4EuBNKhlLJ5dbH/QRJYDniFELe0Ryw1
zy4l2M8lVX4IVLvccrDkD26WikG50PuxCQFdnkwEiSTcsFJvgcv2m6mB50DfUpu5FLAPV1ZQMZ8b
IvW7azZ//KsHo/R9nYA/OPXa3M0WXkbEI3Dus2z9lNsEjEngUv+4PdWrN6EKdpdVoMba1vqqHZkH
b+fNG4mbTQRmWGsFxIdzQM3Kfkbt1AUcBI3doc1xk9ZLJVoC3ClDSm3yBtl4AC8Bvk47CzunCbFl
aqkGQtGR0A74sT4HY8DDq8puaE3xHC1C9H7BD749CHUSFM03CxuwDd2YQD5eskks+B4idSLUZd1V
n+fOt2FzPT2pkptdNtu6EQD6Cg2sRFSr1/NdKvcmSu74EM8ioDerUyT9vHfIWElpvUwkz0F2c2e0
vXpNNvuXsQe81EdofRp6LfDVTYMNUUx6PtQEjL//y79kHiuu1mRzc3xWsDg6TwQepu6OOhkEae88
TKzIbykanVhZ5SOLI3zWZCmjogx6OCXDhAT9yuoulKYevMWdRl2+GOtu9W6a9CHZBxvBramHc8Up
MNsImCJ0jBgThDiA13A6U8UU5IxJ8keQGG8I2At5RecPahjD4yJBrGZgAl2aDgjGwQth47BVXR5l
eQyPQVnRMISPjMUvNgNJeRm4GC/BsMh0SsDUKRt6PB+LelHlIJCl9SYjYfewsjbe5Rb8MNvDONKd
gclhbEAMywgba6u762TdAkjVIkJrd3Pc9antBBgEgUIEXZPZdDzHL/hCTxEtOGkHtg5XBQiuvsUe
DR76a+oCCDe720rWtGncxT69JHi49BTbp/jo8vxe5mEjlx02733NVuNIb4BDao0PgNnDnOWLBoNH
YlB8ebb6ulyVUS5at8MW8c5DeVb2IdQv8RN7R9F02VEWRxBM2Yka77HVZzqLMbwdpBSJVyNzRymr
rY2zKwusPUQpIrNlFYC+AgOHmmCxbfbkS3mMeR47t0UB/ImJHvBP64C4GV/rebiFCKKw66h78mRA
twSKwOJ5DtvXng4ru62tlPsgx+63Hpep/Sgwm2cXH55jxNZzfL4BkkeYHPqNMVBkCSFkaL8isbq4
gPmml0AnmldsPYNbvYiZBwEcZhcHqV9my/bSWVDZRwOaUSyxWs4WT17k15keClcZmCRTHua8lmGn
2QTR9CDbEtabDUaSqLHoMsfUkgoo+WRNL5+oc6ruJqnc3qeeHIwRldNjBXb81aCHeE7ZOnKrAMs6
JeRZtn00qXIQwho5Cek+J21Y5/NJg547drRQwStaLreWg+M2irB7ZM96em4kDSiX65bsDjAeRv5o
FxDt+JEqUVmzWUkwhD23rthXls0wLRvTAWgnjGhV9/nQ77ZKLpUhxNTYmEs0LU5pyNKTK4gEBjmb
9N1DA/zhpmPNutz0JdZ8A3Kq4OOD8HnUtBma8M99fXOpaxLfHF56VUxFwtozfS10vYc5jhl671zD
7ayTOErIrC4EGB9xmPbVK4aBW2dorrdAXwN9HGExwJCikPu4zMx2BF6OmU6k0DCGuXWjXLKi/EVj
6cPYlD+EGSr704CPt/NnWBRijjBqDzLaXMIuYrvNkgIm5DYRWWXWs7XbOlC0e4/aBlcNzNn0PcyX
4YiA8UmHlRVVDBsOKvQU8Za+RmkqgcwZPIy2xgiWNNzOsBncC9Wt/D6Q7a0RwFsKrlhP44ixuK1z
mJ9QT93w2LqzFJocrO+tRA9G1NdLW18R7N2qIuLuRoalfTwrYMvjTOZvf1qorhBgHHuB6tK+cef9
dXFgBIUnqTt7x2b1/kkgi/oLavuDDKZdJdrSSqWeatz5zOa38EHAy7BPzYDYYAG9ToDs9yrhv/O7
3kQ7Br4WHLD71bdOeCgF5OhVRnJYTWDmzWCG7gXK2MxhmlH95s9gPrsaPvrPul6GduWhpZUdPp5a
N2PcoxThd/6pJrx1sN5EJYGnk8OhrdsLw7AUhgfdNssRF36WqOtJGUDH9I5Ym54LYx6gDjjipiH8
GzRgILEkwqnjJHqA2qGeuYfFwfO6a5Akth82c+RUwlZrZ3o+eHt3GRvZAiS7xPiwOw86I0WqASsq
XiiYHiqYL1AeAa34CTufxHKX61vUwIDHiR6bOM3owxV4MCXdhM3g4ej8RRs8OBlsR80VvrMlLdIW
3nfBgJXJ9tz3tXReEIKAYteXjtHylLwGevfXA59j+VMvt9vZhtKaythJKk0XnCFuwFefscdLJhtP
bPZkJd16JMolO2PGpChw674Q4QXqsDnzm1y69MuduvxWi3jyiRA8bz2Rekfxra/2W4JwSk0DR8G5
1ufr6Qkh+qQ+eX/1tputOYUtvzHJOnU2W80EImmTcBHa8Ks0LPYlTOC2O26o/yjs73qiyR9/5CVW
XM/LPJbgFFUB2QUiZatq6RUsJPVB5u/7FOIks+HhaliIPWSZjUXjJHB95C69U3DRp+zJV0BYbJ6c
0rAG7KffB6kosW/K7bCcmBVAk5Idtk/1h43h6CuSu79cUNTZF0DuYPSAWRkOmWt5dafuPqxwDdID
RuVRdZmcuCvciMlEXSW9slV1Zw74/rPCSF6O7uortgAJ6Z9oEndrNKLXksPMs2qqHhRNHxkALyhe
vC1GZr4b5lrqCYyO1pZaqcTrHXgVAbwcNiN267qsV2z2OZzsV0ldheJh1nrowBRr7/96fxetRvBU
azySmk2Qje1JEeHWbRD1YL9j1FPzFiaf2id9EunZ+pATGy59ZKNdIGLwez5wwnOGjbJ86Gy6lTG0
3vHtT723vJnOcCPGE3a//WnZVbMDY+7IYeWtGtkCWMtDn/VP9OaAoDMuSBLAKamAtdi26+UdVCv0
nxHGB0/P63arZ8p+P8aAjKfdFgxTIvYSSooQqy0NdRbWqgDezcajWno5RLO5X0W407WeOob7yGbX
OzcSnYH01e8SzHSeEyjPUkidgdcH5sqjBEs+7mgx4339HtpVgKfH1FBNSEyXP1qAg+vmiqj9XHYZ
M7o0gWP4qAhI1MWdw9HUYErUiXqqvrB5TIpv/QoT1qT7yBbnagRAfUgmthscg4UF6igp/aTR42ab
DvOYezYYNchh/bW/6qPJm7lUJCeGkjXjs7EWmAIN1/Cxsk+ygcwukCCIA4aNoHtHdMCnAOyPTwfj
G8uHpRh8Afh+XSGiOfXPn/HAj5CGBKa2jKJE4aExFh+qtJH4/X1eCEP9HpE9pYrO25u4grz+vmHn
5o4ZScclgGVVMNTZgLqtsWYljPJUo97xibOlPro3aXEDhUbcbR+tGZgVyNXeh6r5rEbjKh17OJ+P
GrX7gg0lW4sc7PjMoLb9GdksWoIEHTUryV5528OuhwUH6xzX1ASdWn/rPQbVkLwRj/kFLP6wjlB7
4hrJqDfA1m1OhmQJwZlG5XrK/ughamObmvrwZrPjDC0wN35GcQsUIDiLpsHwfD5SDfcO4B2nbqF0
0mqql+sSMRSebHghLwsHXVC79OfHp1ueYPu+J1EHl9oCxFsHqj9Tk9EPM3t4etAGH+73rT7LzeqA
0UjP3/mANT2MPJKm4XbCasRb7uI3fgjuZ/FEHfG0uuNmjGYQcy6HXpL6yKbaPdrwlzd1IWuHpZxp
DrFcINJ882XHrLEBh0tLv/mnGlbFPoYgDcQc65L7za9bQ4GOESJqQU8DW+tc3sCoPj/Y3Yk70FWH
aw4DeTeTfWVe69mSqwamD67E5tfv7OxeKCC+kopsK/uZLbOcIVDySUd21/eoz3TQJCg81itWlxxF
8x4VIQzoWiNxF16GOS9bAukzS4i8v1tgvVw3CkRvTsAq92n1dc3nVN5fNx15g04dhL0g3eDY2io1
Y/kzrKNRxrK6hCV2HU3Jfs8D4oGcsKY9Fn1BtL8BdFYAdeAxc8k3/8GnlgmoV/rRnRtzEX71RK83
q/35j1jKK/tNfdS/wHxvFQu8zWYlUXDW9bksHg7sbo6N3SRt2PrJsARffL9DrDcrdyBp38DPfCoQ
f38r2Swmai5/x1Nlu22jKaDFC7KovRDh5FTunBp3WwqCUKEqunLR3KVlKC+39YX4aUjBKD+NEX7X
P1r22UOfYzbeoDffTtS5uV7GErZW0h+9KQJDX0JBhkDfLQo1N09nmL/5Bhiu5VPjmqjD7tUYLSik
EpA6ZsqwO2VVDlKiT+h1KIaoS4ugBd/5or/1O3qBGcD0yV/oY5tifW0rGoKNadyRYFuBy0javwD4
NMo3f94zFtZH/ud/aELW/pvXRgluwfP0x68vh2tg/PG/1tpHOvXUuP/lETKpQslI5AkatGqgI5k/
3cGiawuE9VPSyfztZ+xZ1go8v8cYeyipavbzb/GjD6minXuXFLrbg7g01m8+7Fxat9oN4g+j1AIX
nH3rrwVj66gYJY9b1L8uLw1WtpqSLd6fM4YsLtj7fJNjJMIzWPL3ywMavozUupb20AfSoYILtDn6
84vELv0S5kpIkbjrpuwPfwmGjYQku8uBcFtvEiSkfRIQHludJa8PhEuyAei1trU708ER4SSHB3zU
W06fV9Ks8BxeD1SP0rSe57UP4VooPFbf1gHM6LUUcvp4GGT51gfbA1mE9f1YYLuRlWi9kgGB4C5O
9GAvtvu5yjoPz5pYY4zugst2+tOBumSH1LxHyJ0jyo9wauua7O3XNVtMO2tg7+V3IontAcxD/Jqh
ubwIvbc6ijrTqCXotpJIFc5FbHYvcw4vnNFj5z5u6mWWIw8qkt/Qg3cLszXnPR7IhVVS1XluMxI/
xBBAKeapW5ApY5301mBNDlck2h8PrMn7JIJFbs7UiJ8kG03B9cDxWdzx8aPVbOVUmUBouguxfbPP
mKjWNxhGYEWQj0/ZrJOa/62fnz7U6zPUUjgEbfdnfuf8VDgwy94HrMe+566L/+bhKPgIl7zI6U8h
mhrw3oHDL78z9s3nIqGvHDvGEjG25ftY+vYDrGsvvp6PKioBMFaMxML29UXfXRTw43nuLaBgvb6R
Bb/8B9FI74f5+RbKXx6jaJtVQ0f3Q7k/12cNB3nc6lRO1QoWAxeixi1BtPiNGUL+teGpbQvD0Pey
IoLp824JTO59NpNJ6SFVFEjawWh1VlbtDF/K3qUG94CMvqdRA6p7fP78l7twyyPYJ3ZlYr+ePjW5
vi0L1iVnozJc02hOHa2Rv36ILOUprpcj/tjw+rY+2Dm/d4xOJ92Tf/xJfTAB0N96JpxrUttQ02ht
xgDJ19gZ8c//rBN7QMDrJkedOVSzMWzPMdyOyoT2xn2pxy9fg21cNtP+vHPcHnGuBiEkFlW2mDHm
OfsZdpYAsMKLhT5eT08OjALpCfzsNLaquO1hcZgcwhwJuIvbXA1gb5wDko4XO1vEArbAvXAyteFh
cNnpLGvQng2dPkprqic+WlJZGe0r2fNmw358FBRJxMii6Zcf38x/+QXndfvSf/4QvJ6XHmvRvhtY
4D0k6W7A5Ntv+2H83oe2m2Y0g9D55rVAg/0oQiKxswB6Im0SiWxMheL9Z2LdMPYV+OWH42mjusTM
ux7MZltgMy0k1ke58gLf+aFq7nhgTjiQ/3nf2mVZh+HKGgIDeTtj7CQkI6qSWFAUP5i8u+d9oD1M
uP03X9JDFdpRb5XXGG4dZ0fN4NG7bMcXJehasqBebD9gXiW1lW9pvKO6ewrAt7+uwB1Ch2rlNAHW
8bsGfip0Ii/8Zjq7bo4N+PJq+s1/YHs+vnK4+YjoD+9dvFeWAF0cAT7l12Egx3qL4ErFnhpxiIf9
GV9FuM5Cg60mr4cF7oIS6s3WJFs327G1GVMEI3nOqbbfpdGojXsR7urHEYlfXracuM4B37z8zW9P
0FOtWIF9D05fvnZnzO65HIIXDfHXr331UirBTn1H1OCEsP7pgXQqh4rI85Vn6+/9LE5ikr3/rIcv
b7Z+/AEf4aq7c+ebKQDGjIkgVLk+7697B2qCpJGBuqCeOHbWwEswEI6mZpuRn14GhG/JJtKd+udX
4I9H+8dmZfP97At/+MzmxuAw4c/Og/sxAdTyzoQxJeBz8PE39pdv7/W3U3UpLDUlxt4NfMD4q9d3
lxT4y1N0iryBQLe95tRNPu9h0TYOAvBS6/T40XS2rHYfSycGKbVhcmRbAroQLgN5YG+gA+sfFkhg
2PYJEoT9AkZwfcdgE79V6pq+r/M6fwqg3qEL9peDEs1uqCtw95JK8tRufs1AUNnAOJxe1Do5mr47
inog/fR9zmPb/fILBRKXn7G3T/WMbu61DS89uyMAETeQjU57sG4yRIZZoe46n68KtNj1g621Z+54
VK0S2l6c4JsalD+9QeCb12mmcChbrdsg/vIZxpxfRayMCg++L0WLoDFCl11lXQDvs5lQnz/J7Mej
gZvUzc+/RuyRqSncZ1eFfv25y3/rWT6cT3eMkiCrd+mIE7BjoMX6+/OqWXWTX/CrF4QT4ZbNh+3m
9ocXWf1QZXMfbC3ocXBPvTPl2Ki3XAo43q7oob/OP55syGopqPSX10lTAQPoUS1SWynf7penoD9+
6dSdrXoXk1SERg71n37UfSY4DfzuT2C9I162Oz1CBxo599O/dlhF5hB4E5cEF8MCAG2WbIR5sb38
/HJEj2pbwN4r7ljV9B2YuQeCkmJtEsRzjxysxZ4UkJ6klSoXdacT64RmcNw8O3pQjIFR7TlDuM2s
FStqumEL3KUV7EygIMsbXwPtD3tJEh6BjP0xnPSfvsBbAkpSfCIuY0HoJ0AXLh2av3x6Fpw1he5B
1bC2zPMwvgXqwaOPY4pbfxvN8f4kyE8zeFFzIyqR8Ov/H1+2qdL2ZOgVWw3+jDfubyXaYWfiwIdc
P9SrpHR4PbeqAdqoGagenHWXLP7E//whPpyTKWNfvyh/88mf/YWV2h4C1s4LsDMgqtP4YSXQo/mZ
BhMph13Udx7cbZMNgS/ryabuo4jwiFyMcXRTsuWbp6BUPg/kQ1EdLQFNGrgxrTs2pE8E+tsaSvAy
te6Pp4AJLoMB9XiDsdtMhb5uu30KN9GxR2y6JMP84rcB1IVzh7oqx9my9ZAnhfvXif78ycKUjQPS
zb7+7S8w8kQ3AkVJawjvbIp6XuUlgbzURmQP19plV9aMsBssHZtA93Xy88O/vK+ViaiP0r0U5K8f
psahcKNZcKQbbFwUY6x2oT6/KBbAb/9HxyWuZ+mZVjByKo/65oV3R/zwNPjjZUyITsMo908e3LVD
S/g4pDWTkOj94Y+XdHOpZ16wbLgRhSOST/CarSMnCWDevp/426/An/xvc+1Av+uVzXIj2dI+Fy5Y
iY7IpZddj8D+qhK0+er7aqR7CAso+kRsmr6epQu7yTW2Ttj/wCaaP5z1gostEqpc0hHQIuo18OXZ
ZJGO/cDsXsjB8T4w0hqqGHV0k3I//4rV6D5n9P1CEjjs9SO27mrGGA6iFGZw80Q722U1WY6cLe2U
LsV2RdRhXvkcwfAyZjTcbnBGfzx4t26m335a9NvvBV8egY+k58A01GYKv/yY+rMj6N3LA9ZPn7AD
d80wbMWbAr/9AvuqoLBp3r1n+M1zROhj7+tPrQpu1BNP3nZaRUtyaEZ4xK+G/Pj1ck2mAv78uTek
gt5utrMo/fTLFHiSrVn/GeUfn4y/ej0XVdjCe5xLGD0kTWeKb8ywAsz4+jPObVM8a/CsSTVG/srp
0zBWJfyUq0Qkv/m4q+ynrz/1AGrPYaszxC8Q5I8EzdrZ0dmLditYrhcfW5X/YqOPDhW8XVf05XOb
aC3sxYITjD2qLM1xYCu1R2guDUFSv9tnP14n9/dbivFxkw5ffbaBwCqLKi/j5tIHbsr9ZVOpWHck
oLPrPnZk8zsT6BNx0dgdzyOA1tum+uodwWrdakn8+3cq4D//9ddf/+t3wqBp78X7ezBgKpbp3/99
VODft/vt3zwv/JsKf04ikPFWFn//81+HEP7uhrbppv89ta/iM/79z1/bP6cN/p7a6fb+fy7/6/td
//mv/wMAAP//AwDOXgQl4SAAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 7c09bf823fb50b70-AMS
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Mon, 01 May 2023 17:29:41 GMT
Server:
- cloudflare
access-control-allow-origin:
- '*'
alt-svc:
- h3=":443"; ma=86400, h3-29=":443"; ma=86400
openai-organization:
- user-kd1j0bcill5flig1m29wdaof
openai-processing-ms:
- '69'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=15724800; includeSubDomains
x-ratelimit-limit-requests:
- '3000'
x-ratelimit-remaining-requests:
- '2999'
x-ratelimit-reset-requests:
- 20ms
x-request-id:
- 555d4ffdb6ceac9f62f60bb64d87170d
status:
code: 200
message: OK
version: 1

View File

@@ -1,9 +1,14 @@
import string import string
from unittest.mock import MagicMock
import pytest import pytest
from numpy.random import RandomState from numpy.random import RandomState
from pytest_mock import MockerFixture
from autogpt.llm.llm_utils import get_ada_embedding from autogpt.config import Config
from autogpt.llm import llm_utils
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.modelsinfo import COSTS
from tests.utils import requires_api_key from tests.utils import requires_api_key
@@ -16,10 +21,42 @@ def random_large_string():
return "".join(random.choice(list(string.ascii_lowercase), size=n_characters)) return "".join(random.choice(list(string.ascii_lowercase), size=n_characters))
@pytest.mark.xfail(reason="We have no mechanism for embedding large strings.") @pytest.fixture()
def api_manager(mocker: MockerFixture):
api_manager = ApiManager()
mocker.patch.multiple(
api_manager,
total_prompt_tokens=0,
total_completion_tokens=0,
total_cost=0,
)
yield api_manager
@pytest.fixture()
def spy_create_embedding(mocker: MockerFixture):
return mocker.spy(llm_utils, "create_embedding")
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding(
config: Config, api_manager: ApiManager, spy_create_embedding: MagicMock
):
token_cost = COSTS[config.embedding_model]["prompt"]
llm_utils.get_ada_embedding("test")
spy_create_embedding.assert_called_once_with("test", model=config.embedding_model)
assert (prompt_tokens := api_manager.get_total_prompt_tokens()) == 1
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == (prompt_tokens * token_cost) / 1000
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY") @requires_api_key("OPENAI_API_KEY")
def test_get_ada_embedding_large_context(random_large_string): def test_get_ada_embedding_large_context(random_large_string):
# This test should be able to mock the openai call after we have a fix. We don't need # This test should be able to mock the openai call after we have a fix. We don't need
# to hit the API to test the logic of the function (so not using vcr). This is a quick # to hit the API to test the logic of the function (so not using vcr). This is a quick
# regression test to document the issue. # regression test to document the issue.
get_ada_embedding(random_large_string) llm_utils.get_ada_embedding(random_large_string)

View File

@@ -1,8 +1,7 @@
import pytest import pytest
from openai.error import APIError, RateLimitError from openai.error import APIError, RateLimitError
from autogpt.llm import COSTS, get_ada_embedding from autogpt.llm import llm_utils
from autogpt.llm.llm_utils import retry_openai_api
@pytest.fixture(params=[RateLimitError, APIError]) @pytest.fixture(params=[RateLimitError, APIError])
@@ -13,22 +12,12 @@ def error(request):
return request.param("Error") return request.param("Error")
@pytest.fixture
def mock_create_embedding(mocker):
mock_response = mocker.MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.__getitem__.side_effect = lambda key: [{"embedding": [0.1, 0.2, 0.3]}]
return mocker.patch(
"autogpt.llm.llm_utils.create_embedding", return_value=mock_response
)
def error_factory(error_instance, error_count, retry_count, warn_user=True): def error_factory(error_instance, error_count, retry_count, warn_user=True):
class RaisesError: class RaisesError:
def __init__(self): def __init__(self):
self.count = 0 self.count = 0
@retry_openai_api( @llm_utils.retry_openai_api(
num_retries=retry_count, backoff_base=0.001, warn_user=warn_user num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
) )
def __call__(self): def __call__(self):
@@ -41,7 +30,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
def test_retry_open_api_no_error(capsys): def test_retry_open_api_no_error(capsys):
@retry_openai_api() @llm_utils.retry_openai_api()
def f(): def f():
return 1 return 1
@@ -114,16 +103,31 @@ def test_retry_openapi_other_api_error(capsys):
assert output.out == "" assert output.out == ""
def test_get_ada_embedding(mock_create_embedding, api_manager): def test_chunked_tokens():
model = "text-embedding-ada-002" text = "Auto-GPT is an experimental open-source application showcasing the capabilities of the GPT-4 language model"
embedding = get_ada_embedding("test") expected_output = [
mock_create_embedding.assert_called_once_with( (
"test", model="text-embedding-ada-002" 13556,
) 12279,
2898,
assert embedding == [0.1, 0.2, 0.3] 374,
459,
cost = COSTS[model]["prompt"] 22772,
assert api_manager.get_total_prompt_tokens() == 5 1825,
assert api_manager.get_total_completion_tokens() == 0 31874,
assert api_manager.get_total_cost() == (5 * cost) / 1000 3851,
67908,
279,
17357,
315,
279,
480,
2898,
12,
19,
4221,
1646,
)
]
output = list(llm_utils.chunked_tokens(text, "cl100k_base", 8191))
assert output == expected_output