mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
764 lines
26 KiB
Python
764 lines
26 KiB
Python
from typing import List, Tuple
|
|
|
|
import pytest
|
|
|
|
from exchange.checkpoint import Checkpoint, CheckpointData
|
|
from exchange.content import Text, ToolResult, ToolUse
|
|
from exchange.exchange import Exchange
|
|
from exchange.message import Message
|
|
from exchange.moderators import PassiveModerator
|
|
from exchange.providers import Provider, Usage
|
|
from exchange.tool import Tool
|
|
|
|
|
|
def dummy_tool() -> str:
    """An example tool"""
    # Constant reply: the tool takes no input and always answers the same.
    response = "dummy response"
    return response
|
|
|
|
|
|
# Oversized payloads handed back by the "output too long" tool tests.
# One character more than 2**20 characters of "x".
too_long_output = "x" * ((1 << 20) + 1)
# 128,000 repetitions of the five-character chunk "word ".
too_long_token_output = "word " * 128_000
|
|
|
|
|
|
def no_overlapping_checkpoints(exchange: Exchange) -> bool:
    """Report whether the exchange's checkpoints are pairwise non-overlapping.

    Every checkpoint is compared with each checkpoint stored after it; the
    result is True only when each earlier checkpoint's ``end_index`` is
    strictly less than the later checkpoint's ``start_index``. An empty or
    single-element checkpoint list yields True.
    """
    recorded = exchange.checkpoint_data.checkpoints
    return all(
        earlier.end_index < later.start_index
        for position, earlier in enumerate(recorded)
        for later in recorded[position + 1 :]
    )
|
|
|
|
|
|
def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]:
    """Convert checkpoints into ``(start_index, end_index)`` tuples, preserving order."""
    pairs = []
    for item in checkpoints:
        pairs.append((item.start_index, item.end_index))
    return pairs
|
|
|
|
|
|
class MockProvider(Provider):
    """Provider stub that replays a scripted sequence of replies.

    The n-th call to ``complete`` returns ``sequence[n]`` together with the
    usage built from ``usage_dicts[n]``. ``call_count`` records how many
    completions have been served so far.
    """

    def __init__(self, sequence: List[Message], usage_dicts: List[dict]):
        # We'll use init to provide a preplanned reply sequence
        self.sequence = sequence
        self.call_count = 0
        self.usage_dicts = usage_dicts

    @staticmethod
    def get_usage(data: dict) -> Usage:
        """Build a ``Usage`` from the ``"usage"`` entry of ``data``.

        Args:
            data: Mapping holding a ``"usage"`` dict that may contain
                ``input_tokens``, ``output_tokens`` and ``total_tokens``.

        Returns:
            A ``Usage`` whose ``total_tokens`` falls back to the sum of input
            and output tokens when it is absent and both parts are present.

        Raises:
            KeyError: If ``data`` has no ``"usage"`` entry.
        """
        # Read instead of pop so the caller's dict is left untouched and the
        # same usage record can be consumed more than once.
        usage = data["usage"]
        input_tokens = usage.get("input_tokens")
        output_tokens = usage.get("output_tokens")
        total_tokens = usage.get("total_tokens")

        if total_tokens is None and input_tokens is not None and output_tokens is not None:
            total_tokens = input_tokens + output_tokens

        return Usage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
        )

    def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Tuple[Message, Usage]:
        """Return the next scripted reply and its usage, ignoring the arguments.

        Returns:
            A ``(message, usage)`` tuple for the current position in the script.

        Raises:
            IndexError: If more completions are requested than were scripted.
        """
        output = self.sequence[self.call_count]
        usage = self.get_usage(self.usage_dicts[self.call_count])
        self.call_count += 1
        return (output, usage)
