Files
goose/packages/exchange/tests/test_truncate.py
2024-10-02 11:05:43 -07:00

133 lines
5.1 KiB
Python

import pytest
from exchange import Exchange
from exchange.content import ToolResult, ToolUse
from exchange.message import Message
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
# NOTE(review): MAX_TOKENS is not referenced by the fixture in this file, which
# passes max_tokens=500 to ContextTruncate; confirm which budget is intended.
MAX_TOKENS = 300
# Token count the scripted provider attributes to the system prompt: it is the
# base of every input count and the total reported for the size probe.
SYSTEM_PROMPT_TOKENS = 100

# Scripted conversation. User turns sit at even indices; assistant turns (which
# TruncateLinearProvider replays) sit at odd indices. Tool calls are always
# followed by the user turn carrying the matching ToolResult.
MESSAGE_SEQUENCE = [
    Message.user("Hi, can you help me with my homework?"),
    Message.assistant("Of course! What do you need help with?"),
    Message.user("I need help with math problems."),
    Message.assistant("Sure, I can help with that. Let's get started."),
    # --- math: one tool call / tool result pair per sub-question ---
    Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
    Message(role="assistant", content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})]),
    Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]),
    Message(role="assistant", content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})]),
    Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]),
    Message(role="assistant", content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})]),
    Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]),
    Message(role="assistant", content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})]),
    Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]),
    Message(role="assistant", content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})]),
    Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]),
    Message.assistant("I'm done calculating the answers to your math questions."),
    # --- science: parameterless tool calls ---
    Message.user("Can you also help with my science homework?"),
    Message.assistant("Yes, I can help with science too."),
    Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
    Message(role="assistant", content=[ToolUse(id="6", name="speed_of_light", parameters={})]),
    Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]),
    Message(role="assistant", content=[ToolUse(id="7", name="photon_frequency", parameters={})]),
    Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]),
    Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]),
    Message(role="user", content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")]),
    Message.assistant("I'm done calculating the answers to your science questions."),
    # --- closing small talk ---
    Message.user("That's great! How about history?"),
    Message.assistant("Of course, I can help with history as well."),
    Message.user("Thanks! You're very helpful."),
    Message.assistant("You're welcome! I'm here to help."),
]
class TruncateLinearProvider(Provider):
    """Scripted provider that replays the assistant turns of ``MESSAGE_SEQUENCE``.

    Each regular ``complete`` call returns the next assistant message (the odd
    indices of ``MESSAGE_SEQUENCE``) along with a synthetic ``Usage``. The token
    counts in that ``Usage`` are derived deterministically from message contents,
    so the truncation moderator sees predictable context sizes.
    """

    def __init__(self):
        self.sequence = MESSAGE_SEQUENCE
        # Assistant turns live at the odd indices of the scripted sequence.
        self.current_index = 1
        # Not read anywhere in this file; kept so existing attribute access still works.
        self.summarize_next = False
        self.summarized_count = 0

    @staticmethod
    def _token_cost(message):
        """Return the synthetic token cost of a single message.

        A message whose first content item is a tool call or tool result costs a
        flat 10 tokens. Any other message (including one with empty content)
        costs two tokens per character of its text.
        """
        if message.content and isinstance(message.content[0], (ToolResult, ToolUse)):
            return 10
        return len(message.text) * 2

    def complete(self, model, system, messages, tools):
        """Return the next scripted assistant message and a synthetic usage report.

        Args:
            model: Ignored.
            system: Ignored.
            messages: The conversation the exchange is sending to the provider.
            tools: Ignored.

        Returns:
            A ``(Message, Usage)`` tuple. A single-message conversation whose text
            is ``"a"`` is treated as the system-prompt size probe. It gets a fixed
            reply, reports ``SYSTEM_PROMPT_TOKENS + 1`` total tokens, and does not
            advance the script.

        Raises:
            ValueError: if the conversation starts with a ToolResult, meaning
                truncation cut a tool call away from its result.
        """
        first_content = messages[0].content if messages else None
        if first_content and isinstance(first_content[0], ToolResult):
            raise ValueError("ToolResult should not be the first message")

        if len(messages) == 1 and messages[0].text == "a":
            # The probe reports exactly the system prompt plus one token for the "a".
            probe_output_tokens = 20
            return Message.assistant("Getting system prompt size"), Usage(
                input_tokens=SYSTEM_PROMPT_TOKENS - probe_output_tokens + 1,
                output_tokens=probe_output_tokens,
                total_tokens=SYSTEM_PROMPT_TOKENS + 1,
            )

        # Only index the script once we know this is a regular turn, so the
        # probe above still works after the script is exhausted.
        message = self.sequence[self.current_index]

        input_token_count = SYSTEM_PROMPT_TOKENS + sum(self._token_cost(m) for m in messages)
        output_tokens = self._token_cost(message)
        usage = Usage(
            input_tokens=input_token_count,
            output_tokens=output_tokens,
            total_tokens=input_token_count + output_tokens,
        )
        # Skip over the next user turn to the following assistant turn.
        self.current_index += 2
        return message, usage
@pytest.fixture
def conversation_exchange_instance():
    """Provide an ``Exchange`` driven by the scripted ``TruncateLinearProvider``.

    The exchange is moderated by ``ContextTruncate`` with a 500-token budget.
    """
    scripted_provider = TruncateLinearProvider()
    # NOTE(review): the module-level MAX_TOKENS (300) is not used here; confirm
    # that 500 is the intended budget for this test.
    truncator = ContextTruncate(max_tokens=500)
    return Exchange(
        provider=scripted_provider,
        model="test-model",
        system="test-system",
        moderator=truncator,
    )
def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange):
    """Replay the whole scripted conversation and check truncation keeps context bounded.

    Every user turn (the even indices of MESSAGE_SEQUENCE) is added to the
    exchange and followed by one ``generate`` call. The scripted provider answers
    each call with the matching assistant turn.
    """
    # User turns are at even indices; the provider supplies the odd-indexed replies.
    for user_message in MESSAGE_SEQUENCE[::2]:
        conversation_exchange_instance.add(user_message)
        conversation_exchange_instance.generate()

    # ensure the total token count is not anything exorbitant
    assert conversation_exchange_instance.checkpoint_data.total_token_count < 700
    assert conversation_exchange_instance.moderator.system_prompt_token_count == SYSTEM_PROMPT_TOKENS