mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-23 09:04:26 +01:00
chore: setup workspace for exchange (#105)
This commit is contained in:
227
packages/exchange/tests/test_summarizer.py
Normal file
227
packages/exchange/tests/test_summarizer.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import pytest
|
||||
from exchange import Exchange, Message
|
||||
from exchange.content import ToolResult, ToolUse
|
||||
from exchange.moderators.passive import PassiveModerator
|
||||
from exchange.moderators.summarizer import ContextSummarizer
|
||||
from exchange.providers import Usage
|
||||
|
||||
|
||||
class MockProvider:
|
||||
def complete(self, model, system, messages, tools):
|
||||
assistant_message_text = "Summarized content here."
|
||||
output_tokens = len(assistant_message_text)
|
||||
total_input_tokens = sum(len(msg.text) for msg in messages)
|
||||
if not messages or messages[-1].role == "assistant":
|
||||
message = Message.user(assistant_message_text)
|
||||
else:
|
||||
message = Message.assistant(assistant_message_text)
|
||||
total_tokens = total_input_tokens + output_tokens
|
||||
usage = Usage(
|
||||
input_tokens=total_input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
return message, usage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def exchange_instance():
|
||||
ex = Exchange(
|
||||
provider=MockProvider(),
|
||||
model="test-model",
|
||||
system="test-system",
|
||||
messages=[
|
||||
Message.user("Hi, can you help me with my homework?"),
|
||||
Message.assistant("Of course! What do you need help with?"),
|
||||
Message.user("I need help with math problems."),
|
||||
Message.assistant("Sure, I can help with that. Let's get started."),
|
||||
Message.user("Can you also help with my science homework?"),
|
||||
Message.assistant("Yes, I can help with science too."),
|
||||
Message.user("That's great! How about history?"),
|
||||
Message.assistant("Of course, I can help with history as well."),
|
||||
Message.user("Thanks! You're very helpful."),
|
||||
Message.assistant("You're welcome! I'm here to help."),
|
||||
],
|
||||
moderator=PassiveModerator(),
|
||||
)
|
||||
return ex
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def summarizer_instance():
|
||||
return ContextSummarizer(max_tokens=300)
|
||||
|
||||
|
||||
def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer):
|
||||
# Pre-checks
|
||||
assert len(exchange_instance.messages) == 10
|
||||
|
||||
exchange_instance.generate()
|
||||
|
||||
# the exchange instance has a PassiveModerator so the messages are not truncated nor summarized
|
||||
assert len(exchange_instance.messages) == 11
|
||||
assert len(exchange_instance.checkpoint_data.checkpoints) == 2
|
||||
|
||||
# we now tell the summarizer to summarize the exchange
|
||||
summarizer_instance.rewrite(exchange_instance)
|
||||
|
||||
assert exchange_instance.checkpoint_data.total_token_count <= 200
|
||||
assert len(exchange_instance.messages) == 2
|
||||
|
||||
# Assert that summarized content is the first message
|
||||
first_message = exchange_instance.messages[0]
|
||||
assert first_message.role == "user" or first_message.role == "assistant"
|
||||
assert any("summarized" in content.text.lower() for content in first_message.content)
|
||||
|
||||
# Ensure roles alternate in the output
|
||||
for i in range(1, len(exchange_instance.messages)):
|
||||
assert (
|
||||
exchange_instance.messages[i - 1].role != exchange_instance.messages[i].role
|
||||
), "Messages must alternate between user and assistant"
|
||||
|
||||
|
||||
MESSAGE_SEQUENCE = [
|
||||
Message.user("Hi, can you help me with my homework?"),
|
||||
Message.assistant("Of course! What do you need help with?"),
|
||||
Message.user("I need help with math problems."),
|
||||
Message.assistant("Sure, I can help with that. Let's get started."),
|
||||
Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]),
|
||||
Message.assistant("I'm done calculating the answers to your math questions."),
|
||||
Message.user("Can you also help with my science homework?"),
|
||||
Message.assistant("Yes, I can help with science too."),
|
||||
Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="6", name="speed_of_light", parameters={})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
content=[ToolUse(id="7", name="photon_frequency", parameters={})],
|
||||
),
|
||||
Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]),
|
||||
Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]),
|
||||
Message(
|
||||
role="user",
|
||||
content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")],
|
||||
),
|
||||
Message.assistant("I'm done calculating the answers to your science questions."),
|
||||
Message.user("That's great! How about history?"),
|
||||
Message.assistant("Of course, I can help with history as well."),
|
||||
Message.user("Thanks! You're very helpful."),
|
||||
Message.assistant("You're welcome! I'm here to help."),
|
||||
]
|
||||
|
||||
|
||||
class AnotherMockProvider:
|
||||
def __init__(self):
|
||||
self.sequence = MESSAGE_SEQUENCE
|
||||
self.current_index = 1
|
||||
self.summarize_next = False
|
||||
self.summarized_count = 0
|
||||
|
||||
def complete(self, model, system, messages, tools):
|
||||
system_prompt_tokens = 100
|
||||
input_token_count = system_prompt_tokens
|
||||
|
||||
message = self.sequence[self.current_index]
|
||||
if self.summarize_next:
|
||||
text = "Summary message here"
|
||||
self.summarize_next = False
|
||||
self.summarized_count += 1
|
||||
return Message.assistant(text=text), Usage(
|
||||
# in this case, input tokens can really be whatever
|
||||
input_tokens=40,
|
||||
output_tokens=len(text) * 2,
|
||||
total_tokens=40 + len(text) * 2,
|
||||
)
|
||||
|
||||
if len(messages) > 0 and type(messages[0].content[0]) is ToolResult:
|
||||
raise ValueError("ToolResult should not be the first message")
|
||||
|
||||
if len(messages) == 1 and messages[0].text == "a":
|
||||
# adding a +1 for the "a"
|
||||
return Message.assistant("Getting system prompt size"), Usage(
|
||||
input_tokens=80 + 1, output_tokens=20, total_tokens=system_prompt_tokens + 1
|
||||
)
|
||||
|
||||
for i in range(len(messages)):
|
||||
if type(messages[i].content[0]) in (ToolResult, ToolUse):
|
||||
input_token_count += 10
|
||||
else:
|
||||
input_token_count += len(messages[i].text) * 2
|
||||
|
||||
if type(message.content[0]) in (ToolResult, ToolUse):
|
||||
output_tokens = 10
|
||||
else:
|
||||
output_tokens = len(message.text) * 2
|
||||
|
||||
total_tokens = input_token_count + output_tokens
|
||||
if total_tokens > 300:
|
||||
self.summarize_next = True
|
||||
usage = Usage(
|
||||
input_tokens=input_token_count,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
self.current_index += 2
|
||||
return message, usage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_exchange_instance():
|
||||
ex = Exchange(
|
||||
provider=AnotherMockProvider(),
|
||||
model="test-model",
|
||||
system="test-system",
|
||||
moderator=ContextSummarizer(max_tokens=300),
|
||||
# TODO: make it work with an offset so we don't have to send off requests basically
|
||||
# at every generate step
|
||||
)
|
||||
return ex
|
||||
|
||||
|
||||
def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange):
|
||||
i = 0
|
||||
while i < len(MESSAGE_SEQUENCE):
|
||||
next_message = MESSAGE_SEQUENCE[i]
|
||||
conversation_exchange_instance.add(next_message)
|
||||
message = conversation_exchange_instance.generate()
|
||||
if message.text != "Summary message here":
|
||||
i += 2
|
||||
checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints
|
||||
assert conversation_exchange_instance.checkpoint_data.total_token_count == 570
|
||||
assert len(checkpoints) == 10
|
||||
assert len(conversation_exchange_instance.messages) == 10
|
||||
assert checkpoints[0].start_index == 20
|
||||
assert checkpoints[0].end_index == 20
|
||||
assert checkpoints[-1].start_index == 29
|
||||
assert checkpoints[-1].end_index == 29
|
||||
assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20
|
||||
assert conversation_exchange_instance.provider.summarized_count == 12
|
||||
assert conversation_exchange_instance.moderator.system_prompt_token_count == 100
|
||||
Reference in New Issue
Block a user