|
|
|
|
|
|
def test_reply_with_unsupported_tool():
    """A tool call naming an unregistered tool produces an error ToolResult."""
    scripted_replies = [
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="unsupported_tool", parameters={})],
        ),
        Message(
            role="assistant",
            content=[Text(text="Here is the completion after tool call")],
        ),
    ]
    # One independent usage record per scripted reply.
    scripted_usage = [{"usage": {"input_tokens": 12, "output_tokens": 23}} for _ in scripted_replies]
    ex = Exchange(
        provider=MockProvider(sequence=scripted_replies, usage_dicts=scripted_usage),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=(Tool.from_function(dummy_tool),),
        moderator=PassiveModerator(),
    )

    ex.add(Message(role="user", content=[Text(text="test")]))
    ex.reply()

    # The second-to-last message is expected to carry the tool result.
    tool_result = ex.messages[-2].content[0]
    assert isinstance(tool_result, ToolResult)
    assert tool_result.is_error
    assert "no tool exists" in tool_result.output.lower()
|
|
|
|
|
|
def test_invalid_tool_parameters():
    """A tool call whose parameters are not a mapping is reported as an error result."""
    scripted_replies = [
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="dummy_tool", parameters="invalid json")],
        ),
        Message(
            role="assistant",
            content=[Text(text="Here is the completion after tool call")],
        ),
    ]
    scripted_usage = [{"usage": {"input_tokens": 12, "output_tokens": 23}} for _ in scripted_replies]
    ex = Exchange(
        provider=MockProvider(sequence=scripted_replies, usage_dicts=scripted_usage),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(dummy_tool)],
        moderator=PassiveModerator(),
    )

    ex.add(Message(role="user", content=[Text(text="test invalid parameters")]))
    ex.reply()

    # The second-to-last message is expected to carry the tool result.
    tool_result = ex.messages[-2].content[0]
    assert isinstance(tool_result, ToolResult)
    assert tool_result.is_error
    assert "invalid json" in tool_result.output.lower()
|
|
|
|
|
|
def test_max_tool_use_when_limit_reached():
    """Replying with ``max_tool_use=3`` stops after three tool-calling completions."""
    # Three consecutive assistant turns, each requesting the dummy tool.
    tool_calls = [
        Message(
            role="assistant",
            content=[ToolUse(id=str(call_id), name="dummy_tool", parameters={})],
        )
        for call_id in (1, 2, 3)
    ]
    usage_records = [{"usage": {"input_tokens": 12, "output_tokens": 23}} for _ in tool_calls]
    ex = Exchange(
        provider=MockProvider(sequence=tool_calls, usage_dicts=usage_records),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(dummy_tool)],
        moderator=PassiveModerator(),
    )

    ex.add(Message(role="user", content=[Text(text="test max tool use")]))

    response = ex.reply(max_tool_use=3)

    assert ex.provider.call_count == 3
    assert "reached the limit" in response.content[0].text.lower()

    # The last tool result recorded must answer the third tool call.
    last_tool_result = ex.messages[-2].content[0]
    assert isinstance(last_tool_result, ToolResult) and last_tool_result.tool_use_id == "3"

    assert ex.messages[-1].role == "assistant"
|
|
|
|
|
|
def test_tool_output_too_long_character_error():
    """A tool returning ``too_long_output`` is answered with a 'too long' error result."""

    def long_output_tool_char() -> str:
        return too_long_output

    scripted_replies = [
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="long_output_tool_char", parameters={})],
        ),
        Message(
            role="assistant",
            content=[Text(text="Here is the completion after tool call")],
        ),
    ]
    scripted_usage = [{"usage": {"input_tokens": 12, "output_tokens": 23}} for _ in scripted_replies]
    ex = Exchange(
        provider=MockProvider(sequence=scripted_replies, usage_dicts=scripted_usage),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(long_output_tool_char)],
        moderator=PassiveModerator(),
    )

    ex.add(Message(role="user", content=[Text(text="test long output char")]))
    ex.reply()

    # The second-to-last message is expected to carry the tool result.
    tool_result = ex.messages[-2].content[0]
    assert isinstance(tool_result, ToolResult)
    assert tool_result.is_error
    assert "output that was too long to handle" in tool_result.output.lower()
