mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-30 20:44:28 +01:00
228 lines
9.0 KiB
Python
import pytest
|
|
from exchange import Exchange, Message
|
|
from exchange.content import ToolResult, ToolUse
|
|
from exchange.moderators.passive import PassiveModerator
|
|
from exchange.moderators.summarizer import ContextSummarizer
|
|
from exchange.providers import Usage, Provider
|
|
|
|
|
|
class MockProvider(Provider):
    """Provider stub whose reply is always the fixed text "Summarized content here.".

    Usage is approximated with character counts: input tokens are the summed
    text length of all messages, output tokens the length of the reply text.
    """

    def complete(self, model, system, messages, tools, **kwargs):
        reply_text = "Summarized content here."
        prompt_size = sum(len(m.text) for m in messages)
        reply_size = len(reply_text)

        # Reply as the user when the history is empty or ends with an assistant
        # turn; otherwise reply as the assistant.
        reply_as_user = not messages or messages[-1].role == "assistant"
        reply = Message.user(reply_text) if reply_as_user else Message.assistant(reply_text)

        return reply, Usage(
            input_tokens=prompt_size,
            output_tokens=reply_size,
            total_tokens=prompt_size + reply_size,
        )
|
|
|
|
|
|
@pytest.fixture
def exchange_instance():
    """Exchange backed by MockProvider with a ten-message user/assistant history.

    It uses a PassiveModerator, so the history is neither truncated nor
    summarized unless a test invokes a summarizer explicitly.
    """
    dialogue = [
        ("Hi, can you help me with my homework?", "Of course! What do you need help with?"),
        ("I need help with math problems.", "Sure, I can help with that. Let's get started."),
        ("Can you also help with my science homework?", "Yes, I can help with science too."),
        ("That's great! How about history?", "Of course, I can help with history as well."),
        ("Thanks! You're very helpful.", "You're welcome! I'm here to help."),
    ]

    provider = MockProvider()
    history = []
    for question, answer in dialogue:
        history.append(Message.user(question))
        history.append(Message.assistant(answer))

    return Exchange(
        provider=provider,
        model="test-model",
        system="test-system",
        messages=history,
        moderator=PassiveModerator(),
    )
|
|
|
|
|
|
@pytest.fixture
def summarizer_instance():
    """ContextSummarizer constructed with max_tokens=300."""
    token_limit = 300
    summarizer = ContextSummarizer(max_tokens=token_limit)
    return summarizer
|
|
|
|
|
|
def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer):
    """ContextSummarizer.rewrite collapses the history into a short, alternating one.

    After the summarizer rewrites the exchange, the test checks three things:
    the history holds two messages within 200 tokens, the first one carries the
    summary text, and roles alternate.
    """
    # Pre-checks
    assert len(exchange_instance.messages) == 10

    exchange_instance.generate()

    # the exchange instance has a PassiveModerator so the messages are not truncated nor summarized
    assert len(exchange_instance.messages) == 11
    assert len(exchange_instance.checkpoint_data.checkpoints) == 2

    # we now tell the summarizer to summarize the exchange
    summarizer_instance.rewrite(exchange_instance)

    assert exchange_instance.checkpoint_data.total_token_count <= 200
    assert len(exchange_instance.messages) == 2

    # Assert that summarized content is the first message
    first_message = exchange_instance.messages[0]
    assert first_message.role in ("user", "assistant")
    assert any("summarized" in content.text.lower() for content in first_message.content)

    # Ensure roles alternate in the output: compare each message with its successor.
    messages = exchange_instance.messages
    for previous, current in zip(messages, messages[1:]):
        assert previous.role != current.role, "Messages must alternate between user and assistant"
|
|
|
|
|
|
def _tool_round_trip(call_id, tool_name, arguments, output):
    """Return an assistant ToolUse message followed by the user ToolResult answering it."""
    return [
        Message(
            role="assistant",
            content=[ToolUse(id=call_id, name=tool_name, parameters=arguments)],
        ),
        Message(role="user", content=[ToolResult(tool_use_id=call_id, output=output)]),
    ]


