Fix model context overflow issue (#2542)

Co-authored-by: batyu <batyu@localhost>

Author: bszollosinagy
Committed: 2023-04-19 23:28:57 +02:00 (via GitHub)
Commit: fa91bc154c (parent: a5a9b5dbd8)

4 changed files with 68 additions and 20 deletions

diff --git a/.env.template b/.env.template

@@ -7,9 +7,6 @@
 # EXECUTE_LOCAL_COMMANDS=False
 # RESTRICT_TO_WORKSPACE=True
-## BROWSE_CHUNK_MAX_LENGTH - When browsing website, define the length of chunk stored in memory
-# BROWSE_CHUNK_MAX_LENGTH=8192
 ## USER_AGENT - Define the user-agent used by the requests library to browse website (string)
 # USER_AGENT="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"
@@ -152,6 +149,10 @@ OPENAI_API_KEY=your-openai-api-key
 ## Note: set this to either 'chrome', 'firefox', or 'safari' depending on your current browser
 # HEADLESS_BROWSER=True
 # USE_WEB_BROWSER=chrome
+## BROWSE_CHUNK_MAX_LENGTH - When browsing a website, the maximum length of each chunk to summarize (in tokens, excluding the response; 75% of FAST_TOKEN_LIMIT is usually wise)
+# BROWSE_CHUNK_MAX_LENGTH=3000
+## BROWSE_SPACY_LANGUAGE_MODEL - The spaCy model used to split sentences. Install additional languages via pip and set the model name here. Example for Chinese: python -m spacy download zh_core_web_sm
+# BROWSE_SPACY_LANGUAGE_MODEL=en_core_web_sm
 
 ### GOOGLE
 ## GOOGLE_API_KEY - Google API key (Example: my-google-api-key)
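
The Chinese example above generalizes to any spaCy model. Below is a minimal sketch (not part of the commit) of installing and smoke-testing an alternative sentence-splitting model before pointing BROWSE_SPACY_LANGUAGE_MODEL at it; spacy.cli.download is the programmatic equivalent of "python -m spacy download".

    # Sketch (not from the commit): install and smoke-test an alternative
    # sentence-splitting model. zh_core_web_sm is the Chinese model named
    # in the example above.
    import spacy
    from spacy.cli import download

    MODEL = "zh_core_web_sm"

    try:
        nlp = spacy.load(MODEL)
    except OSError:
        # Same effect as: python -m spacy download zh_core_web_sm
        # (a fresh interpreter may be needed before spacy.load sees the new package)
        download(MODEL)
        nlp = spacy.load(MODEL)

    nlp.add_pipe("sentencizer")  # the same component split_text relies on
    doc = nlp("这是第一句。这是第二句。")
    print([sent.text for sent in doc.sents])  # expect two sentences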

diff --git a/autogpt/config.py b/autogpt/config.py

@@ -31,7 +31,10 @@ class Config(metaclass=Singleton):
         self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
         self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
         self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 8000))
-        self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 8192))
+        self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000))
+        self.browse_spacy_language_model = os.getenv(
+            "BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm"
+        )
         self.openai_api_key = os.getenv("OPENAI_API_KEY")
         self.temperature = float(os.getenv("TEMPERATURE", "0"))
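
Both new settings follow the repo's existing os.getenv pattern, so they can be overridden from the environment or .env without code changes. An illustrative check (not from the commit):

    # Illustrative only: mirrors the getenv defaults added above.
    import os

    os.environ["BROWSE_CHUNK_MAX_LENGTH"] = "2000"

    browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000))
    browse_spacy_language_model = os.getenv("BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm")

    print(browse_chunk_max_length)       # 2000, taken from the environment
    print(browse_spacy_language_model)   # en_core_web_sm, the default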

diff --git a/autogpt/processing/text.py b/autogpt/processing/text.py

@@ -1,8 +1,10 @@
 """Text processing functions"""
 from typing import Dict, Generator, Optional
 
+import spacy
 from selenium.webdriver.remote.webdriver import WebDriver
 
+from autogpt import token_counter
 from autogpt.config import Config
 from autogpt.llm_utils import create_chat_completion
 from autogpt.memory import get_memory
@@ -11,7 +13,12 @@ CFG = Config()
 MEMORY = get_memory(CFG)
 
 
-def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]:
+def split_text(
+    text: str,
+    max_length: int = CFG.browse_chunk_max_length,
+    model: str = CFG.fast_llm_model,
+    question: str = "",
+) -> Generator[str, None, None]:
     """Split text into chunks of a maximum length
 
     Args:
@@ -24,21 +31,46 @@ def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]
     Raises:
         ValueError: If the text is longer than the maximum length
     """
-    paragraphs = text.split("\n")
-    current_length = 0
+    flattened_paragraphs = " ".join(text.split("\n"))
+    nlp = spacy.load(CFG.browse_spacy_language_model)
+    nlp.add_pipe("sentencizer")
+    doc = nlp(flattened_paragraphs)
+    sentences = [sent.text.strip() for sent in doc.sents]
+
     current_chunk = []
 
-    for paragraph in paragraphs:
-        if current_length + len(paragraph) + 1 <= max_length:
-            current_chunk.append(paragraph)
-            current_length += len(paragraph) + 1
+    for sentence in sentences:
+        message_with_additional_sentence = [
+            create_message(" ".join(current_chunk) + " " + sentence, question)
+        ]
+
+        expected_token_usage = (
+            token_usage_of_chunk(messages=message_with_additional_sentence, model=model)
+            + 1
+        )
+        if expected_token_usage <= max_length:
+            current_chunk.append(sentence)
         else:
-            yield "\n".join(current_chunk)
-            current_chunk = [paragraph]
-            current_length = len(paragraph) + 1
+            yield " ".join(current_chunk)
+            current_chunk = [sentence]
+            message_this_sentence_only = [
+                create_message(" ".join(current_chunk), question)
+            ]
+            expected_token_usage = (
+                token_usage_of_chunk(messages=message_this_sentence_only, model=model)
+                + 1
+            )
+            if expected_token_usage > max_length:
+                raise ValueError(
+                    f"Sentence in webpage is too long: {expected_token_usage} tokens."
+                )
 
     if current_chunk:
-        yield "\n".join(current_chunk)
+        yield " ".join(current_chunk)
+
+
+def token_usage_of_chunk(messages, model):
+    return token_counter.count_message_tokens(messages, model)
 
 
 def summarize_text(
@@ -58,11 +90,16 @@ def summarize_text(
     if not text:
         return "Error: No text to summarize"
 
+    model = CFG.fast_llm_model
     text_length = len(text)
     print(f"Text length: {text_length} characters")
 
     summaries = []
-    chunks = list(split_text(text, CFG.browse_chunk_max_length))
+    chunks = list(
+        split_text(
+            text, max_length=CFG.browse_chunk_max_length, model=model, question=question
+        ),
+    )
     scroll_ratio = 1 / len(chunks)
 
     for i, chunk in enumerate(chunks):
@@ -74,15 +111,20 @@ def summarize_text(
         MEMORY.add(memory_to_add)
 
-        print(f"Summarizing chunk {i + 1} / {len(chunks)}")
         messages = [create_message(chunk, question)]
+        tokens_for_chunk = token_counter.count_message_tokens(messages, model)
+        print(
+            f"Summarizing chunk {i + 1} / {len(chunks)} of length {len(chunk)} characters, or {tokens_for_chunk} tokens"
+        )
 
         summary = create_chat_completion(
-            model=CFG.fast_llm_model,
+            model=model,
             messages=messages,
         )
         summaries.append(summary)
-        print(f"Added chunk {i + 1} summary to memory")
+        print(
+            f"Added chunk {i + 1} summary to memory, of length {len(summary)} characters"
+        )
 
         memory_to_add = f"Source: {url}\n" f"Content summary part#{i + 1}: {summary}"
@@ -94,7 +136,7 @@ def summarize_text(
     messages = [create_message(combined_summary, question)]
 
     return create_chat_completion(
-        model=CFG.fast_llm_model,
+        model=model,
         messages=messages,
     )
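
The heart of the fix is packing sentences by token count instead of paragraphs by character count, so a chunk can no longer overflow the model context. A self-contained sketch of the same technique (not the commit's code; tiktoken's cl100k_base encoder stands in for autogpt.token_counter, and the per-chunk prompt overhead of create_message is ignored here):

    # Standalone sketch of token-budgeted sentence chunking.
    from typing import Generator

    import spacy
    import tiktoken

    ENC = tiktoken.get_encoding("cl100k_base")  # assumption: stand-in tokenizer


    def count_tokens(text: str) -> int:
        return len(ENC.encode(text))


    def split_by_tokens(text: str, max_tokens: int = 3000) -> Generator[str, None, None]:
        nlp = spacy.load("en_core_web_sm")  # the commit's default model
        nlp.add_pipe("sentencizer")
        sentences = [s.text.strip() for s in nlp(" ".join(text.split("\n"))).sents]

        chunk: list[str] = []
        for sentence in sentences:
            if count_tokens(" ".join(chunk + [sentence])) <= max_tokens:
                chunk.append(sentence)  # sentence still fits the token budget
            else:
                if chunk:
                    yield " ".join(chunk)  # emit the full chunk, start a new one
                if count_tokens(sentence) > max_tokens:
                    raise ValueError(f"Single sentence exceeds {max_tokens} tokens")
                chunk = [sentence]
        if chunk:
            yield " ".join(chunk)

Counting tokens directly, rather than characters, is what keeps each chunk under the model's context window regardless of language or tokenizer density.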

diff --git a/requirements.txt b/requirements.txt

@@ -20,6 +20,8 @@ webdriver-manager
 jsonschema
 tweepy
 click
+spacy>=3.0.0,<4.0.0
+en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.0/en_core_web_sm-3.4.0-py3-none-any.whl
 ##Dev
 coverage