|
|
|
|
|
|
def test_tool_output_too_long_token_error():
    """A tool returning ``too_long_token_output`` is answered with a 'too long' error result."""

    def long_output_tool_token() -> str:
        return too_long_token_output

    scripted_replies = [
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="long_output_tool_token", parameters={})],
        ),
        Message(
            role="assistant",
            content=[Text(text="Here is the completion after tool call")],
        ),
    ]
    scripted_usage = [{"usage": {"input_tokens": 12, "output_tokens": 23}} for _ in scripted_replies]
    ex = Exchange(
        provider=MockProvider(sequence=scripted_replies, usage_dicts=scripted_usage),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(long_output_tool_token)],
        moderator=PassiveModerator(),
    )

    ex.add(Message(role="user", content=[Text(text="test long output token")]))
    ex.reply()

    # The second-to-last message is expected to carry the tool result.
    tool_result = ex.messages[-2].content[0]
    assert isinstance(tool_result, ToolResult)
    assert tool_result.is_error
    assert "output that was too long to handle" in tool_result.output.lower()
|
|
|
|
|
|
@pytest.fixture(scope="function")
def normal_exchange() -> Exchange:
    """Fresh, empty exchange whose provider is scripted with five text replies.

    The replies read "Message 1" through "Message 5", each paired with its own
    token-usage record. The exchange starts with an empty ``CheckpointData``.
    """
    replies = [Message(role="assistant", content=[Text(text=f"Message {number}")]) for number in range(1, 6)]
    # (total, input, output) token counts, one triple per scripted reply.
    token_triples = [(10, 5, 5), (28, 10, 18), (33, 28, 5), (40, 32, 8), (50, 40, 10)]
    usage_records = [
        {"usage": {"total_tokens": total, "input_tokens": consumed, "output_tokens": produced}}
        for total, consumed, produced in token_triples
    ]
    provider = MockProvider(sequence=replies, usage_dicts=usage_records)
    return Exchange(
        provider=provider,
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=(Tool.from_function(dummy_tool),),
        moderator=PassiveModerator(),
        checkpoint_data=CheckpointData(),
    )
