Rebase MessageHistory on ChatSequence (#4922)
* Rebase `MessageHistory` on `ChatSequence`
* Process feedback & make mypy happy

---------

Co-authored-by: James Collins <collijk@uw.edu>
parent 7dc6d736c7
commit 1e1eff70bc
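Before the diff, a minimal usage sketch, not part of this commit, of the construction path the change introduces. It assumes `config` and `agent` are the same objects the `Agent` constructor already receives, and only uses identifiers that appear in the hunks below.

```python
# Hedged sketch only; mirrors identifiers from the diff, not new API.
from autogpt.memory.message_history import MessageHistory


def build_history(config, agent) -> MessageHistory:
    # MessageHistory is now built via the ChatSequence classmethod,
    # binding it to a concrete chat model instead of just the agent.
    history = MessageHistory.for_model(config.smart_llm, agent=agent)
    history.add("user", "Determine which next command to use")  # add() is inherited from ChatSequence
    recent = history[-1:]  # slicing returns another MessageHistory (deep copy), per the new __getitem__ overloads
    assert isinstance(recent, MessageHistory)
    return history
```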
@@ -70,7 +70,7 @@ class Agent:
     ):
         self.ai_name = ai_name
         self.memory = memory
-        self.history = MessageHistory(self)
+        self.history = MessageHistory.for_model(config.smart_llm, agent=self)
         self.next_action_count = next_action_count
         self.command_registry = command_registry
         self.config = config
@@ -1,13 +1,14 @@
 from __future__ import annotations

 from copy import deepcopy
 from dataclasses import dataclass, field
 from math import ceil, floor
-from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict
+from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload

 if TYPE_CHECKING:
     from autogpt.llm.providers.openai import OpenAIFunctionCall

-MessageRole = Literal["system", "user", "assistant"]
+MessageRole = Literal["system", "user", "assistant", "function"]
 MessageType = Literal["ai_response", "action_result"]

 TText = list[int]
@@ -68,15 +69,31 @@ class EmbeddingModelInfo(ModelInfo):
     embedding_dimensions: int


+# Can be replaced by Self in Python 3.11
+TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")
+
+
 @dataclass
 class ChatSequence:
     """Utility container for a chat sequence"""

     model: ChatModelInfo
-    messages: list[Message] = field(default_factory=list)
+    messages: list[Message] = field(default_factory=list[Message])

-    def __getitem__(self, i: int):
-        return self.messages[i]
+    @overload
+    def __getitem__(self, key: int) -> Message:
+        ...
+
+    @overload
+    def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:
+        ...
+
+    def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
+        if isinstance(key, slice):
+            copy = deepcopy(self)
+            copy.messages = self.messages[key]
+            return copy
+        return self.messages[key]

     def __iter__(self):
         return iter(self.messages)
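A small illustration, not from the diff, of what the `__getitem__` overloads buy: integer keys give a `Message`, slices give a deep-copied sequence of the same type. It assumes `"gpt-3.5-turbo"` is a registered key in `OPEN_AI_CHAT_MODELS`.

```python
# Illustrative only; assumes "gpt-3.5-turbo" is registered in OPEN_AI_CHAT_MODELS.
from autogpt.llm.base import ChatSequence, Message

seq = ChatSequence.for_model("gpt-3.5-turbo")
seq.add("system", "You are a helpful assistant.")
seq.add("user", "Hello!")

first: Message = seq[0]       # int key -> Message
tail: ChatSequence = seq[1:]  # slice key -> deep-copied ChatSequence (or subclass)
assert len(tail) == 1 and tail[0].content == "Hello!"
```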
@@ -84,6 +101,14 @@ class ChatSequence:
     def __len__(self):
         return len(self.messages)

+    def add(
+        self,
+        message_role: MessageRole,
+        content: str,
+        type: MessageType | None = None,
+    ) -> None:
+        self.append(Message(message_role, content, type))
+
     def append(self, message: Message):
         return self.messages.append(message)
@@ -95,21 +120,23 @@ class ChatSequence:
         self.messages.insert(index, message)

     @classmethod
-    def for_model(cls, model_name: str, messages: list[Message] | ChatSequence = []):
+    def for_model(
+        cls: Type[TChatSequence],
+        model_name: str,
+        messages: list[Message] | ChatSequence = [],
+        **kwargs,
+    ) -> TChatSequence:
         from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS

         if not model_name in OPEN_AI_CHAT_MODELS:
             raise ValueError(f"Unknown chat model '{model_name}'")

-        return ChatSequence(
-            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages)
+        return cls(
+            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages), **kwargs
         )

-    def add(self, message_role: MessageRole, content: str):
-        self.messages.append(Message(message_role, content))
-
     @property
-    def token_length(self):
+    def token_length(self) -> int:
         from autogpt.llm.utils import count_message_tokens

         return count_message_tokens(self.messages, self.model.name)
@@ -128,7 +155,7 @@ class ChatSequence:
             [f"{separator(m.role)}\n{m.content}" for m in self.messages]
         )
         return f"""
-============== ChatSequence ==============
+============== {__class__.__name__} ==============
 Length: {self.token_length} tokens; {len(self.messages)} messages
 {formatted_messages}
 ==========================================
@@ -148,7 +175,7 @@ class LLMResponse:
 class EmbeddingModelResponse(LLMResponse):
     """Standard response struct for a response from an embedding model."""

-    embedding: List[float] = field(default_factory=list)
+    embedding: list[float] = field(default_factory=list)

     def __post_init__(self):
         if self.completion_tokens_used:
@@ -93,13 +93,13 @@ def chat_with_ai(

     # Account for user input (appended later)
     user_input_msg = Message("user", triggering_prompt)
-    current_tokens_used += count_message_tokens([user_input_msg], model)
+    current_tokens_used += count_message_tokens(user_input_msg, model)

-    current_tokens_used += 500  # Reserve space for new_summary_message
+    current_tokens_used += agent.history.max_summary_tlength  # Reserve space
     current_tokens_used += 500  # Reserve space for the openai functions TODO improve

     # Add Messages until the token limit is reached or there are no more messages to add.
-    for cycle in reversed(list(agent.history.per_cycle(agent.config))):
+    for cycle in reversed(list(agent.history.per_cycle())):
         messages_to_add = [msg for msg in cycle if msg is not None]
         tokens_to_add = count_message_tokens(messages_to_add, model)
         if current_tokens_used + tokens_to_add > send_token_limit:
@@ -115,9 +115,9 @@ def chat_with_ai(
             new_summary_message, trimmed_messages = agent.history.trim_messages(
                 current_message_chain=list(message_sequence), config=agent.config
             )
-            tokens_to_add = count_message_tokens([new_summary_message], model)
+            tokens_to_add = count_message_tokens(new_summary_message, model)
             message_sequence.insert(insertion_index, new_summary_message)
-            current_tokens_used += tokens_to_add - 500
+            current_tokens_used += tokens_to_add - agent.history.max_summary_tlength

     # FIXME: uncomment when memory is back in use
     # memory_store = get_memory(config)
@@ -143,7 +143,7 @@ def chat_with_ai(
         )
         logger.debug(budget_message)
         message_sequence.add("system", budget_message)
-        current_tokens_used += count_message_tokens([message_sequence[-1]], model)
+        current_tokens_used += count_message_tokens(message_sequence[-1], model)

     # Append user input, the length of this is accounted for above
     message_sequence.append(user_input_msg)
@@ -157,9 +157,7 @@ def chat_with_ai(
         )
         if not plugin_response or plugin_response == "":
             continue
-        tokens_to_add = count_message_tokens(
-            [Message("system", plugin_response)], model
-        )
+        tokens_to_add = count_message_tokens(Message("system", plugin_response), model)
         if current_tokens_used + tokens_to_add > send_token_limit:
             logger.debug(f"Plugin response too long, skipping: {plugin_response}")
             logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
@@ -1,7 +1,7 @@
 """Functions for counting the number of tokens in a message or string."""
 from __future__ import annotations

-from typing import List
+from typing import List, overload

 import tiktoken
@@ -9,8 +9,18 @@ from autogpt.llm.base import Message
 from autogpt.logs import logger


+@overload
+def count_message_tokens(messages: Message, model: str = "gpt-3.5-turbo") -> int:
+    ...
+
+
+@overload
+def count_message_tokens(messages: List[Message], model: str = "gpt-3.5-turbo") -> int:
+    ...
+
+
 def count_message_tokens(
-    messages: List[Message], model: str = "gpt-3.5-turbo-0301"
+    messages: Message | List[Message], model: str = "gpt-3.5-turbo"
 ) -> int:
     """
     Returns the number of tokens used by a list of messages.
@@ -24,6 +34,9 @@ def count_message_tokens(
     Returns:
         int: The number of tokens used by the list of messages.
     """
+    if isinstance(messages, Message):
+        messages = [messages]
+
     if model.startswith("gpt-3.5-turbo"):
         tokens_per_message = (
             4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
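For clarity, the two call shapes the new overloads permit — a sketch, assuming `count_message_tokens` is re-exported from `autogpt.llm.utils` as the later hunks import it:

```python
# Sketch of both call shapes; exact token counts depend on the model's encoding.
from autogpt.llm.base import Message
from autogpt.llm.utils import count_message_tokens

msg = Message("user", "What should I do next?")

single = count_message_tokens(msg, "gpt-3.5-turbo")        # a single Message
batch = count_message_tokens([msg, msg], "gpt-3.5-turbo")  # a list of Messages
assert batch > single
```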
@@ -2,49 +2,46 @@ from __future__ import annotations

 import copy
 import json
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional

 if TYPE_CHECKING:
     from autogpt.agent import Agent

 from autogpt.config import Config
 from autogpt.json_utils.utilities import extract_json_from_response
-from autogpt.llm.base import ChatSequence, Message, MessageRole, MessageType
+from autogpt.llm.base import ChatSequence, Message
 from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
-from autogpt.llm.utils import count_string_tokens, create_chat_completion
+from autogpt.llm.utils import (
+    count_message_tokens,
+    count_string_tokens,
+    create_chat_completion,
+)
 from autogpt.log_cycle.log_cycle import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME
 from autogpt.logs import logger


 @dataclass
-class MessageHistory:
-    agent: Agent
-
-    messages: list[Message] = field(default_factory=list)
+class MessageHistory(ChatSequence):
+    max_summary_tlength: int = 500
+    agent: Optional[Agent] = None
     summary: str = "I was created"

     last_trimmed_index: int = 0

-    def __getitem__(self, i: int):
-        return self.messages[i]
-
-    def __iter__(self):
-        return iter(self.messages)
-
-    def __len__(self):
-        return len(self.messages)
-
-    def add(
-        self,
-        role: MessageRole,
-        content: str,
-        type: MessageType | None = None,
-    ):
-        return self.append(Message(role, content, type))
-
-    def append(self, message: Message):
-        return self.messages.append(message)
+    SUMMARIZATION_PROMPT = '''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.
+
+You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.
+
+Summary So Far:
+"""
+{summary}
+"""
+
+Latest Development:
+"""
+{new_events}
+"""
+'''

     def trim_messages(
         self, current_message_chain: list[Message], config: Config
@@ -84,7 +81,7 @@ class MessageHistory:

         return new_summary_message, new_messages_not_in_chain

-    def per_cycle(self, config: Config, messages: list[Message] | None = None):
+    def per_cycle(self, messages: list[Message] | None = None):
         """
         Yields:
             Message: a message containing user input
@@ -119,26 +116,33 @@ class MessageHistory:
         )

     def update_running_summary(
-        self, new_events: list[Message], config: Config
+        self,
+        new_events: list[Message],
+        config: Config,
+        max_summary_length: Optional[int] = None,
     ) -> Message:
         """
-        This function takes a list of dictionaries representing new events and combines them with the current summary,
-        focusing on key and potentially important information to remember. The updated summary is returned in a message
-        formatted in the 1st person past tense.
+        This function takes a list of Message objects and updates the running summary
+        to include the events they describe. The updated summary is returned
+        in a Message formatted in the 1st person past tense.

         Args:
-            new_events (List[Dict]): A list of dictionaries containing the latest events to be added to the summary.
+            new_events: A list of Messages containing the latest events to be added to the summary.

         Returns:
-            str: A message containing the updated summary of actions, formatted in the 1st person past tense.
+            Message: a Message containing the updated running summary.

         Example:
             ```py
             new_events = [{"event": "entered the kitchen."}, {"event": "found a scrawled note with the number 7"}]
             update_running_summary(new_events)
             # Returns: "This reminds you of these events from your past: \nI entered the kitchen and found a scrawled note saying 7."
             ```
         """
         if not new_events:
             return self.summary_message()
+        if not max_summary_length:
+            max_summary_length = self.max_summary_tlength

         # Create a copy of the new_events list to prevent modifying the original list
         new_events = copy.deepcopy(new_events)
@@ -166,29 +170,29 @@ class MessageHistory:
             elif event.role == "user":
                 new_events.remove(event)

         # Summarize events and current summary in batch to a new running summary
+        summ_model = OPEN_AI_CHAT_MODELS[config.fast_llm]

-        # Assume an upper bound length for the summary prompt template, i.e. Your task is to create a concise running summary...., in summarize_batch func
-        # TODO make this default dynamic
-        prompt_template_length = 100
-        max_tokens = OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens
-        summary_tlength = count_string_tokens(str(self.summary), config.fast_llm)
+        # Determine token lengths for use in batching
+        prompt_template_length = len(
+            MessageHistory.SUMMARIZATION_PROMPT.format(summary="", new_events="")
+        )
+        max_input_tokens = summ_model.max_tokens - max_summary_length
+        summary_tlength = count_string_tokens(self.summary, summ_model.name)
         batch = []
         batch_tlength = 0

-        # TODO Can put a cap on length of total new events and drop some previous events to save API cost, but need to think thru more how to do it without losing the context
+        # TODO: Put a cap on length of total new events and drop some previous events to
+        # save API cost. Need to think thru more how to do it without losing the context.
         for event in new_events:
-            event_tlength = count_string_tokens(str(event), config.fast_llm)
+            event_tlength = count_message_tokens(event, summ_model.name)

             if (
                 batch_tlength + event_tlength
-                > max_tokens - prompt_template_length - summary_tlength
+                > max_input_tokens - prompt_template_length - summary_tlength
             ):
                 # The batch is full. Summarize it and start a new one.
-                self.summarize_batch(batch, config)
-                summary_tlength = count_string_tokens(
-                    str(self.summary), config.fast_llm
-                )
+                self.summarize_batch(batch, config, max_summary_length)
+                summary_tlength = count_string_tokens(self.summary, summ_model.name)
                 batch = [event]
                 batch_tlength = event_tlength
             else:
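To make the batching threshold concrete, a numeric sketch; the model limit, template length, and summary length below are assumptions for illustration, not values taken from this commit.

```python
# Illustrative numbers only; real values come from config.fast_llm and the running summary.
model_max_tokens = 4096                                     # summ_model.max_tokens (assumed)
max_summary_length = 500                                    # default max_summary_tlength
max_input_tokens = model_max_tokens - max_summary_length    # 3596

prompt_template_length = 120  # length of SUMMARIZATION_PROMPT with empty fields (assumed)
summary_tlength = 80          # tokens in the current running summary (assumed)

# A new event starts a fresh batch (triggering summarize_batch) once:
#   batch_tlength + event_tlength > max_input_tokens - prompt_template_length - summary_tlength
per_batch_budget = max_input_tokens - prompt_template_length - summary_tlength  # 3396 tokens
```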
@@ -197,41 +201,36 @@ class MessageHistory:

         if batch:
             # There's an unprocessed batch. Summarize it.
-            self.summarize_batch(batch, config)
+            self.summarize_batch(batch, config, max_summary_length)

         return self.summary_message()

-    def summarize_batch(self, new_events_batch, config):
-        prompt = f'''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.
-
-You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.
-
-Summary So Far:
-"""
-{self.summary}
-"""
-
-Latest Development:
-"""
-{new_events_batch or "Nothing new happened."}
-"""
-'''
+    def summarize_batch(
+        self, new_events_batch: list[Message], config: Config, max_output_length: int
+    ):
+        prompt = MessageHistory.SUMMARIZATION_PROMPT.format(
+            summary=self.summary, new_events=new_events_batch
+        )

         prompt = ChatSequence.for_model(config.fast_llm, [Message("user", prompt)])
-        self.agent.log_cycle_handler.log_cycle(
-            self.agent.ai_name,
-            self.agent.created_at,
-            self.agent.cycle_count,
-            prompt.raw(),
-            PROMPT_SUMMARY_FILE_NAME,
-        )
+        if self.agent:
+            self.agent.log_cycle_handler.log_cycle(
+                self.agent.ai_config.ai_name,
+                self.agent.created_at,
+                self.agent.cycle_count,
+                prompt.raw(),
+                PROMPT_SUMMARY_FILE_NAME,
+            )

-        self.summary = create_chat_completion(prompt, config).content
+        self.summary = create_chat_completion(
+            prompt, config, max_tokens=max_output_length
+        ).content

-        self.agent.log_cycle_handler.log_cycle(
-            self.agent.ai_name,
-            self.agent.created_at,
-            self.agent.cycle_count,
-            self.summary,
-            SUMMARY_FILE_NAME,
-        )
+        if self.agent:
+            self.agent.log_cycle_handler.log_cycle(
+                self.agent.ai_config.ai_name,
+                self.agent.created_at,
+                self.agent.cycle_count,
+                self.summary,
+                SUMMARY_FILE_NAME,
+            )
@@ -38,8 +38,8 @@ def agent(config: Config):
     return agent


-def test_message_history_batch_summary(mocker, agent, config):
-    history = MessageHistory(agent)
+def test_message_history_batch_summary(mocker, agent: Agent, config: Config):
+    history = MessageHistory.for_model(agent.config.smart_llm, agent=agent)
     model = config.fast_llm
     message_tlength = 0
     message_count = 0
@@ -48,7 +48,7 @@ def test_message_history_batch_summary(mocker, agent, config):
     mock_summary_response = ChatModelResponse(
         model_info=OPEN_AI_CHAT_MODELS[model],
         content="I executed browse_website command for each of the websites returned from Google search, but none of them have any job openings.",
-        function_call={},
+        function_call=None,
     )
     mock_summary = mocker.patch(
         "autogpt.memory.message_history.create_chat_completion",
@@ -105,7 +105,7 @@ def test_message_history_batch_summary(mocker, agent, config):
         result = (
             "Command browse_website returned: Answer gathered from website: The text in job"
             + str(i)
-            + " does not provide information on specific job requirements or a job URL.]",
+            + " does not provide information on specific job requirements or a job URL.]"
         )
         msg = Message("system", result, "action_result")
         history.append(msg)
@@ -117,7 +117,7 @@ def test_message_history_batch_summary(mocker, agent, config):
     history.append(user_input_msg)

     # only take the last cycle of the message history, trim the rest of previous messages, and generate a summary for them
-    for cycle in reversed(list(history.per_cycle(config))):
+    for cycle in reversed(list(history.per_cycle())):
         messages_to_add = [msg for msg in cycle if msg is not None]
         message_sequence.insert(insertion_index, *messages_to_add)
         break
@@ -134,7 +134,7 @@ def test_message_history_batch_summary(mocker, agent, config):
     )

     expected_call_count = math.ceil(
-        message_tlength / (OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens)
+        message_tlength / (OPEN_AI_CHAT_MODELS[config.fast_llm].max_tokens)
     )
     # Expecting 2 batches because of over max token
     assert mock_summary.call_count == expected_call_count  # 2 at the time of writing