mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2025-12-24 01:14:22 +01:00)
Fix the maximum context length issue by chunking (#3222)
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
@@ -49,6 +49,14 @@ OPENAI_API_KEY=your-openai-api-key
 # FAST_TOKEN_LIMIT=4000
 # SMART_TOKEN_LIMIT=8000
 
+### EMBEDDINGS
+## EMBEDDING_MODEL - Model to use for creating embeddings
+## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
+## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
+# EMBEDDING_MODEL=text-embedding-ada-002
+# EMBEDDING_TOKENIZER=cl100k_base
+# EMBEDDING_TOKEN_LIMIT=8191
+
 ################################################################################
 ### MEMORY
 ################################################################################
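For illustration only (hypothetical values, not part of the commit), a user wanting smaller embedding chunks could override the new settings in .env:

    EMBEDDING_TOKENIZER=cl100k_base
    EMBEDDING_TOKEN_LIMIT=4000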
@@ -35,6 +35,9 @@ class Config(metaclass=Singleton):
         self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
         self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
         self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 8000))
+        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002")
+        self.embedding_tokenizer = os.getenv("EMBEDDING_TOKENIZER", "cl100k_base")
+        self.embedding_token_limit = int(os.getenv("EMBEDDING_TOKEN_LIMIT", 8191))
         self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000))
         self.browse_spacy_language_model = os.getenv(
             "BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm"
@@ -216,6 +219,18 @@ class Config(metaclass=Singleton):
         """Set the smart token limit value."""
         self.smart_token_limit = value
 
+    def set_embedding_model(self, value: str) -> None:
+        """Set the model to use for creating embeddings."""
+        self.embedding_model = value
+
+    def set_embedding_tokenizer(self, value: str) -> None:
+        """Set the tokenizer to use when creating embeddings."""
+        self.embedding_tokenizer = value
+
+    def set_embedding_token_limit(self, value: int) -> None:
+        """Set the token limit for creating embeddings."""
+        self.embedding_token_limit = value
+
     def set_browse_chunk_max_length(self, value: int) -> None:
         """Set the browse_website command chunk max length value."""
         self.browse_chunk_max_length = value
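A minimal usage sketch (not part of the commit, assuming the autogpt package is importable) of the new Config fields and setters:

    from autogpt.config import Config

    cfg = Config()
    print(cfg.embedding_model)        # "text-embedding-ada-002" unless EMBEDDING_MODEL is set
    print(cfg.embedding_token_limit)  # 8191 unless EMBEDDING_TOKEN_LIMIT is set
    cfg.set_embedding_token_limit(4000)  # switch to smaller chunks at runtime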
@@ -11,6 +11,7 @@ from autogpt.llm.base import (
 from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context
 from autogpt.llm.llm_utils import (
     call_ai_function,
+    chunked_tokens,
     create_chat_completion,
     get_ada_embedding,
 )
@@ -32,6 +33,7 @@ __all__ = [
     "call_ai_function",
     "create_chat_completion",
     "get_ada_embedding",
+    "chunked_tokens",
     "COSTS",
     "count_message_tokens",
     "count_string_tokens",
@@ -2,9 +2,12 @@ from __future__ import annotations
 
 import functools
 import time
+from itertools import islice
 from typing import List, Optional
 
+import numpy as np
 import openai
+import tiktoken
 from colorama import Fore, Style
 from openai.error import APIError, RateLimitError, Timeout
 
@@ -207,6 +210,23 @@ def create_chat_completion(
     return resp
 
 
+def batched(iterable, n):
+    """Batch data into tuples of length n. The last batch may be shorter."""
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
+
+
+def chunked_tokens(text, tokenizer_name, chunk_length):
+    tokenizer = tiktoken.get_encoding(tokenizer_name)
+    tokens = tokenizer.encode(text)
+    chunks_iterator = batched(tokens, chunk_length)
+    yield from chunks_iterator
+
+
 def get_ada_embedding(text: str) -> List[float]:
     """Get an embedding from the ada model.
 
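A usage sketch (not part of the commit) of the new helper: chunked_tokens encodes the text with the named tiktoken encoding and yields tuples of at most chunk_length token ids, so an oversized input is split instead of overflowing the model's context:

    from autogpt.llm import chunked_tokens

    text = "lorem ipsum " * 10_000  # far beyond 8191 tokens
    chunks = list(chunked_tokens(text, "cl100k_base", 8191))
    assert all(len(chunk) <= 8191 for chunk in chunks)  # each chunk fits the limit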
@@ -217,7 +237,7 @@ def get_ada_embedding(text: str) -> List[float]:
         List[float]: The embedding.
     """
     cfg = Config()
-    model = "text-embedding-ada-002"
+    model = cfg.embedding_model
     text = text.replace("\n", " ")
 
     if cfg.use_azure:
@@ -226,13 +246,7 @@ def get_ada_embedding(text: str) -> List[float]:
         kwargs = {"model": model}
 
     embedding = create_embedding(text, **kwargs)
-    api_manager = ApiManager()
-    api_manager.update_cost(
-        prompt_tokens=embedding.usage.prompt_tokens,
-        completion_tokens=0,
-        model=model,
-    )
-    return embedding["data"][0]["embedding"]
+    return embedding
 
 
 @retry_openai_api()
@@ -251,8 +265,31 @@ def create_embedding(
         openai.Embedding: The embedding object.
     """
     cfg = Config()
-    return openai.Embedding.create(
-        input=[text],
-        api_key=cfg.openai_api_key,
-        **kwargs,
-    )
+    chunk_embeddings = []
+    chunk_lengths = []
+    for chunk in chunked_tokens(
+        text,
+        tokenizer_name=cfg.embedding_tokenizer,
+        chunk_length=cfg.embedding_token_limit,
+    ):
+        embedding = openai.Embedding.create(
+            input=[chunk],
+            api_key=cfg.openai_api_key,
+            **kwargs,
+        )
+        api_manager = ApiManager()
+        api_manager.update_cost(
+            prompt_tokens=embedding.usage.prompt_tokens,
+            completion_tokens=0,
+            model=cfg.embedding_model,
+        )
+        chunk_embeddings.append(embedding["data"][0]["embedding"])
+        chunk_lengths.append(len(chunk))
+
+    # do weighted avg
+    chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)
+    chunk_embeddings = chunk_embeddings / np.linalg.norm(
+        chunk_embeddings
+    )  # normalize the length to one
+    chunk_embeddings = chunk_embeddings.tolist()
+    return chunk_embeddings
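A small numeric sketch (not part of the commit) of the combination step above: per-chunk embeddings are averaged with chunk lengths as weights, then scaled to unit length, mirroring the np.average and np.linalg.norm calls in the diff:

    import numpy as np

    chunk_embeddings = [[1.0, 0.0], [0.0, 1.0]]  # toy embeddings for two chunks
    chunk_lengths = [8000, 100]                  # the long chunk dominates the average

    avg = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)
    unit = avg / np.linalg.norm(avg)  # normalize the length to one
    print(unit.tolist())              # ~[0.99992, 0.01250]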
@@ -3,5 +3,8 @@ COSTS = {
     "gpt-3.5-turbo-0301": {"prompt": 0.002, "completion": 0.002},
     "gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
     "gpt-4": {"prompt": 0.03, "completion": 0.06},
+    "gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
+    "gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
+    "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
 }
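For context (plain arithmetic, not part of the commit): COSTS prices are dollars per 1000 tokens, the same conversion the tests below apply:

    COSTS = {"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}}

    prompt_tokens = 8191  # one full-size embedding chunk
    cost = prompt_tokens * COSTS["text-embedding-ada-002"]["prompt"] / 1000
    print(f"${cost:.6f}")  # $0.003276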
@@ -0,0 +1,168 @@
+interactions:
+- request:
+    body: '{"input": [[1985]], "model": "text-embedding-ada-002", "encoding_format":
+      "base64"}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '83'
+      Content-Type:
+      - application/json
+    method: POST
+    uri: https://api.openai.com/v1/embeddings
+  response:
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAA1SaSxO6Or/l5++n2LWn9FsiIgl7xl3kkiAgYldXlyAiKHJNgJw6371L/6dOd08c
+        QAqV5HnXv2qcbfj5QRL2rYZ2dXUJIoIi1wTIqfP9y7/6dOd08cQAqV5HnXv2qcbfj5QRL2rYZ2dX
+        4q8fO14ndpdVoMba1vqqHZkHb+fNG4mbTQRmWGsFxIdzQM3Kfkbt1AUcBI3doc1xk9ZLJVoC3ClD
+        Sm3yBtl4AC8Bvk47CzunCbFlaqkGQtGR0A74sT4HY8DDq8puaE3xHC1C9H7BD749CHUSFM03Cxuw
+        Dd2YQD5eskks+B4idSLUZd1Vn+fOt2FzPT2pkptdNtu6EQD6Cg2sRFSr1/NdKvcmSu74EM8ioDer
+        UyT9vHfIWElpvUwkz0F2c2e0vXpNNvuXsQe81EdofRp6LfDVTYMNUUx6PtQEjL//y79kHiuu1mRz
+        c3xWsDg6TwQepu6OOhkEae88TKzIbykanVhZ5SOLI3zWZCmjogx6OCXDhAT9yuoulKYevMWdRl2+
+        GOtu9W6a9CHZBxvBramHc8UpMNsImCJ0jBgThDiA13A6U8UU5IxJ8keQGG8I2At5RecPahjD4yJB
+        rGZgAl2aDgjGwQth47BVXR5leQyPQVnRMISPjMUvNgNJeRm4GC/BsMh0SsDUKRt6PB+LelHlIJCl
+        9SYjYfewsjbe5Rb8MNvDONKdgclhbEAMywgba6u762TdAkjVIkJrd3Pc9antBBgEgUIEXZPZdDzH
+        L/hCTxEtOGkHtg5XBQiuvsUeDR76a+oCCDe720rWtGncxT69JHi49BTbp/jo8vxe5mEjlx02733N
+        VuNIb4BDao0PgNnDnOWLBoNHYlB8ebb6ulyVUS5at8MW8c5DeVb2IdQv8RN7R9F02VEWRxBM2Yka
+        77HVZzqLMbwdpBSJVyNzRymrrY2zKwusPUQpIrNlFYC+AgOHmmCxbfbkS3mMeR47t0UB/ImJHvBP
+        64C4GV/rebiFCKKw66h78mRAtwSKwOJ5DtvXng4ru62tlPsgx+63Hpep/Sgwm2cXH55jxNZzfL4B
+        kkeYHPqNMVBkCSFkaL8isbq4gPmml0AnmldsPYNbvYiZBwEcZhcHqV9my/bSWVDZRwOaUSyxWs4W
+        T17k15keClcZmCRTHua8lmGn2QTR9CDbEtabDUaSqLHoMsfUkgoo+WRNL5+oc6ruJqnc3qeeHIwR
+        ldNjBXb81aCHeE7ZOnKrAMs6JeRZtn00qXIQwho5Cek+J21Y5/NJg547drRQwStaLreWg+M2irB7
+        ZM96em4kDSiX65bsDjAeRv5oFxDt+JEqUVmzWUkwhD23rthXls0wLRvTAWgnjGhV9/nQ77ZKLpUh
+        xNTYmEs0LU5pyNKTK4gEBjmb9N1DA/zhpmPNutz0JdZ8A3Kq4OOD8HnUtBma8M99fXOpaxLfHF56
+        VUxFwtozfS10vYc5jhl671zD7ayTOErIrC4EGB9xmPbVK4aBW2dorrdAXwN9HGExwJCikPu4zMx2
+        BF6OmU6k0DCGuXWjXLKi/EVj6cPYlD+EGSr704CPt/NnWBRijjBqDzLaXMIuYrvNkgIm5DYRWWXW
+        s7XbOlC0e4/aBlcNzNn0PcyX4YiA8UmHlRVVDBsOKvQU8Za+RmkqgcwZPIy2xgiWNNzOsBncC9Wt
+        /D6Q7a0RwFsKrlhP44ixuK1zmJ9QT93w2LqzFJocrO+tRA9G1NdLW18R7N2qIuLuRoalfTwrYMvj
+        TOZvf1qorhBgHHuB6tK+cef9dXFgBIUnqTt7x2b1/kkgi/oLavuDDKZdJdrSSqWeatz5zOa38EHA
+        y7BPzYDYYAG9ToDs9yrhv/O73kQ7Br4WHLD71bdOeCgF5OhVRnJYTWDmzWCG7gXK2MxhmlH95s9g
+        PrsaPvrPul6GduWhpZUdPp5aN2PcoxThd/6pJrx1sN5EJYGnk8OhrdsLw7AUhgfdNssRF36WqOtJ
+        GUDH9I5Ym54LYx6gDjjipiH8GzRgILEkwqnjJHqA2qGeuYfFwfO6a5Akth82c+RUwlZrZ3o+eHt3
+        GRvZAiS7xPiwOw86I0WqASsqXiiYHiqYL1AeAa34CTufxHKX61vUwIDHiR6bOM3owxV4MCXdhM3g
+        4ej8RRs8OBlsR80VvrMlLdIW3nfBgJXJ9tz3tXReEIKAYteXjtHylLwGevfXA59j+VMvt9vZhtKa
+        ythJKk0XnCFuwFefscdLJhtPbPZkJd16JMolO2PGpChw674Q4QXqsDnzm1y69MuduvxWi3jyiRA8
+        bz2Rekfxra/2W4JwSk0DR8G51ufr6Qkh+qQ+eX/1tputOYUtvzHJOnU2W80EImmTcBHa8Ks0LPYl
+        TOC2O26o/yjs73qiyR9/5CVWXM/LPJbgFFUB2QUiZatq6RUsJPVB5u/7FOIks+HhaliIPWSZjUXj
+        JHB95C69U3DRp+zJV0BYbJ6c0rAG7KffB6kosW/K7bCcmBVAk5Idtk/1h43h6CuSu79cUNTZF0Du
+        YPSAWRkOmWt5dafuPqxwDdIDRuVRdZmcuCvciMlEXSW9slV1Zw74/rPCSF6O7uortgAJ6Z9oEndr
+        NKLXksPMs2qqHhRNHxkALyhevC1GZr4b5lrqCYyO1pZaqcTrHXgVAbwcNiN267qsV2z2OZzsV0ld
+        heJh1nrowBRr7/96fxetRvBUazySmk2Qje1JEeHWbRD1YL9j1FPzFiaf2id9EunZ+pATGy59ZKNd
+        IGLwez5wwnOGjbJ86Gy6lTG03vHtT723vJnOcCPGE3a//WnZVbMDY+7IYeWtGtkCWMtDn/VP9OaA
+        oDMuSBLAKamAtdi26+UdVCv0nxHGB0/P63arZ8p+P8aAjKfdFgxTIvYSSooQqy0NdRbWqgDezcaj
+        Wno5RLO5X0W407WeOob7yGbXOzcSnYH01e8SzHSeEyjPUkidgdcH5sqjBEs+7mgx4339HtpVgKfH
+        1FBNSEyXP1qAg+vmiqj9XHYZM7o0gWP4qAhI1MWdw9HUYErUiXqqvrB5TIpv/QoT1qT7yBbnagRA
+        fUgmthscg4UF6igp/aTR42abDvOYezYYNchh/bW/6qPJm7lUJCeGkjXjs7EWmAIN1/Cxsk+ygcwu
+        kCCIA4aNoHtHdMCnAOyPTwfjG8uHpRh8Afh+XSGiOfXPn/HAj5CGBKa2jKJE4aExFh+qtJH4/X1e
+        CEP9HpE9pYrO25u4grz+vmHn5o4ZScclgGVVMNTZgLqtsWYljPJUo97xibOlPro3aXEDhUbcbR+t
+        GZgVyNXeh6r5rEbjKh17OJ+PGrX7gg0lW4sc7PjMoLb9GdksWoIEHTUryV5528OuhwUH6xzX1ASd
+        Wn/rPQbVkLwRj/kFLP6wjlB74hrJqDfA1m1OhmQJwZlG5XrK/ughamObmvrwZrPjDC0wN35GcQsU
+        IDiLpsHwfD5SDfcO4B2nbqF00mqql+sSMRSebHghLwsHXVC79OfHp1ueYPu+J1EHl9oCxFsHqj9T
+        k9EPM3t4etAGH+73rT7LzeqA0UjP3/mANT2MPJKm4XbCasRb7uI3fgjuZ/FEHfG0uuNmjGYQcy6H
+        XpL6yKbaPdrwlzd1IWuHpZxpDrFcINJ882XHrLEBh0tLv/mnGlbFPoYgDcQc65L7za9bQ4GOESJq
+        QU8DW+tc3sCoPj/Y3Yk70FWHaw4DeTeTfWVe69mSqwamD67E5tfv7OxeKCC+kopsK/uZLbOcIVDy
+        SUd21/eoz3TQJCg81itWlxxF8x4VIQzoWiNxF16GOS9bAukzS4i8v1tgvVw3CkRvTsAq92n1dc3n
+        VN5fNx15g04dhL0g3eDY2io1Y/kzrKNRxrK6hCV2HU3Jfs8D4oGcsKY9Fn1BtL8BdFYAdeAxc8k3
+        /8GnlgmoV/rRnRtzEX71RK83q/35j1jKK/tNfdS/wHxvFQu8zWYlUXDW9bksHg7sbo6N3SRt2PrJ
+        sARffL9DrDcrdyBp38DPfCoQf38r2Swmai5/x1Nlu22jKaDFC7KovRDh5FTunBp3WwqCUKEqunLR
+        3KVlKC+39YX4aUjBKD+NEX7XP1r22UOfYzbeoDffTtS5uV7GErZW0h+9KQJDX0JBhkDfLQo1N09n
+        mL/5Bhiu5VPjmqjD7tUYLSikEpA6ZsqwO2VVDlKiT+h1KIaoS4ugBd/5or/1O3qBGcD0yV/oY5ti
+        fW0rGoKNadyRYFuBy0javwD4NMo3f94zFtZH/ud/aELW/pvXRgluwfP0x68vh2tg/PG/1tpHOvXU
+        uP/lETKpQslI5AkatGqgI5k/3cGiawuE9VPSyfztZ+xZ1go8v8cYeyipavbzb/GjD6minXuXFLrb
+        g7g01m8+7Fxat9oN4g+j1AIXnH3rrwVj66gYJY9b1L8uLw1WtpqSLd6fM4YsLtj7fJNjJMIzWPL3
+        ywMavozUupb20AfSoYILtDn684vELv0S5kpIkbjrpuwPfwmGjYQku8uBcFtvEiSkfRIQHludJa8P
+        hEuyAei1trU708ER4SSHB3zUW06fV9Ks8BxeD1SP0rSe57UP4VooPFbf1gHM6LUUcvp4GGT51gfb
+        A1mE9f1YYLuRlWi9kgGB4C5O9GAvtvu5yjoPz5pYY4zugst2+tOBumSH1LxHyJ0jyo9wauua7O3X
+        NVtMO2tg7+V3IontAcxD/JqhubwIvbc6ijrTqCXotpJIFc5FbHYvcw4vnNFj5z5u6mWWIw8qkt/Q
+        g3cLszXnPR7IhVVS1XluMxI/xBBAKeapW5ApY5301mBNDlck2h8PrMn7JIJFbs7UiJ8kG03B9cDx
+        Wdzx8aPVbOVUmUBouguxfbPPmKjWNxhGYEWQj0/ZrJOa/62fnz7U6zPUUjgEbfdnfuf8VDgwy94H
+        rMe+566L/+bhKPgIl7zI6U8hmhrw3oHDL78z9s3nIqGvHDvGEjG25ftY+vYDrGsvvp6PKioBMFaM
+        xML29UXfXRTw43nuLaBgvb6RBb/8B9FI74f5+RbKXx6jaJtVQ0f3Q7k/12cNB3nc6lRO1QoWAxei
+        xi1BtPiNGUL+teGpbQvD0PeyIoLp824JTO59NpNJ6SFVFEjawWh1VlbtDF/K3qUG94CMvqdRA6p7
+        fP78l7twyyPYJ3ZlYr+ePjW5vi0L1iVnozJc02hOHa2Rv36ILOUprpcj/tjw+rY+2Dm/d4xOJ92T
+        f/xJfTAB0N96JpxrUttQ02htxgDJ19gZ8c//rBN7QMDrJkedOVSzMWzPMdyOyoT2xn2pxy9fg21c
+        NtP+vHPcHnGuBiEkFlW2mDHmOfsZdpYAsMKLhT5eT08OjALpCfzsNLaquO1hcZgcwhwJuIvbXA1g
+        b5wDko4XO1vEArbAvXAyteFhcNnpLGvQng2dPkprqic+WlJZGe0r2fNmw358FBRJxMii6Zcf38x/
+        +QXndfvSf/4QvJ6XHmvRvhtY4D0k6W7A5Ntv+2H83oe2m2Y0g9D55rVAg/0oQiKxswB6Im0SiWxM
+        heL9Z2LdMPYV+OWH42mjusTMux7MZltgMy0k1ke58gLf+aFq7nhgTjiQ/3nf2mVZh+HKGgIDeTtj
+        7CQkI6qSWFAUP5i8u+d9oD1MuP03X9JDFdpRb5XXGG4dZ0fN4NG7bMcXJehasqBebD9gXiW1lW9p
+        vKO6ewrAt7+uwB1Ch2rlNAHW8bsGfip0Ii/8Zjq7bo4N+PJq+s1/YHs+vnK4+YjoD+9dvFeWAF0c
+        AT7l12Egx3qL4ErFnhpxiIf9GV9FuM5Cg60mr4cF7oIS6s3WJFs327G1GVMEI3nOqbbfpdGojXsR
+        7urHEYlfXracuM4B37z8zW9P0FOtWIF9D05fvnZnzO65HIIXDfHXr331UirBTn1H1OCEsP7pgXQq
+        h4rI85Vn6+/9LE5ikr3/rIcvb7Z+/AEf4aq7c+ebKQDGjIkgVLk+7697B2qCpJGBuqCeOHbWwEsw
+        EI6mZpuRn14GhG/JJtKd+udX4I9H+8dmZfP97At/+MzmxuAw4c/Og/sxAdTyzoQxJeBz8PE39pdv
+        7/W3U3UpLDUlxt4NfMD4q9d3lxT4y1N0iryBQLe95tRNPu9h0TYOAvBS6/T40XS2rHYfSycGKbVh
+        cmRbAroQLgN5YG+gA+sfFkhg2PYJEoT9AkZwfcdgE79V6pq+r/M6fwqg3qEL9peDEs1uqCtw95JK
+        8tRufs1AUNnAOJxe1Do5mr47inog/fR9zmPb/fILBRKXn7G3T/WMbu61DS89uyMAETeQjU57sG4y
+        RIZZoe46n68KtNj1g621Z+54VK0S2l6c4JsalD+9QeCb12mmcChbrdsg/vIZxpxfRayMCg++L0WL
+        oDFCl11lXQDvs5lQnz/J7MejgZvUzc+/RuyRqSncZ1eFfv25y3/rWT6cT3eMkiCrd+mIE7BjoMX6
+        +/OqWXWTX/CrF4QT4ZbNh+3m9ocXWf1QZXMfbC3ocXBPvTPl2Ki3XAo43q7oob/OP55syGopqPSX
+        10lTAQPoUS1SWynf7penoD9+6dSdrXoXk1SERg71n37UfSY4DfzuT2C9I162Oz1CBxo599O/dlhF
+        5hB4E5cEF8MCAG2WbIR5sb38/HJEj2pbwN4r7ljV9B2YuQeCkmJtEsRzjxysxZ4UkJ6klSoXdacT
+        64RmcNw8O3pQjIFR7TlDuM2sFStqumEL3KUV7EygIMsbXwPtD3tJEh6BjP0xnPSfvsBbAkpSfCIu
+        Y0HoJ0AXLh2av3x6Fpw1he5B1bC2zPMwvgXqwaOPY4pbfxvN8f4kyE8zeFFzIyqR8Ov/H1+2qdL2
+        ZOgVWw3+jDfubyXaYWfiwIdcP9SrpHR4PbeqAdqoGagenHWXLP7E//whPpyTKWNfvyh/88mf/YWV
+        2h4C1s4LsDMgqtP4YSXQo/mZBhMph13Udx7cbZMNgS/ryabuo4jwiFyMcXRTsuWbp6BUPg/kQ1Ed
+        LQFNGrgxrTs2pE8E+tsaSvAyte6Pp4AJLoMB9XiDsdtMhb5uu30KN9GxR2y6JMP84rcB1IVzh7oq
+        x9my9ZAnhfvXif78ycKUjQPSzb7+7S8w8kQ3AkVJawjvbIp6XuUlgbzURmQP19plV9aMsBssHZtA
+        93Xy88O/vK+ViaiP0r0U5K8fpsahcKNZcKQbbFwUY6x2oT6/KBbAb/9HxyWuZ+mZVjByKo/65oV3
+        R/zwNPjjZUyITsMo908e3LVDS/g4pDWTkOj94Y+XdHOpZ16wbLgRhSOST/CarSMnCWDevp/426/A
+        n/xvc+1Av+uVzXIj2dI+Fy5YiY7IpZddj8D+qhK0+er7aqR7CAso+kRsmr6epQu7yTW2Ttj/wCaa
+        P5z1gostEqpc0hHQIuo18OXZZJGO/cDsXsjB8T4w0hqqGHV0k3I//4rV6D5n9P1CEjjs9SO27mrG
+        GA6iFGZw80Q722U1WY6cLe2ULsV2RdRhXvkcwfAyZjTcbnBGfzx4t26m335a9NvvBV8egY+k58A0
+        1GYKv/yY+rMj6N3LA9ZPn7ADd80wbMWbAr/9AvuqoLBp3r1n+M1zROhj7+tPrQpu1BNP3nZaRUty
+        aEZ4xK+G/Pj1ck2mAv78uTekgt5utrMo/fTLFHiSrVn/GeUfn4y/ej0XVdjCe5xLGD0kTWeKb8yw
+        Asz4+jPObVM8a/CsSTVG/srp0zBWJfyUq0Qkv/m4q+ynrz/1AGrPYaszxC8Q5I8EzdrZ0dmLditY
+        rhcfW5X/YqOPDhW8XVf05XObaC3sxYITjD2qLM1xYCu1R2guDUFSv9tnP14n9/dbivFxkw5ffbaB
+        wCqLKi/j5tIHbsr9ZVOpWLckoLPrPnZk8zsT6BNx0dgdzyOA1tum+uodwWrdakn8+3cq4D//9ddf
+        /+t3wqBp78X7ezBgKpbp3/99VODft/vt3zwv/JsKf04ikPFWFn//81+HEP7uhrbppv89ta/iM/79
+        z1/bP6cN/p7a6fb+fy7/6/td//mv/wMAAP//AwDOXgQl4SAAAA==
+    headers:
+      CF-Cache-Status:
+      - DYNAMIC
+      CF-RAY:
+      - 7c09bf823fb50b70-AMS
+      Connection:
+      - keep-alive
+      Content-Encoding:
+      - gzip
+      Content-Type:
+      - application/json
+      Date:
+      - Mon, 01 May 2023 17:29:41 GMT
+      Server:
+      - cloudflare
+      access-control-allow-origin:
+      - '*'
+      alt-svc:
+      - h3=":443"; ma=86400, h3-29=":443"; ma=86400
+      openai-organization:
+      - user-kd1j0bcill5flig1m29wdaof
+      openai-processing-ms:
+      - '69'
+      openai-version:
+      - '2020-10-01'
+      strict-transport-security:
+      - max-age=15724800; includeSubDomains
+      x-ratelimit-limit-requests:
+      - '3000'
+      x-ratelimit-remaining-requests:
+      - '2999'
+      x-ratelimit-reset-requests:
+      - 20ms
+      x-request-id:
+      - 555d4ffdb6ceac9f62f60bb64d87170d
+    status:
+      code: 200
+      message: OK
+version: 1
File diff suppressed because it is too large
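For context (a hypothetical test, not from this commit): cassettes like the one above are recorded once against the live API and replayed by pytest-vcr, so @pytest.mark.vcr tests run offline; the recorded request body "input": [[1985]] is the chunked cl100k_base encoding of the string "test":

    import pytest

    from autogpt.llm import get_ada_embedding

    @pytest.mark.vcr  # replay the recorded HTTP exchange instead of calling OpenAI
    def test_embedding_replayed_from_cassette():
        embedding = get_ada_embedding("test")
        assert len(embedding) > 0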
@@ -1,9 +1,14 @@
 import string
+from unittest.mock import MagicMock
 
 import pytest
 from numpy.random import RandomState
+from pytest_mock import MockerFixture
 
-from autogpt.llm.llm_utils import get_ada_embedding
+from autogpt.config import Config
+from autogpt.llm import llm_utils
+from autogpt.llm.api_manager import ApiManager
+from autogpt.llm.modelsinfo import COSTS
 from tests.utils import requires_api_key
 
 
@@ -16,10 +21,42 @@ def random_large_string():
     return "".join(random.choice(list(string.ascii_lowercase), size=n_characters))
 
 
-@pytest.mark.xfail(reason="We have no mechanism for embedding large strings.")
+@pytest.fixture()
+def api_manager(mocker: MockerFixture):
+    api_manager = ApiManager()
+    mocker.patch.multiple(
+        api_manager,
+        total_prompt_tokens=0,
+        total_completion_tokens=0,
+        total_cost=0,
+    )
+    yield api_manager
+
+
+@pytest.fixture()
+def spy_create_embedding(mocker: MockerFixture):
+    return mocker.spy(llm_utils, "create_embedding")
+
+
+@pytest.mark.vcr
+@requires_api_key("OPENAI_API_KEY")
+def test_get_ada_embedding(
+    config: Config, api_manager: ApiManager, spy_create_embedding: MagicMock
+):
+    token_cost = COSTS[config.embedding_model]["prompt"]
+    llm_utils.get_ada_embedding("test")
+
+    spy_create_embedding.assert_called_once_with("test", model=config.embedding_model)
+
+    assert (prompt_tokens := api_manager.get_total_prompt_tokens()) == 1
+    assert api_manager.get_total_completion_tokens() == 0
+    assert api_manager.get_total_cost() == (prompt_tokens * token_cost) / 1000
+
+
+@pytest.mark.vcr
 @requires_api_key("OPENAI_API_KEY")
 def test_get_ada_embedding_large_context(random_large_string):
     # This test should be able to mock the openai call after we have a fix. We don't need
     # to hit the API to test the logic of the function (so not using vcr). This is a quick
     # regression test to document the issue.
-    get_ada_embedding(random_large_string)
+    llm_utils.get_ada_embedding(random_large_string)
@@ -1,8 +1,7 @@
 import pytest
 from openai.error import APIError, RateLimitError
 
-from autogpt.llm import COSTS, get_ada_embedding
-from autogpt.llm.llm_utils import retry_openai_api
+from autogpt.llm import llm_utils
 
 
 @pytest.fixture(params=[RateLimitError, APIError])
@@ -13,22 +12,12 @@ def error(request):
     return request.param("Error")
 
 
-@pytest.fixture
-def mock_create_embedding(mocker):
-    mock_response = mocker.MagicMock()
-    mock_response.usage.prompt_tokens = 5
-    mock_response.__getitem__.side_effect = lambda key: [{"embedding": [0.1, 0.2, 0.3]}]
-    return mocker.patch(
-        "autogpt.llm.llm_utils.create_embedding", return_value=mock_response
-    )
-
-
 def error_factory(error_instance, error_count, retry_count, warn_user=True):
     class RaisesError:
         def __init__(self):
             self.count = 0
 
-        @retry_openai_api(
+        @llm_utils.retry_openai_api(
             num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
         )
         def __call__(self):
@@ -41,7 +30,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
 
 
 def test_retry_open_api_no_error(capsys):
-    @retry_openai_api()
+    @llm_utils.retry_openai_api()
     def f():
         return 1
 
@@ -114,16 +103,31 @@ def test_retry_openapi_other_api_error(capsys):
     assert output.out == ""
 
 
-def test_get_ada_embedding(mock_create_embedding, api_manager):
-    model = "text-embedding-ada-002"
-    embedding = get_ada_embedding("test")
-    mock_create_embedding.assert_called_once_with(
-        "test", model="text-embedding-ada-002"
-    )
-
-    assert embedding == [0.1, 0.2, 0.3]
-
-    cost = COSTS[model]["prompt"]
-    assert api_manager.get_total_prompt_tokens() == 5
-    assert api_manager.get_total_completion_tokens() == 0
-    assert api_manager.get_total_cost() == (5 * cost) / 1000
+def test_chunked_tokens():
+    text = "Auto-GPT is an experimental open-source application showcasing the capabilities of the GPT-4 language model"
+    expected_output = [
+        (
+            13556,
+            12279,
+            2898,
+            374,
+            459,
+            22772,
+            1825,
+            31874,
+            3851,
+            67908,
+            279,
+            17357,
+            315,
+            279,
+            480,
+            2898,
+            12,
+            19,
+            4221,
+            1646,
+        )
+    ]
+    output = list(llm_utils.chunked_tokens(text, "cl100k_base", 8191))
+    assert output == expected_output
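A quick cross-check (not part of the commit, assumes tiktoken is installed): the expected ids are just the cl100k_base encoding of the sentence, returned as a single chunk because 20 tokens is far below the 8191 limit:

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode(
        "Auto-GPT is an experimental open-source application showcasing "
        "the capabilities of the GPT-4 language model"
    )
    print(len(ids))  # 20 -> one chunk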