|
|
|
|
|
|
@pytest.fixture(scope="function")
def resumed_exchange() -> Exchange:
    """Exchange preloaded with three user/assistant turns and no checkpoints.

    The provider is scripted with a single further reply, "Assistant Message 4".
    """
    history = []
    for turn in range(1, 4):
        history.append(Message(role="user", content=[Text(text=f"User message {turn}")]))
        history.append(Message(role="assistant", content=[Text(text=f"Assistant Message {turn}")]))

    next_reply = Message(role="assistant", content=[Text(text="Assistant Message 4")])
    next_usage = {"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}}
    provider = MockProvider(sequence=[next_reply], usage_dicts=[next_usage])

    return Exchange(
        provider=provider,
        messages=history,
        tools=[],
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        checkpoint_data=CheckpointData(),
        moderator=PassiveModerator(),
    )
|
|
|
|
|
|
def test_checkpoints_on_exchange(normal_exchange):
    """Three user/reply rounds leave one single-message checkpoint per message."""
    ex = normal_exchange
    for _ in range(3):
        ex.add(Message(role="user", content=[Text(text="User message")]))
        ex.reply()

    # Check if checkpoints are created correctly
    checkpoints = ex.checkpoint_data.checkpoints
    assert len(checkpoints) == 6
    for index, _message in enumerate(ex.messages):
        # asserting that each message has a corresponding checkpoint
        span = checkpoints[index]
        assert span.start_index == index
        assert span.end_index == index

    # Check if the messages are ordered correctly
    observed_texts = [msg.content[0].text for msg in ex.messages]
    expected_texts = [
        "User message",
        "Message 1",
        "User message",
        "Message 2",
        "User message",
        "Message 3",
    ]
    assert observed_texts == expected_texts
    assert no_overlapping_checkpoints(ex)
|
|
|
|
|
|
def test_checkpoints_on_resumed_exchange(resumed_exchange) -> None:
    """Replying on a resumed exchange yields one checkpoint for the history and one for the reply."""
    ex = resumed_exchange
    ex.pop_last_message()
    ex.reply()

    checkpoints = ex.checkpoint_data.checkpoints
    assert len(checkpoints) == 2
    assert len(ex.messages) == 6

    history_span, reply_span = checkpoints
    # First checkpoint: messages 0-4, carrying the scripted input token count.
    assert history_span.token_count == 32
    assert history_span.start_index == 0
    assert history_span.end_index == 4
    # Second checkpoint: the new reply alone, carrying the output token count.
    assert reply_span.token_count == 8
    assert reply_span.start_index == 5
    assert reply_span.end_index == 5
    assert no_overlapping_checkpoints(ex)
|
|
|
|
|
|
def test_pop_last_checkpoint_on_resumed_exchange(resumed_exchange) -> None:
    """Popping the last checkpoint twice empties a resumed exchange after one reply."""
    ex = resumed_exchange
    follow_up = Message(role="user", content=[Text(text="Assistant Message 4")])
    ex.add(follow_up)
    ex.reply()

    # First pop: expected to leave seven messages and one checkpoint.
    ex.pop_last_checkpoint()
    remaining = (len(ex.messages), len(ex.checkpoint_data.checkpoints))
    assert remaining[0] == 7
    assert remaining[1] == 1

    # Second pop: nothing is left.
    ex.pop_last_checkpoint()
    assert len(ex.messages) == 0
    assert len(ex.checkpoint_data.checkpoints) == 0
    assert no_overlapping_checkpoints(ex)
|
|
|
|
|
|
def test_pop_last_checkpoint_on_normal_exchange(normal_exchange) -> None:
    """Popping last checkpoints keeps message and checkpoint counts in step."""
    ex = normal_exchange
    for _ in range(2):
        ex.add(Message(role="user", content=[Text(text="User message")]))
        ex.reply()
    for _ in range(2):
        ex.pop_last_checkpoint()

    assert len(ex.messages) == 2
    assert len(ex.checkpoint_data.checkpoints) == 2
    assert no_overlapping_checkpoints(ex)

    # Add a user message, then pop the last checkpoint before any reply.
    ex.add(Message(role="user", content=[Text(text="User message")]))
    ex.pop_last_checkpoint()
    assert len(ex.messages) == 1
    assert len(ex.checkpoint_data.checkpoints) == 1

    ex.reply()
    assert len(ex.messages) == 2
    assert len(ex.checkpoint_data.checkpoints) == 2
    assert no_overlapping_checkpoints(ex)
|
|
|
|
|
|
def test_pop_first_message_no_messages():
    """Popping the first message of an empty exchange raises ValueError."""
    empty_provider = MockProvider(sequence=[], usage_dicts=[])
    ex = Exchange(
        provider=empty_provider,
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(dummy_tool)],
        moderator=PassiveModerator(),
    )

    with pytest.raises(ValueError) as excinfo:
        ex.pop_first_message()
    # Checked after the context manager exits so the assertion actually runs.
    assert str(excinfo.value) == "There are no messages to pop"
