Rebase MessageHistory on ChatSequence (#4922)

* Rebase `MessageHistory` on `ChatSequence`

* Process feedback & make mypy happy

---------

Co-authored-by: James Collins <collijk@uw.edu>
Reinier van der Leer authored on 2023-07-09 19:52:59 +02:00, committed by GitHub
parent 7dc6d736c7
commit 1e1eff70bc
6 changed files with 145 additions and 108 deletions

View File

@@ -70,7 +70,7 @@ class Agent:
):
self.ai_name = ai_name
self.memory = memory
-        self.history = MessageHistory(self)
+        self.history = MessageHistory.for_model(config.smart_llm, agent=self)
self.next_action_count = next_action_count
self.command_registry = command_registry
self.config = config
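With this change the agent no longer instantiates `MessageHistory` directly; it goes through the `for_model` factory inherited from `ChatSequence`, so the history carries the model info of the configured smart LLM. A minimal usage sketch outside the `Agent` class (any key of `OPEN_AI_CHAT_MODELS` works as the model name; `agent` may be omitted since it defaults to `None`):

```py
from autogpt.memory.message_history import MessageHistory

# Sketch: build a history the same way the Agent constructor above does.
history = MessageHistory.for_model("gpt-4", agent=None)
history.add("system", "You are an autonomous agent.")
print(history.model.name, history.token_length)  # model info and token count come from ChatSequence
```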

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
+from copy import deepcopy
from dataclasses import dataclass, field
from math import ceil, floor
-from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict
+from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
if TYPE_CHECKING:
from autogpt.llm.providers.openai import OpenAIFunctionCall
-MessageRole = Literal["system", "user", "assistant"]
+MessageRole = Literal["system", "user", "assistant", "function"]
MessageType = Literal["ai_response", "action_result"]
TText = list[int]
@@ -68,15 +69,31 @@ class EmbeddingModelInfo(ModelInfo):
embedding_dimensions: int
+# Can be replaced by Self in Python 3.11
+TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")
@dataclass
class ChatSequence:
"""Utility container for a chat sequence"""
model: ChatModelInfo
-    messages: list[Message] = field(default_factory=list)
+    messages: list[Message] = field(default_factory=list[Message])

-    def __getitem__(self, i: int):
-        return self.messages[i]
+    @overload
+    def __getitem__(self, key: int) -> Message:
+        ...
+
+    @overload
+    def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:
+        ...
+
+    def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
+        if isinstance(key, slice):
+            copy = deepcopy(self)
+            copy.messages = self.messages[key]
+            return copy
+        return self.messages[key]
def __iter__(self):
return iter(self.messages)
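The overloads above make indexing type-aware: an `int` key returns a single `Message`, while a `slice` deep-copies the sequence and returns the same concrete type instead of a plain list. A short illustration (a sketch, assuming `"gpt-3.5-turbo"` is among the known chat models):

```py
from autogpt.llm.base import ChatSequence, Message

seq = ChatSequence.for_model("gpt-3.5-turbo")
seq.add("system", "You are a helpful assistant.")
seq.add("user", "Hello!")

first = seq[0]   # int index -> a single Message
tail = seq[1:]   # slice -> a new ChatSequence (deep copy), model info included
assert isinstance(first, Message)
assert isinstance(tail, ChatSequence) and len(tail) == 1
```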
@@ -84,6 +101,14 @@ class ChatSequence:
def __len__(self):
return len(self.messages)
+    def add(
+        self,
+        message_role: MessageRole,
+        content: str,
+        type: MessageType | None = None,
+    ) -> None:
+        self.append(Message(message_role, content, type))
def append(self, message: Message):
return self.messages.append(message)
@@ -95,21 +120,23 @@ class ChatSequence:
self.messages.insert(index, message)
@classmethod
-    def for_model(cls, model_name: str, messages: list[Message] | ChatSequence = []):
+    def for_model(
+        cls: Type[TChatSequence],
+        model_name: str,
+        messages: list[Message] | ChatSequence = [],
+        **kwargs,
+    ) -> TChatSequence:
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
if not model_name in OPEN_AI_CHAT_MODELS:
raise ValueError(f"Unknown chat model '{model_name}'")
-        return ChatSequence(
-            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages)
+        return cls(
+            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages), **kwargs
)
-    def add(self, message_role: MessageRole, content: str):
-        self.messages.append(Message(message_role, content))
@property
-    def token_length(self):
+    def token_length(self) -> int:
from autogpt.llm.utils import count_message_tokens
return count_message_tokens(self.messages, self.model.name)
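Typing `for_model` against `TChatSequence` and returning `cls(...)` means subclasses created through the factory keep their own type, and extra keyword arguments are forwarded to the subclass constructor. That is what allows the `MessageHistory.for_model(..., agent=...)` call in the agent above. A sketch of the behaviour, using only names from this diff:

```py
from autogpt.llm.base import ChatSequence
from autogpt.memory.message_history import MessageHistory

chat = ChatSequence.for_model("gpt-3.5-turbo")
history = MessageHistory.for_model("gpt-3.5-turbo", agent=None)  # **kwargs -> MessageHistory(agent=None, ...)

assert type(chat) is ChatSequence
assert type(history) is MessageHistory  # subclass type preserved, not downcast to ChatSequence
```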
@@ -128,7 +155,7 @@ class ChatSequence:
[f"{separator(m.role)}\n{m.content}" for m in self.messages]
)
return f"""
-============== ChatSequence ==============
+============== {__class__.__name__} ==============
Length: {self.token_length} tokens; {len(self.messages)} messages
{formatted_messages}
==========================================
@@ -148,7 +175,7 @@ class LLMResponse:
class EmbeddingModelResponse(LLMResponse):
"""Standard response struct for a response from an embedding model."""
-    embedding: List[float] = field(default_factory=list)
+    embedding: list[float] = field(default_factory=list)
def __post_init__(self):
if self.completion_tokens_used:

View File

@@ -93,13 +93,13 @@ def chat_with_ai(
# Account for user input (appended later)
user_input_msg = Message("user", triggering_prompt)
-    current_tokens_used += count_message_tokens([user_input_msg], model)
-    current_tokens_used += 500  # Reserve space for new_summary_message
+    current_tokens_used += count_message_tokens(user_input_msg, model)
+    current_tokens_used += agent.history.max_summary_tlength  # Reserve space
current_tokens_used += 500 # Reserve space for the openai functions TODO improve
# Add Messages until the token limit is reached or there are no more messages to add.
-    for cycle in reversed(list(agent.history.per_cycle(agent.config))):
+    for cycle in reversed(list(agent.history.per_cycle())):
messages_to_add = [msg for msg in cycle if msg is not None]
tokens_to_add = count_message_tokens(messages_to_add, model)
if current_tokens_used + tokens_to_add > send_token_limit:
@@ -115,9 +115,9 @@ def chat_with_ai(
new_summary_message, trimmed_messages = agent.history.trim_messages(
current_message_chain=list(message_sequence), config=agent.config
)
-        tokens_to_add = count_message_tokens([new_summary_message], model)
+        tokens_to_add = count_message_tokens(new_summary_message, model)
message_sequence.insert(insertion_index, new_summary_message)
-        current_tokens_used += tokens_to_add - 500
+        current_tokens_used += tokens_to_add - agent.history.max_summary_tlength
# FIXME: uncomment when memory is back in use
# memory_store = get_memory(config)
@@ -143,7 +143,7 @@ def chat_with_ai(
)
logger.debug(budget_message)
message_sequence.add("system", budget_message)
-        current_tokens_used += count_message_tokens([message_sequence[-1]], model)
+        current_tokens_used += count_message_tokens(message_sequence[-1], model)
# Append user input, the length of this is accounted for above
message_sequence.append(user_input_msg)
@@ -157,9 +157,7 @@ def chat_with_ai(
)
if not plugin_response or plugin_response == "":
continue
-        tokens_to_add = count_message_tokens(
-            [Message("system", plugin_response)], model
-        )
+        tokens_to_add = count_message_tokens(Message("system", plugin_response), model)
if current_tokens_used + tokens_to_add > send_token_limit:
logger.debug(f"Plugin response too long, skipping: {plugin_response}")
logger.debug(f"Plugins remaining at stop: {plugin_count - i}")

View File

@@ -1,7 +1,7 @@
"""Functions for counting the number of tokens in a message or string."""
from __future__ import annotations
-from typing import List
+from typing import List, overload
import tiktoken
@@ -9,8 +9,18 @@ from autogpt.llm.base import Message
from autogpt.logs import logger
+@overload
+def count_message_tokens(messages: Message, model: str = "gpt-3.5-turbo") -> int:
+    ...
+
+@overload
+def count_message_tokens(messages: List[Message], model: str = "gpt-3.5-turbo") -> int:
+    ...
def count_message_tokens(
-    messages: List[Message], model: str = "gpt-3.5-turbo-0301"
+    messages: Message | List[Message], model: str = "gpt-3.5-turbo"
) -> int:
"""
Returns the number of tokens used by a list of messages.
@@ -24,6 +34,9 @@ def count_message_tokens(
Returns:
int: The number of tokens used by the list of messages.
"""
+    if isinstance(messages, Message):
+        messages = [messages]
if model.startswith("gpt-3.5-turbo"):
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
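With the `@overload` declarations and the `isinstance` check above, `count_message_tokens` accepts either a single `Message` or a list of them; a lone message is simply wrapped in a one-element list before counting. A small usage sketch (assumes `tiktoken` can resolve the `gpt-3.5-turbo` encoding):

```py
from autogpt.llm.base import Message
from autogpt.llm.utils import count_message_tokens

msg = Message("user", "What's the weather like today?")

single = count_message_tokens(msg, model="gpt-3.5-turbo")                                    # one Message
total = count_message_tokens([msg, Message("system", "Be brief.")], model="gpt-3.5-turbo")   # a list

assert 0 < single <= total
```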

View File

@@ -2,49 +2,46 @@ from __future__ import annotations
import copy
import json
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from autogpt.agent import Agent
from autogpt.config import Config
from autogpt.json_utils.utilities import extract_json_from_response
-from autogpt.llm.base import ChatSequence, Message, MessageRole, MessageType
+from autogpt.llm.base import ChatSequence, Message
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
-from autogpt.llm.utils import count_string_tokens, create_chat_completion
+from autogpt.llm.utils import (
+    count_message_tokens,
+    count_string_tokens,
+    create_chat_completion,
+)
from autogpt.log_cycle.log_cycle import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME
from autogpt.logs import logger
@dataclass
-class MessageHistory:
-    agent: Agent
-    messages: list[Message] = field(default_factory=list)
+class MessageHistory(ChatSequence):
+    max_summary_tlength: int = 500
+    agent: Optional[Agent] = None
    summary: str = "I was created"
    last_trimmed_index: int = 0

-    def __getitem__(self, i: int):
-        return self.messages[i]
-
-    def __iter__(self):
-        return iter(self.messages)
-
-    def __len__(self):
-        return len(self.messages)
-
-    def add(
-        self,
-        role: MessageRole,
-        content: str,
-        type: MessageType | None = None,
-    ):
-        return self.append(Message(role, content, type))
-
-    def append(self, message: Message):
-        return self.messages.append(message)
+    SUMMARIZATION_PROMPT = '''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.
+
+You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.
+
+Summary So Far:
+"""
+{summary}
+"""
+
+Latest Development:
+"""
+{new_events}
+"""
+'''
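Hoisting the prompt into the class-level `SUMMARIZATION_PROMPT` template (with `{summary}` and `{new_events}` placeholders) lets `update_running_summary` measure the empty template when budgeting and lets `summarize_batch` fill it with `str.format` instead of an inline f-string. A hedged sketch of how the template is used (the event below is made up for illustration):

```py
from autogpt.llm.base import Message
from autogpt.memory.message_history import MessageHistory

new_events = [Message("assistant", "I ran a web search for job openings.")]

prompt_text = MessageHistory.SUMMARIZATION_PROMPT.format(
    summary="I was created", new_events=new_events
)
# The empty template is what update_running_summary measures for its budget:
template_only = MessageHistory.SUMMARIZATION_PROMPT.format(summary="", new_events="")
print(len(template_only), len(prompt_text))
```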
def trim_messages(
self, current_message_chain: list[Message], config: Config
@@ -84,7 +81,7 @@ class MessageHistory:
return new_summary_message, new_messages_not_in_chain
-    def per_cycle(self, config: Config, messages: list[Message] | None = None):
+    def per_cycle(self, messages: list[Message] | None = None):
"""
Yields:
Message: a message containing user input
@@ -119,26 +116,33 @@ class MessageHistory:
)
def update_running_summary(
-        self, new_events: list[Message], config: Config
+        self,
+        new_events: list[Message],
+        config: Config,
+        max_summary_length: Optional[int] = None,
) -> Message:
"""
-        This function takes a list of dictionaries representing new events and combines them with the current summary,
-        focusing on key and potentially important information to remember. The updated summary is returned in a message
-        formatted in the 1st person past tense.
+        This function takes a list of Message objects and updates the running summary
+        to include the events they describe. The updated summary is returned
+        in a Message formatted in the 1st person past tense.
Args:
-            new_events (List[Dict]): A list of dictionaries containing the latest events to be added to the summary.
+            new_events: A list of Messages containing the latest events to be added to the summary.
Returns:
-            str: A message containing the updated summary of actions, formatted in the 1st person past tense.
+            Message: a Message containing the updated running summary.

-        Example:
-            ```py
-            new_events = [{"event": "entered the kitchen."}, {"event": "found a scrawled note with the number 7"}]
-            update_running_summary(new_events)
-            # Returns: "This reminds you of these events from your past: \nI entered the kitchen and found a scrawled note saying 7."
-            ```
"""
if not new_events:
return self.summary_message()
+        if not max_summary_length:
+            max_summary_length = self.max_summary_tlength
# Create a copy of the new_events list to prevent modifying the original list
new_events = copy.deepcopy(new_events)
@@ -166,29 +170,29 @@ class MessageHistory:
elif event.role == "user":
new_events.remove(event)
# Summarize events and current summary in batch to a new running summary
-        # Assume an upper bound length for the summary prompt template, i.e. Your task is to create a concise running summary...., in summarize_batch func
-        # TODO make this default dynamic
-        prompt_template_length = 100
-        max_tokens = OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens
-        summary_tlength = count_string_tokens(str(self.summary), config.fast_llm)
+        summ_model = OPEN_AI_CHAT_MODELS[config.fast_llm]
+
+        # Determine token lengths for use in batching
+        prompt_template_length = len(
+            MessageHistory.SUMMARIZATION_PROMPT.format(summary="", new_events="")
+        )
+        max_input_tokens = summ_model.max_tokens - max_summary_length
+        summary_tlength = count_string_tokens(self.summary, summ_model.name)
batch = []
batch_tlength = 0
-        # TODO Can put a cap on length of total new events and drop some previous events to save API cost, but need to think thru more how to do it without losing the context
+        # TODO: Put a cap on length of total new events and drop some previous events to
+        # save API cost. Need to think thru more how to do it without losing the context.
for event in new_events:
-            event_tlength = count_string_tokens(str(event), config.fast_llm)
+            event_tlength = count_message_tokens(event, summ_model.name)
if (
batch_tlength + event_tlength
-                > max_tokens - prompt_template_length - summary_tlength
+                > max_input_tokens - prompt_template_length - summary_tlength
):
# The batch is full. Summarize it and start a new one.
-                self.summarize_batch(batch, config)
-                summary_tlength = count_string_tokens(
-                    str(self.summary), config.fast_llm
-                )
+                self.summarize_batch(batch, config, max_summary_length)
+                summary_tlength = count_string_tokens(self.summary, summ_model.name)
batch = [event]
batch_tlength = event_tlength
else:
@@ -197,41 +201,36 @@ class MessageHistory:
if batch:
# There's an unprocessed batch. Summarize it.
-            self.summarize_batch(batch, config)
+            self.summarize_batch(batch, config, max_summary_length)
return self.summary_message()
-    def summarize_batch(self, new_events_batch, config):
-        prompt = f'''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.
-
-You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.
-
-Summary So Far:
-"""
-{self.summary}
-"""
-
-Latest Development:
-"""
-{new_events_batch or "Nothing new happened."}
-"""
-'''
+    def summarize_batch(
+        self, new_events_batch: list[Message], config: Config, max_output_length: int
+    ):
+        prompt = MessageHistory.SUMMARIZATION_PROMPT.format(
+            summary=self.summary, new_events=new_events_batch
+        )
prompt = ChatSequence.for_model(config.fast_llm, [Message("user", prompt)])
-        self.agent.log_cycle_handler.log_cycle(
-            self.agent.ai_name,
-            self.agent.created_at,
-            self.agent.cycle_count,
-            prompt.raw(),
-            PROMPT_SUMMARY_FILE_NAME,
-        )
+        if self.agent:
+            self.agent.log_cycle_handler.log_cycle(
+                self.agent.ai_config.ai_name,
+                self.agent.created_at,
+                self.agent.cycle_count,
+                prompt.raw(),
+                PROMPT_SUMMARY_FILE_NAME,
+            )

-        self.summary = create_chat_completion(prompt, config).content
+        self.summary = create_chat_completion(
+            prompt, config, max_tokens=max_output_length
+        ).content

-        self.agent.log_cycle_handler.log_cycle(
-            self.agent.ai_name,
-            self.agent.created_at,
-            self.agent.cycle_count,
-            self.summary,
-            SUMMARY_FILE_NAME,
-        )
+        if self.agent:
+            self.agent.log_cycle_handler.log_cycle(
+                self.agent.ai_config.ai_name,
+                self.agent.created_at,
+                self.agent.cycle_count,
+                self.summary,
+                SUMMARY_FILE_NAME,
+            )
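Taken together, `update_running_summary` batches events so that each request to the fast LLM stays within `summ_model.max_tokens` minus the reserved summary length, and `summarize_batch` caps the completion at `max_output_length`. Below is a standalone, simplified sketch of just the batching rule; it is not the project's implementation (the real loop also re-measures the running summary after every batch):

```py
from typing import Callable, TypeVar

T = TypeVar("T")


def batch_by_token_budget(
    events: list[T], measure: Callable[[T], int], budget: int
) -> list[list[T]]:
    """Group events into batches whose measured lengths stay within `budget`."""
    batches: list[list[T]] = []
    batch: list[T] = []
    batch_tlength = 0
    for event in events:
        event_tlength = measure(event)
        if batch and batch_tlength + event_tlength > budget:
            batches.append(batch)        # batch is full: close it and start a new one
            batch, batch_tlength = [], 0
        batch.append(event)
        batch_tlength += event_tlength
    if batch:
        batches.append(batch)            # flush the trailing, partially filled batch
    return batches


# Toy usage: strings stand in for Messages, word count stands in for token count.
groups = batch_by_token_budget(
    ["searched the web", "read three pages", "wrote a summary", "saved a file"],
    measure=lambda e: len(e.split()),
    budget=6,
)
print(groups)  # [['searched the web', 'read three pages'], ['wrote a summary', 'saved a file']]
```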

View File

@@ -38,8 +38,8 @@ def agent(config: Config):
return agent
-def test_message_history_batch_summary(mocker, agent, config):
-    history = MessageHistory(agent)
+def test_message_history_batch_summary(mocker, agent: Agent, config: Config):
+    history = MessageHistory.for_model(agent.config.smart_llm, agent=agent)
model = config.fast_llm
message_tlength = 0
message_count = 0
@@ -48,7 +48,7 @@ def test_message_history_batch_summary(mocker, agent, config):
mock_summary_response = ChatModelResponse(
model_info=OPEN_AI_CHAT_MODELS[model],
content="I executed browse_website command for each of the websites returned from Google search, but none of them have any job openings.",
-        function_call={},
+        function_call=None,
)
mock_summary = mocker.patch(
"autogpt.memory.message_history.create_chat_completion",
@@ -105,7 +105,7 @@ def test_message_history_batch_summary(mocker, agent, config):
result = (
"Command browse_website returned: Answer gathered from website: The text in job"
+ str(i)
+ " does not provide information on specific job requirements or a job URL.]",
+ " does not provide information on specific job requirements or a job URL.]"
)
msg = Message("system", result, "action_result")
history.append(msg)
@@ -117,7 +117,7 @@ def test_message_history_batch_summary(mocker, agent, config):
history.append(user_input_msg)
# only take the last cycle of the message history, trim the rest of previous messages, and generate a summary for them
-    for cycle in reversed(list(history.per_cycle(config))):
+    for cycle in reversed(list(history.per_cycle())):
messages_to_add = [msg for msg in cycle if msg is not None]
message_sequence.insert(insertion_index, *messages_to_add)
break
@@ -134,7 +134,7 @@ def test_message_history_batch_summary(mocker, agent, config):
)
expected_call_count = math.ceil(
-        message_tlength / (OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens)
+        message_tlength / (OPEN_AI_CHAT_MODELS[config.fast_llm].max_tokens)
)
# Expecting 2 batches because of over max token
assert mock_summary.call_count == expected_call_count # 2 at the time of writing
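The assertion above follows directly from the batching rule: the number of summarization calls equals the accumulated message token length divided by the fast LLM's `max_tokens`, rounded up. A worked example with made-up figures (the real values depend on the messages the test generates):

```py
import math

# Hypothetical figures: ~7000 tokens of history against a 4096-token fast LLM.
message_tlength = 7000
max_tokens = 4096

expected_call_count = math.ceil(message_tlength / max_tokens)
print(expected_call_count)  # 2 -> two batches, hence two create_chat_completion calls
```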