# Scripted 30-message conversation replayed by AnotherMockProvider. Even
# indices are user turns and odd indices assistant turns. The math and science
# sections each contain a run of ToolUse/ToolResult round trips.
MESSAGE_SEQUENCE = [
    Message.user("Hi, can you help me with my homework?"),
    Message.assistant("Of course! What do you need help with?"),
    Message.user("I need help with math problems."),
    Message.assistant("Sure, I can help with that. Let's get started."),
    Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
    *_tool_round_trip("1", "add", {"a": 2, "b": 2}, "4"),
    *_tool_round_trip("2", "multiply", {"a": 3, "b": 3}, "9"),
    *_tool_round_trip("3", "divide", {"a": 9, "b": 5}, "1.8"),
    *_tool_round_trip("4", "multiply", {"a": 2, "b": 20}, "40"),
    *_tool_round_trip("5", "divide", {"a": 14, "b": 2}, "7"),
    Message.assistant("I'm done calculating the answers to your math questions."),
    Message.user("Can you also help with my science homework?"),
    Message.assistant("Yes, I can help with science too."),
    Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
    *_tool_round_trip("6", "speed_of_light", {}, "299,792,458 m/s"),
    *_tool_round_trip("7", "photon_frequency", {}, "2.418 x 10^14 Hz"),
    *_tool_round_trip("8", "electron_mass", {}, "9.10938356 x 10^-31 kg"),
    Message.assistant("I'm done calculating the answers to your science questions."),
    Message.user("That's great! How about history?"),
    Message.assistant("Of course, I can help with history as well."),
    Message.user("Thanks! You're very helpful."),
    Message.assistant("You're welcome! I'm here to help."),
]
|
|
|
|
|
|
class AnotherMockProvider(Provider):
    """Scripted provider that replays the assistant turns of MESSAGE_SEQUENCE.

    Token usage is faked with fixed costs:
    - the system prompt costs 100 tokens;
    - every message whose first content item is a ToolUse or ToolResult costs 10;
    - every other message costs 2 tokens per character.

    When a call's total exceeds 300 tokens, the *next* call returns a summary
    message instead of a scripted reply.
    """

    def __init__(self):
        self.sequence = MESSAGE_SEQUENCE
        # Index of the next scripted reply. It starts at 1 and advances by 2, so
        # it always points at an assistant turn of MESSAGE_SEQUENCE.
        self.current_index = 1
        # Set when the previous scripted call exceeded 300 tokens.
        self.summarize_next = False
        # Number of summary replies produced so far.
        self.summarized_count = 0

    @staticmethod
    def _token_cost(message):
        """Return the fake token cost of one message.

        The cost is 10 for tool content and otherwise 2 per character.
        """
        if isinstance(message.content[0], (ToolResult, ToolUse)):
            return 10
        return len(message.text) * 2

    def complete(self, model, system, messages, tools, **kwargs):
        """Return the next scripted (message, usage) pair, or a summary when one is due.

        Raises:
            ValueError: if the first message's first content item is a ToolResult.
        """
        system_prompt_tokens = 100

        if self.summarize_next:
            text = "Summary message here"
            self.summarize_next = False
            self.summarized_count += 1
            return Message.assistant(text=text), Usage(
                # in this case, input tokens can really be whatever
                input_tokens=40,
                output_tokens=len(text) * 2,
                total_tokens=40 + len(text) * 2,
            )

        if messages and isinstance(messages[0].content[0], ToolResult):
            raise ValueError("ToolResult should not be the first message")

        if len(messages) == 1 and messages[0].text == "a":
            # Probe for the system prompt size (presumably issued by
            # ContextSummarizer); adding a +1 for the "a"
            return Message.assistant("Getting system prompt size"), Usage(
                input_tokens=80 + 1, output_tokens=20, total_tokens=system_prompt_tokens + 1
            )

        # Look up the scripted reply only here. The summary and probe calls
        # above must not fail with IndexError once the script is exhausted.
        message = self.sequence[self.current_index]

        input_token_count = system_prompt_tokens + sum(self._token_cost(msg) for msg in messages)
        output_tokens = self._token_cost(message)
        total_tokens = input_token_count + output_tokens
        if total_tokens > 300:
            self.summarize_next = True
        usage = Usage(
            input_tokens=input_token_count,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
        )
        self.current_index += 2
        return message, usage
|
|
|
|
|
|
@pytest.fixture
def conversation_exchange_instance():
    """Exchange with no initial messages.

    It is driven by AnotherMockProvider and moderated by
    ContextSummarizer(max_tokens=300).
    """
    # TODO: make it work with an offset so we don't have to send off requests basically
    # at every generate step
    provider = AnotherMockProvider()
    summarizer = ContextSummarizer(max_tokens=300)
    return Exchange(
        provider=provider,
        model="test-model",
        system="test-system",
        moderator=summarizer,
    )
|
|
|
|
|
|
def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange):
    """Replay MESSAGE_SEQUENCE through a summarizing Exchange and check its final state."""
    exchange = conversation_exchange_instance
    cursor = 0
    while cursor < len(MESSAGE_SEQUENCE):
        # Feed the user-role message at `cursor`. The mock provider answers with
        # the scripted message that follows it.
        exchange.add(MESSAGE_SEQUENCE[cursor])
        reply = exchange.generate()
        if reply.text == "Summary message here":
            # A summary reply does not consume a scripted turn, so the same
            # message is sent again on the next iteration.
            continue
        cursor += 2

    checkpoint_data = exchange.checkpoint_data
    checkpoints = checkpoint_data.checkpoints
    assert checkpoint_data.total_token_count == 570
    assert len(checkpoints) == 10
    assert len(exchange.messages) == 10

    first_checkpoint, last_checkpoint = checkpoints[0], checkpoints[-1]
    assert first_checkpoint.start_index == 20
    assert first_checkpoint.end_index == 20
    assert last_checkpoint.start_index == 29
    assert last_checkpoint.end_index == 29

    assert checkpoint_data.message_index_offset == 20
    assert exchange.provider.summarized_count == 12
    assert exchange.moderator.system_prompt_token_count == 100
|