|
|
|
|
|
|
def test_pop_first_message_checkpoint_with_many_messages(resumed_exchange):
    """Pop first messages one by one through a multi-message checkpoint.

    After the reply there are two checkpoints, (0, 4) and (5, 5). The first pop
    removes the (0, 4) checkpoint. Each further pop only advances
    ``message_index_offset`` until the final message is removed, which clears
    the remaining checkpoint and resets the offset.
    """
    ex = resumed_exchange
    ex.pop_last_message()
    ex.reply()

    data = ex.checkpoint_data
    assert len(ex.messages) == 6
    assert len(data.checkpoints) == 2
    assert data.checkpoints[0].start_index == 0
    assert data.checkpoints[0].end_index == 4
    assert data.checkpoints[1].start_index == 5
    assert data.checkpoints[1].end_index == 5
    assert data.message_index_offset == 0
    assert no_overlapping_checkpoints(ex)

    # Pops 1..5: only the (5, 5) checkpoint survives; the offset counts the pops.
    for popped in range(1, 6):
        ex.pop_first_message()

        assert len(ex.messages) == 6 - popped
        assert len(ex.checkpoint_data.checkpoints) == 1
        assert ex.checkpoint_data.checkpoints[0].start_index == 5
        assert ex.checkpoint_data.checkpoints[0].end_index == 5
        assert ex.checkpoint_data.message_index_offset == popped
        assert no_overlapping_checkpoints(ex)

    # Sixth pop removes the last message and everything resets.
    ex.pop_first_message()

    assert len(ex.messages) == 0
    assert len(ex.checkpoint_data.checkpoints) == 0
    assert ex.checkpoint_data.message_index_offset == 0
    assert no_overlapping_checkpoints(ex)

    with pytest.raises(ValueError) as excinfo:
        ex.pop_first_message()

    assert str(excinfo.value) == "There are no messages to pop"
|
|
|
|
|
|
def test_varied_message_manipulation(normal_exchange):
    """Interleave adds, replies and the various pops, checking bookkeeping after each step."""
    ex = normal_exchange

    def expect(message_count, checkpoint_count, offset=None, pairs=None):
        # Shared assertions; offset and index pairs are only checked when given.
        assert len(ex.messages) == message_count
        assert len(ex.checkpoint_data.checkpoints) == checkpoint_count
        if offset is not None:
            assert ex.checkpoint_data.message_index_offset == offset
        if pairs is not None:
            assert no_overlapping_checkpoints(ex)
            assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == pairs

    ex.add(Message(role="user", content=[Text(text="User message 1")]))
    ex.reply()

    ex.pop_first_message()

    ex.add(Message(role="user", content=[Text(text="User message 2")]))
    ex.reply()

    # (start, end)
    # (1, 1), (2, 2), (3, 3)
    # actual_index_in_messages_arr = any checkpoint index - offset
    expect(3, 3, offset=1, pairs=[(1, 1), (2, 2), (3, 3)])

    ex.pop_last_message()
    expect(2, 2, offset=1, pairs=[(1, 1), (2, 2)])

    ex.add(Message(role="assistant", content=[Text(text="Assistant message")]))
    ex.add(Message(role="user", content=[Text(text="User message 3")]))
    ex.reply()
    expect(5, 4, offset=1, pairs=[(1, 1), (2, 2), (3, 4), (5, 5)])

    ex.pop_last_checkpoint()
    expect(4, 3, offset=1, pairs=[(1, 1), (2, 2), (3, 4)])

    ex.pop_first_message()
    expect(3, 2, offset=2, pairs=[(2, 2), (3, 4)])

    ex.pop_last_message()
    expect(2, 1, offset=2, pairs=[(2, 2)])

    ex.pop_last_message()
    expect(1, 1, offset=2, pairs=[(2, 2)])

    ex.add(Message(role="assistant", content=[Text(text="Assistant message")]))
    ex.add(Message(role="user", content=[Text(text="User message 5")]))
    ex.pop_last_checkpoint()
    expect(0, 0)

    ex.add(Message(role="user", content=[Text(text="User message 6")]))
    ex.reply()
    expect(2, 2, offset=2, pairs=[(2, 2), (3, 3)])

    ex.pop_last_message()
    expect(1, 1, offset=2, pairs=[(2, 2)])

    ex.pop_first_message()
    expect(0, 0, offset=0)

    ex.add(Message(role="user", content=[Text(text="User message 7")]))
    ex.pop_last_message()
    expect(0, 0, offset=0)
