From fa91bc154c239f569346d72e81e5b11b24267413 Mon Sep 17 00:00:00 2001
From: bszollosinagy <4211175+bszollosinagy@users.noreply.github.com>
Date: Wed, 19 Apr 2023 23:28:57 +0200
Subject: [PATCH] Fix model context overflow issue (#2542)

Co-authored-by: batyu
---
 .env.template              |  7 ++--
 autogpt/config/config.py   |  5 ++-
 autogpt/processing/text.py | 74 +++++++++++++++++++++++++++++---------
 requirements.txt           |  2 ++
 4 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/.env.template b/.env.template
index 58486904..f1b511c2 100644
--- a/.env.template
+++ b/.env.template
@@ -7,9 +7,6 @@
 # EXECUTE_LOCAL_COMMANDS=False
 # RESTRICT_TO_WORKSPACE=True

-## BROWSE_CHUNK_MAX_LENGTH - When browsing website, define the length of chunk stored in memory
-# BROWSE_CHUNK_MAX_LENGTH=8192
-
 ## USER_AGENT - Define the user-agent used by the requests library to browse website (string)
 # USER_AGENT="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"

@@ -152,6 +149,10 @@ OPENAI_API_KEY=your-openai-api-key
 ## Note: set this to either 'chrome', 'firefox', or 'safari' depending on your current browser
 # HEADLESS_BROWSER=True
 # USE_WEB_BROWSER=chrome
+## BROWSE_CHUNK_MAX_LENGTH - When browsing a website, the maximum length of the chunks to summarize, in tokens, excluding the response (75% of FAST_TOKEN_LIMIT is usually a good value)
+# BROWSE_CHUNK_MAX_LENGTH=3000
+## BROWSE_SPACY_LANGUAGE_MODEL - spaCy model used to split sentences. Install additional language models via pip and set the model name here. Example for Chinese: python -m spacy download zh_core_web_sm
+# BROWSE_SPACY_LANGUAGE_MODEL=en_core_web_sm

 ### GOOGLE
 ## GOOGLE_API_KEY - Google API key (Example: my-google-api-key)
diff --git a/autogpt/config/config.py b/autogpt/config/config.py
index 0c4576da..c284a4ac 100644
--- a/autogpt/config/config.py
+++ b/autogpt/config/config.py
@@ -31,7 +31,10 @@ class Config(metaclass=Singleton):
         self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
         self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))
         self.smart_token_limit = int(os.getenv("SMART_TOKEN_LIMIT", 8000))
-        self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 8192))
+        self.browse_chunk_max_length = int(os.getenv("BROWSE_CHUNK_MAX_LENGTH", 3000))
+        self.browse_spacy_language_model = os.getenv(
+            "BROWSE_SPACY_LANGUAGE_MODEL", "en_core_web_sm"
+        )
         self.openai_api_key = os.getenv("OPENAI_API_KEY")
         self.temperature = float(os.getenv("TEMPERATURE", "0"))
diff --git a/autogpt/processing/text.py b/autogpt/processing/text.py
index 130de473..2122f0f0 100644
--- a/autogpt/processing/text.py
+++ b/autogpt/processing/text.py
@@ -1,8 +1,10 @@
 """Text processing functions"""
 from typing import Dict, Generator, Optional

+import spacy
 from selenium.webdriver.remote.webdriver import WebDriver

+from autogpt import token_counter
 from autogpt.config import Config
 from autogpt.llm_utils import create_chat_completion
 from autogpt.memory import get_memory
@@ -11,7 +13,12 @@ CFG = Config()
 MEMORY = get_memory(CFG)


-def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]:
+def split_text(
+    text: str,
+    max_length: int = CFG.browse_chunk_max_length,
+    model: str = CFG.fast_llm_model,
+    question: str = "",
+) -> Generator[str, None, None]:
     """Split text into chunks of a maximum length

     Args:
@@ -24,21 +31,46 @@ def split_text(text: str, max_length: int = 8192) -> Generator[str, None, None]:
     Raises:
         ValueError: If the text is longer than the maximum length
     """
-    paragraphs = text.split("\n")
-    current_length = 0
+    flattened_paragraphs = " ".join(text.split("\n"))
+    nlp = spacy.load(CFG.browse_spacy_language_model)
+    nlp.add_pipe("sentencizer")
+    doc = nlp(flattened_paragraphs)
+    sentences = [sent.text.strip() for sent in doc.sents]
+
     current_chunk = []

-    for paragraph in paragraphs:
-        if current_length + len(paragraph) + 1 <= max_length:
-            current_chunk.append(paragraph)
-            current_length += len(paragraph) + 1
+    for sentence in sentences:
+        message_with_additional_sentence = [
+            create_message(" ".join(current_chunk) + " " + sentence, question)
+        ]
+
+        expected_token_usage = (
+            token_usage_of_chunk(messages=message_with_additional_sentence, model=model)
+            + 1
+        )
+        if expected_token_usage <= max_length:
+            current_chunk.append(sentence)
         else:
-            yield "\n".join(current_chunk)
-            current_chunk = [paragraph]
-            current_length = len(paragraph) + 1
+            yield " ".join(current_chunk)
+            current_chunk = [sentence]
+            message_this_sentence_only = [
+                create_message(" ".join(current_chunk), question)
+            ]
+            expected_token_usage = (
+                token_usage_of_chunk(messages=message_this_sentence_only, model=model)
+                + 1
+            )
+            if expected_token_usage > max_length:
+                raise ValueError(
+                    f"Sentence is too long in webpage: {expected_token_usage} tokens."
+                )

     if current_chunk:
-        yield "\n".join(current_chunk)
+        yield " ".join(current_chunk)
+
+
+def token_usage_of_chunk(messages, model):
+    return token_counter.count_message_tokens(messages, model)


 def summarize_text(
@@ -58,11 +90,16 @@
     if not text:
         return "Error: No text to summarize"

+    model = CFG.fast_llm_model
     text_length = len(text)
     print(f"Text length: {text_length} characters")

     summaries = []
-    chunks = list(split_text(text, CFG.browse_chunk_max_length))
+    chunks = list(
+        split_text(
+            text, max_length=CFG.browse_chunk_max_length, model=model, question=question
+        ),
+    )
     scroll_ratio = 1 / len(chunks)

     for i, chunk in enumerate(chunks):
@@ -74,15 +111,20 @@
 
         MEMORY.add(memory_to_add)

-        print(f"Summarizing chunk {i + 1} / {len(chunks)}")
         messages = [create_message(chunk, question)]
+        tokens_for_chunk = token_counter.count_message_tokens(messages, model)
+        print(
+            f"Summarizing chunk {i + 1} / {len(chunks)} of length {len(chunk)} characters, or {tokens_for_chunk} tokens"
+        )

         summary = create_chat_completion(
-            model=CFG.fast_llm_model,
+            model=model,
             messages=messages,
         )
         summaries.append(summary)
-        print(f"Added chunk {i + 1} summary to memory")
+        print(
+            f"Added chunk {i + 1} summary to memory, of length {len(summary)} characters"
+        )

         memory_to_add = f"Source: {url}\n" f"Content summary part#{i + 1}: {summary}"

@@ -94,7 +136,7 @@
     messages = [create_message(combined_summary, question)]

     return create_chat_completion(
-        model=CFG.fast_llm_model,
+        model=model,
         messages=messages,
     )
diff --git a/requirements.txt b/requirements.txt
index e2d76e04..66c90c79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,8 @@ webdriver-manager
 jsonschema
 tweepy
 click
+spacy>=3.0.0,<4.0.0
+en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.0/en_core_web_sm-3.4.0-py3-none-any.whl

 ##Dev
 coverage
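
For reference, here is a minimal, self-contained sketch of the sentence-based, token-bounded chunking strategy this patch introduces. It is illustrative rather than the shipped code: count_tokens is a hypothetical stand-in for autogpt.token_counter.count_message_tokens, and the create_message prompt wrapping is omitted, so its estimates are rougher than the real implementation's.

import spacy


def count_tokens(text: str) -> int:
    # Hypothetical stand-in for AutoGPT's tiktoken-based token counter;
    # a crude words-per-token heuristic keeps the sketch dependency-free.
    return int(len(text.split()) * 1.3)


def split_into_chunks(text: str, max_tokens: int = 3000):
    # Collapse newlines so spaCy sees one continuous document, then split
    # it into sentences with the lightweight sentencizer pipe.
    nlp = spacy.load("en_core_web_sm")  # any BROWSE_SPACY_LANGUAGE_MODEL
    nlp.add_pipe("sentencizer")
    doc = nlp(" ".join(text.split("\n")))

    current_chunk = []
    for sentence in (sent.text.strip() for sent in doc.sents):
        candidate = " ".join(current_chunk + [sentence])
        if count_tokens(candidate) <= max_tokens:
            current_chunk.append(sentence)  # still fits, keep growing
        else:
            if current_chunk:
                yield " ".join(current_chunk)  # emit the filled chunk
            if count_tokens(sentence) > max_tokens:
                raise ValueError("A single sentence exceeds max_tokens")
            current_chunk = [sentence]  # start the next chunk
    if current_chunk:
        yield " ".join(current_chunk)

The key change over the old implementation is that chunk size is measured in model tokens rather than characters, and chunks break on sentence boundaries rather than raw newlines, so no single summarization request can overflow the model's context window.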
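
Why the budget is measured in tokens rather than characters: the model's context window is enforced in tokens, and the character-to-token ratio varies with the text. A quick check with tiktoken (the tokenizer library AutoGPT's token_counter uses) makes this concrete; the sample string is arbitrary:

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
text = "BROWSE_CHUNK_MAX_LENGTH is now a token budget, not a character count."
# English prose averages roughly four characters per token, which is why a
# 3000-token default (75% of the 4000-token FAST_TOKEN_LIMIT, leaving room
# for the model's response) replaces the old 8192-character default.
print(f"{len(text)} characters -> {len(enc.encode(text))} tokens")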