mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
133 lines
5.1 KiB
Python
133 lines
5.1 KiB
Python
import pytest
|
|
from exchange import Exchange
|
|
from exchange.content import ToolResult, ToolUse
|
|
from exchange.message import Message
|
|
from exchange.moderators.truncate import ContextTruncate
|
|
from exchange.providers import Provider, Usage
|
|
|
|
# Token-count constants for this test module.
MAX_TOKENS = 300
# Tokens the fake provider charges for the system prompt on every request.
SYSTEM_PROMPT_TOKENS = 100


def _tool_call(call_id, name, parameters):
    """Build an assistant message that requests a single tool invocation."""
    return Message(
        role="assistant",
        content=[ToolUse(id=call_id, name=name, parameters=parameters)],
    )


def _tool_reply(call_id, output):
    """Build the user message carrying the result for tool call ``call_id``."""
    return Message(role="user", content=[ToolResult(tool_use_id=call_id, output=output)])


# Scripted conversation that strictly alternates user / assistant turns:
# even indices are user messages, odd indices are assistant messages.
MESSAGE_SEQUENCE = [
    # Small talk leading into the math homework.
    Message.user("Hi, can you help me with my homework?"),
    Message.assistant("Of course! What do you need help with?"),
    Message.user("I need help with math problems."),
    Message.assistant("Sure, I can help with that. Let's get started."),
    # Math questions answered through five tool round trips.
    Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
    _tool_call("1", "add", {"a": 2, "b": 2}),
    _tool_reply("1", "4"),
    _tool_call("2", "multiply", {"a": 3, "b": 3}),
    _tool_reply("2", "9"),
    _tool_call("3", "divide", {"a": 9, "b": 5}),
    _tool_reply("3", "1.8"),
    _tool_call("4", "multiply", {"a": 2, "b": 20}),
    _tool_reply("4", "40"),
    _tool_call("5", "divide", {"a": 14, "b": 2}),
    _tool_reply("5", "7"),
    Message.assistant("I'm done calculating the answers to your math questions."),
    # Science questions answered through three tool round trips.
    Message.user("Can you also help with my science homework?"),
    Message.assistant("Yes, I can help with science too."),
    Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
    _tool_call("6", "speed_of_light", {}),
    _tool_reply("6", "299,792,458 m/s"),
    _tool_call("7", "photon_frequency", {}),
    _tool_reply("7", "2.418 x 10^14 Hz"),
    _tool_call("8", "electron_mass", {}),
    _tool_reply("8", "9.10938356 x 10^-31 kg"),
    Message.assistant("I'm done calculating the answers to your science questions."),
    # Closing small talk.
    Message.user("That's great! How about history?"),
    Message.assistant("Of course, I can help with history as well."),
    Message.user("Thanks! You're very helpful."),
    Message.assistant("You're welcome! I'm here to help."),
]
|
|
|
|
|
|
class TruncateLinearProvider(Provider):
    """Scripted provider that replays the assistant turns of ``MESSAGE_SEQUENCE``.

    Every regular call to ``complete`` returns the next odd-indexed entry of
    the sequence together with a synthetic ``Usage``, so token-based moderation
    can be exercised without a real model.
    """

    def __init__(self):
        self.sequence = MESSAGE_SEQUENCE
        # Index 1 is the first assistant turn; assistant turns sit at odd indices.
        self.current_index = 1
        # Not read anywhere in this class; kept so the instance attributes stay the same.
        self.summarize_next = False
        self.summarized_count = 0

    @staticmethod
    def _token_cost(message):
        """Return the synthetic token cost of a single message.

        A message whose first content item is a tool call or tool result costs
        a flat 10 tokens; any other message costs two tokens per character of
        its text.
        """
        if isinstance(message.content[0], (ToolResult, ToolUse)):
            return 10
        return len(message.text) * 2

    def complete(self, model, system, messages, tools):
        """Return the next scripted assistant message and its synthetic usage.

        Args:
            model: Unused; present to satisfy the provider call signature.
            system: Unused; the system prompt is always charged
                ``SYSTEM_PROMPT_TOKENS``.
            messages: Conversation history sent to the provider.
            tools: Unused.

        Returns:
            A ``(Message, Usage)`` tuple.

        Raises:
            ValueError: If the first message in ``messages`` starts with a
                ``ToolResult``.
            IndexError: If the scripted sequence has no reply left.
        """
        if messages and isinstance(messages[0].content[0], ToolResult):
            raise ValueError("ToolResult should not be the first message")

        # A lone "a" message is answered with fixed usage numbers and does not
        # consume a scripted reply (presumably the moderator's system-prompt
        # size probe -- confirm against ContextTruncate).
        if len(messages) == 1 and messages[0].text == "a":
            # adding a +1 for the "a"
            return Message.assistant("Getting system prompt size"), Usage(
                input_tokens=80 + 1, output_tokens=20, total_tokens=SYSTEM_PROMPT_TOKENS + 1
            )

        # Looked up only once a scripted reply is actually needed, so the probe
        # above never depends on the script position.
        reply = self.sequence[self.current_index]

        input_token_count = SYSTEM_PROMPT_TOKENS + sum(self._token_cost(sent) for sent in messages)
        output_tokens = self._token_cost(reply)
        usage = Usage(
            input_tokens=input_token_count,
            output_tokens=output_tokens,
            total_tokens=input_token_count + output_tokens,
        )

        # Skip over the user turn that follows this reply in the script.
        self.current_index += 2
        return reply, usage
|
|
|
|
|
|
@pytest.fixture
def conversation_exchange_instance():
    """Provide an Exchange backed by the scripted provider and a ContextTruncate moderator."""
    scripted_provider = TruncateLinearProvider()
    truncating_moderator = ContextTruncate(max_tokens=500)
    return Exchange(
        provider=scripted_provider,
        model="test-model",
        system="test-system",
        moderator=truncating_moderator,
    )
|
|
|
|
|
|
def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange):
    """Replay the scripted conversation and check that token usage stays bounded."""
    exchange = conversation_exchange_instance
    position = 0
    while position < len(MESSAGE_SEQUENCE):
        # Feed the next scripted user turn and let the exchange produce a reply.
        exchange.add(MESSAGE_SEQUENCE[position])
        reply = exchange.generate()
        # Move on to the next user turn unless the reply was a summary message.
        if reply.text != "Summary message here":
            position += 2
        # The running total must never become exorbitant, and the moderator's
        # system prompt token count must equal 100.
        assert exchange.checkpoint_data.total_token_count < 700
        assert exchange.moderator.system_prompt_token_count == 100
|