|
|
|
|
|
|
def test_pop_last_message_when_no_checkpoints_but_messages_present(normal_exchange):
    """Popping the only, not-yet-checkpointed message leaves the exchange empty."""
    ex = normal_exchange
    lone_message = Message(role="user", content=[Text(text="User message")])
    ex.add(lone_message)

    ex.pop_last_message()

    data = ex.checkpoint_data
    assert len(ex.messages) == 0
    assert len(data.checkpoints) == 0
    assert data.message_index_offset == 0
|
|
|
|
|
|
def test_pop_first_message_when_no_checkpoints_but_message_present(normal_exchange):
    """Popping the first message before any checkpoint exists raises ValueError."""
    ex = normal_exchange
    lone_message = Message(role="user", content=[Text(text="User message")])
    ex.add(lone_message)

    with pytest.raises(ValueError) as excinfo:
        ex.pop_first_message()

    assert str(excinfo.value) == "There must be at least one checkpoint to pop the first message"
|
|
|
|
|
|
def test_pop_first_checkpoint_size_n(resumed_exchange):
    """Popping a multi-message first checkpoint advances the offset by its size."""
    ex = resumed_exchange
    ex.pop_last_message()  # needed because the last message is an assistant message
    ex.reply()

    # Expected (offset, checkpoint count, message count) after each pop.
    for offset, checkpoint_count, message_count in ((5, 1, 1), (0, 0, 0)):
        ex.pop_first_checkpoint()
        assert ex.checkpoint_data.message_index_offset == offset
        assert len(ex.checkpoint_data.checkpoints) == checkpoint_count
        assert len(ex.messages) == message_count
|
|
|
|
|
|
def test_pop_first_checkpoint_size_1(normal_exchange):
    """Popping single-message first checkpoints removes one message at a time."""
    ex = normal_exchange
    ex.add(Message(role="user", content=[Text(text="User message")]))
    ex.reply()

    # Expected (offset, checkpoint count, message count) after each pop.
    for offset, checkpoint_count, message_count in ((1, 1, 1), (0, 0, 0)):
        ex.pop_first_checkpoint()
        assert ex.checkpoint_data.message_index_offset == offset
        assert len(ex.checkpoint_data.checkpoints) == checkpoint_count
        assert len(ex.messages) == message_count
|
|
|
|
|
|
def test_pop_first_checkpoint_no_checkpoints(normal_exchange):
    """Popping the first checkpoint of an exchange without checkpoints raises ValueError."""
    ex = normal_exchange

    with pytest.raises(ValueError) as excinfo:
        ex.pop_first_checkpoint()

    raised_text = str(excinfo.value)
    assert raised_text == "There are no checkpoints to pop"
|
|
|
|
|
|
def test_prepend_checkpointed_message_empty_exchange(normal_exchange):
    """Prepend a checkpointed message to an empty exchange, and again after a pop."""
    ex = normal_exchange

    def prepend_assistant_message():
        # The second argument (10) is presumably the token count - confirm against Exchange.
        ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10)

    prepend_assistant_message()

    assert ex.checkpoint_data.message_index_offset == 0
    assert len(ex.checkpoint_data.checkpoints) == 1
    only_checkpoint = ex.checkpoint_data.checkpoints[0]
    assert only_checkpoint.start_index == 0
    assert only_checkpoint.end_index == 0

    ex.add(Message(role="user", content=[Text(text="User message")]))
    ex.reply()

    assert ex.checkpoint_data.message_index_offset == 0
    assert len(ex.checkpoint_data.checkpoints) == 3
    assert len(ex.messages) == 3
    assert no_overlapping_checkpoints(ex)

    ex.pop_first_checkpoint()

    assert ex.checkpoint_data.message_index_offset == 1
    assert len(ex.checkpoint_data.checkpoints) == 2
    assert len(ex.messages) == 2
    assert no_overlapping_checkpoints(ex)

    # Prepending again is expected to bring the offset back to zero.
    prepend_assistant_message()
    assert ex.checkpoint_data.message_index_offset == 0
    assert len(ex.checkpoint_data.checkpoints) == 3
    assert len(ex.messages) == 3
    assert no_overlapping_checkpoints(ex)
|
|
|
|
|
|
def test_generate_successful_response_on_first_try(normal_exchange):
    """``generate`` succeeds on the first attempt and records the scripted reply.

    Besides checking that ``generate`` does not raise, the test verifies that
    the provider's first scripted reply ("Message 1") is appended after the
    user message, so a silently dropped reply makes the test fail.
    """
    ex = normal_exchange
    ex.add(Message(role="user", content=[Text(text="Hello")]))
    ex.generate()

    # The user message plus the generated assistant reply.
    assert len(ex.messages) == 2
    assert ex.messages[-1].content[0].text == "Message 1"
|
|
|
|
|
|
def test_rewind_in_normal_exchange(normal_exchange):
    """``rewind`` on an empty exchange, with a trailing user message, and with a trailing reply."""
    ex = normal_exchange

    def counts():
        # (number of messages, number of checkpoints)
        return len(ex.messages), len(ex.checkpoint_data.checkpoints)

    # Rewinding an empty exchange leaves it empty.
    ex.rewind()

    message_count, checkpoint_count = counts()
    assert message_count == 0
    assert checkpoint_count == 0

    ex.add(Message(role="user", content=[Text("Hello")]))
    ex.generate()
    ex.add(Message(role="user", content=[Text("Hello")]))

    # testing if it works with a user text message at the end
    ex.rewind()

    message_count, checkpoint_count = counts()
    assert message_count == 2
    assert checkpoint_count == 2

    ex.add(Message(role="user", content=[Text("Hello")]))
    ex.generate()

    # testing if it works with a non-user text message at the end
    ex.rewind()

    message_count, checkpoint_count = counts()
    assert message_count == 2
    assert checkpoint_count == 2
|
|
|
|
|
|
def test_rewind_with_tool_usage():
    """``rewind`` after a tool-using turn leaves only the first user/assistant pair."""
    # simulating a real exchange with tool usage
    scripted_replies = [
        Message.assistant("Hello!"),
        Message(
            role="assistant",
            content=[ToolUse(id="1", name="dummy_tool", parameters={})],
        ),
        Message(
            role="assistant",
            content=[ToolUse(id="2", name="dummy_tool", parameters={})],
        ),
        Message.assistant("Done!"),
    ]
    # (input, output) token counts, one pair per scripted reply.
    token_pairs = [(12, 23), (27, 44), (50, 56), (60, 76)]
    scripted_usage = [
        {"usage": {"input_tokens": consumed, "output_tokens": produced}} for consumed, produced in token_pairs
    ]
    ex = Exchange(
        provider=MockProvider(sequence=scripted_replies, usage_dicts=scripted_usage),
        model="gpt-4o-2024-05-13",
        system="You are a helpful assistant.",
        tools=[Tool.from_function(dummy_tool)],
        moderator=PassiveModerator(),
    )
    for prompt in ("test", "kick it off!"):
        ex.add(Message(role="user", content=[Text(text=prompt)]))
        ex.reply()

    # removing the last message to simulate not getting a response
    ex.pop_last_message()

    # calling rewind to last user message
    ex.rewind()

    assert len(ex.messages) == 2
    assert len(ex.checkpoint_data.checkpoints) == 2
    assert no_overlapping_checkpoints(ex)

    first_message, second_message = ex.messages[0], ex.messages[1]
    assert first_message.content[0].text == "test"
    assert type(second_message.content[0]) is Text
    assert second_message.role == "assistant"
|